Monday, December 28, 2015

On calling Ceres from Scala - 2.0: Swiggin' it

This is the second installment of a series documenting my progress on porting the Ceres solver API to Scala. All the sources quoted below are available on GitHub.

Swiggin' callbacks


Using SWIG to build most (hopefully all) the JNI wrapping code used to call Ceres from the JVM is a choice that hardly needs justification. To wit, its advantages are:
  • Reasonably well-designed and documented API that tries hard to help the coder avoid common pitfalls.
  • Oodles of helper methods to address common use cases.
  • Minimal configuration required, which reduces the maintenance work of keeping the wrappers in sync with the wrapped library.
Disadvantages include the possible generation of underperforming code. However, I believe that, with due care, the wrapper interfaces can be kept minimal enough to be essentially optimal by design (i.e. near the limits imposed by the JNI/JVM interface and by the non-native user application code running in the JVM).

SWIG is well documented and fairly easy to use in practice. Our case, however, needs a couple of its "advanced" features, due to the need to cross the JNI boundary in both directions. This need arises because we want to write our models (and cost functions) in Scala, calling Ceres in a two-step workflow:
  • First, problem specification. We want to give Ceres some Scala objects - the residual terms - that know how to compute some cost terms (and perhaps their derivatives) for some putative values of the unknowns.
  • Second, we run Ceres's C++ optimization code. This in turn will call the Scala residual term computations iteratively, until some optimum set of parameters is (hopefully) found.
So we need to cross the JNI boundary in both directions:
  • In the first step we call Scala → C++ to configure the native optimizer so it can use the Scala residual terms.
  • Then we call Scala → C++ to start the optimizer.
  • Then the optimizer calls back C++ → Scala to compute the residuals.
  • The residual computations themselves may or may not call Scala → C++ to access some Ceres utilities.
Going from Scala to C++ with SWIG is very easy. Going the other way is easy modulo some wading through the docs. The SWIG facility for this use case is directors. Ultra-simplifying, directors make Java classes behave like subclasses of C++ classes. The subclass behavior includes, crucially, resolution of virtual methods.

In our case, we can declare an abstract C++ class that declares the virtual residual term interface as expected by Ceres. We then run SWIG to create an appropriate wrapper/director for it, and extend it in Scala to actually implement the computation.
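To make the director idea less abstract, here is a minimal pure-C++ sketch of the mechanism. It is not SWIG's generated code: the class name CallbackResidualTerm and the helpers EvaluateThroughBase and DirectorDemo are hypothetical, and a std::function stands in for the JNI upcall that a real director would make into the JVM. The point is only that native code holding a base-class reference ends up running foreign code through ordinary virtual dispatch.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Same shape as the abstract residual interface used in this post.
class ResidualTerm {
 public:
  virtual ~ResidualTerm() {}
  virtual bool operator()(double x, double* y) const = 0;
};

// Hypothetical stand-in for a director: a C++ subclass whose override
// forwards the call to "foreign" code. A real SWIG director would call
// into the JVM through JNI; here a std::function plays that role.
class CallbackResidualTerm : public ResidualTerm {
 public:
  using Callback = std::function<bool(double, double*)>;
  explicit CallbackResidualTerm(Callback cb) : cb_(std::move(cb)) {}
  bool operator()(double x, double* y) const override { return cb_(x, y); }

 private:
  Callback cb_;
};

// Native code only sees the base class; virtual dispatch does the rest.
double EvaluateThroughBase(const ResidualTerm& term, double x) {
  double y = 0.0;
  const bool ok = term(x, &y);
  assert(ok);
  (void)ok;  // silence the unused-variable warning when NDEBUG is set
  return y;
}

// The "foreign" side supplies the toy cost x - y0.
double DirectorDemo(double x, double y0) {
  CallbackResidualTerm term([y0](double xx, double* y) {
    *y = xx - y0;
    return true;
  });
  return EvaluateThroughBase(term, x);
}
```

Calling DirectorDemo(10.0, 3.0) from a main function returns 7, computed by the lambda even though EvaluateThroughBase knows nothing about it.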

In practice


To make it concrete, let's code a toy residual term and evaluator library, consisting of:

  • A ResidualTerm abstract base class that assumes residuals are evaluated by its call operator, with the result written through its second argument, an output pointer.
  • A ResidualEvaluator class that holds a collection of concrete ResidualTerms, with a method to register new ones, and an "eval" method to compute the sum of all residual terms at an input point x.
//
// residuals.h
//
#include <vector>
#include <cstdio>

class ResidualTerm {
 public:
  ResidualTerm();
  virtual ~ResidualTerm();
  // Evaluate the residual at point x, write the result into y
  // Returns true iff successful.
  virtual bool operator()(double x, double* y) const = 0;
};

class ResidualEvaluator {
 private:
  std::vector<const ResidualTerm*> residuals;
 public:
  // Register the given residual term (not owned).
  // Returns the number of registered terms.
  int AddResidualTerm(ResidualTerm const* c);
  // Compute the sum of all residual terms at x.
  double Eval(double x) const;
};

//
// residuals.cc
//
#include "residuals.h"

ResidualTerm::ResidualTerm() {}
ResidualTerm::~ResidualTerm() {}

int ResidualEvaluator::AddResidualTerm(ResidualTerm const* c) {
  residuals.push_back(c);
  return residuals.size();
}

double ResidualEvaluator::Eval(double x) const {
  double total = 0.0;
  double y;
  for (int i = 0; i < residuals.size(); ++i) {
    const ResidualTerm& cost = *(residuals[i]);
    cost(x, &y);
    total += y;
    fprintf(stderr, "Computed residuals[%d]=%g, total=%g\n", i, y, total);
  }
  return total;
}
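One detail worth noting in the toy Eval above: it discards the boolean returned by each term. That is fine for a demo, but a sturdier evaluator would probably want to propagate failures. The following is a minimal sketch of that variant, written as a free function so it stays self-contained. EvalChecked, OffsetResidual and FailingResidual are hypothetical names that are not part of the sandbox code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Same abstract interface as in residuals.h.
class ResidualTerm {
 public:
  virtual ~ResidualTerm() {}
  virtual bool operator()(double x, double* y) const = 0;
};

// Hypothetical concrete term: cost(x) = x - y0.
class OffsetResidual : public ResidualTerm {
 public:
  explicit OffsetResidual(double y0) : y0_(y0) {}
  bool operator()(double x, double* y) const override {
    *y = x - y0_;
    return true;
  }

 private:
  double y0_;
};

// Hypothetical term that always reports failure, to exercise the error path.
class FailingResidual : public ResidualTerm {
 public:
  bool operator()(double, double*) const override { return false; }
};

// Sums all terms at x. Returns false as soon as one term fails,
// in which case *total is left untouched.
bool EvalChecked(const std::vector<const ResidualTerm*>& terms, double x,
                 double* total) {
  assert(total != nullptr);
  double sum = 0.0;
  for (std::size_t i = 0; i < terms.size(); ++i) {
    double y = 0.0;
    if (!(*terms[i])(x, &y)) return false;
    sum += y;
  }
  *total = sum;
  return true;
}
```

With terms OffsetResidual(3.0) and OffsetResidual(5.0), EvalChecked at x = 10 succeeds and writes 12 into the total; swapping in a FailingResidual makes it return false.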
Let's wrap it using the following SWIG script (see comments herein for explanation):

/* Name the new SWIG module and enable creation of directors */
/* (they are disabled by default) */
%module(directors="1") residuals
/* Declarations to wrap. */
%{
#include "residuals.h"
%}
/* Specify director feature for the base class ResidualTerm */
/* The "assumeoverride=1" indicates that subclasses are expected */
/* to override its methods */
%feature("director", assumeoverride=1) ResidualTerm;
/* "operator()" of ResidualTerm does not swig into a valid Java/Scala */
/* method name, so we rename its wrap using the standard name for the */
/* Scala call operator: apply. */
%rename(apply) ResidualTerm::operator();
/* Use the SWIG carrays library to treat C++ pointers as arrays. */
%include "carrays.i"
%array_class(double, doubleArray);
%include "residuals.h"


Running the swig(1) command on it generates both a residuals_wrap.cxx C++ interface and the Java sources to call it. The former is compiled into a DLL along with the actual library sources; the latter are compiled with javac into class files, ready to be used by the Scala implementation that follows:

package org.somelightprojections.skeres

import residuals.{ResidualTerm, ResidualEvaluator, SWIGTYPE_p_double, doubleArray}

// Concrete residual term: cost(x) = x - y0
case class Residual(y0: Double) extends ResidualTerm {
  override def apply(x: Double, y: SWIGTYPE_p_double): Boolean = {
    val out = doubleArray.frompointer(y)
    out.setitem(0, x - y0)
    true
  }
}

object Test {
  // Load the swig-wrapped C++ DLL.
  System.loadLibrary("residuals");

  def main(args: Array[String]): Unit = {
    val re = new ResidualEvaluator
    println("Creating/adding residual terms")
    Vector(3.0, 5.0)
      .map(Residual)
      .foreach(re.AddResidualTerm)
    println("Evaluating at 10.0")
    val cost = re.Eval(10.0)
    println(s"Total residual = $cost")
  }
}

Here I added a Test executable to exercise the whole thing. Note how the wrapped classes behave quite naturally - e.g. the evaluator's AddResidualTerm is called in a foreach callback. The only really quirky item is the pointer-wrapping class "SWIGTYPE_p_double", which SWIG generates for the output-pointer argument of the residual's call operator. This too could be finessed using a SWIG typemap, or even just a rename.
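A different option from the typemap or rename mentioned above, and one I have not tried against SWIG, would be to keep the pointer handling entirely on the C++ side. The sketch below assumes a hypothetical adapter class, ValueResidualTerm, whose subclasses override a Value method that takes and returns plain doubles. The adapter implements the pointer-based call operator once. The assumption is that enabling the director feature on such a class would let the Scala side override Value without ever touching a wrapped pointer type.

```cpp
#include <cassert>

// Same abstract interface as in residuals.h.
class ResidualTerm {
 public:
  virtual ~ResidualTerm() {}
  virtual bool operator()(double x, double* y) const = 0;
};

// Hypothetical adapter: subclasses return the residual by value, and the
// pointer-based call operator is implemented once, here, in C++.
class ValueResidualTerm : public ResidualTerm {
 public:
  virtual double Value(double x) const = 0;
  bool operator()(double x, double* y) const override {
    if (y == nullptr) return false;  // nothing to write into
    *y = Value(x);
    return true;
  }
};

// Hypothetical concrete term: cost(x) = x - y0.
class OffsetValueResidual : public ValueResidualTerm {
 public:
  explicit OffsetValueResidual(double y0) : y0_(y0) {}
  double Value(double x) const override { return x - y0_; }

 private:
  double y0_;
};

// Code written against the original pointer-based interface keeps working.
double ResidualViaPointerInterface(const ResidualTerm& term, double x) {
  double y = 0.0;
  const bool ok = term(x, &y);
  assert(ok);
  (void)ok;  // silence the unused-variable warning when NDEBUG is set
  return y;
}
```

The trade-off is one extra virtual call per evaluation on the native side, and the design only fits scalar residuals as in this toy example.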

Running yields the expected output:

$ scala -cp classes org.somelightprojections.skeres.Test
Creating/adding residual terms
Evaluating at 10.0
Computed residuals[0]=7, total=7
Computed residuals[1]=5, total=12
Total residual = 12.0

That's it for today. All code is available in the sandbox section of the skeres GitHub repo.
