# How an LLM Does Question Answering over SQL (Part 1)

Official tutorial: https://python.langchain.com/v0.2/docs/tutorials/sql_qa/

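## Architecture
![](../resource/img_19.png)
At a high level, there are three steps:
1. Convert the question into a SQL query: the model translates the user input into SQL.
2. Execute the SQL query.
3. Answer the question: the model uses the query result together with the user's question to produce the answer.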
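
The three steps can be sketched without any LLM at all. The snippet below is only an illustration under stated assumptions: `fake_llm_write_sql` and `fake_llm_answer` are hypothetical stand-ins for the model calls, and the tiny in-memory `Employee` table is made up for the example rather than taken from Chinook.

```python
import sqlite3


def fake_llm_write_sql(question: str) -> str:
    # Step 1 (hypothetical stand-in): a real model would generate this SQL.
    return "SELECT COUNT(*) FROM Employee"


def run_sql(conn: sqlite3.Connection, sql: str) -> list:
    # Step 2: execute the generated SQL against the database.
    return conn.execute(sql).fetchall()


def fake_llm_answer(question: str, sql: str, rows: list) -> str:
    # Step 3 (hypothetical stand-in): a real model would phrase the answer
    # from the question, the query, and the query result.
    return f"There are {rows[0][0]} employees."


# Made-up sample data, only for this sketch.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.executemany(
    "INSERT INTO Employee (LastName) VALUES (?)",
    [("Adams",), ("Edwards",), ("Peacock",)],
)

question = "How many employees are there"
sql = fake_llm_write_sql(question)
rows = run_sql(conn, sql)
answer = fake_llm_answer(question, sql, rows)
print(answer)
```

The rest of this article replaces the two stand-in functions with real model calls made through LangChain.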

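### Creating the database

This example uses SQLite.
1. Create `Chinook.sql` with the contents of https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql
2. Run `sqlite3 Chinook.db`.
3. In the sqlite3 shell, run `.read Chinook.sql`. (Note: `.read` resolves the file name relative to the directory sqlite3 was started in, so `Chinook.sql` should sit in the same directory as `Chinook.db`; otherwise the import fails.)
4. Verify that the import succeeded: `SELECT * FROM Artist LIMIT 10;`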
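
The same import can also be done from Python instead of the sqlite3 shell. This is a minimal sketch using only the standard library: `executescript` runs a multi-statement SQL script much like `.read` does. The three-statement `script` string is a made-up stand-in for the real `Chinook.sql`, and the in-memory database is an assumption to keep the example self-contained.

```python
import sqlite3

# Made-up stand-in for the contents of Chinook.sql (the real script is much larger).
script = """
CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
INSERT INTO Artist (ArtistId, Name) VALUES (1, 'AC/DC');
INSERT INTO Artist (ArtistId, Name) VALUES (2, 'Accept');
"""

# ":memory:" keeps the sketch self-contained; pass "Chinook.db" to write a file.
conn = sqlite3.connect(":memory:")
conn.executescript(script)  # runs every statement in the script

artist_rows = conn.execute("SELECT * FROM Artist LIMIT 10").fetchall()
print(artist_rows)
```

With the real file, `script` would be the text read from `Chinook.sql`.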

Now test the database from code.

In [1]:
from langchain_community.utilities import SQLDatabase


db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

The database is ready; the implementation starts here.

### What LangChain provides out of the box

In [2]:
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)

#### Converting the user's question into SQL

LangChain provides [create_sql_query_chain](https://python.langchain.com/v0.2/api_reference/langchain/chains/langchain.chains.sql_database.query.create_sql_query_chain.html) for this step.

In [3]:
from langchain.chains import create_sql_query_chain
from langchain.globals import set_debug
set_debug(True)

chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})
response

[chain/start] [chain:RunnableSequence] Entering Chain run with input:
{
  "question": "How many employees are there"
}
[chain/start] [chain:RunnableSequence > chain:RunnableAssign<input,table_info>] Entering Chain run with input:
{
  "question": "How many employees are there"
}
[chain/start] [chain:RunnableSequence > chain:RunnableAssign<input,table_info> > chain:RunnableParallel<input,table_info>] Entering Chain run with input:
{
  "question": "How many employees are there"
}
[chain/start] [chain:RunnableSequence > chain:RunnableAssign<input,table_info> > chain:RunnableParallel<input,table_info> > chain:RunnableLambda] Entering Chain run with input:
{
  "question": "How many employees are there"
}
[chain/end] [chain:RunnableSequence > chain:RunnableAssign<input,table_info> > chain:RunnableParallel<input,table_info> > chain:RunnableLambda] [0ms] Exiting Chain run wit

'SELECT COUNT(EmployeeId) AS NumEmployees\nFROM Employee;'

The trace in LangSmith shows that the prompt is:

```
You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves \"today\".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:

CREATE TABLE \"Album\" (
\t\"AlbumId\" INTEGER NOT NULL, 
\t\"Title\" NVARCHAR(160) NOT NULL, 
\t\"ArtistId\" INTEGER NOT NULL, 
\tPRIMARY KEY (\"AlbumId\"), 
\tFOREIGN KEY(\"ArtistId\") REFERENCES \"Artist\" (\"ArtistId\")
)

/*
3 rows from Album table:
AlbumId\tTitle\tArtistId
1\tFor Those About To Rock We Salute You\t1
2\tBalls to the Wall\t2
3\tRestless and Wild\t2
*/


CREATE TABLE \"Artist\" (
\t\"ArtistId\" INTEGER NOT NULL, 
\t\"Name\" NVARCHAR(120), 
\tPRIMARY KEY (\"ArtistId\")
)

/*
3 rows from Artist table:
ArtistId\tName
1\tAC/DC
2\tAccept
3\tAerosmith
*/


CREATE TABLE \"Customer\" (
\t\"CustomerId\" INTEGER NOT NULL, 
\t\"FirstName\" NVARCHAR(40) NOT NULL, 
\t\"LastName\" NVARCHAR(20) NOT NULL, 
\t\"Company\" NVARCHAR(80), 
\t\"Address\" NVARCHAR(70), 
\t\"City\" NVARCHAR(40), 
\t\"State\" NVARCHAR(40), 
\t\"Country\" NVARCHAR(40), 
\t\"PostalCode\" NVARCHAR(10), 
\t\"Phone\" NVARCHAR(24), 
\t\"Fax\" NVARCHAR(24), 
\t\"Email\" NVARCHAR(60) NOT NULL, 
\t\"SupportRepId\" INTEGER, 
\tPRIMARY KEY (\"CustomerId\"), 
\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")
)

/*
3 rows from Customer table:
CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId
1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3
2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5
3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3
*/


CREATE TABLE \"Employee\" (
\t\"EmployeeId\" INTEGER NOT NULL, 
\t\"LastName\" NVARCHAR(20) NOT NULL, 
\t\"FirstName\" NVARCHAR(20) NOT NULL, 
\t\"Title\" NVARCHAR(30), 
\t\"ReportsTo\" INTEGER, 
\t\"BirthDate\" DATETIME, 
\t\"HireDate\" DATETIME, 
\t\"Address\" NVARCHAR(70), 
\t\"City\" NVARCHAR(40), 
\t\"State\" NVARCHAR(40), 
\t\"Country\" NVARCHAR(40), 
\t\"PostalCode\" NVARCHAR(10), 
\t\"Phone\" NVARCHAR(24), 
\t\"Fax\" NVARCHAR(24), 
\t\"Email\" NVARCHAR(60), 
\tPRIMARY KEY (\"EmployeeId\"), 
\tFOREIGN KEY(\"ReportsTo\") REFERENCES \"Employee\" (\"EmployeeId\")
)

/*
3 rows from Employee table:
EmployeeId\tLastName\tFirstName\tTitle\tReportsTo\tBirthDate\tHireDate\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail
1\tAdams\tAndrew\tGeneral Manager\tNone\t1962-02-18 00:00:00\t2002-08-14 00:00:00\t11120 Jasper Ave NW\tEdmonton\tAB\tCanada\tT5K 2N1\t+1 (780) 428-9482\t+1 (780) 428-3457\tandrew@chinookcorp.com
2\tEdwards\tNancy\tSales Manager\t1\t1958-12-08 00:00:00\t2002-05-01 00:00:00\t825 8 Ave SW\tCalgary\tAB\tCanada\tT2P 2T3\t+1 (403) 262-3443\t+1 (403) 262-3322\tnancy@chinookcorp.com
3\tPeacock\tJane\tSales Support Agent\t2\t1973-08-29 00:00:00\t2002-04-01 00:00:00\t1111 6 Ave SW\tCalgary\tAB\tCanada\tT2P 5M5\t+1 (403) 262-3443\t+1 (403) 262-6712\tjane@chinookcorp.com
*/


CREATE TABLE \"Genre\" (
\t\"GenreId\" INTEGER NOT NULL, 
\t\"Name\" NVARCHAR(120), 
\tPRIMARY KEY (\"GenreId\")
)

/*
3 rows from Genre table:
GenreId\tName
1\tRock
2\tJazz
3\tMetal
*/


CREATE TABLE \"Invoice\" (
\t\"InvoiceId\" INTEGER NOT NULL, 
\t\"CustomerId\" INTEGER NOT NULL, 
\t\"InvoiceDate\" DATETIME NOT NULL, 
\t\"BillingAddress\" NVARCHAR(70), 
\t\"BillingCity\" NVARCHAR(40), 
\t\"BillingState\" NVARCHAR(40), 
\t\"BillingCountry\" NVARCHAR(40), 
\t\"BillingPostalCode\" NVARCHAR(10), 
\t\"Total\" NUMERIC(10, 2) NOT NULL, 
\tPRIMARY KEY (\"InvoiceId\"), 
\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")
)

/*
3 rows from Invoice table:
InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal
1\t2\t2021-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98
2\t4\t2021-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96
3\t8\t2021-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94
*/


CREATE TABLE \"InvoiceLine\" (
\t\"InvoiceLineId\" INTEGER NOT NULL, 
\t\"InvoiceId\" INTEGER NOT NULL, 
\t\"TrackId\" INTEGER NOT NULL, 
\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, 
\t\"Quantity\" INTEGER NOT NULL, 
\tPRIMARY KEY (\"InvoiceLineId\"), 
\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), 
\tFOREIGN KEY(\"InvoiceId\") REFERENCES \"Invoice\" (\"InvoiceId\")
)

/*
3 rows from InvoiceLine table:
InvoiceLineId\tInvoiceId\tTrackId\tUnitPrice\tQuantity
1\t1\t2\t0.99\t1
2\t1\t4\t0.99\t1
3\t2\t6\t0.99\t1
*/


CREATE TABLE \"MediaType\" (
\t\"MediaTypeId\" INTEGER NOT NULL, 
\t\"Name\" NVARCHAR(120), 
\tPRIMARY KEY (\"MediaTypeId\")
)

/*
3 rows from MediaType table:
MediaTypeId\tName
1\tMPEG audio file
2\tProtected AAC audio file
3\tProtected MPEG-4 video file
*/


CREATE TABLE \"Playlist\" (
\t\"PlaylistId\" INTEGER NOT NULL, 
\t\"Name\" NVARCHAR(120), 
\tPRIMARY KEY (\"PlaylistId\")
)

/*
3 rows from Playlist table:
PlaylistId\tName
1\tMusic
2\tMovies
3\tTV Shows
*/


CREATE TABLE \"PlaylistTrack\" (
\t\"PlaylistId\" INTEGER NOT NULL, 
\t\"TrackId\" INTEGER NOT NULL, 
\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), 
\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), 
\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")
)

/*
3 rows from PlaylistTrack table:
PlaylistId\tTrackId
1\t3402
1\t3389
1\t3390
*/


CREATE TABLE \"Track\" (
\t\"TrackId\" INTEGER NOT NULL, 
\t\"Name\" NVARCHAR(200) NOT NULL, 
\t\"AlbumId\" INTEGER, 
\t\"MediaTypeId\" INTEGER NOT NULL, 
\t\"GenreId\" INTEGER, 
\t\"Composer\" NVARCHAR(220), 
\t\"Milliseconds\" INTEGER NOT NULL, 
\t\"Bytes\" INTEGER, 
\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, 
\tPRIMARY KEY (\"TrackId\"), 
\tFOREIGN KEY(\"MediaTypeId\") REFERENCES \"MediaType\" (\"MediaTypeId\"), 
\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), 
\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")
)

/*
3 rows from Track table:
TrackId\tName\tAlbumId\tMediaTypeId\tGenreId\tComposer\tMilliseconds\tBytes\tUnitPrice
1\tFor Those About To Rock (We Salute You)\t1\t1\t1\tAngus Young, Malcolm Young, Brian Johnson\t343719\t11170334\t0.99
2\tBalls to the Wall\t2\t2\t1\tU. Dirkschneider, W. Hoffmann, H. Frank, P. Baltes, S. Kaufmann, G. Hoffmann\t342562\t5510424\t0.99
3\tFast As a Shark\t3\t2\t1\tF. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman\t230619\t3990994\t0.99
*/

Question: How many employees are there
SQLQuery: 
```
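The schema of every table is handed to the LLM, which means the DDL has to be very clear. When there are many tables, the model also has a harder time finding the relevant ones; ways to optimize this are covered later.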
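
One way to shrink the prompt is to include only the tables relevant to the question. The sketch below illustrates the idea with the standard library alone: it reads the stored `CREATE TABLE` statements from SQLite's `sqlite_master` catalog for a chosen subset of tables. The helper name `table_info` and the two sample tables are assumptions for this example; it is not how LangChain builds its prompt.

```python
import sqlite3


def table_info(conn: sqlite3.Connection, table_names: list) -> str:
    """Return the CREATE TABLE statements for the given tables only."""
    placeholders = ",".join("?" for _ in table_names)
    rows = conn.execute(
        "SELECT sql FROM sqlite_master "
        f"WHERE type = 'table' AND name IN ({placeholders}) ORDER BY name",
        list(table_names),
    ).fetchall()
    return "\n\n".join(sql for (sql,) in rows)


# Made-up sample schema, only for this sketch.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.execute("CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, Name TEXT)")

# Only the Employee DDL would be placed in the prompt for an employee question.
schema_for_prompt = table_info(conn, ["Employee"])
print(schema_for_prompt)
```

Deciding which tables are relevant is a separate problem; here the list is simply hard-coded.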

In [4]:
# Run the generated query manually
db.run(response)

'[(8,)]'

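#### Executing the SQL query

This step is risky: there is no guarantee that the SQL generated by the model is safe (it could contain a `DROP`, for example). In practice, consider the following safeguards:
- Grant the database user the minimum privileges it needs
- Require human approval before a query is run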
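
As a rough illustration of the first safeguard, the sketch below makes a SQLite connection read-only with `PRAGMA query_only` and adds a naive statement check that could sit in front of a human-approval step. The helper `is_read_only_sql` is a hypothetical name and its check is deliberately simplistic; on a server database the equivalent would be a user with read-only grants.

```python
import sqlite3


def is_read_only_sql(sql: str) -> bool:
    # Naive check (an assumption for this sketch): accept only a single
    # statement that starts with SELECT. Not a substitute for real permissions.
    stripped = sql.strip().rstrip(";")
    return stripped.upper().startswith("SELECT") and ";" not in stripped


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY)")
conn.commit()

# From here on, SQLite rejects any statement that would modify the database.
conn.execute("PRAGMA query_only = ON")

blocked = False
try:
    conn.execute("DROP TABLE Employee")
except sqlite3.Error:
    blocked = True  # the write was refused

remaining = conn.execute("SELECT COUNT(*) FROM Employee").fetchall()
print(is_read_only_sql("SELECT COUNT(*) FROM Employee;"))
print(is_read_only_sql("DROP TABLE Employee"))
print(blocked, remaining)
```

The database-level restriction is the one that matters; the string check only filters obvious cases before a person reviews the query.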

LangChain provides `QuerySQLDataBaseTool` to run the query conveniently.
Under the hood it simply passes the generated query to `db` for execution.

In [5]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# Chain the two steps: generate the SQL, then execute it
execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)
chain = write_query | execute_query
chain.invoke({"question": "How many employees are there"})

[chain/start] [chain:RunnableSequence] Entering Chain run with input:
{
  "question": "How many employees are there"
}
[chain/start] [chain:RunnableSequence > chain:RunnableAssign<input,table_info>] Entering Chain run with input:
{
  "question": "How many employees are there"
}
[chain/start] [chain:RunnableSequence > chain:RunnableAssign<input,table_info> > chain:RunnableParallel<input,table_info>] Entering Chain run with input:
{
  "question": "How many employees are there"
}
[chain/start] [chain:RunnableSequence > chain:RunnableAssign<input,table_info> > chain:RunnableParallel<input,table_info> > chain:RunnableLambda] Entering Chain run with input:
{
  "question": "How many employees are there"
}
[chain/end] [chain:RunnableSequence > chain:RunnableAssign<input,table_info> > chain:RunnableParallel<input,table_info> > chain:RunnableLambda] [0ms] Exiting Chain run wit

'[(8,)]'

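#### Answering the user's question

The two steps above turn the question into a query and execute it. What remains is to combine the execution result with the user's question to generate the final answer.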

In [6]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

chain = (
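    # How this works:
    # 1. Run write_query and store the generated SQL under the "query" key
    # 2. itemgetter("query") picks up that generated SQL
    # 3. execute_query runs it and the output is stored under the "result" key
    # 4. The dict (question, query, result) then fills in answer_prompt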
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer_prompt
    | llm
    | StrOutputParser()
)

chain.invoke({"question": "How many employees are there"})

[chain/start] [chain:RunnableSequence] Entering Chain run with input:
{
  "question": "How many employees are there"
}
[chain/start] [chain:RunnableSequence > chain:RunnableAssign<query>] Entering Chain run with input:
{
  "question": "How many employees are there"
}
[chain/start] [chain:RunnableSequence > chain:RunnableAssign<query> > chain:RunnableParallel<query>] Entering Chain run with input:
{
  "question": "How many employees are there"
}
[chain/start] [chain:RunnableSequence > chain:RunnableAssign<query> > chain:RunnableParallel<query> > chain:RunnableSequence] Entering Chain run with input:
{
  "question": "How many employees are there"
}
[chain/start] [chain:RunnableSequence > chain:RunnableAssign<query> > chain:RunnableParallel<query> > chain:RunnableSequence > chain:RunnableAssign<input,table_info>] Entering Chain run with input:
{
  "question": "How m

'There are 8 employees.'
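
To make the data flow of this chain concrete, the way a dictionary is threaded through the two `assign` steps can be imitated in plain Python. This is only a conceptual sketch, not LangChain's implementation: `assign` here is a hand-written helper, and `write_query` and `execute_query` are hypothetical stand-ins that return fixed strings.

```python
def write_query(state: dict) -> str:
    # Hypothetical stand-in for the LLM-backed query-writing chain.
    return "SELECT COUNT(*) FROM Employee"


def execute_query(sql: str) -> str:
    # Hypothetical stand-in for the SQL execution tool.
    return "[(8,)]"


def assign(state: dict, **steps) -> dict:
    # Copy the incoming dict and add one key per step,
    # each computed from the incoming state.
    new_state = dict(state)
    for key, fn in steps.items():
        new_state[key] = fn(state)
    return new_state


state = {"question": "How many employees are there"}
state = assign(state, query=write_query)  # adds "query"
state = assign(state, result=lambda s: execute_query(s["query"]))  # adds "result"
print(state)
```

After both steps the dictionary holds `question`, `query`, and `result`, which are exactly the three variables the answer prompt expects.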

### Agents

LangChain also has a SQL Agent, which provides a more flexible way of interacting with SQL databases than the chain above. The main advantages of using the SQL Agent are:

1. It can answer questions based on the database's schema as well as its content (for example, describing a specific table).
2. It can recover from errors: it runs a generated query, catches the traceback, and has the model regenerate the query correctly.
3. It can query the database as many times as needed to answer the user's question.
4. It saves tokens by retrieving the schema only from the relevant tables.

To initialize the agent, use the `SQLDatabaseToolkit`, which creates a set of tools to:

- Create and execute queries
- Check query syntax
- Retrieve table descriptions
- ... and more

In [7]:
# List all the tools
from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

tools = toolkit.get_tools()

for item in tools:
    print("=" * 10 ) 
    print(f"name:{item.name}\ndescribe:{item.description}")

name:sql_db_query
describe:Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.
name:sql_db_schema
describe:Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3
name:sql_db_list_tables
describe:Input is an empty string, output is a comma-separated list of tables in the database.
name:sql_db_query_checker
describe:Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!


In [8]:
from langchain_core.messages import SystemMessage

SQL_PREFIX = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

To start you should ALWAYS look at the tables in the database to see what you can query.
Do NOT skip this step.
Then you should query the schema of the most relevant tables."""

system_message = SystemMessage(content=SQL_PREFIX)

### Initializing the agent

The agent is built with LangGraph.

In [9]:

from langgraph.prebuilt import create_react_agent

agent_executor = create_react_agent(llm, tools, messages_modifier=system_message)



In [11]:
from langchain_core.messages import HumanMessage
from langchain.globals import set_debug
set_debug(False)

for s in agent_executor.stream(
    {"messages": [HumanMessage(content="Which country's customers spent the most?")]}
):
    print(s)
    print("----")

{'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_4nkS4xqUvSoWXW7D4L1g2U7K', 'function': {'arguments': '{"query":"SELECT c.Country, SUM(i.Total) AS Total_Spent FROM customers c JOIN invoices i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY Total_Spent DESC LIMIT 1"}', 'name': 'sql_db_query'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 53, 'prompt_tokens': 557, 'total_tokens': 610}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-b70289a7-ac28-49b8-adc7-ec089f84bb9a-0', tool_calls=[{'name': 'sql_db_query', 'args': {'query': 'SELECT c.Country, SUM(i.Total) AS Total_Spent FROM customers c JOIN invoices i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY Total_Spent DESC LIMIT 1'}, 'id': 'call_4nkS4xqUvSoWXW7D4L1g2U7K', 'type': 'tool_call'}], usage_metadata={'input_tokens': 557, 'output_tokens': 53, 'total_tokens': 610

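The output above shows the model's responses.
In the first turn the model called `sql_db_query` with the generated query `SELECT c.Country, SUM(i.Total) AS Total_Spent FROM customers c JOIN invoices i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY Total_Spent DESC LIMIT 1`
The tool then executed this query and it failed: the query references `customers` and `invoices`, while the tables are actually named `Customer` and `Invoice`.
After that, the model called the `sql_db_list_tables` tool, which returned all the table names.
The model then called `sql_db_query` again, got the correct result, and used it to answer the user's question.
LangSmith trace: https://smith.langchain.com/public/5423647a-e378-4279-8e78-b06925b1fa55/r
The execution is fairly involved; see the link above for the details.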
![](../resource/img_20.png)
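
The recover-from-error behaviour seen in this trace boils down to a retry loop: run the query, catch the database error, feed the error back to the model, and try again. The sketch below shows that loop with the standard library only. `fake_llm_sql` is a hypothetical stand-in for the model, hard-coded to guess a wrong table name first, and the `Customer` table is made-up sample data.

```python
import sqlite3


def fake_llm_sql(question: str, error=None) -> str:
    # Hypothetical stand-in for the model: the first guess uses a wrong
    # table name, the retry (after seeing the error) uses the real one.
    if error is None:
        return "SELECT COUNT(*) FROM customers"
    return "SELECT COUNT(*) FROM Customer"


def answer_with_retry(conn, question: str, max_attempts: int = 3):
    error = None
    attempts = []
    for _ in range(max_attempts):
        sql = fake_llm_sql(question, error)
        attempts.append(sql)
        try:
            return conn.execute(sql).fetchall(), attempts
        except sqlite3.Error as exc:
            error = str(exc)  # fed back to the "model" on the next attempt
    raise RuntimeError(f"gave up after {max_attempts} attempts: {error}")


# Made-up sample data, only for this sketch.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT)")
conn.executemany("INSERT INTO Customer (Country) VALUES (?)", [("USA",), ("Brazil",)])

rows, attempts = answer_with_retry(conn, "How many customers are there?")
print(rows, attempts)
```

The real agent is more flexible than this loop: the model chooses which tool to call at each step, for example listing the tables before retrying.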

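### Handling high-cardinality columns
A high-cardinality column is a column that holds a very large number of distinct values. Columns containing proper nouns such as addresses, song titles, or artist names are typical examples, because there can be a huge number of different addresses, titles, and names.
Before filtering on such a proper noun, its spelling needs to be checked; alternatively, a lookup can return the stored values that are most similar to what the user typed.


The approach here is to embed these values, wrap the lookup as a tool, and hand that tool to the model. When the user's question contains a proper noun, the model calls the tool to find the correct spelling, so that before the query is built it is clear which entity the user means.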
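
The idea of mapping an approximate spelling to stored values can be tried without an embedding model. The sketch below swaps in `difflib` string similarity from the standard library in place of vector search, so it only illustrates the lookup step and is not the embedding-based method used in this article. The function `search_proper_nouns` and the candidate list are assumptions for the example.

```python
import difflib


def search_proper_nouns(query: str, candidates: list, k: int = 5) -> list:
    # Map lowercase forms back to the original spelling so that
    # matching is case-insensitive.
    by_lower = {name.lower(): name for name in candidates}
    hits = difflib.get_close_matches(query.lower(), list(by_lower), n=k, cutoff=0.5)
    return [by_lower[h] for h in hits]  # best match first


# Made-up candidate list, only for this sketch.
names = ["AC/DC", "Alice In Chains", "Alanis Morissette", "Audioslave", "Pearl Jam"]
matches = search_proper_nouns("Alice Chains", names)
print(matches)
```

Character-level similarity handles small typos; embeddings, as used in the cells that follow, can also match on meaning rather than spelling alone.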

In [19]:
import ast
import re

# Collect the entity names; they are embedded and stored in a vector database below
def query_as_list(db, query):
    res = db.run(query)
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return list(set(res))


artists = query_as_list(db, "SELECT Name FROM Artist")
albums = query_as_list(db, "SELECT Title FROM Album")
albums[:5]

['New Adventures In Hi-Fi',
 'Axé Bahia',
 'Greatest Hits II',
 'International Superhits',
 'Surfing with the Alien (Remastered)']
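
To see what `query_as_list` does to the string returned by `db.run`, the same three transformations can be applied to a hard-coded sample. The `raw` string below is made up for illustration; it mimics the stringified list of one-element tuples seen in the earlier outputs.

```python
import ast
import re

# Made-up sample in the same shape as the string returned by db.run.
raw = "[('AC/DC',), ('Accept',), (None,), ('Accept',), ('Live 2003',)]"

# 1. Parse the string into tuples, flatten them, and drop empty values.
values = [el for row in ast.literal_eval(raw) for el in row if el]
# 2. Remove standalone numbers and surrounding whitespace.
values = [re.sub(r"\b\d+\b", "", v).strip() for v in values]
# 3. Deduplicate (sorted here only to make the output stable).
unique_values = sorted(set(values))
print(unique_values)
```

Note that the original function returns `list(set(res))`, so the order of its result is not guaranteed.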

In [25]:
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

vector_db = Chroma.from_texts(artists + albums, OpenAIEmbeddings())
retriever = vector_db.as_retriever(search_kwargs={"k": 5})

# Create the tool, with a description, and hand it to the model
description = """Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \
valid proper nouns. Use the noun most similar to the search."""
retriever_tool = create_retriever_tool(
    retriever,
    name="search_proper_nouns",
    description=description,
)

In [22]:
print(retriever_tool.invoke("Alice Chains"))

Alice In Chains

Alanis Morissette

Pearl Jam

Pearl Jam

Audioslave


In [23]:
system = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

You have access to the following tables: {table_names}

If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the "search_proper_nouns" tool!
Do not try to guess at the proper name - use this function to find similar ones.""".format(
    table_names=db.get_usable_table_names()
)

system_message = SystemMessage(content=system)

tools.append(retriever_tool)

agent = create_react_agent(llm, tools, messages_modifier=system_message)



In [24]:
for s in agent.stream(
    {"messages": [HumanMessage(content="How many albums does alis in chain have?")]}
):
    print(s)
    print("----")

{'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_bmECQtxWvsrUytiYN22tKirx', 'function': {'arguments': '{"query":"alis in chain"}', 'name': 'search_proper_nouns'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 19, 'prompt_tokens': 674, 'total_tokens': 693}, 'model_name': 'gpt-35-turbo', 'system_fingerprint': 'fp_e49e4201a9', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-6a9d7fcb-263f-45f0-89f2-b2ea54316969-0', tool_calls=[{'name': 'search_proper_nouns', 'args': {'query': 'alis in chain'}, 'id': 'call_bmECQtxWvsrUytiYN22tKirx', 'type': 'tool_call'}], usage_metadata={'input_tokens': 674, 'output_tokens': 19, 'total_tokens': 693})]}}
----
{'tools': {'messages': [ToolMessage(content='Aisha Duo\n\nXis\n\nDa Lama Ao Caos\n\nA-Sides\n\nAzymuth', name='search_proper_nouns', tool_call_id='call_bmECQtxWvsrUytiYN22tKirx')]}}
----
{'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{

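As the output shows, the first tool the agent calls is the proper-noun lookup tool. The intent is that the entity names used afterwards, and therefore the database query built from them, are correct.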