美文网首页
[pytorch]tensors used as indices

[pytorch]tensors used as indices

作者: 祁晏晏 | 来源:发表于2019-07-11 00:32 被阅读0次

    精读yolov3源码时发现了一个写法,研究了下学到了新的知识点。

    PS:这个代码写的真的蛮好的,有很多小细节,学到的东西也多

    utils.py源码内容

    # iou: torch.tensor(), size(num_of_labels(an image), )
    # iou_thres: 一个小数
    # t: torch.tensor(), size(num_of_labels(an image), 6)
    
    j = iou > iou_thres
    t = t[j]
    

    第一次见到tensor1[tensor2]这种结构,做了些测试

    tensor2为torch.uint8时

    targets = torch.tensor([ [1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14] ])
    j = torch.tensor([0, 1, 0, 0, 1, 0, 1],dtype=torch.uint8)
    t = targets[j]
    print(t)
    
    '''
    输出:
    tensor([[ 3,  4],
            [ 9, 10],
            [13, 14]])
    '''
    
    targets = torch.tensor([ [1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14] ])
    j = torch.tensor([0, 1, 0, 0, 1, -1, 1],dtype=torch.uint8)
    t = targets[j]
    print(t)
    
    '''
    输出:
    tensor([[ 3,  4],
            [ 9, 10],
            [11, 12],
            [13, 14]])
    '''
    

    结论1
    当tensor2为uint8类型时,tensor1[tensor2]的结果为tensor2不为0元素位置对应的tensor1元素

    tensor2不为torch.uint8时

    进一步拓展
    当tensor2不为uint8类型时结果会怎么样呢?

    targets = torch.tensor([ [1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14] ])
    j = torch.tensor([0, 1, 0, 0, 1, 0, 1])
    t = targets[j]
    print(t)
    print(j.dtype)
    
    '''
    输出:
    tensor([[1, 2],
            [3, 4],
            [1, 2],
            [1, 2],
            [3, 4],
            [1, 2],
            [3, 4]])
    torch.int64
    '''
    

    结论2
    此时tensor2的元素表示的是位置
    torch.tensor不指定type时为int64类型

    相关文章

      网友评论

          本文标题:[pytorch]tensors used as indices

          本文链接:https://www.haomeiwen.com/subject/ouslkctx.html