Prestaties inschatten met cross-validatie

Modelleren met tidymodels in R

David Svancer

Data Scientist

Trainings- en testdatasets

 

Trainings- en testdatasets maken is de eerste stap in het modelleerproces

  • Beschermt tegen overfitting
    • Trainingsdata voor modelleren
    • Testdata voor evaluatie

 

Nadeel

  • Slechts één schatting van modelprestatie

 

Partitioneringsdiagram voor het maken van trainings- en testdatasets

Modelleren met tidymodels in R

K-voudige cross-validatie

Resamplingtechniek om modelprestatie te verkennen

  • Levert K schattingen van modelprestatie tijdens het fitten

 

Verdeling van data in trainings- en testdatasets

Modelleren met tidymodels in R

K-voudige cross-validatie

Resamplingtechniek om modelprestatie te verkennen

  • Levert K schattingen van modelprestatie tijdens het fitten
  • Trainingsdata willekeurig verdeeld in K sets van ongeveer gelijke grootte
  • Folds worden gebruikt voor K iteraties van fitten en evalueren

 

Verdeling van trainingsdata in cross-validatiefolds

Modelleren met tidymodels in R

Machine learning met cross-validatie

5-voudige cross-validatie uitvoeren

  • Vijf iteraties van modeltraining en -evaluatie

 

Iteratie één van vijfvoudige cross-validatie

Modelleren met tidymodels in R

Machine learning met cross-validatie

5-voudige cross-validatie uitvoeren

  • Vijf iteraties van modeltraining en -evaluatie
  • Iteratie 1
    • Fold 1 gereserveerd voor modelevaluatie en folds 2 t/m 5 voor training

 

Iteratie één van vijfvoudige cross-validatie

Modelleren met tidymodels in R

Machine learning met cross-validatie

5-voudige cross-validatie uitvoeren

  • Vijf iteraties van modeltraining en -evaluatie
  • Iteratie 1
    • Fold 1 gereserveerd voor modelevaluatie en folds 2 t/m 5 voor training
  • Iteratie 2
    • Fold 2 gereserveerd voor modelevaluatie

 

Iteratie twee van vijfvoudige cross-validatie

Modelleren met tidymodels in R

Machine learning met cross-validatie

5-voudige cross-validatie uitvoeren

  • Vijf iteraties van modeltraining en -evaluatie
  • Iteratie 1
    • Fold 1 gereserveerd voor modelevaluatie en folds 2 t/m 5 voor training
  • Iteratie 2
    • Fold 2 gereserveerd voor modelevaluatie

 

In totaal vijf schattingen van modelprestatie

 

Iteratie vijf van vijfvoudige cross-validatie

Modelleren met tidymodels in R

Cross-validatiefolds maken

De functie vfold_cv()

  • Trainingsdata
  • Aantal folds, v
  • Stratificatievariabele, strata
  • Voer set.seed() uit vóór vfold_cv() voor reproduceerbaarheid
  • splits
    • List-kolom met datasplit-objecten voor het maken van folds
set.seed(214)
leads_folds <- vfold_cv(leads_training,

v = 10,
strata = purchased)
leads_folds
#  10-fold cross-validation using stratification 
# A tibble: 10 x 2
   splits            id    
   <list>            <chr> 
 1 <split [896/100]> Fold01
 2 <split [896/100]> Fold02
 3 <split [896/100]> Fold03
 . ................  ......
 9 <split [897/99]>  Fold09
10 <split [897/99]>  Fold10
Modelleren met tidymodels in R

Modeltraining met cross-validatie

De functie fit_resamples()

  • Train een parsnip-model of workflow-object
  • Geef cross-validatiefolds op, resamples
  • Optionele aangepaste metriekfunctie, metrics
    • Standaard: accuracy en ROC AUC

 

Elke metriek wordt 10 keer geschat

  • Eén schatting per fold
  • Gemiddelde in kolom mean
leads_rs_fit <- leads_wkfl %>%

fit_resamples(resamples = leads_folds,
metrics = leads_metrics)
leads_rs_fit %>% collect_metrics()
# A tibble: 3 x 5
  .metric .estimator  mean     n std_err
  <chr>   <chr>      <dbl> <int>   <dbl>
1 roc_auc binary     0.823    10  0.0147
2 sens    binary     0.786    10  0.0203
3 spec    binary     0.855    10  0.0159
Modelleren met tidymodels in R

Gedetailleerde cross-validatieresultaten

De functie collect_metrics()

  • Met summarize = FALSE krijg je alle metriek-schattingen per cross-validatiefold
  • 30 combinaties totaal (3 metrieken x 10 folds)
    • Kolom .metric geeft de metriek aan
    • Kolom .estimate geeft de schatting per fold
rs_metrics <- leads_rs_fit %>% 
  collect_metrics(summarize = FALSE)

rs_metrics
# A tibble: 30 x 4
   id     .metric .estimator .estimate
   <chr>  <chr>   <chr>          <dbl>
 1 Fold01 sens    binary         0.861
 2 Fold01 spec    binary         0.891
 3 Fold01 roc_auc binary         0.885
 4 Fold02 sens    binary         0.778
 5 Fold02 spec    binary         0.969
 6 Fold02 roc_auc binary         0.885
# ... with 24 more rows
Modelleren met tidymodels in R

Cross-validatieresultaten samenvatten

collect_metrics() geeft een tibble terug

  • Resultaten kun je samenvatten met dplyr
    • Begin met rs_metrics
    • Groepeer op .metric
    • Bereken samenvattende statistieken met summarize()
rs_metrics %>%

group_by(.metric) %>%
summarize(min = min(.estimate), median = median(.estimate), max = max(.estimate), mean = mean(.estimate), sd = sd(.estimate))
# A tibble: 3 x 6
 .metric   min  median   max   mean     sd
  <chr>   <dbl>  <dbl>  <dbl>  <dbl>   <dbl>
1 roc_auc 0.758  0.806  0.885  0.823   0.0466
2 sens    0.667  0.792  0.861  0.786   0.0642
3 spec    0.810  0.843  0.969  0.855   0.0502
Modelleren met tidymodels in R

Cross-validatiemethodologie

Modellen getraind met fit_resamples() kunnen geen voorspellingen doen op nieuwe data

  • predict() accepteert geen resample-objecten

Doel van fit_resample()

  • De prestatieprofielen van verschillende modeltypen verkennen en vergelijken
  • Beste modeltype kiezen en daarop focussen bij het fitten
predict(leads_rs_fit,
        new_data = leads_test)

Error in UseMethod("predict") : 
  no applicable method for 'predict' applied to 
  an object of class 
  "c('resample_results', 
      'tune_results',  
      'tbl_df', 
      'tbl', 'data.frame')"
Modelleren met tidymodels in R

Laten we cross-validaten!

Modelleren met tidymodels in R

Preparing Video For Download...