Slide 22
Slide 22 text
演習用グラフコンパイラの説明 : C++ コード生成編、生成物
22
// 入力とパラメータを読み込む
const Matrix input = load<256, 784>(input_ptr);
const std::array target = load<256, int>(target_ptr);
const Matrix fc1_weight = load<16, 784>(fc1_weight_ptr);
const Vector fc1_bias = load<16>(fc1_bias_ptr);
const Matrix fc3_weight = load<10, 16>(fc3_weight_ptr);
const Vector fc3_bias = load<10>(fc3_bias_ptr);
// 計算
const Matrix fc1_pre_matmul = matmul<256, 784, 16>(input, trans<16, 784>(fc1_weight));
const Matrix fc1_pre = add_colvec<256, 16>(fc1_pre_matmul, fc1_bias);
const Matrix fc1 = relu<256, 16>(fc1_pre);
const Matrix fc3_pre_matmul = matmul<256, 16, 10>(fc1, trans<10, 16>(fc3_weight));
const Matrix fc3_pre = add_colvec<256, 10>(fc3_pre_matmul, fc3_bias);
const Matrix probs = softmax<256, 10>(fc3_pre);
以下のような C++ コードが生成される(load, matmul などの定義はテンプレートに)