91超碰碰碰碰久久久久久综合_超碰av人澡人澡人澡人澡人掠_国产黄大片在线观看画质优化_txt小说免费全本

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

如何對pytorch中不定長序列補齊

發布時間:2021-05-31 09:45:49 來源:億速云 閱讀:893 作者:小新 欄目:開發技術

小編給大家分享一下如何對pytorch中不定長序列補齊,希望大家閱讀完這篇文章之后都有所收獲,下面讓我們一起去探討吧!

第二種方法通常是在load一個batch數據時, 在collate_fn中進行補齊的.

以下給出兩種思路:

第一種思路是比較容易想到的, 就是對一個batch的樣本進行遍歷, 然后使用np.pad對每一個樣本進行補齊.

for unit in data:
        mask = np.zeros(max_length)
        s_len = len(unit[0])    # calculate the length of sequence in each unit
        mask[: s_len] = 1
        unit[0] = np.pad(unit[0], (0, max_length - s_len), 'constant', constant_values=(0, 0))
        mask_batch.append(mask)

但是這種方法在batch size很大的情況下會很慢, 因為使用for循環進行了遍歷. 我在實際用的時候, 當batch_size=128時, 一個batch的加載時間甚至是一個batch訓練時間的幾倍!

因此, 我想到如何并行地對序列進行補齊. 第二種方法的思路就是使用torch中自帶的pad_sequence來并行補齊.

batch_sequence = list(map(lambda x: torch.tensor(x[findex]), x_data))
batch_data[feat] = torch.nn.utils.rnn.pad_sequence(batch_sequence).T

可以看到這里使用pad_sequence一次性對整個batch進行補齊. 下面對這個函數進行詳細說明.

pad_sequence詳解

from torch.utils.rnn import pad_sequence
a = torch.ones(10)
b = torch.ones(6)
c = torch.ones(20)
abc = pad_sequence([a,b,c])  # shape(20, 3)

注意這個函數接收的是一個元素為tensor的列表, 而不是tensor.

最終, 這個函數會將所有tensor轉換為tensor矩陣#shape(max_length, batch_size). 因此, 在使用完后通常還需要轉置一下.

補充:PyTorch中用于RNN變長序列填充函數的簡單使用

1、PyTorch中RNN變長序列的問題 ??

RNN在處理變長序列時有它的優勢。在分批處理變長序列問題時,每個序列的長度往往不會完全相等,因此針對一個batch中序列長度不一的情況,需要對某些序列進行PAD(填充)操作,使得一個batch內的序列長度相等。 ??

PyTorch中的pack_padded_sequence和pad_packed_sequence可處理上述問題,以下用一個示例演示這兩個函數的簡單使用方法。

2、填充函數簡介

“壓縮”函數:用于將填充后的序列tensor進行壓縮,方便RNN處理

pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)

(1)input->被“壓縮”的tensor,維度一般為[batch_size,_max_seq_len[,embedding_size]]或者[max_seq_len,batch_size[,embedding_size]]

若input維度為:[batch_size,_max_seq_len[,embedding_size]]

要將batch_first設置為True,這表示input的第一個維度為batch的數量

若input維度為:[max_seq_len,batch_size[,embedding_size]]

要將batch_first設置為False(默認值),這表示input的第一個維度不是batch的數量

(2)lengths->lengths參數表示一個batch中序列真實長度,類型為列表,在例子中詳細說明

(3)batch_first->表示batch的數量是否在input的第一維度,默認值為False

(4)enforce_sorted->input中的會自動按照lengths的情況進行排序,默認值為

“解壓”函數:該函數與"壓縮函數"相對應,經“壓縮函數”處理的輸入經過RNN得到的最終結果可以利用該函數進行“解壓”

pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):

(1)sequence->壓縮函數處理過的input經RNN后得到的結果

(2)batch_first->與“壓縮”函數中的batch_first一致

(3)padding_value->序列進行填充時使用的索引,默認為0

(4)total_length->暫略

3、PyTorch代碼示例

代碼如下(示例):

# Create by leslie_miao on 2020/11/1
import torch
import torch.nn as nn
d_model = 10 # 詞嵌入的維度
hidden_size = 20 # lstm隱藏層單元數量
layer_num = 1 # lstm層數
# 輸入inputs,維度為[batch_size,max_seq_len]=[3,4],其中0代表填充
# 該input包含3個序列,每個序列的真實長度分別為: 4 3 2
inputs = torch.tensor([[1,2,3,4],[1,2,3,0],[1,2,0,0]])
embedding = nn.Embedding(5,d_model)
# 獲取詞嵌入后的inputs 當前inputs的維度為[batch_size,max_seq_len,d_model]=[3,4,10]
inputs = embedding(inputs)
# 查看inputs的維度
print(inputs.size())
# print: torch.Size([3, 4, 10])
# 利用“壓縮”函數對inputs進行壓縮處理,[4,3,2]分別為inputs中序列的真實長度,batch_first=True表示inputs的第一維是batch_size
inputs = nn.utils.rnn.pack_padded_sequence(inputs,lengths=[4,3,2],batch_first=True)
# 查看經“壓縮”函數處理過的inputs的維度
print(inputs[0].size())
# print: torch.Size([9, 10])
# 定義RNN網絡
network = nn.LSTM(input_size=d_model,hidden_size=hidden_size,batch_first=True,num_layers=layer_num)
# 初始化RNN相關門參數
c_0 = torch.zeros((layer_num,3,hidden_size))
h_0 = torch.zeros((layer_num,3,hidden_size)) # [rnn層數,batch_size,hidden_size]
# inputs經過RNN網絡后得到的結果outputs
output,(h_n,c_n) = network(inputs,(h_0,c_0))
#查看未經“解壓函數”處理的outputs維度
print(output[0].size())
# print: torch.Size([9, 20])
# 利用“解壓函數”對outputs進行解壓操作,其中batch_first設置與“壓縮函數相同”,padding_value為0
output = nn.utils.rnn.pad_packed_sequence(output,batch_first=True,padding_value=0)
# 查看經“解壓函數”處理的outputs維度
print(output[0].size())
# print:torch.Size([3, 4, 20])

看完了這篇文章,相信你對“如何對pytorch中不定長序列補齊”有了一定的了解,如果想了解更多相關知識,歡迎關注億速云行業資訊頻道,感謝各位的閱讀!

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

烟台市| 临夏县| 昌都县| 且末县| 虹口区| 萍乡市| 万年县| 龙井市| 绍兴县| 墨玉县| 栖霞市| 榕江县| 绥棱县| 龙山县| 绥芬河市| 镇坪县| 常州市| 革吉县| 噶尔县| 读书| 平阳县| 淄博市| 茂名市| 皋兰县| 墨竹工卡县| 崇礼县| 元氏县| 乌兰浩特市| 新化县| 兴城市| 柳林县| 建水县| 邵阳市| 葫芦岛市| 体育| 永善县| 霞浦县| 鹿泉市| 永福县| 崇礼县| 凤凰县|