在PyTorch中保存最佳模型通常是通過保存模型的參數和優化器狀態來實現的。以下是一個示例代碼,演示了如何保存最佳模型:
import torch
import torch.nn as nn
import torch.optim as optim
# 定義模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
model = MyModel()
# 定義損失函數和優化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 訓練模型
best_loss = float('inf')
for epoch in range(num_epochs):
# 訓練過程
train_loss = 0.0
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
train_loss /= len(train_loader)
# 保存最佳模型
if train_loss < best_loss:
best_loss = train_loss
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_loss': best_loss
}, 'best_model.pth')
在上面的示例中,我們首先定義了一個模型、損失函數和優化器。然后在訓練過程中,我們通過比較當前訓練損失和最佳損失來保存最佳模型。當訓練損失小于最佳損失時,我們保存模型的狀態字典和優化器的狀態字典,并將最佳損失更新為當前訓練損失。
最后,我們可以通過加載best_model.pth
文件來恢復最佳模型的狀態,并繼續使用該模型進行推理或進一步的訓練。