对上一篇博客中代码略做修改,在训练完成之后进行模型导出操作

# y = x^2 + 1import tensorflow as tfimport numpy as npimport randomdef get_batch(size=128):    xs = []    ys = []    for i in range(size):        x = random.random() * 2        y = x * x + 1        xs.append(x)        ys.append(y)    return np.array(xs), np.array(ys)X = tf.placeholder(tf.float32, [None,1], name='input')Y = tf.placeholder(tf.float32, [None,1])def my_dnn():    x = tf.reshape(X, shape=[-1, 1])    w1 = tf.Variable(tf.random_normal(shape=[1,256], mean=0.0,                                      stddev=1))    b1 = tf.Variable(tf.random_normal([256]))    out1 = tf.nn.bias_add(tf.matmul(x,w1),b1)    out1 = tf.nn.relu(out1)    w2= tf.Variable(tf.random_normal(shape=[256,256]))    b2 = tf.Variable(tf.random_normal([256]))    out2= tf.nn.bias_add(tf.matmul(out1, w2),b2)    out2 = tf.nn.relu(out2)    w3 = tf.Variable(tf.random_normal(shape=[256, 256]))    b3 = tf.Variable(tf.random_normal([256]))    out3 = tf.nn.bias_add(tf.matmul(out2, w3),b3)    out3 = tf.nn.relu(out3)    w4 = tf.Variable(tf.random_normal(shape=[256, 1]))    b4 = tf.Variable(tf.random_normal([1]))    out4 = tf.nn.bias_add(tf.matmul(out3, w4), b4, name='output')    return out4def train():    out = my_dnn()    loss = tf.reduce_mean(tf.square(Y - out))    optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)    saver = tf.train.Saver()    with tf.Session() as sess:        sess.run(tf.initialize_all_variables())        step = 0        while True:            batch_x, batch_y = get_batch(64)            batch_x = batch_x.reshape([-1, 1])            batch_y = batch_y.reshape([-1, 1])            _, loss_ = sess.run([optimizer, loss], feed_dict={X:batch_x, Y:batch_y})            print(loss_)            if loss_ < 0.0001:                saver.save(sess, "./1.model", global_step=step)                break            step += 1# train()def eval():    out = my_dnn()    saver = tf.train.Saver()    with tf.Session() as sess:        saver.restore(sess, tf.train.latest_checkpoint('.'))        for i in range(100):            x = random.random() * 2            x = np.array([x]).reshape([-1,1])            y = sess.run(out, feed_dict={X:x})            print("x=%.5f 正确的y=%.5f 预测的 y=%.5f" % (x, x*x + 1, y))def exportModel():    out = my_dnn()    saver = tf.train.Saver()    with tf.Session() as sess:        # 恢复模型参数        saver.restore(sess, tf.train.latest_checkpoint('.'))        from tensorflow.python.framework.graph_util import convert_variables_to_constants        output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output'])        with tf.gfile.FastGFile('1.pb', mode='wb') as f:            f.write(output_graph_def.SerializeToString())if __name__ == '__main__':    # 训练    # train()    # 评估    # eval()    # 导出模型    exportModel()

新建一个Android项目 

导入tensorflow-mobile的库 
可以选择导在线的库,在这里导入离线的库 
https://jcenter.bintray.com/org/tensorflow/tensorflow-android/ 
 
我添加1.6.0版本的,修改了gradle文件,完成了添加

需要注意:
compile(name: ‘tensorflow-android-1.6.0-rc0’, ext: ‘aar’) 
“name: “:后要有空格 
“ext: ” :后要有空格

添加模型文件 


编写tensorflow mobile API的封装 

最后在Activity调就可以了 

给出完整项目的GitHub链接: 
https://github.com/imu-hupeng/TensorflowTestApp

更多相关文章

  1. android Log4j学习笔记
  2. android9.0解决http获取异常
  3. Android理解Fragment生命周期,fragment和fragmentactivity解析
  4. android 8.1Settings添加设置项
  5. ECLIPSE android 布局页面文件出错故障排除Exception raised dur
  6. tensorflow和android零接触 (mac)
  7. Android(安卓)GreenDao数据库使用
  8. 如何在Android系统中添加系统服务(以PowerManager为例)
  9. Android(安卓)编辑联系人,增、删、改代码

随机推荐

  1. 音频采集(AudioRecorder)
  2. Android中xml解析--实现软件升级功能
  3. android 音量设置条
  4. Android(安卓)对Layout_weight属性完全解
  5. Android(安卓)OTA 升级之五:updater
  6. android studio 错误:“Gradle sync faile
  7. Android底部菜单的实现
  8. Android强大的数据库开源框架LitePal
  9. Android应用接入第三方登录之新浪登录
  10. ios、android 系统字体说明