Upgrade to Pro — share decks privately, control downloads, hide ads and more …

TensorFlowで作ったAIをAndroidアプリで実行する

 TensorFlowで作ったAIをAndroidアプリで実行する

TensorFlowで機械学習をしたグラフをAndroid端末上でで実行する方法についてご紹介します。

このスライドは 2017/03/23 「TensorFlow Android ワークショップ」での発表に使用したスライドです。
https://tfug-tokyo.connpass.com/event/52648/

Arata Furukawa

March 23, 2017
Tweet

More Decks by Arata Furukawa

Other Decks in Technology

Transcript

  1. TensorFlowで作ったAIを
    Androidアプリで実行する
    2017-03-23 APPDOJO
    TensorFlow User Group / 東海大学 理学部1年生
    古川新

    View full-size slide

  2. 今日の流れ
    ・モバイルとAIについて
    ・サンプルアプリ配布
    ・プログラム解説

    View full-size slide

  3. △モバイルアプリでAIを利用する
    ◎モバイルアプリでAIを実行する

    View full-size slide

  4. モバイルアプリでAIを利用する
    モバイル上で
    実行する
    サーバで実行して
    モバイルで利用する

    View full-size slide

  5. でも、(計算コスト)お高いんでしょう?
    確かに、やれディープラーニングだとか、やれ人工知能だとか、
    よくわからないけどGPUが何百枚だ何だと話をしているのを聞きます。
    本当にモバイルで動くんでしょうか?
    まあ、動くんですが…

    View full-size slide

  6. 人工知能は
    筋肉と同じ

    View full-size slide

  7. 筋肉を鍛えるには、たくさんの時間とお金が必要です
    何回も何時間もトレーニングして、
    少しずつ筋肉は強くなります

    View full-size slide

  8. 一方で、筋肉を動かすこと自体は簡単
    時間もお金もかかりませんよね?

    View full-size slide

  9. 機械学習における学習と実行の非対称性
    機械学習によって「AIを作る」場合の演算コストと
    「作ったAIを動かす」場合の演算コストは別の話。
    実行するだけなら
    モバイル端末でも十分。

    View full-size slide

  10. TensorFlowの
    公式サンプルでは
    ・左
    リアルタイムに人物を認識し、トラッキン
    グするモデルのサンプル。1秒未満のラ
    グで複数の人物の動きを追跡できる。
    ・右
    Google Inceptionモデルによるリアルタイ
    ム物体識別。写真ではカメラに写った
    iPhoneを「iPod」か「携帯電話」であると
    判断している。予測にかかる時間は
    100-300msと、体感では一瞬で識別でき
    ている。
    ※写真はNexus5での実行結果です。
     端末や使用状況によって実行速度は変化します。

    View full-size slide

  11. ● 電池の減りが早くなる(と思われる)
    - 明らかに計算量が増えるので電池の減りが早くなることが予想されます。
    - どのくらいかは、モデル次第で大きく変わってくるので、定量的に示せません …。
    開発ごとに実験して、チューニングが必要です。
    ● 通常に比べるとアプリのデータサイズが増える
    - モデルデータは10MB〜100MBくらいは必要。チューニングに左右される点でもある。
    - 精度を保ったままモデルを縮小するといった先行研究はあります。
    ● チューニング不可欠・専門の技術者が必要
    - モバイル端末のスペックに限界がある以上、サーバでの計算以上に高度なチューニングが要
    求される。モデルの精度やサイズをコントロールできるだけの技術と知識を持った人間が必
    要。
    モバイル上で実行するデメリット

    View full-size slide

  12. ● インターネット接続が不要である
    - 回線状況がスループットに影響しない、オフラインでも実行可能
    ● サーバ不要
    - サーバを運用するコストが不要、または軽減できる
    - 実行速度は訓練より速いと言っても、ユーザからのリクエストが集中すればかなりの演
    算コストになるため、 AI用のサーバ運用はいろいろ大変
    ● プライバシーの尊重
    - ユーザのプライバシーに関わる情報をインターネット上に流出させなくても利用できる
    - サーバ型の場合は、ユーザの情報がインターネットを通過しないと利用できない
    - モバイル端末の扱うデータは個人のプライバシーに関わることが多い
    (顔写真、位置情報など)
    - 「万が一流出」というリスクを完全に除去することが出来る
    モバイル上で実行するメリット

    View full-size slide

  13. TensorFlowを
    Androidで動かす

    View full-size slide

  14. その前に…TensorFlowの仕組み
    TensorFlowにおいて、「あらゆる演算」は、グラフという抽象化されたデータ構造で表現されます。
    グラフはProtocol Buffers形式でシリアライズ可能な、 TensorFlowに依存しない*独立したデータです。
    TensorFlow グラフ
    グラフを作る
    グラフを実行する
    * TensorFlowには依存しないのでTensorFlowを使わなくても読み取ったり出来ますが、TensorFlowを使わずに読み取る意味はあまりないかもしれません。
    足し算でも行列計算でも
    計算機上で実行可能な
    あらゆる演算
    シリアライズできるので
    デバイス間を
    自由に持ち運び可能
    チェックポイ
    ント

    View full-size slide

  15. シリアライズ可能ということは、こういうこともできる
    TensorFlow グラフ
    サーバー:グラフを作る
    作成
    TensorFlow グラフ
    クライアント:グラフを実行する
    実行
    何らかの方法で
    ダウンロード
    チェックポイ
    ント
    チェックポイ
    ント

    View full-size slide

  16. 実行に必要なファイルは2つ
    グラフ
    チェック
    ポイント
    定数と、演算が保存されています。
    例)3 や 8.5 などの定数
    例)3a + b のような演算の情報
    変数の値が保存されています。
    例)a = 1.2 のような変数の値の情報

    View full-size slide

  17. 実行に必要なファイルは2つ1つ
    グラフ
    チェック
    ポイント
    凍結グラフ
    凍結
    本来、これらは機械学習のために分離されていますが、
    実行にあたっては面倒くさいだけなので、凍結処理
    (変数を定数に変換する処理)をすることで合体させます。
    ※もちろん、分離したままでも実行可能ですが、 2つファイルが必要になりますし、実行に必要な
    いメタ情報が多く含まれている分チェックポイントはファイルサイズが大きいので、実行だけであ
    れば凍結するのがおすすめです。

    View full-size slide

  18. Androidでグラフを実行するまでの流れ
    TensorFlow グラフ
    サーバー:グラフを作る
    作成
    TensorFlow 凍結グラフ
    Android:グラフを実行する
    実行
    ダウンロード
    チェックポイ
    ント
    凍結グラフ
    凍結

    View full-size slide

  19. 今日配布する
    サンプルアプリ
    手書きした数字を予測するAIをAndroid
    上で実行するアプリです。

    View full-size slide

  20. 手書き数字の認識モデル
    機械学習のハローワールドとも呼ばれているテー
    マで、手書きした数字を認識するモデルです。
    入力として784次元(28x28ピクセル)のピクセル強
    度の情報を受け取り、予測した数字の確率強度を
    出力として吐き出します。
    TensorFlowにおいて、機械学習のモデルもまた
    グラフで表現されます。
    ※通常、正規化した分布を出力する事が多いですが、本アプリでは出力された強度をそのまま出力しています。ご注意
    ください。
    28ピクセル
    28ピクセル
    モデル

    View full-size slide

  21. モデルの作成・訓練
    こちらにソースコードを用意しました。
    https://github.com/ornew/mnist-android/blob/master/model/mnist.py
    ダウンロードして頂き、 TensorFlowがインストールされた python環境で実行していただくと、モデルが生成さ
    れ出力されます。
    モデルの構造はTensorFlowチュートリアルと同様になっています。
    (チュートリアル翻訳: http://ornew.net/tensorflow-mnist-for-ml-beginners)
    皆さんのノートパソコンに実行させるのはあまりに酷なので、
    今日は出来合いのものをご用意しております。
    https://github.com/ornew/mnist-android/releases/download/v1.0.1/mnist.frozen.pb

    View full-size slide

  22. https://goo.gl/tgIOK9
    今日のサンプルプログラム
    Android Studioでビルドできるようにしてあります
    少し時間を取りますので
    ぜひお手元でお試しください。

    View full-size slide

  23. プログラム解説

    View full-size slide

  24. Androidで
    実行するまでの
    全体の流れ
    1. TensorFlowで機械学習をする
    2. 機械学習で出来た人工知能モデルを Protocol
    Buffers 形式でシリアライズする
    3. シリアライズされたモデルファイルをAndroid上
    にダウンロードする
    4. Android用にビルドしたTensorFlowコア
    ライブラリでモデルを実行する
    ※Android JavaアプリからJNIを介して実行します




    View full-size slide


  25. 機械学習する
    今回の手書き数字の認識モデルは、
    TensorFlow公式チュートリアルに
    ソースコードと解説があります。
    TensorFlowのpython APIでモデルを
    構築します。
    ソースコード全体:
    https://goo.gl/uq62wo
    チュートリアル翻訳:
    https://goo.gl/Vi9nox
    #!/usr/bin/env python
    import os
    import time
    import argparse
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    from tensorflow.python.tools.freeze_graph import freeze_graph
    def build_inference(x, keep_prob=None):
    def weight(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, name='weight')
    def bias(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial, name='bias')
    def convolution(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME', name='convolutional')
    def pooling(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pooling')
    x_image = tf.reshape(x, [-1,28,28,1])
    with tf.name_scope('hidden_1'):
    W1 = weight([5, 5, 1, 32])
    b1 = bias([32])
    C1 = tf.nn.relu(convolution(x_image, W1) + b1)
    h1 = pooling(C1)
    with tf.name_scope('hidden_2'):
    W2 = weight([5, 5, 32, 64])
    b2 = bias([64])
    C2 = tf.nn.relu(convolution(h1, W2) + b2)
    h2 = pooling(C2)
    with tf.name_scope('fully_connect'):

    View full-size slide


  26. シリアライズ
    作った人工知能モデルをファイルに
    シリアライズします。
    グラフを書き出したら、今回は簡単
    のために凍結処理を行います。
    本来、TensorFlowのモデルデータは、
    グラフとチェックポイントの2つの
    ファイルから構成されます。
    凍結は、両者を合体してグラフに
    固めてしまう処理です。
    ソースコード全体:
    https://goo.gl/uq62wo
    learning_rate: 1e-4})
    writer.add_summary(summary, global_step=step)
    elapsed_time = time.time() - start
    print('done!')
    print('Total time: %f [sec]' % elapsed_time)
    ckpt = saver.save(session, os.path.join(FLAGS.ckpt_dir, 'ckpt'), global_step=step)
    print('Save checkpoint: %s' % ckpt)
    # Write graph.
    tf.train.write_graph(session.graph.as_graph_def(), FLAGS.model_dir, FLAGS.model_name + '.pb', as_text=False)
    tf.train.write_graph(session.graph.as_graph_def(), FLAGS.model_dir, FLAGS.model_name + '.pb.txt', as_text=True)
    # End default graph scope.
    # Freeze graph.
    freeze_graph(
    input_graph=os.path.join(FLAGS.model_dir, FLAGS.model_name + '.pb.txt'),
    input_saver='',
    input_binary=False,
    input_checkpoint=ckpt,
    output_node_names='readout/y',
    restore_op_name='save/restore_all',
    filename_tensor_name='save/Const:0',
    output_graph=os.path.join(FLAGS.model_dir, '%s.frozen.pb' % FLAGS.model_name),
    clear_devices=False,
    initializer_nodes='')

    View full-size slide


  27. ダウンロード
    適当にアップロードしたモデルをAndroid
    アプリからダウンロードします。もちろん、
    アセットなどに組み込んでも構いません。
    今回のサンプルでは、HTTP通信でダウ
    ンロードしたモデルを、ローカルストレー
    ジに通常のファイルと同様に保存します。
    public class Download extends AsyncTask {
    private Context context;
    private String pathTo;
    Download(Context context, String pathTo) {
    this.context = context;
    this.pathTo = pathTo;
    }
    @Override
    protected Boolean doInBackground(URL... urls) {
    final byte[] buffer = new byte[4096];
    HttpURLConnection connection = null;
    InputStream input = null;
    OutputStream output = null;
    try {
    connection = (HttpURLConnection) urls[0].openConnection();
    connection.connect();
    int length = connection.getContentLength();
    input = connection.getInputStream();
    output = this.context.openFileOutput(pathTo, Context.MODE_PRIVATE);
    int totalBytes = 0;
    int bytes = 0;
    while ((bytes = input.read(buffer)) != -1) {
    output.write(buffer, 0, bytes);
    totalBytes += bytes;
    publishProgress((int)(totalBytes * 100.f / length));
    }
    } catch (IOException e) {
    e.printStackTrace();
    } finally {
    try {
    if(input != null){
    input.close();
    }

    View full-size slide


  28. 実行
    「識別」ボタンが押されたタイミングでモデ
    ルに入力を与えて実行します。
    今回の場合は、Viewのピクセル情報を配
    列で渡しています。
    @Override
    protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);
    canvas = (CanvasView) findViewById(R.id.canvas);
    recognize = (Button) findViewById(R.id.recognize);
    result = (TextView) findViewById(R.id.result);
    recognize.setOnClickListener(new View.OnClickListener() {
    @Override
    public void onClick(View v) {
    int pixels[] = canvas.getPixels();
    float data[] = new float[pixels.length];
    for(int i = 0; i < pixels.length; ++i){
    data[i] = Color.alpha(pixels[i]) / 255.f;
    }
    float answers[] = MNIST.inference(data);
    result.setText("");
    int bestIndex = 0;
    float bestScore = answers[0];
    for(int i = 1; i < answers.length; ++i) {
    if(bestScore < answers[i]){
    bestScore = answers[i];
    bestIndex = i;
    }
    }
    print("予測: " + bestIndex);
    for(int i = 0; i < answers.length; ++i) {
    print(i + ": " + answers[i]);
    }
    }
    });

    View full-size slide

  29. MNISTクラス
    MNISTクラスは、JNIによって実装されて
    います。
    JNIとはJava Native Interfaceのことで、
    ネイティブコードをJavaから実行するため
    の仕組みです。
    このクラスのメソッドが実行された場合、
    実際には読み込んでいる共有ライブラリ
    が実行されます。
    この共有ライブラリ内部でTensorFlow
    C++ APIが実行されています。
    public class MNIST {
    static {
    System.loadLibrary("mnist");
    }
    public static native void initialize(String model_path);
    public static native float[] inference(float x[]);
    }
    モデルのファイルパスを引数に
    受け取りTensorFlowを初期化する
    メソッド
    ピクセル情報を配列で受け取り
    予測した数字の確率分布を返す
    メソッド
    共有ライブラリ「libmnist.so」を読み込む
    (Android Studioの場合、AndroidManifest.xmlのあるディレクトリに
    jniLibs/armeabi-v7aというディレクトリを作成し、その中に共有ライブラリを入れるこ
    とで読み込めます)

    View full-size slide

  30. C++? Javaで出来ないの?
    Javaは正式サポートされていません。
    C++ APIをJNIで呼び出す試験的なJava APIが有志の手によって提供されて
    いますが、基本的にやっていることは同じです。
    まだまだ不安定で、機能が少ないので使いづらいですが、いずれ安定してJavaから
    簡単に使えるようになると思います。

    View full-size slide

  31. C++コード
    ソースコード全体:
    https://goo.gl/2mu7bR
    共有ライブラリのビルド方法:
    (リンク先はv0.11になっていますがv1.0でも同じ方法です)
    https://goo.gl/N4VLAC
    // Tensors
    Tensor x(DT_FLOAT, TensorShape({1, mnist_input_size}));
    Tensor keep_prob(DT_FLOAT, TensorShape());
    // Initialize Tensors
    auto _x = x.flat();
    for(int i = 0; i < mnist_input_size; ++i){
    _x(i) = inputs[i];
    }
    keep_prob.scalar()() = 1.0f;
    // Input and Output
    vector> feed_dict({
    {"x", x},
    {"keep_prob", keep_prob}
    });
    vector outputs;
    // Run session
    status = global_session->Run(feed_dict, {"readout/y:0"}, {}, &outputs);
    if(!status.ok()){
    Log::i(status.error_message().c_str());
    env->FatalError("Failed to execute runnning to inference operation.");
    }

    View full-size slide