Upgrade to Pro — share decks privately, control downloads, hide ads and more …

TensorFlowで作ったAIをAndroidアプリで実行する

 TensorFlowで作ったAIをAndroidアプリで実行する

TensorFlowで機械学習をしたグラフをAndroid端末上でで実行する方法についてご紹介します。

このスライドは 2017/03/23 「TensorFlow Android ワークショップ」での発表に使用したスライドです。
https://tfug-tokyo.connpass.com/event/52648/

Arata Furukawa

March 23, 2017
Tweet

More Decks by Arata Furukawa

Other Decks in Technology

Transcript

  1. TensorFlowの 公式サンプルでは ・左 リアルタイムに人物を認識し、トラッキン グするモデルのサンプル。1秒未満のラ グで複数の人物の動きを追跡できる。 ・右 Google Inceptionモデルによるリアルタイ ム物体識別。写真ではカメラに写った

    iPhoneを「iPod」か「携帯電話」であると 判断している。予測にかかる時間は 100-300msと、体感では一瞬で識別でき ている。 ※写真はNexus5での実行結果です。  端末や使用状況によって実行速度は変化します。
  2. • 電池の減りが早くなる(と思われる) - 明らかに計算量が増えるので電池の減りが早くなることが予想されます。 - どのくらいかは、モデル次第で大きく変わってくるので、定量的に示せません …。 開発ごとに実験して、チューニングが必要です。 • 通常に比べるとアプリのデータサイズが増える

    - モデルデータは10MB〜100MBくらいは必要。チューニングに左右される点でもある。 - 精度を保ったままモデルを縮小するといった先行研究はあります。 • チューニング不可欠・専門の技術者が必要 - モバイル端末のスペックに限界がある以上、サーバでの計算以上に高度なチューニングが要 求される。モデルの精度やサイズをコントロールできるだけの技術と知識を持った人間が必 要。 モバイル上で実行するデメリット
  3. • インターネット接続が不要である - 回線状況がスループットに影響しない、オフラインでも実行可能 • サーバ不要 - サーバを運用するコストが不要、または軽減できる - 実行速度は訓練より速いと言っても、ユーザからのリクエストが集中すればかなりの演

    算コストになるため、 AI用のサーバ運用はいろいろ大変 • プライバシーの尊重 - ユーザのプライバシーに関わる情報をインターネット上に流出させなくても利用できる - サーバ型の場合は、ユーザの情報がインターネットを通過しないと利用できない - モバイル端末の扱うデータは個人のプライバシーに関わることが多い (顔写真、位置情報など) - 「万が一流出」というリスクを完全に除去することが出来る モバイル上で実行するメリット
  4. その前に…TensorFlowの仕組み TensorFlowにおいて、「あらゆる演算」は、グラフという抽象化されたデータ構造で表現されます。 グラフはProtocol Buffers形式でシリアライズ可能な、 TensorFlowに依存しない*独立したデータです。 TensorFlow グラフ グラフを作る グラフを実行する *

    TensorFlowには依存しないのでTensorFlowを使わなくても読み取ったり出来ますが、TensorFlowを使わずに読み取る意味はあまりないかもしれません。 足し算でも行列計算でも 計算機上で実行可能な あらゆる演算 シリアライズできるので デバイス間を 自由に持ち運び可能 チェックポイ ント
  5. 実行に必要なファイルは2つ グラフ チェック ポイント 定数と、演算が保存されています。 例)3 や 8.5 などの定数 例)3a

    + b のような演算の情報 変数の値が保存されています。 例)a = 1.2 のような変数の値の情報
  6. Androidで 実行するまでの 全体の流れ 1. TensorFlowで機械学習をする 2. 機械学習で出来た人工知能モデルを Protocol Buffers 形式でシリアライズする

    3. シリアライズされたモデルファイルをAndroid上 にダウンロードする 4. Android用にビルドしたTensorFlowコア ライブラリでモデルを実行する ※Android JavaアプリからJNIを介して実行します 1 2 3 4
  7. 1 機械学習する 今回の手書き数字の認識モデルは、 TensorFlow公式チュートリアルに ソースコードと解説があります。 TensorFlowのpython APIでモデルを 構築します。 ソースコード全体: https://goo.gl/uq62wo

    チュートリアル翻訳: https://goo.gl/Vi9nox #!/usr/bin/env python import os import time import argparse import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data from tensorflow.python.tools.freeze_graph import freeze_graph def build_inference(x, keep_prob=None): def weight(shape): initial = tf.truncated_normal(shape, stddev=0.1) return tf.Variable(initial, name='weight') def bias(shape): initial = tf.constant(0.1, shape=shape) return tf.Variable(initial, name='bias') def convolution(x, W): return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME', name='convolutional') def pooling(x): return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pooling') x_image = tf.reshape(x, [-1,28,28,1]) with tf.name_scope('hidden_1'): W1 = weight([5, 5, 1, 32]) b1 = bias([32]) C1 = tf.nn.relu(convolution(x_image, W1) + b1) h1 = pooling(C1) with tf.name_scope('hidden_2'): W2 = weight([5, 5, 32, 64]) b2 = bias([64]) C2 = tf.nn.relu(convolution(h1, W2) + b2) h2 = pooling(C2) with tf.name_scope('fully_connect'):
  8. 2 シリアライズ 作った人工知能モデルをファイルに シリアライズします。 グラフを書き出したら、今回は簡単 のために凍結処理を行います。 本来、TensorFlowのモデルデータは、 グラフとチェックポイントの2つの ファイルから構成されます。 凍結は、両者を合体してグラフに

    固めてしまう処理です。 ソースコード全体: https://goo.gl/uq62wo learning_rate: 1e-4}) writer.add_summary(summary, global_step=step) elapsed_time = time.time() - start print('done!') print('Total time: %f [sec]' % elapsed_time) ckpt = saver.save(session, os.path.join(FLAGS.ckpt_dir, 'ckpt'), global_step=step) print('Save checkpoint: %s' % ckpt) # Write graph. tf.train.write_graph(session.graph.as_graph_def(), FLAGS.model_dir, FLAGS.model_name + '.pb', as_text=False) tf.train.write_graph(session.graph.as_graph_def(), FLAGS.model_dir, FLAGS.model_name + '.pb.txt', as_text=True) # End default graph scope. # Freeze graph. freeze_graph( input_graph=os.path.join(FLAGS.model_dir, FLAGS.model_name + '.pb.txt'), input_saver='', input_binary=False, input_checkpoint=ckpt, output_node_names='readout/y', restore_op_name='save/restore_all', filename_tensor_name='save/Const:0', output_graph=os.path.join(FLAGS.model_dir, '%s.frozen.pb' % FLAGS.model_name), clear_devices=False, initializer_nodes='')
  9. 3 ダウンロード 適当にアップロードしたモデルをAndroid アプリからダウンロードします。もちろん、 アセットなどに組み込んでも構いません。 今回のサンプルでは、HTTP通信でダウ ンロードしたモデルを、ローカルストレー ジに通常のファイルと同様に保存します。 public class

    Download extends AsyncTask<URL, Integer, Boolean> { private Context context; private String pathTo; Download(Context context, String pathTo) { this.context = context; this.pathTo = pathTo; } @Override protected Boolean doInBackground(URL... urls) { final byte[] buffer = new byte[4096]; HttpURLConnection connection = null; InputStream input = null; OutputStream output = null; try { connection = (HttpURLConnection) urls[0].openConnection(); connection.connect(); int length = connection.getContentLength(); input = connection.getInputStream(); output = this.context.openFileOutput(pathTo, Context.MODE_PRIVATE); int totalBytes = 0; int bytes = 0; while ((bytes = input.read(buffer)) != -1) { output.write(buffer, 0, bytes); totalBytes += bytes; publishProgress((int)(totalBytes * 100.f / length)); } } catch (IOException e) { e.printStackTrace(); } finally { try { if(input != null){ input.close(); }
  10. 4 実行 「識別」ボタンが押されたタイミングでモデ ルに入力を与えて実行します。 今回の場合は、Viewのピクセル情報を配 列で渡しています。 @Override protected void onCreate(Bundle

    savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); canvas = (CanvasView) findViewById(R.id.canvas); recognize = (Button) findViewById(R.id.recognize); result = (TextView) findViewById(R.id.result); recognize.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { int pixels[] = canvas.getPixels(); float data[] = new float[pixels.length]; for(int i = 0; i < pixels.length; ++i){ data[i] = Color.alpha(pixels[i]) / 255.f; } float answers[] = MNIST.inference(data); result.setText(""); int bestIndex = 0; float bestScore = answers[0]; for(int i = 1; i < answers.length; ++i) { if(bestScore < answers[i]){ bestScore = answers[i]; bestIndex = i; } } print("予測: " + bestIndex); for(int i = 0; i < answers.length; ++i) { print(i + ": " + answers[i]); } } });
  11. MNISTクラス MNISTクラスは、JNIによって実装されて います。 JNIとはJava Native Interfaceのことで、 ネイティブコードをJavaから実行するため の仕組みです。 このクラスのメソッドが実行された場合、 実際には読み込んでいる共有ライブラリ

    が実行されます。 この共有ライブラリ内部でTensorFlow C++ APIが実行されています。 public class MNIST { static { System.loadLibrary("mnist"); } public static native void initialize(String model_path); public static native float[] inference(float x[]); } モデルのファイルパスを引数に 受け取りTensorFlowを初期化する メソッド ピクセル情報を配列で受け取り 予測した数字の確率分布を返す メソッド 共有ライブラリ「libmnist.so」を読み込む (Android Studioの場合、AndroidManifest.xmlのあるディレクトリに jniLibs/armeabi-v7aというディレクトリを作成し、その中に共有ライブラリを入れるこ とで読み込めます)
  12. C++コード ソースコード全体: https://goo.gl/2mu7bR 共有ライブラリのビルド方法: (リンク先はv0.11になっていますがv1.0でも同じ方法です) https://goo.gl/N4VLAC // Tensors Tensor x(DT_FLOAT,

    TensorShape({1, mnist_input_size})); Tensor keep_prob(DT_FLOAT, TensorShape()); // Initialize Tensors auto _x = x.flat<float>(); for(int i = 0; i < mnist_input_size; ++i){ _x(i) = inputs[i]; } keep_prob.scalar<float>()() = 1.0f; // Input and Output vector<pair<string, Tensor>> feed_dict({ {"x", x}, {"keep_prob", keep_prob} }); vector<Tensor> outputs; // Run session status = global_session->Run(feed_dict, {"readout/y:0"}, {}, &outputs); if(!status.ok()){ Log::i(status.error_message().c_str()); env->FatalError("Failed to execute runnning to inference operation."); }