您好,登錄后才能下訂單哦!
pytorch 中的 state_dict 是一個簡單的python的字典對象,將每一層與它的對應參數建立映射關系.(如model的每一層的weights及偏置等等)
(注意,只有那些參數可以訓練的layer才會被保存到模型的state_dict中,如卷積層,線性層等等)
優化器對象Optimizer也有一個state_dict,它包含了優化器的狀態以及被使用的超參數(如lr, momentum,weight_decay等)
備注:
1) state_dict是在定義了model或optimizer之后pytorch自動生成的,可以直接調用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"
torch.save(model.state_dict(), PATH)
2) load_state_dict 也是model或optimizer之后pytorch自動具備的函數,可以直接調用
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因為,只有在執行該命令后,"dropout層"及"batch normalization層"才會進入 evalution 模態. 而在"訓練(training)模態"與"評估(evalution)模態"下,這兩層有不同的表現形式.
模態字典(state_dict)的保存(model是一個網絡結構類的對象)
1.1)僅保存學習到的參數,用以下命令
torch.save(model.state_dict(), PATH)
1.2)加載model.state_dict,用以下命令
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
備注:model.load_state_dict的操作對象是 一個具體的對象,而不能是文件名
2.1)保存整個model的狀態,用以下命令
torch.save(model,PATH)
2.2)加載整個model的狀態,用以下命令:
# Model class must be defined somewhere model = torch.load(PATH) model.eval()
state_dict 是一個python的字典格式,以字典的格式存儲,然后以字典的格式被加載,而且只加載key匹配的項
如何僅加載某一層的訓練的到的參數(某一層的state)
If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
加載模型參數后,如何設置某層某參數的"是否需要訓練"(param.requires_grad)
for param in list(model.pretrained.parameters()): param.requires_grad = False
注意: requires_grad的操作對象是tensor.
疑問:能否直接對某個層直接之用requires_grad呢?例如:model.conv1.requires_grad=False
回答:經測試,不可以.model.conv1 沒有requires_grad屬性.
全部測試代碼:
#-*-coding:utf-8-*- import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim # define model class TheModelClass(nn.Module): def __init__(self): super(TheModelClass,self).__init__() self.conv1 = nn.Conv2d(3,6,5) self.pool = nn.MaxPool2d(2,2) self.conv2 = nn.Conv2d(6,16,5) self.fc1 = nn.Linear(16*5*5,120) self.fc2 = nn.Linear(120,84) self.fc3 = nn.Linear(84,10) def forward(self,x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1,16*5*5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # initial model model = TheModelClass() #initialize the optimizer optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9) # print the model's state_dict print("model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor,'\t',model.state_dict()[param_tensor].size()) print("\noptimizer's state_dict") for var_name in optimizer.state_dict(): print(var_name,'\t',optimizer.state_dict()[var_name]) print("\nprint particular param") print('\n',model.conv1.weight.size()) print('\n',model.conv1.weight) print("------------------------------------") torch.save(model.state_dict(),'./model_state_dict.pt') # model_2 = TheModelClass() # model_2.load_state_dict(torch.load('./model_state_dict')) # model.eval() # print('\n',model_2.conv1.weight) # print((model_2.conv1.weight == model.conv1.weight).size()) ## 僅僅加載某一層的參數 conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight'] print(conv1_weight_state==model.conv1.weight) model_2 = TheModelClass() model_2.load_state_dict(torch.load('./model_state_dict.pt')) model_2.conv1.requires_grad=False print(model_2.conv1.requires_grad) print(model_2.conv1.bias.requires_grad)
以上這篇pytorch 狀態字典:state_dict使用詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。