如何创建自定义 LLM 类
先决条件
本指南假设你熟悉以下概念
本笔记本介绍了如何创建自定义 LLM 包装器,如果你想使用自己的 LLM 或 LangChain 中直接支持的包装器之外的其他包装器,可以使用它。
在扩展 LLM
类 后,自定义 LLM 需要实现一些必要的东西。
- 一个
_call
方法,它接收一个字符串和调用选项(包括stop
序列),并返回一个字符串。 - 一个
_llmType
方法,它返回一个字符串。仅用于日志记录目的。
你还可以实现以下可选方法。
- 一个
_streamResponseChunks
方法,它返回一个AsyncIterator
并生成GenerationChunks
。这允许 LLM 支持流式传输输出。
让我们实现一个非常简单的自定义 LLM,它只回显输入的前 n
个字符。
import { LLM, type BaseLLMParams } from "@langchain/core/language_models/llms";
import type { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { GenerationChunk } from "@langchain/core/outputs";
interface CustomLLMInput extends BaseLLMParams {
n: number;
}
class CustomLLM extends LLM {
n: number;
constructor(fields: CustomLLMInput) {
super(fields);
this.n = fields.n;
}
_llmType() {
return "custom";
}
async _call(
prompt: string,
options: this["ParsedCallOptions"],
runManager: CallbackManagerForLLMRun
): Promise<string> {
// Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
// await subRunnable.invoke(params, runManager?.getChild());
return prompt.slice(0, this.n);
}
async *_streamResponseChunks(
prompt: string,
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<GenerationChunk> {
// Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
// await subRunnable.invoke(params, runManager?.getChild());
for (const letter of prompt.slice(0, this.n)) {
yield new GenerationChunk({
text: letter,
});
// Trigger the appropriate callback
await runManager?.handleLLMNewToken(letter);
}
}
}
现在,我们可以像使用其他 LLM 一样使用它。
const llm = new CustomLLM({ n: 4 });
await llm.invoke("I am an LLM");
I am
并支持流式传输。
const stream = await llm.stream("I am an LLM");
for await (const chunk of stream) {
console.log(chunk);
}
I
a
m
如果你想利用 LangChain 的回调系统来实现令牌跟踪等功能,你可以扩展 BaseLLM
类并实现更低级的 _generate
方法。它不接受单个字符串作为输入和单个字符串作为输出,而是可以接受多个输入字符串并将每个输入字符串映射到多个字符串输出。此外,它返回一个包含附加元数据字段的 Generation
输出,而不仅仅是一个字符串。
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { LLMResult } from "@langchain/core/outputs";
import {
BaseLLM,
BaseLLMCallOptions,
BaseLLMParams,
} from "@langchain/core/language_models/llms";
interface AdvancedCustomLLMCallOptions extends BaseLLMCallOptions {}
interface AdvancedCustomLLMParams extends BaseLLMParams {
n: number;
}
class AdvancedCustomLLM extends BaseLLM<AdvancedCustomLLMCallOptions> {
n: number;
constructor(fields: AdvancedCustomLLMParams) {
super(fields);
this.n = fields.n;
}
_llmType() {
return "advanced_custom_llm";
}
async _generate(
inputs: string[],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): Promise<LLMResult> {
const outputs = inputs.map((input) => input.slice(0, this.n));
// Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
// await subRunnable.invoke(params, runManager?.getChild());
// One input could generate multiple outputs.
const generations = outputs.map((output) => [
{
text: output,
// Optional additional metadata for the generation
generationInfo: { outputCount: 1 },
},
]);
const tokenUsage = {
usedTokens: this.n,
};
return {
generations,
llmOutput: { tokenUsage },
};
}
}
这将在回调事件和`streamEvents 方法中传递返回的附加信息。
const llm = new AdvancedCustomLLM({ n: 4 });
const eventStream = await llm.streamEvents("I am an LLM", {
version: "v2",
});
for await (const event of eventStream) {
if (event.event === "on_llm_end") {
console.log(JSON.stringify(event, null, 2));
}
}
{
"event": "on_llm_end",
"data": {
"output": {
"generations": [
[
{
"text": "I am",
"generationInfo": {
"outputCount": 1
}
}
]
],
"llmOutput": {
"tokenUsage": {
"usedTokens": 4
}
}
}
},
"run_id": "a9ce50e4-f85b-41eb-bcbe-793efc52f9d8",
"name": "AdvancedCustomLLM",
"tags": [],
"metadata": {}
}