您好,登錄后才能下訂單哦!
這篇文章主要為大家展示了“Pytorch轉ONNX中tracing機制有什么用”,內容簡而易懂,條理清晰,希望能夠幫助大家解決疑惑,下面讓小編帶領大家一起研究并學習一下“Pytorch轉ONNX中tracing機制有什么用”這篇文章吧。
(1)tracing的機制
上文提到過,Pytorch轉ONNX的方式是基于tracing(追蹤),通俗來說,就是ONNX的相關代碼在一旁看著Pytorch跑一遍,運行了什么內容就把什么記錄下來。但是在這里并不是所有Python的運行內容都會被記錄。舉個例子,下面的代碼中,
c = torch.matmul(a, b)
print("Blabla")
e = torch.matmul(c, d)
其中只有第1,3行相關的內容會被記錄,因為只有他們是和Pytorch相關的,而第二行只是普通的python語句。
具體來說,只有ATen操作會被記錄下來。ATen可以被理解為一個Pytorch的基本操作庫,一切的Pytorch函數都是基于這些零部件構造出來的(比如ATen就是加減乘除,所有Pytorch的其他操作,比如平方,算sigmoid,都可以根據加減乘除構造出來)
*之前說的ONNX無法記錄if語句的問題也是因為if并不是Aten中的操作
雖然ONNX可以記錄所有Pytorch的執行(即記錄所有ATen操作),但是在輸出的時候會做一個剪枝,把沒用的操作剪掉
舉個例子,下面的程序,顯而易見第一句話是沒有用的。
t1 = torch.matmul(a, b)
t2 = torch.matmul(c, d)
return t2
ONNX會在得到全部的操作以及他們之間的輸入輸出關系后(以DAG作為表示),根據DAG的輸出往前推,做遍歷,所有可以被遍歷到的節點被保留,其他節點直接扔掉。
在MMDetection(https://github.com/open-mmlab/mmdetection)中,在NMS(non-Maximumnon maximum suppression)中有如下代碼:
if bboxes.numel() == 0:
bboxes = multibboxes.newzeros((0, 5))
labels = multibboxes.newzeros((0, ), dtype=torch.long)
if torch.onnx.isinonnxexport():
raise RuntimeError('[ONNX Error] Can not record NMS '
'as it has not been executed this time')
return bboxes, labels
dets, keep = batchednms(bboxes, scores, labels, nmscfg)
代碼邏輯很簡單,如果之前的網絡根本沒有輸出任何合法的bbox(第一行的分支判斷),那么顯然nms的結果就是一堆0,所以沒必要運行nms直接返回0就可以。
如果我們想將這段代碼轉換到ONNX,之前我們提到過ONNX不能處理分支邏輯,因此只能選擇一條路去走,記錄那條路轉換得到的模型。很顯然,正常情況下我們自然期待會有較多的bbox,并且將這些bbox作為參數調用nms。
所以如果我們發現模型執行的路徑觸發了if分支,我們必須要進行一個判斷,看看是不是在轉ONNX,如果是的話我們就需要直接報錯,因為顯然轉出來的ONNX不是我們想要的。
假設什么都不做,在這種情況下我們轉出來的模型是什么樣呢?思考一下不難發現,假設函數的返回值就是網絡的最終輸出,那么我們只會得到一個2個節點的DAG,即第2,3行的兩個操作。之前說過ONNX拿到所有的DAG之后會做剪枝,在這里ONNX拿到返回值(bboxes, labels)做回溯,發現最頭上就是第2,3行的兩個操作,就直接停掉了。所有其他的操作,比如backbone,rpn,fpn,都會被扔掉。
因此,在進行MMDet模型的轉換的時候,必須用真實的數據和訓練好的參數來做轉換,否則基本不會得到有效的bbox,于是就會觸發第6行的error
(2)利用tracing機制做優化
在MMSeg中有一個很巧妙的利用tracing機制做優化的例子。
在slide inference時,我們需要計算一個count mat矩陣,這個矩陣在h, w以及對應的stride都固定的情況下會是一個常量。
不過在訓練時,往往這些都是我們要調的參數,所有MMSeg沒有選擇把這些常數保存下來,而是每次都算一遍
countmat = img.newzeros((batchsize, 1, himg, wimg))
for hidx in range(hgrids):
for widx in range(wgrids):
y1 = hidx * hstride
x1 = widx * wstride
y2 = min(y1 + hcrop, himg)
x2 = min(x1 + wcrop, wimg)
y1 = max(y2 - hcrop, 0)
x1 = max(x2 - wcrop, 0)
cropimg = img[:, :, y1:y2, x1:x2]
cropseglogit = self.encodedecode(cropimg, imgmeta)
preds += F.pad(cropseglogit,
(int(x1), int(preds.shape[3] - x2), int(y1),
int(preds.shape[2] - y2)))
countmat[:, :, y1:y2, x1:x2] += 1
assert (countmat == 0).sum() == 0
if torch.onnx.isinonnxexport():
# cast countmat to constant while exporting to ONNX
countmat = torch.fromnumpy(
countmat.cpu().detach().numpy()).to(device=img.device)
不過在部署時,這些參數往往是固定的,因此我們沒必要把它算一遍。因此在倒數第4行的if分支里,我們做了一件看似很沒用的事
countmat = torch.fromnumpy(countmat.cpu().detach().numpy()).to(device=img.device)
即我們把算出來的countmat從tensor轉換成numpy,再轉回tensor。
其實我們的目的是切斷tracing。
之前提到過,ONNX只能記錄ATen相關的操作,但是很顯然,tensor和numpy的互轉肯定不是ATen操作。因此在回溯的時候,當訪問到count mat,ONNX并不能發現它是被誰運算出來的,所以countmat就會被看作一個常數被保存下來,之前計算countmat的部分都會被扔掉
以上是“Pytorch轉ONNX中tracing機制有什么用”這篇文章的所有內容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內容對大家有所幫助,如果還想學習更多知識,歡迎關注億速云行業資訊頻道!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。