Hi there, and welcome to lesson 16, where we are working on building our first flexible training framework: the Learner. I've got some very good news, which is that I've thought of a way of doing it a little more gradually and simply than last time, so that should make things a bit easier; we're going to take it a bit more step by step. We're working in the 09_learner notebook today. We've already seen this basic Learner, which wasn't flexible at all, but it had all the basic pieces. We've got a fit method; we're hard-coding that we can only calculate accuracy and average loss; we're hard-coding putting things on a default device; we're hard-coding a single learning rate. But the basic idea is here: we go through each epoch and call one_epoch to train or evaluate depending on this flag, and then we loop through each batch in the data loader. one_batch grabs the x and y parts of the batch, calls the model, calls the loss function, and if we're training does the backward pass; then it calculates the statistics for our accuracy, and at the end of the epoch prints them out. So it wasn't very flexible, but it did do something, and that's good. What we're going to do now is an intermediate step: we're going to look at what I'm calling a basic callbacks Learner, and it actually has nearly all the functionality of the full thing. Then, after we look at this basic callbacks Learner and create some callbacks and metrics, we're going to look at something called the flexible Learner. So let's go step by step. The basic callbacks Learner looks very similar to the previous Learner.
It's got a fit function, which goes through each epoch, calling one_epoch with training on and then training off; one_epoch goes through each batch and calls one_batch; and one_batch calls the model and the loss function, and if we're training, does the backward step. So that's all pretty similar, but there are a few more things going on here. For example, if we have a look at fit, you'll see that after creating the optimizer (we call self.opt_func, which defaults to SGD, so we instantiate an SGD object, passing in our model's parameters and the requested learning rate), and before we start looping through one epoch at a time for the requested number of epochs, we first of all call self.callback, passing in 'before_fit'. Now, what does that do? self.callback takes a method name, in this case 'before_fit', and calls a function called run_cbs, passing in the list of our callbacks and the method name, in this case 'before_fit'. run_cbs is something that goes through each callback, sorted by its order attribute; there's a base class for our callbacks which has an order of zero, so all our callbacks have an order of zero unless we say otherwise. So here's an example of a callback. Before we look at how callbacks work, let's just run one. We can create a ridiculously simple callback called a completion callback: before we start fitting a new model, it sets its count attribute to zero; after each batch, it increments that; and after completing the fitting process, it prints out how many batches we've done. So before we even train a model, we can just manually run before_fit, after_batch, and after_fit using run_cbs, and you can see it ends up saying "Completed 1 batches". So what did that do? It went through each of the cbs in this list; there's only one.
So it's going to look at that one cb, and it's going to try to use getattr to find an attribute with this name, before_fit. So let's try that manually. This is the kind of thing I want you to do if you find anything difficult to understand: do it all manually. Create a callback, set cb to cbs[0], just like in the loop, and find out what happens if we call getattr on it, passing in that name. You'll see it returns a method. And then what happens to that method? It gets called. So let's try calling it. There we are: that's what happened when we called before_fit, which doesn't do anything very interesting; but if we then call after_batch, and then after_fit, there it is. So, yeah, make sure you don't just run code blindly; understand it by experimenting with it. I don't always experiment with it myself in these classes; often I'm leaving that to you, but sometimes I'm trying to give you a sense of how I would experiment with code if I were learning it. Having done that, I would then go ahead and delete those cells, but you can see I'm using this interactive notebook environment to explore and learn and understand. And if I haven't created a simple example or something to make it really easy to understand, you should do that. Don't just use what I've already created, or what somebody else has already created. So we've now got something that works totally independently, and we can see how it works.
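To make that concrete, here's a minimal, self-contained sketch along the lines of the notebook's run_cbs and CompletionCB (simplified, so details may differ from the notebook's versions), including the manual getattr exploration described above:

```python
from operator import attrgetter

class Callback:
    order = 0  # callbacks run in ascending order of this attribute

class CompletionCB(Callback):
    def before_fit(self): self.count = 0
    def after_batch(self): self.count += 1
    def after_fit(self): print(f'Completed {self.count} batches')

def run_cbs(cbs, method_nm):
    # Find the named method (e.g. 'before_fit') on each callback; if the
    # callback defines it, call it.
    for cb in sorted(cbs, key=attrgetter('order')):
        method = getattr(cb, method_nm, None)
        if method is not None:
            method()

cbs = [CompletionCB()]
run_cbs(cbs, 'before_fit')
run_cbs(cbs, 'after_batch')
run_cbs(cbs, 'after_fit')    # prints: Completed 1 batches

# The manual exploration: getattr returns a bound method, then we call it.
cb = cbs[0]
method = getattr(cb, 'before_fit')
method()                     # same as cb.before_fit(); resets the count
```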
This is what a callback does. A callback is a class where you can define one or more of before/after fit, before/after batch, and before/after epoch. The Learner goes through and runs all the callbacks that have a before_fit method before we start fitting; then it goes through each epoch, calling one_epoch with training and one_epoch with evaluation, and when that's all done, it calls the after_fit callbacks. And one_epoch, before it starts enumerating through the batches, calls before_epoch, and when it's done, it calls after_epoch. The other thing you'll notice is that there's a try immediately before every before method, and an except immediately after every after method, and each one has a different thing to look for: CancelFitException, CancelEpochException, and CancelBatchException. So here's the bit which goes through each batch: it calls before_batch, processes the batch, calls after_batch, and if there's an exception of type CancelBatchException, it gets ignored. So what's that for? The reason we have this is that any of our callbacks can raise any one of these three exceptions to say: I don't want to do this batch, please. We'll see an example of that in a moment. So we can now train with this. Let's create a little get_model function that creates a sequential model with just some linear layers, and then we'll call fit. And it's not telling us anything interesting, because the only callback we added was the completion callback. That's fine; it's training.
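The shape of that try/except machinery can be sketched like this; the run_batches and DummyLearner names here are hypothetical stand-ins, just enough to run the idea on its own:

```python
class CancelFitException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass

def run_batches(learn):
    # Just the per-batch portion of one_epoch, simplified: everything is
    # stored on the learner (self.iter, self.batch) so callbacks can see it.
    for learn.iter, learn.batch in enumerate(learn.dl):
        try:
            learn.callback('before_batch')
            learn.one_batch()
            learn.callback('after_batch')
        except CancelBatchException:
            pass  # a callback asked to skip the rest of this batch

class DummyLearner:
    # Stand-in learner so the sketch runs on its own, recording calls.
    def __init__(self, dl): self.dl, self.log = dl, []
    def callback(self, nm): self.log.append(nm)
    def one_batch(self): self.log.append('batch')

learn = DummyLearner(dl=[10, 20])
run_batches(learn)
# learn.log is ['before_batch', 'batch', 'after_batch'] twice over
```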
It's doing something, and we now have a trained model; it just didn't print out any metrics or anything, because we don't have any callbacks for that. That's the basic idea. So we could create, maybe we could call it a single batch callback, which after a single batch raises a CancelFitException. I suppose that could actually be kind of useful, if you want to run just one batch through your model to make sure it works. So we could try that. Now we're going to add the single batch callback to our list of callbacks. Let's try it. There we go: it ran, and nothing happened, and the reason nothing happened is that this cancelled before the completion callback ran. So we could make it run second by setting its order to be higher; we could just say order = 1, because the default order is zero, and we sort in order of the order attribute. Actually, let's use CancelEpochException instead; that way it'll still run the final after_fit. There we are. So it did one batch for the training and one batch for the evaluation; that's a total of two batches. So remember, callbacks are not some special magic part of, like, the Python language or anything. It's just a name we use to refer to these functions or classes (callables, more accurately) that we pass into something that will then call back to that callable at particular times. And I think these are kind of interesting kinds of callbacks, because these callbacks have multiple methods in them. So is each method a callback? Is each class with all those methods a callback? I don't know.
I tend to think of the class with all the methods in it as a single callback; I'm not sure we have great nomenclature for this, though. All right, so let's actually try to get this doing something more interesting, not by modifying the Learner at all, but just by adding callbacks, because that's the great hope of callbacks, right? It would be very nice if it told us the accuracy and the loss. To do that, it would be great to have a class that can keep track of a metric. So I've created here a Metric class, and maybe before we look at it, we'll see how it works. You could create, for example, an accuracy metric by defining the calculation necessary to calculate the accuracy metric, which is the mean of how often the inputs equal the targets. And the idea is that you could then create an Accuracy metric object, add a batch of inputs and targets, add another batch of inputs and targets, and get the value, and there you'd get the 0.45 accuracy. Or another way you could do it would be just to create a plain Metric, which simply takes the weighted average of, for example, your loss. So you could add 0.6 as the loss with a batch size of 32, and 0.9 as the loss with a batch size of 2, and that's going to give us a weighted average loss of 0.62, which is equal to this weighted-average calculation. So that's one way we could make it easy to calculate metrics. So here's the class. Basically, we're going to keep track of all of the actual values that we're averaging, and the number in each mini-batch. And so, when you add a mini-batch, we call calculate (which, for accuracy, remember, overrides the parent class's calculate, so it does the accuracy calculation here), and then we add that to our list of values, and we add the current batch size to our list of batch sizes. And then, when you ask for the value, we calculate the weighted sum; that's right, the weighted mean, the weighted average. Now notice something about value here.
I didn't have to put parentheses after it, and that's because it's a property. I think we've seen this before, but just to remind you: property just means you don't have to put parentheses after it to get the calculation to happen. All right, just let me know if anybody's got any questions up to here. So, we now need some way to use this metric in a callback, to actually print things out. The first thing I'm going to do, though, is create one more useful callback first, a very simple one, just two lines of code, called the device callback. And that is something which is going to allow us to use CUDA, or the Apple GPU, or whatever, without the complications we had before of, you know, how do we have multiple processes in our data loader and also use our device, without everything falling over. The way we do it is: before fit, put the model onto the default device; and before each batch is run, put that batch onto the device. Because, and this is really, really important, look at what happens in the Learner: absolutely everything is put inside self, which means it's all modifiable. So we go `for self.iter, self.batch in enumerate(...)` over the data loader, and then we call one_batch; but before it, we call the callback, so a callback can modify this. Now, how does the callback get access to the Learner? Well, what actually happens is that we go through each of our callbacks and set an attribute called learn equal to the Learner, and that means that in the callback itself we can say self.learn.model. And actually, we could make this a bit better.
Maybe you don't want to use the default device, so this is where I would be inclined to add a constructor and set the device; we could default it to the default device, of course, and then use that instead, and that would give us a bit more flexibility. So if you wanted to train on some different device, then you could; I think that might be a slight improvement. Okay, so there's a callback we can use to put things on CUDA, and we can check that it works by just quickly going back to our old learner here, removing the single batch cb, and replacing it with the device cb. Yep, still works; that's a good sign. Okay, so now let's do our metrics. Now, of course, we couldn't use metrics until we'd built them by hand. The good news is that we don't have to write every single metric by hand now, because they already exist, in a fairly new project called torcheval, which is an official PyTorch project. torcheval is something I actually came across after I'd created my own Metric class, but it looks pretty similar to the one that I built earlier. You can install it with pip; I'm not sure if it's on conda yet, but it probably will be soon, by the time you see the video. I think it's pure Python anyway, so it doesn't matter how you install it. And yeah, it has a pretty similar approach, where you call .update and you call .compute; slightly different names, but basically super similar to the thing that we just built. And there's a nice big list of metrics to pick from. So, because we've already built our own, that means we're allowed to use theirs, so we can import the MulticlassAccuracy metric and the Mean metric. And just to show you, they look very, very similar: we can instantiate MulticlassAccuracy, pass in a mini-batch of inputs and targets with update, and call compute, and that all works nicely. Now, in fact, it's almost exactly the same as what I wrote; we both added this thing called reset, which basically, well, resets it.
We both added this thing called reset which basically well resets it and um So obviously we're going to be wanting to do that probably at the start of each epoch um And so if you reset it And then try to compute You'll get nan because you can't get accuracy accuracy is meaningless when you don't have any data yet Okay So let's create a metrics callback so we can print out our metrics I've got some ideas to improve this which maybe I'll do this week, but here's a basic working version Slightly hacky, but it's not too bad um So generally speaking One thing I noticed actually is I don't know if this is considered a bug But a lot of the metrics didn't seem to work correctly in torch eval when I had Tenses that were on the gpu and um had requires grad So I created a little two cpu function, which I think is very useful um, and that's just going to detach the so detach Takes the tensor and removes all the gradient history the computation history used to calculate a gradient and puts it on the cpu That'll do the same for dictionaries Of tensors lists of tensors and tuples of tensors So our metrics callback Basically, here's how we're going to use it. Um, so let's run it So here we're creating a metrics callback object And saying we want to create a metric called accuracy That's what's going to print out And this is the metrics object. We're going to use to calculate accuracy And so then we just pass that in as one of our callbacks And so you can see what it's going to do Is it's going to print out the epoch number Whether it's training or evaluating so training set or validation set and it'll print out our metrics and our current status actually we can simplify that We don't need to print those bits because it's all in the dictionary now. Let's do that There we go um, so Let's take a look at how this works. 
So, for the callback, we're going to be passing in the names and metric objects for the metrics to track and print. So here it is: **metrics. We've seen ** before. And as a little shortcut, I decided that it might be nice if, when you didn't want to write accuracy=, you could just remove that and run it; and if you do that, then it will still give it a name, and it'll just use the same name as the class. And that's why you can pass them in either way. So *ms will be a tuple; it's just receiving a list of positional arguments, which will be turned into a tuple, or you can pass in named arguments, which will be turned into a dictionary. If you pass in positional arguments, then I'm going to turn them into named arguments in the dictionary by just grabbing the name from their type. So that's where this comes from; that's all that's going on here, just a little shortcut, a bit of convenience. So we'll store that away. And this is, yeah, a bit I think I can simplify a little, but I'm just manually adding an additional metric, which I'm going to call loss, and that's just going to be the weighted average of the losses. So, before we start fitting, we're going to actually tell the Learner that we are the metrics callback; you'll see later where we're going to actually use this. Before each epoch, we will reset all of our metrics. After each epoch, we will create a dictionary of the keys and values, the actual strings that we want to print out, and we will call _log, which for now will just print them. And then, after each batch, and this is the key thing, we're going to actually grab the input and target, put them on the CPU, and then go through each of our metrics and call update. So remember, update in the metric is the thing that actually says: here's a batch of data.
So we're passing in the batch of data, which is the predictions and the targets, and then we'll do the same thing for our special loss metric, passing in the actual loss and the size of our mini-batch. And so that's how we're able to get this actually running on the Nvidia GPU and showing our metrics. Obviously there's a lot of room to improve how this is displayed, but all the information we need is here, and it's just a case of changing that function. Okay, so that's our kind of intermediate-complexity Learner. We can make it more sophisticated, but it's still going to fit in a single screen of code. This was kind of my goal here: to keep everything in a single screen of code. This first bit is exactly the same as before, but you'll see that one_epoch and fit and one_batch have gone from quite a lot of code, all this, to much less code. And the trick to doing that is that I decided to use a context manager. We're going to learn more about context managers in the next notebook, but basically, originally last week I said I was going to do this as a decorator, but I realized a context manager is better. Basically, what we're going to do is call our before and after callbacks in a try/except block, and to say that we want to use the callbacks around the try and except block, we're going to use a with statement.
So, in Python, a with statement says: for everything in that block, call our context manager before and after it. Now, there are a few ways to do that, but one really easy one is using this @contextmanager decorator. Everything up to the yield statement is called before your code; where it says yield, it then calls your code; and then everything after the yield is called after your code. So in this case, it's going to be `try: self.callback(f'before_{name}')`, where name is fit; then it will run the `for self.epoch in ...` loop, because that's where the yield is; and then it'll call `self.callback(f'after_{name}')`, and then the except. Okay, and now we need to grab the CancelFitException. All of the global variables that you have in Python live inside a special dictionary called globals(), so I can just look up in that dictionary the variable called CancelFitException (with a capital F), built from the name. So this is `except CancelFitException`. This is exactly the same, then, as the earlier code, except the nice thing is that now I only have to write it once, rather than at least three times, and I'm probably going to want more of them. So, you know, I tend to think it's worth refactoring code when you have duplicate code; particularly here, we had the same code three times, so that's going to be more of a maintenance headache.
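A minimal sketch of that pattern, with stand-in Learner and callback_ctx names (the notebook's version differs in its details):

```python
from contextlib import contextmanager

class CancelFitException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass

class Learner:
    def __init__(self): self.log = []
    def callback(self, nm): self.log.append(nm)

    @contextmanager
    def callback_ctx(self, nm):
        # Everything before the yield runs before the with-block's body;
        # everything after it runs afterwards. The except clause looks up
        # the matching exception class by name in globals().
        try:
            self.callback(f'before_{nm}')
            yield
            self.callback(f'after_{nm}')
        except globals()[f'Cancel{nm.title()}Exception']:
            pass

learn = Learner()
with learn.callback_ctx('fit'):
    learn.log.append('epochs')   # stands in for the epoch loop

with learn.callback_ctx('epoch'):
    # a cancel exception raised in the body is swallowed by the context
    # manager, and the 'after_epoch' callback is skipped
    raise CancelEpochException()
```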
We're probably going to want to add callbacks to more things later, so by putting it into a context manager just once, I think we're going to reduce our maintenance burden. Well, I know we will, because I've had a similar thing in fastai for some years now, and it's been quite convenient. So that's what this context manager is about. Other than that, the code's exactly the same. We create our optimizer, and then, with our callback context manager for fit, we go through each epoch and call one_epoch; that sets the model to training or evaluation mode based on the argument we pass in, and grabs the training or validation set based on the same argument. Then, using the context manager for epoch, we go through each batch in the data loader, and for each batch we use the batch context manager. Now, this is where something gets quite interesting. We call predict, get_loss, and, if we're training, backward, step, and zero_grad. But previously, we were calling the model, calling the loss function, calling loss.backward, opt.step, and opt.zero_grad; now we are calling self.predict, self.get_loss, self.backward, and so on instead. And how on earth is that working, when they're not defined here at all? The reason I've decided to do this is that it gives us a lot of flexibility: we can now create our own way of doing predict, get_loss, backward, step, and zero_grad in different situations, and we're going to see some of those situations. So what happens if we call self.predict and it doesn't exist? Well, it doesn't necessarily cause an error. What actually happens is that Python calls a special magic method called dunder getattr, __getattr__.
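That dispatch can be sketched on its own; this stand-in Learner forwards the five names to its callbacks via __getattr__, and the one-method TrainCB here is deliberately minimal (the real one defines all five):

```python
from functools import partial

class Learner:
    _forwarded = ('predict', 'get_loss', 'backward', 'step', 'zero_grad')

    def __init__(self, cbs):
        self.cbs = cbs
        for cb in cbs: cb.learn = self   # give each callback a handle on us

    def callback(self, nm):
        # call the named method on every callback that defines it
        for cb in self.cbs:
            method = getattr(cb, nm, None)
            if method is not None: method()

    def __getattr__(self, nm):
        # only called when normal attribute lookup fails; forward the five
        # special names to the callbacks instead of raising AttributeError
        if nm in self._forwarded: return partial(self.callback, nm)
        raise AttributeError(nm)

class TrainCB:
    def predict(self):
        self.learn.preds = self.learn.model(self.learn.batch)

learn = Learner([TrainCB()])
learn.model = lambda x: x * 2
learn.batch = 3
learn.predict()   # no predict on Learner: dispatched to TrainCB.predict
```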
We've seen __getattr__ before. And what I'm doing here is saying: okay, if the missing attribute is one of these special five names, don't raise an AttributeError, which is the default thing it does, but instead call self.callback, passing in that name. So it's actually going to call self.callback('predict'), and self.callback is exactly the same as before. And so what that means is that, to make this work exactly the same as it did before, I need a callback which does these five things, and here it is; I'm going to call it TrainCB. So here are the five things: predict, get_loss, backward, step, and zero_grad. They're almost exactly the same as they looked in our intermediate Learner, except now I just need to put self.learn in front of everything, because, remember, this is the callback, not the Learner, and the callback can access the Learner using self.learn. So self.learn.preds is self.learn.model called on self.learn.batch, well, just the independent variables; ditto for the loss, which calls the loss function; and backward, step, and zero_grad. So at this point, this isn't doing anything that it wasn't doing before, but the nice thing is that now, if you want to use Hugging Face Accelerate, or you want something that works on Hugging Face datasets-style dictionary things, or whatever, you can change exactly how it behaves just by creating a callback for training. And if you want everything except one thing to be the same, you can inherit from TrainCB. I've not tried this before; I haven't seen it done anywhere else, so it's a bit of an experiment. Let me know how you go with it. And then, finally, I thought it'd be nice to have a progress bar.
So let's create a progress callback. The progress bar is going to show our current loss on it, and it's going to create a plot of it. So I'm going to use a project that we created called fastprogress, mainly created by the wonderful Sylvain. Basically, fastprogress is a very nice way to create very flexible progress bars. So let me show you what it looks like first: let's get the model and train, and as you can see, it actually updates the graph and everything in real time. There you go; that's pretty cool. So that's the progress bar, the metrics callback, the device callback, and the training callback all in action. So, before we fit, we actually have to set self.learn.epochs. Now, that might look a little bit weird, but self.learn.epochs is the thing that we loop through, `for self.epoch in self.epochs`, so we can change it so that it's not just a normal range, but instead a progress bar wrapped around a range. We can then check, remember I told you that the Learner is going to have the metrics attribute applied: if the Learner has a metrics attribute, then we replace its _log method with ours, and ours will instead write to the progress bar. Now, this one's pretty simple; it looks very similar to before, but we could easily replace this with something that creates an HTML table, which is another thing fastprogress does, or other stuff like that. So the nice thing is we can modify how our metrics are displayed. That's a very powerful thing Python lets us do: actually replace one piece of code with another. And that's the whole purpose of why the metrics callback had this _log defined separately. So why didn't I just say print here?
Oh, that's because this way classes can replace how the metrics are displayed, so we could change that to, like, send them over to Weights & Biases, for example, or create visualizations, and so forth. So, before each epoch, we do a very similar thing with self.learn.dl: we change it to have a progress bar wrapped around it. And then, after each batch, we set the progress bar's comment to be the loss, so it's going to show the loss on the progress bar as it goes; and if we've asked for a plot, then we will append the losses to a list of losses, and we will update the graph with the losses and the batch numbers. So there we have it: we have a nice working Learner, which is, I think, the most flexible training loop, probably, I hope, that's ever been written; I think the fastai 2 one was the most flexible that had ever been written before, and this is more flexible. And the nice thing is that you can make this your own, you know; you can fully understand this training loop. So it's kind of like you can use a framework, but it's a framework in which you're totally in control, and you can make it work exactly how you want to. Ideally, not by changing the Learner itself; ideally, by creating callbacks. But if you want to, you certainly could; the whole Learner fits on a single screen, so you could certainly change it. We haven't added inference yet, although that shouldn't be too much to add; I guess we'll have to do that at some point. Okay, now, interestingly, I love this about Python, it's so flexible: when we said self.predict and self.get_loss,
I said that if they don't exist, then it's going to use __getattr__, and it's going to try to find those in the callbacks. And in fact, you could have multiple callbacks that define these things, and then they would chain together, which would be kind of interesting. But there's another way we could make these exist, which is that we could subclass the Learner. So let's not use TrainCB, just to show how this would work, and instead we're going to use a subclass. So here I'm going to subclass Learner, and I'm going to define the five methods directly in the Learner subclass (well, it's not exactly overriding, since I didn't have any definition of them before). That way it's never going to end up going to __getattr__, because __getattr__ is only called if something doesn't exist. So here, all five are basically exactly the same as in our TrainCB, except we don't need self.learn anymore; we can just use self, because we're now in the Learner. But I've changed zero_grad to do something a bit crazy. I'm not sure if this has been done before; I haven't seen it, but maybe it's an old trick that I just haven't come across. It occurred to me that zero_grad, which, remember, is the thing that we call after we take the optimizer step, doesn't actually have to zero the gradients at all. What if, instead of zeroing the gradients, we multiplied them by some number, like, say, 0.85? Well, what would that do? What it would do is it would mean that your previous gradients would still be there, but they would be reduced a bit. And remember what happens in PyTorch: PyTorch always adds the new gradients to the existing gradients, and that's why we normally have to call zero_grad. But if instead we multiply the gradients by some number... I mean, we should really make this a parameter. Let's do that.
Shall we? So let's create a parameter. There are probably a few ways we could do this. Well, let's do it properly; we've got a little bit of time. So we could copy and paste all those arguments over here, and we'll add momentum: momentum=0.85, then self.momentum = momentum, and then super(), making sure we call the superclass's __init__, passing in all the stuff. We could use delegates and kwargs for this; that would possibly be another great way of doing it, but let's just do this for now. Okay, and then here we wouldn't make it 0.85; we'd make it self.momentum. So you'll see it still trains, but there's no TrainCB callback in my list anymore; I don't need one, because I've defined the five methods in the subclass. Now, this is training at the same learning rate for the same time, but the accuracy has improved by more. Let's run them all. Yeah, this is a lot like a gradient accumulation callback, but kind of cooler, I think. Okay, so let's see: the loss has gone from 0.8 to 0.55, and the accuracy has gone from about 0.7 to about 0.8, so they've improved. Why is that? Well, we're going to be learning a lot more about this pretty shortly, but basically what's happening here is that we have just implemented, in a very interesting way which I haven't seen done before, something called momentum. And basically, what momentum does is this: imagine you've got some kind of complex contoured loss surface, so imagine these are hills, with a marble, and your marble is up here. What would normally happen with gradient descent is that it would go in the direction downhill, which is this way.
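Before going on with the intuition, here's a runnable sketch of that grad-scaling subclass; the base Learner here is a tiny stand-in (the real one takes the model, data loaders, loss function, and so on), and the "model" is just a list of parameter dicts so the arithmetic is visible:

```python
class Learner:
    # minimal stand-in base class, just enough to pass the model through
    def __init__(self, model): self.model = model

class MomentumLearner(Learner):
    def __init__(self, model, momentum=0.85):
        self.momentum = momentum
        super().__init__(model)   # pass the rest through to the superclass

    def zero_grad(self):
        # decay the gradients instead of zeroing them: since backward()
        # *adds* into .grad, this leaves an exponentially weighted sum of
        # past gradients behind, i.e. momentum stored in .grad itself
        for p in self.model:
            p['grad'] *= self.momentum

# simulate three steps where backward() always contributes a gradient of 1.0
model = [{'grad': 0.0}]
learn = MomentumLearner(model)
seen = []
for _ in range(3):
    model[0]['grad'] += 1.0        # what backward() effectively does
    seen.append(model[0]['grad'])  # the gradient the optimizer step uses
    learn.zero_grad()
# seen climbs toward 1 / (1 - 0.85): approximately [1.0, 1.85, 2.5725]
```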
So we go over here, and then over here; very slow. What momentum does is this: the first step is the same, and then the second step says, oh, I wanted to go this way, but I'm going to add together the previous direction plus the new direction, with the previous direction reduced a bit, so that would actually make me end up about here; and then the next step does the same thing. And so momentum basically makes you get to your destination much more quickly. Now, the reason I did it this way is partly just to show you; it's a bit of fun, a bit of interest. But it's very useful, because normally with momentum you have to store a complete copy, basically, of all the gradients, the momentum version of the gradients, so that you can keep track of that running exponentially weighted moving average. Using this trick, you're actually using .grad itself to store the exponentially weighted moving average. So anyway, there's a little bit of fun, which hopefully those of you who are interested in accelerated optimizers and memory saving might find a bit inspiring. All right, there's one more callback I'm going to show you before the break, which is the wonderful learning rate finder. I'm assuming that anybody who's watching this is already familiar with the learning rate finder from fastai; if you're not, there are lots of videos and tutorials around about it. It's an idea that comes from a paper by Leslie Smith from a few years ago, and the basic idea is that we will increase the learning rate.
I should have put titles on this: the x-axis here is learning rate, and the y-axis is loss. We increase the learning rate gradually over time, plot the loss against the learning rate, and find how high we can bring the learning rate before the loss starts getting worse. And you roughly want the point of steepest slope, so probably here it would be about 0.1. So it'd be nice to create a learning rate finder, and here's a learning rate finder callback. What does a learning rate finder need to do? Well, you have to tell it how much to multiply the learning rate by each batch. Let's say we add 30 percent of the learning rate each batch, and we'll store that. Before we fit, we obviously need to keep track of the learning rates, and we need to keep track of the losses, because those are the things that we put on the plot. The other thing we have to do is decide when to stop training: when has it clearly gone off the rails? And I decided that if the loss is three times higher than the minimum loss we've seen, then we should stop. So we're going to keep track of the minimum loss, and let's initially set that to infinity; it's a nice big number. Well, not quite a number, but a number-ish kind of thing. So then, after every batch: first of all, let's check that we're training. If we're not training, then we don't want to do anything; we don't use the learning rate finder during validation. So here's a really handy thing: just raise CancelEpochException, and that stops it from doing that epoch entirely. Just to see how that works: you can see here that one_epoch wraps everything in the callback context manager for "epoch", and when the exception is raised, it's been cancelled.
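The control-flow idea here can be seen in a tiny standalone sketch (the real Learner wraps this in a context manager, but the try/except is the essence; all the names here are made up for illustration):

```python
class CancelEpochException(Exception):
    pass

def one_epoch(batches, callbacks):
    # The training loop wraps each epoch in try/except, so any callback
    # can abort the rest of the epoch cleanly by raising the exception.
    try:
        for b in batches:
            for cb in callbacks:
                cb(b)
    except CancelEpochException:
        pass  # skip the remainder of this epoch, carry on with the next

seen = []
def stop_after_two(b):
    seen.append(b)
    if len(seen) >= 2:
        raise CancelEpochException()

one_epoch(range(10), [stop_after_two])
# seen is now [0, 1]: the exception skipped batches 2..9
```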
So it goes straight to the except, which goes all the way to the end of that code and skips it. So you can see that we're using exceptions as control structures, which is actually a really powerful programming technique that is really underutilized, in my opinion. Like a lot of things I do, it's somewhat controversial; some people think it's a bad idea, but I find it makes my code more concise, more maintainable, and more powerful, so I like it. So, we've got our CancelEpochException. Then we're just going to keep track of our learning rates. We're going to learn a lot more about optimizers shortly, so I won't worry too much about this, but basically the learning rates are stored by PyTorch inside the optimizer, in things called param_groups (parameter groups). Don't worry too much about the details, but we can grab the learning rate from that dictionary, and we'll learn more about that shortly. We've also got to keep track of the loss: append it to our list of losses, and if it's less than the minimum we've seen, record it as the minimum. And if it's greater than three times the minimum, then look at this, this is really cool: CancelFitException. This will stop everything, in a very nice, clean way. No need for lots of returns and conditionals and stuff like that; just raise the CancelFitException. And then finally we've got to actually update our learning rate to 1.3 times the previous one. Basically, the way you do it in PyTorch is you go through each parameter group, grab the learning rate in the dictionary, and multiply it by lr_mult. So, you've already seen it run, and at the end of running you will find that the callback now contains an lrs and a losses. So for this callback, I can't just add it directly to the callback list; I need to instantiate it first. And the reason I need to instantiate it first is that I need to be able to grab its learning rates and its losses. And in fact, you know, we could grab that whole plotting code and move it in here; there's no reason callbacks only have to have the callback methods, right? So we could do this, and now that's just going to become self. There we go. And so then we can train it again and just call lrfind.plot. So callbacks can really be quite self-contained, nice things, as you can see. So there's a more sophisticated callback, and I think it's doing a lot of really nice stuff. Now, you might have come across something in PyTorch called learning rate schedulers, and in fact we could implement this whole thing with a learning rate scheduler. It won't actually save that much code, but I just want to show you that when you use stuff in PyTorch like learning rate schedulers, you're actually using things that are extremely simple; the learning rate scheduler basically does this one line of code for us. So I'm going to now create a new LRFinderCB, and this time I'm going to use PyTorch's ExponentialLR scheduler, which is here. Now, it's interesting that the documentation of this is actually kind of wrong. It claims that it decays the learning rate of each parameter group by gamma (gamma is just some number you pass in; I don't know why it has to be a Greek letter, but it sounds fancier than "multiplying by an lr multiplier") every epoch. But it's not actually done every epoch at all. What actually happens is that in PyTorch the schedulers have a step method, and the decay happens each time you call step. And if you set gamma, which is really lr_mult, to a number bigger than one, it's not a decay.
It's an increase. So here's the difference; I'll copy and paste the previous version, so the previous version is on the top. The main difference is that in before_fit we're going to create something called self.sched, equal to the scheduler. And the scheduler, because it's going to be adjusting the learning rates, actually needs access to the optimizer, so we pass in the optimizer and the lr multiplier. And then in after_batch, rather than having this line of code, we replace it with this line of code: self.sched.step(). So that's the only difference, and, you know, we're not gaining much, as I said, by using PyTorch's ExponentialLR scheduler, but I mainly wanted to do it so you can see that things like PyTorch schedulers are not doing anything magic; they're just doing that one line of code for us. And so I run it again using this new version. Oopsie-daisy, I forgot to run this line of code. There we go. And I guess I should also add the nice little plot method; maybe we'll just move it to the bottom there: lrfind.plot. There we go, and put that one back to how it was. All right, perfect timing. So we added a few very important things here, so make sure we export, and we'll be able to use them shortly. All right, let's have an eight-minute break. Actually, let's have a ten-minute break, so I'll see you back here at eight past. All right, welcome back. One suggestion, which I really like, is that we could rename plot to after_fit, which means we should be able to just call learn.fit and delete the next line. And... let's see, that didn't work. Why not? Oh no, that doesn't work, does it, because the... hmm. You know what, I think the callback here could go into a finally block. That would actually let us always call the callback, even if we've cancelled. I think that's reasonable. That may have its own confusions.
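Pulling the pieces of this section together, here is a rough sketch of such a learning-rate-finder callback. The Learner interface it assumes (learn.training, learn.loss, learn.opt) follows this lesson's design, but the exact attribute names are assumptions, and loss is treated as a plain float for simplicity:

```python
import math

class CancelFitException(Exception):
    pass

class LRFinderCB:
    "Sketch of an LR finder: grow the lr each batch until the loss blows up."
    def __init__(self, lr_mult=1.3):
        self.lr_mult = lr_mult

    def before_fit(self, learn):
        self.lrs, self.losses = [], []
        self.min_loss = math.inf            # the "number-ish" starting point

    def after_batch(self, learn):
        if not learn.training:
            return                          # skip during validation
        self.lrs.append(learn.opt.param_groups[0]['lr'])
        loss = learn.loss
        self.losses.append(loss)
        self.min_loss = min(self.min_loss, loss)
        if loss > 3 * self.min_loss:
            raise CancelFitException()      # stop everything, cleanly
        # Bump the lr. ExponentialLR(opt, gamma=lr_mult).step() would do
        # exactly this one line for us.
        for g in learn.opt.param_groups:
            g['lr'] *= self.lr_mult
```

With a real Learner you'd pass LRFinderCB() into the callback list and plot cb.lrs against cb.losses afterwards.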
Anyway, we could try it for now, because that would let us put this after_fit in. There we go; so that automatically runs that. So that's an interesting idea, and I think I quite like it. Cool. So let's now look at notebook 10. I feel like this is the next big piece we need. We've got a pretty good system now for training models. What I think we're really missing, though, is a way to identify how our models are training. And to identify how our models are training, we need to be able to look inside them and see what's going on while they train. We don't currently have any way to do that, and therefore it's very hard for us to diagnose and fix problems. Most people have no way of looking inside their models, and so most people have no way to properly diagnose and fix them. And that's why most people, when they have a problem training their model, randomly try things until something hopefully starts working. We're not going to do that; we're going to do it properly. So, we can import the stuff that we just created in the learner. And the first thing I'm going to introduce now is a set_seed function. We've been using torch.manual_seed before, and we know all about RNGs, random number generators. We've actually got three of them: PyTorch's, NumPy's, and Python's. Let's seed all of them. And also, in PyTorch you can use a flag to ask it to use deterministic algorithms, so things should be reproducible. As we've discussed before, you shouldn't always just make things reproducible, but for lessons I think this is useful. So here's a function that lets you set a reproducible seed. All right, let's use the same dataset as before, the Fashion-MNIST dataset.
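As a sketch, a set_seed along the lines just described might look like this (seeding all three RNGs; the deterministic flag is optional because some ops don't support it):

```python
import random
import numpy as np
import torch

def set_seed(seed, deterministic=False):
    "Seed PyTorch's, NumPy's, and Python's RNGs for reproducibility."
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if deterministic:
        # Ask PyTorch to error on (or avoid) nondeterministic kernels
        torch.use_deterministic_algorithms(True)
```

Calling set_seed(42) twice and drawing the same random numbers after each call should give identical results from all three generators.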
We'll load it up in the same way, and let's create a model that looks very similar to our previous models. This one might be a bit bigger, might not; I didn't actually check. Okay, so let's use multi-class accuracy again, and the same callbacks that we used before; we'll use the TrainCB version for no particular reason. And generally speaking, we want to train as fast as possible. Not just because we don't like wasting time, but, more importantly, because the higher the learning rate you train at, the more you're often able to find a more generalizable set of weights. Also, training quickly means we look at each item in the data less often, so we're going to have fewer issues with overfitting. And generally speaking, if we can train at a high learning rate, then that means that we're learning to train in a stable way, and stable training is very good. So let's try setting a high learning rate of 0.6 and see what happens. So here's a function that's just going to create our learner with our callbacks, fit it, and return the learner in case we want to use it. And it's training... oh, and then it suddenly fell apart. So it was going well for a while, and then it stopped training nicely. One nice thing about this graph is that we can immediately see when it stops training well, which is very useful. So what happened there? Why did it go badly? We can guess that it might have been because of our high learning rate, but what's really going on? So let's try to look inside it. One way to look inside it would be to create our own sequential model, just like the sequential model we've built before. Do you remember, we created one using nn.ModuleList in a previous lesson; if you've forgotten, go back and check that out. When we call that model, we go through each layer and just call the layer. And what we could do is add something extra, which is that at each layer we could also get the mean of that layer's activations and their standard deviation, and append them to a couple of different lists: activation means and activation standard deviations. So after we call this model, those are going to contain the means and standard deviations for each layer. And then we could define dunder iter, which makes this into an iterator, so that when you iterate through this model, you iterate through its layers. So we can then train this model in the usual way, and this is going to give us exactly the same outcome as before, because I'm using the same seed; you can see it looks identical. But the difference is that instead of using nn.Sequential, we've now used something that has actually saved the means and standard deviations of each layer, and so therefore we can plot them. Okay, so here we've plotted the activation means, and notice that we've done it for every batch; that's why along the x-axis here we have batch number, and on the y-axis we have the activation means, and we have it for each layer. Rather than starting at one, because it's Python we're starting at zero: the first layer is blue, the second layer is orange, the third layer green, the fourth layer red, and the fifth layer this purple-ish kind of color. And look what's happened: the activations have started pretty small, close to zero, have increased at an exponentially increasing rate, and then have crashed. And then they've increased again exponentially and crashed again, and increased again and crashed again. And each time, they've gone up even higher and crashed, in this case, even lower. And what's
happening here when our activations are really close to zero? Well, when your activations are really close to zero, that means the inputs to each layer are numbers very close to zero, and as a result, of course, the outputs are very close to zero, because we're just doing matrix multiplies. So this is a disaster. When activations are very close to zero, they're dead units; they're not able to do anything, and you can see for ages here it's not training at all. So this is the activation means; the standard deviations tell an even stronger story. Generally speaking, you want the means of the activations to be about zero and the standard deviations to be about one. A mean of zero is fine as long as they're spread around zero, but a standard deviation close to zero is terrible, because that means all of the activations are about the same. So here, after batch 30 or so, all of the activations are close to zero and all of their standard deviations are close to zero. So all the numbers are about the same, and they're about zero; nothing's going on. And you can see the same thing happening with the standard deviations: we start with not very much variety in the weights, the variety exponentially increases, and then it crashes again; exponentially increases, crashes again. This is a classic shape of bad behavior, and with these two plots you can really understand what's going on in your model. If you train a model, and at the end of it you kind of think, "well, I wonder if this is any good?", then if you haven't looked at these plots, you don't know, because you haven't checked whether it's training nicely. Maybe it could be a lot better. We'll see some nicer training pictures later, but generally speaking you want something where your mean is always about zero and your standard deviation is always about one. And if you see that, there's a pretty good chance you're training properly; if you don't see that, you're almost certainly not training properly. Okay, so what I'm going to do in the rest of this part of the lesson is explain how to do this in a more elegant way, because, as I say, being able to look inside your models is such a critically important thing for building and debugging models. We don't have to do it manually; we don't have to create our own sequential model. We can actually use a PyTorch thing called hooks. As it says here, a hook is called when a layer that it's registered to is executed: during the forward pass, that's called a forward hook, or during the backward pass, and that's called a backward hook. And the key thing about hooks is that we don't have to rewrite the model; we can add them to any existing model. So we can just use a standard nn.Sequential, passing in our layers, which were these ones here. And we're still going to have something to keep track of the activation means and standard deviations, so just create an empty list for each layer in the model. And let's create a little function that's going to be called during the forward pass (for a forward hook) or the backward pass (for a backward hook). It's a function called append_stats. It's going to be passed the layer number (sorry, not the hook number, the layer number), the module, and the input and the output. So we're going to be grabbing the output's mean and putting it in activation means, and the output's standard deviation and putting it in activation standard deviations. So here's how you do it: we've got a model, you go through each layer of the model, and you call register_forward_hook on it. That's part of PyTorch, and we don't need to write it ourselves, because we already did, right?
It's basically just doing the same thing as this. And what function is going to be called? The function that's going to be called is the append_stats function; remember, partial is the equivalent of saying append_stats passing in i as the first argument. So if we now fit that model, it trains in the usual way, but after each layer this gets called, and so you can see we get exactly the same thing as before. So one question we get here is: what's the difference between a hook and a callback? Nothing at all. Hooks and callbacks are the same thing; it's just that PyTorch defines hooks, and they call them hooks instead of callbacks. They are less flexible than the callbacks that we used in the learner, because you don't have access to all the available state and you can't change things, but they're a particular kind of callback. It's just setting a piece of code that's going to be run for us when something happens, and in this case the something that happens is that a layer is called, either in the forward pass or in the backward pass. I guess you could describe the function that's being called back as the callback, and the thing that's doing the calling back as having the hook. I'm not sure if that level of distinction is important, but maybe that's true. Okay, so anyway, this is a little bit fussy, with all this creating globals and appending to them and stuff like that, so let's try to simplify it a little bit. What I did here was I created a class called Hook. When we create it, we pass in the module that we're hooking, so we call m.register_forward_hook, and we pass in the function that we want to be called; we're also going to pass the hook object itself into that function. Let's also define a remove, because this is actually the thing that removes the hook.
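A minimal sketch of that Hook class, plus the hook-based append_stats, reconstructed from the description above (method and attribute names are assumptions):

```python
import torch
from torch import nn

class Hook:
    "Registers a forward hook that passes this Hook object on to `func`."
    def __init__(self, module, func):
        self.hook = module.register_forward_hook(
            lambda mod, inp, outp: func(self, mod, inp, outp))

    def remove(self):
        self.hook.remove()

    def __del__(self):
        # When Python frees the object, unregister the hook too
        self.remove()

def append_stats(hook, mod, inp, outp):
    # State lives on the hook itself instead of in globals
    if not hasattr(hook, 'stats'):
        hook.stats = ([], [])          # (means, stds)
    acts = outp.detach().cpu()
    hook.stats[0].append(acts.mean().item())
    hook.stats[1].append(acts.std().item())

layer = nn.Linear(4, 4)
h = Hook(layer, append_stats)
layer(torch.randn(8, 4))               # one forward pass, one stat recorded
h.remove()                             # stop collecting
```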
We don't want it sitting around forever. __del__ is called by Python when an object is freed, so when that happens, we should also make sure that we remove the hook. Okay, so append_stats: now we're going to replace it so that it gets passed the hook instead, because that's what we asked to be passed. And if there's no .stats attribute on it yet, then let's create one. Then we take the activations we're passed, put them on the CPU, and append the mean and the standard deviation. And now the nice thing is that the stats are actually stored inside this object, which is convenient. So now we can do exactly the same thing as before, but we don't have to set up any of that global stuff or whatever. We can just say: our hooks is a Hook with that layer and that function, for all the model's layers. And so we're just calling it, and it has called register_forward_hook for us. So now when we fit, it's going to run with the hooks. There we go, it trains, and we get exactly the same shape as usual, and we get back the same results as usual. But as we can see, we're gradually making this more convenient, which is nice. We can make it nicer still, because generally speaking we're going to be adding multiple hooks, and this stuff with the list comprehension and whatever is a bit inconvenient. So let's create a Hooks class. First of all, let's see how the Hooks class works in practice. The way we're going to use it is: we say "with Hooks", pass in the model, pass in the function to use as the hook, and then we fit the model. And that's it.
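Here's a sketch of such a Hooks class: a list subclass that's also a context manager (again a reconstruction from the description, with a minimal Hook repeated so the example is self-contained):

```python
import torch
from torch import nn

class Hook:
    def __init__(self, m, f):
        self.hook = m.register_forward_hook(
            lambda mod, inp, outp: f(self, mod, inp, outp))
    def remove(self):
        self.hook.remove()

class Hooks(list):
    "A list of Hooks that cleans itself up as a context manager."
    def __init__(self, modules, func):
        # An nn.Sequential iterates over its layers, so the model itself
        # can be passed in as `modules`
        super().__init__([Hook(m, func) for m in modules])
    def __enter__(self):
        return self                    # bound to the `as` variable
    def __exit__(self, *args):
        self.remove()                  # runs when the with-block ends
    def __delitem__(self, i):
        self[i].remove()               # deleting one hook unregisters it too
        super().__delitem__(i)
    def remove(self):
        for h in self:
            h.remove()

def count_calls(hook, mod, inp, outp):
    hook.count = getattr(hook, 'count', 0) + 1

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
with Hooks(model, count_calls) as hooks:
    model(torch.randn(2, 4))           # each layer's hook fires once
```

After the with-block, the hooks are gone: further forward passes no longer touch them, but the collected state is still available on each Hook in the list.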
It's literally just one extra line of code to set up the whole thing, and then we can go through each hook and plot the mean and standard deviation of each layer. So that's how the Hooks class is going to make things much easier. In the Hooks class, as you can see, we're making it a context manager, we want to be able to loop through it, and we want to be able to index into it, so there's a lot of behavior we want. Believe it or not, all that behavior is in this tiny little thing. And we're going to use the most flexible, general way of creating context managers. Context managers are things that we can say "with"; the general way of creating one is to create a class and define two special methods, dunder enter and dunder exit. Dunder enter is a function that's going to be called when Python hits the with statement, and if you add an "as blah" after it, then the contents of that variable will be whatever is returned from dunder enter. As you can see, we just return the object itself, so the Hooks object is going to be stored in hooks. Now, interestingly, the Hooks class inherits from list. You can do this; you can actually inherit from things like list in Python. So the Hooks object is a list, and therefore we need to call the superclass's constructor, and we pass in that list comprehension we saw: the list of hooks, where we hook into each module in the list of modules we asked to hook into. Now, we're passing in a model here, but because the model is an nn.Sequential, you can actually loop through an nn.Sequential and it returns each of the layers. So this is actually very nice and concise and convenient. So that's the constructor; dunder enter just returns it; and dunder exit is what's called automatically at the end of the whole with block. So when the whole thing's finished, it's going to remove the hooks, and removing the hooks just goes through each hook and removes it. The reason we can do "for h in self" is because, remember, this is a list. And then, finally, we've got a dunder del like before, and I also added a dunder delitem; this is the thing that lets you delete a single hook from the list, which will remove that one hook and call the list's delitem. So there's our whole thing. This one's optional; it's the one that lets us remove a single hook rather than all of them. So let's just understand some of what's going on there. Here's a dummy context manager. As you can see, it's got a dunder enter which is going to return itself, and it's going to print something. So you can see here: I call "with DummyCtxMgr", and so it prints "let's go first". The second thing it's going to do is run the code inside the context manager: we've got "as dcm", so that's itself, and so it's going to call hello, which prints "hello"; here it is. And then finally it's going to automatically call dunder exit, which prints "all done", so here's "all done". So again, if you haven't used context managers before, you want to be creating little samples like this yourself and getting them to work. So this is your key homework for this week: for anything in the lesson where we're using a part of Python you're not 100% familiar with, create from scratch some simple dummy version that fully explores what it's doing. If you're familiar with all the Python pieces, then do the same with the PyTorch pieces, like with hooks and so forth. And I just wanted to show you what it's like to inherit from list. So here I'm inheriting from list, and I can redefine how dunder delitem works. So now I can create a dummy list, and it looks exactly the same as usual, but now if I delete an item from the list, it's going to call my overridden version, and then it will call the original version. And so the list has now removed that item and printed this at the same time. So you can see you can actually modify how Python works, or create your own things that get all the behavior and convenience of Python classes, like this one, and add stuff to them. So that's what's happening there. Okay, so that's our Hooks class. So the next bit was largely developed the last time we did a part 2 course, in San Francisco with Stefano, so many thanks to him for helping get this next bit looking great. We're going to create my favorite single-image explanations of what's going on inside a model. We call them the colorful dimension; they're histograms. We're going to take our same append_stats (these are all the same as before), and we're going to add an extra line of code, which is to get a histogram of the absolute values of the activations. A histogram, to remind you, is something that takes a collection of numbers and tells you how frequent each group of numbers is, and we're going to create 50 bins for our histogram. So we will use the hooks that we just created, with this new version of append_stats. So it's going to train as before, but now we're going to have this extra thing in stats.
It's just going to contain a histogram. And with that, we're now going to create this amazing plot. Now, what this plot is showing is, for the first, second, third, and fourth layers, what the training looks like, and you can immediately see the same basic pattern. But what exactly is this pattern showing? What is going on in these pictures? I think it might be best if we try to draw a picture of this. So let's take a normal histogram, where we've basically grouped the data into bins, and then we have counts of how much is in each bin. So, for example, this would be the value of the activations, and it might go, say, from zero to ten, then from ten to twenty, then from twenty to thirty; these are generally equally spaced bins. And then here is the count: the number of items in that range of values. So this is called a histogram. Now, what Stefano and I did was to turn that whole histogram into a single column of pixels. So if I take one column of pixels, that's actually one histogram, and the way we do it is we take these numbers; let's say it's 14, then 2, 7, 9, 11, 3, 2, 4, 2, say. And so then what we do is turn it into a single column. In this case we've got one, two, three, four, five, six, seven, eight, nine groups, so we would create our nine groups (sorry, they were meant to be evenly spaced; I didn't do a very good job). So we take the first group: it's 14. And what we do is we color it with a gradient, according to how big that number is. So 14 is a really big number, so depending on what gradient we use, maybe red is really big, and the next one's really small, which might be green, and then the next one's quite big, in the middle, which is like blue. The next one's getting quite big still, so maybe it goes back to more red; the next one's bigger still, so it's even more red, and so forth. So basically we're turning the histogram into a color-coded single-column plot, if that makes sense. And what that means is this: let's take layer number two here. We can take the very first column, and in the color scheme that matplotlib picked here, yellow is the most common, then light green is less common, then light blue is less common, and dark blue is zero. So you can see the vast majority is zero, and there's a few with slightly bigger numbers, which is exactly what we saw for the index-one layer here: the average is pretty close to zero, and the standard deviation is pretty small. This is giving us more information, however. As we train, at this point here there are quite a few activations that are a lot larger, as you can see, but still the vast majority of them are very small; there's a few big ones, so it's still got a bright yellow bar at the bottom. The other thing to notice here is what we've done: we've taken those stats, those histograms, stacked them all up into a single tensor, and then taken their log. Now, log1p is just log of the number plus one; that's because we've got zeros here, and so just taking the log wouldn't work, whereas this lets us see the full range more clearly.
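As a sketch of that bookkeeping (the 0-to-10 range and 50 bins are this lesson's choices, not universal constants), including the near-zero fraction used as a diagnostic:

```python
import torch

def append_hist(hists, acts):
    # 50-bin histogram of the absolute activation values, one per batch
    hists.append(acts.detach().cpu().abs().histc(50, 0, 10))

def colorful(hists):
    # Stack into a (bins, batches) image; log1p keeps the zero counts
    # finite so the full range is visible when plotted (e.g. with imshow)
    return torch.stack(hists).t().log1p()

def dead_frac(hists):
    # Share of activations in the two smallest-magnitude bins, per batch:
    # the "basically dead" units
    h = torch.stack(hists).t()
    return h[:2].sum(0) / h.sum(0)

hists = []
for _ in range(8):                     # pretend: one batch of activations each step
    append_hist(hists, torch.randn(1000))
img = colorful(hists)                  # shape (50, 8): one column per batch
```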
So that's what the log's for. Basically, what we'd really ideally like to see here is that this whole thing should be more like a rectangle: the maximum shouldn't be changing very much, and there shouldn't be a thick yellow bar at the bottom. Instead there should be a nice even gradient; each single column of pixels wants to look like a normal distribution, with a gradually decreasing number of activations as the magnitude grows. That's what we're aiming for. There's another really important, and actually easier to read, version of this, which is: what if we just took those bottom two pixels? In the bottom two pixels, we've got the smallest two equally sized groups of activations. We don't want there to be too many of them, because those are basically dead or nearly dead: they're much, much smaller than the big ones. And so taking the ratio between those bottom two groups and the total basically tells us what percentage of the activations have zero, near-zero, or extremely small magnitudes. And remember that these are absolute values. So if we plot those, you can see how bad this is; in particular, for example, at the final layer, nearly from the very start, just about all of the activations are entirely disabled. So this is bad news. If most of your model is close to zero, then most of your model is doing no work, and so it's really not working. It may look like at the very end things were improving, but as you can see from this chart, that's not true.
The vast majority is still inactive. Generally speaking, I've found that if early in training you see this rise-and-crash, rise-and-crash pattern at all, you should stop and restart training, because your model will probably never recover; too many of the activations have gone off the rails. So we want it to look kind of like this the whole time, but with less of this very thick yellow bar, which is showing us that most are inactive. Okay, so that's our activations. So we've now got, I think, all of the key pieces we need to be able to flexibly change how we train models and to understand what's going on inside our models. And so from this point, we've drilled down as deep as we need to go, and we can now start to come back up again, putting together the pieces and building up all of the things that are going to help us train models reliably and quickly. And then, hopefully, we're going to be able to successfully create from scratch some really high-quality generative models, and other models along the way. Okay, I think that's everything for this class. Next class we're going to start looking at things like initialization; it's a really important topic. If you want to do some revision before then, just make sure that you're very comfortable with things like standard deviations and other stuff like that, because we'll be using that quite a lot next time. And, yeah, thanks for joining me. I look forward to the next lesson. See you again.