This post explains how to use PyTorch's hook functions. The approach is simple and practical: the examples below cover tensor hooks (register_hook) as well as module-level forward, forward-pre, and backward hooks.
""" @brief : pytorch的hook函數 """ import torch import torch.nn as nn from tools.common_tools2 import set_seed set_seed(1) # ----------------------------------- 1 tensor hook 1 flag = 0 # flag = 1 if flag: w = torch.tensor([1.], requires_grad=True) x = torch.tensor([2.], requires_grad=True) a = torch.add(w, x) b = torch.add(w, 1) y = torch.mul(a, b) a_grad = list() def grad_hook(grad): a_grad.append(grad) handle = a.register_hook(grad_hook) y.backward() # 查看梯度 print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad) print("a_grad[0]:", a_grad[0]) handle.remove() # ----------------------------------- 2 tensor hook 2 flag = 0 # flag = 1 if flag: w = torch.tensor([1.], requires_grad=True) x = torch.tensor([2.], requires_grad=True) a = torch.add(w, x) b = torch.add(w, 1) y = torch.mul(a, b) a_grad = list() def grad_hook(grad): grad *= 2 return grad * 3 handle = w.register_hook(grad_hook) y.backward() print("w.grad:", w.grad) handle.remove() # --------------------------- 3 Module.register_forward_hook and pre hook # flag = 0 flag = 1 if flag: class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 2, 3) self.pool1 = nn.MaxPool2d(2, 2) def forward(self, x): x = self.conv1(x) x = self.pool1(x) return x def forward_hook(module, data_input, data_output): fmap_block.append(data_output) input_block.append(data_input) def forward_pre_hook(module, data_input): print("forward_pre_hook input:{}".format(data_input)) def backward_hook(module, grad_input, grad_output): print("backward hook input:{}".format(grad_input)) print("backward hook output:{}".format(grad_output)) # 初始化網絡 net = Net() net.conv1.weight[0].detach().fill_(1) net.conv1.weight[1].detach().fill_(2) net.conv1.bias.data.detach().zero_() # 注冊hook fmap_block = list() input_block = list() net.conv1.register_forward_hook(forward_hook) net.conv1.register_forward_pre_hook(forward_pre_hook) net.conv1.register_backward_hook(backward_hook) # inference fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W output = net(fake_img) # 前向傳播 loss_fnc = nn.L1Loss() target = torch.randn_like(output) loss = loss_fnc(target, output) loss.backward() # 觀察 print("output shape: {}\noutput value: {}\n".format(output.shape, output)) print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0])) print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))
1. Use the torch.nn.Module.register_forward_hook mechanism to visualize the output feature maps of AlexNet's first convolutional layer. Then change line 28 of /torchvision/models/alexnet.py to nn.ReLU(inplace=False) and observe the difference between inplace=True and inplace=False (a sketch of this effect follows the listing below).
# -*- coding:utf-8 -*-
"""
@brief : visualize feature maps with hook functions
"""
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from tools.common_tools2 import set_seed
import torchvision.models as models

set_seed(1)  # set random seed

# ----------------------------------- feature map visualization
# flag = 0
flag = 1
if flag:
    writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

    # data
    path_img = "./lena.png"  # your path to image
    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]

    norm_transform = transforms.Normalize(normMean, normStd)
    img_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        norm_transform
    ])

    img_pil = Image.open(path_img).convert('RGB')
    if img_transforms is not None:
        img_tensor = img_transforms(img_pil)
    img_tensor.unsqueeze_(0)  # chw --> bchw

    # model
    alexnet = models.alexnet(pretrained=True)

    # register a forward hook on every Conv2d layer
    fmap_dict = dict()
    for name, sub_module in alexnet.named_modules():
        if isinstance(sub_module, nn.Conv2d):
            key_name = str(sub_module.weight.shape)
            fmap_dict.setdefault(key_name, list())

            n1, n2 = name.split(".")

            def hook_func(m, i, o):
                key_name = str(m.weight.shape)
                fmap_dict[key_name].append(o)

            alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)

    # forward
    output = alexnet(img_tensor)

    # add image
    for layer_name, fmap_list in fmap_dict.items():
        fmap = fmap_list[0]
        fmap.transpose_(0, 1)

        nrow = int(np.sqrt(fmap.shape[0]))
        fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
        writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)
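Why the exercise asks you to compare inplace=True with inplace=False: a forward hook stores a reference to the layer's output tensor, so if the ReLU that follows the convolution modifies that tensor in place, the feature map you cached no longer holds the pre-activation values. Below is a minimal sketch of the effect using a hypothetical toy layer rather than AlexNet itself.

import torch
import torch.nn as nn

captured = list()

def save_output(module, data_input, data_output):
    captured.append(data_output)  # stores a reference, not a copy

conv = nn.Conv2d(1, 1, 3, bias=False)
relu = nn.ReLU(inplace=True)      # switch to inplace=False and compare
conv.register_forward_hook(save_output)

img = torch.randn(1, 1, 4, 4)
with torch.no_grad():
    out = relu(conv(img))

# with inplace=True the cached conv output has been overwritten by the ReLU
# (negative activations are now zero); with inplace=False it is left intact
print("min of cached feature map:", captured[0].min().item())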
That covers how PyTorch's hook functions are used; try running the examples above yourself.