在 PyTorch 中讀取 CSV 數據集，通常有以下幾種方法：
import pandas as pd
import torch
# Method 1: load the whole CSV into memory and convert it to a tensor in one go.
# Read the CSV file into a DataFrame.
data = pd.read_csv('data.csv')
# Convert the DataFrame to a PyTorch tensor. Request float32 explicitly: without
# a dtype, torch.tensor keeps pandas' float64/int64 column dtype, which does not
# match torch's default float32 tensors/parameters.
# NOTE: every column must be numeric; a non-numeric column produces an
# object-dtype array that torch.tensor cannot convert.
tensor_data = torch.tensor(data.to_numpy(), dtype=torch.float32)
import torch
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
    """Map-style dataset that yields one CSV row at a time as a tensor.

    Each item is the full row (every column, no feature/label split).
    """

    def __init__(self, csv_file, dtype=None):
        """Load *csv_file* and convert it to a tensor once.

        Args:
            csv_file: Path or file-like object accepted by ``pd.read_csv``.
            dtype: Optional torch dtype for the returned rows. ``None`` keeps
                the dtype inferred from the DataFrame's values (e.g. float64 or
                int64); pass ``torch.float32`` to feed rows straight to a model.

        Raises:
            TypeError: if the CSV has a non-numeric column, because torch
                cannot convert an object-dtype array. This now fails at
                construction instead of on the first item access.
        """
        self.data = pd.read_csv(csv_file)
        # Convert once up front: indexing a tensor is much cheaper than building
        # a new tensor from ``DataFrame.iloc`` on every ``__getitem__`` call.
        self._rows = torch.tensor(self.data.to_numpy(), dtype=dtype)

    def __len__(self):
        """Return the number of data rows (the header is not counted)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return row *idx* as a 1-D tensor with one entry per column."""
        return self._rows[idx]
# Build the row-wise dataset from data.csv and wrap it in a DataLoader that
# serves shuffled batches of up to 32 rows.
dataset = MyDataset(csv_file='data.csv')
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=32,
)
import torch
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that splits each CSV row into features and a label.

    All columns except the last become float32 features; the last column
    becomes an int64 (``torch.long``) label.
    """

    def __init__(self, csv_file):
        """Load *csv_file* and pre-convert it to feature and label tensors.

        Args:
            csv_file: Path or file-like object accepted by ``pd.read_csv``.
        """
        data = pd.read_csv(csv_file)
        # Convert once so that __getitem__ is just a cheap tensor index.
        # NOTE(review): assumes all feature columns are numeric and the last
        # column holds integer class ids; float labels would be cast to int64
        # and string labels would fail to convert.
        self.X = torch.tensor(data.iloc[:, :-1].to_numpy(), dtype=torch.float32)
        self.y = torch.tensor(data.iloc[:, -1].to_numpy(), dtype=torch.long)

    def __len__(self):
        """Return the number of samples (CSV data rows)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair for sample *idx*."""
        return self.X[idx], self.y[idx]
# Build the (features, label) dataset from data.csv and wrap it in a
# DataLoader that serves shuffled batches of up to 32 samples.
dataset = CustomDataset(csv_file='data.csv')
dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=32,
)
以上是幾種常用的方法：第一種一次將整個 CSV 轉成張量，適合小型數據；後兩種實作 `Dataset`，可搭配 `DataLoader` 進行分批與打亂。你可以根據自己的需求，選擇適合的方法來讀取 CSV 數據集。