what does torch.nn.CosineEmbeddi

Author: asl_1da7 | Published: 2020-07-29 11:04
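Loss function for each sample, where y is the target (1 for a similar pair, -1 for a dissimilar pair) and margin defaults to 0:

loss(x_1, x_2, y) = \begin{cases} 1 - \cos(x_1, x_2), & y = 1 \\ \max(0, \cos(x_1, x_2) - margin), & y = -1 \end{cases}

Here \cos(x_1, x_2) is the cosine similarity of the two vectors:

similarity = \cos(\theta) = \frac{\vec{A} \cdot \vec{B}}{|\vec{A}||\vec{B}|}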
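To make the per-sample rule concrete (1 - cos for a target of 1, max(0, cos - margin) for a target of -1), here is a minimal, dependency-free sketch on a single pair of vectors. The helper names `cosine_similarity` and `cosine_embedding_loss_single` are hypothetical and exist only for this illustration; they are not part of PyTorch.

```python
import math

def cosine_similarity(a, b):
    # Dot product divided by the product of the two Euclidean norms.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def cosine_embedding_loss_single(a, b, y, margin=0.0):
    # Hypothetical helper: the per-sample loss for one pair and one target y.
    cos = cosine_similarity(a, b)
    if y == 1:
        return 1.0 - cos            # similar pair: penalise low similarity
    return max(0.0, cos - margin)   # dissimilar pair: penalise similarity above the margin

a = [1.0, 0.0]
b = [1.0, 1.0]  # 45 degrees away from a, so cos = 1/sqrt(2), about 0.7071
print(cosine_embedding_loss_single(a, b, 1))               # 1 - cos, about 0.2929
print(cosine_embedding_loss_single(a, b, -1))              # cos - 0, about 0.7071
print(cosine_embedding_loss_single(a, b, -1, margin=0.8))  # cos is below the margin, so 0.0
```

With a target of -1, the margin sets how similar a pair may be before it is penalised at all.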

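import torch

def CustomCosineEmbeddingLoss(x1, x2, target, margin=0.0):
    x1_ = torch.sqrt(torch.sum(x1 * x1, dim=1))  # |x1|, one norm per row
    x2_ = torch.sqrt(torch.sum(x2 * x2, dim=1))  # |x2|, one norm per row
    cos_x1_x2 = torch.sum(x1 * x2, dim=1) / (x1_ * x2_)  # cosine similarity per row
    # target == 1: 1 - cos; target == -1: max(0, cos - margin)
    loss = torch.where(target == 1, 1 - cos_x1_x2, torch.clamp(cos_x1_x2 - margin, min=0))
    return torch.mean(loss)

crit = torch.nn.CosineEmbeddingLoss(reduction="mean")
x1 = torch.randn((5, 3))
x2 = torch.randn((5, 3))
# Assumption: the original snippet never defines target; all ones marks every pair as similar.
target = torch.ones(5)

a1 = crit(x1, x2, target)
print(a1)
a2 = CustomCosineEmbeddingLoss(x1, x2, target)
print(a2)
# Out[11]:
# tensor(1.0479)
# tensor(1.0479)
# The inputs are random and unseeded, so the exact value changes per run; the two results should match.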
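The comparison above is most informative when the targets are mixed and the margin is nonzero, since the 1 and -1 branches are then both exercised. The following is a sketch of such a check, assuming PyTorch is installed. The function name `cosine_embedding_loss_ref`, the seed, the margin of 0.2, and the target values are arbitrary choices made for illustration; `torch.nn.functional.cosine_similarity` replaces the hand-written norm computation.

```python
import torch
import torch.nn.functional as F

def cosine_embedding_loss_ref(x1, x2, target, margin=0.0):
    # Hypothetical reference implementation, not PyTorch's own code.
    cos = F.cosine_similarity(x1, x2, dim=1)   # one similarity per row
    pos = 1 - cos                              # branch used where target == 1
    neg = torch.clamp(cos - margin, min=0)     # branch used where target == -1
    return torch.where(target == 1, pos, neg).mean()

torch.manual_seed(0)  # arbitrary seed, only so repeated runs use the same inputs
x1 = torch.randn(5, 3)
x2 = torch.randn(5, 3)
target = torch.tensor([1.0, -1.0, 1.0, -1.0, 1.0])  # mixed similar / dissimilar pairs

builtin = torch.nn.CosineEmbeddingLoss(margin=0.2, reduction="mean")(x1, x2, target)
manual = cosine_embedding_loss_ref(x1, x2, target, margin=0.2)
print(builtin, manual)
print(torch.allclose(builtin, manual, atol=1e-6))
```

The two values are expected to agree up to floating-point tolerance; small differences can come from the epsilon terms the library uses to guard against division by zero.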
