博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
使用图神经网络(GNN)寻找最短路径
阅读量:6240 次
发布时间:2019-06-22

本文共 5816 字,大约阅读时间需要 19 分钟。

使用图神经网络(GNN)寻找最短路径

在本文中,我们将展示具有关注读写功能的图形网络如何执行最短路径计算。经过最少的培训后,该网络可以100%的准确率执行此任务。

1

引言

在Octavian,我们相信图是表示复杂知识的强大媒介(例如BenevolentAI使用它们来代表药物研究和知识)。

神经网络是一种创造人类无法表达的函数的方法。人们使用大型数据集对网络进行训练。对于神经网络可以处理的模型,我门使用示例的方法训练神经网络拟合输入和输出的关系.

我门需要能够从图结构中进行学习的神经网络,这些神经网络学习有效的归纳偏置用以可靠的学习处理图中的函数.在这个基础上,我门建立了一个强大的神经图系统.

在这里我门提出了一个’面向读写的图网络’,一个可以有效处理最短路径的简单的网络.它是如何组合不同神经网络组件以使系统易于学习经典图算法的一个例子。

这个网络本身是一个新的运算系统,但更重要的是,网络是进一步研究神经图计算的基础。

代码在这里

问题陈述

考虑一个问题“站A和B之间最短路径的长度是多少?

如果考虑一个图中任意两点的最短路径呢?

我们想要通过训练神经网络返回整数的答案.

相关工作

机器学习图表是一个年轻但不断发展的领域。

详细的请参见综述文章Graph Neural Networks: A Review of Methods and Applications or our introduction

用于计算最短路径的经典算法是A-star,Dijkstra和Bellman Ford算法。

这些方法有效并且可以广泛的被应用。 Dijkstra与我们的用例最相似,即在没有路径成本启发的情况下找到两个特定节点之间的最短路径。

最早的基于神经的最短路径解决方案的工作是通过通信和分组路由来驱动的,这样的近似算法比经典算法速度更快。 这些操作与当今的神经网络完全不同,它们使用迭代反向传播来解决特定图形上的最短路径。 该领域的工作示例包括 Neural networks for routing communication traffic (1988), A Neural Network for Shortest Path Computation (2000) 和 Neural Network for Optimization of Routing in Communication Networks (2006).

本工作建立了一个可以在未知结构图中工作的模型,与前文提到的只能解决某个图中问题的方法形成鲜明的对比.另外,我门寻求为从输入输出对中寻找解决复杂图问题运算提供基础.

最近一个突破性的解决方法在 Differentiable neural computers, 它通过将图形作为连接元组序列并使用读写存储器学习逐步算法来实现这一点。训练的过程以一个学习计划的形式提供,逐步的提高图和问题的规模.

相比之下,我们的解决方案在更大的路径(长度9对4)上表现更好(100%对55.3%),不需要计划学习,不需要训练LSTM控制器,参数更少,网络更简单组件更少。虽然我们还没有找到任何其他公布的解决方案来解决这个问题,但是有很多类似技术被用于不同的问题。几个相关的例子:

• Commonsense Knowledge Aware Conversation Generation with Graph Attention 使用注意力来读出知识图

• Deeply learning molecular structure-property relationships using attention- and gate-augmented graph convolutional network 在每一个图节点使用GRU模型,在节点处使用注意力机制
• DeepPath: A Reinforcement Learning Method for Knowledge Graph Reasoning 使用策略网络在图中进行导航据我们所知,我们的第一个将注意力读写与图形网络相结合的例子

我门将要解决的问题

提出问题,状态1和状态15之间有多少个状态,我们需要的正确答案可能是6.

更具体地说,我们将使用Graph-Question-Answer元组训练网络。每个元组包含一个独特的随机生成的图形,一个英语语言问题和预期的答案。

比如:

3

这些数据被分为不重叠的训练集,验证集和测试集。

这些数描述的网络将用于以前从未见过的新图形。也就是说,它将学习图算法。

我门将使用 CLEVR-Graph 数据集来描述这些问题.

CLEVR-Graph介绍

在构建机器学习解决方案而方案不要求高精度时,很难知道模型是否存在缺陷,或者数据是否具有固有的噪声和模糊性。

为了消除这种不确定性,我们使用了手工数据集。 也就是说我们根据自己的规则生了数据集。 由于数据结构明确,一个好的模型可以获得100%的准确率。 在比较不同的架构时,这确实很有用。

CLEVR图包含一组有关程序生成的传输网络图的问题和答案。 以下是其中一个传输网络的样子(以伦敦地铁为模型)以及一些示例问题和答案:

4

CLEVR-Graph中的每个问题都带有一个答案和一个产生的图。

CLEVR-Graph可以生成许多不同类型的问题。

