【懒懒的Tensorflow学习笔记三之搭建简单的神经网络模型】
16lz
2021-01-22
用Tensorflow 实现了一个简单的三层神经网络结构,来拟合Y=X²,激励函数采用Relu,采用误差反向传播更新权值,并用可视化的形式展示处理,具体代码如下:
# coding=utf-8 import numpy as np import tensorflow as tf import matplotlib.pyplot as plt # 构建添加层的函数和一个简单的神经网络 def add_layer(inputs, in_size, out_size, activation_function=None): ''' :param inputs:输入值 :param in_size:输入值的维度 :param out_size:输出装的维度 :param activation_function:激活函数 :return: ''' # 初始化变量的时候采用随机值比全0值好 Weight = tf.Variable(tf.random_uniform([in_size, out_size])) Bias = tf.Variable(tf.zeros([1, out_size]) + 0.1) # 网络输出值 Wx_plus_b = tf.matmul(inputs, Weight) + Bias # 选择激活函数 if activation_function is None: output = Wx_plus_b else: output = activation_function(Wx_plus_b) return output # 构建一个简单的神经网络模型 # 创建训练数据集 x_data = np.linspace(-1, 1, 200, dtype=np.float32)[:, np.newaxis] # 添加噪声 noise = np.random.normal(0, 0.05, x_data.shape).astype(np.float32) y_data = np.square(x_data) + noise # 定义输入和输出 # placehoolder()表示占位符,None表示多少个都可以,1表示数据特征维度为1 inputs = tf.placeholder(tf.float32, [None, 1]) outputs = tf.placeholder(tf.float32, [None, 1]) # 搭建网络 # 定义隐含层 # inputs表示输入数据,1表示输入数据特征维度为1,10表示隐含层神经元个数为10,激活函数使用Tensorflow自带的激活函数relu hidden_outputs = add_layer(inputs, 1, 10, activation_function=tf.nn.relu) # 定义输出层 # hidden_outputs是隐含层的输出,10是隐含层输出特征的个数,1表示输出值的维度 prediction = add_layer(hidden_outputs, 10, 1, activation_function=None) # 定义损失函数 reduction_indices=[1]表示按行求和 loss = tf.reduce_mean(tf.reduce_sum(tf.square(prediction - outputs), reduction_indices=[1])) # 优化器 train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 变量初始化 init = tf.global_variables_initializer() # 开始训练 sess = tf.Session() sess.run(init) # 输出散点图 fig = plt.figure() ax = fig.add_subplot(1, 1, 1) ax.scatter(x_data, y_data, color='b') plt.ion() # 用于连续显示 plt.show() for i in range(1000): sess.run(train_step, feed_dict={inputs: x_data, outputs: y_data}) if i % 50 ==0: try: ax.lines.remove(lines[0]) except: pass prediction_value = sess.run(prediction, feed_dict={inputs: x_data}) print(sess.run(loss, feed_dict={inputs: x_data, outputs: y_data})) lines = ax.plot(x_data, prediction_value, 'r-', lw=2) plt.pause(0.1)
结果如图:
更多相关文章
- Python3语法——Python3函数参数的各种形式
- python 函数、参数及参数解构
- python函数小练习
- Python学习总结-(15)---返回函数和闭包初步理解
- 初识python:高阶函数(附-高阶函数)
- Python学习札记(二十六) 函数式编程7 修饰器
- Python:运算类内建函数列举
- 学习笔记(11月02日)--高阶函数
- python:inspect函数自动生成函数名