您好,登錄后才能下訂單哦!
這篇文章主要介紹“PyTorch怎么安裝和使用”的相關知識,小編通過實際案例向大家展示操作過程,操作方法簡單快捷,實用性強,希望這篇“PyTorch怎么安裝和使用”文章能幫助大家解決問題。
安裝PyTorch geometric
首先確保安裝了PyTorch 1.2.0及以上版本
$ python -c "import torch; print(torch.__version__)"
>>> 1.2.0
安裝依賴包
$ pip install --verbose --no-cache-dir torch-scatter
$ pip install --verbose --no-cache-dir torch-sparse
$ pip install --verbose --no-cache-dir torch-cluster
$ pip install --verbose --no-cache-dir torch-spline-conv (optional)
$ pip install torch-geometric
注意:
def spawn(self, cmd):
spawn(cmd, dry_run=self.dry_run)
改為
def spawn(self, cmd):
spawn(cmd, dry_run=self.dry_run)
import torch
from torch_geometric.data import Data
#邊,shape = [2,num_edge]
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
#點,shape = [num_nodes, num_node_features]
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
>>> Data(edge_index=[2, 4], x=[3, 1])
數據集
PyTorch Geometric已經包含有很多常見的基準數據集,包括:
Cora:一個根據科學論文之間相互引用關系而構建的Graph數據集合,論文分為7類:Genetic_Algorithms,Neural_Networks,Probabilistic_Methods,Reinforcement_Learning,Rule_Learning,Theory,共2708篇;
Citeseer:一個論文之間引用信息數據集,論文分為6類:Agents、AI、DB、IR、ML和HCI,共包含3312篇論文;
Pubmed:生物醫學方面的論文搜尋以及摘要數據集。
以及網址中的數據集等等。
初始化這樣的一個數據集也很簡單,會自動下載對應的數據集然后處理成需要的格式,例如ENZYMES dataset (覆蓋6大類的600個圖,可用于graph-level的分類任務):
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
>>> ENZYMES(600)
len(dataset)
>>> 600
dataset.num_classes
>>> 6
dataset.num_node_features
>>> 3
對于其中的第一個圖,可以這樣取得:
data = dataset[0]
>>> Data(edge_index=[2, 168], x=[37, 3], y=[1])
#可以看出這個圖包含邊168/2=84條,節點37個,每個節點包含三個特征
data.is_undirected()
>>> True
再看一個node-level的數據集
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
>>> Cora()
#可以看到這個數據集只有一個圖
len(dataset)
>>> 1
dataset.num_classes
>>> 7
dataset.num_node_features
>>> 1433
#train_mask
data = dataset[0]
>>> Data(edge_index=[2, 10556], test_mask=[2708],
train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
#用來訓練的數據量
data.train_mask.sum().item()
>>> 140
#用來驗證的數據量
data.val_mask.sum().item()
>>> 500
#用來測試的數據量
data.test_mask.sum().item()
>>> 1000
下面再來看一個完整的例子:
import torch
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
#數據集加載
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
#網絡定義
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = GCNConv(dataset.num_node_features, 16)
self.conv2 = GCNConv(16, dataset.num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
#網絡訓練
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
#測試
model.eval()
_, pred = model(data).max(dim=1)
correct = float(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))
GCNConv層具體的實現代碼為:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add') # "Add" aggregation.
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: 增加自連接到鄰接矩陣
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: 對節點的特征矩陣進行線性變換
x = self.lin(x)
# Step 3-5: Start propagating messages.
return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
def message(self, x_j, edge_index, size):
# x_j has shape [E, out_channels]
# Step 3: Normalize node features.
row, col = edge_index
deg = degree(row, size[0], dtype=x_j.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
# aggr_out has shape [N, out_channels]
# Step 5: Return new node embeddings.
return aggr_out
關于“PyTorch怎么安裝和使用”的內容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業相關的知識,可以關注億速云行業資訊頻道,小編每天都會為大家更新不同的知識點。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。