Slide 49
Slide 49 text
Creating Python bindings for C++ with pybind11
#include <array>
#include <functional>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

#include <vexcl/vexcl.hpp>
#include <boost/numeric/odeint.hpp>
#include <boost/numeric/odeint/external/vexcl/vexcl.hpp>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
namespace odeint = boost::numeric::odeint;
namespace py = pybind11;

// ODE state: three device vectors, one per Lorenz coordinate (x, y, z).
// Each component holds that coordinate for every system of the ensemble;
// the component count 3 matches the x(0)..x(2) accesses in lorenz_system.
typedef vex::multivector<double, 3> state_type;

// Classic fixed-step 4th-order Runge-Kutta. vector_space_algebra makes odeint
// combine states with whole-object arithmetic expressions instead of iterating
// over elements, which is what lets it drive the VexCL multivector directly.
typedef odeint::runge_kutta4_classic<
    state_type, double, state_type, double,
    odeint::vector_space_algebra, odeint::default_operations
    > Stepper;
/// Returns the VexCL context shared by the whole module, building it on first use.
///
/// Devices are chosen by combining the environment-driven filter
/// (vex::Filter::Env) with a device-name filter built from `name`.
/// The context is a function-local static, so only the `name` given on the
/// very first call has any effect; later calls return the existing context
/// and ignore their argument.
vex::Context& ctx(std::string name = "") {
    static vex::Context shared_context(
            vex::Filter::Env && vex::Filter::Name(name)
            );
    return shared_context;
}
struct lorenz_system {
int n;
double sigma, b;
vex::vector R;
state_type X;
lorenz_system(double sigma, double b, py::array_t R)
: n(R.size()), sigma(sigma), b(b), R(ctx(), n, R.data()), X(ctx(), n) {}
void operator()(const state_type &x, state_type &dxdt, double t) const {
dxdt = std::make_tuple(
sigma * (x(1) - x(0)),
R * x(0) - x(1) - x(0) * x(2),
x(0) * x(1) - b * x(2));
}
py::array_t advance(py::array_t x_in, int steps, double dt) {
for(int i=0, b=0, e=n; i<3; ++i, b+=n, e+=n)
vex::copy(x_in.data()+b, x_in.data()+e, X(i).begin());
odeint::integrate_n_steps(Stepper(), std::ref(*this), X, 0.0, dt, steps);
py::array_t x_out(std::array{x_in.shape(0), x_in.shape(1)});
for(int i=0, b=0; i<3; ++i, b+=n)
vex::copy(X(i).begin(), X(i).end(), x_out.mutable_data()+b);
return x_out;
}
};
/// Python module "pylorenz":
///   context(name="")                   -- print the VexCL context; the device
///                                         name filter applies on first use only.
///   Stepper(sigma, b, R).advance(x, steps, dt)
///                                      -- wraps lorenz_system.
///
/// Uses the legacy PYBIND11_PLUGIN entry point (deprecated since pybind11 2.2
/// in favour of PYBIND11_MODULE) so that older pybind11 releases still build.
PYBIND11_PLUGIN(pylorenz) {
    py::module m("pylorenz",
                 "Ensembles of Lorenz systems integrated with VexCL and Boost.odeint");

    m.def("context", [](std::string name) {
        std::ostringstream s;
        s << ctx(name);
        py::print(s.str());
    }, py::arg("name") = std::string(""),
       "Print the VexCL context. `name` filters devices only on the first call.");

    py::class_<lorenz_system>(m, "Stepper")
        .def(py::init<double, double, py::array_t<double>>(),
             py::arg("sigma"), py::arg("b"), py::arg("R"))
        .def("advance", &lorenz_system::advance,
             py::arg("x"), py::arg("steps"), py::arg("dt"),
             "Advance the (3, n) state x by `steps` RK4 steps of size dt.")
        ;

    return m.ptr();
}
3. System function
24 struct lorenz_system {
25 int n;
26 double sigma, b;
27 vex::vector<double> R;
28 state_type X;
29
30 lorenz_system(double sigma, double b, py::array_t<double> R)
31 : n(R.size()), sigma(sigma), b(b), R(ctx(), n, R.data()), X(ctx(), n) {}
32
33 void operator()(const state_type &x, state_type &dxdt, double t) const {
34 dxdt = std::make_tuple(
35 sigma * (x(1) - x(0)),
36 R * x(0) - x(1) - x(0) * x(2),
37 x(0) * x(1) - b * x(2));
38 }
39
40 py::array_t<double> advance(py::array_t<double> x_in, int steps, double dt) {
41 for(int i=0, b=0, e=n; i<3; ++i, b+=n, e+=n)
42 vex::copy(x_in.data()+b, x_in.data()+e, X(i).begin());
43
44 odeint::integrate_n_steps(Stepper(), std::ref(*this), X, 0.0, dt, steps);
45
46 py::array_t<double> x_out(std::array<ssize_t, 2>{x_in.shape(0), x_in.shape(1)});
47 for(int i=0, b=0; i<3; ++i, b+=n)
48 vex::copy(X(i).begin(), X(i).end(), x_out.mutable_data()+b);
49
50 return x_out;
51 }
52 };