图神经网络注意力机制解析:原理、可视化与NumPy实现

深入解析图神经网络自注意力层,通过NumPy实现与可视化,揭示其内部运作机制与数学原理。

原文标题:深入解析图神经网络注意力机制:数学原理与可视化实现

原文作者:数据派THU

冷月清谈:

本文深入探讨了图神经网络(GNNs)中自注意力机制的工作原理,通过可视化方法和数学推导,揭示了注意力权重在图结构数据中的生成和应用过程。文章采用“位置-转移图”的概念框架,并使用NumPy实现了GNN自注意力层的关键计算步骤,包括节点特征矩阵和权重矩阵的初始化、邻接矩阵的处理、非线性激活函数和归一化操作的应用等。通过将复杂的数学公式转化为易于理解的代码和可视化图形,本文旨在帮助读者直观地理解GNN自注意力机制,并为GNN的可解释性研究提供新的思路。

怜星夜思:

1、文章中提到了使用NumPy实现GNN自注意力层而非PyTorch Geometric的原因,你觉得除了文中提到的原因外,还有哪些其他可能的考虑因素?
2、文章中提到了“位置-转移图”的概念,这个概念在理解GNN的结构和信息流动方面有什么作用?你觉得它还能应用在哪些其他类型的神经网络中?
3、文章通过NumPy实现了GNN自注意力层,并将其与PyG中的实现进行了对比。假设让你来设计一个GNN库,你会如何权衡易用性和性能之间的关系,又会如何保证代码的可维护性?

原文内容

来源:DeepHub IMBA‍‍‍‍‍‍‍
本文约5000字,建议阅读9分钟
本文旨在通过可视化方法和数学推导,揭示图神经网络自注意力层的内部运作机制。


在图神经网络(Graph Neural Networks, GNNs)的发展历程中,注意力机制扮演着至关重要的角色。通过赋予模型关注图中最相关节点和连接的能力,注意力机制显著提升了GNN在节点分类、链接预测和图分类等任务上的性能。尽管这一机制的重要性不言而喻,但其内部工作原理对许多研究者和工程师而言仍是一个"黑盒"。
本文旨在通过可视化方法和数学推导,揭示图神经网络自注意力层的内部运作机制。我们将采用"位置-转移图"的概念框架,结合NumPy编程实现,一步步拆解自注意力层的计算过程,使读者能够直观理解注意力权重是如何生成并应用于图结构数据的。
通过将复杂的数学表达式转化为易于理解的代码块和可视化图形,本文不仅适合已经熟悉图神经网络的研究人员,也为刚开始接触这一领域的学习者提供了一个清晰的学习路径。
本文详细解析了图神经网络自注意力层的可视化方法及其数学原理,通过代码实现展示其内部工作机制。

图神经网络自注意力层的数学表示

在采用自注意力机制的图神经网络中,一个典型层的计算可以通过以下张量乘法表示:
其中各元素定义如下:
Image
包含自循环的邻接矩阵的转置
注意力张量
Image
节点特征矩阵
常规(非注意力)权重张量的转置
"自注意力"机制的核心在于注意力张量实际上是由方程中其他元素通过线性函数与非线性函数组合生成的。这一概念可能较为抽象,但我们可以通过编程实现来展示这种组合关系,并从代码中推导出直观的图形表示。

选择NumPy实现而非解析PyTorch Geometric

我们选择使用NumPy的原因在于:
PyG的实际代码包含大量计算细节,且设计目标是扩展基础MessagePassing模块,这使得理解张量元素间的关系变得复杂。例如,GATv2Conv模块处理了以下复杂性:
  • 参数重置

  • forward()方法的多种变体

  • SparseTensors的特殊处理
而基本的MessagePassing模块则考虑了更多复杂因素,包括钩子、Jinja文本渲染、可解释性、推理分解、张量大小不匹配异常、"提升"和"收集"的子任务以及分解层等。
因此使用NumPy构建一个简洁明了的例子能够更有效地帮助我们理解注意力张量是如何从方程的其他元素构建而来的。

图注意力层的NumPy实现

