美文网首页TensorFlow技术帖我爱编程
修改TensorFlow张量特定位置的值

修改TensorFlow张量特定位置的值

作者: 巾梵 | 来源:发表于2017-12-31 11:05 被阅读961次

这篇文章简单的讲一讲如何在TensorFlow里指定修改Variable类型张量指定坐标位置的值。

不得不吐槽TensorFlow的张量设计得蛋疼,明明支持下标和切片操作,却只支持到一半,只能读不能改。比如matrix是个二维的Variable,用matrix[x][y]下标,或者matrix[x1:x2][y1:y2]这样的切片能读取出指定位置或者范围的值,但是要是想局部更新一个张量可就没那么容易了。想写matrix[x][y] = 0?试试您就知道了0_0。(说它蛋疼是那是因为有对比,隔壁老李家MXNet的ndarray这么写就没得问题,溜溜的)

那莫,就只好曲线救国啦。首先搜StackOverflow,找到一篇回答How to update a subset of 2D tensor in Tensorflow?,大致的思路是,TensorFlow不让你直接单独改指定位置的值,但是留了个歪门儿,就是tf.scatter_update这个方法,它可以批量替换张量某一维上的所有数据。

照着这个思路改改,写出了第一版的解决方法。提取个函数的话,长成下面这个样子:

def set_value(matrix, x, y, val):
    # 提取出要更新的行
    row = tf.gather(matrix, x)
    # 构造这行的新数据
    new_row = tf.concat([row[:y], [val], row[y+1:]], axis=0)
    # 使用 tf.scatter_update 方法进正行替换
    matrix.assign(tf.scatter_update(matrix, x, new_row))                 

其中matrix是要更新的张量,x和y是目标坐标,val是要写入的值。其余的代码注释得很清楚了,不赘述。

问题解决,但是这么做有没什么缺点呢?有,那就是慢,特别是矩阵很大的时候,那是真心的慢。

继续想办法,TensorFlow是对张量运算(其实二维的就是矩阵运算)有速度优化的,能不能将张量修改的操作变成一个普通的张量运算呢?能,再构建一个差值张量然后做个加法,哎,又是一条旁门邪道。把刚刚的函数改改,参数不变,计算过程变成这样:

def set_value(matrix, x, y, val):
    # 得到张量的宽和高,即第一维和第二维的Size
    w = int(matrix.get_shape()[0])
    h = int(matrix.get_shape()[1])
    # 构造一个只有目标位置有值的稀疏矩阵,其值为目标值于原始值的差
    val_diff = val - matrix[x][y]
    diff_matrix = tf.sparse_tensor_to_dense(tf.SparseTensor(indices=[x, y], values=[val_diff], dense_shape=[w, h]))
    # 用 Variable.assign_add 将两个矩阵相加
    matrix.assign_add(diff_matrix)

注意在这个方法里面我用了一个tf.SparseTensor类型,这是一个TensorFlow里的稀疏张量(或者叫稀疏矩阵),构造它的时候只需要指定有值位置的内容,其余位置默认为0。这样一方面方便了差值张量的构造,另一方面大大的减少了内存的消耗(别忘了我们是要修改一个很大的矩阵)。

实测在我的场景下,后一种方法的效率大概提升了4倍。我的场景是什么呢?其实是cs20si课程作业1的第3题,具体的代码和上下文可以看Github仓库的这个文件

最后,祝各位TF Boy/Girl们,Happy Hacking。

相关文章

网友评论

    本文标题:修改TensorFlow张量特定位置的值

    本文链接:https://www.haomeiwen.com/subject/uzvjxxtx.html