91超碰碰碰碰久久久久久综合_超碰av人澡人澡人澡人澡人掠_国产黄大片在线观看画质优化_txt小说免费全本

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

返回最大值的index?pytorch方式是什么

發布時間:2022-07-18 10:04:30 來源:億速云 閱讀:112 作者:iii 欄目:開發技術

這篇文章主要講解了“返回最大值的index pytorch方式是什么”,文中的講解內容簡單清晰,易于學習與理解,下面請大家跟著小編的思路慢慢深入,一起來研究和學習“返回最大值的index pytorch方式是什么”吧!

返回最大值的index

import torch
a=torch.tensor([[.1,.2,.3],
                [1.1,1.2,1.3],
                [2.1,2.2,2.3],
                [3.1,3.2,3.3]])
print(a.argmax(dim=1))
print(a.argmax())

輸出:

tensor([ 2,  2,  2,  2])
tensor(11)

pytorch 找最大值

題意:使用神經網絡實現,從數組中找出最大值。

提供數據:兩個 csv 文件,一個存訓練集:n 個 m 維特征自然數數據,另一個存每條數據對應的 label ,就是每條數據中的最大值。

這里將隨機構建訓練集:

#%%
import numpy as np 
import pandas as pd 
import torch 
import random 
import torch.utils.data as Data
import torch.nn as nn
import torch.optim as optim
  
def GetData(m, n):
    dataset = []
    for j in range(m):
        max_v = random.randint(0, 9)
        data = [random.randint(0, 9) for i in range(n)]
        dataset.append(data)
    label = [max(dataset[i]) for i in  range(len(dataset))]
    data_list = np.column_stack((dataset, label))
    data_list = data_list.astype(np.float32)
    return data_list
 
#%%
# 數據集封裝 重載函數len, getitem
class GetMaxEle(Data.Dataset):
    def __init__(self, trainset):
        self.data = trainset 
 
    def __getitem__(self, index):
        item = self.data[index]
        x = item[:-1]
        y = item[-1]
        return x, y
    
    def __len__(self):
        return len(self.data)
 
# %% 定義網絡模型
class SingleNN(nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(SingleNN, self).__init__()
        
        self.hidden = nn.Linear(n_feature, n_hidden)
        self.relu = nn.ReLU()
        self.predict = nn.Linear(n_hidden, n_output)
 
    def forward(self, x):
        x = self.hidden(x)
        x = self.relu(x)
        x = self.predict(x)
        return x
  
def train(m, n, batch_size, PATH):
    # 隨機生成 m 個 n 個維度的訓練樣本
    data_list =GetData(m, n)
    dataset = GetMaxEle(data_list)
    trainset = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                      shuffle=True)
 
    net = SingleNN(n_feature=10, n_hidden=100,
                   n_output=10)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    #
    total_epoch = 100
    for epoch in range(total_epoch):
        for index, data in enumerate(trainset):
            input_x, labels = data
            labels = labels.long()
            optimizer.zero_grad()
 
            output = net(input_x)
            # print(output)
            # print(labels)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
 
        # scheduled_optimizer.step()
        print(f"Epoch {epoch}, loss:{loss.item()}")
 
    # %% 保存參數
    torch.save(net.state_dict(), PATH)
    #測試 
  
def test(m, n, batch_size, PATH):
    data_list = GetData(m, n)
    dataset = GetMaxEle(data_list)
    testloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    dataiter = iter(testloader)
    input_x, labels = dataiter.next()
    net = SingleNN(n_feature=10, n_hidden=100,
                   n_output=10)
    net.load_state_dict(torch.load(PATH))
    outputs = net(input_x)
 
    _, predicted = torch.max(outputs, 1)
    print("Ground_truth:",labels.numpy())
    print("predicted:",predicted.numpy())
  
if __name__ == "__main__":
    m = 1000
    n = 10
    batch_size = 64
    PATH = './max_list.pth'
    train(m, n, batch_size, PATH)
    test(m, n, batch_size, PATH)

初始的想法是使用全連接網絡+分類來實現, 但是結果不盡人意,主要原因:不同類別之間的樣本量差太大,幾乎90%都是最大值。

比如代碼中隨機構建 10 個 0~9 的數字構成一個樣本[2, 3, 5, 8, 9, 5, 3, 9, 3, 6], 該樣本標簽是9。

感謝各位的閱讀,以上就是“返回最大值的index pytorch方式是什么”的內容了,經過本文的學習后,相信大家對返回最大值的index pytorch方式是什么這一問題有了更深刻的體會,具體使用情況還需要大家實踐驗證。這里是億速云,小編將為大家推送更多相關知識點的文章,歡迎關注!

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

清徐县| 宁远县| 义马市| 沭阳县| 象山县| 增城市| 三门峡市| 澄迈县| 汽车| 贞丰县| 繁昌县| 丹凤县| 海伦市| 五家渠市| 平安县| 温州市| 邵阳县| 河北省| 新郑市| 油尖旺区| 天长市| 柳江县| 五常市| 余姚市| 农安县| 池州市| 浦城县| 东辽县| 崇文区| 时尚| 马边| 六枝特区| 庆阳市| 长白| 汶川县| 鄢陵县| 桦甸市| 贺兰县| 沾益县| 突泉县| 陆良县|