Operator overloading

Martin McBride, 2019-01-05
Tags operator overload
Categories magic methods
In section Python language

One of our example classes is Matrix, a 2 by 2 matrix.

You can perform operations such as add or multiply on matrices. The most obvious way to do this would be to define an add method:

p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
r = p.add(q)

That is ok, but it would be nice if we could write:

p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
r = p + q

Well, we can! In fact, we can override all of the arithmetic operators if we wish. In this section we will look at + and *, but the same technique applies to the other operators too.

Matrix addition

Just to recap the basics of matrix algebra, the sum of two matrices is:

$$\begin{pmatrix}a & b\\c & d\end{pmatrix} + \begin{pmatrix}e & f\\g & h\end{pmatrix} = \begin{pmatrix}a + e & b + f\\c + g & d + h\end{pmatrix}$$

Overriding the addition operator

You can override the addition operator for a class by providing an __add__ method:

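class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __add__(self, other):
        if isinstance(other, Matrix):
            return Matrix(self.data[0] + other.data[0],
                          self.data[1] + other.data[1],
                          self.data[2] + other.data[2],
                          self.data[3] + other.data[3])
        # other is not a Matrix, so let Python decide what to do
        return NotImplemented

    def __str__(self):
        # Display as [a, b][c, d], one row at a time
        return str(self.data[0:2]) + str(self.data[2:4])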

The __add__ method accepts a parameter other. We first check that other is a Matrix. If it is, the method creates a brand new Matrix object whose elements are formed by adding each element of other to the corresponding element of self.

If other is not a Matrix, our code doesn't know how to handle it. In that case we must return NotImplemented. Python can then decide what to do (this is covered in more detail below).
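To see what returning NotImplemented does in practice, here is a minimal sketch using the Matrix class above (repeated so the snippet runs on its own, with return NotImplemented placed outside the if block so it is actually reached). Adding an int to a Matrix makes __add__ return NotImplemented; int cannot add a Matrix either, so Python raises a TypeError instead of quietly producing None.

```python
class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __add__(self, other):
        if isinstance(other, Matrix):
            return Matrix(self.data[0] + other.data[0],
                          self.data[1] + other.data[1],
                          self.data[2] + other.data[2],
                          self.data[3] + other.data[3])
        # Not a Matrix: hand the decision back to Python
        return NotImplemented


p = Matrix(1, 2, 3, 4)

try:
    p + 1
except TypeError as e:
    # Matrix.__add__ and int.__radd__ both returned NotImplemented
    print("TypeError:", e)
```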

Here is how this is used:

p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
r = p + q
print(r)

We create two matrices, p and q. We then perform the calculation p + q. This calls the __add__ method on the first object p, passing the second object q as the other parameter.

The __add__ method returns a new Matrix that is the result of the addition, and this gets assigned to r. The result is printed:

[6, 8][10, 12]
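One way to convince yourself that p + q really is just a method call is to make the call explicitly. In the simple case where __add__ succeeds, p + q gives the same result as p.__add__(q) (strictly, Python looks the method up on the type, so it is type(p).__add__(p, q)). This sketch uses a more compact, equivalent __add__ built with zip; that is a stylistic choice for the example, not the article's code, and the __str__ method is assumed so the output matches [6, 8][10, 12].

```python
class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __add__(self, other):
        if isinstance(other, Matrix):
            # Pair up corresponding elements and add them
            return Matrix(*(x + y for x, y in zip(self.data, other.data)))
        return NotImplemented

    def __str__(self):
        return str(self.data[0:2]) + str(self.data[2:4])


p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)

print(p + q)                    # operator syntax
print(p.__add__(q))             # the method call that + performs
print(type(p).__add__(p, q))    # the lookup Python actually does
```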

Matrix multiplication

Matrix multiplication is a more interesting case, because you can multiply a matrix by another matrix, or alternatively you can multiply it by a scalar (i.e. an ordinary number).

Multiplying a matrix by a matrix

The product of two matrices is:

$$\begin{pmatrix}a & b\\c & d\end{pmatrix} \cdot \begin{pmatrix}e & f\\g & h\end{pmatrix} = \begin{pmatrix}a \cdot e + b \cdot g & a \cdot f + b \cdot h\\c \cdot e + d \cdot g & c \cdot f + d \cdot h\end{pmatrix}$$
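As a worked check of the product formula, here it is applied to the same values used in the code examples (a = 1, b = 2, c = 3, d = 4 and e = 5, f = 6, g = 7, h = 8). Each element is a row of the first matrix multiplied term by term with a column of the second:

$$\begin{pmatrix}1 & 2\\3 & 4\end{pmatrix} \cdot \begin{pmatrix}5 & 6\\7 & 8\end{pmatrix} = \begin{pmatrix}1 \cdot 5 + 2 \cdot 7 & 1 \cdot 6 + 2 \cdot 8\\3 \cdot 5 + 4 \cdot 7 & 3 \cdot 6 + 4 \cdot 8\end{pmatrix} = \begin{pmatrix}19 & 22\\43 & 50\end{pmatrix}$$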

Multiplying a matrix by a scalar

You can also multiply a matrix by a scalar (an ordinary number n):

$$\begin{pmatrix}a & b\\c & d\end{pmatrix} \cdot n = \begin{pmatrix}a \cdot n & b \cdot n\\c \cdot n & d \cdot n\end{pmatrix}$$

Overriding the multiply operator

Here is a version of the Matrix class with an implementation of __mul__:

class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __str__(self):
        # Display as [a, b][c, d], one row at a time
        return str(self.data[0:2]) + str(self.data[2:4])

    def __mul__(self, other):
        if isinstance(other, (int, float)):
            # Scalar multiplication: scale every element
            return Matrix(self.data[0] * other,
                          self.data[1] * other,
                          self.data[2] * other,
                          self.data[3] * other)
        elif isinstance(other, Matrix):
            # Matrix multiplication: each row of self times each column of other
            return Matrix(self.data[0] * other.data[0] + self.data[1] * other.data[2],
                          self.data[0] * other.data[1] + self.data[1] * other.data[3],
                          self.data[2] * other.data[0] + self.data[3] * other.data[2],
                          self.data[2] * other.data[1] + self.data[3] * other.data[3])
        # Neither a number nor a Matrix
        return NotImplemented

If you look at __mul__, you will see that the first thing we do is check whether other is a scalar. We do this by checking if it is an instance of int or float (you could also check complex if you wanted the Matrix class to support complex numbers, but we won't bother in this example).

If the value is a number, we execute the code for the scalar multiplication equation above.

If the value is not a scalar, we check if it is a Matrix, and execute the code for the matrix multiplication equation above.

If the value is neither a number nor a Matrix we return NotImplemented.
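The check for int or float could be widened to include complex, as mentioned above. One possible variation, sketched here as an alternative rather than the article's own approach, is to test against the abstract base class numbers.Number from the standard library, which int, float, complex and fractions.Fraction all register with. The matrix branch follows the product formula given earlier; unpacking the elements into a to h is just a readability choice for this sketch.

```python
import numbers
from fractions import Fraction


class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __mul__(self, other):
        # Any registered number type counts as a scalar, including complex
        if isinstance(other, numbers.Number):
            return Matrix(*(x * other for x in self.data))
        if isinstance(other, Matrix):
            a, b, c, d = self.data
            e, f, g, h = other.data
            return Matrix(a * e + b * g, a * f + b * h,
                          c * e + d * g, c * f + d * h)
        return NotImplemented


p = Matrix(1, 2, 3, 4)
print((p * 1j).data)               # complex scalar
print((p * Fraction(1, 2)).data)   # exact rational scalar
```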

Here is an example:

p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
print(p * 2)
print(p * q)

This prints

[2, 4][6, 8]
[19, 22][43, 50]

as expected.

Reversing the arguments

What if we try to do this:

print(2 * p)

Unfortunately our existing code doesn't quite cope with this situation. We get an error:

TypeError: unsupported operand type(s) for *: 'int' and 'Matrix'

So what has happened here? We are trying to multiply 2*p:

  • Python looks at the first value, 2, which is an int.
  • It calls int.__mul__ passing in the second value p which is a Matrix.
  • Since int is a built-in type, its __mul__ method knows nothing of our Matrix type, so it returns NotImplemented.

You might think that Python would give an error at this point, but actually it tries one last thing:

  • Python checks if the second argument p has a __rmul__ method. If not, it gives an error.
  • It calls p.__rmul__ passing in the first value 2.
  • If p.__rmul__ can handle an integer type, the value will be calculated.
  • If not, p.__rmul__ returns NotImplemented and Python gives an error.
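The steps above can be modelled as a plain function. This is a simplified sketch, not CPython's actual implementation: the hypothetical helper multiply ignores details such as the rule that a subclass's reflected method is tried first, and the special handling of sequence repetition. It does include the rule that __rmul__ is only tried when the operands are of different types.

```python
def multiply(x, y):
    """Simplified model of how Python evaluates x * y (hypothetical helper)."""
    # Step 1: ask the left operand's type
    result = NotImplemented
    mul = getattr(type(x), "__mul__", None)
    if mul is not None:
        result = mul(x, y)
    if result is not NotImplemented:
        return result
    # Step 2: ask the right operand's type for the reflected method,
    # but only if the two operands are of different types
    rmul = getattr(type(y), "__rmul__", None)
    if rmul is not None and type(x) is not type(y):
        result = rmul(y, x)
    if result is not NotImplemented:
        return result
    # Step 3: nobody could handle it
    raise TypeError(
        f"unsupported operand type(s) for *: "
        f"'{type(x).__name__}' and '{type(y).__name__}'")


class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __mul__(self, other):
        # Scalar case only; no __rmul__ yet
        if isinstance(other, (int, float)):
            return Matrix(*(x * other for x in self.data))
        return NotImplemented


p = Matrix(1, 2, 3, 4)
print(multiply(p, 2).data)   # handled by Matrix.__mul__
print(multiply(3, 4))        # handled by int.__mul__
try:
    multiply(2, p)           # int gives up and Matrix has no __rmul__
except TypeError as e:
    print(e)
```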

So, we can handle this extra case by implementing __rmul__ for our Matrix class:

    def __rmul__(self, other):
        if isinstance(other, (int, float)):
            return Matrix(self.data[0] * other,
                          self.data[1] * other,
                          self.data[2] * other,
                          self.data[3] * other)
        # Not a number, so give up
        return NotImplemented

In this case, self is the second operand p, and other is the first operand 2. This is because __rmul__ reverses the arguments.

Since other is an int, our code executes and creates the correct result. In this case, the code for handling numbers is identical in __mul__ and __rmul__, because scaling a matrix gives the same result whichever side the number is written on: p*2 and 2*p are equal. That won't be true for all data types and all operators, of course.
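Putting __mul__ and __rmul__ together, here is a minimal sketch of the complete class. Because scalar multiplication commutes, this version has __rmul__ delegate to __mul__ for numbers instead of repeating the element-wise code. That delegation is a design choice of this sketch, not the article's own code; it is safe here because __rmul__ only accepts numbers and so never reaches the non-commutative matrix product.

```python
class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __str__(self):
        return str(self.data[0:2]) + str(self.data[2:4])

    def __mul__(self, other):
        if isinstance(other, (int, float)):
            return Matrix(*(x * other for x in self.data))
        if isinstance(other, Matrix):
            a, b, c, d = self.data
            e, f, g, h = other.data
            return Matrix(a * e + b * g, a * f + b * h,
                          c * e + d * g, c * f + d * h)
        return NotImplemented

    def __rmul__(self, other):
        # Reached for e.g. 2 * p; scalar products commute, so reuse __mul__
        if isinstance(other, (int, float)):
            return self * other
        return NotImplemented


p = Matrix(1, 2, 3, 4)
print(2 * p)   # uses Matrix.__rmul__
print(p * 2)   # uses Matrix.__mul__
```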

Notice that if both operands are of type Matrix, the case will always be handled by __mul__, so there is no need to handle it in __rmul__. Python only tries the reflected method when the two operands are of different types, so this is generally true for all data types and operators.
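A small experiment shows the same-type rule directly. The hypothetical class Probe below exists only to observe dispatch: its __mul__ always gives up. With an int on the left, the types differ and Python falls back to Probe.__rmul__. With two Probe objects, Python never tries __rmul__ and raises TypeError.

```python
class Probe:
    """Hypothetical class used only to observe operator dispatch."""

    def __mul__(self, other):
        # Always refuse, so Python has to look elsewhere
        return NotImplemented

    def __rmul__(self, other):
        return "__rmul__ was called"


print(3 * Probe())        # different types: __rmul__ is tried

try:
    Probe() * Probe()     # same type: __rmul__ is never tried
except TypeError as e:
    print("TypeError:", e)
```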

Error checking

What if we try something crazy, like multiplying a Matrix by a string?

Our __mul__ code checks the type of the other value. It isn't a number, it isn't a Matrix, so we return NotImplemented.

Python will then check if str has an __rmul__ method. It does, but it can't handle our Matrix type so again it returns NotImplemented.

So Python raises a TypeError.
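A minimal sketch to confirm this, assuming a Matrix class whose __mul__ only accepts numbers (cut down from the version above for brevity). The exact wording of the error message comes from Python itself, so the snippet just prints it rather than depending on it.

```python
class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __mul__(self, other):
        if isinstance(other, (int, float)):
            return Matrix(*(x * other for x in self.data))
        # A str (or anything else unexpected) ends up here
        return NotImplemented


p = Matrix(1, 2, 3, 4)
try:
    p * "abc"
except TypeError as e:
    print("TypeError:", e)
```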

In-place operators

There is one more case to consider: the in-place operators such as += and *=. They are covered in the next section.

Summary of operators

Here is a summary of the binary arithmetic and bitwise operators you can overload:

Method Symbol
__add__ +
__sub__ -
__mul__ *
__matmul__ @
__truediv__ /
__floordiv__ //
__mod__ %
__pow__ **
__lshift__ <<
__rshift__ >>
__and__ &
__xor__ ^
__or__ |
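Every method in this table follows the same pattern as __add__ and __mul__, and each also has a reflected form with an r prefix (__rsub__, __rmatmul__ and so on). As a sketch, assuming the four-element data list used throughout, here is __sub__ plus one possible design choice: @ (__matmul__) is the operator added in Python 3.5 specifically for matrix multiplication, so a class could use @ for the matrix product and keep * for scalars. That split is an illustration, not how the Matrix class above is written.

```python
class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __sub__(self, other):
        # Element-wise subtraction, same shape as __add__
        if isinstance(other, Matrix):
            return Matrix(*(x - y for x, y in zip(self.data, other.data)))
        return NotImplemented

    def __matmul__(self, other):
        # p @ q: the matrix product
        if isinstance(other, Matrix):
            a, b, c, d = self.data
            e, f, g, h = other.data
            return Matrix(a * e + b * g, a * f + b * h,
                          c * e + d * g, c * f + d * h)
        return NotImplemented


p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
print((q - p).data)   # [4, 4, 4, 4]
print((p @ q).data)   # [19, 22, 43, 50]
```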




Copyright (c) Axlesoft Ltd 2020