您好,登錄后才能下訂單哦!
如何使用pytorch打印網絡回傳梯度?相信很多沒有經驗的人對此束手無策,為此本文總結了問題出現的原因和解決方法,通過這篇文章希望你能解決這個問題。
打印梯度,檢查網絡學習情況
net = your_network().cuda() def train(): ... outputs = net(inputs) loss = criterion(outputs, targets) loss.backward() for name, parms in net.named_parameters(): print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \ ' -->grad_value:',parms.grad) ...
name表示網絡參數的名字; parms.requires_grad 表示該參數是否可學習,是不是frozen的; parm.grad 打印該參數的梯度值。
補充:pytorch的梯度計算
import torch from torch.autograd import Variable x = torch.Tensor([[1.,2.,3.],[4.,5.,6.]]) #grad_fn是None x = Variable(x, requires_grad=True) y = x + 2 z = y*y*3 out = z.mean() #x->y->z->out print(x) print(y) print(z) print(out) #結果: tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True) tensor([[3., 4., 5.], [6., 7., 8.]], grad_fn=<AddBackward>) tensor([[ 27., 48., 75.], [108., 147., 192.]], grad_fn=<MulBackward>) tensor(99.5000, grad_fn=<MeanBackward1>)
若是關于graph leaves求導的結果變量是一個標量,那么gradient默認為None,或者指定為“torch.Tensor([1.0])”
若是關于graph leaves求導的結果變量是一個向量,那么gradient是不能缺省的,要是和該向量同緯度的tensor
out.backward() print(x.grad) #結果: tensor([[3., 4., 5.], [6., 7., 8.]]) #如果是z關于x求導就必須指定gradient參數: gradients = torch.Tensor([[2.,1.,1.],[1.,1.,1.]]) z.backward(gradient=gradients) #若z不是一個標量,那么就先構造一個標量的值:L = torch.sum(z*gradient),再關于L對各個leaf Variable計算梯度 #對x關于L求梯度 x.grad #結果: tensor([[36., 24., 30.], [36., 42., 48.]])
z.backward() print(x.grad) #報錯:RuntimeError: grad can be implicitly created only for scalar outputs只能為標量創建隱式變量 x1 = Variable(torch.Tensor([[1.,2.,3.],[4.,5.,6.]])) x2 = Variable(torch.arange(4).view(2,2).type(torch.float), requires_grad=True) c = x2.mm(x1) c.backward(torch.ones_like(c)) # c.backward() #RuntimeError: grad can be implicitly created only for scalar outputs print(x2.grad)
1.PyTorch是相當簡潔且高效快速的框架;2.設計追求最少的封裝;3.設計符合人類思維,它讓用戶盡可能地專注于實現自己的想法;4.與google的Tensorflow類似,FAIR的支持足以確保PyTorch獲得持續的開發更新;5.PyTorch作者親自維護的論壇 供用戶交流和求教問題6.入門簡單
看完上述內容,你們掌握如何使用pytorch打印網絡回傳梯度的方法了嗎?如果還想學到更多技能或想了解更多相關內容,歡迎關注億速云行業資訊頻道,感謝各位的閱讀!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。