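How to create a custom chat model class
Prerequisites
This guide assumes familiarity with chat models.
This notebook goes over how to create a custom chat model wrapper. Use it if you want to use your own chat model, or a different wrapper than one that LangChain supports directly.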
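After extending the `SimpleChatModel` class, a chat model needs to implement a few required things:
- A `_call` method that takes in a list of messages and call options (which include things like `stop` sequences) and returns a string.
- A `_llmType` method that returns a string. This is used for logging purposes only.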
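The division of labour can be illustrated with a dependency-free toy that mimics the shape of this contract. `ToyMessage`, `ToySimpleChatModel`, `ToyEchoModel` and their `invoke` method are hypothetical stand-ins, not LangChain's API. The real `SimpleChatModel` also handles callbacks and message coercion, and wraps the string in an `AIMessage`. The only point here is that the base class owns the public entry point, while the subclass supplies `_call` (returning a plain string) and `_llmType` (a label for logging).

```typescript
// Hypothetical message shape; LangChain's BaseMessage is much richer.
interface ToyMessage {
  role: "human" | "ai" | "system";
  content: string;
}

interface ToyCallOptions {
  stop?: string[];
}

// Toy stand-in for SimpleChatModel: it owns the public entry point and
// delegates the actual generation to the subclass's `_call`.
abstract class ToySimpleChatModel {
  abstract _llmType(): string;
  abstract _call(messages: ToyMessage[], options: ToyCallOptions): Promise<string>;

  async invoke(
    messages: ToyMessage[],
    options: ToyCallOptions = {}
  ): Promise<{ role: "ai"; content: string }> {
    const text = await this._call(messages, options);
    // The real base class would also emit callback events at this point.
    return { role: "ai", content: text };
  }
}

// Subclass that echoes the first `n` characters of the first message.
class ToyEchoModel extends ToySimpleChatModel {
  n: number;

  constructor(n: number) {
    super();
    this.n = n;
  }

  _llmType(): string {
    return "toy-echo";
  }

  async _call(messages: ToyMessage[], options: ToyCallOptions): Promise<string> {
    if (messages.length === 0) {
      throw new Error("No messages provided.");
    }
    let text = messages[0].content.slice(0, this.n);
    // Truncate at the earliest occurrence of any stop sequence.
    for (const stop of options.stop ?? []) {
      const idx = text.indexOf(stop);
      if (idx !== -1) {
        text = text.slice(0, idx);
      }
    }
    return text;
  }
}
```

With `n = 4`, invoking on `"I am an LLM"` yields `"I am"`, and passing `stop: [" "]` cuts it to `"I"`.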
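You can also implement the following optional method:
- A `_streamResponseChunks` method that returns an `AsyncGenerator` and yields `ChatGenerationChunk`s. This allows the LLM to support streaming output.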
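A useful invariant when implementing `_streamResponseChunks` is that concatenating the streamed chunks should reproduce what `_call` returns for the same input. The sketch below checks this without any dependencies; `ToyChunk` is a hypothetical stand-in for `ChatGenerationChunk`, and `echoCall`, `echoStream` and `collect` are illustrative helpers, not LangChain functions.

```typescript
// Hypothetical stand-in for ChatGenerationChunk: just the generated text.
interface ToyChunk {
  text: string;
}

// Non-streaming path: return the first `n` characters in one go.
async function echoCall(input: string, n: number): Promise<string> {
  return input.slice(0, n);
}

// Streaming path: yield the same characters one chunk at a time.
async function* echoStream(input: string, n: number): AsyncGenerator<ToyChunk> {
  for (const letter of input.slice(0, n)) {
    yield { text: letter };
  }
}

// Drain a stream and join its chunks back into a single string.
async function collect(stream: AsyncGenerator<ToyChunk>): Promise<string> {
  let text = "";
  for await (const chunk of stream) {
    text += chunk.text;
  }
  return text;
}
```

Keeping both paths built on the same `slice(0, n)` is what guarantees they agree; if the streaming path computed its output differently, streamed and non-streamed calls could silently diverge.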
Let's implement a very simple custom chat model that just echoes back the first `n` characters of the input.
```typescript
import {
  SimpleChatModel,
  type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { AIMessageChunk, type BaseMessage } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";

interface CustomChatModelInput extends BaseChatModelParams {
  n: number;
}

class CustomChatModel extends SimpleChatModel {
  n: number;

  constructor(fields: CustomChatModelInput) {
    super(fields);
    this.n = fields.n;
  }

  _llmType() {
    return "custom";
  }

  async _call(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<string> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    return messages[0].content.slice(0, this.n);
  }

  async *_streamResponseChunks(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): AsyncGenerator<ChatGenerationChunk> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    for (const letter of messages[0].content.slice(0, this.n)) {
      yield new ChatGenerationChunk({
        message: new AIMessageChunk({
          content: letter,
        }),
        text: letter,
      });
      // Trigger the appropriate callback for new chunks
      await runManager?.handleLLMNewToken(letter);
    }
  }
}
```
Now we can use this like any other chat model:
```typescript
const chatModel = new CustomChatModel({ n: 4 });
await chatModel.invoke([["human", "I am an LLM"]]);
```
```
AIMessage {
  lc_serializable: true,
  lc_kwargs: {
    content: 'I am',
    tool_calls: [],
    invalid_tool_calls: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'I am',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  usage_metadata: undefined
}
```
It also supports streaming:
```typescript
const stream = await chatModel.stream([["human", "I am an LLM"]]);
for await (const chunk of stream) {
  console.log(chunk);
}
```
```
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: 'I',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'I',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: ' ',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: ' ',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: 'a',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'a',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: 'm',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'm',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
```
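If you want to take advantage of LangChain's callback system for functionality like token tracking, you can extend the `BaseChatModel` class and implement the lower-level `_generate` method. It also takes a list of `BaseMessage`s as input, but requires you to construct and return a `ChatResult` containing `ChatGeneration` objects, which allows additional metadata to be returned. Here's an example: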
```typescript
import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";
import {
  BaseChatModel,
  BaseChatModelCallOptions,
  BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";

interface AdvancedCustomChatModelOptions extends BaseChatModelCallOptions {}

interface AdvancedCustomChatModelParams extends BaseChatModelParams {
  n: number;
}

class AdvancedCustomChatModel extends BaseChatModel<AdvancedCustomChatModelOptions> {
  n: number;

  static lc_name(): string {
    return "AdvancedCustomChatModel";
  }

  constructor(fields: AdvancedCustomChatModelParams) {
    super(fields);
    this.n = fields.n;
  }

  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    const content = messages[0].content.slice(0, this.n);
    // Extra metadata returned alongside the generation
    const tokenUsage = {
      usedTokens: this.n,
    };
    return {
      generations: [{ message: new AIMessage({ content }), text: content }],
      llmOutput: { tokenUsage },
    };
  }

  _llmType(): string {
    return "advanced_custom_chat_model";
  }
}
```
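Because `_generate` returns `llmOutput` alongside the generations, usage can be aggregated across runs. The dependency-free sketch below mirrors only the fields used in the example above (`generations[].text` and `llmOutput.tokenUsage.usedTokens`); `ToyChatResult`, `toyGenerate` and `UsageTracker` are hypothetical names. In LangChain you would more likely read this data in a callback handler's `handleLLMEnd`, rather than from a return value directly.

```typescript
// Only the parts of a ChatResult that this sketch reads.
interface ToyChatResult {
  generations: { text: string }[];
  llmOutput?: { tokenUsage?: { usedTokens: number } };
}

// Mirrors AdvancedCustomChatModel._generate: echo the first n characters
// and report n as the number of "tokens" used.
function toyGenerate(input: string, n: number): ToyChatResult {
  const content = input.slice(0, n);
  return {
    generations: [{ text: content }],
    llmOutput: { tokenUsage: { usedTokens: n } },
  };
}

// Hypothetical aggregator that sums reported usage across many results.
class UsageTracker {
  total = 0;
  runs = 0;

  record(result: ToyChatResult): void {
    // Results without usage information count as zero tokens.
    this.total += result.llmOutput?.tokenUsage?.usedTokens ?? 0;
    this.runs += 1;
  }
}
```

Like the example model, `toyGenerate` reports `n` even when the input is shorter than `n` characters; a real integration would report whatever the provider says it actually used.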
The additional information returned from `_generate` is passed through callback events and the `streamEvents` method:
```typescript
const chatModel = new AdvancedCustomChatModel({ n: 4 });
const eventStream = await chatModel.streamEvents([["human", "I am an LLM"]], {
  version: "v2",
});
for await (const event of eventStream) {
  if (event.event === "on_chat_model_end") {
    console.log(JSON.stringify(event, null, 2));
  }
}
```
```
{
  "event": "on_chat_model_end",
  "data": {
    "output": {
      "lc": 1,
      "type": "constructor",
      "id": [
        "langchain_core",
        "messages",
        "AIMessage"
      ],
      "kwargs": {
        "content": "I am",
        "tool_calls": [],
        "invalid_tool_calls": [],
        "additional_kwargs": {},
        "response_metadata": {
          "tokenUsage": {
            "usedTokens": 4
          }
        }
      }
    }
  },
  "run_id": "11dbdef6-1b91-407e-a497-1a1ce2974788",
  "name": "AdvancedCustomChatModel",
  "tags": [],
  "metadata": {
    "ls_model_type": "chat"
  }
}
```
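As the printed event shows, the `tokenUsage` object reported in `llmOutput` ends up under the output message's `response_metadata`. The sketch below sums it over a stream of events. `ToyEvent` is a hypothetical type that describes only the fields read here, and `fakeEvents` stands in for the stream returned by `streamEvents(..., { version: "v2" })`. Passing the real stream may need a type cast, depending on how your code types the events.

```typescript
// Only the fields this sketch reads; real stream events carry more
// (run_id, tags, metadata, ...), and `data.output` is an AIMessage.
interface ToyEvent {
  event: string;
  name?: string;
  data: {
    output?: { response_metadata?: { tokenUsage?: { usedTokens?: number } } };
  };
}

// Sum usedTokens over every "on_chat_model_end" event in a stream.
async function totalUsedTokens(events: AsyncIterable<ToyEvent>): Promise<number> {
  let total = 0;
  for await (const ev of events) {
    if (ev.event !== "on_chat_model_end") continue;
    total += ev.data.output?.response_metadata?.tokenUsage?.usedTokens ?? 0;
  }
  return total;
}

// Fake event source shaped like the printed event above.
async function* fakeEvents(): AsyncGenerator<ToyEvent> {
  yield { event: "on_chat_model_start", data: {} };
  yield {
    event: "on_chat_model_end",
    name: "AdvancedCustomChatModel",
    data: { output: { response_metadata: { tokenUsage: { usedTokens: 4 } } } },
  };
}
```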
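Tracing (advanced)
If you are implementing a custom chat model and want to use it with a tracing service like LangSmith, you can automatically log the parameters used for each invocation by implementing the `invocationParams()` method on the model.
This method is entirely optional, but anything it returns will be logged as metadata in the trace.
Here's one pattern you might use: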
```typescript
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
  BaseChatModel,
  type BaseChatModelCallOptions,
  type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";

interface CustomChatModelOptions extends BaseChatModelCallOptions {
  // Some required or optional inner args
  tools?: Record<string, any>[];
}

interface CustomChatModelParams extends BaseChatModelParams {
  temperature: number;
  n: number;
}

// Hypothetical placeholder for a call to your model provider's API (not part
// of LangChain). Replace it with a real request; here it just echoes text.
async function someAPIRequest(
  messages: BaseMessage[],
  params: { n: number; tools?: Record<string, any>[] }
): Promise<string> {
  return String(messages[0].content).slice(0, params.n);
}

class CustomChatModel extends BaseChatModel<CustomChatModelOptions> {
  temperature: number;
  n: number;

  static lc_name(): string {
    return "CustomChatModel";
  }

  constructor(fields: CustomChatModelParams) {
    super(fields);
    this.temperature = fields.temperature;
    this.n = fields.n;
  }

  // Anything returned in this method will be logged as metadata in the trace.
  // It is common to pass it any options used to invoke the function.
  invocationParams(options?: this["ParsedCallOptions"]) {
    return {
      tools: options?.tools,
      n: this.n,
    };
  }

  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    const additionalParams = this.invocationParams(options);
    const content = await someAPIRequest(messages, additionalParams);
    return {
      generations: [{ message: new AIMessage({ content }), text: content }],
      llmOutput: {},
    };
  }

  _llmType(): string {
    return "advanced_custom_chat_model";
  }
}
```
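The useful property of `invocationParams()` is that it combines per-call options (here `tools`) with settings fixed at construction time (here `n`), so each trace records both. The dependency-free sketch below imitates that flow. `ToyTracedModel` and its in-memory `trace` array are hypothetical, and only stand in for what a tracing service such as LangSmith would receive.

```typescript
interface ToyCallOptions {
  tools?: Record<string, unknown>[];
}

interface TraceEntry {
  name: string;
  metadata: Record<string, unknown>;
}

class ToyTracedModel {
  temperature: number;
  n: number;
  // Hypothetical in-memory "tracer" standing in for a real tracing backend.
  trace: TraceEntry[] = [];

  constructor(fields: { temperature: number; n: number }) {
    this.temperature = fields.temperature;
    this.n = fields.n;
  }

  // Same shape as the example above: per-call tools plus the instance's n.
  invocationParams(options?: ToyCallOptions) {
    return { tools: options?.tools, n: this.n };
  }

  async invoke(input: string, options?: ToyCallOptions): Promise<string> {
    // Record the parameters for this specific call, then "generate".
    this.trace.push({ name: "ToyTracedModel", metadata: this.invocationParams(options) });
    return input.slice(0, this.n);
  }
}
```

Note that `temperature` never appears in the recorded metadata, because the example's `invocationParams()` does not return it. Add it to the returned object if you want it to show up in traces.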