Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions libs/admin-api-lib/src/admin_api_lib/dependency_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from rag_core_lib.impl.settings.ollama_llm_settings import OllamaSettings
from rag_core_lib.impl.settings.rag_class_types_settings import RAGClassTypeSettings
from rag_core_lib.impl.settings.stackit_vllm_settings import StackitVllmSettings
from rag_core_lib.impl.tracers.langfuse_traced_chain import LangfuseTracedGraph
from rag_core_lib.impl.tracers.langfuse_traced_runnable import LangfuseTracedRunnable
from rag_core_lib.impl.utils.async_threadsafe_semaphore import AsyncThreadsafeSemaphore


Expand Down Expand Up @@ -147,7 +147,7 @@ class DependencyContainer(DeclarativeContainer):
summary_enhancer,
)
information_enhancer = Singleton(
LangfuseTracedGraph,
LangfuseTracedRunnable,
inner_chain=untraced_information_enhancer,
settings=langfuse_settings,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from langchain_core.documents import Document
from langchain_core.runnables import RunnableConfig

from rag_core_lib.chains.async_chain import AsyncChain
from rag_core_lib.runnables.async_runnable import AsyncRunnable

RetrieverInput = list[Document]
RetrieverOutput = list[Document]


class InformationEnhancer(AsyncChain[RetrieverInput, RetrieverOutput], ABC):
class InformationEnhancer(AsyncRunnable[RetrieverInput, RetrieverOutput], ABC):
"""The base class for an information enhancer."""

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions libs/admin-api-lib/src/admin_api_lib/summarizer/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

from langchain_core.runnables import RunnableConfig

from rag_core_lib.chains.async_chain import AsyncChain
from rag_core_lib.runnables.async_runnable import AsyncRunnable

SummarizerInput = str
SummarizerOutput = str


class Summarizer(AsyncChain[SummarizerInput, SummarizerOutput], ABC):
class Summarizer(AsyncRunnable[SummarizerInput, SummarizerOutput], ABC):
"""Baseclass for summarizers."""

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions libs/rag-core-api/src/rag_core_api/dependency_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
from rag_core_lib.impl.settings.ollama_llm_settings import OllamaSettings
from rag_core_lib.impl.settings.rag_class_types_settings import RAGClassTypeSettings
from rag_core_lib.impl.settings.stackit_vllm_settings import StackitVllmSettings
from rag_core_lib.impl.tracers.langfuse_traced_chain import LangfuseTracedGraph
from rag_core_lib.impl.tracers.langfuse_traced_runnable import LangfuseTracedRunnable
from rag_core_lib.impl.utils.async_threadsafe_semaphore import AsyncThreadsafeSemaphore


Expand Down Expand Up @@ -218,7 +218,7 @@ class DependencyContainer(DeclarativeContainer):

# wrap graph in tracer
traced_chat_graph = Singleton(
LangfuseTracedGraph,
LangfuseTracedRunnable,
inner_chain=chat_graph,
settings=langfuse_settings,
)
Expand Down
4 changes: 2 additions & 2 deletions libs/rag-core-api/src/rag_core_api/graph/graph_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from langchain_core.runnables.utils import Input, Output
from langgraph.graph import StateGraph

from rag_core_lib.chains.async_chain import AsyncChain
from rag_core_lib.runnables.async_runnable import AsyncRunnable


class GraphBase(AsyncChain[Input, Output], ABC):
class GraphBase(AsyncRunnable[Input, Output], ABC):
"""
Base class for a langgraph graph.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from langchain_core.output_parsers import StrOutputParser

from rag_core_api.impl.graph.graph_state.graph_state import AnswerGraphState
from rag_core_lib.chains.async_chain import AsyncChain
from rag_core_lib.runnables.async_runnable import AsyncRunnable
from rag_core_lib.impl.langfuse_manager.langfuse_manager import LangfuseManager

RunnableInput = AnswerGraphState
RunnableOutput = str


class AnswerGenerationChain(AsyncChain[RunnableInput, RunnableOutput]):
class AnswerGenerationChain(AsyncRunnable[RunnableInput, RunnableOutput]):
"""Base class for LLM answer generation chain."""

def __init__(self, langfuse_manager: LangfuseManager):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from langchain_core.runnables import Runnable, RunnableConfig

from rag_core_api.impl.graph.graph_state.graph_state import AnswerGraphState
from rag_core_lib.chains.async_chain import AsyncChain
from rag_core_lib.runnables.async_runnable import AsyncRunnable
from rag_core_lib.impl.langfuse_manager.langfuse_manager import LangfuseManager

RunnableInput = AnswerGraphState
RunnableOutput = str


class RephrasingChain(AsyncChain[RunnableInput, RunnableOutput]):
class RephrasingChain(AsyncRunnable[RunnableInput, RunnableOutput]):
"""Base class for rephrasing of the input question."""

def __init__(self, langfuse_manager: LangfuseManager):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from rag_core_api.api_endpoints.chat import Chat
from rag_core_api.models.chat_request import ChatRequest
from rag_core_api.models.chat_response import ChatResponse
from rag_core_lib.tracers.traced_chain import TracedGraph
from rag_core_lib.tracers.traced_runnable import TracedRunnable


class DefaultChat(Chat):
"""DefaultChat is a class that handles chat interactions using a traced graph."""

def __init__(self, chat_graph: TracedGraph):
def __init__(self, chat_graph: TracedRunnable):
"""
Initialize the DefaultChat instance.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from langfuse.langchain import CallbackHandler

from rag_core_lib.impl.settings.langfuse_settings import LangfuseSettings
from rag_core_lib.tracers.traced_chain import TracedGraph
from rag_core_lib.tracers.traced_runnable import TracedRunnable


class LangfuseTracedGraph(TracedGraph):
class LangfuseTracedRunnable(TracedRunnable):
"""A class to trace the execution of a Runnable using Langfuse.

This class wraps an inner Runnable and adds tracing capabilities using the Langfuse tracer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from langchain_core.runnables.utils import Input, Output


class AsyncChain(Runnable[Input, Output], ABC):
class AsyncRunnable(Runnable[Input, Output], ABC):
"""Base class for asynchronous chains."""

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from langfuse import get_client

from rag_core_lib.chains.async_chain import AsyncChain
from rag_core_lib.runnables.async_runnable import AsyncRunnable

RunnableInput = Any
RunnableOutput = Any


class TracedGraph(AsyncChain[RunnableInput, RunnableOutput], ABC):
class TracedRunnable(AsyncRunnable[RunnableInput, RunnableOutput], ABC):
"""A class to represent a traced graph in an asynchronous chain.

This class is designed to wrap around an inner Runnable chain and add tracing capabilities to it.
Expand Down
4 changes: 2 additions & 2 deletions services/rag-backend/chat_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from rag_core_api.api_endpoints.chat import Chat
from rag_core_api.models.chat_request import ChatRequest
from rag_core_api.models.chat_response import ChatResponse
from rag_core_lib.tracers.traced_chain import TracedGraph
from rag_core_lib.tracers.traced_runnable import TracedRunnable

logger = logging.getLogger(__name__)


class UseCaseChat(Chat):
def __init__(self, chat_graph: TracedGraph):
def __init__(self, chat_graph: TracedRunnable):
self._chat_graph = chat_graph

async def achat(
Expand Down