Batchtraining

Introductie tot TensorFlow in Python

Isaiah Hull

Visiting Associate Professor of Finance, BI Norwegian Business School

Wat is batchtraining?

Deze afbeelding toont prijs, grootte en aantal slaapkamers voor huizen in King County.

Deze afbeelding toont prijs, grootte en aantal slaapkamers voor huizen in King County, verdeeld in batches.

Introductie tot TensorFlow in Python

De parameter chunksize

  • Met pd.read_csv() kun je data in batches laden
    • Vermijd laden van hele dataset
    • Parameter chunksize bepaalt batchgrootte
# Import pandas and numpy
import pandas as pd
import numpy as np

# Load data in batches
for batch in pd.read_csv('kc_housing.csv', chunksize=100):
    # Extract price column
    price = np.array(batch['price'], np.float32)

    # Extract size column
    size = np.array(batch['size'], np.float32)
Introductie tot TensorFlow in Python

Een lineair model trainen in batches

# Import tensorflow, pandas, and numpy
import tensorflow as tf
import pandas as pd
import numpy as np
# Define trainable variables
intercept = tf.Variable(0.1, tf.float32)
slope = tf.Variable(0.1, tf.float32)
# Define the model
def linear_regression(intercept, slope, features):
    return intercept + features*slope
Introductie tot TensorFlow in Python

Een lineair model trainen in batches

# Compute predicted values and return loss function
def loss_function(intercept, slope, targets, features):
    predictions = linear_regression(intercept, slope, features)
    return tf.keras.losses.mse(targets, predictions)
# Define optimization operation
opt = tf.keras.optimizers.Adam()
Introductie tot TensorFlow in Python

Een lineair model trainen in batches

# Load the data in batches from pandas
for batch in pd.read_csv('kc_housing.csv', chunksize=100):
    # Extract the target and feature columns
    price_batch = np.array(batch['price'], np.float32)
    size_batch = np.array(batch['lot_size'], np.float32)

    # Minimize the loss function
    opt.minimize(lambda: loss_function(intercept, slope, price_batch, size_batch), 
                 var_list=[intercept, slope])
# Print parameter values
print(intercept.numpy(), slope.numpy())
Introductie tot TensorFlow in Python

Volledige steekproef vs. batchtraining

  • Volledige steekproef
    1. Eén update per epoch
    2. Neemt dataset zonder aanpassing
    3. Beperkt door geheugen
  • Batchtraining
    1. Meerdere updates per epoch
    2. Vereist opdelen van dataset
    3. Geen limiet op datasetgrootte
Introductie tot TensorFlow in Python

Laten we oefenen!

Introductie tot TensorFlow in Python

Preparing Video For Download...