您好,登錄后才能下訂單哦!
這篇文章主要介紹了tensorflow中沒有output結點如何存儲成pb文件,具有一定借鑒價值,感興趣的朋友可以參考下,希望大家閱讀完這篇文章之后大有收獲,下面讓小編帶著大家一起了解一下。
Tensorflow中保存成pb file 需要 使用函數
graph_util.convert_variables_to_constants(sess, sess.graph_def,
output_node_names=[]) []中需要填寫你需要保存的結點。如果保存的結點在神經網絡中沒有被顯示定義該怎么辦?
例如我使用了tf.contrib.slim或者keras,在tf的高層很多情況下都會這樣。
在寫神經網絡時,只需要簡單的一層層傳導,一個slim.conv2d層就包含了kernal,bias,activation function,非常的方便,好處是網絡結構一目了然,壞處是什么呢?
在嘗試保存pb的 output node names時,需要將最后的輸出結點保存下來,與這個結點相關的,從輸入開始,經過層層傳遞的嵌套函數或者操作的相關結點,都會被保存,但無效的例如 計算準確率,計算loss等,就可以省略了,因為保存的pb主要是用來做預測的。
在準備查看所有的結點名稱并選取保存時,發現scope "local3"里面僅有相關的weights 和biases,這兩個是單獨存在的,即保存這兩個參數并沒有任何意義。
那么這時候有兩種解決辦法:
方法一:
graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=[var.name[:-2] for var in tf.global_variables()])
那么這個的意思是所有的variable的都被保存下來 但函數中要求的是 node name 我們通過 global_variables獲得的是 變量名 并不是 節點名
(例如 output:0 就是變量名,又叫tensor name)
output就是 node name了。
在tensorboard中可以一窺究竟
通過這樣 也可以將 所有的變量全部保存下來(但是你并不能使用,是因為你的output并沒有名字,所以你不可以通過常用的sess.graph.get_tensor_by_name來使用)
方法二:
那就是直接改寫神經網絡了....當然了還是比較簡單的,只要改寫最后一個,改寫成output即可,tensorflow中無論是 變量、操作op、函數、都可以命名,那么這個地方是一個簡單的全連接,僅需要將weights*net(上一層的輸出) +bias 即可,我們只要將bias相加的結果命名為 ouput即可:
with tf.name_scope('local3'): local3_weights = tf.Variable(tf.truncated_normal([4096, self.output_size], stddev=0.1)) local3_bias = tf.Variable(tf.constant(0.1, shape=[self.output_size])) result = tf.add(tf.matmul(net, local3_weights), local3_bias, name="output")
這樣將上述的convert_variables_to_constants中的output_node_names只需要填寫一個['output']即可,因為這一個output結點,需要從input開始,將所有的神經網絡前向傳播的操作和參數全部保存下來,因此保存的結點數量 和 方法一保存的結點數量是一樣的(console顯示都是 convert 24)。
完整的pb保存為:(我是將ckpt讀入進來,然后存成pb的)
from tensorflow.python.platform import gfile load_ckpt(): path = './data/output/loss1.0/' print("read from ckpt") ckpt = tf.train.get_checkpoint_state(path) saver = tf.train.Saver() saver.restore(sess, ckpt.model_checkpoint_path) def write2pb_file(): constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=["output"]) with tf.gfile.GFile(path+'loss1.0.pb', mode='wb') as f: f.write(constant_graph.SerializeToString()) print("Model is saved as " + path+'loss1.0.pb') def main(): load_ckpt() write2pb_file()
如果是簡單的直接保存,那就更簡單了。
pb文件的read,很多人會將一個net寫成一個類,在引入的時候會將新建這個類,然后讀入ckpt文件,這完全沒有問題,但是在讀取pb時,就會發生問題,因為pb中已經包含了圖與參數,引入時會創建一個默認的圖,但是net類中自己也會創建一個圖,那么這時候你運行程序,參數其實并沒有使用.pb的文件。
所以我們不能創建net類,然后直接讀入.pb文件,對.pb文件,通過如下代碼,獲取.pb的graph中的輸入和輸出。
self.output = self.sess.graph.get_tensor_by_name("output:0") self.input = self.sess.graph.get_tensor_by_name("images:0")
注意此時要加:0 因為你獲取的不再是結點了,而是一個真實的變量,我的理解是,結點相當于一個類,:0是對象,默認初始化值就是對象的初始化。
然后就可以通過self.sess.run(self.output(feed_dict={self.input: your_input})))運行你的網絡了!
感謝你能夠認真閱讀完這篇文章,希望小編分享的“tensorflow中沒有output結點如何存儲成pb文件”這篇文章對大家有幫助,同時也希望大家多多支持億速云,關注億速云行業資訊頻道,更多相關知識等著你來學習!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。