Swiggin' callbacks
Using SWIG to build most (hopefully all) of the JNI wrapping code used to call Ceres from the JVM is a choice that hardly needs justification. Its advantages are:
- A reasonably well-thought-out and documented API that tries hard to help the coder avoid common pitfalls.
- Oodles of helper methods to address common use cases.
- Minimal configuration, which reduces the maintenance work of keeping wrappers in sync with the wrapped library.
Disadvantages include the possible generation of underperforming code. However, I believe that, with due care, the wrapped interfaces can be kept minimal enough to be essentially optimal by design (i.e. close to the limits imposed by the JNI/JVM interface and by the non-native user application code running in the JVM).
- First, problem specification. We want to give Ceres some Scala objects - the residual terms - that know how to compute some cost terms (and perhaps their derivatives) for some putative values of the unknowns.
- Second, we run Ceres's C++ optimization code. This in turn will call the Scala residual term computations iteratively, until some optimum set of parameters is (hopefully) found.
- In the first step we call Scala → C++ to configure the native optimizer so it can use the Scala residual terms.
- Then we call Scala → C++ to start the optimizer.
- Then the optimizer calls back C++ → Scala to compute the residuals.
- The residual computations themselves may or may not call Scala → C++ to access some Ceres utilities.
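The two phases above can be sketched in plain C++, with no JVM involved. This is only an illustration of the control flow: ToyOptimizer, AddResidual, Cost and Minimize are hypothetical names, the std::function stands in for a residual term living on the Scala side, and the fixed-step gradient descent is a placeholder that has nothing to do with the algorithms Ceres actually uses.

```cpp
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-in for a residual term implemented on the JVM side.
using Residual = std::function<double(double)>;

// Toy stand-in for the native optimizer (NOT Ceres, and not its algorithm).
class ToyOptimizer {
 public:
  // Phase 1, problem specification: the caller hands over its callbacks.
  void AddResidual(Residual r) { residuals_.push_back(std::move(r)); }

  // Sum of squared residuals at x. Every evaluation calls back into the terms.
  double Cost(double x) const {
    double total = 0.0;
    for (const Residual& r : residuals_) {
      const double v = r(x);
      total += v * v;
    }
    return total;
  }

  // Phase 2, optimization: fixed-step gradient descent using a
  // central-difference derivative. The callbacks run on every iteration.
  double Minimize(double x, int iterations, double step) const {
    const double h = 1e-6;
    for (int i = 0; i < iterations; ++i) {
      const double gradient = (Cost(x + h) - Cost(x - h)) / (2.0 * h);
      x -= step * gradient;
    }
    return x;
  }

 private:
  std::vector<Residual> residuals_;
};
```

Registering the two terms x - 3 and x - 5 and calling Minimize(10.0, 200, 0.1) should drive x towards 4, the minimizer of (x - 3)^2 + (x - 5)^2. The point is only that the optimizer, not the caller, decides when and how often the callbacks run.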
In our case, we can declare an abstract C++ class that implements the residual term virtual interface as expected by Ceres. We then run SWIG to create an appropriate wrapper/director for it, and extend it in Scala to actually implement the computation.
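Conceptually, a director is a C++ subclass whose virtual methods forward to the other language. The sketch below mimics that idea by hand, assuming a one-dimensional residual interface like the toy one used later in this post. CallbackResidualTerm is a hypothetical name, and the std::function merely stands in for the JNI upcall; the class SWIG really generates is more involved.

```cpp
#include <functional>
#include <utility>

// Abstract residual interface, as seen by the native optimizer.
class ResidualTerm {
 public:
  virtual ~ResidualTerm() = default;
  // Evaluate the residual at x and write it into *y. Returns true on success.
  virtual bool operator()(double x, double* y) const = 0;
};

// Hypothetical hand-written stand-in for a SWIG director: it overrides the
// virtual method and forwards the call to code supplied from "outside".
class CallbackResidualTerm : public ResidualTerm {
 public:
  using Callback = std::function<bool(double, double*)>;

  explicit CallbackResidualTerm(Callback cb) : cb_(std::move(cb)) {}

  bool operator()(double x, double* y) const override {
    // An empty callback is reported as a failed evaluation.
    return cb_ ? cb_(x, y) : false;
  }

 private:
  Callback cb_;
};
```

The native side only ever sees a ResidualTerm reference, so it needs no knowledge of where the computation actually happens. That is the property the SWIG-generated director provides for a Scala subclass.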
In practice
To make it concrete, let's code a toy residual term and evaluator library, consisting of:
- A ResidualTerm abstract base class that assumes residuals are evaluated by its call operator, with the result returned into its by-reference second argument.
- A ResidualEvaluator class that holds a collection of concrete ResidualTerms, with a method to register new ones, and an "eval" method to compute the sum of all residual terms at an input point x.
//
// residuals.h
//
#include <vector>
#include <cstdio>

class ResidualTerm {
 public:
  ResidualTerm();
  virtual ~ResidualTerm();
  // Evaluate the residual at point x, write the result into y
  // Returns true iff successful.
  virtual bool operator()(double x, double* y) const = 0;
};

class ResidualEvaluator {
 private:
  std::vector<const ResidualTerm*> residuals;
 public:
  // Register the given residual term (not owned).
  // Returns the number of registered terms.
  int AddResidualTerm(ResidualTerm const* c);
  // Compute the sum of all residual terms at x.
  double Eval(double x) const;
};
//
// residuals.cc
//
#include "residuals.h"

ResidualTerm::ResidualTerm() {}
ResidualTerm::~ResidualTerm() {}

int ResidualEvaluator::AddResidualTerm(ResidualTerm const* c) {
  residuals.push_back(c);
  // size() is unsigned; convert explicitly to the declared return type.
  return static_cast<int>(residuals.size());
}

double ResidualEvaluator::Eval(double x) const {
  double total = 0.0;
  double y = 0.0;
  for (std::size_t i = 0; i < residuals.size(); ++i) {
    const ResidualTerm& cost = *(residuals[i]);
    // NOTE: the success flag returned by cost() is ignored in this toy example.
    cost(x, &y);
    total += y;
    fprintf(stderr, "Computed residuals[%zu]=%g, total=%g\n", i, y, total);
  }
  return total;
}
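The Eval method above ignores the boolean returned by each residual term. As a possible variation, here is a minimal sketch of an evaluation loop that propagates failure instead. EvalChecked, OffsetResidual and FailingResidual are hypothetical names introduced only for this illustration, and the ResidualTerm interface is repeated so that the snippet compiles on its own.

```cpp
#include <vector>

// Same interface as in residuals.h, repeated to keep the sketch self-contained.
class ResidualTerm {
 public:
  virtual ~ResidualTerm() {}
  virtual bool operator()(double x, double* y) const = 0;
};

// Hypothetical concrete term: y = x - y0.
class OffsetResidual : public ResidualTerm {
 public:
  explicit OffsetResidual(double y0) : y0_(y0) {}
  bool operator()(double x, double* y) const override {
    *y = x - y0_;
    return true;
  }

 private:
  double y0_;
};

// Hypothetical term that always reports failure.
class FailingResidual : public ResidualTerm {
 public:
  bool operator()(double, double*) const override { return false; }
};

// Hypothetical helper (not part of the toy library): sums the residuals at x.
// Returns false, leaving *total untouched, as soon as any term fails.
bool EvalChecked(const std::vector<const ResidualTerm*>& terms, double x,
                 double* total) {
  double sum = 0.0;
  for (const ResidualTerm* term : terms) {
    double y = 0.0;
    if (term == nullptr || !(*term)(x, &y)) return false;
    sum += y;
  }
  *total = sum;
  return true;
}
```

With terms for y0 = 3 and y0 = 5 evaluated at x = 10, the helper writes 12 into total. Replacing either term with a FailingResidual makes it return false without modifying total.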
Let's wrap the library using the following SWIG script (see its comments for explanation):
/* Name the new SWIG module and enable creation of directors */
/* (they are disabled by default) */
%module(directors="1") residuals

/* Declarations to wrap. */
%{
#include "residuals.h"
%}

/* Specify director feature for the base class ResidualTerm */
/* The "assumeoverride=1" indicates that subclasses are expected */
/* to override its methods */
%feature("director", assumeoverride=1) ResidualTerm;

/* "operator()" of ResidualTerm does not swig into a valid Java/Scala */
/* method name, so we rename its wrap using the standard name for the */
/* Scala call operator: apply. */
%rename(apply) ResidualTerm::operator();

/* Use the SWIG carrays library to treat C++ pointers as arrays. */
%include "carrays.i"
%array_class(double, doubleArray);

%include "residuals.h"
Running the swig(1) command on it generates both a residuals_wrap.cxx C++ interface file and the Java sources that call it. The former is compiled into a DLL along with the library sources; the latter are compiled with javac into class files, ready to be used by the Scala implementation that follows:
package org.somelightprojections.skeres

import residuals.{ResidualTerm, ResidualEvaluator, SWIGTYPE_p_double, doubleArray}

// Concrete residual term: cost(x) = x - y0
case class Residual(y0: Double) extends ResidualTerm {
  override def apply(x: Double, y: SWIGTYPE_p_double): Boolean = {
    val out = doubleArray.frompointer(y)
    out.setitem(0, x - y0)
    true
  }
}

object Test {
  // Load the swig-wrapped C++ DLL.
  System.loadLibrary("residuals")

  def main(args: Array[String]): Unit = {
    val re = new ResidualEvaluator
    println("Creating/adding residual terms")
    Vector(3.0, 5.0)
      .map(Residual)
      .foreach(re.AddResidualTerm)
    println("Evaluating at 10.0")
    val cost = re.Eval(10.0)
    println(s"Total residual = $cost")
  }
}
Here I added a Test executable to exercise the whole thing. Note how the wrapped classes behave quite naturally - e.g. the evaluator's AddResidualTerm is called in a foreach callback. The only really quirky item is the use of the pointer-wrapping class "SWIGTYPE_p_double" to wrap the by-reference output argument of the cost Residual. This too could be finessed using a SWIG typemap, or even just a rename.
Running yields the expected output:
$ scala -cp classes org.somelightprojections.skeres.Test
Creating/adding residual terms
Evaluating at 10.0
Computed residuals[0]=7, total=7
Computed residuals[1]=5, total=12
Total residual = 12.0
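Besides a typemap or a rename, the SWIGTYPE_p_double quirk mentioned above could in principle also be sidestepped on the C++ side, by giving subclasses a method that uses only plain doubles. The sketch below is an assumption-laden illustration and has not been run through SWIG: ScalarResidualTerm, Value and OffsetTerm are hypothetical names, and treating a non-finite value as a failure is an arbitrary convention chosen for the example.

```cpp
#include <cmath>

// Same interface as in residuals.h, repeated to keep the sketch self-contained.
class ResidualTerm {
 public:
  virtual ~ResidualTerm() {}
  virtual bool operator()(double x, double* y) const = 0;
};

// Hypothetical adapter: subclasses override Value(), which takes and returns
// plain doubles, so no pointer wrapper would need to cross the language
// boundary if this class (rather than ResidualTerm) were given a director.
class ScalarResidualTerm : public ResidualTerm {
 public:
  virtual double Value(double x) const = 0;

  // Bridges to the pointer-based interface the evaluator expects.
  // A null output pointer or a non-finite value counts as a failure.
  bool operator()(double x, double* y) const override {
    if (y == nullptr) return false;
    *y = Value(x);
    return std::isfinite(*y);
  }
};

// Hypothetical concrete term, mirroring the Scala Residual: x - y0.
class OffsetTerm : public ScalarResidualTerm {
 public:
  explicit OffsetTerm(double y0) : y0_(y0) {}
  double Value(double x) const override { return x - y0_; }

 private:
  double y0_;
};
```

The evaluator would still hold ResidualTerm pointers and call operator() as before; only the method that subclasses override changes.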
That's it for today. All code is available in the sandbox section of the skeres GitHub repo.