目录
【示例6.2】基于LangChain+LangGraph+通义千问实现MySQL数据库查询助手。
LangGraph开发AI Agent实践(人工智能技术丛书)【行情 报价 价格 评测】-京东
数据库操作(查询、插入、更新、删除)是智能体处理结构化数据的核心场景(如用户信息查询、订单管理、数据统计等)。LangGraph中通过SQLAlchemy ORM与LangChain SQL工具实现数据库集成,核心是将SQL操作封装为工具,由LLM生成SQL语句并执行。
6.2.1 数据库连接与操作:结构化数据交互
LangGraph是LangChain生态系统中的一个组件,用于构建具有状态管理和循环控制能力的有状态、多参与者、可循环的智能体工作流。然而,需要澄清一个关键点:LangGraph本身并不是一个数据库,它不提供数据库存储功能,也不直接连接数据库。它主要用于控制流(Control Flow)建模,通过图(Graph)结构定义智能体的状态转移和节点间交互逻辑。
尽管如此,LangGraph应用可以与结构化数据库(如PostgreSQL、MySQL、SQLite、MongoDB等)进行集成,实现对结构化数据的读写操作。
技术详解:
- 数据库连接:使用SQLAlchemy创建数据库引擎,支持MySQL、PostgreSQL、SQLite等主流数据库。
- SQL工具封装:LangChain提供SQLDatabaseToolkit,自动生成常见SQL操作工具(如query_sql_db、list_tables、describe_table),无须手动编写SQL函数。
- LLM生成SQL:通过提示词引导LLM根据用户需求生成合法SQL 语句,避免SQL注入风险(LangChain内置基础防护)。
- 图节点设计:分为SQL生成节点、SQL执行节点、结果格式化节,通过状态流转完成从用户查询到数据库结果的全流程。
6.2.2 实战案例:用户信息管理(MySQL数据库操作)
【示例6.2】基于LangChain+LangGraph+通义千问实现MySQL数据库查询助手。
步骤1:创建数据库与表(MySQL),先在MySQL中创建测试数据库和表。
CREATE DATABASE IF NOT EXISTS test_db;
USE langgraph_demo;
CREATE TABLE IF NOT EXISTS users (
id INT PRIMARY KEY AUTO_INCREMENT,
name VARCHAR(50) NOT NULL COMMENT '用户名',
age INT COMMENT '年龄',
email VARCHAR(100) UNIQUE NOT NULL COMMENT '邮箱',
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间'
);
-- 插入测试数据
INSERT INTO users (name, age, email) VALUES
('张三', 25, 'zhangsan@example.com'),
('李四', 30, 'lisi@example.com'),
('王五', 28, 'wangwu@example.com');
步骤 2:安装必要的依赖(requirements.txt)。
# 核心依赖(最宽松约束,确保支持所有版本)
langchain>=0.1.0
langchain-community>=0.1.0
langchain-core>=0.1.0
langgraph>=0.0.1 # 支持包括极早期版本在内的所有LangGraph
pydantic>=1.10.0 # 同时兼容pydantic 1.x和2.x
# 通义千问依赖(二选一即可)
langchain-qwen>=0.1.10 # 推荐方式
dashscope>=1.10.0 # 备用方式
# 数据库依赖
python-dotenv==1.0.1
mysql-connector-python==8.4.0
pymysql==1.1.1
sqlalchemy>=1.4.0 # 兼容旧版本SQLAlchemy
# 基础依赖
requests>=2.20.0
numpy>=1.18.0
步骤 3:设置.env文件。
# mysql set
DB_USER=root #用户名
DB_PASSWORD=1111 #密码
DB_HOST=localhost
DB_PORT=3306
DB_NAME=test_db #数据库名
# 通义千问配置
QWEN_API_KEY=sk-xxxxxxxxxxxx.... # 从阿里云获取
QWEN_MODEL=qwen-turbo # 可选:qwen-turbo、qwen-plus、qwen-max
步骤 4:实现案例代码(LangGraph_database_connection_and_operation.py)。
# -*- coding: utf-8 -*-
import os
from typing import Optional, List, Dict, Any, ClassVar, Union
from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from langchain_core.tools import tool
from langchain_core.messages import (
SystemMessage,
HumanMessage,
ToolMessage,
BaseMessage,
AIMessage,
)
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel
from langgraph.graph import StateGraph, START, END
import warnings
warnings.filterwarnings('ignore')
# ------------------- 关键修复:完善通义千问导入(支持两种方式) -------------------
try:
# 方式1:优先使用 langchain_qwen(推荐)
from langchain_qwen import ChatQwen
QWEN_IMPORT_METHOD = "langchain_qwen"
except ImportError:
try:
# 方式2:若 langchain_qwen 安装失败,使用 DashScope(通义千问官方SDK)+ LangChain 包装
from dashscope import Generation
from langchain_core.language_models import BaseChatModel
from langchain_core.outputs import ChatResult, ChatGeneration, GenerationChunk
import json
# 自定义 ChatQwen 类(基于 DashScope,修复抽象方法问题)
class ChatQwen(BaseChatModel):
api_key: str
model: str = "qwen-turbo"
temperature: float = 0.1
max_tokens: int = 2048
# 必须实现的抽象属性:指定LLM类型
_llm_type: ClassVar[str] = "qwen-dashscope"
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
# 转换 LangChain 消息格式为 DashScope 格式
dashscope_messages = []
for msg in messages:
if isinstance(msg, HumanMessage):
dashscope_messages.append({"role": "user", "content": msg.content})
elif isinstance(msg, SystemMessage):
dashscope_messages.append({"role": "system", "content": msg.content})
elif isinstance(msg, AIMessage):
dashscope_messages.append({"role": "assistant", "content": msg.content})
elif isinstance(msg, ToolMessage):
dashscope_messages.append({"role": "tool", "content": msg.content})
# 调用通义千问 API
response = Generation.call(
model=self.model,
messages=dashscope_messages,
api_key=self.api_key,
temperature=self.temperature,
max_tokens=self.max_tokens,
result_format="message"
)
# 解析响应
if response.output.choices and len(response.output.choices) > 0:
choice = response.output.choices[0]
content = choice.message.content
tool_calls = choice.message.get("tool_calls", [])
# 构建 AIMessage(适配工具调用格式)
additional_kwargs = {}
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
additional_kwargs["tool_calls"] = tool_calls
aimsg = AIMessage(
content=content,
additional_kwargs=additional_kwargs
)
return ChatResult(generations=[ChatGeneration(message=aimsg)])
else:
raise ValueError(f"通义千问 API 响应为空:{response}")
def _stream(self, messages, stop=None, run_manager=None, **kwargs):
# 实现流式输出(简化版)
yield GenerationChunk(content="")
# 可选:实现 _llm_type 方法(兼容部分旧版本)
def _llm_type(self) -> str:
return "qwen-dashscope"
QWEN_IMPORT_METHOD = "dashscope"
print("⚠️ 注意:未安装 langchain_qwen,使用 dashscope 兼容模式(已修复抽象方法问题)")
except ImportError:
raise ImportError(
"❌ 无法导入通义千问相关依赖,请执行以下命令安装:\n"
"pip install langchain-qwen # 推荐方式(优先选择)\n"
"或\n"
"pip install dashscope # 备用方式"
)
# -------------------------- 1. 加载环境变量和配置 --------------------------
load_dotenv()
# 数据库配置(从.env文件读取,若不存在,则使用默认值)
DB_CONFIG = {
"user": os.getenv("DB_USER", "root"),
"password": os.getenv("DB_PASSWORD", "123456"),
"host": os.getenv("DB_HOST", "localhost"),
"port": os.getenv("DB_PORT", "3306"),
"name": os.getenv("DB_NAME", "test_db")
}
# Qwen模型配置(根据实际需求调整)
QWEN_CONFIG = {
"model_name": os.getenv("QWEN_MODEL", "qwen-turbo"),
"api_key": os.getenv("QWEN_API_KEY"), # 必须在.env中配置
"temperature": 0.1,
"max_tokens": 2048
}
# 检查必要的环境变量
if not QWEN_CONFIG["api_key"]:
raise ValueError("❌ 请在.env文件中配置QWEN_API_KEY(通义千问API密钥)")
# -------------------- 2. 初始化LLM模型(通义千问Qwen) --------------------
try:
llm = ChatQwen(
api_key=QWEN_CONFIG["api_key"],
model=QWEN_CONFIG["model_name"],
temperature=QWEN_CONFIG["temperature"],
max_tokens=QWEN_CONFIG["max_tokens"],
streaming=False
)
print(f"✅ Qwen模型初始化成功(导入方式:{QWEN_IMPORT_METHOD})")
except Exception as e:
raise RuntimeError(f"❌ Qwen模型初始化失败:{str(e)}\n"
"请检查:1. API密钥是否正确 2. 网络是否通畅 3. dashscope/
langchain-qwen是否安装")
# -------------------------- 3. 数据库连接初始化 --------------------------
def create_db_connection() -> SQLDatabase:
"""创建数据库连接"""
db_uri = (
f"mysql+pymysql://{DB_CONFIG['user']}:{DB_CONFIG['password']}@"
f"{DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['name']}"
)
try:
db = SQLDatabase.from_uri(
db_uri,
sample_rows_in_table_info=2,
include_tables=["users"], # 只关注users表
ignore_tables=None
)
# 测试连接
db.run("SELECT 1")
print(f"✅ 数据库连接成功:{DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['name']}")
return db
except Exception as e:
raise ConnectionError(f"❌ 数据库连接失败:{str(e)}\n"
"请检查:1. 数据库服务是否启动 2. 连接参数是否正确 3.Pymysql是否安装")
# 初始化数据库
try:
db = create_db_connection()
except ConnectionError as e:
print(f"❌ 数据库初始化失败:{e}")
exit(1)
# ------------------- 4. 手动构建SQL工具(修复安全校验逻辑) ------------------
@tool
def list_tables() -> str:
"""列出数据库中所有可用的表名,用于确认表是否存在"""
try:
tables = db.get_usable_table_names()
return f"数据库中的表:{tables}"
except Exception as e:
return f"❌ 列出表失败:{str(e)}"
@tool
def describe_table(table_name: str = "users") -> str:
"""查看指定表的结构(字段名、类型、注释)和样本数据,默认查看users表"""
if table_name != "users":
return "❌ 仅支持查看users表的结构"
try:
table_info = db.get_table_info([table_name])
return f"✅ users表结构和样本数据:\n{table_info}"
except Exception as e:
return f"❌ 查看表结构失败:{str(e)}"
@tool
def query_sql_db(query: str) -> str:
"""执行MySQL SELECT查询,仅允许查询users表,返回查询结果"""
try:
# 关键修复:SQL安全校验逻辑
query_stripped = query.strip().upper()
# (1)确保是SELECT查询(严格匹配开头)
if not query_stripped.startswith("SELECT"):
return "❌ 错误:仅支持SELECT查询操作"
# (2)确保查询的是users表(不区分大小写)
if "FROM" in query_stripped:
from_index = query_stripped.index("FROM") + 4
table_part = query_stripped[from_index:].strip().split()[0]
if table_part.upper() != "USERS":
return "❌ 错误:仅允许查询users表"
# (3)禁止危险操作(精确匹配完整单词,避免误判)
forbidden_keywords = ["CREATE", "DROP", "ALTER", "INSERT", "UPDATE", "DELETE", "TRUNCATE", "REPLACE"]
# 将查询按空格分割成单词,排除注释部分
query_words = query_stripped.split()
# 检查是否包含禁止的关键字(完整单词匹配)
for word in query_words:
if word in forbidden_keywords:
return f"❌ 错误:禁止执行{word}操作"
# (4)执行查询(移除SQL语句末尾的分号,避免语法错误)
query_clean = query.rstrip(';').strip()
result = db.run(query_clean)
return f"✅ 查询结果:\n{result}" if result else "✅ 查询结果为空"
except Exception as e:
return f"❌ SQL执行错误:{str(e)}"
# 组装SQL工具列表
sql_tools = [list_tables, describe_table, query_sql_db]
# 查看可用工具
print("\n
可用SQL工具:")
for tool in sql_tools:
print(f" - 工具名称:{tool.name}")
print(f" 描述:{tool.description}")
print()
# ------------------ 5. 定义状态(兼容字典和Pydantic模型) --------------------
class SQLState(BaseModel):
messages: List[BaseMessage] = [] # 对话历史
sql_result: Optional[str] = None # SQL执行结果
error_message: Optional[str] = None # 错误信息
tool_calls: Optional[List[Dict[str, Any]]] = None # 工具调用信息
@classmethod
def from_dict(cls, state_dict: Dict[str, Any]) -> "SQLState":
"""从字典创建状态对象(兼容旧版本LangGraph)"""
return cls(
messages=state_dict.get("messages", []),
sql_result=state_dict.get("sql_result"),
error_message=state_dict.get("error_message"),
tool_calls=state_dict.get("tool_calls")
)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典(兼容旧版本LangGraph)"""
return {
"messages": self.messages,
"sql_result": self.sql_result,
"error_message": self.error_message,
"tool_calls": self.tool_calls
}
# ---------------- 6. 节点函数(兼容字典和Pydantic模型输入) --------------------
def sql_planner_node(
state: Union[SQLState, Dict[str, Any]],
config: Optional[RunnableConfig] = None
) -> Union[SQLState, Dict[str, Any]]:
"""生成SQL工具调用计划(兼容字典和Pydantic模型输入)"""
# 将输入状态统一转换为SQLState对象处理
if isinstance(state, dict):
state = SQLState.from_dict(state)
try:
system_prompt = """
你是专业的MySQL数据库查询助手,仅处理users表查询,严格遵守以下规则:
(1)仅查询users表,禁止访问其他表或执行DDL/DML操作(如CREATE/UPDATE/DELETE等)
(2)操作步骤:
a. 已确认users表存在(无须调用list_tables),直接调用describe_table获取字段信息
b. 根据字段信息,调用query_sql_db工具执行SELECT查询
(3)SQL编写规范:
- 只查询用户需要的字段,避免使用SELECT *
- 参数直接写具体值(如WHERE id=1,无须占位符)
- 不添加额外排序或计算(除非用户明确要求)
- 不要在SQL语句末尾加分号
(4)工具调用格式(必须严格遵循JSON格式,不要添加其他内容):
- 调用describe_table:{"name":"describe_table", "parameters":{"table_name":"users"}}
- 调用query_sql_db:{"name":"query_sql_db", "parameters":{"query":"SELECT 字段名 FROM users WHERE 条件"}}
(5)特殊情况处理:
- 若用户需求不明确(如未指定查询条件、字段模糊)→ 直接追问用户补充信息,不调用工具
- 仅返回工具调用指令或追问内容,不返回自然语言回答或解释
"""
messages = [SystemMessage(content=system_prompt)] + state.messages
response = llm.invoke(messages, config=config) if config else llm.invoke(messages)
# 提取工具调用(适配两种导入方式的格式)
tool_calls = []
if isinstance(response, AIMessage) and hasattr(response, 'additional_kwargs'):
# 处理 langchain_qwen 格式(tool_calls在additional_kwargs中)
tool_calls = response.additional_kwargs.get("tool_calls", [])
# 处理 DashScope 格式(可能直接返回工具调用JSON字符串)
if not tool_calls and response.content.strip().startswith("{") and response.content.strip().endswith("}"):
try:
tool_calls = [json.loads(response.content.strip())]
except:
pass
# 更新状态
state.messages.append(response)
state.tool_calls = tool_calls if tool_calls and len(tool_calls) > 0 else None
# 根据输入类型返回对应格式(兼容旧版本)
return state.to_dict() if isinstance(state, SQLState) else state
except Exception as e:
error_msg = f"计划生成失败:{str(e)}"
print(f"❌ {error_msg}")
state.error_message = error_msg
state.messages.append(SystemMessage(content=error_msg))
return state.to_dict() if isinstance(state, SQLState) else state
def sql_executor_node(
state: Union[SQLState, Dict[str, Any]],
config: Optional[RunnableConfig] = None
) -> Union[SQLState, Dict[str, Any]]:
"""执行工具调用(兼容字典和Pydantic模型输入)"""
# 将输入状态统一转换为SQLState对象处理
if isinstance(state, dict):
state = SQLState.from_dict(state)
try:
if state.error_message:
return state.to_dict() if isinstance(state, SQLState) else state
tool_calls = state.tool_calls or []
if not tool_calls:
state.messages.append(SystemMessage(content="未获取到有效工具调用指令,直接返回结果"))
return state.to_dict() if isinstance(state, SQLState) else state
# 执行第一个工具调用(支持多工具调用扩展)
tool_call = tool_calls[0]
tool_name = tool_call.get("name", "")
tool_args = tool_call.get("parameters", {})
# 匹配工具
selected_tool = next((t for t in sql_tools if t.name == tool_name), None)
if not selected_tool:
error_msg = f"未找到工具:{tool_name}(可用工具:{[t.name for t in sql_tools]})"
state.error_message = error_msg
state.messages.append(SystemMessage(content=error_msg))
return state.to_dict() if isinstance(state, SQLState) else state
# 执行工具
print(f"
执行工具:{tool_name},参数:{tool_args}")
tool_result = selected_tool.func(**tool_args)
# 构建工具响应消息
tool_message = ToolMessage(
content=str(tool_result),
tool_call_id=tool_call.get("id", str(hash(f"{tool_name}_{tool_args}"))),
name=tool_name
)
# 更新状态
state.messages.append(tool_message)
state.sql_result = str(tool_result)
# 如果是describe_table,自动生成后续的query_sql_db调用
if tool_name == "describe_table":
# 提取用户原始查询
user_query = next((m.content for m in state.messages if isinstance(m, HumanMessage)), "")
if user_query:
# 生成查询SQL的提示
sql_prompt = f"""
根据以下users表结构信息,为用户查询生成MySQL SELECT语句:
表结构信息:{tool_result}
用户查询:{user_query}
要求:
(1)只查询用户需要的字段,不要查询多余字段
(2)语法正确,不要在语句末尾加分号
(3)只返回SQL语句,不返回任何其他内容
(4)条件判断要准确(如id=1要精确匹配)
"""
sql_response = llm.invoke([HumanMessage(content=sql_prompt)], config=config) if config else llm.invoke([HumanMessage(content=sql_prompt)])
sql_query = sql_response.content.strip()
# 添加query_sql_db工具调用
state.tool_calls = [{"name": "query_sql_db", "parameters": {"query": sql_query}}]
# 递归执行query_sql_db工具
return sql_executor_node(state, config)
return state.to_dict() if isinstance(state, SQLState) else state
except Exception as e:
error_msg = f"工具执行失败:{str(e)}"
print(f"❌ {error_msg}")
state.error_message = error_msg
state.messages.append(SystemMessage(content=error_msg))
return state.to_dict() if isinstance(state, SQLState) else state
def result_formatter_node(
state: Union[SQLState, Dict[str, Any]],
config: Optional[RunnableConfig] = None
) -> Union[SQLState, Dict[str, Any]]:
"""格式化查询结果为自然语言(兼容字典和Pydantic模型输入)"""
# 将输入状态统一转换为SQLState对象处理
if isinstance(state, dict):
state = SQLState.from_dict(state)
try:
# 处理错误情况
if state.error_message:
formatted_msg = f"查询过程中出现错误:{state.error_message}\n建议检查:1. 数据库连接 2. 查询条件 3. 表结构"
state.messages.append(HumanMessage(content=formatted_msg))
return state.to_dict() if isinstance(state, SQLState) else state
# 处理无工具调用且无查询结果的情况(直接返回LLM的追问/解释)
if not state.sql_result and state.messages:
# 取最后一条消息作为结果(可能是LLM的追问或解释)
last_msg = state.messages[-1]
if isinstance(last_msg, AIMessage) and not any(isinstance(m, ToolMessage) for m in state.messages):
state.messages.append(HumanMessage(content=last_msg.content))
return state.to_dict() if isinstance(state, SQLState) else state
# 处理无结果情况
if not state.sql_result or "查询结果为空" in state.sql_result:
state.messages.append(HumanMessage(content="未查询到符合条件的用户数据。"))
return state.to_dict() if isinstance(state, SQLState) else state
# 正常格式化结果
system_prompt = """
你是结果格式化助手,将数据库查询结果转换为清晰易读的自然语言,要求:
(1)去掉"查询结果:"、"✅"、"❌"等冗余前缀和符号
(2)若结果包含多条用户数据,分点列出(用数字或项目符号),每行一个用户
(3)仅展示用户查询的字段信息,不添加额外内容或解释
(4)格式简洁美观,避免使用代码块、引号等格式符号
(5)日期时间字段(如create_time)保留年月日即可,去掉时分秒
(6)示例:
用户查询:查询所有用户的姓名和邮箱
格式化结果:
1. 姓名:张三,邮箱:zhangsan@example.com
2. 姓名:李四,邮箱:lisi@example.com
"""
# 获取用户原始查询
user_query = next((m.content for m in state.messages if isinstance(m, HumanMessage)), "用户查询")
# 构建格式化请求
messages = [
SystemMessage(content=system_prompt),
HumanMessage(content=f"用户查询:{user_query}"),
HumanMessage(content=f"需要格式化的SQL结果:{state.sql_result}")
]
formatted_response = llm.invoke(messages, config=config) if config else llm.invoke(messages)
state.messages.append(formatted_response)
return state.to_dict() if isinstance(state, SQLState) else state
except Exception as e:
error_msg = f"结果格式化失败:{str(e)}"
print(f"❌ {error_msg}")
state.error_message = error_msg
state.messages.append(HumanMessage(content=error_msg))
return state.to_dict() if isinstance(state, SQLState) else state
# ------------------- 7. 构建状态图(无任何条件边,最稳定写法) --------------------
def build_sql_graph() -> StateGraph:
"""构建LangGraph状态图(无条件边,固定流程,兼容所有版本)"""
try:
# 旧版本LangGraph可能需要指定状态类型
try:
sql_graph_builder = StateGraph(SQLState)
except:
# 极早期版本不支持Pydantic状态,直接使用字典
sql_graph_builder = StateGraph(dict)
# 添加节点
sql_graph_builder.add_node("sql_planner", sql_planner_node) # 计划节点
sql_graph_builder.add_node("sql_executor", sql_executor_node)# 执行节点
sql_graph_builder.add_node("result_formatter", result_formatter_node) # 格式化节点
# 固定流程:START → 计划节点 → 执行节点 → 格式化节点 → END
sql_graph_builder.add_edge(START, "sql_planner")
sql_graph_builder.add_edge("sql_planner", "sql_executor")
sql_graph_builder.add_edge("sql_executor", "result_formatter")
sql_graph_builder.add_edge("result_formatter", END)
# 编译状态图
sql_graph = sql_graph_builder.compile()
print("✅ LangGraph状态图构建成功(兼容字典/模型输入,支持所有版本)")
return sql_graph
except Exception as e:
raise RuntimeError(f"❌ 状态图构建失败:{str(e)}")
# -------------------------- 8. 执行查询函数 --------------------------
def run_sql_query(user_query: str) -> str:
"""执行用户查询并返回自然语言结果"""
print(f"\n
收到用户查询:{user_query}")
try:
# 构建状态图
sql_graph = build_sql_graph()
# 初始化状态(兼容字典格式,避免Pydantic模型问题)
initial_state = {
"messages": [HumanMessage(content=user_query)],
"sql_result": None,
"error_message": None,
"tool_calls": None
}
# 执行状态图(兼容新旧版本的调用方式)
try:
# 新版本LangGraph:支持invoke
result_state = sql_graph.invoke(initial_state)
except AttributeError:
# 旧版本LangGraph:使用run
result_state = sql_graph.run(initial_state)
except Exception as e:
# 极端情况:直接调用节点函数执行流程
result_state = sql_planner_node(initial_state)
result_state = sql_executor_node(result_state)
result_state = result_formatter_node(result_state)
# 提取最终结果(兼容字典和模型格式)
if isinstance(result_state, dict):
messages = result_state.get("messages", [])
elif isinstance(result_state, SQLState):
messages = result_state.messages
else:
messages = []
# 返回最后一条人类可读消息
for msg in reversed(messages):
if isinstance(msg, (HumanMessage, AIMessage)):
return msg.content
return "查询完成,但未获取到有效结果"
except Exception as e:
error_msg = f"查询执行失败:{str(e)}"
print(f"❌ {error_msg}")
return error_msg
# -------------------------- 9. 主函数 --------------------------
if __name__ == "__main__":
print("=" * 60)
print("
通义千问 SQL 查询助手(最终完美版)")
print("
功能:MySQL users表查询 | 自动工具调用 | 自然语言结果")
print("
支持查询示例:'查询所有用户的姓名和邮箱'、'查询id为1的用户信息'")
print("
兼容:所有LangGraph版本 + 两种通义千问导入方式 + 字典/模型输入")
print("✅ 修复所有已知错误:语法/参数/版本/类型/安全校验错误")
print("=" * 60)
# 可选测试查询(帮助用户快速验证功能)
test_queries = [
"查询所有用户的姓名和邮箱",
"查询id为1的用户信息",
]
# 执行测试查询
for i, query in enumerate(test_queries, 1):
print(f"\n--- 测试查询 {i} ---")
answer = run_sql_query(query)
print(f"
回答:{answer}")
print("-" * 40)
# 交互式查询(核心功能)
print("\n
测试完成!现在可以输入您的查询(输入 'exit' 退出)")
while True:
user_input = input("> ")
if user_input.lower() == "exit":
print("
再见!")
break
if not user_input.strip():
print("❌ 请输入有效的查询内容(例如:查询所有女性用户的姓名和电话)")
continue
answer = run_sql_query(user_input)
print(f"
回答:{answer}")
输出结果如图6.2所示。

图6.2 输出结果
6.2.3 案例代码解析
该代码实现了一个基于LangChain+LangGraph+通义千问的MySQL数据库查询助手,核心功能是将用户自然语言查询自动转换为SQL语句,执行查询后再以自然语言返回结果,全程无须人工编写SQL。以下是分层解析。
1. 整体架构与核心流程
1)核心目标
- 仅支持users表查询,保证数据安全。
- 自动完成自然语言→SQL生成→SQL执行→结果格式化全流程。
- 兼容不同版本依赖(LangGraph新旧版本、通义千问两种导入方式)。
- 提供安全校验、错误处理、交互式查询等工程化特性。
2)整体流程(固定状态图流程)
用户输入自然语言查询→START→sql_planner(生成工具调用计划)→sql_executor(执行SQL工具)→result_formatter(格式化结果)→END→返回自然语言答案。
2. 关键依赖与导入
1)核心依赖
- langchain_core:LangChain核心组件(工具、消息、可运行对象)。
- langchain_community:第三方工具集成(如SQLDatabaseToolkit)。
- langgraph:状态图工作流(控制流程流转)。
- pydantic:数据模型定义(状态管理)。
- dotenv:环境变量加载。
- 通义千问依赖:二选一(langchain_qwen,推荐DashScope备用)。
- 数据库依赖:pymysql(MySQL连接)。
2)通义千问导入兼容(核心修复)
解决不同环境下的依赖导入问题,支持以下两种方式。
- 优先使用langchain_qwen:LangChain官方适配包,调用更简洁。
- 备用DashScope:通义千问官方SDK,自定义ChatQwen类实现LangChain接口。
- 核心是实现_generate方法(同步调用)和_stream方法(流式调用,简化版)。
- 自动转换LangChain消息格式与DashScope格式(HumanMessage→user角色等)。
3. 配置加载与初始化
1)环境变量与配置
- 从.env文件加载配置,缺失时使用默认值:
- 数据库配置(DB_USER、DB_PASSWORD、DB_HOST 等)。
- 通义千问配置(QWEN_API_KEY必配,QWEN_MODEL 默认为 qwen-turbo)。
- 校验关键配置:若未配置QWEN_API_KEY,则直接报错。
2)核心实例初始化
(1)通义千问LLM初始化
llm = ChatQwen(
api_key=QWEN_CONFIG["api_key"],
model=QWEN_CONFIG["model_name"],
temperature=QWEN_CONFIG["temperature"],
max_tokens=QWEN_CONFIG["max_tokens"],
streaming=False
)
异常处理:捕获API密钥错误、网络问题、依赖缺失等。
(2)数据库连接初始化
- 构建MySQL连接URI(mysql+pymysql://user:password@host:port/dbname)。
- 初始化SQLDatabase实例,仅关注users表(include_tables=["users"])。
- 测试连接:执行SELECT验证数据库可达性。
- 错误处理:数据库未启动、连接参数错误、pymysql未安装等。
4. SQL工具定义(安全+功能)
通过@tool装饰器定义3个核心工具,仅支持users表相关操作。
1)list_tables:列出可用表
- 功能:返回数据库中可查询的表名(仅users)。
- 用途:确认表存在性(实际流程中已默认确认,可用于调试)。
2)describe_table:查看表结构
- 功能:返回users表的字段名、类型、注释和样本数据。
- 安全限制:仅允许查看users表,禁止访问其他表。
3)query_sql_db:执行SQL查询(核心工具)
- 功能:执行MySQL SELECT查询,返回结果。
- 关键安全校验(防止恶意SQL):
- 仅允许SELECT开头的查询(禁止DDL/DML操作)。
- 仅允许查询users表(从FROM后提取表名校验)。
- 禁止危险关键字(CREATE、DROP、UPDATE、DELETE等)。
- 移除SQL末尾分号(避免语法错误)。
- 错误处理:SQL语法错误、查询逻辑错误等。
5. 状态管理(兼容字典与Pydantic)
- SQLState模型定义,存储流程中的关键数据,支持字典与Pydantic模型互转(兼容 LangGraph 新旧版本):
class SQLState(BaseModel):
messages: List[BaseMessage] = [] # 对话历史(用户消息、LLM 消息、工具消息)
sql_result: Optional[str] = None # SQL 执行结果
error_message: Optional[str] = None # 错误信息
tool_calls: Optional[List[Dict[str, Any]]] = None # LLM 生成的工具调用指令
- 提供from_dict和to_dict方法,解决旧版本LangGraph不支持Pydantic模型的问题。
6. 节点函数(状态图核心逻辑)
每个节点函数处理特定任务,输入输出均兼容字典/Pydantic模型,保证通用性。
1)sql_planner_node:生成工具调用计划
- 核心逻辑:让LLM根据用户查询,决定调用哪个工具(describe_table或query_sql_db)
- 系统提示词约束:
- 限制表的访问:仅查询users表,禁止其他操作。
- 明确步骤:先调用describe_table获取字段,再调用query_sql_db执行查询。
- SQL编写规范:不使用SELECT *、无分号、参数直接写值。
- 补充信息:模糊查询时追问用户补充信息,不盲目调用工具。
- 工具调用格式:严格JSON格式(如{"name":"describe_table","parameters" :{"table_name"
:"users"}})。
- 适配两种导入方式的工具调用格式提取(langchain_qwen的additional_kwargs以及dashscope 的JSON字符串)。
2)sql_executor_node:执行工具调用
- 核心逻辑:解析tool_calls指令,调用对应的SQL工具,返回执行结果。
- 关键流程:
- 匹配工具:根据tool_name 从sql_tools中找到对应工具。
- 执行工具:传入parameters调用工具函数,获取结果。
- 自动衔接:若执行describe_table,则自动让LLM生成query_sql_db所需的SQL语句,递归执行查询。
- 错误处理:工具不存在、参数错误、执行失败等。
3)result_formatter_node:格式化结果
- 核心逻辑:将SQL执行结果(JSON/字符串)转换为清晰易读的自然语言。
- 格式化规则:
- 去除冗余符号(✅/❌/查询结果等)。
- 多条数据分点列出,每行一个用户。
- 日期字段仅保留年月日。
- 仅展示查询字段,不添加额外解释。
- 异常处理:无结果、格式化失败、错误信息展示等。
7. LangGraph状态图构建(兼容所有版本)
1)核心设计:无条件固定流程
为避免LangGraph版本差异导致的条件边报错,采用固定流程设计:START→sql_planner→sql_executor→result_formatter→END。
- 不使用条件判断(如should_continue),降低复杂度,提升稳定性。
- 兼容处理:新版本LangGraph支持Pydantic状态模型,旧版本LangGraph自动降级为字典状态。
2)编译状态图
sql_graph_builder = StateGraph(SQLState) # 优先 Pydantic
# 若失败,则使用字典:sql_graph_builder = StateGraph(dict)
sql_graph_builder.add_node("sql_planner", sql_planner_node)
sql_graph_builder.add_edge(START, "sql_planner")
# 其他节点和边
sql_graph = sql_graph_builder.compile()
8. 核心执行函数与交互逻辑
1)run_sql_query:执行查询入口
- 输入:用户自然语言查询(如“查询所有用户的姓名和邮箱”)。
- 流程:
- 初始化状态:字典格式,兼容所有版本。
- 执行状态图:优先使用invoke方法,旧版本降级为run,极端情况直接调用节点函数。
- 提取结果:从状态的messages中找到最后一条人类可读消息(HumanMessage/AIM
-essage)。
- 错误处理:状态图执行失败、结果提取失败等。
2)交互式查询(主函数)
- 测试查询:启动时自动执行两个测试用例(验证功能可用性)。
- 交互式循环:用户输入查询,输入exit退出。
- 输入校验:过滤空输入,提示有效查询示例。
9. 总结
该代码是一个工程化程度较高、兼容性较强、安全性较好的自然语言转SQL工具,核心优势为:
- 全程自动化,无须人工干预。
- 兼容多种依赖版本,降低部署难度。
- 严格的安全限制,避免数据风险。
- 完善的错误处理和用户引导,提升使用体验。
该代码适用于需要快速实现自然语言查询MySQL users表的场景,如后台管理系统、简单数据查询工具等。

1339

被折叠的 条评论
为什么被折叠?



