Recursion and the lru_cache in Python
Martin McBride, 2020-02-12
Recursion is a common technique that is often associated with functional programming. The basic idea is this – given a difficult problem, try to find a procedure that turns the original problem into a simpler version of the same problem. Apply the same procedure repeatedly to make the problem simpler and simpler, until you have a problem that is so simple you can just solve it in one go.
As a Python programmer you may well look at some examples of recursion and think that it would obviously be easier to write a loop instead. Some other languages don’t have loops, so you have to use recursion, but in those cases the interpreter often creates a loop behind the scenes.
But there are plenty of problems that are inherently recursive in nature and would be very difficult to solve in any other way, so recursion is definitely something to have in your toolbox.
Our first example, the factorial function, is a slight cliché, but it is still a good illustration of both the beauty and the pitfalls of recursion.
The factorial of an integer n is the product of all the integers between 1 and n. For example, 6 factorial (usually written 6!) is:
6*5*4*3*2*1 = 720
Now as we said in the introduction, the obvious way to do this is with a loop. But there is an alternative, "cleverer" way, using recursion.
We can make the simple observation that 6! is actually 6*5!, and 5! is 5*4!, and so on. So, we could calculate n! without ever explicitly calculating a factorial at all. We just keep relying on smaller and smaller factorials, without ever calculating them.
Of course, you must stop somewhere – we know that 1! is 1.
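Written as a formula, the observation above defines n! in terms of a smaller factorial, together with a stopping case. Applying it repeatedly to 6! reproduces the product we started with:

$$
n! = \begin{cases} 1, & n = 1 \\ n \times (n-1)!, & n > 1 \end{cases}
\qquad
6! = 6 \times 5! = 6 \times 5 \times 4! = \cdots = 6 \times 5 \times 4 \times 3 \times 2 \times 1 = 720
$$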
Here is the Python code for calculating the factorial of n. As we said, we just return n times the factorial of n – 1, unless n is 1 (or less), in which case we just return 1:
```python
def factorial(n):
    if n > 1:
        # n! = n * (n-1)!
        x = n * factorial(n - 1)
    else:
        # Stopping case: 1! is 1
        x = 1
    return x

print(factorial(6))
```
Amazingly enough, this works. We can investigate this further by adding some debug print statements:
```python
def factorial(n):
    print('Enter', n)
    if n > 1:
        x = n * factorial(n - 1)
    else:
        x = 1
    print('Exit', n)
    return x

factorial(6)
```
Here is what it prints:
```
Enter 6
Enter 5
Enter 4
Enter 3
Enter 2
Enter 1
Exit 1
Exit 2
Exit 3
Exit 4
Exit 5
Exit 6
```
As you can see, we have called a function within a function within a function ... that’s recursion, of course.
Recursion is relatively inefficient compared to looping. This is because each step in a recursion results in a function call, whereas each step in a loop merely requires a "jump" to a different place in the code.
Calling a function involves considerably more work than a simple loop step. In any system it is going to take more time and use extra memory, because memory is required to store the current state of the function (the values of its local variables) each time the function calls itself recursively.
However, Python has a rather more immediate problem. By default, CPython limits the depth of recursive calls to 1000 (you can check the current value with sys.getrecursionlimit()). The code above therefore fails with a RecursionError for any n close to or above 1000.
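Here is a minimal sketch of hitting the limit, using the factorial function from above. sys.getrecursionlimit() is a standard function. The helper fails_with_recursion_error is a hypothetical name used only for this illustration.

```python
import sys

def factorial(n):
    if n > 1:
        x = n * factorial(n - 1)
    else:
        x = 1
    return x

def fails_with_recursion_error(n):
    """Hypothetical helper: True if factorial(n) exceeds the recursion limit."""
    try:
        factorial(n)
    except RecursionError:
        return True
    return False

print(sys.getrecursionlimit())          # 1000 by default in CPython
print(fails_with_recursion_error(100))  # False: well within the limit
print(fails_with_recursion_error(5000)) # True: far too deep
```

sys.setrecursionlimit() can raise the limit. Setting it too high, however, risks crashing the interpreter with a genuine stack overflow, so rewriting the function as a loop is usually the safer fix.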
This doesn’t mean that recursion isn’t a useful tool in Python. If you are processing a balanced binary tree, for example, a depth of 1000 allows you to process a tree containing around 2^1000 nodes, which is a vast number. But if the problem can be solved with a simple loop, that is probably the best solution.
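This small sketch shows why depth is rarely a problem for balanced trees. The recursion depth grows with the height of the tree, not with the number of nodes. The Node class and the build_balanced and count_nodes functions are hypothetical names made up for this example.

```python
class Node:
    """Hypothetical minimal binary tree node."""
    def __init__(self, left=None, right=None):
        self.left = left
        self.right = right

def build_balanced(levels):
    """Build a perfectly balanced tree with the given number of levels."""
    if levels == 0:
        return None
    return Node(build_balanced(levels - 1), build_balanced(levels - 1))

def count_nodes(node):
    """Count nodes recursively; the recursion only goes as deep as the tree is tall."""
    if node is None:
        return 0
    return 1 + count_nodes(node.left) + count_nodes(node.right)

tree = build_balanced(15)
print(count_nodes(tree))  # 32767, i.e. 2**15 - 1, with recursion only about 16 calls deep
```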
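The form of recursion exhibited by factorial is closely related to what is called tail recursion. Tail recursion is when the recursive call is the very last thing the function does, and its result is returned directly (usually with a condition beforehand to terminate the function before making the recursive call). Strictly speaking, our factorial is not tail recursive, because the multiplication by n happens after the recursive call returns, but it can easily be rewritten so that it is.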
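In the factorial above, the multiplication by n still happens after the recursive call returns. A common way to make the recursive call the very last step is to carry the running product in an extra accumulator argument. This is a sketch: factorial_acc and its acc parameter are names made up for this illustration. Python still creates a new stack frame for every call, so this version hits the recursion limit just like the original.

```python
def factorial_acc(n, acc=1):
    # acc holds the product built up so far, so nothing is left to do
    # after the recursive call: it is a true tail call.
    if n <= 1:
        return acc
    return factorial_acc(n - 1, acc * n)

print(factorial_acc(6))  # 720
```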
When a function is tail recursive, you can generally replace the recursive call with a loop. In Python, you usually should do that!
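Here is a sketch of that conversion for factorial. The hypothetical factorial_loop keeps the running product in a local variable and replaces each recursive call with one pass around a while loop. It therefore needs no extra stack frames and is not affected by the recursion limit.

```python
def factorial_loop(n):
    # acc plays the role of an accumulator; each iteration does the work
    # that one recursive call would have done.
    acc = 1
    while n > 1:
        acc *= n
        n -= 1
    return acc

print(factorial_loop(6))  # 720

# Far deeper than the default recursion limit would allow
big = factorial_loop(5000)
print(big.bit_length())
```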
Some languages automatically spot tail recursion and replace it with a looping operation. This is often called TCO (Tail Call Optimisation). Python does not do this. It tends to happen in pure functional languages, where in some cases loops don’t even exist. Such languages are often far more declarative than Python, which makes it easier to detect tail recursion.
There are some hacks that allow you to implement tail recursion in Python, but they are not covered here.
Inefficient recursion – Fibonacci numbers
Here is another classic example of recursion – calculating the nth Fibonacci number. It turns out that this is hopelessly inefficient using pure recursion, but we will also look at a useful technique to alleviate the problem.
If you are not familiar with the Fibonacci series, it is an infinite series of numbers defined as follows:
$$
\begin{aligned}
F_0 &= 0 \\
F_1 &= 1 \\
F_2 &= F_1 + F_0 = 1 \\
F_3 &= F_2 + F_1 = 2 \\
&\;\;\vdots \\
F_n &= F_{n-1} + F_{n-2}
\end{aligned}
$$
In other words, each element is the sum of the two previous elements. Here are the first few values of the series:
0, 1, 1, 2, 3, 5, 8, 13, 21...
This can obviously be calculated recursively, like this:
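```python
def fibonacci(n):
    # Two initial cases, F0 and F1
    if n == 0:
        x = 0
    elif n == 1:
        x = 1
    else:
        # Each element is the sum of the two previous elements
        x = fibonacci(n - 1) + fibonacci(n - 2)
    return x

print(fibonacci(8))  # 21
```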
Notice that we need to supply two initial cases. You can’t calculate F0 or F1; they must be defined. The series is numbered from 0, so element 8 is 21.
If we look at how this function actually works, by adding Enter and Exit print statements as before, it turns out to be a bit of a nightmare!
Calculating F8 requires us to calculate F7 and F6. That is where the inefficiencies start, because of course calculating F7 also requires us to calculate F6. Since these calculations are done in separate branches of the recursion, F6 will be calculated twice.
Calculating F6 twice then requires us to calculate F5 twice, but we also need to calculate it again as part of the F7 calculation, so we end up calculating F5 three times.
Calculating F6 twice and F5 three times means we end up calculating F4 five times. You might be noticing a pattern here – the number of times we have to calculate each successively lower level of recursion increases according to the Fibonacci series!
In short, this is a terribly inefficient method.
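We can check this pattern without wading through pages of Enter and Exit output by counting the calls instead. This sketch uses collections.Counter. The calls counter is an addition for illustration only and is not part of the original function.

```python
from collections import Counter

calls = Counter()  # how many times fibonacci() is called with each n

def fibonacci(n):
    calls[n] += 1
    if n == 0:
        x = 0
    elif n == 1:
        x = 1
    else:
        x = fibonacci(n - 1) + fibonacci(n - 2)
    return x

fibonacci(8)
for n in sorted(calls, reverse=True):
    print(n, calls[n])
# Prints one pair per line: 8 1, 7 1, 6 2, 5 3, 4 5, 3 8, 2 13, 1 21, 0 13
print(sum(calls.values()))  # 67 calls in total, just to compute F8
```

F0 is the odd one out: it is only reached from calls to F2, so it is computed as often as F2 is.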
The basic problem here is that we are calling fibonacci multiple times with the same argument, but each time we are calculating the value all over again.
Now we know that fibonacci is a pure function. It has no side effects, and every time you call it with a particular value, you will always get the same result.
What we need is some way to remember all the times it has been called before, store the result, and only calculate it if it is called with a value that has never been seen before. We can do this using a dictionary to hold all the previous calls. The dictionary key is the argument, and the dictionary value is the result. Here is the code:
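```python
cache = dict()  # maps argument n -> previously calculated result

def fibonacci(n):
    # Return the stored result if we have seen this n before
    if n in cache:
        return cache[n]
    if n == 0:
        x = 0
    elif n == 1:
        x = 1
    else:
        x = fibonacci(n - 1) + fibonacci(n - 2)
    # Remember the result for next time
    cache[n] = x
    return x

print(fibonacci(8))
```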
Here we define an empty dictionary called cache. Every time we enter the fibonacci function, we check if the value of n already exists in the dictionary. If it does, we simply return the previously stored value for the function result, which is found in cache[n].
If the value doesn’t already exist, we calculate it in the normal way. Then, before fibonacci returns, we store the result in cache, so we never have to calculate it again.
This is all very well, but it is adding extra code to the fibonacci function. That extra code has little to do with what the function is really doing. It is an efficiency improvement that you might wish to use with other functions, not just fibonacci.
These so-called cross-cutting concerns are exactly what decorators were invented for.
Another problem is that our cache implementation is quite crude and simplistic. It relies on having a global variable, cache, kicking around in the file, and hoping that nobody else uses it. It only works for functions that take exactly one argument. It also allows the cache to grow to any size, when it might sometimes be more sensible to set a maximum size.
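A decorator can separate the caching from the calculation, as this minimal hand-rolled sketch shows. The memoize decorator is a hypothetical name written for this illustration. It fixes the first two problems: the cache is private to the decorated function, and any number of hashable positional arguments works. Like our dictionary version, though, it still lets the cache grow without limit.

```python
import functools

def memoize(func):
    """Hypothetical memoizing decorator (illustration only)."""
    cache = {}  # private to each decorated function, not a module global

    @functools.wraps(func)
    def wrapper(*args):
        # The tuple of positional arguments is the key, so functions with
        # several (hashable) arguments work too.
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]

    return wrapper

@memoize
def fibonacci(n):
    # The function body is back to pure Fibonacci logic
    if n == 0:
        x = 0
    elif n == 1:
        x = 1
    else:
        x = fibonacci(n - 1) + fibonacci(n - 2)
    return x

print(fibonacci(8))  # 21
```

The recursive calls inside fibonacci look up the global name fibonacci, which now refers to the wrapper, so they go through the cache as well.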
Fortunately, an existing decorator, lru_cache, solves all those problems. It is in the functools module, and it only takes one line of code to set it up:
```python
from functools import lru_cache

@lru_cache()
def fibonacci(n):
    print('Enter', n)
    if n == 0:
        x = 0
    elif n == 1:
        x = 1
    else:
        x = fibonacci(n - 1) + fibonacci(n - 2)
    print('Exit', n)
    return x

print(fibonacci(8))
```
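That is it. Just import the decorator and add @lru_cache before the function definition, and the body of fibonacci will only ever run once for every value of n. Strictly, that holds as long as the cache doesn't fill up: by default lru_cache keeps the 128 most recently used results, which is plenty here.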
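lru_cache also lets you control the cache size and inspect how well it is working. The sketch below drops the print statements and passes maxsize=None, which makes the cache unbounded. cache_info() is a standard method that lru_cache adds to the decorated function.

```python
from functools import lru_cache

@lru_cache(maxsize=None)  # None means unbounded; the default maxsize is 128
def fibonacci(n):
    if n == 0:
        x = 0
    elif n == 1:
        x = 1
    else:
        x = fibonacci(n - 1) + fibonacci(n - 2)
    return x

print(fibonacci(8))            # 21
print(fibonacci.cache_info())  # CacheInfo(hits=6, misses=9, maxsize=None, currsize=9)
```

The 9 misses are the nine values F0 to F8, each calculated exactly once. Every other call was a hit. If you ever need to start from scratch, fibonacci.cache_clear() empties the cache.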