91超碰碰碰碰久久久久久综合_超碰av人澡人澡人澡人澡人掠_国产黄大片在线观看画质优化_txt小说免费全本

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

Pytorch如何實現數據加載與數據預處理

發布時間:2021-08-18 15:31:46 來源:億速云 閱讀:576 作者:小新 欄目:開發技術

這篇文章給大家分享的是有關Pytorch如何實現數據加載與數據預處理的內容。小編覺得挺實用的,因此分享給大家做個參考,一起跟隨小編過來看看吧。

數據加載分為加載torchvision.datasets中的數據集以及加載自己使用的數據集兩種情況。

torchvision.datasets中的數據集

torchvision.datasets中自帶MNIST,Imagenet-12,CIFAR等數據集,所有的數據集都是torch.utils.data.Dataset的子類,都包含 _ _ len _ (獲取數據集長度)和 _ getItem _ _ (獲取數據集中每一項)兩個子方法。

Pytorch如何實現數據加載與數據預處理

Dataset源碼如上,可以看到其中包含了兩個沒有實現的子方法,之后所有的Dataet類都繼承該類,并根據數據情況定制這兩個子方法的具體實現。

因此當我們需要加載自己的數據集的時候也可以借鑒這種方法,只需要繼承torch.utils.data.Dataset類并重寫 init ,len,以及getitem這三個方法即可。這樣組著的類可以直接作為參數傳入到torch.util.data.DataLoader中去。

以CIFAR10為例 源碼:

class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

加載自己的數據集

對于torchvision.datasets中有兩個不同的類,分別為DatasetFolder和ImageFolder,ImageFolder是繼承自DatasetFolder。

下面我們通過源碼來看一看folder文件中DatasetFolder和ImageFolder分別做了些什么

import torch.utils.data as data
from PIL import Image
import os
import os.path


def has_file_allowed_extension(filename, extensions): //檢查輸入是否是規定的擴展名
  """Checks if a file is an allowed extension.

  Args:
    filename (string): path to a file

  Returns:
    bool: True if the filename ends with a known image extension
  """
  filename_lower = filename.lower()
  return any(filename_lower.endswith(ext) for ext in extensions)


def find_classes(dir):
  classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] //獲取root目錄下所有的文件夾名稱

  classes.sort()
  class_to_idx = {classes[i]: i for i in range(len(classes))} //生成類別名稱與類別id的對應Dictionary
  return classes, class_to_idx


def make_dataset(dir, class_to_idx, extensions):
  images = []
  dir = os.path.expanduser(dir)// 將~和~user轉化為用戶目錄,對參數中出現~進行處理
  for target in sorted(os.listdir(dir)):
    d = os.path.join(dir, target)
    if not os.path.isdir(d):
      continue

    for root, _, fnames in sorted(os.walk(d)): //os.work包含三個部分,root代表該目錄路徑 _代表該路徑下的文件夾名稱集合,fnames代表該路徑下的文件名稱集合
      for fname in sorted(fnames):
        if has_file_allowed_extension(fname, extensions):
          path = os.path.join(root, fname)
          item = (path, class_to_idx[target])
          images.append(item)  //生成(訓練樣本圖像目錄,訓練樣本所屬類別)的元組

  return images  //返回上述元組的列表


