Compositional Generalization and Parsimony

July 19, 2022

We recently presented our work, Compositional Generalization in Grounded Language Learning via Induced Model Sparsity, at the Student Research Workshop at NAACL 2022. NAACL was a real blast; I have another post about what I learned from attending my first conference, but in this post I want to talk a little bit more about our work and the field of compositional generalization, which I think is quite exciting (even if it's been around for some time now).

Compositional Generalization and Out-of-Distribution Failures

Deep learning continues to make strides in many different areas. As someone who entered the software engineering space about ten years ago, it really feels like magic. Back then, nobody could have imagined writing a program to do really underspecified things like detecting what kind of bird is in a photograph, creating extremely realistic chatbots which can outperform humans on reasoning and question-answering tasks, wiping the floor with humans at games like Go and StarCraft, or generating hyper-realistic artwork of things that seem almost impossible to imagine, like an avocado-shaped chair.

But all this deep learning stuff has a lot of really silly failure cases too. Performance on tasks like vision-language navigation is still nowhere near human level, and deep learning fails in pretty silly ways when you start using it in contexts that it didn't see in training. Usually the answer to this is more data, but for some domains, like robotics, collecting that data is really hard, or at least really frustrating and tedious.

One really simple and silly failure case is the failure to compositionally generalize and, relatedly, the failure to disentangle correlated factors present in the training data. Here's a really simple example. Say I had an agent in a very simple environment where the task is to go to objects of particular colors and shapes. In this case, the colors are red, blue, green, yellow, purple and grey and the objects are ball, box and key. If I give the agent a reasonable number of examples of each combination, with an instruction to go to the described object and a demonstration of how that is done, neural networks can learn to solve this task quite easily with a very high success rate. However, if I withhold some combinations, for example red ball, blue key and green box, while keeping enough other examples in the training set that the meaning of each color word and each object word is still unambiguous, quite a lot of neural architectures will fail on those withheld examples, and the ones that do work only do so for some initializations and not all of the time.
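Concretely, a held-out split like the one described above might be constructed something like this. This is a rough sketch for illustration, not the actual environment code; the instruction template and the sampling helper are made up:

```python
from itertools import product
import random

COLORS = ["red", "blue", "green", "yellow", "purple", "grey"]
OBJECTS = ["ball", "box", "key"]

# Combinations never shown during training, only at test time.
HELD_OUT = {("red", "ball"), ("blue", "key"), ("green", "box")}

all_combinations = set(product(COLORS, OBJECTS))
train_combinations = sorted(all_combinations - HELD_OUT)

# Every color word and every object word still appears in training,
# just never in the held-out pairings, so no word is ambiguous on its own.
assert {c for c, _ in train_combinations} == set(COLORS)
assert {o for _, o in train_combinations} == set(OBJECTS)

def sample_instruction(combinations, rng=random):
    color, obj = rng.choice(combinations)
    return f"go to the {color} {obj}"

print(sample_instruction(train_combinations))   # seen combination
print(sample_instruction(sorted(HELD_OUT)))     # held-out combination
```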

Why this matters

To regular people, these failure cases are quite perplexing because the task really seems quite simple. How come GPT-3 can write me a convincing-sounding research article, and yet the same architecture only sometimes works on such a toy task?

To make matters worse, this isn’t just a toy problem. Datasets can have internal correlations which cause models learned from them to exhibit unexpected behaviour. Take for example a question-answering model learned on correlated descriptions of food. Let’s say that Japanese food is usually umami, sweet and salty but never spicy, and that chilli sauce can be sweet or spicy depending on whether or not it is hot. If you asked this model what a seaweed salad with hot chilli sauce tasted like, it would likely tell you either umami or spicy, but not both. This problem gets even worse in language goal-conditioned settings, where you really want your robot helper to be able to learn individual properties as they relate to the words you speak, as opposed to having to teach it every single combination of tasks.

Why this happens

It’s hard to fault models for exhibiting this behaviour, since we sort of ask for it. Consider the cross-entropy loss function for next-word prediction:

$$
\mathbb{E}_{(w_{t+1}, w_t, \ldots, w_0) \sim \mathcal{D}}\left[\frac{1}{N}\left(-\log p_M(w_{t+1} \mid w_t, \ldots, w_0) + \sum_{w' \in W,\, w' \ne w_{t+1}} \log p_M(w' \mid w_t, \ldots, w_0)\right)\right]
$$

The first term of the cross-entropy loss tries to minimize the negative log-likelihood (i.e., maximize the likelihood) with which we predict that word $w_{t+1}$ comes after the sequence $w_t, \ldots, w_0$. The second term tries to minimize the likelihood, according to our model, that anything else ($w' \in W, w' \ne w_{t+1}$) comes after that sequence.

Of course, this loss is an expected value, so sometimes you can have different words coming after the same sequence $w_t, \ldots, w_0$, and the model will minimize the expected loss by predicting a likelihood for each word proportionate to how often it appears in the data.
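In code, this objective is just the usual softmax cross-entropy. Here is a minimal PyTorch sketch with made-up logits and targets; because the softmax has to sum to one, raising the probability of the observed next word necessarily pushes mass away from every other word, which plays the role of the second term above:

```python
import torch
import torch.nn.functional as F

# Toy stand-ins: logits over a vocabulary W for the next word, given
# some encoding of the context (w_t, ..., w_0).
vocab_size, batch_size = 1000, 8
logits = torch.randn(batch_size, vocab_size)           # unnormalized log p_M(. | w_t, ..., w_0)
targets = torch.randint(0, vocab_size, (batch_size,))  # the observed next words w_{t+1}

# cross_entropy = -log softmax(logits)[target], averaged over the batch.
loss = F.cross_entropy(logits, targets)

# If the same context appears in the data with different continuations,
# minimizing this expected loss makes the model match their empirical frequencies.
print(loss.item())
```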

Going back to the example, in the cases where the food is Japanese, the words “umami”, “salty” and “sweet” might appear after a description of the food and our question “what does this food taste like?”. They sometimes appear in the first term, so their log-likelihood will be pushed close to zero in those cases.

However, the word “spicy” never appears in the first term and always appears in the second. So when the model sees Japanese food, the expected thing to do is to predict a very small likelihood for “spicy”.

The same goes for “chilli sauce”. Food that includes it might always be spicy or sweet, but never umami.

What happens if we describe to the model Japanese food with a side of hot chilli sauce? In that case we’re in a bit of a pickle (no pun intended). On one hand, the presence of Japanese food tells us that the food is probably either umami, salty or sweet and definitely not spicy. On the other hand, the presence of “chilli sauce” tells us that the food is almost certainly going to be spicy or sweet, but definitely not umami. At this point we’re out of distribution. The log-probabilities in this case might just cancel out, meaning that the model is completely unsure about what Japanese food with chilli sauce tastes like, or it might predict that the food is just sweet, or just spicy, or just umami, but it probably won’t predict that it is both umami and spicy. That combination just doesn’t exist according to the likelihood distribution the model has been optimized to produce.

What we did

We took a look at this from the perspective of a different problem: a simple vision-language navigation benchmark called BabyAI. At first we weren’t really looking at compositional generalization at all; instead we were studying sample efficiency. But we ended up falling into this research direction by taking a more critical look at BabyAI’s validation set. In particular, the validation tasks are drawn from the same distribution as the training ones, meaning that you’re likely to see the same tasks in validation (up to a permutation of the environment layout) as you see in training.1 This means that you could just be memorizing what each training task means instead of actually learning to interpret the language, which is sort of the whole point. We decided to hold out some description combinations (for example, red ball, blue key, green box) from the training tasks and test on those - and surprise surprise, this completely confuses the model and it gets a very low success rate.2

BabyAI environment

This sort of happens for the same reasons as described above, though a little more indirectly, since we’re modelling actions in the environment as opposed to descriptions (a description in this case might be a prediction of where the goal object actually is). However, if you make a pared-down version of this experiment where you just learn to predict the goal object and not a policy for getting there, you’ll start to see the same behaviour: the model maps the actual red box or red key to “red box” and “red key” perfectly fine, but fails completely when asked to identify the “red ball”. Again, this is because the model has learned that “red” means “not ball” and “ball” means “not red”, so even though an object in the environment might match the description, the learned factors actually work against you when it comes to predicting where the object is.

You can make an even simpler version of this problem by casting it as a simple form of attention between words and cells in the grid, where the “values” in the attention are given by the identity matrix.

Let’s call our sentence $Q$ with length $M$ and the linearized image $K$ with length $N = W \times H$. If we compute $QK^T \in \mathbb{R}^{M \times N}$, we have $M$ rows of correlation scores for each word against each of the $N$ cells in the grid. If we have lots of examples of grids and annotations about where the goal state is for a corresponding sentence, we could sort of hope that $\sum_Q QK^T$ tells us which of the $N$ cells are likely to be the goal state (because they matched many of the words). For the same reasons as above, this still fails to generalize.
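A minimal sketch of that scoring scheme, with made-up sizes and random embeddings standing in for the real encoders:

```python
import torch

M, d = 5, 32           # words in the sentence, embedding size
W, H = 8, 8
N = W * H              # cells in the linearized grid

Q = torch.randn(M, d)  # word embeddings for the instruction
K = torch.randn(N, d)  # embeddings for each grid cell

scores = Q @ K.T                  # (M, N): correlation of each word with each cell
goal_logits = scores.sum(dim=0)   # sum over words: how well each cell matches the sentence
goal_probs = torch.softmax(goal_logits, dim=0)

# Trained against the true goal cell with cross-entropy, this runs into the
# same problem as before: "red" ends up meaning "not ball" and "ball" "not red".
```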

Structural prior - factored attention

One possible solution is to enforce, with a structural prior, that the words are treated independently: a cell matches a sentence if both its color aspect matches the sentence and its objectness aspect matches the sentence. This assumes that we have a disentangled representation, which we do have in BabyAI, and that we know which components of a cell’s vector representation correspond to different aspects3. We can then compute the attention separately for each aspect and multiply the results together, which is equivalent to imposing a logical AND.

$$
\sum_Q QK_{\text{colors}}^T \odot QK_{\text{objects}}^T
$$

This still fails to generalize.

Even though we’d hope that, for example, seeing “red” in a sentence along with a red cell indicates a very high probability of a match, the presence of “ball” in the sentence is a negative confounder. You’ll see this if you visualize the correlation between attribute embeddings and word embeddings. A simple way to fix this is to apply L1 regularization to the correlations during the optimization process, which penalizes weaker predictors in favor of stronger ones.
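Putting the two ideas together, a rough sketch of the factored scores with an L1 penalty on the word/attribute correlations might look like this. This is an illustration of the idea rather than the exact implementation from the paper; the disentangled color and object slices of each cell, and the goal index, are made up here:

```python
import torch
import torch.nn.functional as F

M, N, d = 5, 64, 32

Q = torch.randn(M, d, requires_grad=True)          # word embeddings
K_colors = torch.randn(N, d, requires_grad=True)   # color aspect of each cell
K_objects = torch.randn(N, d, requires_grad=True)  # object aspect of each cell

color_scores = Q @ K_colors.T    # (M, N)
object_scores = Q @ K_objects.T  # (M, N)

# Elementwise product acts like a logical AND: a cell only scores highly
# if both its color and its object agree with words in the sentence.
match = (color_scores * object_scores).sum(dim=0)  # (N,)

# L1 penalty on the word/attribute correlations, pushing each word to
# correlate strongly with only a few attributes (the parsimony prior).
l1_penalty = color_scores.abs().mean() + object_scores.abs().mean()

goal = torch.tensor([3])  # index of the true goal cell, made up for the example
loss = F.cross_entropy(match.unsqueeze(0), goal) + 1e-2 * l1_penalty
loss.backward()
```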

Correlation heatmaps

These correlation heatmaps also have a nice interpretation from a language perspective. Because language is grounded in shared experience, we expect that individual units of language refer to specific and independent parts of the world, regardless of whether or not correlations exist. For example, the word “sky” means the actual sky, and not blue, even though the sky is mostly blue, grey or black (or sometimes pink-orange). Similarly, “blue” doesn’t mean “sky” - lots of things can be blue. Ensuring that each word refers to at most a few distinct concepts in the world is a kind of parsimonious prior and is more aligned with our intuitions about how language actually works. It also happens to enable significantly stronger compositional generalization, as we demonstrate in the paper.

What’s next?

Goal identification is an interesting problem, but as pointed out in both Qiu et al. and our paper, it seems that in the limit of data, the Transformer architecture is able to figure this out. An interesting follow-up research question is exactly why this happens.

However, there are still many problems that have not yet been solved by a general architecture. These exist particularly on the generative side, similar to the language generation example given above. It is possible to solve them if we start making some assumptions about the data, or giving the network a few extra hints about the data and how tokens in unfamiliar contexts should be treated. Or perhaps the answer is a newer network architecture which is more robust to these sorts of challenges.

1 This isn’t a particularly unreasonable assumption. It basically just assumes that the data is i.i.d., just like it is in most cases where machine learning is applied.

2 That said, some models seem to do OK at this task! For example, the Transformer model, once you get enough data. This is similar to a result in Qiu et al.

3 We don’t have to assume that we know which aspect is which, just that each sub-slice of the vector corresponds to a different aspect.


Written by Sam Spilsbury, an Australian PhD student living in Helsinki.