# Operator overloading

Martin McBride, 2019-01-05

One of our example classes is `Matrix`, a 2 by 2 matrix.

You can perform operations such as add or multiply on matrices. The most obvious way to do this would be to define an `add` function:

```python
p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
r = p.add(q)
```

That is ok, but it would be nice if we could write:

```python
p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
r = p + q
```

Well, we can! In fact, we can override all the arithmetic operators if we wish. In this section we will look at `+` and `*`, but the technique applies to any operator.

## Matrix addition

Just to recap the basics of matrix algebra, the sum of two matrices:

$$\begin{pmatrix}a & b\\c & d\end{pmatrix} + \begin{pmatrix}e & f\\g & h\end{pmatrix}$$

is

$$\begin{pmatrix}a + e & b + f\\c + g & d + h\end{pmatrix}$$

## Overriding the addition operator

You can override the addition operator for a class by providing an `__add__` method:

```python
class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __add__(self, other):
        if isinstance(other, Matrix):
            return Matrix(self.data[0] + other.data[0],
                          self.data[1] + other.data[1],
                          self.data[2] + other.data[2],
                          self.data[3] + other.data[3])
        else:
            return NotImplemented
```

The `__add__` method accepts a parameter `other`. We first check that `other` is a `Matrix`. If it is, the method creates a brand new `Matrix` object whose elements are formed by adding the elements of `other` to the corresponding elements of `self`.

If `other` is not a `Matrix`, our code doesn't know how to handle it. In that case we must return `NotImplemented`. Python can then decide what to do (this is covered in more detail below).

Here is how this is used:

```python
p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
r = p + q
print(r)
```

We create two matrices, `p` and `q`. We then perform the calculation `p + q`. This calls the `__add__` method on the first object `p`, passing the second object `q` as the `other` parameter.

The `__add__` method returns a `Matrix` that is the result of the addition, and this gets assigned to `r`. The result is printed:

```
[6, 8][10, 12]
```
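Note that `print(r)` only produces this output if `Matrix` also defines a string representation, which the listing in this section leaves out. Here is a minimal, self-contained sketch of the addition example that can be run as is. The `__str__` method is an assumption, written only to reproduce the `[a, b][c, d]` layout shown above:

```python
class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __add__(self, other):
        if isinstance(other, Matrix):
            # Add the two matrices element by element
            return Matrix(*(x + y for x, y in zip(self.data, other.data)))
        return NotImplemented

    def __str__(self):
        # Assumed formatting, chosen to match the output shown in the text
        a, b, c, d = self.data
        return '[{}, {}][{}, {}]'.format(a, b, c, d)


p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
r = p + q
print(r)  # [6, 8][10, 12]
```

Calling `p.__add__(1)` directly returns `NotImplemented`, because 1 is not a `Matrix`.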

## Matrix multiplication

Matrix multiplication is a more interesting case, because you can multiply a matrix by another matrix, or alternatively you can multiply it by a scalar (i.e. an ordinary number).

### Multiplying a matrix by a matrix

The product of two matrices:

$$\begin{pmatrix}a & b\\c & d\end{pmatrix} . \begin{pmatrix}e & f\\g & h\end{pmatrix}$$

is

$$\begin{pmatrix}a.e + b.g & a.f + b.h\\c.e + d.g & c.f + d.h\end{pmatrix}$$
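As a worked example, take the two matrices with elements 1, 2, 3, 4 and 5, 6, 7, 8 that are used in the code in this section (products are written with $\times$ here so that they are not mistaken for decimals):

$$\begin{pmatrix}1 & 2\\3 & 4\end{pmatrix} . \begin{pmatrix}5 & 6\\7 & 8\end{pmatrix} = \begin{pmatrix}1 \times 5 + 2 \times 7 & 1 \times 6 + 2 \times 8\\3 \times 5 + 4 \times 7 & 3 \times 6 + 4 \times 8\end{pmatrix} = \begin{pmatrix}19 & 22\\43 & 50\end{pmatrix}$$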

### Multiplying a matrix by a scalar

You can also multiply a matrix by a scalar (an ordinary number `n`):

$$\begin{pmatrix}a & b\\c & d\end{pmatrix} . n$$

giving:

$$\begin{pmatrix}a.n & b.n\\c.n & d.n\end{pmatrix}$$

## Overriding the multiply operator

Here is a version of the `Matrix` class with an implementation of `__mul__`:

```python
class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __mul__(self, other):
        if isinstance(other, (int, float)):
            # Scalar: multiply every element by the number
            return Matrix(self.data[0] * other, self.data[1] * other,
                          self.data[2] * other, self.data[3] * other)
        elif isinstance(other, Matrix):
            # Matrix: each element is a row of self times a column of other
            return Matrix(self.data[0] * other.data[0] + self.data[1] * other.data[2],
                          self.data[0] * other.data[1] + self.data[1] * other.data[3],
                          self.data[2] * other.data[0] + self.data[3] * other.data[2],
                          self.data[2] * other.data[1] + self.data[3] * other.data[3])
        else:
            return NotImplemented
```

If you look at `__mul__`, you will see that the first thing we do is to check if `other` is a scalar. We do this by checking if it is an instance of `int` or `float` (you could also check `complex` if you wanted the `Matrix` class to support complex numbers, but we won't bother in this example).

If the value is a number, we execute the code for the scalar multiplication equation above.

If the value is not a scalar, we check if it is a `Matrix`, and execute the code for the matrix multiplication equation above.

If the value is neither a number nor a `Matrix`, we return `NotImplemented`.

Here is an example:

```python
p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
print(p*2)
print(p*q)
```

This prints:

```
[2, 4][6, 8]
[19, 22][43, 50]
```

as expected.
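If complex scalars were wanted, one option would be to widen the scalar check. The following is only a sketch of that idea: the name `SCALAR_TYPES` is hypothetical, and the matrix by matrix branch is left out to keep it short.

```python
SCALAR_TYPES = (int, float, complex)  # hypothetical name, not part of the original class


class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __mul__(self, other):
        if isinstance(other, SCALAR_TYPES):
            # Scalar case now also accepts complex numbers
            return Matrix(*(x * other for x in self.data))
        # Matrix by matrix branch omitted in this sketch
        return NotImplemented


m = Matrix(1, 2, 3, 4) * 2j
print(m.data)  # [2j, 4j, 6j, 8j]
```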

## Reversing the arguments

What if we try to do this:

```python
print(2*p)
```

Unfortunately our existing code doesn't quite cope with this situation. We get an error:

```
TypeError: unsupported operand type(s) for *: 'int' and 'Matrix'
```

So what has happened here? We are trying to multiply `2*p`:

- Python looks at the first value, 2, which is an `int`.
- It calls `int.__mul__`, passing in the second value `p`, which is a `Matrix`.
- Since `int` is a built-in type, its `__mul__` method knows nothing of our `Matrix` type, so it returns `NotImplemented`.

You might think that Python would give an error at this point, but actually it tries one last thing:

- Python checks if the second argument `p` has a `__rmul__` method. If not, it gives an error.
- It calls `p.__rmul__`, passing in the first value 2.
- If `p.__rmul__` can handle an integer type, the value will be calculated.
- If not, `p.__rmul__` returns `NotImplemented` and Python gives an error.
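The steps above can be modelled in ordinary Python. The `multiply` function below is a hypothetical, simplified stand-in for what the interpreter does when it evaluates `lhs * rhs`. It ignores some details of the real rules, such as the priority given to subclasses. The `Matrix` here deliberately has no `__rmul__`, so `multiply(2, p)` fails in the same way as `2*p` did:

```python
class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __mul__(self, other):
        if isinstance(other, (int, float)):
            return Matrix(*(x * other for x in self.data))
        return NotImplemented  # matrix by matrix case left out to keep this short


def multiply(lhs, rhs):
    """Hypothetical, simplified model of how Python evaluates lhs * rhs."""
    # Step 1: ask the type of the left operand
    forward = getattr(type(lhs), '__mul__', None)
    if forward is not None:
        result = forward(lhs, rhs)
        if result is not NotImplemented:
            return result
    # Step 2: ask the type of the right operand for the reflected method.
    # Python skips this step when both operands have the same type.
    if type(lhs) is not type(rhs):
        reflected = getattr(type(rhs), '__rmul__', None)
        if reflected is not None:
            result = reflected(rhs, lhs)
            if result is not NotImplemented:
                return result
    # Step 3: neither operand could handle the operation
    raise TypeError("unsupported operand type(s) for *: '{}' and '{}'".format(
        type(lhs).__name__, type(rhs).__name__))


p = Matrix(1, 2, 3, 4)
print(int.__mul__(2, p))    # NotImplemented
print(multiply(p, 2).data)  # [2, 4, 6, 8]
print(multiply(2, 'ab'))    # abab
try:
    multiply(2, p)
except TypeError as err:
    print(err)  # unsupported operand type(s) for *: 'int' and 'Matrix'
```

The `multiply(2, 'ab')` call shows the second step succeeding for a built-in type: `int.__mul__` cannot handle a string, but `str.__rmul__` can handle an integer.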

So, we can handle this extra case by implementing `__rmul__` for our `Matrix` class:

```python
class Matrix:
    # __init__ and __mul__ as before

    def __rmul__(self, other):
        if isinstance(other, (int, float)):
            return Matrix(self.data[0] * other, self.data[1] * other,
                          self.data[2] * other, self.data[3] * other)
        else:
            return NotImplemented
```

In this case, `self` is the second operand `p`, and `other` is the first operand 2. This is because `__rmul__` reverses the arguments.

Since `other` is an `int`, our code executes and creates the correct result. In this case, the code for handling numbers is identical in `__mul__` and `__rmul__` because for matrices `p*2` and `2*p` are the same. That won't be true for all data types and all operators, of course.

Notice that if both operands are of type `Matrix`, the case will *always* be handled by `__mul__`, so there is no need to handle that case in `__rmul__`. This is generally true for all data types and operators.
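Putting the pieces together, here is a self-contained sketch of a `Matrix` that supports both `p*2` and `2*p`. The `_scale` helper is a hypothetical addition, used only so that the scalar code is written once and shared by `__mul__` and `__rmul__`:

```python
class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def _scale(self, n):
        # Hypothetical helper shared by __mul__ and __rmul__
        return Matrix(*(x * n for x in self.data))

    def __mul__(self, other):
        if isinstance(other, (int, float)):
            return self._scale(other)
        if isinstance(other, Matrix):
            a, b, c, d = self.data
            e, f, g, h = other.data
            return Matrix(a*e + b*g, a*f + b*h, c*e + d*g, c*f + d*h)
        return NotImplemented

    def __rmul__(self, other):
        # Reached when the left operand could not handle the multiplication
        if isinstance(other, (int, float)):
            return self._scale(other)
        return NotImplemented


p = Matrix(1, 2, 3, 4)
print((p * 2).data)  # [2, 4, 6, 8]
print((2 * p).data)  # [2, 4, 6, 8]
```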

## Error checking

What if we try something crazy like:

```python
print(p*'abc')
```

Our `__mul__` code checks the type of the `other` value. It isn't a number, it isn't a `Matrix`, so we return `NotImplemented`.

Python will then check if `str` has an `__rmul__` method. It does, but a string can only be multiplied by an integer, so it can't handle our `Matrix` type either.

Python gives an error (a `TypeError`).
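This behaviour can be checked with a short, self-contained sketch. It assumes the `__mul__` logic described earlier, written more compactly here. The exact wording of the error message is not shown because it may vary between Python versions:

```python
class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __mul__(self, other):
        if isinstance(other, (int, float)):
            return Matrix(*(x * other for x in self.data))
        if isinstance(other, Matrix):
            a, b, c, d = self.data
            e, f, g, h = other.data
            return Matrix(a*e + b*g, a*f + b*h, c*e + d*g, c*f + d*h)
        return NotImplemented


p = Matrix(1, 2, 3, 4)

# Calling the method directly shows the value our code hands back to Python
print(p.__mul__('abc'))  # NotImplemented

# Using the operator lets Python finish the job and raise the error
try:
    p * 'abc'
except TypeError as err:
    print('TypeError:', err)
```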

## In place operators

There is an additional case to consider: the in place operators such as `+=` and `*=`. They are covered in the next section.

## Summary of operators

Here is a summary of the available numerical operators:

| Method | Symbol |
|---|---|
| `__add__` | `+` |
| `__sub__` | `-` |
| `__mul__` | `*` |
| `__matmul__` | `@` |
| `__truediv__` | `/` |
| `__floordiv__` | `//` |
| `__mod__` | `%` |
| `__divmod__` | `divmod()` built-in function |
| `__pow__` | `**` |
| `__lshift__` | `<<` |
| `__rshift__` | `>>` |
| `__and__` | `&` |
| `__xor__` | `^` |
| `__or__` | `\|` |
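The other methods in the table follow the same pattern as `__add__` and `__mul__`. As an illustration only, here is a sketch of a `Matrix` with `__sub__` (matrix minus matrix) and `__truediv__` (matrix divided by a scalar). The choice of which operand types to accept is an assumption made for this example:

```python
class Matrix:

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __sub__(self, other):
        if isinstance(other, Matrix):
            # Subtract element by element
            return Matrix(*(x - y for x, y in zip(self.data, other.data)))
        return NotImplemented

    def __truediv__(self, other):
        if isinstance(other, (int, float)):
            # Divide every element by the scalar
            return Matrix(*(x / other for x in self.data))
        return NotImplemented


p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
print((q - p).data)  # [4, 4, 4, 4]
print((p / 2).data)  # [0.5, 1.0, 1.5, 2.0]
```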