diff --git a/cortex_on/agents/code_agent.py b/cortex_on/agents/code_agent.py index 43fa047..d1eb785 100644 --- a/cortex_on/agents/code_agent.py +++ b/cortex_on/agents/code_agent.py @@ -15,7 +15,7 @@ from pydantic_ai.models.anthropic import AnthropicModel # Local application imports -from utils.ant_client import get_client +from utils.ant_client import get_client, get_anthropic_model_instance, get_openai_model_instance, get_openai_client from utils.stream_response_format import StreamResponse load_dotenv() @@ -25,6 +25,7 @@ class CoderAgentDeps: websocket: Optional[WebSocket] = None stream_output: Optional[StreamResponse] = None + model_preference: str = "Anthropic" # Constants ALLOWED_COMMANDS = { @@ -236,180 +237,188 @@ async def send_stream_update(ctx: RunContext[CoderAgentDeps], message: str) -> N stream_output_json = json.dumps(asdict(ctx.deps.stream_output)) logfire.debug("WebSocket message sent: {stream_output_json}", stream_output_json=stream_output_json) -# Initialize the model -model = AnthropicModel( - model_name=os.environ.get("ANTHROPIC_MODEL_NAME"), - anthropic_client=get_client() -) + # Initialize the agent -coder_agent = Agent( + +async def coder_agent(model_preference: str = "Anthropic") -> Agent: + if model_preference == "Anthropic": + model = get_anthropic_model_instance() + elif model_preference == "OpenAI": + model = get_openai_model_instance() + else: + raise ValueError(f"Unknown model_preference: {model_preference}") + print(f"[CODER_INIT] Creating coder agent with model: {model}") + coder_agent = Agent( model=model, name="Coder Agent", result_type=CoderResult, deps_type=CoderAgentDeps, system_prompt=coder_system_message -) - -@coder_agent.tool -async def execute_shell(ctx: RunContext[CoderAgentDeps], command: str) -> str: - """ - Executes a shell command within a restricted directory and returns the output. - This consolidated tool handles terminal commands and file operations. - """ - try: - # Extract base command for security checks and messaging - base_command = command.split()[0] if command.split() else "" - - # Send operation description message - operation_message = get_high_level_operation_message(command, base_command) - await send_stream_update(ctx, operation_message) - - logfire.info("Executing shell command: {command}", command=command) - - # Setup restricted directory - base_dir = os.path.abspath(os.path.dirname(__file__)) - restricted_dir = os.path.join(base_dir, "code_files") - os.makedirs(restricted_dir, exist_ok=True) - - # Security check - if base_command not in ALLOWED_COMMANDS: - await send_stream_update(ctx, "Operation not permitted") - return f"Error: Command '{base_command}' is not allowed for security reasons." - - # Change to restricted directory for execution - original_dir = os.getcwd() - os.chdir(restricted_dir) - + ) + + @coder_agent.tool + async def execute_shell(ctx: RunContext[CoderAgentDeps], command: str) -> str: + """ + Executes a shell command within a restricted directory and returns the output. + This consolidated tool handles terminal commands and file operations. + """ try: - # Handle echo with redirection (file writing) - if ">" in command and base_command == "echo": - file_path = command.split(">", 1)[1].strip() - await send_stream_update(ctx, f"Writing content to {file_path}") - - # Parse command parts - parts = command.split(">", 1) - echo_cmd = parts[0].strip() - - # Extract content, removing quotes if present - content = echo_cmd[5:].strip() - if (content.startswith('"') and content.endswith('"')) or \ - (content.startswith("'") and content.endswith("'")): - content = content[1:-1] - - try: - with open(file_path, "w") as file: - file.write(content) - - await send_stream_update(ctx, f"File {file_path} created successfully") - return f"Successfully wrote to {file_path}" - except Exception as e: - error_msg = f"Error writing to file: {str(e)}" - await send_stream_update(ctx, f"Failed to create file {file_path}") - logfire.error(error_msg, exc_info=True) - return error_msg + # Extract base command for security checks and messaging + base_command = command.split()[0] if command.split() else "" - # Handle cat with here-document for multiline file writing - elif "<<" in command and base_command == "cat": - cmd_parts = command.split("<<", 1) - cat_part = cmd_parts[0].strip() - - # Extract filename for status message if possible - file_path = None - if ">" in cat_part: - file_path = cat_part.split(">", 1)[1].strip() - await send_stream_update(ctx, f"Creating file {file_path}") + # Send operation description message + operation_message = get_high_level_operation_message(command, base_command) + await send_stream_update(ctx, operation_message) + + logfire.info("Executing shell command: {command}", command=command) + + # Setup restricted directory + base_dir = os.path.abspath(os.path.dirname(__file__)) + restricted_dir = os.path.join(base_dir, "code_files") + os.makedirs(restricted_dir, exist_ok=True) + + # Security check + if base_command not in ALLOWED_COMMANDS: + await send_stream_update(ctx, "Operation not permitted") + return f"Error: Command '{base_command}' is not allowed for security reasons." + + # Change to restricted directory for execution + original_dir = os.getcwd() + os.chdir(restricted_dir) + + try: + # Handle echo with redirection (file writing) + if ">" in command and base_command == "echo": + file_path = command.split(">", 1)[1].strip() + await send_stream_update(ctx, f"Writing content to {file_path}") + + # Parse command parts + parts = command.split(">", 1) + echo_cmd = parts[0].strip() + + # Extract content, removing quotes if present + content = echo_cmd[5:].strip() + if (content.startswith('"') and content.endswith('"')) or \ + (content.startswith("'") and content.endswith("'")): + content = content[1:-1] + + try: + with open(file_path, "w") as file: + file.write(content) + + await send_stream_update(ctx, f"File {file_path} created successfully") + return f"Successfully wrote to {file_path}" + except Exception as e: + error_msg = f"Error writing to file: {str(e)}" + await send_stream_update(ctx, f"Failed to create file {file_path}") + logfire.error(error_msg, exc_info=True) + return error_msg - try: - # Parse heredoc parts - doc_part = cmd_parts[1].strip() + # Handle cat with here-document for multiline file writing + elif "<<" in command and base_command == "cat": + cmd_parts = command.split("<<", 1) + cat_part = cmd_parts[0].strip() - # Extract filename + # Extract filename for status message if possible + file_path = None if ">" in cat_part: file_path = cat_part.split(">", 1)[1].strip() - else: - await send_stream_update(ctx, "Invalid file operation") - return "Error: Invalid cat command format. Must include redirection." + await send_stream_update(ctx, f"Creating file {file_path}") - # Parse the heredoc content and delimiter - if "\n" in doc_part: - delimiter_and_content = doc_part.split("\n", 1) - delimiter = delimiter_and_content[0].strip("'").strip('"') - content = delimiter_and_content[1] + try: + # Parse heredoc parts + doc_part = cmd_parts[1].strip() - # Find the end delimiter and extract content - if f"\n{delimiter}" in content: - content = content.split(f"\n{delimiter}")[0] - - # Write to file - with open(file_path, "w") as file: - file.write(content) + # Extract filename + if ">" in cat_part: + file_path = cat_part.split(">", 1)[1].strip() + else: + await send_stream_update(ctx, "Invalid file operation") + return "Error: Invalid cat command format. Must include redirection." + + # Parse the heredoc content and delimiter + if "\n" in doc_part: + delimiter_and_content = doc_part.split("\n", 1) + delimiter = delimiter_and_content[0].strip("'").strip('"') + content = delimiter_and_content[1] - await send_stream_update(ctx, f"File {file_path} created successfully") - return f"Successfully wrote multiline content to {file_path}" + # Find the end delimiter and extract content + if f"\n{delimiter}" in content: + content = content.split(f"\n{delimiter}")[0] + + # Write to file + with open(file_path, "w") as file: + file.write(content) + + await send_stream_update(ctx, f"File {file_path} created successfully") + return f"Successfully wrote multiline content to {file_path}" + else: + await send_stream_update(ctx, "File content format error") + return "Error: End delimiter not found in heredoc" else: await send_stream_update(ctx, "File content format error") - return "Error: End delimiter not found in heredoc" - else: - await send_stream_update(ctx, "File content format error") - return "Error: Invalid heredoc format" - except Exception as e: - error_msg = f"Error processing cat with heredoc: {str(e)}" - file_path_str = file_path if file_path else 'file' - await send_stream_update(ctx, f"Failed to create file {file_path_str}") - logfire.error(error_msg, exc_info=True) - return error_msg - - # Execute standard commands - else: - # Send execution message - execution_msg = get_high_level_execution_message(command, base_command) - await send_stream_update(ctx, execution_msg) + return "Error: Invalid heredoc format" + except Exception as e: + error_msg = f"Error processing cat with heredoc: {str(e)}" + file_path_str = file_path if file_path else 'file' + await send_stream_update(ctx, f"Failed to create file {file_path_str}") + logfire.error(error_msg, exc_info=True) + return error_msg - # Execute the command using subprocess - try: - args = shlex.split(command) - result = subprocess.run( - args, - shell=True, - capture_output=True, - text=True, - timeout=60, - ) + # Execute standard commands + else: + # Send execution message + execution_msg = get_high_level_execution_message(command, base_command) + await send_stream_update(ctx, execution_msg) - logfire.info(f"Command executed: {result.args}") + # Execute the command using subprocess + try: + args = shlex.split(command) + result = subprocess.run( + args, + shell=True, + capture_output=True, + text=True, + timeout=60, + ) + + logfire.info(f"Command executed: {result.args}") + + # Handle success + if result.returncode == 0: + success_msg = get_success_message(command, base_command) + await send_stream_update(ctx, success_msg) + logfire.info(f"Command executed successfully: {result.stdout}") + return result.stdout + + # Handle failure + else: + files = os.listdir('.') + error_msg = f"Command failed with error code {result.returncode}:\n{result.stderr}\n\nFiles in directory: {files}" + failure_msg = get_failure_message(command, base_command) + await send_stream_update(ctx, failure_msg) + return error_msg - # Handle success - if result.returncode == 0: - success_msg = get_success_message(command, base_command) - await send_stream_update(ctx, success_msg) - logfire.info(f"Command executed successfully: {result.stdout}") - return result.stdout + except subprocess.TimeoutExpired: + await send_stream_update(ctx, "Operation timed out") + return "Command execution timed out after 60 seconds" - # Handle failure - else: - files = os.listdir('.') - error_msg = f"Command failed with error code {result.returncode}:\n{result.stderr}\n\nFiles in directory: {files}" - failure_msg = get_failure_message(command, base_command) - await send_stream_update(ctx, failure_msg) + except Exception as e: + error_msg = f"Error executing command: {str(e)}" + await send_stream_update(ctx, "Operation failed") + logfire.error(error_msg, exc_info=True) return error_msg + + finally: + # Always return to the original directory + os.chdir(original_dir) - except subprocess.TimeoutExpired: - await send_stream_update(ctx, "Operation timed out") - return "Command execution timed out after 60 seconds" - - except Exception as e: - error_msg = f"Error executing command: {str(e)}" - await send_stream_update(ctx, "Operation failed") - logfire.error(error_msg, exc_info=True) - return error_msg + except Exception as e: + error_msg = f"Error executing command: {str(e)}" + await send_stream_update(ctx, "Operation failed") + logfire.error(error_msg, exc_info=True) + return error_msg - finally: - # Always return to the original directory - os.chdir(original_dir) - - except Exception as e: - error_msg = f"Error executing command: {str(e)}" - await send_stream_update(ctx, "Operation failed") - logfire.error(error_msg, exc_info=True) - return error_msg \ No newline at end of file + logfire.info("All tools initialized for coder agent") + return coder_agent \ No newline at end of file diff --git a/cortex_on/agents/orchestrator_agent.py b/cortex_on/agents/orchestrator_agent.py index 6b001ba..e08c881 100644 --- a/cortex_on/agents/orchestrator_agent.py +++ b/cortex_on/agents/orchestrator_agent.py @@ -12,9 +12,9 @@ from pydantic_ai import Agent, RunContext from agents.web_surfer import WebSurfer from utils.stream_response_format import StreamResponse -from agents.planner_agent import planner_agent, update_todo_status +from agents.planner_agent import planner_agent from agents.code_agent import coder_agent, CoderAgentDeps -from utils.ant_client import get_client +from utils.ant_client import get_client, get_anthropic_model_instance, get_openai_model_instance, get_openai_client @dataclass class orchestrator_deps: @@ -22,6 +22,7 @@ class orchestrator_deps: stream_output: Optional[StreamResponse] = None # Add a collection to track agent-specific streams agent_responses: Optional[List[StreamResponse]] = None + model_preference: str = "Anthropic" orchestrator_system_prompt = """You are an AI orchestrator that manages a team of agents to solve tasks. You have access to tools for coordinating the agents and managing the task flow. @@ -158,347 +159,377 @@ class orchestrator_deps: - Format: "Task description (agent_name)" """ -model = AnthropicModel( - model_name=os.environ.get("ANTHROPIC_MODEL_NAME"), - anthropic_client=get_client() -) - -orchestrator_agent = Agent( - model=model, - name="Orchestrator Agent", - system_prompt=orchestrator_system_prompt, - deps_type=orchestrator_deps -) - -@orchestrator_agent.tool -async def plan_task(ctx: RunContext[orchestrator_deps], task: str) -> str: - """Plans the task and assigns it to the appropriate agents""" - try: - logfire.info(f"Planning task: {task}") - - # Create a new StreamResponse for Planner Agent - planner_stream_output = StreamResponse( - agent_name="Planner Agent", - instructions=task, - steps=[], - output="", - status_code=0 - ) - - # Add to orchestrator's response collection if available - if ctx.deps.agent_responses is not None: - ctx.deps.agent_responses.append(planner_stream_output) +async def orchestrator_agent(model_preference: str) -> Agent: + if model_preference == "Anthropic": + + model = get_anthropic_model_instance() + elif model_preference == "OpenAI": + model = get_openai_model_instance() + else: + raise ValueError(f"Unknown model_preference: {model_preference}") + orchestrator_agent = Agent( + model=model, + name="Orchestrator Agent", + system_prompt=orchestrator_system_prompt, + deps_type=orchestrator_deps + ) + + @orchestrator_agent.tool + async def plan_task(ctx: RunContext[orchestrator_deps], task: str) -> str: + """Plans the task and assigns it to the appropriate agents""" + try: + logfire.info(f"Planning task: {task}") - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - # Update planner stream - planner_stream_output.steps.append("Planning task...") - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - # Run planner agent - planner_response = await planner_agent.run(user_prompt=task) - - # Update planner stream with results - plan_text = planner_response.data.plan - planner_stream_output.steps.append("Task planned successfully") - planner_stream_output.output = plan_text - planner_stream_output.status_code = 200 - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - # Also update orchestrator stream - ctx.deps.stream_output.steps.append("Task planned successfully") - await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output) - - return f"Task planned successfully\nTask: {plan_text}" - except Exception as e: - error_msg = f"Error planning task: {str(e)}" - logfire.error(error_msg, exc_info=True) - - # Update planner stream with error - if planner_stream_output: - planner_stream_output.steps.append(f"Planning failed: {str(e)}") - planner_stream_output.status_code = 500 + # Create a new StreamResponse for Planner Agent + planner_stream_output = StreamResponse( + agent_name="Planner Agent", + instructions=task, + steps=[], + output="", + status_code=0 + ) + + # Add to orchestrator's response collection if available + if ctx.deps.agent_responses is not None: + ctx.deps.agent_responses.append(planner_stream_output) + + await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + + # Update planner stream + planner_stream_output.steps.append("Planning task...") await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - # Also update orchestrator stream - if ctx.deps.stream_output: - ctx.deps.stream_output.steps.append(f"Planning failed: {str(e)}") + # Run planner agent + agent = await planner_agent(model_preference=ctx.deps.model_preference) + planner_response = await agent.run(user_prompt=task) + + logfire.info(f"Planner Agent using model type: {ctx.deps.model_preference}") + + # Update planner stream with results + plan_text = planner_response.data.plan + planner_stream_output.steps.append("Task planned successfully") + planner_stream_output.output = plan_text + planner_stream_output.status_code = 200 + await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + + # Also update orchestrator stream + ctx.deps.stream_output.steps.append("Task planned successfully") await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output) - return f"Failed to plan task: {error_msg}" + return f"Task planned successfully\nTask: {plan_text}" + except Exception as e: + error_msg = f"Error planning task: {str(e)}" + logfire.error(error_msg, exc_info=True) + + # Update planner stream with error + if planner_stream_output: + planner_stream_output.steps.append(f"Planning failed: {str(e)}") + planner_stream_output.status_code = 500 + await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + + # Also update orchestrator stream + if ctx.deps.stream_output: + ctx.deps.stream_output.steps.append(f"Planning failed: {str(e)}") + await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output) + + return f"Failed to plan task: {error_msg}" -@orchestrator_agent.tool -async def coder_task(ctx: RunContext[orchestrator_deps], task: str) -> str: - """Assigns coding tasks to the coder agent""" - try: - logfire.info(f"Assigning coding task: {task}") - - # Create a new StreamResponse for Coder Agent - coder_stream_output = StreamResponse( - agent_name="Coder Agent", - instructions=task, - steps=[], - output="", - status_code=0 - ) - - # Add to orchestrator's response collection if available - if ctx.deps.agent_responses is not None: - ctx.deps.agent_responses.append(coder_stream_output) - - # Send initial update for Coder Agent - await _safe_websocket_send(ctx.deps.websocket, coder_stream_output) - - # Create deps with the new stream_output - deps_for_coder_agent = CoderAgentDeps( - websocket=ctx.deps.websocket, - stream_output=coder_stream_output - ) - - # Run coder agent - coder_response = await coder_agent.run( - user_prompt=task, - deps=deps_for_coder_agent - ) - - # Extract response data - response_data = coder_response.data.content - - # Update coder_stream_output with coding results - coder_stream_output.output = response_data - coder_stream_output.status_code = 200 - coder_stream_output.steps.append("Coding task completed successfully") - await _safe_websocket_send(ctx.deps.websocket, coder_stream_output) - - # Add a reminder in the result message to update the plan using planner_agent_update - response_with_reminder = f"{response_data}\n\nReminder: You must now call planner_agent_update with the completed task description: \"{task} (coder_agent)\"" - - return response_with_reminder - except Exception as e: - error_msg = f"Error assigning coding task: {str(e)}" - logfire.error(error_msg, exc_info=True) + @orchestrator_agent.tool + async def coder_task(ctx: RunContext[orchestrator_deps], task: str) -> str: + """Assigns coding tasks to the coder agent""" + try: + logfire.info(f"Assigning coding task: {task}") + + # Create a new StreamResponse for Coder Agent + coder_stream_output = StreamResponse( + agent_name="Coder Agent", + instructions=task, + steps=[], + output="", + status_code=0 + ) + + # Add to orchestrator's response collection if available + if ctx.deps.agent_responses is not None: + ctx.deps.agent_responses.append(coder_stream_output) + + # Send initial update for Coder Agent + await _safe_websocket_send(ctx.deps.websocket, coder_stream_output) + + # Create deps with the new stream_output + deps_for_coder_agent = CoderAgentDeps( + websocket=ctx.deps.websocket, + stream_output=coder_stream_output, + model_preference=ctx.deps.model_preference + ) + + # Run coder agent + agent = await coder_agent(model_preference=ctx.deps.model_preference) + coder_response = await agent.run( + user_prompt=task, + deps=deps_for_coder_agent + ) + logfire.info(f"coder_response: {coder_response}") + logfire.info(f"Coder Agent using model type: {ctx.deps.model_preference}") + + + # Extract response data + response_data = coder_response.data.content + + # Update coder_stream_output with coding results + coder_stream_output.output = response_data + coder_stream_output.status_code = 200 + coder_stream_output.steps.append("Coding task completed successfully") + await _safe_websocket_send(ctx.deps.websocket, coder_stream_output) + + # Add a reminder in the result message to update the plan using planner_agent_update + response_with_reminder = f"{response_data}\n\nReminder: You must now call planner_agent_update with the completed task description: \"{task} (coder_agent)\"" + + return response_with_reminder + except Exception as e: + error_msg = f"Error assigning coding task: {str(e)}" + logfire.error(error_msg, exc_info=True) - # Update coder_stream_output with error - coder_stream_output.steps.append(f"Coding task failed: {str(e)}") - coder_stream_output.status_code = 500 - await _safe_websocket_send(ctx.deps.websocket, coder_stream_output) + # Update coder_stream_output with error + coder_stream_output.steps.append(f"Coding task failed: {str(e)}") + coder_stream_output.status_code = 500 + await _safe_websocket_send(ctx.deps.websocket, coder_stream_output) - return f"Failed to assign coding task: {error_msg}" + return f"Failed to assign coding task: {error_msg}" -@orchestrator_agent.tool -async def web_surfer_task(ctx: RunContext[orchestrator_deps], task: str) -> str: - """Assigns web surfing tasks to the web surfer agent""" - try: - logfire.info(f"Assigning web surfing task: {task}") - - # Create a new StreamResponse for WebSurfer - web_surfer_stream_output = StreamResponse( - agent_name="Web Surfer", - instructions=task, - steps=[], - output="", - status_code=0, - live_url=None - ) - - # Add to orchestrator's response collection if available - if ctx.deps.agent_responses is not None: - ctx.deps.agent_responses.append(web_surfer_stream_output) - - await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output) - - # Initialize WebSurfer agent - web_surfer_agent = WebSurfer(api_url="http://localhost:8000/api/v1/web/stream") - - # Run WebSurfer with its own stream_output - success, message, messages = await web_surfer_agent.generate_reply( - instruction=task, - websocket=ctx.deps.websocket, - stream_output=web_surfer_stream_output - ) - - # Update WebSurfer's stream_output with final result - if success: - web_surfer_stream_output.steps.append("Web search completed successfully") - web_surfer_stream_output.output = message - web_surfer_stream_output.status_code = 200 - - # Add a reminder to update the plan - message_with_reminder = f"{message}\n\nReminder: You must now call planner_agent_update with the completed task description: \"{task} (web_surfer_agent)\"" - else: - web_surfer_stream_output.steps.append(f"Web search completed with issues: {message[:100]}") - web_surfer_stream_output.status_code = 500 - message_with_reminder = message - - await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output) - - web_surfer_stream_output.steps.append(f"WebSurfer completed: {'Success' if success else 'Failed'}") - await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output) - - return message_with_reminder - except Exception as e: - error_msg = f"Error assigning web surfing task: {str(e)}" - logfire.error(error_msg, exc_info=True) - - # Update WebSurfer's stream_output with error - web_surfer_stream_output.steps.append(f"Web search failed: {str(e)}") - web_surfer_stream_output.status_code = 500 - await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output) - return f"Failed to assign web surfing task: {error_msg}" - -@orchestrator_agent.tool -async def ask_human(ctx: RunContext[orchestrator_deps], question: str) -> str: - """Sends a question to the frontend and waits for human input""" - try: - logfire.info(f"Asking human: {question}") - - # Create a new StreamResponse for Human Input - human_stream_output = StreamResponse( - agent_name="Human Input", - instructions=question, - steps=[], - output="", - status_code=0 - ) - - # Add to orchestrator's response collection if available - if ctx.deps.agent_responses is not None: - ctx.deps.agent_responses.append(human_stream_output) - - # Send the question to frontend - await _safe_websocket_send(ctx.deps.websocket, human_stream_output) - - # Update stream with waiting message - human_stream_output.steps.append("Waiting for human input...") - await _safe_websocket_send(ctx.deps.websocket, human_stream_output) - - # Wait for response from frontend - response = await ctx.deps.websocket.receive_text() - - # Update stream with response - human_stream_output.steps.append("Received human input") - human_stream_output.output = response - human_stream_output.status_code = 200 - await _safe_websocket_send(ctx.deps.websocket, human_stream_output) - - return response - except Exception as e: - error_msg = f"Error getting human input: {str(e)}" - logfire.error(error_msg, exc_info=True) - - # Update stream with error - human_stream_output.steps.append(f"Failed to get human input: {str(e)}") - human_stream_output.status_code = 500 - await _safe_websocket_send(ctx.deps.websocket, human_stream_output) - - return f"Failed to get human input: {error_msg}" - -@orchestrator_agent.tool -async def planner_agent_update(ctx: RunContext[orchestrator_deps], completed_task: str) -> str: - """ - Updates the todo.md file to mark a task as completed and returns the full updated plan. - - Args: - completed_task: Description of the completed task including which agent performed it - - Returns: - The complete updated todo.md content with tasks marked as completed - """ - try: - logfire.info(f"Updating plan with completed task: {completed_task}") - - # Create a new StreamResponse for Planner Agent update - planner_stream_output = StreamResponse( - agent_name="Planner Agent", - instructions=f"Update todo.md to mark as completed: {completed_task}", - steps=[], - output="", - status_code=0 - ) - - # Send initial update - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - # Directly read and update the todo.md file - base_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) - planner_dir = os.path.join(base_dir, "agents", "planner") - todo_path = os.path.join(planner_dir, "todo.md") - - planner_stream_output.steps.append("Reading current todo.md...") - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) - - # Make sure the directory exists - os.makedirs(planner_dir, exist_ok=True) - + @orchestrator_agent.tool + async def web_surfer_task(ctx: RunContext[orchestrator_deps], task: str) -> str: + """Assigns web surfing tasks to the web surfer agent""" try: - # Check if todo.md exists - if not os.path.exists(todo_path): - planner_stream_output.steps.append("No todo.md file found. Will create new one after task completion.") - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + logfire.info(f"Assigning web surfing task: {task}") + + # Create a new StreamResponse for WebSurfer + web_surfer_stream_output = StreamResponse( + agent_name="Web Surfer", + instructions=task, + steps=[], + output="", + status_code=0, + live_url=None + ) + + # Add to orchestrator's response collection if available + if ctx.deps.agent_responses is not None: + ctx.deps.agent_responses.append(web_surfer_stream_output) + + await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output) + + # Initialize WebSurfer agent with model preference + web_surfer_agent = WebSurfer( + api_url="http://localhost:8000/api/v1/web/stream", + model_preference=ctx.deps.model_preference + ) + logfire.info(f"web_surfing on model type: {ctx.deps.model_preference}") - # We'll directly call planner_agent.run() to create a new plan first - plan_prompt = f"Create a simple task plan based on this completed task: {completed_task}" - plan_response = await planner_agent.run(user_prompt=plan_prompt) - current_content = plan_response.data.plan - else: - # Read existing todo.md - with open(todo_path, "r") as file: - current_content = file.read() - planner_stream_output.steps.append(f"Found existing todo.md ({len(current_content)} bytes)") - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + # Run WebSurfer with its own stream_output + success, message, messages = await web_surfer_agent.generate_reply( + instruction=task, + websocket=ctx.deps.websocket, + stream_output=web_surfer_stream_output + ) - # Now call planner_agent.run() with specific instructions to update the plan - update_prompt = f""" - Here is the current todo.md content: + # Update WebSurfer's stream_output with final result + if success: + web_surfer_stream_output.steps.append("Web search completed successfully") + web_surfer_stream_output.output = message + web_surfer_stream_output.status_code = 200 + + # Add a reminder to update the plan + message_with_reminder = f"{message}\n\nReminder: You must now call planner_agent_update with the completed task description: \"{task} (web_surfer_agent)\"" + else: + web_surfer_stream_output.steps.append(f"Web search completed with issues: {message[:100]}") + web_surfer_stream_output.status_code = 500 + message_with_reminder = message - {current_content} + await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output) - Please update this plan to mark the following task as completed: {completed_task} - Return ONLY the fully updated plan with appropriate tasks marked as [x] instead of [ ]. - """ + web_surfer_stream_output.steps.append(f"WebSurfer completed: {'Success' if success else 'Failed'}") + await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output) - planner_stream_output.steps.append("Asking planner to update the plan...") - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + return message_with_reminder + except Exception as e: + error_msg = f"Error assigning web surfing task: {str(e)}" + logfire.error(error_msg, exc_info=True) - updated_plan_response = await planner_agent.run(user_prompt=update_prompt) - updated_plan = updated_plan_response.data.plan + # Update WebSurfer's stream_output with error + web_surfer_stream_output.steps.append(f"Web search failed: {str(e)}") + web_surfer_stream_output.status_code = 500 + await _safe_websocket_send(ctx.deps.websocket, web_surfer_stream_output) + return f"Failed to assign web surfing task: {error_msg}" + + @orchestrator_agent.tool + async def ask_human(ctx: RunContext[orchestrator_deps], question: str) -> str: + """Sends a question to the frontend and waits for human input""" + try: + logfire.info(f"Asking human: {question}") - # Write the updated plan back to todo.md - with open(todo_path, "w") as file: - file.write(updated_plan) + # Create a new StreamResponse for Human Input + human_stream_output = StreamResponse( + agent_name="Human Input", + instructions=question, + steps=[], + output="", + status_code=0 + ) + + # Add to orchestrator's response collection if available + if ctx.deps.agent_responses is not None: + ctx.deps.agent_responses.append(human_stream_output) + + # Send the question to frontend + await _safe_websocket_send(ctx.deps.websocket, human_stream_output) - planner_stream_output.steps.append("Plan updated successfully") - planner_stream_output.output = updated_plan - planner_stream_output.status_code = 200 - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + # Update stream with waiting message + human_stream_output.steps.append("Waiting for human input...") + await _safe_websocket_send(ctx.deps.websocket, human_stream_output) - # Update orchestrator stream - if ctx.deps.stream_output: - ctx.deps.stream_output.steps.append(f"Plan updated to mark task as completed: {completed_task}") - await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output) + # Wait for response from frontend + response = await ctx.deps.websocket.receive_text() - return updated_plan + # Update stream with response + human_stream_output.steps.append("Received human input") + human_stream_output.output = response + human_stream_output.status_code = 200 + await _safe_websocket_send(ctx.deps.websocket, human_stream_output) + return response except Exception as e: - error_msg = f"Error during plan update operations: {str(e)}" + error_msg = f"Error getting human input: {str(e)}" logfire.error(error_msg, exc_info=True) - planner_stream_output.steps.append(f"Plan update failed: {str(e)}") - planner_stream_output.status_code = a500 - await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + # Update stream with error + human_stream_output.steps.append(f"Failed to get human input: {str(e)}") + human_stream_output.status_code = 500 + await _safe_websocket_send(ctx.deps.websocket, human_stream_output) - return f"Failed to update the plan: {error_msg}" - - except Exception as e: - error_msg = f"Error updating plan: {str(e)}" - logfire.error(error_msg, exc_info=True) + return f"Failed to get human input: {error_msg}" + + @orchestrator_agent.tool + async def planner_agent_update(ctx: RunContext[orchestrator_deps], completed_task: str) -> str: + """ + Updates the todo.md file to mark a task as completed and returns the full updated plan. - # Update stream output with error - if ctx.deps.stream_output: - ctx.deps.stream_output.steps.append(f"Failed to update plan: {str(e)}") - await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output) + Args: + completed_task: Description of the completed task including which agent performed it - return f"Failed to update plan: {error_msg}" + Returns: + The complete updated todo.md content with tasks marked as completed + """ + try: + logfire.info(f"Updating plan with completed task: {completed_task}") + + # Create a new StreamResponse for Planner Agent update + planner_stream_output = StreamResponse( + agent_name="Planner Agent", + instructions=f"Update todo.md to mark as completed: {completed_task}", + steps=[], + output="", + status_code=0 + ) + + # Send initial update + await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + + # Directly read and update the todo.md file + base_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + planner_dir = os.path.join(base_dir, "agents", "planner") + todo_path = os.path.join(planner_dir, "todo.md") + + planner_stream_output.steps.append("Reading current todo.md...") + await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + + # Make sure the directory exists + os.makedirs(planner_dir, exist_ok=True) + + try: + # Check if todo.md exists + if not os.path.exists(todo_path): + planner_stream_output.steps.append("No todo.md file found. Will create new one after task completion.") + await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + + # We'll directly call planner_agent.run() to create a new plan first + plan_prompt = f"Create a simple task plan based on this completed task: {completed_task}" + + agent = await planner_agent(model_preference=ctx.deps.model_preference) + plan_response = await agent.run(user_prompt=plan_prompt) + current_content = plan_response.data.plan + logfire.info(f"Updating plan task using model type: {ctx.deps.model_preference}") + + else: + # Read existing todo.md + with open(todo_path, "r") as file: + current_content = file.read() + planner_stream_output.steps.append(f"Found existing todo.md ({len(current_content)} bytes)") + await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + + # Now call planner_agent.run() with specific instructions to update the plan + update_prompt = f""" + Here is the current todo.md content: + + {current_content} + + Please update this plan to mark the following task as completed: {completed_task} + Return ONLY the fully updated plan with appropriate tasks marked as [x] instead of [ ]. + """ + + planner_stream_output.steps.append("Asking planner to update the plan...") + await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + + agent = await planner_agent(model_preference=ctx.deps.model_preference) + updated_plan_response = await agent.run(user_prompt=update_prompt) + updated_plan = updated_plan_response.data.plan + + + logfire.info(f"Updating prompt for plan task using model type: {ctx.deps.model_preference}") + + # Write the updated plan back to todo.md + with open(todo_path, "w") as file: + file.write(updated_plan) + + planner_stream_output.steps.append("Plan updated successfully") + planner_stream_output.output = updated_plan + planner_stream_output.status_code = 200 + await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + + # Update orchestrator stream + if ctx.deps.stream_output: + ctx.deps.stream_output.steps.append(f"Plan updated to mark task as completed: {completed_task}") + await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output) + + return updated_plan + + except Exception as e: + error_msg = f"Error during plan update operations: {str(e)}" + logfire.error(error_msg, exc_info=True) + + planner_stream_output.steps.append(f"Plan update failed: {str(e)}") + planner_stream_output.status_code = 500 + await _safe_websocket_send(ctx.deps.websocket, planner_stream_output) + + return f"Failed to update the plan: {error_msg}" + + except Exception as e: + error_msg = f"Error updating plan: {str(e)}" + logfire.error(error_msg, exc_info=True) + + # Update stream output with error + if ctx.deps.stream_output: + ctx.deps.stream_output.steps.append(f"Failed to update plan: {str(e)}") + await _safe_websocket_send(ctx.deps.websocket, ctx.deps.stream_output) + + return f"Failed to update plan: {error_msg}" + + + + logfire.info("All tools initialized for orchestrator agent") + return orchestrator_agent + + # Helper function for sending WebSocket messages async def _safe_websocket_send(websocket: Optional[WebSocket], message: Any) -> bool: diff --git a/cortex_on/agents/planner_agent.py b/cortex_on/agents/planner_agent.py index 897c22f..d607d04 100644 --- a/cortex_on/agents/planner_agent.py +++ b/cortex_on/agents/planner_agent.py @@ -10,7 +10,7 @@ from pydantic_ai.models.anthropic import AnthropicModel # Local application imports -from utils.ant_client import get_client +from utils.ant_client import get_client, get_anthropic_model_instance, get_openai_model_instance, get_openai_client @@ -176,172 +176,178 @@ class PlannerResult(BaseModel): plan: str = Field(description="The generated or updated plan in string format - this should be the complete plan text") -model = AnthropicModel( - model_name=os.environ.get("ANTHROPIC_MODEL_NAME"), - anthropic_client=get_client() -) +async def planner_agent(model_preference: str = "Anthropic") -> Agent: + if model_preference == "Anthropic": + model = get_anthropic_model_instance() + elif model_preference == "OpenAI": + model = get_openai_model_instance() + else: + raise ValueError(f"Unknown model_preference: {model_preference}") + planner_agent = Agent( + model=model, + name="Planner Agent", + result_type=PlannerResult, + system_prompt=planner_prompt + ) -planner_agent = Agent( - model=model, - name="Planner Agent", - result_type=PlannerResult, - system_prompt=planner_prompt -) - -@planner_agent.tool_plain -async def update_todo_status(task_description: str) -> str: - """ - A helper function that logs the update request but lets the planner agent handle the actual update logic. - - Args: - task_description: Description of the completed task + @planner_agent.tool_plain + async def update_todo_status(task_description: str) -> str: + """ + A helper function that logs the update request but lets the planner agent handle the actual update logic. - Returns: - A simple log message - """ - logfire.info(f"Received request to update todo.md for task: {task_description}") - return f"Received update request for: {task_description}" + Args: + task_description: Description of the completed task + + Returns: + A simple log message + """ + logfire.info(f"Received request to update todo.md for task: {task_description}") + return f"Received update request for: {task_description}" -@planner_agent.tool_plain -async def execute_terminal(command: str) -> str: - """ - Executes a terminal command within the planner directory for file operations. - This consolidated tool handles reading and writing plan files. - Restricted to only read and write operations for security. - """ - try: - logfire.info(f"Executing terminal command: {command}") - - # Define the restricted directory - base_dir = os.path.abspath(os.path.dirname(__file__)) - planner_dir = os.path.join(base_dir, "planner") - os.makedirs(planner_dir, exist_ok=True) - - # Extract the base command - base_command = command.split()[0] - - # Allow only read and write operations - ALLOWED_COMMANDS = {"cat", "echo", "ls"} - - # Security checks - if base_command not in ALLOWED_COMMANDS: - return f"Error: Command '{base_command}' is not allowed. Only read and write operations are permitted." - - if ".." in command or "~" in command or "/" in command: - return "Error: Path traversal attempts are not allowed." - - # Change to the restricted directory - original_dir = os.getcwd() - os.chdir(planner_dir) - + @planner_agent.tool_plain + async def execute_terminal(command: str) -> str: + """ + Executes a terminal command within the planner directory for file operations. + This consolidated tool handles reading and writing plan files. + Restricted to only read and write operations for security. + """ try: - # Handle echo with >> (append) - if base_command == "echo" and ">>" in command: - try: - # Split only on the first occurrence of >> - parts = command.split(">>", 1) - echo_part = parts[0].strip() + logfire.info(f"Executing terminal command: {command}") + + # Define the restricted directory + base_dir = os.path.abspath(os.path.dirname(__file__)) + planner_dir = os.path.join(base_dir, "planner") + os.makedirs(planner_dir, exist_ok=True) + + # Extract the base command + base_command = command.split()[0] + + # Allow only read and write operations + ALLOWED_COMMANDS = {"cat", "echo", "ls"} + + # Security checks + if base_command not in ALLOWED_COMMANDS: + return f"Error: Command '{base_command}' is not allowed. Only read and write operations are permitted." + + if ".." in command or "~" in command or "/" in command: + return "Error: Path traversal attempts are not allowed." + + # Change to the restricted directory + original_dir = os.getcwd() + os.chdir(planner_dir) + + try: + # Handle echo with >> (append) + if base_command == "echo" and ">>" in command: + try: + # Split only on the first occurrence of >> + parts = command.split(">>", 1) + echo_part = parts[0].strip() + file_path = parts[1].strip() + + # Extract content after echo command + content = echo_part[4:].strip() + + # Handle quotes if present + if (content.startswith('"') and content.endswith('"')) or \ + (content.startswith("'") and content.endswith("'")): + content = content[1:-1] + + # Append to file + with open(file_path, "a") as file: + file.write(content + "\n") + return f"Successfully appended to {file_path}" + except Exception as e: + logfire.error(f"Error appending to file: {str(e)}", exc_info=True) + return f"Error appending to file: {str(e)}" + + # Special handling for echo with redirection (file writing) + elif ">" in command and base_command == "echo" and ">>" not in command: + # Simple parsing for echo "content" > file.txt + parts = command.split(">", 1) + echo_cmd = parts[0].strip() file_path = parts[1].strip() - # Extract content after echo command - content = echo_part[4:].strip() - - # Handle quotes if present + # Extract content between echo and > (removing quotes if present) + content = echo_cmd[5:].strip() if (content.startswith('"') and content.endswith('"')) or \ - (content.startswith("'") and content.endswith("'")): + (content.startswith("'") and content.endswith("'")): content = content[1:-1] - # Append to file - with open(file_path, "a") as file: - file.write(content + "\n") - return f"Successfully appended to {file_path}" - except Exception as e: - logfire.error(f"Error appending to file: {str(e)}", exc_info=True) - return f"Error appending to file: {str(e)}" - - # Special handling for echo with redirection (file writing) - elif ">" in command and base_command == "echo" and ">>" not in command: - # Simple parsing for echo "content" > file.txt - parts = command.split(">", 1) - echo_cmd = parts[0].strip() - file_path = parts[1].strip() + # Write to file + try: + with open(file_path, "w") as file: + file.write(content) + return f"Successfully wrote to {file_path}" + except Exception as e: + logfire.error(f"Error writing to file: {str(e)}", exc_info=True) + return f"Error writing to file: {str(e)}" - # Extract content between echo and > (removing quotes if present) - content = echo_cmd[5:].strip() - if (content.startswith('"') and content.endswith('"')) or \ - (content.startswith("'") and content.endswith("'")): - content = content[1:-1] - - # Write to file - try: - with open(file_path, "w") as file: - file.write(content) - return f"Successfully wrote to {file_path}" - except Exception as e: - logfire.error(f"Error writing to file: {str(e)}", exc_info=True) - return f"Error writing to file: {str(e)}" - - # Handle cat with here-document for multiline file writing - elif "<<" in command and base_command == "cat": - try: - # Parse the command: cat > file.md << 'EOF'\nplan content\nEOF - cmd_parts = command.split("<<", 1) - cat_part = cmd_parts[0].strip() - doc_part = cmd_parts[1].strip() - - # Extract filename - if ">" in cat_part: - file_path = cat_part.split(">", 1)[1].strip() - else: - return "Error: Invalid cat command format. Must include redirection." - - # Parse the heredoc content - if "\n" in doc_part: - delimiter_and_content = doc_part.split("\n", 1) - delimiter = delimiter_and_content[0].strip("'").strip('"') - content = delimiter_and_content[1] + # Handle cat with here-document for multiline file writing + elif "<<" in command and base_command == "cat": + try: + # Parse the command: cat > file.md << 'EOF'\nplan content\nEOF + cmd_parts = command.split("<<", 1) + cat_part = cmd_parts[0].strip() + doc_part = cmd_parts[1].strip() - # Find the end delimiter and extract content - if f"\n{delimiter}" in content: - content = content.split(f"\n{delimiter}")[0] + # Extract filename + if ">" in cat_part: + file_path = cat_part.split(">", 1)[1].strip() + else: + return "Error: Invalid cat command format. Must include redirection." + + # Parse the heredoc content + if "\n" in doc_part: + delimiter_and_content = doc_part.split("\n", 1) + delimiter = delimiter_and_content[0].strip("'").strip('"') + content = delimiter_and_content[1] - # Write to file - with open(file_path, "w") as file: - file.write(content) - return f"Successfully wrote multiline content to {file_path}" + # Find the end delimiter and extract content + if f"\n{delimiter}" in content: + content = content.split(f"\n{delimiter}")[0] + + # Write to file + with open(file_path, "w") as file: + file.write(content) + return f"Successfully wrote multiline content to {file_path}" + else: + return "Error: End delimiter not found in heredoc" else: - return "Error: End delimiter not found in heredoc" - else: - return "Error: Invalid heredoc format" - except Exception as e: - logfire.error(f"Error processing cat with heredoc: {str(e)}", exc_info=True) - return f"Error processing cat with heredoc: {str(e)}" - - # Handle cat for reading files - elif base_command == "cat" and ">" not in command and "<<" not in command: - try: - file_path = command.split()[1] - with open(file_path, "r") as file: - content = file.read() - return content - except Exception as e: - logfire.error(f"Error reading file: {str(e)}", exc_info=True) - return f"Error reading file: {str(e)}" - - # Handle ls for listing files - elif base_command == "ls": - try: - files = os.listdir('.') - return "Files in planner directory:\n" + "\n".join(files) - except Exception as e: - logfire.error(f"Error listing files: {str(e)}", exc_info=True) - return f"Error listing files: {str(e)}" - else: - return f"Error: Command '{command}' is not supported. Only read and write operations are permitted." - - finally: - os.chdir(original_dir) - - except Exception as e: - logfire.error(f"Error executing command: {str(e)}", exc_info=True) - return f"Error executing command: {str(e)}" \ No newline at end of file + return "Error: Invalid heredoc format" + except Exception as e: + logfire.error(f"Error processing cat with heredoc: {str(e)}", exc_info=True) + return f"Error processing cat with heredoc: {str(e)}" + + # Handle cat for reading files + elif base_command == "cat" and ">" not in command and "<<" not in command: + try: + file_path = command.split()[1] + with open(file_path, "r") as file: + content = file.read() + return content + except Exception as e: + logfire.error(f"Error reading file: {str(e)}", exc_info=True) + return f"Error reading file: {str(e)}" + + # Handle ls for listing files + elif base_command == "ls": + try: + files = os.listdir('.') + return "Files in planner directory:\n" + "\n".join(files) + except Exception as e: + logfire.error(f"Error listing files: {str(e)}", exc_info=True) + return f"Error listing files: {str(e)}" + else: + return f"Error: Command '{command}' is not supported. Only read and write operations are permitted." + + finally: + os.chdir(original_dir) + + except Exception as e: + logfire.error(f"Error executing command: {str(e)}", exc_info=True) + return f"Error executing command: {str(e)}" + + + logfire.info("All tools initialized for planner agent") + return planner_agent \ No newline at end of file diff --git a/cortex_on/agents/web_surfer.py b/cortex_on/agents/web_surfer.py index 34e2cfd..2711348 100644 --- a/cortex_on/agents/web_surfer.py +++ b/cortex_on/agents/web_surfer.py @@ -29,19 +29,25 @@ TIMEOUT = 9999999999999999999999999999999999999999999 class WebSurfer: - def __init__(self, api_url: str = "http://localhost:8000/api/v1/web/stream"): + def __init__(self, api_url: str = "http://localhost:8000/api/v1/web/stream", model_preference: str = "Anthropic"): self.api_url = api_url self.name = "Web Surfer Agent" self.description = "An agent that is a websurfer and a webscraper that can access any web-page to extract information or perform actions." self.websocket: Optional[WebSocket] = None self.stream_output: Optional[StreamResponse] = None + self.model_preference = model_preference async def _make_api_call(self, instruction: str) -> Tuple[int, List[Dict[str, Any]]]: session_timeout = aiohttp.ClientTimeout(total=None, sock_connect=TIMEOUT, sock_read=TIMEOUT) async with aiohttp.ClientSession(timeout=session_timeout) as session: final_json_response = [] try: - payload = {"cmd": instruction, "critique_disabled": False} + payload = { + "cmd": instruction, + "critique_disabled": False, + "model_preference": self.model_preference + } + logfire.info(f"Making API call with model_preference: {self.model_preference}") async with session.post(self.api_url, json=payload) as response: if response.status != 200: error_text = await response.text() diff --git a/cortex_on/instructor.py b/cortex_on/instructor.py index b4f0efb..bf9670f 100644 --- a/cortex_on/instructor.py +++ b/cortex_on/instructor.py @@ -25,9 +25,6 @@ load_dotenv() - - - class DateTimeEncoder(json.JSONEncoder): """Custom JSON encoder that can handle datetime objects""" def default(self, obj): @@ -38,10 +35,12 @@ def default(self, obj): # Main Orchestrator Class class SystemInstructor: - def __init__(self): + def __init__(self, model_preference: str = "Anthropic"): + logfire.info(f"Initializing SystemInstructor with model_preference: {model_preference}") self.websocket: Optional[WebSocket] = None self.stream_output: Optional[StreamResponse] = None self.orchestrator_response: List[StreamResponse] = [] + self.model_preference = model_preference self._setup_logging() def _setup_logging(self) -> None: @@ -80,7 +79,8 @@ async def run(self, task: str, websocket: WebSocket) -> List[Dict[str, Any]]: deps_for_orchestrator = orchestrator_deps( websocket=self.websocket, stream_output=stream_output, - agent_responses=self.orchestrator_response # Pass reference to collection + agent_responses=self.orchestrator_response, # Pass reference to collection + model_preference=self.model_preference ) try: @@ -88,11 +88,15 @@ async def run(self, task: str, websocket: WebSocket) -> List[Dict[str, Any]]: await self._safe_websocket_send(stream_output) stream_output.steps.append("Agents initialized successfully") await self._safe_websocket_send(stream_output) - - orchestrator_response = await orchestrator_agent.run( + + agent = await orchestrator_agent(self.model_preference) + orchestrator_response = await agent.run( user_prompt=task, deps=deps_for_orchestrator ) + + logfire.info(f"Orchestrator Agent using model type: {self.model_preference}") + stream_output.output = orchestrator_response.data stream_output.status_code = 200 logfire.debug(f"Orchestrator response: {orchestrator_response.data}") @@ -116,15 +120,11 @@ async def run(self, task: str, websocket: WebSocket) -> List[Dict[str, Any]]: finally: logfire.info("Orchestration process complete") - # Clear any sensitive data + + async def shutdown(self): """Clean shutdown of orchestrator""" try: - # Close websocket if open - if self.websocket: - await self.websocket.close() - - # Clear all responses self.orchestrator_response = [] logfire.info("Orchestrator shutdown complete") diff --git a/cortex_on/main.py b/cortex_on/main.py index a8dd4de..44b840b 100644 --- a/cortex_on/main.py +++ b/cortex_on/main.py @@ -1,27 +1,112 @@ # Standard library imports from typing import List, Optional +from contextlib import asynccontextmanager # Third-party imports -from fastapi import FastAPI, WebSocket +from fastapi import FastAPI, WebSocket, WebSocketDisconnect ,Depends +from fastapi.middleware.cors import CORSMiddleware +import logfire + +# Configure Logfire +logfire.configure() # Local application imports from instructor import SystemInstructor +# Default model preference is Anthropic +MODEL_PREFERENCE = "Anthropic" + +# Global instructor instance +instructor = None + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Set default model preference at startup + app.state.model_preference = MODEL_PREFERENCE + logfire.info(f"Setting default model preference to: {MODEL_PREFERENCE}") + + # Initialize the instructor + global instructor + instructor = SystemInstructor(model_preference=MODEL_PREFERENCE) + + yield + + + +app: FastAPI = FastAPI(lifespan=lifespan) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Allow all origins for development + allow_credentials=True, + allow_methods=["*"], # Allow all methods + allow_headers=["*"], # Allow all headers +) + +async def get_model_preference() -> str: + """ + Get the current model preference from app state + """ + logfire.info(f"Current model preference: {app.state.model_preference}") + return app.state.model_preference -app: FastAPI = FastAPI() +@app.get("/set_model_preference") +async def set_model_preference(model: str): + """ + Set the model preference (Anthropic or OpenAI) and reinitialize the instructor + """ + if model not in ["Anthropic", "OpenAI"]: + logfire.error(f"Invalid model preference attempted: {model}") + return {"error": "Model must be 'Anthropic' or 'OpenAI'"} + + logfire.info(f"Changing model preference from {app.state.model_preference} to {model}") + app.state.model_preference = model + + # Reinitialize the instructor with new model preference + global instructor + if instructor: + await instructor.shutdown() + instructor = SystemInstructor(model_preference=model) + logfire.info(f"Instructor reinitialized with model preference: {model}") + + return {"message": f"Model preference set to {model}"} -async def generate_response(task: str, websocket: Optional[WebSocket] = None): - orchestrator: SystemInstructor = SystemInstructor() - return await orchestrator.run(task, websocket) +async def generate_response(task: str, websocket: Optional[WebSocket] = None, model_preference: str = None): + if model_preference is None: + model_preference = app.state.model_preference + logfire.info(f"Using model preference: {model_preference} for task: {task[:30]}...") + + global instructor + if not instructor: + instructor = SystemInstructor(model_preference=model_preference) + + return await instructor.run(task, websocket) @app.get("/agent/chat") -async def agent_chat(task: str) -> List: - final_agent_response = await generate_response(task) +async def agent_chat(task: str, model_preference: str = Depends(get_model_preference)) -> List: + logfire.info(f"Received chat request with model preference: {model_preference}") + final_agent_response = await generate_response(task, model_preference=model_preference) return final_agent_response @app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - while True: - data = await websocket.receive_text() - await generate_response(data, websocket) + model_preference = app.state.model_preference + logfire.info(f"New connection using model preference: {model_preference}") + try: + while True: + try: + data = await websocket.receive_text() + logfire.info(f"Received message, using model: {model_preference}") + await generate_response(data, websocket, model_preference) + except WebSocketDisconnect: + logfire.info("[WEBSOCKET] Client disconnected") + break + except Exception as e: + logfire.error(f"[WEBSOCKET] Error handling message: {str(e)}") + if "disconnect message has been received" in str(e): + print(f"[WEBSOCKET] DIsconnect detected, closing connection: {str(e)}") + break + except Exception as e: + logfire.error(f"[WEBSOCKET] Connection error: {str(e)}") diff --git a/cortex_on/utils/ant_client.py b/cortex_on/utils/ant_client.py index 924f0c6..c04732c 100644 --- a/cortex_on/utils/ant_client.py +++ b/cortex_on/utils/ant_client.py @@ -1,5 +1,6 @@ from pydantic_ai.models.anthropic import AnthropicModel from anthropic import AsyncAnthropic +from openai import AsyncOpenAI import os from dotenv import load_dotenv @@ -12,3 +13,26 @@ def get_client(): max_retries=3, timeout=10000) return client + + + +def get_openai_client(): + api_key = os.getenv("OPENAI_API_KEY") + client = AsyncOpenAI(api_key=api_key, + max_retries=3, + timeout=10000) + return client + +def get_openai_model_instance(): + model_name = os.getenv("OPENAI_MODEL_NAME") + # Create model instance + from pydantic_ai.models.openai import OpenAIModel + model_instance = OpenAIModel(model_name=model_name, openai_client=get_openai_client()) + return model_instance + +def get_anthropic_model_instance(): + model_name = os.getenv("ANTHROPIC_MODEL_NAME") + # Create model instance + from pydantic_ai.models.anthropic import AnthropicModel + model_instance = AnthropicModel(model_name=model_name, anthropic_client=get_client()) + return model_instance diff --git a/frontend/src/components/home/ChatList.tsx b/frontend/src/components/home/ChatList.tsx index 9d8cb21..50fba0a 100644 --- a/frontend/src/components/home/ChatList.tsx +++ b/frontend/src/components/home/ChatList.tsx @@ -94,9 +94,30 @@ const ChatList = ({isLoading, setIsLoading}: ChatListPageProps) => { retryOnError: true, shouldReconnect: () => true, reconnectInterval: 3000, + share: true, } ); + useEffect(() => { + if (readyState === ReadyState.CONNECTING) { + if (messages.length > 0) { + setIsLoading(true); + } + } else if (readyState === ReadyState.OPEN) { + setIsLoading(false); + } + }, [readyState, messages.length]); + + useEffect(() => { + if (messages.length === 0) { + setLiveUrl(""); + setCurrentOutput(null); + setOutputsList([]); + setIsLoading(false); + } + }, [messages.length]); + + const scrollToBottom = (smooth = true) => { if (scrollAreaRef.current) { const scrollableDiv = scrollAreaRef.current.querySelector( @@ -557,7 +578,7 @@ const ChatList = ({isLoading, setIsLoading}: ChatListPageProps) => { className="space-y-2 animate-fade-in animate-once animate-duration-500" key={idx} > - {readyState === ReadyState.CONNECTING ? ( + {readyState === ReadyState.CONNECTING && messages.length > 0 ? ( <> diff --git a/frontend/src/components/home/Header.tsx b/frontend/src/components/home/Header.tsx index d6217fd..4d142fb 100644 --- a/frontend/src/components/home/Header.tsx +++ b/frontend/src/components/home/Header.tsx @@ -4,6 +4,7 @@ import {useDispatch} from "react-redux"; import {useLocation, useNavigate} from "react-router-dom"; import Logo from "../../assets/CortexON_logo_dark.svg"; import {Button} from "../ui/button"; +import ModelToggle from "../ui/ModelToggle"; const Header = () => { const nav = useNavigate(); @@ -12,19 +13,20 @@ const Header = () => { return (
-
{ - dispatch(setMessages([])); - nav("/"); - }} - > - Logo -
-
+
+
{ + dispatch(setMessages([])); + nav("/"); + }} + > + Logo +
+
nav("/vault")} - className={`w-[10%] h-full flex justify-center items-center cursor-pointer border-b-2 hover:border-[#BD24CA] ${ + className={`h-full flex justify-center items-center cursor-pointer border-b-2 -mb-[27px] px-4 pb-7 hover:border-[#BD24CA] ${ location.includes("/vault") ? "border-[#BD24CA]" : "border-background" @@ -33,17 +35,25 @@ const Header = () => {

Vault

- + +
+ {/* Model Toggle Button - Smaller size */} +
+ +
+ + +
); }; diff --git a/frontend/src/components/ui/ModelToggle.tsx b/frontend/src/components/ui/ModelToggle.tsx new file mode 100644 index 0000000..ef97d77 --- /dev/null +++ b/frontend/src/components/ui/ModelToggle.tsx @@ -0,0 +1,61 @@ +import { useDispatch, useSelector } from "react-redux"; +import { RootState } from "@/dataStore/store"; +import { setModelPreference } from "@/dataStore/modelPreferenceSlice"; +import { useSetModelPreferenceMutation } from "@/services/modelPreferenceApi"; +import { useEffect, useState } from "react"; + +export default function ModelToggle() { + const dispatch = useDispatch(); + const modelPreference = useSelector((state: RootState) => state.modelPreference.preference); + const [setModelPreferenceApi] = useSetModelPreferenceMutation(); + const [isAnthropicSelected, setIsAnthropicSelected] = useState(modelPreference === "Anthropic"); + + const toggleModel = async () => { + const newPreference = isAnthropicSelected ? "OpenAI" : "Anthropic"; + setIsAnthropicSelected(!isAnthropicSelected); + + try { + // Update UI immediately + dispatch(setModelPreference(newPreference)); + + // Call API to update backend + await setModelPreferenceApi(newPreference).unwrap(); + } catch (error) { + console.error("Failed to update model preference:", error); + // Revert UI on error + setIsAnthropicSelected(isAnthropicSelected); + dispatch(setModelPreference(isAnthropicSelected ? "Anthropic" : "OpenAI")); + } + }; + + // Ensure UI reflects the redux state + useEffect(() => { + setIsAnthropicSelected(modelPreference === "Anthropic"); + }, [modelPreference]); + + return ( +
+ + Claude + + + {/* Simple toggle switch built with CSS */} + + + + GPT-4o + +
+ ); +} \ No newline at end of file diff --git a/frontend/src/dataStore/modelPreferenceSlice.ts b/frontend/src/dataStore/modelPreferenceSlice.ts new file mode 100644 index 0000000..7c1d602 --- /dev/null +++ b/frontend/src/dataStore/modelPreferenceSlice.ts @@ -0,0 +1,22 @@ +import { createSlice, PayloadAction } from "@reduxjs/toolkit"; + +interface ModelPreferenceState { + preference: "Anthropic" | "OpenAI"; +} + +const initialState: ModelPreferenceState = { + preference: "Anthropic", +}; + +const modelPreferenceSlice = createSlice({ + name: "modelPreference", + initialState, + reducers: { + setModelPreference: (state, action: PayloadAction<"Anthropic" | "OpenAI">) => { + state.preference = action.payload; + }, + }, +}); + +export const { setModelPreference } = modelPreferenceSlice.actions; +export default modelPreferenceSlice; \ No newline at end of file diff --git a/frontend/src/dataStore/store.ts b/frontend/src/dataStore/store.ts index c3be4e6..516535f 100644 --- a/frontend/src/dataStore/store.ts +++ b/frontend/src/dataStore/store.ts @@ -1,13 +1,18 @@ import vaultApi from "@/services/vaultApi"; +import modelPreferenceApi from "@/services/modelPreferenceApi"; import {configureStore} from "@reduxjs/toolkit"; import messagesSlice from "./messagesSlice"; +import modelPreferenceSlice from "./modelPreferenceSlice"; + export const store = configureStore({ reducer: { [messagesSlice.name]: messagesSlice.reducer, [vaultApi.reducerPath]: vaultApi.reducer, + [modelPreferenceSlice.name]: modelPreferenceSlice.reducer, + [modelPreferenceApi.reducerPath]: modelPreferenceApi.reducer, }, middleware: (getDefaultMiddleware) => - getDefaultMiddleware().concat(vaultApi.middleware), + getDefaultMiddleware().concat(vaultApi.middleware, modelPreferenceApi.middleware), }); export type RootState = ReturnType; diff --git a/frontend/src/lib/utils.ts b/frontend/src/lib/utils.ts index dfe0774..07ab4f7 100644 --- a/frontend/src/lib/utils.ts +++ b/frontend/src/lib/utils.ts @@ -1,5 +1,5 @@ -import {clsx, type ClassValue} from "clsx"; -import {twMerge} from "tailwind-merge"; +import { type ClassValue, clsx } from "clsx"; +import { twMerge } from "tailwind-merge"; export function cn(...inputs: ClassValue[]) { return twMerge(clsx(inputs)); diff --git a/frontend/src/services/modelPreferenceApi.ts b/frontend/src/services/modelPreferenceApi.ts new file mode 100644 index 0000000..e7320a6 --- /dev/null +++ b/frontend/src/services/modelPreferenceApi.ts @@ -0,0 +1,23 @@ +import { createApi } from "@reduxjs/toolkit/query/react"; +import { fetchBaseQuery } from "@reduxjs/toolkit/query/react"; + +// Get the base URL from environment variable or use a default +const CORTEX_ON_API_URL = import.meta.env.VITE_CORTEX_ON_API_URL || "http://localhost:8081"; + +const modelPreferenceApi = createApi({ + reducerPath: "modelPreferenceApi", + baseQuery: fetchBaseQuery({ + baseUrl: CORTEX_ON_API_URL, + }), + endpoints: (builder) => ({ + setModelPreference: builder.mutation<{ message: string }, string>({ + query: (model) => ({ + url: `/set_model_preference?model=${model}`, + method: "GET", + }), + }), + }), +}); + +export const { useSetModelPreferenceMutation } = modelPreferenceApi; +export default modelPreferenceApi; \ No newline at end of file diff --git a/ta-browser/core/orchestrator.py b/ta-browser/core/orchestrator.py index 9dbf16e..4d56511 100644 --- a/ta-browser/core/orchestrator.py +++ b/ta-browser/core/orchestrator.py @@ -244,27 +244,35 @@ def log_token_usage(self, agent_type: str, usage: Usage, step: Optional[int] = N """ ) - async def async_init(self, job_id: str, start_url: str = "https://google.com"): + async def async_init(self, job_id: str, start_url: str = "https://google.com", model_preference: str = "Anthropic"): """Initialize a new session context with improved error handling""" try: logger.info("Initializing browser session", extra={ "job_id": job_id, - "start_url": start_url + "start_url": start_url, + "model_preference": model_preference }) - - # 1. Initialize job_id and validate + logger.info(f"Starting async_init with model_preference: {model_preference}") + + # Store and validate parameters + self.job_id = job_id + if not job_id: - raise ValueError("job_id is required for initialization") - self.job_id = str(job_id) - logger.debug(f"job_id: {self.job_id}") + raise ValueError("job_id is required") - # 2. Initialize conversation storage + # 1. Create conversation_storage with job_id try: self.conversation_storage = ConversationStorage(job_id=self.job_id) except Exception as storage_error: raise RuntimeError(f"Failed to initialize conversation storage: {str(storage_error)}") from storage_error - - # 3. Set and validate URL + + # 2. Initialize browser manager + try: + await self.initialize_browser_manager(start_url) + except Exception as browser_error: + raise RuntimeError(f"Failed to initialize browser: {str(browser_error)}") from browser_error + + # 3. Set current URL and domain try: self.current_url = start_url self.current_domain = extract_domain(self.current_url) @@ -273,23 +281,16 @@ async def async_init(self, job_id: str, start_url: str = "https://google.com"): except InvalidURLError as url_error: raise ValueError(f"Invalid URL provided: {str(url_error)}") from url_error - # 4. Initialize browser manager with start_url - if not self.browser_manager: - try: - self.browser_manager = await self.initialize_browser_manager(start_url=start_url) - if not self.browser_manager: - raise RuntimeError("Browser manager initialization failed") - except Exception as browser_error: - raise RuntimeError(f"Failed to initialize browser manager: {str(browser_error)}") from browser_error - - # 5. Initialize client and agents + # 4. Initialize client and agents try: - from core.utils.init_client import initialize_client - self.client, model_instance = await initialize_client() + self.client, model_instance = await initialize_client(model_preference) self.initialize_agents(model_instance) + logger.info(f"Agents initialized successfully with model_preference: {model_preference}") except Exception as agent_error: - raise RuntimeError(f"Failed to initialize client and agents: {str(agent_error)}") from agent_error + error_msg = f"Failed to initialize client and agents: {str(agent_error)}" + logger.error(f"{error_msg}") + raise RuntimeError(error_msg) from agent_error self.async_init_done = True logger.debug("Async initialization completed successfully") diff --git a/ta-browser/core/server/models/web.py b/ta-browser/core/server/models/web.py index bf9e9fc..ca397f0 100644 --- a/ta-browser/core/server/models/web.py +++ b/ta-browser/core/server/models/web.py @@ -7,6 +7,7 @@ class StreamRequestModel(BaseModel): cmd: str = Field(..., description="Command to execute") url: str = Field("https://google.com", description="URL to navigate to") critique_disabled: bool = Field(False, description="Whether to disable critique") + model_preference: str = Field("Anthropic", description="Model provider preference (Anthropic or OpenAI)") @validator('url') def validate_and_format_url(cls, v): diff --git a/ta-browser/core/server/routes/web.py b/ta-browser/core/server/routes/web.py index a5cee6d..1be01af 100644 --- a/ta-browser/core/server/routes/web.py +++ b/ta-browser/core/server/routes/web.py @@ -1,17 +1,19 @@ import asyncio -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, WebSocket from fastapi.responses import StreamingResponse from typing import AsyncGenerator import json import time from datetime import datetime from queue import Empty +import os from core.server.models.web import StreamRequestModel, StreamResponseModel from core.server.constants import GLOBAL_TIMEOUT from core.server.utils.timeout import timeout from core.server.utils.session_tracker import SessionTracker from core.utils.logger import Logger +from core.utils.init_client import initialize_client logger = Logger() @@ -115,6 +117,7 @@ async def stream_session( and streams back real-time updates as the command is executed. """ print(f"[{time.time()}] Stream route: Starting processing") + print(f"[STREAM] Received request with model_preference: {request.model_preference}") session_tracker = SessionTracker() @@ -124,8 +127,12 @@ async def stream_session( # Initialize the session logger.debug(f"Initializing stream session with ID {session_id}") + print(f"[STREAM] Initializing session with model_preference: {request.model_preference}") session_info = await session_tracker.initialize_session( - request.url, request.critique_disabled, session_id + request.url, + request.critique_disabled, + session_id, + request.model_preference ) session_context = session_tracker.active_sessions.get(session_id) @@ -145,4 +152,5 @@ async def stream_session( except Exception as e: logger.error(f"Error in /stream: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file + raise HTTPException(status_code=500, detail=str(e)) + diff --git a/ta-browser/core/server/utils/session_tracker.py b/ta-browser/core/server/utils/session_tracker.py index acf7168..fb3fa13 100644 --- a/ta-browser/core/server/utils/session_tracker.py +++ b/ta-browser/core/server/utils/session_tracker.py @@ -28,17 +28,19 @@ def get_active_sessions_status(self) -> dict: "sessions": list(self.active_sessions.keys()) } - async def initialize_session(self, start_url: str, no_crit: bool, session_id: str) -> Dict[str, Any]: + async def initialize_session(self, start_url: str, no_crit: bool, session_id: str, model_preference: str = "Anthropic") -> Dict[str, Any]: """Initialize a new session context""" logger.debug(f"Starting session initialization with URL: {start_url}") + print(f"[SESSION_TRACKER] Initializing session with model_preference: {model_preference}") logger.set_job_id(session_id) orchestrator = None try: orchestrator = Orchestrator(input_mode="API", no_crit=no_crit) - await orchestrator.async_init(job_id=session_id, start_url=start_url) - logger.debug(f"Orchestrator async_init completed with params: job_id={session_id}, start_url={start_url}") + print(f"[SESSION_TRACKER] Passing model_preference: {model_preference} to orchestrator.async_init") + await orchestrator.async_init(job_id=session_id, start_url=start_url, model_preference=model_preference) + logger.debug(f"Orchestrator async_init completed with params: job_id={session_id}, start_url={start_url}, model_preference={model_preference}") notification_queue = Queue() orchestrator.notification_queue = notification_queue @@ -49,7 +51,8 @@ async def initialize_session(self, start_url: str, no_crit: bool, session_id: st "notification_queue": notification_queue, "start_time": datetime.now(), "current_url": start_url, - "include_screenshot": False + "include_screenshot": False, + "model_preference": model_preference } self.add_active_session(session_id, session_context) diff --git a/ta-browser/core/utils/init_client.py b/ta-browser/core/utils/init_client.py index 7d170c6..5be2ea3 100644 --- a/ta-browser/core/utils/init_client.py +++ b/ta-browser/core/utils/init_client.py @@ -1,44 +1,81 @@ import os from core.utils.anthropic_client import create_client_with_retry as create_anthropic_client, AsyncAnthropic from core.utils.logger import Logger +from core.utils.openai_client import get_client logger = Logger() -async def initialize_client(): +async def initialize_client(model_preference: str = "Anthropic"): """ Initialize and return the Anthropic client and model instance + Args: + model_preference (str): The model provider to use ("Anthropic" or "OpenAI") + Returns: tuple: (client_instance, model_instance) """ try: - # Get API key from environment variable - api_key = os.getenv("ANTHROPIC_API_KEY") - if not api_key: - logger.error("ANTHROPIC_API_KEY not found in environment variables") - raise ValueError("ANTHROPIC_API_KEY not found in environment variables") - - # Set model name - Claude 3.5 Sonnet - model_name = os.getenv("ANTHROPIC_MODEL_NAME") - - # Create client config - config = { - "api_key": api_key, - "model": model_name, - "max_retries": 3, - "timeout": 300.0 - } - - # Initialize client - client_instance = create_anthropic_client(AsyncAnthropic, config) - - # Create model instance - from pydantic_ai.models.anthropic import AnthropicModel - model_instance = AnthropicModel(model_name=model_name, anthropic_client=client_instance) - - logger.info(f"Anthropic client initialized successfully with model: {model_name}") - return client_instance, model_instance - + logger.info(f"Initializing client with model preference: {model_preference}") + if model_preference == "Anthropic": + # Get API key from environment variable + api_key = os.getenv("ANTHROPIC_API_KEY") + if not api_key: + logger.error("ANTHROPIC_API_KEY not found in environment variables") + raise ValueError("ANTHROPIC_API_KEY not found in environment variables") + + # Set model name - Claude 3.5 Sonnet + model_name = os.getenv("ANTHROPIC_MODEL_NAME") + + # Create client config + config = { + "api_key": api_key, + "model": model_name, + "max_retries": 3, + "timeout": 300.0 + } + + # Initialize client + client_instance = create_anthropic_client(AsyncAnthropic, config) + + # Create model instance + from pydantic_ai.models.anthropic import AnthropicModel + model_instance = AnthropicModel(model_name=model_name, anthropic_client=client_instance) + + logger.info(f"Anthropic client initialized successfully with model: {model_name}") + return client_instance, model_instance + elif model_preference == "OpenAI": + # Get API key from environment variable + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + logger.error("OPENAI_API_KEY not found in environment variables") + raise ValueError("OPENAI_API_KEY not found in environment variables") + + # Set model name - GPT-4o + model_name = os.getenv("OPENAI_MODEL_NAME") + + # Create client config + config = { + "api_key": api_key, + "model": model_name, + "max_retries": 3, + "timeout": 300.0 + } + + # Initialize client + client_instance = get_client() + + # Create model instance + from pydantic_ai.models.openai import OpenAIModel + model_instance = OpenAIModel(model_name=model_name, openai_client=client_instance) + + logger.info(f"OpenAI client initialized successfully with model: {model_name}") + return client_instance, model_instance + else: + error_msg = f"Invalid model preference: {model_preference}. Must be 'Anthropic' or 'OpenAI'" + raise ValueError(error_msg) + except Exception as e: - logger.error(f"Error initializing Anthropic client: {str(e)}", exc_info=True) + error_msg = f"Error initializing client: {str(e)}" + logger.error(error_msg, exc_info=True) raise \ No newline at end of file diff --git a/ta-browser/core/utils/openai_client.py b/ta-browser/core/utils/openai_client.py index 80670e3..5c74992 100644 --- a/ta-browser/core/utils/openai_client.py +++ b/ta-browser/core/utils/openai_client.py @@ -28,57 +28,15 @@ def validate_model(model: str) -> bool: return True @staticmethod def get_text_config() -> Dict: - model = get_env_var("AGENTIC_BROWSER_TEXT_MODEL") - if not OpenAIConfig.validate_model(model): - raise ModelValidationError( - f"Invalid model: {model}. Must match one of the patterns: " - f"{', '.join(OpenAIConfig.VALID_MODEL_PATTERNS)}" - ) - return { - "api_key": get_env_var("AGENTIC_BROWSER_TEXT_API_KEY"), - "base_url": get_env_var("AGENTIC_BROWSER_TEXT_BASE_URL"), - "model": model, + "api_key": get_env_var("OPENAI_API_KEY"), + "base_url": "https://api.openai.com/v1", + "model": get_env_var("OPENAI_MODEL_NAME"), "max_retries": 3, "timeout": 300.0 } - @staticmethod - def get_ss_config() -> Dict: - model = get_env_var("AGENTIC_BROWSER_SS_MODEL") - if not OpenAIConfig.validate_model(model): - raise ModelValidationError( - f"Invalid model: {model}. Must match one of the patterns: " - f"{', '.join(OpenAIConfig.VALID_MODEL_PATTERNS)}" - ) - - return { - "api_key": get_env_var("AGENTIC_BROWSER_SS_API_KEY"), - "base_url": get_env_var("AGENTIC_BROWSER_SS_BASE_URL"), - "model": model, - "max_retries": 3, - "timeout": 300.0 - } -async def validate_models(client: AsyncOpenAI) -> bool: - """Validate that configured models are available""" - try: - available_models = await client.models.list() - available_model_ids = [model.id for model in available_models.data] - - text_model = get_text_model() - ss_model = get_ss_model() - - if text_model not in available_model_ids: - raise ModelValidationError(f"Text model '{text_model}' not available. Available models: {', '.join(available_model_ids)}") - - if ss_model not in available_model_ids: - raise ModelValidationError(f"Screenshot model '{ss_model}' not available. Available models: {', '.join(available_model_ids)}") - - return True - except Exception as e: - logger.error(f"Model validation failed: {str(e)}") - return False def create_client_with_retry(client_class, config: dict): """Create an OpenAI client with proper error handling""" @@ -102,26 +60,9 @@ def get_client(): config = OpenAIConfig.get_text_config() return create_client_with_retry(AsyncOpenAI, config) -def get_ss_client(): - """Get OpenAI client for screenshot analysis""" - config = OpenAIConfig.get_ss_config() - return create_client_with_retry(OpenAI, config) - -def get_text_model() -> str: - """Get model name for text analysis""" - return OpenAIConfig.get_text_config()["model"] - -def get_ss_model() -> str: - """Get model name for screenshot analysis""" - return OpenAIConfig.get_ss_config()["model"] - # Example usage async def initialize_and_validate(): """Initialize client and validate configuration""" client = get_client() - # Validate models - if not await validate_models(client): - raise ModelValidationError("Failed to validate models. Please check your configuration.") - return client \ No newline at end of file