您好,登錄后才能下訂單哦!
下面代碼的功能是先訓練一個簡單的模型,然后保存模型,同時保存到一個pb文件當中,后續可以從pd文件里讀取權重值。
import tensorflow as tf import numpy as np import os import h6py import pickle from tensorflow.python.framework import graph_util from tensorflow.python.platform import gfile #設置使用指定GPU os.environ['CUDA_VISIBLE_DEVICES'] = '1' #下面這段代碼是在訓練好之后將所有的權重名字和權重值羅列出來,訓練的時候需要注釋掉 reader = tf.train.NewCheckpointReader('./model.ckpt-100') variables = reader.get_variable_to_shape_map() for ele in variables: print(ele) print(reader.get_tensor(ele)) x = tf.placeholder(tf.float32, shape=[None, 1]) y = 4 * x + 4 w = tf.Variable(tf.random_normal([1], -1, 1)) b = tf.Variable(tf.zeros([1])) y_predict = w * x + b loss = tf.reduce_mean(tf.square(y - y_predict)) optimizer = tf.train.GradientDescentOptimizer(0.5) train = optimizer.minimize(loss) isTrain = False#設成True去訓練模型 train_steps = 100 checkpoint_steps = 50 checkpoint_dir = '' saver = tf.train.Saver() # defaults to saving all variables - in this case w and b x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1)) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) if isTrain: for i in xrange(train_steps): sess.run(train, feed_dict={x: x_data}) if (i + 1) % checkpoint_steps == 0: saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) else: ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) else: pass print(sess.run(w)) print(sess.run(b)) graph_def = tf.get_default_graph().as_graph_def() #通過修改下面的函數,個人覺得理論上能夠實現修改權重,但是很復雜,如果哪位有好辦法,歡迎指教 output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['Variable']) with tf.gfile.FastGFile('./test.pb', 'wb') as f: f.write(output_graph_def.SerializeToString()) with tf.Session() as sess: #對應最后一部分的寫,這里能夠將對應的變量取出來 with gfile.FastGFile('./test.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) res = tf.import_graph_def(graph_def, return_elements=['Variable:0']) print(sess.run(res)) print(sess.run(graph_def))
以上這篇tensorflow 保存模型和取出中間權重例子就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。