继续探索Graph OOD的相关问题,与以往工作不同的是,这篇工作避开了复杂的数学推导和琐碎的数据生成过程,直接从简单有效的判别模型入手研究图上的OOD检测问题。
论文题目:Energy-based Out-of-Distribution Detection for Graph Neural Networks
论文链接:
代码链接(含实验细节说明):
尽管针对图数据的学习方法目前已取得了空前的进展,绝大部分现有的方法都假设训练数据与测试数据来自同一分布。目前有大量研究表明,现有的图深度学习模型(如图神经网络)通常在分布外数据上表现差强人意,这也使得针对图数据分布外泛化(Out-of-Distribution Generalization,简称OOD泛化)问题的研究[1,2]逐渐流行起来。
OOD泛化问题旨在解决训练和测试分布不一致的问题,其学习目标是为了提升模型在新的未知分布的测试数据上的性能。而另一类比较常见的实际问题,是分布外检测(Out-of-Distribution Detection)[3,4],问题定义为:
当分类器在有限观测的训练数据上完成训练后,需要具备识别测试集中不同于训练主体分布的数据的能力。
OOD检测在一些对安全性要求较高的领域(如自动驾驶、医疗诊断、风险投资)具有重要的实际价值。尽管OOD检测在图像领域已被广泛研究,但针对图数据的OOD检测还是一个几乎未被探索的领域[5]。
首先,从整体上看,与图片不同的是,图结构数据中的每个样本通常是图上的节点。由于节点互联的特性,节点样本之间存在着依赖关系,导致了样本的非独立性。因此,在对OOD样本进行判定时,需要考虑到这种数据依赖关系(data inter-dependence)。
下面我们对图上的OOD检测问题给出定义。假设输入数据样本构成了一个图G=(V, E),V是节点集合,E是连边集合,使用A表示邻接矩阵。图中每个节点i都是一个样本,包含输入特征\mathbf x_i和标签\mathbf y_i。图中的节点集合\mathcal I分成了训练集\mathcal I_s和测试集\mathcal I_u。定义X = [\mathbf x_i]_{i\in \mathcal I}和Y = [\mathbf y_i]_{i\in \mathcal I},我们需要训练一个节点分类器f,它能预测节点的标签\hat Y = f(X, A)。此外,更重要的是,这一分类器具备识别分布外样本的能力。具体的,考虑一个由f产生的决策函数G(\mathbf x, \mathcal G_{\mathbf x}; f),使得对于任意输入\mathbf x有
G(\mathbf x, \mathcal G_{\mathbf x}; f) = \begin{cases} 1, \quad &\mathbf x ~~\mbox{is an in-distribution instance}, \\ 0, \quad &\mathbf x ~~\mbox{is an out-of-distribution instance}, \end{cases} \\
其中\mathcal G_{\mathbf x}表示节点\mathbf x在图中对应的邻居子图。
本文提出的方法主要基于简单有效的设计原则。首先,对于输入图首先考虑一个图神经网络h_\theta来得到节点的表征。具体的,如果采用图卷积网络(GCN),其节点表征的更新公式如下:
Z^{(l)} = \sigma \left ( D^{-1/2} \tilde A D^{-1/2} Z^{(l-1)} W^{(l)} \right ), \quad Z^{(l-1)} = [\mathbf z_i^{(l-1)}]_{i\in \mathcal I}, \quad Z^{(0)} = X, \\
在上式中节点表征的计算依赖于图中相邻的节点,从而将样本间的依赖关系建模了出来。通过L层图卷积之后,将最后一层的输出结果\mathbf z_i^{(L)} = h_\theta(\mathbf{x}, \mathcal G_{\mathbf x})作为logits用于对节点标签的预测,即模型给出的预测分布可以写为
p(y \mid \mathbf{x}, \mathcal G_{\mathbf x})=\frac{e^{h_\theta(\mathbf{x}, \mathcal G_{\mathbf x})_{[y]} }}{\sum_{c=1}^{C} e^{h_\theta(\mathbf{x}, \mathcal G_{\mathbf x})_{[c]}}}. \\
用于OOD检测的能量函数 已有的研究[6]表明,当假设E(\mathbf{x}, \mathcal G_{\mathbf{x}}, y; h_\theta)=-h_\theta(\mathbf{x}, \mathcal G_{\mathbf x})_{[y]}时,上式可以看作一个玻尔兹曼分布(Boltzmann distribution)
p(y | \mathbf x, \mathcal G_{\mathbf x}) = \frac{e^{-E(\mathbf x, \mathcal G_{\mathbf x}, y)}}{\sum_{y'} e^{-E(\mathbf x, \mathcal G_{\mathbf x}, y')}} = \frac{e^{-E(\mathbf x, \mathcal G_{\mathbf x}, y)}}{e^{-E(\mathbf x, \mathcal G_{\mathbf x})}}, \\
这里的E(\mathbf{x}, \mathcal G_{\mathbf{x}}, y; h_\theta)称为分类器h_\theta对应的给定标签y下的能量函数,而通过对y进行marginalization可以得到对于输入(\mathbf{x}, \mathcal G_{\mathbf{x}})的自由能量函数:
E(\mathbf{x}, \mathcal G_{\mathbf{x}} ; h_\theta) = - \log \sum_{c=1}^{C} e^{h_\theta(\mathbf{x}, \mathcal G_{\mathbf{x}})_{[c]}}. \\
这一能量函数对每个输入节点都能返回一个能量值,它可以衡量分类器对图中节点的置信度,即作为判别是否是OOD样本的依据。
基于能量的信任传播 为了进一步的利用图结构产生的样本依赖性,我们提出了基于能量的信任传播,具体实现为将每个节点的能量值沿着输入图进行信息传递:
\mathbf E^{(k)} = \alpha \mathbf E^{(k-1)} + (1 - \alpha) D^{-1} A \mathbf E^{(k-1)}, \quad \mathbf E^{(k)} = [E^{(k)}_i]_{i\in \mathcal I}, \quad \mathbf E^{(0)} = [E(\mathbf{x}_i, \mathcal G_{\mathbf{x}_i} ; h_\theta)]_{i\in \mathcal I}. \\
这样做的好处是,可以使得分类器产生的置信度沿着图结构加强,即利用全局的信息来强化个体节点的OOD估计。具体的,上述的能量传播可以使得更新后的能量值朝着邻居节点中的“大多数的声音”靠近,即如果邻居中多数节点都是OOD样本则该节点属于OOD的倾向性更大。我们在论文的3.2节对这一性质给出了理论证明,并且在实验中通过大量的消融实验验证了这一简单方法的有效性。
损失函数 在模型训练方面,我们考虑两种可能的情形,以分别适用于两种被广泛研究的OOD检测问题。第一种情形是训练集中仅包含主体分布数据(即分布内训练数据\mathcal I_s),此时可以使用标准的分类损失函数训练图神经网络分类器(我们称提出的方法叫GNNSafe):
\mathcal L_{sup} = \mathbb E_{(\mathbf x, \mathcal G_{\mathbf x}, y) \sim \mathcal D_{in}} \left ( -\log p(y \mid \mathbf{x}, \mathcal G_{\mathbf x}) \right ) = \sum_{i\in \mathcal I_{s}} \left (- h_\theta(\mathbf x_i, \mathcal G_{\mathbf x_i})_{[y_i]} + \log \sum_{c=1}^C e^{h_\theta(\mathbf x_i, \mathcal G_{\mathbf x_i})_{[c]}} \right ). \\
另一种情形是训练数据中还额外包括已知的分布外数据(表示为\mathcal I_o),此时常见方法是引入一个额外的正则项,例如可以对模型输出的能量值进行上下界约束[5](我们称提出的方法叫GNNSafe++):
\mathcal L_{ref} = \frac{1}{|\mathcal I_{s}|}\sum_{i\in \mathcal I_{s}}\left(\mbox{ReLU} \left(\tilde E\left(\mathbf{x}_i, \mathcal G_{\mathbf{x}_i}; h_\theta\right)-t_{in}\right)\right)^{2} +\frac{1}{|\mathcal I_{o}|}\sum_{j\in \mathcal I_{o}} \left(\mbox{ReLU} \left(t_{out}-\tilde E\left(\mathbf{x}_j, \mathcal G_{\mathbf{x}_j}; h_\theta\right)\right)\right)^{2}. \\
对于第二种情形最终的损失函数可以写为加权和\mathcal L_{sup} + \lambda \mathcal L_{reg}。
由于图数据的分布外检测问题目前还有待探索,本文也对这一问题背景下如何有效和全面的评测模型的能力给出了系统的探讨,包括1)如何选择数据集,2)如何划分数据集,3)如何评估OOD检测的能力。
评估准则 首先,我们需要明确的是,与传统监督学习不同的是,OOD检测问题需要额外考虑分布外的测试数据(以及可能用到的训练数据)。下图展示了监督学习与OOD检测(包含两类问题)问题对数据集划分的要求。
数据集和划分 数据划分是非常重要的环节,需要考虑的是如何在不破坏原数据内在特性的情况下,引入分布差异。整体原则包含两点:
基于上述两个原则,我们进一步考虑两类常见的图数据集,对数据的划分方式描述如下图。
具体的,我们在实验里考虑了五个不同的数据集,根据它们不同的特性,采用不同的划分方式:
实验结果 下面的表格展示了在5个数据集上的OOD检测结果,这里采用常规的评测指标AUROC/AUPR/FPR95来衡量模型对IND-Te和OOD-Te样本估计值排序的正确性。这里我们统一使用GCN作为分类器主干,并在两种情形下进行各自的对比,即使用或不使用OOD exposure。可以看到,本文提出的方法GNNSafe显著好于其他同类的不使用OOD exposure的方法,而GNNSafe++取得了最好的性能。特别的,相比SOTA方法,在Twitch和Cora-Structure数据集,GNNSafe++对AUROC指标分别提升了12.8%和17.0%,而对FPR95指标分别降低了44.8%和21.0%。
为了进一步验证提出方法的有效性,我们也对两个关键模块能量信任传播和能量正则项进行了消融实验。下图分别绘制了三种方法在Twitch和Arxiv上对IND-Te和OOD-Te所估计的能量值分布。可以看到,相比于GNNSafe w/o energy propagation(不考虑能量信任传播和能量正则项)和GNNSafe(仅不考虑能量正则项),GNNSafe++所给出的能量分布能够更明显的把分布内和分布外的样本区分开。
此外,我们也探索了使用GNN backbone对模型性能的影响,下图分别考虑MLP,GCN,GAT,JKNet和MixHop作为主干,可以看到几种方法的相对优劣保持一致,这也进一步验证了GNNSafe在使用不同GNN主干时的优越性。
这一工作主要对图结构数据节点分布外检测的问题进行了初步探索,并提出了一种简单有效的方法叫作GNNSafe,可以作为这一(尚未被充分探索的)研究领域的强有力的基线方法。此外,还对如何针对不同数据集在数据划分中引入相应的分布偏移作了讨论,为图数据OOD检测提供了benchmarks。
当然,本文的方法以及提供的代码也可以很方便的进行拓展和延伸,包括但不限于:
[1] Qitian Wu, et al., Handling Distribution Shifts on Graphs: An Invariance Perspective, ICLR 2022.
[2] Jiaqi Ma, et al., Subgroup Generalization and Fairness of Graph Neural Networks, NeurIPS 2021.
[3] Dario Amodei et al., Concrete problems in ai safety, Arxiv 2016.
[4] Shiyu Liang et al., Enhancing the reliability of out-of-distribution image detection in neural networks, ICLR 2018.
[5] Zenan Li et al., Graphde: A generative framework for debiased learning and out-of-distribution detection on graphs, NeurIPS 2022.
[6] Will Grathwohl et al., Your classifier is secretly an energy based model and you should treat it like one, ICLR 2020.
[7] Weitang Liu et al., Energy-based out-of-distribution detection, NeurIPS 2020.
[8] Qitian Wu et al., NodeFormer: A Scalable Graph Structure Learning Transformer for Node Classification, NeurIPS 2022.
下一篇:黑洞是暗能量的来源