一、基本概念
1.1、Prompt
大模型的所有输入,即,我们每一次访问大模型的输入为一个 Prompt, 而大模型给我们的返回结果则被称为 Completion。
1.2、Temperature
LLM 生成是具有随机性的,在模型的顶层通过选取不同预测概率的预测结果来生成最后的结果,而Temperature 参数就是用来控制 LLM 生成结果的随机性与创造性。
Temperature 一般取值在 0~1 之间,当取值较低接近0时,预测的随机性小,更为保守,严谨,稳定。当取值较高接近1时,预测的随机性会较高,预测结果更创意,多样化。
不同问题的应用场景,设置不同的Temperature。例如:
- 个人知识库项目,一般将 Temperature 设置为0,从而保证知识库内容的稳定使用,规避错误内容;
- 产品智能客服、科研论文写作等场景中,同样更需要稳定性而不是创造性;
- 个性化 AI、创意营销文案生成等场景中,更需要创意性,从而更倾向于将 Temperature 设置为较高的值。
1.3、System Prompt
使用 ChatGPT API 时,可以设置两种 Prompt:一种是 System Prompt,该种 Prompt 内容会在整个会话过程中持久地影响模型的回复,且相比于普通 Prompt 具有更高的重要性;另一种是 User Prompt,这更偏向于普通的 Prompt,即需要模型做出回复的输入。
System Prompt 一般在一个会话中仅有一个。
二、调用ChatGPT
调用 ChatGPT API 的2种方法:直接调用 OpenAI 的原生接口,或是基于 LangChain 调用 ChatGPT API。
2.1、准备OpenAI API Key
1、登录openai官网,登录账号
2、选择API,然后点击右上角的头像,选择View API keys
3、点击Create new secret key按钮创建OpenAI API key,将创建好的OpenAI API key复制以此形式OPENAI_API_KEY="sk-..."保存到.env文件中,并将.env文件保存在项目根目录下。
读取.env文件的代码
import os
import openai
from dotenv import load_dotenv, find_dotenv
# 读取本地/项目的环境变量。
# find_dotenv()寻找并定位.env文件的路径
# load_dotenv()读取该.env文件,并将其中的环境变量加载到当前的运行环境中
# 如果你设置的是全局的环境变量,这行代码则没有任何作用。
_ = load_dotenv(find_dotenv())
# 如果你需要通过代理端口访问,你需要如下配置
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'
os.environ["HTTP_PROXY"] = 'http://127.0.0.1:7890'
# 获取环境变量 OPENAI_API_KEY
openai.api_key = os.environ['OPENAI_API_KEY']
2.2、调用OpenAI原生接口
接口文档地址:https://platform.openai.com/docs/api-reference/chat
调用 ChatGPT 需要使用 ChatCompletion API,ChatCompletion API 调用方法如下:
import openai
# 导入所需库
# 注意,此处我们假设你已根据上文配置了 OpenAI API Key,如没有将访问失败
completion = openai.ChatCompletion.create(
# 创建一个 ChatCompletion
# 调用模型:ChatGPT-3.5
model="gpt-3.5-turbo",
# message 是你的 prompt
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"}
]
)
调用该 API 会返回一个 ChatCompletion 对象,其中包括了回答文本、创建时间、ID等属性。我们一般需要的是回答文本,也就是回答对象中的 content 信息。
<OpenAIObject chat.completion id=chatcmpl-80QUFny7lXqOcfu5CZMRYhgXqUCv0 at 0x7f1fbc0bd770> JSON: {
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "Hello! How can I assist you today?",
"role": "assistant"
}
}
],
"created": 1695112507,
"id": "chatcmpl-80QUFny7lXqOcfu5CZMRYhgXqUCv0",
"model": "gpt-3.5-turbo-0613",
"object": "chat.completion",
"usage": {
"completion_tokens": 9,
"prompt_tokens": 19,
"total_tokens": 28
}
}
print(completion["choices"][0]["message"]["content"])
Hello! How can I assist you today?
API 常会用到的几个参数:
· model,即调用的模型,一般取值包括“gpt-3.5-turbo”(ChatGPT-3.5)、“gpt-3.5-16k-0613”(ChatGPT-3.5 16K 版本)、“gpt-4”(ChatGPT-4)。注意,不同模型的成本是不一样的。
· message,即我们的 prompt。ChatCompletion 的 message 需要传入一个列表,列表中包括多个不同角色的 prompt。我们可以选择的角色一般包括 system:即前文中提到的 system prompt;user:用户输入的 prompt;assitance:助手,一般是模型历史回复,作为给模型参考的示例。
· temperature,温度。即前文中提到的 Temperature 系数。
· max_tokens,最大 token 数,即模型输出的最大 token 数。OpenAI 计算 token 数是合并计算 Prompt 和 Completion 的总 token 数,要求总 token 数不能超过模型上限(如默认模型 token 上限为 4096)。因此,如果输入的 prompt 较长,需要设置较小的 max_token 值,否则会报错超出限制长度。
把对OpenAI API的调用封装成一个函数,直接传入 Prompt 并获得模型的输出:
# 一个封装 OpenAI 接口的函数,参数为 Prompt,返回对应结果
def get_completion(prompt, model="gpt-3.5-turbo", temperature = 0):
'''
prompt: 对应的提示词
model: 调用的模型,默认为 gpt-3.5-turbo(ChatGPT),有内测资格的用户可以选择 gpt-4
'''
messages = [{"role": "user", "content": prompt}]
response = openai.ChatCompletion.create(
model=model,
messages=messages,
temperature=temperature, # 模型输出的温度系数,控制输出的随机程度
)
# 调用 OpenAI 的 ChatCompletion 接口
return response.choices[0].message["content"]
上述函数中,封装了 messages 的细节,仅使用 user prompt 来实现调用。在简单场景中,该函数完全足够使用。
2.3、基于 LangChain 调用 ChatGPT
LangChain 提供了对于多种大模型的封装,基于 LangChain 的接口可以便捷地调用 ChatGPT 并将其集合在以 LangChain 为基础框架搭建的个人应用中。
官网文档:https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.chat_models
1、从langchain.chat_models导入OpenAI的对话模型ChatOpenAI。
from langchain.chat_models import ChatOpenAI
2、实例化一个 ChatOpenAI 类,可以在实例化时传入超参数来控制回答,例如 temperature 参数。
# 这里我们将参数temperature设置为0.0,从而减少生成答案的随机性。
# 如果你想要每次得到不一样的有新意的答案,可以尝试调整该参数。
chat = ChatOpenAI(temperature=0.0)
chat
ChatOpenAI(cache=None, verbose=False, callbacks=None, callback_manager=None, tags=None, metadata=None, client=<class 'openai.api_resources.chat_completion.ChatCompletion'>, model_name='gpt-3.5-turbo', temperature=0.0, model_kwargs={}, openai_api_key='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', openai_api_base='', openai_organization='', openai_proxy='', request_timeout=None, max_retries=6, streaming=False, n=1, max_tokens=None, tiktoken_model_name=None)
常用的超参数设置包括:
· model_name:所要使用的模型,默认为 ‘gpt-3.5-turbo’,参数设置与 OpenAI 原生接口参数设置一致。
· temperature:温度系数,取值同原生接口。
· openai_api_key:OpenAI API key,如果不使用环境变量设置 API Key,也可以在实例化时设置。
· openai_proxy:设置代理,如果不使用环境变量设置代理,也可以在实例化时设置。
· streaming:是否使用流式传输,即逐字输出模型回答,默认为 False,此处不赘述。
· max_tokens:模型输出的最大 token 数,意义及取值同上。
3、 构造个性化 Template。
Template,即模板,是 LangChain 设置好的一种 Prompt 格式,开发者可以直接调用 Template 向里面填充个性化任务。
from langchain.prompts import ChatPromptTemplate
# 这里我们要求模型对给定文本进行中文翻译
template_string = """Translate the text \
that is delimited by triple backticks \
into a Chinses. \
text: ```{text}```
"""
# 接着将 Template 实例化
chat_template = ChatPromptTemplate.from_template(template_string)
4、将 template 转化为message 格式
# 我们首先设置变量值
text = "Today is a nice day."
# 接着调用 format_messages 将 template 转化为 message 格式
message = chat_template.format_messages(text=text)
print(message)
[HumanMessage(content='Translate the text that is delimited by triple backticks into a Chinses. text: ```Today is a nice day.```\n', additional_kwargs={}, example=False)]
5、使用实例化的类型直接传入设定好的 prompt
response = chat(message)
response
AIMessage(content='今天是个好天气。', additional_kwargs={}, example=False)
返回值的 content 属性即为模型的返回文本。
三、langchain 核心组件详解
3.1、模型输入/输出
LangChain 中模型输入/输出模块是与各种大语言模型进行交互的基本组件,是大语言模型应用的核心元素。模型 I/O 允许管理 prompt(提示),通过通用接口调用语言模型以及从模型输出中提取信息。该模块的基本流程:
3.2、数据连接
大语言模型的知识来源于其训练数据集,并没有用户的信息(比如用户的个人数据,公司的自有数据),也没有最新发生时事的信息(在大模型数据训练后发表的文章或者新闻)。因此大模型能给出的答案比较受限。如果能够让大模型在训练数据集的基础上,利用我们自有数据中的信息来回答我们的问题,那便能够得到更有用的答案。
LangChain 数据连接(Data connection)模块支持自定义数据,支持加载、转换、存储和查询数据,模块具体内容包括:Document loaders、Document transformers、Text embedding models、Vector stores 以及 Retrievers。数据连接模块部分的基本框架如下。
3.3、链(Chain)
独立使用大型语言模型能够应对一些简单任务,但对于更加复杂的需求,可能需要将多个大型语言模型进行链式组合,或与其他组件进行链式调用。链允许将多个组件组合在一起,创建一个连贯的应用程序。
大语言模型链(LLMChain)的使用:
import warnings
warnings.filterwarnings('ignore')
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
# 这里我们将参数temperature设置为0.0,从而减少生成答案的随机性。
# 如果你想要每次得到不一样的有新意的答案,可以尝试调整该参数。
llm = ChatOpenAI(temperature=0.0)
#初始化提示模版
prompt = ChatPromptTemplate.from_template("描述制造{product}的一个公司的最佳名称是什么?")
#将大语言模型(LLM)和提示(Prompt)组合成链
chain = LLMChain(llm=llm, prompt=prompt)
#运行大语言模型链
product = "大号床单套装"
chain.run(product)
输出:
'"豪华床纺"'
除了LLMChain,LangChain 中链还包含 RouterChain、SimpleSequentialChain、SequentialChain、TransformChain 等。
RouterChain 可以根据输入数据的某些属性/特征值,选择调用不同的子链(Subchain)。
SimpleSequentialChain 是最简单的序列链形式,其中每个步骤具有单一的输入/输出,上一个步骤的输出是下一个步骤的输入。
SequentialChain 是简单顺序链的更复杂形式,允许多个输入/输出。
TransformChain 可以引入自定义转换函数,对输入进行处理后进行输出。
使用 SimpleSequentialChain 的代码示例:
from langchain.chains import SimpleSequentialChain
llm = ChatOpenAI(temperature=0.9)
#创建两个子链
# 提示模板 1 :这个提示将接受产品并返回最佳名称来描述该公司
first_prompt = ChatPromptTemplate.from_template(
"描述制造{product}的一个公司的最好的名称是什么"
)
chain_one = LLMChain(llm=llm, prompt=first_prompt)
# 提示模板 2 :接受公司名称,然后输出该公司的长为20个单词的描述
second_prompt = ChatPromptTemplate.from_template(
"写一个20字的描述对于下面这个\
公司:{company_name}的"
)
chain_two = LLMChain(llm=llm, prompt=second_prompt)
#构建简单顺序链
#现在我们可以组合两个LLMChain,以便我们可以在一个步骤中创建公司名称和描述
overall_simple_chain = SimpleSequentialChain(chains=[chain_one, chain_two], verbose=True)
#运行简单顺序链
product = "大号床单套装"
overall_simple_chain.run(product)
输出:
> Entering new SimpleSequentialChain chain...
优床制造公司
优床制造公司是一家专注于生产高品质床具的公司。
> Finished chain.
'优床制造公司是一家专注于生产高品质床具的公司。'
3.4、记忆(Meomory)
在 LangChain 中,记忆(Memory)指的是大语言模型(LLM)的短期记忆。
为什么是短期记忆?那是因为LLM训练好之后 (获得了一些长期记忆),它的参数便不会因为用户的输入而发生改变。当用户与训练好的LLM进行对话时,LLM 会暂时记住用户的输入和它已经生成的输出,以便预测之后的输出,而模型输出完毕后,它便会“遗忘”之前用户的输入和它的输出。因此,之前的这些信息只能称作为 LLM 的短期记忆。
3.5、 代理(Agents)
大型语言模型(LLMs)非常强大,但它们缺乏“最笨”的计算机程序可以轻松处理的特定能力。例如,无法准确回答简单的计算问题,还有当询问最近发生的事件时,其回答也可能过时或错误,因为无法主动获取最新信息。这是由于当前语言模型仅依赖预训练数据,与外界“断开”。要克服这一缺陷, LangChain 框架提出了 “代理”( Agent ) 的解决方案。
代理作为语言模型的外部模块,可提供计算、逻辑、检索等功能的支持,使语言模型获得异常强大的推理和获取信息的超能力。
3.6、回调(Callback)
LangChain的回调系统,允许连接到LLM应用程序的各个阶段。这对于日志记录、监视、流式处理和其他任务非常有用。
Callback 模块扮演着记录整个流程运行情况的角色,充当类似于日志的功能。在每个关键节点,它记录了相应的信息,以便跟踪整个应用的运行情况。
Callback 模块的具体实现包括两个主要功能,对应CallbackHandler 和 CallbackManager 的基类功能:
- CallbackHandler 用于记录每个应用场景(如 Agent、LLchain 或 Tool )的日志,它是单个日志处理器,主要记录单个场景的完整日志信息。
- CallbackManager则封装和管理所有的 CallbackHandler ,包括单个场景的处理器,也包括整个运行时链路的处理器。"
四、基于 LangChain 自定义 LLM
LangChain 为基于 LLM 开发自定义应用提供了高效的开发框架,便于开发者迅速地激发 LLM 的强大能力,搭建 LLM 应用。LangChain 也同样支持多种大模型,内置了 OpenAI、LLAMA 等大模型的调用接口。但是,LangChain 并没有内置所有大模型,它通过允许用户自定义 LLM 类型,来提供强大的可扩展性。
要实现自定义 LLM,需要定义一个自定义类继承自 LangChain 的 LLM 基类,然后定义两个函数:① _call 方法,其接受一个字符串,并返回一个字符串,即模型的核心调用;② _identifying_params 方法,用于打印 LLM 信息。
4.1、导入所需的第三方库
import json
import time
from typing import Any, List, Mapping, Optional, Dict, Union, Tuple
import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
from pydantic import Field, model_validator
4.2、定义一个 get_access_token 方法来获取 access_token
def get_access_token(api_key : str, secret_key : str):
"""
使用 API Key,Secret Key 获取access_token,替换下列示例中的应用API Key、应用Secret Key
"""
# 指定网址
url = f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}"
# 设置 POST 访问
payload = json.dumps("")
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json'
}
# 通过 POST 访问获取账户对应的 access_token
response = requests.request("POST", url, headers=headers, data=payload)
return response.json().get("access_token")
4.3、定义一个继承自 LLM 类的自定义 LLM 类:
# 继承自 langchain.llms.base.LLM
class Wenxin_LLM(LLM):
# 原生接口地址
url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
# 默认选用 ERNIE-Bot-turbo 模型,即目前一般所说的百度文心大模型
model_name: str = Field(default="ERNIE-Bot-turbo", alias="model")
# 访问时延上限
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
# 温度系数
temperature: float = 0.1
# API_Key
api_key: str = None
# Secret_Key
secret_key : str = None
# access_token
access_token: str = None
# 必备的可选参数
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
4.4、实现一个初始化方法 init_access_token,当模型的 access_token 为空时调用
def init_access_token(self):
if self.api_key != None and self.secret_key != None:
# 两个 Key 均非空才可以获取 access_token
try:
self.access_token = get_access_token(self.api_key, self.secret_key)
except Exception as e:
print(e)
print("获取 access_token 失败,请检查 Key")
else:
print("API_Key 或 Secret_Key 为空,请检查 Key")
4.5、实现核心的方法——调用模型 API
def _call(self, prompt : str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any):
# 除 prompt 参数外,其他参数并没有被用到,但当我们通过 LangChain 调用时会传入这些参数,因此必须设置
# 如果 access_token 为空,初始化 access_token
if self.access_token == None:
self.init_access_token()
# API 调用 url
url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token={}".format(self.access_token)
# 配置 POST 参数
payload = json.dumps({
"messages": [
{
"role": "user",# user prompt
"content": "{}".format(prompt)# 输入的 prompt
}
],
'temperature' : self.temperature
})
headers = {
'Content-Type': 'application/json'
}
# 发起请求
response = requests.request("POST", url, headers=headers, data=payload, timeout=self.request_timeout)
if response.status_code == 200:
# 返回的是一个 Json 字符串
js = json.loads(response.text)
return js["result"]
else:
return "请求失败"
4.6、定义模型的描述方法
# 首先定义一个返回默认参数的方法
@property
def _default_params(self) -> Dict[str, Any]:
"""获取调用Ennie API的默认参数。"""
normal_params = {
"temperature": self.temperature,
"request_timeout": self.request_timeout,
}
return {**normal_params}
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {**{"model_name": self.model_name}, **self._default_params}
五、将大模型 API 封装成本地 API
我们可以使用 FastAPI,对不同的大模型 API 再进行一层封装,将其映射到本地接口上,从而通过统一的方式来调用本地接口实现不同大模型的调用。通过这样的手段,可以极大程度减少对于模型调用的工作量和复杂度。
以讯飞星火大模型 API 为例。
5.1、创建 api 对象
安装 fastapi 第三方库
! pip install fastapi
导入需要依赖包
from fastapi import FastAPI
from pydantic import BaseModel
import os
app = FastAPI() # 创建 api 对象
5.2、定义数据模型接收数据
本地 API 一般通过 POST 方式进行访问,即参数会附加在 POST 请求中,我们需要定义一个数据模型来接收 POST 请求中的数据:
# 定义一个数据模型,用于接收POST请求中的数据
class Item(BaseModel):
prompt : str # 用户 prompt
temperature : float # 温度系数
max_tokens : int # token 上限
if_list : bool = False # 是否多轮对话
数据模型中常用参数:
· prompt:即用户输入的 Prompt。我们默认为单轮对话调用,因此 prompt 默认为一句输入;如果将 if_list 设置为 True,那么就是多轮对话调用,prompt 应为一个已构造好(即有标准 role、content 格式)的列表字符串
· temperature:温度系数
· max_tokens:回答的最大 token 上限
· if_list:是否多轮对话,默认为 False
5.3、创建 POST 请求的 API 端点
@app.post("/spark/")
async def get_spark_response(item: Item):
# 实现星火大模型调用的 API 端点
response = get_spark(item)
return response
定义一个函数来实现对星火 API 的调用
import SparkApiSelf
# 首先定义一个构造参数函数
def getText(role, content, text = []):
# role 是指定角色,content 是 prompt 内容
jsoncon = {}
jsoncon["role"] = role
jsoncon["content"] = content
text.append(jsoncon)
return text
def get_spark(item):
# 配置 spark 秘钥
#以下密钥信息从控制台获取
appid = "9f922c84" #填写控制台中获取的 APPID 信息
api_secret = "YjU0ODk4MWQ4NTgyNDU5MzNiNWQzZmZm" #填写控制台中获取的 APISecret 信息
api_key ="5d4e6e41f6453936ccc34dd524904324" #填写控制台中获取的 APIKey 信息
domain = "generalv2" # v2.0版本
Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址
# 构造请求参数
if item.if_list:
prompt = item.prompt
else:
prompt = getText("user", item.prompt)
response = SparkApiSelf.main(appid,api_key,api_secret,Spark_url,domain,prompt, item.temperature, item.max_tokens)
return response
注意,由于星火给出的示例 SparkApi 中将 temperature、max_tokens 都进行了封装,我们需要对示例代码进行改写,暴露出这两个参数接口,我们实现了一个新的文件 SparkApiSelf,对其中的改动如下:
首先,我们对参数类中新增了 temperature、max_tokens 两个属性:
class Ws_Param(object):
# 初始化
def __init__(self, APPID, APIKey, APISecret, Spark_url):
self.APPID = APPID
self.APIKey = APIKey
self.APISecret = APISecret
self.host = urlparse(Spark_url).netloc
self.path = urlparse(Spark_url).path
self.Spark_url = Spark_url
# 自定义
self.temperature = 0
self.max_tokens = 2048
然后在生成请求参数的函数中,增加这两个参数并在构造请求数据时加入参数:
def gen_params(appid, domain,question, temperature, max_tokens):
"""
通过appid和用户的提问来生成请参数
"""
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": max_tokens,
"temperature" : temperature,
"auditing": "default"
}
},
"payload": {
"message": {
"text": question
}
}
}
return data
在 run 函数中调用生成参数时加入这两个参数:
def run(ws, *args):
data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question, temperature = ws.temperature, max_tokens = ws.max_tokens))
ws.send(data)
由于 WebSocket 是直接打印到终端,但我们需要将最后的结果返回给用户,我们需要修改 main 函数,使用一个队列来装填星火流式输出产生的结果,并最终集成返回给用户:
def main(appid, api_key, api_secret, Spark_url,domain, question, temperature, max_tokens):
# print("星火:")
output_queue = queue.Queue()
def on_message(ws, message):
data = json.loads(message)
code = data['header']['code']
if code != 0:
print(f'请求错误: {code}, {data}')
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
# print(content, end='')
# 将输出值放入队列
output_queue.put(content)
if status == 2:
ws.close()
wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
websocket.enableTrace(False)
wsUrl = wsParam.create_url()
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
ws.appid = appid
ws.question = question
ws.domain = domain
ws.temperature = temperature
ws.max_tokens = max_tokens
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
return ''.join([output_queue.get() for _ in range(output_queue.qsize())])
API 封装完成。
六、基于 LangChain 自定义 Embeddings
LangChain 为基于 LLM 开发自定义应用提供了高效的开发框架,便于开发者迅速地激发 LLM 的强大能力,搭建 LLM 应用。LangChain 也同样支持多种大模型的 Embeddings,内置了 OpenAI、LLAMA 等大模型 Embeddings 的调用接口。但是,LangChain 并没有内置所有大模型,它通过允许用户自定义 Embeddings 类型,来提供强大的可扩展性。
要实现自定义 Embeddings,需要定义一个自定义类继承自 LangChain 的 Embeddings 基类,然后定义三个函数:① _embed 方法,其接受一个字符串,并返回一个存放 Embeddings 的 List[float],即模型的核心调用;② embed_query 方法,用于对单个字符串(query)进行 embedding。③ embed_documents 方法,用于对字符串列表(documents)进行 embedding。
6.1、导入所需的第三方库
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.utils import get_from_dict_or_env
6.2、定义一个继承自 Embeddings 类的自定义 Embeddings 类
class ZhipuAIEmbeddings(BaseModel, Embeddings):
"""`Zhipuai Embeddings` embedding models."""
zhipuai_api_key: Optional[str] = None
"""Zhipuai application apikey"""
在 Python 中,root_validator 是 Pydantic 模块中一个用于自定义数据校验的装饰器函数。root_validator 用于在校验整个数据模型之前对整个数据模型进行自定义校验,以确保所有的数据都符合所期望的数据结构。
root_validator 接收一个函数作为参数,该函数包含需要校验的逻辑。函数应该返回一个字典,其中包含经过校验的数据。如果校验失败,则抛出一个 ValueError 异常。
装饰器 root_validator 确保导入了相关的包和并配置了相关的 API_Key 这里取巧,在确保导入 zhipuai model 后直接将zhipuai.model_api绑定到 cliet 上,减少和其他 Embeddings 类的差异。
values["client"] = zhipuai.model_api
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""
验证环境变量或配置文件中的zhipuai_api_key是否可用。
Args:
values (Dict): 包含配置信息的字典,必须包含 zhipuai_api_key 的字段
Returns:
values (Dict): 包含配置信息的字典。如果环境变量或配置文件中未提供 zhipuai_api_key,则将返回原始值;否则将返回包含 zhipuai_api_key 的值。
Raises:
ValueError: zhipuai package not found, please install it with `pip install
zhipuai`
"""
values["zhipuai_api_key"] = get_from_dict_or_env(
values,
"zhipuai_api_key",
"ZHIPUAI_API_KEY",
)
try:
import zhipuai
zhipuai.api_key = values["zhipuai_api_key"]
values["client"] = zhipuai.model_api
except ImportError:
raise ValueError(
"Zhipuai package not found, please install it with "
"`pip install zhipuai`"
)
return values
6.3、重写 _embed 方法,调用远程 API 并解析 embedding 结果
def _embed(self, texts: str) -> List[float]:
"""
生成输入文本的 embedding。
Args:
texts (str): 要生成 embedding 的文本。
Return:
embeddings (List[float]): 输入文本的 embedding,一个浮点数值列表。
"""
try:
resp = self.client.invoke(
model="text_embedding",
prompt=texts
)
except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}")
if resp["code"] != 200:
raise ValueError(
"Error raised by inference API HTTP code: %s, %s"
% (resp["code"], resp["msg"])
)
embeddings = resp["data"]["embedding"]
return embeddings
6.4、重写 embed_documents 方法
因为这里 _embed 已经定义好了,可以直接传入文本并返回结果即可。
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
生成输入文本列表的 embedding。
Args:
texts (List[str]): 要生成 embedding 的文本列表.
Returns:
List[List[float]]: 输入列表中每个文档的 embedding 列表。每个 embedding 都表示为一个浮点值列表。
"""
return [self._embed(text) for text in texts]
embed_query 是对单个文本计算 embedding 的方法,因为我们已经定义好对文档列表计算 embedding 的方法embed_documents 了,这里可以直接将单个文本组装成 list 的形式传给 embed_documents。
def embed_query(self, text: str) -> List[float]:
"""
生成输入文本的 embedding。
Args:
text (str): 要生成 embedding 的文本。
Return:
List [float]: 输入文本的 embedding,一个浮点数值列表。
"""
resp = self.embed_documents([text])
return resp[0]
什么要先定义embed_documents再用 embed_query 调用呢,不返过来呢,其实也是可以的,embed_query 单独请求也是可以的。
对于 embed_documents 可以加入一些内容处理后再请求 embedding,如果文档特别长,可以考虑对文档分段,防止超过最大 token 限制。