美文网首页
跨模态检索的mAP(@R)和 PR曲线(Precision-Re

跨模态检索的mAP(@R)和 PR曲线(Precision-Recall Curve)

作者: s苏薳 | 来源:发表于2020-08-19 11:52 被阅读0次

    基础知识网上资料很多,在此不再赘述。本文主要记录如何在 PyTorch 下实现。

    1.mAP(@R)(参考DCMH:https://github.com/WendellGul/DCMH/blob/master/utils.py)

    import torch
    
    def calc_hammingDist(B1, B2):
        """Return the pairwise Hamming distances between two sets of codes.

        The distance is computed as ``0.5 * (q - B1 @ B2^T)`` where ``q`` is
        the code length; this equals the Hamming distance when every code
        entry is -1 or +1 (assumed, not checked).

        Args:
            B1: codes of shape (m, q), or a single code of shape (q,).
            B2: codes of shape (n, q).

        Returns:
            Tensor of shape (m, n); a 1-D ``B1`` is treated as one row (m = 1).
        """
        code_len = B2.size(1)
        # Promote a single code vector to a one-row matrix so mm() accepts it.
        queries = B1 if B1.dim() >= 2 else B1.unsqueeze(0)
        inner = torch.mm(queries, B2.t())
        return 0.5 * (code_len - inner)
    
    
    def calc_map_k(qB, rB, query_L, retrieval_L, k=None):
        """Compute the mean average precision (mAP) of a Hamming ranking.

        For every query the retrieval items are ranked by ascending Hamming
        distance; an item counts as relevant when it shares at least one label
        with the query. Queries without any relevant item contribute 0 but are
        still counted in the final average.

        Args:
            qB: query codes, {-1,+1}^{m x q}.
            rB: retrieval codes, {-1,+1}^{n x q}.
            query_L: query labels, {0,1}^{m x l}.
            retrieval_L: retrieval labels, {0,1}^{n x l}.
            k: upper bound on how many relevant items (taken in ranking order)
                enter each query's average precision. Defaults to the size of
                the retrieval set. Note that this caps the number of relevant
                hits used; it does not truncate the ranking itself at k.

        Returns:
            0-dim tensor holding the mAP, or the Python float 0.0 when no
            query has any relevant item.
        """
        num_query = query_L.shape[0]
        if k is None:
            k = retrieval_L.shape[0]

        ap_sum = 0
        for idx in range(num_query):
            labels = query_L[idx].reshape(1, -1)
            # reshape(-1) instead of squeeze(): stays 1-D even when the
            # retrieval set holds a single item.
            relevant = (labels.mm(retrieval_L.t()) > 0).reshape(-1).type(torch.float32)
            num_relevant = int(relevant.sum())
            if num_relevant == 0:
                continue

            _, order = torch.sort(calc_hammingDist(qB[idx, :], rB))
            ranked = relevant[order.reshape(-1)]

            total = min(k, num_relevant)
            # 1-based ranking positions of the first `total` relevant items.
            hit_ranks = torch.nonzero(ranked).reshape(-1)[:total].type(torch.float32) + 1.0
            # Created directly on the device of hit_ranks, so inputs on any
            # CUDA device (not only the default one) work.
            hit_counts = torch.arange(
                1, total + 1, dtype=torch.float32, device=hit_ranks.device
            )
            ap_sum = ap_sum + torch.mean(hit_counts / hit_ranks)

        return ap_sum / num_query
    

    2.Precision-Recall Curve(参考:https://blog.csdn.net/HackerTom/article/details/89425729)

    import matplotlib.pyplot as plt
    def pr_curve(qB, rB, qL, rL, ep, task, topK=-1):
        """Compute and plot the precision-recall curve of a Hamming ranking.

        For every cut-off K = 1..topK, precision@K and recall@K are computed
        per query and averaged over all queries. A retrieval item is relevant
        when it shares at least one label with the query. Queries without any
        relevant item contribute 0 to both averages.

        Side effects: prints the precision and recall lists and shows the
        curve in a matplotlib figure (``plt.show()``).

        Args:
            qB: query codes, one row per query.
            rB: retrieval codes, one row per retrieval item.
            qL: query labels (multi-hot), one row per query.
            rL: retrieval labels (multi-hot), one row per retrieval item.
            ep: not used by this function; kept so existing callers keep working.
            task: used as the legend label of the plotted curve.
            topK: largest cut-off K; -1 or any value above the retrieval-set
                size means "use the whole retrieval set".

        Returns:
            Tuple ``(P, R)``: two lists of 0-dim tensors with the mean
            precision and mean recall for K = 1..topK.
        """
        n_retrieval = rB.shape[0]
        if topK == -1 or topK > n_retrieval:  # upper bound of K
            topK = n_retrieval
        topK = max(topK, 0)  # other negative values yield empty curves

        # Gnd[i, j] == 1 when query i and retrieval item j share a label.
        Gnd = (qL.mm(rL.transpose(0, 1)) > 0).type(torch.float32)
        _, Rank = torch.sort(calc_hammingDist(qB, rB))

        # Relevance of each query's top-K items in ranking order, then the
        # running count of relevant items found within the first K positions.
        ranked = torch.gather(Gnd, 1, Rank[:, :topK])
        hits = torch.cumsum(ranked, dim=1)

        cutoffs = torch.arange(1, topK + 1, dtype=torch.float32, device=hits.device)
        # clamp avoids 0/0 for queries with no relevant item; their hits are
        # all zero, so precision and recall stay 0 for them.
        n_relevant = Gnd.sum(dim=1, keepdim=True).clamp(min=1)

        P = list((hits / cutoffs).mean(dim=0).cpu())
        R = list((hits / n_relevant).mean(dim=0).cpu())
        print(P)
        print(R)

        # Draw the P-R curve: recall on the x axis, precision on the y axis.
        plt.figure(figsize=(5, 5))
        plt.plot(R, P, label=str(task))
        plt.grid(True)
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.xlabel('recall')
        plt.ylabel('precision')
        plt.legend()
        plt.show()

        return P, R
    

    3.实例说明

    if __name__ == '__main__':
        # Toy example: 4 queries and 6 retrieval items, each with a 4-bit
        # {-1,+1} code and a 4-dimensional multi-hot label vector.
        qB = torch.tensor([[1, -1, 1, 1],
                           [-1, -1, -1, 1],
                           [1, 1, -1, 1],
                           [1, 1, 1, -1]], dtype=torch.float32)
        rB = torch.tensor([[1, -1, 1, -1],
                           [-1, -1, 1, -1],
                           [-1, -1, 1, -1],
                           [1, 1, -1, -1],
                           [-1, 1, -1, -1],
                           [1, 1, -1, 1]], dtype=torch.float32)
        query_L = torch.tensor([[0, 1, 0, 0],
                                [1, 1, 0, 0],
                                [1, 0, 0, 1],
                                [0, 1, 0, 1]], dtype=torch.float32)
        retrieval_L = torch.tensor([[1, 0, 0, 1],
                                    [1, 1, 0, 0],
                                    [0, 1, 1, 0],
                                    [0, 0, 1, 0],
                                    [1, 0, 0, 0],
                                    [0, 0, 1, 0]], dtype=torch.float32)

        # Named map_score so the builtin map() is not shadowed.
        map_score = calc_map_k(qB, rB, query_L, retrieval_L)
        print("map", map_score)
        pr = pr_curve(qB, rB, query_L, retrieval_L, 2, 'i2t', topK=-1)
        print('pr', pr)
    

    4.结果

    map tensor(0.7042)
    [tensor(0.5000), tensor(0.5000), tensor(0.6667), tensor(0.6250), tensor(0.6000), tensor(0.5000)]
    [tensor(0.1458), tensor(0.3333), tensor(0.6875), tensor(0.8542), tensor(1.), tensor(1.)]  
    
    Figure_1.png

    相关文章

      网友评论

          本文标题:跨模态检索的mAP(@R)和 PR曲线(Precision-Re

          本文链接:https://www.haomeiwen.com/subject/tfyrjktx.html