91超碰碰碰碰久久久久久综合_超碰av人澡人澡人澡人澡人掠_国产黄大片在线观看画质优化_txt小说免费全本

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

tensorflow從.ckpt文件中讀取任意變量的實現方法

發布時間:2020-07-21 16:51:30 來源:億速云 閱讀:424 作者:小豬 欄目:開發技術

這篇文章主要講解了tensorflow從.ckpt文件中讀取任意變量的實現方法,內容清晰明了,對此有興趣的小伙伴可以學習一下,相信大家閱讀完之后會有幫助。

看了faster rcnn的tensorflow代碼,關于fix_variables的作用我不是很明白,所以寫了以下代碼,讀取了預訓練模型vgg16得fc6和fc7的參數,以及faster rcnn中heat_to_tail中的fc6和fc7,將它們做了對比,發現結果不一樣,說明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被訓練。

具體讀取任意變量的代碼如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路徑
name_variable_to_restore = 'vgg_16/fc7/weights' #要讀取權重的變量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #輸出這個變量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定義接收權重的變量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定義恢復變量的對象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必須初始化
restorer_fc.restore(sess, file_name) #恢復變量
print(sess.run(fc7_conv)) #輸出結果

用以上的代碼分別讀取兩個網絡的fc6 和 fc7 ,對應參數尺寸和權值都不同,但參數量相同。

再看lib/nets/vgg16.py中的:

(注意注釋)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定義接收權重的變量,不可被訓練
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定義恢復變量的對象
   restorer_fc.restore(sess, pretrained_model) #恢復這些變量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #將vgg16中的fc6、fc7中的權重reshape賦給faster-rcnn中的fc6、fc7

我的理解:faster rcnn的網絡繼承了分類網絡的特征提取權重和分類器的權重,讓網絡從一個比較好的起點開始被訓練,有利于訓練結果的快速收斂。

補充知識:TensorFlow:加載部分ckpt文件變量&不同命名空間中加載模型

TensorFlow中,在加載和保存模型時,一般會直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,當需要選擇性加載模型參數時,則需要利用pywrap_tensorflow讀取模型,分析模型內的變量關系。

例子:Faster-RCNN中,模型加載vgg16.ckpt,需要利用pywrap_tensorflow讀取ckpt文件中的參數

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此處構建vgg16模型
variables = tf.global_variables()#獲取模型中所有變量
 
file_name='vgg16.ckpt'#vgg16網絡模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#獲取ckpt模型中的變量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空間名
variables_to_restore={}#構建字典:需要的變量和對應的模型變量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#將需要加載的變量作為參數輸入
restorer.restore(sess, file_name)

實際中,Faster RCNN中所構建的vgg16網絡的fc6和fc7權重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7權重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

看完上述內容,是不是對tensorflow從.ckpt文件中讀取任意變量的實現方法有進一步的了解,如果還想學習更多內容,歡迎關注億速云行業資訊頻道。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

永川市| 哈巴河县| 建湖县| 苏州市| 博兴县| 铜山县| 清镇市| 界首市| 环江| 霍邱县| 高雄县| 乾安县| 水城县| 长子县| 凤阳县| 洪湖市| 大姚县| 陵川县| 开化县| 遂川县| 石台县| 来安县| 绥芬河市| 广南县| 县级市| 夏河县| 嘉黎县| 青神县| 凤山县| 常熟市| 枣强县| 金昌市| 新宁县| 南投县| 名山县| 兴安盟| 特克斯县| 梓潼县| 宁夏| 海盐县| 丽江市|