tf.slice() / tf.strided_slice()

Author: sterio | Published: 2017-09-21 11:14

    Slicing in TensorFlow works much like slicing in NumPy. Both tf.slice and tf.strided_slice are documented in the TensorFlow API reference.

    tf.slice

    tf.slice(
    input_, # input tensor
    begin, # begin location
    size, # output tensor size
    name=None # name of operation
    )
    

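    In the example below, a equals b. tf.slice has no stride argument and always takes contiguous elements, which corresponds to strides of [1,1,1] in tf.strided_slice. The end passed to tf.strided_slice is begin + size: [0,1,0] + [2,2,2] = [2,3,2].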

    import tensorflow as tf
    input = [[[1,1,1],[2,2,2],[3,3,3]],
             [[4,4,4],[5,5,5],[6,6,6]],
             [[7,7,7],[8,8,8],[9,9,9]]]
    
    sess = tf.Session()
    
    a = tf.slice(input, [0,1,0], [2,2,2])
    print(sess.run(a))
    print("\n")
    
    b = tf.strided_slice(input, [0,1,0], [2,3,2], [1,1,1])
    print(sess.run(b))
    
    [[[2 2]
      [3 3]]
    
     [[5 5]
      [6 6]]]
    
    
    [[[2 2]
      [3 3]]
    
     [[5 5]
      [6 6]]]
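
The relation between (begin, size) and (begin, end) can be checked without a TensorFlow session. The following is a minimal NumPy sketch, not the TensorFlow implementation; `slice_like` is a hypothetical helper name introduced only for illustration. It also covers the case where an entry of size is -1, which tf.slice treats as "all remaining elements in that dimension".

```python
import numpy as np

def slice_like(x, begin, size):
    """NumPy analogue of tf.slice (hypothetical helper, for illustration only)."""
    x = np.asarray(x)
    index = []
    for dim, (b, s) in enumerate(zip(begin, size)):
        # A size of -1 means "everything from begin to the end of this dimension".
        stop = x.shape[dim] if s == -1 else b + s
        index.append(slice(b, stop))
    return x[tuple(index)]

data = [[[1, 1, 1], [2, 2, 2], [3, 3, 3]],
        [[4, 4, 4], [5, 5, 5], [6, 6, 6]],
        [[7, 7, 7], [8, 8, 8], [9, 9, 9]]]

a = slice_like(data, [0, 1, 0], [2, 2, 2])   # begin + size form
b = np.asarray(data)[0:2, 1:3, 0:2]          # begin/end form, stride 1
print(np.array_equal(a, b))                  # True
```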
    

    tf.strided_slice

    tf.strided_slice(
    input_, # input tensor
    begin, # begin location
    end, # end location, ATTENTION: not included!
    strides=None, # strides of slice
    begin_mask=0,
    end_mask=0,
    ellipsis_mask=0,
    new_axis_mask=0,
    shrink_axis_mask=0,
    var=None,
    name=None
    )
    
    tf.strided_slice 1-dim operation: tf.strided_slice(input_data, [begin_element], [end_element (not included)], [stride_step])
    import tensorflow as tf
    input_data = [1,2,3,4,5,6,7,8,9]
    a = tf.strided_slice(input_data, [0], [4], [1])
    b = tf.strided_slice(input_data, [0], [-1], [1])
    c = tf.strided_slice(input_data, [0], [-1], [2])
    d = tf.strided_slice(input_data, [0], [-1])
    e = tf.strided_slice(input_data, [-1], [-2], [-1])
    
    a begins at input_data[0] and ends before input_data[4] (index 4 is excluded), with stride 1
    [1 2 3 4]
    
    
    b begins at input_data[0] and ends before the last element (index -1 is excluded), with stride 1
    [1 2 3 4 5 6 7 8]
    
    
    c begins at input_data[0] and ends before the last element, with stride 2
    [1 3 5 7]
    
    
    d begins at input_data[0] and ends before the last element; the stride defaults to 1, so d equals b
    [1 2 3 4 5 6 7 8]
    
    
    e begins at the last element and walks backwards with stride -1, ending before the second-to-last element (index -2 is excluded)
    [9]
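
For comparison, the same five cases can be written with ordinary Python slice notation, where begin, end and stride map to start:stop:step. This is only a plain-Python analogue of the calls above, using a list instead of a tensor.

```python
input_data = [1, 2, 3, 4, 5, 6, 7, 8, 9]

results = {
    "a": input_data[0:4],        # begin 0, end 4 (excluded), stride 1
    "b": input_data[0:-1],       # end -1 excludes the last element
    "c": input_data[0:-1:2],     # same range, every second element
    "d": input_data[0:-1],       # stride omitted, defaults to 1
    "e": input_data[-1:-2:-1],   # start at the last element, step backwards once
}

for name, value in results.items():
    print(name, value)
```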
    
    tf.strided_slice higher-dimensional operation:
    import tensorflow as tf
    import numpy as np
    input_data = np.arange(60).reshape(3, 4, 5)
    
    slice_np_data = input_data[1:2, 0:2, 0:2]
    print(slice_np_data)
    >>>
    [[[20 21]
      [25 26]]]
    
    slice_tensor = tf.strided_slice(input_data,[1,0,0],[2,2,2])
    with tf.Session() as sess:
        print(sess.run(slice_tensor))
    >>>
    [[[20 21]
      [25 26]]]
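
The higher-dimensional example uses the default stride of 1 in every dimension. As a rough sketch of what non-unit strides do, the helper below builds one Python slice per dimension from begin, end and strides and applies it to a NumPy array. `strided_slice_np` is a hypothetical name, and the sketch ignores the mask arguments (begin_mask, end_mask and so on), so it is only an approximation of tf.strided_slice for the simple cases shown here.

```python
import numpy as np

def strided_slice_np(x, begin, end, strides=None):
    """NumPy analogue of tf.strided_slice without masks (illustrative only)."""
    x = np.asarray(x)
    if strides is None:
        strides = [1] * len(begin)   # same default as in the examples above
    index = tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))
    return x[index]

input_data = np.arange(60).reshape(3, 4, 5)

# Unit strides: same result as input_data[1:2, 0:2, 0:2]
unit = strided_slice_np(input_data, [1, 0, 0], [2, 2, 2])
print(unit.tolist())      # [[[20, 21], [25, 26]]]

# Different stride per dimension: same result as input_data[0:3:2, 0:4:3, 0:5:4]
stepped = strided_slice_np(input_data, [0, 0, 0], [3, 4, 5], [2, 3, 4])
print(stepped.tolist())   # [[[0, 4], [15, 19]], [[40, 44], [55, 59]]]
```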
    
    tf.strided_slice in a seq2seq model:

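    In a seq2seq model, the target sequences are sliced before they are fed to the decoder network: the last token of each sequence is dropped and a start token '<s>' is prepended.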

    ending = tf.strided_slice(targets, [0, 0], [batch_size, -1], [1, 1])
    dec_input = tf.concat([tf.fill([batch_size, 1], target_letter_to_int['<s>']), ending], 1)
    

    Suppose batch_size is n. tf.strided_slice(targets, [0, 0], [batch_size, -1], [1, 1]) keeps every sequence in the batch and drops the last integer of each sequence, i.e. ending = targets[0:n, 0:-1].
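
The effect of these two lines on a small batch can be sketched with NumPy. The token ids below (`GO_ID` for '<s>', 3 as an end marker, 0 as padding) and the contents of `targets` are assumptions made up for the illustration; they do not come from the original model.

```python
import numpy as np

GO_ID = 1  # hypothetical integer id standing in for target_letter_to_int['<s>']

# Hypothetical batch of 2 target sequences (3 = end marker, 0 = padding)
targets = np.array([[5, 6, 7, 3],
                    [8, 9, 3, 0]])
batch_size = targets.shape[0]

# Counterpart of tf.strided_slice(targets, [0, 0], [batch_size, -1], [1, 1])
ending = targets[0:batch_size, 0:-1]

# Counterpart of tf.fill([batch_size, 1], ...) followed by tf.concat(..., 1)
go_column = np.full((batch_size, 1), GO_ID)
dec_input = np.concatenate([go_column, ending], axis=1)

print(dec_input.tolist())   # [[1, 5, 6, 7], [1, 8, 9, 3]]
```

The decoder input keeps the same length as the targets: one token is removed at the end of each row and one is added at the front.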

    Article link: https://www.haomeiwen.com/subject/ikgnsxtx.html