Torch中可以通過使用一些可視化工具來對模型進行可視化,例如使用TensorBoardX庫。以下是一個簡單示例:
pip install tensorflow
pip install tensorboardX
from tensorboardX import SummaryWriter
# 創建一個SummaryWriter對象,指定log目錄
writer = SummaryWriter('logs')
# 在訓練過程中,可以使用add_scalar方法記錄損失值
for i in range(num_epochs):
loss = train_model()
writer.add_scalar('Loss/train', loss, i)
# 在訓練過程中,也可以使用add_graph方法記錄模型結構
model = Model()
data = torch.rand(1, 3, 224, 224)
writer.add_graph(model, data)
# 訓練完成后,關閉SummaryWriter對象
writer.close()
tensorboard --logdir logs