91超碰碰碰碰久久久久久综合_超碰av人澡人澡人澡人澡人掠_国产黄大片在线观看画质优化_txt小说免费全本

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

pytorch中torch.topk()函數怎么用

發布時間:2022-02-25 11:35:27 來源:億速云 閱讀:228 作者:iii 欄目:開發技術

這篇文章主要介紹“pytorch中torch.topk()函數怎么用”,在日常操作中,相信很多人在pytorch中torch.topk()函數怎么用問題上存在疑惑,小編查閱了各式資料,整理出簡單好用的操作方法,希望對大家解答”pytorch中torch.topk()函數怎么用”的疑惑有所幫助!接下來,請跟著小編一起來學習吧!

函數作用:

pytorch中torch.topk()函數怎么用

pytorch中torch.topk()函數怎么用

該函數的作用即按字面意思理解,topk:取數組的前k個元素進行排序。

通常該函數返回2個值,第一個值為排序的數組,第二個值為該數組中獲取到的元素在原數組中的位置標號。

舉個栗子:

import numpy as np
import torch
import torch.utils.data.dataset as Dataset
from torch.utils.data import Dataset,DataLoader

####################準備一個數組#########################
tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],
             [3,4,5,1,1,1,1,1,1,1,1],
             [7,8,9,1,1,1,1,1,1,1,1],
             [1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)

####################打印這個原數組#########################
print('tensor1:')
print(tensor1)

#################使用torch.topk()這個函數##################
print('使用torch.topk()這個函數得到:')

'''k=3代表從原數組中取得3個元素,dim=1表示從原數組中的第一維獲取元素
(在本例中是分別從[10,1,2,1,1,1,1,1,1,1,10]、[3,4,5,1,1,1,1,1,1,1,1]、
  [7,8,9,1,1,1,1,1,1,1,1]、[1,4,7,1,1,1,1,1,1,1,1]這四個數組中獲取3個元素)
其中largest=True表示從大到小取元素'''
print(torch.topk(tensor1, k=3, dim=1, largest=True))


#################打印這個函數第一個返回值####################
print('函數第一個返回值topk[0]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])

#################打印這個函數第二個返回值####################
print('函數第二個返回值topk[1]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])
'''

#######################運行結果##########################
tensor1:
tensor([[10.,  1.,  2.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 10.],
        [ 3.,  4.,  5.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 7.,  8.,  9.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 1.,  4.,  7.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]])

使用torch.topk()這個函數得到:

'得到的values是原數組dim=1的四組從大到小的三個元素值;
得到的indices是獲取到的元素值在原數組dim=1中的位置。'


torch.return_types.topk(
values=tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]]),
indices=tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]]))

函數第一個返回值topk[0]如下
tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]])
        
函數第二個返回值topk[1]如下
tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]])
'''

該函數功能經常用來獲取張量或者數組中最大或者最小的元素以及索引位置,是一個經常用到的基本函數。

實例演示

任務一:

取top1(最大值):

pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
print(pred)
values, indices = pred.topk(1, dim=0, largest=True, sorted=True)
print(indices)
print(values)
# 用max得到的結果,設置keepdim為True,避免降維。因為topk函數返回的index不降維,shape和輸入一致。
_, indices_max = pred.max(dim=0, keepdim=True)
print(indices_max)
print(indices_max == indices)
輸出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
tensor([[1, 1, 1, 1, 1]])
tensor([[0.7265, 1.4164, 1.3443, 1.2035, 1.8823]])
tensor([[1, 1, 1, 1, 1]])
tensor([[True, True, True, True, True]])

任務二:

按行取出topk,將小于topk的置為inf:

pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
print(pred)
top_k = 2  # 按行求出每一行的最大的前兩個值
filter_value=-float('Inf')
indices_to_remove = pred < torch.topk(pred, top_k)[0][..., -1, None]
print(indices_to_remove)
pred[indices_to_remove] = filter_value  # 對于topk之外的其他元素的logits值設為負無窮
print(pred)
 
輸出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
tensor([[4],
        [4],
        [4],
        [3]])
tensor([[0.4053],
        [1.8823],
        [1.7255],
        [0.3849]])
tensor([[ True, False,  True,  True, False],
        [ True, False,  True,  True, False],
        [ True,  True, False,  True, False],
        [ True, False,  True, False,  True]])
tensor([[   -inf, -0.3873,    -inf,    -inf,  0.4053],
        [   -inf,  1.4164,    -inf,    -inf,  1.8823],
        [   -inf,    -inf,  1.2590,    -inf,  1.7255],
        [   -inf,  0.3041,    -inf,  0.3849,    -inf]])

任務三:

import numpy as np
import torch
import torch.utils.data.dataset as Dataset
from torch.utils.data import Dataset,DataLoader
tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],
             [3,4,5,1,1,1,1,1,1,1,1],
             [7,8,9,1,1,1,1,1,1,1,1],
             [1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)
# tensor2=torch.tensor([[3,2,1],
#                       [6,5,4],
#                       [1,4,7],
#                       [9,8,7]],dtype=torch.float32)
#
print('tensor1:')
print(tensor1)
print('直接輸出topk,會得到兩個東西,我們需要的是第二個indices')
print(torch.topk(tensor1, k=3, dim=1, largest=True))
print('topk[0]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])
print('topk[1]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])
'''
tensor1:
tensor([[10.,  1.,  2.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 10.],
        [ 3.,  4.,  5.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 7.,  8.,  9.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 1.,  4.,  7.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]])
直接輸出topk,會得到兩個東西,我們需要的是第二個indices
torch.return_types.topk(
values=tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]]),
indices=tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]]))
topk[0]如下
tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]])
topk[1]如下
tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]])
'''

到此,關于“pytorch中torch.topk()函數怎么用”的學習就結束了,希望能夠解決大家的疑惑。理論與實踐的搭配能更好的幫助大家學習,快去試試吧!若想繼續學習更多相關知識,請繼續關注億速云網站,小編會繼續努力為大家帶來更多實用的文章!

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

广饶县| 内丘县| 芜湖县| 宁德市| 徐州市| 南昌县| 哈巴河县| 镇赉县| 丰城市| 朝阳县| 图木舒克市| 遂平县| 昔阳县| 新和县| 桂林市| 观塘区| 壶关县| 东阿县| 长顺县| 鸡泽县| 武夷山市| 墨玉县| 文安县| 驻马店市| 张掖市| 敖汉旗| 涿州市| 咸丰县| 邯郸市| 麻城市| 新兴县| 资中县| 兴安县| 东辽县| 巍山| 友谊县| 芦溪县| 张掖市| 岗巴县| 昌乐县| 赞皇县|