将tensorflow训练好的模型移植到Android (MNIST手写数字识别)

【尊重原创,转载请注明出处】https://blog.csdn.net/guyuealian/article/details/79672257

项目Github下载地址:https://github.com/PanJinquan/Mnist-tensorFlow-AndroidDemo麻烦给个“star”哈

       本博客将以最简单的方式,利用TensorFlow实现了MNIST手写数字识别,并将Python TensoFlow训练好的模型移植到Android手机上运行。网上也有很多移植教程,大部分是在Ubuntu(Linux)系统,一般先利用Bazel工具把TensoFlow编译成.so库文件和jar包,再进行Android配置,实现模型移植。不会使用Bazel也没关系,实质上TensoFlow已经为开发者提供了最新的.so库文件和对应的jar包了(如libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar),我们只需要下载文件,并在本地Android Studio导入jar包和.so库文件,即可以在Android加载TensoFlow的模型了。 

      当然了,本博客的项目代码都上传到Github:https://github.com/PanJinquan/Mnist-tensorFlow-AndroidDemo

      先说一下,本人的开发环境:

  • Windows 7
  • Python3.5
  • TensoFlow 1.6.0(2018年3月23日—当前最新版)
  • Android Studio 3.0.1(2018年3月23日—当前最新版)

一、利用Python训练模型

   以MNIST手写数字识别为例,这里首先使用Python版的TensorFlow实现单隐含层的SoftMax Regression分类器,并将训练好的模型的网络拓扑结构和参数保存为pb文件。首先,需要定义模型的输入层和输出层节点的名字(通过形参 'name'指定,名字可以随意,后面加载模型时,都是通过该name来传递数据的):

x = tf.placeholder(tf.float32,[None,784],name='x_input')#输入节点:x_input...pre_num=tf.argmax(y,1,output_type='int32',name="output")#输出节点:output

PS:说一下鄙人遇到坑:起初,我参照网上相关教程训练了一个模型,在Windows下测试没错,但把模型移植到Android后就出错了,但用别人的模型又正常运行;后来折腾了半天才发现,是类型转换出错啦!!!!
TensorFlow默认类型是float32,但我们希望返回的是一个int型,因此需要指定output_type='int32';但注意了,在Windows下测试使用int64和float64都是可以的,但在Android平台上只能使用int32和float32,并且对应Java的int和float类型。

 将训练好的模型保存为.pb文件,这就需要用到tf.graph_util.convert_variables_to_constants函数了。

# 保存训练好的模型#形参output_node_names用于指定输出的节点名称,output_node_names=['output']对应pre_num=tf.argmax(y,1,name="output"),output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def,output_node_names=['output'])with tf.gfile.FastGFile('model/mnist.pb', mode='wb') as f:#’wb’中w代表写文件,b代表将数据以二进制方式写入文件。    f.write(output_graph_def.SerializeToString())

   关于tensorflow保存模型和加载模型的方法,请参考本人另一篇博客:https://blog.csdn.net/guyuealian/article/details/79693741

   这里给出Python训练模型完整的代码如下:

