在PyTorch中自定義數據集需要繼承torch.utils.data.Dataset
類,并實現以下方法:
__init__(self, ...)
:初始化方法,可以在這里加載數據或設置數據路徑等。__len__(self)
:返回數據集的大小。__getitem__(self, index)
:根據索引返回數據樣本。以下是一個例子,假設我們有一個包含圖像和標簽的數據集:
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = {
'image': self.data[index],
'label': self.labels[index]
}
return sample
# 使用自定義數據集
data = [...] # 圖像數據
labels = [...] # 圖像標簽
custom_dataset = CustomDataset(data, labels)
data_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=64, shuffle=True)
在上面的例子中,CustomDataset
類繼承了torch.utils.data.Dataset
,并實現了__init__
、__len__
和__getitem__
方法。然后我們可以通過創建一個DataLoader
對象來加載自定義數據集,以便于后續的訓練或測試。