您好,登錄后才能下訂單哦!
pytorch 官網給出的例子中都是使用了已經定義好的特殊數據集接口來加載數據,而且其使用的數據都是官方給出的數據。如果我們有自己收集的數據集,如何用來訓練網絡呢?此時需要我們自己定義好數據處理接口。幸運的是pytroch給出了一個數據集接口類(torch.utils.data.Dataset),可以方便我們繼承并實現自己的數據集接口。
torch.utils.data
torch的這個文件包含了一些關于數據集處理的類。
class torch.utils.data.Dataset: 一個抽象類, 所有其他類的數據集類都應該是它的子類。而且其子類必須重載兩個重要的函數:len(提供數據集的大小)、getitem(支持整數索引)。
class torch.utils.data.TensorDataset: 封裝成tensor的數據集,每一個樣本都通過索引張量來獲得。
class torch.utils.data.ConcatDataset: 連接不同的數據集以構成更大的新數據集。
class torch.utils.data.Subset(dataset, indices): 獲取指定一個索引序列對應的子數據集。
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 數據加載器。組合了一個數據集和采樣器,并提供關于數據的迭代器。
torch.utils.data.random_split(dataset, lengths): 按照給定的長度將數據集劃分成沒有重疊的新數據集組合。
class torch.utils.data.Sampler(data_source):所有采樣的器的基類。每個采樣器子類都需要提供 __iter__ 方法以方便迭代器進行索引 和一個 len方法 以方便返回迭代器的長度。
class torch.utils.data.SequentialSampler(data_source):順序采樣樣本,始終按照同一個順序。
class torch.utils.data.RandomSampler(data_source):無放回地隨機采樣樣本元素。
class torch.utils.data.SubsetRandomSampler(indices):無放回地按照給定的索引列表采樣樣本元素。
class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照給定的概率來采樣樣本。
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一個batch中封裝一個其他的采樣器。
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采樣器可以約束數據加載進數據集的子集。
自定義數據集
自己定義的數據集需要繼承抽象類class torch.utils.data.Dataset,并且需要重載兩個重要的函數:__len__ 和__getitem__。
整個代碼僅供參考。在__init__中是初始化了該類的一些基本參數;__getitem__中是真正讀取數據的地方,迭代器通過索引來讀取數據集中數據,因此只需要這一個方法中加入讀取數據的相關功能即可;__len__給出了整個數據集的尺寸大小,迭代器的索引范圍是根據這個函數得來的。
import torch class myDataset(torch.nn.data.Dataset): def __init__(self, dataSource) self.dataSource = dataSource def __getitem__(self, index): element = self.dataSource[index] return element def __len__(self): return len(self.dataSource) train_data = myDataset(dataSource)
自定義數據集加載器
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 數據加載器。組合了一個數據集和采樣器,并提供關于數據的迭代器。
dataset (Dataset) – 需要加載的數據集(可以是自定義或者自帶的數據集)。
batch_size – batch的大小(可選項,默認值為1)。
shuffle – 是否在每個epoch中shuffle整個數據集, 默認值為False。
sampler – 定義從數據中抽取樣本的策略. 如果指定了, shuffle參數必須為False。
num_workers – 表示讀取樣本的線程數, 0表示只有主線程。
collate_fn – 合并一個樣本列表稱為一個batch。
pin_memory – 是否在返回數據之前將張量拷貝到CUDA。
drop_last (bool, optional) – 設置是否丟棄最后一個不完整的batch,默認為False。
timeout – 用來設置數據讀取的超時時間的,但超過這個時間還沒讀取到數據的話就會報錯。應該為非負整數。
train_loader=torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
以上這篇pytorch 自定義數據集加載方法就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。