您好,登錄后才能下訂單哦!
小編給大家分享一下關于圖片標準化函數per_image_standardization的用法簡介,希望大家閱讀完這篇文章后大所收獲,下面讓我們一起去探討方法吧!
實驗環境:windows 7,anaconda 3(Python 3.5),tensorflow(gpu/cpu)
函數介紹:標準化處理可以使得不同的特征具有相同的尺度(Scale)。
這樣,在使用梯度下降法學習參數的時候,不同特征對參數的影響程度就一樣了。
tf.image.per_image_standardization(image),此函數的運算過程是將整幅圖片標準化(不是歸一化),加速神經網絡的訓練。
主要有如下操作,(x - mean) / adjusted_stddev,其中x為圖片的RGB三通道像素值,mean分別為三通道像素的均值,adjusted_stddev = max(stddev, 1.0/sqrt(image.NumElements()))。
stddev為三通道像素的標準差,image.NumElements()計算的是三通道各自的像素個數。
實驗代碼:
import tensorflow as tf import matplotlib.image as img import matplotlib.pyplot as plt import numpy as np sess = tf.InteractiveSession() image = img.imread('D:/Documents/Pictures/logo7.jpg') shape = tf.shape(image).eval() h,w = shape[0],shape[1] standardization_image = tf.image.per_image_standardization(image)#標準化 fig = plt.figure() fig1 = plt.figure() ax = fig.add_subplot(111) ax.set_title('orginal image') ax.imshow(image) ax1 = fig1.add_subplot(311) ax1.set_title('original hist') ax1.hist(sess.run(tf.reshape(image,[h*w,-1]))) ax1 = fig1.add_subplot(313) ax1.set_title('standardization hist') ax1.hist(sess.run(tf.reshape(standardization_image,[h*w,-1]))) plt.ion() plt.show()
實驗結果:
兩幅hist圖分別是原圖和標準化后的RGB的像素值分布圖,可以看到只是將圖片的像素值大小限定到一個范圍,但是像素值的分布為改變。
補充知識:tensorflow運行單張圖像與加載模型時注意的問題
關于模型的保存加載:
在做實驗的情況下,一般使用save函數與restore函數就足夠用,該剛發只加載模型的參數而不加載模型,這意味著
當前的程序要能找到模型的結構
saver = tf.train.Saver()#聲明saver用來保存模型 with tf.Session() as sess: for i in range(train_step): #.....訓練操作 if i%100 == 0 && i!= 0:#每間隔訓練100次存儲一個模型,默認最多能存5個,如果超過5個先將序號小的覆蓋掉 saver.save(sess,str(i)+"_"+'model.ckpt',global_step=i)
得到的文件如下:
在一個文件夾中,會有一個checkpoint文件,以及一系列不同訓練階段的模型文件,如下圖
ckeckpoint文件可以放在編輯器里面打開看,里面記錄的是每個階段保存模型的信息,同時也是記錄最近訓練的檢查點
ckpt文件是模型參數,index文件一般用不到(我也查到是啥-_-|||)
在讀取模型時,聲明一個saver調用restore函數即可,我看很多博客里面寫的都是添加最近檢查點的模型,這樣添加的模型都是最后一次訓練的結果,想要加載固定的模型,直接把模型參數名稱的字符串寫到參數里就行了,如下段程序
saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess, "step_1497batch_64model.ckpt-1497")#加載對應的參數
這樣就把參數加載到Session當中,如果有數據,就可以直接塞進來進行計算了
運行單張圖片:
運行單張圖像的方法的流程大致如下,首先使用opencv或者Image或者使用numpy將圖像讀進來,保存成numpy的array的格式
接下來可以對圖像使用opencv進行預處理。然后將處理后的array使用feed_dict的方式輸入到tensorflow的placeholder中,這里注意兩點,不要單獨的使用下面的方法將tensor轉換成numpy再進行處理,除非是想查看一下圖像輸出,否則在驗證階段,強烈不要求這樣做,盡量使用feed_dict,原因后面說明
numpy_img = sess.run(tensor_img)#將tensor轉換成numpy
這里注意一點,如果你的圖像是1通道的圖像,即灰度圖,那么你得到的numpy是一個二維矩陣,將使用opencv讀入的圖像輸出shape會得到如(424,512)這樣的形狀,分別表示行和列,但是在模型當中通常要要有batch和通道數,所以需要將圖像使用python opencv庫中的reshape函數轉換成四維的矩陣,如
cv_img = cv_img.reshape(1,cv_img.shape[0],cv_img.shape[1],1)#cv_img是使用Opencv讀進來的圖片
用來輸入到網絡中的placeholder設置為如下,即可進行輸入了
img_raw = tf.placeholder(dtype=tf.float32, shape=[1,512, 424, 1], name='input')
測試:
如果使用的是自己的數據集,通常是制作成tfrecords,在訓練和測試的過程中,需要讀取tfrecords文件,這里注意,千萬不要把讀取tfrecords文件的函數放到循環當中,而是把這個文件放到外面,否則你訓練或者測試的數據都是同一批,Loss會固定在一個值!
這是因為tfrecords在讀取的過程中是將圖像信息加入到一個隊列中進行讀取,不要當成普通的函數調用,要按照tensorflow的思路,將它看成一個節點!
def read_data(tfrecords_file, batch_size, image_size):#讀取tfrecords文件 filename_queue = tf.train.string_input_producer([tfrecords_file]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) img_features = tf.parse_single_example( serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'image_raw': tf.FixedLenFeature([], tf.string), }) image = tf.decode_raw(img_features['image_raw'], tf.float32) min_after_dequeue = 1000 image = tf.reshape(image, [image_size, image_size,1]) image = tf.image.resize_images(image, (32,32),method=3)#縮放成32×32 image = tf.image.per_image_standardization(image)#圖像標準化 label = tf.cast(img_features['label'], tf.int32) capacity = min_after_dequeue + 3 * batch_size image_batch, label_batch = tf.train.shuffle_batch([image, label], min_after_dequeue = min_after_dequeue) return image_batch, tf.one_hot(label_batch,6)#返回的標簽經過one_hot編碼 #將得到的圖像數據與標簽都是tensor哦,不能輸出的! read_image_batch,read_label_batch = read_data('train_data\\tfrecord\\TrainC6_95972.tfrecords',batch_size,120)
回到在運行單張圖片的那個問題,直接對某個tensor進行sess.run()會得到圖計算后的類型,也就是咱們python中常見的類型。
使用sess.run(feed_dict={…})得到的計算結果和直接使用sess.run有什么不同呢?
可以使用一個循環實驗一下,在循環中不停的調用sess.run()相當于每次都向圖中添加節點,而使用sess.run(feed_dict={})是向圖中開始的位置添加數據!
結果會發現,直接使用sess.run()的運行會越來越慢,使用sess.run(feed_dict={})會運行的飛快!
為什么要提這個呢?
在上面的read_data中有這么三行函數
image = tf.reshape(image, [image_size, image_size,1])#與opencv的reshape結果一樣 image = tf.image.resize_images(image, (32,32),method=3)#縮放成32×32,與opencv的resize結果一樣,插值方法要選擇三次立方插值 image = tf.image.per_image_standardization(image)#圖像標準化
如果想要在將訓練好的模型作為網絡節點添加到系統中,得到的數據必須是經過與訓練數據經過相同處理的圖像,也就是必須要對原始圖像經過上面的處理。如果使用其他的庫容易造成結果對不上,最好使用與訓練數據處理時相同的函數。
如果使用將上面的函數當成普通的函數使用,得到的是一個tensor,沒有辦法進行其他的圖像預處理,需要先將tensor變成numpy類型,問題來了,想要變成numpy類型,就得調用sess.run(),如果模型作為接口死循環,那么就會一直使用sess.run,效率會越來越慢,最后卡死!
原因在于你沒有將tensorflow中的函數當成節點調用,而是將其當成普通的函數調用了!
解決辦法就是按部就班的來,將得到的numpy數據先提前處理好,然后使用sess.run(feed_dict)輸入到placeholder中,按照圖的順序一步一步運行即可!
如下面程序
with tf.name_scope('inputs'): img_raw = tf.placeholder(dtype=tf.float32, shape=[1,120, 120, 1], name='input')#輸入數據 keep_prob = tf.placeholder(tf.float32,name='keep_prob') with tf.name_scope('preprocess'):#圖中的預處理函數,當成節點順序調用 img_120 = tf.reshape(img_raw, [120, 120,1]) img_norm = tf.cast(img_120, "float32") / 256 img_32 = tf.image.resize_images(img_norm, (32,32),method=3) img_std = tf.image.per_image_standardization(img_32) img = tf.reshape(img_std, [1,32, 32,1]) with tf.name_scope('output'):#圖像塞到網絡中 output = MyNet(img,keep_prob,n_cls) ans = tf.argmax(tf.nn.softmax(output),1)#計算模型得到的結果 init = tf.global_variables_initializer() saver = tf.train.Saver() if __name__ == '__main__': with tf.Session() as sess: sess.run(init) saver.restore(sess, "step_1497batch_64model.ckpt-1497")#效果更好 index = 0 path = "buffer\\" while True: f = path + str(index)+'.jpg'#從0.jpg、1.jpg、2.jpg.....一直讀 if os.path.exists(f): cv_img = cv.imread(f,0) cv_img = OneImgPrepro(cv_img) cv_img = cv_img.reshape(1,cv_img.shape[0],cv_img.shape[1],1)#需要reshape成placeholder可接收型 clas = ans.eval(feed_dict={img_raw:cv_img,keep_prob:1})#feed的速度快! print(clas)#輸出分類 index += 1
看完了這篇文章,相信你對關于圖片標準化函數per_image_standardization的用法簡介有了一定的了解,想了解更多相關知識,歡迎關注億速云行業資訊頻道,感謝各位的閱讀!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。