一. tensorflow和android的桥接库

源代码地址:

https://github.com/tensorflow/tensorflow

使用git下载下来:

git clone https://github.com/tensorflow/tensorflow.git

配置下编译环境:

vi WORKSPACE

找到下面两个配置,取消掉原有的注释,并且将本地路径填上去:

android_sdk_repository(

name = "androidsdk",

path = "---> here is your android sdk path <---",

)

android_ndk_repository(

name="androidndk",

path="---> here is your android ndk path <---",

api_level=14)

更改好后保存。 

从上面看,你要至少有android sdk,ndk的环境,没有的话自己百度下。

接着编译android对应的jar和so文件。

编译jar包:

bazel build //tensorflow/contrib/android:android_tensorflow_inference_java

产出物路径:

bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar

编译so文件:

bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so--crosstool_top=//external:android/crosstool--host_crosstool_top=@bazel_tools//tools/cpp:toolchain--cpu=armeabi-v7a

产出物路径:

bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so

其中没有bazel工具的,可以通过brew安装下:

brew install bazel

最终的产出物,就是android和机器学习模型的一个桥梁,通过它们可以在android上加载机器学习的模型并且针对特定输入给予推荐的输出。

那么机器学习模型怎么获取呢?


二. 训练模型

先配置一下训练的环境。

python,可以下载2.7.13

https://www.python.org/downloads/release/python-2713/

直接包安装,记得刷新下环境变量。

source ~/.base_profile

tensorflow的python包,选择Mac CPU-only Python 2,下载好的文件假定是download_file

https://github.com/tensorflow/tensorflow

sudo pip install download_file

另外安装一些有用的包:

numpy,quandl,matplotlib

这里有numpy的介绍,它是python的一个计算库,提供丰富的数据结构:

https://docs.scipy.org/doc/numpy-dev/user/quickstart.html

环境都配置好了,我们以一个实际的例子,开始机器学习。


真实数据


上面的图形,展示了一组真实的数据,[x, y],我们希望通过已知的这些数据,预测当x=35时y的值。 (上面的数据,实际上是 y = 3 * x + 2 )

那么这个问题,怎么用机器学习的思路解决呢?

1. 需要建立一个模型,用于推测这些数据的规律。

如上面的数据,看起来是很简单的线性模型,我们建立模型如下:

y = a * x + b

2. 模型结构图

tensorflow提供了一个建立模型关系的框架,上述的数学模型可以分解为多个被称为tensor的节点,比如a, x,甚至乘法 *,这些形成一个单向图,如下:


模型图谱


上图把一个数学模型,分解成了由多个节点组成的单向图。

其中a和b,属于不确定的,我们希望通过数据训练得到的参数,在tensorflow中称为Variable;

而x和y,属于用于训练的原始数据,在tensorflow中称为placeholder,也就是占位符;

训练过程,就是持续不断的给这些占位符输入正确的数据,使得模型可以不断改进参数来尽量接近这些正确数据,最终训练的结果就是改进好的参数a和b; 这个过程,就像流水一样(flow),一遍遍的冲刷这个由tensor组成的模型图谱, 

3. 编程实现和训练

运行下面的py文件

#!/usr/bin/python

# -*- coding: utf-8 -*-

# 导入tensorflow和numpy

import tensorflow as tf

import numpy as np

# 准备训练数据,随机模拟[x, y]

tx = np.random.ranf(20)

ty = tx * 3 + 2

# 按照模型图谱,建立相应的节点

x = tf.placeholder("float", name='input')

y = tf.placeholder("float", name='output')

a = tf.Variable(0., name='a')

b = tf.Variable(0., name='b')

y_pred = tf.add(a * x,b, name='y_pred')

# 定义对训练效果的度量

error = tf.reduce_sum(tf.square(y - y_pred))

train = tf.train.GradientDescentOptimizer(0.03).minimize(error)

# 定义训练的最小偏差

min_loss = tf.constant(1e-3)

# 启动session

session = tf.Session()

session.run(tf.global_variables_initializer())

# 模拟训练

for i in range(500):

session.run(train, feed_dict={x : tx, y : ty})

a_value, b_value, error_value, is_ok = session.run([a, b, error, error < min_loss], feed_dict={x : tx, y : ty})

print a_value, b_value, error_value

if is_ok:

break

# 打印训练好的参数

result = session.run([a, b])

print result

# 保存模型和参数,文件为pb

output_graph_def = tf.graph_util.convert_variables_to_constants(session, session.graph_def,output_node_names=['input', 'output', 'a', 'b','y_pred'])

f = tf.gfile.FastGFile('fuck.pb', mode='wb')

f.write(output_graph_def.SerializeToString())

