您好,登錄后才能下訂單哦!
數據增強
卷積神經網絡非常容易出現過擬合的問題,而數據增強的方法是對抗過擬合問題的一個重要方法。
2012 年 AlexNet 在 ImageNet 上大獲全勝,圖片增強方法功不可沒,因為有了圖片增強,使得訓練的數據集比實際數據集多了很多'新'樣本,減少了過擬合的問題,下面我們來具體解釋一下。
常用的數據增強方法
常用的數據增強方法如下:
1.對圖片進行一定比例縮放
2.對圖片進行隨機位置的截取
3.對圖片進行隨機的水平和豎直翻轉
4.對圖片進行隨機角度的旋轉
5.對圖片進行亮度、對比度和顏色的隨機變化
這些方法 pytorch 都已經為我們內置在了 torchvision 里面,我們在安裝 pytorch 的時候也安裝了 torchvision,下面我們來依次展示一下這些數據增強方法。
import sys sys.path.append('..') from PIL import Image from torchvision import transforms as tfs # 讀入一張圖片 im = Image.open('./cat.png') im
隨機比例放縮
隨機比例縮放主要使用的是 torchvision.transforms.Resize()
這個函數,第一個參數可以是一個整數,那么圖片會保存現在的寬和高的比例,并將更短的邊縮放到這個整數的大小,第一個參數也可以是一個 tuple,那么圖片會直接把寬和高縮放到這個大小;第二個參數表示放縮圖片使用的方法,比如最鄰近法,或者雙線性差值等,一般雙線性差值能夠保留圖片更多的信息,所以 pytorch 默認使用的是雙線性差值,你可以手動去改這個參數,更多的信息可以看看文檔
# 比例縮放 print('before scale, shape: {}'.format(im.size)) new_im = tfs.Resize((100, 200))(im) print('after scale, shape: {}'.format(new_im.size)) new_im
隨機位置截取
隨機位置截取能夠提取出圖片中局部的信息,使得網絡接受的輸入具有多尺度的特征,所以能夠有較好的效果。在 torchvision 中主要有下面兩種方式,一個是 torchvision.transforms.RandomCrop()
,傳入的參數就是截取出的圖片的長和寬,對圖片在隨機位置進行截取;第二個是 torchvision.transforms.CenterCrop()
,同樣傳入介曲初的圖片的大小作為參數,會在圖片的中心進行截取
# 隨機裁剪出 100 x 100 的區域 random_im1 = tfs.RandomCrop(100)(im) random_im1
# 中心裁剪出 100 x 100 的區域 center_im = tfs.CenterCrop(100)(im) center_im
隨機的水平和豎直方向翻轉
對于上面這一張貓的圖片,如果我們將它翻轉一下,它仍然是一張貓,但是圖片就有了更多的多樣性,所以隨機翻轉也是一種非常有效的手段。在 torchvision 中,隨機翻轉使用的是 torchvision.transforms.RandomHorizontalFlip()
和 torchvision.transforms.RandomVerticalFlip()
# 隨機水平翻轉 h_filp = tfs.RandomHorizontalFlip()(im) h_filp
# 隨機豎直翻轉 v_flip = tfs.RandomVerticalFlip()(im) v_flip
隨機角度旋轉
一些角度的旋轉仍然是非常有用的數據增強方式,在 torchvision 中,使用 torchvision.transforms.RandomRotation()
來實現,其中第一個參數就是隨機旋轉的角度,比如填入 10,那么每次圖片就會在 -10 ~ 10 度之間隨機旋轉
rot_im = tfs.RandomRotation(45)(im) rot_im
亮度、對比度和顏色的變化
除了形狀變化外,顏色變化又是另外一種增強方式,其中可以設置亮度變化,對比度變化和顏色變化等,在 torchvision 中主要使用 torchvision.transforms.ColorJitter() 來實現的,第一個參數就是亮度的比例,第二個是對比度,第三個是飽和度,第四個是顏色
# 亮度 bright_im = tfs.ColorJitter(brightness=1)(im) # 隨機從 0 ~ 2 之間亮度變化,1 表示原圖 bright_im
# 對比度 contrast_im = tfs.ColorJitter(contrast=1)(im) # 隨機從 0 ~ 2 之間對比度變化,1 表示原圖 contrast_im
# 顏色 color_im = tfs.ColorJitter(hue=0.5)(im) # 隨機從 -0.5 ~ 0.5 之間對顏色變化 color_im
上面我們講了這么圖片增強的方法,其實這些方法都不是孤立起來用的,可以聯合起來用,比如先做隨機翻轉,然后隨機截取,再做對比度增強等等,torchvision 里面有個非常方便的函數能夠將這些變化合起來,就是 torchvision.transforms.Compose(),下面我們舉個例子
im_aug = tfs.Compose([ tfs.Resize(120), tfs.RandomHorizontalFlip(), tfs.RandomCrop(96), tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5) ])
import matplotlib.pyplot as plt %matplotlib inline nrows = 3 ncols = 3 figsize = (8, 8) _, figs = plt.subplots(nrows, ncols, figsize=figsize) for i in range(nrows): for j in range(ncols): figs[i][j].imshow(im_aug(im)) figs[i][j].axes.get_xaxis().set_visible(False) figs[i][j].axes.get_yaxis().set_visible(False) plt.show()
可以看到每次做完增強之后的圖片都有一些變化,所以這就是我們前面講的,增加了一些'新'數據
下面我們使用圖像增強進行訓練網絡,看看具體的提升究竟在什么地方,使用 ResNet 進行訓練
使用數據增強
import numpy as np import torch from torch import nn import torch.nn.functional as F from torch.autograd import Variable from torchvision.datasets import CIFAR10 from utils import train, resnet from torchvision import transforms as tfs # 使用數據增強 def train_tf(x): im_aug = tfs.Compose([ tfs.Resize(120), tfs.RandomHorizontalFlip(), tfs.RandomCrop(96), tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5), tfs.ToTensor(), tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) x = im_aug(x) return x def test_tf(x): im_aug = tfs.Compose([ tfs.Resize(96), tfs.ToTensor(), tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) x = im_aug(x) return x train_set = CIFAR10('./data', train=True, transform=train_tf) train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True) test_set = CIFAR10('./data', train=False, transform=test_tf) test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False) net = resnet(3, 10) optimizer = torch.optim.SGD(net.parameters(), lr=0.01) criterion = nn.CrossEntropyLoss() train(net, train_data, test_data, 10, optimizer, criterion)
不使用數據增強
# 不使用數據增強 def data_tf(x): im_aug = tfs.Compose([ tfs.Resize(96), tfs.ToTensor(), tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) x = im_aug(x) return x train_set = CIFAR10('./data', train=True, transform=data_tf) train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True) test_set = CIFAR10('./data', train=False, transform=data_tf) test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False) net = resnet(3, 10) optimizer = torch.optim.SGD(net.parameters(), lr=0.01) criterion = nn.CrossEntropyLoss() train(net, train_data, test_data, 10, optimizer, criterion)
從上面可以看出,對于訓練集,不做數據增強跑 10 次,準確率已經到了 95%,而使用了數據增強,跑 10 次準確率只有 75%,說明數據增強之后變得更難了。
而對于測試集,使用數據增強進行訓練的時候,準確率會比不使用更高,因為數據增強提高了模型應對于更多的不同數據集的泛化能力,所以有更好的效果。
以上就是深度學習入門之Pytorch 數據增強的實現的詳細內容,更多關于Pytorch 數據增強的資料請關注億速云其它相關文章!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。