Understanding callbacks in fastai

A large part of lesson 2 of fast.ai’s deep learning part 2 course focuses on understanding the callback system in fastai. If you are not participating in the course, here is an interesting talk by Sylvain Gugger that gives a high-level overview of the callback system. In one sentence, the callback system is a design pattern in the fastai library that hooks into different parts of the training loop of a neural network, so you can easily tweak the specifics of that training loop. By using callbacks, you can add learning rate annealing, keep track of metrics, or even implement a GAN without having to write an entirely new training loop. The flexibility of this callback system makes it convenient to experiment with custom neural networks.

How does this callback system work? In this article, I will highlight a few lines of the callback system that might not seem obvious at first. If you follow along, you will deepen your understanding of the fastai callback system and learn a thing or two about Python if you have not used it extensively.

Note: In this post, I look at the new callback system as discussed in lesson 2 of part 2 of the fast.ai course. This new system is not included in the fastai library yet and is thus a bit different from the current (fastai v1.0.50) callback system. If you see things in this post that you cannot find in the fastai source code, this is the reason.

General overview of the Callback system

Before diving into the specifics, I will explain the overall flow of the callback system as briefly as possible. Having high-level knowledge of the callback system is necessary to understand its details.

The new training loop is part of a Runner class. A simplified training loop in fastai looks like this - I have removed and modified parts of this class to make it easier to follow:

class Runner():
    def __init__(self, cbs=None, cb_funcs=None):
        cbs = listify(cbs)
        for cbf in listify(cb_funcs):
            cb = cbf()
            setattr(self, cb.name, cb)
            cbs.append(cb)
        self.cbs = cbs

    @property
    def opt(self):       return self.learn.opt
    @property
    def model(self):     return self.learn.model
    @property
    def loss_func(self): return self.learn.loss_func
    @property
    def data(self):      return self.learn.data

    def fit(self, epochs, learn):
        for cb in self.cbs: cb.set_runner(self)
        self.epochs,self.learn = epochs,learn
        for epoch in range(epochs):
            for xb,yb in self.data.train_dl:
                self.xb,self.yb = xb,yb
                self.pred = self.model(self.xb)
                self.loss = self.loss_func(self.pred, self.yb)
                self.loss.backward()
                self.opt.step()
                self.opt.zero_grad()

The main thing to look at here is the fit function. There is pred to get the predictions, loss to see how good these predictions are, and loss.backward() to calculate the gradient of the loss with respect to the parameters. Finally, opt.step() updates the weights by calling the PyTorch Optimizer class, and opt.zero_grad() zeroes out the gradients again.
With callbacks, a simplified version of the fit() method inside the Runner class looks like this:

def fit(self, epochs, learn):
    self.epochs,self.learn = epochs,learn
    if self('begin_fit'): return
    for epoch in range(epochs):
        for xb,yb in self.data.train_dl:
            self.xb,self.yb = xb,yb
            if self('begin_batch'): return
            self.pred = self.model(self.xb)
            if self('after_pred'): return
            self.loss = self.loss_func(self.pred, self.yb)
            # in_train is a flag set by a callback: only compute gradients
            # and update weights while training, not while validating
            if self('after_loss') or not self.in_train: return
            self.loss.backward()
            if self('after_backward'): return
            self.opt.step()
            if self('after_step'): return
            self.opt.zero_grad()

It is quite similar! The only extra parts are all these if statements that call the Runner object with an argument.

What the hell does __call__ do in Runner?

Wait. There seems to be something wrong. If I run something like self('begin_batch') right now, the code does not work. Nothing in the __init__ method enables a call like that (no, listify() does not do that either). More precisely, the Runner class is not even callable right now - something like r = Runner(); r() does not work, although self('begin_batch') seems to suggest it should. It is clear that we need to add something. This is where the special Python dunder method __call__ in the Runner class comes into play, a method that I left out of the example until now. __call__ makes instances of the Runner class callable; it lets you treat an object as a function. This means that once we implement this method, we can call a Runner object with self('begin_batch'), and the __call__ method will run when we do that. Here is the implementation of that method:

def __call__(self, cb_name):
    for cb in sorted(self.cbs, key=lambda x: x._order):
        f = getattr(cb, cb_name, None)
        if f and f(): return True
    return False

When I first looked at this method, I thought: what the hell does this do? There is a lot to unpack here. Let us go through it step by step.

def __call__(self, cb_name):

The method receives a callback name, cb_name, which allows us to do something like self('begin_batch'). This should be straightforward by now. If it is not, have a look at the previous paragraph again, and the included links.
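
If __call__ is new to you, here is a minimal standalone sketch (plain Python, not fastai code) of how it makes an instance callable:

class Greeter():
    def __call__(self, name):
        return 'hello ' + name

g = Greeter()
print(g('fastai'))  # prints 'hello fastai' - the instance is used like a function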

for cb in sorted(self.cbs, key=lambda x: x._order):

This line loops through all callbacks that were instantiated when the Runner object was created. It loops through them in the order defined in each callback class:

class Callback():
    _order=0

So each Callback can have an order, which allows for maximum flexibility: you can specify which callback executes first. This comes in handy when implementing more complex behavior where the order of execution matters.
listify() in the __init__ method of the Runner class ensures that self.cbs is a list, so you can iterate over the callbacks even if you do not instantiate the Runner class with a list of callbacks.
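
To see the ordering in action, here is a small made-up example with two callbacks that override _order:

class FirstCallback(Callback):
    _order = -10  # sorts before the default _order of 0

class LateCallback(Callback):
    _order = 10   # sorts after the default _order of 0

cbs = [LateCallback(), FirstCallback()]
print([type(cb).__name__ for cb in sorted(cbs, key=lambda x: x._order)])
# prints ['FirstCallback', 'LateCallback']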

f = getattr(cb, cb_name, None)

This line looks inside each callback with getattr and checks whether a method named cb_name is implemented. The third argument ensures that f is set to None instead of raising an AttributeError when the method cannot be found.
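
In isolation, getattr with a default value behaves like this (hypothetical callback):

class PrintingCallback(Callback):
    def begin_batch(self): print('starting a batch')

cb = PrintingCallback()
print(getattr(cb, 'begin_batch', None))  # the bound method - it is implemented
print(getattr(cb, 'after_step', None))   # None - not implemented, no error raised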

if f and f(): return True

If the callback method can be found (if f), i.e., f is not None, call it with f(), and if that call returns True, stop calling the other callbacks and make __call__ return True.

return False

This last line ensures that __call__ returns False, so the training loop continues (have another look at the fit() function if it is not clear to you why).
The line seems to be added for explicitness and is not strictly necessary, since a function that does not return anything returns None by default, and None is falsy in Python.
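
A quick demonstration of that default:

def noop(): pass

print(noop() is None)  # True - a function without a return statement returns None
print(bool(None))      # False - None is falsy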
This is worth remembering, because it is also the reason why False is used to signal that the training loop should continue, even though True is (at least in my experience) more commonly used to express that execution should continue. Returning True to continue the training loop would have a definite downside in this case:
Most callbacks do not stop the training loop - they only add a tweak to it. By treating False as "continue fitting", the majority of the callbacks do not have to return anything at all, because returning nothing yields None, which is falsy.
That is one less thing to remember when implementing most callbacks, and one less line of code per callback.

class ExampleCallback(Callback):
    def begin_batch(self):
        # ..., body of callback
        return False # not necessary, returning nothing = None = False
        # On the other hand, if True would signal to continue,
        # adding `return True` would have been necessary in most cases
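
Conversely, returning True stops the loop. For instance, a hypothetical callback that halts fitting right after the first optimizer step:

class OneBatchCallback(Callback):
    def after_step(self):
        # True propagates up through __call__, so fit() returns after one batch
        return True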

I hope this explanation of the __call__ function helps you understand what is going on inside it. If it did not, remember this: the __call__ function runs multiple times during the training loop and receives a different callback name each time it is run. It checks whether that callback name is implemented, and if it is, it executes that specific callback.

Using Runner attributes in a callback and vice versa

There is one more thing I would like to explain in this blog post.
When you implement your own callback, you often use attributes from the Runner class, because these are the things that you want to tweak (e.g., model, opt). You have access to these attributes because for cb in self.cbs: cb.set_runner(self) is implemented in the Runner class:

class Runner():
    # ...
    def fit(self, epoch, learn):
        for cb in self.cbs: cb.set_runner(self)

class Callback():
    # ...
    def set_runner(self, run): self.run=run

However, wouldn’t it be more convenient if you did not have to access each of these attributes via self.run? Yes!
That’s why __getattr__ is implemented in the Callback class:

class Callback():
    # ...
    def __getattr__(self, k): return getattr(self.run, k)

This method allows you to do something like self.model inside a Callback class, even though you did not explicitly define a self.model in that class. How does that work? The special dunder method __getattr__ runs each time you try to get an attribute that cannot be found through normal lookup. Because the body of that method is return getattr(self.run, k), the lookup is forwarded to the Runner object. It is a small but very nice addition to the Callback class that makes a big difference in the end, given the number of callbacks in the fastai library.
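
For example, a hypothetical callback can read the runner's loss directly:

class PrintLossCallback(Callback):
    def after_loss(self):
        # self.loss is not defined on the callback, so Python falls back to
        # __getattr__, which forwards the lookup to the Runner: self.run.loss
        print(self.loss)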

At the same time, it can be beneficial to access a callback's attributes from within the Runner class. For example, the ability to plot losses in fastai is implemented in a callback. Unfortunately, when you are working in your Jupyter notebook, you do not have direct access to that Callback object each time you are training. Or do you? Have a look at this piece of code:

class Runner():
    def __init__(self, cbs=None, cb_funcs=None):
        cbs = listify(cbs)
        for cbf in listify(cb_funcs):
            cb = cbf()
            setattr(self, cb.name, cb)
            cbs.append(cb)
        self.cbs = cbs

class Callback():
    @property
    def name(self):
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')
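
One thing to note: camel2snake is a small helper from the course notebooks that converts CamelCase to snake_case; a version matching the behavior described below looks roughly like this:

import re

_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')

def camel2snake(name):
    # 'LearningRate' -> 'Learning_Rate' -> 'learning_rate'
    s1 = re.sub(_camel_re1, r'\1_\2', name)
    return re.sub(_camel_re2, r'\1_\2', s1).lower()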

What is going on here? Each callback constructed from cb_funcs is set as an attribute on the Runner via setattr(self, cb.name, cb), which uses the Callback name property. This property takes the class name of the callback, removes ‘Callback’ from it with a regular expression, and returns a camel2snake version of what remains. What does this mean in practice?
Let us say you have implemented a callback named LearningRateCallback that keeps track of the learning rate during training and allows you to plot it via a method called plot() (in fact, fastai has something very similar via the more general Recorder callback and its plot_lr method).
The name property removes Callback from LearningRateCallback and transforms the rest to snake case: learning_rate.
setattr(self, cb.name, cb) inside Runner then sets this learning_rate as an attribute on the Runner object. This allows you to do something like this:

run = Runner(cb_funcs=LearningRateCallback)
# ... run.fit(epochs, learn)
run.learning_rate.plot()

That is quite neat and useful - you no longer have to think about how to access each callback when you are working with the Runner.

Let me know if you have any questions, if there is anything I could have explained better or if you have any feedback in general on the fastai forum.