diff --git a/scripts/advanced_merge.py b/scripts/advanced_merge.py index 99ec005..c9e5b05 100644 --- a/scripts/advanced_merge.py +++ b/scripts/advanced_merge.py @@ -1263,7 +1263,7 @@ def topological_sort(self, symbols: Set[Symbol]) -> List[Symbol]: for symbol in symbols: for dep in symbol.dependencies: - if dep in symbols: + if dep in symbols and dep != symbol: # 忽略自引用 graph[dep].add(symbol) in_degree[symbol] += 1 @@ -1293,13 +1293,95 @@ def topological_sort(self, symbols: Set[Symbol]) -> List[Symbol]: # 检查循环依赖 if len(sorted_symbols) != len(symbols): remaining = symbols - set(sorted_symbols) - raise CircularDependencyError( - f"Circular dependency detected among symbols: " - f"{', '.join(s.qname for s in remaining)}" - ) + # 找出循环依赖的详细路径 + cycles = self._find_cycles_in_graph(graph, remaining) + error_msg = "Circular dependency detected:\n" + if cycles: + for i, cycle in enumerate(cycles, 1): + cycle_str = " -> ".join(s.qname for s in cycle) + error_msg += f" Cycle {i}: {cycle_str}\n" + else: + # 如果找不到明确的循环,列出所有剩余符号 + error_msg += f" Involved symbols: {', '.join(s.qname for s in remaining)}\n" + + # 添加更详细的调试信息 + error_msg += "\nDetailed dependency information:\n" + for sym in sorted(remaining, key=lambda s: s.qname): # 显示所有剩余符号 + # 显示所有依赖,不只是在remaining中的 + all_deps = [d.qname for d in sym.dependencies] + deps_in_remaining = [d.qname for d in sym.dependencies if d in remaining] + deps_in_output = [d.qname for d in sym.dependencies if d in symbols and d not in remaining] + + if all_deps: + error_msg += f" {sym.qname}:\n" + error_msg += f" All dependencies: {', '.join(all_deps)}\n" + if deps_in_remaining: + error_msg += f" Dependencies in cycle: {', '.join(deps_in_remaining)}\n" + if deps_in_output: + error_msg += f" Dependencies already processed: {', '.join(deps_in_output)}\n" + + raise CircularDependencyError(error_msg) return sorted_symbols + def _find_cycles_in_graph(self, graph: Dict[Symbol, Set[Symbol]], candidates: Set[Symbol]) -> List[List[Symbol]]: + """在图中查找循环依赖的路径 + + 使用DFS算法查找所有循环路径 + """ + cycles = [] + + # 构建反向图(用于查找谁依赖于某个符号) + reverse_graph = defaultdict(set) + for node, deps in graph.items(): + for dep in deps: + reverse_graph[dep].add(node) + + # Tarjan算法查找强连通分量 + index_counter = [0] + stack = [] + lowlinks = {} + index = {} + on_stack = defaultdict(bool) + + def strongconnect(node: Symbol): + """Tarjan算法的核心函数""" + index[node] = index_counter[0] + lowlinks[node] = index_counter[0] + index_counter[0] += 1 + stack.append(node) + on_stack[node] = True + + # 考虑后继节点 + for successor in graph.get(node, []): + if successor in candidates: + if successor not in index: + strongconnect(successor) + lowlinks[node] = min(lowlinks[node], lowlinks[successor]) + elif on_stack[successor]: + lowlinks[node] = min(lowlinks[node], index[successor]) + + # 如果node是强连通分量的根 + if lowlinks[node] == index[node]: + component = [] + while True: + w = stack.pop() + on_stack[w] = False + component.append(w) + if w == node: + break + # 只添加包含多个节点的强连通分量(即循环) + if len(component) > 1: + component.reverse() # 恢复正确的顺序 + cycles.append(component) + + # 对每个候选节点运行Tarjan算法 + for symbol in candidates: + if symbol not in index: + strongconnect(symbol) + + return cycles + def _ast_equal(self, node1: ast.AST, node2: ast.AST) -> bool: """检查两个 AST 节点是否相等""" return ast.dump(node1) == ast.dump(node2) @@ -1639,6 +1721,24 @@ def merge_script(self, script_path: Path) -> str: for s in self.needed_symbols: if s.symbol_type in ('import_alias', 'module', 'parameter'): continue + + # 检查是否是类的方法(通过判断qname中是否包含类名) + is_class_method = False + if s.symbol_type == 'function' and '.' in s.qname: + # 检查是否有对应的类符号 + parts = s.qname.rsplit('.', 1) + if len(parts) == 2: + potential_class_qname = parts[0] + # 查找是否有对应的类 + for other_s in self.needed_symbols: + if other_s.symbol_type == 'class' and other_s.qname == potential_class_qname: + is_class_method = True + break + + if is_class_method: + # 类的方法不需要单独输出,它们会随类一起输出 + continue + if s.is_nested: # 检查是否有其他符号依赖这个嵌套符号 # 暂时保留所有嵌套类,因为它们可能被属性访问使用 diff --git a/tests/test_circular_dependency_detection.py b/tests/test_circular_dependency_detection.py new file mode 100644 index 0000000..ba8433d --- /dev/null +++ b/tests/test_circular_dependency_detection.py @@ -0,0 +1,181 @@ +"""测试循环依赖检测和诊断功能""" +import pytest +import tempfile +import shutil +from pathlib import Path +import sys + +# 添加项目根目录到 Python 路径 +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from scripts.advanced_merge import CircularDependencyError, AdvancedCodeMerger + + +class TestCircularDependencyDetection: + """测试循环依赖检测功能""" + + def setup_method(self): + """设置测试环境""" + self.test_dir = Path(tempfile.mkdtemp()) + + def teardown_method(self): + """清理测试环境""" + shutil.rmtree(self.test_dir) + + def test_simple_circular_dependency(self): + """测试简单的循环依赖检测""" + # 创建测试文件 + module_a = self.test_dir / "module_a.py" + module_a.write_text(""" +from module_b import func_b + +def func_a(): + return func_b() +""") + + module_b = self.test_dir / "module_b.py" + module_b.write_text(""" +from module_a import func_a + +def func_b(): + return func_a() +""") + + main_script = self.test_dir / "main.py" + main_script.write_text(""" +from module_a import func_a + +if __name__ == "__main__": + print(func_a()) +""") + + # 测试合并器应该检测到循环依赖 + merger = AdvancedCodeMerger(self.test_dir) + with pytest.raises(CircularDependencyError) as exc_info: + merger.merge_script(main_script) + + error_msg = str(exc_info.value) + assert "Circular dependency detected" in error_msg + # 应该包含循环路径信息 + assert "func_a" in error_msg and "func_b" in error_msg + + def test_class_method_circular_dependency(self): + """测试类和方法之间的循环依赖""" + # 创建测试文件 + class_a = self.test_dir / "class_a.py" + class_a.write_text(""" +from class_b import ClassB + +class ClassA: + def method_a(self): + b = ClassB() + return b.method_b() +""") + + class_b = self.test_dir / "class_b.py" + class_b.write_text(""" +from class_a import ClassA + +class ClassB: + def method_b(self): + a = ClassA() + return a.method_a() +""") + + main_script = self.test_dir / "main.py" + main_script.write_text(""" +from class_a import ClassA + +if __name__ == "__main__": + a = ClassA() + print(a.method_a()) +""") + + # 测试合并器应该检测到循环依赖 + merger = AdvancedCodeMerger(self.test_dir) + with pytest.raises(CircularDependencyError) as exc_info: + merger.merge_script(main_script) + + error_msg = str(exc_info.value) + assert "Circular dependency detected" in error_msg + + def test_complex_circular_dependency(self): + """测试复杂的多节点循环依赖""" + # 创建测试文件 + module_a = self.test_dir / "module_a.py" + module_a.write_text(""" +from module_b import func_b + +def func_a(): + return func_b() +""") + + module_b = self.test_dir / "module_b.py" + module_b.write_text(""" +from module_c import func_c + +def func_b(): + return func_c() +""") + + module_c = self.test_dir / "module_c.py" + module_c.write_text(""" +from module_a import func_a + +def func_c(): + return func_a() +""") + + main_script = self.test_dir / "main.py" + main_script.write_text(""" +from module_a import func_a + +if __name__ == "__main__": + print(func_a()) +""") + + # 测试合并器应该检测到循环依赖 + merger = AdvancedCodeMerger(self.test_dir) + with pytest.raises(CircularDependencyError) as exc_info: + merger.merge_script(main_script) + + error_msg = str(exc_info.value) + assert "Circular dependency detected" in error_msg + # 应该显示完整的循环路径 + assert "func_a" in error_msg and "func_b" in error_msg and "func_c" in error_msg + + def test_no_circular_dependency(self): + """测试没有循环依赖的情况""" + # 创建测试文件 + module_a = self.test_dir / "module_a.py" + module_a.write_text(""" +def func_a(): + return "A" +""") + + module_b = self.test_dir / "module_b.py" + module_b.write_text(""" +from module_a import func_a + +def func_b(): + return func_a() + "B" +""") + + main_script = self.test_dir / "main.py" + main_script.write_text(""" +from module_b import func_b + +if __name__ == "__main__": + print(func_b()) +""") + + # 测试合并器应该成功 + merger = AdvancedCodeMerger(self.test_dir) + merged_code = merger.merge_script(main_script) + assert merged_code is not None + assert "func_a" in merged_code + assert "func_b" in merged_code + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file