Explainability in tree-based models

Explainable AI in Python

Fouad Trad

Machine Learning Engineer

Decision tree

  • Fundamental building block of tree-based models
  • Used for regression and classification
  • Tree-like structure for predictions
    • Several decisions
    • Each decision is based on one feature
  • Inherently explainable
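
Because each decision tests a single feature, a fitted tree can be read as a set of if/else rules. The sketch below shows this with scikit-learn's `export_text`. It assumes a hypothetical synthetic dataset from `make_classification`, not the admissions data, and hypothetical feature names `feature_0` to `feature_3`. A shallow `max_depth` keeps the rules short.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical synthetic data standing in for a real dataset
X, y = make_classification(n_samples=200, n_features=4, n_informative=2,
                           n_redundant=0, random_state=0)

# A shallow tree keeps the rule set short enough to read
tree = DecisionTreeClassifier(max_depth=2, random_state=0)
tree.fit(X, y)

# Each line is one decision on a single feature; leaves show the predicted class
rules = export_text(tree, feature_names=[f"feature_{i}" for i in range(4)])
print(rules)
```

Each path from the root to a `class:` leaf is a complete explanation of one prediction.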

Representation of a decision tree showing how predictions are made based on multiple conditions

Random forest

  • Consists of many decision trees
  • Used for regression and classification
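  • Harder to interpret directly than a single tree
  • Feature importance
    • Measures how much each feature reduces impurity (uncertainty) across the splits that use it
    • Different from coefficients in linear models: importances are non-negative and do not indicate the direction of a feature's effect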
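
A forest's importance can be understood as an aggregate of its trees' importances. The sketch below checks this under an assumption about scikit-learn's implementation: the forest's `feature_importances_` is the mean of the per-tree impurity-based importances, taken over trees that actually split, then renormalized to sum to 1. The data is a hypothetical synthetic set with six features.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Hypothetical synthetic data with six features, like the admissions table
X, y = make_classification(n_samples=300, n_features=6, n_informative=3,
                           random_state=0)

forest = RandomForestClassifier(n_estimators=50, random_state=0)
forest.fit(X, y)

# Per-tree impurity-based importances (each tree's vector sums to 1)
per_tree = [est.feature_importances_ for est in forest.estimators_
            if est.tree_.node_count > 1]  # skip trees that never split

# Average across trees, then renormalize so the result sums to 1
averaged = np.mean(per_tree, axis=0)
averaged /= averaged.sum()

print(np.allclose(averaged, forest.feature_importances_))
# Unlike linear-model coefficients, these values are non-negative and carry no sign
print(forest.feature_importances_.min() >= 0, forest.feature_importances_.sum())
```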

Image showing a random forest as a collection of trees that each receive the same instance; their predictions are aggregated into a final prediction.

Admissions dataset

GRE Score  TOEFL Score  University Rating  SOP  LOR  CGPA  Accept
337        118          4                  4.5  4.5  9.65  1
324        107          4                  4    4.5  8.87  1
316        104          3                  3    3.5  8.00  1
322        110          3                  3.5  2.5  8.67  1
314        103          2                  2    3    8.21  0

X_train = data.drop(['Accept'], axis=1)
y_train = data['Accept']
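
The slide does not show how `data` is loaded. To make the split runnable on its own, this sketch builds `data` by hand from only the five sample rows in the table. It assumes the columns appear in the table's order. With the real dataset you would load the full file instead; its name is not given here.

```python
import pandas as pd

# Only the five sample rows from the table; the full dataset is not included here
data = pd.DataFrame(
    {
        "GRE Score": [337, 324, 316, 322, 314],
        "TOEFL Score": [118, 107, 104, 110, 103],
        "University Rating": [4, 4, 3, 3, 2],
        "SOP": [4.5, 4, 3, 3.5, 2],
        "LOR": [4.5, 4.5, 3.5, 2.5, 3],
        "CGPA": [9.65, 8.87, 8.00, 8.67, 8.21],
        "Accept": [1, 1, 1, 1, 0],
    }
)

# Features are every column except the target; the target is 'Accept'
X_train = data.drop(["Accept"], axis=1)
y_train = data["Accept"]
print(X_train.shape, y_train.tolist())
```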

Model training

from sklearn.tree import DecisionTreeClassifier

# Fit a single decision tree; feature_importances_ follows the column order of X_train
tree_model = DecisionTreeClassifier()
tree_model.fit(X_train, y_train)
print(tree_model.feature_importances_)
[0.17936982 0.08878744 0.04388924 
 0.0532897  0.07130751 0.56335628]
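
The numbers printed for a single tree can be reproduced from the fitted tree's internal arrays. Each split credits its feature with the impurity decrease it achieves, weighted by the number of samples reaching the node. The totals are then normalized to sum to 1. This is a sketch of scikit-learn's mean-decrease-in-impurity definition, assuming its `tree_` attributes (`children_left`, `children_right`, `feature`, `impurity`, `weighted_n_node_samples`). It uses hypothetical synthetic data rather than the admissions set.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Hypothetical synthetic data; any fitted tree works the same way
X, y = make_classification(n_samples=300, n_features=6, n_informative=3,
                           random_state=0)
model = DecisionTreeClassifier(random_state=0).fit(X, y)

t = model.tree_
w = t.weighted_n_node_samples  # (weighted) number of samples reaching each node
manual = np.zeros(X.shape[1])

for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == right:  # leaf: both children are -1, so there is no split to credit
        continue
    # Impurity decrease of this split, weighted by the samples it handles
    decrease = (w[node] * t.impurity[node]
                - w[left] * t.impurity[left]
                - w[right] * t.impurity[right])
    manual[t.feature[node]] += decrease

manual /= manual.sum()  # normalize so the importances sum to 1
print(np.allclose(manual, model.feature_importances_))
```

A feature used near the root, where many samples are split, tends to collect a large share. This is consistent with CGPA dominating the tree's output above.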
from sklearn.ensemble import RandomForestClassifier

# Fit an ensemble of trees; its importances aggregate those of the individual trees
forest_model = RandomForestClassifier()
forest_model.fit(X_train, y_train)
print(forest_model.feature_importances_)
[0.25347149 0.17518662 0.06551317 
 0.06758647 0.07866478 0.35957747]
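
The raw arrays are printed in column order without labels. Pairing them with the column names makes the two models easier to compare. This sketch copies the printed values and assumes they follow the table's column order. With fitted models you could pass `tree_model.feature_importances_` and `forest_model.feature_importances_` directly, indexed by `X_train.columns`.

```python
import pandas as pd

columns = ["GRE Score", "TOEFL Score", "University Rating", "SOP", "LOR", "CGPA"]

# Importance values copied from the printed outputs (same order as X_train.columns)
tree_imp = [0.17936982, 0.08878744, 0.04388924, 0.0532897, 0.07130751, 0.56335628]
forest_imp = [0.25347149, 0.17518662, 0.06551317, 0.06758647, 0.07866478, 0.35957747]

importances = pd.DataFrame({"tree": tree_imp, "forest": forest_imp}, index=columns)

# Sort by the forest's ranking so the most important features come first
importances = importances.sort_values("forest", ascending=False)
print(importances.round(3))
```

Both models rank CGPA first and GRE Score second. The forest spreads importance more evenly across the remaining features than the single tree does.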

Feature importance

import matplotlib.pyplot as plt

plt.barh(X_train.columns, 
         tree_model.feature_importances_)
plt.title('Feature Importance - Decision Tree')
plt.show()

A horizontal bar plot of the importances showing that CGPA and GRE score are the most important features.

import matplotlib.pyplot as plt

plt.barh(X_train.columns, 
         forest_model.feature_importances_)
plt.title('Feature Importance - Random Forest')
plt.show()

A horizontal bar plot of the importances showing that CGPA and GRE score are the most important features.
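
The two separate plots can also be combined into one grouped bar chart, which makes the tree-versus-forest comparison direct. This is a sketch using the printed importance values. It assumes matplotlib's non-interactive `Agg` backend so it runs without a display; remove the `matplotlib.use` line and call `plt.show()` to view it interactively.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove to display the plot interactively
import matplotlib.pyplot as plt
import numpy as np

columns = ["GRE Score", "TOEFL Score", "University Rating", "SOP", "LOR", "CGPA"]
# Values copied from the printed importances of the two models
tree_imp = np.array([0.17936982, 0.08878744, 0.04388924, 0.0532897, 0.07130751, 0.56335628])
forest_imp = np.array([0.25347149, 0.17518662, 0.06551317, 0.06758647, 0.07866478, 0.35957747])

y = np.arange(len(columns))
height = 0.4  # bar thickness; two bars share each feature's slot

fig, ax = plt.subplots()
ax.barh(y - height / 2, tree_imp, height=height, label="Decision tree")
ax.barh(y + height / 2, forest_imp, height=height, label="Random forest")
ax.set_yticks(y)
ax.set_yticklabels(columns)
ax.set_xlabel("Importance")
ax.set_title("Feature Importance - Tree vs. Forest")
ax.legend()
fig.tight_layout()
```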

Let's practice!