Vanishing and exploding gradients

Intermediate Deep Learning with PyTorch

Michal Oleszak

Machine Learning Engineer

Vanishing gradients

  • Gradients get smaller and smaller during backward pass
  • Earlier layers get small parameter updates
  • Model doesn't learn

Graphs showing gradient size versus layer index: for earlier layers, the gradients are smaller
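
To see this in action, here is a minimal sketch (not from the original slides; the depth and layer sizes are arbitrary choices to make the effect visible) of a deep sigmoid stack whose early-layer gradients collapse:

import torch
import torch.nn as nn

torch.manual_seed(0)

# A toy stack of 20 small sigmoid layers (hypothetical sizes)
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(20)])

x = torch.randn(32, 16)
for layer in layers:
    x = torch.sigmoid(layer(x))
x.sum().backward()

# Weight gradients shrink as we move toward the earlier layers
for i in [0, 9, 19]:
    print(f"layer {i}: grad norm = {layers[i].weight.grad.norm().item():.2e}")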


Exploding gradients

  • Gradients get bigger and bigger
  • Parameter updates are too large
  • Training diverges

Graphs showing gradient size versus layer index: for earlier layers, the gradients are larger
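
A matching sketch (again with hypothetical sizes; the oversized weights stand in for a bad initialization) shows gradients growing as they flow backward through the layers:

import torch
import torch.nn as nn

torch.manual_seed(0)

# The same toy stack, but with deliberately oversized weights
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(20)])
for layer in layers:
    nn.init.normal_(layer.weight, std=1.0)  # far too large for fan_in=16

x = torch.randn(32, 16)
acts = []
for layer in layers:
    x = torch.relu(layer(x))
    x.retain_grad()   # keep gradients of the intermediate activations
    acts.append(x)
x.sum().backward()

# Activation gradients grow as we move toward the earlier layers
for i in [0, 9, 19]:
    print(f"after layer {i}: grad norm = {acts[i].grad.norm().item():.2e}")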


Solution to unstable gradients

  1. Proper weights initialization
  2. Good activations
  3. Batch normalization


Weights initialization

import torch.nn as nn

layer = nn.Linear(8, 1)
print(layer.weight)
Parameter containing:
tensor([[-0.0195,  0.0992,  0.0391,  0.0212,
         -0.3386, -0.1892, -0.3170,  0.2148]])

Weights initialization

Good initialization ensures:

  • Variance of layer inputs = variance of layer outputs
  • Variance of gradients the same before and after a layer

 

How to achieve this depends on the activation:

  • For ReLU and similar, we can use He/Kaiming initialization
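
As a quick empirical check (a sketch, not from the original slides; the width and depth are arbitrary), He/Kaiming initialization keeps the activation scale roughly constant through a deep ReLU stack instead of letting it shrink toward zero or blow up:

import torch
import torch.nn as nn
import torch.nn.init as init

torch.manual_seed(0)

# A hypothetical 30-layer ReLU stack, He-initialized
layers = [nn.Linear(256, 256) for _ in range(30)]
for layer in layers:
    init.kaiming_uniform_(layer.weight, nonlinearity="relu")

x = torch.randn(1024, 256)
for i, layer in enumerate(layers):
    x = nn.functional.relu(layer(x))
    if (i + 1) % 10 == 0:
        # The std stays roughly constant across depth
        print(f"layer {i + 1}: activation std = {x.std().item():.3f}")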

Weights initialization

import torch.nn.init as init

init.kaiming_uniform_(layer.weight)
print(layer.weight)
Parameter containing:
tensor([[-0.3063, -0.2410,  0.0588,  0.2664,
          0.0502, -0.0136,  0.2274,  0.0901]])

He / Kaiming initialization

init.kaiming_uniform_(self.fc1.weight)
init.kaiming_uniform_(self.fc2.weight)
init.kaiming_uniform_(
  self.fc3.weight,
  nonlinearity="sigmoid",
)

He / Kaiming initialization

import torch.nn as nn
import torch.nn.init as init

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9, 16)
        self.fc2 = nn.Linear(16, 8)
        self.fc3 = nn.Linear(8, 1)

        init.kaiming_uniform_(self.fc1.weight)
        init.kaiming_uniform_(self.fc2.weight)
        init.kaiming_uniform_(
            self.fc3.weight,
            nonlinearity="sigmoid",
        )
    def forward(self, x):
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = nn.functional.sigmoid(self.fc3(x))
        return x
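
A quick sanity check (not part of the original slide) that the initialized network runs end to end:

import torch

net = Net()
x = torch.randn(4, 9)   # a dummy batch: 4 samples, 9 features each
print(net(x).shape)     # torch.Size([4, 1])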

Activation functions

A plot showing the ReLU function. For values below zero, the line is horizontal at zero; for values above zero, the line rises with a constant positive slope.

  • Often used as the default activation
  • nn.functional.relu()
  • Zero output and zero gradient for negative inputs - can cause dying neurons

A plot showing the ELU function. It is similar to the ReLU function but with a curved and smooth transition from negative values to the positive region.

  • nn.functional.elu()
  • Non-zero gradients for negative values - helps against dying neurons
  • Average output around zero - helps against vanishing gradients
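
The difference is easy to verify numerically (a small sketch, not from the slides): ReLU zeroes out gradients for negative inputs, while ELU keeps them non-zero:

import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)

# ReLU: gradient is zero for the negative inputs, one elsewhere
F.relu(x).sum().backward()
print(x.grad)

# ELU: negative inputs keep a non-zero gradient (exp(x) for x < 0,
# so roughly 0.14 and 0.61 here), one elsewhere
x.grad = None
F.elu(x).sum().backward()
print(x.grad)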

Batch normalization

After a layer:

  1. Normalize the layer's outputs by:

    • Subtracting the mean
    • Dividing by the standard deviation
  2. Scale and shift normalized outputs using learned parameters

Model learns the optimal input distribution for each layer:

  • Faster loss decrease
  • Helps against unstable gradients
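
A minimal sketch (not from the slides; the feature count and batch statistics are arbitrary) of what nn.BatchNorm1d does to a badly scaled batch in training mode:

import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(4)            # one learned scale/shift pair per feature
x = torch.randn(32, 4) * 5 + 3    # features far from zero mean / unit variance

out = bn(x)
print(out.mean(dim=0).detach())   # roughly 0 for every feature
print(out.std(dim=0).detach())    # roughly 1 for every feature
print(bn.weight.data)             # learned scale, initialized to 1
print(bn.bias.data)               # learned shift, initialized to 0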

Batch normalization

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9, 16)
        self.bn1 = nn.BatchNorm1d(16)

        ...


    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = nn.functional.elu(x)
        ...

Let's practice!
