-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathapp_workflow.py
More file actions
203 lines (161 loc) · 6.34 KB
/
app_workflow.py
File metadata and controls
203 lines (161 loc) · 6.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
from langchain_community.vectorstores import DeepLake
from langchain_cohere import CohereEmbeddings
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_cohere.rerank import CohereRerank
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import ChatOpenAI
import streamlit as st
import io
import re
import sys
from typing import Any, Callable
from data_loading import run_job
import os
# Credentials are pulled from Streamlit's secrets store
# (.streamlit/secrets.toml) and exported as environment variables,
# which is where the langchain integrations expect to find them.
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
os.environ["COHERE_API_KEY"] = st.secrets["COHERE_API_KEY"]
os.environ["ACTIVELOOP_TOKEN"] = st.secrets["ACTIVELOOP_TOKEN"]
# Activeloop workspace id + dataset name used to build the hub:// path below.
activeloop_org_id = st.secrets["ACTIVELOOP_ORG_ID"]
activeloop_db_id = "activeloop_course_educhain_bot"
# Task that will run only once to gather the data needed
# for providing context to the chatbot
run_job()
# ------ Data Retrieval Process ------
@st.cache_resource()
def data_lake():
    """Open the Deep Lake dataset and build the retrieval pipeline.

    Returns:
        A ``(store, compression_retriever, base_retriever)`` triple: the
        read-only DeepLake vector store, the Cohere-reranked retriever the
        chain uses, and the underlying MMR retriever.  Cached so the
        connection is established only once per Streamlit session.
    """
    cohere_embeddings = CohereEmbeddings(model="embed-english-v2.0")
    store = DeepLake(
        dataset_path=f"hub://{activeloop_org_id}/{activeloop_db_id}",
        read_only=True,
        embedding=cohere_embeddings,
    )
    # MMR search over the vector store; fetch 20 candidates per query
    # using cosine distance.
    base_retriever = store.as_retriever(search_type="mmr")
    base_retriever.search_kwargs.update(
        {"distance_metric": "cos", "fetch_k": 20, "k": 20}
    )
    # -- Refines and ranks documents in alignment with the user's query --
    # Cohere's rerank endpoint acts as the last-stage reranker, keeping
    # only the 5 most relevant of the fetched documents.
    reranker = CohereRerank(
        model="rerank-english-v2.0",
        top_n=5,
    )
    reranking_retriever = ContextualCompressionRetriever(
        base_compressor=reranker, base_retriever=base_retriever
    )
    return store, reranking_retriever, base_retriever


dbs, compression_retriever, retriever = data_lake()
# ---------- Setting up a Memory System for the ChatBot ----------
@st.cache_resource()
def memory():
    """Build the windowed conversation memory (one instance per session)."""
    # Keep only the last 3 question/answer exchanges so the prompt stays
    # small; the chain reads and writes this under the "chat_history" key
    # and stores the model reply under "answer".
    window = ConversationBufferWindowMemory(
        k=3,
        memory_key="chat_history",
        return_messages=True,
        output_key="answer",
    )
    return window


memory = memory()
# ---------- Initiates the LLM Chat model ----------
# Streaming gpt-3.5-turbo at temperature 0 (deterministic answers),
# capped at 1000 completion tokens per response.
llm = ChatOpenAI(
    temperature=0,
    model='gpt-3.5-turbo',
    streaming=True,
    max_tokens=1000
)
# ---------- Builds the Conversational Chain ----------
# "stuff" chain type: the reranked documents are inserted directly into
# the prompt.  return_source_documents=True so the UI can cite sources.
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=compression_retriever,
    memory=memory,
    verbose=True,
    chain_type="stuff",
    return_source_documents=True
)
# =========================
def clear_cache_and_session():
    """Reset the app: drop the chat history and all cached resources.

    This function was referenced by the sidebar button below but never
    defined, so clicking the button raised a NameError.  Emptying
    ``st.session_state`` discards the chat history; clearing
    ``st.cache_resource`` forces ``data_lake()``/``memory()`` to rebuild
    on the next rerun.
    """
    # Iterate over a snapshot of the keys: deleting while iterating the
    # live view would raise a RuntimeError.
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    st.cache_resource.clear()


