复函数的可导性
复变函数按照是否可导,分为全纯函数holomothic和nonholomophic,判断条件为Cauchy-Riemann方程。
对于不可导的nonholomophic函数:
Wirtinger算子
采用Wirtinger算子来计算反向传播。
Wirtinger算子的思路是,将任何复变函数f,看做f(z,z*),求导数就是对z和共轭z*分别求导:
其中:
z=x + jy。
而全纯函数f(z),当且仅当df/dz*=0。
Pytorch实现
损失函数梯度
损失函数J的梯度为:
且由于J为实数,因此:
综上,算法流程如下:
1) 全纯函数y=f(w):
由于dy/dw*=0,由推导可知,梯度与实数域结果一样,无需额外实现
2)非全纯函数y=f(w,w*):
a, 求得g1 = dy/dw,g2 = dy/dw*。
b, 拿到上层backward回来的梯度,也就是grad_output
c, 求得本节点的梯度 += grad_output.g1* + grad_output*.g2
具体实现
pytorch自动求导机制可以通过继承torch.autograd.Function来扩展求导算法。由上可知,只需要扩展非全纯函数即可。
复数的矩阵表示形式为z[..., 2],最后维度的2个值分别是实部和虚部。
例如函数 y=z.z*的实现如下:
网友评论