今天写代码的时候遇到一个问题,网络前向过程中有一个张量A,我想把张量A中的大于0的值变成张量B中对应的值,最初的实现是:
A[A>0]=B[A>0]
然后运行起来就报错了,原因是这个操作属于in-place操作,而pytorch在涉及到求梯度的tensor时,是不允许对这些tensor做原地操作的,否则在反向传播的时候,这些张量计算出来的梯度发生变化。
所以我后来采用了torch.where()方法:
torch.where(condition, x, y) → Tensor
# 使用where方法
C = torch.where(A > 0, B, A)
condition为y的条件表达式,where方法检查y中的所有元素,对于y中满足condition的元素,用x中对应元素替换;否则,还保留y中的元素。where返回一个新的张量。
网友评论