为了绘制方程的位置-转移图,我们将Labonne的代码重构为四个类,这四个类对应于本文顶部图中的四个胶囊(GAL1到GAL4)。
采用面向对象的方法使得我们可以通过构造函数(init方法)区分中间结果和在整个位置-转移图中四个类/胶囊间共享的结果。共享结果通过self.x = y赋值保存为实例数据成员。
为便于理解,下面是一个四节点图的示例:
我们假设每个节点都与自身连接。图中展示了入站和出站弧而非无向边,因为入站-出站关系在代码中被显式表示。
为简化起见,我们假设特征和权重初始化均在(-1, 1)范围内。
以下是GAL1的代码实现:
import numpy as np

np.random.seed(0)

class GAL1:

num_nodes = 4
num_features = 4
num_hidden_dimensions = 2 # We just choose this arbitrarily // 我们任意选择这个值

X = np.random.uniform(-1, 1, (num_nodes, num_features))
print(‘X\n’, X, ‘\n’)

def init(self):

W = np.random.uniform(-1, 1, (GAL1.num_hidden_dimensions, GAL1.num_nodes))
print(‘W\n’, W, ‘\n’)

self.XatWT = GAL1.X @ W.T
print(‘XatWT\n’, self.XatWT, ‘\n’)

执行该代码会产生以下输出:
X
[[ 0.09762701  0.43037873  0.20552675  0.08976637]
[-0.1526904   0.29178823 -0.12482558  0.783546  ]
[ 0.92732552 -0.23311696  0.58345008  0.05778984]
[ 0.13608912  0.85119328 -0.85792788 -0.8257414 ]]

W
[[-0.95956321 0.66523969 0.5563135 0.7400243 ]
[ 0.95723668 0.59831713 -0.07704128 0.56105835]]

XatWT
[[ 0.37339233 0.38548525]
[ 0.85102612 0.47765279]
[-0.67755906 0.73566587]
[-0.65268413 0.24235977]]


在这一阶段,我们初始化了节点特征矩阵X和标准权重矩阵W。在实际训练场景中,X来自图结构,而W则源自初始化或前一轮训练。这在位置-转移图上表示为标记为"Graph"和"PyTorch Geo"的"云"位置。
GAL1的主要保留数据成员是self.XatWT,即我们方程的右侧部分("at"表示矩阵乘法的"@"中缀符号)。在后续代码中,这个中间结果将与邻接矩阵结合,形成注意力张量。
GAL2的代码实现如下:
class GAL2:

A = np.array([
[1, 1, 1, 1],
[1, 1, 0, 0],
[1, 0, 1, 1],
[1, 0, 1, 1]
])

def init(self, gal1: GAL1):

print(‘A\n’, GAL2.A, ‘\n’)

u = np.asarray(GAL2.A > 0)
print(‘u\n’, u, ‘\n’)

self.connections = u.nonzero()
print(‘connections\n’, self.connections, ‘\n’)

XatWTc0 = gal1.XatWT[self.connections[0]]
print(‘XatWTc0\n’, XatWTc0, ‘\n’)

XatWTc1 = gal1.XatWT[self.connections[1]]
print(‘XatWTc1\n’, XatWTc1, ‘\n’)

self.XatWT_concat = np.concatenate([XatWTc0, XatWTc1], axis=1)
print(‘XatWT_concat\n’, self.XatWT_concat, ‘\n’)

def reshape(self, e: np.ndarray) -> np.ndarray:
E = np.zeros(GAL2.A.shape)
E[self.connections[0], self.connections[1]] = e[0]
return E


邻接矩阵A由图的结构固定。connections计算的结果如下:
A
[[1 1 1 1]
[1 1 0 0]
[1 0 1 1]
[1 0 1 1]]

u
[[ True True True True]
[ True True False False]
[ True False True True]
[ True False True True]]

connections
(array([0, 0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3]), array([0, 1, 2, 3, 0, 1, 0, 2, 3, 0, 2, 3]))


我们选择的节点标签与邻接矩阵中的索引对应。第一个connections数组表示具有到节点j的出站连接的节点索引i。
例如:
  • 节点0出现四次(出站连接到所有节点包括自身)。

  • 节点1仅出现两次(出站连接到节点0和自身)。

  • 节点2和节点3各出现三次(出站连接到节点0、彼此和自身)。
