美文网首页
pytorch学习经验(六)torch.where():根据条件

pytorch学习经验(六)torch.where():根据条件

作者: nowherespyfly | 来源:发表于2021-01-18 21:32 被阅读0次

    今天写代码的时候遇到一个问题,网络前向过程中有一个张量A,我想把张量A中的大于0的值变成张量B中对应的值,最初的实现是:

    A[A>0]=B[A>0]
    

    然后运行起来就报错了,原因是这个操作属于in-place操作,而pytorch在涉及到求梯度的tensor时,是不允许对这些tensor做原地操作的,否则在反向传播的时候,这些张量计算出来的梯度发生变化。
    所以我后来采用了torch.where()方法:

    torch.where(condition, x, y) → Tensor
    # 使用where方法
    C = torch.where(A > 0, B, A)
    

    condition为y的条件表达式,where方法检查y中的所有元素,对于y中满足condition的元素,用x中对应元素替换;否则,还保留y中的元素。where返回一个新的张量。

    相关文章

      网友评论

          本文标题:pytorch学习经验(六)torch.where():根据条件

          本文链接:https://www.haomeiwen.com/subject/wtcfzktx.html