您好,登錄后才能下訂單哦!
代碼如下,U我認為對于新手來說最重要的是學會rnn讀取數據的格式。
# -*- coding: utf-8 -*- """ Created on Tue Oct 9 08:53:25 2018 @author: www """ import sys sys.path.append('..') import torch import datetime from torch.autograd import Variable from torch import nn from torch.utils.data import DataLoader from torchvision import transforms as tfs from torchvision.datasets import MNIST #定義數據 data_tf = tfs.Compose([ tfs.ToTensor(), tfs.Normalize([0.5], [0.5]) ]) train_set = MNIST('E:/data', train=True, transform=data_tf, download=True) test_set = MNIST('E:/data', train=False, transform=data_tf, download=True) train_data = DataLoader(train_set, 64, True, num_workers=4) test_data = DataLoader(test_set, 128, False, num_workers=4) #定義模型 class rnn_classify(nn.Module): def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2): super(rnn_classify, self).__init__() self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用兩層lstm self.classifier = nn.Linear(hidden_feature, num_class)#將最后一個的rnn使用全連接的到最后的輸出結果 def forward(self, x): #x的大小為(batch,1,28,28),所以我們需要將其轉化為rnn的輸入格式(28,batch,28) x = x.squeeze() #去掉(batch,1,28,28)中的1,變成(batch, 28,28) x = x.permute(2, 0, 1)#將最后一維放到第一維,變成(batch,28,28) out, _ = self.rnn(x) #使用默認的隱藏狀態,得到的out是(28, batch, hidden_feature) out = out[-1,:,:]#取序列中的最后一個,大小是(batch, hidden_feature) out = self.classifier(out) #得到分類結果 return out net = rnn_classify() criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adadelta(net.parameters(), 1e-1) #定義訓練過程 def get_acc(output, label): total = output.shape[0] _, pred_label = output.max(1) num_correct = (pred_label == label).sum().item() return num_correct / total def train(net, train_data, valid_data, num_epochs, optimizer, criterion): if torch.cuda.is_available(): net = net.cuda() prev_time = datetime.datetime.now() for epoch in range(num_epochs): train_loss = 0 train_acc = 0 net = net.train() for im, label in train_data: if torch.cuda.is_available(): im = Variable(im.cuda()) # (bs, 3, h, w) label = Variable(label.cuda()) # (bs, h, w) else: im = Variable(im) label = Variable(label) # forward output = net(im) loss = criterion(output, label) # backward optimizer.zero_grad() loss.backward() optimizer.step() train_loss += loss.item() train_acc += get_acc(output, label) cur_time = datetime.datetime.now() h, remainder = divmod((cur_time - prev_time).seconds, 3600) m, s = divmod(remainder, 60) time_str = "Time %02d:%02d:%02d" % (h, m, s) if valid_data is not None: valid_loss = 0 valid_acc = 0 net = net.eval() for im, label in valid_data: if torch.cuda.is_available(): im = Variable(im.cuda()) label = Variable(label.cuda()) else: im = Variable(im) label = Variable(label) output = net(im) loss = criterion(output, label) valid_loss += loss.item() valid_acc += get_acc(output, label) epoch_str = ( "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, " % (epoch, train_loss / len(train_data), train_acc / len(train_data), valid_loss / len(valid_data), valid_acc / len(valid_data))) else: epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " % (epoch, train_loss / len(train_data), train_acc / len(train_data))) prev_time = cur_time print(epoch_str + time_str) train(net, train_data, test_data, 10, optimizer, criterion)
以上這篇pytorch 利用lstm做mnist手寫數字識別分類的實例就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。