As I’ve used PyTorch’s nn.Module for more complex deep learning systems, I’ve gradually discovered that nn.Module has many mechanisms I wasn’t previously aware of, and they can become pitfalls when coding in PyTorch. The pitfall I’m discussing here concerns the module registration mechanism. PyTorch identifies the submodules of a module through attribute assignment: when you assign an nn.Module to an instance attribute, nn.Module.__setattr__() records it in an internal registry (this is what the docs call “registration”), and that registry is created in the __init__() method of nn.Module.
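For instance, here is a minimal sketch of what registration looks like from the outside (the class Net and its fc layer are made-up names for illustration):

from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning an nn.Module to an attribute goes through nn.Module.__setattr__,
        # which stores it in the internal self._modules dict ("registration").
        self.fc = nn.Linear(4, 2)

net = Net()
print(dict(net.named_children()))  # {'fc': Linear(in_features=4, out_features=2, bias=True)}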
[Figure: nn.Module in PyTorch Docs]

Don’t Initialize nn.Module Twice
The pitfall in short: never call nn.Module.__init__() twice!

I encountered a strange situation in my project involving diamond inheritance with nn.Module classes (see my previous article). The following toy example illustrates the problem:
from torch import nn

class A(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.a = nn.Linear(1, 1)

    def forward(self):
        return

class B(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.b = nn.Linear(2, 2)

    def forward(self):
        return

class C(A, B):
    def __init__(self):
        A.__init__(self)
        B.__init__(self)
        self.c = nn.Linear(3, 3)

    def forward(self):
        return

C()
Note that I implemented the __init__() calls manually by referring to class names, instead of using super(), because I didn’t fully understand the mechanism of super() at the time. Please refer to the previous article for details.
Although this code ends up calling nn.Module.__init__() twice, it still assigns all of the submodules a, b, and c, so we would expect a C instance to have all three properties after instantiation. However, only b and c end up registered on the C instance, while a does not. We added some debug prints:
class C(A, B):
    def __init__(self):
        A.__init__(self)
        print(self.a)
        B.__init__(self)
        print(self.b)
        print(self.a)
        self.c = nn.Linear(3, 3)
and found that the first print(self.a) and the print(self.b) execute smoothly, but the second print(self.a) raises an AttributeError saying that the object has no attribute a. However, if self.a were any other kind of attribute rather than an nn.Module, it would print successfully!
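To see this difference, here is a minimal sketch comparing a plain attribute with a registered submodule, reusing classes A and B from the toy example (the class name C2 and the tag attribute are made up for illustration):

class C2(A, B):
    def __init__(self):
        A.__init__(self)           # registers self.a in self._modules
        self.tag = "hello"         # plain attribute, stored in the instance __dict__
        B.__init__(self)           # re-runs nn.Module.__init__(), wiping self._modules
        print(self.tag)            # "hello": ordinary attributes survive
        print(hasattr(self, "a"))  # False: the registered submodule is gone
        self.c = nn.Linear(3, 3)

C2()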
This is the pitfall. The only possible explanation is that something happens inside B.__init__(self) that causes a to be lost, and it only happens to nn.Module attributes.

We can find the reason simply by examining the source code. The answer lies in the internal mechanism of nn.Module for module registration: registered submodules are not kept in the instance __dict__ but in internal containers such as _modules, and nn.Module.__init__() (re)creates all of these containers through a series of super().__setattr__() calls. In particular, the super().__setattr__("_modules", {}) call replaces the registry with an empty dict, so the previously registered self.a disappears and nn.Module.__getattr__() can no longer find it.
class Module:
    def __init__(self, *args, **kwargs) -> None:
        ...
        super().__setattr__("training", True)
        super().__setattr__("_parameters", {})
        super().__setattr__("_buffers", {})
        super().__setattr__("_non_persistent_buffers_set", set())
        super().__setattr__("_backward_pre_hooks", OrderedDict())
        super().__setattr__("_backward_hooks", OrderedDict())
        super().__setattr__("_is_full_backward_hook", None)
        super().__setattr__("_forward_hooks", OrderedDict())
        super().__setattr__("_forward_hooks_with_kwargs", OrderedDict())
        super().__setattr__("_forward_hooks_always_called", OrderedDict())
        super().__setattr__("_forward_pre_hooks", OrderedDict())
        super().__setattr__("_forward_pre_hooks_with_kwargs", OrderedDict())
        super().__setattr__("_state_dict_hooks", OrderedDict())
        super().__setattr__("_state_dict_pre_hooks", OrderedDict())
        super().__setattr__("_load_state_dict_pre_hooks", OrderedDict())
        super().__setattr__("_load_state_dict_post_hooks", OrderedDict())
        super().__setattr__("_modules", {})

        if self.call_super_init:
            super().__init__(*args, **kwargs)
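We can confirm this without any inheritance at all: re-running __init__() on a plain module reproduces the effect (a minimal sketch; the attribute name layer is arbitrary):

from torch import nn

m = nn.Module()
m.layer = nn.Linear(1, 1)      # __setattr__ stores the submodule in m._modules["layer"]
print("layer" in m._modules)   # True

nn.Module.__init__(m)          # the second init runs super().__setattr__("_modules", {}) again
print(hasattr(m, "layer"))     # False: nn.Module.__getattr__ finds nothing in the fresh _modules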
An Inelegant Solution
Now let’s return to our original intention. The code above is meant to combine two modules into a new one, but we now see it doesn’t work with the approach shown. How can we avoid this issue?
I figured out a workaround: move the submodule registration of subclass A into a separate method and call it again explicitly in C’s __init__() method (B does not need this, since B.__init__() runs last):
class A(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.register_modules()

    def register_modules(self):
        self.a = nn.Linear(1, 1)

    def forward(self):
        return

class B(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.b = nn.Linear(2, 2)

    def forward(self):
        return

class C(A, B):
    def __init__(self):
        A.__init__(self)
        B.__init__(self)
        A.register_modules(self)  # re-register a after B.__init__() wiped _modules
        self.c = nn.Linear(3, 3)

    def forward(self):
        return
This works, although it’s not an elegant solution: self.a is registered twice, and it makes class A look somewhat unconventional. (Note that nn.Module already has a register_module() method; don’t confuse it with our custom register_modules().)
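For reference, here is a quick sketch of that built-in method; as far as I know, in recent PyTorch versions register_module(name, module) is simply an alias of add_module():

container = nn.Module()
container.register_module("proj", nn.Linear(4, 4))  # same effect as: container.proj = nn.Linear(4, 4)
print(dict(container.named_children()))              # {'proj': Linear(in_features=4, out_features=4, bias=True)}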
The Correct Way: Let super() Handle It

In fact, if I had understood the problem discussed in my previous article, I would have preferred to use super() to handle the inheritance and initialization calls automatically:
from torch import nn

class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(1, 1)

    def forward(self):
        return

class B(nn.Module):
    def __init__(self):
        super().__init__()
        self.b = nn.Linear(2, 2)

    def forward(self):
        return

class C(A, B):
    def __init__(self):
        super().__init__()
        self.c = nn.Linear(3, 3)

    def forward(self):
        return

C()
With this approach, nn.Module.__init__() is called only once, avoiding the registration conflict entirely! This is perfectly elegant, and I’m amazed by how well the Python designers made cooperative inheritance work.
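As a quick sanity check, here is a sketch of what I would expect this version to report (exact reprs depend on your environment):

# With cooperative super().__init__() calls, the MRO visits C -> A -> B -> nn.Module -> object,
# so nn.Module.__init__() runs exactly once and the registry is never wiped.
print([cls.__name__ for cls in C.__mro__])              # ['C', 'A', 'B', 'Module', 'object']

c = C()
print(sorted(name for name, _ in c.named_children()))   # ['a', 'b', 'c']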