您好,登錄后才能下訂單哦!
這期內容當中小編將會給大家帶來有關如何在Pytorch 中使用retain_graph,文章內容豐富且以專業的角度為大家分析和敘述,閱讀完這篇文章希望大家可以有所收獲。
用法分析
在查看SRGAN源碼時有如下損失函數,其中設置了retain_graph=True,其作用是什么?
############################ # (1) Update D network: maximize D(x)-1-D(G(z)) ########################### real_img = Variable(target) if torch.cuda.is_available(): real_img = real_img.cuda() z = Variable(data) if torch.cuda.is_available(): z = z.cuda() fake_img = netG(z) netD.zero_grad() real_out = netD(real_img).mean() fake_out = netD(fake_img).mean() d_loss = 1 - real_out + fake_out d_loss.backward(retain_graph=True) ##### optimizerD.step() ############################ # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss ########################### netG.zero_grad() g_loss = generator_criterion(fake_out, fake_img, real_img) g_loss.backward() optimizerG.step() fake_img = netG(z) fake_out = netD(fake_img).mean() g_loss = generator_criterion(fake_out, fake_img, real_img) running_results['g_loss'] += g_loss.data[0] * batch_size d_loss = 1 - real_out + fake_out running_results['d_loss'] += d_loss.data[0] * batch_size running_results['d_score'] += real_out.data[0] * batch_size running_results['g_score'] += fake_out.data[0] * batch_size
在更新D網絡時的loss反向傳播過程中使用了retain_graph=True,目的為是為保留該過程中計算的梯度,后續G網絡更新時使用;
其實retain_graph這個參數在平常中我們是用不到的,但是在特殊的情況下我們會用到它,
如下代碼:
import torch y=x**2 z=y*4 output1=z.mean() output2=z.sum() output1.backward() output2.backward()
輸出如下錯誤信息:
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-19-8ad6b0658906> in <module>() ----> 1 output1.backward() 2 output2.backward() D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph) 91 products. Defaults to ``False``. 92 """ ---> 93 torch.autograd.backward(self, gradient, retain_graph, create_graph) 94 95 def register_hook(self, hook): D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables) 88 Variable._execution_engine.run_backward( 89 tensors, grad_tensors, retain_graph, create_graph, ---> 90 allow_unreachable=True) # allow_unreachable flag 91 92 RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
修改成如下正確:
import torch y=x**2 z=y*4 output1=z.mean() output2=z.sum() output1.backward(retain_graph=True) output2.backward()
# 假如你有兩個Loss,先執行第一個的backward,再執行第二個backward loss1.backward(retain_graph=True) loss2.backward() # 執行完這個后,所有中間變量都會被釋放,以便下一次的循環 optimizer.step() # 更新參數
Variable 類源代碼
class Variable(_C._VariableBase): """ Attributes: data: 任意類型的封裝好的張量。 grad: 保存與data類型和位置相匹配的梯度,此屬性難以分配并且不能重新分配。 requires_grad: 標記變量是否已經由一個需要調用到此變量的子圖創建的bool值。只能在葉子變量上進行修改。 volatile: 標記變量是否能在推理模式下應用(如不保存歷史記錄)的bool值。只能在葉變量上更改。 is_leaf: 標記變量是否是圖葉子(如由用戶創建的變量)的bool值. grad_fn: Gradient function graph trace. Parameters: data (any tensor class): 要包裝的張量. requires_grad (bool): bool型的標記值. **Keyword only.** volatile (bool): bool型的標記值. **Keyword only.** """ def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None): """計算關于當前圖葉子變量的梯度,圖使用鏈式法則導致分化 如果Variable是一個標量(例如它包含一個單元素數據),你無需對backward()指定任何參數 如果變量不是標量(包含多個元素數據的矢量)且需要梯度,函數需要額外的梯度; 需要指定一個和tensor的形狀匹配的grad_output參數(y在指定方向投影對x的導數); 可以是一個類型和位置相匹配且包含與自身相關的不同函數梯度的張量。 函數在葉子上累積梯度,調用前需要對該葉子進行清零。 Arguments: grad_variables (Tensor, Variable or None): 變量的梯度,如果是一個張量,除非“create_graph”是True,否則會自動轉換成volatile型的變量。 可以為標量變量或不需要grad的值指定None值。如果None值可接受,則此參數可選。 retain_graph (bool, optional): 如果為False,用來計算梯度的圖將被釋放。 在幾乎所有情況下,將此選項設置為True不是必需的,通常可以以更有效的方式解決。 默認值為create_graph的值。 create_graph (bool, optional): 為True時,會構造一個導數的圖,用來計算出更高階導數結果。 默認為False,除非``gradient``是一個volatile變量。 """ torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables) def register_hook(self, hook): """Registers a backward hook. 每當與variable相關的梯度被計算時調用hook,hook的申明:hook(grad)->Variable or None 不能對hook的參數進行修改,但可以選擇性地返回一個新的梯度以用在`grad`的相應位置。 函數返回一個handle,其``handle.remove()``方法用于將hook從模塊中移除。 Example: >>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True) >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient >>> v.backward(torch.Tensor([1, 1, 1])) >>> v.grad.data 2 2 2 [torch.FloatTensor of size 3] >>> h.remove() # removes the hook """ if self.volatile: raise RuntimeError("cannot register a hook on a volatile variable") if not self.requires_grad: raise RuntimeError("cannot register a hook on a variable that " "doesn't require gradient") if self._backward_hooks is None: self._backward_hooks = OrderedDict() if self.grad_fn is not None: self.grad_fn._register_hook_dict(self) handle = hooks.RemovableHandle(self._backward_hooks) self._backward_hooks[handle.id] = hook return handle def reinforce(self, reward): """Registers a reward obtained as a result of a stochastic process. 區分隨機節點需要為他們提供reward值。如果圖表中包含任何的隨機操作,都應該在其輸出上調用此函數,否則會出現錯誤。 Parameters: reward(Tensor): 帶有每個元素獎賞的張量,必須與Variable數據的設備位置和形狀相匹配。 """ if not isinstance(self.grad_fn, StochasticFunction): raise RuntimeError("reinforce() can be only called on outputs " "of stochastic functions") self.grad_fn._reinforce(reward) def detach(self): """返回一個從當前圖分離出來的心變量。 結果不需要梯度,如果輸入是volatile,則輸出也是volatile。 .. 注意:: 返回變量使用與原始變量相同的數據張量,并且可以看到其中任何一個的就地修改,并且可能會觸發正確性檢查中的錯誤。 """ result = NoGrad()(self) # this is needed, because it merges version counters result._grad_fn = None return result def detach_(self): """從創建它的圖中分離出變量并作為該圖的一個葉子""" self._grad_fn = None self.requires_grad = False def retain_grad(self): """Enables .grad attribute for non-leaf Variables.""" if self.grad_fn is None: # no-op for leaves return if not self.requires_grad: raise RuntimeError("can't retain_grad on Variable that has requires_grad=False") if hasattr(self, 'retains_grad'): return weak_self = weakref.ref(self) def retain_grad_hook(grad): var = weak_self() if var is None: return if var._grad is None: var._grad = grad.clone() else: var._grad = var._grad + grad self.register_hook(retain_grad_hook) self.retains_grad = True
上述就是小編為大家分享的如何在Pytorch 中使用retain_graph了,如果剛好有類似的疑惑,不妨參照上述分析進行理解。如果想知道更多相關知識,歡迎關注億速云行業資訊頻道。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。