The Numba "stencil" directive
After a bit of a delay we're getting the blog posts going again with a mention of a slightly odd bit of the Python Numba compiler: the stencil directive. The purpose of Numba is to produce compiled code from Python source that runs at a reasonable fraction of the speed of classical C or Fortran codes. In general it produces code that is about 20-40% as fast as C or Fortran, so typically only about 1-2% as fast as the computer can theoretically operate. The more you can tell Numba in advance about how you are going to use your data, the more options it has to optimise the code as it compiles it. The "stencil" directive tells Numba that you are going to operate on an array by moving a stencil across it and updating each point using data from neighbouring points.
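To make the idea of "moving a stencil across an array" concrete, here is a minimal pure-Python sketch of a three-point stencil on a 1-D array. The function name and the choice to leave the two end points at zero are assumptions made for illustration; nothing here uses Numba yet.

```python
import numpy as np


def three_point_average(u):
    # Hypothetical 1-D example: each interior point becomes the mean of
    # itself and its two neighbours; the two end points are left at zero.
    out = np.zeros(len(u))
    for i in range(1, len(u) - 1):
        out[i] = (u[i - 1] + u[i] + u[i + 1]) / 3.0
    return out


result = three_point_average(np.array([0.0, 3.0, 6.0, 9.0, 0.0]))
print(result)  # interior values are 3.0, 6.0 and 5.0
```

The pattern is always the same: the output at one point depends only on a fixed, small neighbourhood of the input around that point.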
This is a fairly common thing to want to do and crops up in algorithms from image smoothing to numerical solutions of differential equations, so this is a useful bit of the library. As a simple example consider the simplest possible image smoothing algorithm. For each pixel in the image P(i,j), replace the value with the average of the four nearest surrounding pixels, so
P'(i,j) = 1/4 * (P(i+1,j) + P(i-1,j) + P(i,j+1) + P(i,j-1))
Note that the left hand side of this equation is P', not P, so the average for each pixel is calculated using the original values surrounding it, not values that have already been through the averaging process. After you have finished you copy P' to P. There is then one last thing to worry about: what do you do when you reach the edge of the array? Since you are using adjacent cells you have to do something or you will read outside your array. Numba's stencil operator at present only lets you set the outer cells to zero or to some other constant value. In general, this means you will want to add an outer strip of cells around your image, otherwise your image will get smaller as it smooths. These outer cells are often called ghost or guard cells and are also common in the numerical solution of differential equations, where they represent boundary conditions. The code for doing all this in Python is quite simple:
import numpy as np

def blur(A):
    s = A.shape
    B = np.zeros((s[0], s[1]))
    for i in range(1, s[0] - 1):  # Range only over the inner cells
        for j in range(1, s[1] - 1):  # Range only over the inner cells
            B[i, j] = 0.25 * (A[i-1, j] + A[i+1, j] + A[i, j-1] + A[i, j+1])
    return B
This code takes a NumPy array, iterates over all but the outer strip of cells in every direction, and returns the array of averages. Each call to this function smooths your image a bit (by a radius of about 1 pixel), so in general you'll want to call it several times to smooth your image as much as you want. Running this algorithm on a stock image of 1529 x 2250 pixels on a 3.4GHz processor takes about 3.3 seconds per iteration using the pure Python implementation and 0.006 seconds using the Numba @njit decorator. For testing we ran 100 iterations of the pure Python code and 1000 iterations of the Numba code. If run for the same number of iterations, the results in both cases are the same.
The simple equivalent using stencil is:
import numba

@numba.stencil
def blur(A):
    return 0.25 * (A[-1,0] + A[1,0] + A[0,-1] + A[0,1])
You can see how i-1 becomes just -1, and similarly for the other indices: inside a stencil, indices are relative to the current cell. You can also see how this operation now becomes a one-liner, so it's much easier to write. Unfortunately the performance is much worse than the @njit version, taking about 0.22 seconds per iteration, although that is still some 15x faster than the pure Python version. Fortunately performance can be improved by calling the stencil from a Numba jit-ed function, so for example:
from numba import njit, stencil

@stencil
def inner_blur(A):
    return 0.25 * (A[-1,0] + A[1,0] + A[0,-1] + A[0,1])

@njit(parallel=False)
def blur(A):
    return inner_blur(A)
The result from this method gives performance that is indistinguishable from that of the direct Numba jit version, and the code is still rather shorter. Since these stencils lay out the data dependency, you can also set parallel=True in the @njit call. This can be quite successful but tends to work better for more complex stencils. In this particular case, despite the Python interpreter apparently using 6 cores solidly, execution slows down by a factor of 3.