在本文中,我们将生成与最短路径相关的那些。 通过模版(“A和B之间有多少个站?”),它与每个随机生成的图形中随机选择的一对站点组合在一起,给出一个( 图形 - 问题 - 答案)三元组。

图形-问题-答案三元组生成为YAML文件,然后我们将其编译为TFRecords。

由于只有一个问题模板,培训数据缺乏多样性。 你会得到一个更自然(人类)的来源。 这使数据集更容易解决。 我们将语言多样性作为未来的延伸挑战(并希望看到读者的解决方案!)。

5

解决方案

我们在tensorflow建立了神经网络来解决这个问题,代码在这里. The code for this system is available in our repository.

我们将构建的系统需要一个问题,执行多次迭代处理,然后最终生成一个输出:

6

我们将使用的结构是循环神经网络(RNN) - 在RNN中,相同的单元被顺序执行多次,将其内部状态向前传递到下一次执行。

RNN单元将问题和图形作为输入,以及来自单元的早期执行的任何输出。对它们进行变换,并由单元生成输出向量和更新的节点状态。

RNN单元内部有两个主要组件:图形网络和输出单元。他们的细节是理解这个网络如何运作的关键。我们将在下一节详细介绍这些内容。

RNN单元向前传递隐藏状态,即“节点状态”。这是节点状态表,图中每个节点一个向量。网络使用它来跟踪每个节点的正在进行的计算。

RNN单元执行固定次数(通过实验确定,通常比两个节点之间的最长路径长),然后将最终小区的输出用作系统的总输出。

这样就完成了对整体结构的简要介绍。接下来的部分将概述网络的输入,以及RNN单元的工作原理。

数据输入

T建立系统的第一步是建立输入数据的管道,这提供了3件事.

• 输入的问题是“两个节点间的距离是”
• 输入的图结构是
• 期望的输出是

所有这些都被预处理成TFRecords,因此可以有效地加载它们并传递给模型。 此过程的代码位于随附的GitHub存储库中的build.py中。 您也可以下载预编译的TFRecords。

问题文本的输入

将英文问题转化成信息使用如下三个步骤

• Split the text into a series of ‘tokens’ (e.g. common words and special characters like ? and spaces)将问题区分为一系列的tokens
• 为每个唯一标记分配一个整数ID,并将该标记的每个实例表示为该整数ID
• 将每个标记(例如,单词,特殊字符)嵌入为矢量。此步骤在模型运行时完成,对于这个简单的示例,我们使用单热矢量来编码整数。

图的输入

该图由TFRecord示例中的三个数据结构表示:

1,具有id,名称和属性的节点列表
2,边列表及其源节点ID和目标节点ID以及边属性
3,邻接矩阵,映射节点之间的连接。 如果两个节点直接连接则为1.0,否则为0.0。
使用多维张量进行描述

7

期望输出

期望输出(对于该数据集,始终是从0到9的整数)表示为单个文本标记(即,作为整数),使用与问题文本和节点/边缘属性相同的编码方案。

在训练模式期间使用预期答案进行损耗计算和反向传播,在验证和测试期间,它用于测量模型精度并确定用于调试的失败数据示例。

RNN 如何工作

网络的核心是RNN。 它由一个RNN单元组成,该单元被重复执行,并将其结果向前传递。

8

在我们的实验中,我们使用了10次RNN迭代(通常,迭代次数需要大于或等于要测试的最长路径)。

这个RNN单元每次迭代都会做四件事:

1,将数据写入选定的节点状态
2,沿图中的边沿传播节点状态
3,从选定节点状态读取数据
4,获取读到的数据、所有先前RNN单元的输出,组合它们,并为此RNN迭代生成输出
只需这四个步骤,网络就能够轻松学习如何计算最短路径。

图网络

图形网络是该模型功能的关键。 它使它能够计算图形结构的功能。

在图形网络中,每个节点n在时间t具有状态向量S(n,t)。 我们使用宽度为4的状态向量。每次迭代,节点状态都传播到节点的邻居adj(n):

9

上标和下标用于公式效果图中,以便于理解而不是函数符号S(n,t)

初始状态S(n,0)是零向量。
这种简单的状态传播需要两个以上的部分才能进行最短路径计算:节点状态写入和节点状态读取。

节点状态写入和节点状态读取

节点状态写入是模型将信号向量添加到图中特定节点的状态的机制:

10

该机制首先从问题中提取单词,以形成写入查询q_write。此查询将用于选择节点状态以将写入信号p添加.

使用索引关注生成写查询,其计算问题词Q中应该关注哪些索引(作为RNN迭代id r,单热矢量的函数),然后将它们提取为加权和:

11

通过获取RNN迭代id并应用具有S形激活的密集层来计算写信号。

12

