您好,登錄后才能下訂單哦!
本篇文章給大家分享的是有關TensorFlow中怎么利用saver保存和提取參數,小編覺得挺實用的,因此分享給大家學習,希望大家閱讀完這篇文章后可以有所收獲,話不多說,跟著小編一起來看看吧。
在訓練循環中,定期調用 saver.save() 方法,向文件夾中寫入包含了當前模型中所有可訓練變量的 checkpoint 文件。
saver.save(sess, FLAGS.train_dir, global_step=step)
global_step是訓練的第幾步
保存參數:
import tensorflow as tf W = tf.Variable([[1, 2, 3]], dtype=tf.float32) b = tf.Variable([[1]], dtype=tf.float32) saver = tf.train.Saver() sess = tf.InteractiveSession() tf.global_variables_initializer().run() # 必須要指定文件夾,保存到ckpt文件 save_path = saver.save(sess, "winycg/1.ckpt") print(save_path)
一次 saver.save() 后可以在文件夾中看到新增的四個文件,實際上每調用一次保存操作會創建后3個數據文件并創建一個檢查點(checkpoint)文件,簡單理解就是權重等參數被保存到 .chkp.data 文件中,以字典的形式;圖和元數據被保存到 .chkp.meta 文件中,可以被 tf.train.import_meta_graph 加載到當前默認的圖。
讀取參數:
import tensorflow as tf import numpy as np W = tf.Variable(np.arange(3).reshape(1, 3), dtype=tf.float32) b = tf.Variable(np.arange(1).reshape(1, 1), dtype=tf.float32) saver = tf.train.Saver() sess = tf.InteractiveSession() # 讀取參數時不需要global_variables_initializer() save_path = saver.restore(sess, "parameter/1.ckpt") print("weights:", sess.run(W)) print("bias:", sess.run(b))
weights: [[ 1. 2. 3.]]
bias: [[ 1.]]
以上就是TensorFlow中怎么利用saver保存和提取參數,小編相信有部分知識點可能是我們日常工作會見到或用到的。希望你能通過這篇文章學到更多知識。更多詳情敬請關注億速云行業資訊頻道。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。