数据库连接与操作(LangGraph框架)

目录

6.2.1  数据库连接与操作:结构化数据交互

6.2.2  实战案例:用户信息管理(MySQL数据库操作)

【示例6.2】基于LangChain+LangGraph+通义千问实现MySQL数据库查询助手。

6.2.3  案例代码解析

1. 整体架构与核心流程

2. 关键依赖与导入

3. 配置加载与初始化

4. SQL工具定义(安全+功能)

5. 状态管理(兼容字典与Pydantic)

6. 节点函数(状态图核心逻辑)

7. LangGraph状态图构建(兼容所有版本)

8. 核心执行函数与交互逻辑

9. 总结


LangGraph开发AI Agent实践(人工智能技术丛书)【行情 报价 价格 评测】-京东

数据库操作(查询、插入、更新、删除)是智能体处理结构化数据的核心场景(如用户信息查询、订单管理、数据统计等)。LangGraph中通过SQLAlchemy ORM与LangChain SQL工具实现数据库集成,核心是将SQL操作封装为工具,由LLM生成SQL语句并执行。

6.2.1  数据库连接与操作:结构化数据交互

LangGraph是LangChain生态系统中的一个组件,用于构建具有状态管理和循环控制能力的有状态、多参与者、可循环的智能体工作流。然而,需要澄清一个关键点:LangGraph本身并不是一个数据库,它不提供数据库存储功能,也不直接连接数据库。它主要用于控制流(Control Flow)建模,通过图(Graph)结构定义智能体的状态转移和节点间交互逻辑。

尽管如此,LangGraph应用可以与结构化数据库(如PostgreSQL、MySQL、SQLite、MongoDB等)进行集成,实现对结构化数据的读写操作。

技术详解:

  • 数据库连接:使用SQLAlchemy创建数据库引擎,支持MySQL、PostgreSQL、SQLite等主流数据库。
  • SQL工具封装:LangChain提供SQLDatabaseToolkit,自动生成常见SQL操作工具(如query_sql_db、list_tables、describe_table),无须手动编写SQL函数。
  • LLM生成SQL:通过提示词引导LLM根据用户需求生成合法SQL 语句,避免SQL注入风险(LangChain内置基础防护)。
  • 图节点设计:分为SQL生成节点、SQL执行节点、结果格式化节,通过状态流转完成从用户查询到数据库结果的全流程。

6.2.实战案例:用户信息管理(MySQL数据库操作)

【示例6.2】基于LangChain+LangGraph+通义千问实现MySQL数据库查询助手。

步骤1:创建数据库与表(MySQL),先在MySQL中创建测试数据库和表。

CREATE DATABASE IF NOT EXISTS test_db;

USE langgraph_demo;

CREATE TABLE IF NOT EXISTS users (

    id INT PRIMARY KEY AUTO_INCREMENT,

    name VARCHAR(50) NOT NULL COMMENT '用户名',

    age INT COMMENT '年龄',

    email VARCHAR(100) UNIQUE NOT NULL COMMENT '邮箱',

    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间'

);

-- 插入测试数据

INSERT INTO users (name, age, email) VALUES

('张三', 25, 'zhangsan@example.com'),

('李四', 30, 'lisi@example.com'),

('王五', 28, 'wangwu@example.com');

2:安装必要的依赖(requirements.txt)。

# 核心依赖(最宽松约束,确保支持所有版本)

langchain>=0.1.0

langchain-community>=0.1.0

langchain-core>=0.1.0

langgraph>=0.0.1  # 支持包括极早期版本在内的所有LangGraph

pydantic>=1.10.0  # 同时兼容pydantic 1.x2.x

# 通义千问依赖(二选一即可)

langchain-qwen>=0.1.10      # 推荐方式

dashscope>=1.10.0           # 备用方式

# 数据库依赖

python-dotenv==1.0.1

mysql-connector-python==8.4.0

pymysql==1.1.1

sqlalchemy>=1.4.0  # 兼容旧版本SQLAlchemy

# 基础依赖

requests>=2.20.0

numpy>=1.18.0

步骤 3:设置.env文件。

# mysql set

DB_USER=root   #用户名

DB_PASSWORD=1111   #密码

DB_HOST=localhost

DB_PORT=3306

DB_NAME=test_db  #数据库名

# 通义千问配置

QWEN_API_KEY=sk-xxxxxxxxxxxx....  # 从阿里云获取

QWEN_MODEL=qwen-turbo   # 可选:qwen-turboqwen-plusqwen-max

4:实现案例代码(LangGraph_database_connection_and_operation.py)。

# -*- coding: utf-8 -*-

import os

