Intro to Neural Networks — 13: Implementing Gradient Descent
Published: 2019-06-19


Implementing gradient descent

We now know how to update our weights.
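For a single output unit, the update rule takes roughly this form (a sketch, with $\eta$ the learning rate, $f$ the activation function, and $h = \sum_i w_i x_i$, matching the code later in this post):

$$
\delta = (y - \hat{y})\, f'(h), \qquad \Delta w_i = \eta\, \delta\, x_i
$$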

What you've seen so far is how to implement a single update. How do we turn that code into something that computes many weight updates, so that our network can actually learn?

As an example, let's take graduate school admissions data and train a network on it with gradient descent. The data (binary.csv, reproduced below) has three input features: GRE score, GPA, and the rank of the undergraduate school (on a scale of 1 to 4), where rank 1 is the best and rank 4 the worst.


Our goal is to predict whether a student will be admitted to graduate school based on these features. Here we'll use a network with a single output layer, with the sigmoid as the activation function.
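Concretely, with one sigmoid output unit the prediction is the weighted sum of the inputs passed through the sigmoid (a sketch of the forward pass):

$$
\hat{y} = \sigma\!\Big(\sum_i w_i x_i\Big), \qquad \sigma(h) = \frac{1}{1 + e^{-h}}
$$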

Data cleanup

You might think there are three input units here, but we actually need to transform the data first. rank is a categorical feature: the numbers don't express any relative value. Rank 2 is not twice as good as rank 1, and rank 3 is not 1.5 times rank 2. So we need to encode rank with dummy variables, splitting it into four new columns filled with 0s and 1s. Rows with rank 1 get a 1 in the rank_1 column and 0s in the other three; rows with rank 2 get a 1 in the rank_2 column and 0s in the other three, and so on. A minimal sketch of this encoding is shown below.
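The sketch uses pandas with made-up toy ranks; the real preprocessing is in the data_prep.py file further down:

import pandas as pd

# Toy example: encode a categorical "rank" column as 0/1 dummy columns
df = pd.DataFrame({'rank': [1, 2, 3, 2, 4]})
dummies = pd.get_dummies(df['rank'], prefix='rank')   # columns rank_1 .. rank_4
df = pd.concat([df, dummies], axis=1).drop('rank', axis=1)
print(df)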

We also need to standardize the GRE and GPA data, that is, scale them so they have a mean of zero and a standard deviation of 1. This is necessary because the sigmoid function squashes really large and really small inputs: the gradient for those inputs is zero, which means the gradient descent step also goes to zero. Since the GRE and GPA values are fairly large, we would have to be very careful about how we initialize the weights or the gradient descent steps would die out and the network wouldn't train. If we standardize the data instead, initializing the weights becomes much easier.
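A minimal sketch of that standardization step (toy GRE values; data_prep.py below applies the same idea to both gre and gpa):

import pandas as pd

data = pd.DataFrame({'gre': [380.0, 660.0, 800.0, 640.0, 520.0]})  # toy values
mean, std = data['gre'].mean(), data['gre'].std()
data['gre'] = (data['gre'] - mean) / std   # now zero mean, unit standard deviation
print(data)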

This is just a brief introduction; you'll learn more about preprocessing data later. If you want to see how I did it, check out the data_prep.py file in the programming exercise below.


Ten rows of the data after the transformation

Now that the data is ready, we can see there are six input features: gre, gpa, and the four rank dummy variables.

Mean square error

Here we'll make a small change to how we calculate the error. Instead of the SSE, we'll use the mean of the squared errors (MSE). Now that we're working with a lot of data, summing up all the weight steps would lead to really large updates and make the gradient descent fail to converge. To compensate, you'd need a very small learning rate. Instead, we can just divide by the number of data points, m, to take the average. That way, no matter how much data we have, our learning rate will typically fall somewhere between 0.01 and 0.001. We then use the MSE (shown below) to calculate the gradient, and the result is the same as before, just averaged instead of summed.
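Written out, the error being averaged is roughly the following (a sketch; the factor of 1/2 is a common convention that cancels when taking the derivative, and m is the number of records):

$$
E = \frac{1}{2m} \sum_{\mu} \left(y^{\mu} - \hat{y}^{\mu}\right)^{2}
$$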

import numpy as np
from data_prep import features, targets, features_test, targets_test

def sigmoid(x):
    """
    Calculate sigmoid
    """
    return 1 / (1 + np.exp(-x))

# TODO: We haven't provided the sigmoid_prime function like we did in
#       the previous lesson to encourage you to come up with a more
#       efficient solution. If you need a hint, check out the comments
#       in solution.py from the previous lecture.

# Use the same seed to make debugging easier
np.random.seed(42)

n_records, n_features = features.shape
last_loss = None

# Initialize weights
weights = np.random.normal(scale=1 / n_features**.5, size=n_features)

# Neural Network hyperparameters
epochs = 1000
learnrate = 0.5

for e in range(epochs):
    del_w = np.zeros(weights.shape)
    for x, y in zip(features.values, targets):
        # Loop through all records, x is the input, y is the target

        # Activation of the output unit
        #   Notice we multiply the inputs and the weights here
        #   rather than storing h as a separate variable
        output = sigmoid(np.dot(x, weights))

        # The error, the target minus the network output
        error = y - output

        # The error term
        #   Notice we calculate f'(h) here instead of defining a separate
        #   sigmoid_prime function. This just makes it faster because we
        #   can re-use the result of the sigmoid function stored in
        #   the output variable
        error_term = error * output * (1 - output)

        # The gradient descent step, the error times the gradient times the inputs
        del_w += error_term * x

    # Update the weights here. The learning rate times the
    # change in weights, divided by the number of records to average
    weights += learnrate * del_w / n_records

    # Printing out the mean square error on the training set
    if e % (epochs / 10) == 0:
        out = sigmoid(np.dot(features, weights))
        loss = np.mean((out - targets) ** 2)
        if last_loss and last_loss < loss:
            print("Train loss: ", loss, "  WARNING - Loss Increasing")
        else:
            print("Train loss: ", loss)
        last_loss = loss

# Calculate accuracy on test data
test_out = sigmoid(np.dot(features_test, weights))
predictions = test_out > 0.5
accuracy = np.mean(predictions == targets_test)
print("Prediction accuracy: {:.3f}".format(accuracy))
# data_prep.py
import numpy as np
import pandas as pd

admissions = pd.read_csv('binary.csv')

# Make dummy variables for rank
data = pd.concat([admissions, pd.get_dummies(admissions['rank'], prefix='rank')], axis=1)
data = data.drop('rank', axis=1)

# Standardize features
for field in ['gre', 'gpa']:
    mean, std = data[field].mean(), data[field].std()
    data.loc[:, field] = (data[field] - mean) / std

# Split off random 10% of the data for testing
np.random.seed(42)
sample = np.random.choice(data.index, size=int(len(data)*0.9), replace=False)
data, test_data = data.loc[sample], data.drop(sample)

# Split into features and targets
features, targets = data.drop('admit', axis=1), data['admit']
features_test, targets_test = test_data.drop('admit', axis=1), test_data['admit']
binary.csv (first 10 rows shown; the remaining rows are omitted here):

admit,gre,gpa,rank
0,380,3.61,3
1,660,3.67,3
1,800,4,1
1,640,3.19,4
0,520,2.93,4
1,760,3,2
1,560,2.98,1
0,400,3.08,2
1,540,3.39,3
0,700,3.92,2
...

Reposted from: https://www.cnblogs.com/fuhang/p/8962065.html
