您好,登錄后才能下訂單哦!
小編這次要給大家分享的是如何實現TensorFlow固化模型,文章內容豐富,感興趣的小伙伴可以來了解一下,希望大家閱讀完這篇文章之后能夠有所收獲。
前言
TensorFlow目前在移動端是無法training的,只能跑已經訓練好的模型,但一般的保存方式只有單一保存參數或者graph的,如何將參數、graph同時保存呢?
生成模型
主要有兩種方法生成模型,一種是通過freeze_graph把tf.train.write_graph()生成的pb文件與tf.train.saver()生成的chkp文件固化之后重新生成一個pb文件,這一種現在不太建議使用。另一種是把變量轉成常量之后寫入PB文件中。我們簡單的介紹下freeze_graph方法。
freeze_graph
這種方法我們需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代碼如下:
with tf.Session() as sess: saver = tf.train.Saver() saver.save(session, "model.ckpt") tf.train.write_graph(session.graph_def, '', 'graph.pb')
然后使用TensorFlow源碼中的freeze_graph工具進行固化操作:
首先需要build freeze_graph 工具( 需要 bazel ):
bazel build tensorflow/python/tools:freeze_graph
然后使用這個工具進行固化(/path/to/表示文件路徑):
bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants
其實在TensorFlow中傳統的保存模型方式是保存常量以及graph的,而我們的權重主要是變量,如果我們把訓練好的權重變成常量之后再保存成PB文件,這樣確實可以保存權重,就是方法有點繁瑣,需要一個一個調用eval方法獲取值之后賦值,再構建一個graph,把W和b賦值給新的graph。
牛逼的Google為了方便大家使用,編寫了一個方法供我們快速的轉換并保存。
首先我們需要引入這個方法
from tensorflow.python.framework.graph_util import convert_variables_to_constants
在想要保存的地方加入如下代碼,把變量轉換成常量
output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])
這里參數第一個是當前的session,第二個為graph,第三個是輸出節點名(如我的輸出層代碼是這樣的:)
with tf.name_scope('output'): w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN])) tf.summary.histogram('output/weight', w_out) b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN])) tf.summary.histogram('output/biases', b_out) out = tf.add(tf.matmul(dense2, w_out), b_out) out = tf.nn.softmax(out) predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')
由于我們采用了name_scope所以我們在predict之前需要加上output/
生成文件
with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())
第一個參數是文件路徑,第二個是指文件操作的模式,這里指的是以二進制的方式寫入文件。
運行代碼,系統會生成一個PB文件,接下來我們要測試下這個模型是否能夠正常的讀取、運行。
測試模型
在Python環境下,我們首先需要加載這個模型,代碼如下:
with open('./model/rounded_graph.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) output = tf.import_graph_def(graph_def, input_map={'inputs/X:0': newInput_X}, return_elements=['output/predict:0'])
由于我們原本的網絡輸入值是一個placeholder,這里為了方便輸入我們也先定義一個新的placeholder:
newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")
在input_map的參數填入新的placeholder。
在調用我們的網絡的時候直接用這個新的placeholder接收數據,如:
text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})
然后就是運行我們的網絡,看是否可以運行吧。
看完這篇關于如何實現TensorFlow固化模型的文章,如果覺得文章內容寫得不錯的話,可以把它分享出去給更多人看到。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。