图神经网络 torch_geometric 库的 MessagePassing 运行机制学习
torch_geometric.nn.conv. MessagePassing( )
继承这个类,可以自定义节点信息传播机制
例子
import torch
from torch_geometric.utils import add_self_loops
from torch_geometric.nn.conv import MessagePassingclass GCNConv(MessagePassing):def __init__(self):#选择相加的方式进行邻居节点信息聚合super().__init__(aggr='add')def forward(self, x, edge_index):#给图添加自环edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))print(edge_index)out = self.propagate(edge_index, x=x)print(out)def message(self, x_j):print(x_j)return x_jedge_index = torch.tensor([[0, 1],[1, 0],[1, 2],[2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [6], [1]], dtype=torch.float)
edge_index = edge_index.permute(1, 0)
model = GCNConv( )
out = model(x, edge_index)
运行截图

MessagePassing的运行机制就是用行坐标[0,1,1,2,0,1,2]计算每个节点要汇聚的feature,然后用列坐标[1,0,2,1,0,1,2]进行 add 聚合信息,x_j其实就是根据行坐标得来的,行坐标里面的每一个元素其实就是一个节点标号,它告诉我们当前聚合信息时,每一个节点的信息应该是怎么样,在这里我没有转换节点feature,直接就是初始feature进行聚合,然后列坐标的元素进行聚合,如:列坐标中0节点与行节点对应的元素为1,0,所以在x_j对应位置找到元素6,-1然后相加得5,同理,1节点为-1+1+6 =6,2节点为6+1=7;
需要注意的是def message(self, x_j)中x_j的参数名字不能随便改变,不然会出错;其实x_j可以变为x_i,x_i代表以列坐标[0,1,1,2,0,1,2]计算每个节点要汇聚的feature,但仍以列坐标[0,1,1,2,0,1,2]汇聚信息;最后得到的结果如下:

也可以def message(self, x_j,x_i) ,其中x_j,x_i同时返回,可以根据具体应用进行灵活操作;
本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!
