美文网首页
tensorflow API使用笔记 StringToHashB

tensorflow API使用笔记 StringToHashB

作者: peteyuan | 来源:发表于2018-12-17 11:49 被阅读87次

    tensorflow对类别特征,会先转换成字符串,然后做hash。代码实现如下:

    template <uint64 hash(StringPiece)>
    class StringToHashBucketOp : public OpKernel {
     public:
      explicit StringToHashBucketOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
        OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_));
      }
    
      void Compute(OpKernelContext* context) override {
        const Tensor* input_tensor;
        OP_REQUIRES_OK(context, context->input("input", &input_tensor));
        const auto& input_flat = input_tensor->flat<string>();
    
        Tensor* output_tensor = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output("output", input_tensor->shape(),
                                                &output_tensor));
        auto output_flat = output_tensor->flat<int64>();
    
        typedef decltype(input_flat.size()) Index;
        for (Index i = 0; i < input_flat.size(); ++i) {
          const uint64 input_hash = hash(input_flat(i));
          const uint64 bucket_id = input_hash % num_buckets_;
          // The number of buckets is always in the positive range of int64 so is
          // the resulting bucket_id. Casting the bucket_id from uint64 to int64 is
          // safe.
          output_flat(i) = static_cast<int64>(bucket_id);
        }
      }
    
     private:
      int64 num_buckets_;
    
      TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp);
    };
    

    注意,这里定义了个模板类,其中hash是一个模板类型名。对模板的调用如下:

    REGISTER_KERNEL_BUILDER(Name("StringToHashBucketFast").Device(DEVICE_CPU),
                            StringToHashBucketOp<Fingerprint64>);
    

    以上,定义在头文件string_to_hash_bucket_op.h中,实现在string_to_hash_bucket_op.cc里。

    可以发现这里的hash函数使用的是Fingerprint64,来自于google开源的farmhash。

    以下两种实现是等价的:

    const uint64 input_hash = hash(input_flat(i));
    

    const uint64_t input_hash = NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(x.data(), x.size());
    

    相关文章

      网友评论

          本文标题:tensorflow API使用笔记 StringToHashB

          本文链接:https://www.haomeiwen.com/subject/jouqkqtx.html