Kreuzvalidierung

Überwachtes Lernen mit scikit-learn

George Boorman

Core Curriculum Manager, DataCamp

Notwendigkeit der Kreuzvalidierung

  • Leistung eines Modells hängt von der Aufteilung der Daten ab

  • Eventuell nicht repräsentativ für die Fähigkeit des Modells, Regeln für neue Daten zu verallgemeinern

  • Lösung: Kreuzvalidierung!

Überwachtes Lernen mit scikit-learn

Grundlagen der Kreuzvalidierung

Tabellenüberschriften: Zerlegung 1, Teilmenge 1, Teilmenge 2, Teilmenge 3, Teilmenge 4 und Teilmenge 5

Überwachtes Lernen mit scikit-learn

Grundlagen der Kreuzvalidierung

Teilmenge 1 als Testdaten

Überwachtes Lernen mit scikit-learn

Grundlagen der Kreuzvalidierung

Teilmengen 2 bis 5 als Trainingsdaten

Überwachtes Lernen mit scikit-learn

Grundlagen der Kreuzvalidierung

Berechnung der Kennzahl für diese Teilmengen

Überwachtes Lernen mit scikit-learn

Grundlagen der Kreuzvalidierung

Teilmenge 2 als Testdaten

Überwachtes Lernen mit scikit-learn

Grundlagen der Kreuzvalidierung

Teilmengen 1, 3, 4 und 5 als Trainingsdaten

Überwachtes Lernen mit scikit-learn

Grundlagen der Kreuzvalidierung

Erneute Berechnung der Kennzahl

Überwachtes Lernen mit scikit-learn

Grundlagen der Kreuzvalidierung

Wiederholung mit der dritten Teilmenge

Überwachtes Lernen mit scikit-learn

Grundlagen der Kreuzvalidierung

Wiederholung mit der vierten Teilmenge

Überwachtes Lernen mit scikit-learn

Grundlagen der Kreuzvalidierung

Wiederholung mit der fünften Teilmenge

Überwachtes Lernen mit scikit-learn

Kreuzvalidierung und Modellleistung

  • 5 Teilmengen = 5-fache KV

  • 10 Teilmengen = 10-fache KV

  • k Teilmengen = k-fache KV

  • Mehr Teilmengen = höherer Rechenaufwand

Überwachtes Lernen mit scikit-learn

Kreuzvalidierung in scikit-learn

from sklearn.model_selection import cross_val_score, KFold

kf = KFold(n_splits=6, shuffle=True, random_state=42)
reg = LinearRegression()
cv_results = cross_val_score(reg, X, y, cv=kf)
Überwachtes Lernen mit scikit-learn

Auswertung der Ergebnisse der Kreuzvalidierung

print(cv_results)
[0.70262578, 0.7659624, 0.75188205, 0.76914482, 0.72551151, 0.73608277]
print(np.mean(cv_results), np.std(cv_results))
0.7418682216666667 0.023330243960652888
print(np.quantile(cv_results, [0.025, 0.975]))
array([0.7054865, 0.76874702])
Überwachtes Lernen mit scikit-learn

Lass uns üben!

Überwachtes Lernen mit scikit-learn

Preparing Video For Download...