Tensorflow:Android调用Tensorflow Mobile版本API(2)-基于Android的调用
16lz
2021-01-26
对上一篇博客中代码略做修改,在训练完成之后进行模型导出操作
# y = x^2 + 1import tensorflow as tfimport numpy as npimport randomdef get_batch(size=128): xs = [] ys = [] for i in range(size): x = random.random() * 2 y = x * x + 1 xs.append(x) ys.append(y) return np.array(xs), np.array(ys)X = tf.placeholder(tf.float32, [None,1], name='input')Y = tf.placeholder(tf.float32, [None,1])def my_dnn(): x = tf.reshape(X, shape=[-1, 1]) w1 = tf.Variable(tf.random_normal(shape=[1,256], mean=0.0, stddev=1)) b1 = tf.Variable(tf.random_normal([256])) out1 = tf.nn.bias_add(tf.matmul(x,w1),b1) out1 = tf.nn.relu(out1) w2= tf.Variable(tf.random_normal(shape=[256,256])) b2 = tf.Variable(tf.random_normal([256])) out2= tf.nn.bias_add(tf.matmul(out1, w2),b2) out2 = tf.nn.relu(out2) w3 = tf.Variable(tf.random_normal(shape=[256, 256])) b3 = tf.Variable(tf.random_normal([256])) out3 = tf.nn.bias_add(tf.matmul(out2, w3),b3) out3 = tf.nn.relu(out3) w4 = tf.Variable(tf.random_normal(shape=[256, 1])) b4 = tf.Variable(tf.random_normal([1])) out4 = tf.nn.bias_add(tf.matmul(out3, w4), b4, name='output') return out4def train(): out = my_dnn() loss = tf.reduce_mean(tf.square(Y - out)) optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss) saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.initialize_all_variables()) step = 0 while True: batch_x, batch_y = get_batch(64) batch_x = batch_x.reshape([-1, 1]) batch_y = batch_y.reshape([-1, 1]) _, loss_ = sess.run([optimizer, loss], feed_dict={X:batch_x, Y:batch_y}) print(loss_) if loss_ < 0.0001: saver.save(sess, "./1.model", global_step=step) break step += 1# train()def eval(): out = my_dnn() saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess, tf.train.latest_checkpoint('.')) for i in range(100): x = random.random() * 2 x = np.array([x]).reshape([-1,1]) y = sess.run(out, feed_dict={X:x}) print("x=%.5f 正确的y=%.5f 预测的 y=%.5f" % (x, x*x + 1, y))def exportModel(): out = my_dnn() saver = tf.train.Saver() with tf.Session() as sess: # 恢复模型参数 saver.restore(sess, tf.train.latest_checkpoint('.')) from tensorflow.python.framework.graph_util import convert_variables_to_constants output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output']) with tf.gfile.FastGFile('1.pb', mode='wb') as f: f.write(output_graph_def.SerializeToString())if __name__ == '__main__': # 训练 # train() # 评估 # eval() # 导出模型 exportModel()
新建一个Android项目
导入tensorflow-mobile的库
可以选择导在线的库,在这里导入离线的库
https://jcenter.bintray.com/org/tensorflow/tensorflow-android/
我添加1.6.0版本的,修改了gradle文件,完成了添加
需要注意:
compile(name: ‘tensorflow-android-1.6.0-rc0’, ext: ‘aar’)
“name: “:后要有空格
“ext: ” :后要有空格
添加模型文件
编写tensorflow mobile API的封装
最后在Activity调就可以了
给出完整项目的GitHub链接:
https://github.com/imu-hupeng/TensorflowTestApp
更多相关文章
- android Log4j学习笔记
- android9.0解决http获取异常
- Android理解Fragment生命周期,fragment和fragmentactivity解析
- android 8.1Settings添加设置项
- ECLIPSE android 布局页面文件出错故障排除Exception raised dur
- tensorflow和android零接触 (mac)
- Android(安卓)GreenDao数据库使用
- 如何在Android系统中添加系统服务(以PowerManager为例)
- Android(安卓)编辑联系人,增、删、改代码