美文网首页
Tesorflow中计算Pairwise的Euclidean D

Tesorflow中计算Pairwise的Euclidean D

作者: Nevrast | 来源:发表于2019-06-20 19:39 被阅读0次

    问题场景

    已知两组向量为:
    \begin{array}{l}{X=\left\{\mathbf{x}_{1}, \mathbf{x}_{2}, \cdots, \mathbf{x}_{n}\right\}} \\ {Y=\left\{\mathbf{y}_{1}, \mathbf{y}_{2}, \cdots, \mathbf{y}_{m}\right\}}\end{array}
    现在要计算X中每一个向量和Y中每一个向量的欧式距离。

    解决思路一

    X中向量使用tf.tile复制m次,把Y中向量复制n次。
    \left[\begin{array}{cccc}{\mathbf{x}_{1}} & {\mathbf{x}_{1}} & {\cdots} & {\mathbf{x}_{1}} \\ {\mathbf{x}_{2}} & {\mathbf{x}_{2}} & {\cdots} & {\mathbf{x}_{2}} \\ {\vdots} & {\vdots} & {\vdots} & {\vdots} \\ {\mathbf{x}_{n}} & {\mathbf{x}_{n}} & {\cdots} & {\mathbf{x}_{n}}\end{array}\right] - \left[\begin{array}{cccc}{\mathbf{y}_{1}} & {\mathbf{y}_{2}} & {\cdots} & {\mathbf{y}_{m}} \\ {\mathbf{y}_{1}} & {\mathbf{y}_{2}} & {\cdots} & {\mathbf{y}_{m}} \\ {\vdots} & {\vdots} & {\vdots} & {\vdots} \\ {\mathbf{y}_{1}} & {\mathbf{y}_{2}} & {\cdots} & {\mathbf{y}_{m}}\end{array}\right]
    然后按照向量距离公式\|\mathbf{x}-\mathbf{y}\|就可以得到n\times m的距离矩阵了。
    缺点:

    1. 计算量大,复杂度大概是n\times m \times kk为向量维度。
    2. tf.tile后要保存两个很大的矩阵,占资源

    解决思路二

    利用完全平方公式,对于距离矩阵D中的D_{ij}元素(表示X中第i个向量和Y中的第j个向量之间的欧式距离):
    \begin{aligned} D_{i j} &=\left(\mathbf{x}_{i}-\mathbf{y}_{j}\right)^{2} \\ &=\mathbf{x}_{i}^{2}-2 \mathbf{x}_{i} \mathbf{y}_{j}+\mathbf{y}_{j}^{2} \end{aligned}
    其中的\mathbf{x}_{i} \mathbf{y}_{j}矩阵可以由两个向量相乘快速计算。

    import tensorflow as tf
    
    
    def euclidean_dist(x, y):
        square_x = tf.reduce_sum(tf.square(x), axis=-1)
        square_y = tf.reduce_sum(tf.square(y), axis=-1)
        # expand dims for broadcasting
        ex = tf.expand_dims(square_x, axis=-1)
        ey = tf.expand_dims(square_y, axis=-2)
        # XY matrix
        xy = tf.einsum('bij,bkj->bik', x, y)
        # 如果没有batch_size这个维度,可以写成:
        # xy = tf.einsum('ij,kj->ik', x, y)
        # compute distance,浮点防溢出
        dist = tf.sqrt(ex - 2 * xy + ey + 1e-10)
        return dist
    
    

    计算dist的时候加上了1e-10,因为\sqrt{x}x=0处导数不存在,需要做平滑。

    优点:

    1. XY中向量的平方不用重复计算,加速运算。
    2. xy矩阵是标量矩阵,节约显存资源,这个有时很重要!

    相关文章

      网友评论

          本文标题:Tesorflow中计算Pairwise的Euclidean D

          本文链接:https://www.haomeiwen.com/subject/qrtsqctx.html