Understanding callbacks in fastai
A large part of lesson 2 of fast.ai’s deep learning part 2 course focuses on understanding the Callback system in fast.ai.
If you are not participating in the course, here is an interesting talk by Sylvain Gugger to get a high-level overview of the callback system.
In one sentence: the callback system is a design pattern in the fastai library that lets you hook into different parts of a neural network's training loop and tweak its specifics without rewriting the loop itself.
By using callbacks, you can add learning rate annealing, keep track of metrics, or even implement a GAN without having to write an entirely new training loop. The flexibility of this callback system makes it convenient to experiment with custom neural networks.
How does this callback system work? In this article, I will highlight a few lines of the callback system that might not seem very obvious. If you follow along, you will deepen your understanding of the fast.ai callback system and learn a thing or two about Python if you have not used it extensively.
Note: In this post, I look at the new Callback system as discussed in lesson two part two of the fast.ai course. This new system is not included in the fastai library yet, and thus a bit different from the current (fastai v1.0.50) callback system. If you see things in this post that you cannot find in the fastai source code, this is the reason.
General overview of the Callback system
Before diving into the specifics, I explain the current flow of the Callback system as briefly as possible. Having high-level knowledge of the callback system is necessary to understand its details.
The new training loop is part of a Runner class. A simplified training loop in fastai looks like this - I removed and modified parts of this class to make it easier to follow:
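A minimal sketch of such a loop, assuming the Runner holds model, opt, loss_func and data attributes (my reconstruction of the condensed class, not the exact course code):

```python
class Runner():
    def __init__(self, model, opt, loss_func, data):
        self.model, self.opt = model, opt
        self.loss_func, self.data = loss_func, data

    def fit(self, epochs):
        for epoch in range(epochs):
            for xb, yb in self.data:
                pred = self.model(xb)            # forward pass: get predictions
                loss = self.loss_func(pred, yb)  # how good are the predictions?
                loss.backward()                  # gradients w.r.t. the parameters
                self.opt.step()                  # update the weights
                self.opt.zero_grad()             # reset gradients for the next batch
```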
The main thing to look at here is the fit function. There is pred to get the predictions, loss to see how good these predictions are, and loss.backward() to calculate the gradient of the loss with respect to the parameters. Finally, opt.step() updates the weights by calling the PyTorch Optimizer class, and opt.zero_grad() zeroes out the gradients again.
With callbacks, a simplified version of the fit() method inside the Runner class looks like this:
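A sketch of that callback-aware version; the callback names follow the lesson's pattern, and the stubbed __call__ is my reconstruction so the snippet is self-contained:

```python
class Runner():
    def __init__(self, model, opt, loss_func, data, cbs=None):
        self.model, self.opt = model, opt
        self.loss_func, self.data = loss_func, data
        self.cbs = cbs or []

    def __call__(self, cb_name):
        # stub so the sketch runs; the real version is discussed later in the post
        for cb in self.cbs:
            f = getattr(cb, cb_name, None)
            if f and f(): return True
        return False

    def fit(self, epochs):
        if self('begin_fit'): return
        for epoch in range(epochs):
            for xb, yb in self.data:
                self.xb, self.yb = xb, yb
                if self('begin_batch'): return
                self.pred = self.model(self.xb)
                if self('after_pred'): return
                self.loss = self.loss_func(self.pred, self.yb)
                if self('after_loss'): return
                self.loss.backward()
                if self('after_backward'): return
                self.opt.step()
                if self('after_step'): return
                self.opt.zero_grad()
            self('after_epoch')
        self('after_fit')
```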
It is quite similar! The only extra parts are all these if statements that call the runner class with an argument.
What the hell does __call__ do in the Runner class?
Wait. There seems to be something wrong. If I run something like self('begin_batch') right now, the code does not work. The init method does not allow something like this (no, listify() does not do that either). More precisely, the Runner class is not even callable right now - something like r = Runner(); r() does not work, although the calls inside fit() seem to suggest it should.
It is clear that we need to add something. This is where the special Python dunder method __call__ in the Runner class comes into play, a method that I left out of the example until now. __call__ makes the Runner class callable; it lets you treat an object as a function. This means that when we implement this method, we can call the Runner class with self('begin_batch'), and the __call__ method will run when we do that. Here is the implementation of that method:
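Roughly, reconstructed from the lesson (small details may differ), and wrapped in a minimal class so it runs on its own:

```python
class Runner():
    def __init__(self, cbs=None):
        self.cbs = cbs or []

    def __call__(self, cb_name):
        for cb in sorted(self.cbs, key=lambda cb: cb._order):
            f = getattr(cb, cb_name, None)  # None if the callback does not implement cb_name
            if f and f(): return True       # a truthy return stops the remaining callbacks
        return False                        # False tells fit() to keep going
```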
When I first looked at this method, I thought: what the hell does this do? There is a lot to unpack here. Let us go through it step by step.
The method receives a callback name, cb_name, which allows us to do something like self('begin_batch'). This should be straightforward by now. If it is not, have a look at the previous paragraph again, and at the included links.
This line loops through all callbacks that were passed in when the Runner class was instantiated. It loops through them in the order that was defined in each callback class:
So each Callback may have an order attribute, which allows for maximum flexibility: you can define the order of the callbacks and thereby specify which one executes first. This can become quite handy when implementing more involved things where the order of execution matters.
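For illustration, two callbacks with different positions in the execution order; the attribute name _order follows the lesson, and ChartingCallback is a made-up example:

```python
class Callback():
    _order = 0  # default position in the callback execution order

class TrainEvalCallback(Callback):
    _order = 0   # run early: basic bookkeeping

class ChartingCallback(Callback):
    _order = 10  # run late: depends on values that earlier callbacks set
```

Sorting the callbacks by _order, as __call__ does, then guarantees that TrainEvalCallback runs before ChartingCallback, no matter in which order they were passed to the Runner.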
listify() in the init method of the Runner class ensures that self.cbs is a list, so you can iterate over the callbacks even if you do not instantiate the Runner class with a list of callbacks.
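A sketch of what listify does, modeled on the course's helper:

```python
from collections.abc import Iterable

def listify(o):
    if o is None: return []            # no callbacks -> empty list
    if isinstance(o, list): return o   # already a list, pass through
    if isinstance(o, str): return [o]  # strings are iterable, but we want one item
    if isinstance(o, Iterable): return list(o)
    return [o]                         # wrap a single object
```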
This line looks inside each callback with getattr and checks whether cb_name is implemented. If it cannot find it, f is set to None, the default value passed to getattr.
If the callback name can be found (if f), i.e., f is not None, call it with f(), and if that call returns True, stop calling the other callbacks and exit the loop by returning True.
This last line ensures that __call__ returns False, so the training loop continues (have a look at the fit() function if it is not clear to you why this is true). The line seems to be added to be explicit and is not strictly necessary, since a function that does not return anything returns None in Python, and None evaluates to False in a condition.
This is an important thing to remember, because it is also the reason why False is used to signal that the training loop should continue, although True is (at least in my experience) usually used to express that the execution of something should continue. However, returning True to continue the training loop would have a definite downside in this case: most implemented callbacks do not stop the training loop - they only add a tweak to it. By letting False mean "continue fitting", the majority of the callbacks do not have to return anything, because returning nothing returns None, which evaluates to False. That is one less thing to remember when implementing most callbacks, and one less line of code per callback.
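You can verify that behavior with a tiny example:

```python
def begin_batch():
    pass  # a typical callback hook: tweak something, return nothing

result = begin_batch()  # result is None, because there is no return statement
stop = bool(result)     # None is falsy, so stop is False: training continues
```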
I hope this explanation of the __call__ function helps you understand what is going on inside it. If it did not, remember this: the __call__ function runs multiple times during the training loop and receives a different callback name each time it is run. It checks whether that callback name is implemented, and if that is the case, it executes that specific callback.
Using Runner attributes in a callback and vice versa
There is one more thing I would like to explain in this blog post.
When you implement your callback, you often use attributes from the Runner class, because these are the things that you want to tweak (e.g., optim). You have access to these attributes because for cb in self.cbs: cb.set_runner(self) is implemented in the Runner class:
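A minimal reconstruction of that wiring: set_runner simply stores a reference to the Runner on each callback.

```python
class Callback():
    def set_runner(self, run):
        self.run = run  # the callback now knows its Runner

class Runner():
    def __init__(self, cbs=None):
        self.cbs = cbs or []
        for cb in self.cbs:
            cb.set_runner(self)  # hand every callback a reference back to us
```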
However, wouldn’t it be more convenient if you did not have to access each of these attributes via self.run? That’s why __getattr__ is implemented in the Callback class:
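The method is essentially a one-liner; a minimal reconstruction:

```python
class Callback():
    def set_runner(self, run):
        self.run = run

    def __getattr__(self, k):
        # only called when normal attribute lookup fails,
        # so attributes defined on the callback itself still win
        return getattr(self.run, k)
```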
This method allows you to do something like self.model inside a Callback class, although you did not explicitly define self.model in that class. How does that work? The special dunder method __getattr__ runs each time you try to get an attribute that is not implemented. Because the body of that method is return getattr(self.run, k), it checks whether that attribute is implemented in the Runner class. It is a small addition to the Callback class, but a very nice one that makes a big difference in the end, given the number of callbacks in the fastai library.
At the same time, it can be beneficial to access callback attributes from within the Runner class. For example, the ability to plot losses in fastai is implemented in a callback. Unfortunately, when you are working in your Jupyter notebook, you do not have direct access to that Callback class each time you are training. Or do you? Have a look at this piece of code:
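A minimal reconstruction of the relevant part of the Runner init (what cb.name produces is discussed next):

```python
class Runner():
    def __init__(self, cbs=None):
        self.cbs = cbs or []
        for cb in self.cbs:
            # expose each callback as an attribute on the runner,
            # under the name the callback reports for itself
            setattr(self, cb.name, cb)
```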
What is going on here? Each callback that is passed to the Runner class is set as an attribute via setattr(self, cb.name, cb), which calls the Callback name() method. This method takes the name of the callback class, removes ‘Callback’ from that name with a regular expression and returns a ‘camel2snake’ version of what remains. What does this mean in practice?
Let us say you have implemented a callback named LearningRateCallback that keeps track of the learning rate during training and allows you to plot the learning rate via a method called plot() (in fact, fastai has something very similar via a more general Recorder callback).
The code in def name() removes ‘Callback’ from LearningRateCallback and transforms what is left to snake case: learning_rate. setattr(self, cb.name, cb) inside Runner then sets this learning_rate as an attribute inside the Runner class.
This allows you to do something like this:
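For example, with the hypothetical LearningRateCallback from above, stubbed out here so the snippet is self-contained:

```python
class LearningRateCallback():
    # stands in for what the name property from the text produces
    name = 'learning_rate'

    def __init__(self): self.lrs = []

    def plot(self):
        # stand-in for plotting self.lrs with matplotlib
        return self.lrs

class Runner():
    def __init__(self, cbs=None):
        self.cbs = cbs or []
        for cb in self.cbs:
            setattr(self, cb.name, cb)

run = Runner(cbs=[LearningRateCallback()])
run.learning_rate.plot()  # access the callback straight from the runner
```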
That is quite neat and useful - you do not have to think about how to access each callback when you are working with Runner anymore.
Let me know if you have any questions, if there is anything I could have explained better or if you have any feedback in general on the fastai forum.