
Trying to write pybind11 bindings

Author: KyoDante | Published 2024-03-13 01:13

include "chat.h"

include <pybind11/pybind11.h>

include <pybind11/stl.h>

include "models.cpp"

namespace chatllm {

namespace py = pybind11;
using namespace pybind11::literals;

// class PyBaseTokenizer : public BaseTokenizer {
//   public:
//     using BaseTokenizer::BaseTokenizer;

//     std::vector<int> encode(const std::string &text, int max_length) const override {
//         PYBIND11_OVERRIDE_PURE(std::vector<int>, BaseTokenizer, encode, text, max_length);
//     }
//     std::string decode(const std::vector<int> &ids) const override {
//         PYBIND11_OVERLOAD_PURE(std::string, BaseTokenizer, decode, ids);
//     }
//     std::vector<int> encode_messages(const std::vector<ChatMessage> &history, int max_length) const override {
//         PYBIND11_OVERLOAD_PURE(std::vector<int>, BaseTokenizer, encode_messages, history, max_length);
//     }
// };

// class PyBaseModelForCausalLM : public BaseModelForCausalLM {
//   public:
//     using BaseModelForCausalLM::BaseModelForCausalLM;

//     void load(ModelLoader &loader) override { PYBIND11_OVERLOAD_PURE(void, PyBaseModelForCausalLM, load, loader); }

//     ggml_tensor *forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx,
//                          bool is_decoding) const override {
//         PYBIND11_OVERLOAD_PURE(ggml_tensor *, PyBaseModelForCausalLM, forward, ctx, input_ids, n_past, n_ctx,
//                                is_decoding)
//     }
// };

template <typename T>
static inline std::string to_string(const T &obj) {
    std::ostringstream oss;
    oss << obj;
    return oss.str();
}

PYBIND11_MODULE(_C, m) {
m.doc() = "ChatLLM.cpp python binding";

py::enum_<ModelType>(m, "ModelType")
    .value("MINICPM", ModelType::MODEL_TYPE_MINICPM);

py::class_<minicpm::Config>(m, "MiniCPMConfig")
    // .def_readonly("dtype", &BaseConfig::dtype)
    .def_readonly("vocab_size", &minicpm::Config::vocab_size)
    .def_readonly("hidden_size", &minicpm::Config::hidden_size)
    .def_readonly("num_attention_heads", &minicpm::Config::num_attention_heads)
    .def_readonly("num_hidden_layers", &minicpm::Config::num_hidden_layers)
    .def_readonly("intermediate_size", &minicpm::Config::intermediate_size)
    .def_readonly("max_length", &minicpm::Config::max_length)
    .def_readonly("bos_token_id", &minicpm::Config::bos_token_id)
    .def_readonly("eos_token_id", &minicpm::Config::eos_token_id)
    .def_readonly("pad_token_id", &minicpm::Config::pad_token_id)
    .def_readonly("sep_token_id", &minicpm::Config::sep_token_id)
    .def_readonly("num_key_value_heads", &minicpm::Config::num_key_value_heads)
    .def_readonly("rope_scaling", &minicpm::Config::rope_scaling)
    .def_readonly("rope_theta", &minicpm::Config::rope_theta)
    .def_readonly("scale_depth", &minicpm::Config::scale_depth);

py::class_<GenerationConfig>(m, "GenerationConfig")
    .def(py::init<int, int, bool, int, float, float, int>(), "max_length"_a = 2048,
        "max_context_length"_a = 512, "do_sample"_a = true, "top_k"_a = 0,
        "top_p"_a = 0.7, "temperature"_a = 0.95, "num_threads"_a = 0)
    .def_readwrite("max_length", &GenerationConfig::max_length)
    .def_readwrite("max_context_length", &GenerationConfig::max_context_length)
    .def_readwrite("do_sample", &GenerationConfig::do_sample)
    .def_readwrite("top_k", &GenerationConfig::top_k)
    .def_readwrite("top_p", &GenerationConfig::top_p)
    .def_readwrite("temperature", &GenerationConfig::temperature)
    .def_readwrite("num_threads", &GenerationConfig::num_threads);

// py::class_<ChatMessage>(m, "ChatMessage")
//     .def(py::init<std::string, std::string, std::vector<ToolCallMessage>>(), "role"_a, "content"_a,
//          "tool_calls"_a = std::vector<ToolCallMessage>{})
//     .def("__repr__", &to_string<ChatMessage>)
//     .def("__str__", &to_string<ChatMessage>)
//     .def_readonly_static("ROLE_SYSTEM", &ChatMessage::ROLE_SYSTEM)
//     .def_readonly_static("ROLE_USER", &ChatMessage::ROLE_USER)
//     .def_readonly_static("ROLE_ASSISTANT", &ChatMessage::ROLE_ASSISTANT)
//     .def_readonly_static("ROLE_OBSERVATION", &ChatMessage::ROLE_OBSERVATION)
//     .def_readwrite("role", &ChatMessage::role)
//     .def_readwrite("content", &ChatMessage::content)
//     .def_readwrite("tool_calls", &ChatMessage::tool_calls);

// py::class_<minicpm::Tokenizer>(m, "Tokenizer")
//     .def("encode", &minicpm::Tokenizer::encode, py::arg("text"))
//     .def("decode", &minicpm::Tokenizer::decode, "ids"_a);

// py::class_<chatllm::BaseHistoryEncoder>(m, "BaseHistoryEncoder");
// py::class_<chatllm::BaseTokenizer>(m, "BaseTokenizer")
//     .def("load", [](chatllm::BaseTokenizer& tokenizer, const char *buffer, int n_vocab){

//     });
// py::class_<chatllm::BaseStreamer>(m, "BaseStreamer");
// py::class_<chatllm::TextStreamer>(m, "TextStreamer");
    // .def(py::init<chatllm::BaseTokenizer>(), "tokenizer"_a); // has a bug

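// Expose minicpm::Tokenizer to Python; encode/decode are wrapped in lambdas so
// std::vector<int> and std::string convert to Python list/str via pybind11/stl.h.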
py::class_<minicpm::Tokenizer, chatllm::BaseTokenizer>(m, "MiniCPMTokenizer")
    .def("encode", [](minicpm::Tokenizer& tokenizer, const std::string& text){
        return tokenizer.encode(text);
    })
    .def("decode", [](minicpm::Tokenizer& tokenizer, const std::vector<int> &ids){
        return tokenizer.decode(ids);
    });
    // .def("load", [](minicpm::Tokenizer& tokenizer, const char *buffer, int n_vocab){
    //     return tokenizer.load(buffer, n_vocab);
    // });

// py::class_<minicpm::ConditionalGeneration>(m, "MiniCPMModel")
//     .def("generate_next_token", &minicpm::ConditionalGeneration::generate_next_token, 
//     "input_ids"_a, "gen_config"_a);

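// generate_next_token is wrapped so the Python side can keep passing the full
// id list: the first call feeds the whole prompt and advances n_past by its
// length; later calls feed only the newly appended token and advance n_past by 1.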
py::class_<minicpm::ConditionalGeneration>(m, "MiniCPMModel")
    .def("generate_next_token", [](minicpm::ConditionalGeneration& generation, const std::vector<int> &input_ids, const GenerationConfig &gen_config) {
        int gen_token = -1;
        if (generation.get_n_past() == 0) {
            gen_token = generation.generate_next_token(input_ids, gen_config);
            generation.set_n_past(generation.get_n_past() + input_ids.size());
        } else {
            int lastElement = input_ids.back();
            const std::vector<int> lastElementVec = {lastElement};
            gen_token = generation.generate_next_token(lastElementVec, gen_config);
            generation.set_n_past(generation.get_n_past() + 1);
        }
        return gen_token;
    })
    .def("reset_n_past", [](minicpm::ConditionalGeneration& generation){
        generation.set_n_past(0);
    })
    .def_readonly("config", &minicpm::ConditionalGeneration::config);
    // .def("generate", [](minicpm::ConditionalGeneration& generation, const std::vector<int> &input_ids, const GenerationConfig &gen_config,
    //                           const bool continuous,
    //                           bool &completed){
        
    // });

// ===== ChatGLM3 =====

// py::class_<ChatGLM3Tokenizer, BaseTokenizer>(m, "ChatGLM3Tokenizer");

// ===== Pipeline =====

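// Pipeline loads the model and tokenizer from a converted .bin file; both are
// exposed as read-only properties so the Python side can drive generation itself.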
py::class_<Pipeline>(m, "Pipeline")
    .def(py::init<const std::string &>(), "path"_a)
    .def_property_readonly("model", [](const Pipeline &self) { return self.model; })
    .def_property_readonly("tokenizer", [](const Pipeline &self) { return self.tokenizer; })
    .def("chat", [](Pipeline& pipeline, std::vector<std::string> &history, const GenerationConfig &gen_config){
        return pipeline.chat(history, gen_config);
    });

}

} // namespace chatllm
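
To actually import this as chatllm_cpp._C, the C++ file above has to be compiled into a Python extension module. Below is a minimal build sketch using pybind11's setuptools helpers; the source file name (chatllm_pybind.cpp), the include directories, and the omitted ggml link settings are assumptions of mine, not something chatllm.cpp ships, so adjust them to the local checkout and build.

# setup.py -- minimal build sketch (file and directory names are assumptions)
from pybind11.setup_helpers import Pybind11Extension, build_ext
from setuptools import setup

ext_modules = [
    Pybind11Extension(
        "chatllm_cpp._C",
        ["chatllm_pybind.cpp"],  # the binding source above (assumed file name)
        include_dirs=["chatllm.cpp", "chatllm.cpp/ggml/include"],  # assumed layout
        cxx_std=17,
        # The ggml / chatllm object files or libraries must be linked as well;
        # they are omitted here because they depend on how ggml was built.
    ),
]

setup(
    name="chatllm_cpp",
    packages=["chatllm_cpp"],
    ext_modules=ext_modules,
    cmdclass={"build_ext": build_ext},
)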

from pathlib import Path
import chatllm_cpp._C as _C

class Pipeline(_C.Pipeline):
    def __init__(self, model_path: str) -> None:
        if Path(model_path).is_file():
            # load the converted ggml model file
            super().__init__(str(model_path))
        else:
            raise RuntimeError("invalid model path")

    def chat(
        self,
        message: str,
        *,
        max_length: int = 2048,
        max_context_length: int = 512,
        do_sample: bool = True,
        top_k: int = 0,
        top_p: float = 0.7,
        temperature: float = 0.95,
        num_threads: int = 0,
        # stream: bool = False,
    ) -> str:
        input_ids = self.tokenizer.encode(message)

        gen_config = _C.GenerationConfig(
            max_length=max_length,
            max_context_length=max_context_length,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            num_threads=num_threads,
        )

        # Generate token by token until the model emits EOS or max_length is
        # reached, then decode only the newly generated ids.
        self.model.reset_n_past()
        output_ids = list(input_ids)
        while len(output_ids) < max_length:
            next_token = self.model.generate_next_token(output_ids, gen_config)
            if next_token == self.model.config.eos_token_id:
                break
            output_ids.append(next_token)
        return self.tokenizer.decode(output_ids[len(input_ids):])
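
Assuming the wrapper class is re-exported from the package (for example from chatllm_cpp/__init__.py, which is my assumption, not something shown above), it can be used like this; the prompt uses MiniCPM's <用户>/<AI> template, the same as the test script below.

# Hypothetical usage of the Python wrapper; the package layout and model path
# are assumptions.
from chatllm_cpp import Pipeline

pipe = Pipeline("quantized_16.bin")
print(pipe.chat(" <用户>Hello.<AI>"))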

import _C
pipeline = _C.Pipeline(r"C:\Users\KyoDa\Downloads\chatllm.cpp\quantized_16.bin")
question = "Hello."
ids = pipeline.tokenizer.encode(f" <用户>{question}<AI>")
config = _C.GenerationConfig()
new_token = 0
pipeline.model.reset_n_past()
print(pipeline.model.config.eos_token_id, "<< id")
while new_token != pipeline.model.config.eos_token_id:
    new_token = pipeline.model.generate_next_token(ids, config)
    ids.append(new_token)
    print(new_token, end=',', flush=True)

print(pipeline.tokenizer.decode(ids))

pipeline.chat(["Hello."], config)
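
The stream flag that is commented out in the wrapper can be approximated purely on the Python side with nothing but the bindings already exported. The sketch below is my own rough take (the helper name stream_chat and the re-decode-each-step approach are not part of chatllm.cpp); it prints text as it is generated.

# Rough Python-side streaming sketch built only on the bindings above.
# Re-decoding the generated ids on every step is simple but not exact: if the
# tokenizer changes earlier characters once more ids arrive, the printed delta
# can look slightly off.
def stream_chat(pipeline, prompt, config):
    ids = pipeline.tokenizer.encode(prompt)
    prompt_len = len(ids)
    pipeline.model.reset_n_past()
    printed = ""
    while len(ids) < config.max_length:
        token = pipeline.model.generate_next_token(ids, config)
        if token == pipeline.model.config.eos_token_id:
            break
        ids.append(token)
        text = pipeline.tokenizer.decode(ids[prompt_len:])
        print(text[len(printed):], end="", flush=True)
        printed = text
    print()

stream_chat(pipeline, f" <用户>{question}<AI>", config)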
