Python+Android进行TensorFlow开发_第1张图片

Tensorflow是Google开源的一套机器学习框架,支持GPU、CPU、Android等多种计算平台。本文将介绍在Tensorflow在Android上的使用。

Android使用Tensorflow框架需要引入两个文件libtensorflow_inference.so、libandroid_tensorflow_inference_java.jar。这两个文件可以使用官方预编译的文件。如果预编译的so不满足要求(比如不支持训练模型中的某些操作符运算),也可以自己通过bazel编译生成这两个文件。
将libandroid_tensorflow_inference_java.jar放在app下的libs目录下,so文件命名为libtensorflow_jni.so放在src/main/jniLibs目录下对应的ABI文件夹下。目录结构如下:

Python+Android进行TensorFlow开发_第2张图片

Android目录结构

同时在app的build.gradle中的dependencies模块下添加如下配置:

    

dependencies {
    ...
    compile files('libs/libandroid_tensorflow_inference_java.jar')
    ...
}

使用tensorflow框架进行机器学习分为四个步骤:

  • 构造神经网络

  • 训练神经网络模型

  • 将训练好的模型输出为pb文件

  • 在Android上加载pb模型进行计算

前三步是模型的构造,我们通过python实现,下面给出了一个二分类的简单模型的构造过程,首先是训练过程:

    

# -*-coding:utf-8 -*-
from __future__ import print_function
import os
import tensorflow as tf
from numpy.random import RandomState

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

"""
训练模型
"""

def train():
    # 定义训练数据集batch大小为8
    batch_size = 8

    # 定义神经网络参数,参数体现出神经网络结构,一个输入层,一个输出层,一个隐藏层
    w1 = tf.Variable(tf.random_normal([23], stddev=1, seed=1), name="w1_val")
    w2 = tf.Variable(tf.random_normal([31], stddev=1, seed=1), name="w2_val")

    # 定义输入输出格式
    x = tf.placeholder(tf.float32, shape=(None2), name='x_input')
    y_ = tf.placeholder(tf.float32, shape=(None1))

    # 定义神经网络前向传播过程
    a = tf.matmul(x, w1)
    y = tf.matmul(a, w2, name="cal_node")

    # 定义交叉熵和反向传播算法
    cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-101.0)))
    train_step = tf.train.AdadeltaOptimizer(0.001).minimize(cross_entropy)

    # 生成随机训练集
    rdm = RandomState(1)
    dataset_size = 128

    # 定义映射关系
    X = rdm.rand(dataset_size, 2)
    Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]

    with tf.Session() as sess:
        # 初始化所有参数
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # print sess.run(w1)
        # print sess.run(w2)

        STEPS = 500
        for i in range(STEPS):
            start = (i * batch_size) % dataset_size
            end = min(start + batch_size, dataset_size)

            # 训练神经网络,更新神经网络参数
            sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})

            if i % 100 == 0:
                total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
                print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy))

            print(sess.run(w1))
            print(sess.run(w2))

        # 保存check point
        saver = tf.train.Saver(tf.trainable_variables())
        saver.save(sess, './model/checpt')

上面的代码首先定义神经网络,初始化训练数据,进行500次训练过程,并将训练结果checkpoints保存到model文件夹下,checkpoints包含了训练模型得到的参数信息,共生成四个相关的文件,如下图:

Python+Android进行TensorFlow开发_第3张图片

checkpoint相关文件

由于checkpoint文件众多,为了方便使用,我们通过下面的代码将它们生成一个pb文件,在android上只需要这个pb文件即可使用这个训练好的模型:

    

"""
存储pb模型
"""

