91超碰碰碰碰久久久久久综合_超碰av人澡人澡人澡人澡人掠_国产黄大片在线观看画质优化_txt小说免费全本

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

pytorch 狀態字典:state_dict使用詳解

發布時間:2020-08-19 21:38:20 來源:腳本之家 閱讀:984 作者:wzg2016 欄目:開發技術

pytorch 中的 state_dict 是一個簡單的python的字典對象,將每一層與它的對應參數建立映射關系.(如model的每一層的weights及偏置等等)

(注意,只有那些參數可以訓練的layer才會被保存到模型的state_dict中,如卷積層,線性層等等)

優化器對象Optimizer也有一個state_dict,它包含了優化器的狀態以及被使用的超參數(如lr, momentum,weight_decay等)

備注:

1) state_dict是在定義了model或optimizer之后pytorch自動生成的,可以直接調用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自動具備的函數,可以直接調用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因為,只有在執行該命令后,"dropout層"及"batch normalization層"才會進入 evalution 模態. 而在"訓練(training)模態"與"評估(evalution)模態"下,這兩層有不同的表現形式.

模態字典(state_dict)的保存(model是一個網絡結構類的對象)

1.1)僅保存學習到的參數,用以下命令

 torch.save(model.state_dict(), PATH)

1.2)加載model.state_dict,用以下命令

 model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

備注:model.load_state_dict的操作對象是 一個具體的對象,而不能是文件名

2.1)保存整個model的狀態,用以下命令

torch.save(model,PATH)

2.2)加載整個model的狀態,用以下命令:

   # Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一個python的字典格式,以字典的格式存儲,然后以字典的格式被加載,而且只加載key匹配的項

如何僅加載某一層的訓練的到的參數(某一層的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加載模型參數后,如何設置某層某參數的"是否需要訓練"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作對象是tensor.

疑問:能否直接對某個層直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:經測試,不可以.model.conv1 沒有requires_grad屬性.

全部測試代碼:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 僅僅加載某一層的參數
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上這篇pytorch 狀態字典:state_dict使用詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

文昌市| 观塘区| 车险| 漳州市| 昌平区| 邵武市| 旬阳县| 通海县| 大渡口区| 阿尔山市| 安西县| 隆德县| 抚远县| 澜沧| 鹤岗市| 清涧县| 延津县| 稻城县| 达拉特旗| 邵阳县| 杭锦旗| 丰原市| 罗城| 诸城市| 渭南市| 肥东县| 布拖县| 庆阳市| 安丘市| 英山县| 县级市| 河曲县| 布拖县| 扶风县| 泗阳县| 靖远县| 岱山县| 清河县| 宿松县| 内江市| 封丘县|