-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
107 lines (100 loc) · 5.97 KB
/
main.py
File metadata and controls
107 lines (100 loc) · 5.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import csv
import json
import os

from tqdm import tqdm

from augment_generate import generate_answer
from read_from_db import read_from_db
from score import get_score
from utils import write_csv, calculate_avg, count_lines_in_jsonl, write_jsonl, read_jsonl
# Run configuration for the whole script. The dict is passed as-is to
# read_from_db, generate_answer and get_score further down.
options = dict()
# Number of knowledge entries requested per question; may affect performance.
options['k'] = 3
# System prompt sent to the chat model. The text is runtime content and is kept
# verbatim; the backslashes join the physical lines into a single-line string.
options['system_prompt'] = '你是一个蚂蚁集团的TuGraph数据库专家,\
擅长使用与TuGraph数据库相关的知识来回答用户的问题,\
针对用户的提问,你会得到一些文本材料辅助回答,如果某些辅助文本与提问关联性不强,则可以忽略,\
结合有用的部分以及你的知识,回答用户的提问。如果可以直接给出答案,则只回答最关键的具体信息部分,做到尽可能简洁。\
注意:问题中的数据库一律指代TuGraph,\
请仿照下面的样例答案格式进行后续的回答,给出答案.\
样例问题1:"RPC 及 HA 服务中,verbose 参数的设置有几个级别?", 样例答案: "三个级别(0,1,2)。"\
样例问题2:"如果成功修改一个用户的描述,应返回什么状态码?"样例答案:“200” '
# Disabled alternative prompt, kept verbatim (its text is prompt content, not
# commentary). Compared with the active prompt it tells the model to answer
# "not supported yet" when unsure whether a feature exists.
# options['system_prompt'] = '你是一个蚂蚁集团的TuGraph数据库专家,\
# 擅长使用TuGraph数据库的相关知识来回答用户的问题,\
# 针对用户的提问,你会得到一些知识辅助,请忽略没有帮助的知识,\
# 结合有用的部分以及你的知识,尽可能简洁地直接给出答案,不需要任何解释。\
# 注意:问题中的数据库一律指代TuGraph,问及系统是否支持某些功能时,若不清楚一律回答暂不支持\
# 请仿照下面的样例答案格式进行后续的回答:\
# 样例问题1:"RPC 及 HA 服务中,verbose 参数的设置有几个级别?", 样例答案: "三个级别(0,1,2)。"\
# 样例问题2:"如果成功修改一个用户的描述,应返回什么状态码?"样例答案:“200” '
options['chat-model'] = "gpt-4o-mini"
options['embedding-model'] = "text-embedding-3-large"
# Truncation limit so that a single knowledge entry cannot get too long.
# NOTE(review): the truncation itself happens in code not visible here.
options['tokens_per_knowledge'] = 2000
# GPT API access
# NOTE(review): the base URL ends in '/v', which looks truncated (perhaps
# '/v1') — confirm against the provider before changing it.
options['gpt-baseurl'] = 'https://api.gptapi.us/v'
# SECURITY: an API key is committed in source. Supply the key through the
# GPT_API_KEY environment variable instead; the literal below remains only as
# a backward-compatible fallback and should be rotated and then removed.
options['gpt-apikey'] = os.environ.get(
    'GPT_API_KEY', "sk-xfovpV3O7IwdmDDJBb05Ff03E5014c14Ab5e935715Fe90D3"
)
# Input paths
options['persist_directory'] = './db/xldatabase/rag'
options['test_path'] = './test/test1.jsonl'
options['val_path'] = './test/val.jsonl'
# Output paths
options['test_out_path'] = './result/answer_test.jsonl'
options['val_out_path'] = './result/answer_val.jsonl'
options['score_path'] = './result/score.csv'
options['retrieval_path'] = './result/'  # directory for the retrieved-knowledge dumps
# Feature switches; 1 enables the step.
options['use_val'] = 0
options['use_val_score'] = 0
options['use_test'] = 1
options['save_knowledge'] = 1  # also save the knowledge retrieved for each question
if options['use_val']:
    # Validation split: retrieve knowledge for every question, generate an
    # answer from it, and write the answers (plus, optionally, the retrieved
    # knowledge) to the configured output files.
    print('正在对 val.jsonl 进行生成检索.....')
    answers_val = []
    knowledge_val = []  # only filled and written when save_knowledge is on
    top_k = options['k']
    val_total = count_lines_in_jsonl(options['val_path'])
    for obj in tqdm(read_jsonl(options['val_path']), total=val_total):
        query = obj.get('input_field')  # the question text
        # Retrieval is needed in both modes, so do it once up front.
        knowledges = read_from_db(query, top_k, options)  # a list of Documents
        if options['save_knowledge']:
            # One column per requested entry (K1..Kk) instead of a fixed
            # K1..K3, so changing options['k'] neither raises IndexError nor
            # drops entries. Short results are padded so every row has the
            # same columns.
            # NOTE(review): assumes write_csv accepts '' cells — confirm.
            row = dict(Q=query)
            for rank in range(top_k):
                row['K{}'.format(rank + 1)] = knowledges[rank] if rank < len(knowledges) else ''
            knowledge_val.append(row)
        answers_val.append(dict(id=obj.get('id'), output_field=generate_answer(query, knowledges, options)))
    # Write the answers to file.
    write_jsonl(answers_val, options['val_out_path'])
    if options['save_knowledge']:
        # Write the saved knowledge to file.
        write_csv(knowledge_val, options['retrieval_path'] + 'retrieval_val.csv')
    print('val.jsonl 已生成答案!\n \n')
if options['use_val_score']:
    # Compute the scores via get_score, store them as CSV at score_path and
    # print their average.
    print('正在计算分数.....')
    score_output = get_score(options)
    score_csv_path = options['score_path']
    write_csv(score_output, score_csv_path)
    average_score = calculate_avg(score_output)
    print('分数平均为{}! \n \n'.format(average_score))
    # Disabled alternative output format:
    # write_jsonl(score_output, options['score_path'])
if options['use_test']:
    # Test split: retrieve knowledge for every question, generate an answer
    # from it, and write the answers (plus, optionally, the retrieved
    # knowledge) to the configured output files.
    print('正在对 test1.jsonl 进行生成检索.....')
    answers_test = []
    knowledge_test = []  # only filled and written when save_knowledge is on
    top_k = options['k']
    test_total = count_lines_in_jsonl(options['test_path'])
    for obj in tqdm(read_jsonl(options['test_path']), total=test_total):
        query = obj.get('input_field')  # the question text
        # Retrieval is needed in both modes, so do it once up front.
        knowledges = read_from_db(query, top_k, options)  # a list of Documents
        if options['save_knowledge']:
            # One column per requested entry (K1..Kk) instead of a fixed
            # K1..K3, so changing options['k'] neither raises IndexError nor
            # drops entries. Short results are padded so every row has the
            # same columns.
            # NOTE(review): assumes write_csv accepts '' cells — confirm.
            row = dict(Q=query)
            for rank in range(top_k):
                row['K{}'.format(rank + 1)] = knowledges[rank] if rank < len(knowledges) else ''
            knowledge_test.append(row)
        answers_test.append(dict(id=obj.get('id'), output_field=generate_answer(query, knowledges, options)))
    # Write the results to file.
    write_jsonl(answers_test, options['test_out_path'])
    if options['save_knowledge']:
        write_csv(knowledge_test, options['retrieval_path'] + 'retrieval_test.csv')
    print('test1.jsonl 已生成答案!\n \n')