Predict and evaluate

Machine Learning with Tree-Based Models in R

Sandro Raabe

Data Scientist

Predicting on new data

General call:
predict(model, new_data, type)
Arguments:
  1. Trained model
  2. Dataset to predict on
  3. Prediction type: "class" for labels or "prob" for probabilities

# Predict class labels
predict(model, new_data = test_data,
        type = "class")
  .pred_class
  <fct>      
1 no         
2 no         
3 yes        
4 no

# Predict class probabilities
predict(model, new_data = test_data,
        type = "prob")
     .pred_no  .pred_yes
     <dbl>     <dbl>
1    0.866     0.134
2    0.956     0.044
3    0.672     0.328
4    0.877     0.123
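
The calls above assume that a trained `model` and a `test_data` frame already exist. As a minimal sketch of where they could come from, the example below fits a decision tree with parsnip and the rpart engine. The `mtcars` data, the recoded `am` outcome and the 24/8 split are assumptions made for illustration and are not the course data.

```r
# Assumption: mtcars stands in for the course data; `am` is recoded to no/yes.
library(parsnip)  # decision_tree(), set_engine(), set_mode(), fit()

cars <- mtcars
cars$am <- factor(ifelse(cars$am == 1, "yes", "no"), levels = c("no", "yes"))

# Simple reproducible split: 24 rows for training, 8 for testing
set.seed(1)
train_rows <- sample(nrow(cars), 24)
train_data <- cars[train_rows, ]
test_data  <- cars[-train_rows, ]

# Argument 1: a trained model
spec  <- set_mode(set_engine(decision_tree(), "rpart"), "classification")
model <- fit(spec, am ~ ., data = train_data)

# Arguments 2 and 3: the data to predict on and the prediction type
class_preds <- predict(model, new_data = test_data, type = "class")
prob_preds  <- predict(model, new_data = test_data, type = "prob")
```

Both results are tibbles with one row per row of `test_data`: `class_preds` holds a single `.pred_class` column, and `prob_preds` holds one `.pred_<level>` column per outcome level.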

Confusion matrix

  • Cross-tabulates predicted classes against true classes, revealing which classes the model confuses
  • Diagonal: correct predictions
  • Off-diagonal: incorrect predictions
  • TP (true positive): prediction is yes, truth is yes
  • TN (true negative): prediction is no, truth is no
  • FP (false positive): prediction is yes, truth is no
  • FN (false negative): prediction is no, truth is yes
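
The four cells can be counted directly from a vector of predictions and a vector of true classes. The sketch below uses base R only; the six example observations are made up for illustration and do not come from the course data.

```r
# Hypothetical predictions and true classes (illustration only)
truth <- factor(c("no", "yes", "no", "yes", "no", "yes"), levels = c("no", "yes"))
pred  <- factor(c("no", "no", "yes", "yes", "no", "yes"), levels = c("no", "yes"))

tp <- sum(pred == "yes" & truth == "yes")  # true positives
tn <- sum(pred == "no"  & truth == "no")   # true negatives
fp <- sum(pred == "yes" & truth == "no")   # false positives
fn <- sum(pred == "no"  & truth == "yes")  # false negatives

# The same counts as a 2 x 2 table: rows are predictions, columns are truth
cm <- table(Prediction = pred, Truth = truth)
```

Here `tp` and `tn` are both 2, and `fp` and `fn` are both 1, so four of the six predictions sit on the diagonal of `cm`.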


Create the confusion matrix

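library(tidymodels)  # attaches parsnip, dplyr and yardstick
predictions <- predict(model, new_data = test_data, type = "class")

# Combine predictions and truth values
pred_combined <- predictions %>% 
  mutate(true_class = test_data$outcome)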

pred_combined
  .pred_class  true_class
  <fct>        <fct>     
1 no           no        
2 no           yes       
3 no           no        
4 yes          yes
# Calculate the confusion matrix
conf_mat(data = pred_combined,
         estimate = .pred_class,
         truth = true_class)
          Truth
Prediction  no yes
       no  116  31
       yes  12  40
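
To connect the printed matrix to the diagonal and off-diagonal reading, the counts shown above can be rebuilt as a plain base R matrix and summed. This is only a sketch for reading the output; the object name `cm` is an assumption, and the numbers are copied from the matrix printed above.

```r
# Counts copied from the printed confusion matrix (filled column by column)
cm <- matrix(c(116, 12, 31, 40), nrow = 2,
             dimnames = list(Prediction = c("no", "yes"),
                             Truth      = c("no", "yes")))

correct   <- sum(diag(cm))        # diagonal: TN + TP
incorrect <- sum(cm) - correct    # off-diagonal: FN + FP

false_negatives <- cm["no", "yes"]  # predicted no, truth is yes
false_positives <- cm["yes", "no"]  # predicted yes, truth is no
```

With these counts, 156 predictions are correct and 43 are incorrect, of which 31 are false negatives and 12 are false positives.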

Accuracy

  $$\text{accuracy} = \frac{\text{number of correct predictions}}{\text{total number of predictions}} = \frac{TP + TN}{TP + TN + FP + FN}$$

  • Function name: accuracy()
  • Same arguments as conf_mat()
    • data, estimate and truth
    • Common structure in yardstick
accuracy(pred_combined,
         estimate = .pred_class,
         truth = true_class)
  .metric     .estimate
  <chr>           <dbl>
1 accuracy        0.708
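
As a small worked check of the formula, the sketch below rebuilds only the four example rows of `pred_combined` printed earlier and compares yardstick's `accuracy()` with a manual share of matching rows. Using just these four rows is an assumption for illustration, so the result differs from the output above, which is based on the full test set.

```r
library(yardstick)

# The four example rows of pred_combined shown earlier (illustration only)
pred_combined <- data.frame(
  .pred_class = factor(c("no", "no", "no", "yes"), levels = c("no", "yes")),
  true_class  = factor(c("no", "yes", "no", "yes"), levels = c("no", "yes"))
)

# yardstick: same data / truth / estimate structure as conf_mat()
acc <- accuracy(pred_combined, truth = true_class, estimate = .pred_class)

# Manual check: share of rows where prediction and truth agree
manual_acc <- mean(pred_combined$.pred_class == pred_combined$true_class)
```

Three of the four predictions match the truth, so both `acc$.estimate` and `manual_acc` equal 0.75.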

Let's evaluate!
