// Source: https://github.com/dfm/AstroFlow/blob/master/astroflow/ops/kepler_op.cc
template
class KeplerOp : public OpKernel {
public:
explicit KeplerOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("maxiter", &maxiter_));
OP_REQUIRES(context, maxiter_ >= 0,
errors::InvalidArgument("Need maxiter >= 0, got ", maxiter_));
OP_REQUIRES_OK(context, context->GetAttr("tol", &tol_));
}
void Compute(OpKernelContext* context) override {
const Tensor& M_tensor = context->input(0);
const Tensor& e_tensor = context->input(1);
const int64 N = M_tensor.NumElements();
Tensor* E_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, M_tensor.shape(), &E_tensor));
const auto M = M_tensor.template flat();
const auto e = e_tensor.template scalar()(0);
auto E = E_tensor->template flat();
for (int64 n = 0; n < N; ++n) {
E(n) = kepler(M(n), e, maxiter_, tol_);
}
}
private:
int maxiter_;
float tol_;
};