美文网首页我爱编程
tensorflow学习-embedding_lookup()用

tensorflow学习-embedding_lookup()用

作者: 听风1996 | 来源:发表于2018-04-08 15:12 被阅读441次

    embedding_lookup( )的用法
    关于tensorflow中embedding_lookup( )的用法,在Udacity的word2vec会涉及到,本文将通俗的进行解释

    #!/usr/bin/env/python
    # coding=utf-8
    import tensorflow as tf
    import numpy as np
    
    input_ids = tf.placeholder(dtype=tf.int32, shape=[None])
    
    embedding = tf.Variable(np.identity(5, dtype=np.int32))
    input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
    
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    print(embedding.eval())
    print(sess.run(input_embedding, feed_dict={input_ids:[1, 2, 3, 0, 3, 2, 1]}))
    

    代码中先使用palceholder定义了一个未知变量input_ids用于存储索引,和一个已知变量embedding,是一个5*5的对角矩阵。
    运行结果为:

    embedding = [[1 0 0 0 0]
                 [0 1 0 0 0]
                 [0 0 1 0 0]
                 [0 0 0 1 0]
                 [0 0 0 0 1]]
    input_embedding = [[0 1 0 0 0]
                       [0 0 1 0 0]
                       [0 0 0 1 0]
                       [1 0 0 0 0]
                       [0 0 0 1 0]
                       [0 0 1 0 0]
                       [0 1 0 0 0]]
    

    简单的讲就是根据input_ids中的id,寻找embedding中的对应元素。比如,input_ids=[1,3,5],则找出embedding中下标为1,3,5的向量组成一个矩阵返回。
    如果将input_ids改写成下面的格式:

    input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
    print(sess.run(input_embedding, feed_dict={input_ids:[[1, 2], [2, 1], [3, 3]]}))
    

    输出结果就会变成如下的格式:

    [[[0 1 0 0 0]
      [0 0 1 0 0]]
     [[0 0 1 0 0]
      [0 1 0 0 0]]
     [[0 0 0 1 0]
      [0 0 0 1 0]]]
    

    对比上下两个结果不难发现,相当于在np.array中直接采用下标数组获取数据。需要注意的细节是返回的tensor的dtype和传入的被查询的tensor的dtype保持一致;和ids的dtype无关。

    相关文章

      网友评论

        本文标题:tensorflow学习-embedding_lookup()用

        本文链接:https://www.haomeiwen.com/subject/ywzphftx.html