Generalist Neural Algorithmic Learner

A single graph neural network processor capable of learning to execute a wide range of algorithms, such as sorting, searching, dynamic programming, path-finding, and geometry.

\[\renewcommand{\V}[1]{\mathbf{#1}}\]

Another generalist learner in the neural algorithm domain (Ibarz et al., 2022)

Abstract

The cornerstone of neural algorithmic reasoning is the ability to solve algorithmic tasks, especially in a way that generalizes out of distribution. While recent years have seen a surge in methodological improvements in this area, they mostly focused on building specialist models. Specialist models are capable of learning to neurally execute either only one algorithm or a collection of algorithms with identical control-flow backbone. Here, instead, we focus on constructing a generalist neural algorithmic learner – a single graph neural network processor capable of learning to execute a wide range of algorithms, such as sorting, searching, dynamic programming, path-finding, and geometry. We leverage the CLRS benchmark to empirically show that, much like recent successes in the domain of perception, generalist algorithmic learners can be built by “incorporating” knowledge. That is, it is possible to effectively learn algorithms in a multi-task manner, so long as we can learn to execute them well in a single-task regime. Motivated by this, we present a series of improvements to the input representation, training regime, and processor architecture over CLRS, improving average single-task performance by over 20% from the prior art. We then conduct a thorough ablation of multi-task learners leveraging these improvements. Our results demonstrate a generalist learner that effectively incorporates knowledge captured by specialist models.

Single-task Experiments

Each algorithm in the CLRS benchmark (Veličković et al., 2022) is specified by a number of inputs, hints, and outputs. In a given sample, the inputs and outputs are fixed, while hints are time series of intermediate states of the algorithm. Each sample for a particular task has a size, \(n\), corresponding to the number of nodes in the GNN that will execute the algorithm.

A sample of every algorithm is represented as a graph. Each input, output, and hint is located either in the nodes, in the edges, or in the graph itself, and therefore has shape (excluding the batch dimension and, for hints, the time dimension) \(n \times f\), \(n \times n \times f\), or \(f\), respectively, where \(f\) is the dimensionality of the feature and depends on its type. The CLRS benchmark defines five types of features: scalar, categorical, mask, mask_one, and pointer, each with its own encoding and decoding strategy and loss function. For example, a scalar feature is encoded and decoded directly by a single linear layer and optimized using mean squared error. We defer to the CLRS benchmark paper (Veličković et al., 2022) for further details.
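To make the three feature locations concrete, the following minimal Python sketch computes the per-sample shape of a feature from its location, together with the mean-squared-error loss described for scalar features. The helper names `feature_shape` and `scalar_loss` are hypothetical and are not part of the CLRS library.

```python
def feature_shape(location: str, n: int, f: int) -> tuple:
    """Shape of one feature for a single sample (no batch or time dimension)."""
    if location == "node":
        return (n, f)
    if location == "edge":
        return (n, n, f)
    if location == "graph":
        return (f,)
    raise ValueError(f"unknown location: {location!r}")


def scalar_loss(pred: list, target: list) -> float:
    """Mean squared error, the loss used for scalar features."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


# A node-level scalar feature on a 5-node sample has one value per node.
print(feature_shape("node", 5, 1))          # (5, 1)
print(scalar_loss([1.0, 2.0], [1.0, 4.0]))  # 2.0
```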

Base Model

Encoder.

We adopt the same encode-process-decode paradigm (Hamrick et al., 2018) presented with the CLRS benchmark (Veličković et al., 2022). At each time step \(t\) of a particular task \(\tau\) (e.g. insertion sort), the task-based encoder \(f_\tau\), consisting of a linear encoder for each input and hint, embeds the inputs and the current hints as high-dimensional vectors. The embeddings of inputs and hints located in the nodes all have the same dimension and are added together; the same happens with hints and inputs located on edges and in the graph. In our experiments, we use the same dimension, \(h=128\), for node, edge, and graph embeddings. Thus, at the end of the encoding step for time step \(t\) of the algorithm, we have a single set of embeddings \(\Big \{ \V x_i^{(t)}, \V e_{ij}^{(t)}, \V g^{(t)} \Big \}\), of shapes \(n \times h\), \(n \times n \times h\), and \(h\), in the nodes, edges, and graph, respectively. Note that this is independent of the number and type of the inputs and hints of the particular algorithm, allowing us to share this latent space across all thirty algorithms in CLRS. Further, note that at each step the input encoding is fed directly to these embeddings; this recall mechanism significantly improves the model’s robustness over long trajectories (Bansal et al., 2022).
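The key property of the encoder, that every node-located input and hint is embedded separately and the embeddings are then summed, can be sketched in plain Python. This is an illustration under stated assumptions: the latent size is shrunk from 128 to 4, weights are random, and the feature names (`pos`, `key`, `mask`) and helpers (`make_linear`, `encode_nodes`) are hypothetical. Edge- and graph-located features would be handled the same way.

```python
import random

H = 4  # latent size; the post uses h = 128, shrunk here for readability


def make_linear(in_dim, out_dim, rng):
    """Random weights for one hypothetical per-feature linear encoder (in_dim -> out_dim)."""
    w = [[rng.gauss(0.0, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]
    return w, [0.0] * out_dim


def apply_linear(params, v):
    w, b = params
    return [sum(wi * vi for wi, vi in zip(row, v)) + bi for row, bi in zip(w, b)]


def encode_nodes(node_features, encoders, n):
    """Embed every node-located input/hint separately, then sum into one n x H array.

    node_features: name -> list of n feature vectors (each of length f).
    encoders: name -> linear parameters mapping f -> H.
    """
    x = [[0.0] * H for _ in range(n)]
    for name, values in node_features.items():
        for i in range(n):
            emb = apply_linear(encoders[name], values[i])
            x[i] = [acc + new for acc, new in zip(x[i], emb)]
    return x


rng = random.Random(0)
features = {
    "pos": [[0.0], [0.5], [1.0]],   # input: position scalar
    "key": [[3.0], [1.0], [2.0]],   # input: values to sort
    "mask": [[1.0], [0.0], [0.0]],  # hint for the current step, re-encoded every step
}
encoders = {name: make_linear(1, H, rng) for name in features}
x = encode_nodes(features, encoders, n=3)
print(len(x), len(x[0]))  # 3 4
```

Because the embeddings are added, adding or removing a feature changes only the dictionaries, never the shape of `x`; this is what lets all thirty algorithms share one latent space.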

Processor.

The embeddings are fed into a processor \(P\), a GNN that performs one step of the computation. The processor transforms the input node, edge, and graph embeddings into processed node embeddings \(\V h_i^{(t)}\); it also takes the processed node embeddings from the previous step, \(\V h_i^{(t-1)}\), as inputs. Importantly, the same processor model can operate on graphs of any size. As our base model, we use a message-passing neural network (MPNN; Gilmer et al., 2017) with \(\max\) aggregation, passing messages over a fully connected graph. The MPNN computes processed embeddings as follows:

\[\V z_i^{(t)} = \V x_i^{(t)} \| \V h_i^{(t-1)} \qquad \V m_i^{(t)} = \max_{1\le j\le n} f_m \left ( \V z_i^{(t)}, \V z_j^{(t)}, \V e_{ij}^{(t)}, \V g^{(t)} \right) \qquad \V h_i ^{(t)} = f_r \left ( \V z_i^{(t)}, \V m_i^{(t)} \right)\]

starting from \(\V h^{(0)} = \V 0\). Here \(\|\) denotes concatenation, \(f_m : \mathbb{R}^{2h} \times \mathbb{R}^{2h} \times \mathbb{R}^h \times \mathbb{R}^h \rightarrow \mathbb{R}^h\) is the message function (for which we use a three-layer MLP with ReLU activations), and \(f_r: \mathbb{R}^{2h} \times \mathbb{R}^h \rightarrow \mathbb{R}^h\) is the readout function (for which we use a linear layer with ReLU activation). The use of the \(\max\) aggregator is well motivated by prior work (Veličković et al., 2019; Veličković et al., 2022), and we use the fully connected graph, letting the neighbors \(j\) range over all nodes \((1\le j \le n)\), so that the model can overcome situations where the input graph structure may be suboptimal. Layer normalization (Ba et al., 2016) is applied to \(\V h_i^{(t)}\) before they are used further. Further details on the MPNN processor can be found in Veličković et al. (2022).
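A single MPNN step with max aggregation over the fully connected graph might look like the plain-Python sketch below. It is a simplification, not the authors' implementation: the message function is one linear layer with ReLU instead of a three-layer MLP, weights are random, sizes are tiny, and the helper names are hypothetical.

```python
import math
import random


def init_linear(in_dim, out_dim, rng):
    w = [[rng.gauss(0.0, 0.3) for _ in range(in_dim)] for _ in range(out_dim)]
    return w, [0.0] * out_dim


def linear(params, v):
    w, b = params
    return [sum(wi * vi for wi, vi in zip(row, v)) + bi for row, bi in zip(w, b)]


def relu(v):
    return [max(0.0, a) for a in v]


def layer_norm(v, eps=1e-5):
    mu = sum(v) / len(v)
    var = sum((a - mu) ** 2 for a in v) / len(v)
    return [(a - mu) / math.sqrt(var + eps) for a in v]


def mpnn_step(x, e, g, h_prev, f_m, f_r):
    """One max-aggregation MPNN step over the fully connected graph.

    x: n x h node embeddings, e: n x n x h edge embeddings, g: h graph embedding,
    h_prev: n x h processed embeddings from the previous step.
    """
    n = len(x)
    z = [x[i] + h_prev[i] for i in range(n)]  # z_i = x_i || h_i^(t-1) (list concat)
    h_new = []
    for i in range(n):
        # A message from every node j (including i itself), then element-wise max.
        msgs = [relu(linear(f_m, z[i] + z[j] + e[i][j] + g)) for j in range(n)]
        m_i = [max(col) for col in zip(*msgs)]
        h_new.append(layer_norm(relu(linear(f_r, z[i] + m_i))))
    return h_new


rng = random.Random(0)
h, n = 3, 4
f_m = init_linear(6 * h, h, rng)  # input: z_i (2h) + z_j (2h) + e_ij (h) + g (h)
f_r = init_linear(3 * h, h, rng)  # input: z_i (2h) + m_i (h)
x = [[rng.gauss(0, 1) for _ in range(h)] for _ in range(n)]
e = [[[rng.gauss(0, 1) for _ in range(h)] for _ in range(n)] for _ in range(n)]
g = [rng.gauss(0, 1) for _ in range(h)]
h1 = mpnn_step(x, e, g, [[0.0] * h for _ in range(n)], f_m, f_r)  # h^(0) = 0
```

Because the max runs over every node \(j\) and ignores node order, relabelling the nodes simply relabels the outputs, and the same parameters work for any \(n\).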

Decoder.

The processed embeddings are finally decoded with a task-based decoder \(g_\tau\), to predict the hints for the next step, and the outputs at the final step. Akin to the encoder, the task-based decoder relies mainly on a linear decoder for each hint and output, along with a mechanism to compute pairwise node similarities when appropriate. Specifically, the pointer type decoder computes a score, \(s_{ij}\), for each pair of nodes, and then chooses the pointer of node \(i\) by taking either the \(\arg\max_j s_{ij}\) or \(\mathrm{softmax}_j s_{ij}\) (depending on whether a hard or soft prediction is used).
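Given a matrix of pair scores \(s_{ij}\), hard and soft pointer predictions differ only in the final row-wise operation. The sketch below assumes the scores are already computed (how the decoder produces them is not shown) and uses hypothetical helper names.

```python
import math


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [v / total for v in exps]


def decode_pointers(scores, hard=True):
    """scores[i][j] = s_ij. Hard: index of node i's pointer; soft: a distribution over j."""
    if hard:
        return [max(range(len(row)), key=row.__getitem__) for row in scores]
    return [softmax(row) for row in scores]


s = [[0.1, 2.0, -1.0],
     [3.0, 0.0, 0.5],
     [0.0, 0.0, 4.0]]
print(decode_pointers(s, hard=True))   # [1, 0, 2]
soft = decode_pointers(s, hard=False)  # each row sums to 1
```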

Loss.

The decoded hints and outputs are used to compute the loss during training, according to their type (Veličković et al., 2022). For each sample in a batch, the hint prediction losses are averaged across hints and time, and the output loss is averaged across outputs (most algorithms have a single output, though some have two). The hint loss and output loss are then added together. In addition, the hint predictions at each time step are fed back as inputs for the next step, except possibly at train time if teacher forcing is used (see “Dataset and training” below).
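The per-sample loss combination can be written as a short function. This is a sketch: the per-hint and per-output losses are assumed to be already computed by each feature's type-specific loss, and averaging over the batch is left out.

```python
def combined_loss(hint_losses, output_losses):
    """hint_losses[t][k]: loss of hint k at step t; output_losses: one loss per output."""
    flat = [loss for step in hint_losses for loss in step]
    hint_loss = sum(flat) / len(flat)                      # average over hints and time
    output_loss = sum(output_losses) / len(output_losses)  # average over outputs
    return hint_loss + output_loss                         # the two terms are summed


# Two hints over two steps, and a single output.
print(combined_loss([[1.0, 3.0], [2.0, 2.0]], [0.5]))  # 2.5
```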

We train the model on samples with sizes \(n\le 16\), and periodically evaluate it on in-distribution samples of size \(n=16\). Also periodically, we evaluate the model with the best in-distribution evaluation score so far on OOD samples of size \(n = 64\). In what follows, we report only these OOD evaluation scores. Full details of the model, training, and evaluation hyperparameters can be found in Appendix A of the paper (Ibarz et al., 2022).
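The model-selection protocol (validate in distribution, report OOD scores of the best model so far) can be sketched as follows. The checkpoint dictionaries and evaluation callables are hypothetical stand-ins for real models and datasets.

```python
def evaluate_best_checkpoint(checkpoints, eval_in_dist, eval_ood):
    """Track the best in-distribution checkpoint and report its OOD score each period.

    checkpoints: iterable of model snapshots, in training order.
    eval_in_dist / eval_ood: callables returning a score (higher is better).
    """
    best_score, best_ckpt, ood_scores = float("-inf"), None, []
    for ckpt in checkpoints:
        score = eval_in_dist(ckpt)  # periodic validation on n = 16
        if score > best_score:
            best_score, best_ckpt = score, ckpt
        ood_scores.append(eval_ood(best_ckpt))  # OOD test (n = 64) of the best model so far
    return best_ckpt, ood_scores


ckpts = [{"id": 0, "val": 0.5, "test": 0.3},
         {"id": 1, "val": 0.9, "test": 0.6},
         {"id": 2, "val": 0.7, "test": 0.8}]
best, reported = evaluate_best_checkpoint(ckpts, lambda c: c["val"], lambda c: c["test"])
print(best["id"], reported)  # 1 [0.3, 0.6, 0.6]
```

Checkpoint 2 has the highest OOD score, but it is never reported, because selection only looks at the in-distribution score.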

Model improvements

As previously discussed, single-task improvements, especially in terms of learning stability, will empirically transfer well to multi-task algorithmic learning. We now describe, in a gradual manner, all the changes made to the model, which have led to an absolute improvement of over 20% on average across all 30 tasks in CLRS.

Dataset and training

Removing teacher forcing.

At evaluation time, the model has no access to the step-by-step hints in the dataset, and has to rely on its own hint predictions.
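The difference between training with and without teacher forcing comes down to which hint is fed into the next step. In this hypothetical sketch, `step_fn` stands in for a full encode-process-decode step and hints are plain numbers.

```python
def rollout(step_fn, hint0, true_hints, teacher_forcing):
    """Unroll a model for len(true_hints) steps and collect its hint predictions.

    With teacher forcing, the ground-truth hint is fed to the next step; without it,
    the model consumes its own prediction, exactly as it must at evaluation time.
    """
    preds, hint = [], hint0
    for t in range(len(true_hints)):
        pred = step_fn(hint)
        preds.append(pred)
        hint = true_hints[t] if teacher_forcing else pred
    return preds


def step(h):
    return h + 1  # toy "model"


print(rollout(step, 0, [10, 20, 30], teacher_forcing=True))   # [1, 11, 21]
print(rollout(step, 0, [10, 20, 30], teacher_forcing=False))  # [1, 2, 3]
```

Without teacher forcing, training sees the same kind of self-generated inputs as evaluation, so the two regimes match.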

Augmenting the training data.

To prevent our model from over-fitting to the statistics of the fixed CLRS training dataset (Veličković et al., 2022), we augmented the training data in three key ways, without breaking the intended size distribution shift (training on \(n \le 16\), testing on \(n = 64\)). For instance, we generate new training examples on the fly, rather than using a fixed dataset, which is easier to overfit to.
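Generating examples on the fly, instead of drawing from a fixed pool, can be illustrated with a toy Python generator for sorting inputs. This is a sketch under assumptions: it is not the CLRS sampler API, the minimum size of 4 is arbitrary, and only the upper bound \(n \le 16\) comes from the text.

```python
import random


def online_sorting_sampler(max_size=16, min_size=4, seed=None):
    """Yield freshly generated (keys, sorted_order) pairs forever.

    A toy stand-in for an on-line sampler: every example is new and sizes vary
    up to max_size, so there is no fixed pool of examples to memorize.
    """
    rng = random.Random(seed)
    while True:
        n = rng.randint(min_size, max_size)  # mixed sizes, n <= 16
        keys = [rng.random() for _ in range(n)]
        yield keys, sorted(range(n), key=keys.__getitem__)


sampler = online_sorting_sampler(seed=0)
keys, order = next(sampler)  # a new example on every call
```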

Soft hint propagation.

When predicted hints are fed back as inputs during training, gradients may or may not be allowed to flow through them.

Static hint elimination.

Eleven algorithms in CLRS¹ specify a fixed ordering of the nodes, common to every sample, via a node pointer hint that never changes along the trajectories.
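A hint that never changes can be detected by comparing every time step with the first one. The sketch below splits hints into static and dynamic ones; the hint names are hypothetical, and treating the static value as an input rather than a prediction target is an assumption about what elimination involves.

```python
def split_static_hints(hints):
    """Separate hints that never change over time from those that do.

    hints: name -> trajectory, a list with one value per time step.
    Returns (static, dynamic); a static hint keeps a single time-independent value.
    """
    static, dynamic = {}, {}
    for name, trajectory in hints.items():
        if all(step == trajectory[0] for step in trajectory):
            static[name] = trajectory[0]
        else:
            dynamic[name] = trajectory
    return static, dynamic


hints = {
    "order": [[0, 0, 1], [0, 0, 1], [0, 0, 1]],   # fixed node ordering: never changes
    "pred_h": [[0, 1, 2], [0, 0, 2], [0, 0, 1]],  # evolves during the algorithm
}
static, dynamic = split_static_hints(hints)
print(sorted(static), sorted(dynamic))  # ['order'] ['pred_h']
```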

Improving training stability with encoder initialization and gradient clipping.

The scalar hints have, in principle, unbounded values and are optimized using mean squared error, hence their gradients can quickly grow with increasing prediction error.
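The sketch below illustrates both halves of the argument: the gradient of a squared error grows linearly with the error, and clipping by global norm bounds the update regardless. The threshold `max_norm=1.0` is a hypothetical value, the helper names are illustrative, and the encoder-initialization part of the fix is not shown.

```python
import math


def mse_grad(pred, target):
    """d/dpred of (pred - target)^2: grows linearly with the prediction error."""
    return 2.0 * (pred - target)


def clip_by_global_norm(grads, max_norm):
    """Rescale gradient vectors so that their joint L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    if norm <= max_norm:
        return grads
    scale = max_norm / norm
    return [[g * scale for g in vec] for vec in grads]


grads = [[mse_grad(103.0, 100.0)], [mse_grad(4.0, 0.0)]]  # [[6.0], [8.0]], norm 10
print(clip_by_global_norm(grads, max_norm=1.0))             # approx [[0.6], [0.8]]
```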

Encoders and decoders

Randomized position scalar.

Across all algorithms in the dataset, there exists a position scalar input that uniquely indexes the nodes, with values linearly spaced between 0 and 1 along the node index. To avoid overfitting to these linearly spaced values during training, we replaced them with random values, uniformly sampled in [0, 1] and sorted to match the initial order implied by the linearly spaced values. The benefit of this change is notable in algorithms where it would be easy to overfit to these positions, such as string matching. Namely, the model could learn to base all of its computations on the assumption that it will always be finding an \(m\)-character pattern inside an \(n\)-character string, even though at test time \(m\) and \(n\) will increase fourfold.
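The randomized position scalar amounts to drawing uniform values and sorting them, so node \(i\) still receives the \(i\)-th smallest position. A minimal sketch, with hypothetical helper names:

```python
import random


def linear_positions(n):
    """Default positions: evenly spaced in [0, 1] along the node index (assumes n >= 2)."""
    return [i / (n - 1) for i in range(n)]


def random_positions(n, rng):
    """Uniform samples in [0, 1], sorted so node i still gets the i-th smallest value."""
    return sorted(rng.uniform(0.0, 1.0) for _ in range(n))


rng = random.Random(0)
print(linear_positions(5))       # [0.0, 0.25, 0.5, 0.75, 1.0]
print(random_positions(5, rng))  # five sorted values in [0, 1], different for every sample
```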

Permutation decoders and the Sinkhorn operator.

Sorting algorithms (Insertion Sort, Bubble Sort, Heapsort, and Quicksort) always output a permutation of the input nodes. In the CLRS benchmark, this permutation is encoded as a pointer where each node points to its predecessor in the sorted order (the first node points to itself); this is represented as an \(n \times n\) matrix \(\V P\) where each row is a one-hot vector, such that element \((i, j)\) is 1 if node \(i\) points to node \(j\). As with all types of pointers, such permutation pointers can be predicted using a row-wise softmax on unconstrained decoder outputs (logits), trained with cross entropy (as in Veličković et al., 2022). However, this does not explicitly take advantage of the fact that the pointers encode a permutation, which the model then has to learn. Our early experiments showed that the model often failed to predict valid permutations OOD.

Accordingly, we enforce a permutation inductive bias in the output decoder of sorting algorithms, as follows. First, we modify the output representation by rewiring the first node to point to the last one, turning \(\V P\) into a permutation matrix, i.e., a matrix whose rows and columns are one-hot vectors. We also augment the representation with a one-hot vector of size \(n\) that specifies the first node, so we do not lose this information; this vector is treated like a regular mask_one feature. Second, we predict the permutation matrix \(\V P\) from unconstrained decoder outputs \(\V Y\) by replacing the usual row-wise softmax with the Sinkhorn operator \(\mathcal{S}\) (Sinkhorn, 1964; Sinkhorn & Knopp, 1967; Santa Cruz et al., 2017; Mena et al., 2017; Mena et al., 2018). \(\mathcal{S}\) projects an arbitrary square matrix \(\V Y\) onto a doubly stochastic matrix \(\mathcal{S}(\V Y)\) (a non-negative matrix whose rows and columns sum to 1) by exponentiating and repeatedly normalizing rows and columns so they sum to 1. Specifically, \(\mathcal{S}\) is defined by:

\[\mathcal{S}^0 (\V Y) = \exp (\V Y) \qquad \mathcal{S}^l (\V Y) = \mathcal{T}_c (\mathcal{T}_r(\mathcal{S}^{l-1}(\V Y))) \qquad \mathcal{S}(\V Y) = \lim_{l\to\infty}\mathcal{S}^l(\V Y),\]

where \(\exp\) acts element-wise, and \(\mathcal{T}_r\) and \(\mathcal{T}_c\) denote row and column normalisation, respectively. Although the Sinkhorn operator produces a doubly stochastic matrix rather than a permutation matrix, we can obtain a permutation matrix by introducing a temperature parameter \(\tau > 0\) and taking \(\V P = \lim_{\tau\to0^+}\mathcal{S}(\V Y / \tau)\); as long as there are no ties in the elements of \(\V Y\), \(\V P\) is guaranteed to be a permutation matrix (Mena et al., 2017, Theorem 1).
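The representation change described above, rewiring the first node to point to the last and storing the first node separately as a mask_one feature, can be sketched in plain Python together with its inverse. The helper names are hypothetical, and the sketch assumes at least two nodes.

```python
def pointers_to_permutation(pred):
    """Turn predecessor pointers into a permutation matrix plus a first-node mask.

    pred[i] is node i's predecessor in sorted order; the first node points to itself.
    Assumes n >= 2.
    """
    n = len(pred)
    first = next(i for i in range(n) if pred[i] == i)
    targets = set(pred)
    last = next(i for i in range(n) if i not in targets)  # nobody's predecessor
    rewired = list(pred)
    rewired[first] = last  # close the cycle: the first node now points to the last
    perm = [[1 if rewired[i] == j else 0 for j in range(n)] for i in range(n)]
    first_mask = [1 if i == first else 0 for i in range(n)]  # mask_one feature
    return perm, first_mask


def permutation_to_pointers(perm, first_mask):
    """Inverse mapping back to the original pointer representation."""
    pred = [row.index(1) for row in perm]
    first = first_mask.index(1)
    pred[first] = first
    return pred


# keys = [3.0, 1.0, 2.0]: the sorted order is node 1, node 2, node 0.
pred = [2, 1, 1]
perm, first_mask = pointers_to_permutation(pred)
print(perm)        # [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
print(first_mask)  # [0, 1, 0]
```

For \(n \ge 2\) the rewired matrix also has a zero diagonal, since no node points to itself, which is the property the decoder can enforce by masking the diagonal.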

In practice, we compute the Sinkhorn operator using a fixed number of iterations \(l_{\max}\). We use a smaller number of iterations, \(l_{\max} = 10\), for training, to limit vanishing and exploding gradients, and \(l_{\max}= 60\) for evaluation. A fixed temperature \(\tau=0.1\) was experimentally found to give a good balance between speed of convergence and tie-breaking. We also encode the fact that no node points to itself, that is, that all diagonal elements of \(\V P\) should be 0, by setting the diagonal elements of \(\V Y\) to \(-\infty\). To avoid ties, we follow Mena et al. (2018) and inject Gumbel noise into the elements of \(\V Y\) prior to applying the Sinkhorn operator, during training only. Finally, we transform the predicted matrix \(\V P\), together with the mask_one feature pointing to the first element, back into the original pointer representation used by CLRS.
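The following plain-Python sketch puts these practical choices together: temperature, fixed iteration count, masked diagonal, and Gumbel noise during training. It works in log space, where subtracting a row's log-sum-exp is equivalent to exponentiating and dividing by the row sum, which avoids overflow at \(\tau = 0.1\). It is an illustrative sketch with hypothetical names, not the authors' code, and assumes \(n \ge 2\).

```python
import math
import random


def logsumexp(vals):
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))


def sinkhorn(y, tau=0.1, l_max=60, rng=None):
    """Log-space Sinkhorn sketch; y is an n x n list of decoder outputs, n >= 2.

    tau: temperature; l_max: number of (row, then column) normalizations;
    rng: if given (training only), Gumbel noise is added to y before scaling by 1/tau.
    The diagonal is set to -inf so that no node can point to itself.
    """
    n = len(y)
    log_p = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            noise = 0.0
            if rng is not None:
                u = rng.uniform(1e-12, 1.0 - 1e-12)  # avoid log(0)
                noise = -math.log(-math.log(u))      # standard Gumbel sample
            log_p[i][j] = -math.inf if i == j else (y[i][j] + noise) / tau
    for _ in range(l_max):
        row_lse = [logsumexp(row) for row in log_p]
        log_p = [[v - row_lse[i] for v in log_p[i]] for i in range(n)]  # T_r
        col_lse = [logsumexp([log_p[i][j] for i in range(n)]) for j in range(n)]
        log_p = [[log_p[i][j] - col_lse[j] for j in range(n)] for i in range(n)]  # T_c
    return [[math.exp(v) for v in row] for row in log_p]


def to_hard_permutation(p):
    """Row-wise argmax: node i points to the column holding most of its mass."""
    return [max(range(len(row)), key=row.__getitem__) for row in p]


y = [[0.0, 5.0, 0.0],
     [0.0, 0.0, 5.0],
     [5.0, 0.0, 0.0]]
p_eval = sinkhorn(y, l_max=60)                         # evaluation: no noise
p_train = sinkhorn(y, l_max=10, rng=random.Random(0))  # training: Gumbel noise
print(to_hard_permutation(p_eval))  # [1, 2, 0]
```

Reading off a row-wise argmax from the converged matrix gives each node's pointer; combining it with the first-node mask recovers the CLRS pointer format.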

Processor networks

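Gating mechanisms.

We augment the MPNN with an update gate, so that each node can keep its previous representation when no update is needed. The gate is a vector computed for every node:

\[\V g_i^{(t)} = f_g \left ( \V z_i^{(t)}, \V m_i^{(t)} \right )\]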

where \(f_g : \mathbb{R}^{2h} \times \mathbb{R}^h \to \mathbb{R}^h\) is the gating function, for which we use a two-layer MLP with ReLU activation for the hidden layer and logistic sigmoid activation for the output. Importantly, the final-layer bias of \(f_g\) is initialized to \(-3\), which biases the network towards not updating its representations unless necessary. The processed gated embeddings, \(\widehat {\V h}_i^{(t)}\), are computed as follows:

\[\widehat {\V h}_i^{(t)} = \V g_i^{(t)} \odot \V h_i^{(t)} + (1 - \V g_i^{(t)}) \odot \V h_i^{(t-1)}\]

and are used instead of \(\V h_i^{(t)}\) in the subsequent steps, replacing \(\V z_i^{(t)}\) in the MPNN equations above by \(\V z_i^{(t)} = \V x_i^{(t)} \| \widehat{\V h}_i^{(t-1)}\).
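The gate can be sketched as a two-layer MLP whose output bias starts at \(-3\), followed by the convex combination above. Everything else here (random weights, tiny sizes, and the helper names `init_gate`, `f_g`, `gated_update`) is an illustrative assumption.

```python
import math
import random


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def init_gate(in_dim, h, rng):
    """Two-layer gate MLP parameters; the output bias starts at -3 (gate mostly closed)."""
    w1 = [[rng.gauss(0.0, 0.1) for _ in range(in_dim)] for _ in range(h)]
    w2 = [[rng.gauss(0.0, 0.1) for _ in range(h)] for _ in range(h)]
    return (w1, [0.0] * h), (w2, [-3.0] * h)


def f_g(z_i, m_i, params):
    """Gate vector g_i = sigmoid(W2 relu(W1 [z_i || m_i] + b1) + b2)."""
    (w1, b1), (w2, b2) = params
    x = z_i + m_i  # list concatenation stands in for z_i || m_i
    hidden = [max(0.0, sum(w * v for w, v in zip(row, x)) + b) for row, b in zip(w1, b1)]
    return [sigmoid(sum(w * v for w, v in zip(row, hidden)) + b) for row, b in zip(w2, b2)]


def gated_update(h_new, h_prev, g):
    """Element-wise g * h_new + (1 - g) * h_prev."""
    return [gi * a + (1.0 - gi) * b for gi, a, b in zip(g, h_new, h_prev)]


rng = random.Random(0)
h = 4
params = init_gate(3 * h, h, rng)  # input: z_i (2h) concatenated with m_i (h)
z_i, m_i = [0.1] * (2 * h), [0.2] * h
g = f_g(z_i, m_i, params)          # every entry is close to sigmoid(-3)
h_hat = gated_update([1.0] * h, [0.0] * h, g)
```

With the bias at \(-3\), \(\sigma(-3) \approx 0.047\), so at initialization each node keeps roughly 95% of its previous embedding, and the network has to learn to open the gate where updates are needed.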

Triplet reasoning.

Several algorithms within CLRS-30 explicitly require edge-based reasoning, where edges store values and update them based on other edges’ values. An example of this is the Floyd-Warshall algorithm (Floyd, 1962), which computes all-pairs shortest paths in a weighted graph. The update rule for \(d_{ij}\), the algorithm’s estimate of the best distance from node \(i\) to node \(j\), is \(d_{ij} = \min_k (d_{ik} + d_{kj})\), which roughly says “the best way to get from \(i\) to \(j\) is to find the optimal mid-point \(k\), travel from \(i\) to \(k\), then from \(k\) to \(j\)”. Similar rules are pervasive across many CLRS-30 algorithms, especially in dynamic programming. Even though there are no node representations in the above update, all our processors are centered on passing messages between node representations \(\V h_i\).
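To see why this rule is edge-centric, here is the update \(d_{ij} = \min_k (d_{ik} + d_{kj})\) applied to every pair at once in plain Python. Each new \(d_{ij}\) reads only other edge values and reduces over the mid-point \(k\); no per-node quantity appears. The helper name is hypothetical.

```python
INF = float("inf")


def relax_all_pairs(d):
    """One synchronous application of d_ij <- min_k (d_ik + d_kj) to every pair.

    k ranges over all nodes, including i and j; since d_ii = 0, the old d_ij is
    always a candidate, so distances never increase.
    """
    n = len(d)
    return [[min(d[i][k] + d[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


# Edges 0->1 (weight 1), 1->2 (weight 2), 0->2 (weight 10).
d = [[0, 1, 10],
     [INF, 0, 2],
     [INF, INF, 0]]
d1 = relax_all_pairs(d)
print(d1[0][2])  # 3: going through mid-point k = 1 beats the direct edge
```

Classic Floyd-Warshall instead fixes \(k\) in an outer loop and updates distances in place; the synchronous version shown here converges after a logarithmic number of rounds, since each round can double the number of edges on a path.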

To bridge this gap between edge-centric updates and node-centric processors, we augment our processor to perform message passing towards edges. Referring again to the update for \(d_{ij}\), we note that the edge representations are updated by choosing an intermediate node, then aggregating over all possible choices. Accordingly, and as previously observed by Dudzik and Veličković (2022), we introduce triplet reasoning: first computing representations over triplets of nodes, then reducing over one node to obtain edge latents:

\[\V t_{ijk} = \psi_t (\V h_i, \V h_j, \V h_k, \V e_{ij}, \V e_{ik}, \V e_{kj}, \V g) \qquad \V h_{ij} = \phi_t(\max_k \V t_{ijk})\]

Here, \(\psi_t\) is a triplet message function, mapping all relevant representations to a single vector for each triplet of nodes, and \(\phi_t\) is an edge readout function, which transforms the aggregated triplets for each edge for later use. Following prior findings on the CLRS benchmark (Veličković et al., 2022), we use the max aggregation to obtain edge representations. The computed \(\V h_{ij}\) vectors can then be used in any edge-based reasoning task, and empirically they are indeed significantly beneficial, even in tasks where we did not initially anticipate such benefits. One example is Kruskal’s minimum spanning tree algorithm (Kruskal, 1956), where we presume that access to triplet reasoning allowed the model to more easily sort the edges by weight, as it selects how to augment the spanning forest at each step.

To keep the footprint of triplet embeddings as lightweight as possible, we compute only 8-dimensional features in \(\psi_t\). \(\phi_t\) then upscales the aggregated edge features back to 128 dimensions, to make them compatible with the rest of the architecture. Our initial experiments showed that the output dimensionality of \(\psi_t\) did not significantly affect downstream performance. Note that computing triplet representations has been a useful approach in general GNN design (Morris et al., 2018); however, it has predominantly been studied in the context of GNNs over constant input features. Our study is among the first to verify its utility over reasoning tasks with well-specified initial features.
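Triplet reasoning can be sketched directly from the two equations: build a small vector for every triplet \((i, j, k)\), max-reduce over \(k\), and upscale each edge back to the latent size. As a simplifying assumption, \(\psi_t\) and \(\phi_t\) are single linear layers with random weights, and the latent size is 6 rather than 128; the helper names are hypothetical.

```python
import random


def init_linear(in_dim, out_dim, rng):
    w = [[rng.gauss(0.0, 0.2) for _ in range(in_dim)] for _ in range(out_dim)]
    return w, [0.0] * out_dim


def linear(params, v):
    w, b = params
    return [sum(wi * vi for wi, vi in zip(row, v)) + bi for row, bi in zip(w, b)]


def triplet_edge_latents(h, e, g, psi, phi):
    """h_ij = phi(max_k psi(h_i, h_j, h_k, e_ij, e_ik, e_kj, g)) for every pair (i, j).

    h: n x H node latents, e: n x n x H edge latents, g: H graph latent.
    """
    n = len(h)
    out = [[None] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            t_ijk = [linear(psi, h[i] + h[j] + h[k] + e[i][j] + e[i][k] + e[k][j] + g)
                     for k in range(n)]                  # one small vector per mid-point k
            pooled = [max(col) for col in zip(*t_ijk)]  # max-reduce over k
            out[i][j] = linear(phi, pooled)             # upscale back to H dims
    return out


rng = random.Random(0)
H, T, n = 6, 8, 4  # T = 8 triplet dims as in the post; H stands in for 128
psi = init_linear(7 * H, T, rng)
phi = init_linear(T, H, rng)
h = [[rng.gauss(0, 1) for _ in range(H)] for _ in range(n)]
e = [[[rng.gauss(0, 1) for _ in range(H)] for _ in range(n)] for _ in range(n)]
g = [rng.gauss(0, 1) for _ in range(H)]
h_edges = triplet_edge_latents(h, e, g, psi, phi)  # n x n x H
```

The triple loop makes the \(O(n^3)\) cost explicit, which is why the triplet vectors are kept to 8 dimensions.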

Extra Reading Materials

NeuralExecuter++ (Xhonneux et al., 2021)

References

  1. Ibarz, B., Kurin, V., Papamakarios, G., Nikiforou, K., Bennani, M., Csordás, R., Dudzik, A., Bošnjak, M., Vitvitskyi, A., Rubanova, Y., Deac, A., Bevilacqua, B., Ganin, Y., Blundell, C., & Veličković, P. (2022). A Generalist Neural Algorithmic Learner. arXiv. doi: 10.48550/ARXIV.2209.11142 https://arxiv.org/abs/2209.11142
  2. Veličković, P., Badia, A. P., Budden, D., Pascanu, R., Banino, A., Dashevskiy, M., Hadsell, R., & Blundell, C. (2022). The CLRS Algorithmic Reasoning Benchmark. arXiv. doi: 10.48550/ARXIV.2205.15659 https://arxiv.org/abs/2205.15659
  3. Hamrick, J. B., Allen, K. R., Bapst, V., Zhu, T., McKee, K. R., Tenenbaum, J. B., & Battaglia, P. W. (2018). Relational inductive bias for physical construction in humans and machines. arXiv. doi: 10.48550/ARXIV.1806.01203 https://arxiv.org/abs/1806.01203
  4. Bansal, A., Schwarzschild, A., Borgnia, E., Emam, Z., Huang, F., Goldblum, M., & Goldstein, T. (2022). End-to-end Algorithm Synthesis with Recurrent Networks: Logical Extrapolation Without Overthinking. arXiv. doi: 10.48550/ARXIV.2202.05826 https://arxiv.org/abs/2202.05826
  5. Gilmer, J., Schoenholz, S. S., Riley, P. F., Vinyals, O., & Dahl, G. E. (2017). Neural Message Passing for Quantum Chemistry. In D. Precup & Y. W. Teh (Eds.), Proceedings of the 34th International Conference on Machine Learning (Vol. 70, pp. 1263–1272). PMLR. https://proceedings.mlr.press/v70/gilmer17a.html
  6. Veličković, P., Ying, R., Padovano, M., Hadsell, R., & Blundell, C. (2019). Neural Execution of Graph Algorithms. arXiv. doi: 10.48550/ARXIV.1910.10593 https://arxiv.org/abs/1910.10593
  7. Ba, J. L., Kiros, J. R., & Hinton, G. E. (2016). Layer Normalization. arXiv. doi: 10.48550/ARXIV.1607.06450 https://arxiv.org/abs/1607.06450
  8. Sinkhorn, R. (1964). A Relationship Between Arbitrary Positive Matrices and Doubly Stochastic Matrices. The Annals of Mathematical Statistics, 35(2), 876–879. doi: 10.1214/aoms/1177703591 https://doi.org/10.1214/aoms/1177703591
  9. Sinkhorn, R., & Knopp, P. (1967). Concerning nonnegative matrices and doubly stochastic matrices. Pacific Journal of Mathematics, 21, 343–348.
  10. Santa Cruz, R., Fernando, B., Cherian, A., & Gould, S. (2017). DeepPermNet: Visual Permutation Learning. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR).
  11. Mena, G., Belanger, D., Muñoz, G., & Snoek, J. (2017). Sinkhorn Networks: Using Optimal Transport Techniques to Learn Permutations.
  12. Mena, G., Belanger, D., Linderman, S., & Snoek, J. (2018). Learning Latent Permutations with Gumbel-Sinkhorn Networks. doi: 10.48550/ARXIV.1802.08665 https://arxiv.org/abs/1802.08665
  13. Floyd, R. W. (1962). Algorithm 97: Shortest path. Communications of the ACM, 5, 345.
  14. Dudzik, A., & Veličković, P. (2022). Graph Neural Networks are Dynamic Programmers. arXiv. doi: 10.48550/ARXIV.2203.15544 https://arxiv.org/abs/2203.15544
  15. Kruskal, J. B. (1956). On the shortest spanning subtree of a graph and the traveling salesman problem. Proceedings of the American Mathematical Society, 7(1), 48–50. doi: 10.1090/s0002-9939-1956-0078686-7 https://app.dimensions.ai/details/publication/pub.1018477579
  16. Morris, C., Ritzert, M., Fey, M., Hamilton, W. L., Lenssen, J. E., Rattan, G., & Grohe, M. (2018). Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks. arXiv. doi: 10.48550/ARXIV.1810.02244 https://arxiv.org/abs/1810.02244
  17. Xhonneux, L.-P. A. C., Deac, A., Veličković, P., & Tang, J. (2021). How to transfer algorithmic reasoning knowledge to learn new algorithms? In A. Beygelzimer, Y. Dauphin, P. Liang, & J. W. Vaughan (Eds.), Advances in Neural Information Processing Systems. https://openreview.net/forum?id=q2JWz371le
  18. Bentley, J. (1986). Programming Pearls. Association for Computing Machinery.
  19. Aho, A. V., Hopcroft, J. E., & Ullman, J. D. (1974). The design and analysis of computer algorithms. Addison-Wesley.
  20. Gavril, F. (1972). Algorithms for Minimum Coloring, Maximum Clique, Minimum Covering by Cliques, and Maximum Independent Set of a Chordal Graph. SIAM J. Comput., 1, 180–187.
  21. Lawler, E. L. (1985). The traveling salesman problem: a guided tour of combinatorial optimization.
  22. Knuth, D. E., Morris, J. H., Jr., & Pratt, V. R. (1977). Fast Pattern Matching in Strings. SIAM Journal on Computing, 6(2), 323–350. doi: 10.1137/0206024 https://doi.org/10.1137/0206024
  23. Jarvis, R. A. (1973). On the Identification of the Convex Hull of a Finite Set of Points in the Plane. Inf. Process. Lett., 2(1), 18–21. http://dblp.uni-trier.de/db/journals/ipl/ipl2.html#Jarvis73

Footnotes

  1. Binary Search, Minimum, Max Subarray (Bentley, 1986) , Matrix Chain Order, LCS Length, Optimal BST (Aho et al., 1974) , Activity Selector (Gavril, 1972) , Task Scheduling (Lawler, 1985) , Naïve String Matcher, Knuth-Morris-Pratt (Knuth et al., 1977) and Jarvis’ March (Jarvis, 1973) 

This post is licensed under CC BY 4.0 by the author.
