Diamond Inheritance in Python
When I was trying out Python’s multiple inheritance in my project, I ran into a pitfall that is commonly discussed online. I had implemented a `Backbone` class for neural network backbones, and a `HATMaskBackbone` class for backbones with the HAT mask mechanism, which inherits from `Backbone`. Then I needed a `HATMaskMLP` that should logically inherit from both `HATMaskBackbone` and `MLP`, where `MLP` is also a subclass of `Backbone`. This created a number of frustrating problems.
Let me illustrate this with a simplified toy example:
```python
class A:
    def __init__(self, a):
        self.a = a

class B(A):
    def __init__(self, a, b):
        super().__init__(a)
        self.b = b

class C(A):
    def __init__(self, a, c):
        super().__init__(a)
        self.c = c

class D(B, C):
    def __init__(self, a, b, c, d):
        # How to call its parent classes' __init__ methods?
        self.d = d
```
This structure, in which `D` reaches `A` through both `B` and `C`, is commonly called diamond inheritance. Our question is in the comment of class `D`:
- How should we call its parent classes’ `__init__()` methods?
- How can we ensure that `self.a = a`, `self.b = b`, `self.c = c` and `self.d = d` are each executed exactly once?
Misunderstanding of super()
The natural approach is to call `super().__init__(...)` in class `D`. But which parent class’s `__init__` method does `super()` refer to here: `B`, `C`, or even `A`? According to the MRO (Method Resolution Order), the answer is `B` (please refer to my article about MRO).
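You can inspect the MRO directly through the `__mro__` attribute or the `mro()` method. Here is a minimal sketch that strips the bodies from the toy classes. The MRO depends only on the base classes listed in each definition:

```python
class A:
    pass

class B(A):
    pass

class C(A):
    pass

class D(B, C):
    pass

# __mro__ is the linearized lookup order used by attribute access and super().
print([cls.__name__ for cls in D.__mro__])  # ['D', 'B', 'C', 'A', 'object']
```

Python always appends `object` at the end. The lists in this post leave it out.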
At this point, I believed that calling `super().__init__(...)` would call `B.__init__(...)` first, which would then call `A.__init__(...)`, because `A` is the direct parent of `B` as written in `B`’s class definition. This seemed very reasonable, but it was exactly my pitfall.
If we look carefully at how `super()` works, we discover something that goes against this intuition. `super()` follows the MRO of the instance’s class (`type(self)`), not the base classes written in each individual class definition. In this case, when `super().__init__(...)` is called in class `D` and every class in the chain also calls `super().__init__(...)`, Python calls `B.__init__(...)`, then `C.__init__(...)`, and finally `A.__init__(...)`.
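A quick way to see this order is to record which `__init__` is entered. This sketch strips the arguments so that only the call order matters. `calls` is a list added purely for tracing:

```python
calls = []  # records the order in which each __init__ is entered

class A:
    def __init__(self):
        calls.append("A")

class B(A):
    def __init__(self):
        calls.append("B")
        super().__init__()

class C(A):
    def __init__(self):
        calls.append("C")
        super().__init__()

class D(B, C):
    def __init__(self):
        calls.append("D")
        super().__init__()

D()
print(calls)  # ['D', 'B', 'C', 'A']

calls.clear()
B()  # the very same B.__init__, but now self is a B, whose MRO is [B, A, object]
print(calls)  # ['B', 'A']
```

The same `super().__init__()` line inside `B` reaches `C` or `A` depending on which object is being constructed.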
Why does this happen? To answer that, we need to understand `super()`. In Python, `super` is a built-in class whose full form takes two arguments. The first, `type`, is a class. The second, `obj`, is an instance of that class (or, in class methods, a subclass of it). `super(type, obj)` takes the MRO of `type(obj)`, finds `type` in it, and returns a proxy object that looks up attributes starting from the class right after `type` in that MRO. `super` has one more mechanism. When `super()` is called with no arguments inside a method, Python fills in both arguments automatically: the class in which the method is defined, and the method’s first argument (`self` in an `__init__()` method).
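You can also call the two-argument form by hand to see where the lookup lands. This sketch gives each toy class a hypothetical `who()` method that returns its own name:

```python
class A:
    def who(self):
        return "A"

class B(A):
    def who(self):
        return "B"

class C(A):
    def who(self):
        return "C"

class D(B, C):
    def who(self):
        return "D"

d = D()
# Lookup starts after the given class in type(d).__mro__ == (D, B, C, A, object).
print(super(D, d).who())  # B
print(super(B, d).who())  # C  (not A: the MRO comes from d, not from B)
print(super(C, d).who())  # A
# With a plain B instance the MRO is (B, A, object), so the same call reaches A.
print(super(B, B()).who())  # A
```

The first argument only marks a position. The object passed as the second argument decides which MRO that position is looked up in.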
Therefore, the process of calling `super().__init__(...)` in class `D` is as follows:

1. `super().__init__()` in `D` is equivalent to `super(D, self).__init__(...)`.
2. `self` is an instance of `D`, whose MRO is `[D, B, C, A]` (followed by `object`). The first argument is `D`. The class after `D` in this MRO is `B`, so `super(D, self).__init__(...)` becomes `B.__init__(...)`.
3. When `B.__init__(...)` runs, it contains another `super().__init__()`, which is equivalent to `super(B, self).__init__(...)`.
4. At this point, `self` is still the `D` instance (this is crucial!), so the MRO is still `[D, B, C, A]`. The class after `B` is `C`, so `super(B, self).__init__(...)` becomes `C.__init__(...)`.
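You can check these steps from inside the methods. A `super` object exposes the class and the object it was built with as `__thisclass__` and `__self__`. In this sketch, `seen` is a list added only for illustration. Each toy `__init__` records those two values before continuing the chain:

```python
seen = []  # illustration only: (class passed to super, class of self) per call

class A:
    def __init__(self):
        pass  # end of the chain in this sketch

class B(A):
    def __init__(self):
        s = super()  # zero-argument form, same as super(B, self)
        seen.append((s.__thisclass__.__name__, type(s.__self__).__name__))
        s.__init__()

class C(A):
    def __init__(self):
        s = super()  # same as super(C, self)
        seen.append((s.__thisclass__.__name__, type(s.__self__).__name__))
        s.__init__()

class D(B, C):
    def __init__(self):
        s = super()  # same as super(D, self)
        seen.append((s.__thisclass__.__name__, type(s.__self__).__name__))
        s.__init__()

D()
print(seen)  # [('D', 'D'), ('B', 'D'), ('C', 'D')]
```

The second entry matches step 3 and the `D` in it matches step 4. Inside `B`, the object is still the `D` instance.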
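Step 4 explains why `super()` calls `C.__init__(...)` instead of going directly to `A.__init__(...)`. Likewise, the `super().__init__()` inside `C` resolves to `A`, the class after `C` in the MRO. The final execution order is therefore `B.__init__()` → `C.__init__()` → `A.__init__()`, so every `__init__` runs exactly once. This solves the problem posed at the end of the previous section, with one condition. Because `B`’s `super().__init__()` now reaches `C` rather than `A`, each `__init__` must forward the arguments that the next class in the MRO needs. This is usually done with keyword arguments.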
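One way to make this chain actually work for the toy example is the cooperative pattern. Each `__init__` takes the keyword argument it needs and forwards the rest with `**kwargs`. With the original positional signatures, `B`’s `super().__init__(a)` would land in `C.__init__`, which also requires `c`, and raise a `TypeError`. Here is a minimal sketch, assuming all arguments are passed by keyword:

```python
class A:
    def __init__(self, a, **kwargs):
        super().__init__(**kwargs)  # forwards leftovers; object.__init__ accepts none
        self.a = a

class B(A):
    def __init__(self, b, **kwargs):
        super().__init__(**kwargs)  # next class in type(self).__mro__, not necessarily A
        self.b = b

class C(A):
    def __init__(self, c, **kwargs):
        super().__init__(**kwargs)
        self.c = c

class D(B, C):
    def __init__(self, d, **kwargs):
        super().__init__(**kwargs)  # D -> B -> C -> A, each exactly once
        self.d = d

obj = D(a=1, b=2, c=3, d=4)
print(obj.a, obj.b, obj.c, obj.d)  # 1 2 3 4
```

Under this design, `D(1, 2, 3, 4)` with positional arguments would fail. The trade-off is that every class in the hierarchy has to follow the same keyword-forwarding convention. In return, `B(a=1, b=2)` still works on its own.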
This behaviour is counter-intuitive, and misunderstanding how `super()` works was the key to my pitfall. From another angle, though, the purpose of `super` is to reach the parents of the object being built. Since both `B` and `C` are parents of `D`, both of their `__init__()` methods should be called. Following `D`’s MRO is what makes that happen. From this angle, `super` becomes quite understandable.
The Wrong Alternative: Manually Calling Parent Classes’ __init__()
Before I understood this, I tried calling the parent classes’ `__init__()` methods explicitly instead of using `super()`. Since I initially thought `super()` would never reach `C`, I called them directly with `B.__init__(self, a, b)` and `C.__init__(self, a, c)`:
```python
class A:
    def __init__(self, a):
        self.a = a

class B(A):
    def __init__(self, a, b):
        A.__init__(self, a)
        self.b = b

class C(A):
    def __init__(self, a, c):
        A.__init__(self, a)
        self.c = c

class D(B, C):
    def __init__(self, a, b, c, d):
        B.__init__(self, a, b)
        C.__init__(self, a, c)
        self.d = d
```
While this appears to solve the problem, it has significant issues. First, the manual approach is harder to maintain, although that alone does not cause bugs. The more serious problem is that `A.__init__(...)` is called twice: once through `B.__init__` and once through `C.__init__`. Without diamond inheritance there would be no such redundancy. It comes from the diamond. This redundancy might also seem minor, but it is a potential pitfall that can cause hard-to-detect bugs later.
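The double call is easy to observe. This sketch reuses the manual-call classes and adds a hypothetical counter, `init_calls`, that `A.__init__` increments:

```python
init_calls = {"A": 0}  # hypothetical counter, only for demonstration

class A:
    def __init__(self, a):
        init_calls["A"] += 1
        self.a = a

class B(A):
    def __init__(self, a, b):
        A.__init__(self, a)
        self.b = b

class C(A):
    def __init__(self, a, c):
        A.__init__(self, a)
        self.c = c

class D(B, C):
    def __init__(self, a, b, c, d):
        B.__init__(self, a, b)  # runs A.__init__ once
        C.__init__(self, a, c)  # runs A.__init__ a second time
        self.d = d

D(1, 2, 3, 4)
print(init_calls["A"])  # 2
```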
In my case above, the diamond inheritance structure of `HATMaskMLP`, `MLP`, `HATMaskBackbone`, and `Backbone` caused `Backbone`’s initialization to run twice. `Backbone` is a subclass of `nn.Module`, so `nn.Module.__init__()` was also called twice. PyTorch does not expect this. The second call resets the module’s internal registries (such as its parameters and submodules), which leads to serious problems. Please refer to my other code pitfall post for details.
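You can reproduce the effect without PyTorch. The sketch below is a simplified model, not PyTorch’s actual code. It makes three assumptions:
- `FakeModule` is a hypothetical stand-in that, like `nn.Module`, creates a fresh registry in its `__init__`.
- `register` is a hypothetical method replacing the attribute-based registration `nn.Module` performs.
- The base order of `HATMaskMLP` is assumed.

The `__init__` calls are manual, as in the previous section:

```python
class FakeModule:
    """Hypothetical stand-in for nn.Module: __init__ creates a fresh registry."""
    def __init__(self):
        self._modules = {}

    def register(self, name, value):
        self._modules[name] = value

class Backbone(FakeModule):
    def __init__(self):
        FakeModule.__init__(self)

class MLP(Backbone):
    def __init__(self):
        Backbone.__init__(self)
        self.register("fc", "linear layer")  # placeholder for a real submodule

class HATMaskBackbone(Backbone):
    def __init__(self):
        Backbone.__init__(self)
        self.register("masks", "mask parameters")  # placeholder

class HATMaskMLP(HATMaskBackbone, MLP):  # base order assumed for this sketch
    def __init__(self):
        HATMaskBackbone.__init__(self)  # registry: {'masks': ...}
        MLP.__init__(self)  # FakeModule.__init__ runs again and wipes 'masks'

net = HATMaskMLP()
print(net._modules)  # {'fc': 'linear layer'}
```

If every explicit `X.__init__(self)` call in this sketch is rewritten as `super().__init__()`, `FakeModule.__init__` runs only once and both entries survive.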