精读yolov3源码时发现了一个写法,研究了下学到了新的知识点。
PS:这个代码写的真的蛮好的,有很多小细节,学到的东西也多
utils.py源码内容
# iou: torch.tensor(), size(num_of_labels(an image), )
# iou_thres: 一个小数
# t: torch.tensor(), size(num_of_labels(an image), 6)
j = iou > iou_thres
t = t[j]
第一次见到tensor1[tensor2]这种结构,做了些测试
tensor2为torch.uint8时
targets = torch.tensor([ [1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14] ])
j = torch.tensor([0, 1, 0, 0, 1, 0, 1],dtype=torch.uint8)
t = targets[j]
print(t)
'''
输出:
tensor([[ 3, 4],
[ 9, 10],
[13, 14]])
'''
targets = torch.tensor([ [1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14] ])
j = torch.tensor([0, 1, 0, 0, 1, -1, 1],dtype=torch.uint8)
t = targets[j]
print(t)
'''
输出:
tensor([[ 3, 4],
[ 9, 10],
[11, 12],
[13, 14]])
'''
结论1
当tensor2为uint8类型时,tensor1[tensor2]的结果为tensor2不为0元素位置对应的tensor1元素
tensor2不为torch.uint8时
进一步拓展
当tensor2不为uint8类型时结果会怎么样呢?
targets = torch.tensor([ [1,2],[3,4],[5,6],[7,8],[9,10],[11,12],[13,14] ])
j = torch.tensor([0, 1, 0, 0, 1, 0, 1])
t = targets[j]
print(t)
print(j.dtype)
'''
输出:
tensor([[1, 2],
[3, 4],
[1, 2],
[1, 2],
[3, 4],
[1, 2],
[3, 4]])
torch.int64
'''
结论2
此时tensor2的元素表示的是位置
torch.tensor不指定type时为int64类型
网友评论