您好,登錄后才能下訂單哦!
最近在使用tensorflow進行網絡訓練的時候,需要提取出別人訓練好的卷積核的部分層的數據。由于tensorflow中的tensor和python中的list不同,無法直接使用加法進行拼接,后來發現一個函數可以完成tensor的拼接。
函數形式如下:
tf.concat(concat_dim,values,name='concat')
其中,第一個參數表示需要拼接的多維tensor,并且可以將多個tensor同事拼接,第二個表示按照哪一個維度拼接(從數字0開始)。
例子:創建一個三維的tensor,然后分別取出最后一個維度(注意:tensor支持與python中list相似的切片操作,可以使用這種方式進行拆分),然后在拼接在一起。
import tensorflow as tf weights=tf.Variable(tf.truncated_normal([2,3,4],dtype=tf.float32,stddev=1e-1),name='weights') weight1=weights[0:2,0:3,1:2] weight2=weights[0:2,0:3,2:3] weight3=weights[0:2,0:3,1:2] weight4=tf.concat([weight1,weight2,weight3],2) #2表示最后一個維度 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(weights)) print("****************") print(sess.run(weight4))
以上這篇Tensorflow進行多維矩陣的拆分與拼接實例就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。