Have you ever wondered how PyTorch's nn.Module works?
I was always curious to understand how the internals work too. Recently I was reading chapter 19 of Fast.ai's Deep Learning for Coders book, where we learn how to build minimal versions of PyTorch and FastAI modules like
- Dataset, Dataloaders
- Modules
- FastAI Learner
This intrigued me enough to take a look at the PyTorch source code for nn.Module. The code for nn.Module is 1000+ lines. After a few cups of coffee, I was able to make sense of what is happening inside. Hopefully, by the end of this post, you will have an understanding of what goes on inside nn.Module without those cups of coffee.
A simple PyTorch model
All models in PyTorch subclass the nn.Module class, which implements various methods. Some of the methods we usually use are:
- to
- cuda
- zero_grad
- train
- eval
- load_state_dict
A simple example looks like this.

    import torch.nn as nn
    import torch.nn.functional as F


    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            # Assigning layers as attributes registers them with the module.
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Let's look at the __dict__ attribute of our model, which holds the attributes of our SimpleModel instance.

    model = SimpleModel()
    model.__dict__

    {'training': True,
     '_parameters': OrderedDict(),
     '_buffers': OrderedDict(),
     '_non_persistent_buffers_set': set(),
     '_backward_hooks': OrderedDict(),
     '_is_full_backward_hook': None,
     '_forward_hooks': OrderedDict(),
     '_forward_pre_hooks': OrderedDict(),
     '_state_dict_hooks': OrderedDict(),
     '_load_state_dict_pre_hooks': OrderedDict(),
     '_modules': OrderedDict([('conv1',
                   Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))),
                  ('conv2', Conv2d(20, 20, kernel_size=(5, 5), stride=(1, 1)))])}

If we look closely, we can see that our modules/layers like conv1 end up in the _modules dict internally. After observing the above output, I was wondering:
- How did these modules end up in _modules and not as direct attributes?
- How are we able to access conv1 like this: model.conv1?
- How are we able to pass input to a model as if it were a function: model(input)?
- How was tab completion able to suggest the keys of _modules when we typed model. and pressed tab?
Let's put on our detective hat and search for answers.
How did the conv1 attribute end up in _modules?
__init__ and __setattr__ are responsible for this behaviour. Let's look at minimal versions of these functions.

    def __init__(self):
        # Registries that hold parameters and submodules by name.
        self._parameters = OrderedDict()
        self._modules = OrderedDict()

    def __setattr__(self, name, value):
        if isinstance(value, nn.Parameter):
            self._parameters[name] = value
        elif isinstance(value, nn.Module):
            self._modules[name] = value
        else:
            # Everything else is stored as a normal attribute.
            object.__setattr__(self, name, value)

__setattr__ is called by Python whenever we assign an attribute, for example self.conv1 = ... inside __init__. Based on the type of the value, PyTorch decides whether to store it in _modules, in _parameters, or on the object itself (which is the default behavior).
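
To watch this routing happen without reading the real source, here is a small pure-Python sketch. MiniParameter and MiniModule are hypothetical stand-ins I made up for nn.Parameter and nn.Module; they imitate only the bookkeeping, not anything else PyTorch does.

```python
from collections import OrderedDict


class MiniParameter:
    """Hypothetical stand-in for nn.Parameter; it only wraps a value."""

    def __init__(self, data):
        self.data = data


class MiniModule:
    """Hypothetical stand-in for nn.Module; imitates only attribute routing."""

    def __init__(self):
        # Bypass our own __setattr__ so the registries are plain attributes.
        object.__setattr__(self, "_parameters", OrderedDict())
        object.__setattr__(self, "_modules", OrderedDict())

    def __setattr__(self, name, value):
        if isinstance(value, MiniParameter):
            self._parameters[name] = value
        elif isinstance(value, MiniModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)


class Child(MiniModule):
    pass


class Parent(MiniModule):
    def __init__(self):
        super().__init__()
        self.conv1 = Child()              # routed to _modules
        self.weight = MiniParameter(1.0)  # routed to _parameters
        self.name = "parent"              # stored as a normal attribute


parent = Parent()
print(list(parent._modules))     # ['conv1']
print(list(parent._parameters))  # ['weight']
print(sorted(parent.__dict__))   # ['_modules', '_parameters', 'name']
```

Notice that conv1 and weight never reach __dict__, so with this sketch alone parent.conv1 would raise an AttributeError.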
How are we able to access conv1 like model.conv1?
__getattr__ is called by Python when normal attribute lookup fails, that is, when the name is found neither in the instance's __dict__ nor on its class. So to access conv1, we need to check whether it is present in either _modules or _parameters and return the value if it is. The actual code contains a number of validations; a minimal version would look like this.
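
    def __getattr__(self, name):
        if name in self._modules:
            return self._modules[name]
        if name in self._parameters:
            return self._parameters[name]
        # Without this, unknown attributes would silently evaluate to None.
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
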
How does tab completion work?
I was happy with the progress so far, but I got disappointed when I tried building a model with our newly built module.
If we type model. and press tab, nothing is suggested for our layers, but it works with the PyTorch module.
After some digging, I figured out that __dir__ is responsible for it. Let's look at the minimal version.

    def __dir__(self):
        module_attrs = dir(self.__class__)
        attrs = list(self.__dict__.keys())
        modules = list(self._modules.keys())
        parameters = list(self._parameters.keys())
        keys = module_attrs + attrs + modules + parameters
        return keys

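
Here is a small before-and-after sketch of that effect, using dir() directly since completers such as the standard library's rlcompleter build their suggestions from it. MiniModule and WithDir are hypothetical classes written only for this comparison.

```python
from collections import OrderedDict


class MiniModule:
    """Hypothetical stand-in for nn.Module without a custom __dir__."""

    def __init__(self):
        object.__setattr__(self, "_modules", OrderedDict())

    def __setattr__(self, name, value):
        if isinstance(value, MiniModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)


class WithDir(MiniModule):
    """Same routing, plus a __dir__ that exposes the registered names."""

    def __dir__(self):
        # Start from what Python would list anyway, then add registered names.
        names = set(dir(type(self))) | set(self.__dict__) | set(self._modules)
        return sorted(names)


plain = MiniModule()
plain.conv1 = MiniModule()
print("conv1" in dir(plain))  # False

fixed = WithDir()
fixed.conv1 = MiniModule()
print("conv1" in dir(fixed))  # True
```

Using a set here also removes duplicate names, which the list concatenation in the minimal version does not.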
How the forward method works
Ever wondered how, when we pass inputs to the model object like model(x), PyTorch somehow calls the forward method? Python calls __call__ when we use an object as a function, and our forward method gets called inside __call__. A very minimal version looks like this.

    def __call__(self, *input, **kwargs):
        return self.forward(*input, **kwargs)

In the actual code, PyTorch does more than just call the forward method.
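
One of those extra jobs is running hooks around forward; the _forward_pre_hooks and _forward_hooks dicts we saw in __dict__ earlier are where they live. The sketch below is a simplified approximation, not the real implementation: MiniModule and Double are hypothetical, and I write into the hook dicts directly instead of going through PyTorch's registration methods.

```python
from collections import OrderedDict


class MiniModule:
    """Hypothetical stand-in for nn.Module; imitates only the call path."""

    def __init__(self):
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()

    def __call__(self, *args, **kwargs):
        # Pre-hooks may replace the positional inputs by returning a value.
        for hook in self._forward_pre_hooks.values():
            result = hook(self, args)
            if result is not None:
                args = result if isinstance(result, tuple) else (result,)
        output = self.forward(*args, **kwargs)
        # Forward hooks may replace the output by returning a value.
        for hook in self._forward_hooks.values():
            result = hook(self, args, output)
            if result is not None:
                output = result
        return output

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class Double(MiniModule):
    def forward(self, x):
        return 2 * x


model = Double()
print(model(3))  # 6

# Simplification: assign hooks straight into the dicts.
model._forward_pre_hooks["add_one"] = lambda module, args: (args[0] + 1,)
model._forward_hooks["negate"] = lambda module, args, output: -output
print(model(3))  # -8
```

This is also why we call model(x) rather than model.forward(x): calling forward directly would skip the hooks.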
Other interesting methods
If reading this post has sparked your interest to understand more, you can try understanding the functions below.
- get_submodule
- add_module
- apply
- float/double
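
As a starting point for apply, here is a rough sketch of the recursion as I understand it: the function is applied to every child first and then to the module itself. MiniModule is again a hypothetical stand-in, and its add_module and children are stripped down to the bare minimum.

```python
class MiniModule:
    """Hypothetical stand-in for nn.Module; imitates only the module tree."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def add_module(self, name, module):
        # Register a child module under the given name.
        self._modules[name] = module

    def children(self):
        # Iterate over direct children only.
        return iter(self._modules.values())

    def apply(self, fn):
        # Visit children recursively first, then the module itself.
        for child in self.children():
            child.apply(fn)
        fn(self)
        return self


root = MiniModule("root")
block = MiniModule("block")
block.add_module("conv", MiniModule("conv"))
root.add_module("block", block)
root.add_module("head", MiniModule("head"))

visited = []
root.apply(lambda m: visited.append(m.name))
print(visited)  # ['conv', 'block', 'head', 'root']
```

Returning self is what makes it possible to chain the call, for example when initialising weights right after building a model.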
I hope you enjoyed reading the blog.
You can find the source code for nn.Module here.