TensorFlow集成Android工程的框架
欢迎Follow我的GitHub,关注我的
在Android工程中,集成TensorFlow模型。运行TensorFlow的默认Android工程,请参考。
Android源码:https://github.com/SpikeKing/TFAndroid/tree/master
库及模型的大小
libtensorflow_inference.so 10.2 Mlibandroid_tensorflow_inference_java.jar 27 KBoptimized_tfdroid.pb 291 B
如果将so转换为jar库,参考,则TF的so由10.2M缩小至4.1M。
TF AndroidTensorFlow
TF模型源码:
https://github.com/SpikeKing/MachineLearningTutorial/blob/master/tests/android_test.py
创建TensorFlow模型,简单的y=WX+b
,存储图信息write_graph
,存储参数信息saver.save
。输入数据placeholder是I
,输出数据是O
。
import tensorflow as tfI = tf.placeholder(tf.float32, shape=[None, 3], name='I') # inputW = tf.Variable(tf.zeros(shape=[3, 2]), dtype=tf.float32, name='W') # weightsb = tf.Variable(tf.zeros(shape=[2]), dtype=tf.float32, name='b') # biasesO = tf.nn.relu(tf.matmul(I, W) + b, name='O') # activation / outputsaver = tf.train.Saver()init_op = tf.global_variables_initializer()with tf.Session() as sess: sess.run(init_op) tf.train.write_graph(sess.graph_def, './data/android/', 'tfdroid.pbtxt') # 存储TensorFlow的图 # 训练数据,本例直接赋值 sess.run(tf.assign(W, [[1, 2], [4, 5], [7, 8]])) sess.run(tf.assign(b, [1, 1])) # 存储checkpoint文件,即参数信息 saver.save(sess, './data/android/tfdroid.ckpt')
创建Freeze的图,将图结构与参数组合在一起,生成模型,参考。
def gnr_freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, output_node_names, output_graph, clear_devices): """ 将输入图与参数结合在一起 :param input_graph: 输入图 :param input_saver: Saver解析器 :param input_binary: 输入图的格式,false是文本,true是二进制 :param input_checkpoint: checkpoint,检查点文件 :param output_node_names: 输出节点名称 :param output_graph: 保存输出文件 :param clear_devices: 清除训练设备 :return: NULL """ restore_op_name = "save/restore_all" filename_tensor_name = "save/Const:0" freeze_graph.freeze_graph( input_graph=input_graph, # 输入图 input_saver=input_saver, # Saver解析器 input_binary=input_binary, # 输入图的格式,false是文本,true是二进制 input_checkpoint=input_checkpoint, # checkpoint,检查点文件 output_node_names=output_node_names, # 输出节点名称 restore_op_name=restore_op_name, # 从模型恢复节点的名字 filename_tensor_name=filename_tensor_name, # tensor名称 output_graph=output_graph, # 保存输出文件 clear_devices=clear_devices, # 清除训练设备 initializer_nodes="") # 初始化节点
优化模型,剪切节点,模型只保留输入输出的参数。
def gnr_optimize_graph(graph_path, optimized_graph_path): """ 优化图 :param graph_path: 原始图 :param optimized_graph_path: 优化的图 :return: NULL """ input_graph_def = tf.GraphDef() # 读取原始图 with tf.gfile.Open(graph_path, "r") as f: data = f.read() input_graph_def.ParseFromString(data) # 设置输入输出节点,剪切分支,大约节省1/4 output_graph_def = optimize_for_inference_lib.optimize_for_inference( input_graph_def, ["I"], # an array of the input node(s) ["O"], # an array of output nodes tf.float32.as_datatype_enum) # 存储优化的图 f = tf.gfile.FastGFile(optimized_graph_path, "w") f.write(output_graph_def.SerializeToString())
执行函数,生成模型,frozen_tfdroid.pb
和optimized_tfdroid.pb
。
if __name__ == "__main__": input_graph_path = MODEL_FOLDER + MODEL_NAME + '.pbtxt' # 输入图 checkpoint_path = MODEL_FOLDER + MODEL_NAME + '.ckpt' # 输入参数 output_path = MODEL_FOLDER + 'frozen_' + MODEL_NAME + '.pb' # Freeze模型 gnr_freeze_graph(input_graph=input_graph_path, input_saver="", input_binary=False, input_checkpoint=checkpoint_path, output_node_names="O", output_graph=output_path, clear_devices=True) optimized_output_graph = MODEL_FOLDER + 'optimized_' + MODEL_NAME + '.pb' gnr_optimize_graph(output_path, optimized_output_graph)
Android
编译Android的库,参考,或者,直接在Nightly中下载,参考,archive.zip,大约158M。
创建Android工程,添加app/libs/
中添加库文件。
armeabi-v7a/libtensorflow_inference.solibandroid_tensorflow_inference_java.jar
在build.gradle中,添加
android { sourceSets { main { jniLibs.srcDirs = ['libs'] } }}
在app/src/main/assets中,添加模型optimized_tfdroid.pb
文件。
在MainActivity中,添加so库。
static { System.loadLibrary("tensorflow_inference");}
模型文件在assets中,TF的核心接口类TensorFlowInferenceInterface。
private static final String MODEL_FILE = "file:///android_asset/optimized_tfdroid.pb";private TensorFlowInferenceInterface mInferenceInterface;
初始模型文件
mInferenceInterface = new TensorFlowInferenceInterface();mInferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE);
模型Feed数据,输入点名称是INPUT_NODE
,输入结构INPUT_SIZE
,输入数据inputFloats。
float[] inputFloats = {num1, num2, num3};mInferenceInterface.fillNodeFloat(INPUT_NODE, INPUT_SIZE, inputFloats);
模型执行文件,输出点名称是OUTPUT_NODE
,即"O"
mInferenceInterface.runInference(new String[]{OUTPUT_NODE});
输出数据结构
float[] resu = {0, 0};mInferenceInterface.readNodeFloat(OUTPUT_NODE, resu);
最后,在layout中创建GUI布局。
效果
DemoTensorFlow集成至春雨医生
CY-TFThat's all! Enjoy it!
更多相关文章
- NPM 和webpack 的基础使用
- 【阿里云镜像】使用阿里巴巴DNS镜像源——DNS配置教程
- android 超简单的下载功能,进度条 异步下载
- vim+ctags+cscope 打造Android源码阅读工具
- Android(安卓)读取 assets目录下的文件
- android 五种 布局文件
- Android安装apk文件,适配Android(安卓)7.0
- Android(安卓)Studio 新建编辑条 点击按钮显示控件中的内容
- linux 下使用ndk-build编译android使用的c++静态库