But I think the paper fails to answer the most important question. It alleges that this isn't a statistical model:
"it is not a statistical model that predicts the most likely next state based on all the examples it has been trained on.
We observe that it learns to use its attention mechanism to compute 3x3 convolutions — 3x3 convolutions are a common way to implement the Game of Life, since it can be used to count the neighbours of a cell, which is used to decide whether the cell lives or dies."
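(For concreteness, the convolution-based implementation of Life the quote refers to can be sketched like this; the wrap-around padding is my assumption, done with np.roll so no conv library is needed:)

```python
import numpy as np

def life_step_conv(grid):
    """One Game of Life step via a 3x3 neighbour-count convolution.

    Summing the 8 shifted copies of the grid is equivalent to convolving
    with a 3x3 kernel of ones with a zero centre (toroidal wrap-around).
    """
    neighbours = sum(
        np.roll(np.roll(grid, dy, axis=0), dx, axis=1)
        for dy in (-1, 0, 1) for dx in (-1, 0, 1)
        if (dy, dx) != (0, 0)
    )
    # Standard rules: a live cell survives with 2 or 3 neighbours,
    # a dead cell is born with exactly 3.
    return ((neighbours == 3) | ((grid == 1) & (neighbours == 2))).astype(grid.dtype)
```

E.g. a blinker (three live cells in a row) flips between horizontal and vertical with period 2 under this step.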
But it is never actually shown that this is the case. Later on this isn't even alleged; the metric they use is that the net gives the correct answers often enough, as a test for convergence, not that the net has converged to weights which implement the correct algorithm.
But there is no guarantee that it has actually learned the game. There are still learned parameters, and the paper doesn't investigate whether these parameters have actually converged to something where the net is simply a computation of the algorithm. The most interesting question is left unanswered.
The diagonal-looking attention matrix shown in the post is mathematically equivalent to a 3x3 convolution. The model learns to do that via its attention mechanism, and it's not obvious that it would be able to do that via attention.
(This can be shown by comparing that attention matrix to a manually computed "neighbour attention" matrix, which is known to be equivalent to a 3x3 conv.)
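(A sketch of what such a manually computed neighbour attention matrix looks like, and why multiplying it by the flattened grid is the same computation as the 3x3 neighbour-count conv; the wrap-around indexing is my assumption:)

```python
import numpy as np

def neighbour_attention_matrix(h, w):
    """Hand-built (N x N) attention matrix over the flattened h*w grid.

    Entry (i, j) is 1 exactly when cell j is one of cell i's 8 neighbours
    (toroidal wrap-around). A @ grid.reshape(-1) then gives each cell's
    live-neighbour count, i.e. the same result as a 3x3 conv with a
    ones kernel and zero centre.
    """
    A = np.zeros((h * w, h * w))
    for y in range(h):
        for x in range(w):
            for dy in (-1, 0, 1):
                for dx in (-1, 0, 1):
                    if (dy, dx) != (0, 0):
                        A[y * w + x, ((y + dy) % h) * w + ((x + dx) % w)] = 1.0
    return A
```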
It would be more convincing if they did an exhaustive enumeration and verified that for every possible 3x3 Life neighbourhood (there are only 2^9 = 512 of them) the learned NN was correct. How do I know, looking at a speckled screenshot, that it is exactly correct and there isn't a little floating point error somewhere which results in one edge case being slightly off? If the only testing is '100 Life games for 100 steps', that isn't water-tight. (Whereas if you do exhaustive enumeration, well, it has to be correct, because the NN is deterministic and fixed and there's no way for it to go wrong then.)
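(A sketch of such an enumeration harness; `predict_centre` is a hypothetical stand-in for the trained model's prediction on a 3x3 patch, and this check assumes the model's centre-cell output depends only on the patch, which, as discussed elsewhere in this thread, is not guaranteed for an attention model:)

```python
from itertools import product

def life_rule(centre, n_live):
    """Exact Game of Life rule for a single cell."""
    return 1 if (n_live == 3 or (centre == 1 and n_live == 2)) else 0

def check_all_neighbourhoods(predict_centre):
    """Run a centre-cell predictor over all 2**9 = 512 possible 3x3
    neighbourhoods (as row-major 9-tuples of 0/1) and return the list
    of patches on which it disagrees with the exact rule."""
    failures = []
    for patch in product((0, 1), repeat=9):
        centre = patch[4]
        n_live = sum(patch) - centre
        if predict_centre(patch) != life_rule(centre, n_live):
            failures.append(patch)
    return failures
```

An empty return value from `check_all_neighbourhoods` would be the water-tight certificate: every reachable local decision has been checked.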
Edit: I've increased the validation to 10,000 Life grids for 100 steps (taking 16 minutes to check), which is hopefully somewhat more convincing. That's 1,000,000 Life steps computed without errors in total, plus 32,000 steps computed without error during training.
When the attention matrix is manually computed (to be equivalent to a 3x3 conv), the model can be trained to be 100% perfect, verified by checking all 3x3 grid states. (And this manually computed attention matrix means that once the tokens reach the classifier layer, each token contains only the information of the relevant 3x3 grid, and the whole thing is deterministic, as you say.)
However, when the model is computing the attention matrix itself, just checking that all 3x3 sub-grid states crop up is not enough, because the position of a sub-grid can affect the attention matrix, and the state of other cells can affect it too. So, as shown in the post, it does approximate a 3x3 conv, but if it doesn't get the approximation quite right, there could be errors. I would still say that it's computing the Game of Life algorithm in an interpretable way; it's just that maybe it has struggled to create a perfect 3x3 convolution via attention in that particular case. (Exhaustively checking this would require checking all 2^(16x16) grid states.)
I think it would also have been very interesting to manually construct a NN which represents the rules exactly. Maybe there is some nice mathematical way to describe them, or some constraints that need to be fulfilled.
Then afterwards you could check the trained neural network against the exact algorithm.
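(There is in fact a compact identity that makes an exact hand-built net easy to write down: a cell is alive next step iff 5 <= 2*(live neighbours) + (current state) <= 7. A sketch using that identity, my construction rather than anything from the paper:)

```python
import numpy as np

step = lambda x: (x > 0).astype(float)  # hard threshold unit

def exact_life_net(grid):
    """A hand-constructed 'network' computing Life exactly.

    Layer 1: a fixed 3x3 convolution (weight 2 on each of the 8
    neighbours, 1 on the centre), giving s = 2*n + centre per cell.
    Layer 2: two threshold units and their difference, implementing
    the band 5 <= s <= 7.
    """
    k = sum(2.0 * np.roll(np.roll(grid, dy, 0), dx, 1)
            for dy in (-1, 0, 1) for dx in (-1, 0, 1) if (dy, dx) != (0, 0))
    s = k + grid
    return step(s - 4.5) - step(s - 7.5)
```

Checking the band: birth is n=3, c=0 giving s=6; survival is n=3, c=1 (s=7) or n=2, c=1 (s=5); every dead outcome falls outside [5, 7].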
Yes, I also quoted that part from the article. It does not address the fact that the attention matrix does not represent all the learned parameters. And even supposing that the form of the attention matrix guarantees the correct functioning of the algorithm, why was that not used as the metric to decide convergence?
"We detected that the model had converged by looking for 1024 training batches with perfect predictions, and that it could perfectly run 100 Life games for 100 steps." This would be superfluous (and even a pretty bizarre methodology) if the shape of the attention matrix were proof that the network performed the actual Game of Life algorithm.
Just to be clear, I am not saying that the NN isn't converging to performing some computation that would also be seen in other algorithms. I am saying that the paper does not investigate whether the resulting NN actually performs the Game of Life algorithm. The convolution part is certainly evidence, but I think it would have been worthwhile to look at the actual resulting net and figure out whether the trained weights together form the algorithm. That is also the only way to settle the initial claim, that this isn't just a statistical model but an actual algorithm.
> the paper does not investigate whether the resulting NN actually performs the game of life algorithm
How could it not be computing the Game of Life algorithm? Given that it gets 100% accuracy over multiple steps on a bunch of random game boards it's never seen before.
And then based on the structure of the net, and by examining the attention layers and finding that it's doing 3x3 average pooling, we can see that the attention layer produces a set of tokens where each token contains the information of the number of neighbours it had, and its previous state. This then goes through a classifier layer, which decides its next state given that information.
Further evidence for that: it was possible to use linear probes to confirm that the tokens that had been through the attention layer contained the information about the number of neighbours and the previous state.
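(For anyone unfamiliar with the technique, here is a toy illustration of linear probing: if token representations linearly encode (neighbour count, previous state), a least-squares probe can read those quantities back out. The "representations" here are synthetic, a random linear mix of the true quantities plus noise, standing in for the trained model's activations, which I don't have:)

```python
import numpy as np

def probe_demo(seed=0):
    """Fit a linear probe that recovers (neighbour count, prev state)
    from synthetic token representations, and return its accuracy."""
    rng = np.random.default_rng(seed)
    n_tokens, d = 2000, 16
    counts = rng.integers(0, 9, n_tokens)   # neighbour count, 0..8
    state = rng.integers(0, 2, n_tokens)    # previous cell state
    targets = np.stack([counts, state], axis=1).astype(float)
    W = rng.normal(size=(2, d))             # stand-in encoding
    reps = targets @ W + 0.01 * rng.normal(size=(n_tokens, d))
    # The probe itself: ordinary least squares from reps to targets.
    probe, *_ = np.linalg.lstsq(reps, targets, rcond=None)
    pred = np.round(reps @ probe)
    return (pred == targets).mean()
```

If the probe reads the quantities out with near-perfect accuracy, the representation linearly encodes them, which is the kind of evidence the post's probes provide.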
From all of this, it's clear that the model is running the Game of Life properly.
Do you not understand the difference between empirical evidence and mathematical proof? Surely every person talking about NN research should be aware of that distinction.
> How could it not be computing the game of life algorithm? Given that it gets 100% accuracy over multiple steps on a bunch random game boards it's never seen before.
Reminds me of this great story about a programmer-turned-businessman who tried to learn a game from examples and ended up with an almost-correct brute force solution.