程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

暂无数据

如何创建带有条件的 PyTorch 钩子?

发布于2022-10-06 22:15     阅读(801)     评论(0)     点赞(2)     收藏(5)


我正在学习钩子并使用二值化神经网络。问题是有时我的梯度在反向传递中为 0。我正在尝试用某个值替换那些渐变。

说我有以下网络

import torch
import torch.nn as nn
import torch.optim as optim

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 2)
        self.fc2 = nn.Linear(2, 3)
        self.fc3 = nn.Linear(3, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)        
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Model()

opt = optim.Adam(net.parameters())

还有一些功能

features = torch.rand((3,1))

我可以使用以下方法正常训练它:

for i in range(10):
    opt.zero_grad()
    out = net(features)
    loss = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))
    loss.backward()
    opt.step()

如何附加一个钩子函数,该函数将具有以下反向传递条件(对于每一层):

  • 如果单层中的所有梯度都为 0,则将它们更改为 1.0。

  • 如果其中一个梯度为 0,但至少有一个梯度不是 0,则将其更改为 0.5。


解决方案


您可以在nn.Modulewith上附加一个回调函数nn.Module.register_full_backward_hook

您将不得不处理这两种情况:如果所有元素都等于零,则使用torch.all,否则(至少一个非零)如果至少一个元素等于零,则使用torch.any

def grad_mod(module, grad_inputs, grad_outputs):
    if module.weight.grad is None: # safety measure for last layer 
        return None                # and layers w/ require_grad=False

    flat = module.weight.grad.view(-1)
    if torch.all(flat == 0):
        flat.data.fill_(1.)
    elif torch.any(flat == 0):
        flat.data.scatter_(0, (flat == 0).nonzero()[:,0], value=.5)

第一个子句中的指令将填充所有值,1.而第二个子句中的指令仅将零值替换为.5.

将挂钩连接到nn.Module

>>> net.fc3.register_full_backward_hook(grad_mod)

这里我使用print变异前后的语句flat来展示钩子的效果:

>>> net(torch.rand((3,1))).backward(torch.tensor([[0],[1],[2]]))
>>> tensor([0.0947, 0.0000, 0.0000]) # before
>>> tensor([0.0947, 0.5000, 0.5000]) # after

>>> net(torch.rand((3,1))).backward(torch.tensor([[0],[1],[2]]))
>>> tensor([0., 0., 0.])             # before
>>> tensor([1., 1., 1.])             # after

为了将此钩子应用于多个层,您可以包装grad_mod并利用nn.Module.apply递归行为:

>>> def apply_grad_mod(module):
...     if hasattr(module, 'weight'):
...         module.register_full_backward_hook(grad_mod)

然后下面将在所有层权重上应用钩子。

>>> net.apply(apply_grad_mod)

注意:如果您还希望影响偏差,则必须扩展此行为!



所属网站分类: 技术文章 > 问答

作者:黑洞官方问答小能手

链接:https://www.pythonheidong.com/blog/article/1793559/926e2ea7a5b2e9b7d6af/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

2 0
收藏该文
已收藏

评论内容:(最多支持255个字符)