您好,登錄后才能下訂單哦!
如何進行移動端SOTA模型MixNet的分析,很多新手對此不是很清楚,為了幫助大家解決這個難題,下面小編將為大家詳細講解,有這方面需求的人可以來學習下,希望你能有所收獲。
Depthwise卷積在設計更輕量高效的網絡中經常被使用,但人們通常都忽略了Depthwise卷積中的卷積核大小(通常都是使用3x3)。我們研究了不同大小卷積核對網絡性能的影響,并觀察到不同大小卷積核相互組合,能得到更高的準確性。基于這個思想,我們得到了一個以不同大小卷積核組合成Depthwise卷積模塊,再AutoML的搜索下,提出了一個更高效的網絡Mixnet,超越大部分移動端網絡如Mobilenetv1, v2, shufflenet等等。
由于Depthwise卷積是分離各個通道,單獨做一個卷積操作。因此在設計網絡中,為了減少計算量,研究人員通常把注意力放在如何控制通道數,使得網絡計算量不會增長過大。然后網絡中通常只采用了3x3大小卷積核的卷積,而在其他工作中表明大卷積核在一定程度上能提高模型性能。我們問題轉為使用大卷積核是否就一定提高模型準確率?
通過對比兩種網絡結構,我們可以得知不同網絡最好的性能對應著不同的卷積核大小
基于觀察的結果,我們設置了一個不同大小卷積核構成的MixConv模塊
MixConv模塊還有很多參數沒有實際確定
進行MixConv需要對通道做分組,分配給不同大小的卷積核。實驗中,研究人員發現Groups = 4時候,是對MobileNets結構最穩定的。借助于NAS搜索,研究人員分別從1-5的分組數進行結構搜索。
卷積核大小雖然能隨意設計,但還是有一定前提的。比如當兩個組的卷積核大小相同,其實可以等價于這兩個組融合進一個卷積組里(比如2組都是3x3卷積核,輸出通道為X,相當于1組由3x3卷積核,輸出通道X)
因此我們設定,卷積核起始大小為3,組與組之間卷積核增長為2
比如分4組的話,卷積核為3x3 5x5 7x7 9x9
我們采取了兩種策略
空洞卷積往往能得到更大的感受野,相較于同等感受野的大卷積核,它能一定程度上減少參數量,然而根據我們的實驗,空洞卷積的性能通常要比大卷積核的差
上圖是基于Mobilenet結構上,對Mixconv各種策略的進一步驗證
介紹完前面的設計理念后,這篇論文也就差不多了,后續的工作都是AutoML進行搜索得到的,Mixnet有三種大小的模型(MixNet-S, MixNet-M, MixNet-L)
下面兩圖分別是Mixnet-S和Mixnet-M的結構
這里采用的是https://github.com/romulus0914/MixNet-PyTorch 這版代碼,講解的是研究人員提出的不同kernel_size的DepthwiseConv模塊
class MDConv(nn.Module):
"""
實現分離depthwise卷積
"""
def __init__(self, channels, kernel_size, stride):
super(MDConv, self).__init__()
self.num_groups = len(kernel_size)
self.split_channels = _SplitChannels(channels, self.num_groups)
self.mixed_depthwise_conv = nn.ModuleList()
for i in range(self.num_groups):
self.mixed_depthwise_conv.append(nn.Conv2d(
self.split_channels[i], self.split_channels[i],
kernel_size[i], stride=stride, padding=kernel_size[i] // 2,
groups=self.split_channels[i],
bias=False
))
def forward(self, x):
if self.num_groups == 1:
return self.mixed_depthwise_conv[0](x)
x_split = torch.split(x, self.split_channels, dim=1)
x = [conv(t) for conv, t in zip(self.mixed_depthwise_conv, x_split)]
x = torch.cat(x, dim=1)
return x
首先通過splitchannels這個方法,得到每個kernel size對應的通道數。
再用一個for循環,把每個不同kernel size的卷積模塊,添加到ModuleList容器中
在前向傳播里面,先是調用torch.split方法對輸入在通道維度上做分離,通過一個列表,保存所有卷積得到的張量。最后調用torch.cat在通道維上進行連結。
看完上述內容是否對您有幫助呢?如果還想對相關知識有進一步的了解或閱讀更多相關文章,請關注億速云行業資訊頻道,感謝您對億速云的支持。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。