f.close()

# 终止会话

session.close()

训练结果展示:

2.61073 4.31385 89.048

1.30782 1.78556 29.8708

2.19754 3.12242 10.8184

1.83855 2.28745 4.55673

2.17195 2.68346 2.39239

2.10437 2.39156 1.55744

2.25308 2.49305 1.1679

2.27088 2.37789 0.93927

2.35422 2.38957 0.778111

2.39303 2.33406 0.652181

2.44999 2.3204 0.549087

2.49031 2.2868 0.463075

2.53413 2.2678 0.390786

2.57058 2.24364 0.329862

2.60621 2.22522 0.278461

2.63775 2.20617 0.235078

2.6674 2.18987 0.198456

2.69427 2.17421 0.167541

2.71916 2.1602 0.141441

2.74192 2.14712 0.119408

2.7629 2.13522 0.100807

2.78213 2.12422 0.0851035

2.79983 2.11415 0.0718459

2.81607 2.10487 0.060654

2.83101 2.09636 0.0512053

2.84473 2.08854 0.0432287

2.85733 2.08135 0.0364945

2.86892 2.07474 0.0308094

2.87956 2.06868 0.0260099

2.88934 2.0631 0.0219582

2.89832 2.05798 0.0185376

2.90657 2.05327 0.0156498

2.91416 2.04895 0.013212

2.92113 2.04497 0.0111538

2.92753 2.04132 0.0094163

2.93341 2.03797 0.00794944

2.93882 2.03488 0.00671111

2.94379 2.03205 0.00566565

2.94835 2.02945 0.00478304

2.95254 2.02706 0.00403797

2.9564 2.02486 0.00340894

2.95994 2.02284 0.00287788

2.96319 2.02099 0.00242957

2.96618 2.01929 0.00205113

2.96892 2.01772 0.0017316

2.97145 2.01628 0.00146184

2.97376 2.01496 0.00123412

2.97589 2.01375 0.00104187

2.97785 2.01263 0.000879567

[2.9778514, 2.012629]

上述程序运行后,会得到*.pb的文件,这个文件就是最终的模型和训练参数结果。

把这个结果保存下来,我们开始在android上,应用这个结果。


三. android上的模型应用

现在,我们已经得到了android桥接tensorflow的库(jar和so),还有一个训练好的模型(pb)。

新建一个android工程,将上述结果分别放置到:

1. jar包:   app/libs

2. so文件:app/src/jniLibs/armeabi/

3. pb文件:app/src/main/assets

然后在要使用的类中,加入对so的加载:

static {

    System.loadLibrary("tensorflow_inference");

}

最后导入TensorFlowInferenceInterface

mTensorFlowInterface = new TensorFlowInferenceInterface(getAssets(),"file:///android_asset/fuck.pb");

这样,相关的环境都ok了。

在使用pb前,我们回顾下,在上面的训练中导出的节点:

input(placeholder,需要提供数据)

output(placeholder,需要提供数据)

a, b(Variable,是训练好的参数或者需要训练的参数,这里可以作为输出数据)

y_pred(Operation,单向图节点,可以获取,可以作为输出)

开始三部曲:feed,run,fetch(提供数据,运行,获取数据)

1. 提供数据

mTensorFlowInterface.feed("input",new float[]{10.0f});

2. 运行目标节点

mTensorFlowInterface.run(new String[]{"y_pred"},false);

3. 获取结果

float[]result = new float[1];

mTensorFlowInterface.fetch("y_pred",result);

程序运行后展示:


demo展示

完整的demo如下:

链接: https://pan.baidu.com/s/1pLgEPMr 密码: 359f

更多相关文章

  1. android 通话记录去重查询方法
  2. Android(安卓)GreenDao数据库使用
  3. Android新特性-RecyclerView之基础篇
  4. Android黑群出品:SQLite数据库的使用和升级
  5. Android本地存储——SQLite数据库
  6. Android(安卓)Studio提高效率插件---adb idea
  7. Android中用onSaveInstanceState保存Fragment状态的方法
  8. Android记事本项目开发
  9. [Android(安卓)SQLite]数据存储与访问 - 内部存储

随机推荐

  1. 带你解决PHP界面显示中文乱码的问题
  2. PHP中return用法解读
  3. 两小时学会用php做网站购物车
  4. PHP之扩展Memcached命令用法实例总结
  5. 带你深入了解php与C语言的区别
  6. PHP之array_unique实现二维数组去重
  7. 通过实例解析PHP数据类型转换方法
  8. 详细解说三种PHP嵌套HTML的写法
  9. 谈谈PHP运算符"::"、"->"和"=>"的区别
  10. 十大你需要在PHP中避免的坑