# Operator overloading

Categories: magic methods

One of our example classes is `Matrix`, a 2 by 2 matrix.

You can perform operations such as addition or multiplication on matrices. The most obvious way to do this would be to define an `add` method:

```
p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
r = p.add(q)
```

That is ok, but it would be nice if we could write:

```
p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
r = p + q
```

Well, we can! In fact, we can overload all the arithmetic operators if we wish. In this section we will look at `+` and `*`, but the same technique applies to any operator.

## Matrix addition

Just to recap the basics of matrix algebra, the sum of two matrices:

$$\begin{pmatrix}a & b\\c & d\end{pmatrix} + \begin{pmatrix}e & f\\g & h\end{pmatrix}$$

is

$$\begin{pmatrix}a + e & b + f\\c + g & d + h\end{pmatrix}$$
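As a worked example, here are the values used for `p` and `q` in the code in this section:

$$\begin{pmatrix}1 & 2\\3 & 4\end{pmatrix} + \begin{pmatrix}5 & 6\\7 & 8\end{pmatrix} = \begin{pmatrix}1 + 5 & 2 + 6\\3 + 7 & 4 + 8\end{pmatrix} = \begin{pmatrix}6 & 8\\10 & 12\end{pmatrix}$$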

## Overloading the addition operator

You can override the addition operator for a class by providing an `__add__` method:

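```
class Matrix:
    def __init__(self, a, b, c, d):
        # Store the four elements row by row: [a, b, c, d]
        self.data = [a, b, c, d]

    def __add__(self, other):
        if isinstance(other, Matrix):
            # Element-wise sum, returned as a brand new Matrix
            return Matrix(self.data[0] + other.data[0],
                          self.data[1] + other.data[1],
                          self.data[2] + other.data[2],
                          self.data[3] + other.data[3])
        else:
            # Unknown type: let Python decide what to do
            return NotImplemented

    def __str__(self):
        # Print the two rows, e.g. [6, 8][10, 12]
        return str(self.data[:2]) + str(self.data[2:])
```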

The `__add__` method accepts a parameter `other`. We first check that `other` is a `Matrix`. If it is, the method creates a brand new `Matrix` object whose elements are formed by adding the elements of `other` to the corresponding elements of `self`.

If `other` is not a `Matrix`, our code doesn't know how to handle it. In that case we must return `NotImplemented`. Python can then decide what to do (this is covered in more detail below).
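As a quick illustration of where returning `NotImplemented` leads, here is a minimal sketch that tries to add a plain integer to a `Matrix`. `Matrix.__add__` declines, and `int` can't add a `Matrix` either, so Python raises a `TypeError`. The `try_add` helper is hypothetical, only there to capture the message, and `__add__` is written more compactly with `zip` than in the class above.

```python
class Matrix:
    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __add__(self, other):
        if isinstance(other, Matrix):
            # Add corresponding elements pairwise
            return Matrix(*(x + y for x, y in zip(self.data, other.data)))
        # Unknown type: tell Python we can't handle it
        return NotImplemented


def try_add(x, y):
    """Hypothetical helper: return x + y, or the error text if it fails."""
    try:
        return x + y
    except TypeError as e:
        return f"TypeError: {e}"


p = Matrix(1, 2, 3, 4)
print(p.__add__(1))   # NotImplemented
print(try_add(p, 1))  # TypeError: unsupported operand type(s) for +: 'Matrix' and 'int'
```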

Here is how this is used:

```
p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
r = p + q
print(r)
```

We create two matrices, `p` and `q`. We then perform the calculation `p + q`. This calls the `__add__` method on the first object `p`, passing the second object `q` as the `other` parameter.

The `__add__` method returns a `Matrix` that is the result of the addition, and this gets assigned to `r`. The result is printed:

```
[6, 8][10, 12]
```
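For a case like this one, the `+` operator has the same effect as calling `__add__` yourself. The sketch below checks that, using a trimmed copy of the class above. It also shows that `p` and `q` are left unchanged, because `__add__` builds a new `Matrix` rather than modifying either operand.

```python
class Matrix:
    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __add__(self, other):
        if isinstance(other, Matrix):
            return Matrix(self.data[0] + other.data[0],
                          self.data[1] + other.data[1],
                          self.data[2] + other.data[2],
                          self.data[3] + other.data[3])
        return NotImplemented

    def __str__(self):
        return str(self.data[:2]) + str(self.data[2:])


p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
r1 = p + q           # operator syntax
r2 = p.__add__(q)    # explicit method call with the same arguments
print(r1, r2)        # [6, 8][10, 12] [6, 8][10, 12]
print(p, q)          # [1, 2][3, 4] [5, 6][7, 8]  (operands unchanged)
```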

## Matrix multiplication

Matrix multiplication is a more interesting case, because you can multiply a matrix by another matrix, or alternatively by a scalar (i.e. an ordinary number).

### Multiplying a matrix by a matrix

The product of two matrices:

$$\begin{pmatrix}a & b\\c & d\end{pmatrix} \cdot \begin{pmatrix}e & f\\g & h\end{pmatrix}$$

is

$$\begin{pmatrix}a \cdot e + b \cdot g & a \cdot f + b \cdot h\\c \cdot e + d \cdot g & c \cdot f + d \cdot h\end{pmatrix}$$
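As a worked example with the same `p` and `q` as before, each element combines a row of the first matrix with a column of the second:

$$\begin{pmatrix}1 & 2\\3 & 4\end{pmatrix} \cdot \begin{pmatrix}5 & 6\\7 & 8\end{pmatrix} = \begin{pmatrix}1 \cdot 5 + 2 \cdot 7 & 1 \cdot 6 + 2 \cdot 8\\3 \cdot 5 + 4 \cdot 7 & 3 \cdot 6 + 4 \cdot 8\end{pmatrix} = \begin{pmatrix}19 & 22\\43 & 50\end{pmatrix}$$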

### Multiplying a matrix by a scalar

You can also multiply a matrix by a scalar (an ordinary number `n`):

$$\begin{pmatrix}a & b\\c & d\end{pmatrix} \cdot n$$

giving:

$$\begin{pmatrix}a \cdot n & b \cdot n\\c \cdot n & d \cdot n\end{pmatrix}$$
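For example, with $n = 2$ and the matrix `p` used in the code in this section:

$$\begin{pmatrix}1 & 2\\3 & 4\end{pmatrix} \cdot 2 = \begin{pmatrix}1 \cdot 2 & 2 \cdot 2\\3 \cdot 2 & 4 \cdot 2\end{pmatrix} = \begin{pmatrix}2 & 4\\6 & 8\end{pmatrix}$$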

## Overloading the multiply operator

Here is a version of the `Matrix` class with an implementation of `__mul__`:

```
class Matrix:
    def __init__(self, a, b, c, d):
        # Store the four elements row by row: [a, b, c, d]
        self.data = [a, b, c, d]

    def __mul__(self, other):
        if isinstance(other, (int, float)):
            # Scalar multiplication: multiply every element by other
            return Matrix(self.data[0] * other,
                          self.data[1] * other,
                          self.data[2] * other,
                          self.data[3] * other)
        elif isinstance(other, Matrix):
            # Matrix multiplication: each row of self times each column of other
            return Matrix(self.data[0] * other.data[0] + self.data[1] * other.data[2],
                          self.data[0] * other.data[1] + self.data[1] * other.data[3],
                          self.data[2] * other.data[0] + self.data[3] * other.data[2],
                          self.data[2] * other.data[1] + self.data[3] * other.data[3])
        else:
            return NotImplemented

    def __str__(self):
        return str(self.data[:2]) + str(self.data[2:])
```

If you look at `__mul__`, you will see that the first thing we do is check whether `other` is a scalar. We do this by checking whether it is an instance of `int` or `float` (you could also check `complex` if you wanted the `Matrix` class to support complex numbers, but we won't bother in this example).
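If you did want to accept `complex` (and other numeric types such as `fractions.Fraction`), one option is to test against the abstract base class `numbers.Number` from the standard library instead of listing each type. This is a sketch of that variation, not the version used in the rest of this section, and it only covers scalar multiplication:

```python
import numbers


class Matrix:
    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __mul__(self, other):
        # numbers.Number covers int, float, complex, Fraction, ...
        if isinstance(other, numbers.Number):
            return Matrix(*(x * other for x in self.data))
        return NotImplemented


p = Matrix(1, 2, 3, 4)
print((p * 1j).data)   # [1j, 2j, 3j, 4j]
print((p * 2).data)    # [2, 4, 6, 8]
```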

If the value is a number, we execute the code for the scalar multiplication equation above.

If the value is not a scalar, we check whether it is a `Matrix`, and execute the code for the matrix multiplication equation above.

If the value is neither a number nor a `Matrix`, we return `NotImplemented`.

Here is an example:

```
p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
print(p*2)
print(p*q)
```

This prints

```
[2, 4][6, 8]
[19, 22][43, 50]
```

as expected.

## Reversing the arguments

What if we try to do this:

```
print(2*p)
```

Unfortunately our existing code doesn't quite cope with this situation. We get an error:

```
TypeError: unsupported operand type(s) for *: 'int' and 'Matrix'
```

So what has happened here? We are trying to calculate `2*p`:

- Python looks at the first value, 2, which is an `int`.
- It calls `int.__mul__`, passing in the second value `p`, which is a `Matrix`.
- Since `int` is a built-in type, its `__mul__` method knows nothing of our `Matrix` type, so it returns `NotImplemented`.
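You can observe the third step directly by calling `int.__mul__` yourself. This minimal sketch uses a bare-bones `Matrix` with no operator methods at all:

```python
class Matrix:
    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]


p = Matrix(1, 2, 3, 4)
# int's own multiply method doesn't recognise Matrix, so it declines
print((2).__mul__(p))   # NotImplemented
# ...but it happily handles another int
print((2).__mul__(3))   # 6
```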

You might think that Python would give an error at this point, but actually it tries one last thing:

- Python checks whether the second argument `p` has an `__rmul__` method. If not, it gives an error.
- It calls `p.__rmul__`, passing in the first value, 2.
- If `p.__rmul__` can handle an integer, the result is calculated.
- If not, `p.__rmul__` returns `NotImplemented` and Python gives an error.
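The steps above can be modelled in plain Python. The `multiply` function below is a hypothetical, simplified sketch of the lookup, not how CPython actually implements it: it ignores the rule that gives a subclass's reflected method priority, and the sequence-repetition fallback used by types such as `str`. It does include one detail from the Python data model: the reflected method is only tried when the two operands are of different types.

```python
def multiply(x, y):
    """Hypothetical, simplified model of how Python evaluates x * y."""
    result = NotImplemented
    # Step 1: try the left operand's __mul__
    mul = getattr(type(x), "__mul__", None)
    if mul is not None:
        result = mul(x, y)
    # Step 2: if that failed and the types differ, try the right operand's __rmul__
    if result is NotImplemented and type(x) is not type(y):
        rmul = getattr(type(y), "__rmul__", None)
        if rmul is not None:
            result = rmul(y, x)
    # Step 3: nobody could handle it
    if result is NotImplemented:
        raise TypeError(
            f"unsupported operand type(s) for *: "
            f"'{type(x).__name__}' and '{type(y).__name__}'"
        )
    return result


class Matrix:
    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __rmul__(self, other):
        # Only scalar * Matrix is supported in this sketch
        if isinstance(other, (int, float)):
            return Matrix(*(x * other for x in self.data))
        return NotImplemented


p = Matrix(1, 2, 3, 4)
print(multiply(2, 3))        # 6
print(multiply(2, p).data)   # [2, 4, 6, 8]
```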

So, we can handle this extra case by implementing `__rmul__` for our `Matrix` class:

```
def __rmul__(self, other):
if isinstance(other, (int, float)):
return Matrix(self.data[0] * other,
self.data[1] * other,
self.data[2] * other,
self.data[3] * other)
else:
return NotImplemented
```

In this case, `self` is the second operand `p`, and `other` is the first operand 2. This is because `__rmul__` reverses the arguments.

Since `other` is an `int`, our code executes and creates the correct result. In this case, the code for handling numbers is identical in `__mul__` and `__rmul__`, because for matrices `p*2` and `2*p` are the same. That won't be true for all data types and all operators, of course.
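Putting the pieces together, here is one way the complete class might look. As a small variation on the code above, `__rmul__` delegates to the scalar branch of `__mul__` (via `self * other`) instead of repeating the arithmetic, and the matrix branch uses tuple unpacking for readability. Both are design choices for this sketch rather than anything Python requires.

```python
class Matrix:
    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __mul__(self, other):
        if isinstance(other, (int, float)):
            # Scalar multiplication: scale every element
            return Matrix(*(x * other for x in self.data))
        elif isinstance(other, Matrix):
            # Matrix multiplication: row of self times column of other
            a, b, c, d = self.data
            e, f, g, h = other.data
            return Matrix(a * e + b * g, a * f + b * h,
                          c * e + d * g, c * f + d * h)
        return NotImplemented

    def __rmul__(self, other):
        # Called for other * self, e.g. 2*p. For a scalar the result
        # equals self * other, so reuse the scalar branch of __mul__.
        if isinstance(other, (int, float)):
            return self * other
        return NotImplemented

    def __str__(self):
        return str(self.data[:2]) + str(self.data[2:])


p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
print(2 * p)   # [2, 4][6, 8]
print(p * 2)   # [2, 4][6, 8]
print(p * q)   # [19, 22][43, 50]
```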

Notice that if both operands are of type `Matrix`, the case will *always* be handled by `__mul__`, so there is no need to handle that case in `__rmul__`. In fact, Python only tries a reflected method such as `__rmul__` when the two operands are of different types, so this is true for all data types and operators.
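You can confirm this with a deliberately unhelpful class. In the sketch below, `__mul__` handles nothing and `__rmul__` counts how often it is called (the `rmul_calls` counter is hypothetical, just for the demonstration). Multiplying two instances fails without `__rmul__` ever being tried:

```python
class Matrix:
    rmul_calls = 0  # hypothetical counter, for this demonstration only

    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __mul__(self, other):
        # Deliberately handle nothing, so Python must look elsewhere
        return NotImplemented

    def __rmul__(self, other):
        Matrix.rmul_calls += 1
        return NotImplemented


p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
try:
    p * q
    message = "no error"
except TypeError as e:
    message = str(e)
print(message)            # unsupported operand type(s) for *: 'Matrix' and 'Matrix'
print(Matrix.rmul_calls)  # 0: __rmul__ was never tried
```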

## Error checking

What if we try something crazy like:

```
print(p*'abc')
```

Our `__mul__` code checks the type of the `other` value. It isn't a number and it isn't a `Matrix`, so we return `NotImplemented`.

Python will then give `str` a chance to handle the operation. A string only knows how to be repeated a whole number of times, so it can't handle our `Matrix` type either.

Python gives a `TypeError`.
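Here is a minimal sketch of that case. It only checks that a `TypeError` is raised, rather than relying on the exact wording of the error message:

```python
class Matrix:
    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __mul__(self, other):
        if isinstance(other, (int, float)):
            return Matrix(*(x * other for x in self.data))
        # 'abc' is neither a number nor a Matrix, so it ends up here
        return NotImplemented


p = Matrix(1, 2, 3, 4)
print(p.__mul__('abc'))   # NotImplemented
try:
    result = p * 'abc'
except TypeError as e:
    result = e
print(type(result).__name__)   # TypeError
```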

## In-place operators

There is an additional case to consider: the in-place operators such as `+=` and `*=`. They are covered in the next section.

## Summary of operators

Here is a summary of the available numerical operators:

Method | Operator |
---|---|
`__add__` | `+` |
`__sub__` | `-` |
`__mul__` | `*` |
`__matmul__` | `@` |
`__truediv__` | `/` |
`__floordiv__` | `//` |
`__mod__` | `%` |
`__divmod__` | `divmod()` (built-in function) |
`__pow__` | `**` |
`__lshift__` | `<<` |
`__rshift__` | `>>` |
`__and__` | `&` |
`__xor__` | `^` |
`__or__` | `\|` |
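The `@` operator (`__matmul__`) was added to Python specifically for matrix multiplication (PEP 465), although no built-in type implements it; NumPy uses it for matrix products. One possible design, shown here only as a sketch and not what this article's `Matrix` class does, is to use `@` for the matrix product and keep `*` for scalars:

```python
class Matrix:
    def __init__(self, a, b, c, d):
        self.data = [a, b, c, d]

    def __matmul__(self, other):
        # p @ q: matrix product, row of self times column of other
        if isinstance(other, Matrix):
            a, b, c, d = self.data
            e, f, g, h = other.data
            return Matrix(a * e + b * g, a * f + b * h,
                          c * e + d * g, c * f + d * h)
        return NotImplemented

    def __mul__(self, other):
        # p * n: scalar multiplication only
        if isinstance(other, (int, float)):
            return Matrix(*(x * other for x in self.data))
        return NotImplemented

    def __rmul__(self, other):
        # n * p: the same as p * n for a scalar
        return self.__mul__(other)


p = Matrix(1, 2, 3, 4)
q = Matrix(5, 6, 7, 8)
print((p @ q).data)   # [19, 22, 43, 50]
print((2 * p).data)   # [2, 4, 6, 8]
```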
