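"""Query a markdown-header RAG index stored in Chroma from the command line.

The script loads project environment variables, retrieves chunks with MMR search,
builds a Korean-language QA prompt, and prints the answer from the selected chat
LLM provider (OpenAI, Ollama, or LM Studio).
"""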
from __future__ import annotations

import argparse

from langchain_chroma import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

from common import (
    create_chat_llm,
    create_embeddings,
    default_persist_dir,
    load_project_env,
    resolve_llm_config,
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Query the markdown-header RAG index.")
    # Default query (Korean): "Summarize the representative scientific achievements of each country."
    parser.add_argument("query", type=str, nargs="?", default="각 국가별 대표적인 과학적인 성과를 요약해줘")
    parser.add_argument("--persist-dir", type=str, default=str(default_persist_dir()))
    parser.add_argument("--collection", type=str, default="w2_007_header_rag")
    parser.add_argument("--embedding-model", type=str, default="BAAI/bge-m3")
    parser.add_argument("--llm-provider", type=str, default="openai", choices=["openai", "ollama", "lmstudio"])
    parser.add_argument("--llm-model", "--llm", dest="llm_model", type=str, default=None)
    parser.add_argument("--llm-api-key", type=str, default=None)
    parser.add_argument("--llm-base-url", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--k", type=int, default=3)
    parser.add_argument("--show-docs", action="store_true")
    return parser.parse_args()
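
# Illustrative invocations (model names and directories depend on your local setup):
#   python query_cli.py "your question here" --show-docs
#   python query_cli.py --llm-provider ollama --llm-model <local-model> --k 5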


def format_docs(docs) -> str:
    """Render retrieved documents as numbered blocks with source/header metadata."""
    lines = []
    for idx, doc in enumerate(docs, 1):
        source = doc.metadata.get("source", "unknown")
        h2 = doc.metadata.get("h2", "")
        lines.append(f"[{idx}] source={source} h2={h2}\n{doc.page_content}")
    return "\n\n".join(lines)


def main() -> None:
    args = parse_args()

    env_path = load_project_env()
    if env_path:
        print(f"Loaded env: {env_path}")

    embeddings = create_embeddings(args.embedding_model)
    db = Chroma(
        collection_name=args.collection,
        embedding_function=embeddings,
        persist_directory=args.persist_dir,
    )
    # MMR retrieval balances query relevance against diversity among the k returned chunks.
    retriever = db.as_retriever(
        search_type="mmr",
        search_kwargs={"k": args.k, "fetch_k": 10, "lambda_mult": 0.3},
    )

    # Prompt (Korean). System: "You are a Q&A assistant for European history of science.
    # Answer in Korean using only the information in [Context]; if the evidence is
    # insufficient, reply '제공된 문서에서 확인되지 않습니다.' (not confirmed in the provided
    # documents). Do not output internal reasoning or numbered lists; return only the
    # final answer." The human turn asks for just the content of the <final_answer> tag,
    # and the trailing assistant turn pre-opens that tag.
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """당신은 유럽 과학사 질의응답 어시스턴트입니다.
반드시 [Context]에 있는 정보만 사용해 한국어로 답변하세요.
근거가 부족하면 '제공된 문서에서 확인되지 않습니다.'라고 답변하세요.
내부 추론, Thinking Process, Analyze, Constraint, 단계별 사고, 번호 목록은 출력하지 마세요.
최종 답변만 반환하세요.""",
            ),
            (
                "human",
                """[Context]
{context}
[Question]
{question}
[Answer]
`<final_answer>` 태그 내부에 들어갈 최종 답변 내용만 작성하세요.""",
            ),
            ("assistant", "<final_answer>"),
        ]
    )

    provider, model, api_key, base_url = resolve_llm_config(
        provider=args.llm_provider,
        model=args.llm_model,
        api_key=args.llm_api_key,
        base_url=args.llm_base_url,
    )
    print(f"[LLM] provider={provider} model={model}")
    if base_url:
        print(f"[LLM] base_url={base_url}")

    llm = create_chat_llm(
        provider=provider,
        model=model,
        temperature=args.temperature,
        api_key=api_key,
        base_url=base_url,
    )

    # LCEL pipeline: the retriever fills {context} with formatted documents while the
    # raw query passes through as {question}; the prompt, LLM, and string parser follow.
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    if args.show_docs:
        docs = retriever.invoke(args.query)
        print("\n[Retrieved Docs]")
        print(format_docs(docs)[:2500])

    output = rag_chain.invoke(args.query)
    print("\n[Query]")
    print(args.query)
    print("\n[Answer]")
    print(output)


if __name__ == "__main__":
    main()