#coding=utf-8# 单隐层SoftMax Regression分类器:训练和保存模型模块from tensorflow.examples.tutorials.mnist import input_dataimport tensorflow as tffrom tensorflow.python.framework import graph_utilprint('tensortflow:{0}'.format(tf.__version__))mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)#create modelwith tf.name_scope('input'):    x = tf.placeholder(tf.float32,[None,784],name='x_input')#输入节点名:x_input    y_ = tf.placeholder(tf.float32,[None,10],name='y_input')with tf.name_scope('layer'):    with tf.name_scope('W'):        #tf.zeros([3, 4], tf.int32) ==> [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]        W = tf.Variable(tf.zeros([784,10]),name='Weights')    with tf.name_scope('b'):        b = tf.Variable(tf.zeros([10]),name='biases')    with tf.name_scope('W_p_b'):        Wx_plus_b = tf.add(tf.matmul(x, W), b, name='Wx_plus_b')    y = tf.nn.softmax(Wx_plus_b, name='final_result')# 定义损失函数和优化方法with tf.name_scope('loss'):    loss = -tf.reduce_sum(y_ * tf.log(y))with tf.name_scope('train_step'):    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)    print(train_step)# 初始化sess = tf.InteractiveSession()init = tf.global_variables_initializer()sess.run(init)# 训练for step in range(100):    batch_xs,batch_ys =mnist.train.next_batch(100)    train_step.run({x:batch_xs,y_:batch_ys})    # variables = tf.all_variables()    # print(len(variables))    # print(sess.run(b))# 测试模型准确率pre_num=tf.argmax(y,1,output_type='int32',name="output")#输出节点名:outputcorrect_prediction = tf.equal(pre_num,tf.argmax(y_,1,output_type='int32'))accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))a = accuracy.eval({x:mnist.test.images,y_:mnist.test.labels})print('测试正确率:{0}'.format(a))# 保存训练好的模型#形参output_node_names用于指定输出的节点名称,output_node_names=['output']对应pre_num=tf.argmax(y,1,name="output"),output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def,output_node_names=['output'])with tf.gfile.FastGFile('model/mnist.pb', mode='wb') as f:#’wb’中w代表写文件,b代表将数据以二进制方式写入文件。    f.write(output_graph_def.SerializeToString())sess.close()

上面的代码已经将训练模型保存在model/mnist.pb,当然我们可以先在Python中使用该模型进行简单的预测,测试方法如下:

import tensorflow as tfimport numpy as npfrom PIL import Imageimport matplotlib.pyplot as plt#模型路径model_path = 'model/mnist.pb'#测试图片testImage = Image.open("data/test_image.jpg");with tf.Graph().as_default():    output_graph_def = tf.GraphDef()    with open(model_path, "rb") as f:        output_graph_def.ParseFromString(f.read())        tf.import_graph_def(output_graph_def, name="")    with tf.Session() as sess:        tf.global_variables_initializer().run()        # x_test = x_test.reshape(1, 28 * 28)        input_x = sess.graph.get_tensor_by_name("input/x_input:0")        output = sess.graph.get_tensor_by_name("output:0")        #对图片进行测试        testImage=testImage.convert('L')        testImage = testImage.resize((28, 28))        test_input=np.array(testImage)        test_input = test_input.reshape(1, 28 * 28)        pre_num = sess.run(output, feed_dict={input_x: test_input})#利用训练好的模型预测结果        print('模型预测结果为:',pre_num)        #显示测试的图片        # testImage = test_x.reshape(28, 28)        fig = plt.figure(), plt.imshow(testImage,cmap='binary')  # 显示图片        plt.title("prediction result:"+str(pre_num))        plt.show()

二、移植到Android

    相信大家看到很多大神的博客,都是要自己编译TensoFlow的so库和jar包,说实在的,这个过程真TM麻烦,反正我弄了半天都没成功过,然后放弃了……。本博客的移植方法不需要安装Bazel,也不需要构建TensoFlow的so库和jar包,因为Google在TensoFlow github中给我们提供了,为什么不用了!!!

1、下载TensoFlow的jar包和so库

    TensoFlow在Github已经存放了很多开发文件:https://github.com/PanJinquan/tensorflow

   我们需要做的是,下载Android: native libs ,打包下载全部文件,其中有我们需要的libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar,有了这两个文件,剩下的就是在Android Studio配置的问题了

2、Android Studio配置

(1)新建一个Android项目

(2)把训练好的pb文件(mnist.pb)放入Android项目中app/src/main/assets下,若不存在assets目录,右键main->new->Directory,输入assets。

(3)将下载的libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar如下结构放在libs文件夹下

 

(4)app\build.gradle配置

    在defaultConfig中添加

   multiDexEnabled true        ndk {            abiFilters "armeabi-v7a"        }

    增加sourceSets

    sourceSets {        main {            jniLibs.srcDirs = ['libs']        }    }

    在dependencies中增加TensoFlow编译的jar文件libandroid_tensorflow_inference_java.jar:

    compile files('libs/libandroid_tensorflow_inference_java.jar')

   OK了,build.gradle配置完成了,剩下的就是java编程的问题了。

