Torch的底层核心是Storage与Tensor;应用核心就是Module的设计封装;Module中比较巧妙的是可训练参数的管理。
本主题从源代码角度捋了一下,作为Module深入理解的一部分。并使用Module及其相关封装实现抛物线的极小值求解。
理解Module的设计思想后,基本上Module,Sequential,Layer,Loss Function就可以全部打通理解了。
参数跟踪
- 在成员中构建的Layer的参数都会自动被跟踪。
from torch.nn import Module, Linear
class TestModule(Module):
def __init__(self):
super(TestModule, self).__init__()
self.layer1 = Linear(2, 1)
def forward(self, x):
return x
module = TestModule()
for param in module.parameters():
print(param)
Parameter containing:
tensor([[-0.2155, 0.2611]], requires_grad=True)
Parameter containing:
tensor([0.6998], requires_grad=True)
定制参数
Linear类的实现源代码
- 用户定义的参数怎样才能被跟踪到? 我们先看看官方的源代码的Linear的实现
def __init__(self, in_features, out_features, bias=True):
super(Linear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
-
Module类的可训练参数被跟踪机制:
- 使用Parameter构建变量默认被跟踪。
- 参数的初始化是通过reset_parameters函数实现,而且是在构造器调用一次,如果被改变可以使用reset_parameters()恢复到初始状态。
-
Linear的初始化
def reset_parameters(self):
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
使用Model实现一个抛物线极小值点寻找
- 思路:
- 实现抛物线计算
- 使用迭代n次,自然得到极小值点。
- 模型实现
- 公式:
- 定义参数:因为我们需要求极小值点,就是迭代x。定义x为参数,并初始化一个值。
- 注册参数
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
class ParabolaModule(Module):
def __init__(self):
super(ParabolaModule, self).__init__()
self.x = Parameter(torch.tensor(3.0))
def forward(self, x=0):
return self.x ** 2 - 3 * self.x + 4
- 迭代计算
import torch
from torch.optim import Adam
from torch.nn import Module
net = ParabolaModule()
optimizer = Adam(net.parameters(),lr=0.01)
loss = torch.nn.Identity()
epoch = 1000
for n in range(epoch): # 迭代
y = net()
ls = loss(y)
optimizer.zero_grad()
ls.backward()
optimizer.step()
print(F"训练次数足够大,我们总能找到极值点:{net.x:6.2}", )
训练次数足够大,我们总能找到极值点: 1.5
- 实际上上面的损失函数torch.nn.Identity是可以不需要的,如下:
import torch
from torch.optim import Adam
from torch.nn import Module
net = ParabolaModule()
optimizer = Adam(net.parameters(),lr=0.01)
epoch = 1000
for n in range(epoch): # 迭代
y = net()
optimizer.zero_grad()
y.backward()
optimizer.step()
print(F"训练次数足够大,我们总能找到极值点:{net.x:6.2}", )
训练次数足够大,我们总能找到极值点: 1.5
Parameter类与自动跟踪的关系
-
原理:
- 实现函数
def __setattr__(self, name, value)
- 这个函数实现,只要使用self.xx= yy;就会导致该函数被调用;
- 在
__setattr__
函数中判定value类型:- 是Parameter类型就会被添加到参数的管理成员:
_parameters
- 而且直接使用属性名作为名字。
- 是Parameter类型就会被添加到参数的管理成员:
- 实现函数
所有逻辑都在函数:
__setattr__
def __setattr__(self, name, value):
def remove_from(*dicts):
for d in dicts:
if name in d:
del d[name]
params = self.__dict__.get('_parameters')
if isinstance(value, Parameter):
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call")
remove_from(self.__dict__, self._buffers, self._modules)
self.register_parameter(name, value)
elif params is not None and name in params:
if value is not None:
raise TypeError("cannot assign '{}' as parameter '{}' "
"(torch.nn.Parameter or None expected)"
.format(torch.typename(value), name))
self.register_parameter(name, value)
else:
modules = self.__dict__.get('_modules')
if isinstance(value, Module):
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call")
remove_from(self.__dict__, self._parameters, self._buffers)
modules[name] = value
elif modules is not None and name in modules:
if value is not None:
raise TypeError("cannot assign '{}' as child module '{}' "
"(torch.nn.Module or None expected)"
.format(torch.typename(value), name))
modules[name] = value
else:
buffers = self.__dict__.get('_buffers')
if buffers is not None and name in buffers:
if value is not None and not isinstance(value, torch.Tensor):
raise TypeError("cannot assign '{}' as buffer '{}' "
"(torch.Tensor or None expected)"
.format(torch.typename(value), name))
buffers[name] = value
else:
object.__setattr__(self, name, value)
- 本质还是调用register_parameter函数实现参数管理:
self.register_parameter(name, value)