Thursday, 22 April 2010

Plotting Lambert's W Function

For this post I wanted to come up with something a little easier on the eye than my previous offerings ("Loading DTDs using DOM in Python", "Colorised svn diffs" and "How do you like your Fortran 95 interfaces?").

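Here you go:

[Figures: Riemann-surface plots of Re(W), Im(W) and abs(W) for branches -1, 0 and 1 of Lambert's W function]
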
I've been developing code for computing complex values of Lambert's W function, and I thought I'd share the Riemann-surface plots above. From top to bottom they show Re(W), Im(W) and abs(W) for branches -1 (Blues colour map), 0 (Greens) and 1 (Oranges). The route I took to get the plots has been a good learning experience.
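For readers who want to experiment without the Fortran code, here is a minimal sketch of what "branch k of W" means. It is not the NAG routine described in this post. W_k(z) is a solution w of w·exp(w) = z, and a plain Halley iteration started from a branch-dependent guess will find it. The name lambert_w, the starting guesses (log(1 + z) for k = 0, and the leading asymptotic terms L1 - log(L1) with L1 = log(z) + 2πik otherwise) and the tolerance are illustrative assumptions. The sketch does nothing special near the branch point z = -1/e and does not handle z = -1 on branch 0.

```python
import cmath


def lambert_w(z, k=0, tol=1e-14, max_iter=100):
    """Branch k of Lambert's W function: a solution w of w*exp(w) = z.

    Illustrative sketch only: Halley's iteration from a crude starting
    guess, with no special handling near the branch point z = -1/e.
    """
    z = complex(z)
    if k == 0:
        if z == 0:
            return 0j  # W_0(0) = 0; the other branches are singular at 0
        w = cmath.log(1 + z)  # rough guess for the principal branch
    else:
        # Leading terms of the asymptotic expansion for branch k
        l1 = cmath.log(z) + 2j * cmath.pi * k
        w = l1 - cmath.log(l1)
    for _ in range(max_iter):
        ew = cmath.exp(w)
        f = w * ew - z  # residual of the defining equation
        wp1 = w + 1
        # Halley step for f(w) = w*exp(w) - z
        step = f / (ew * wp1 - (w + 2) * f / (2 * wp1))
        w -= step
        if abs(step) <= tol * (1 + abs(w)):
            break
    return w
```

For example, lambert_w(1) gives the omega constant, about 0.567143, and lambert_w(1, 1) and lambert_w(1, -1) are complex conjugates. For production use, SciPy's scipy.special.lambertw(z, k) uses the same branch numbering.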

First approach

My initial idea was to write a mex wrapper around my algorithmic Fortran code, along the lines of what the NAG Toolbox for MATLAB does. I also wanted to see if I could provide a vectorized interface, like those of native MATLAB functions, so that a single call handles a whole array of z values.


Whew, it's not much fun debugging mex files, is it?


The main thing I learned from this experience was to use mxCreateNumericMatrix to create my function's return values directly in plhs: in this case the W values, the residual errors and the fail flags, each with the same shape as the input array.
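/* prhs[2] is the array of input zs; each output array is created
   with the same dimensions (dims is declared as const mwSize *) */
dims = mxGetDimensions(prhs[2]);

/* W values: complex double */
plhs[0] = mxCreateNumericMatrix(dims[0], dims[1],
                                mxDOUBLE_CLASS,
                                mxCOMPLEX);

if (nlhs >= 2)
{
  /* residual errors: real double */
  plhs[1] = mxCreateNumericMatrix(dims[0], dims[1],
                                  mxDOUBLE_CLASS,
                                  mxREAL);

  if (nlhs >= 3)
    {
      /* fail flags: there is no mxINT_CLASS, so use a sized integer
         class matching the Fortran default INTEGER (usually 32-bit) */
      plhs[2] = mxCreateNumericMatrix(dims[0], dims[1],
                                      mxINT32_CLASS,
                                      mxREAL);
    }
}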
Other approaches I tried, such as populating local mxArray * variables that were assigned to the plhs array only at the end (and only if required), caused MATLAB all sorts of upset when it exited my mex routine.
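The same "every output shaped like the input" idea carries over to Python. Below is a minimal NumPy sketch of a vectorized wrapper that applies a scalar routine elementwise and returns value, residual and fail-flag arrays, mirroring the three mex outputs. Both vectorize_with_status and the stand-in sqrt_with_status are hypothetical placeholders for a call into the Fortran code. Neither is the nag_LambertW wrapper used for the plots.

```python
import cmath

import numpy as np


def vectorize_with_status(scalar_fn, z):
    """Apply scalar_fn to every element of z.

    scalar_fn(zi) must return (value, residual, fail_flag). The results come
    back as three arrays with the same shape as z: complex values, real
    residuals and 32-bit integer fail flags.
    """
    z = np.asarray(z, dtype=complex)
    values = np.empty(z.shape, dtype=complex)
    resid = np.empty(z.shape, dtype=float)
    fail = np.empty(z.shape, dtype=np.int32)
    for idx, zi in np.ndenumerate(z):
        values[idx], resid[idx], fail[idx] = scalar_fn(zi)
    return values, resid, fail


def sqrt_with_status(zi):
    """Hypothetical stand-in for the Fortran routine: the complex square
    root, its residual |w*w - z| and a fail flag that is always 0."""
    w = cmath.sqrt(zi)
    return w, abs(w * w - zi), 0


values, resid, fail = vectorize_with_status(sqrt_with_status, [[4, -1], [0, 2j]])
```

An explicit loop keeps the per-element fail flag that the Fortran routine reports, which numpy.vectorize would make more awkward to collect.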

With the code working, I managed to get some good plots in a MATLABy way. I still felt a little discontented, though, and then I remembered the advice of my old (and sadly departed) A-level statistics teacher: it's better to be able to solve a problem two ways than one.

Second approach

I thought back to Mike Croucher's PyNAG tutorials on calling the NAG C Library from Python, and decided it might be fun to call my algorithmic Fortran code from Python and use matplotlib for the plotting. While I was at it, I wanted to see whether Fortran 2003's C interoperability features could make the call from Python more straightforward. In the 'unassisted' case, calling Fortran from Python is much like calling Fortran from C, with the usual headaches that entails: compiler-specific name mangling of routine names, every argument passed by reference, and hidden length arguments for character strings.

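The primary hoop I had to jump through to allow C interoperability was mapping Fortran character strings to and from arrays of single characters, because a character dummy argument of an interoperable (BIND(C)) procedure must have length 1: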
...
   USE, INTRINSIC :: iso_c_binding, ONLY : c_char
   ...
   CHARACTER (kind=c_char), DIMENSION (6), &
      INTENT (IN) :: in_str
   CHARACTER (kind=c_char), DIMENSION (10), &
      INTENT (OUT) :: out_str
   CHARACTER (len=SIZE(in_str)) :: local_in_str
   CHARACTER (len=SIZE(out_str)) :: local_out_str
   INTEGER :: i
   ...
   DO i = 1, SIZE(in_str)
      local_in_str(i:i) = in_str(i)
   END DO
   ...
!  Here's where all the real stuff happens. Operate on
!  local_in_str and local_out_str.
   ...
   DO i = 1, SIZE(out_str)
      out_str(i) = local_out_str(i:i)
   END DO

END
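On the Python side, the matching job is to hand Fortran fixed-length, blank-padded character arrays and to read them back. Here is a minimal ctypes sketch, assuming the Fortran routine is compiled with BIND(C) into a shared library. The library path and routine name are hypothetical, so the actual call is left as a comment; to_fortran_chars and from_fortran_chars are illustrative helpers.

```python
import ctypes


def to_fortran_chars(s, length):
    """Return a ctypes char array holding s blank-padded to exactly `length`
    characters, the layout a fixed-length Fortran CHARACTER value uses.
    No trailing NUL is added, since the Fortran side never looks for one."""
    data = s.encode("ascii")
    if len(data) > length:
        raise ValueError(f"{s!r} does not fit in {length} characters")
    return ctypes.create_string_buffer(data.ljust(length), length)


def from_fortran_chars(buf):
    """Decode a char array filled in by Fortran, dropping trailing blanks."""
    return buf.raw.decode("ascii").rstrip(" ")


# Buffers matching the DIMENSION (6) input and DIMENSION (10) output above
in_buf = to_fortran_chars("abc", 6)
out_buf = ctypes.create_string_buffer(10)

# With the Fortran compiled into a shared library, the call would look like
# this (library path and routine name are hypothetical):
#   lib = ctypes.CDLL("./liblambertw.so")
#   lib.some_bind_c_routine(in_buf, out_buf)
```

ctypes passes a char array as a pointer to its first element, which is what a BIND(C) dummy argument declared as CHARACTER (kind=c_char), DIMENSION (n) expects.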
Then, with a similarly vectorized and Pythonic wrapper around my Fortran call (nag_LambertW below), I generated the plots above using:
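from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib
from matplotlib import cm
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
# add_subplot(..., projection='3d') is the supported way to create 3D axes;
# plain Axes3D(fig) no longer attaches itself to the figure in recent matplotlib
ax = fig.add_subplot(111, projection='3d')
n = 100
lower = -1
upper = -lower
threshold = 0.001
# Split the square into lower and upper half-planes, leaving out a thin
# strip around the real axis, where the branch cuts lie
X1, Y1 = np.meshgrid(np.linspace(lower, upper, n),
                     np.linspace(lower, -threshold, n))
X2, Y2 = np.meshgrid(np.linspace(lower, upper, n),
                     np.linspace(threshold, upper, n))
cmaps = {-1: cm.Blues,
         0: cm.Greens,
         1: cm.Oranges}

# Loop over some branches; nag_LambertW is my vectorized wrapper around
# the Fortran code, and its first output is the array of W values
for b in [-1, 0, 1]:
    for X, Y in [(X1, Y1), (X2, Y2)]:
        w = nag_LambertW(X + 1j*Y, b=b, fail=1)[0]
        ax.plot_surface(X, Y, w.imag,
                        rstride=1, cstride=1, cmap=cmaps[b])

plt.show()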
I had to plot two separate surfaces for each branch, one per half-plane, to stop matplotlib interpolating across the discontinuity at the branch cuts on the real axis. Being a bit of a matplotlib newb, I'm now wondering:
  • is there a better way of telling matplotlib to avoid discontinuities?
  • is it possible to apply a colour map continuously across two surfaces?
  • doesn't plt.title('Title') work with surface plots?
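On the colour-map question, one approach that should work with current matplotlib releases is to give both surfaces the same colour limits through plot_surface's vmin and vmax arguments, so that one colour map spans both halves. This is an assumption, not checked against the 2010-era mplot3d. In this sketch np.angle, that is Im(log z), stands in for Im(W), because it has a similar jump across the negative real axis. The helper name shared_colour_surfaces is hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headless; drop it to use plt.show()
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on older matplotlib)


def shared_colour_surfaces(ax, pieces, cmap):
    """Plot each (X, Y, Z) piece as a surface, all sharing one colour scale.

    Passing the same vmin/vmax to every plot_surface call makes a given
    Z value map to the same colour in every piece.
    """
    vmin = min(float(Z.min()) for _, _, Z in pieces)
    vmax = max(float(Z.max()) for _, _, Z in pieces)
    return [ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cmap,
                            vmin=vmin, vmax=vmax)
            for X, Y, Z in pieces]


n, threshold = 50, 0.001
X1, Y1 = np.meshgrid(np.linspace(-1, 1, n), np.linspace(-1, -threshold, n))
X2, Y2 = np.meshgrid(np.linspace(-1, 1, n), np.linspace(threshold, 1, n))
# Stand-in data: Im(log z) = arg z jumps across the negative real axis,
# much as Im(W) does across its branch cuts
pieces = [(X, Y, np.angle(X + 1j * Y)) for X, Y in [(X1, Y1), (X2, Y2)]]

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
surfaces = shared_colour_surfaces(ax, pieces, cm.Greens)
fig.suptitle("Imag(log z)")
```

Passing one shared matplotlib.colors.Normalize instance as norm= to each call should have the same effect.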

2 comments:

  1. Well, turns out what I wanted was

    fig.suptitle('Imag(W)')

    before the plt.show()...

  2. matplotlib's 3d support is a bit basic and I believe that support for it is uncertain at the moment. I think it was pulled altogether a while ago but has recently been added back in.

    Have you tried Mayavi?

    http://mayavi.sourceforge.net/docs/guide/ch04.html


NAG moderates all replies and reserves the right to not publish posts that are deemed inappropriate.