Pytorch:利用torch.nn.Modules.parameters修改模型参数

导读:本篇文章讲解 Pytorch:利用torch.nn.Modules.parameters修改模型参数,希望对大家有帮助,欢迎收藏,转发!站点地址:www.bmabk.com

1. 关于parameters()方法

Pytorch中继承了torch.nn.Module的模型类具有named_parameters()/parameters()方法,这两个方法都会返回一个用于迭代模型参数的迭代器(named_parameters还包括参数名字):

import torch

net = torch.nn.LSTM(input_size=512, hidden_size=64)
print(net.parameters())
print(net.named_parameters())
# <generator object Module.parameters at 0x12a4e9890>
# <generator object Module.named_parameters at 0x12a4e9890>

我们可以将net.parameters()迭代器和将net.named_parameters()转化为列表类型,前者列表元素是模型参数,后者是包含参数名和模型参数的元组。

当然,我们更多的是对迭代器直接进行迭代:

for param in net.parameters():
    print(param.shape)
# torch.Size([256, 512])
# torch.Size([256, 64])
# torch.Size([256])
# torch.Size([256])
for name, param in net.named_parameters():
    print(name, param.shape)
# weight_ih_l0 torch.Size([256, 512])
# weight_hh_l0 torch.Size([256, 64])
# bias_ih_l0 torch.Size([256])
# bias_hh_l0 torch.Size([256])

我们知道,Pytorch在进行优化时需要给优化器传入这个参数迭代器,如:

from torch.optim import RMSprop
optimizer = RMSprop(net.parameters(), lr=0.01)

2. 关于参数修改

那么底层具体是怎么对参数进行修改的呢?

我们在博客《Python对象模型与序列迭代陷阱》中介绍过,Python序列中本身存放的就是对象的引用,而迭代器返回的是序列中的对象的二次引用,如果序列的引用指向基础数据类型,则是不可以通过遍历序列进行修改的,如:

my_list = [1, 2, 3, 4]
for x in my_list:
    x += 1
print(my_list) #[1, 2, 3, 4]

而序列中的引用指向复合数据类型,则可以通过遍历序列来完成修改操作,如:

my_list = [[1, 2],[3, 4]]
for sub_list in my_list:
    sub_list[0] += 1
print(my_list)
# [1, 2, 3, 4]
# [[2, 2], [4, 4]]

具体原理可参照该篇博客,此处我就不在赘述。这里想提到的是,用net.parameters()/net.named_parameters()来迭代并修改参数,本质上就是上述第二种对复合数据类型序列的修改。我们可以如下写:

for param in net.parameters():
    with torch.no_grad():
        param += 1

with torch.no_grad():表示将将所要修改的张量关闭梯度计算。所增加的1会广播到param张量的中的每一个元素上。上述操作本质上为:

for param in net.parameters():
    with torch.no_grad():
        param += torch.ones(param.shape)

但是需要注意,如果我们想让参数全部置为0,切不可像下列这样写:

for param in net.parameters():
    with torch.no_grad():
        param = torch.zeros(param.shape) 

param是二次引用,param=0操作再语义上会被解释为让param这个二次引用去指向新的全0张量对象,但是对参数张量本身并不会产生任何变动。该操作实际上类似下列这种操作:

list_1 = [1, 2]
list_2 = list_1
list_2 = [0, 0]
print(list_1) # [1, 2]

修改二次引用list_2自然不会影响到list_1引用的对象。

下面让我们纠正这种错误,采用下列方法直接来将参数张量中的所有数值置0:

for param in net.parameters():
    with torch.no_grad():
        param[:] = 0 #张量类型自带广播操作,等效于param[:] = torch.zeros(param.shape) 

这时语义上就类似

list_1 = [1, 2]
list_2 = list_1
list_2[:] = [0, 0]
print(list_1) # [0, 0]

自然就能完成修改的操作了。

参考

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/12489.html

(0)
小半的头像小半

相关推荐

极客之音——专业性很强的中文编程技术网站,欢迎收藏到浏览器,订阅我们!