第二个connections数组包含相同的值,但按入站顺序排列,这是因为该图实际上是非定向的。
使用connections数组作为gal1.XatWT的索引,产生以下输出:
XatWTc0
[[ 0.37339233  0.38548525]
[ 0.37339233  0.38548525]
[ 0.37339233  0.38548525]
[ 0.37339233  0.38548525]
[ 0.85102612  0.47765279]
[ 0.85102612  0.47765279]
[-0.67755906  0.73566587]
[-0.67755906  0.73566587]
[-0.67755906  0.73566587]
[-0.65268413  0.24235977]
[-0.65268413  0.24235977]
[-0.65268413  0.24235977]]

XatWTc1
[[ 0.37339233 0.38548525]
[ 0.85102612 0.47765279]
[-0.67755906 0.73566587]
[-0.65268413 0.24235977]
[ 0.37339233 0.38548525]
[ 0.85102612 0.47765279]
[ 0.37339233 0.38548525]
[-0.67755906 0.73566587]
[-0.65268413 0.24235977]
[ 0.37339233 0.38548525]
[-0.67755906 0.73566587]
[-0.65268413 0.24235977]]


此处,我们的十二元素入站和出站connections索引数组分别被转换为gal1.XatWT元素的十二元素数组。
将入站和出站数组连接,得到结果:
XatWT_concat
[[ 0.37339233  0.38548525  0.37339233  0.38548525]
[ 0.37339233  0.38548525  0.85102612  0.47765279]
[ 0.37339233  0.38548525 -0.67755906  0.73566587]
[ 0.37339233  0.38548525 -0.65268413  0.24235977]
[ 0.85102612  0.47765279  0.37339233  0.38548525]
[ 0.85102612  0.47765279  0.85102612  0.47765279]
[-0.67755906  0.73566587  0.37339233  0.38548525]
[-0.67755906  0.73566587 -0.67755906  0.73566587]
[-0.67755906  0.73566587 -0.65268413  0.24235977]
[-0.65268413  0.24235977  0.37339233  0.38548525]
[-0.65268413  0.24235977 -0.67755906  0.73566587]
[-0.65268413  0.24235977 -0.65268413  0.24235977]]

ndarray connections被赋值给self,但并非为了在GAL2外部使用(因此在图中用虚线椭圆表示)。相反,我们在reshape方法中使用connections。reshape方法通过创建一个与A形状相同的零矩阵来生成ndarray E,然后使用connections[0]作为E的行索引,connections[1]作为E的列索引,从输入ndarray e[0]分配值。此方法将被GAL3调用。
显然,E按connections和e的排序应具有相同数量的元素。E的某些元素将保持未分配状态(零值),即那些对应于图中缺少入站或出站弧的节点对的元素。
除了GAL2.A之外,数组XatWT_concat也将在后续计算中使用,因此被赋值给self。
GAL3的代码实现如下:
class GAL3:

@staticmethod
def leaky_relu(x, alpha=0.2) -> np.ndarray:
return np.maximum(alpha * x, x)

@staticmethod
def softmax2D(x, axis) -> np.ndarray:
e = np.exp(x - np.expand_dims(np.max(x, axis=axis), axis))
sum_ = np.expand_dims(np.sum(e, axis=axis), axis)
return e / sum_

def init(self, gal2: GAL2):

W_att = np.random.uniform(-1, 1, (1, GAL1.num_nodes))
print(‘W_att\n’, W_att, ‘\n’)

a = W_att @ gal2.XatWT_concat.T
print(‘a\n’, a, ‘\n’)

e = GAL3.leaky_relu(a)
print(‘e\n’, e, ‘\n’)

E = gal2.reshape(e)
print(‘E\n’, E, ‘\n’)

W_alpha = GAL3.softmax2D(E, 1)
print(‘W_alpha\n’, W_alpha, ‘\n’)

self.left = gal2.A.T @ W_alpha
print(‘left\n’, self.left, ‘\n’)


