Computing the Fibonacci Sequence Seven Ways
Lessons learned from optimizing beyond reason
- Naive Recursive
- Cached Recursion
- Iterative
- Matrix form
- Matrix form, with the power of two trick
- Exponential
- Exponential, but now more accurate
- Optimal
- Credits
Naive Recursive
def fib_recursive(n):
return float(n) if n < 2 else fib_recursive(n-1) + fib_recursive(n-2)
It’s brief, and fairly clear, I’ll give it that. Also, the runtime starts to look like O(fib(n)), which I think is neat.
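If you want to feel just how bad that is, here is a small sketch (a hypothetical fib_recursive_counted, which is just the function above plus a global call counter) that makes the blowup concrete:

call_count = 0

def fib_recursive_counted(n):
    # Identical to fib_recursive, except it counts how many times it runs.
    global call_count
    call_count += 1
    return float(n) if n < 2 else fib_recursive_counted(n-1) + fib_recursive_counted(n-2)

for n in (10, 20, 30):
    call_count = 0
    fib_recursive_counted(n)
    print(n, call_count)   # the call count itself grows roughly like fib(n)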
Cached Recursion
Since the problem with the naive recursion is that we recompute the same input value many times, the obvious solution is just to memoize the function. Specifically, if we just cache the most recent three function inputs (fib(n), fib(n-1), and fib(n-2)), then we will get rid of all the redundant computations, and we can keep the nice recursive expression:
from functools import lru_cache
@lru_cache(maxsize=3)
def fib_recursive_lru(n):
return float(n) if n < 2 else fib_recursive_lru(n-1) + fib_recursive_lru(n-2)
As usual, be careful with caching a recursive function, since the cache makes reasoning about performance and correctness hard. For example, if you test this code by checking fib(1), then fib(2), then fib(3), all the way to fib(1000), the cached implementation will look great! All the values come out correct, and so fast! But then a user comes along, and the first Fibonacci number they need is fib(100)- they find the function is sadly linear time. Then worse: they try fib(1000), which you’ve tried and know works fine. But the user reports a RecursionError: maximum recursion depth exceeded. It turns out that the order of your testing meant you never actually tested the full-depth recursion- you always reached a cached value immediately.
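Here is a minimal sketch of that failure mode, assuming the fib_recursive_lru above and Python’s default recursion limit:

# Warm the cache incrementally, the way the optimistic test did- each call
# only recurses a level or two before hitting a cached value.
for i in range(1001):
    fib_recursive_lru(i)
print(fib_recursive_lru(1000))          # fast, no deep recursion needed

# Now play the unlucky user: clear the cache and jump straight to a big input,
# forcing the recursion to go all the way down.
fib_recursive_lru.cache_clear()
try:
    fib_recursive_lru(1000)
except RecursionError as e:
    print("cold cache:", e)             # maximum recursion depth exceeded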
Iterative
def fib_iterative(n):
if n < 2:
return float(n)
current, prev = 1., 0.
for _ in range(n-1):
current, prev = current + prev, current
return current
This is the solution it seems like your algorithms 101 professor wanted to see on the exam. It’s fine- linearly performant and nearly as clear as the recursive version, with absolutely no gotchas. So let’s quickly proceed past it into the fun algorithms.
Matrix form
def fib_matrix(n):
if n < 2:
return float(n)
def mat_mul(m,n):
return [[sum(m[i][k] * n[k][j] for k in range(len(m[0]))) for j in range(len(n[0]))] for i in range(len(m))]
result = [[1.], [0.]]
m = [[1.,1.],[1.,0.]]
for _ in range(n-1):
result = mat_mul(m, result)
return result[0][0]
What’s happening here? Basically the same thing as the iterative form. result[0] is the iterative current, and result[1] is the iterative prev. On each step, result[0] becomes the sum of the old result[0] and result[1], and result[1] is assigned 1 times the old result[0].
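A one-step sanity check (plain lists, small numbers, nothing beyond the code above) shows the matrix doing exactly one iterative update:

# Multiplying [[1,1],[1,0]] into [fib(5), fib(4)] should give [fib(6), fib(5)].
m = [[1., 1.], [1., 0.]]
state = [[5.], [3.]]                    # [fib(5), fib(4)]
stepped = [[m[0][0]*state[0][0] + m[0][1]*state[1][0]],
           [m[1][0]*state[0][0] + m[1][1]*state[1][0]]]
print(stepped)                          # [[8.0], [5.0]], i.e. [fib(6), fib(5)]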
Why is this useful? Honestly, it’s not. This does exactly the same operations as the iterative form, except with a lot of complexity and some extra computation. But, when we think about algorithms as matrix operations, we can start bringing in the tools of linear algebra. In particular, we can see that we’re multiplying by the same matrix many times- we’re computing result = m^(n-1) * result. And because matrix multiplication is associative, we can restructure the exponentiation to re-use some computation, and be more efficient, leading to:
Matrix form, with the power of two trick
def fib_matrix_power_2(n):
if n < 2:
return float(n)
def mat_mul(m,n):
return [[sum(m[i][k] * n[k][j] for k in range(len(m[0]))) for j in range(len(n[0]))] for i in range(len(m))]
result = [[1.], [0.]]
m = [[1.,1.],[1.,0.]]
i = 1
while i <= (n-1):
if i & (n-1):
result=mat_mul(m, result)
m = mat_mul(m,m)
i *= 2
return result[0][0]
For some insight into this method, let’s start with an example. To multiply r * M^15, you can multiply r by M 15 times. Or, you could multiply r by M, then M^2, then M^4, then M^8, for 4 operations. Each power of 2 of M comes from just multiplying the last by itself, for 4 operations, so 8 total for this method. That’s about half the operations of the naive way, and it grows logarithmically in the matrix exponent. In general, to multiply by M^x, you just need to multiply by the powers of 2 corresponding to which bits are set in x- for x = 17, that’s bits 1 and 5, i.e. M^1 and M^16.
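The same bookkeeping is easier to see with plain integers- here is a sketch of the identical square-and-multiply loop with an integer base standing in for the matrix:

def pow_by_squaring(base, exponent):
    # Multiply base^(2^k) into the result exactly when bit k of the exponent
    # is set, squaring the running power once per loop iteration.
    result = 1
    power = base
    i = 1
    while i <= exponent:
        if i & exponent:
            result *= power
        power *= power
        i *= 2
    return result

print(pow_by_squaring(3, 17), 3**17)    # both print 129140163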
Converting an algorithm to matrix form to solve it with matrix exponentiation is a general trick- maybe not common, but if you see everything in matrices, the technique can come up in surprising places. The next technique, however, is totally Fibonacci specific.
Exponential
Domain expertise > programming prowess. According to the Fibonacci Wikipedia page, there’s a closed form solution to the Fibonacci sequence, and it’s pretty easy to implement as a power of the golden ratio:
import math
def fib_exponential(n):
return float(round((((1+math.sqrt(5))/2)**(n)) / math.sqrt(5)))
Tada! One operation, one line of code, an excess of parentheses, and we have achieved Fibonacci. Clearly, this is the holy grail of fast Fibonacci code. Except…
The lie, and the neat pattern in the lie
Domain expertise should probably be paired with understanding the limits of code. In this case, the problem shows up when we test the exponential version of the function against the reference iterative version on large-ish inputs.
for i in range(500):
    assert fib_iterative(i) == fib_cached(i), f"Mismatch on cached version for {i}, expected {fib_iterative(i)}, got {fib_cached(i)}"
    assert fib_iterative(i) == fib_recursive_lru(i), f"Mismatch on recursive version for {i}, expected {fib_iterative(i)}, got {fib_recursive_lru(i)}"
    assert fib_iterative(i) == fib_matrix(i), f"Mismatch on matrix version for {i}, expected {fib_iterative(i)}, got {fib_matrix(i)}"
    assert fib_iterative(i) == fib_matrix_power_2(i), f"Mismatch on matrix power 2 version for {i}, expected {fib_iterative(i)}, got {fib_matrix_power_2(i)}"
    assert fib_iterative(i) == fib_exponential(i), f"Mismatch on exponential version for {i}, expected {fib_iterative(i)}, got {fib_exponential(i)}"
raises AssertionError: Mismatch on exponential version for 71, expected 308061521170129.0, got 308061521170130.0
That’s odd- it’s weird that it would work fine for the first 71 numbers, and then fail by being just slightly off. Maybe the closed form approximation starts to show its limits? But Wikipedia didn’t describe it as an approximation- that’s just the real formula.
Turns out it’s not our formula that’s an approximation, it’s our computer. Floating point numbers only have so much precision- 64 bits total for a Python float, and only 53 of those go to the significand. When we try to represent the golden ratio, then raise it to larger and larger exponents, those precision limits eventually produce errors large enough to survive rounding to an integer.
Interestingly, when we log those errors, we see a familiar pattern:
exp_errors = [fib_exponential(i) - fib_iterative(i) for i in range(500)]
print(exp_errors)
[0.0, {0.0 69 times} ... 0.0, 1.0, 1.0, 2.0, 3.0, 5.0, 9.0, 14.0, 24.0, 40.0, 60.0, 104.0, 168.0, 288.0, ...]
Notice the fun pattern? At a certain point, the error itself starts to look pretty similar to the Fibonacci sequence. This turns out to be less of a coincidence on deeper inspection- the relative error stays about the same for all n, but since Fibonacci values grow exponentially, the absolute error grows right along with them.
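A quick check of that claim, comparing against exact integer Fibonacci values so the reference itself has no rounding error (a sketch- the exact numbers it prints will vary a little by platform):

# Relative errors of the closed-form float result against exact integers.
exact = [0, 1]
while len(exact) < 500:
    exact.append(exact[-1] + exact[-2])

rel_errors = [abs(fib_exponential(i) - exact[i]) / exact[i] for i in range(71, 500)]
print(min(rel_errors), max(rel_errors))   # both stay tiny, even as the absolute errors explode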
Exponential, but now more accurate
The exponential error can be reduced by taking advantage of Python’s arbitrary precision libraries. But higher precision means a higher runtime- there’s no free lunch.
from decimal import Decimal, getcontext
def fib_exponential_precise(n, precision=100):
getcontext().prec = precision
sqrt_5 = Decimal(5) ** Decimal(.5)
phi = (1+sqrt_5)/2
return float(round((pow(phi,n)-pow(1-phi,n))/sqrt_5))
Are our precision problems finally squashed? Or at least, would they be if we set the precision high enough? Absolutely not- the unit of least precision (the gap between adjacent representable floats) is well above 1 for large float values. No amount of computational precision in the world will give perfect accuracy when the output type can’t even represent the computed value.
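math.ulp (available since Python 3.9) makes the spacing problem concrete- near fib(100), adjacent representable floats are tens of thousands apart:

import math

# The gap between neighboring float64 values around fib(100) is far larger than 1,
# so the rounded output simply cannot land on every integer up there.
fib_100 = 354224848179261915075        # the exact 100th Fibonacci number
print(float(fib_100))                  # roughly 3.5422e+20
print(math.ulp(float(fib_100)))        # 65536.0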
Digression- what does big-O notation mean?
When we say a function is O(ln(x)), we mean that with sufficiently large input x, it will take a number of operations proportional to ln(x) to compute the function. What counts as an operation? That depends on your machine model. Usually this doesn’t matter too much [1]- for most programs, for most machine models, a line of code is basically constant time, and what really matters is the number of iterations of an inner loop or something.
For arbitrary precision math though, a line of Python code is not constant time at all. In fact, it hides a complex inner loop that handles the many memory elements that collectively represent a high-precision number. How slow is that? In practice, it depends on your Python version. In theory, nobody knows for sure right now.
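A rough, machine-dependent illustration with timeit: the same one-line Decimal multiplication gets slower as the context precision grows. The absolute timings depend on your machine and Python version; the upward trend is the point.

import timeit
from decimal import Decimal, getcontext

for prec in (100, 1_000, 10_000, 100_000):
    getcontext().prec = prec
    x = Decimal(2).sqrt()              # sqrt() is computed to the current context precision
    t = timeit.timeit("x * x", globals={"x": x}, number=100)
    print(prec, t)                     # more digits per "line of code", more time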
Optimal
What’s better than log(n)-ish runtime? Constant runtime. And my personal favorite way to develop a constant runtime algorithm is to skip the algorithm part- let’s just know the answers. In this case, there aren’t actually that many possible inputs to the Fibonacci function- sure, you can write x = fib(1_000_000), but it’s just going to return inf anyway, since an FP64 number can only go so high, and fib() is an exponential function. So let’s compute inputs until we run out of output, embed the results in a lookup table, and use that to find the answers at runtime.
cached_fib = [0., 1.]
while cached_fib[-1] != float('inf'):
cached_fib.append(cached_fib[-1] + cached_fib[-2])
def fib_cached(n):
return cached_fib[n] if n < len(cached_fib) else float('inf')
It appears that our cached_fib list is a totally reasonable 1478 elements.
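If you’re skeptical, the table is cheap to inspect- a quick look at its size and its last couple of entries:

print(len(cached_fib))                 # 1478
print(cached_fib[-2])                  # the largest Fibonacci value that fits in a float64
print(cached_fib[-1])                  # inf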
Once again, we see that when we realize problem domains are actually quite small, exhaustive methods become tractable, clear, and efficient.
Credits
Thanks to Matthew Caseres for suggesting Fibonacci as part of the Leetcode practice for the day, and especially Adam Kelly for pointing out the pattern in the accuracy errors in the exponential solution and the subsequent deep discussion on big-O notation.
Footnotes
[1] But there’s lots of performance to be gained by developing an algorithm for a more accurate machine model- for instance, a machine model that treats computation as free but cache misses as expensive will be a better match for real world servers in most domains.