在PyTorch中,nn.Parameter
是一個特殊的Tensor,它是nn.Module
中可訓練參數的一種特殊類型。nn.Parameter
對象由nn.Module
的構造函數自動識別并將其注冊為模型的可訓練參數。
要使用nn.Parameter
,首先需要創建一個nn.Parameter
對象,并將其作為模型的屬性。下面是一個簡單的示例:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.weight = nn.Parameter(torch.rand(3, 4)) # 創建一個參數
def forward(self, x):
out = torch.matmul(x, self.weight)
return out
model = MyModel()
print(model.weight) # 打印參數
在上面的示例中,我們定義了一個MyModel
類,它繼承自nn.Module
。在構造函數__init__
中,我們創建了一個nn.Parameter
對象self.weight
,它是一個形狀為(3, 4)
的隨機初始化的Tensor。
在forward
方法中,我們可以使用self.weight
參數進行計算。在模型創建完畢后,我們可以通過model.weight
來訪問這個參數。
需要注意的是,nn.Parameter
對象會自動被注冊為模型的可訓練參數,并且在模型的parameters()
方法中可以訪問到。此外,nn.Parameter
對象還會自動具有梯度計算的功能,可以通過backward()
方法自動計算梯度。