91超碰碰碰碰久久久久久综合_超碰av人澡人澡人澡人澡人掠_国产黄大片在线观看画质优化_txt小说免费全本

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

pytorch 利用lstm做mnist手寫數字識別分類的實例

發布時間:2020-08-29 04:15:00 來源:腳本之家 閱讀:496 作者:xckkcxxck 欄目:開發技術

代碼如下,U我認為對于新手來說最重要的是學會rnn讀取數據的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定義數據
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定義模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用兩層lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#將最后一個的rnn使用全連接的到最后的輸出結果
     
   def forward(self, x):
     #x的大小為(batch,1,28,28),所以我們需要將其轉化為rnn的輸入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,變成(batch, 28,28)
     x = x.permute(2, 0, 1)#將最后一維放到第一維,變成(batch,28,28)
     out, _ = self.rnn(x) #使用默認的隱藏狀態,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一個,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分類結果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定義訓練過程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)    

以上這篇pytorch 利用lstm做mnist手寫數字識別分類的實例就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

高陵县| 丹凤县| 日照市| 讷河市| 海南省| 河西区| 古交市| 漳州市| 平潭县| 凤台县| 泗阳县| 克什克腾旗| 社会| 申扎县| 江华| 施秉县| 三门县| 屏东县| 武汉市| 永宁县| 阜南县| 东兰县| 来安县| 海丰县| 崇信县| 宝清县| 赣榆县| 九龙坡区| 千阳县| 肇源县| 建阳市| 汕头市| 纳雍县| 贵南县| 阿克| 双江| 祁连县| 砚山县| 綦江县| 台前县| 武平县|