您好,登錄后才能下訂單哦!
本篇內容主要講解“Pytorch的使用技巧有哪些”,感興趣的朋友不妨來看看。本文介紹的方法操作簡單快捷,實用性強。下面就讓小編來帶大家學習“Pytorch的使用技巧有哪些”吧!
訓練模型,最常看的指標就是 Loss。我們可以根據 Loss 的收斂情況,初步判斷模型訓練的好壞。
如果,Loss 值突然上升了,那說明訓練有問題,需要檢查數據和代碼。
如果,Loss 值趨于穩定,那說明訓練完畢了。
觀察 Loss 情況,最直觀的方法,就是繪制 Loss 曲線圖。
通過繪圖,我們可以很清晰的看到,左圖還有收斂空間,而右圖已經完全收斂。
通過 Loss 曲線,我們可以分析模型訓練的好壞,模型是否訓練完成,起到一個很好的“監控”作用。
繪制 Loss 曲線圖,第一步就是需要保存訓練過程中的 Loss 值。
一個最簡單的方法是使用,sys.stdout 標準輸出重定向,簡單好用,實乃“煉丹”必備“良寶”。
import os import sys class Logger(): def __init__(self, filename="log.txt"): self.terminal = sys.stdout self.log = open(filename, "w") def write(self, message): self.terminal.write(message) self.log.write(message) def flush(self): pass sys.stdout = Logger() print("Jack Cui") print("https://cuijiahua.com") print("https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA")
代碼很簡單,創建一個 log.py 文件,自己寫一個 Logger 類,并采用 sys.stdout 重定向輸出。
在 Terminal 中,不僅可以使用 print 打印結果,同時也會將結果保存到 log.txt 文件中。
運行 log.py,打印 print 內容的同時,也將內容寫入了 log.txt 文件中。
使用這個代碼,就可以在打印 Loss 的同時,將結果保存到指定的 txt 中,比如保存上篇文章訓練 UNet 的 Loss。
Matplotlib 是一個 Python 的繪圖庫,簡單好用。
簡單幾行命令,就可以繪制曲線圖、散點圖、條形圖、直方圖、餅圖等等。
在深度學習中,一般就是繪制曲線圖,比如 Loss 曲線、Acc 曲線。
舉一個,簡單的例子。
使用 sys.stdout 保存的 train_loss.txt,繪制 Loss 曲線。
train_loss.txt 下載地址:點擊查看
思路非常簡單,讀取 txt 內容,解析 txt 內容,使用 Matplotlib 繪制曲線。
import matplotlib.pyplot as plt # Jupyter notebook 中開啟 # %matplotlib inline with open('train_loss.txt', 'r') as f: train_loss = f.readlines() train_loss = list(map(lambda x:float(x.strip()), train_loss)) x = range(len(train_loss)) y = train_loss plt.plot(x, y, label='train loss', linewidth=2, color='r', marker='o', markerfacecolor='r', markersize=5) plt.xlabel('Epoch') plt.ylabel('Loss Value') plt.legend() plt.show()
指定 x 和 y 對應的值,就可以繪制。
是不是很簡單?
說到保存日志,那不得不提 Python 的內置標準模塊 Logging,它主要用于輸出運行日志,可以設置輸出日志的等級、日志保存路徑、日志文件回滾等,同時,我們也可以設置日志的輸出格式。
import logging def get_logger(LEVEL, log_file = None): head = '[%(asctime)-15s] [%(levelname)s] %(message)s' if LEVEL == 'info': logging.basicConfig(level=logging.INFO, format=head) elif LEVEL == 'debug': logging.basicConfig(level=logging.DEBUG, format=head) logger = logging.getLogger() if log_file != None: fh = logging.FileHandler(log_file) logger.addHandler(fh) return logger logger = get_logger('info') logger.info('Jack Cui') logger.info('https://cuijiahua.com') logger.info('https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA')
只需要幾行代碼,進行一個簡單的封裝使用。使用函數 get_logger 創建一個級別為 info 的 logger,如果指定 log_file,則會對日志進行保存。
logging 默認支持的日志一共有 5 個等級:
日志級別等級 CRITICAL > ERROR > WARNING > INFO > DEBUG。
默認的日志級別設置為 WARNING,也就是說如果不指定日志級別,只會顯示大于等于 WARNING 級別的日志。
例如:
import logging logging.debug("debug_msg") logging.info("info_msg") logging.warning("warning_msg") logging.error("error_msg") logging.critical("critical_msg")
運行結果:
WARNING:root:warning_msg ERROR:root:error_msg CRITICAL:root:critical_msg
可以看到 info 和 debug 級別的日志不會輸出,默認的日志格式也比較簡單。
默認的日志格式為日志級別:Logger名稱:用戶輸出消息
當然,我們可以通過,logging.basicConfig 的 format 參數,設置日志格式。
字段有很多,可謂應有盡有,足以滿足我們定制化的需求。
上文介紹的“法寶”,并非針對深度學習“煉丹”使用的工具。
而 TensorboardX 則不同,它是專門用于深度學習“煉丹”的高級“法寶”。
早些時候,很多人更喜歡用 Tensorflow 的原因之一,就是 Tensorflow 框架有個一個很好的可視化工具 Tensorboard。
Pytorch 要想使用 Tensorboard 配置起來費勁兒不說,還有很多 Bug。
Pytorch 1.1.0 版本發布后,打破了這個局面,TensorBoard 成為了 Pytorch 的正式可用組件。
在 Pytorch 中,這個可視化工具叫做 TensorBoardX,其實就是針對 Tensorboard 的一個封裝,使得 PyTorch 用戶也能夠調用 Tensorboard。
TensorboardX 安裝也非常簡單,使用 pip 即可安裝。
pip install tensorboardX
tensorboardX 使用也很簡單,編寫如下代碼。
from tensorboardX import SummaryWriter # 創建 writer1 對象 # log 會保存到 runs/exp 文件夾中 writer1 = SummaryWriter('runs/exp') # 使用默認參數創建 writer2 對象 # log 會保存到 runs/日期_用戶名 格式的文件夾中 writer2 = SummaryWriter() # 使用 commet 參數,創建 writer3 對象 # log 會保存到 runs/日期_用戶名_resnet 格式的文件中 writer3 = SummaryWriter(comment='_resnet')
使用的時候,創建一個 SummaryWriter 對象即可,以上展示了三種初始化 SummaryWriter 的方法:
提供一個路徑,將使用該路徑來保存日志
無參數,默認將使用 runs/日期_用戶名 路徑來保存日志
提供一個 comment 參數,將使用 runs/日期_用戶名+comment 路徑來保存日志
運行結果:
有了 writer 我們就可以往日志里寫入數字、圖片、甚至聲音等數據。
這個是最簡單的,使用 add_scalar 方法來記錄數字常量。
add_scalar(tag, scalar_value, global_step=None, walltime=None)
總共 4 個參數。
tag (string): 數據名稱,不同名稱的數據使用不同曲線展示
scalar_value (float): 數字常量值
global_step (int, optional): 訓練的 step
walltime (float, optional): 記錄發生的時間,默認為 time.time()
需要注意,這里的 scalar_value 一定是 float 類型,如果是 PyTorch scalar tensor,則需要調用 .item() 方法獲取其數值。我們一般會使用 add_scalar 方法來記錄訓練過程的 loss、accuracy、learning rate 等數值的變化,直觀地監控訓練過程。
運行如下代碼:
from tensorboardX import SummaryWriter writer = SummaryWriter('runs/scalar_example') for i in range(10): writer.add_scalar('quadratic', i**2, global_step=i) writer.add_scalar('exponential', 2**i, global_step=i) writer.close()
通過 add_scalar 往日志里寫入數字,日志保存到 runs/scalar_example中,writer 用完要記得 close,否則無法保存數據。
在 cmd 中使用如下命令:
tensorboard --logdir=runs/scalar_example --port=8088
指定日志地址,使用端口號,在瀏覽器中,就可以使用如下地址,打開 Tensorboad。
http://localhost:8088/
省去了我們自己寫代碼可視化的麻煩。
使用 add_image 方法來記錄單個圖像數據。注意,該方法需要 pillow 庫的支持。
add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')
參數:
tag (string):數據名稱
img_tensor (torch.Tensor / numpy.array):圖像數據
global_step (int, optional):訓練的 step
walltime (float, optional):記錄發生的時間,默認為 time.time()
dataformats (string, optional):圖像數據的格式,默認為 'CHW',即 Channel x Height x Width,還可以是 'CHW'、'HWC' 或 'HW' 等
我們一般會使用 add_image 來實時觀察生成式模型的生成效果,或者可視化分割、目標檢測的結果,幫助調試模型。
from tensorboardX import SummaryWriter from urllib.request import urlretrieve import cv2 urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/0.png',filename = '1.jpg') urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/1.png',filename = '2.jpg') urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/2.png',filename = '3.jpg') writer = SummaryWriter('runs/image_example') for i in range(1, 4): writer.add_image('UNet_Seg', cv2.cvtColor(cv2.imread('{}.jpg'.format(i)), cv2.COLOR_BGR2RGB), global_step=i, dataformats='HWC') writer.close()
代碼就是下載上篇文章數據集里的三張圖片,然后使用 Tensorboard 可視化處理來,使用 8088 端口開打 Tensorboard:
tensorboard --logdir=runs/image_example --port=8088
運行結果:
試想一下,一邊訓練,一邊輸出圖片結果,是不是很酸爽呢?
Tensorboard 中常用的 Scalar 和 Image,直方圖、運行圖、嵌入向量等,可以查看官方手冊進行學習,方法都是類似的,簡單好用。
到此,相信大家對“Pytorch的使用技巧有哪些”有了更深的了解,不妨來實際操作一番吧!這里是億速云網站,更多相關內容可以進入相關頻道進行查詢,關注我們,繼續學習!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。