[Reference links]
tf.nn.embedding_lookup()
tf.gather()
Reading the source code of embedding_lookup, it is easy to see that it is just a special form of gather: the underlying implementation of embedding_lookup ultimately calls gather.
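To make the relationship concrete, here is a minimal sketch (the table values and ids are made up for illustration, and it assumes a working TensorFlow install; the snippet runs eagerly under TF 2.x). With a single, unpartitioned params tensor, embedding_lookup returns exactly the same rows as a plain gather:

import tensorflow as tf

# A 4-row embedding table and the row ids to look up (illustrative values).
params = tf.constant([[0.0, 0.1],
                      [1.0, 1.1],
                      [2.0, 2.1],
                      [3.0, 3.1]])
ids = tf.constant([3, 0, 2])

via_lookup = tf.nn.embedding_lookup(params, ids)
via_gather = tf.gather(params, ids)
# Both yield [[3.0, 3.1], [0.0, 0.1], [2.0, 2.1]].

What embedding_lookup adds on top of gather is the handling of a partitioned (sharded) embedding table. The relevant TensorFlow v1 source is reproduced below: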
@tf_export(v1=["nn.embedding_lookup"])
def embedding_lookup(
params,
ids,
partition_strategy="mod",
name=None,
validate_indices=True, # pylint: disable=unused-argument
max_norm=None):
"""Looks up `ids` in a list of embedding tensors.
This function is used to perform parallel lookups on the list of
tensors in `params`. It is a generalization of
`tf.gather`, where `params` is
interpreted as a partitioning of a large embedding tensor. `params` may be
a `PartitionedVariable` as returned by using `tf.compat.v1.get_variable()`
with a
partitioner.
If `len(params) > 1`, each element `id` of `ids` is partitioned between
the elements of `params` according to the `partition_strategy`.
In all strategies, if the id space does not evenly divide the number of
partitions, each of the first `(max_id + 1) % len(params)` partitions will
be assigned one more id.
If `partition_strategy` is `"mod"`, we assign each id to partition
`p = id % len(params)`. For instance,
13 ids are split across 5 partitions as:
`[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
If `partition_strategy` is `"div"`, we assign ids to partitions in a
contiguous manner. In this case, 13 ids are split across 5 partitions as:
`[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`
The results of the lookup are concatenated into a dense
tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
Args:
params: A single tensor representing the complete embedding tensor, or a
list of P tensors all of same shape except for the first dimension,
representing sharded embedding tensors. Alternatively, a
`PartitionedVariable`, created by partitioning along dimension 0. Each
element must be appropriately sized for the given `partition_strategy`.
ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
up in `params`.
partition_strategy: A string specifying the partitioning strategy, relevant
if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
is `"mod"`.
name: A name for the operation (optional).
validate_indices: DEPRECATED. If this operation is assigned to CPU, values
in `indices` are always validated to be within range. If assigned to GPU,
out-of-bound indices result in safe but unspecified behavior, which may
include raising an error.
max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
than this value.
Returns:
A `Tensor` with the same type as the tensors in `params`.
Raises:
ValueError: If `params` is empty.
"""
return _embedding_lookup_and_transform(
params=params,
ids=ids,
partition_strategy=partition_strategy,
name=name,
max_norm=max_norm,
transform_fn=None)
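Before looking at the helper that does the real work, the docstring's partitioning example is easy to reproduce. The sketch below (plain Python, using the same 13-ids / 5-partitions numbers as the docstring above) shows which ids each strategy assigns to each partition:

# Reproduce the docstring example: 13 ids split across 5 partitions.
num_ids, num_parts = 13, 5

# "mod" strategy: id i goes to partition i % num_parts.
mod_split = [[i for i in range(num_ids) if i % num_parts == p]
             for p in range(num_parts)]
# -> [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]

# "div" strategy: ids are assigned contiguously; the first
# num_ids % num_parts partitions each get one extra id.
ids_per_part, extras = divmod(num_ids, num_parts)
div_split, start = [], 0
for p in range(num_parts):
    size = ids_per_part + (1 if p < extras else 0)
    div_split.append(list(range(start, start + size)))
    start += size
# -> [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]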
def _embedding_lookup_and_transform(params,
ids,
partition_strategy="mod",
name=None,
max_norm=None,
transform_fn=None):
"""Helper function for embedding_lookup and _compute_sampled_logits.
This function is a generalization of embedding_lookup that optionally
applies a caller-specified transformation to each embedding. This is
done through the `transform_fn` argument. If provided, the function is
applied to each partitioned tensor of retrieved embeddings, colocated
with the embeddings. This function will be called with a single `Tensor`
argument of the same type as the `params` tensor and should return a
`Tensor`. The shape of the argument will be the same as `params` except
for the size of the first dimension. The first dimension of the result's
shape must be the same size as the argument's.
Args:
params: See embedding_lookup.
ids: See embedding_lookup.
partition_strategy: See embedding_lookup.
name: See embedding_lookup.
max_norm: See embedding_lookup.
transform_fn: An optional function to apply to each retrieved embedding. If
max_norm is provided, transform_fn is applied to the norm-limited
embeddings.
Returns:
See embedding_lookup for details.
Raises:
ValueError: If `params` is empty.
"""
if params is None or params in ((), []):
raise ValueError("Need at least one param")
if isinstance(params, variables.PartitionedVariable):
params = list(params) # Iterate to get the underlying Variables.
if not isinstance(params, list):
params = [params]
with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
np = len(params) # Number of partitions
# Preserve the resource variable status to avoid accidental dense reads.
if not any(
isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
ids = ops.convert_to_tensor(ids, name="ids")
if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
with ops.colocate_with(params[0]):
result = _clip(
array_ops.gather(params[0], ids, name=name), ids, max_norm)
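# Single-partition fast path (annotation added here, not in the TF source):
# with one params tensor the lookup is literally a tf.gather on params[0],
# followed by optional norm clipping.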
if transform_fn:
result = transform_fn(result)
# Make sure the final result does not have colocation constraints on the
# params. Similar to the case np > 1 where parallel_dynamic_stitch is
# outside the scope of all with ops.colocate_with(params[p]).
return array_ops.identity(result)
else:
# Flatten the ids. There are two cases where we need to do this.
# - There is more than one params tensor.
# - There is a transform_fn and ids is not statically known to be 1-D.
# We must flatten in this case because transform_fn expects a flat
# tensor of embeddings.
flat_ids = array_ops.reshape(ids, [-1])
original_indices = math_ops.range(array_ops.size(flat_ids))
# Create p_assignments and set new_ids depending on the strategy.
if partition_strategy == "mod":
p_assignments = flat_ids % np
new_ids = flat_ids // np
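# Worked example (illustrative, not part of the original source): with
# np = 5 and flat_ids = [0, 1, ..., 12], p_assignments is
# [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2] and new_ids is
# [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2]; e.g. id 12 becomes row 2 of
# partition 2, matching the "mod" split in the docstring above.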
elif partition_strategy == "div":
# Compute num_total_ids as the sum of dim-0 of params, then assign to
# partitions based on a constant number of ids per partition. Optimize
# if we already know the full shape statically.
dim_0_size = tensor_shape.Dimension(
tensor_shape.dimension_value(params[0].get_shape()[0]))
for p in xrange(1, np):
dim_0_size += tensor_shape.Dimension(
tensor_shape.dimension_value(params[p].get_shape()[0]))
if dim_0_size.value:
num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
else:
dim_0_sizes = []
for p in xrange(np):
param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
if param_p_dim is not None:
dim_0_sizes.append(param_p_dim)
else:
with ops.colocate_with(params[p]):
dim_0_sizes.append(array_ops.shape(params[p])[0])
num_total_ids = math_ops.reduce_sum(
math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
ids_per_partition = num_total_ids // np
extras = num_total_ids % np
p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
(flat_ids - extras) //
ids_per_partition)
# Emulate a conditional using a boolean indicator tensor
new_ids = array_ops.where(p_assignments < extras,
flat_ids % (ids_per_partition + 1),
(flat_ids - extras) % ids_per_partition)
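# Worked example (illustrative, not part of the original source): with 13
# total ids and np = 5, ids_per_partition = 2 and extras = 3. Id 7 is
# assigned to partition max(7 // 3, (7 - 3) // 2) = 2, and since 2 < extras
# its row within that partition is 7 % 3 = 1, matching the [6, 7, 8] block
# of the "div" split shown in the docstring above.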
else:
raise ValueError("Unrecognized partition strategy: " +
partition_strategy)
# Cast partition assignments to int32 for use in dynamic_partition.
# There really should not be more than 2^32 partitions.
p_assignments = math_ops.cast(p_assignments, dtypes.int32)
# Partition list of ids based on assignments into np separate lists
gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
# Similarly, partition the original indices.
pindices = data_flow_ops.dynamic_partition(original_indices,
p_assignments, np)
# Do np separate lookups, finding embeddings for plist[p] in params[p]
partitioned_result = []
for p in xrange(np):
pids = gather_ids[p]
with ops.colocate_with(params[p]):
result = array_ops.gather(params[p], pids)
if transform_fn:
# If transform_fn is provided, the clip_by_norm precedes
# the transform and hence must be co-located. See below
# for the counterpart if transform_fn is not provided.
result = transform_fn(_clip(result, pids, max_norm))
partitioned_result.append(result)
# Stitch these back together
ret = data_flow_ops.parallel_dynamic_stitch(
pindices, partitioned_result, name=name)
# Determine the static element shape.
if transform_fn is None:
element_shape_s = params[0].get_shape()[1:]
for p in params[1:]:
element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
else:
element_shape_s = ret.get_shape()[1:]
# Compute the dynamic element shape.
if element_shape_s.is_fully_defined():
element_shape_d = element_shape_s
elif transform_fn is None:
# It's important that we compute params[0].shape on the right device
# to avoid data motion.
with ops.colocate_with(params[0]):
params_shape = array_ops.shape(params[0])
element_shape_d = params_shape[1:]
else:
element_shape_d = array_ops.shape(ret)[1:]
# Reshape to reverse the flattening of ids.
ret = array_ops.reshape(
ret, array_ops.concat([array_ops.shape(ids), element_shape_d], 0))
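# Shape example (illustrative, not part of the original source): ids of
# shape [4, 2] looked up in 16-dimensional embeddings are flattened to 8
# lookups above and reshaped back here to [4, 2, 16].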
# Normally the reshape is sufficient, but setting shape explicitly
# teaches shape inference that params[1:].get_shape() matters
# (in the case that transform_fn is None).
ret.set_shape(ids.get_shape().concatenate(element_shape_s))
if not transform_fn:
# If transform_fn was provided, the clip_by_norm was done above.
ret = _clip(ret, ids, max_norm)
return ret
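The multi-partition path above boils down to three steps: assign each id to a shard, gather within each shard, then stitch the rows back into the original order. The following sketch replays that logic for the "mod" strategy, with the public tf.dynamic_partition and tf.dynamic_stitch ops standing in for the internal parallel_dynamic_stitch (the two shards and the ids are made-up values):

import tensorflow as tf

# Two shards of a 6-row table, partitioned with the "mod" strategy:
# shard 0 holds rows 0, 2, 4 and shard 1 holds rows 1, 3, 5.
shards = [
    tf.constant([[0.0], [2.0], [4.0]]),  # params[0]
    tf.constant([[1.0], [3.0], [5.0]]),  # params[1]
]
ids = tf.constant([5, 0, 3, 4])
num_parts = len(shards)

p_assignments = tf.cast(ids % num_parts, tf.int32)  # which shard each id lives in
new_ids = ids // num_parts                          # row index inside that shard
original_indices = tf.range(tf.size(ids))

gather_ids = tf.dynamic_partition(new_ids, p_assignments, num_parts)
pindices = tf.dynamic_partition(original_indices, p_assignments, num_parts)
partitioned = [tf.gather(shards[p], gather_ids[p]) for p in range(num_parts)]
result = tf.dynamic_stitch(pindices, partitioned)
# result == [[5.0], [0.0], [3.0], [4.0]], identical to gathering the same
# ids from the full, unsharded table.

So whether params is a single tensor or a list of shards, every embedding row is ultimately fetched by array_ops.gather; embedding_lookup only adds the bookkeeping for sharding, norm clipping, and reshaping.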