The number of epochs to train for is a hyperparameter determined via early stopping.
Goals are:
- Prevent overfitting (and thereby improve model performance)
- Reduce training time
- Determine a (close to) optimal value for the number of epochs
(dirty) Manual interruption
Use this instead of patience unless your model is in its final version.
try:
    for epoch in range(num_epochs):
        for batch in data_loader:
            # Training code goes here
            ...
        print(f"Epoch {epoch + 1}/{num_epochs} completed.")
except KeyboardInterrupt:
    print("Training interrupted by the user.")
    # Save the model or perform any necessary cleanup here
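If you want the interrupted run to still be usable, the cleanup step can simply checkpoint the current weights. A minimal sketch of the except branch; the file name is just an illustration:
except KeyboardInterrupt:
    print("Training interrupted by the user.")
    # checkpoint the current weights so the interrupted run can be resumed or evaluated later
    torch.save(model.state_dict(), "interrupted_checkpoint.pt")  # illustrative path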
(clean) Patience
We track the best loss seen so far and compare each new epoch against it. If the loss does not improve for patience consecutive epochs, we stop training and keep the model with the best result, i.e. the one from epoch (current epoch number - patience).
The important bits:
for epoch in range(epochs):
    ...
    if test_loss < best_loss:
        best_loss = test_loss
        best_model_weights = copy.deepcopy(model.state_dict())
        patience = patience_base_value
    else:
        patience -= 1
        if patience == 0:
            break
Full code for reference
stolen from Model Implementation (Pytorch)
import copy

import torch
from torch import nn

# Initialize variables for early stopping
best_loss = float('inf')
best_model_weights = None
patience_base_value = 10
patience = patience_base_value
def train_loop(dataloader, model, loss_fn, optimizer):
    '''
    exact same training loop, not modified
    '''
    ...
def test_loop(dataloader, model, loss_fn):
    model.eval()
    avg_test_loss, accuracy = 0, 0
    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode;
    # it also avoids unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            avg_test_loss += loss_fn(pred, y).item()
            # cast to float so the correct-prediction count divides cleanly by the dataset size below
            accuracy += (pred.argmax(1) == y).type(torch.float).sum().item()
    avg_test_loss /= len(dataloader)
    accuracy /= len(dataloader.dataset)
    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {avg_test_loss:>8f} \n")
    # we need to return the test loss (and potentially the accuracy) to compare it across epochs
    return avg_test_loss, accuracy
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
epochs = 100  # upper bound well above patience, so early stopping can actually trigger

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    # using a validation loss would be even better
    avg_test_loss, accuracy = test_loop(test_dataloader, model, loss_fn)
    if avg_test_loss < best_loss:
        best_loss = avg_test_loss
        best_model_weights = copy.deepcopy(model.state_dict())
        patience = patience_base_value
    else:
        patience -= 1
        if patience == 0:
            break

model.load_state_dict(best_model_weights)
print("Done!")
You can also combine both approaches: run the patience-based loop, but wrap it in a try/except so you can still interrupt manually and keep the best weights found so far.
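A minimal sketch of the combination, reusing the variables from the full example above:
try:
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loop(train_dataloader, model, loss_fn, optimizer)
        avg_test_loss, accuracy = test_loop(test_dataloader, model, loss_fn)
        if avg_test_loss < best_loss:
            best_loss = avg_test_loss
            best_model_weights = copy.deepcopy(model.state_dict())
            patience = patience_base_value
        else:
            patience -= 1
            if patience == 0:
                break
except KeyboardInterrupt:
    print("Training interrupted by the user.")

# whether we stopped via patience or manually, fall back to the best weights seen so far
if best_model_weights is not None:
    model.load_state_dict(best_model_weights)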