美文网首页
实现三维tensor计算欧式距离

实现三维tensor计算欧式距离

作者: 全村希望gone | 来源:发表于2019-06-27 11:41 被阅读0次

    前言

    在网上找不到实现三维tensor计算欧式距离的代码,tensorflow中也没有封装,于是自己写,写得时候有点曲折,花了两个小时,主要是因为三维tensor的维度问题,稍不注意便出错,还需要多用debug和print才能看懂,还有二维转三维,不得不说tf.expand_ndims真好用。Talk is cheap,show you the code.


    补充

    刚刚在项目中用了一下,发现根本不使用,如果你的tensor非常大,batch很多,那么运行起来会非常非常慢,但是稍小一点的tensor还是可以用的。

    代码

    import tensorflow as tf
    
    a = tf.constant([[[1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]],
                     [[1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]]], dtype=tf.float32)
    
    
    def euclidean_metric(embeddings):
        width = embeddings.shape.as_list()[1]
        # 容器,便于拼接之后的向量
        output = tf.sqrt(tf.reduce_sum(tf.square(embeddings[:, 0] - embeddings[:, 0]), axis=1, keepdims=True))
        # 两次循环是因为每个向量与其它向量都有一个欧式距离,而且欧氏距离矩阵的大小就是width*width
        for i in range(width):
            for j in range(width):
                # 计算距离
                distance = tf.sqrt(tf.reduce_sum(tf.square(embeddings[:, i] - embeddings[:, j]), axis=1, keepdims=True))
                # 在行的方向上拼接
                output = tf.concat([output, distance], 1)
        # 将最开始的容器去掉
        output_slice = tf.slice(output, [0, 1], [-1, width * width])
        # 升维,与原embeddings维度一致,方便后面reshape
        output_expand = tf.expand_dims(output_slice, -1)
        output_reshape = tf.reshape(output_expand, [-1, width, width])
        return output_reshape
    
    
    with tf.Session() as sess:
        b = euclidean_metric(a)
        print(sess.run(b))import tensorflow as tf
    
    a = tf.constant([[[1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]],
                     [[1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]]], dtype=tf.float32)
    

    输出结果

    [[[0.       2.236068 4.472136]
      [2.236068 0.       2.236068]
      [4.472136 2.236068 0.      ]]
    
     [[0.       2.236068 4.472136]
      [2.236068 0.       2.236068]
      [4.472136 2.236068 0.      ]]]
    

    相关文章

      网友评论

          本文标题:实现三维tensor计算欧式距离

          本文链接:https://www.haomeiwen.com/subject/anmlcctx.html