要在PyTorch中制作自己的數據集,你需要創建一個繼承自torch.utils.data.Dataset
的自定義數據集類。這個類需要實現__len__
和__getitem__
方法。
下面是一個簡單的例子,展示了如何創建一個自定義數據集類:
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, targets):
self.data = data
self.targets = targets
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x = self.data[index]
y = self.targets[index]
return x, y
在這個例子中,CustomDataset
類接受兩個參數data
和targets
,分別代表數據和對應的標簽。__len__
方法返回數據集的長度,__getitem__
方法根據給定的索引返回對應的數據和標簽。
接下來,你可以實例化這個自定義數據集類并將其用于創建一個DataLoader
對象,從而可以方便地迭代數據集進行訓練或測試:
data = [...] # your data
targets = [...] # your targets
custom_dataset = CustomDataset(data, targets)
dataloader = torch.utils.data.DataLoader(custom_dataset, batch_size=64, shuffle=True)
現在你可以使用dataloader
來迭代自定義數據集進行訓練。