91超碰碰碰碰久久久久久综合_超碰av人澡人澡人澡人澡人掠_国产黄大片在线观看画质优化_txt小说免费全本

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

TensorFlow saver指定變量的存取

發布時間:2020-08-20 16:41:46 來源:腳本之家 閱讀:132 作者:main_h_ 欄目:開發技術

今天和大家分享一下用TensorFlow的saver存取訓練好的模型那點事。

1. 用saver存取變量;
2. 用saver存取指定變量。

用saver存取變量。

話不多說,先上代碼

# coding=utf-8
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集沒有裝,加這個不顯示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
saver =tf.train.Saver()
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路徑可以自己定
 print("save to path:",save_path)

這里我隨便定義了幾個變量然后進行存操作,運行后,變量w,b,s會被保存下來。保存會生成如下幾個文件:

  • cheakpoint
  • save_net.ckpt.data-*
  • save_net.ckpt.index
  • save_net.ckpt.meta

接下來是讀取的代碼

import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
saver =tf.train.Saver()
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 print ("s",sess.run(a))

在寫讀取代碼時要注意變量定義的類型、大小和變量的數量以及順序等要與存的時候一致,不然會報錯。你存的時候順序是w,b,s,取的時候同樣這個順序。存的時候w定義了dtype沒有 定義name,取的時候同樣要這樣,因為TensorFlow存取是按照鍵值對來存取的,所以必須一致。這里變量名,也就是w,s之類可以不同。

如下是我成功讀取的效果

TensorFlow saver指定變量的存取

用saver存取指定變量。

在我們做訓練時候,有些變量是沒有必要保存的,但是如果直接用tf.train.Saver()。程序會將所有的變量保存下來,這時候我們可以指定保存,只保存我們需要的變量,其他的統統丟掉。
其實很簡單,只需要在上面代碼基礎上稍加修改,只需把tf.train.Saver()替換成如下代碼

program = []
program += [w,b]
tf.train.Saver(program)

這樣,程序就只會存w和b了。同樣,讀取程序里面的tf.train.Saver()也要做如上修改。dtype,name之類依舊必須一致。

最后附上最終代碼:

# coding=utf-8
# saver保存變量測試
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集沒有裝,加這個不顯示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
program = []
program += [w, b]
saver =tf.train.Saver(program)
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路徑可以自己定
 print("save to path:",save_path)


#saver提取變量測試
import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
program = []
program +=[w,b]
saver =tf.train.Saver(program)
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 #print ("s",sess.run(a))

以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支持億速云。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

南木林县| 南漳县| 大新县| 雷州市| 搜索| 濉溪县| 沂源县| 崇义县| 霞浦县| 宝山区| 萝北县| 奎屯市| 桃源县| 昌邑市| 于都县| 黑山县| 平阴县| 车险| 新田县| 鸡泽县| 崇明县| 隆子县| 安康市| 文水县| 墨江| 鞍山市| 通化市| 齐河县| 庐江县| 舟山市| 茌平县| 邹城市| 嘉鱼县| 宁城县| 璧山县| 德昌县| 武汉市| 屯留县| 南木林县| 射阳县| 左贡县|