Cross-Validation

Machine Learning with PySpark

Andrew Collier

Data Scientist, Fathom Data

Full dataset

Dataset split into training and testing sets
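A minimal plain-Python sketch of this split, assuming an 80/20 ratio and a fixed seed so the result is repeatable. Both choices are illustrative, and `train_test_split` is a hypothetical helper, not a Spark function. In PySpark the usual equivalent is `DataFrame.randomSplit(weights, seed)`, which assigns each row independently at random, so its split sizes are only approximately 80/20.

```python
import random

def train_test_split(rows, train_fraction=0.8, seed=13):
    """Hypothetical helper: shuffle a copy of the rows, then cut it in two."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)  # seeded, so the split is repeatable
    cut = int(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]

rows = list(range(10))  # stand-ins for the records in the full dataset
train, test = train_test_split(rows)
print(len(train), len(test))  # 8 2
```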

Training data split into multiple folds
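One way to split the training rows into k folds, sketched in plain Python with a hypothetical helper `assign_folds`. It shuffles the rows and deals them out round-robin, so fold sizes differ by at most one. This is an assumption for illustration: Spark's `CrossValidator` assigns rows to folds at random, so its folds are only roughly equal in size.

```python
import random

def assign_folds(n_rows, k, seed=13):
    """Hypothetical helper: return a fold number (0..k-1) for each row index."""
    order = list(range(n_rows))
    random.Random(seed).shuffle(order)
    folds = [0] * n_rows
    # Deal the shuffled rows out to the folds like cards,
    # so fold sizes differ by at most one.
    for position, row in enumerate(order):
        folds[row] = position % k
    return folds

fold_of = assign_folds(n_rows=23, k=5)
sizes = [fold_of.count(f) for f in range(5)]
print(sizes)  # [5, 5, 5, 4, 4]
```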

Fold upon fold - first fold

First fold

Fold upon fold - second fold

Second fold

Fold upon fold - other folds

Remaining folds
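The rotation through the folds can be written out as a loop. In this sketch the names `cross_validate`, `fit_mean` and `rmse` are hypothetical, and the "model" is a toy that always predicts the mean training label. What matters is the bookkeeping: each fold is held out once, a model is trained on the remaining folds, and the per-fold scores are averaged. That average is the kind of number `CrossValidator` reports in `avgMetrics`.

```python
import math

def rmse(actual, predicted):
    """Root mean squared error of paired labels and predictions."""
    return math.sqrt(sum((a - p) ** 2 for a, p in zip(actual, predicted)) / len(actual))

def fit_mean(rows):
    """Toy 'model': ignores the feature and predicts the mean training label."""
    mean_y = sum(y for _, y in rows) / len(rows)
    return lambda x: mean_y

def cross_validate(folds, fit, metric):
    """Hold out each fold once, train on the others, return (average, per-fold scores)."""
    scores = []
    for i, held_out in enumerate(folds):
        # Training rows: everything outside fold i.
        training = [row for j, fold in enumerate(folds) if j != i for row in fold]
        model = fit(training)
        actual = [y for _, y in held_out]
        predicted = [model(x) for x, _ in held_out]
        scores.append(metric(actual, predicted))
    return sum(scores) / len(scores), scores

# Three folds of hypothetical (feature, label) pairs.
folds = [[(1, 2.0), (2, 4.0)], [(3, 6.0), (4, 8.0)], [(5, 10.0), (6, 12.0)]]
avg, scores = cross_validate(folds, fit_mean, rmse)
print(round(avg, 3))  # 4.389
```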

Cars revisited

cars.select('mass', 'cyl', 'consumption').show(5)
+------+---+-----------+
|  mass|cyl|consumption|
+------+---+-----------+
|1451.0|  6|       9.05|
|1129.0|  4|       6.53|
|1399.0|  4|       7.84|
|1147.0|  4|       7.84|
|1111.0|  4|       9.05|
+------+---+-----------+

Estimator and evaluator

An object to build the model. This can be a pipeline.

from pyspark.ml.regression import LinearRegression
regression = LinearRegression(labelCol='consumption')

An object to evaluate model performance.

from pyspark.ml.evaluation import RegressionEvaluator
evaluator = RegressionEvaluator(labelCol='consumption')
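`RegressionEvaluator` reports RMSE unless its `metricName` parameter is set to something else, for example `'mse'`, `'mae'` or `'r2'`. The plain-Python sketch below uses the textbook definitions of those metrics to show what each number measures. The helper `regression_metrics` and the toy inputs are hypothetical, and Spark's own implementation may differ in edge cases. Note that for RMSE, MSE and MAE smaller is better, while for R² larger is better.

```python
import math

def regression_metrics(actual, predicted):
    """Hypothetical helper: common regression metrics for paired labels/predictions."""
    n = len(actual)
    residuals = [a - p for a, p in zip(actual, predicted)]
    sse = sum(r ** 2 for r in residuals)          # sum of squared errors
    mean_a = sum(actual) / n
    sst = sum((a - mean_a) ** 2 for a in actual)  # total sum of squares
    return {
        "mse": sse / n,
        "rmse": math.sqrt(sse / n),               # RegressionEvaluator's default
        "mae": sum(abs(r) for r in residuals) / n,
        "r2": 1 - sse / sst,                      # larger is better
    }

metrics = regression_metrics([3.0, 5.0, 7.0], [2.0, 5.0, 9.0])
print(metrics["mae"], metrics["r2"])  # 1.0 0.375
```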

Grid and cross-validator

from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

A grid of parameter values (empty for the moment).

params = ParamGridBuilder().build()
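An empty `ParamGridBuilder` still produces one parameter map: `build()` returns a list holding a single empty map, so the cross-validator evaluates only the estimator's default settings. That is why `avgMetrics` has exactly one entry. Grid dimensions are added in PySpark with `addGrid(param, values)`. The sketch below mimics the expansion with `itertools.product`; `build_grid` is a hypothetical helper, and the parameter names are only examples (both exist on `LinearRegression`).

```python
import itertools

def build_grid(grid):
    """Hypothetical helper: expand {name: [values]} into one dict per combination."""
    names = list(grid)
    return [dict(zip(names, combo)) for combo in itertools.product(*grid.values())]

# An empty grid still yields one (empty) parameter map.
print(build_grid({}))  # [{}]

# Three regParam values x two elasticNetParam values = six parameter maps.
maps = build_grid({"regParam": [0.01, 0.1, 1.0], "elasticNetParam": [0.0, 1.0]})
print(len(maps))  # 6
```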

The cross-validation object.

cv = CrossValidator(estimator=regression,
                    estimatorParamMaps=params,
                    evaluator=evaluator,
                    numFolds=10, seed=13)

Cross-validators need training too

Apply cross-validation to the training data.

cv = cv.fit(cars_train)

What's the average RMSE across the folds?

cv.avgMetrics
[0.800663722151572]
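`avgMetrics` holds one number per parameter map: the metric averaged over the folds. With an empty grid there is only one map, hence a single element. With several maps, `CrossValidator` picks the best average, using the evaluator's `isLargerBetter()` to decide whether to maximise or minimise (for RMSE, smaller is better), then refits the estimator with those parameters on the full training data; the refit model is available as `bestModel`. The sketch covers only the selection step; `best_index` and the metric values in the second call are hypothetical.

```python
def best_index(avg_metrics, larger_is_better):
    """Hypothetical helper: position of the best averaged metric."""
    pick = max if larger_is_better else min
    return pick(range(len(avg_metrics)), key=lambda i: avg_metrics[i])

# Empty grid: a single candidate, which is trivially the best.
print(best_index([0.800663722151572], larger_is_better=False))  # 0

# Hypothetical RMSE averages for three parameter maps: the smallest wins.
print(best_index([0.91, 0.80, 0.85], larger_is_better=False))  # 1
```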

Cross-validators act like models

Make predictions on the original testing data.

evaluator.evaluate(cv.transform(cars_test))
# RMSE on testing data
0.745974203928479

Smaller than the cross-validated RMSE.

# RMSE from cross-validation
0.800663722151572

Relying on a single train-test split would have given an overly optimistic view of model performance.

Cross-validate all the models!
