美文网首页
[pytorch]tensors used as indices

[pytorch]tensors used as indices

作者: 祁晏晏 | 来源:发表于2019-07-11 00:32 被阅读0次

精读yolov3源码时发现了一个写法,研究了下学到了新的知识点。

PS:这个代码写的真的蛮好的,有很多小细节,学到的东西也多

utils.py源码内容

# iou: torch.tensor(), size(num_of_labels(an image), )
# iou_thres: 一个小数
# t: torch.tensor(), size(num_of_labels(an image), 6)

j = iou > iou_thres
t = t[j]

第一次见到tensor1[tensor2]这种结构,做了些测试

tensor2为torch.uint8时

targets = torch.tensor([ [1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14] ])
j = torch.tensor([0, 1, 0, 0, 1, 0, 1],dtype=torch.uint8)
t = targets[j]
print(t)

'''
输出:
tensor([[ 3,  4],
        [ 9, 10],
        [13, 14]])
'''

targets = torch.tensor([ [1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14] ])
j = torch.tensor([0, 1, 0, 0, 1, -1, 1],dtype=torch.uint8)
t = targets[j]
print(t)

'''
输出:
tensor([[ 3,  4],
        [ 9, 10],
        [11, 12],
        [13, 14]])
'''

结论1
当tensor2为uint8类型时,tensor1[tensor2]的结果为tensor2不为0元素位置对应的tensor1元素

tensor2不为torch.uint8时

进一步拓展
当tensor2不为uint8类型时结果会怎么样呢?

targets = torch.tensor([ [1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14] ])
j = torch.tensor([0, 1, 0, 0, 1, 0, 1])
t = targets[j]
print(t)
print(j.dtype)

'''
输出:
tensor([[1, 2],
        [3, 4],
        [1, 2],
        [1, 2],
        [3, 4],
        [1, 2],
        [3, 4]])
torch.int64
'''

结论2
此时tensor2的元素表示的是位置
torch.tensor不指定type时为int64类型

相关文章

网友评论

      本文标题:[pytorch]tensors used as indices

      本文链接:https://www.haomeiwen.com/subject/ouslkctx.html