欢迎Follow我的GitHub,关注我的

在Android工程中,集成TensorFlow模型。运行TensorFlow的默认Android工程,请参考。

Android源码:https://github.com/SpikeKing/TFAndroid/tree/master

库及模型的大小

libtensorflow_inference.so  10.2 Mlibandroid_tensorflow_inference_java.jar  27 KBoptimized_tfdroid.pb  291 B

如果将so转换为jar库,参考,则TF的so由10.2M缩小至4.1M。

TensorFlow集成Android工程的框架_第1张图片 TF Android

TensorFlow

TF模型源码:
https://github.com/SpikeKing/MachineLearningTutorial/blob/master/tests/android_test.py

创建TensorFlow模型,简单的y=WX+b,存储图信息write_graph,存储参数信息saver.save。输入数据placeholder是I,输出数据是O

import tensorflow as tfI = tf.placeholder(tf.float32, shape=[None, 3], name='I')  # inputW = tf.Variable(tf.zeros(shape=[3, 2]), dtype=tf.float32, name='W')  # weightsb = tf.Variable(tf.zeros(shape=[2]), dtype=tf.float32, name='b')  # biasesO = tf.nn.relu(tf.matmul(I, W) + b, name='O')  # activation / outputsaver = tf.train.Saver()init_op = tf.global_variables_initializer()with tf.Session() as sess:    sess.run(init_op)    tf.train.write_graph(sess.graph_def, './data/android/', 'tfdroid.pbtxt')  # 存储TensorFlow的图    # 训练数据,本例直接赋值    sess.run(tf.assign(W, [[1, 2], [4, 5], [7, 8]]))    sess.run(tf.assign(b, [1, 1]))    # 存储checkpoint文件,即参数信息    saver.save(sess, './data/android/tfdroid.ckpt')

创建Freeze的图,将图结构与参数组合在一起,生成模型,参考。

def gnr_freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,                     output_node_names, output_graph, clear_devices):    """    将输入图与参数结合在一起        :param input_graph: 输入图    :param input_saver: Saver解析器    :param input_binary: 输入图的格式,false是文本,true是二进制    :param input_checkpoint: checkpoint,检查点文件        :param output_node_names: 输出节点名称    :param output_graph: 保存输出文件    :param clear_devices: 清除训练设备    :return: NULL    """    restore_op_name = "save/restore_all"    filename_tensor_name = "save/Const:0"    freeze_graph.freeze_graph(        input_graph=input_graph,  # 输入图        input_saver=input_saver,  # Saver解析器        input_binary=input_binary,  # 输入图的格式,false是文本,true是二进制        input_checkpoint=input_checkpoint,  # checkpoint,检查点文件        output_node_names=output_node_names,  # 输出节点名称        restore_op_name=restore_op_name,  # 从模型恢复节点的名字        filename_tensor_name=filename_tensor_name,  # tensor名称        output_graph=output_graph,  # 保存输出文件        clear_devices=clear_devices,  # 清除训练设备        initializer_nodes="")  # 初始化节点

优化模型,剪切节点,模型只保留输入输出的参数。

def gnr_optimize_graph(graph_path, optimized_graph_path):    """    优化图    :param graph_path: 原始图    :param optimized_graph_path: 优化的图    :return: NULL    """    input_graph_def = tf.GraphDef()  # 读取原始图    with tf.gfile.Open(graph_path, "r") as f:        data = f.read()        input_graph_def.ParseFromString(data)    # 设置输入输出节点,剪切分支,大约节省1/4    output_graph_def = optimize_for_inference_lib.optimize_for_inference(        input_graph_def,        ["I"],  # an array of the input node(s)        ["O"],  # an array of output nodes        tf.float32.as_datatype_enum)    # 存储优化的图    f = tf.gfile.FastGFile(optimized_graph_path, "w")    f.write(output_graph_def.SerializeToString())

执行函数,生成模型,frozen_tfdroid.pboptimized_tfdroid.pb

if __name__ == "__main__":    input_graph_path = MODEL_FOLDER + MODEL_NAME + '.pbtxt'  # 输入图    checkpoint_path = MODEL_FOLDER + MODEL_NAME + '.ckpt'  # 输入参数    output_path = MODEL_FOLDER + 'frozen_' + MODEL_NAME + '.pb'  # Freeze模型    gnr_freeze_graph(input_graph=input_graph_path, input_saver="",                     input_binary=False, input_checkpoint=checkpoint_path,                     output_node_names="O", output_graph=output_path, clear_devices=True)    optimized_output_graph = MODEL_FOLDER + 'optimized_' + MODEL_NAME + '.pb'    gnr_optimize_graph(output_path, optimized_output_graph)

Android

编译Android的库,参考,或者,直接在Nightly中下载,参考,archive.zip,大约158M。

创建Android工程,添加app/libs/中添加库文件。

armeabi-v7a/libtensorflow_inference.solibandroid_tensorflow_inference_java.jar

在build.gradle中,添加

android {    sourceSets {        main {            jniLibs.srcDirs = ['libs']        }    }}

在app/src/main/assets中,添加模型optimized_tfdroid.pb文件。

在MainActivity中,添加so库。

static {    System.loadLibrary("tensorflow_inference");}

模型文件在assets中,TF的核心接口类TensorFlowInferenceInterface。

private static final String MODEL_FILE = "file:///android_asset/optimized_tfdroid.pb";private TensorFlowInferenceInterface mInferenceInterface;

初始模型文件

mInferenceInterface = new TensorFlowInferenceInterface();mInferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE);

模型Feed数据,输入点名称是INPUT_NODE,输入结构INPUT_SIZE,输入数据inputFloats。

float[] inputFloats = {num1, num2, num3};mInferenceInterface.fillNodeFloat(INPUT_NODE, INPUT_SIZE, inputFloats);

模型执行文件,输出点名称是OUTPUT_NODE,即"O"

mInferenceInterface.runInference(new String[]{OUTPUT_NODE});

输出数据结构

float[] resu = {0, 0};mInferenceInterface.readNodeFloat(OUTPUT_NODE, resu);

最后,在layout中创建GUI布局。

效果

TensorFlow集成Android工程的框架_第2张图片 Demo

TensorFlow集成至春雨医生

TensorFlow集成Android工程的框架_第3张图片 CY-TF

That's all! Enjoy it!

更多相关文章

  1. Android 读取 assets目录下的文件
  2. android 五种 布局文件
  3. Android安装apk文件,适配Android 7.0
  4. IDA调试Android so文件
  5. 《Android面试宝典》学习笔记(第五章:文件存储)
  6. Android 文件读写操作 总结
  7. Android基础知识:Day02 常见布局、logcat相关和文件读写
  8. android解析xml文件的方式

随机推荐

  1. MySql反向模糊查询
  2. 如何创建a '。sql的文件
  3. MYSQL5.5和5.6参数的差异
  4. 在MySQL数据库中存储无法访问的用户
  5. MySQL很有用的命令
  6. MySQL查询中的变量会导致错误
  7. mysql字符集浅谈
  8. 反驳"MySQL InnoDB (不行)的性能问题",千
  9. 在同一列上选择多个条件
  10. MYSQL必知必会-SQL语句查询