수민 '-'

플오그래밍

제가 작성하는 모든 글은 절대 상업적인 이용이 아니며, 그저 개인적인 공부 용도로만 사용하는 것임을 밝힙니다.

쿼리문을 작성하는 RAG - LangGraph 커스텀 SQL 에이전트

1부에서 create_react_agent로 SQL 에이전트를 통째로 돌려봤다면, 2부는 흐름을 노드 단위로 쪼개 디버깅·프롬프트·분기 규칙을 세밀하게 잡는 버전이다. ToolNode로 도구 실행을 격리하고, 질문·스키마·실행 결과를 메시지 슬롯에 쌓은 뒤 쿼리 생성 → 실행 → 오류면 재시도 → 최종 한국어 답변 순으로 연결한다. 원문 흐름은 쿼리문을 작성하는 RAG를 바탕으로 정리했다.


0. 전제

아래 코드는 1부와 동일하게 db, llm, SQLDatabaseToolkit으로 만든 tools가 이미 있다고 가정한다. 없다면 1부의 Chinook 다운로드·SQLDatabase·toolkit.get_tools()까지 먼저 실행하면 된다.


1. 그래프 개요(다이어그램)

노드 연결 관계를 한눈에 보려면 원문에 실린 흐름 그림을 참고하면 좋다.


2. ToolNode로 도구별 실행 노드 만들기

리스트에서 이름으로 도구를 고른 뒤, 도구마다 ToolNode를 하나씩 만든다.

from langgraph.prebuilt import ToolNode

list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
list_tables_node = ToolNode([list_tables_tool], name="list_tables")

get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")
get_schema_node = ToolNode([get_schema_tool], name="get_schema")

run_query_tool = next(tool for tool in tools if tool.name == "sql_db_query")
run_query_node = ToolNode([run_query_tool], name="run_query")

3. DB 조회가 필요한지 판단하는 chatbot 노드

list_tables 도구만 bind한 LLM을 호출한다. tool_calls가 나오면 DB 탐색이 필요하고, 단순 AIMessage면 질문만으로 끝낼 수 있는 경우로 볼 수 있다.

from langgraph.graph import END, START, MessagesState, StateGraph

def chatbot(state: MessagesState):
    llm_with_tools = llm.bind_tools([list_tables_tool])
    response = llm_with_tools.invoke(state["messages"])
    return {"messages": [response]}

4. 스키마 조회를 유도하는 call_get_schema

tool_choice="any"로 스키마 도구 호출을 강제에 가깝게 유도한다.

def call_get_schema(state: MessagesState):
    llm_with_tools = llm.bind_tools([get_schema_tool], tool_choice="any")
    response = llm_with_tools.invoke(state["messages"])
    return {"messages": [response]}

5. 사용자 질문 기반으로 SQL만 생성하기

시스템 프롬프트에서 DML 금지, 백틱·코드펜스 없이 SQL 문장만 출력하도록 못을 박는다. 메시지 슬롯 관점에서는 보통 0번: 사용자 질문, 1번: 스키마 텍스트, 2번 이후: 실행 결과·에러·수정 이력으로 쌓인다고 보면 history를 구성하기 쉽다.

from langchain_core.prompts import ChatPromptTemplate

generate_query_system_prompt = f"""
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {db.dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most 5 results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
DO NOT wrap the response in markdown code fences or backticks. Respond with a SQL statement only!
"""

generate_query_user_prompt = """
User input: {question}
Schema: {schema}

If an error message is given, regenerate the query based on the error message.
History: {history}

SQL query:"""

def generate_query(state: MessagesState):
    print("##### GENERATE QUERY #####")
    print(state["messages"])
    history = ""
    for message in state["messages"][2:]:
        history += message.content + "\n"

    generate_query_msgs = [
        ("system", generate_query_system_prompt),
        ("user", generate_query_user_prompt),
    ]
    generate_prompt = ChatPromptTemplate.from_messages(generate_query_msgs)

    response = llm.invoke(
        generate_prompt.format_messages(
            question=state["messages"][0].content,
            schema=state["messages"][1].content,
            history=history,
        )
    )
    print("generate_query", response)

    return {"messages": [response]}

6. 생성한 쿼리 실행 + 코드펜스 제거

