美文网首页
TensorFlow核心概念之Tensor(2):索引切片

TensorFlow核心概念之Tensor(2):索引切片

作者: 老羊_肖恩 | 来源:发表于2023-01-09 23:25 被阅读0次

      张量(Tensor)是TensorFlow的基本数据结构。张量即多维数组(0~n维),Tensorflow的张量和Numpy中的ndarray对象很类似,用来定义一个 n 维数组对象,它是一个一系列相同类型元素组成的数组集合。Tensor的索引及切片操作和Numpy中的索引及切片操作基本上相差无几,这也是为啥Numpy是真个数据分析领域最重要的工具包之一,因为很多的后来者针对多维数组的创建和操作都是基于或者借鉴了Numpy的思想,比如Pandas中的DataFrame以及我们这里要介绍的TensorFlow中的Tensor。接下来我们分别从一维张量的索引及切片和多维张量的索引及切片来展示张量的索引及切片操作。在展示索引及切片功能之前,我们需要先引入NumpyTensorFlow的包,代码如下:

    import numpy as np
    import tensorflow as tf
    
    print(np.__version__)
    print(tf.__version__)
    

    输出如下:

    1.23.5
    2.11.0
    
    一维张量索引及切片

      一维张量的索引及切片操作与Numpy中的一维数组的索引及切片基本一致,也跟Pandas中的Series索引及切片差不多,只是TensorFlow中的张量分成常量和变量,对索引切片的赋值在Variable中需要使用assign等方法来实现。下面展示一维常量张量的索引及切片操作,代码如下:

    #一维张量的索引及切片
    tf.random.set_seed(1)
    c = tf.random.uniform([6], minval=0, maxval=10, dtype=tf.int32)
    print("原始张量:")
    tf.print(c)
    print("----------------------")
    
    #第i个元素,下标从0开始
    print("第[0, 3, 5]个元素:")
    tf.print(c[0], c[3], c[5])
    print("----------------------")
    
    #倒数第i个元素,下标从-1开始
    print("倒数第[1, 3, 6]个元素:")
    tf.print(c[-1], c[-3], c[-6])
    print("----------------------")
    
    #第i至第j个元素,不包括元素j
    print("第3至第5个元素:")
    tf.print(c[3:5])
    print("----------------------")
    

    结果如下:

    原始张量:
    [2 9 7 1 8 6]
    ----------------------
    第[0, 3, 5]个元素:
    2 1 6
    ----------------------
    倒数第[1, 3, 6]个元素:
    6 1 2
    ----------------------
    第3至第5个元素:
    [1 8]
    ----------------------
    

      由于常量是不可改变的,如果在建模的过程中需要用到可变的一维张量,那么可以使用一维的Variable张量,关于一维的Variable张量的索引及切片操作代码如下:

    #对Variable来说,可以使用assign来通过索引切片修改分元素的值
    v = tf.Variable([1, 2, 3, 4, 5, 6], dtype = tf.int32)
    print("原始Variable张量v:")
    tf.print(v)
    print("----------------------")
    
    #第i个元素,下表从0开始
    print("第[0, 3, 5]个元素:")
    tf.print(v[0], v[3], v[5])
    print("----------------------")
    
    #倒数第i个元素,下表从-1开始
    print("倒数第[1, 3, 6]个元素:")
    tf.print(v[-1], v[-3], v[-6])
    print("----------------------")
    
    #第i至第j个元素,不包括元素j
    print("第3至第5个元素:")
    tf.print(v[3:5])
    print("----------------------")
    
    v[0].assign(-1)
    print("将v[0]修改为-1后的张量v:")
    tf.print(v)
    print("----------------------")
    
    v[2:5].assign([-1, -1, -1])
    print("将v[2:5]修改为-1后的张量v:")
    tf.print(v)
    print("----------------------")
    
    #布尔索引:找出不为-1的所有值
    print("索引出所有不为-1的值")
    tf.print(v[v != -1])
    print("----------------------")
    

    结果如下:

    原始Variable张量v:
    [1 2 3 4 5 6]
    ----------------------
    第[0, 3, 5]个元素:
    1 4 6
    ----------------------
    倒数第[1, 3, 6]个元素:
    6 4 1
    ----------------------
    第3至第5个元素:
    [4 5]
    ----------------------
    将v[0]修改为-1后的张量v:
    [-1 2 3 4 5 6]
    ----------------------
    将v[2:5]修改为-1后的张量v:
    [-1 2 -1 -1 -1 6]
    ----------------------
    索引出所有不为-1的值
    [2 6]
    
    多维张量索引及切片

      多维张量的索引及切片根据切片和索引的内容是否规范,又可以分成规范索引及切片和不规范索引及切片,其中规范的意思值得是切片的范围为规则的行,列或者规则的矩阵范围内。而不规范的索引及切片一般为某些特定条件下的索引和切片,如布尔索引等。下面介绍简单的规范索引的实例,这里仅以constant常量作为实例,不在以Variable变量作为实例,代码如下:

    #多维张量
    tf.random.set_seed(5)
    m = tf.random.uniform([5, 5], minval=0, maxval=5, dtype=tf.int32)
    print("原始张量m:")
    tf.print(m)
    print("----------------------")
    
    #获取第i行切片,下标从0开始
    print("获取第[1, 3]行元素:")
    tf.print(m[1], m[3])
    print("----------------------")
    
    
    #倒数第i个元素,下表从-1开始
    print("获取倒数第[1, 3]行元素:")
    tf.print(m[-1], m[-3])
    print("----------------------")
    
    #获取第1行第3列元素
    print("获取倒数第1行第3列元素:")
    tf.print(m[1,3])
    tf.print(m[1][3])
    print("----------------------")
    
    #获取第1行至第3行元素
    print("获取第1行至第3行元素:")
    tf.print(m[1:4, :])
    print("----------------------")
    
    #使用tf.slice(input,begin_vector,size_vector)获取切片
    print("使用tf.slice获取第[1,0]至第[3,4]范围内的二维张量的元素:")
    tf.print(tf.slice(m, [1, 0], [3, 5])) 
    print("----------------------")
    
    #获取第1行至第5行元素,每两行取一行
    print("获取第1行至第5行,每隔两取一行的元素:")
    tf.print(m[1:5:2, :])
    tf.print(m[1:5:2, ...]) #也可以用省略号代替:,省略号可以代替多个:
    print("----------------------")
    

    结果如下:

    原始张量m:
    [[0 1 3 1 2]
     [1 4 0 1 1]
     [3 0 3 2 4]
     [4 0 0 2 1]
     [3 1 2 0 2]]
    ----------------------
    获取第[1, 3]行元素:
    [1 4 0 1 1] [4 0 0 2 1]
    ----------------------
    获取倒数第[1, 3]行元素:
    [3 1 2 0 2] [3 0 3 2 4]
    ----------------------
    获取倒数第1行第3列元素:
    1
    1
    ----------------------
    获取第1行至第3行元素:
    [[1 4 0 1 1]
     [3 0 3 2 4]
     [4 0 0 2 1]]
    ----------------------
    使用tf.slice获取第[1,0]至第[3,4]范围内的二维张量的元素:
    [[1 4 0 1 1]
     [3 0 3 2 4]
     [4 0 0 2 1]]
    ----------------------
    获取第1行至第5行,每隔两取一行的元素:
    [[1 4 0 1 1]
     [4 0 0 2 1]]
    [[1 4 0 1 1]
     [4 0 0 2 1]]
    ----------------------
    

      以上切片方式相对来说规则,而对于不规则的切片,TensorFlow为我们提供了gathergather_nd以及boolean_mask等方法。代码如下:

    #使用tf.gather,tf.gather_nd,tf.boolean_mask实现不规则切片
    tf.random.set_seed(5)
    m = tf.random.uniform([5, 5], minval=0, maxval=5, dtype=tf.int32)
    print("原始张量m:")
    tf.print(m)
    print("----------------------")
    
    #从axis=0的角度,选择第[0,1,3]行数据
    print("axis=0,索引[0,1,3]行数据:")
    g1 = tf.gather(m, [0,1,3], axis=0)
    tf.print(g1)
    print("----------------------")
    
    #从axis=1的角度,选择第[0,1,4]列数据
    print("axis=1,索引[0,1,4]列数据:")
    g2 = tf.gather(m, [0,1,4], axis=1)
    tf.print(g2)
    print("----------------------")
    
    
    #先从axis=0的角度,选取第[0,1,3]行数据,然后在此基础上从axis=1的角度,选择第[0,1,4]列数据
    print("axis=0索引[0,1,3]行,且axis=1索引[0,1,4]列数据:")
    g3 = tf.gather(tf.gather(m,[0, 1, 3], axis=0), [0,1,4], axis=1)
    tf.print(g3)
    print("----------------------")
    
    #使用tf.gather_nd按位置索引元素,indices为要索引元素的坐标
    gn1 = tf.gather_nd(m, indices = [(0,0),(1,1),(0,2)])
    tf.print(gn1)
    print("----------------------")
    
    #使用tf.boolean_mask实现按行索引
    print("a使用tf.boolean_mask按行索[0,1,3]行数据:")
    bm1 = tf.boolean_mask(m, [True,True,False,True,False], axis=0)
    tf.print(bm1)
    print("----------------------")
    
    #使用tf.boolean_mask实现按行索引
    print("a使用tf.boolean_mask按列索[0,1,4]列数据:")
    bm2 = tf.boolean_mask(m, [True,True,False,False,True], axis=1)
    tf.print(bm2)
    print("----------------------")
    
    #使用tf.boolean_mask实现位置索引
    print("a使用tf.boolean_mask按位置索引数据:")
    bm3 = tf.boolean_mask(m, 
                          [[True,True,False,False,True],
                          [False,False,False,False,True],
                          [False,False,True,False,False,],
                          [False,False,False,False,False],
                          [True,False,False,False,False]])
    tf.print(bm3)
    print("----------------------")
    
    #使用tf.boolean_mask实现布尔索引
    print("a使用tf.boolean_mask实现布尔索引:")
    bm4 = tf.boolean_mask(m, m < 1)
    tf.print(bm4)
    #使用tf.boolean_mask语法糖形式,实现布尔索引
    print("a使用tf.boolean_mask布尔索引语法糖:")
    bm5 = m[m < 1]
    tf.print(bm5)
    print("----------------------")
    

    结果如下:

    原始张量m:
    [[0 1 3 1 2]
     [1 4 0 1 1]
     [3 0 3 2 4]
     [4 0 0 2 1]
     [3 1 2 0 2]]
    ----------------------
    axis=0,索引[0,1,3]行数据:
    [[0 1 3 1 2]
     [1 4 0 1 1]
     [4 0 0 2 1]]
    ----------------------
    axis=1,索引[0,1,4]列数据:
    [[0 1 2]
     [1 4 1]
     [3 0 4]
     [4 0 1]
     [3 1 2]]
    ----------------------
    axis=0索引[0,1,3]行,且axis=1索引[0,1,4]列数据:
    [[0 1 2]
     [1 4 1]
     [4 0 1]]
    ----------------------
    [0 4 3]
    ----------------------
    a使用tf.boolean_mask按行索[0,1,3]行数据:
    [[0 1 3 1 2]
     [1 4 0 1 1]
     [4 0 0 2 1]]
    ----------------------
    a使用tf.boolean_mask按列索[0,1,4]列数据:
    [[0 1 2]
     [1 4 1]
     [3 0 4]
     [4 0 1]
     [3 1 2]]
    ----------------------
    a使用tf.boolean_mask按位置索引数据:
    [0 1 2 1 3 3]
    ----------------------
    a使用tf.boolean_mask实现布尔索引:
    [0 0 0 0 0 0]
    a使用tf.boolean_mask布尔索引语法糖:
    [0 0 0 0 0 0]
    ----------------------
    
    张量改值

      上面提到的方法主要用于对tf.constant张量或tf.Variable张量进行规则或者不规则索引及切片下的数据获取,会生成一个新的tf.constant张量,并不支持在获取原张量内容的前提下修改数据,即有什么查什么,满足条件的查出来即可。为了基于原张量返回一个修改若干值的新张量,TensorFlow提供了tf.wheretf.scatter_nd等方法来实现张量内容的修改。示例代码如下:

    #张量改值 
    tf.random.set_seed(5)
    m = tf.random.uniform([5, 5], minval=-1, maxval=5, dtype=tf.int32)
    print("原始张量m:")
    tf.print(m)
    print("----------------------")
    #将m小于0的值,替换成6
    print("利用tf.where替换-1为 666:")
    m1 = tf.where(m == -1, tf.fill(m.shape, 666), m)
    tf.print(m1)
    print("----------------------")
    
    print("利用tf.scatter_nd将m对角线数据倒叙插入新的m.shape的全0张量的对角线中")
    m2 = tf.scatter_nd([[0,0],[1,1],[2,2],[3,3],[4,4]], [m[4,4],m[3,3],m[2,2],m[1,1],m[0,0]], m.shape)
    tf.print(m2)
    print("----------------------")
    
    #如果where只有一个参数,将返回所有满足条件的位置坐标
    indices = tf.where(m < 0)
    print("挑选出所有值小于0的位置:")
    tf.print(indices)
    print("----------------------")
    print("按位置返回m中所有小于0的值,其他的值用0替换:")
    m3 = tf.scatter_nd(indices, tf.gather_nd(m, indices), m.shape)
    tf.print(m3)
    print("----------------------")
    

    结果如下:

    原始张量m:
    [[1 0 1 3 0]
     [-1 2 4 4 0]
     [-1 4 3 4 -1]
     [2 -1 1 2 2]
     [1 2 2 0 -1]]
    ----------------------
    利用tf.where替换-1为 666:
    [[1 0 1 3 0]
     [666 2 4 4 0]
     [666 4 3 4 666]
     [2 666 1 2 2]
     [1 2 2 0 666]]
    ----------------------
    利用tf.scatter_nd将m对角线数据倒叙插入新的m.shape的全0张量的对角线中
    [[-1 0 0 0 0]
     [0 2 0 0 0]
     [0 0 3 0 0]
     [0 0 0 2 0]
     [0 0 0 0 1]]
    ----------------------
    挑选出所有值小于0的位置:
    [[1 0]
     [2 0]
     [2 4]
     [3 1]
     [4 4]]
    ----------------------
    按位置返回m中所有小于0的值,其他的值用0替换:
    [[0 0 0 0 0]
     [-1 0 0 0 0]
     [-1 0 0 0 -1]
     [0 -1 0 0 0]
     [0 0 0 0 -1]]
    ----------------------
    

      以上关于TensorFlow中张量的索引及切片操作进行了简单的介绍,除了TensorFlow内置的一些方法外,其实基本上大部分的索引及切片操作都类似于Numpy中针对ndarray的索引及切片操作,如果熟悉Numpy的话,这一块看起来将毫无压力,因为百分之90以上是一样的。好了关于TensorFlow中张量的索引及切片操作就介绍这么多。

    TensorFlow系列文章:
    1. TensorFlow核心概念之Tensor(1):张量创建
    2. TensorFlow核心概念之Tensor(2):索引切片
    3. TensorFlow核心概念之Tensor(3):变换拆合
    4. TensorFlow核心概念之Tensor(4):张量运算
    5. TensorFlow核心概念之计算图

    相关文章

      网友评论

          本文标题:TensorFlow核心概念之Tensor(2):索引切片

          本文链接:https://www.haomeiwen.com/subject/rbqmcdtx.html