Intermediate Deep Learning with PyTorch
Michal Oleszak
Machine Learning Engineer
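Error: prediction - target
Squared error: (prediction - target)^2
Mean squared error (MSE): average of (prediction - target)^2 over all observations
Squaring the error: keeps positive and negative errors from cancelling out and penalizes large errors more heavily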
criterion = nn.MSELoss()
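The definitions above can be checked by hand on a tiny example. This is a minimal sketch with made-up predictions and targets (the numbers are illustrative assumptions, not data from the course); it computes the error, the squared error, and the MSE step by step and compares the result with nn.MSELoss.

```python
import torch
import torch.nn as nn

# Hypothetical predictions and targets, chosen only for illustration
predictions = torch.tensor([2.0, 0.0, 4.0, -2.0])
targets = torch.tensor([1.0, 1.0, 1.0, 1.0])

error = predictions - targets       # [1, -1, 3, -3]
squared_error = error ** 2          # [1, 1, 9, 9]
manual_mse = squared_error.mean()   # (1 + 1 + 9 + 9) / 4 = 5

# nn.MSELoss averages the squared errors by default (reduction="mean")
criterion = nn.MSELoss()
loss = criterion(predictions, targets)

print(error.mean())  # 0: positive and negative errors cancel out
print(manual_mse)    # 5: squaring removes the cancellation
print(loss)          # same value as manual_mse
```

Note how the two errors of size 3 contribute 18 of the 20 units of total squared error, which is what "penalizes large errors more heavily" means in practice.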
Expected input shape: (batch_size, seq_length, num_features)
Shape returned by the DataLoader: (batch_size, seq_length)
for seqs, labels in dataloader_train:
    print(seqs.shape)
torch.Size([32, 96])
Add a feature dimension at the end:
seqs = seqs.view(32, 96, 1)
print(seqs.shape)
torch.Size([32, 96, 1])
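The view call above hard-codes the batch size and sequence length. As a hedged alternative, unsqueeze(-1) adds the same trailing dimension without naming any sizes, which also works if the last batch happens to be smaller. The tensors below are random stand-ins for DataLoader batches, not the course data.

```python
import torch

# Stand-in for one full batch: 32 sequences of length 96
seqs = torch.randn(32, 96)

with_view = seqs.view(32, 96, 1)     # sizes spelled out, as in the text
with_unsqueeze = seqs.unsqueeze(-1)  # adds a trailing dimension of size 1

print(with_view.shape)
print(torch.equal(with_view, with_unsqueeze))

# Hypothetical smaller final batch: view(32, 96, 1) would raise an error here,
# while unsqueeze(-1) adapts to whatever batch size arrives
last_batch = torch.randn(20, 96)
print(last_batch.unsqueeze(-1).shape)
```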
Labels have shape (batch_size)
for seqs, labels in test_loader:
    print(labels.shape)
torch.Size([32])
Model outputs have shape (batch_size, 1)
out = net(seqs)
print(out.shape)
torch.Size([32, 1])
We can drop the last dimension of the outputs so that their shape matches the labels
out = net(seqs).squeeze()
print(out.shape)
torch.Size([32])
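Why the shapes must match can be seen from broadcasting. The sketch below uses random stand-ins for the model outputs and labels (assumptions for illustration). When a (32, 1) tensor meets a (32) tensor, the difference broadcasts to (32, 32), so every output is compared with every label; PyTorch typically raises a UserWarning in that case rather than an error. The example also shows squeeze(-1), which removes only the last dimension and so behaves predictably when a batch contains a single sample.

```python
import warnings

import torch
import torch.nn as nn

torch.manual_seed(0)
out = torch.randn(32, 1)  # stand-in for net(seqs)
labels = torch.randn(32)  # stand-in for a batch of labels

criterion = nn.MSELoss()

# Shapes match: each output is compared with its own label
matched = criterion(out.squeeze(-1), labels)

# Shapes differ: the difference broadcasts to a 32 x 32 grid
print((out - labels).shape)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the size-mismatch warning
    mismatched = criterion(out, labels)
print(matched, mismatched)

# With a single-sample batch, squeeze() removes every size-1 dimension,
# while squeeze(-1) keeps the batch dimension
single = torch.randn(1, 1)
print(single.squeeze().shape)    # zero-dimensional tensor
print(single.squeeze(-1).shape)  # one-dimensional tensor of length 1
```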
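import torch.nn as nn
import torch.optim as optim

net = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

for epoch in range(num_epochs):
    for seqs, labels in dataloader_train:
        # Add the feature dimension: (32, 96) -> (32, 96, 1)
        seqs = seqs.view(32, 96, 1)
        # Squeeze so the outputs match the labels' shape (32)
        outputs = net(seqs).squeeze()
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()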
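The training loop relies on a Net class and a dataloader_train that are defined elsewhere. To run the loop end to end, the sketch below fills both in with assumptions: a small hypothetical LSTM-based Net and synthetic data whose target is simply the mean of each sequence. Neither is the course's actual model or dataset.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


class Net(nn.Module):
    """Hypothetical stand-in for the Net used in the text."""

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=16, batch_first=True)
        self.fc = nn.Linear(16, 1)

    def forward(self, x):
        out, _ = self.lstm(x)          # (batch, seq_length, hidden)
        return self.fc(out[:, -1, :])  # last time step -> (batch, 1)


torch.manual_seed(0)

# Synthetic data (assumption): 64 sequences of length 96
X = torch.randn(64, 96)
y = X.mean(dim=1)
dataloader_train = DataLoader(TensorDataset(X, y), batch_size=32)

net = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

losses = []
for epoch in range(3):
    for seqs, labels in dataloader_train:
        seqs = seqs.unsqueeze(-1)        # (batch, 96) -> (batch, 96, 1)
        outputs = net(seqs).squeeze(-1)  # (batch, 1) -> (batch)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

print(losses)
```

Three epochs over two batches give six loss values; on real data, num_epochs would be chosen by watching how the loss evolves.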
import torch
import torchmetrics

mse = torchmetrics.MeanSquaredError()

net.eval()
with torch.no_grad():
    for seqs, labels in test_loader:
        seqs = seqs.view(32, 96, 1)
        outputs = net(seqs).squeeze()
        # Update the metric's running state with this batch
        mse(outputs, labels)

print(f"Test MSE: {mse.compute()}")
Test MSE: 0.13292162120342255
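The metric object is updated once per batch and compute() reports a single MSE over everything seen so far. A rough sketch of what such accumulation amounts to, written in plain PyTorch, is shown below. It assumes the metric keeps a running sum of squared errors and a count of observations; the batches are random stand-ins, deliberately of unequal size.

```python
import torch

torch.manual_seed(0)

# Hypothetical (outputs, labels) pairs for two test batches
batches = [
    (torch.randn(32), torch.randn(32)),
    (torch.randn(20), torch.randn(20)),
]

sum_squared_error = 0.0
n_obs = 0
for outputs, labels in batches:
    sum_squared_error += ((outputs - labels) ** 2).sum().item()
    n_obs += labels.numel()

test_mse = sum_squared_error / n_obs

# Reference: the MSE computed over all observations at once
all_outputs = torch.cat([o for o, _ in batches])
all_labels = torch.cat([l for _, l in batches])
reference = ((all_outputs - all_labels) ** 2).mean().item()

print(test_mse, reference)
```

Accumulating sums and counts, rather than averaging per-batch MSE values, keeps the result correct when batches differ in size.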
Test MSE: 0.13292162120342255
Test MSE: 0.12187089771032333