使用pytorch
庫中的model.parameters()
可以獲得模型的所有參數,然后使用len()
函數可以統計參數的數量。下面是一個示例代碼:
import torch
import torch.nn as nn
# 創建模型
model = nn.Linear(10, 5)
# 統計參數數量
num_parameters = sum(p.numel() for p in model.parameters())
print(f"模型參數數量: {num_parameters}")
輸出結果會顯示模型的參數數量。