Pelatihan batch

Pendahuluan TensorFlow di Python

Isaiah Hull

Visiting Associate Professor of Finance, BI Norwegian Business School

Apa itu pelatihan batch?

Gambar ini menampilkan data harga, luas, dan jumlah kamar tidur untuk rumah di King County.

Gambar ini menampilkan data harga, luas, dan jumlah kamar tidur untuk rumah di King County, dibagi menjadi batch.

Pendahuluan TensorFlow di Python

Parameter chunksize

  • pd.read_csv() memungkinkan memuat data per batch
    • Hindari memuat seluruh dataset
    • Parameter chunksize menentukan ukuran batch
# Import pandas and numpy
import pandas as pd
import numpy as np

# Load data in batches
for batch in pd.read_csv('kc_housing.csv', chunksize=100):
    # Extract price column
    price = np.array(batch['price'], np.float32)

    # Extract size column
    size = np.array(batch['size'], np.float32)
Pendahuluan TensorFlow di Python

Melatih model linear per batch

# Import tensorflow, pandas, and numpy
import tensorflow as tf
import pandas as pd
import numpy as np
# Define trainable variables
intercept = tf.Variable(0.1, tf.float32)
slope = tf.Variable(0.1, tf.float32)
# Define the model
def linear_regression(intercept, slope, features):
    return intercept + features*slope
Pendahuluan TensorFlow di Python

Melatih model linear per batch

# Compute predicted values and return loss function
def loss_function(intercept, slope, targets, features):
    predictions = linear_regression(intercept, slope, features)
    return tf.keras.losses.mse(targets, predictions)
# Define optimization operation
opt = tf.keras.optimizers.Adam()
Pendahuluan TensorFlow di Python

Melatih model linear per batch

# Load the data in batches from pandas
for batch in pd.read_csv('kc_housing.csv', chunksize=100):
    # Extract the target and feature columns
    price_batch = np.array(batch['price'], np.float32)
    size_batch = np.array(batch['lot_size'], np.float32)

    # Minimize the loss function
    opt.minimize(lambda: loss_function(intercept, slope, price_batch, size_batch), 
                 var_list=[intercept, slope])
# Print parameter values
print(intercept.numpy(), slope.numpy())
Pendahuluan TensorFlow di Python

Sampel penuh vs. pelatihan batch

  • Sampel penuh
    1. Satu pembaruan per epoch
    2. Menerima dataset tanpa diubah
    3. Terbatas oleh memori
  • Pelatihan batch
    1. Beberapa pembaruan per epoch
    2. Perlu pembagian dataset
    3. Tidak dibatasi ukuran dataset
Pendahuluan TensorFlow di Python

Ayo berlatih!

Pendahuluan TensorFlow di Python

Preparing Video For Download...