PyTorch的PyG庫可以支持自定義層。在PyTorch中,可以通過繼承torch.nn.Module
類來創建自定義層。例如,定義一個簡單的全連接層,可以這樣做:
import torch
import torch.nn as nn
class MyLayer(nn.Module):
def __init__(self, input_dim, output_dim):
super(MyLayer, self).__init__()
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.linear(x)
在這個例子中,MyLayer
類繼承自nn.Module
,并定義了一個全連接層self.linear
。在forward
方法中,我們將輸入x
傳遞給這個全連接層,并返回其輸出。
然后,在使用PyG庫時,可以將這個自定義層添加到圖結構中。例如,定義一個包含自定義層和PyTorch nn.Linear
層的圖結構:
from torch_geometric.nn import MessagePassing
import torch
class MyModel(MessagePassing):
def __init__(self, in_channels, out_channels):
super(MyModel, self).__init__(aggr='add')
self.lin = nn.Linear(in_channels, out_channels)
self.my_layer = MyLayer(in_channels, 64)
def forward(self, x, edge_index):
row, col = edge_index
x = self.my_layer(x)
x = self.lin(x)
row, col = row.view(-1, 1), col.view(-1, 1)
deg = self.degree(row, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
def degree(self, row, num_nodes, dtype):
row, col = row.to(dtype), col.to(dtype)
deg = torch.bincount(row, minlength=num_nodes, dtype=dtype)
deg = deg[row] + deg[col]
return deg.view(-1, 1)
在這個例子中,MyModel
類繼承自MessagePassing
,并定義了一個包含自定義層self.my_layer
和PyTorch nn.Linear
層的圖結構。在forward
方法中,我們首先對輸入x
應用自定義層,然后應用線性層,最后根據邊的權重計算消息和更新節點特征。