from typing import Optional, List, Dict, Any, ClassVar, Union

from dotenv import load_dotenv

from langchain_community.utilities import SQLDatabase

from langchain_core.tools import tool

from langchain_core.messages import (

    SystemMessage,

    HumanMessage,

    ToolMessage,

    BaseMessage,

    AIMessage,

)

from langchain_core.runnables import RunnableConfig

from pydantic import BaseModel

from langgraph.graph import StateGraph, START, END

import warnings

warnings.filterwarnings('ignore')

# ------------------- 关键修复:完善通义千问导入(支持两种方式) -------------------

try:

    # 方式1:优先使用 langchain_qwen(推荐)

    from langchain_qwen import ChatQwen

    QWEN_IMPORT_METHOD = "langchain_qwen"

except ImportError:

    try:

        # 方式2:若 langchain_qwen 安装失败,使用 DashScope(通义千问官方SDK+ LangChain 包装

        from dashscope import Generation

        from langchain_core.language_models import BaseChatModel

        from langchain_core.outputs import ChatResult, ChatGeneration, GenerationChunk

        import json

       

        # 自定义 ChatQwen 类(基于 DashScope,修复抽象方法问题)

        class ChatQwen(BaseChatModel):

            api_key: str

            model: str = "qwen-turbo"

            temperature: float = 0.1

            max_tokens: int = 2048

            # 必须实现的抽象属性:指定LLM类型

            _llm_type: ClassVar[str] = "qwen-dashscope"

           

            def _generate(self, messages, stop=None, run_manager=None, **kwargs):

                # 转换 LangChain 消息格式为 DashScope 格式

                dashscope_messages = []

                for msg in messages:

                    if isinstance(msg, HumanMessage):

                        dashscope_messages.append({"role": "user", "content": msg.content})

                    elif isinstance(msg, SystemMessage):

                        dashscope_messages.append({"role": "system", "content": msg.content})

                    elif isinstance(msg, AIMessage):

                        dashscope_messages.append({"role": "assistant", "content": msg.content})

                    elif isinstance(msg, ToolMessage):

                        dashscope_messages.append({"role": "tool", "content": msg.content})

               

                # 调用通义千问 API

                response = Generation.call(

                    model=self.model,

                    messages=dashscope_messages,

                    api_key=self.api_key,

                    temperature=self.temperature,

                    max_tokens=self.max_tokens,

                    result_format="message"

                )

               

                # 解析响应

                if response.output.choices and len(response.output.choices) > 0:

                    choice = response.output.choices[0]

                    content = choice.message.content

                    tool_calls = choice.message.get("tool_calls", [])

                   

                    # 构建 AIMessage(适配工具调用格式)

                    additional_kwargs = {}

                    if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:

                        additional_kwargs["tool_calls"] = tool_calls

                   

                    aimsg = AIMessage(

                        content=content,

                        additional_kwargs=additional_kwargs

                    )

                    return ChatResult(generations=[ChatGeneration(message=aimsg)])

                else:

                    raise ValueError(f"通义千问 API 响应为空:{response}")

           

            def _stream(self, messages, stop=None, run_manager=None, **kwargs):

                # 实现流式输出(简化版)

                yield GenerationChunk(content="")

           

            # 可选:实现 _llm_type 方法(兼容部分旧版本)

            def _llm_type(self) -> str:

                return "qwen-dashscope"

       

        QWEN_IMPORT_METHOD = "dashscope"

        print("⚠️  注意:未安装 langchain_qwen,使用 dashscope 兼容模式(已修复抽象方法问题)")

    except ImportError:

        raise ImportError(

            " 无法导入通义千问相关依赖,请执行以下命令安装:\n"

            "pip install langchain-qwen  # 推荐方式(优先选择)\n"

            "\n"

            "pip install dashscope  # 备用方式"

        )

# -------------------------- 1. 加载环境变量和配置 --------------------------

load_dotenv()

# 数据库配置(从.env文件读取,若不存在,则使用默认值)

DB_CONFIG = {

    "user": os.getenv("DB_USER", "root"),

    "password": os.getenv("DB_PASSWORD", "123456"),

    "host": os.getenv("DB_HOST", "localhost"),

    "port": os.getenv("DB_PORT", "3306"),

    "name": os.getenv("DB_NAME", "test_db")

}

# Qwen模型配置(根据实际需求调整)

QWEN_CONFIG = {

    "model_name": os.getenv("QWEN_MODEL", "qwen-turbo"),

    "api_key": os.getenv("QWEN_API_KEY"),  # 必须在.env中配置

    "temperature": 0.1,

    "max_tokens": 2048

}

# 检查必要的环境变量

