您好,登錄后才能下訂單哦!
這篇文章主要為大家展示了“pytorch標簽轉onehot形式的示例分析”,內容簡而易懂,條理清晰,希望能夠幫助大家解決疑惑,下面讓小編帶領大家一起研究并學習一下“pytorch標簽轉onehot形式的示例分析”這篇文章吧。
代碼:
import torch class_num = 10 batch_size = 4 label = torch.LongTensor(batch_size, 1).random_() % class_num print(label.size()) one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1) print(one_hot)
輸出:
torch.Size([4, 1]) tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.], [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
注意:
label的形狀必須是[n,1]的,也就是必須是二維的,且第二個維度長度為1,如果是一維度的,則需要升維度,代碼如下:
import torch class_num = 10 batch_size = 4 label = torch.LongTensor(batch_size).random_() % class_num print(label.size()) label = torch.unsqueeze(label,dim=1) print(label.size())
以上是“pytorch標簽轉onehot形式的示例分析”這篇文章的所有內容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內容對大家有所幫助,如果還想學習更多知識,歡迎關注億速云行業資訊頻道!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。