모델이 가끔 마크다운 SQL 코드 블록으로 감싸면 실행기가 실패하므로, 정규식으로 내부 SQL만 추출한다.

from langchain_core.messages import AIMessage
import re

def _sanitize_sql(text: str) -> str:
    fence = chr(96) * 3  # 백틱 3개(마크다운 코드 펜스)
    pattern = fence + r"(?:sql)?\s*(.*?)" + fence
    m = re.search(pattern, text, flags=re.DOTALL | re.IGNORECASE)
    return (m.group(1) if m else text).strip()

def check_query(state: MessagesState):
    print("##### CHECK QUERY #####")
    raw = state["messages"][-1].content
    query = _sanitize_sql(raw)
    print("check_query", query)
    response = run_query_tool.invoke({"query": query})
    return {"messages": [AIMessage(content=str(response))]}

7. 질문·쿼리·결과를 묶어 한국어 답변

answer_system_prompt = """
You are a highly intelligent assistant trained to provide concise and accurate answers.
You will be given a context that has been retrieved from a database using a specific SQL query.
Your task is to analyze the context and answer the user's question based on the information provided in the context.
ANSWER IN KOREAN.
"""

def answer(state: MessagesState):
    print("##### ANSWER #####")
    question = state["messages"][0].content
    context = state["messages"][-1].content
    generated_query = state["messages"][-2].content
    print("context", context)

    answer_msgs = [
        ("system", answer_system_prompt),
        ("user", "User Question: {question} SQL Query: {generated_query} Context: {context}"),
    ]
    answer_prompt = ChatPromptTemplate.from_messages(answer_msgs)

    response = llm.invoke(
        answer_prompt.format_messages(
            question=question,
            generated_query=generated_query,
            context=context,
        )
    )
    print("response", response)

    return {"messages": [response]}

8. 그래프 컴파일: 오류 시 generate_query로 되돌리기

check_query 직후 메시지에 Error: 또는 error가 보이면 쿼리 재생성으로, 아니면 answer로 보낸다.

from langgraph.prebuilt import tools_condition

graph_builder = StateGraph(MessagesState)
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_edge(START, "chatbot")

graph_builder.add_node("list_tables", list_tables_node)
graph_builder.add_conditional_edges(
    "chatbot",
    tools_condition,
    {
        "tools": "list_tables",
        END: END,
    },
)

graph_builder.add_node("call_get_schema", call_get_schema)
graph_builder.add_node("get_schema", get_schema_node)
graph_builder.add_node("generate_query", generate_query)
graph_builder.add_node("check_query", check_query)
graph_builder.add_node("answer", answer)

graph_builder.add_edge("list_tables", "call_get_schema")
graph_builder.add_edge("call_get_schema", "get_schema")
graph_builder.add_edge("get_schema", "generate_query")
graph_builder.add_edge("generate_query", "check_query")

def should_correct(state):
    print(state["messages"][-1].content)
    txt = state["messages"][-1].content
    if "Error:" in txt or "error" in txt.lower():
        return "generate_query"
    return "answer"

graph_builder.add_conditional_edges(
    "check_query",
    should_correct,
    {
        "generate_query": "generate_query",
        "answer": "answer",
    },
)

graph_builder.add_edge("answer", END)

graph = graph_builder.compile()

9. 실행 예시

question = "2009년에 가장 많은 매출을 올린 영업 사원은 누구인가요?"

for step in graph.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()

조금 더 복합적인 질문도 같은 그래프로 돌릴 수 있다.

question = (
    "DB 조회를 기반으로 2009년 가장 많은 양의 음반을 판매한 아티스트의 "
    "해당 앨범 판매 기간을 알려주세요."
)

for step in graph.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()

마치며

  • 커스텀 그래프는 어느 단계에서 실패했는지 로그로 남기기 좋고, 프롬프트·도구·분기만 바꿔도 동작이 달라진다.
  • _sanitize_sql은 LLM 출력 형식이 흔들릴 때 실행 실패를 줄이는 작은 방어층이다.
  • should_correct는 문자열 검사라 DB·드라이버 메시지 형식에 따라 키워드를 조정할 여지가 있다.
  • 운영 환경에서는 DML 차단, 행 수 제한, 접근 가능한 테이블 화이트리스트 등을 추가하는 것이 안전하다.

참고