美文网首页
pytorch 计算2维矩阵每行最小N个值的索引

pytorch 计算2维矩阵每行最小N个值的索引

作者: Ailien | 来源:发表于2022-01-26 19:46 被阅读0次

    以二维numpy矩阵为例

    import torch
    import numpy as np
    K=3 #取每行最小3个值的索引
    data=np.random.rand(4,7)
    print(data)
    data=torch.from_numpy(data)
    a, idx = torch.sort(data, descending=False)
    lists=idx[:,:K]
    print(lists)
    

    运行结果如下:

    results.jpg

    相关文章

      网友评论

          本文标题:pytorch 计算2维矩阵每行最小N个值的索引

          本文链接:https://www.haomeiwen.com/subject/bqunhrtx.html