在 PyTorch 中,有幾種常見的方法可以導入數據集:
torchvision.datasets
模塊導入常見的計算機視覺數據集,例如 CIFAR10、MNIST 等。可以使用 torchvision.datasets.CIFAR10
、torchvision.datasets.MNIST
等類來實例化數據集對象。示例代碼:
import torchvision.datasets as datasets
# 導入 CIFAR10 數據集
train_dataset = datasets.CIFAR10(root='path/to/dataset', train=True, transform=None, download=True)
# 導入 MNIST 數據集
test_dataset = datasets.MNIST(root='path/to/dataset', train=False, transform=None, download=True)
torch.utils.data.Dataset
,并實現 __len__
和 __getitem__
方法。這允許您以自定義方式加載和處理數據集。示例代碼:
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, labels, transform=None):
self.data = data
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x = self.data[index]
y = self.labels[index]
if self.transform:
x = self.transform(x)
return x, y
# 使用自定義數據集
my_dataset = MyDataset(data, labels, transform=None)
torch.utils.data.DataLoader
類將數據集包裝成可迭代的數據加載器。數據加載器可以用于批量加載數據、多線程加載數據等。示例代碼:
from torch.utils.data import DataLoader
# 創建數據加載器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
以上是 PyTorch 中導入數據集的幾種常見方法。具體的選擇取決于數據集的類型和需求。