TensorFlow 对类别特征会先将其转换成字符串，然后再做 hash。代码实现如下：
// Kernel that maps each string of the "input" tensor to an integer bucket id.
//
// `hash` is a non-type template parameter: a function taking a StringPiece and
// returning a 64-bit hash. Each element becomes hash(s) % num_buckets, written
// as int64 into an "output" tensor that has the same shape as "input".
template <uint64 hash(StringPiece)>
class StringToHashBucketOp : public OpKernel {
 public:
  // Reads the "num_buckets" attribute; if it cannot be read, OP_REQUIRES_OK
  // reports the failure on `ctx`.
  explicit StringToHashBucketOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor* in = nullptr;
    OP_REQUIRES_OK(context, context->input("input", &in));
    const auto& strings = in->flat<string>();

    // One bucket id per input string; the output shape mirrors the input.
    Tensor* out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("output", in->shape(), &out));
    auto bucket_ids = out->flat<int64>();

    using Index = decltype(strings.size());
    const Index count = strings.size();
    for (Index idx = 0; idx < count; ++idx) {
      // uint64 % int64: num_buckets_ is converted to uint64 (equal rank,
      // unsigned wins), so the remainder is computed on unsigned values.
      // NOTE(review): assumes num_buckets_ >= 1 (presumably enforced by the
      // op's attribute constraint, which is not visible here). Under that
      // assumption the remainder is < num_buckets_ <= INT64_MAX, so the cast
      // to int64 is lossless; a value of 0 would be a division by zero.
      bucket_ids(idx) = static_cast<int64>(hash(strings(idx)) % num_buckets_);
    }
  }

 private:
  int64 num_buckets_;  // Number of hash buckets, read from the attribute.
  TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp);
};
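注意，这里定义的是一个类模板，其中 hash 并不是类型名，而是一个非类型模板参数（non-type template parameter）：它是一个签名为 uint64(StringPiece) 的函数。用不同的哈希函数实例化这个模板，就能得到使用不同哈希算法的内核。对模板的调用（即内核注册）如下：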
// Registers the CPU kernel for the "StringToHashBucketFast" op, instantiating
// the class template with Fingerprint64 as its hash function.
REGISTER_KERNEL_BUILDER(
    Name("StringToHashBucketFast").Device(DEVICE_CPU),
    StringToHashBucketOp<Fingerprint64>);
以上，类模板定义在头文件 string_to_hash_bucket_op.h 中，而内核注册代码位于 string_to_hash_bucket_op.cc 里。
可以发现，这里的 hash 函数使用的是 Fingerprint64，它来自 Google 开源的 FarmHash 库。
当 hash 为 Fingerprint64 时，以下两种写法是等价的（二者都对第 i 个字符串的字节内容求 64 位指纹）：
const uint64 input_hash = hash(input_flat(i));  // inside StringToHashBucketOp<Fingerprint64>, `hash` is Fingerprint64
和
const uint64_t input_hash = NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(x.data(), x.size());
网友评论