Monday, December 28, 2015

On calling Ceres from Scala - 2.0: Swiggin' it

This is the second installment of a series documenting my progress on porting the Ceres solver API to Scala. All the sources quoted below are available on GitHub.

Swiggin' callbacks


Using SWIG to build most (hopefully all) of the JNI wrapping code needed to call Ceres from the JVM is a choice that hardly needs justification. To wit, its advantages are:
  • A reasonably well-thought-out and well-documented API that tries hard to help the coder avoid common pitfalls.
  • Oodles of helper methods addressing common use cases.
  • Minimal configuration, which reduces the work of keeping the wrappers in sync with the wrapped library.
Disadvantages include the possible generation of underperforming code. However, I believe that, with due care, the wrapped interfaces can be kept minimal enough to be essentially optimal by design (i.e. limited only by the JNI/JVM interface itself and by the non-native user application code running in the JVM).

SWIG is well documented and fairly easy to use in practice. Our case, however, needs a couple of its "advanced" features, since we must cross the JNI boundary in both directions. That requirement comes from wanting to write our models (and cost functions) in Scala while calling Ceres in a two-step workflow:
  • First, problem specification. We want to give Ceres some Scala objects - the residual terms - that know how to compute some cost terms (and perhaps their derivatives) for some putative values of the unknowns.
  • Second, we run Ceres's C++ optimization code. This in turn will call the Scala residual term computations iteratively, until some optimum set of parameters is (hopefully) found.
So we need to cross the JNI boundary in both directions:
  • In the first step we call Scala → C++ to configure the native optimizer so it can use the Scala residual terms.
  • Then we call Scala → C++ to start the optimizer.
  • Then the optimizer calls back C++ → Scala to compute the residuals.
  • The residual computations themselves may or may not call Scala → C++ to access some Ceres utilities.
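Before bringing the JVM into the picture, the two-step control flow can be illustrated in plain C++. The following is a minimal sketch, not Ceres's API: CostCallback, ToySolver and ShiftedParabola are hypothetical names, and the "optimizer" is a fixed-step gradient descent chosen only for brevity. Step 1 (SetCost) hands the solver an object; step 2 (Minimize) runs a loop that calls back into that object's virtual Evaluate method on every iteration. In the Scala port, the role of ShiftedParabola would be played by a Scala class, and each of those calls would cross the JNI boundary in the C++ → Scala direction.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical interface the solver expects from user code: given a value of
// the unknown x, report the cost and its derivative through the references.
class CostCallback {
 public:
  virtual ~CostCallback() = default;
  virtual void Evaluate(double x, double& cost, double& derivative) const = 0;
};

// Hypothetical toy "solver" (fixed-step gradient descent). It never sees the
// user's formula; it only calls back through the virtual Evaluate().
class ToySolver {
 public:
  // Step 1: problem specification. The solver keeps a non-owning pointer.
  void SetCost(const CostCallback* cost) {
    assert(cost != nullptr);
    cost_ = cost;
  }

  // Step 2: run the optimizer, calling back into user code on every iteration.
  double Minimize(double x0, double step, int max_iterations) const {
    assert(cost_ != nullptr && step > 0.0);
    double x = x0;
    for (int i = 0; i < max_iterations; ++i) {
      double cost = 0.0;
      double derivative = 0.0;
      cost_->Evaluate(x, cost, derivative);  // solver -> user code callback
      if (std::fabs(derivative) < 1e-12) break;  // (near-)stationary point
      x -= step * derivative;
    }
    return x;
  }

 private:
  const CostCallback* cost_ = nullptr;
};

// User-side cost (x - 3)^2, minimized at x = 3. In the Scala port this would
// be a Scala class reached through the JNI boundary.
class ShiftedParabola : public CostCallback {
 public:
  void Evaluate(double x, double& cost, double& derivative) const override {
    cost = (x - 3.0) * (x - 3.0);
    derivative = 2.0 * (x - 3.0);
  }
};
```

For example, minimizing ShiftedParabola from x0 = 10 with step 0.1 converges towards x = 3, since each iteration shrinks the distance to 3 by a factor of 0.8.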
Going from Scala to C++ with SWIG is very easy. Going the other way is also easy, modulo some wading through the docs. The SWIG facility for this use case is called directors. Oversimplifying, directors make Java classes behave like subclasses of C++ classes. That subclass behavior crucially includes the resolution of virtual methods: a virtual call made on the C++ side is dispatched to the Java (and hence Scala) override.

In our case, we can declare an abstract C++ class exposing the residual-term virtual interface expected by Ceres. We then run SWIG to create the appropriate wrapper/director for it, and extend it in Scala to actually implement the computation.
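The director mechanism can be pictured with a small, self-contained C++ analogy. When directors are enabled for a class (in SWIG, via `%module(directors="1")` plus `%feature("director")` on that class), SWIG generates a C++ subclass, named along the lines of SwigDirector_ResidualTerm, whose virtual overrides forward each call through JNI to the Java proxy, and hence to any Scala class extending it. In the sketch below a std::function stands in for that JNI upcall. ForwardingResidualTerm and EvaluateOnce are hypothetical names, and the ResidualTerm signature (a double input, the result written into a double&) is an assumption matching the toy ResidualTerm described below.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Abstract residual interface, as the C++ side sees it (assumed signature).
class ResidualTerm {
 public:
  virtual ~ResidualTerm() = default;
  virtual void operator()(double x, double& residual) = 0;
};

// Analogy for a SWIG-generated director: a concrete C++ subclass whose virtual
// override forwards the call to code living elsewhere. A real director
// forwards through JNI to the Java/Scala override; here a std::function
// plays that role.
class ForwardingResidualTerm : public ResidualTerm {
 public:
  explicit ForwardingResidualTerm(std::function<double(double)> upcall)
      : upcall_(std::move(upcall)) {
    assert(upcall_ != nullptr);
  }

  void operator()(double x, double& residual) override {
    residual = upcall_(x);  // "upcall" into the foreign implementation
  }

 private:
  std::function<double(double)> upcall_;
};

// Library code written only against the base class. Virtual dispatch is what
// lets it reach the foreign implementation without knowing it exists.
double EvaluateOnce(ResidualTerm& term, double x) {
  double residual = 0.0;
  term(x, residual);
  return residual;
}
```

The point is that EvaluateOnce is written only against the abstract base class: it neither knows nor cares that the override lives outside C++.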

In practice


To make it concrete, let's code a toy residual term and evaluator library, consisting of:

  • A ResidualTerm abstract base class, whose call operator evaluates the residual and returns the result through its by-reference second argument.
  • A ResidualEvaluator class that holds a collection of concrete ResidualTerms, with a method to register new ones, and an "eval" method to compute the sum of all residual terms at an input point x.
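A minimal C++ sketch of such a library follows. It is a reconstruction from the description above, not the actual sources from the repo: the names ResidualTerm, ResidualEvaluator, AddResidualTerm and eval come from the text, while the exact signatures (a double point x, a double& output argument, raw non-owning pointers) are assumptions.

```cpp
#include <cassert>
#include <vector>

// Abstract residual term: concrete terms compute the residual at x and store
// it in the by-reference second argument (assumed signature).
class ResidualTerm {
 public:
  virtual ~ResidualTerm() = default;
  virtual void operator()(double x, double& residual) = 0;
};

// Holds non-owning pointers to registered terms and sums them at a point.
class ResidualEvaluator {
 public:
  // Registers a term; the caller must keep it alive while it is in use.
  void AddResidualTerm(ResidualTerm* term) {
    assert(term != nullptr);
    terms_.push_back(term);
  }

  // Sum of all registered residual terms evaluated at x.
  double eval(double x) {
    double total = 0.0;
    for (ResidualTerm* term : terms_) {
      double residual = 0.0;
      (*term)(x, residual);  // virtual call: may land in Scala via a director
      total += residual;
    }
    return total;
  }

 private:
  std::vector<ResidualTerm*> terms_;
};
```

Storing raw, non-owning pointers keeps the interface minimal, but it makes the caller responsible for keeping each term alive while the evaluator may use it. That is worth remembering on the JVM side, where a term implemented in Scala could otherwise be garbage-collected while C++ still points at it.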
Let's wrap it using the following SWIG script (see comments herein for explanation):



Running the swig(1) command on it generates both a residuals_wrap.cxx C++ interface and the Java sources to call it. The former is compiled into a DLL along with the actual library sources; the latter are compiled with javac into class files, ready to be used by the Scala implementation that follows:


Here I also added a Test executable to exercise the whole thing. Note how naturally the wrapped classes behave: for example, the evaluator's AddResidualTerm is called from a foreach callback. The only really quirky item is the pointer-wrapping class "SWIGTYPE_p_double", used to wrap the by-reference output argument that receives the computed residual. This too could be finessed using a SWIG typemap, or even just a rename.

Running yields the expected output:

$ scala -cp classes org.somelightprojections.skeres.Test
Creating/adding residual terms
Evaluating at 10.0
Computed residuals[0]=7, total=7
Computed residuals[1]=5, total=12
Total residual = 12.0

That's it for today. All code is available in the sandbox section of the skeres GitHub repo.
