diff --git a/cortex_on/agents/code_agent.py b/cortex_on/agents/code_agent.py
index 43fa047..d1eb785 100644
--- a/cortex_on/agents/code_agent.py
+++ b/cortex_on/agents/code_agent.py
@@ -15,7 +15,7 @@
from pydantic_ai.models.anthropic import AnthropicModel
# Local application imports
-from utils.ant_client import get_client
+from utils.ant_client import get_client, get_anthropic_model_instance, get_openai_model_instance, get_openai_client
from utils.stream_response_format import StreamResponse
load_dotenv()
@@ -25,6 +25,7 @@
class CoderAgentDeps:
websocket: Optional[WebSocket] = None
stream_output: Optional[StreamResponse] = None
+ model_preference: str = "Anthropic"
# Constants
ALLOWED_COMMANDS = {
@@ -236,180 +237,188 @@ async def send_stream_update(ctx: RunContext[CoderAgentDeps], message: str) -> N
stream_output_json = json.dumps(asdict(ctx.deps.stream_output))
logfire.debug("WebSocket message sent: {stream_output_json}", stream_output_json=stream_output_json)
-# Initialize the model
-model = AnthropicModel(
- model_name=os.environ.get("ANTHROPIC_MODEL_NAME"),
- anthropic_client=get_client()
-)
+
# Initialize the agent
-coder_agent = Agent(
+
+async def coder_agent(model_preference: str = "Anthropic") -> Agent:
+ if model_preference == "Anthropic":
+ model = get_anthropic_model_instance()
+ elif model_preference == "OpenAI":
+ model = get_openai_model_instance()
+ else:
+ raise ValueError(f"Unknown model_preference: {model_preference}")
+ print(f"[CODER_INIT] Creating coder agent with model: {model}")
+ coder_agent = Agent(
model=model,
name="Coder Agent",
result_type=CoderResult,
deps_type=CoderAgentDeps,
system_prompt=coder_system_message
-)
-
-@coder_agent.tool
-async def execute_shell(ctx: RunContext[CoderAgentDeps], command: str) -> str:
- """
- Executes a shell command within a restricted directory and returns the output.
- This consolidated tool handles terminal commands and file operations.
- """
- try:
- # Extract base command for security checks and messaging
- base_command = command.split()[0] if command.split() else ""
-
- # Send operation description message
- operation_message = get_high_level_operation_message(command, base_command)
- await send_stream_update(ctx, operation_message)
-
- logfire.info("Executing shell command: {command}", command=command)
-
- # Setup restricted directory
- base_dir = os.path.abspath(os.path.dirname(__file__))
- restricted_dir = os.path.join(base_dir, "code_files")
- os.makedirs(restricted_dir, exist_ok=True)
-
- # Security check
- if base_command not in ALLOWED_COMMANDS:
- await send_stream_update(ctx, "Operation not permitted")
- return f"Error: Command '{base_command}' is not allowed for security reasons."
-
- # Change to restricted directory for execution
- original_dir = os.getcwd()
- os.chdir(restricted_dir)
-
+ )
+
+ @coder_agent.tool
+ async def execute_shell(ctx: RunContext[CoderAgentDeps], command: str) -> str:
+ """
+ Executes a shell command within a restricted directory and returns the output.
+ This consolidated tool handles terminal commands and file operations.
+ """
try:
- # Handle echo with redirection (file writing)
- if ">" in command and base_command == "echo":
- file_path = command.split(">", 1)[1].strip()
- await send_stream_update(ctx, f"Writing content to {file_path}")
-
- # Parse command parts
- parts = command.split(">", 1)
- echo_cmd = parts[0].strip()
-
- # Extract content, removing quotes if present
- content = echo_cmd[5:].strip()
- if (content.startswith('"') and content.endswith('"')) or \
- (content.startswith("'") and content.endswith("'")):
- content = content[1:-1]
-
- try:
- with open(file_path, "w") as file:
- file.write(content)
-
- await send_stream_update(ctx, f"File {file_path} created successfully")
- return f"Successfully wrote to {file_path}"
- except Exception as e:
- error_msg = f"Error writing to file: {str(e)}"
- await send_stream_update(ctx, f"Failed to create file {file_path}")
- logfire.error(error_msg, exc_info=True)
- return error_msg
+ # Extract base command for security checks and messaging
+ base_command = command.split()[0] if command.split() else ""
- # Handle cat with here-document for multiline file writing
- elif "<<" in command and base_command == "cat":
- cmd_parts = command.split("<<", 1)
- cat_part = cmd_parts[0].strip()
-
- # Extract filename for status message if possible
- file_path = None
- if ">" in cat_part:
- file_path = cat_part.split(">", 1)[1].strip()
- await send_stream_update(ctx, f"Creating file {file_path}")
+ # Send operation description message
+ operation_message = get_high_level_operation_message(command, base_command)
+ await send_stream_update(ctx, operation_message)
+
+ logfire.info("Executing shell command: {command}", command=command)
+
+ # Setup restricted directory
+ base_dir = os.path.abspath(os.path.dirname(__file__))
+ restricted_dir = os.path.join(base_dir, "code_files")
+ os.makedirs(restricted_dir, exist_ok=True)
+
+ # Security check
+ if base_command not in ALLOWED_COMMANDS:
+ await send_stream_update(ctx, "Operation not permitted")
+ return f"Error: Command '{base_command}' is not allowed for security reasons."
+
+ # Change to restricted directory for execution
+ original_dir = os.getcwd()
+ os.chdir(restricted_dir)
+
+ try:
+ # Handle echo with redirection (file writing)
+ if ">" in command and base_command == "echo":
+ file_path = command.split(">", 1)[1].strip()
+ await send_stream_update(ctx, f"Writing content to {file_path}")
+
+ # Parse command parts
+ parts = command.split(">", 1)
+ echo_cmd = parts[0].strip()
+
+ # Extract content, removing quotes if present
+ content = echo_cmd[5:].strip()
+ if (content.startswith('"') and content.endswith('"')) or \
+ (content.startswith("'") and content.endswith("'")):
+ content = content[1:-1]
+
+ try:
+ with open(file_path, "w") as file:
+ file.write(content)
+
+ await send_stream_update(ctx, f"File {file_path} created successfully")
+ return f"Successfully wrote to {file_path}"
+ except Exception as e:
+ error_msg = f"Error writing to file: {str(e)}"
+ await send_stream_update(ctx, f"Failed to create file {file_path}")
+ logfire.error(error_msg, exc_info=True)
+ return error_msg
- try:
- # Parse heredoc parts
- doc_part = cmd_parts[1].strip()
+ # Handle cat with here-document for multiline file writing
+ elif "<<" in command and base_command == "cat":
+ cmd_parts = command.split("<<", 1)
+ cat_part = cmd_parts[0].strip()
- # Extract filename
+ # Extract filename for status message if possible
+ file_path = None
if ">" in cat_part:
file_path = cat_part.split(">", 1)[1].strip()
- else:
- await send_stream_update(ctx, "Invalid file operation")
- return "Error: Invalid cat command format. Must include redirection."
+ await send_stream_update(ctx, f"Creating file {file_path}")
- # Parse the heredoc content and delimiter
- if "\n" in doc_part:
- delimiter_and_content = doc_part.split("\n", 1)
- delimiter = delimiter_and_content[0].strip("'").strip('"')
- content = delimiter_and_content[1]
+ try:
+ # Parse heredoc parts
+ doc_part = cmd_parts[1].strip()
- # Find the end delimiter and extract content
- if f"\n{delimiter}" in content:
- content = content.split(f"\n{delimiter}")[0]
-
- # Write to file
- with open(file_path, "w") as file:
- file.write(content)
+ # Extract filename
+ if ">" in cat_part:
+ file_path = cat_part.split(">", 1)[1].strip()
+ else:
+ await send_stream_update(ctx, "Invalid file operation")
+ return "Error: Invalid cat command format. Must include redirection."
+
+ # Parse the heredoc content and delimiter
+ if "\n" in doc_part:
+ delimiter_and_content = doc_part.split("\n", 1)
+ delimiter = delimiter_and_content[0].strip("'").strip('"')
+ content = delimiter_and_content[1]
- await send_stream_update(ctx, f"File {file_path} created successfully")
- return f"Successfully wrote multiline content to {file_path}"
+ # Find the end delimiter and extract content
+ if f"\n{delimiter}" in content:
+ content = content.split(f"\n{delimiter}")[0]
+
+ # Write to file
+ with open(file_path, "w") as file:
+ file.write(content)
+
+ await send_stream_update(ctx, f"File {file_path} created successfully")
+ return f"Successfully wrote multiline content to {file_path}"
+ else:
+ await send_stream_update(ctx, "File content format error")
+ return "Error: End delimiter not found in heredoc"
else:
await send_stream_update(ctx, "File content format error")
- return "Error: End delimiter not found in heredoc"
- else:
- await send_stream_update(ctx, "File content format error")
- return "Error: Invalid heredoc format"
- except Exception as e:
- error_msg = f"Error processing cat with heredoc: {str(e)}"
- file_path_str = file_path if file_path else 'file'
- await send_stream_update(ctx, f"Failed to create file {file_path_str}")
- logfire.error(error_msg, exc_info=True)
- return error_msg
-
- # Execute standard commands
- else:
- # Send execution message
- execution_msg = get_high_level_execution_message(command, base_command)
- await send_stream_update(ctx, execution_msg)
+ return "Error: Invalid heredoc format"
+ except Exception as e:
+ error_msg = f"Error processing cat with heredoc: {str(e)}"
+ file_path_str = file_path if file_path else 'file'
+ await send_stream_update(ctx, f"Failed to create file {file_path_str}")
+ logfire.error(error_msg, exc_info=True)
+ return error_msg
- # Execute the command using subprocess
- try:
- args = shlex.split(command)
- result = subprocess.run(
- args,
- shell=True,
- capture_output=True,
- text=True,
- timeout=60,
- )
+ # Execute standard commands
+ else:
+ # Send execution message
+ execution_msg = get_high_level_execution_message(command, base_command)
+ await send_stream_update(ctx, execution_msg)
- logfire.info(f"Command executed: {result.args}")
+ # Execute the command using subprocess
+ try:
+ args = shlex.split(command)
+ result = subprocess.run(
+ args,
+ shell=True,
+ capture_output=True,
+ text=True,
+ timeout=60,
+ )
+
+ logfire.info(f"Command executed: {result.args}")
+
+ # Handle success
+ if result.returncode == 0:
+ success_msg = get_success_message(command, base_command)
+ await send_stream_update(ctx, success_msg)
+ logfire.info(f"Command executed successfully: {result.stdout}")
+ return result.stdout
+
+ # Handle failure
+ else:
+ files = os.listdir('.')
+ error_msg = f"Command failed with error code {result.returncode}:\n{result.stderr}\n\nFiles in directory: {files}"
+ failure_msg = get_failure_message(command, base_command)
+ await send_stream_update(ctx, failure_msg)
+ return error_msg
- # Handle success
- if result.returncode == 0:
- success_msg = get_success_message(command, base_command)
- await send_stream_update(ctx, success_msg)
- logfire.info(f"Command executed successfully: {result.stdout}")
- return result.stdout
+ except subprocess.TimeoutExpired:
+ await send_stream_update(ctx, "Operation timed out")
+ return "Command execution timed out after 60 seconds"
- # Handle failure
- else:
- files = os.listdir('.')
- error_msg = f"Command failed with error code {result.returncode}:\n{result.stderr}\n\nFiles in directory: {files}"
- failure_msg = get_failure_message(command, base_command)
- await send_stream_update(ctx, failure_msg)
+ except Exception as e:
+ error_msg = f"Error executing command: {str(e)}"
+ await send_stream_update(ctx, "Operation failed")
+ logfire.error(error_msg, exc_info=True)
return error_msg
+
+ finally:
+ # Always return to the original directory
+ os.chdir(original_dir)
- except subprocess.TimeoutExpired:
- await send_stream_update(ctx, "Operation timed out")
- return "Command execution timed out after 60 seconds"
-
- except Exception as e:
- error_msg = f"Error executing command: {str(e)}"
- await send_stream_update(ctx, "Operation failed")
- logfire.error(error_msg, exc_info=True)
- return error_msg
+ except Exception as e:
+ error_msg = f"Error executing command: {str(e)}"
+ await send_stream_update(ctx, "Operation failed")
+ logfire.error(error_msg, exc_info=True)
+ return error_msg
- finally:
- # Always return to the original directory
- os.chdir(original_dir)
-
- except Exception as e:
- error_msg = f"Error executing command: {str(e)}"
- await send_stream_update(ctx, "Operation failed")
- logfire.error(error_msg, exc_info=True)
- return error_msg
\ No newline at end of file
+ logfire.info("All tools initialized for coder agent")
+ return coder_agent
\ No newline at end of file
diff --git a/cortex_on/agents/orchestrator_agent.py b/cortex_on/agents/orchestrator_agent.py
index 6b001ba..e08c881 100644
--- a/cortex_on/agents/orchestrator_agent.py
+++ b/cortex_on/agents/orchestrator_agent.py
@@ -12,9 +12,9 @@
from pydantic_ai import Agent, RunContext
from agents.web_surfer import WebSurfer
from utils.stream_response_format import StreamResponse
-from agents.planner_agent import planner_agent, update_todo_status
+from agents.planner_agent import planner_agent
from agents.code_agent import coder_agent, CoderAgentDeps
-from utils.ant_client import get_client
+from utils.ant_client import get_client, get_anthropic_model_instance, get_openai_model_instance, get_openai_client
@dataclass
class orchestrator_deps:
@@ -22,6 +22,7 @@ class orchestrator_deps:
stream_output: Optional[StreamResponse] = None
# Add a collection to track agent-specific streams
agent_responses: Optional[List[StreamResponse]] = None
+ model_preference: str = "Anthropic"
orchestrator_system_prompt = """You are an AI orchestrator that manages a team of agents to solve tasks. You have access to tools for coordinating the agents and managing the task flow.
@@ -158,347 +159,377 @@ class orchestrator_deps:
- Format: "Task description (agent_name)"
"""
-model = AnthropicModel(
- model_name=os.environ.get("ANTHROPIC_MODEL_NAME"),
- anthropic_client=get_client()
-)
-
-orchestrator_agent = Agent(
- model=model,
- name="Orchestrator Agent",
- system_prompt=orchestrator_system_prompt,
- deps_type=orchestrator_deps
-)
-
-@orchestrator_agent.tool
-async def plan_task(ctx: RunContext[orchestrator_deps], task: str) -> str:
- """Plans the task and assigns it to the appropriate agents"""
- try:
- logfire.info(f"Planning task: {task}")
-
- # Create a new StreamResponse for Planner Agent
- planner_stream_output = StreamResponse(
- agent_name="Planner Agent",
- instructions=task,
- steps=[],
- output="",
- status_code=0
- )
-
- # Add to orchestrator's response collection if available
- if ctx.deps.agent_responses is not None:
- ctx.deps.agent_responses.append(planner_stream_output)
+async def orchestrator_agent(model_preference: str) -> Agent:
+ if model_preference == "Anthropic":
+
+ model = get_anthropic_model_instance()
+ elif model_preference == "OpenAI":
+ model = get_openai_model_instance()
+ else:
+ raise ValueError(f"Unknown model_preference: {model_preference}")
+ orchestrator_agent = Agent(
+ model=model,
+ name="Orchestrator Agent",
+ system_prompt=orchestrator_system_prompt,
+ deps_type=orchestrator_deps
+ )
+
+ @orchestrator_agent.tool
+ async def plan_task(ctx: RunContext[orchestrator_deps], task: str) -> str:
+ """Plans the task and assigns it to the appropriate agents"""
+ try:
+ logfire.info(f"Planning task: {task}")
- await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
-
- # Update planner stream
- planner_stream_output.steps.append("Planning task...")
- await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
-
- # Run planner agent
- planner_response = await planner_agent.run(user_prompt=task)
-
- # Update planner stream with results
- plan_text = planner_response.data.plan
- planner_stream_output.steps.append("Task planned successfully")
- planner_stream_output.output = plan_text
- planner_stream_output.status_code = 200
- await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
-
- # Also update orchestrator stream
- ctx.deps.stream_output.steps.append("Task planned successfully")
- await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output)
-
- return f"Task planned successfully\nTask: {plan_text}"
- except Exception as e:
- error_msg = f"Error planning task: {str(e)}"
- logfire.error(error_msg, exc_info=True)
-
- # Update planner stream with error
- if planner_stream_output:
- planner_stream_output.steps.append(f"Planning failed: {str(e)}")
- planner_stream_output.status_code = 500
+ # Create a new StreamResponse for Planner Agent
+ planner_stream_output = StreamResponse(
+ agent_name="Planner Agent",
+ instructions=task,
+ steps=[],
+ output="",
+ status_code=0
+ )
+
+ # Add to orchestrator's response collection if available
+ if ctx.deps.agent_responses is not None:
+ ctx.deps.agent_responses.append(planner_stream_output)
+
+ await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+
+ # Update planner stream
+ planner_stream_output.steps.append("Planning task...")
await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
- # Also update orchestrator stream
- if ctx.deps.stream_output:
- ctx.deps.stream_output.steps.append(f"Planning failed: {str(e)}")
+ # Run planner agent
+ agent = await planner_agent(model_preference=ctx.deps.model_preference)
+ planner_response = await agent.run(user_prompt=task)
+
+ logfire.info(f"Planner Agent using model type: {ctx.deps.model_preference}")
+
+ # Update planner stream with results
+ plan_text = planner_response.data.plan
+ planner_stream_output.steps.append("Task planned successfully")
+ planner_stream_output.output = plan_text
+ planner_stream_output.status_code = 200
+ await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+
+ # Also update orchestrator stream
+ ctx.deps.stream_output.steps.append("Task planned successfully")
await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output)
- return f"Failed to plan task: {error_msg}"
+ return f"Task planned successfully\nTask: {plan_text}"
+ except Exception as e:
+ error_msg = f"Error planning task: {str(e)}"
+ logfire.error(error_msg, exc_info=True)
+
+ # Update planner stream with error
+ if planner_stream_output:
+ planner_stream_output.steps.append(f"Planning failed: {str(e)}")
+ planner_stream_output.status_code = 500
+ await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+
+ # Also update orchestrator stream
+ if ctx.deps.stream_output:
+ ctx.deps.stream_output.steps.append(f"Planning failed: {str(e)}")
+ await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output)
+
+ return f"Failed to plan task: {error_msg}"
-@orchestrator_agent.tool
-async def coder_task(ctx: RunContext[orchestrator_deps], task: str) -> str:
- """Assigns coding tasks to the coder agent"""
- try:
- logfire.info(f"Assigning coding task: {task}")
-
- # Create a new StreamResponse for Coder Agent
- coder_stream_output = StreamResponse(
- agent_name="Coder Agent",
- instructions=task,
- steps=[],
- output="",
- status_code=0
- )
-
- # Add to orchestrator's response collection if available
- if ctx.deps.agent_responses is not None:
- ctx.deps.agent_responses.append(coder_stream_output)
-
- # Send initial update for Coder Agent
- await _safe_websocket_send(ctx.deps.websocket, coder_stream_output)
-
- # Create deps with the new stream_output
- deps_for_coder_agent = CoderAgentDeps(
- websocket=ctx.deps.websocket,
- stream_output=coder_stream_output
- )
-
- # Run coder agent
- coder_response = await coder_agent.run(
- user_prompt=task,
- deps=deps_for_coder_agent
- )
-
- # Extract response data
- response_data = coder_response.data.content
-
- # Update coder_stream_output with coding results
- coder_stream_output.output = response_data
- coder_stream_output.status_code = 200
- coder_stream_output.steps.append("Coding task completed successfully")
- await _safe_websocket_send(ctx.deps.websocket, coder_stream_output)
-
- # Add a reminder in the result message to update the plan using planner_agent_update
- response_with_reminder = f"{response_data}\n\nReminder: You must now call planner_agent_update with the completed task description: \"{task} (coder_agent)\""
-
- return response_with_reminder
- except Exception as e:
- error_msg = f"Error assigning coding task: {str(e)}"
- logfire.error(error_msg, exc_info=True)
+ @orchestrator_agent.tool
+ async def coder_task(ctx: RunContext[orchestrator_deps], task: str) -> str:
+ """Assigns coding tasks to the coder agent"""
+ try:
+ logfire.info(f"Assigning coding task: {task}")
+
+ # Create a new StreamResponse for Coder Agent
+ coder_stream_output = StreamResponse(
+ agent_name="Coder Agent",
+ instructions=task,
+ steps=[],
+ output="",
+ status_code=0
+ )
+
+ # Add to orchestrator's response collection if available
+ if ctx.deps.agent_responses is not None:
+ ctx.deps.agent_responses.append(coder_stream_output)
+
+ # Send initial update for Coder Agent
+ await _safe_websocket_send(ctx.deps.websocket, coder_stream_output)
+
+ # Create deps with the new stream_output
+ deps_for_coder_agent = CoderAgentDeps(
+ websocket=ctx.deps.websocket,
+ stream_output=coder_stream_output,
+ model_preference=ctx.deps.model_preference
+ )
+
+ # Run coder agent
+ agent = await coder_agent(model_preference=ctx.deps.model_preference)
+ coder_response = await agent.run(
+ user_prompt=task,
+ deps=deps_for_coder_agent
+ )
+ logfire.info(f"coder_response: {coder_response}")
+ logfire.info(f"Coder Agent using model type: {ctx.deps.model_preference}")
+
+
+ # Extract response data
+ response_data = coder_response.data.content
+
+ # Update coder_stream_output with coding results
+ coder_stream_output.output = response_data
+ coder_stream_output.status_code = 200
+ coder_stream_output.steps.append("Coding task completed successfully")
+ await _safe_websocket_send(ctx.deps.websocket, coder_stream_output)
+
+ # Add a reminder in the result message to update the plan using planner_agent_update
+ response_with_reminder = f"{response_data}\n\nReminder: You must now call planner_agent_update with the completed task description: \"{task} (coder_agent)\""
+
+ return response_with_reminder
+ except Exception as e:
+ error_msg = f"Error assigning coding task: {str(e)}"
+ logfire.error(error_msg, exc_info=True)
- # Update coder_stream_output with error
- coder_stream_output.steps.append(f"Coding task failed: {str(e)}")
- coder_stream_output.status_code = 500
- await _safe_websocket_send(ctx.deps.websocket, coder_stream_output)
+ # Update coder_stream_output with error
+ coder_stream_output.steps.append(f"Coding task failed: {str(e)}")
+ coder_stream_output.status_code = 500
+ await _safe_websocket_send(ctx.deps.websocket, coder_stream_output)
- return f"Failed to assign coding task: {error_msg}"
+ return f"Failed to assign coding task: {error_msg}"
-@orchestrator_agent.tool
-async def web_surfer_task(ctx: RunContext[orchestrator_deps], task: str) -> str:
- """Assigns web surfing tasks to the web surfer agent"""
- try:
- logfire.info(f"Assigning web surfing task: {task}")
-
- # Create a new StreamResponse for WebSurfer
- web_surfer_stream_output = StreamResponse(
- agent_name="Web Surfer",
- instructions=task,
- steps=[],
- output="",
- status_code=0,
- live_url=None
- )
-
- # Add to orchestrator's response collection if available
- if ctx.deps.agent_responses is not None:
- ctx.deps.agent_responses.append(web_surfer_stream_output)
-
- await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output)
-
- # Initialize WebSurfer agent
- web_surfer_agent = WebSurfer(api_url="http://localhost:8000/api/v1/web/stream")
-
- # Run WebSurfer with its own stream_output
- success, message, messages = await web_surfer_agent.generate_reply(
- instruction=task,
- websocket=ctx.deps.websocket,
- stream_output=web_surfer_stream_output
- )
-
- # Update WebSurfer's stream_output with final result
- if success:
- web_surfer_stream_output.steps.append("Web search completed successfully")
- web_surfer_stream_output.output = message
- web_surfer_stream_output.status_code = 200
-
- # Add a reminder to update the plan
- message_with_reminder = f"{message}\n\nReminder: You must now call planner_agent_update with the completed task description: \"{task} (web_surfer_agent)\""
- else:
- web_surfer_stream_output.steps.append(f"Web search completed with issues: {message[:100]}")
- web_surfer_stream_output.status_code = 500
- message_with_reminder = message
-
- await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output)
-
- web_surfer_stream_output.steps.append(f"WebSurfer completed: {'Success' if success else 'Failed'}")
- await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output)
-
- return message_with_reminder
- except Exception as e:
- error_msg = f"Error assigning web surfing task: {str(e)}"
- logfire.error(error_msg, exc_info=True)
-
- # Update WebSurfer's stream_output with error
- web_surfer_stream_output.steps.append(f"Web search failed: {str(e)}")
- web_surfer_stream_output.status_code = 500
- await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output)
- return f"Failed to assign web surfing task: {error_msg}"
-
-@orchestrator_agent.tool
-async def ask_human(ctx: RunContext[orchestrator_deps], question: str) -> str:
- """Sends a question to the frontend and waits for human input"""
- try:
- logfire.info(f"Asking human: {question}")
-
- # Create a new StreamResponse for Human Input
- human_stream_output = StreamResponse(
- agent_name="Human Input",
- instructions=question,
- steps=[],
- output="",
- status_code=0
- )
-
- # Add to orchestrator's response collection if available
- if ctx.deps.agent_responses is not None:
- ctx.deps.agent_responses.append(human_stream_output)
-
- # Send the question to frontend
- await _safe_websocket_send(ctx.deps.websocket, human_stream_output)
-
- # Update stream with waiting message
- human_stream_output.steps.append("Waiting for human input...")
- await _safe_websocket_send(ctx.deps.websocket, human_stream_output)
-
- # Wait for response from frontend
- response = await ctx.deps.websocket.receive_text()
-
- # Update stream with response
- human_stream_output.steps.append("Received human input")
- human_stream_output.output = response
- human_stream_output.status_code = 200
- await _safe_websocket_send(ctx.deps.websocket, human_stream_output)
-
- return response
- except Exception as e:
- error_msg = f"Error getting human input: {str(e)}"
- logfire.error(error_msg, exc_info=True)
-
- # Update stream with error
- human_stream_output.steps.append(f"Failed to get human input: {str(e)}")
- human_stream_output.status_code = 500
- await _safe_websocket_send(ctx.deps.websocket, human_stream_output)
-
- return f"Failed to get human input: {error_msg}"
-
-@orchestrator_agent.tool
-async def planner_agent_update(ctx: RunContext[orchestrator_deps], completed_task: str) -> str:
- """
- Updates the todo.md file to mark a task as completed and returns the full updated plan.
-
- Args:
- completed_task: Description of the completed task including which agent performed it
-
- Returns:
- The complete updated todo.md content with tasks marked as completed
- """
- try:
- logfire.info(f"Updating plan with completed task: {completed_task}")
-
- # Create a new StreamResponse for Planner Agent update
- planner_stream_output = StreamResponse(
- agent_name="Planner Agent",
- instructions=f"Update todo.md to mark as completed: {completed_task}",
- steps=[],
- output="",
- status_code=0
- )
-
- # Send initial update
- await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
-
- # Directly read and update the todo.md file
- base_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
- planner_dir = os.path.join(base_dir, "agents", "planner")
- todo_path = os.path.join(planner_dir, "todo.md")
-
- planner_stream_output.steps.append("Reading current todo.md...")
- await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
-
- # Make sure the directory exists
- os.makedirs(planner_dir, exist_ok=True)
-
+ @orchestrator_agent.tool
+ async def web_surfer_task(ctx: RunContext[orchestrator_deps], task: str) -> str:
+ """Assigns web surfing tasks to the web surfer agent"""
try:
- # Check if todo.md exists
- if not os.path.exists(todo_path):
- planner_stream_output.steps.append("No todo.md file found. Will create new one after task completion.")
- await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+ logfire.info(f"Assigning web surfing task: {task}")
+
+ # Create a new StreamResponse for WebSurfer
+ web_surfer_stream_output = StreamResponse(
+ agent_name="Web Surfer",
+ instructions=task,
+ steps=[],
+ output="",
+ status_code=0,
+ live_url=None
+ )
+
+ # Add to orchestrator's response collection if available
+ if ctx.deps.agent_responses is not None:
+ ctx.deps.agent_responses.append(web_surfer_stream_output)
+
+ await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output)
+
+ # Initialize WebSurfer agent with model preference
+ web_surfer_agent = WebSurfer(
+ api_url="http://localhost:8000/api/v1/web/stream",
+ model_preference=ctx.deps.model_preference
+ )
+ logfire.info(f"web_surfing on model type: {ctx.deps.model_preference}")
- # We'll directly call planner_agent.run() to create a new plan first
- plan_prompt = f"Create a simple task plan based on this completed task: {completed_task}"
- plan_response = await planner_agent.run(user_prompt=plan_prompt)
- current_content = plan_response.data.plan
- else:
- # Read existing todo.md
- with open(todo_path, "r") as file:
- current_content = file.read()
- planner_stream_output.steps.append(f"Found existing todo.md ({len(current_content)} bytes)")
- await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+ # Run WebSurfer with its own stream_output
+ success, message, messages = await web_surfer_agent.generate_reply(
+ instruction=task,
+ websocket=ctx.deps.websocket,
+ stream_output=web_surfer_stream_output
+ )
- # Now call planner_agent.run() with specific instructions to update the plan
- update_prompt = f"""
- Here is the current todo.md content:
+ # Update WebSurfer's stream_output with final result
+ if success:
+ web_surfer_stream_output.steps.append("Web search completed successfully")
+ web_surfer_stream_output.output = message
+ web_surfer_stream_output.status_code = 200
+
+ # Add a reminder to update the plan
+ message_with_reminder = f"{message}\n\nReminder: You must now call planner_agent_update with the completed task description: \"{task} (web_surfer_agent)\""
+ else:
+ web_surfer_stream_output.steps.append(f"Web search completed with issues: {message[:100]}")
+ web_surfer_stream_output.status_code = 500
+ message_with_reminder = message
- {current_content}
+ await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output)
- Please update this plan to mark the following task as completed: {completed_task}
- Return ONLY the fully updated plan with appropriate tasks marked as [x] instead of [ ].
- """
+ web_surfer_stream_output.steps.append(f"WebSurfer completed: {'Success' if success else 'Failed'}")
+ await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output)
- planner_stream_output.steps.append("Asking planner to update the plan...")
- await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+ return message_with_reminder
+ except Exception as e:
+ error_msg = f"Error assigning web surfing task: {str(e)}"
+ logfire.error(error_msg, exc_info=True)
- updated_plan_response = await planner_agent.run(user_prompt=update_prompt)
- updated_plan = updated_plan_response.data.plan
+ # Update WebSurfer's stream_output with error
+ web_surfer_stream_output.steps.append(f"Web search failed: {str(e)}")
+ web_surfer_stream_output.status_code = 500
+ await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output)
+ return f"Failed to assign web surfing task: {error_msg}"
+
+ @orchestrator_agent.tool
+ async def ask_human(ctx: RunContext[orchestrator_deps], question: str) -> str:
+ """Sends a question to the frontend and waits for human input"""
+ try:
+ logfire.info(f"Asking human: {question}")
- # Write the updated plan back to todo.md
- with open(todo_path, "w") as file:
- file.write(updated_plan)
+ # Create a new StreamResponse for Human Input
+ human_stream_output = StreamResponse(
+ agent_name="Human Input",
+ instructions=question,
+ steps=[],
+ output="",
+ status_code=0
+ )
+
+ # Add to orchestrator's response collection if available
+ if ctx.deps.agent_responses is not None:
+ ctx.deps.agent_responses.append(human_stream_output)
+
+ # Send the question to frontend
+ await _safe_websocket_send(ctx.deps.websocket, human_stream_output)
- planner_stream_output.steps.append("Plan updated successfully")
- planner_stream_output.output = updated_plan
- planner_stream_output.status_code = 200
- await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+ # Update stream with waiting message
+ human_stream_output.steps.append("Waiting for human input...")
+ await _safe_websocket_send(ctx.deps.websocket, human_stream_output)
- # Update orchestrator stream
- if ctx.deps.stream_output:
- ctx.deps.stream_output.steps.append(f"Plan updated to mark task as completed: {completed_task}")
- await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output)
+ # Wait for response from frontend
+ response = await ctx.deps.websocket.receive_text()
- return updated_plan
+ # Update stream with response
+ human_stream_output.steps.append("Received human input")
+ human_stream_output.output = response
+ human_stream_output.status_code = 200
+ await _safe_websocket_send(ctx.deps.websocket, human_stream_output)
+ return response
except Exception as e:
- error_msg = f"Error during plan update operations: {str(e)}"
+ error_msg = f"Error getting human input: {str(e)}"
logfire.error(error_msg, exc_info=True)
- planner_stream_output.steps.append(f"Plan update failed: {str(e)}")
- planner_stream_output.status_code = a500
- await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+ # Update stream with error
+ human_stream_output.steps.append(f"Failed to get human input: {str(e)}")
+ human_stream_output.status_code = 500
+ await _safe_websocket_send(ctx.deps.websocket, human_stream_output)
- return f"Failed to update the plan: {error_msg}"
-
- except Exception as e:
- error_msg = f"Error updating plan: {str(e)}"
- logfire.error(error_msg, exc_info=True)
+ return f"Failed to get human input: {error_msg}"
+
+ @orchestrator_agent.tool
+ async def planner_agent_update(ctx: RunContext[orchestrator_deps], completed_task: str) -> str:
+ """
+ Updates the todo.md file to mark a task as completed and returns the full updated plan.
- # Update stream output with error
- if ctx.deps.stream_output:
- ctx.deps.stream_output.steps.append(f"Failed to update plan: {str(e)}")
- await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output)
+ Args:
+ completed_task: Description of the completed task including which agent performed it
- return f"Failed to update plan: {error_msg}"
+ Returns:
+ The complete updated todo.md content with tasks marked as completed
+ """
+ try:
+ logfire.info(f"Updating plan with completed task: {completed_task}")
+
+ # Create a new StreamResponse for Planner Agent update
+ planner_stream_output = StreamResponse(
+ agent_name="Planner Agent",
+ instructions=f"Update todo.md to mark as completed: {completed_task}",
+ steps=[],
+ output="",
+ status_code=0
+ )
+
+ # Send initial update
+ await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+
+ # Directly read and update the todo.md file
+ base_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+ planner_dir = os.path.join(base_dir, "agents", "planner")
+ todo_path = os.path.join(planner_dir, "todo.md")
+
+ planner_stream_output.steps.append("Reading current todo.md...")
+ await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+
+ # Make sure the directory exists
+ os.makedirs(planner_dir, exist_ok=True)
+
+ try:
+ # Check if todo.md exists
+ if not os.path.exists(todo_path):
+ planner_stream_output.steps.append("No todo.md file found. Will create new one after task completion.")
+ await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+
+ # We'll directly call planner_agent.run() to create a new plan first
+ plan_prompt = f"Create a simple task plan based on this completed task: {completed_task}"
+
+ agent = await planner_agent(model_preference=ctx.deps.model_preference)
+ plan_response = await agent.run(user_prompt=plan_prompt)
+ current_content = plan_response.data.plan
+ logfire.info(f"Updating plan task using model type: {ctx.deps.model_preference}")
+
+ else:
+ # Read existing todo.md
+ with open(todo_path, "r") as file:
+ current_content = file.read()
+ planner_stream_output.steps.append(f"Found existing todo.md ({len(current_content)} bytes)")
+ await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+
+ # Now call planner_agent.run() with specific instructions to update the plan
+ update_prompt = f"""
+ Here is the current todo.md content:
+
+ {current_content}
+
+ Please update this plan to mark the following task as completed: {completed_task}
+ Return ONLY the fully updated plan with appropriate tasks marked as [x] instead of [ ].
+ """
+
+ planner_stream_output.steps.append("Asking planner to update the plan...")
+ await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+
+ agent = await planner_agent(model_preference=ctx.deps.model_preference)
+ updated_plan_response = await agent.run(user_prompt=update_prompt)
+ updated_plan = updated_plan_response.data.plan
+
+
+ logfire.info(f"Updating prompt for plan task using model type: {ctx.deps.model_preference}")
+
+ # Write the updated plan back to todo.md
+ with open(todo_path, "w") as file:
+ file.write(updated_plan)
+
+ planner_stream_output.steps.append("Plan updated successfully")
+ planner_stream_output.output = updated_plan
+ planner_stream_output.status_code = 200
+ await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+
+ # Update orchestrator stream
+ if ctx.deps.stream_output:
+ ctx.deps.stream_output.steps.append(f"Plan updated to mark task as completed: {completed_task}")
+ await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output)
+
+ return updated_plan
+
+ except Exception as e:
+ error_msg = f"Error during plan update operations: {str(e)}"
+ logfire.error(error_msg, exc_info=True)
+
+ planner_stream_output.steps.append(f"Plan update failed: {str(e)}")
+ planner_stream_output.status_code = 500
+ await _safe_websocket_send(ctx.deps.websocket, planner_stream_output)
+
+ return f"Failed to update the plan: {error_msg}"
+
+ except Exception as e:
+ error_msg = f"Error updating plan: {str(e)}"
+ logfire.error(error_msg, exc_info=True)
+
+ # Update stream output with error
+ if ctx.deps.stream_output:
+ ctx.deps.stream_output.steps.append(f"Failed to update plan: {str(e)}")
+ await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output)
+
+ return f"Failed to update plan: {error_msg}"
+
+
+
+ logfire.info("All tools initialized for orchestrator agent")
+ return orchestrator_agent
+
+
# Helper function for sending WebSocket messages
async def _safe_websocket_send(websocket: Optional[WebSocket], message: Any) -> bool:
diff --git a/cortex_on/agents/planner_agent.py b/cortex_on/agents/planner_agent.py
index 897c22f..d607d04 100644
--- a/cortex_on/agents/planner_agent.py
+++ b/cortex_on/agents/planner_agent.py
@@ -10,7 +10,7 @@
from pydantic_ai.models.anthropic import AnthropicModel
# Local application imports
-from utils.ant_client import get_client
+from utils.ant_client import get_client, get_anthropic_model_instance, get_openai_model_instance, get_openai_client
@@ -176,172 +176,178 @@
class PlannerResult(BaseModel):
plan: str = Field(description="The generated or updated plan in string format - this should be the complete plan text")
-model = AnthropicModel(
- model_name=os.environ.get("ANTHROPIC_MODEL_NAME"),
- anthropic_client=get_client()
-)
+async def planner_agent(model_preference: str = "Anthropic") -> Agent:
+ if model_preference == "Anthropic":
+ model = get_anthropic_model_instance()
+ elif model_preference == "OpenAI":
+ model = get_openai_model_instance()
+ else:
+ raise ValueError(f"Unknown model_preference: {model_preference}")
+ planner_agent = Agent(
+ model=model,
+ name="Planner Agent",
+ result_type=PlannerResult,
+ system_prompt=planner_prompt
+ )
-planner_agent = Agent(
- model=model,
- name="Planner Agent",
- result_type=PlannerResult,
- system_prompt=planner_prompt
-)
-
-@planner_agent.tool_plain
-async def update_todo_status(task_description: str) -> str:
- """
- A helper function that logs the update request but lets the planner agent handle the actual update logic.
-
- Args:
- task_description: Description of the completed task
+ @planner_agent.tool_plain
+ async def update_todo_status(task_description: str) -> str:
+ """
+ A helper function that logs the update request but lets the planner agent handle the actual update logic.
- Returns:
- A simple log message
- """
- logfire.info(f"Received request to update todo.md for task: {task_description}")
- return f"Received update request for: {task_description}"
+ Args:
+ task_description: Description of the completed task
+
+ Returns:
+ A simple log message
+ """
+ logfire.info(f"Received request to update todo.md for task: {task_description}")
+ return f"Received update request for: {task_description}"
-@planner_agent.tool_plain
-async def execute_terminal(command: str) -> str:
- """
- Executes a terminal command within the planner directory for file operations.
- This consolidated tool handles reading and writing plan files.
- Restricted to only read and write operations for security.
- """
- try:
- logfire.info(f"Executing terminal command: {command}")
-
- # Define the restricted directory
- base_dir = os.path.abspath(os.path.dirname(__file__))
- planner_dir = os.path.join(base_dir, "planner")
- os.makedirs(planner_dir, exist_ok=True)
-
- # Extract the base command
- base_command = command.split()[0]
-
- # Allow only read and write operations
- ALLOWED_COMMANDS = {"cat", "echo", "ls"}
-
- # Security checks
- if base_command not in ALLOWED_COMMANDS:
- return f"Error: Command '{base_command}' is not allowed. Only read and write operations are permitted."
-
- if ".." in command or "~" in command or "/" in command:
- return "Error: Path traversal attempts are not allowed."
-
- # Change to the restricted directory
- original_dir = os.getcwd()
- os.chdir(planner_dir)
-
+ @planner_agent.tool_plain
+ async def execute_terminal(command: str) -> str:
+ """
+ Executes a terminal command within the planner directory for file operations.
+ This consolidated tool handles reading and writing plan files.
+ Restricted to only read and write operations for security.
+ """
try:
- # Handle echo with >> (append)
- if base_command == "echo" and ">>" in command:
- try:
- # Split only on the first occurrence of >>
- parts = command.split(">>", 1)
- echo_part = parts[0].strip()
+ logfire.info(f"Executing terminal command: {command}")
+
+ # Define the restricted directory
+ base_dir = os.path.abspath(os.path.dirname(__file__))
+ planner_dir = os.path.join(base_dir, "planner")
+ os.makedirs(planner_dir, exist_ok=True)
+
+ # Extract the base command
+ base_command = command.split()[0]
+
+ # Allow only read and write operations
+ ALLOWED_COMMANDS = {"cat", "echo", "ls"}
+
+ # Security checks
+ if base_command not in ALLOWED_COMMANDS:
+ return f"Error: Command '{base_command}' is not allowed. Only read and write operations are permitted."
+
+ if ".." in command or "~" in command or "/" in command:
+ return "Error: Path traversal attempts are not allowed."
+
+ # Change to the restricted directory
+ original_dir = os.getcwd()
+ os.chdir(planner_dir)
+
+ try:
+ # Handle echo with >> (append)
+ if base_command == "echo" and ">>" in command:
+ try:
+ # Split only on the first occurrence of >>
+ parts = command.split(">>", 1)
+ echo_part = parts[0].strip()
+ file_path = parts[1].strip()
+
+ # Extract content after echo command
+ content = echo_part[4:].strip()
+
+ # Handle quotes if present
+ if (content.startswith('"') and content.endswith('"')) or \
+ (content.startswith("'") and content.endswith("'")):
+ content = content[1:-1]
+
+ # Append to file
+ with open(file_path, "a") as file:
+ file.write(content + "\n")
+ return f"Successfully appended to {file_path}"
+ except Exception as e:
+ logfire.error(f"Error appending to file: {str(e)}", exc_info=True)
+ return f"Error appending to file: {str(e)}"
+
+ # Special handling for echo with redirection (file writing)
+ elif ">" in command and base_command == "echo" and ">>" not in command:
+ # Simple parsing for echo "content" > file.txt
+ parts = command.split(">", 1)
+ echo_cmd = parts[0].strip()
file_path = parts[1].strip()
- # Extract content after echo command
- content = echo_part[4:].strip()
-
- # Handle quotes if present
+ # Extract content between echo and > (removing quotes if present)
+ content = echo_cmd[5:].strip()
if (content.startswith('"') and content.endswith('"')) or \
- (content.startswith("'") and content.endswith("'")):
+ (content.startswith("'") and content.endswith("'")):
content = content[1:-1]
- # Append to file
- with open(file_path, "a") as file:
- file.write(content + "\n")
- return f"Successfully appended to {file_path}"
- except Exception as e:
- logfire.error(f"Error appending to file: {str(e)}", exc_info=True)
- return f"Error appending to file: {str(e)}"
-
- # Special handling for echo with redirection (file writing)
- elif ">" in command and base_command == "echo" and ">>" not in command:
- # Simple parsing for echo "content" > file.txt
- parts = command.split(">", 1)
- echo_cmd = parts[0].strip()
- file_path = parts[1].strip()
+ # Write to file
+ try:
+ with open(file_path, "w") as file:
+ file.write(content)
+ return f"Successfully wrote to {file_path}"
+ except Exception as e:
+ logfire.error(f"Error writing to file: {str(e)}", exc_info=True)
+ return f"Error writing to file: {str(e)}"
- # Extract content between echo and > (removing quotes if present)
- content = echo_cmd[5:].strip()
- if (content.startswith('"') and content.endswith('"')) or \
- (content.startswith("'") and content.endswith("'")):
- content = content[1:-1]
-
- # Write to file
- try:
- with open(file_path, "w") as file:
- file.write(content)
- return f"Successfully wrote to {file_path}"
- except Exception as e:
- logfire.error(f"Error writing to file: {str(e)}", exc_info=True)
- return f"Error writing to file: {str(e)}"
-
- # Handle cat with here-document for multiline file writing
- elif "<<" in command and base_command == "cat":
- try:
- # Parse the command: cat > file.md << 'EOF'\nplan content\nEOF
- cmd_parts = command.split("<<", 1)
- cat_part = cmd_parts[0].strip()
- doc_part = cmd_parts[1].strip()
-
- # Extract filename
- if ">" in cat_part:
- file_path = cat_part.split(">", 1)[1].strip()
- else:
- return "Error: Invalid cat command format. Must include redirection."
-
- # Parse the heredoc content
- if "\n" in doc_part:
- delimiter_and_content = doc_part.split("\n", 1)
- delimiter = delimiter_and_content[0].strip("'").strip('"')
- content = delimiter_and_content[1]
+ # Handle cat with here-document for multiline file writing
+ elif "<<" in command and base_command == "cat":
+ try:
+ # Parse the command: cat > file.md << 'EOF'\nplan content\nEOF
+ cmd_parts = command.split("<<", 1)
+ cat_part = cmd_parts[0].strip()
+ doc_part = cmd_parts[1].strip()
- # Find the end delimiter and extract content
- if f"\n{delimiter}" in content:
- content = content.split(f"\n{delimiter}")[0]
+ # Extract filename
+ if ">" in cat_part:
+ file_path = cat_part.split(">", 1)[1].strip()
+ else:
+ return "Error: Invalid cat command format. Must include redirection."
+
+ # Parse the heredoc content
+ if "\n" in doc_part:
+ delimiter_and_content = doc_part.split("\n", 1)
+ delimiter = delimiter_and_content[0].strip("'").strip('"')
+ content = delimiter_and_content[1]
- # Write to file
- with open(file_path, "w") as file:
- file.write(content)
- return f"Successfully wrote multiline content to {file_path}"
+ # Find the end delimiter and extract content
+ if f"\n{delimiter}" in content:
+ content = content.split(f"\n{delimiter}")[0]
+
+ # Write to file
+ with open(file_path, "w") as file:
+ file.write(content)
+ return f"Successfully wrote multiline content to {file_path}"
+ else:
+ return "Error: End delimiter not found in heredoc"
else:
- return "Error: End delimiter not found in heredoc"
- else:
- return "Error: Invalid heredoc format"
- except Exception as e:
- logfire.error(f"Error processing cat with heredoc: {str(e)}", exc_info=True)
- return f"Error processing cat with heredoc: {str(e)}"
-
- # Handle cat for reading files
- elif base_command == "cat" and ">" not in command and "<<" not in command:
- try:
- file_path = command.split()[1]
- with open(file_path, "r") as file:
- content = file.read()
- return content
- except Exception as e:
- logfire.error(f"Error reading file: {str(e)}", exc_info=True)
- return f"Error reading file: {str(e)}"
-
- # Handle ls for listing files
- elif base_command == "ls":
- try:
- files = os.listdir('.')
- return "Files in planner directory:\n" + "\n".join(files)
- except Exception as e:
- logfire.error(f"Error listing files: {str(e)}", exc_info=True)
- return f"Error listing files: {str(e)}"
- else:
- return f"Error: Command '{command}' is not supported. Only read and write operations are permitted."
-
- finally:
- os.chdir(original_dir)
-
- except Exception as e:
- logfire.error(f"Error executing command: {str(e)}", exc_info=True)
- return f"Error executing command: {str(e)}"
\ No newline at end of file
+ return "Error: Invalid heredoc format"
+ except Exception as e:
+ logfire.error(f"Error processing cat with heredoc: {str(e)}", exc_info=True)
+ return f"Error processing cat with heredoc: {str(e)}"
+
+ # Handle cat for reading files
+ elif base_command == "cat" and ">" not in command and "<<" not in command:
+ try:
+ file_path = command.split()[1]
+ with open(file_path, "r") as file:
+ content = file.read()
+ return content
+ except Exception as e:
+ logfire.error(f"Error reading file: {str(e)}", exc_info=True)
+ return f"Error reading file: {str(e)}"
+
+ # Handle ls for listing files
+ elif base_command == "ls":
+ try:
+ files = os.listdir('.')
+ return "Files in planner directory:\n" + "\n".join(files)
+ except Exception as e:
+ logfire.error(f"Error listing files: {str(e)}", exc_info=True)
+ return f"Error listing files: {str(e)}"
+ else:
+ return f"Error: Command '{command}' is not supported. Only read and write operations are permitted."
+
+ finally:
+ os.chdir(original_dir)
+
+ except Exception as e:
+ logfire.error(f"Error executing command: {str(e)}", exc_info=True)
+ return f"Error executing command: {str(e)}"
+
+
+ logfire.info("All tools initialized for planner agent")
+ return planner_agent
\ No newline at end of file
diff --git a/cortex_on/agents/web_surfer.py b/cortex_on/agents/web_surfer.py
index 34e2cfd..2711348 100644
--- a/cortex_on/agents/web_surfer.py
+++ b/cortex_on/agents/web_surfer.py
@@ -29,19 +29,25 @@
TIMEOUT = 9999999999999999999999999999999999999999999
class WebSurfer:
- def __init__(self, api_url: str = "http://localhost:8000/api/v1/web/stream"):
+ def __init__(self, api_url: str = "http://localhost:8000/api/v1/web/stream", model_preference: str = "Anthropic"):
self.api_url = api_url
self.name = "Web Surfer Agent"
self.description = "An agent that is a websurfer and a webscraper that can access any web-page to extract information or perform actions."
self.websocket: Optional[WebSocket] = None
self.stream_output: Optional[StreamResponse] = None
+ self.model_preference = model_preference
async def _make_api_call(self, instruction: str) -> Tuple[int, List[Dict[str, Any]]]:
session_timeout = aiohttp.ClientTimeout(total=None, sock_connect=TIMEOUT, sock_read=TIMEOUT)
async with aiohttp.ClientSession(timeout=session_timeout) as session:
final_json_response = []
try:
- payload = {"cmd": instruction, "critique_disabled": False}
+ payload = {
+ "cmd": instruction,
+ "critique_disabled": False,
+ "model_preference": self.model_preference
+ }
+ logfire.info(f"Making API call with model_preference: {self.model_preference}")
async with session.post(self.api_url, json=payload) as response:
if response.status != 200:
error_text = await response.text()
diff --git a/cortex_on/instructor.py b/cortex_on/instructor.py
index b4f0efb..bf9670f 100644
--- a/cortex_on/instructor.py
+++ b/cortex_on/instructor.py
@@ -25,9 +25,6 @@
load_dotenv()
-
-
-
class DateTimeEncoder(json.JSONEncoder):
"""Custom JSON encoder that can handle datetime objects"""
def default(self, obj):
@@ -38,10 +35,12 @@ def default(self, obj):
# Main Orchestrator Class
class SystemInstructor:
- def __init__(self):
+ def __init__(self, model_preference: str = "Anthropic"):
+ logfire.info(f"Initializing SystemInstructor with model_preference: {model_preference}")
self.websocket: Optional[WebSocket] = None
self.stream_output: Optional[StreamResponse] = None
self.orchestrator_response: List[StreamResponse] = []
+ self.model_preference = model_preference
self._setup_logging()
def _setup_logging(self) -> None:
@@ -80,7 +79,8 @@ async def run(self, task: str, websocket: WebSocket) -> List[Dict[str, Any]]:
deps_for_orchestrator = orchestrator_deps(
websocket=self.websocket,
stream_output=stream_output,
- agent_responses=self.orchestrator_response # Pass reference to collection
+ agent_responses=self.orchestrator_response, # Pass reference to collection
+ model_preference=self.model_preference
)
try:
@@ -88,11 +88,15 @@ async def run(self, task: str, websocket: WebSocket) -> List[Dict[str, Any]]:
await self._safe_websocket_send(stream_output)
stream_output.steps.append("Agents initialized successfully")
await self._safe_websocket_send(stream_output)
-
- orchestrator_response = await orchestrator_agent.run(
+
+ agent = await orchestrator_agent(self.model_preference)
+ orchestrator_response = await agent.run(
user_prompt=task,
deps=deps_for_orchestrator
)
+
+ logfire.info(f"Orchestrator Agent using model type: {self.model_preference}")
+
stream_output.output = orchestrator_response.data
stream_output.status_code = 200
logfire.debug(f"Orchestrator response: {orchestrator_response.data}")
@@ -116,15 +120,11 @@ async def run(self, task: str, websocket: WebSocket) -> List[Dict[str, Any]]:
finally:
logfire.info("Orchestration process complete")
- # Clear any sensitive data
+
+
async def shutdown(self):
"""Clean shutdown of orchestrator"""
try:
- # Close websocket if open
- if self.websocket:
- await self.websocket.close()
-
- # Clear all responses
self.orchestrator_response = []
logfire.info("Orchestrator shutdown complete")
diff --git a/cortex_on/main.py b/cortex_on/main.py
index a8dd4de..44b840b 100644
--- a/cortex_on/main.py
+++ b/cortex_on/main.py
@@ -1,27 +1,112 @@
# Standard library imports
from typing import List, Optional
+from contextlib import asynccontextmanager
# Third-party imports
-from fastapi import FastAPI, WebSocket
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect ,Depends
+from fastapi.middleware.cors import CORSMiddleware
+import logfire
+
+# Configure Logfire
+logfire.configure()
# Local application imports
from instructor import SystemInstructor
+# Default model preference is Anthropic
+MODEL_PREFERENCE = "Anthropic"
+
+# Global instructor instance
+instructor = None
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ # Set default model preference at startup
+ app.state.model_preference = MODEL_PREFERENCE
+ logfire.info(f"Setting default model preference to: {MODEL_PREFERENCE}")
+
+ # Initialize the instructor
+ global instructor
+ instructor = SystemInstructor(model_preference=MODEL_PREFERENCE)
+
+ yield
+
+
+
+app: FastAPI = FastAPI(lifespan=lifespan)
+
+# Add CORS middleware
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"], # Allow all origins for development
+ allow_credentials=True,
+ allow_methods=["*"], # Allow all methods
+ allow_headers=["*"], # Allow all headers
+)
+
+async def get_model_preference() -> str:
+ """
+ Get the current model preference from app state
+ """
+ logfire.info(f"Current model preference: {app.state.model_preference}")
+ return app.state.model_preference
-app: FastAPI = FastAPI()
+@app.get("/set_model_preference")
+async def set_model_preference(model: str):
+ """
+ Set the model preference (Anthropic or OpenAI) and reinitialize the instructor
+ """
+ if model not in ["Anthropic", "OpenAI"]:
+ logfire.error(f"Invalid model preference attempted: {model}")
+ return {"error": "Model must be 'Anthropic' or 'OpenAI'"}
+
+ logfire.info(f"Changing model preference from {app.state.model_preference} to {model}")
+ app.state.model_preference = model
+
+ # Reinitialize the instructor with new model preference
+ global instructor
+ if instructor:
+ await instructor.shutdown()
+ instructor = SystemInstructor(model_preference=model)
+ logfire.info(f"Instructor reinitialized with model preference: {model}")
+
+ return {"message": f"Model preference set to {model}"}
-async def generate_response(task: str, websocket: Optional[WebSocket] = None):
- orchestrator: SystemInstructor = SystemInstructor()
- return await orchestrator.run(task, websocket)
+async def generate_response(task: str, websocket: Optional[WebSocket] = None, model_preference: str = None):
+ if model_preference is None:
+ model_preference = app.state.model_preference
+ logfire.info(f"Using model preference: {model_preference} for task: {task[:30]}...")
+
+ global instructor
+ if not instructor:
+ instructor = SystemInstructor(model_preference=model_preference)
+
+ return await instructor.run(task, websocket)
@app.get("/agent/chat")
-async def agent_chat(task: str) -> List:
- final_agent_response = await generate_response(task)
+async def agent_chat(task: str, model_preference: str = Depends(get_model_preference)) -> List:
+ logfire.info(f"Received chat request with model preference: {model_preference}")
+ final_agent_response = await generate_response(task, model_preference=model_preference)
return final_agent_response
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
- while True:
- data = await websocket.receive_text()
- await generate_response(data, websocket)
+ model_preference = app.state.model_preference
+ logfire.info(f"New connection using model preference: {model_preference}")
+ try:
+ while True:
+ try:
+ data = await websocket.receive_text()
+ logfire.info(f"Received message, using model: {model_preference}")
+ await generate_response(data, websocket, model_preference)
+ except WebSocketDisconnect:
+ logfire.info("[WEBSOCKET] Client disconnected")
+ break
+ except Exception as e:
+ logfire.error(f"[WEBSOCKET] Error handling message: {str(e)}")
+ if "disconnect message has been received" in str(e):
+ print(f"[WEBSOCKET] DIsconnect detected, closing connection: {str(e)}")
+ break
+ except Exception as e:
+ logfire.error(f"[WEBSOCKET] Connection error: {str(e)}")
diff --git a/cortex_on/utils/ant_client.py b/cortex_on/utils/ant_client.py
index 924f0c6..c04732c 100644
--- a/cortex_on/utils/ant_client.py
+++ b/cortex_on/utils/ant_client.py
@@ -1,5 +1,6 @@
from pydantic_ai.models.anthropic import AnthropicModel
from anthropic import AsyncAnthropic
+from openai import AsyncOpenAI
import os
from dotenv import load_dotenv
@@ -12,3 +13,26 @@ def get_client():
max_retries=3,
timeout=10000)
return client
+
+
+
+def get_openai_client():
+ api_key = os.getenv("OPENAI_API_KEY")
+ client = AsyncOpenAI(api_key=api_key,
+ max_retries=3,
+ timeout=10000)
+ return client
+
+def get_openai_model_instance():
+ model_name = os.getenv("OPENAI_MODEL_NAME")
+ # Create model instance
+ from pydantic_ai.models.openai import OpenAIModel
+ model_instance = OpenAIModel(model_name=model_name, openai_client=get_openai_client())
+ return model_instance
+
+def get_anthropic_model_instance():
+ model_name = os.getenv("ANTHROPIC_MODEL_NAME")
+ # Create model instance
+ from pydantic_ai.models.anthropic import AnthropicModel
+ model_instance = AnthropicModel(model_name=model_name, anthropic_client=get_client())
+ return model_instance
diff --git a/frontend/src/components/home/ChatList.tsx b/frontend/src/components/home/ChatList.tsx
index 9d8cb21..50fba0a 100644
--- a/frontend/src/components/home/ChatList.tsx
+++ b/frontend/src/components/home/ChatList.tsx
@@ -94,9 +94,30 @@ const ChatList = ({isLoading, setIsLoading}: ChatListPageProps) => {
retryOnError: true,
shouldReconnect: () => true,
reconnectInterval: 3000,
+ share: true,
}
);
+ useEffect(() => {
+ if (readyState === ReadyState.CONNECTING) {
+ if (messages.length > 0) {
+ setIsLoading(true);
+ }
+ } else if (readyState === ReadyState.OPEN) {
+ setIsLoading(false);
+ }
+ }, [readyState, messages.length]);
+
+ useEffect(() => {
+ if (messages.length === 0) {
+ setLiveUrl("");
+ setCurrentOutput(null);
+ setOutputsList([]);
+ setIsLoading(false);
+ }
+ }, [messages.length]);
+
+
const scrollToBottom = (smooth = true) => {
if (scrollAreaRef.current) {
const scrollableDiv = scrollAreaRef.current.querySelector(
@@ -557,7 +578,7 @@ const ChatList = ({isLoading, setIsLoading}: ChatListPageProps) => {
className="space-y-2 animate-fade-in animate-once animate-duration-500"
key={idx}
>
- {readyState === ReadyState.CONNECTING ? (
+ {readyState === ReadyState.CONNECTING && messages.length > 0 ? (
<>
Vault