在PyTorch中,torch.nn.Linear
是一個用于定義線性變換的類。它將輸入的特征向量進行線性變換,并輸出一個新的特征向量。
在使用torch.nn.Linear
時,你需要指定輸入特征的維度和輸出特征的維度。這兩個參數分別是in_features
和out_features
。例如,如果你有一個輸入特征是100維,輸出特征是50維的線性變換,可以使用以下方式創建一個Linear
對象:
import torch
import torch.nn as nn
linear = nn.Linear(100, 50)
然后,你可以將輸入特征向量傳遞給線性層,使用forward
方法進行線性變換。例如,假設你有一個大小為[batch_size, 100]
的輸入特征張量x
,你可以通過以下方式對其進行線性變換:
output = linear(x)
最后,output
將是一個大小為[batch_size, 50]
的特征張量,它是輸入特征經過線性變換得到的結果。
此外,torch.nn.Linear
類還包含了參數權重weight
和偏置bias
,它們可以通過linear.weight
和linear.bias
來訪問。這些參數會在模型訓練過程中自動更新,以最小化定義的損失函數。