Model Inference with the ChatGLM API on the PAI-EAS Platform

Author: 梅西爱骑车 | Published 2024-03-08 22:11

    ChatGLM deployment was covered in the previous article, Deploying ChatGLM with PAI-EAS. After deployment, the langchain-ChatGLM service is accessed through a web page; below is an example of calling it through the API instead.

    1. Obtaining the Service Access Address and Token

    1. Go to the PAI-EAS model online service page; for details, see Deploying ChatGLM with PAI-EAS.
    2. On that page, click the target service name to open the "Service Details" page.
    3. In the "Basic Information" section, click "View Invocation Information" and, on the "Public Endpoint Invocation" tab, obtain the service Token and access address.


      [Figure: invocation information]
      Since I call the service from a local command line, only public-endpoint invocation is relevant here; VPC invocation is not covered.
      [Figure: public-endpoint invocation information]

    2. Calling the API for Model Inference

    2.1 Calling the Service over HTTP

    2.1.1 Non-streaming Calls

    The client uses the standard HTTP format. When calling the service with curl, the following two types of requests are supported:

    1. Send a String-type request
    curl $host -H "Authorization: $authorization" --data-binary @chatllm_data.txt -v
    

    Here $authorization must be replaced with the service Token, $host with the service access address, and chatllm_data.txt is a plain-text file containing the question.
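
    For example (the endpoint and Token below are hypothetical placeholders; substitute the values obtained from the invocation information page):

    # hypothetical endpoint and Token, for illustration only
    host=http://chatllm-demo.****.cn-hangzhou.pai-eas.aliyuncs.com/
    authorization=YOUR_EAS_TOKEN
    echo "How to install it?" > chatllm_data.txt
    curl $host -H "Authorization: $authorization" --data-binary @chatllm_data.txt -v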

    2. Send a structured request
    
    curl $host -H "Authorization: $authorization" -H "Content-Type: application/json" --data-binary @chatllm_data.json -v -H "Connection: close"
    
    

    The chatllm_data.json file sets the inference parameters; its content has the following format:

    {
        "max_new_tokens": 4096,
        "use_stream_chat": false,
        "prompt": "How to install it?",
        "system_prompt": "Act like you are a programmer with 5+ years of experience.",
        "history": [
            [
                "Can you tell me what's the bladellm?",
                "BladeLLM is a framework for LLM serving, integrated with acceleration techniques like quantization, ai compilation, etc., and supporting popular LLMs like OPT, Bloom, LLaMA, etc."
            ]
        ],
        "temperature": 0.8,
        "top_k": 10,
        "top_p": 0.8,
        "do_sample": true,
        "use_cache": true
    }
    

    The parameters are described below; add or remove them as appropriate:

    - prompt: the user's question.
    - system_prompt: the system prompt that sets the model's role.
    - history: the conversation history as a list of [question, answer] pairs, used for multi-turn chat.
    - max_new_tokens: the maximum number of tokens to generate.
    - use_stream_chat: whether the output is returned in streaming mode.
    - temperature: the sampling temperature; higher values produce more random output.
    - top_k: the number of highest-probability candidate tokens kept when sampling.
    - top_p: the cumulative-probability threshold for nucleus sampling.
    - do_sample: whether to use sampling rather than greedy decoding.
    - use_cache: whether to use the KV cache to speed up decoding.

    You can also implement your own client based on Python's requests package. Example code:

    import argparse
    import json
    from typing import List, Tuple
    
    import requests
    
    def post_http_request(prompt: str,
                          system_prompt: str,
                          history: list,
                          host: str,
                          authorization: str,
                          max_new_tokens: int = 2048,
                          temperature: float = 0.95,
                          top_k: int = 1,
                          top_p: float = 0.8,
                          langchain: bool = False,
                          use_stream_chat: bool = False) -> requests.Response:
        headers = {
            "User-Agent": "Test Client",
            "Authorization": f"{authorization}"
        }
        if not history:
            history = [
                (
                    "San Francisco is a",
                    "city located in the state of California in the United States. \
                    It is known for its iconic landmarks, such as the Golden Gate Bridge \
                    and Alcatraz Island, as well as its vibrant culture, diverse population, \
                    and tech industry. The city is also home to many famous companies and \
                    startups, including Google, Apple, and Twitter."
                )
            ]
        pload = {
            "prompt": prompt,
            "system_prompt": system_prompt,
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "use_stream_chat": use_stream_chat,
            "history": history
        }
        if langchain:
            print(langchain)
            pload["langchain"] = langchain
        response = requests.post(host, headers=headers,
                                 json=pload, stream=use_stream_chat)
        return response
    
    # The server returns JSON containing the inference result and the conversation history
    def get_response(response: requests.Response) -> Tuple[str, List]:
        data = json.loads(response.content)
        output = data["response"]
        history = data["history"]
        return output, history
    
    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--top-k", type=int, default=4)
        parser.add_argument("--top-p", type=float, default=0.8)
        parser.add_argument("--max-new-tokens", type=int, default=2048)
        parser.add_argument("--temperature", type=float, default=0.95)
        parser.add_argument("--prompt", type=str, default="How can I get there?")
        parser.add_argument("--langchain", action="store_true")
    
        args = parser.parse_args()
    
        prompt = args.prompt
        top_k = args.top_k
        top_p = args.top_p
        use_stream_chat = False
        temperature = args.temperature
        langchain = args.langchain
        max_new_tokens = args.max_new_tokens
    
        host = "EAS服务公网地址"
        authorization = "EAS服务公网Token"
    
        print(f"Prompt: {prompt!r}\n", flush=True)
        # The system prompt for the language model can be set in the client request
        system_prompt = "Act like you are a programmer with \
                    5+ years of experience."
    
        # Conversation history can also be set in the client request. The client maintains the current user's dialogue record to support multi-turn conversations; normally the history returned by the previous round is reused. Its format is List[Tuple[str, str]].
        history = []
        response = post_http_request(
            prompt, system_prompt, history,
            host, authorization,
            max_new_tokens, temperature, top_k, top_p,
            langchain=langchain, use_stream_chat=use_stream_chat)
        output, history = get_response(response)
        print(f" --- output: {output} \n --- history: {history}", flush=True)
    
    

    Where:
    host: set to the service access address.
    authorization: set to the service Token.
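
    Assuming the script above is saved as chatllm_client.py (a hypothetical filename), it can be run like this, with the sampling parameters passed as command-line flags:

    python chatllm_client.py --prompt "How to install it?" --temperature 0.8 --top-k 10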

    2.1.2 Streaming Calls

    Streaming calls use HTTP SSE; the other settings are the same as for non-streaming calls.
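
    Streaming can also be checked quickly with curl. This is a sketch assuming "use_stream_chat": true has been set in chatllm_data.json; the -N flag disables curl's output buffering so that chunks are printed as they arrive:

    # set "use_stream_chat": true in chatllm_data.json before running this
    curl -N $host -H "Authorization: $authorization" -H "Content-Type: application/json" --data-binary @chatllm_data.json

    The Python reference code: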

    import argparse
    import json
    from typing import Iterable, List, Tuple
    
    import requests
    
    
    def clear_line(n: int = 1) -> None:
        LINE_UP = '\033[1A'
        LINE_CLEAR = '\x1b[2K'
        for _ in range(n):
            print(LINE_UP, end=LINE_CLEAR, flush=True)
    
    
    def post_http_request(prompt: str,
                          system_prompt: str,
                          history: list,
                          host: str,
                          authorization: str,
                          max_new_tokens: int = 2048,
                          temperature: float = 0.95,
                          top_k: int = 1,
                          top_p: float = 0.8,
                          langchain: bool = False,
                          use_stream_chat: bool = False) -> requests.Response:
        headers = {
            "User-Agent": "Test Client",
            "Authorization": f"{authorization}"
        }
        if not history:
            history = [
                (
                    "San Francisco is a",
                    "city located in the state of California in the United States. \
                    It is known for its iconic landmarks, such as the Golden Gate Bridge \
                    and Alcatraz Island, as well as its vibrant culture, diverse population, \
                    and tech industry. The city is also home to many famous companies and \
                    startups, including Google, Apple, and Twitter."
                )
            ]
        pload = {
            "prompt": prompt,
            "system_prompt": system_prompt,
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "use_stream_chat": use_stream_chat,
            "history": history
        }
        if langchain:
            print(langchain)
            pload["langchain"] = langchain
        response = requests.post(host, headers=headers,
                                 json=pload, stream=use_stream_chat)
        return response
    
    
    def get_streaming_response(response: requests.Response) -> Iterable[Tuple[str, List]]:
        for chunk in response.iter_lines(chunk_size=8192,
                                         decode_unicode=False,
                                         delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode("utf-8"))
                output = data["response"]
                history = data["history"]
                yield output, history
    
    
    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--top-k", type=int, default=4)
        parser.add_argument("--top-p", type=float, default=0.8)
        parser.add_argument("--max-new-tokens", type=int, default=2048)
        parser.add_argument("--temperature", type=float, default=0.95)
        parser.add_argument("--prompt", type=str, default="How can I get there?")
        parser.add_argument("--langchain", action="store_true")
        args = parser.parse_args()
    
        prompt = args.prompt
        top_k = args.top_k
        top_p = args.top_p
        use_stream_chat = True
        temperature = args.temperature
        langchain = args.langchain
        max_new_tokens = args.max_new_tokens
    
        host = ""
        authorization = ""
    
        print(f"Prompt: {prompt!r}\n", flush=True)
        system_prompt = "Act like you are programmer with \
                    5+ years of experience."
        history = []
        response = post_http_request(
            prompt, system_prompt, history,
            host, authorization,
            max_new_tokens, temperature, top_k, top_p,
            langchain=langchain, use_stream_chat=use_stream_chat)
    
        for h, history in get_streaming_response(response):
            print(
                f" --- stream line: {h} \n --- history: {history}", flush=True)
    

    2.1.3 Configuring More Parameters

    The command-line parameters supported by the scripts above are as follows:

    - --top-k: top-k sampling value (default 4).
    - --top-p: top-p (nucleus) sampling value (default 0.8).
    - --max-new-tokens: maximum number of tokens to generate (default 2048).
    - --temperature: sampling temperature (default 0.95).
    - --prompt: the question to send (default "How can I get there?").
    - --langchain: flag that enables langchain integration.
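
    For example, to run the streaming client above (saved under a hypothetical name, chatllm_stream_client.py):

    python chatllm_stream_client.py --prompt "Please introduce yourself." --max-new-tokens 1024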
    Source: https://www.haomeiwen.com/subject/hhtewdtx.html