您好,登錄后才能下訂單哦!
這篇文章給大家介紹怎么用Tensorflow完成手寫數字識別,內容非常詳細,感興趣的小伙伴們可以參考借鑒,希望對大家能有所幫助。
深度學習最經典的任務問題就是分類。通過分類,我們可以將照片中的數字,人臉,動植物等等分到它屬于的那一類當中,完成識別。接下來,我就帶著大家一起完成一個簡單的程序,來實現經典問題手寫數字識別。
數據集
我們第一步需要收集一堆手寫數據,并且將每個手寫數字都標號類別,用來做成數據集。對于深度學習而言,一般的數據集大小至少上萬起。所以收集數據這個工作還是比較繁瑣的。不過呢,有人已經幫我們弄好了數據集,這就是鼎鼎有名的MNIST數據集。
MNIST數據集是一個標準的手寫數據集,如上圖所示,數據集里面有六萬個手寫數字且都標記完全。其中有五萬個手寫數字作為訓練集,另外一萬作為測試集。
這里有一份傳送門:
http://yann.lecun.com/exdb/mnist/
我們并不需要事先下載MNIST數據集,Tensorflow幾行代碼就可以搞定:
搭建網絡
準備好了數據集之后,我們開始用Tensorflow搭建神經網絡模型:
1.輸入輸出
tf.placeholder是占位符的意思,先把坑填好,之后會有數據填充進去。其中y_是輸入對應的正確的數字標簽,x就是手寫數字照片。
2.網絡主體
我們建立了一個四層全連接網絡,每一層的網絡寬度都是400。因為MNIST數據集的數字照片都是28*28的,所以第一層網絡的權重的形狀是[784,400],注意到我們使用了Dropout技術,所以代碼中有tf.nn.dropout。對于最后一層我們用softmax技術,將對0-9數字的預測歸一化,變成一個概率。
3.損失函數和優化器
對于損失函數,我們選擇了平方差函數,其實就是線性規劃。而優化器我們選擇了Adam,是目前主流的優化器。
訓練網絡
1.初始化
我們在這里做了兩件事情,一個是初始化網絡中變量,第二個建立一個存儲器,用來存儲訓練過程的一些變量。
2.訓練
第一行的循環是控制循環的次數,我們使用了隨機梯度訓練,就是每次更新參數的時候并不是一次性把五萬張照片一起塞進去,而是從中隨機選出來作為一個batch來訓練,這樣的做的好處是可以大大減輕計算量。我們需要在每一步都在訓練集上面訓練來更新網絡的參數,接著我們一定步驟后在測試集上面看看我們的訓練效果。
3.執行程序
才開始訓練集和測試集上的準確率是在10%附近,這是因為在網絡的參數沒有更新的時候,所有參數都是隨機的,相當于我們在瞎猜。一共有十個數字,所以猜對的概率是十分之一。之后,隨著訓練的進行,訓練集和測試集上的準確率都在增加。我們同時觀察訓練集和測試集上的準確率,是防止網絡過擬合把我們欺騙了。
訓練到一定步時,我們發現訓練集的準確率已經接近百分之百了,測試集上的準確率也達到了百分之九十七以上。簡簡單單的四層就能做到如此之高的準確率,可見神經網絡之神奇!
關于怎么用Tensorflow完成手寫數字識別就分享到這里了,希望以上內容可以對大家有一定的幫助,可以學到更多知識。如果覺得文章不錯,可以把它分享出去讓更多的人看到。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。