您好,登錄后才能下訂單哦!
本篇文章給大家分享的是有關怎么在Tensorflow中使用tfrecord輸入數據格式,小編覺得挺實用的,因此分享給大家學習,希望大家閱讀完這篇文章后可以有所收獲,話不多說,跟著小編一起來看看吧。
1. TFRecord格式介紹
TFRecord文件中的數據是通過tf.train.Example Protocol Buffer的格式存儲的,下面是tf.train.Example的定義
message Example { Features features = 1; }; message Features{ map<string,Feature> featrue = 1; }; message Feature{ oneof kind{ BytesList bytes_list = 1; FloatList float_list = 2; Int64List int64_list = 3; } };
從上述代碼可以看到,ft.train.Example 的數據結構相對簡潔。tf.train.Example中包含了一個從屬性名稱到取值的字典,其中屬性名稱為一個字符串,屬性的取值可以為字符串(BytesList ),實數列表(FloatList )或整數列表(Int64List )。例如我們可以將解碼前的圖片作為字符串,圖像對應的類別標號作為整數列表。
2. 將自己的數據轉化為TFRecord格式
準備數據
在上一篇中,我們為了像偉大的MNIST致敬,所以選擇圖像的前綴來進行不同類別的分類依據,但是大多數的情況下,在進行分類任務的過程中,不同的類別都會放在不同的文件夾下,而且類別的個數往往浮動性又很大,所以針對這樣的情況,我們現在利用不同類別在不同文件夾中的圖像來生成TFRecord.
我們在Iris&Contact這個文件夾下有兩個文件夾,分別為iris,contact。對于每個文件夾中存放的是對應的圖片
轉換數據
數據準備好以后,就開始準備生成TFRecord,具體代碼如下:
import os import tensorflow as tf from PIL import Image import matplotlib.pyplot as plt cwd='/home/ruyiwei/Documents/Iris&Contact/' classes={'iris','contact'} writer= tf.python_io.TFRecordWriter("iris_contact.tfrecords") for index,name in enumerate(classes): class_path=cwd+name+'/' for img_name in os.listdir(class_path): img_path=class_path+img_name img=Image.open(img_path) img= img.resize((512,80)) img_raw=img.tobytes() #plt.imshow(img) # if you want to check you image,please delete '#' #plt.show() example = tf.train.Example(features=tf.train.Features(feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) })) writer.write(example.SerializeToString()) writer.close()
3. Tensorflow從TFRecord中讀取數據
def read_and_decode(filename): # read iris_contact.tfrecords filename_queue = tf.train.string_input_producer([filename])# create a queue reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue)#return file_name and file features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([], tf.string), })#return image and label img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [512, 80, 3]) #reshape image to 512*80*3 img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #throw img tensor label = tf.cast(features['label'], tf.int32) #throw label tensor return img, label
4. 將TFRecord中的數據保存為圖片
filename_queue = tf.train.string_input_producer(["iris_contact.tfrecords"]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) #return file and file_name features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([], tf.string), }) image = tf.decode_raw(features['img_raw'], tf.uint8) image = tf.reshape(image, [512, 80, 3]) label = tf.cast(features['label'], tf.int32) with tf.Session() as sess: init_op = tf.initialize_all_variables() sess.run(init_op) coord=tf.train.Coordinator() threads= tf.train.start_queue_runners(coord=coord) for i in range(20): example, l = sess.run([image,label])#take out image and label img=Image.fromarray(example, 'RGB') img.save(cwd+str(i)+'_''Label_'+str(l)+'.jpg')#save image print(example, l) coord.request_stop() coord.join(threads)
以上就是怎么在Tensorflow中使用tfrecord輸入數據格式,小編相信有部分知識點可能是我們日常工作會見到或用到的。希望你能通過這篇文章學到更多知識。更多詳情敬請關注億速云行業資訊頻道。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。