Hi, everyone. Can you hear me in the back? OK. Thank you so much for the invitation to this beautiful place. The workshop has been really great. My talk will be slightly different from the talks we've seen so far. Today I'm going to talk about Transformers, and in particular about recent work we presented at ICLR on how Transformers do reasoning tasks. You might notice some things: if you've done a CS degree, these automata figures may look familiar from a long time ago. We will see more of them in the talk.

OK, so let me begin. We've seen a lot of emergent behavior in these large language models. The way these models are trained is basically on a next-word prediction task. For example, you're given a sentence here and you want to predict the next word; here's another sentence, and again you want to predict the next word. What these models are learning is the probability distribution of the next word given the context. This is usually trained on a huge corpus of data: I'm not even sure this figure is up to date, but it was 45 TB of internet text for GPT-3, and I don't even know what GPT-4 is trained on. And they are trained with local optimization methods.

Even though the task seems very simple, just predicting the next word, when you evaluate these models you're not evaluating them only on next-word prediction but on what they're able to do, and they end up doing much more than the training task would suggest. The first example here is my interaction with ChatGPT, where I gave it one of the questions I set for my midterm this semester. It actually solves it and produces a very nice LaTeX version of the solution. It's a simple problem, computing an MLE, but it gets it all logically correct. These other examples are translation, bracket matching, code generation, and so forth. So despite this simple training task, these models get good at open-ended "reasoning". I'm putting reasoning in quotes because it's not very clear how to define it; for my talk I will focus on a very specific definition of reasoning, and I'll get to it in a second.

So one could ask: is this model already doing really well, and is it always correct? Of course not. If you've played around with it, you've seen that it does some crazy things sometimes. For example, addition, a very simple thing to do. This is a trick question, because I chose an example with a lot of carries. It produces some output, and you may not be able to see it from here, but it's wrong. And here, if you ask it to do binary addition, it knows that in binary there should be only two digits, 0 and 1, yet it produces a final result that contains a 3. So it hallucinates logic and knowledge: even though it knows what it's supposed to do, it might produce output you're not expecting.

So the question we tried to address in this work is: how do these transformer models do reasoning internally, and why do they fail in the ways that they fail?
And the hope is that if you want to improve these models and make them more robust to the issues I just raised, then we need to understand their internal mechanisms. By internal mechanisms I mean: how do they represent these reasoning tasks? How do they generalize? How do they learn from samples and perform well in distribution? And of course the hard question is how these local optimization methods actually get you to these solutions, that is, what the training dynamics are. In this talk I'm only going to focus on representation and not on training dynamics, though I will show empirical experiments on training. There are a lot of open questions about how optimization works in these settings.

Okay, so in the theory part of today's talk we'll ask how transformers can represent these algorithmic reasoning tasks, and in the empirical part we'll see how they actually learn to do them. Here, algorithmic reasoning will be defined as executing the transitions of a finite state machine. I will define more concretely what this means and why we think of it as algorithmic reasoning. This is based on the work "Transformers Learn Shortcuts to Automata", and these are my amazing collaborators with whom we did this work. Please feel free to stop me at any point if you don't follow something.

Okay, so let's begin by addressing the elephant in the room: quantifying what reasoning means. One way to think about it is that we're doing next-token prediction in natural language, and to predict the next word you need to solve a complex mixture of subtasks. For instance, if you want to do code completion, you need to parse syntactic structure; if you want to produce a grammatically correct sentence, you need to understand syntax. You also want to perform factual inference: to complete the sentence I had before, "the ICTP is located ...", I need the factual knowledge of where the ICTP is. You also need to retrieve memorized associations, and to search over and eliminate answers. These can all be thought of as symbolic computations.

So here are different kinds of problems that have been studied in the transformer literature, mostly empirically and some theoretically. For example, regular languages: say you want to generate an email address or a sentence, there is some regular expression that generates it, so this is a subtask you need to solve for next-word prediction. If you're generating code, you need to do bracket matching; these are hierarchical languages, and there are many such hierarchical languages. You might also want to solve arithmetic, so the model needs to understand how to add numbers. And this is an example of a problem called parity, which we'll see a lot in this talk. The task is that you're given a sequence of events: there's a light switch, and somebody turns it on, turns it off, or just does nothing, and at the end you want to predict the state of the room, whether the light is on or not.
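To make the parity example concrete before I formalize it, here is a tiny illustrative sketch (my own, not code from the paper) of simulating this light-switch task as a two-state machine; the formal definition of the state machine follows next.

```python
# Minimal sketch (illustrative, not from the paper): the parity / light-switch task.
# State: 0 = light off (even), 1 = light on (odd). Input 1 = toggle, 0 = do nothing.

def parity_step(state: int, symbol: int) -> int:
    """One transition of the two-state machine: toggle on 1, stay on 0."""
    return state ^ symbol  # XOR implements the toggle

def run_parity(inputs):
    """Iteratively apply the transition, returning the state after every step."""
    state, states = 0, []          # start with the light off (even parity)
    for s in inputs:
        state = parity_step(state, s)
        states.append(state)
    return states

print(run_parity([1, 1, 0, 1, 1, 1, 0]))  # -> [1, 0, 0, 1, 0, 1, 1]
```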
This has also been captured as symbolic computation, and it has been looked at a lot in trying to build theory for transformers. All of these problems, if you break them down, can be thought of as executing the transitions of a finite state machine. So we can just think of algorithmic reasoning as simulating this finite state machine and figuring out what its state is; that's what we'll call reasoning, and let me define what I mean by a finite state machine here.

A finite state machine, or more precisely a semi-automaton, is a discrete-time dynamical system with the following structure. It has a state space, which I'll denote by Q, an input space, which I'll denote by Σ, and a transition function δ. The way it works is that at every time step, if you are at some state q_{t−1} and you receive the input σ_t, you move to the next state q_t = δ(q_{t−1}, σ_t). More classically, an automaton additionally takes the state you're in and produces some output that you might desire; here we're only going to worry about semi-automata, so we only care about how the state transitions happen.

There are many examples, from easy to hard, that you can fit into this model. One is the parity counter I mentioned: there are only two states, even and odd, and the inputs are 0 and 1. If you get a 0 you stay in the same state; if you get a 1 you move to the other state. So parity can be represented as a finite state machine. A memory unit is another very interesting example, since, as I mentioned, we need to recall and memorize things. Here I'm denoting the two states by a club and a diamond, and the task has a read operation and a write operation: once you've written to memory, you need to recall later what is stored there, so this is a one-bit memory. Then you can think of examples from RL, where you're operating in a two- or three-dimensional world and want to figure out which position you're in after taking a sequence of actions. And, more complicated still, a Rubik's cube: you can view the sticker configuration as the state space and the rotations as the actions, and you can ask very hard questions about it, like whether you can solve it starting from a given state. So this formalism captures all of these operations.

Yes, great question. So a semi-automaton just does not have the output function. In an automaton there is an additional output function, so you have your hidden state but you're actually outputting something else as a function of it, and you might never see the hidden state itself. We're making the problem even simpler: we just want to predict the hidden state, and we don't care what the output would be. For example, with the regular expression for an email address, you could imagine that the only output is whether the string so far is a valid email address or not, so it would be 0, 0, 0, ... until you reach the ".com" at the end, and only then do you get a 1.
So that would be an automaton, but the hidden state might be different. Sorry, say that again? Yes, exactly: if you just want to predict the hidden state, and not some other function of it, then RNNs are a perfect fit for this task. They can directly learn the transition dynamics δ, because the way they process the input is by applying the same function again and again, which is exactly this transition. This will come up again when I talk about transformers, because transformers do not work like this.

Okay, so let me make the reasoning task precise. You have an initial state q0, and I give you the input sequence σ_1 through σ_T. You only see the inputs, and your task is to output the sequence of states you would have gone through while reading that input sequence. You can view it either as generating the entire state sequence or just the final state. To find the final state you need to apply the transition function δ T times, so you can view this as unrolling the machine T times, that is, T steps of reasoning. For the parity counter, for example, if I start in the even state and I see 1, 1, 0, 1, 1, 1, 0, then I want to predict odd, even, even, odd, even, odd, odd. There is a simple iterative solution, as was already mentioned: you just learn the function δ and compute q_t from the previous state q_{t−1}. This needs T iterations to compute the final state, so you have to recurse T times, which an RNN would be great at doing.

So this is the task we're going to look at. We assume you have a lot of examples of these input token sequences together with the output state sequences, and you want to understand what you can learn with transformers: can you learn the underlying semi-automaton?

Okay, so the main question is: this is all very natural with RNNs, but transformers are what we use in these GPT models, and we've thrown away all the recurrent connections. Transformers are not like RNNs: they are shallow parallel circuits with no recurrent structure. Just a quick poll, how many of you know the transformer architecture really well? Okay, cool. I'll explain a little so we're on the same page, but for the rest of the talk I'll give a TL;DR and just treat it as a circuit.

Let me explain why transformers are so different from recurrent architectures. Instead of receiving the input one token at a time and operating on it step by step, as an RNN would, a transformer gets the whole input sequence, say x1, x2, ..., x6; imagine these as words in a sentence or tokens of any other structure you might want. The architecture computes a pairwise score between every token and every other token, to capture how each token is related to the others. It then normalizes these scores, averages the tokens according to them, and produces a new layer of embeddings; then it applies a fully connected network. That constitutes one layer of the transformer, and you can stack many such layers. The things to notice here are that these are only pairwise scores.
So the pairwise scores can be computed in parallel, and the fully connected networks are also applied identically to every position, so that is a parallel operation too. These operations are parallel across the input positions and sequential only across layers: every layer is a parallel layer, and the only sequential depth you get is the number of layers you stack.

If you think about the models we have right now, GPT-3 and GPT-4 (I have no idea what GPT-4's size is), the number of layers is much smaller than the sequence length. For GPT-3, the sequence length goes up to about 2,000 while the number of layers is only 96. So if you wanted to do the recurrent computation, you would need the depth to be as large as the sequence length, and we don't have that. The model cannot execute a chain of T reasoning steps unless its depth is as large as T, so it's not at all obvious that it can solve these reasoning tasks given that it doesn't have the depth to implement them step by step. That's why there's a mismatch: even though these models do well at these tasks, the architecture is not designed in a way that fits the iterative view of the task. The main part of the talk is about how they solve it despite not having that depth.

Okay, so if you want to black-box the transformer (I know the architecture is complicated), the way to think about it is that it is very good at two kinds of attention operation. The first is uniform attention, which just looks uniformly at everything in the past: imagine all the positions you have seen so far being weighted equally, so at position t each of them gets weight 1/t. This is very easy for the architecture to implement. The second is sparse attention, which means it can pick out one particular position in the past that is of interest to the current position. The nice part, which I won't dwell on, is that it can do this in a very parameter-efficient way: to implement that kind of sparse selection with an MLP you would need a number of parameters on the order of the sequence length, but a transformer can implement sparse attention with only polylogarithmically many parameters. And this is exactly the functionality you would expect from the name "attention": it can attend sparsely to whatever it needs. So we can think of transformers as shallow parallel circuits built from these two kinds of gates, sparse attention and uniform attention, and that's all we'll need about transformers for the rest of the talk. Any questions at this point? Okay, cool.

So we're going to look at transformers as shallow parallel circuits, and ask how they do reasoning. The main result is the following: transformers can represent shortcuts, which I already hinted at, meaning solutions that require fewer than T sequential steps to solve the semi-automaton problem; they require only o(T) depth. In particular, every semi-automaton can be simulated with O(log T) depth, and there is a class of solvable semi-automata that can be simulated with depth depending only on the size of the state space.
So the depth doesn't even depend on the length of the sequence, and for some special automata you can do it with constant depth. So you don't need depth equal to the number of reasoning steps, the way an RNN does: instead of computing one step at a time, transformers can do a kind of parallel computation to speed up the process and find shortcut solutions. This is what we show, in the representation sense; we can't guarantee that training will actually find these solutions, but the architecture can at least represent them. If you're interested in circuit complexity, you can relate these constructions to NC1 and ACC0 circuits; the class with depth depending only on the state-space size is not a standard circuit complexity class, so it's not clear exactly where it fits.

Okay, so that's the main result, and now I'll go through how to think about these shortcuts and why we should even expect them to exist. Let's see what we're trying to do. For each time t, we want to compute the state q_t. What does q_t mean? You start with q0, apply the transition function with σ_1 as the input, then apply it with σ_2, and so on. So you're really doing a function composition. If you think about the parity function, you can write each transition δ(·, σ) as a matrix: given the current state, which state do you move to when you read that symbol? And then, for any input sequence, finding the final state amounts to multiplying together the matrices of all the symbols you've seen so far and applying the product to the initial state. So this is another way to compute the state at time t: instead of computing the first state and then applying the next transition and so on, you view the whole thing as one composed function.

Okay, so how can we do this faster than depth T? The iterative solution corresponds to one particular bracketing: you start with the initial state, apply the first transition, then the next one, and so on, and that structure has depth T. But there's something smarter you can do: you can partition the product as a binary decomposition, grouping pairs together, then groups of four, and so on, building a tree. In this tree, at every internal node we don't know what the state is; we only know what the combined transition is. For example, once we have σ_1 and σ_2, multiplying their matrices tells us what the combined effect of reading σ_1 then σ_2 is, and so forth. So we can work in the space of transitions and only at the very end apply the initial state, and now we only need log T depth. We did the function composition at different scales, and this gives a simple log T solution. Notice that I didn't assume anything about the automaton, only that we're composing functions, so you can always do this.
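Here is a small sketch of that idea (my own illustration, not the paper's construction): represent each input symbol by its transition matrix, compose the matrices pairwise in a balanced tree, and only apply the initial state at the end. The number of sequential rounds is ceil(log2 T) instead of T.

```python
import numpy as np

# Illustrative sketch of the log-depth shortcut: each symbol sigma becomes a
# 0/1 transition matrix M[sigma], with M[sigma][i, j] = 1 iff delta(i, sigma) = j.
# Example: the parity automaton over states {0 = even, 1 = odd}.
M = {
    0: np.eye(2, dtype=int),                   # input 0: stay in the same state
    1: np.array([[0, 1], [1, 0]], dtype=int),  # input 1: flip the state
}

def final_state_logdepth(q0, inputs):
    """Compose all transition matrices in a balanced binary tree, then apply q0."""
    mats = [M[s] for s in inputs]
    rounds = 0
    while len(mats) > 1:                        # each round halves the list
        mats = [mats[i] @ mats[i + 1] if i + 1 < len(mats) else mats[i]
                for i in range(0, len(mats), 2)]
        rounds += 1
    e = np.zeros(2, dtype=int)
    e[q0] = 1                                   # indicator vector of the start state
    return int(np.argmax(e @ mats[0])), rounds  # final state, sequential rounds used

print(final_state_logdepth(0, [1, 1, 0, 1, 1, 1, 0]))  # -> (1, 3): odd, in 3 rounds
```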
What this suggests is that instead of depth T, you can always get a depth log T solution if all you care about is the final state. And if you want the state at all positions, you can reuse parts of this tree to compute the state at any intermediate position, not just the final q_T. How do you implement this with attention? I said attention does two things very well; in particular it can do sparse attention, so it can decide which position to pick out and attend exactly there, and then the MLP can learn to multiply, that is, compose, the transitions. So you can carry out exactly this construction: transformers fit it really well, because the sparse attention attends to the right positions in the tree and the MLPs compose them. So we can always show that these shortcuts exist.

Okay, so this is general; we used nothing about the automaton except that it is a function composition. Can we find even shorter shortcuts? Let's look at the parity question, where we just have to predict the parity at the final position. Does anybody want to guess how you could do this in an even shorter way than what I showed? If you got a sequence of zeros and ones and wanted to figure out whether the final state is even or odd, how could you do it? Yes, great. It's just a mod-2 counter: you sum everything from the start and take it modulo 2. That is a completely parallel operation, and in the architecture it looks like taking the sum and applying mod 2 at the end. And I said transformers are very good at uniform attention: they can uniformly attend to everything and sum it up, and the MLP can implement the mod-2 gate. So even though the transition rule is a local dynamic, it induces a global structure that lets us compute the answer without actually executing every step of the reasoning.

The property we used is that the function composition is commutative: it doesn't matter in which order the toggles happen, only how many there are, and as long as that is true, counting suffices, which gives a shorter solution. So here we went from a log T solution to a depth-1 solution for this particular parity problem. The hope is that there is some algebraic structure we can exploit to get even shorter shortcuts. The challenge, of course, is that the structure I just used is commutativity, meaning the ordering doesn't matter, but for most tasks the ordering does matter. For example, if you're writing an email address, it really matters in which order the characters appear, so you cannot just assume away the ordering. And there is also non-invertibility: the history you've seen so far can change what a new operation means. For instance, if I've already seen the "@" sign, then "gmail" means something different than if it appeared before the "@". So these are the two challenges that this simple counting argument doesn't handle, and let me show you the challenge with a toy problem.
It's a really simple toy puzzle which gives the intuition clearly. In this task you're driving a car on a roundabout with four positions, and you have only two operations: you can either drive forward one position, or take a U-turn, which maybe isn't a legal move on a real roundabout, but let's just say you can flip the car around in place. You start at the top position facing right, you're given a sequence of operations, drive, U-turn, drive, U-turn, and so on, and you want to predict where you are at the end. Now look at this: if I drive first and then take a U-turn, I end up in this state, but if I take the U-turn first and then drive, I end up in this other state. So this is not commutative at all. How would you do this? I'll give you a couple of seconds to think about it, and then I'll tell you the solution. There is a two-layer solution to this problem.

Okay, so you get a sequence of operations, say D, U, D, D, and you want to predict the final state. What's happening is that everything depends on the car's direction: if there were no direction, if you only ever drove forward, this would just be a parity-style task where you count modulo 4. The only extra thing you need is the direction, and how do you find the car's direction? At any time, the direction is just a parity task on the U-turns: count the U-turns so far modulo 2 and you know which way the car is facing at that position. And now, to find the position of the car, you take each drive step and sign it by the direction at that moment: if you drive forward while facing the negative direction you add −1, if you're facing the positive direction you add +1, and then you take the total modulo 4, since we're on a four-position circle. So this is again a counting task, now on the 4-cycle. For instance, with drive, U-turn, drive, U-turn, the parity of the U-turns tells you the final direction of the car, and the signed counts give you the final position. So what we did was break the task into two components and compose them: if we can decompose any task into these commutative pieces, then we can do the counting trick on each piece and glue them together.

In general, you can imagine that with more structure on the automaton there could be even shorter shortcuts, and there are really nice theorems from semigroup theory about such decompositions. I'll just give the high level and not go into the details: you can take any automaton and form its transformation semigroup, which is just the semigroup generated by these transition functions under composition. If there is structure, you can factorize it. For instance, the parity automaton gives the cyclic group C2, and the roundabout I just described gives a semidirect product of C2 and C4: the sign is determined by a count modulo 2 and the position by a count modulo 4, so you get a product of these.
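As a sanity check of that decomposition, here is a small sketch (my own, with my own naming of states, not from the paper) comparing the step-by-step simulation of the roundabout with the two-counter shortcut: direction from the parity of U-turns, position from the direction-signed drive counts modulo 4.

```python
from itertools import product

# Roundabout with 4 positions; state = (position mod 4, direction in {+1, -1}).
# 'D' = drive one step in the current direction, 'U' = flip the direction.

def simulate(ops, pos=0, direction=+1):
    """Iterative (depth-T) simulation: apply the transition one op at a time."""
    for op in ops:
        if op == 'U':
            direction = -direction
        else:  # 'D'
            pos = (pos + direction) % 4
    return pos, direction

def shortcut(ops, pos=0, direction=+1):
    """Two commutative counting passes: parity of U-turns, then signed drive count."""
    signed_drives = 0
    for t, op in enumerate(ops):
        if op == 'D':
            # direction at this drive = initial direction * (-1)^(#U-turns so far)
            u_before = ops[:t].count('U')
            signed_drives += direction * (-1) ** u_before
    final_dir = direction * (-1) ** ops.count('U')
    return (pos + signed_drives) % 4, final_dir

# Check the shortcut against brute-force simulation on all sequences up to length 6.
for n in range(7):
    for ops in product('DU', repeat=n):
        assert simulate(ops) == shortcut(ops)
print("shortcut matches the iterative simulation")
```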
And in general, if you can factorize this semigroup, then you get much shorter constructions. For intuition about decomposition: for the natural numbers we have prime factorization; for groups there is the Jordan-Hölder decomposition; and for semigroups there is a seminal result by Krohn and Rhodes showing that different kinds of semigroups admit different kinds of decompositions, and we can use that theorem. The nice thing about the Krohn-Rhodes decomposition is that it consists of only two kinds of automata, and these are exactly the two kinds of automata that transformers are really good at: the parity-style (cyclic group) automata and the memory automata, which just need to remember the last thing written. The decomposition reduces the problem to implementing these blocks and gluing them together, and we can implement all of those constructions with transformers.

Yes, so what the Krohn-Rhodes theorem states is that if you have a semigroup, then it is contained in this decomposition, where the decomposition, I won't define the wreath product precisely, is a kind of factorization into simple groups and flip-flop operations; a flip-flop just predicts the value of the last thing you wrote to memory. You can decompose any semi-automaton into a sequence of these operations; the theorem holds in general, but when we say solvable, the simple groups in the decomposition are abelian, so we can think of them as cyclic groups. The cyclic groups we can implement with the counting operation, the flip-flops with sparse attention, and then we compose them together in the transformer architecture. And the number of factors in this decomposition is always bounded by the size of the state space, so you never get anything larger than that.

Yeah, exactly, that's why you need the depth. What happens is that when you partition the sequence, you still need to keep track of how many different possible histories could lead to each state. The original construction is actually very beautiful in how it does this efficiently: you only have to maintain the history up to that level of the decomposition, and there are only about Q levels, so you only need about 2^Q different histories, and it no longer depends on the length of the sequence, because the histories collapse thanks to the decomposition. For transformers we need to do a bit more work to carry out this construction, but because the factors are so well suited to transformers, it goes through; the gluing procedure is a little more involved, but it's done with very standard MLPs. So basically there are at most Q factors, but for each factor we need about Q layers of depth to keep track of which states it could have been in given the history. Sorry, say again? What are the two factors? Oh, these factors.
Oh, so the two types of factors in this decomposition: the m's and the h's. Don't worry about m, that's just the gluing part; h1 to hn is the thing to think about, and n is the number of states here. Those are the factors you want to implement, but because there's no invertibility, you have to maintain the different histories that could have led to the same state, and for that you need an extra Q layers. It doesn't depend on T, the sequence length, because this is just a semigroup: the semigroup is defined by composition, and it doesn't care how many times you compose, so it has no knowledge of how long you run the machine. It's a very beautiful construction if you want to check it out; I mean the original paper's, not ours, ours is very ugly. Okay, so this gets you O(Q^2) depth.

Something very interesting happened along the way. We initially didn't think transformers would work for all automata, so we tried this particular automaton, a very simple problem, but you'll see why we thought it would be challenging. Basically you're walking in a room with walls, moving left and right, and when you hit a wall you can't go further, you just stay in that position. We thought the transformer would really need to know how many times it had hit the wall, and that this would be hard for it. The recurrent solution is simple, and Krohn-Rhodes would give an O(Q), actually O(Q^2), that's my typo, layer solution. But when we trained a transformer, it solved it with as few as two layers; the four-layer solution was more robust. We tried different, much larger state spaces, and it always solved it, and it did it really well. If you look at the attention maps, it actually learned a boundary detector: it finds the most recent boundary you hit, and once you know the last boundary, the rest is just a parity-style counting problem. It learned this, and only afterwards did we come up with a construction that does it in two layers; there is a clever way to maintain the history and solve it faster, and the transformer found it before we did. So it was fascinating to see that the solutions transformers learn can be quite interesting. These are the attention plots they produce, and they operate in exactly the modes I described: either uniform attention or really sparse attention. So for several problems you can get these faster solutions; we don't really know which class of automata admits them, but this is one example.

Okay, yeah, let me try to explain that. Think of the sequence as x1 to xT. When the transformer computes the output at some position, say x_t, the new embedding it produces puts some weight on every position it has seen, so you can think of weights alpha_1 through alpha_T.
So this matrix is the attention matrix: the (i, j) entry is the weight that position x_i puts on position x_j in the attention layer. It's going to be triangular, because a position cannot look at anything in the future, so those entries are all zero. If you read each row, it tells you where that position is attending. Here, for all of these positions, if this is the line, every one of them puts weight only on this single coordinate in the entire history and looks at nothing else. A boundary detector is exactly that: when was the last time I hit a boundary? Everything after that should just look at that boundary, and that's what this attention map shows. It's a very clean attention map, and that's how we discovered the algorithm.

No, it's not always this clean; this one just turned out to be. Sometimes, for a very simple problem, you realize the attention maps don't mean anything, because the network is hugely over-parameterized and a lot of heads cancel out; many of the attention heads are not useful, so they don't carry signal. You will often find some attention head that seems meaningful, but there's also the danger that I know what I'm looking for and therefore I see it. So there's always a problem with just reading these off by eye; it's very unscientific, and I think there are better ways to do it, but at least in this case it was very clear that it was doing something meaningful. For us this was just motivation. There is a whole literature on mechanistic interpretability that tries to understand how transformers behave and does this rigorously: they study each attention head, ablate it, change the inputs to see whether the head is actually useful. We didn't do all of that; it requires a lot of engineering effort. But yes, these maps can be misleading for sure.

Okay, so let me summarize and then talk about the experiments, because the experiments are what actually started this whole project, and we thought they were really fascinating. The summary of the theory is that instead of the iterative state simulation, which is very RNN-style, transformers can do a kind of faster computation and reach shallower solutions to these problems. This unifies a lot of results that existed separately for addition, parity, Dyck languages, and regular languages in the literature, because this is one construction that works for all of them. And it shows that, given long sequences, transformers can represent these emergent hierarchical representations, which might be useful in the future if the models actually learn them at a hierarchical level.

For the experiments, we used synthetic data; we tried about 19 different semi-automata, motivated by all the tasks that had been studied before, and we asked the following questions. Does standard transformer training actually find these shortcuts? Representability alone doesn't guarantee that optimization will find them.
And questions like: can it work with limited supervision, where we don't label all the intermediate states? Are these actually good solutions, that is, do they work out of distribution? And can we make them learn more recurrent solutions if we want to? The answers are the following. Yes, transformers with standard training always find these shortcuts. It works with some limited supervision, though not always. But the most interesting thing we observed was that these shortcut solutions are very good in distribution, but the moment you go out of distribution they break, whereas a recurrent solution would not break, because a recurrent solution has actually learned the dynamics; the shortcut is doing something else.

Here's a plot over the different automata we tried: lighter is better performance, darker is worse. What you can see is that, going from top to bottom, the complexity of the automaton increases and it takes more depth. A5 and S5 are non-solvable, so they cannot have solutions shallower than log of the sequence length, and it takes roughly that much depth to learn them. So we do see that a deeper factorization takes more layers, and maybe the model is doing something of this flavor in its solution; it's very hard to interpret whether it actually is, but this is what we observe.

How do we understand the optimization? That's a good question; we're thinking about it, and some people in the room are probably working on these questions. There are several ways to think about it: maybe there is a kind of curriculum, because the model gets to see prefixes of different lengths, so it can learn something more structured; maybe we can understand mechanistically what it has learned. But it's a very interesting open question why training actually finds these solutions.

The out-of-distribution question is very interesting. Take parity, for instance, and change the number of ones in the test sequences. We train on randomly generated sequences, so the numbers of ones and zeros are balanced, but if we change the number of ones, the performance starts dipping on both ends. The way to think about this is that the shortcut just needs to count, and for random balanced sequences the count of ones concentrates around T/2 plus or minus about the square root of T. Since the model only sees such sequences, it learns to count well in that range and doesn't care about counting outside it. So when we shift the number of ones, performance degrades. In fact, if we increase the sequence length while keeping the number of ones constant, performance stays really good, but if the number of ones scales with the length, say each position is one with probability 0.5, performance drops very fast. So it's really learning to count only up to some value. These shortcut solutions are great, but because they don't learn the recurrent dynamics and instead have to learn this more complicated modular operation, they may not learn it over the entire distribution.
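Just to illustrate the kind of distribution shift being described (my own sketch of the setup, not the paper's evaluation code): train-style sequences are i.i.d. with probability 0.5 of a one, while the shifted test sequences fix a different fraction of ones, so their counts fall outside the T/2 ± √T range the model effectively saw during training.

```python
import random

def parity_label(seq):
    """Ground-truth final state: parity of the number of ones."""
    return sum(seq) % 2

def in_distribution(T, n):
    """Training-style data: each position is 1 with probability 0.5."""
    return [[random.randint(0, 1) for _ in range(T)] for _ in range(n)]

def shifted(T, n, p_one):
    """Shifted test data: same length, but a different probability of ones."""
    return [[1 if random.random() < p_one else 0 for _ in range(T)] for _ in range(n)]

T = 64
train_like = in_distribution(T, 1000)
ood = shifted(T, 1000, p_one=0.9)  # many more ones than T/2 +- sqrt(T)

avg = lambda xs: sum(xs) / len(xs)
print("mean #ones, train-like:", avg([sum(s) for s in train_like]))
print("mean #ones, shifted   :", avg([sum(s) for s in ood]))
# A trained model would then be scored on parity_label over both sets.
```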
So parity is one example; there are other examples where it also fails out of distribution. And the thing is, you can only detect this brittleness by going out of distribution, because in distribution it's perfect, no matter how many samples you try. It's only when you make the sequences out of distribution in this way that it starts failing. This could be a potential reason why addition and other things don't work: the sequences on which they fail may be more out of distribution relative to the sequences the model has seen.

I'm almost over time, but I'll just say that you can hope to make these solutions behave more recurrently. In fact, if you've seen this before, there are ways to do it using the scratchpad idea, where you make the model predict the hidden state and recurse many times, so you try to unroll it like an RNN. What you can imagine is that instead of only giving the input, I generate the output, feed that output back in as an input, and keep doing this. Trained in that fashion, the model has to learn the recurrent structure, because that solution exists. At test time I don't have the hidden states, so I just take the model's output and feed it back in as a proxy for the hidden state. This has actually been studied as a way to improve reasoning, by giving the model scratchpad access so it can work out the steps of the computation. What we see is that we lose the parallelism, because now we have to unroll this T times, but we get performance that is perfect out of distribution by training the transformer in this fashion.

Yeah, so we don't change the rule, we change the input distribution. You train with some distribution, say uniform over the symbols, but maybe you want to test on a skewed distribution with different properties. For parity, imagine you train with zeros and ones appearing with equal probability and then test with more ones and fewer zeros in the sequence. The model should still be robust, because the underlying algorithm is simple and doesn't change, but it hasn't learned the algorithm properly.

So let me end with some discussion. This is what we showed, and the question you can ask is: are these shortcuts a bug or a feature? Of course there are problems, they're brittle, they're maybe not learning the intended recurrence, but they're shallower, they're faster, there's some hierarchical representation, so they may actually be more efficient than recurrent solutions. If you could get the best of both worlds, that would be great. There are some new architectures being proposed that try to do this, to get recurrence with parallelism; I'm not sure how these problems would perform on them, but that's something we're going to try. And I'll just summarize: there are of course a lot of open questions. Maybe you can suggest architecture fixes that help with these problems; for example, if the error is in the modular operation, we can fix that in the architecture. We can try to interpret these solutions.
We can try to find out whether these kinds of shortcuts are actually happening in the wild; we have a paper coming up that tries to do that. And we can try to understand what this scratchpad-like reasoning is doing. In terms of theory, one can get better constructions, understand in a more fine-grained way how the attention network behaves, and quantify these out-of-distribution failures. That's hard: as was asked, it's not very clear what out of distribution even means. Do you want worst case, working on every input, or something else? So quantifying that is open. And here we only looked at deterministic settings, whereas in the real world these are stochastic; all of these RL settings are stochastic, so understanding what happens in stochastic environments is another interesting question. I'll end with that, and thank you so much for your attention.

[Audience question.] It does well on the distribution we trained it on, but if we change the number of ones, the probability of a one, it starts failing, whether we make it smaller or larger. For fewer ones you would think it should still do well, because it sees such subsequences, but then it learns something that depends on the position encoding: it just learns that at this position I should see this count, and memorizes that.

Yeah, exactly. So for the log T solutions, it's not even clear whether it's learning them, because we can't really control what it learns. But yes, they can be brittle to train, because of what has to be learned: for example, you have to multiply these composed functions, and multiplication is not an easy computation for MLPs. Right now they're able to do it because we're looking at very discrete, small state spaces, maybe five to ten states, so the multiplication isn't that complicated, but in general I would expect them to start failing, because multiplication is hard for MLPs. Then you could hope to change the architecture: you don't have to put plain MLPs on top, you can add multiplicative elements and it could do better. We've tried some of these ideas, and some other activations actually do a little better. But yes, these challenges are obviously there. For example, even in the counting solution you need to count up to T, which depends on the length of the sequence, whereas in the recurrent solution I only ever do a mod-2 update, so that's a much more robust solution. So of course these are problems: if you want to operate not exactly on the state space but on the function compositions, you're doing something more complicated.

Non-solvable groups? Yeah, in the experiments we tried S5 and A5, which are the smallest, and they do require a lot of depth, but the model still finds a solution, which was cool to see. One of my collaborators ran this at a larger scale, and the attention patterns, okay, again I may be reading into them, but they look like they're picking up the structured decomposition we have. So maybe it's actually doing that, and it does it in much less depth: I think we tried sequences of length 100 and it does it in something like eight or ten layers, it doesn't require 100 layers.
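As a small appendix to the scratchpad idea mentioned above (my own sketch of the general recipe, not the paper's training code; `model_step` is a hypothetical stand-in for a trained next-state predictor), the test-time loop feeds the model's own predicted state back in as input, so the computation is unrolled T times like an RNN instead of being done in one parallel pass:

```python
# Sketch of scratchpad-style unrolling at test time. `model_step` is a placeholder
# for a trained model that, given the current (predicted) state and the next input
# symbol, predicts the next state. Here it is stubbed with the true parity transition
# just so the loop runs end to end.

def model_step(predicted_state: int, symbol: int) -> int:
    # Hypothetical stand-in for the learned one-step predictor.
    return predicted_state ^ symbol

def scratchpad_decode(q0, inputs):
    """Unroll T sequential calls, feeding each predicted state back as context."""
    scratchpad = [q0]
    for symbol in inputs:
        next_state = model_step(scratchpad[-1], symbol)  # uses the model's own output
        scratchpad.append(next_state)
    return scratchpad[1:]   # predicted state sequence q1..qT

print(scratchpad_decode(0, [1, 1, 0, 1, 1, 1, 0]))  # -> [1, 0, 0, 1, 0, 1, 1]
```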