class DatasetFolder(data.Dataset):
  """A generic data loader where the samples are arranged in this way: ::

    root/class_x/xxx.ext
    root/class_x/xxy.ext
    root/class_x/xxz.ext

    root/class_y/123.ext
    root/class_y/nsdf3.ext
    root/class_y/asd932_.ext

  Args:
    root (string): Root directory path.
    loader (callable): A function to load a sample given its path.
    extensions (list[string]): A list of allowed extensions.
    transform (callable, optional): A function/transform that takes in
      a sample and returns a transformed version.
      E.g, ``transforms.RandomCrop`` for images.
    target_transform (callable, optional): A function/transform that takes
      in the target and transforms it.

   Attributes:
    classes (list): List of the class names.
    class_to_idx (dict): Dict with items (class_name, class_index).
    samples (list): List of (sample path, class_index) tuples
  """

  def __init__(self, root, loader, extensions, transform=None, target_transform=None):
    classes, class_to_idx = find_classes(root)
    samples = make_dataset(root, class_to_idx, extensions)
    if len(samples) == 0:
      raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                "Supported extensions are: " + ",".join(extensions)))

    self.root = root
    self.loader = loader
    self.extensions = extensions

    self.classes = classes
    self.class_to_idx = class_to_idx
    self.samples = samples

    self.transform = transform
    self.target_transform = target_transform

  def __getitem__(self, index):
    """
    根據index獲取sample 返回值為(sample,target)元組,同時如果該類輸入參數中有transform和target_transform,torchvision.transforms類型的參數時,將獲取的元組分別執行transform和target_transform中的數據轉換方法。
       Args:
      index (int): Index

    Returns:
      tuple: (sample, target) where target is class_index of the target class.
    """
    path, target = self.samples[index]
    sample = self.loader(path)
    if self.transform is not None:
      sample = self.transform(sample)
    if self.target_transform is not None:
      target = self.target_transform(target)

    return sample, target


  def __len__(self):
    return len(self.samples)

  def __repr__(self): //定義輸出對象格式 其中和__str__的區別是__repr__無論是print輸出還是直接輸出對象自身 都是以定義的格式進行輸出,而__str__ 只有在print輸出的時候會是以定義的格式進行輸出
    fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
    fmt_str += '  Number of datapoints: {}\n'.format(self.__len__())
    fmt_str += '  Root Location: {}\n'.format(self.root)
    tmp = '  Transforms (if any): '
    fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
    tmp = '  Target Transforms (if any): '
    fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
    return fmt_str



IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']


def pil_loader(path):
  # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
  with open(path, 'rb') as f:
    img = Image.open(f)
    return img.convert('RGB')


def accimage_loader(path):
  import accimage
  try:
    return accimage.Image(path)
  except IOError:
    # Potentially a decoding problem, fall back to PIL.Image
    return pil_loader(path)


def default_loader(path):
  from torchvision import get_image_backend
  if get_image_backend() == 'accimage':
    return accimage_loader(path)
  else:
    return pil_loader(path)


class ImageFolder(DatasetFolder): 
  """A generic data loader where the images are arranged in this way: ::

    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/xxz.png

    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/asd932_.png

  Args:
    root (string): Root directory path.
    transform (callable, optional): A function/transform that takes in an PIL image
      and returns a transformed version. E.g, ``transforms.RandomCrop``
    target_transform (callable, optional): A function/transform that takes in the
      target and transforms it.
    loader (callable, optional): A function to load an image given its path.

   Attributes:
    classes (list): List of the class names.
    class_to_idx (dict): Dict with items (class_name, class_index).
    imgs (list): List of (image path, class_index) tuples
  """
  def __init__(self, root, transform=None, target_transform=None,
         loader=default_loader):
    super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
                     transform=transform,
                     target_transform=target_transform)
    self.imgs = self.samples

如果自己所要加載的數據組織形式如下

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

即不同類別的訓練數據分別存儲在不同的文件夾中,這些文件夾都在root(即形如 D:/animals 或者 /usr/animals )路徑下

class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>)

參數如下:

root (string) – Root directory path.
transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
loader – A function to load an image given its path. 就是上述源碼中


__getitem__(index)
Parameters: index (int) – Index
Returns:  (sample, target) where target is class_index of the target class.
Return type:  tuple

可以通過torchvision.datasets.ImageFolder進行加載

img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
                      transform=transforms.Compose([
                        transforms.Scale(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor()])
                      )
print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader))

對于所有的訓練樣本都在一個文件夾中 同時有一個對應的txt文件每一行分別是對應圖像的路徑以及其所屬的類別,可以參照上述class寫出對應的加載類

def default_loader(path):
  return Image.open(path).convert('RGB')


class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader))

DataLoader解析

位于torch.util.data.DataLoader中 源代碼

該接口的主要目的是將pytorch中已有的數據接口如torchvision.datasets.ImageFolder,或者自定義的數據讀取接口轉化按照

