Çapraz doğrulama

R ile Ağaç Tabanlı Modellerle Machine Learning

Sandro Raabe

Data Scientist

k-katlı çapraz doğrulama

veri k katmana bölündü

R ile Ağaç Tabanlı Modellerle Machine Learning

k-katlı çapraz doğrulama

test kümesi olarak kat 1

R ile Ağaç Tabanlı Modellerle Machine Learning

k-katlı çapraz doğrulama

test kümesi olarak kat 2

R ile Ağaç Tabanlı Modellerle Machine Learning

k-katlı çapraz doğrulama

test kümesi olarak kat 3

R ile Ağaç Tabanlı Modellerle Machine Learning

k-katlı çapraz doğrulama

test kümesi olarak kat 4

R ile Ağaç Tabanlı Modellerle Machine Learning

k-katlı çapraz doğrulama

test kümesi olarak kat 5

R ile Ağaç Tabanlı Modellerle Machine Learning

k-katlı çapraz doğrulama

toplam çapraz doğrulanmış MAE

R ile Ağaç Tabanlı Modellerle Machine Learning

Son modeli tüm veri kümesinde eğitme

son modeli tam veri kümesinde eğit

R ile Ağaç Tabanlı Modellerle Machine Learning

Kodlama - Veriyi 10 kez bölme

# Random seed for reproducibility
set.seed(100)

# Create 10 folds of the dataset
chocolate_folds <- vfold_cv(chocolate_train, v = 10)
#  10-fold cross-validation
# A tibble: 10 x 2
  splits             id
1 <split [1293/144]> Fold1
2 <split [1293/144]> Fold2
3 <split [1293/144]> Fold3
4 ...
R ile Ağaç Tabanlı Modellerle Machine Learning

Kodlama - Katları eğitme

# Fit a model for every fold and calculate MAE and RMSE
fits_cv <- fit_resamples(tree_spec,

final_grade ~ .,
resamples = chocolate_folds,
metrics = metric_set(mae, rmse))
# Resampling results
# 10-fold cross-validation
# A tibble: 10 x 4
  splits             id    .metrics
  <list>             <chr> <list>
1 <split [1293/144]> Fold1 <tibble [2 x 4]>
2 <split [1293/144]> Fold2 <tibble [2 x 4]>
3 <split [1293/144]> Fold3 <tibble [2 x 4]>
4 ...
R ile Ağaç Tabanlı Modellerle Machine Learning

Kodlama - Tüm hataları topla

# Collect raw errors of all model runs
all_errors <- collect_metrics(fits_cv, 
                              summarize = FALSE)

print(all_errors)
# A tibble: 20 x 3
   id      .metric  .estimate
   <chr>     <chr>      <dbl>
 1 Fold01      mae      0.362
 2 Fold01     rmse      0.442
 3 Fold02      mae      0.385
 4 Fold02     rmse      0.504
 5 ...
library(ggplot2)
ggplot(all_errors, aes(x = .estimate, 
                       fill = .metric)) +
   geom_histogram()

hata histogramı

R ile Ağaç Tabanlı Modellerle Machine Learning

Kodlama - Eğitimleri özetle

# Collect and summarize errors of all model runs
collect_metrics(fits_cv)
# A tibble: 2 x 3
  .metric   mean      n
  <chr>    <dbl>  <int>
1 mae      0.383     10
2 rmse     0.477     10
R ile Ağaç Tabanlı Modellerle Machine Learning

Hadi çapraz doğrulayalım!

R ile Ağaç Tabanlı Modellerle Machine Learning

Preparing Video For Download...