|
6 | 6 | from codesage.governance.task_models import GovernancePlan, GovernanceTask |
7 | 7 | from codesage.llm.client import BaseLLMClient, LLMRequest |
8 | 8 | from codesage.governance.patch_manager import PatchManager |
| 9 | +from codesage.governance.validator import CodeValidator |
| 10 | +from codesage.config.governance import GovernanceConfig |
9 | 11 |
|
# Module-level structured logger shared by this module.
logger = structlog.get_logger()

# Numeric rank for each risk label; "unknown" ranks lowest (below "low").
# NOTE(review): presumably used to compare/filter tasks by severity
# (e.g. in select_tasks) — confirm against callers.
RISK_LEVEL_MAP = {"low": 1, "medium": 2, "high": 3, "unknown": 0}
|
class TaskOrchestrator:
    """Coordinates selection and execution of tasks from a GovernancePlan."""

    def __init__(
        self,
        plan: GovernancePlan,
        llm_client: Optional[BaseLLMClient] = None,
        config: Optional[GovernanceConfig] = None
    ) -> None:
        """Set up the orchestrator around a plan, an optional LLM, and config.

        Args:
            plan: Governance plan whose grouped tasks will be flattened and run.
            llm_client: Client used to generate fixes; when absent, task
                execution is skipped.
            config: Governance settings; any falsy value falls back to
                ``GovernanceConfig.default()``.
        """
        # Collaborators that do not depend on the plan or configuration.
        self.llm_client = llm_client
        self.patch_manager = PatchManager()

        # Plan state: keep the plan plus a flat view of all of its tasks.
        self._plan = plan
        self._all_tasks: List[GovernanceTask] = self._flatten_tasks()

        # Configuration drives validation behaviour.
        self.config = config or GovernanceConfig.default()
        self.validator = CodeValidator(self.config)
|
21 | 30 | def _flatten_tasks(self) -> List[GovernanceTask]: |
22 | 31 | """Extracts and flattens all tasks from the plan's groups.""" |
@@ -63,66 +72,94 @@ def select_tasks( |
63 | 72 |
|
64 | 73 | return filtered_tasks |
65 | 74 |
|
    def execute_task(self, task: GovernanceTask, apply_fix: bool = False, max_retries: int = 3) -> bool:
        """
        Executes a governance task using the LLM client and optionally applies the fix.
        Includes a validation loop with rollback and retry.

        Args:
            task: The governance task describing the file, rule, and risk level.
            apply_fix: When False, performs a dry run that only prints a diff;
                when True, writes the patch, validates, and rolls back on failure.
            max_retries: Number of retries after the initial attempt, so up to
                ``max_retries + 1`` LLM calls are made in total.

        Returns:
            True if the dry run produced a diff or the applied patch passed
            validation; False otherwise (no client, missing file, LLM error,
            or all attempts exhausted).
        """
        # Guard: without an LLM client there is nothing to generate.
        if not self.llm_client:
            logger.warning("LLM client not configured, skipping execution", task_id=task.id)
            return False

        logger.info("Executing task", task_id=task.id, file=task.file_path)

        file_path = Path(task.file_path)
        if not file_path.exists():
            logger.error("File not found", file_path=str(file_path))
            return False

        # Snapshot the pre-patch content: it is embedded in every prompt and
        # used as the diff baseline for dry runs.
        original_content = file_path.read_text(encoding="utf-8")

        # Initial Prompt
        base_prompt = (
            f"Fix the following issue in {task.file_path}:\n"
            f"Issue: {task.rule_id} - {task.description}\n"
            f"Severity: {task.risk_level}\n\n"
            f"Here is the file content:\n"
            f"```\n{original_content}\n```\n\n"
            f"Please provide the FULL corrected file content in a markdown code block."
        )

        # current_prompt is rewritten after a validation failure to feed the
        # validator's error back to the model; base_prompt stays unchanged.
        current_prompt = base_prompt
        attempts = 0

        while attempts <= max_retries:
            # 1. Call LLM
            request = LLMRequest(
                prompt=current_prompt,
                metadata={"task_id": task.id, "file_path": task.file_path, "attempt": attempts}
            )

            try:
                response = self.llm_client.generate(request)
            except Exception as e:
                # NOTE(review): a transport/generation error aborts the whole
                # task immediately instead of consuming a retry — confirm that
                # this is intended for transient failures.
                logger.error("LLM generation failed", error=str(e))
                return False

            # 2. Extract Code — an unparseable response consumes one attempt
            # and retries with the same prompt.
            new_content = self.patch_manager.extract_code_block(response.content, language=task.language)
            if not new_content:
                logger.error("Failed to extract code from LLM response", attempt=attempts)
                attempts += 1
                continue

            # 3. Apply Fix (or Dry Run)
            if not apply_fix:
                # Dry run: print the diff and stop after the first successful
                # extraction — no write and no validation is performed.
                diff = self.patch_manager.create_diff(original_content, new_content, filename=task.file_path)
                print(f"--- Patch for {task.file_path} (Dry Run) ---\n{diff}\n-----------------------------")
                logger.info("Dry run completed", task_id=task.id)
                return True

            # Apply with backup so a failed validation can be reverted.
            if self.patch_manager.apply_patch(file_path, new_content, create_backup=True):
                # 4. Validate
                # We use file_path as scope for now. Ideally, we should detect the test scope.
                validation_result = self.validator.validate(
                    file_path,
                    language=task.language,
                    related_test_scope=str(file_path)
                )

                if validation_result.success:
                    # Success: drop the backup and mark the task complete.
                    logger.info("Validation passed", task_id=task.id)
                    self.patch_manager.cleanup_backup(file_path)
                    task.status = "done"
                    return True
                else:
                    # Failure: restore the original file, then retry with the
                    # validator's stage and error appended to the prompt.
                    logger.warning("Validation failed, rolling back", task_id=task.id, error=validation_result.error)
                    self.patch_manager.revert(file_path)

                    # Prepare retry prompt
                    current_prompt = (
                        f"{base_prompt}\n\n"
                        f"Previous attempt failed validation ({validation_result.stage}):\n"
                        f"Error:\n{validation_result.error}\n\n"
                        f"Please try again and fix the error."
                    )
            else:
                # NOTE(review): a failed apply_patch retries with an unchanged
                # prompt, so the next attempt may produce the same patch.
                logger.error("Failed to apply patch", task_id=task.id)

            attempts += 1

        logger.error("Task failed after retries", task_id=task.id)
        return False
0 commit comments