# Triggers the clearing of the cache and session states
if st.sidebar.button("Start a New Chat Interaction"):
    clear_cache_and_session()
# Initializes chat history
if "messages" not in st.session_state:
    st.session_state.messages = []
# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# -------- Verbose Display Code --------
def capture_and_display_output(func: Callable[..., Any], args, **kwargs) -> Any:
    """Call ``func(args, **kwargs)`` while capturing stdout, and render the
    captured (verbose langchain) trace inside an expander.

    Args:
        func: the callable to invoke (here, the QA chain).
        args: single positional argument forwarded to ``func`` unchanged.
        **kwargs: keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.
    """
    # Redirect stdout into a buffer so the chain's verbose output can be
    # captured instead of going to the server console.
    original_stdout = sys.stdout
    sys.stdout = output_catcher = io.StringIO()
    try:
        # Run the given function and capture its output
        response = func(args, **kwargs)
    finally:
        # BUGFIX: restore stdout even when func raises; previously an
        # exception left sys.stdout pointing at the dead StringIO buffer
        # for the rest of the process.
        sys.stdout = original_stdout
    # Strip ANSI escape sequences (terminal colors) from the captured text.
    # BUGFIX: the original pattern r"\x1b[.?[@-~]" lacked the escaped "[",
    # so it matched only ESC plus one character and left most codes behind.
    output_text = output_catcher.getvalue()
    clean_text = re.sub(r"\x1b\[[0-?]*[ -/]*[@-~]", "", output_text)
    # Custom CSS for the response box.  BUGFIX: the "#"-style comment in
    # the original CSS is invalid and caused the following declaration to
    # be dropped by the browser's error recovery; use /* */ instead.
    st.markdown("""
    <style>
        .response-value {
            border: 2px solid #6c757d;
            border-radius: 5px;
            padding: 20px;
            background-color: #f8f9fa;
            color: #3d3d3d;
            font-size: 20px; /* Change this value to adjust the text size */
            font-family: monospace;
        }
    </style>
    """, unsafe_allow_html=True)
    with st.expander("See Langchain Thought Process"):
        # Display the cleaned text as code
        st.code(clean_text)
    return response
# ------ Function for handling chat interactions ------
def chat_ui(qa):
    """Render one turn of the chat UI for the current Streamlit rerun.

    Args:
        qa: the ConversationalRetrievalChain used to answer the prompt.
    """
    # Accept user input
    if prompt := st.chat_input(
        "Ask me questions: How can I retrieve data from Deep Lake in Langchain?"
    ):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(prompt)
        # Display assistant response in chat message container
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""
            # Load the memory variables, which include the chat history
            memory_variables = memory.load_memory_variables({})
            # Predict the AI's response in the conversation
            with st.spinner("Searching course material"):
                response = capture_and_display_output(
                    qa, ({"question": prompt, "chat_history": memory_variables})
                )
            # Display chat response
            full_response += response["answer"]
            message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)
            # Display up to the top 2 retrieved sources.
            # BUGFIX: the original indexed source_documents[0] and [1]
            # unconditionally, raising IndexError whenever the retriever
            # returned fewer than two documents.
            top_docs = response.get("source_documents", [])[:2]
            with st.expander("See Resources"):
                for doc in top_docs:
                    meta = doc.metadata
                    st.write(f"Title: {meta['title'].split('·')[0].strip()}")
                    st.write(f"Source: {meta['source']}")
                    st.write(f"Relevance to Query: {meta['relevance_score'] * 100}%")
        # Append message to session state
        st.session_state.messages.append(
            {"role": "assistant", "content": full_response}
        )
def main():
    """Entry point: launch the chat UI backed by the module-level QA chain."""
    # Run function passing the ConversationalRetrievalChain
    chat_ui(qa)

if __name__ == "__main__":
    main()