GAL3是我们引入非线性(leaky_relu)和归一化(softmax2D)操作的类。GAL3最终将生成原始方程的整个左侧,仅剩右侧gal1.XatWT未处理。GAL3的唯一"输出"是self.left。
以下是GAL3中的前四个计算步骤:
  • W_att:初始化或来自前一轮训练

  • a:W_att与gal2.XatWT_concat的矩阵乘法

  • e:对a应用leaky_relu函数

  • E:调用gal2.reshape方法,传入e作为输入
这四个计算的结果如下:
W_att
[[-0.76345115  0.27984204 -0.71329343  0.88933783]]

a
[[-0.1007035 -0.35942847 0.96036209 0.50390318 -0.43956122 -0.69828618
0.79964181 1.8607074 1.40424849 0.64260322 1.70366881 1.2472099 ]]

e
[[-0.0201407 -0.07188569 0.96036209 0.50390318 -0.08791224 -0.13965724
0.79964181 1.8607074 1.40424849 0.64260322 1.70366881 1.2472099 ]]
E
[[-0.0201407 -0.07188569 0.96036209 0.50390318]
[-0.08791224 -0.13965724 0. 0. ]
[ 0.79964181 0. 1.8607074 1.40424849]
[ 0.64260322 0. 1.70366881 1.2472099 ]]


GAL3中的最后两个计算步骤:
  • W_alpha:对E应用softmax函数。

  • self.left:gal2.A.T与W_alpha的矩阵乘法。
结果如下:
W_alpha
[[0.15862414 0.15062488 0.42285965 0.26789133]
[0.24193418 0.22973368 0.26416607 0.26416607]
[0.16208847 0.07285714 0.46834625 0.29670814]
[0.16010498 0.08420266 0.46261506 0.2930773 ]]

left
[[0.72275177 0.53741836 1.61798703 1.12184284]
[0.40055832 0.38035856 0.68702572 0.5320574 ]
[0.48081759 0.30768468 1.35382096 0.85767677]
[0.48081759 0.30768468 1.35382096 0.85767677]]


GAL3的唯一"输出"是left,因此它被赋值给self。
至此,我们已经计算出原始方程的左侧和右侧(gal1.XatWT)。
GAL4的代码实现及主函数如下:
class GAL4:

def init(self, gal1: GAL1, gal3: GAL3):

self.H = gal3.left @ gal1.XatWT
print(‘H\n’, self.H, ‘\n’)

if name == ‘main’:

gal_1 = GAL1()
gal_2 = GAL2(gal_1)
gal_3 = GAL3(gal_2)
gal_4 = GAL4(gal_1, gal_3)


最终结果H为:
H
[[-1.10126376  1.99749693]
[-0.33950544  0.97045933]
[-1.03570438  1.53614075]
[-1.03570438  1.53614075]]


在这里,我们将原始方程的左侧和右侧进行矩阵乘法运算,得到最终结果。

图注意力层的结构分析

从文章开头的图和上面"main"中的代码可以看出,每个GALx仅依赖于前一个GAL(x-1),除了GAL4,它同时依赖于GAL1和GAL3。通过对代码进行分类和封装,我们使其结构更加清晰,从而更易于理解。
该图由位置(椭圆)和转移(矩形)组成,因此被称为位置-转移图。在本文中,我们仅针对GAL特定实现的位置-转移图进行直观分析。有关位置-转移图的更详细信息,请参考我之前的文章(参考文献[PT-GNN-TD])中的"位置-转移图基础"部分。
下面我们将详细分析GAL位置-转移图的各个组成部分。
GAL1结构相对简单,仅执行一次矩阵乘法运算。但其结果是原始方程的整个右侧,也是GAL2和GAL4的主要非邻接相关输入。
将这两个组件合并分析是因为它们之间的连接较为紧密。GAL3利用了GAL2的值A和XatWT_concat,以及GAL2的方法reshape。我们通过标记来自输入引用gal2的弧线来突出每个值或方法的使用位置。
同样,GAL2的connections使用虚线表示,因为它仅在公开方法reshape中使用。
GAL2专注于矩阵操作,是邻接矩阵A"注入"到原始方程的关键点。因此,GAL2是以图结构为中心的组件。
GAL3同样执行矩阵操作,但其核心功能是应用非线性函数(leaky_relu)和归一化操作(softmax)。注意力权重矩阵W_att的引入对GAL3的功能也至关重要。GAL3是以注意力机制为中心的组件。
与GAL1类似,GAL4的结构也相对简单,仅执行一次矩阵乘法。它将方程的左侧gal3.left与右侧gal1.XatWT结合。GAL4是唯一一个接收来自多个组件输入的类,因此它扮演着"混合器"的角色,在"串联"和"并联"模式下连接节点特征、邻接关系和注意力机制。

