Ajuste de hiperparámetros en R
Dr. Shirin Elsinghorst
Senior Data Scientist
?h2o.gbm
ntrees: Número de árboles. Predeterminado: 50.
max_depth: Profundidad máxima. Predeterminado: 5.
min_rows: Mín. observaciones (ponderadas) en una hoja. Predeterminado: 10.
learn_rate: Tasa de aprendizaje (0.0 a 1.0). Predeterminado: 0.1.
learn_rate_annealing: Escala la tasa tras cada árbol (p. ej., 0.99 o 0.999). Predeterminado: 1.seeds_data_hf <- as.h2o(seeds_data)
y <- "seed_type"
x <- setdiff(colnames(seeds_data_hf), y)
sframe <- h2o.splitFrame(data = seeds_data_hf, ratios = c(0.7, 0.15), seed = 42)
train <- sframe[[1]]
valid <- sframe[[2]]
test <- sframe[[3]]
gbm_params <- list(ntrees = c(100, 150, 200), max_depth = c(3, 5, 7), learn_rate = c(0.001, 0.01, 0.1))
h2o.gridgbm_grid <- h2o.grid("gbm",
grid_id = "gbm_grid",
x = x,
y = y,
training_frame = train,
validation_frame = valid,
seed = 42,
hyper_params = gbm_params)
h2o.getGridExamina resultados de gbm_grid con h2o.getGrid.
Obtén la cuadrícula ordenada por accuracy de validación
gbm_gridperf <- h2o.getGrid(grid_id = "gbm_grid", sort_by = "accuracy", decreasing = TRUE)
ID de la cuadrícula: gbm_grid
Hiperparámetros usados:
- learn_rate
- max_depth
- ntrees
Número de modelos: 27
Modelos con error: 0
Resumen de búsqueda de hiperparámetros: ordenado por accuracy descendente
best_gbm <- h2o.getModel(gbm_gridperf@model_ids[[1]])
print(best_gbm@model[["model_summary"]])
Resumen del modelo:
number_of_trees number_of_internal_trees model_size_in_bytes min_depth
200 600 100961 2
max_depth mean_depth min_leaves max_leaves mean_leaves
7 5.22667 3 10 8.38833
best_gbm es un modelo H2O normal y se trata como tal.h2o.performance(best_gbm, test)
MSE: (Extraer con `h2o.mse`) 0.04761904
RMSE: (Extraer con `h2o.rmse`) 0.2182179
Logloss: (Extraer con `h2o.loglos
gbm_params <- list(ntrees = c(100, 150, 200), max_depth = c(3, 5, 7), learn_rate = c(0.001, 0.01, 0.1))search_criteria <- list(strategy = "RandomDiscrete", max_runtime_secs = 60, seed = 42)gbm_grid <- h2o.grid("gbm", grid_id = "gbm_grid", x = x, y = y, training_frame = train, validation_frame = valid, seed = 42, hyper_params = gbm_params, search_criteria = search_criteria)
search_criteria <- list(strategy = "RandomDiscrete", stopping_metric = "mean_per_class_error", stopping_tolerance = 0.0001, stopping_rounds = 6)gbm_grid <- h2o.grid("gbm", x = x, y = y, training_frame = train, validation_frame = valid, seed = 42, hyper_params = gbm_params, search_criteria = search_criteria)
Detalles de la cuadrícula H2O
=============================
ID de la cuadrícula: gbm_grid
Hiperparámetros usados:
- learn_rate
- max_depth
- ntrees
Número de modelos: 30
Modelos con error: 0
Ajuste de hiperparámetros en R