0%

2020-11-09-pytorch反向传播以及参数更新理解

反向传播以及更新

方法一:手动计算变量

这种方法不常用,因为一般的模型参数太多了

1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
from torch.autograd import Variable
# 定义参数
w1 = Variable(torch.FloatTensor([1,2,3]),requires_grad = True)
# 定义输出
d = torch.mean(w1)
# 反向求导
d.backward()
# 定义学习率等参数
lr = 0.001
# 手动更新参数
w1.data.zero_() # BP求导更新参数之前,需先对导数置0
w1.data.sub_(lr*w1.grad.data)12345678910111213

一个网络中通常有很多变量,如果按照上述的方法手动求导,然后更新参数,是很麻烦的,这个时候可以调用torch.optim

方法二:使用torch.optim

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
# 这里假设我们定义了一个网络,为net
steps = 10000
# 定义一个optim对象
optimizer = optim.SGD(net.parameters(), lr = 0.01)
# 在for循环中更新参数
for i in range(steps):
optimizer.zero_grad() # 对网络中参数当前的导数置0,清零梯度缓存
output = net(input) # 网络前向计算
loss = criterion(output, target) # 通过定义损失函数:criterion,计算误差,得到网络的损失值:loss;
loss.backward() # 通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。
optimizer.step() # 更新参数123456789101112131415

torch.optim只用于参数更新和对参数的梯度置0,不能计算参数的梯度,在使用torch.optim进行参数更新之前,需要写前向与反向传播求导的代码

loss是反向传播整个计算图/模型(有一条传播路径)的节点参数,其中一个模型可以认为是一个连通图,是由数据传播的,比如encoder和decoder之间会有隐藏向量Z进行连接,那么就是一个计算图,那么loss反向传播就会更新所有的参数。参数在定义时默认就是可动态更新的。

Variable & Parameter的区别

之所以有Variable这个数据结构,是为了引入计算图(自动求导),方便构建神经网络。也就是一般模型网络(计算图)的输入是Variable类型的,是要外部给值的,返回的是tensor类型。

不同于Parameter,Parameter一般是随机初始化,然后根据loss反向传播被动更新值

1
x = Variable(torch.Tensor(array), requires_grad = True) #可以自求导更新,若一个节点requires_grad被设置为True,那么计图中所有依赖它求得的节点的requires_grad都为True

Pytorch主要通过引入nn.Parameter类型的变量和optimizer机制来解决自动更新多个参数的问题。

Parameter是Variable的子类,本质上和后者一样,只不过parameter默认是求梯度的,同时一个网络net中的parameter变量是可以通过 net.parameters() 来很方便地访问到的,只需将网络中所有需要训练更新的参数定义为Parameter类型,再佐以optimizer,就能够完成所有参数的更新了

Parameter是torch.autograd.Variable的一个字类,常被用于Module的参数。例如权重和偏置。自动加入参数列表,可以进行保存恢复。和Variable具有相同的运算。

Parameter的require_grad默认设置为true。Varaible默认设置为False.

Parameters类是Tensor 的子类, 不过相对于它的父类,Parameters类有一个很重要的特性就是当其在 Module类中被使用并被当做这个Module类的模块属性的时候,那么这个Parameters对象会被自动地添加到这个Module类的参数列表(list of parameters)之中,同时也就会被添加入此Module类的 parameters()方法所返回的参数迭代器中。而Parameters类的父类Tensor类也可以被用为构建模块的属性,但不会被加入参数列表。这样主要是因为,有时可能需要在模型中存储一些非模型参数的临时状态,比如RNN中的最后一个隐状态。而通过使用非Parameter的Tensor类,可以将这些临时变量注册(register)为模型的属性的同时使其不被加入参数列表。

我们可以这样简单区分,在计算图中,数据(包括输入数据和计算过程中产生的feature map等)是 variable 类型,该类型不会被保存到模型中。

网络的权重是 parameter 类型,在计算过程中会被更新,将会被保存到模型中。

https://www.jianshu.com/p/cb739922ce88

https://zhoef.com/2019/08/12/16_Pytorch_Basic/

-------------本文结束感谢阅读-------------
卑微博主,在线求赏