核心代码

以下是实际PyG库中GATv2Conv的核心代码,涵盖了我们使用NumPy模拟的大部分功能:
def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: OptTensor,
index: Tensor, ptr: OptTensor,
dim_size: Optional[int]) -> Tensor:
x = x_i + x_j

some conditional edge code removed… // 删除了一些条件边缘代码…

x = F.leaky_relu(x, self.negative_slope)
alpha = (x * self.att).sum(dim=-1)
alpha = softmax(alpha, index, ptr, dim_size)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
return alpha

def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
return x_j * alpha.unsqueeze(-1)


忽略MessagePassing的部分复杂性,我们可以看到实际的PyG代码与我们的NumPy实现在核心逻辑上非常相似。

总结

通过本文的分析,我们已经深入剖析了图神经网络自注意力机制的内部工作原理。从数学表达式到代码实现再到可视化图形,我们提供了一个全方位的视角来理解注意力权重如何在图结构数据中生成和应用。
通过位置-转移图的概念框架,我们不仅展示了计算流程,还揭示了各组件之间的依赖关系,为图神经网络的可解释性研究提供了新的思路。
作者:John Baumgarten
编辑:黄继彦



关于我们

数据派THU作为数据科学类公众号,背靠清华大学大数据研究中心,分享前沿数据科学与大数据技术创新研究动态、持续传播数据科学知识,努力建设数据人才聚集平台、打造中国大数据最强集团军。




新浪微博:@数据派THU

微信视频号:数据派THU

今日头条:数据派THU

我认为易用性和性能可以通过分层设计来兼顾。底层使用高性能的C++或CUDA实现核心计算,上层提供Python API,方便用户使用。同时,要注重代码的注释和文档编写,方便后续的维护和扩展。

其实这和数据流图有点类似,本质上都是一种计算图。只要网络结构能用计算图表示,就可以用类似位置-转移图的方法进行分析。比如,可以用它来分析卷积神经网络中不同卷积层之间的信息传递。

还有一点,NumPy的debug更方便。PyTorch Geometric封装程度高,出错时定位问题比较困难。NumPy代码相对简单,可以逐行调试,更容易找到bug。

从研究的角度看,用NumPy实现可以更自由地定制和修改模型结构,方便进行创新性实验。PyTorch Geometric的模块化设计虽然方便,但也限制了一些实验的灵活性。

我会优先考虑易用性,提供简洁的API和丰富的文档,降低用户的上手门槛。性能方面,可以利用CUDA等技术进行加速,并提供性能优化的选项。可维护性方面,要注重代码的模块化和接口设计,并编写充分的单元测试。

我觉得位置-转移图的核心在于对计算流程的解耦和可视化。不仅可以用于理解GNN,还可以用于优化模型结构。例如,通过分析位置-转移图,可以发现计算瓶颈,从而有针对性地进行优化。

我觉得可能还因为NumPy更适合做原型验证和教学演示。PyTorch Geometric虽然功能强大,但学习曲线较陡峭,上手成本高。用NumPy可以快速实现核心逻辑,方便理解和修改,特别是在教学场景下,更易于学生掌握。

代码可维护性真的很重要!我会强制执行code review,统一代码风格,并定期进行代码重构。另外,要建立完善的issue tracking系统,及时修复bug和响应用户反馈。

“位置-转移图”有点像电路图,把GNN的各个计算模块和数据依赖关系清晰地展现出来,有助于理解信息的传递路径和各个模块的功能。我觉得可以应用到Transformer这种结构复杂的网络中,可视化attention的计算过程,方便理解。