方案一:只定义loss函数的前向计算公式
在pytorch中定义了前向计算的公式,在训练时它会自动帮你计算反向传播。
import torch.nn as nn
Class YourLoss(nn.Module):
def __init__():
pass
def forward():
pass
方案二:自定义loss函数的forward和backward
from numpy.fft import rfft2, irfft2
class BadFFTFunction(Function):
def forward(self, input):
numpy_input = input.numpy()
result = abs(rfft2(numpy_input))
return input.new(result)
def backward(self, grad_output):
numpy_go = grad_output.numpy()
result = irfft2(numpy_go)
return grad_output.new(result)
方案三:自己写一个pytorch的C扩展
这个了解不多,所以也不太会
方案四:简单定义
看网上有说直接定义一个简单函数就可以了,可以尝试一下,与只定义forward类似。
import torch
...... #模型操作
loss = torch.sum(x - y)
网友评论