3、模型调用

  在需要调用TensoFlow的地方,加载so库“System.loadLibrary("tensorflow_inference");并”import org.tensorflow.contrib.android.TensorFlowInferenceInterface;就可以使用了

     注意,旧版的TensoFlow,是如下方式进行,该方法可参考大神的博客:https://www.jianshu.com/p/1168384edc1e

TensorFlowInferenceInterface.fillNodeFloat(); //送入输入数据TensorFlowInferenceInterface.runInference();  //进行模型的推理TensorFlowInferenceInterface.readNodeFloat(); //获取输出数据

     但在最新的libandroid_tensorflow_inference_java.jar中,已经没有这些方法了,换为

TensorFlowInferenceInterface.feed()TensorFlowInferenceInterface.run()TensorFlowInferenceInterface.fetch()

     下面是以MNIST手写数字识别为例,其实现方法如下:

package com.example.jinquan.pan.mnist_ensorflow_androiddemo;import android.content.res.AssetManager;import android.graphics.Bitmap;import android.graphics.Color;import android.graphics.Matrix;import android.util.Log;import org.tensorflow.contrib.android.TensorFlowInferenceInterface;public class PredictionTF {    private static final String TAG = "PredictionTF";    //设置模型输入/输出节点的数据维度    private static final int IN_COL = 1;    private static final int IN_ROW = 28*28;    private static final int OUT_COL = 1;    private static final int OUT_ROW = 1;    //模型中输入变量的名称    private static final String inputName = "input/x_input";    //模型中输出变量的名称    private static final String outputName = "output";    TensorFlowInferenceInterface inferenceInterface;    static {        //加载libtensorflow_inference.so库文件        System.loadLibrary("tensorflow_inference");        Log.e(TAG,"libtensorflow_inference.so库加载成功");    }    PredictionTF(AssetManager assetManager, String modePath) {        //初始化TensorFlowInferenceInterface对象        inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);        Log.e(TAG,"TensoFlow模型文件加载成功");    }    /**     *  利用训练好的TensoFlow模型预测结果     * @param bitmap 输入被测试的bitmap图     * @return 返回预测结果,int数组     */    public int[] getPredict(Bitmap bitmap) {        float[] inputdata = bitmapToFloatArray(bitmap,28, 28);//需要将图片缩放带28*28        //将数据feed给tensorflow的输入节点        inferenceInterface.feed(inputName, inputdata, IN_COL, IN_ROW);        //运行tensorflow        String[] outputNames = new String[] {outputName};        inferenceInterface.run(outputNames);        ///获取输出节点的输出信息        int[] outputs = new int[OUT_COL*OUT_ROW]; //用于存储模型的输出数据        inferenceInterface.fetch(outputName, outputs);        return outputs;    }    /**     * 将bitmap转为(按行优先)一个float数组,并且每个像素点都归一化到0~1之间。     * @param bitmap 输入被测试的bitmap图片     * @param rx 将图片缩放到指定的大小(列)->28     * @param ry 将图片缩放到指定的大小(行)->28     * @return   返回归一化后的一维float数组 ->28*28     */    public static float[] bitmapToFloatArray(Bitmap bitmap, int rx, int ry){        int height = bitmap.getHeight();        int width = bitmap.getWidth();        // 计算缩放比例        float scaleWidth = ((float) rx) / width;        float scaleHeight = ((float) ry) / height;        Matrix matrix = new Matrix();        matrix.postScale(scaleWidth, scaleHeight);        bitmap = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);        Log.i(TAG,"bitmap width:"+bitmap.getWidth()+",height:"+bitmap.getHeight());        Log.i(TAG,"bitmap.getConfig():"+bitmap.getConfig());        height = bitmap.getHeight();        width = bitmap.getWidth();        float[] result = new float[height*width];        int k = 0;        //行优先        for(int j = 0;j < height;j++){            for (int i = 0;i < width;i++){                int argb = bitmap.getPixel(i,j);                int r = Color.red(argb);                int g = Color.green(argb);                int b = Color.blue(argb);                int a = Color.alpha(argb);                //由于是灰度图,所以r,g,b分量是相等的。                assert(r==g && g==b);//                Log.i(TAG,i+","+j+" : argb = "+argb+", a="+a+", r="+r+", g="+g+", b="+b);                result[k++] = r / 255.0f;            }        }        return result;    }}
  • 简单说明一下:项目新建了一个PredictionTF类,该类会先加载libtensorflow_inference.so库文件;PredictionTF(AssetManager assetManager, String modePath) 构造方法需要传入AssetManager对象和pb文件的路径;
  •  从资源文件中获取BitMap图片,并传入 getPredict(Bitmap bitmap)方法,该方法首先将BitMap图像缩放到28*28的大小,由于原图是灰度图,我们需要获取灰度图的像素值,并将28*28的像素转存为行向量的一个float数组,并且每个像素点都归一化到0~1之间,这个就是bitmapToFloatArray(Bitmap bitmap, int rx, int ry)方法的作用;
  •  然后将数据feed给tensorflow的输入节点,并运行(run)tensorflow,最后获取(fetch)输出节点的输出信息。

   MainActivity很简单,一个单击事件获取预测结果:

package com.example.jinquan.pan.mnist_ensorflow_androiddemo;import android.graphics.Bitmap;import android.graphics.BitmapFactory;import android.support.v7.app.AppCompatActivity;import android.os.Bundle;import android.util.Log;import android.view.View;import android.widget.ImageView;import android.widget.TextView;public class MainActivity extends AppCompatActivity {    // Used to load the 'native-lib' library on application startup.    static {        System.loadLibrary("native-lib");//可以去掉    }    private static final String TAG = "MainActivity";    private static final String MODEL_FILE = "file:///android_asset/mnist.pb"; //模型存放路径    TextView txt;    TextView tv;    ImageView imageView;    Bitmap bitmap;    PredictionTF preTF;    @Override    protected void onCreate(Bundle savedInstanceState) {        super.onCreate(savedInstanceState);        setContentView(R.layout.activity_main);        // Example of a call to a native method        tv = (TextView) findViewById(R.id.sample_text);        txt=(TextView)findViewById(R.id.txt_id);        imageView =(ImageView)findViewById(R.id.imageView1);        bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.test_image);        imageView.setImageBitmap(bitmap);        preTF =new PredictionTF(getAssets(),MODEL_FILE);//输入模型存放路径,并加载TensoFlow模型    }    public void click01(View v){        String res="预测结果为:";        int[] result= preTF.getPredict(bitmap);        for (int i=0;i

   activity_main布局文件:

<?xml version="1.0" encoding="utf-8"?>        

最后一步,就是run,run,run,效果如下, 

本博客的项目代码都上传到Github:下载地址:https://github.com/PanJinquan/Mnist-tensorFlow-AndroidDemo

 

相关参考资料:https://blog.csdn.net/gzhermit/article/details/73924515

如果你觉得该帖子帮到你,还望贵人多多支持,鄙人会再接再厉,继续努力的~

 

更多相关文章

  1. Android(安卓)Studio添加so文件并打包到APK的lib文件夹中
  2. android studio 设置模板
  3. Android偶遇问题处理
  4. build/envsetup.sh简记
  5. 导入android工程没有R文件的解决办法
  6. Mac 配置Android(安卓)studio 环境变量
  7. android线程模型文章集合
  8. Android(安卓)开发技巧之Log写入文件
  9. Android(安卓)数据持久化方式

随机推荐

  1. Android实现拖拽GridView到目标View
  2. Android(安卓)AlertDialog工具类 实现对
  3. android studio 修改包名
  4. Attacking memory problems on Android
  5. Android(安卓)IjkPlayer 视频播放器
  6. How to enable Android(安卓)ConfigFS ga
  7. Android之调用js常见错误
  8. Android实时获得经纬度,发送给c++服务端
  9. android位置服务
  10. Android(安卓)图片的缓存机制分析