您好,登錄后才能下訂單哦!
PyTorch中還單獨提供了一個sampler模塊,用來對數據進行采樣。常用的有隨機采樣器:RandomSampler,當dataloader的shuffle參數為True時,系統會自動調用這個采樣器,實現打亂數據。默認的是采用SequentialSampler,它會按順序一個一個進行采樣。這里介紹另外一個很有用的采樣方法: WeightedRandomSampler,它會根據每個樣本的權重選取數據,在樣本比例不均衡的問題中,可用它來進行重采樣。
構建WeightedRandomSampler時需提供兩個參數:每個樣本的權重weights、共選取的樣本總數num_samples,以及一個可選參數replacement。權重越大的樣本被選中的概率越大,待選取的樣本數目一般小于全部的樣本數目。replacement用于指定是否可以重復選取某一個樣本,默認為True,即允許在一個epoch中重復采樣某一個數據。如果設為False,則當某一類的樣本被全部選取完,但其樣本數目仍未達到num_samples時,sampler將不會再從該類中選擇數據,此時可能導致weights參數失效。
下面舉例說明。
from dataSet import * dataset = DogCat('data/dogcat/', transform=transform) from torch.utils.data import DataLoader # 狗的圖片被取出的概率是貓的概率的兩倍 # 兩類圖片被取出的概率與weights的絕對大小無關,只和比值有關 weights = [2 if label == 1 else 1 for data, label in dataset] print(weights) from torch.utils.data.sampler import WeightedRandomSampler sampler = WeightedRandomSampler(weights,\ num_samples=9,\ replacement=True) dataloader = DataLoader(dataset, batch_size=3, sampler=sampler) for datas, labels in dataloader: print(labels.tolist())
輸出:
[2, 2, 1, 1, 2, 1, 1, 2] [1, 1, 0] [1, 0, 0] [0, 0, 1]
github 地址:
https://github.com/WebLearning17/CommonTool
以上這篇pytorch sampler對數據進行采樣的實現就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。