How to create a custom chat model class

Prerequisites

This guide assumes familiarity with the following concepts:

This notebook goes over how to create a custom chat model wrapper, in case you want to use your own chat model or a different wrapper than one that is directly supported in LangChain.

After extending the SimpleChatModel class, there are a few required things a chat model needs to implement:

  • A _call method that takes in a list of messages and call options (which includes things like stop sequences), and returns a string.
  • A _llmType method that returns a string. It is used only for logging purposes.

You can also implement the following optional method:

  • A _streamResponseChunks method that returns an AsyncGenerator and yields ChatGenerationChunks. This allows the LLM to support streaming outputs.

Let's implement a very simple custom chat model that just echoes back the first n characters of the input.

import {
  SimpleChatModel,
  type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { AIMessageChunk, type BaseMessage } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";

interface CustomChatModelInput extends BaseChatModelParams {
  n: number;
}

class CustomChatModel extends SimpleChatModel {
  n: number;

  constructor(fields: CustomChatModelInput) {
    super(fields);
    this.n = fields.n;
  }

  _llmType() {
    return "custom";
  }

  async _call(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<string> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    return messages[0].content.slice(0, this.n);
  }

  async *_streamResponseChunks(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): AsyncGenerator<ChatGenerationChunk> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    for (const letter of messages[0].content.slice(0, this.n)) {
      yield new ChatGenerationChunk({
        message: new AIMessageChunk({
          content: letter,
        }),
        text: letter,
      });
      // Trigger the appropriate callback for new chunks
      await runManager?.handleLLMNewToken(letter);
    }
  }
}

We can now use it like any other chat model:

const chatModel = new CustomChatModel({ n: 4 });

await chatModel.invoke([["human", "I am an LLM"]]);
AIMessage {
  lc_serializable: true,
  lc_kwargs: {
    content: 'I am',
    tool_calls: [],
    invalid_tool_calls: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'I am',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  usage_metadata: undefined
}

And it supports streaming:

const stream = await chatModel.stream([["human", "I am an LLM"]]);

for await (const chunk of stream) {
  console.log(chunk);
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: 'I',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'I',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: ' ',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: ' ',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: 'a',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'a',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: 'm',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'm',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
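
Because the custom model implements the standard chat model interface, it also composes with other runnables. Here is a minimal sketch (assuming the CustomChatModel defined above) that pipes a prompt template into it:

import { ChatPromptTemplate } from "@langchain/core/prompts";

// Compose the custom model with a prompt template like any other chat model.
const prompt = ChatPromptTemplate.fromMessages([["human", "{input}"]]);
const chain = prompt.pipe(new CustomChatModel({ n: 4 }));

await chain.invoke({ input: "I am an LLM" });

The formatted prompt becomes a single human message with content "I am an LLM", so the resulting AIMessage content is again "I am".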

If you want to take advantage of LangChain's callback system for functionality like token tracking, you can extend the BaseChatModel class and implement the lower level _generate method. It also takes a list of BaseMessages as input, but requires you to construct and return a ChatGeneration object that permits additional metadata. Here's an example:

import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";
import {
  BaseChatModel,
  BaseChatModelCallOptions,
  BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";

interface AdvancedCustomChatModelOptions extends BaseChatModelCallOptions {}

interface AdvancedCustomChatModelParams extends BaseChatModelParams {
  n: number;
}

class AdvancedCustomChatModel extends BaseChatModel<AdvancedCustomChatModelOptions> {
  n: number;

  static lc_name(): string {
    return "AdvancedCustomChatModel";
  }

  constructor(fields: AdvancedCustomChatModelParams) {
    super(fields);
    this.n = fields.n;
  }

  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    const content = messages[0].content.slice(0, this.n);
    const tokenUsage = {
      usedTokens: this.n,
    };
    return {
      generations: [{ message: new AIMessage({ content }), text: content }],
      llmOutput: { tokenUsage },
    };
  }

  _llmType(): string {
    return "advanced_custom_chat_model";
  }
}

This will pass the additional returned information through in callback events and in the streamEvents method:

const chatModel = new AdvancedCustomChatModel({ n: 4 });

const eventStream = await chatModel.streamEvents([["human", "I am an LLM"]], {
  version: "v2",
});

for await (const event of eventStream) {
  if (event.event === "on_chat_model_end") {
    console.log(JSON.stringify(event, null, 2));
  }
}
{
  "event": "on_chat_model_end",
  "data": {
    "output": {
      "lc": 1,
      "type": "constructor",
      "id": [
        "langchain_core",
        "messages",
        "AIMessage"
      ],
      "kwargs": {
        "content": "I am",
        "tool_calls": [],
        "invalid_tool_calls": [],
        "additional_kwargs": {},
        "response_metadata": {
          "tokenUsage": {
            "usedTokens": 4
          }
        }
      }
    }
  },
  "run_id": "11dbdef6-1b91-407e-a497-1a1ce2974788",
  "name": "AdvancedCustomChatModel",
  "tags": [],
  "metadata": {
    "ls_model_type": "chat"
  }
}
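
The llmOutput returned from _generate is also surfaced to callback handlers. As a rough sketch (the inline handler below is illustrative and not part of the example above), you could capture the token usage with a handleLLMEnd callback:

const model = new AdvancedCustomChatModel({ n: 4 });

await model.invoke([["human", "I am an LLM"]], {
  callbacks: [
    {
      // `output` is the final LLMResult, including the `llmOutput`
      // object returned from `_generate` above.
      handleLLMEnd(output) {
        console.log(output.llmOutput?.tokenUsage);
      },
    },
  ],
});

This should log { usedTokens: 4 } once the call completes.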

Tracing (advanced)

If you are implementing a custom chat model and want to use it with a tracing service like LangSmith, you can automatically log the parameters used for a given invocation by implementing the invocationParams() method on the model.

This method is purely optional, but anything it returns will be logged as metadata for the trace.

Here's one pattern you might use:

import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
  BaseChatModel,
  type BaseChatModelCallOptions,
  type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";

interface CustomChatModelOptions extends BaseChatModelCallOptions {
  // Some required or optional inner args
  tools: Record<string, any>[];
}

interface CustomChatModelParams extends BaseChatModelParams {
  temperature: number;
  n: number;
}

class CustomChatModel extends BaseChatModel<CustomChatModelOptions> {
  temperature: number;

  n: number;

  static lc_name(): string {
    return "CustomChatModel";
  }

  constructor(fields: CustomChatModelParams) {
    super(fields);
    this.temperature = fields.temperature;
    this.n = fields.n;
  }

  // Anything returned in this method will be logged as metadata in the trace.
  // It is common to pass it any options used to invoke the function.
  invocationParams(options?: this["ParsedCallOptions"]) {
    return {
      tools: options?.tools,
      n: this.n,
    };
  }

  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    const additionalParams = this.invocationParams(options);
    // `someAPIRequest` is a placeholder for a call to your model's API.
    const content = await someAPIRequest(messages, additionalParams);
    return {
      generations: [{ message: new AIMessage({ content }), text: content }],
      llmOutput: {},
    };
  }

  _llmType(): string {
    return "advanced_custom_chat_model";
  }
}