batch_size的大小封裝為Tensor,即相當于在內置數據接口或者自定義數據接口的基礎上增加一維,大小為batch_size的大小,

得到的數據在之后可以通過封裝為Variable,作為模型的輸出

_ _ init _ _中所需的參數如下

1. dataset torch.utils.data.Dataset類的子類,可以是torchvision.datasets.ImageFolder等內置類,也可是繼承了torch.utils.data.Dataset的自定義類
2. batch_size 每一個batch中包含的樣本個數,默認是1 
3. shuffle 一般在訓練集中采用,默認是false,設置為true則每一個epoch都會將訓練樣本打亂
4. sampler 訓練樣本選取策略,和shuffle是互斥的 如果 shuffle為true,該參數一定要為None
5. batch_sampler BatchSampler 一次產生一個 batch 的 indices,和sampler以及shuffle互斥,一般使用默認的即可
  上述Sampler的源代碼地址如下[源代碼](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py)
6. num_workers 用于數據加載的線程數量 默認為0 即只有主線程用來加載數據
7. collate_fn 用來聚合數據生成mini_batch

使用的時候一般為如下使用方法:

train_data=torch.utils.data.DataLoader(...) 
for i, (input, target) in enumerate(train_data): 
...

循環取DataLoader中的數據會觸發類中_ _ iter __方法,查看源代碼可知 其中調用的方法為 return _DataLoaderIter(self),因此需要查看 DataLoaderIter 這一內部類

class DataLoaderIter(object):
  "Iterates once over the DataLoader's dataset, as specified by the sampler"

  def __init__(self, loader):
    self.dataset = loader.dataset
    self.collate_fn = loader.collate_fn
    self.batch_sampler = loader.batch_sampler
    self.num_workers = loader.num_workers
    self.pin_memory = loader.pin_memory and torch.cuda.is_available()
    self.timeout = loader.timeout
    self.done_event = threading.Event()

    self.sample_iter = iter(self.batch_sampler)

    if self.num_workers > 0:
      self.worker_init_fn = loader.worker_init_fn
      self.index_queue = multiprocessing.SimpleQueue()
      self.worker_result_queue = multiprocessing.SimpleQueue()
      self.batches_outstanding = 0
      self.worker_pids_set = False
      self.shutdown = False
      self.send_idx = 0
      self.rcvd_idx = 0
      self.reorder_dict = {}

      base_seed = torch.LongTensor(1).random_()[0]
      self.workers = [
        multiprocessing.Process(
          target=_worker_loop,
          args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
             base_seed + i, self.worker_init_fn, i))
        for i in range(self.num_workers)]

      if self.pin_memory or self.timeout > 0:
        self.data_queue = queue.Queue()
        self.worker_manager_thread = threading.Thread(
          target=_worker_manager_loop,
          args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
             torch.cuda.current_device()))
        self.worker_manager_thread.daemon = True
        self.worker_manager_thread.start()
      else:
        self.data_queue = self.worker_result_queue

      for w in self.workers:
        w.daemon = True # ensure that the worker exits on process exit
        w.start()

      _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
      _set_SIGCHLD_handler()
      self.worker_pids_set = True

      # prime the prefetch loop
      for _ in range(2 * self.num_workers):
        self._put_indices()

感謝各位的閱讀!關于“Pytorch如何實現數據加載與數據預處理”這篇文章就分享到這里了,希望以上內容可以對大家有一定的幫助,讓大家可以學到更多知識,如果覺得文章不錯,可以把它分享出去讓更多的人看到吧!

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

潼南县| 奈曼旗| 万州区| 石河子市| 湘乡市| 福安市| 永福县| 辉南县| 永兴县| 鹤峰县| 天津市| 益阳市| 辉南县| 富平县| 东台市| 永济市| 阳西县| 历史| 日喀则市| 平利县| 丰城市| 大方县| 阳谷县| 晋中市| 八宿县| 侯马市| 松原市| 济源市| 元氏县| 屯门区| 滦平县| 阿巴嘎旗| 徐闻县| 江川县| 巴彦淖尔市| 长沙市| 凯里市| 古田县| 汨罗市| 朝阳县| 清河县|