您好,登錄后才能下訂單哦!
這篇文章主要介紹“python的tf.train.batch函數怎么用”的相關知識,小編通過實際案例向大家展示操作過程,操作方法簡單快捷,實用性強,希望這篇“python的tf.train.batch函數怎么用”文章能幫助大家解決問題。
tf.train.batch( tensors, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None )
其中:
1、tensors:利用slice_input_producer獲得的數據組合。
2、batch_size:設置每次從隊列中獲取出隊數據的數量。
3、num_threads:用來控制線程的數量,如果其值不唯一,由于線程執行的特性,數據獲取可能變成亂序。
4、capacity:一個整數,用來設置隊列中元素的最大數量
5、allow_samller_final_batch:當其為True時,如果隊列中的樣本數量小于batch_size,出隊的數量會以最終遺留下來的樣本進行出隊;當其為False時,小于batch_size的樣本不會做出隊處理。
6、name:名字
import pandas as pd import numpy as np import tensorflow as tf # 生成數據 def generate_data(): num = 18 label = np.arange(num) return label # 獲取數據 def get_batch_data(): label = generate_data() input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2) label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=True) return label_batch # 數據組 label = get_batch_data() sess = tf.Session() # 初始化變量 sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) # 初始化batch訓練的參數 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess,coord) try: while not coord.should_stop(): # 自動獲取下一組數據 l = sess.run(label) print(l) except tf.errors.OutOfRangeError: print('Done training') finally: coord.request_stop() coord.join(threads) sess.close()
運行結果為:
[0 1 2 3 4]
[5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 0 1]
[2 3 4 5 6]
[ 7 8 9 10 11]
[12 13 14 15 16]
[17]
Done training
相比allow_samller_final_batch=True,輸出結果少了[17]
import pandas as pd import numpy as np import tensorflow as tf # 生成數據 def generate_data(): num = 18 label = np.arange(num) return label # 獲取數據 def get_batch_data(): label = generate_data() input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2) label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=False) return label_batch # 數據組 label = get_batch_data() sess = tf.Session() # 初始化變量 sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) # 初始化batch訓練的參數 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess,coord) try: while not coord.should_stop(): # 自動獲取下一組數據 l = sess.run(label) print(l) except tf.errors.OutOfRangeError: print('Done training') finally: coord.request_stop() coord.join(threads) sess.close()
運行結果為:
[0 1 2 3 4]
[5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 0 1]
[2 3 4 5 6]
[ 7 8 9 10 11]
[12 13 14 15 16]
Done training
關于“python的tf.train.batch函數怎么用”的內容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業相關的知識,可以關注億速云行業資訊頻道,小編每天都會為大家更新不同的知識點。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。