TensorFlow在Android平台上的安装和应用
当我们有一个已经训练好的TF模型的时候,我们怎么去调用这个模型并且让他顺利在Android平台上运行起来呢?
大概包括这几个方面:
1、 保存训练完毕的TF模型
2、 在Android项目中导入TF模型、导入Android平台调用TF模型需要的jar包和so文件 (它们负责TF模型的解析和运算)
3、定义变量、存储数据,通过jar包提供的接口进行模型的调用
移植过程
我们以mnist数据集上自己训练的一个图像识别模型为例,进行讲解
一、 在使用python代码编写的TF模型定义中为模型的输入层和输出层Tensor Variable分别指定名字(通过形参 ‘name’)
X = tf.placeholder(tf.float32, shape = […], name=‘input’) //网络的输入Y = tf.nn.softmax(tf.matmul(f, out_weights) + out_biases, name=’output’) //网络的输出
- 1
- 2
名字可以随便起,以方便好记为主,后面还会反复用到。我起的是input和output。
二、 将使用TensorFlow训练好的模型保存为.pb文件
在模型训练结束后的代码位置,添加下述两句代码,可将模型保存为.pb文件
output_graph_def = tf.graph_until.convert_variables_to_constants(session, session.graph_def, output_node_names=[‘output’])//形参output_node_names用于指定输出的节点名称
- 1
- 2
贴一个说明文档,帮助大家进一步了解这个函数
with tf.gfile.FastGFile(model\mnist.pb, mode = ’wb’) as f: f.write(output_graph_def.SerializeToString())
- 1
- 2
第一个参数用于指定输出的文件存放路径、文件名及格式。我把它放在与代码同级目录的model文件下,取名为mnist.pb
第二个参数 mode用于指定文件操作的模式,’wb’中w代表写文件,b代表将数据以二进制方式写入文件。
如果不指明‘b’,则默认会以文本txt方式写入文件。现在TF还不支持对文本格式.pb文件的解析,在调用时会出现报错。
注:
1)、不能使用 tf.train.write_graph()保存模型,因为它只是保存了模型的结构,并不保存训练完毕的参数值
2)、不能使用 tf.train.saver()保存模型,因为它只是保存了网络中的参数值,并不保存模型的结构。
很显然,我们需要的是既保存模型的结构,又保存模型中每个参数的值。以上两者皆不符合。
三、生成在Android平台上调用tensorflow 模型需要的jar包和so文件
1) 从github下载TensorFlow的项目源码
2) 安装Bazel
Bazel的安装过程,我在另一篇文章中有介绍,欢迎参阅
Ubuntu14.04 源代码安装 TensorFlow r0.12 详细教程
3) 参考如下图的官方教程,生成Android上调用TF模型需要的so文件和jar包
四、安装Android Studio,创建Android 项目
Android Studio安装完毕后,还需要搭建环境。搭建过程可参考我的另一篇文章:
Ubuntu 使用 Android Studio 编译 TensorFlow android demo
五、添加资源到项目
1) 将(二)步生成的.pb文件放入项目中
打开 Project view ,app/src/main/assets。
若不存在assets目录,右键main->new->folder->Assets Folder
2) 添加(三)步生成的jar包
打开Project view,将jar包拷贝到app->libs下
选中jar文件,右键 add as library
3) 添加(三)生成的so文件
打开 Project view,将.so文件拷贝到 app/src/main/jniLibs下(jniLibs文件夹若没有则新建)
如果我讲的不太明白的话,可自行谷歌搜索“如何在 Android studio中添加引用 jar文件和so文件”
六、创建接口,实现调用
1) 导入jar包和so文件
在需要调用模型的.java文件中,导入jar包:
import org.tensorflow.contrib.android.TensorFlowInferenceInterface
- 1
在该java类定义的首行,导入so文件:
{ System.loadLibrary(“tensorflow_inference”)}
- 1
- 2
- 3
2)定义变量及对象
private static final String MODEL_FILE = “file:///android_asset/mnist.pb” //模型存放路径private static final String INPUT_NODE = “input”; //模型中输入变量的名称private static final String INPUT_NODE = “output”; //模型中输出变量的名称private static final int NUM_CLASSES = 10; //样本集的类别数量,mnist数据集对应10private static final int HEIGHT = 24; //输入图片的像素高private static final int WIDTH = 24; //输入图片的像素宽private static final int CHANNEL = 3; //输入图片的通道数:RGBprivate floats inputs = new float[HEIGHT*WIDTH*CHANNEL]; //用于存储的模型输入数据private floats outputs = new float[NUM_CLASSES]; //用于存储模型的输出数据
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
2)Tensorflow 接口初始化
private TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(); //接口定义inferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE); //接口初始化
- 1
- 2
在完成上述两步之后,就可以反复调用模型。
在每次调用前,先将待输入的数据按顺序存放进 inputs 变量中,然后执行下述三个语句。
3)TF模型的调用
inferenceInterface.fillNodeFloat(INPUT_NODE, new int[]{1, HEIGHT, WIDTH, CHANNEL}, inputs); //送入输入数据inferenceInterface.runInference(new String[]{OUTPUT_NODE}); //进行模型的推理inferenceInterface.readNodeFloat(OUTPUT_NODE, outputs); //获取输出数据
- 1
- 2
- 3
然后接下来的主要工作就是安卓项目的编译以及将编译完的apk文件安装到手机,这部分内容与一般的安卓项目并无区别。这些内容在另一篇文章中也有所提及:
Ubuntu 使用 Android Studio 编译 TensorFlow android demo
更多相关文章
- State 状态模式在 Android(安卓)多弹窗的应用
- Android如何在字符串资源文件strings.xml中通过引用的方式在一个
- Android多媒体学习六:访问网络上的Audio对应的M3U文件,实现网络音
- Android(安卓)APN的设置问题
- 【移动开发】Android中将我们平时积累的工具类打包
- Android(安卓)调用百度在线语音识别功能
- (6) Android中Binder调用流程 --- Binder驱动总结
- 关于Android调用单目摄像头以及双目摄像头的方法(智能平板)
- 关于Android的Sensor驱动,不支持内核模块模式的驱动