Validação cruzada

Aprendizado Supervisionado com o scikit-learn

George Boorman

Core Curriculum Manager, DataCamp

Motivação para a validação cruzada

  • O desempenho do modelo depende da maneira como dividimos os dados

  • Não representa a capacidade do modelo de generalizar para dados não vistos

  • Solução: validação cruzada

Aprendizado Supervisionado com o scikit-learn

Noções básicas de validação cruzada

Títulos da tabela: divisão 1, grupo 1, grupo 2, grupo 3, grupo 4 e grupo 5

Aprendizado Supervisionado com o scikit-learn

Noções básicas de validação cruzada

divisão 1 reservada como conjunto de teste

Aprendizado Supervisionado com o scikit-learn

Noções básicas de validação cruzada

grupos 2–5 usados como dados de treinamento

Aprendizado Supervisionado com o scikit-learn

Noções básicas de validação cruzada

calcular a métrica com esses grupos

Aprendizado Supervisionado com o scikit-learn

Noções básicas de validação cruzada

Grupo 2 como dados de teste

Aprendizado Supervisionado com o scikit-learn

Noções básicas de validação cruzada

Grupos 1, 3, 4 e 5 como dados de treinamento

Aprendizado Supervisionado com o scikit-learn

Noções básicas de validação cruzada

calcular a métrica novamente

Aprendizado Supervisionado com o scikit-learn

Noções básicas de validação cruzada

repetir com o terceiro grupo

Aprendizado Supervisionado com o scikit-learn

Noções básicas de validação cruzada

repetir com o quarto grupo

Aprendizado Supervisionado com o scikit-learn

Noções básicas de validação cruzada

repetir com o quinto grupo

Aprendizado Supervisionado com o scikit-learn

Validação cruzada e desempenho do modelo

  • 5 grupos = validação cruzada com 5 grupos

  • 10 grupos = validação cruzada com 10 grupos

  • k grupos = validação cruzada com k grupos (k-fold CV)

  • Mais grupos = mais caro do ponto de vista computacional

Aprendizado Supervisionado com o scikit-learn

Validação cruzada no scikit-learn

from sklearn.model_selection import cross_val_score, KFold

kf = KFold(n_splits=6, shuffle=True, random_state=42)
reg = LinearRegression()
cv_results = cross_val_score(reg, X, y, cv=kf)
Aprendizado Supervisionado com o scikit-learn

Avaliação do desempenho da validação cruzada

print(cv_results)
[0.70262578, 0.7659624, 0.75188205, 0.76914482, 0.72551151, 0.73608277]
print(np.mean(cv_results), np.std(cv_results))
0.7418682216666667 0.023330243960652888
print(np.quantile(cv_results, [0.025, 0.975]))
array([0.7054865, 0.76874702])
Aprendizado Supervisionado com o scikit-learn

Vamos praticar!

Aprendizado Supervisionado com o scikit-learn

Preparing Video For Download...