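How to deal with large databases
This guide assumes familiarity with the following concepts:
In order to write valid queries against a database, we need to give the model the table names, table schemas, and feature values it will query over. When there are many tables, many columns, and/or high-cardinality columns, we can't include the full information about the database in every prompt. Instead, we have to find ways to dynamically insert only the most relevant information into the prompt. Let's look at some techniques for doing this.
Setup
First, install the required packages and set your environment variables. This example uses OpenAI as the LLM.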
npm install langchain @langchain/community @langchain/openai typeorm sqlite3
export OPENAI_API_KEY="your api key"
# Uncomment the below to use LangSmith. Not required.
# export LANGCHAIN_API_KEY="your api key"
# export LANGCHAIN_TRACING_V2=true
# Reduce tracing latency if you are not in a serverless environment
# export LANGCHAIN_CALLBACKS_BACKGROUND=true
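The examples below use a SQLite connection to the Chinook database. Follow these installation steps to create Chinook.db in the same directory as this notebook:
- Save this file as Chinook_Sqlite.sql
- Run sqlite3 Chinook.db
- Run .read Chinook_Sqlite.sql
- Test SELECT * FROM Artist LIMIT 10;
Now that Chinook.db is in our directory, we can interact with it using the TypeORM-driven SqlDatabase class: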
import { SqlDatabase } from "langchain/sql_db";
import { DataSource } from "typeorm";
const datasource = new DataSource({
type: "sqlite",
database: "../../../../Chinook.db",
});
const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource,
});
console.log(db.allTables.map((t) => t.tableName));
/**
[
'Album', 'Artist',
'Customer', 'Employee',
'Genre', 'Invoice',
'InvoiceLine', 'MediaType',
'Playlist', 'PlaylistTrack',
'Track'
]
*/
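API Reference:
- SqlDatabase from langchain/sql_db
Many tables
One of the main pieces of information we need to include in the prompt is the schemas of the relevant tables. When there are very many tables, we can't fit every schema into a single prompt. In that case we can first extract the names of the tables that are relevant to the user input, and then include only their schemas.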
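The second step, including only the selected schemas, amounts to filtering and joining strings. Below is a minimal sketch. It assumes the schemas are already available as a name-to-DDL map. Both the map and the helper buildSchemaContext are hypothetical; in this guide, SqlDatabase and createSqlQueryChain do this for you. Proposed names that don't exist in the database are reported separately rather than silently dropped:

```typescript
// Hypothetical, abbreviated schemas keyed by table name (illustrative only).
const schemas: Record<string, string> = {
  Artist: "CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)",
  Album:
    "CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER)",
  Invoice: "CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, Total NUMERIC)",
};

// Keep only the schemas of tables the model proposed. Names that do not exist
// (hallucinated or misspelled tables) are returned separately in `missing`.
function buildSchemaContext(
  allSchemas: Record<string, string>,
  proposed: string[],
): { schemaContext: string; missing: string[] } {
  const isKnown = (name: string) =>
    Object.prototype.hasOwnProperty.call(allSchemas, name);
  // Deduplicate while preserving the order the model returned.
  const known = Array.from(new Set(proposed.filter(isKnown)));
  const missing = proposed.filter((name) => !isKnown(name));
  const schemaContext = known.map((name) => allSchemas[name]).join("\n\n");
  return { schemaContext, missing };
}

const { schemaContext, missing } = buildSchemaContext(schemas, [
  "Artist",
  "Album",
  "Songs",
]);
console.log(schemaContext); // only the Artist and Album schemas
console.log(missing); // [ 'Songs' ]
```

Reporting missing names gives you a hook for retrying or falling back to the full table list. Whether that is worth doing depends on how often your model proposes tables that don't exist.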
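One simple and reliable way to extract the relevant table names is OpenAI function calling with a Zod schema. Chat models such as ChatOpenAI provide a withStructuredOutput method that binds the Zod schema to the model so it returns an object of that shape, as the code below shows: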
import { ChatPromptTemplate } from "@langchain/core/prompts";
import {
RunnablePassthrough,
RunnableSequence,
} from "@langchain/core/runnables";
import { ChatOpenAI } from "@langchain/openai";
import { createSqlQueryChain } from "langchain/chains/sql_db";
import { SqlDatabase } from "langchain/sql_db";
import { DataSource } from "typeorm";
import { z } from "zod";
const datasource = new DataSource({
type: "sqlite",
database: "../../../../Chinook.db",
});
const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource,
});
const llm = new ChatOpenAI({ model: "gpt-4", temperature: 0 });
const Table = z.object({
names: z.array(z.string()).describe("Names of tables in SQL database"),
});
const tableNames = db.allTables.map((t) => t.tableName).join("\n");
const system = `Return the names of ALL the SQL tables that MIGHT be relevant to the user question.
The tables are:
${tableNames}
Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed.`;
const prompt = ChatPromptTemplate.fromMessages([
["system", system],
["human", "{input}"],
]);
const tableChain = prompt.pipe(llm.withStructuredOutput(Table));
console.log(
await tableChain.invoke({
input: "What are all the genres of Alanis Morisette songs?",
})
);
/**
{ names: [ 'Artist', 'Track', 'Genre' ] }
*/
// -------------
// You can see a LangSmith trace of the above chain here:
// https://smith.langchain.com/public/5ca0c91e-4a40-44ef-8c45-9a4247dc474c/r
// -------------
/**
This works pretty well! Except, as we’ll see below, we actually need a few other tables as well.
This would be pretty difficult for the model to know based just on the user question.
In this case, we might think to simplify our model’s job by grouping the tables together.
We’ll just ask the model to choose between categories “Music” and “Business”, and then take care of selecting all the relevant tables from there:
*/
const prompt2 = ChatPromptTemplate.fromMessages([
[
"system",
`Return the names of the SQL tables that are relevant to the user question.
The tables are:
Music
Business`,
],
["human", "{input}"],
]);
const categoryChain = prompt2.pipe(llm.withStructuredOutput(Table));
console.log(
await categoryChain.invoke({
input: "What are all the genres of Alanis Morisette songs?",
})
);
/**
{ names: [ 'Music' ] }
*/
// -------------
// You can see a LangSmith trace of the above chain here:
// https://smith.langchain.com/public/12b62e78-bfbe-42ff-86f2-ad738a476554/r
// -------------
const getTables = (categories: z.infer<typeof Table>): Array<string> => {
let tables: Array<string> = [];
for (const category of categories.names) {
if (category === "Music") {
tables = tables.concat([
"Album",
"Artist",
"Genre",
"MediaType",
"Playlist",
"PlaylistTrack",
"Track",
]);
} else if (category === "Business") {
tables = tables.concat([
"Customer",
"Employee",
"Invoice",
"InvoiceLine",
]);
}
}
return tables;
};
const tableChain2 = categoryChain.pipe(getTables);
console.log(
await tableChain2.invoke({
input: "What are all the genres of Alanis Morisette songs?",
})
);
/**
[
'Album',
'Artist',
'Genre',
'MediaType',
'Playlist',
'PlaylistTrack',
'Track'
]
*/
// -------------
// You can see a LangSmith trace of the above chain here:
// https://smith.langchain.com/public/e78c10aa-e923-4a24-b0c8-f7a6f5d316ce/r
// -------------
// Now that we’ve got a chain that can output the relevant tables for any query we can combine this with our createSqlQueryChain, which can accept a list of tableNamesToUse to determine which table schemas are included in the prompt:
const queryChain = await createSqlQueryChain({
llm,
db,
dialect: "sqlite",
});
const tableChain3 = RunnableSequence.from([
{
input: (i: { question: string }) => i.question,
},
tableChain2,
]);
const fullChain = RunnablePassthrough.assign({
tableNamesToUse: tableChain3,
}).pipe(queryChain);
const query = await fullChain.invoke({
question: "What are all the genres of Alanis Morisette songs?",
});
console.log(query);
/**
SELECT DISTINCT "Genre"."Name"
FROM "Genre"
JOIN "Track" ON "Genre"."GenreId" = "Track"."GenreId"
JOIN "Album" ON "Track"."AlbumId" = "Album"."AlbumId"
JOIN "Artist" ON "Album"."ArtistId" = "Artist"."ArtistId"
WHERE "Artist"."Name" = 'Alanis Morissette'
LIMIT 5;
*/
console.log(await db.run(query));
/**
[{"Name":"Rock"}]
*/
// -------------
// You can see a LangSmith trace of the above chain here:
// https://smith.langchain.com/public/c7d576d0-3462-40db-9edc-5492f10555bf/r
// -------------
// We might rephrase our question slightly to remove redundancy in the answer
const query2 = await fullChain.invoke({
question: "What is the set of all unique genres of Alanis Morisette songs?",
});
console.log(query2);
/**
SELECT DISTINCT Genre.Name FROM Genre
JOIN Track ON Genre.GenreId = Track.GenreId
JOIN Album ON Track.AlbumId = Album.AlbumId
JOIN Artist ON Album.ArtistId = Artist.ArtistId
WHERE Artist.Name = 'Alanis Morissette'
*/
console.log(await db.run(query2));
/**
[{"Name":"Rock"}]
*/
// -------------
// You can see a LangSmith trace of the above chain here:
// https://smith.langchain.com/public/6e80087d-e930-4f22-9b40-f7edb95a2145/r
// -------------
API Reference:
- ChatPromptTemplate from @langchain/core/prompts
- RunnablePassthrough from @langchain/core/runnables
- RunnableSequence from @langchain/core/runnables
- ChatOpenAI from @langchain/openai
- createSqlQueryChain from langchain/chains/sql_db
- SqlDatabase from langchain/sql_db
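The getTables helper above expands categories with an if/else chain and concat, so a category the model lists twice contributes its tables twice. A minimal alternative sketch expresses the same category-to-table grouping as a lookup table and uses a Set for deduplication. The name tablesForCategories is illustrative, not part of LangChain:

```typescript
// Same category-to-table grouping as getTables, expressed as data.
const TABLES_BY_CATEGORY: Record<string, string[]> = {
  Music: ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"],
  Business: ["Customer", "Employee", "Invoice", "InvoiceLine"],
};

// Expand the model's categories into a deduplicated table list.
// Unknown categories are skipped instead of raising an error.
function tablesForCategories(categories: { names: string[] }): string[] {
  const tables = new Set<string>();
  for (const category of categories.names) {
    if (!Object.prototype.hasOwnProperty.call(TABLES_BY_CATEGORY, category)) {
      continue;
    }
    TABLES_BY_CATEGORY[category].forEach((table) => tables.add(table));
  }
  return Array.from(tables);
}

// "Music" twice and an unknown "Sports" still yield the 7 music tables once each.
console.log(tablesForCategories({ names: ["Music", "Music", "Sports"] }));
```

tablesForCategories takes the same input shape as getTables, so it should be pipeable after categoryChain in the same way. Adding a category then only means adding an entry to the map.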
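We've seen how to dynamically include a subset of table schemas in a prompt within a chain. Another possible approach is to let an agent decide for itself when to look up tables, by giving it a tool for doing so.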
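At its core, an agent tool for this is a function from a string to a string. The sketch below shows only that function body and assumes the agent passes comma-separated table names. The schema map and lookupTablesTool are hypothetical, and registering the function as a LangChain tool is not shown. If a name is bad, the tool returns the list of valid tables so the agent has something to correct itself with:

```typescript
// Hypothetical, abbreviated schemas the tool can return.
const TABLE_SCHEMAS: Record<string, string> = {
  Artist: "CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)",
  Genre: "CREATE TABLE Genre (GenreId INTEGER PRIMARY KEY, Name TEXT)",
};

// Tool body: comma-separated table names in, schema text (or an error) out.
function lookupTablesTool(input: string): string {
  const available = Object.keys(TABLE_SCHEMAS).join(", ");
  const requested = input
    .split(",")
    .map((name) => name.trim())
    .filter((name) => name.length > 0);
  // With no names, let the agent discover which tables exist.
  if (requested.length === 0) {
    return `Available tables: ${available}`;
  }
  const missing = requested.filter(
    (name) => !Object.prototype.hasOwnProperty.call(TABLE_SCHEMAS, name),
  );
  if (missing.length > 0) {
    // Listing the valid names lets the agent retry with a corrected call.
    return `Unknown table(s): ${missing.join(", ")}. Available tables: ${available}`;
  }
  return requested.map((name) => TABLE_SCHEMAS[name]).join("\n\n");
}

console.log(lookupTablesTool("Artist, Genre"));
console.log(lookupTablesTool("Artists"));
// Unknown table(s): Artists. Available tables: Artist, Genre
```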
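High-cardinality columns
High cardinality refers to database columns with a very large number of unique values, such as individual names, addresses, or product serial numbers. Such data can be hard to index and query, because filtering and retrieving specific entries efficiently requires more sophisticated strategies.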
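One way to spot high-cardinality columns is to compare each column's number of distinct values with its number of rows. Below is a minimal sketch over in-memory toy rows. The helper cardinalityByColumn and the data are made up for illustration:

```typescript
type Row = Record<string, string | number | null>;

interface ColumnCardinality {
  column: string;
  distinct: number; // number of distinct non-null values
  ratio: number; // distinct / total rows
}

// Compute distinct-value counts per column and sort from most to least unique.
function cardinalityByColumn(rows: Row[]): ColumnCardinality[] {
  if (rows.length === 0) return [];
  const columns = Object.keys(rows[0]);
  return columns
    .map((column) => {
      const values = new Set(
        rows.map((row) => row[column]).filter((value) => value != null),
      );
      return { column, distinct: values.size, ratio: values.size / rows.length };
    })
    .sort((a, b) => b.ratio - a.ratio);
}

// Toy rows, not taken from Chinook.
const rows: Row[] = [
  { Name: "Song A", GenreId: 1 },
  { Name: "Song B", GenreId: 1 },
  { Name: "Song C", GenreId: 2 },
  { Name: "Song D", GenreId: 1 },
];

// Name: 4 distinct of 4 rows (ratio 1); GenreId: 2 distinct (ratio 0.5).
console.log(cardinalityByColumn(rows));
```

Against a real database, the same two numbers come from SELECT COUNT(DISTINCT column), COUNT(*) FROM table. Columns with a ratio near 1 that users are likely to filter on by name are good candidates for the retrieval technique described next.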
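To filter on columns that contain proper nouns, such as addresses, song names, or artists, we first need to double-check the spelling so that the filter matches the stored data.
One simple strategy is to create a vector store containing all the distinct proper nouns in the database. We can then query that vector store with each user input and inject the most relevant proper nouns into the prompt.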
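The guide uses embedding similarity for this lookup. To illustrate the retrieve-then-inject idea without an embeddings API, the sketch below swaps in plain Levenshtein edit distance, a different and much cruder technique. It assumes the misspelled name has already been isolated from the question (the guide's embedding retriever is queried with the whole question instead). closestValues is a hypothetical helper:

```typescript
// Levenshtein edit distance via dynamic programming with a single row.
function editDistance(a: string, b: string): number {
  const prev = Array.from({ length: b.length + 1 }, (_, j) => j);
  for (let i = 1; i <= a.length; i++) {
    let diag = prev[0]; // distance for (i - 1, j - 1)
    prev[0] = i;
    for (let j = 1; j <= b.length; j++) {
      const above = prev[j]; // distance for (i - 1, j)
      const cost = a[i - 1] === b[j - 1] ? 0 : 1;
      prev[j] = Math.min(above + 1, prev[j - 1] + 1, diag + cost);
      diag = above;
    }
  }
  return prev[b.length];
}

// Return the k stored values closest to the query (case-insensitive).
function closestValues(query: string, values: string[], k: number): string[] {
  const q = query.toLowerCase();
  return values
    .map((value) => ({ value, distance: editDistance(q, value.toLowerCase()) }))
    .sort((x, y) => x.distance - y.distance)
    .slice(0, k)
    .map((entry) => entry.value);
}

// The first five proper nouns printed later in this guide.
const properNouns = ["AC/DC", "Accept", "Aerosmith", "Alanis Morissette", "Alice In Chains"];
const hits = closestValues("Elenis Moriset", properNouns, 2);

// Newline-joined, like the text the retriever chain inserts for {proper_nouns}.
const properNounsForPrompt = hits.join("\n");
console.log(properNounsForPrompt); // first line: Alanis Morissette
```

Edit distance only catches character-level typos. Embeddings can also match values that are phrased or spelled quite differently, which is why the guide uses them.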
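First, we need the unique values of each entity we want to match. To get them, we define a function that parses the query result into a list of elements: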
import { DocumentInterface } from "@langchain/core/documents";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import {
RunnablePassthrough,
RunnableSequence,
} from "@langchain/core/runnables";
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import { createSqlQueryChain } from "langchain/chains/sql_db";
import { SqlDatabase } from "langchain/sql_db";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { DataSource } from "typeorm";
const datasource = new DataSource({
type: "sqlite",
database: "../../../../Chinook.db",
});
const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource,
});
async function queryAsList(database: SqlDatabase, query: string): Promise<string[]> {
const res: Array<{ [key: string]: string }> = JSON.parse(
await database.run(query)
)
.flat()
.filter((el: any) => el != null);
const justValues: Array<string> = res.map((item) =>
Object.values(item)[0]
.replace(/\b\d+\b/g, "")
.trim()
);
return justValues;
}
let properNouns: string[] = await queryAsList(db, "SELECT Name FROM Artist");
properNouns = properNouns.concat(
await queryAsList(db, "SELECT Title FROM Album")
);
properNouns = properNouns.concat(
await queryAsList(db, "SELECT Name FROM Genre")
);
console.log(properNouns.length);
/**
647
*/
console.log(properNouns.slice(0, 5));
/**
[
'AC/DC',
'Accept',
'Aerosmith',
'Alanis Morissette',
'Alice In Chains'
]
*/
// Now we can embed and store all of our values in a vector database:
const vectorDb = await MemoryVectorStore.fromTexts(
properNouns,
{},
new OpenAIEmbeddings()
);
const retriever = vectorDb.asRetriever(15);
// And put together a query construction chain that first retrieves values from the database and inserts them into the prompt:
const system = `You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run.
Unless otherwise specified, do not return more than {top_k} rows.
Here is the relevant table info: {table_info}
Here is a non-exhaustive list of possible feature values.
If filtering on a feature value make sure to check its spelling against this list first:
{proper_nouns}`;
const prompt = ChatPromptTemplate.fromMessages([
["system", system],
["human", "{input}"],
]);
const llm = new ChatOpenAI({ model: "gpt-4", temperature: 0 });
const queryChain = await createSqlQueryChain({
llm,
db,
prompt,
dialect: "sqlite",
});
const retrieverChain = RunnableSequence.from([
(i: { question: string }) => i.question,
retriever,
(docs: Array<DocumentInterface>) =>
docs.map((doc) => doc.pageContent).join("\n"),
]);
const chain = RunnablePassthrough.assign({
proper_nouns: retrieverChain,
}).pipe(queryChain);
// To try out our chain, let’s see what happens when we try filtering on “elenis moriset”, a misspelling of Alanis Morissette, without and with retrieval:
// Without retrieval
const query = await queryChain.invoke({
question: "What are all the genres of Elenis Moriset songs?",
proper_nouns: "",
});
console.log("query", query);
/**
query SELECT DISTINCT Genre.Name
FROM Genre
JOIN Track ON Genre.GenreId = Track.GenreId
JOIN Album ON Track.AlbumId = Album.AlbumId
JOIN Artist ON Album.ArtistId = Artist.ArtistId
WHERE Artist.Name = 'Elenis Moriset'
LIMIT 5;
*/
console.log("db query results", await db.run(query));
/**
db query results []
*/
// -------------
// You can see a LangSmith trace of the above chain here:
// https://smith.langchain.com/public/b153cb9b-6fbb-43a8-b2ba-4c86715183b9/r
// -------------
// With retrieval:
const query2 = await chain.invoke({
question: "What are all the genres of Elenis Moriset songs?",
});
console.log("query2", query2);
/**
query2 SELECT DISTINCT Genre.Name
FROM Genre
JOIN Track ON Genre.GenreId = Track.GenreId
JOIN Album ON Track.AlbumId = Album.AlbumId
JOIN Artist ON Album.ArtistId = Artist.ArtistId
WHERE Artist.Name = 'Alanis Morissette';
*/
console.log("db query results", await db.run(query2));
/**
db query results [{"Name":"Rock"}]
*/
// -------------
// You can see a LangSmith trace of the above chain here:
// https://smith.langchain.com/public/2f4f0e37-3b7f-47b5-837c-e2952489cac0/r
// -------------
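API Reference:
- DocumentInterface from @langchain/core/documents
- ChatPromptTemplate from @langchain/core/prompts
- RunnablePassthrough from @langchain/core/runnables
- RunnableSequence from @langchain/core/runnables
- ChatOpenAI from @langchain/openai
- OpenAIEmbeddings from @langchain/openai
- createSqlQueryChain from langchain/chains/sql_db
- SqlDatabase from langchain/sql_db
- MemoryVectorStore from langchain/vectorstores/memory
With retrieval, the model corrects the spelling to "Alanis Morissette" and the query returns a valid result.
Another possible approach is to let an agent decide for itself when to look up proper nouns.
Next steps
You've now learned how to handle large databases in SQL generation by inserting only the most relevant table schemas and proper nouns into the prompt.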