在Torch中進行模型評估通常需要使用驗證集或測試集來評估模型的性能。下面是一個基本的示例來展示如何在Torch中進行模型評估:
import torch
import torch.nn as nn
import torch.optim as optim
# 定義模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# 創建模型實例
model = SimpleModel()
# 加載訓練好的模型參數
model.load_state_dict(torch.load('model.pth'))
# 定義評估函數
def evaluate(model, dataloader, criterion):
model.eval()
total_loss = 0.0
total_samples = 0
with torch.no_grad():
for inputs, targets in dataloader:
outputs = model(inputs)
loss = criterion(outputs, targets)
total_loss += loss.item() * inputs.size(0)
total_samples += inputs.size(0)
avg_loss = total_loss / total_samples
return avg_loss
# 創建驗證集的數據加載器
val_dataloader = ...
# 定義損失函數
criterion = nn.MSELoss()
# 計算模型在驗證集上的平均損失
avg_val_loss = evaluate(model, val_dataloader, criterion)
print('Average validation loss:', avg_val_loss)
在上面的示例中,首先定義了一個簡單的模型SimpleModel
,然后加載了預訓練好的模型參數。接著定義了評估函數evaluate
來計算模型在驗證集上的平均損失。最后,通過調用evaluate
函數來評估模型在驗證集上的性能,并輸出平均損失值。