您好,登錄后才能下訂單哦!
這篇文章將為大家詳細講解有關pytorch怎么打印網絡回傳梯度,小編覺得挺實用的,因此分享給大家做個參考,希望大家閱讀完這篇文章后可以有所收獲。
打印梯度,檢查網絡學習情況
net = your_network().cuda()
def train():
...
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
for name, parms in net.named_parameters():
print('-->name:', name, '-->grad_requirs:',parms.requires_grad,
' -->grad_value:',parms.grad)
...
name表示網絡參數的名字; parms.requires_grad 表示該參數是否可學習,是不是frozen的; parm.grad 打印該參數的梯度值。
補充:pytorch的梯度計算
import torch
from torch.autograd import Variable
x = torch.Tensor([[1.,2.,3.],[4.,5.,6.]]) #grad_fn是None
x = Variable(x, requires_grad=True)
y = x + 2
z = y*y*3
out = z.mean()
#x->y->z->out
print(x)
print(y)
print(z)
print(out)
#結果:
tensor([[1., 2., 3.],
[4., 5., 6.]], requires_grad=True)
tensor([[3., 4., 5.],
[6., 7., 8.]], grad_fn=<AddBackward>)
tensor([[ 27., 48., 75.],
[108., 147., 192.]], grad_fn=<MulBackward>)
tensor(99.5000, grad_fn=<MeanBackward1>)
若是關于graph leaves求導的結果變量是一個標量,那么gradient默認為None,或者指定為“torch.Tensor([1.0])”
若是關于graph leaves求導的結果變量是一個向量,那么gradient是不能缺省的,要是和該向量同緯度的tensor
out.backward()
print(x.grad)
#結果:
tensor([[3., 4., 5.],
[6., 7., 8.]])
#如果是z關于x求導就必須指定gradient參數:
gradients = torch.Tensor([[2.,1.,1.],[1.,1.,1.]])
z.backward(gradient=gradients)
#若z不是一個標量,那么就先構造一個標量的值:L = torch.sum(z*gradient),再關于L對各個leaf Variable計算梯度
#對x關于L求梯度
x.grad
#結果:
tensor([[36., 24., 30.],
[36., 42., 48.]])
z.backward()
print(x.grad)
#報錯:RuntimeError: grad can be implicitly created only for scalar outputs只能為標量創建隱式變量
x1 = Variable(torch.Tensor([[1.,2.,3.],[4.,5.,6.]]))
x2 = Variable(torch.arange(4).view(2,2).type(torch.float), requires_grad=True)
c = x2.mm(x1)
c.backward(torch.ones_like(c))
# c.backward()
#RuntimeError: grad can be implicitly created only for scalar outputs
print(x2.grad)
從上面的例子中,out是常量,可以默認創建隱變量,如果反向傳播的不是常量,要知道該矩陣的具體值,在網絡中就是loss矩陣,方向傳播的過程中就是拿該歸一化的損失乘梯度來更新各神經元的參數。
看到一個博客這樣說:loss = criterion(outputs, labels)對應loss += (label[k] - h) * (label[k] - h) / 2
就是求loss(其實我覺得這一步不用也可以,反向傳播時用不到loss值,只是為了讓我們知道當前的loss是多少)
我認為一定是要求loss的具體值,才能對比閾值進行分類,通過非線性激活函數,判斷是否激活。
關于“pytorch怎么打印網絡回傳梯度”這篇文章就分享到這里了,希望以上內容可以對大家有一定的幫助,使各位可以學到更多知識,如果覺得文章不錯,請把它分享出去讓更多的人看到。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。