Upgrade to Pro — share decks privately, control downloads, hide ads and more …

TensorFlowで作ったAIをAndroidアプリで実行する

 TensorFlowで作ったAIをAndroidアプリで実行する

TensorFlowで機械学習をしたグラフをAndroid端末上でで実行する方法についてご紹介します。

このスライドは 2017/03/23 「TensorFlow Android ワークショップ」での発表に使用したスライドです。
https://tfug-tokyo.connpass.com/event/52648/

D745b61e4ca7584109de26b112442e56?s=128

Arata Furukawa

March 23, 2017
Tweet

More Decks by Arata Furukawa

Other Decks in Technology

Transcript

  1. TensorFlowで作ったAIを Androidアプリで実行する 2017-03-23 APPDOJO TensorFlow User Group / 東海大学 理学部1年生

    古川新
  2. 今日の流れ ・モバイルとAIについて ・サンプルアプリ配布 ・プログラム解説

  3. △モバイルアプリでAIを利用する ◎モバイルアプリでAIを実行する

  4. モバイルアプリでAIを利用する モバイル上で 実行する サーバで実行して モバイルで利用する

  5. でも、(計算コスト)お高いんでしょう? 確かに、やれディープラーニングだとか、やれ人工知能だとか、 よくわからないけどGPUが何百枚だ何だと話をしているのを聞きます。 本当にモバイルで動くんでしょうか? まあ、動くんですが…

  6. 人工知能は 筋肉と同じ

  7. 筋肉を鍛えるには、たくさんの時間とお金が必要です 何回も何時間もトレーニングして、 少しずつ筋肉は強くなります

  8. 一方で、筋肉を動かすこと自体は簡単 時間もお金もかかりませんよね?

  9. 機械学習における学習と実行の非対称性 機械学習によって「AIを作る」場合の演算コストと 「作ったAIを動かす」場合の演算コストは別の話。 実行するだけなら モバイル端末でも十分。

  10. TensorFlowの 公式サンプルでは ・左 リアルタイムに人物を認識し、トラッキン グするモデルのサンプル。1秒未満のラ グで複数の人物の動きを追跡できる。 ・右 Google Inceptionモデルによるリアルタイ ム物体識別。写真ではカメラに写った

    iPhoneを「iPod」か「携帯電話」であると 判断している。予測にかかる時間は 100-300msと、体感では一瞬で識別でき ている。 ※写真はNexus5での実行結果です。  端末や使用状況によって実行速度は変化します。
  11. • 電池の減りが早くなる(と思われる) - 明らかに計算量が増えるので電池の減りが早くなることが予想されます。 - どのくらいかは、モデル次第で大きく変わってくるので、定量的に示せません …。 開発ごとに実験して、チューニングが必要です。 • 通常に比べるとアプリのデータサイズが増える

    - モデルデータは10MB〜100MBくらいは必要。チューニングに左右される点でもある。 - 精度を保ったままモデルを縮小するといった先行研究はあります。 • チューニング不可欠・専門の技術者が必要 - モバイル端末のスペックに限界がある以上、サーバでの計算以上に高度なチューニングが要 求される。モデルの精度やサイズをコントロールできるだけの技術と知識を持った人間が必 要。 モバイル上で実行するデメリット
  12. • インターネット接続が不要である - 回線状況がスループットに影響しない、オフラインでも実行可能 • サーバ不要 - サーバを運用するコストが不要、または軽減できる - 実行速度は訓練より速いと言っても、ユーザからのリクエストが集中すればかなりの演

    算コストになるため、 AI用のサーバ運用はいろいろ大変 • プライバシーの尊重 - ユーザのプライバシーに関わる情報をインターネット上に流出させなくても利用できる - サーバ型の場合は、ユーザの情報がインターネットを通過しないと利用できない - モバイル端末の扱うデータは個人のプライバシーに関わることが多い (顔写真、位置情報など) - 「万が一流出」というリスクを完全に除去することが出来る モバイル上で実行するメリット
  13. TensorFlowを Androidで動かす

  14. その前に…TensorFlowの仕組み TensorFlowにおいて、「あらゆる演算」は、グラフという抽象化されたデータ構造で表現されます。 グラフはProtocol Buffers形式でシリアライズ可能な、 TensorFlowに依存しない*独立したデータです。 TensorFlow グラフ グラフを作る グラフを実行する *

    TensorFlowには依存しないのでTensorFlowを使わなくても読み取ったり出来ますが、TensorFlowを使わずに読み取る意味はあまりないかもしれません。 足し算でも行列計算でも 計算機上で実行可能な あらゆる演算 シリアライズできるので デバイス間を 自由に持ち運び可能 チェックポイ ント
  15. シリアライズ可能ということは、こういうこともできる TensorFlow グラフ サーバー:グラフを作る 作成 TensorFlow グラフ クライアント:グラフを実行する 実行 何らかの方法で

    ダウンロード チェックポイ ント チェックポイ ント
  16. 実行に必要なファイルは2つ グラフ チェック ポイント 定数と、演算が保存されています。 例)3 や 8.5 などの定数 例)3a

    + b のような演算の情報 変数の値が保存されています。 例)a = 1.2 のような変数の値の情報
  17. 実行に必要なファイルは2つ1つ グラフ チェック ポイント 凍結グラフ 凍結 本来、これらは機械学習のために分離されていますが、 実行にあたっては面倒くさいだけなので、凍結処理 (変数を定数に変換する処理)をすることで合体させます。 ※もちろん、分離したままでも実行可能ですが、

    2つファイルが必要になりますし、実行に必要な いメタ情報が多く含まれている分チェックポイントはファイルサイズが大きいので、実行だけであ れば凍結するのがおすすめです。
  18. Androidでグラフを実行するまでの流れ TensorFlow グラフ サーバー:グラフを作る 作成 TensorFlow 凍結グラフ Android:グラフを実行する 実行 ダウンロード

    チェックポイ ント 凍結グラフ 凍結
  19. 今日配布する サンプルアプリ 手書きした数字を予測するAIをAndroid 上で実行するアプリです。

  20. 手書き数字の認識モデル 機械学習のハローワールドとも呼ばれているテー マで、手書きした数字を認識するモデルです。 入力として784次元(28x28ピクセル)のピクセル強 度の情報を受け取り、予測した数字の確率強度を 出力として吐き出します。 TensorFlowにおいて、機械学習のモデルもまた グラフで表現されます。 ※通常、正規化した分布を出力する事が多いですが、本アプリでは出力された強度をそのまま出力しています。ご注意 ください。

    28ピクセル 28ピクセル モデル 3
  21. モデルの作成・訓練 こちらにソースコードを用意しました。 https://github.com/ornew/mnist-android/blob/master/model/mnist.py ダウンロードして頂き、 TensorFlowがインストールされた python環境で実行していただくと、モデルが生成さ れ出力されます。 モデルの構造はTensorFlowチュートリアルと同様になっています。 (チュートリアル翻訳: http://ornew.net/tensorflow-mnist-for-ml-beginners)

    皆さんのノートパソコンに実行させるのはあまりに酷なので、 今日は出来合いのものをご用意しております。 https://github.com/ornew/mnist-android/releases/download/v1.0.1/mnist.frozen.pb
  22. https://goo.gl/tgIOK9 今日のサンプルプログラム Android Studioでビルドできるようにしてあります 少し時間を取りますので ぜひお手元でお試しください。

  23. プログラム解説

  24. Androidで 実行するまでの 全体の流れ 1. TensorFlowで機械学習をする 2. 機械学習で出来た人工知能モデルを Protocol Buffers 形式でシリアライズする

    3. シリアライズされたモデルファイルをAndroid上 にダウンロードする 4. Android用にビルドしたTensorFlowコア ライブラリでモデルを実行する ※Android JavaアプリからJNIを介して実行します 1 2 3 4
  25. 1 機械学習する 今回の手書き数字の認識モデルは、 TensorFlow公式チュートリアルに ソースコードと解説があります。 TensorFlowのpython APIでモデルを 構築します。 ソースコード全体: https://goo.gl/uq62wo

    チュートリアル翻訳: https://goo.gl/Vi9nox #!/usr/bin/env python import os import time import argparse import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data from tensorflow.python.tools.freeze_graph import freeze_graph def build_inference(x, keep_prob=None): def weight(shape): initial = tf.truncated_normal(shape, stddev=0.1) return tf.Variable(initial, name='weight') def bias(shape): initial = tf.constant(0.1, shape=shape) return tf.Variable(initial, name='bias') def convolution(x, W): return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME', name='convolutional') def pooling(x): return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pooling') x_image = tf.reshape(x, [-1,28,28,1]) with tf.name_scope('hidden_1'): W1 = weight([5, 5, 1, 32]) b1 = bias([32]) C1 = tf.nn.relu(convolution(x_image, W1) + b1) h1 = pooling(C1) with tf.name_scope('hidden_2'): W2 = weight([5, 5, 32, 64]) b2 = bias([64]) C2 = tf.nn.relu(convolution(h1, W2) + b2) h2 = pooling(C2) with tf.name_scope('fully_connect'):
  26. 2 シリアライズ 作った人工知能モデルをファイルに シリアライズします。 グラフを書き出したら、今回は簡単 のために凍結処理を行います。 本来、TensorFlowのモデルデータは、 グラフとチェックポイントの2つの ファイルから構成されます。 凍結は、両者を合体してグラフに

    固めてしまう処理です。 ソースコード全体: https://goo.gl/uq62wo learning_rate: 1e-4}) writer.add_summary(summary, global_step=step) elapsed_time = time.time() - start print('done!') print('Total time: %f [sec]' % elapsed_time) ckpt = saver.save(session, os.path.join(FLAGS.ckpt_dir, 'ckpt'), global_step=step) print('Save checkpoint: %s' % ckpt) # Write graph. tf.train.write_graph(session.graph.as_graph_def(), FLAGS.model_dir, FLAGS.model_name + '.pb', as_text=False) tf.train.write_graph(session.graph.as_graph_def(), FLAGS.model_dir, FLAGS.model_name + '.pb.txt', as_text=True) # End default graph scope. # Freeze graph. freeze_graph( input_graph=os.path.join(FLAGS.model_dir, FLAGS.model_name + '.pb.txt'), input_saver='', input_binary=False, input_checkpoint=ckpt, output_node_names='readout/y', restore_op_name='save/restore_all', filename_tensor_name='save/Const:0', output_graph=os.path.join(FLAGS.model_dir, '%s.frozen.pb' % FLAGS.model_name), clear_devices=False, initializer_nodes='')
  27. 3 ダウンロード 適当にアップロードしたモデルをAndroid アプリからダウンロードします。もちろん、 アセットなどに組み込んでも構いません。 今回のサンプルでは、HTTP通信でダウ ンロードしたモデルを、ローカルストレー ジに通常のファイルと同様に保存します。 public class

    Download extends AsyncTask<URL, Integer, Boolean> { private Context context; private String pathTo; Download(Context context, String pathTo) { this.context = context; this.pathTo = pathTo; } @Override protected Boolean doInBackground(URL... urls) { final byte[] buffer = new byte[4096]; HttpURLConnection connection = null; InputStream input = null; OutputStream output = null; try { connection = (HttpURLConnection) urls[0].openConnection(); connection.connect(); int length = connection.getContentLength(); input = connection.getInputStream(); output = this.context.openFileOutput(pathTo, Context.MODE_PRIVATE); int totalBytes = 0; int bytes = 0; while ((bytes = input.read(buffer)) != -1) { output.write(buffer, 0, bytes); totalBytes += bytes; publishProgress((int)(totalBytes * 100.f / length)); } } catch (IOException e) { e.printStackTrace(); } finally { try { if(input != null){ input.close(); }
  28. 4 実行 「識別」ボタンが押されたタイミングでモデ ルに入力を与えて実行します。 今回の場合は、Viewのピクセル情報を配 列で渡しています。 @Override protected void onCreate(Bundle

    savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); canvas = (CanvasView) findViewById(R.id.canvas); recognize = (Button) findViewById(R.id.recognize); result = (TextView) findViewById(R.id.result); recognize.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { int pixels[] = canvas.getPixels(); float data[] = new float[pixels.length]; for(int i = 0; i < pixels.length; ++i){ data[i] = Color.alpha(pixels[i]) / 255.f; } float answers[] = MNIST.inference(data); result.setText(""); int bestIndex = 0; float bestScore = answers[0]; for(int i = 1; i < answers.length; ++i) { if(bestScore < answers[i]){ bestScore = answers[i]; bestIndex = i; } } print("予測: " + bestIndex); for(int i = 0; i < answers.length; ++i) { print(i + ": " + answers[i]); } } });
  29. MNISTクラス MNISTクラスは、JNIによって実装されて います。 JNIとはJava Native Interfaceのことで、 ネイティブコードをJavaから実行するため の仕組みです。 このクラスのメソッドが実行された場合、 実際には読み込んでいる共有ライブラリ

    が実行されます。 この共有ライブラリ内部でTensorFlow C++ APIが実行されています。 public class MNIST { static { System.loadLibrary("mnist"); } public static native void initialize(String model_path); public static native float[] inference(float x[]); } モデルのファイルパスを引数に 受け取りTensorFlowを初期化する メソッド ピクセル情報を配列で受け取り 予測した数字の確率分布を返す メソッド 共有ライブラリ「libmnist.so」を読み込む (Android Studioの場合、AndroidManifest.xmlのあるディレクトリに jniLibs/armeabi-v7aというディレクトリを作成し、その中に共有ライブラリを入れるこ とで読み込めます)
  30. C++? Javaで出来ないの? Javaは正式サポートされていません。 C++ APIをJNIで呼び出す試験的なJava APIが有志の手によって提供されて いますが、基本的にやっていることは同じです。 まだまだ不安定で、機能が少ないので使いづらいですが、いずれ安定してJavaから 簡単に使えるようになると思います。

  31. C++コード ソースコード全体: https://goo.gl/2mu7bR 共有ライブラリのビルド方法: (リンク先はv0.11になっていますがv1.0でも同じ方法です) https://goo.gl/N4VLAC // Tensors Tensor x(DT_FLOAT,

    TensorShape({1, mnist_input_size})); Tensor keep_prob(DT_FLOAT, TensorShape()); // Initialize Tensors auto _x = x.flat<float>(); for(int i = 0; i < mnist_input_size; ++i){ _x(i) = inputs[i]; } keep_prob.scalar<float>()() = 1.0f; // Input and Output vector<pair<string, Tensor>> feed_dict({ {"x", x}, {"keep_prob", keep_prob} }); vector<Tensor> outputs; // Run session status = global_session->Run(feed_dict, {"readout/y:0"}, {}, &outputs); if(!status.ok()){ Log::i(status.error_message().c_str()); env->FatalError("Failed to execute runnning to inference operation."); }