美文网首页
tf.nn.embedding_lookup用法浅析

tf.nn.embedding_lookup用法浅析

作者: 陈晓峥 | 来源:发表于2018-10-31 17:38 被阅读0次

tf.nn.embedding_lookup的用法主要是选取一个张量里面索引对应的元素

原型:tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None, validate_indices=True, max_norm=None)

params 代表输入的张量,ids代表要选取params里对应的那个维度的数据

简单来个例子(粘贴可直接运行)

import tensorflow as tf

import numpy as np

a = [[0.1, 0.2, 0.3], [1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3]]

a = np.asarray(a)

idx1 = tf.Variable([0, 2, 3, 1], tf.int32)

idx2 = tf.Variable([[0, 2, 3, 1], [4, 0, 2, 2]], tf.int32)

b = [[0.1, 0.2, 1], [2.1, 1.2, 1]]

b = np.asarray(b)

idx3 = tf.placeholder(tf.int32, [None, 3], name="input_x")

out1 = tf.nn.embedding_lookup(a, idx1)

out2 = tf.nn.embedding_lookup(a, idx2)

out3 = tf.nn.embedding_lookup(a, idx3)

init = tf.global_variables_initializer()

with tf.Session() as sess:

sess.run(init)

print (sess.run(out1))

print (out1)

print ('==================')

print (sess.run(out2))

print (out2)

print (sess.run(out3, feed_dict ={idx3: b}))

print (out3)

其输入内容为

咱们一个一个分析

1.第一个out1代表从a中依次取第 0, 2, 3, 1维数据进行拼装,拼出来的shape还是(4,3)

2.第二个out2代表从a中依次取 第0, 2, 3, 1维数据拼装一个(4,3)的数据 接着再从a中依次取4, 0, 2, 2 维来进行拼装,之后再把两个(4, 3) 拼装在一起形成(2,4,3)的张量(tensor)

3.第三个使用了placeholder来输入ids,placeholder的shape为(?,3),代表从数据里先取3个数据出来,每个数据有3个元素,最后再 ?个(3, 3)拼接在一起组成(?,3,3)的tensor

自己多动手多跑跑例子就可以了。

如有问题欢迎大家指正,谢谢

相关文章

网友评论

      本文标题:tf.nn.embedding_lookup用法浅析

      本文链接:https://www.haomeiwen.com/subject/zqnatqtx.html