
Getting Started with Langchain

Author: AlienPaul | Published 2024-12-08 09:17

Introduction

Langchain is a simple, easy-to-use framework for developing LLM applications. With Langchain, we can easily integrate LLM capabilities into our own programs. Langchain can also connect an LLM to external computation systems and data sources, so that the LLM can answer based on our own custom data.

This article is a quick introduction to Langchain, covering RAG, question-answer pair generation, external tool calling, and Text2sql.

Installing dependencies

pip install langchain langchain-community -i https://pypi.tuna.tsinghua.edu.cn/simple/

Connecting to an LLM

Using Langchain with LM Studio

First, enable LM Studio's Local Server on your development machine: after loading a local model in LM Studio, click the Developer tab on the right, then click the Start button next to Server Status to launch the Local Server. The startup log shows the Local Server's port, endpoint URLs, and other information.

[INFO] [LM STUDIO SERVER] Success! HTTP server listening on port 1234
[INFO]
[INFO] [LM STUDIO SERVER] Supported endpoints:
[INFO] [LM STUDIO SERVER] ->    GET  http://localhost:1234/v1/models
[INFO] [LM STUDIO SERVER] ->    POST http://localhost:1234/v1/chat/completions
[INFO] [LM STUDIO SERVER] ->    POST http://localhost:1234/v1/completions
[INFO] [LM STUDIO SERVER] ->    POST http://localhost:1234/v1/embeddings
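
Before wiring the server into Langchain, you can sanity-check it from Python. A minimal sketch using the requests library against the /v1/models endpoint shown in the log above:

import requests

# Should list the model(s) currently loaded in LM Studio
resp = requests.get("http://localhost:1234/v1/models")
print(resp.json())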

To use it from Langchain, first install the langchain-openai dependency.

Then use the following code to connect to the Local Server and create the llm object:

import os

from langchain_openai import ChatOpenAI

# URL of the Local Server
BASE_URL = "http://localhost:1234/v1/"
# The API key can be any value
os.environ["OPENAI_API_KEY"] = "sk-1234567890"
os.environ["OPENAI_API_BASE"] = BASE_URL
# Create the llm object
llm = ChatOpenAI(temperature=0.0, verbose=True)
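
A quick smoke test of the connection; the reply depends on whichever model is loaded in LM Studio:

# Send a single message and print the reply text
print(llm.invoke("Say hello in one sentence.").content)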

Using Langchain with Ollama

First install the langchain-ollama dependency, then use the following code to connect to the Ollama service:

from langchain_ollama import ChatOllama

# URL of the Ollama service
BASE_URL = "http://ollama_ip:11434/"
# Create the llm object
llm = ChatOllama(base_url=BASE_URL, model="llama3.2:3b", temperature=0.0)

RAG

Loading documents

Loading plain text

from langchain.document_loaders import TextLoader
documents = TextLoader("/path/to/document.md", encoding='utf-8').load()

Loading CSV files

from langchain.document_loaders.csv_loader import CSVLoader

loader = CSVLoader(file_path='/path/to/data.csv')
data = loader.load()
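
CSVLoader produces one Document per CSV row. A small sketch to inspect the result; the metadata records the source file and row index:

# Each row becomes its own Document
print(len(data))
print(data[0].page_content)
print(data[0].metadata)  # e.g. {'source': '/path/to/data.csv', 'row': 0}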

Loading PDF documents

Install the pypdf dependency first, then load the file as follows:

from langchain.document_loaders import PyPDFLoader

loader = PyPDFLoader("/path/to/paper.pdf")
pages = loader.load_and_split()
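
load_and_split accepts an optional text splitter and falls back to a default RecursiveCharacterTextSplitter when none is given. A sketch with an explicit splitter; the chunk sizes here are arbitrary:

from langchain.text_splitter import RecursiveCharacterTextSplitter

# Split the PDF pages with an explicit splitter instead of the default one
splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
pages = loader.load_and_split(text_splitter=splitter)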

Loading web pages

Install the BeautifulSoup4 dependency first, then load the page as follows:

from langchain_community.document_loaders import WebBaseLoader
documents = WebBaseLoader(web_path="https://example.com/aaa.html").load()

Complete example

This example uses RetrievalQA, a chain that Langchain ships ready-made for RAG jobs, so it can be used directly.

import os

from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.retrievers import ContextualCompressionRetriever
from langchain_ollama import ChatOllama
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import TextLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

BASE_URL = "http://ollama_ip:11434/"

if __name__ == '__main__':
    llm = ChatOllama(base_url=BASE_URL, model="llama3.2:3b", temperature=0.5)

    # Load and split the document
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)
    documents = TextLoader("/path/to/document.md", encoding='utf-8').load()
    chunks = text_splitter.split_documents(documents)
    print(chunks)
    os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

    # Load the embedding model; the first run downloads it from Hugging Face
    embeddings = HuggingFaceEmbeddings(model_name='TencentBAC/Conan-embedding-v1',
                                       cache_folder="/path/to/model_cache")
    # Initialize the vector store
    db = FAISS.from_documents(chunks, embeddings)
    # Uncomment this line to persist the index to disk
    # db.save_local("/path/to/vector_data")
    # Query similar chunks with the MMR algorithm:
    # fetch fetch_k chunks from the vector store, then return the k most diverse ones
    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 3, "fetch_k": 5})

    # Use CrossEncoderReranker to re-rank the chunks by relevance
    # This step is optional and can be skipped
    rerank_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base",
                                           model_kwargs={})
    reranker = CrossEncoderReranker(model=rerank_model, top_n=3)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=reranker, base_retriever=retriever
    )

    # The "stuff" chain type combines all document chunks and passes them to the LLM together
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=compression_retriever)

    # Ask a question about the document and print the answer
    print(qa.run("A question about the document content"))
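
On later runs you can skip re-splitting and re-embedding by reloading a persisted index. A sketch, assuming the index was saved earlier with db.save_local; recent langchain-community versions also require allow_dangerous_deserialization=True because the index is stored with pickle:

from langchain_community.vectorstores import FAISS

# Reload the saved index instead of rebuilding it from the documents
db = FAISS.load_local("/path/to/vector_data", embeddings,
                      allow_dangerous_deserialization=True)
retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 3, "fetch_k": 5})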

If you want to assemble the chain yourself, with either a prompt provided by the Langchain hub or a custom prompt, you can write the code as follows.

Sample code:

import os

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import TextLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_ollama import ChatOllama

BASE_URL = "http://ollama_ip:11434/"

if __name__ == '__main__':
    llm = ChatOllama(base_url=BASE_URL, model="llama3.2:3b", temperature=0.6)

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)
    documents = TextLoader("/path/to/document.md", encoding='utf-8').load()
    chunks = text_splitter.split_documents(documents)
    print(chunks)
    os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

    from langchain_huggingface import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS
    embeddings = HuggingFaceEmbeddings(model_name='TencentBAC/Conan-embedding-v1', cache_folder="/path/to/model_cache")
    db = FAISS.from_documents(chunks, embeddings)
    # db.save_local("/path/to/vector_data")
    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 3, "fetch_k": 5})
    # Option 1: use a custom prompt
    # template = """You are a question answering assistant. Use the retrieved context chunks below to answer the question. If you do not know the answer, say that you do not know. Use at most five sentences and keep the answer concise.
    #         Question: {question}
    #         Context: {context}
    #         Answer:
    #         """
    # prompt = ChatPromptTemplate.from_template(template)

    # Option 2: use a prompt from the Langchain hub
    from langchain import hub
    prompt = hub.pull("rlm/rag-prompt")
    
    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    print(rag_chain.invoke("A question about the document content"))
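
Note that the retriever feeds the prompt a list of Document objects, which is stringified as-is. A common refinement is a small helper that joins only the chunk texts; format_docs below is our own name, not a Langchain API:

def format_docs(docs):
    # Keep only the text of each chunk, separated by blank lines
    return "\n\n".join(doc.page_content for doc in docs)

rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)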

Generating QA pairs

Langchain can generate question-answer pairs that match the content of each chunk a document is split into.

By default the JSON produced by the LLM is wrapped in triple backticks, which makes JSON parsing in Python fail. Langchain's default prompt therefore needs to be modified. The default prompt lives at

site-packages/langchain/chains/qa_generation/prompt.py

Its content is:

templ = """You are a smart assistant designed to help high school teachers come up with reading comprehension questions.
Given a piece of text, you must come up with a question and answer pair that can be used to test a student's reading comprehension abilities.
When coming up with this question/answer pair, you must respond in the following format:
```
{{
    "question": "$YOUR_QUESTION_HERE",
    "answer": "$THE_ANSWER_HERE"
}}
```

Everything between the ``` must be valid json.

Please come up with a question/answer pair, in the specified JSON format, for the following text:
----------------
{text}"""

In the next program fragment we split a document and generate a question-answer pair for each of its chunks.

It replaces the default prompt to ensure the JSON generated by the model is not wrapped in triple backticks, so the subsequent parsing does not fail.

Sample code:

from langchain.chains.qa_generation.base import QAGenerationChain
from langchain_core.prompts import PromptTemplate
from langchain_ollama import ChatOllama

BASE_URL = "http://ollama_ip:11434/"

templ = """You are a smart assistant designed to help high school teachers come up with reading comprehension questions.
Given a piece of text, you must come up with a question and answer pair that can be used to test a student's reading comprehension abilities.
When coming up with this question/answer pair, you must respond in the following format:
{{
    "question": "$YOUR_QUESTION_HERE",
    "answer": "$THE_ANSWER_HERE"
}}

Response must be valid json.

Please come up with a question/answer pair, in the specified JSON format, for the following text:
----------------
{text}"""
PROMPT = PromptTemplate.from_template(templ)

if __name__ == '__main__':
    llm = ChatOllama(base_url=BASE_URL, model="qwen2.5:7b", temperature=0.0)

    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.document_loaders import TextLoader
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)
    documents = TextLoader("/path/to/document.md", encoding='utf-8').load()
    chunks = text_splitter.split_documents(documents)
    # Print the document chunks
    print(chunks)

    gen_chain = QAGenerationChain.from_llm(llm, prompt=PROMPT)
    examples = gen_chain.apply([{'text': t.page_content} for t in chunks[:5]])
    # Print the generated question-answer pairs
    print(examples)
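
If the pairs are meant for reuse, for example as an evaluation set, they can be written out as JSON Lines. A sketch, assuming the chain's default output key "questions" and a placeholder output path:

import json

with open("/path/to/qa_pairs.jsonl", "w", encoding="utf-8") as f:
    for example in examples:
        # Each apply() result is a dict whose "questions" key holds the parsed pairs
        for pair in example["questions"]:
            f.write(json.dumps(pair, ensure_ascii=False) + "\n")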

Tool calling

Langchain lets an LLM call external code, giving the language model far more powerful capabilities.

Two examples follow. The first gives the model the ability to know the current date. The second demonstrates calling a tool with an argument: the model returns the installation path of a component of the big-data HDP platform on the host.

from datetime import date

from langchain.agents import initialize_agent, AgentType
from langchain_ollama import ChatOllama
from langchain.tools import tool

BASE_URL = "http://ollama_ip:11434/"

@tool
def today(text: str) -> str:
    """
    Returns today's date. Use for any scenario that needs today's date.
    This function always returns today's date; any date arithmetic must happen outside it.
    """
    return str(date.today())

@tool
def get_hdp_installation_dir(component: str) -> str:
    """
    Returns the installation directory of an HDP component.
    Takes the HDP component name as its argument.
    """
    import os
    return os.path.join('/usr/hdp/3.0.1.0-187/', component)

if __name__ == '__main__':
    llm = ChatOllama(base_url=BASE_URL, model="qwen2.5:7b", temperature=0.5)
    agent = initialize_agent(tools=[today, get_hdp_installation_dir], llm=llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
    print(agent.run("What is today's date?"))
    print(agent.run("What is the installation path of hive in HDP?"))
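
The @tool decorator turns each function into a structured tool; its name, description (taken from the docstring), and argument schema are what the agent exposes to the LLM, so the docstrings above matter. A quick way to inspect them:

# Print what the LLM will see for each tool
print(today.name, "-", today.description)
print(get_hdp_installation_dir.name, "-", get_hdp_installation_dir.args)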

Text2sql

Langchain can use an LLM to translate natural language into SQL, execute it, and fetch the query result. Testing shows that the quality of the generated SQL depends heavily on the chosen language model and its parameter count.

The additional dependency langchain-experimental must be installed.

pip install langchain-experimental -i https://pypi.tuna.tsinghua.edu.cn/simple/

Connecting to MySQL additionally requires pymysql:

pip install pymysql -i https://pypi.tuna.tsinghua.edu.cn/simple/

The following examples all use this schema and data:

create database demo character set utf8mb4;

create table demo.student (
    id int,
    name varchar(50),
    age int,
    enrollment_date date,
    tutor_id int
);

create table demo.tutor (
    id int,
    name varchar(50),
    age int
);

insert into demo.student values
    (1, 'Paul', 20, '2024-01-01', 1),
    (2, 'Kate', 25, '2023-10-01', 1),
    (3, 'Peter', 22, '2024-04-30', 2),
    (4, 'Mary', 18, '2022-05-10', 1),
    (5, 'Liza', 34, '2023-03-30', 2);
    
insert into demo.tutor values
    (1, 'Adam', 42),
    (2, 'Sam', 53);

Using SQLDatabaseChain

This approach passes the database schema, sample data, and the question to the LLM together as the prompt.
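
To see exactly what gets filled into the prompt, you can print the database's table_info. A small sketch, with connection details matching the full example below:

from langchain_community.utilities import SQLDatabase

# table_info renders the CREATE TABLE statements plus a few sample rows;
# this text fills the {table_info} slot of the prompt
db = SQLDatabase.from_uri("mysql+pymysql://root:password@mysql_ip:3306/demo")
print(db.table_info)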

Langchain's default prompt tends to make the LLM wrap the generated SQL in triple backticks, so executing the SQL fails. The example below modifies the default prompt and explicitly instructs the model not to wrap the generated SQL that way.

from langchain_core.prompts import PromptTemplate
from langchain_ollama import ChatOllama
from langchain_experimental.sql import SQLDatabaseChain
from langchain_community.utilities import SQLDatabase

BASE_URL = "http://ollama_ip:11434/"

_mysql_prompt = """You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today". 
You must not wrap sql query in triple apostrophe. You must make sure the constant in where clause is wrapped in single quote ('). You must make sure to wrap all subqueries with brackets.

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

"""

PROMPT_SUFFIX = """Only use the following tables:
{table_info}

Question: {input}"""

MYSQL_PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "top_k"],
    template=_mysql_prompt + PROMPT_SUFFIX,
)

if __name__ == '__main__':
    # Precise answers do not need a high temperature
    llm = ChatOllama(base_url=BASE_URL, model="qwen2.5:7b", temperature=0.0)

    # Database connection information
    db_user = 'root'
    db_password = 'password'
    db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@mysql_ip:3306/demo")
    # Number of sample rows included in the prompt sent to the LLM
    # db._sample_rows_in_table_info = 5

    # Query some data to verify that the database connection works
    # print(db.get_usable_table_names())
    # print(db.run('select * from student'))

    # Create the SQLDatabaseChain. use_query_checker makes the chain check SQL syntax;
    # return_direct executes the SQL and returns the raw query result
    db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=MYSQL_PROMPT, use_query_checker=True, return_direct=True)

    # Ask questions; the LLM converts each one to SQL and returns the query result
    db_chain.run("How many students are there?")
    db_chain.run("get the id of the student named Mary")
    db_chain.run("get all names of students whose name begins with letter P")
    db_chain.run("get all names of students whose name's length is 4 letters")

Using a ReAct agent

The ReAct agent builds on the tool-calling mechanism from the Tool calling section above. Langchain has already defined the database-related toolkit for us:

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

This toolkit contains the following four tools:

  • sql_db_list_tables: lists all table names.
  • sql_db_schema: fetches a table's schema, similar to show create table.
  • sql_db_query_checker: checks SQL syntax.
  • sql_db_query: executes the SQL. This is the step that actually queries the database.

While executing, the LLM calls these tools based on the current state.
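
To see the concrete tool list the agent will receive, you can print each tool's name and description. A minimal sketch, assuming db and llm have been created as in the full example below:

from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

# List every tool in the toolkit together with the start of its description
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
for t in toolkit.get_tools():
    print(t.name, "-", t.description[:60])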

(Figure: tool invocation relationship diagram)

Reference: https://python.langchain.ac.cn/docs/integrations/tools/sql_database/

from langchain_ollama import ChatOllama
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain import hub
from langgraph.prebuilt import create_react_agent

BASE_URL = "http://ollama_ip:11434/"

if __name__ == '__main__':
    llm = ChatOllama(base_url=BASE_URL, model="qwen2.5:7b", temperature=0.0, num_predict=-1)
    
    # Information required to connect to the database
    db_user = 'root'
    db_password = 'password'
    db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@mysql_ip:3306/demo")
    # Number of sample rows included in the prompt sent to the LLM
    # db._sample_rows_in_table_info = 5

    # Create the database toolkit
    toolkit = SQLDatabaseToolkit(db=db, llm=llm)

    # Pull the agent's system prompt from the Langchain hub
    prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
    system_message = prompt_template.format(dialect="MySQL", top_k=5)

    # Create the agent
    agent_executor = create_react_agent(
        llm, tools=toolkit.get_tools(), state_modifier=system_message
    )

    # Query agent
    example_query = "Get all tutor Sam's student names and ages"
 
    events = agent_executor.stream(
        {"messages": [("user", example_query)]},
        stream_mode="values",
    )
    for event in events:
        event["messages"][-1].pretty_print()
