美文网首页
Python 使用 Tornado Web 框架写HTTP接口实

Python 使用 Tornado Web 框架写HTTP接口实

作者: 光剑书架上的书 | 来源:发表于2022-08-31 02:28 被阅读0次
    # Copyright (c) 2022, salesforce.com, inc.
    # All rights reserved.
    # SPDX-License-Identifier: BSD-3-Clause
    # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
    # pip3 install tornado
    
    import tornado.ioloop
    import tornado.web
    import tornado.gen
    from concurrent.futures import ThreadPoolExecutor
    from tornado.concurrent import run_on_executor
    import json
    
    from aixcoder.aixcode import AIXCode
    
    # Both models are loaded once, when this module is imported: each AIXCode(...)
    # call loads a checkpoint and tokenizer from disk and reseeds the RNGs. The two
    # instances are then used by the /aix1 and /aix2 handlers below.
    AIXCode1 = AIXCode('codegen-350M-multi')
    AIXCode2 = AIXCode('codegen-2B-multi')
    
    
    def get_body_json(body):
        """Decode a raw request body as UTF-8 and parse it as JSON.

        :param body: request body bytes (callers pass ``self.request.body``).
        :return: the parsed JSON value (a ``dict`` when the body is a JSON object).
        :raises UnicodeDecodeError: if ``body`` is not valid UTF-8.
        :raises json.JSONDecodeError: if the decoded text is not valid JSON.
        """
        text = body.decode()
        return json.loads(text)
    
    
    class PingHandler(tornado.web.RequestHandler):
        """Health-check endpoint: replies ``Pong!`` to both GET and POST.

        POST bodies are parsed as JSON and printed to stdout for debugging.
        """

        def _log_url(self):
            """Print the full URL of the current request."""
            print(f'request:{self.request.full_url()}')

        @tornado.gen.coroutine
        def get(self):
            self._log_url()
            self.write("Pong!")

        @tornado.gen.coroutine
        def post(self):
            self._log_url()
            # The body is parsed only so it can be logged; a malformed body raises here.
            print(f'request:{get_body_json(self.request.body)}')
            self.write("Pong!")
    
    
    class AIX1Handler(tornado.web.RequestHandler):
        """Code-completion endpoint backed by ``AIXCode1`` (``codegen-350M-multi``).

        GET takes the prompt from the ``x`` query argument; POST takes it from the
        ``x`` field of a JSON object body. The completed code is written back as
        the response body. Malformed POST input is rejected with HTTP 400 instead
        of being passed into the model call.
        """

        # AIXCode1.aixcode is a synchronous call, so it runs on this pool rather
        # than on the IOLoop thread.
        executor = ThreadPoolExecutor(32)

        @run_on_executor
        def aixcode(self, x):
            """Run ``AIXCode1.aixcode(x)`` on ``executor``; returns a Future of the result."""
            return AIXCode1.aixcode(x)

        def _prompt_from_body(self):
            """Extract the ``x`` prompt string from the JSON request body.

            :return: the prompt string.
            :raises tornado.web.HTTPError: 400 if the body is not UTF-8 JSON, is not
                a JSON object, or ``x`` is not a string.
            :raises tornado.web.MissingArgumentError: (status 400) if ``x`` is
                absent, matching what ``get_argument('x')`` does for GET.
            """
            try:
                body_json = get_body_json(self.request.body)
            except ValueError as err:  # JSONDecodeError and UnicodeDecodeError are ValueErrors
                raise tornado.web.HTTPError(400, 'request body is not valid UTF-8 JSON') from err
            print(f'request:{body_json}')
            if not isinstance(body_json, dict):
                raise tornado.web.HTTPError(400, 'request body must be a JSON object')
            x = body_json.get("x")
            if x is None:
                raise tornado.web.MissingArgumentError('x')
            if not isinstance(x, str):
                raise tornado.web.HTTPError(400, 'argument x must be a string')
            return x

        @tornado.gen.coroutine
        def get(self):
            """Handle GET: complete the prompt given in the ``x`` query argument."""
            print(f'request:{self.request.full_url()}')
            x = self.get_argument('x')
            y = yield self.aixcode(x)
            self.write(y)

        @tornado.gen.coroutine
        def post(self):
            """Handle POST: complete the prompt in the JSON body's ``x`` field."""
            print(f'request:{self.request.full_url()}')
            x = self._prompt_from_body()
            y = yield self.aixcode(x)
            self.write(y)
    
    
    class AIX2Handler(tornado.web.RequestHandler):
        """Code-completion endpoint backed by ``AIXCode2`` (``codegen-2B-multi``).

        GET takes the prompt from the ``x`` query argument; POST takes it from the
        ``x`` field of a JSON object body. The completed code is written back as
        the response body. Malformed POST input is rejected with HTTP 400 instead
        of being passed into the model call.
        """

        # AIXCode2.aixcode is a synchronous call, so it runs on this pool rather
        # than on the IOLoop thread.
        executor = ThreadPoolExecutor(32)

        @run_on_executor
        def aixcode(self, x):
            """Run ``AIXCode2.aixcode(x)`` on ``executor``; returns a Future of the result."""
            return AIXCode2.aixcode(x)

        def _prompt_from_body(self):
            """Extract the ``x`` prompt string from the JSON request body.

            :return: the prompt string.
            :raises tornado.web.HTTPError: 400 if the body is not UTF-8 JSON, is not
                a JSON object, or ``x`` is not a string.
            :raises tornado.web.MissingArgumentError: (status 400) if ``x`` is
                absent, matching what ``get_argument('x')`` does for GET.
            """
            try:
                body_json = get_body_json(self.request.body)
            except ValueError as err:  # JSONDecodeError and UnicodeDecodeError are ValueErrors
                raise tornado.web.HTTPError(400, 'request body is not valid UTF-8 JSON') from err
            print(f'request:{body_json}')
            if not isinstance(body_json, dict):
                raise tornado.web.HTTPError(400, 'request body must be a JSON object')
            x = body_json.get("x")
            if x is None:
                raise tornado.web.MissingArgumentError('x')
            if not isinstance(x, str):
                raise tornado.web.HTTPError(400, 'argument x must be a string')
            return x

        @tornado.gen.coroutine
        def get(self):
            """Handle GET: complete the prompt given in the ``x`` query argument."""
            print(f'request:{self.request.full_url()}')
            x = self.get_argument('x')
            y = yield self.aixcode(x)
            self.write(y)

        @tornado.gen.coroutine
        def post(self):
            """Handle POST: complete the prompt in the JSON body's ``x`` field."""
            print(f'request:{self.request.full_url()}')
            x = self._prompt_from_body()
            y = yield self.aixcode(x)
            self.write(y)
    
    
    if __name__ == "__main__":
        # Register routes.
        app = tornado.web.Application([
            (r"/ping", PingHandler),
            (r"/aix1", AIX1Handler),
            (r"/aix2", AIX2Handler),
        ])

        # Listen on the port.
        port = 8888
        app.listen(port)
        print(f'AIXCoder Started, Listening on Port:{port}')
        # Start the event loop. IOLoop.current() is used because IOLoop.instance()
        # is only a deprecated alias for it.
        tornado.ioloop.IOLoop.current().start()
    
    

    其中，AIXCode 类的代码如下：

    # Copyright (c) 2022, salesforce.com, inc.
    # All rights reserved.
    # SPDX-License-Identifier: BSD-3-Clause
    # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
    
    # models_nl = ['codegen-350M-nl', 'codegen-2B-nl', 'codegen-6B-nl', 'codegen-16B-nl']
    # models_pl = ['codegen-350M-multi', 'codegen-2B-multi', 'codegen-6B-multi', 'codegen-16B-multi',
    #              'codegen-350M-mono',
    #              'codegen-2B-mono', 'codegen-6B-mono', 'codegen-16B-mono']
    
    import os
    import re
    import time
    import random
    
    import torch
    
    from transformers import GPT2TokenizerFast
    from aixcoder.codegen.modeling_codegen import CodeGenForCausalLM
    
    
    ########################################################################
    # util
    class print_time:
        """Context manager that prints ``desc`` on entry and the elapsed wall time on exit.

        ``with print_time(...) as v`` binds ``None``; exceptions raised in the body
        propagate (``__exit__`` returns a falsy value).
        """

        def __init__(self, desc):
            self.desc = desc

        def __enter__(self):
            print(self.desc)
            self.t = time.time()

        def __exit__(self, exc_type, exc_value, exc_tb):
            elapsed = time.time() - self.t
            print(f'{self.desc} took {elapsed:.02f}s')
    
    
    def set_env():
        """Set ``TOKENIZERS_PARALLELISM=false`` in this process's environment.

        Presumably meant to turn off the tokenizers library's internal parallelism
        (and its fork warnings) -- TODO confirm.
        """
        os.environ.update({'TOKENIZERS_PARALLELISM': 'false'})
    
    
    def set_seed(seed, deterministic=True):
        """Seed Python's ``random`` module and torch with ``seed``.

        :param seed: integer seed.
        :param deterministic: only used when CUDA is available. Sets
            ``torch.backends.cudnn.deterministic`` to this value and
            ``torch.backends.cudnn.benchmark`` to its negation.
        """
        random.seed(seed)
        # NOTE: PYTHONHASHSEED is read at interpreter startup, so setting it here
        # only affects child processes, not hashing in the current process.
        os.environ['PYTHONHASHSEED'] = str(seed)
        torch.manual_seed(seed)
        if not torch.cuda.is_available():
            return
        torch.cuda.manual_seed(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = deterministic
        cudnn.benchmark = not deterministic
    
    
    def cast(model, fp16=True):
        """Return ``model`` unchanged.

        ``fp16`` is accepted for interface compatibility but ignored: the
        half-precision conversion is disabled.
        """
        del fp16  # intentionally unused, see docstring
        return model
    
    
    ########################################################################
    # model
    
    
    def create_model(ckpt, fp16=False):
        """Load a ``CodeGenForCausalLM`` from ``ckpt``.

        The caller in this file passes a local checkpoint directory. ``fp16`` is
        ignored: the model is always loaded with ``from_pretrained``'s default
        settings, because the float16 loading path is disabled.
        """
        loader = CodeGenForCausalLM.from_pretrained
        return loader(ckpt)
    
    
    def create_tokenizer():
        """Load the pretrained ``gpt2`` fast tokenizer and raise its 'gpt2' length limit."""
        tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
        # NOTE(review): presumably meant to silence "sequence too long" warnings.
        # Depending on the transformers version, this dict may be shared at class
        # level and may not change this instance's model_max_length -- confirm.
        limits = tokenizer.max_model_input_sizes
        limits['gpt2'] = 1e20
        return tokenizer
    
    
    def include_whitespace(t, n_min=2, n_max=20, as_special_tokens=False):
        """Add runs of ``n`` spaces, for ``n_min <= n < n_max``, as tokens of ``t``.

        Runs are added longest first. Returns the same tokenizer object ``t``.
        """
        widths = range(n_max - 1, n_min - 1, -1)
        t.add_tokens([' ' * width for width in widths], special_tokens=as_special_tokens)
        return t
    
    
    def include_tabs(t, n_min=2, n_max=20, as_special_tokens=False):
        """Add runs of ``n`` tab characters, for ``n_min <= n < n_max``, as tokens of ``t``.

        Runs are added longest first. Returns the same tokenizer object ``t``.
        """
        tab_runs = ['\t' * width for width in range(n_min, n_max)]
        tab_runs.reverse()
        t.add_tokens(tab_runs, special_tokens=as_special_tokens)
        return t
    
    
    def create_custom_gpt2_tokenizer():
        """Build the GPT-2 tokenizer extended with multi-space and multi-tab tokens.

        Adds runs of 2-31 spaces and 2-9 tabs as regular (non-special) tokens.
        """
        tokenizer = create_tokenizer()
        tokenizer = include_whitespace(t=tokenizer, n_min=2, n_max=32, as_special_tokens=False)
        return include_tabs(t=tokenizer, n_min=2, n_max=10, as_special_tokens=False)
    
    
    ########################################################################
    # sample
    
    # Default upper bound on the number of new tokens sample() generates per call
    # (sample()'s max_length_sample default; also passed explicitly by AIXCode.aixcode).
    MAX_LENGTH_SAMPLE = 512
    
    
    def sample(
            model,
            tokenizer,
            context,
            pad_token_id,
            num_return_sequences=1,
            temp=0.2,
            top_p=0.95,
            max_length_sample=MAX_LENGTH_SAMPLE,
            max_length=2048
    ):
        """Sample continuations of ``context`` from ``model``.

        :param model: causal LM exposing ``generate`` (a ``CodeGenForCausalLM`` here).
        :param tokenizer: tokenizer used to encode ``context`` and decode the output.
        :param context: prompt string.
        :param pad_token_id: token id passed to ``generate`` as ``pad_token_id``.
        :param num_return_sequences: number of samples to draw.
        :param temp: sampling temperature.
        :param top_p: nucleus-sampling probability mass.
        :param max_length_sample: maximum number of new tokens to generate.
        :param max_length: maximum total length in tokens. The prompt is truncated
            to it, and prompt plus generated tokens never exceed it. The default
            of 2048 is presumably the model's context window -- confirm.
        :return: list of ``num_return_sequences`` decoded strings holding only the
            generated continuation (prompt tokens are sliced off).
        :raises ValueError: if the encoded prompt is ``max_length`` tokens or
            longer, leaving no room to generate.
        """
        input_ids = tokenizer(
            context,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors='pt',
        ).input_ids

        input_ids_len = input_ids.shape[1]
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if input_ids_len >= max_length:
            raise ValueError(
                f'prompt is {input_ids_len} tokens; it must be shorter than '
                f'max_length={max_length}'
            )

        # Never ask generate() for more than max_length positions in total;
        # input_ids_len + max_length_sample alone can exceed it for long prompts.
        total_length = min(input_ids_len + max_length_sample, max_length)

        with torch.no_grad():
            tokens = model.generate(
                input_ids,
                do_sample=True,
                num_return_sequences=num_return_sequences,
                temperature=temp,
                max_length=total_length,
                top_p=top_p,
                pad_token_id=pad_token_id,
                use_cache=True,
            )
            # Drop the prompt tokens so only the continuation is decoded.
            text = tokenizer.batch_decode(tokens[:, input_ids_len:, ...])

        return text
    
    
    def truncate(completion):
        """Cut a raw model completion down to the first code unit worth keeping.

        Three passes, in order:

        1. keep text only up to the second line that starts with ``print``;
        2. then keep text only up to the second line that starts with ``def``;
        3. then cut at the earliest stop marker. The markers are a line starting
           with ``#``, a line starting with three single or three double quotes,
           the literal ``<|endoftext|>``, or three consecutive newlines.

        :param completion: decoded model output (without the prompt).
        :return: the trimmed prefix of ``completion``.
        """
        # Passes 1 and 2: a second top-level print/def starts a new unit.
        for head in ('^print', '^def'):
            starts = [m.start() for m in re.finditer(head, completion, re.MULTILINE)]
            if len(starts) > 1:
                completion = completion[:starts[1]]

        # Pass 3: the earliest stop marker wins.
        stop_markers = (
            '^#',
            re.escape('<|endoftext|>'),
            "^'''",
            '^"""',
            '\n\n\n',
        )
        cut_points = []
        for marker in stop_markers:
            hit = re.search(marker, completion, re.MULTILINE)
            if hit is not None:
                cut_points.append(hit.start())
        if not cut_points:
            return completion
        return completion[:min(cut_points)]
    
    
    class AIXCode:
        """Wraps one CodeGen checkpoint plus the customised GPT-2 tokenizer for prompt completion."""

        def __init__(self, model_name, ckpt_dir='/Users/bytedance/githubcode/CodeGen/checkpoints'):
            """Load the ``model_name`` checkpoint from ``ckpt_dir`` and build the tokenizer.

            Also sets ``TOKENIZERS_PARALLELISM`` and reseeds the RNGs with 42.

            :param model_name: checkpoint directory name, e.g. ``'codegen-350M-multi'``.
            :param ckpt_dir: directory holding the checkpoints. The default is the
                original hard-coded location, kept for backward compatibility.
            """
            # preamble
            set_env()
            set_seed(42, deterministic=True)

            ckpt = os.path.join(ckpt_dir, model_name)

            # load
            with print_time(f'{model_name} loading parameters'):
                model = create_model(ckpt=ckpt, fp16=False)

            with print_time(f'{model_name} loading tokenizer'):
                tokenizer = create_custom_gpt2_tokenizer()
                tokenizer.padding_side = 'left'
                # pad_token must be a token string, not an id. Assigning the int
                # 50256 is rejected (ValueError) by newer transformers releases.
                # GPT-2's eos token '<|endoftext|>' has id 50256, which matches the
                # pad_token_id passed to sample() in aixcode().
                tokenizer.pad_token = tokenizer.eos_token

            self.model = model
            self.tokenizer = tokenizer

        def aixcode(self, context_string):
            """Complete ``context_string`` with one sampled, truncated continuation.

            :param context_string: prompt text.
            :return: ``context_string`` followed by the truncated completion.
            """
            # sample
            with print_time(f'{context_string} ... AIXCoding >>>'):
                completion = sample(model=self.model,
                                    tokenizer=self.tokenizer,
                                    context=context_string,
                                    pad_token_id=50256,
                                    num_return_sequences=1,
                                    temp=0.2,
                                    top_p=0.95,
                                    max_length_sample=MAX_LENGTH_SAMPLE)[0]

                truncation = truncate(completion)

                return context_string + truncation
    
    

    参考文档:
    https://blog.csdn.net/rensihui/article/details/80474706

    相关文章

      网友评论

          本文标题:Python 使用 Tornado Web 框架写HTTP接口实

          本文链接:https://www.haomeiwen.com/subject/xgtkwrtx.html