if not QWEN_CONFIG["api_key"]:

    raise ValueError(" 请在.env文件中配置QWEN_API_KEY(通义千问API密钥)")

# -------------------- 2. 初始化LLM模型(通义千问Qwen --------------------

try:

    llm = ChatQwen(

        api_key=QWEN_CONFIG["api_key"],

        model=QWEN_CONFIG["model_name"],

        temperature=QWEN_CONFIG["temperature"],

        max_tokens=QWEN_CONFIG["max_tokens"],

        streaming=False

    )

    print(f" Qwen模型初始化成功(导入方式:{QWEN_IMPORT_METHOD}")

except Exception as e:

    raise RuntimeError(f" Qwen模型初始化失败:{str(e)}\n"

                       "请检查:1. API密钥是否正确 2. 网络是否通畅 3. dashscope/
langchain-qwen
是否安装")

# -------------------------- 3. 数据库连接初始化 --------------------------

def create_db_connection() -> SQLDatabase:

    """创建数据库连接"""

    db_uri = (

        f"mysql+pymysql://{DB_CONFIG['user']}:{DB_CONFIG['password']}@"

        f"{DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['name']}"

    )

    try:

        db = SQLDatabase.from_uri(

            db_uri,

            sample_rows_in_table_info=2,

            include_tables=["users"],  # 只关注users

            ignore_tables=None

        )

        # 测试连接

        db.run("SELECT 1")

        print(f" 数据库连接成功:{DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['name']}")

        return db

    except Exception as e:

        raise ConnectionError(f" 数据库连接失败:{str(e)}\n"

                             "请检查:1. 数据库服务是否启动 2. 连接参数是否正确 3.Pymysql是否安装")

# 初始化数据库

try:

    db = create_db_connection()

except ConnectionError as e:

    print(f" 数据库初始化失败:{e}")

    exit(1)

# ------------------- 4. 手动构建SQL工具(修复安全校验逻辑) ------------------

@tool

def list_tables() -> str:

    """列出数据库中所有可用的表名,用于确认表是否存在"""

    try:

        tables = db.get_usable_table_names()

        return f"数据库中的表:{tables}"

    except Exception as e:

        return f" 列出表失败:{str(e)}"

@tool

def describe_table(table_name: str = "users") -> str:

    """查看指定表的结构(字段名、类型、注释)和样本数据,默认查看users"""

    if table_name != "users":

        return " 仅支持查看users表的结构"

    try:

        table_info = db.get_table_info([table_name])

        return f" users表结构和样本数据:\n{table_info}"

    except Exception as e:

        return f" 查看表结构失败:{str(e)}"

@tool

def query_sql_db(query: str) -> str:

    """执行MySQL SELECT查询,仅允许查询users表,返回查询结果"""

    try:

        # 关键修复:SQL安全校验逻辑

        query_stripped = query.strip().upper()

       

        # 1)确保是SELECT查询(严格匹配开头)

        if not query_stripped.startswith("SELECT"):

            return " 错误:仅支持SELECT查询操作"

       

        # 2)确保查询的是users表(不区分大小写)

        if "FROM" in query_stripped:

            from_index = query_stripped.index("FROM") + 4

            table_part = query_stripped[from_index:].strip().split()[0]

            if table_part.upper() != "USERS":

                return " 错误:仅允许查询users"

       

        # 3)禁止危险操作(精确匹配完整单词,避免误判)

        forbidden_keywords = ["CREATE", "DROP", "ALTER", "INSERT", "UPDATE", "DELETE", "TRUNCATE", "REPLACE"]

        # 将查询按空格分割成单词,排除注释部分

        query_words = query_stripped.split()

        # 检查是否包含禁止的关键字(完整单词匹配)

        for word in query_words:

            if word in forbidden_keywords:

                return f" 错误:禁止执行{word}操作"

       

        # 4)执行查询(移除SQL语句末尾的分号,避免语法错误)

        query_clean = query.rstrip(';').strip()

        result = db.run(query_clean)

        return f" 查询结果:\n{result}" if result else " 查询结果为空"

    except Exception as e:

        return f" SQL执行错误:{str(e)}"

# 组装SQL工具列表

sql_tools = [list_tables, describe_table, query_sql_db]

# 查看可用工具

print("\n 可用SQL工具:")

for tool in sql_tools:

    print(f"  - 工具名称:{tool.name}")

    print(f"    描述:{tool.description}")

    print()

# ------------------ 5. 定义状态(兼容字典和Pydantic模型) --------------------

class SQLState(BaseModel):

    messages: List[BaseMessage] = []  # 对话历史

    sql_result: Optional[str] = None  # SQL执行结果

    error_message: Optional[str] = None  # 错误信息

    tool_calls: Optional[List[Dict[str, Any]]] = None  # 工具调用信息

   

    @classmethod

    def from_dict(cls, state_dict: Dict[str, Any]) -> "SQLState":

        """从字典创建状态对象(兼容旧版本LangGraph"""

        return cls(

            messages=state_dict.get("messages", []),

            sql_result=state_dict.get("sql_result"),

            error_message=state_dict.get("error_message"),

            tool_calls=state_dict.get("tool_calls")

        )

   

    def to_dict(self) -> Dict[str, Any]:

        """转换为字典(兼容旧版本LangGraph"""

        return {

            "messages": self.messages,

            "sql_result": self.sql_result,

            "error_message": self.error_message,

            "tool_calls": self.tool_calls

        }

# ---------------- 6. 节点函数(兼容字典和Pydantic模型输入) --------------------

def sql_planner_node(

    state: Union[SQLState, Dict[str, Any]],

    config: Optional[RunnableConfig] = None

) -> Union[SQLState, Dict[str, Any]]:

    """生成SQL工具调用计划(兼容字典和Pydantic模型输入)"""

    # 将输入状态统一转换为SQLState对象处理

    if isinstance(state, dict):

        state = SQLState.from_dict(state)

   

    try:

        system_prompt = """

        你是专业的MySQL数据库查询助手,仅处理users表查询,严格遵守以下规则:

        1)仅查询users表,禁止访问其他表或执行DDL/DML操作(如CREATE/UPDATE/DELETE等)

        2)操作步骤:

           a. 已确认users表存在(无须调用list_tables),直接调用describe_table获取字段信息

           b. 根据字段信息,调用query_sql_db工具执行SELECT查询

        3SQL编写规范:

           - 只查询用户需要的字段,避免使用SELECT *

           - 参数直接写具体值(如WHERE id=1,无须占位符)

           - 不添加额外排序或计算(除非用户明确要求)

           - 不要在SQL语句末尾加分号

        4)工具调用格式(必须严格遵循JSON格式,不要添加其他内容):

           - 调用describe_table{"name":"describe_table", "parameters":{"table_name":"users"}}

           - 调用query_sql_db{"name":"query_sql_db", "parameters":{"query":"SELECT 字段名 FROM users WHERE 条件"}}

        5)特殊情况处理:

           - 若用户需求不明确(如未指定查询条件、字段模糊)→ 直接追问用户补充信息,不调用工具

           - 仅返回工具调用指令或追问内容,不返回自然语言回答或解释

        """

       

        messages = [SystemMessage(content=system_prompt)] + state.messages

        response = llm.invoke(messages, config=config) if config else llm.invoke(messages)

       

        # 提取工具调用(适配两种导入方式的格式)

        tool_calls = []

        if isinstance(response, AIMessage) and hasattr(response, 'additional_kwargs'):

            # 处理 langchain_qwen 格式(tool_callsadditional_kwargs中)

            tool_calls = response.additional_kwargs.get("tool_calls", [])

            # 处理 DashScope 格式(可能直接返回工具调用JSON字符串)

            if not tool_calls and response.content.strip().startswith("{") and response.content.strip().endswith("}"):

                try:

                    tool_calls = [json.loads(response.content.strip())]

                except:

                    pass

       

        # 更新状态

        state.messages.append(response)

        state.tool_calls = tool_calls if tool_calls and len(tool_calls) > 0 else None

       

        # 根据输入类型返回对应格式(兼容旧版本)

        return state.to_dict() if isinstance(state, SQLState) else state

    except Exception as e:

        error_msg = f"计划生成失败:{str(e)}"

        print(f" {error_msg}")

        state.error_message = error_msg

        state.messages.append(SystemMessage(content=error_msg))

        return state.to_dict() if isinstance(state, SQLState) else state

def sql_executor_node(

    state: Union[SQLState, Dict[str, Any]],

    config: Optional[RunnableConfig] = None

) -> Union[SQLState, Dict[str, Any]]:

    """执行工具调用(兼容字典和Pydantic模型输入)"""

    # 将输入状态统一转换为SQLState对象处理

    if isinstance(state, dict):

        state = SQLState.from_dict(state)

   

    try:

        if state.error_message:

            return state.to_dict() if isinstance(state, SQLState) else state

       

        tool_calls = state.tool_calls or []

        if not tool_calls:

            state.messages.append(SystemMessage(content="未获取到有效工具调用指令,直接返回结果"))

            return state.to_dict() if isinstance(state, SQLState) else state

       

        # 执行第一个工具调用(支持多工具调用扩展)

        tool_call = tool_calls[0]

        tool_name = tool_call.get("name", "")

        tool_args = tool_call.get("parameters", {})

       

        # 匹配工具

        selected_tool = next((t for t in sql_tools if t.name == tool_name), None)

        if not selected_tool:

            error_msg = f"未找到工具:{tool_name}(可用工具:{[t.name for t in sql_tools]}"

            state.error_message = error_msg

            state.messages.append(SystemMessage(content=error_msg))

            return state.to_dict() if isinstance(state, SQLState) else state

       

        # 执行工具

        print(f" 执行工具:{tool_name},参数:{tool_args}")

        tool_result = selected_tool.func(**tool_args)

       

        # 构建工具响应消息

        tool_message = ToolMessage(

            content=str(tool_result),

            tool_call_id=tool_call.get("id", str(hash(f"{tool_name}_{tool_args}"))),

            name=tool_name

        )

       

        # 更新状态

        state.messages.append(tool_message)

        state.sql_result = str(tool_result)

       

        # 如果是describe_table,自动生成后续的query_sql_db调用

        if tool_name == "describe_table":

            # 提取用户原始查询

            user_query = next((m.content for m in state.messages if isinstance(m, HumanMessage)), "")

            if user_query:

                # 生成查询SQL的提示

                sql_prompt = f"""

                根据以下users表结构信息,为用户查询生成MySQL SELECT语句:

                表结构信息:{tool_result}

                用户查询:{user_query}

                要求:

                1)只查询用户需要的字段,不要查询多余字段

                2)语法正确,不要在语句末尾加分号

                3)只返回SQL语句,不返回任何其他内容

                4)条件判断要准确(如id=1要精确匹配)

                """

                sql_response = llm.invoke([HumanMessage(content=sql_prompt)], config=config) if config else llm.invoke([HumanMessage(content=sql_prompt)])

                sql_query = sql_response.content.strip()

               

                # 添加query_sql_db工具调用

                state.tool_calls = [{"name": "query_sql_db", "parameters": {"query": sql_query}}]

                # 递归执行query_sql_db工具

                return sql_executor_node(state, config)

       

        return state.to_dict() if isinstance(state, SQLState) else state

    except Exception as e:

        error_msg = f"工具执行失败:{str(e)}"

        print(f" {error_msg}")

        state.error_message = error_msg

        state.messages.append(SystemMessage(content=error_msg))

        return state.to_dict() if isinstance(state, SQLState) else state

def result_formatter_node(

    state: Union[SQLState, Dict[str, Any]],

    config: Optional[RunnableConfig] = None

) -> Union[SQLState, Dict[str, Any]]:

    """格式化查询结果为自然语言(兼容字典和Pydantic模型输入)"""

    # 将输入状态统一转换为SQLState对象处理

    if isinstance(state, dict):

        state = SQLState.from_dict(state)

   

    try:

        # 处理错误情况

        if state.error_message:

            formatted_msg = f"查询过程中出现错误:{state.error_message}\n建议检查:1. 数据库连接 2. 查询条件 3. 表结构"

            state.messages.append(HumanMessage(content=formatted_msg))

            return state.to_dict() if isinstance(state, SQLState) else state

       

        # 处理无工具调用且无查询结果的情况(直接返回LLM的追问/解释)

        if not state.sql_result and state.messages:

            # 取最后一条消息作为结果(可能是LLM的追问或解释)

            last_msg = state.messages[-1]

            if isinstance(last_msg, AIMessage) and not any(isinstance(m, ToolMessage) for m in state.messages):

                state.messages.append(HumanMessage(content=last_msg.content))

                return state.to_dict() if isinstance(state, SQLState) else state

       

        # 处理无结果情况

        if not state.sql_result or "查询结果为空" in state.sql_result:

            state.messages.append(HumanMessage(content="未查询到符合条件的用户数据。"))

            return state.to_dict() if isinstance(state, SQLState) else state

       

        # 正常格式化结果

        system_prompt = """

        你是结果格式化助手,将数据库查询结果转换为清晰易读的自然语言,要求:

        1)去掉"查询结果:"""""等冗余前缀和符号

        2)若结果包含多条用户数据,分点列出(用数字或项目符号),每行一个用户

        3)仅展示用户查询的字段信息,不添加额外内容或解释

        4)格式简洁美观,避免使用代码块、引号等格式符号

        5)日期时间字段(如create_time)保留年月日即可,去掉时分秒

        6)示例:

           用户查询:查询所有用户的姓名和邮箱

           格式化结果:

           1. 姓名:张三,邮箱:zhangsan@example.com

           2. 姓名:李四,邮箱:lisi@example.com

        """

        # 获取用户原始查询

        user_query = next((m.content for m in state.messages if isinstance(m, HumanMessage)), "用户查询")

       

        # 构建格式化请求

        messages = [

            SystemMessage(content=system_prompt),

            HumanMessage(content=f"用户查询:{user_query}"),

            HumanMessage(content=f"需要格式化的SQL结果:{state.sql_result}")

        ]

        formatted_response = llm.invoke(messages, config=config) if config else llm.invoke(messages)

        state.messages.append(formatted_response)

        return state.to_dict() if isinstance(state, SQLState) else state

    except Exception as e:

        error_msg = f"结果格式化失败:{str(e)}"

        print(f" {error_msg}")

        state.error_message = error_msg

        state.messages.append(HumanMessage(content=error_msg))

        return state.to_dict() if isinstance(state, SQLState) else state

# ------------------- 7. 构建状态图(无任何条件边,最稳定写法) --------------------

def build_sql_graph() -> StateGraph:

    """构建LangGraph状态图(无条件边,固定流程,兼容所有版本)"""

    try:

        # 旧版本LangGraph可能需要指定状态类型

        try:

            sql_graph_builder = StateGraph(SQLState)

        except:

            # 极早期版本不支持Pydantic状态,直接使用字典

            sql_graph_builder = StateGraph(dict)

        # 添加节点

        sql_graph_builder.add_node("sql_planner", sql_planner_node)  # 计划节点

        sql_graph_builder.add_node("sql_executor", sql_executor_node)# 执行节点

        sql_graph_builder.add_node("result_formatter", result_formatter_node)  # 格式化节点

        # 固定流程:START 计划节点 执行节点 格式化节点 END

        sql_graph_builder.add_edge(START, "sql_planner")

        sql_graph_builder.add_edge("sql_planner", "sql_executor")

        sql_graph_builder.add_edge("sql_executor", "result_formatter")

        sql_graph_builder.add_edge("result_formatter", END)

       

        # 编译状态图

        sql_graph = sql_graph_builder.compile()

        print(" LangGraph状态图构建成功(兼容字典/模型输入,支持所有版本)")

        return sql_graph

    except Exception as e:

        raise RuntimeError(f" 状态图构建失败:{str(e)}")

# -------------------------- 8. 执行查询函数 --------------------------

def run_sql_query(user_query: str) -> str:

    """执行用户查询并返回自然语言结果"""

    print(f"\n 收到用户查询:{user_query}")

    try:

        # 构建状态图

        sql_graph = build_sql_graph()

       

        # 初始化状态(兼容字典格式,避免Pydantic模型问题)

        initial_state = {

            "messages": [HumanMessage(content=user_query)],

            "sql_result": None,

            "error_message": None,

            "tool_calls": None

        }

        # 执行状态图(兼容新旧版本的调用方式)

        try:

            # 新版本LangGraph:支持invoke

            result_state = sql_graph.invoke(initial_state)

        except AttributeError:

            # 旧版本LangGraph:使用run

            result_state = sql_graph.run(initial_state)

        except Exception as e:

            # 极端情况:直接调用节点函数执行流程

            result_state = sql_planner_node(initial_state)

            result_state = sql_executor_node(result_state)

            result_state = result_formatter_node(result_state)

        # 提取最终结果(兼容字典和模型格式)

        if isinstance(result_state, dict):

            messages = result_state.get("messages", [])

        elif isinstance(result_state, SQLState):

            messages = result_state.messages

        else:

            messages = []

        # 返回最后一条人类可读消息

        for msg in reversed(messages):

            if isinstance(msg, (HumanMessage, AIMessage)):

                return msg.content

        return "查询完成,但未获取到有效结果"

    except Exception as e:

        error_msg = f"查询执行失败:{str(e)}"

        print(f" {error_msg}")

        return error_msg

# -------------------------- 9. 主函数 --------------------------

if __name__ == "__main__":

    print("=" * 60)

    print(" 通义千问 SQL 查询助手(最终完美版)")

    print(" 功能:MySQL users表查询 | 自动工具调用 | 自然语言结果")

    print(" 支持查询示例:'查询所有用户的姓名和邮箱''查询id1的用户信息'")

    print(" 兼容:所有LangGraph版本 + 两种通义千问导入方式 + 字典/模型输入")

    print(" 修复所有已知错误:语法/参数/版本/类型/安全校验错误")

    print("=" * 60)

    # 可选测试查询(帮助用户快速验证功能)

    test_queries = [

        "查询所有用户的姓名和邮箱",

        "查询id1的用户信息",

    ]

    # 执行测试查询

    for i, query in enumerate(test_queries, 1):

        print(f"\n--- 测试查询 {i} ---")

        answer = run_sql_query(query)

        print(f" 回答:{answer}")

        print("-" * 40)

    # 交互式查询(核心功能)

    print("\n 测试完成!现在可以输入您的查询(输入 'exit' 退出)")

    while True:

        user_input = input("> ")

        if user_input.lower() == "exit":

            print(" 再见!")

            break

        if not user_input.strip():

            print(" 请输入有效的查询内容(例如:查询所有女性用户的姓名和电话)")

            continue

        answer = run_sql_query(user_input)

        print(f" 回答:{answer}")

输出结果如图6.2所示。

图6.2  输出结果

6.2.3  案例代码解析

该代码实现了一个基于LangChain+LangGraph+通义千问的MySQL数据库查询助手,核心功能是将用户自然语言查询自动转换为SQL语句,执行查询后再以自然语言返回结果,全程无须人工编写SQL。以下是分层解析。

1. 整体架构与核心流程

1)核心目标

  • 仅支持users表查询,保证数据安全。
  • 自动完成自然语言→SQL生成→SQL执行→结果格式化全流程。
  • 兼容不同版本依赖(LangGraph新旧版本、通义千问两种导入方式)。
  • 提供安全校验、错误处理、交互式查询等工程化特性。

2)整体流程(固定状态图流程)

用户输入自然语言查询→START→sql_planner(生成工具调用计划)→sql_executor(执行SQL工具)→result_formatter(格式化结果)→END→返回自然语言答案。

2. 关键依赖与导入

1)核心依赖

  • langchain_core:LangChain核心组件(工具、消息、可运行对象)。
  • langchain_community:第三方工具集成(如SQLDatabaseToolkit)。
  • langgraph:状态图工作流(控制流程流转)。
  • pydantic:数据模型定义(状态管理)。
  • dotenv:环境变量加载。
  • 通义千问依赖:二选一(langchain_qwen,推荐DashScope备用)。
  • 数据库依赖:pymysql(MySQL连接)。

2)通义千问导入兼容(核心修复)

解决不同环境下的依赖导入问题,支持以下两种方式。

  • 优先使用langchain_qwen:LangChain官方适配包,调用更简洁。
  • 备用DashScope:通义千问官方SDK,自定义ChatQwen类实现LangChain接口。
  • 核心是实现_generate方法(同步调用)和_stream方法(流式调用,简化版)。
  • 自动转换LangChain消息格式与DashScope格式(HumanMessage→user角色等)。
3. 配置加载与初始化

1)环境变量与配置

  • 从.env文件加载配置,缺失时使用默认值:
  • 数据库配置(DB_USER、DB_PASSWORD、DB_HOST 等)。
  • 通义千问配置(QWEN_API_KEY必配,QWEN_MODEL 默认为 qwen-turbo)。
  • 校验关键配置:若未配置QWEN_API_KEY,则直接报错。

2)核心实例初始化

(1)通义千问LLM初始化

llm = ChatQwen(

    api_key=QWEN_CONFIG["api_key"],

    model=QWEN_CONFIG["model_name"],

    temperature=QWEN_CONFIG["temperature"],

    max_tokens=QWEN_CONFIG["max_tokens"],

    streaming=False

)

异常处理:捕获API密钥错误、网络问题、依赖缺失等。

(2)数据库连接初始化

  • 构建MySQL连接URI(mysql+pymysql://user:password@host:port/dbname)。
  • 初始化SQLDatabase实例,仅关注users表(include_tables=["users"])。
  • 测试连接:执行SELECT验证数据库可达性。
  • 错误处理:数据库未启动、连接参数错误、pymysql未安装等。
4. SQL工具定义(安全+功能)

通过@tool装饰器定义3个核心工具,仅支持users表相关操作。

1)list_tables:列出可用表

  • 功能:返回数据库中可查询的表名(仅users)。
  • 用途:确认表存在性(实际流程中已默认确认,可用于调试)。

2)describe_table:查看表结构

  • 功能:返回users表的字段名、类型、注释和样本数据。
  • 安全限制:仅允许查看users表,禁止访问其他表。

3)query_sql_db:执行SQL查询(核心工具)

  • 功能:执行MySQL SELECT查询,返回结果。
  • 关键安全校验(防止恶意SQL):
  • 仅允许SELECT开头的查询(禁止DDL/DML操作)。
  • 仅允许查询users表(从FROM后提取表名校验)。
  • 禁止危险关键字(CREATE、DROP、UPDATE、DELETE等)。
  • 移除SQL末尾分号(避免语法错误)。
  • 错误处理:SQL语法错误、查询逻辑错误等。
5. 态管理(兼容字典与Pydantic)
  • SQLState模型定义,存储流程中的关键数据,支持字典与Pydantic模型互转(兼容 LangGraph 新旧版本):

class SQLState(BaseModel):

    messages: List[BaseMessage] = []  # 对话历史(用户消息、LLM 消息、工具消息)

    sql_result: Optional[str] = None  # SQL 执行结果

    error_message: Optional[str] = None  # 错误信息

    tool_calls: Optional[List[Dict[str, Any]]] = None  # LLM 生成的工具调用指令

  • 提供from_dict和to_dict方法,解决旧版本LangGraph不支持Pydantic模型的问题。
6. 点函数(状态图核心逻辑)

每个节点函数处理特定任务,输入输出均兼容字典/Pydantic模型,保证通用性。

1)sql_planner_node:生成工具调用计划

  • 核心逻辑:让LLM根据用户查询,决定调用哪个工具(describe_table或query_sql_db)
  • 系统提示词约束:
  • 限制表的访问:仅查询users表,禁止其他操作。
  • 明确步骤:先调用describe_table获取字段,再调用query_sql_db执行查询。
  • SQL编写规范:不使用SELECT *、无分号、参数直接写值。
  • 补充信息:模糊查询时追问用户补充信息,不盲目调用工具。
  • 工具调用格式:严格JSON格式(如{"name":"describe_table","parameters" :{"table_name"

:"users"}}

  • 适配两种导入方式的工具调用格式提取(langchain_qwen的additional_kwargs以及dashscope 的JSON字符串)。

2)sql_executor_node:执行工具调用

  • 核心逻辑:解析tool_calls指令,调用对应的SQL工具,返回执行结果。
  • 关键流程:
  • 匹配工具:根据tool_name 从sql_tools中找到对应工具。
  • 执行工具:传入parameters调用工具函数,获取结果。
  • 自动衔接:若执行describe_table,则自动让LLM生成query_sql_db所需的SQL语句,递归执行查询。
  • 错误处理:工具不存在、参数错误、执行失败等。

3)result_formatter_node:格式化结果

  • 核心逻辑:将SQL执行结果(JSON/字符串)转换为清晰易读的自然语言。
  • 格式化规则:
  • 去除冗余符号(✅/❌/查询结果等)。
  • 多条数据分点列出,每行一个用户。
  • 日期字段仅保留年月日。
  • 仅展示查询字段,不添加额外解释。
  • 异常处理:无结果、格式化失败、错误信息展示等。

7. LangGraph状态图构建(兼容所有版本)

1)核心设计:无条件固定流程

为避免LangGraph版本差异导致的条件边报错,采用固定流程设计:START→sql_planner→sql_executor→result_formatter→END。

  • 不使用条件判断(如should_continue),降低复杂度,提升稳定性。
  • 兼容处理:新版本LangGraph支持Pydantic状态模型,旧版本LangGraph自动降级为字典状态。

2)编译状态图

sql_graph_builder = StateGraph(SQLState)  # 优先 Pydantic

# 若失败,则使用字典:sql_graph_builder = StateGraph(dict)

sql_graph_builder.add_node("sql_planner", sql_planner_node)

sql_graph_builder.add_edge(START, "sql_planner")

# 其他节点和边

sql_graph = sql_graph_builder.compile()

8. 核心执行函数与交互逻辑

1)run_sql_query:执行查询入口

  • 输入:用户自然语言查询(如“查询所有用户的姓名和邮箱”)。
  • 流程:
  • 初始化状态:字典格式,兼容所有版本。
  • 执行状态图:优先使用invoke方法,旧版本降级为run,极端情况直接调用节点函数。
  • 提取结果:从状态的messages中找到最后一条人类可读消息(HumanMessage/AIM
        -essage)。
  • 错误处理:状态图执行失败、结果提取失败等。

2)交互式查询(主函数)

  • 测试查询:启动时自动执行两个测试用例(验证功能可用性)。
  • 交互式循环:用户输入查询,输入exit退出。
  • 输入校验:过滤空输入,提示有效查询示例。
9. 总结

该代码是一个工程化程度较高、兼容性较强、安全性较好的自然语言转SQL工具,核心优势为:

  • 全程自动化,无须人工干预。
  • 兼容多种依赖版本,降低部署难度。
  • 严格的安全限制,避免数据风险。
  • 完善的错误处理和用户引导,提升使用体验。

该代码适用于需要快速实现自然语言查询MySQL users表的场景,如后台管理系统、简单数据查询工具等。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值