接下来,写入信号和写入查询被内容层馈送到注意,以确定如何将写入信号添加到节点状态。 内容注意力只是标准的点积注意机制,其中每个项目与点积的查询进行比较,以产生一组分数。 然后通过softmax将得分转换成一个总和为1的分布:

13

在这种情况下,分数被计算为每个节点状态的相关节点id与写查询的点积。 最后,写入信号与分数成比例地添加到节点状态:

14

节点状态读取

15

通过获取RNN迭代id并应用具有S形激活的密集层来计算写信号.

16

接下来,以与信号写入方式类似的方式从图中读取状态。从输入的问题单词计算读取查询,再次使用索引注意:

17

然后使用节点状态的加权和计算最终读出值:

18

输出单元

RNN的最后一个重要部分是输出单元。这对网络的成功至关重要(删除之前的输出回调会将精度降低到95%)。

输出单元的总体如.

20

输出单元有两部分:

通过索引注意先前的输出和最近的图形网络读取(这与上一节中的读取和写入查询相同)

一个基本的前馈网络,用于将注意力输出转换为单元的输出

输出单元可以将早期迭代的输出与当前图形网络输出组合。 这允许单元重复组合先前的输出,提供简单递归的形式。 这也有助于网络轻松回顾早期迭代的输出,而不管RNN迭代的总数。

结果

网络的训练

网络的超参数是通过实验确定的。 使用 Learning Rate Finder protocol识别学习速率,并且通过网格搜索确定其他参数,例如节点状态大小,图形读/写头的数量和RNN迭代的数量。

在9k训练周期(MacBook Pro CPU上2分钟)后,网络可实现100%的测试准确度。 这种快速收敛表明网络对解决这个问题有很强的可归纳性。

21

22

观察模型是如何工作的

本文对测模式注意力进行可视化,让您了解网络正在做什么。 它显示了读取,写入和输出注意力部分的工作情况:

23

注意力主要用于显而易见的方式:

每一步都从第一个提到的站的节点状态读取

每一步都从第二个提到的站的节点状态写入

输出单元主要使用来自网络的读取值,但通常将其(至少部分地)与其他步骤的输出组合

方法的效率

对于问题的任何解决方案,值得将其与其他方法进行比较。 在这里,我们将此模型与可微分神经计算机和标准经典方法进行比较。与经典方法Dijkstra相比,这种方法(实际上大多数神经方法)效率较低:这种模式需要适量的初始训练,而Dijkstra则不需要在预测模式期间,该模型比Dijkstra执行更多操作,尽管由于它们是并行矩阵操作,由于专用硬件(例如GPU),它们可能具有类似的运行时间这两种方法有相似的运行时间:Dijksta的标度O(| E | + | N | log | N |)其中E是边和N个节点,这个方法缩放O(|最长路径|)~O(| E |)

但是,我们的方法有一个主要的好处,它有可能根据训练样例学习不同的功能。

与Differentiable Neural Computers方法相比我门的方法有更好的效果.

1,与DNC相比,该方法实现了更高的精度和可扩展性 - 100%(长度为9的路径)与55.3%(长度为4的路径)相比

2,该方法不需要构建和管理学习计划

3,我们怀疑这种方法比DNC需要更少的培训资源(我们在笔记本电脑CPU上2分钟后获得100%的准确度),尽管DNC论文中没有公布数据

4,这种方法是一个更简单的网络,具有更少的读头(Read head)(1对5),更小的内存状态(64个元素对128)和没有LSTM单元

5,我们怀疑这种架构更容易扩展到更大的图形,因为我们并行化图形探索(例如DNC需要更多内存,读取头和运行时来处理更大的图形或更高的边缘密度)

学习其他函数

作为这项工作的一部分,我们探索了在每个节点使用门控循环单元(GRU)作为节点状态更新功能。 这有效,但由于参数增加而额外的培训工作没有带来任何好处,因此最终GRU被禁用。 我们将使用所提出架构的扩展来留下未来的工作,以学习不同的图形函数。

转载地址:http://qgdia.baihongyu.com/

你可能感兴趣的文章
为域用户创建漫游用户配置文件
查看>>
设置域用户只能登陆到特定的计算机
查看>>
将博客搬至CSDN
查看>>
逼自己一把,你就优秀了
查看>>
我的友情链接
查看>>
sql server 第二讲
查看>>
我的友情链接
查看>>
我的友情链接
查看>>
mysql导出表结构
查看>>
Log4j使用总结
查看>>
Mysql主主复制原理及配置
查看>>
nginx编译安装
查看>>
我的友情链接
查看>>
我的友情链接
查看>>
模拟video播放器
查看>>
防杀病毒的12项纪律
查看>>
拦截器的执行顺序
查看>>
Nginx+Tomcat实现动静分离
查看>>
Linux网络配置
查看>>
python之九九乘法表
查看>>