引
我很喜欢这一篇文章,因为证明用到的知识并不难,但是却用的很巧,数学真是太牛了,这些人的嗅觉怎么这么好呢?
这篇文章,归根结底就是想说明一个问题,就是和一般的认知不同,随着神经网络的加深,参数更新的收敛速度并不会下降,感觉也有很多论文论述了深度depth的重要性.
不过,这篇文章,是在线性神经网络上做的一个分析,另外,标题中的Acceleration并没有很好的理论支撑,作者给出了几个特例和一些实验论据。我想作者肯定尝试过,但是想要证明想想就不易,至少得弄出个之类的.
虽然理论支撑不够,但是我感觉还是很厉害了.
主要内容
首先,为了排除一些干扰因素,就是Acceleration来自于俩个网络的表达能力不同,神经网络, 如果二者的收敛速度不同,原因可能是和能让损失下降的程度不同. 而在线性网络中,层数增加并不会改变网络的表达能力.
是关于的损失函数,这个网络的表达能力和的表达能力是相同的,如果.
对上面的结论,有一点点存疑,假设后者的最优为,那么只要让即可,所以.
反过来似乎不一定,假设, , 但是利用here的结果,只要, 满足且关于为凸函数,就能说明等价. 居然还用上了之前看过的结果.
符号可能有点多,尽可能简化点吧. 为样本,为输出,
显然. 假设是关于的函数, 可得
不要觉得这么做多此一举,不然后面证明的时候会弄乱的.
梯度下降采用了类似momentum的感觉,但是又有点不一样:
的学习率,是权重的递减系数.
定义
故.
作者假设,也就是学习率是一个小量,所以上面的式子可以从微分方程的角度去看
怎么说呢,这个所以,我的理解是很小的时候,很平缓,所以可以认为导数和的时候是一样的?
看了之前有一篇类似的Oja'rule也用了这种方法,感觉作者的意思应该是如果:
此时,
我觉得应该是这个样子的,不过对于最后的结果没有影响.
定理1
定理1 假设权重矩阵满足微分方程:
且
则权重矩阵的变化满足下列微分方程:
其中是关于半正定矩阵的一个定义,假如:
对角矩阵是让对角线元素的.
所以,权重的更新变换近似于:
Claim 1
上面的更新实际上让人看不出一个所以然来,所以作者给出了一个向量形式的更新方式,可以更加直观地展现其中地奥秘.
Claim 1 对于任意矩阵, 定义为由矩阵按列重排后的向量形式. 于是,
其中是一个半正定矩阵,依赖于, 假设
其中, 的对角线元素,即的奇异值从大到小为, 则的特征向量和对应的特征值为:
在这里插入图片描述
这说明了什么呢?也就是overparameterization后的更新,的更新,也就是的更新倾向于, 感觉这一点就和一些梯度下降方法的思想有点类似了,借用之前的成果. 而且,这个借用,会有一种坐标之间的互相沟通,一般的下降方法是不具备这一点的.
Claim 2
在这里插入图片描述定理2
定理2 假设在处有定义,的某个邻域内连续,那么对于给定的, 定义:
那么,不存在关于的一个函数,其梯度场为.
定理2的意义在于,它告诉我们,overparameterization的方法是不能通过添加正则项来实现的,因为不存在原函数,所以诸如
的操作是不可能实现overparametrization的更新变化的.
证明思路是,构造一个封闭曲线,证明在其上的线积分不为0. (太帅了...)
证明
定理1的证明
首先是一些符号:
用
表示块对角矩阵.
容易证明(其实费了一番功夫,但是不想写下来,因为每次都会忘,如果下次忘了,就再推一次当惩罚):
在这里插入图片描述于是
第个等式俩边右乘, 第个等式俩边左乘可得:
在这里插入图片描述
俩边乘以2
在这里插入图片描述
令, 则
注意,我们将上面的等式改写以下,等价于
用, 则
另外有初值条件(这是题设的条件).
容易知道,上面的微分方程的解为.
所以
假设的奇异值分解为
且假设的对角线元素,即奇异值是从大到小排列的.
则可得
显然, 这是因为一个矩阵的特征值是固定的(如果顺序固定的话),特征向量是不一定的,因为可能有多个相同的特征值,那么对于一个特征值的子空间的任意正交基都可以作为特征向量,也就是说
在这里插入图片描述
其中是单位矩阵, 是正交矩阵.
所以对于, 成立
有
故
在这里插入图片描述
注意,上面的推导需要用到:
既然
在这里插入图片描述
那么
在这里插入图片描述
在这里插入图片描述
上式左端为, 于是
在这里插入图片描述
再利用(23)(24)的结论
在这里插入图片描述Claim 1 的证明
Kronecker product (克罗内克积)
网上似乎都用, 不过这里还是遵循论文的使用规范吧, 用来表示Kronecker product:
其中.
容易证明 的第列为:
其中表示的第列, 沿用为的列展开. 相应的,的第行为:
其中表示的第行.
用表示的第列行的元素, 则
另外.
下面再证明几个重要的性质:
假设, 则
考察俩边矩阵的的元素,
得证. 注意,倒数第四个等式到倒数第三个用到了迹的可交换性.
所以.
回到Claim 1 的证明上来,容易证明
于是
在这里插入图片描述
第二个等式用到了.
只需要证明:
在这里插入图片描述
等价于. 令
其中.
所以
第三个等式用了俩次.
定义:
在这里插入图片描述则
剩下的,关于的列
的对角元素:
在这里插入图片描述
只是一些简单的推导罢了.
Theorem 2 的证明
这个证明我不想贴在这里,因为这个证明我只能看懂,所以想知道就直接看原文吧.
代码
在这里插入图片描述在这里插入图片描述
虽然只是用了一个很简单的例子做实验,但是感觉,这个迭代算法很吃初始值. 就像Claim 1 所解释的那样,这个下降方法,会更倾向于之前的方向,也就是之前的错了,后面也会错?
y1设置为100, y2设置为1, lr=0.005, 会出现(也有可能是收敛不到0):
在这里插入图片描述
这种下降的方式是蛮恐怖的啊,但是感觉实在是不稳定. 当然,也有可能是程序写的太烂了.
"""
On the Optimization of Deep
Net works: Implicit Acceleration by
Overparameterization
"""
import numpy as np
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required
class Net(nn.Module):
def __init__(self, d, k):
"""
:param k: 输出维度
:param d: 输入维度
"""
super(Net, self).__init__()
self.d = d
self.dense = nn.Sequential(
nn.Linear(d, k)
)
def forward(self, input):
x = input.view(-1, self.d)
output = self.dense(x)
return output
class Overparameter(Optimizer):
def __init__(self, params, N, lr=required, weight_decay=1.):
defaults = dict(lr=lr)
super(Overparameter, self).__init__(params, defaults)
self.N = N
self.weight_decay = weight_decay
def __setstate__(self, state):
super(Overparameter, self).__setstate__(state)
print("????")
print(state)
print("????")
def step(self, colsure=None):
def calc_part2(W, dw, N):
dw = dw.detach().numpy()
w = W.detach().numpy()
norm = np.linalg.norm(w, 2)
part2 = norm ** (2-2/N) * (
dw +
(N - 1) * (w @ dw.T) * w / (norm ** 2 + 1e-5)
)
return torch.from_numpy(part2)
p = self.param_groups[0]['params'][0]
if p.grad is None:
return 0
d_p = p.grad.data
part1 = (self.weight_decay * p.data).float()
part2 = (calc_part2(p, d_p, self.N)).float()
p.data -= self.param_groups[0]['lr'] * (part1+part2)
return 1
class L4Loss(nn.Module):
def __init__(self):
super(L4Loss, self).__init__()
def forward(self, x, y):
return torch.norm(x-y, 4)
x1 = torch.tensor([1., 0])
y1 = torch.tensor(10.)
x2 = torch.tensor([0, 1.])
y2 = torch.tensor(2.)
net = Net(2, 1)
criterion = L4Loss()
opti = Overparameter(net.parameters(), 4, lr=0.01)
loss_store = []
for epoch in range(500):
running_loss = 0.0
out1 = net(x1)
loss1 = criterion(out1, y1)
opti.zero_grad()
loss1.backward()
opti.step()
running_loss += loss1.item()
out2 = net(x2)
loss2 = criterion(out2, y2)
opti.zero_grad()
loss2.backward()
opti.step()
running_loss += loss2.item()
#print(running_loss)
loss_store.append(running_loss)
net = Net(2, 1)
criterion = nn.MSELoss()
opti = torch.optim.SGD(net.parameters(), lr=0.01)
loss_store2 = []
for epoch in range(500):
running_loss = 0.0
out1 = net(x1)
loss1 = criterion(out1, y1)
opti.zero_grad()
loss1.backward()
opti.step()
running_loss += loss1.item()
out2 = net(x2)
loss2 = criterion(out2, y2)
opti.zero_grad()
loss2.backward()
opti.step()
running_loss += loss2.item()
#print(running_loss)
loss_store2.append(running_loss)
import matplotlib.pyplot as plt
plt.plot(range(len(loss_store)), loss_store, color="red", label="Over")
plt.plot(range(len(loss_store2)), loss_store2, color="blue", label="normal")
plt.legend()
plt.show()
网友评论