91超碰碰碰碰久久久久久综合_超碰av人澡人澡人澡人澡人掠_国产黄大片在线观看画质优化_txt小说免费全本

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

如何在pytorch中解決state_dict()的拷貝問題

發布時間:2021-03-03 15:32:46 來源:億速云 閱讀:333 作者:Leah 欄目:開發技術

如何在pytorch中解決state_dict()的拷貝問題?很多新手對此不是很清楚,為了幫助大家解決這個難題,下面小編將為大家詳細講解,有這方面需求的人可以來學習下,希望你能有所收獲。

model.state_dict()是淺拷貝,返回的參數仍然會隨著網絡的訓練而變化。

應該使用deepcopy(model.state_dict()),或將參數及時序列化到硬盤。

再講故事,前幾天在做一個模型的交叉驗證訓練時,通過model.state_dict()保存了每一組交叉驗證模型的參數,后根據效果選擇準確率最佳的模型load回去,結果每一次都是最后一個模型,從地址來看,每一個保存的state_dict()都具有不同的地址,但進一步發現state_dict()下的各個模型參數的地址是共享的,而我又使用了in-place的方式重置模型參數,進而導致了上述問題。

補充:pytorch中state_dict的理解

在PyTorch中,state_dict是一個Python字典對象(在這個有序字典中,key是各層參數名,value是各層參數),包含模型的可學習參數(即權重和偏差,以及bn層的的參數) 優化器對象(torch.optim)也具有state_dict,其中包含有關優化器狀態以及所用超參數的信息。

其實看了如下代碼的輸出應該就懂了

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

輸出如下:

Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是剛接觸深度學西的小白一個,希望大佬可以為我指出我的不足,此博客僅為自己的筆記!!!!

補充:pytorch保存模型時報錯***object has no attribute 'state_dict'

定義了一個類BaseNet并實例化該類:

net=BaseNet()

保存net時報錯 object has no attribute 'state_dict'

torch.save(net.state_dict(), models_dir)

原因是定義類的時候不是繼承nn.Module類,比如:

class BaseNet(object):
  def __init__(self):

把類定義改為

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()

看完上述內容是否對您有幫助呢?如果還想對相關知識有進一步的了解或閱讀更多相關文章,請關注億速云行業資訊頻道,感謝您對億速云的支持。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

漯河市| 安福县| 临安市| 正镶白旗| 金平| 河北区| 河津市| 南宁市| 丽江市| 松滋市| 奉新县| 新平| 乡城县| 安宁市| 夹江县| 镇远县| 淮北市| 台东县| 新乡县| 鹿泉市| 武鸣县| 宝兴县| 扎赉特旗| 合江县| 苏尼特右旗| 手机| 胶南市| 建水县| 龙江县| 克山县| 高青县| 修水县| 抚远县| 余姚市| 平遥县| 石柱| 陆川县| 团风县| 衡阳市| 潜山县| 卢龙县|