取出标签对应的样本特征
samples = torch.randn(6,3)
labels = torch.tensor([1,1,0,0,0,1])
# 取出label 是1 的数据
mask = labels == 1
samples[mask]
image.png
第二种操作
samples = torch.randn(6,3)
print(samples)
labels = torch.tensor([1,1,0,0,0,1])
print(labels)
# 取出标签为1
index = torch.eq(labels,1)
index = index.nonzero()[:,0]
print(index)
result = torch.index_select(samples, 0, index)
print(result)
image.png
网友评论