As I used PyTorch's nn.Module for more complicated deep learning systems, I gradually found that nn.Module has many mechanisms that I didn't know before, which can also become pitfalls when coding in PyTorch. The pitfall I am talking about here concerns the module registration mechanism. The way PyTorch identifies the submodules of a module is through instance-attribute assignment (they call it registering), which relies on internal state set up in the __init__() method of nn.Module (see nn.Module in the PyTorch docs). I came across a weird situation in my project, which has diamond inheritance on nn.Module classes (see my previous post). The following toy example illustrates the problem:
from torch import nn

class A(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.a = nn.Linear(1, 1)

    def forward(self):
        return

class B(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.b = nn.Linear(2, 2)

    def forward(self):
        return

class C(A, B):
    def __init__(self):
        A.__init__(self)
        B.__init__(self)
        self.c = nn.Linear(3, 3)

    def forward(self):
        return

C()
We expect C to have three submodules, a, b, and c, after instantiation. In reality, however, only b and c are registered on the C instance, while a is not. If we debug with some prints:
class C(A, B):
    def __init__(self):
        A.__init__(self)
        print(self.a)
        B.__init__(self)
        print(self.b)
        print(self.a)
        self.c = nn.Linear(3, 3)
we find that the first print(self.a) and the print(self.b) run smoothly, but the second print(self.a) raises an AttributeError saying that the instance has no attribute a. However, if self.a is any type of value other than an nn.Module, it prints fine! This is the pitfall. The only possible explanation is that something in B.__init__(self) makes a disappear, and it only happens to nn.Module attributes.
We can find the reason simply by reading the source code. The answer is the internal registration mechanism of nn.Module: the __init__() method of nn.Module calls __setattr__ many times to set up the registries for submodules, parameters, buffers, and hooks, and most of these are resetting actions performed through direct manipulation of instance attributes. Crucially, a submodule assigned via self.name = module is stored in the _modules dict rather than in the ordinary instance __dict__, so the line super().__setattr__("_modules", {}) replaces that dict with an empty one and makes the earlier attribute self.a disappear.
class Module:
    def __init__(self, *args, **kwargs) -> None:
        ...
        super().__setattr__("training", True)
        super().__setattr__("_parameters", {})
        super().__setattr__("_buffers", {})
        super().__setattr__("_non_persistent_buffers_set", set())
        super().__setattr__("_backward_pre_hooks", OrderedDict())
        super().__setattr__("_backward_hooks", OrderedDict())
        super().__setattr__("_is_full_backward_hook", None)
        super().__setattr__("_forward_hooks", OrderedDict())
        super().__setattr__("_forward_hooks_with_kwargs", OrderedDict())
        super().__setattr__("_forward_hooks_always_called", OrderedDict())
        super().__setattr__("_forward_pre_hooks", OrderedDict())
        super().__setattr__("_forward_pre_hooks_with_kwargs", OrderedDict())
        super().__setattr__("_state_dict_hooks", OrderedDict())
        super().__setattr__("_state_dict_pre_hooks", OrderedDict())
        super().__setattr__("_load_state_dict_pre_hooks", OrderedDict())
        super().__setattr__("_load_state_dict_post_hooks", OrderedDict())
        super().__setattr__("_modules", {})

        if self.call_super_init:
            super().__init__(*args, **kwargs)
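To see the wiping mechanism in isolation, here is a stdlib-only sketch with a simplified stand-in for nn.Module (MiniModule is a mock I am introducing for illustration, not torch's real code); it keeps only the _modules registry and reproduces the disappearance of a:

```python
class MiniModule:
    """Simplified stand-in for nn.Module: only the _modules registry."""

    def __init__(self):
        # Bypasses the custom __setattr__ below and resets the registry,
        # even if it already holds submodules -- just like nn.Module.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Submodules are stored in _modules, not in the instance __dict__.
        if isinstance(value, MiniModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Called only when normal lookup fails; fall back to the registry.
        modules = self.__dict__["_modules"]
        if name in modules:
            return modules[name]
        raise AttributeError(name)

class A(MiniModule):
    def __init__(self):
        MiniModule.__init__(self)
        self.a = MiniModule()

class B(MiniModule):
    def __init__(self):
        MiniModule.__init__(self)
        self.b = MiniModule()

class C(A, B):
    def __init__(self):
        A.__init__(self)   # registers "a"
        B.__init__(self)   # resets _modules, wiping "a", then registers "b"
        self.c = MiniModule()

print(sorted(C()._modules))  # ['b', 'c'] -- 'a' is gone
```

Because a lived only inside _modules, resetting that dict in B.__init__ erases it; a plain attribute stored in __dict__ would have survived, which is exactly why the pitfall hits nn.Module values only.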
Now we go back to our original intention. The code above is meant to combine two modules into a new one, and we now see that this doesn't work through plain multiple inheritance. How can we do it? I figured out a way: do the module initialisation of subclass A (or B) in a separate method and call it again explicitly in C's __init__() method:
class A(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.register_modules()

    def register_modules(self):
        self.a = nn.Linear(1, 1)

    def forward(self):
        return

class B(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.b = nn.Linear(2, 2)

    def forward(self):
        return

class C(A, B):
    def __init__(self):
        A.__init__(self)
        B.__init__(self)
        A.register_modules(self)  # re-register A's submodules after B wiped them
        self.c = nn.Linear(3, 3)

    def forward(self):
        return
It works, although it is not an elegant way: self.a is registered twice, and it makes A look slightly unconventional. Note that there is an existing register_module() method in nn.Module; don't confuse the helper defined here with it.