美文网首页
TORCH08-01深入理解Module的训练参数管理

TORCH08-01深入理解Module的训练参数管理

作者: 杨强AT南京 | 来源:发表于2020-04-09 08:20 被阅读0次

      Torch的底层核心是Storage与Tensor;应用核心就是Module的设计封装;Module中比较巧妙的是可训练参数的管理。
      本主题从源代码角度捋了一下,作为Module深入理解的一部分。并使用Module及其相关封装实现抛物线的极小值求解。
      理解Module的设计思想后,基本上Module,Sequential,Layer,Loss Function就可以全部打通理解了。


    参数跟踪

    • 在成员中构建的Layer的参数都会自动被跟踪。
    from torch.nn import Module, Linear
    
    class TestModule(Module):
        def __init__(self):
            super(TestModule, self).__init__()
            self.layer1 = Linear(2, 1)
            
        def forward(self, x):
            return x
    
    
    module = TestModule()
    for param in module.parameters():
        print(param)
    
    Parameter containing:
    tensor([[-0.2155,  0.2611]], requires_grad=True)
    Parameter containing:
    tensor([0.6998], requires_grad=True)
    

    定制参数

    Linear类的实现源代码

    • 用户定义的参数怎样才能被跟踪到? 我们先看看官方的源代码的Linear的实现
       def __init__(self, in_features, out_features, bias=True):
            super(Linear, self).__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = Parameter(torch.Tensor(out_features, in_features))
            if bias:
                self.bias = Parameter(torch.Tensor(out_features))
            else:
                self.register_parameter('bias', None)
            self.reset_parameters()
    
    • Module类的可训练参数被跟踪机制:

      1. 使用Parameter构建变量默认被跟踪。
      2. 参数的初始化是通过reset_parameters函数实现,而且是在构造器调用一次,如果被改变可以使用reset_parameters()恢复到初始状态。
    • Linear的初始化

      • \text{uniform}(- \sqrt{k}, \sqrt{k}) \qquad k=\dfrac{1}{\text{in_features}}
        def reset_parameters(self):
            init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            if self.bias is not None:
                fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
                bound = 1 / math.sqrt(fan_in)
                init.uniform_(self.bias, -bound, bound)
    

    使用Model实现一个抛物线极小值点寻找

    • 思路:
      • 实现抛物线计算
      • 使用迭代n次,自然得到极小值点。
    1. 模型实现
      • 公式:y = x ^ 2 - 3 x + 4
      • 定义参数:因为我们需要求极小值点,就是迭代x。定义x为参数,并初始化一个值。
      • 注册参数
    import torch
    from torch.nn import Module
    from torch.nn.parameter import Parameter
    
    class ParabolaModule(Module):
        def __init__(self):
            super(ParabolaModule, self).__init__()
            self.x = Parameter(torch.tensor(3.0))
            
        def forward(self, x=0):
            return self.x ** 2 - 3 * self.x + 4
    
    1. 迭代计算
    import torch
    from torch.optim import Adam
    from torch.nn import Module
    
    net = ParabolaModule()
    optimizer = Adam(net.parameters(),lr=0.01)
    
    loss = torch.nn.Identity()
    epoch = 1000
    
    for n in range(epoch): # 迭代
        y = net()
        ls = loss(y)
        optimizer.zero_grad()
        ls.backward()
        optimizer.step()
    print(F"训练次数足够大,我们总能找到极值点:{net.x:6.2}", )
    
    训练次数足够大,我们总能找到极值点:   1.5
    
    • 实际上上面的y=x损失函数torch.nn.Identity是可以不需要的,如下:
    import torch
    from torch.optim import Adam
    from torch.nn import Module
    
    net = ParabolaModule()
    optimizer = Adam(net.parameters(),lr=0.01)
    
    epoch = 1000
    
    for n in range(epoch): # 迭代
        y = net()
        optimizer.zero_grad()
        y.backward()
        optimizer.step()
    print(F"训练次数足够大,我们总能找到极值点:{net.x:6.2}", )
    
    训练次数足够大,我们总能找到极值点:   1.5
    

    Parameter类与自动跟踪的关系

    • 原理:

      1. 实现函数 def __setattr__(self, name, value)
        • 这个函数实现,只要使用self.xx= yy;就会导致该函数被调用;
      2. __setattr__函数中判定value类型:
        • 是Parameter类型就会被添加到参数的管理成员:_parameters
        • 而且直接使用属性名作为名字。
    • 所有逻辑都在函数:__setattr__

        def __setattr__(self, name, value):
            def remove_from(*dicts):
                for d in dicts:
                    if name in d:
                        del d[name]
    
            params = self.__dict__.get('_parameters')
            if isinstance(value, Parameter):
                if params is None:
                    raise AttributeError(
                        "cannot assign parameters before Module.__init__() call")
                remove_from(self.__dict__, self._buffers, self._modules)
                self.register_parameter(name, value)
            elif params is not None and name in params:
                if value is not None:
                    raise TypeError("cannot assign '{}' as parameter '{}' "
                                    "(torch.nn.Parameter or None expected)"
                                    .format(torch.typename(value), name))
                self.register_parameter(name, value)
            else:
                modules = self.__dict__.get('_modules')
                if isinstance(value, Module):
                    if modules is None:
                        raise AttributeError(
                            "cannot assign module before Module.__init__() call")
                    remove_from(self.__dict__, self._parameters, self._buffers)
                    modules[name] = value
                elif modules is not None and name in modules:
                    if value is not None:
                        raise TypeError("cannot assign '{}' as child module '{}' "
                                        "(torch.nn.Module or None expected)"
                                        .format(torch.typename(value), name))
                    modules[name] = value
                else:
                    buffers = self.__dict__.get('_buffers')
                    if buffers is not None and name in buffers:
                        if value is not None and not isinstance(value, torch.Tensor):
                            raise TypeError("cannot assign '{}' as buffer '{}' "
                                            "(torch.Tensor or None expected)"
                                            .format(torch.typename(value), name))
                        buffers[name] = value
                    else:
                        object.__setattr__(self, name, value)
    
    • 本质还是调用register_parameter函数实现参数管理:
      • self.register_parameter(name, value)

    相关文章

      网友评论

          本文标题:TORCH08-01深入理解Module的训练参数管理

          本文链接:https://www.haomeiwen.com/subject/ayqqmhtx.html