要實現自定義數據集類,需要繼承PyTorch中的Dataset類,并重寫其中的兩個方法:len__和__getitem。下面是一個簡單的例子,演示如何實現一個自定義數據集類:
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, targets):
self.data = data
self.targets = targets
def __len__(self):
return len(self.data)
def __getitem__(self, index):
data_point = self.data[index]
target = self.targets[index]
return data_point, target
在上面的例子中,CustomDataset類接收兩個參數data和targets作為初始化參數,分別表示數據和標簽。然后重寫了__len__方法,返回數據集的長度,重寫了__getitem__方法,根據索引index返回對應的數據點和標簽。
使用這個自定義數據集類的方法如下:
data = [...] # your data
targets = [...] # your targets
custom_dataset = CustomDataset(data, targets)
data_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=64, shuffle=True)
for data, target in data_loader:
# do something with data and target
這樣就可以通過自定義數據集類來加載自己的數據集,并使用DataLoader來批量加載數據。