def dump_graph_to_pb(pb_path):
    with tf.Session() as sess:
        check_point = tf.train.get_checkpoint_state("./model/")
        if check_point:
            saver = tf.train.import_meta_graph(check_point.model_checkpoint_path + '.meta')
            saver.restore(sess, check_point.model_checkpoint_path)
        else:
            raise ValueError("Model load failed from {}".format(check_point.model_checkpoint_path))

        graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), "cal_node".split(","))

        with tf.gfile.GFile(pb_path, "wb"as f:
            f.write(graph_def.SerializeToString())

拿到生成的pb模型,我们可以在android上使用了。将pb文件在这main/assets下:

Python+Android进行TensorFlow开发_第4张图片

接下来就可以载入pb,进行计算了:

    

public class MainActivity extends AppCompatActivity {
    private Graph graph_;
    private Session session_;
    private AssetManager assetManager;

    private static ExecutorService executorService;
    private static Handler handler;
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        executorService = Executors.newFixedThreadPool(5);

        // 初始化tensorflow
        initTensorFlow("outmodel.pb");

        // 使用tensorflow进行计算
        runTensorFlow();
    }
    ...
}

通过如下方式载入pb模型,初始化tensorflow:

    

private boolean initTensorFlow(String modelFile) {
        assetManager = getAssets();
        // 新建Graph
        graph_ = new Graph();

        InputStream is = null;
        try {
            // 读取Assets pb文件
            is = assetManager.open(modelFile);
        } catch (IOException e) {
            e.printStackTrace();
            return false;
        }

        try {
            // 加载pb到Graph
            TensorUtil.loadGraph(is, graph_);
            is.close();
        } catch (IOException e) {
            e.printStackTrace();
            return false;
        }
        // 初始化session
        session_ = new Session(graph_);
        if (session_ == null) {
            return false;
        }

        return true;
    }

然后就可以使用tensorflow API进行运算了:

    

private void runTensorFlow() {
        executorService.execute(generatePredictRunnable(handler));
    }

    private Runnable generatePredictRunnable(Handler handler) {
        return new Runnable() {
            @Override
            public void run() {
                float[][] input = new float[1][2];

                input[0][0] = 1;
                input[0][1] = 2;

                // 定义输入tensor
                Tensor inputTensor = Tensor.create(input);

                // 指定输入,输出节点,运行并得到结果
                Tensor resultTensor = session_.runner()
                        .feed("x_input", inputTensor)
                        .fetch("cal_node")
                        .run()
                        .get(0);

                float[][] dst = new float[1][1];
                resultTensor.copyTo(dst);

                // 处理结果
                ArrayList<Float> resultList = new ArrayList<>();
                for (float val : dst[0]) {
                    if (val != 0) {
                        resultList.add(val);
                    } else {
                        break;
                    }
                }
            }
        };
    }

上面就是通过python训练机器学习模型,并在android平台进行调用的完整流程。

原创作者:JackMeGo,原文链接:https://www.jianshu.com/p/eef4ab014a12

PS: 我之前写的书《Android音视频开发》,断货厉害,最近都没有去宣传,目前当当网还有200多册,感兴趣可以去当当网入手。源码和勘误地址:https://github.com/hejunlin2013/AVBookCode


640?wx_fmt=png

突围单一技术孤岛,与近300位一见如故的球友同行,来鱼哥的知识星球,如果你今天加入了,请在微信评论区告诉我你加入的原因,你对知识星球的期待,你自己通过加入知识星球实现的愿望,这些反馈会帮助我更好理解你,帮助你。即刻加入,请点击【阅读原文】。

Python+Android进行TensorFlow开发_第5张图片

更多相关文章

  1. Android中彩信文件的读取
  2. Android 上传图片到服务器(多文件上传)
  3. Android 系统文件简介
  4. Android 保存数据到文件
  5. Android 使用FTP上传文件
  6. Android将需要的日志文件LOG记录到本地文件夹下指定的文件

随机推荐

  1. Android模块化编译
  2. Android(安卓)zygote的分裂总结
  3. Android开发之使用MediaRecorder录制声音
  4. MT6573 android 2.3系统默认语言处理流程
  5. Android开发者指南-摄像头-Camera
  6. App应用唯一标示码
  7. Android开发工具——ADB(Android(安卓)De
  8. Android(安卓)SDK下, 如何在程序中输出日
  9. Android(安卓)Matrix 超级详解
  10. android获取已安装应用中的系统应用程序