A deep dive into the self-attention layer of graph neural networks: a NumPy implementation and visualization that expose its inner workings and mathematical underpinnings.
Original title: An In-Depth Look at the Attention Mechanism in Graph Neural Networks: Mathematical Principles and a Visualized Implementation
Original author: 数据派THU
冷月清谈:
怜星夜思:
2. The article mentions the concept of a "position-transition graph" (位置-转移图). What does this concept contribute to understanding the structure and information flow of a GNN, and in which other kinds of neural networks do you think it could be applied?
3. The article implements a GNN self-attention layer in NumPy and compares it with the implementation in PyG. If you were designing a GNN library, how would you trade off ease of use against performance, and how would you keep the code maintainable?
Original content
Mathematical Representation of the GNN Self-Attention Layer
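For reference, the standard graph-attention (GAT) formulation is given below in the usual notation: W is the shared weight matrix, a the attention vector, and N(i) the neighborhood of node i. The NumPy walkthrough that follows implements these steps in spirit, with some simplifications (a dense adjacency matrix and an unmasked row-wise softmax):

e_{ij} = \mathrm{LeakyReLU}\left(\mathbf{a}^{\top}\,[\mathbf{W}\mathbf{h}_i \,\Vert\, \mathbf{W}\mathbf{h}_j]\right)

\alpha_{ij} = \operatorname{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}

\mathbf{h}_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}\,\mathbf{W}\mathbf{h}_j\right)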


Choosing a NumPy Implementation over Dissecting PyTorch Geometric

Reading the attention layer directly from the PyTorch Geometric source means wading through engineering details that obscure the underlying math, for example:
- parameter resetting
- the many variants of the forward() method
- the special handling of SparseTensors
A NumPy Implementation of the Graph Attention Layer
import numpy as np

np.random.seed(0)


class GAL1:
    num_nodes = 4
    num_features = 4
    num_hidden_dimensions = 2  # We just choose this arbitrarily
    X = np.random.uniform(-1, 1, (num_nodes, num_features))
    print('X\n', X, '\n')

    def __init__(self):
        # W projects the input features down to num_hidden_dimensions
        # (num_features == num_nodes == 4 here, so either constant works)
        W = np.random.uniform(-1, 1, (GAL1.num_hidden_dimensions, GAL1.num_nodes))
        print('W\n', W, '\n')
        self.XatWT = GAL1.X @ W.T  # projected node features, shape (num_nodes, num_hidden_dimensions)
        print('XatWT\n', self.XatWT, '\n')
X
 [[ 0.09762701  0.43037873  0.20552675  0.08976637]
 [-0.1526904   0.29178823 -0.12482558  0.783546  ]
 [ 0.92732552 -0.23311696  0.58345008  0.05778984]
 [ 0.13608912  0.85119328 -0.85792788 -0.8257414 ]]
W
[[-0.95956321 0.66523969 0.5563135 0.7400243 ]
[ 0.95723668 0.59831713 -0.07704128 0.56105835]]
XatWT
[[ 0.37339233 0.38548525]
[ 0.85102612 0.47765279]
[-0.67755906 0.73566587]
[-0.65268413 0.24235977]]
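As a quick sanity check on the shapes in this first step, the following sketch can be run; the assert lines are our own illustration, not part of the original classes, and assume an instance gal_1 = GAL1() has been created as in the __main__ block at the end:

gal_1 = GAL1()
assert GAL1.X.shape == (GAL1.num_nodes, GAL1.num_features)                # (4, 4)
assert gal_1.XatWT.shape == (GAL1.num_nodes, GAL1.num_hidden_dimensions)  # (4, 2)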
class GAL2:
    # Adjacency matrix, with self-loops on the diagonal
    A = np.array([
        [1, 1, 1, 1],
        [1, 1, 0, 0],
        [1, 0, 1, 1],
        [1, 0, 1, 1]
    ])

    def __init__(self, gal1: GAL1):
        print('A\n', GAL2.A, '\n')
        u = np.asarray(GAL2.A > 0)
        print('u\n', u, '\n')
        # connections is a pair of index arrays (source nodes, destination nodes), one entry per edge
        self.connections = u.nonzero()
        print('connections\n', self.connections, '\n')
        XatWTc0 = gal1.XatWT[self.connections[0]]
        print('XatWTc0\n', XatWTc0, '\n')
        XatWTc1 = gal1.XatWT[self.connections[1]]
        print('XatWTc1\n', XatWTc1, '\n')
        # one row per edge: [projected source features || projected destination features]
        self.XatWT_concat = np.concatenate([XatWTc0, XatWTc1], axis=1)
        print('XatWT_concat\n', self.XatWT_concat, '\n')

    def reshape(self, e: np.ndarray) -> np.ndarray:
        # Scatter the per-edge scores e back into a dense num_nodes x num_nodes matrix
        E = np.zeros(GAL2.A.shape)
        E[self.connections[0], self.connections[1]] = e[0]
        return E
A
 [[1 1 1 1]
 [1 1 0 0]
 [1 0 1 1]
 [1 0 1 1]]
u
[[ True True True True]
[ True True False False]
[ True False True True]
[ True False True True]]
connections
(array([0, 0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3]), array([0, 1, 2, 3, 0, 1, 0, 2, 3, 0, 2, 3]))
- Node 0 appears four times (outgoing connections to every node, including itself).
- Node 1 appears only twice (outgoing connections to node 0 and to itself).
- Nodes 2 and 3 each appear three times (outgoing connections to node 0, to each other, and to themselves).
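Zipping the two index arrays together gives the explicit (source, destination) edge list that the rest of the layer works with, one pair per nonzero entry of A. A small illustrative sketch, assuming gal_2 = GAL2(gal_1) has been constructed as in the __main__ block at the end:

edges = list(zip(gal_2.connections[0].tolist(), gal_2.connections[1].tolist()))
print(edges)
# [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1),
#  (2, 0), (2, 2), (2, 3), (3, 0), (3, 2), (3, 3)]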
XatWTc0
 [[ 0.37339233  0.38548525]
 [ 0.37339233  0.38548525]
 [ 0.37339233  0.38548525]
 [ 0.37339233  0.38548525]
 [ 0.85102612  0.47765279]
 [ 0.85102612  0.47765279]
 [-0.67755906  0.73566587]
 [-0.67755906  0.73566587]
 [-0.67755906  0.73566587]
 [-0.65268413  0.24235977]
 [-0.65268413  0.24235977]
 [-0.65268413  0.24235977]]
XatWTc1
[[ 0.37339233 0.38548525]
[ 0.85102612 0.47765279]
[-0.67755906 0.73566587]
[-0.65268413 0.24235977]
[ 0.37339233 0.38548525]
[ 0.85102612 0.47765279]
[ 0.37339233 0.38548525]
[-0.67755906 0.73566587]
[-0.65268413 0.24235977]
[ 0.37339233 0.38548525]
[-0.67755906 0.73566587]
[-0.65268413 0.24235977]]
XatWT_concat
 [[ 0.37339233  0.38548525  0.37339233  0.38548525]
 [ 0.37339233  0.38548525  0.85102612  0.47765279]
 [ 0.37339233  0.38548525 -0.67755906  0.73566587]
 [ 0.37339233  0.38548525 -0.65268413  0.24235977]
 [ 0.85102612  0.47765279  0.37339233  0.38548525]
 [ 0.85102612  0.47765279  0.85102612  0.47765279]
 [-0.67755906  0.73566587  0.37339233  0.38548525]
 [-0.67755906  0.73566587 -0.67755906  0.73566587]
 [-0.67755906  0.73566587 -0.65268413  0.24235977]
 [-0.65268413  0.24235977  0.37339233  0.38548525]
 [-0.65268413  0.24235977 -0.67755906  0.73566587]
 [-0.65268413  0.24235977 -0.65268413  0.24235977]]
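Each row of XatWT_concat is the concatenation [XatWT[src] || XatWT[dst]] for one edge, which corresponds to the [Wh_i || Wh_j] term of the attention formula. A quick check of the first row, again assuming the gal_1 and gal_2 instances from the __main__ block:

src, dst = gal_2.connections
row0 = np.concatenate([gal_1.XatWT[src[0]], gal_1.XatWT[dst[0]]])  # first edge: (0, 0)
assert np.allclose(row0, gal_2.XatWT_concat[0])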
class GAL3:
    @staticmethod
    def leaky_relu(x, alpha=0.2) -> np.ndarray:
        return np.maximum(alpha * x, x)

    @staticmethod
    def softmax2D(x, axis) -> np.ndarray:
        e = np.exp(x - np.expand_dims(np.max(x, axis=axis), axis))
        sum_ = np.expand_dims(np.sum(e, axis=axis), axis)
        return e / sum_

    def __init__(self, gal2: GAL2):
        W_att = np.random.uniform(-1, 1, (1, GAL1.num_nodes))
        print('W_att\n', W_att, '\n')
        a = W_att @ gal2.XatWT_concat.T   # one raw attention score per edge
        print('a\n', a, '\n')
        e = GAL3.leaky_relu(a)
        print('e\n', e, '\n')
        E = gal2.reshape(e)               # scatter the edge scores into a dense matrix
        print('E\n', E, '\n')
        W_alpha = GAL3.softmax2D(E, 1)    # row-wise normalization into attention weights
        print('W_alpha\n', W_alpha, '\n')
        self.left = gal2.A.T @ W_alpha
        print('left\n', self.left, '\n')
- W_att: either freshly initialized or carried over from a previous round of training
- a: the matrix product of W_att and the transpose of gal2.XatWT_concat
- e: the result of applying the leaky_relu function to a
- E: the result of calling gal2.reshape with e as input
W_att
 [[-0.76345115  0.27984204 -0.71329343  0.88933783]]
a
[[-0.1007035 -0.35942847 0.96036209 0.50390318 -0.43956122 -0.69828618
0.79964181 1.8607074 1.40424849 0.64260322 1.70366881 1.2472099 ]]
e
[[-0.0201407 -0.07188569 0.96036209 0.50390318 -0.08791224 -0.13965724
0.79964181 1.8607074 1.40424849 0.64260322 1.70366881 1.2472099 ]]
E
[[-0.0201407 -0.07188569 0.96036209 0.50390318]
[-0.08791224 -0.13965724 0. 0. ]
[ 0.79964181 0. 1.8607074 1.40424849]
[ 0.64260322 0. 1.70366881 1.2472099 ]]
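The shapes in this step are easy to lose track of: the attention vector scores each edge once, and gal2.reshape then scatters those 12 scores back into a 4 x 4 grid aligned with A. The sketch below retraces the step with a freshly drawn vector w_att (a new random vector for illustration, not the W_att printed above), reusing the gal_2 instance from the __main__ block:

w_att = np.random.uniform(-1, 1, (1, 2 * GAL1.num_hidden_dimensions))  # shape (1, 4)
a = w_att @ gal_2.XatWT_concat.T   # one raw score per edge, shape (1, 12)
e = GAL3.leaky_relu(a)             # LeakyReLU keeps small negative scores alive
E = gal_2.reshape(e)               # 4 x 4 matrix; zeros wherever A has no edge
print(E.shape)                     # (4, 4)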
- W_alpha: the result of applying the softmax function to E, row by row.
- self.left: the matrix product of gal2.A.T and W_alpha.
W_alpha
 [[0.15862414 0.15062488 0.42285965 0.26789133]
 [0.24193418 0.22973368 0.26416607 0.26416607]
 [0.16208847 0.07285714 0.46834625 0.29670814]
 [0.16010498 0.08420266 0.46261506 0.2930773 ]]
left
[[0.72275177 0.53741836 1.61798703 1.12184284]
[0.40055832 0.38035856 0.68702572 0.5320574 ]
[0.48081759 0.30768468 1.35382096 0.85767677]
[0.48081759 0.30768468 1.35382096 0.85767677]]
class GAL4:
    def __init__(self, gal1: GAL1, gal3: GAL3):
        # Final node embeddings: (A.T @ W_alpha) @ (X @ W.T)
        self.H = gal3.left @ gal1.XatWT
        print('H\n', self.H, '\n')


if __name__ == '__main__':
    gal_1 = GAL1()
    gal_2 = GAL2(gal_1)
    gal_3 = GAL3(gal_2)
    gal_4 = GAL4(gal_1, gal_3)
H
 [[-1.10126376  1.99749693]
 [-0.33950544  0.97045933]
 [-1.03570438  1.53614075]
 [-1.03570438  1.53614075]]
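For reference, the whole pipeline above can be collapsed into a single function. The sketch below is our own consolidation of GAL1 through GAL4 (same steps, same unmasked row-wise softmax, and the same final A.T @ W_alpha @ X @ W.T product); the name gal_forward and its signature are not part of the original code:

def gal_forward(X, W, W_att, A):
    XatWT = X @ W.T                                      # projected node features, (N, hidden)
    src, dst = np.asarray(A > 0).nonzero()               # one (src, dst) pair per edge
    concat = np.concatenate([XatWT[src], XatWT[dst]], axis=1)
    a = W_att @ concat.T                                 # one raw score per edge
    e = np.maximum(0.2 * a, a)                           # leaky ReLU, alpha = 0.2
    E = np.zeros(A.shape)
    E[src, dst] = e[0]                                   # scatter scores into an N x N grid
    exp_E = np.exp(E - E.max(axis=1, keepdims=True))
    W_alpha = exp_E / exp_E.sum(axis=1, keepdims=True)   # row-wise softmax
    return A.T @ W_alpha @ XatWT                         # final embeddings H, (N, hidden)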
Structural Analysis of the Graph Attention Layer
Core code
def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: OptTensor,
                index: Tensor, ptr: OptTensor,
                dim_size: Optional[int]) -> Tensor:
    x = x_i + x_j
    # some conditional edge code removed...
    x = F.leaky_relu(x, self.negative_slope)
    alpha = (x * self.att).sum(dim=-1)
    alpha = softmax(alpha, index, ptr, dim_size)
    alpha = F.dropout(alpha, p=self.dropout, training=self.training)
    return alpha

def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
    return x_j * alpha.unsqueeze(-1)
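One structural difference from the dense NumPy version above is the normalization: PyG's softmax(alpha, index, ptr, dim_size) normalizes the per-edge scores within groups of edges that share the same aggregation index, rather than over full rows of a dense N x N matrix. Below is a minimal NumPy sketch of such a grouped softmax (our own illustrative helper, not code from PyTorch Geometric):

def segment_softmax(alpha: np.ndarray, index: np.ndarray, num_segments: int) -> np.ndarray:
    # Softmax over groups of entries in `alpha` that share the same value in `index`.
    out = np.zeros_like(alpha)
    for s in range(num_segments):
        mask = index == s
        if mask.any():
            shifted = np.exp(alpha[mask] - alpha[mask].max())
            out[mask] = shifted / shifted.sum()
    return out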
Summary