首页
视频
资源
登录
原
python 梯度下降和反向传播(下)(学习笔记)
1569
人阅读
2023/4/14 14:56
总访问:
2534996
评论:
0
收藏:
0
手机
分类:
AI
![](https://img.tnblog.net/arcimg/hb/55a6e3fdaa9846cb81829fd20288e216.jpg) >#python 梯度下降和反向传播(下)(学习笔记) [TOC] 前言 ------------ ![](https://img.tnblog.net/arcimg/hb/487c69c122a74c2e89d58f64c6d635c3.png) tn2>回到我们以前讲的有一个激活函数b,这是用来干嘛的呢? 举个例子:有时候豆豆的毒性会随着体积越大而毒性越小,那么此时b所代表的是y毒性自由的移动,w则是毒性。 ![](https://img.tnblog.net/arcimg/hb/e0afb0cf705f49b58e4fe8b4b33993bf.png) tn2>如果b为0的时候代价函数就是前面几篇中的样子。 ![](https://img.tnblog.net/arcimg/hb/7f28e644e1ae41ef931c9e3d1290f47a.png) tn2>把b当成第三个坐标轴,当我们为b附上值,通过一点点的绘制会形成一个曲线,绘制很多之后就是一个曲面图。 ![](https://img.tnblog.net/arcimg/hb/15cc9c59bbc64789999f301110ae909b.png) ![](https://img.tnblog.net/arcimg/hb/9d41a4c3461048a9944249e6f5344ad8.png) tn2>在这个曲面图的最低点的w和b值放回到预测函数中,会让预测的误差最小。 ![](https://img.tnblog.net/arcimg/hb/c8987126d1df4419bbd18bd02b1534a8.png) 如何求出最低点? ------------ tn2>首先我们在e和w的曲面切上一刀,然后通过梯度下降算法调整w,但是你会发现曲线的最低点并不是这个曲面的最低点。 ![](https://img.tnblog.net/arcimg/hb/548c7261a3ad4f56a4717bcf08333a7a.png) ![](https://img.tnblog.net/arcimg/hb/35732d97fb234a72af15bb88ff782982.png) tn2>如果我们在曲面的b和e上也来一刀,我们发现只要将在e和w的曲面上的点,慢慢移动到最b和e上的最低点就比较接近于完美了。 ![](https://img.tnblog.net/arcimg/hb/2df119d91fef473b82b22e6916a5b3d8.png) tn2>但事实上这又产生了两次计算,如果把w和b看成一个整体使用空间位移的方式进行一点点的挪动就会慢慢到达底部了。 ![](https://img.tnblog.net/arcimg/hb/21fe0448c04a477c847533017d42d5f9.png) ![](https://img.tnblog.net/arcimg/hb/1407176f451d4784bab33a89a5d67782.png) ![](https://img.tnblog.net/arcimg/hb/3a08ba290e6a436496944b6ae6b29bfd.png) tn2>由于我们有两个面,所以有两个斜率下降,通过两个面同时下降时的变量称之为偏导数,如果把w和b看作一个向量合起来就是一个新的合向量,沿着这个合向量进行下降的过程称为梯度下降。 所以我们说梯度下降比斜率下降更快。 编程实践 ------------ tn2>首先我们更改一下获取豆豆的变化`dataset.py`。 ```python import numpy as np ... def get_beans2(counts): xs = np.random.rand(counts) xs = np.sort(xs) ys = np.array([(0.7*x+(0.5-np.random.rand())/5+0.5) for x in xs]) return xs,ys ``` ```python # 生成-1到2每次递增加0.1的一组ws ws = np.arange(-1,2,0.1) # 生成-2到2每次递增加0.1的一组bs bs = np.arange(-2,2,0.1) for b in bs: es = [] for w in ws: # 获取预测值 y_pre = w * xs + b # 计算所有样本的均方误差 e = np.sum((ys-y_pre)**2)*(1/m) es.append(e) plt.plot(ws,es) plt.show() # 丝滑 ``` ![](https://img.tnblog.net/arcimg/hb/754ea38740a84b2c90e757fb40551519.png) ```python # 获取图形对象 fig = plt.figure() # 创建3D对象 ax = fig.add_subplot(111,projection='3d') # 生成-1到2每次递增加0.1的一组ws ws = np.arange(-1,2,0.1) # 生成-2到2每次递增加0.1的一组bs bs = np.arange(-2,2,0.1) for b in bs: es = [] for w in ws: # 获取预测值 y_pre = w * xs + b # 计算所有样本的均方误差 e = np.sum((ys-y_pre)**2)*(1/m) es.append(e) # 2D # plt.plot(ws,es) # 3D 传入三个坐标轴,zdir表示谁朝上。 ax.plot(ws, es, b, zdir='y') # 散点 # ax.scatter(ws, es, b, zdir='y') plt.show() ``` ![](https://img.tnblog.net/arcimg/hb/ac080ea23daf4e5b807fedf015fbcbe5.png) tn2>这里由于z轴的范围太大了,所以我们可以通过`ax.set_zlim(0,2)`设置一下z轴的范围。 ```python fig = plt.figure() # 创建3D对象 ax = fig.add_subplot(111,projection='3d') # 设置垂直范围 ax.set_zlim(0,2) ``` ![](https://img.tnblog.net/arcimg/hb/97fabe855edc4a77a452745d4059e38d.png) tn2>接下来我们通过代码,进行50000次训练然后同时从b和w进行梯度下降。 ```python # 50000次学习 for _ in range(500): for i in range(100): # 获取散点的值 x = xs[i] y = ys[i] # w曲面进行斜率下降 dw = 2*x**2*w + 2*x*b - 2*x*y # b曲面进行斜率下降 db = 2*b + 2*x*w - 2*y # 设置阿尔法 alpha = 0.01 # 获取新误差值 w = w - alpha * dw # 合梯度下降 b = b - alpha * db # 清除绘图窗口 plt.clf() # 重新绘制 plt.scatter(xs,ys) y_pre = w*xs + b plt.xlim(0,1) plt.ylim(0,1.2) plt.plot(xs,y_pre) # 暂停0.01秒,不暂停看不到绘制图 plt.pause(0.01) ``` ![](https://img.tnblog.net/arcimg/hb/aa1b682006aa4ff2a943d8d7929b7c5b.png) ![](https://img.tnblog.net/arcimg/hb/617d7914c117471bb1be9a2aa9f0f988.png)
欢迎加群讨论技术,1群:677373950(满了,可以加,但通过不了),2群:656732739
👈{{preArticle.title}}
👉{{nextArticle.title}}
评价
{{titleitem}}
{{titleitem}}
{{item.content}}
{{titleitem}}
{{titleitem}}
{{item.content}}
尘叶心繁
这一世以无限游戏为使命!
博主信息
排名
6
文章
6
粉丝
16
评论
8
文章类别
.net后台框架
166篇
linux
17篇
linux中cve
1篇
windows中cve
0篇
资源分享
10篇
Win32
3篇
前端
28篇
传说中的c
4篇
Xamarin
9篇
docker
15篇
容器编排
101篇
grpc
4篇
Go
15篇
yaml模板
1篇
理论
2篇
更多
Sqlserver
4篇
云产品
39篇
git
3篇
Unity
1篇
考证
2篇
RabbitMq
23篇
Harbor
1篇
Ansible
8篇
Jenkins
17篇
Vue
1篇
Ids4
18篇
istio
1篇
架构
2篇
网络
7篇
windbg
4篇
AI
18篇
threejs
2篇
人物
1篇
嵌入式
2篇
python
13篇
HuggingFace
8篇
pytorch
9篇
opencv
6篇
最新文章
最新评价
{{item.articleTitle}}
{{item.blogName}}
:
{{item.content}}
关于我们
ICP备案 :
渝ICP备18016597号-1
网站信息:
2018-2024
TNBLOG.NET
技术交流:
群号656732739
联系我们:
contact@tnblog.net
欢迎加群
欢迎加群交流技术