Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 105 additions & 5 deletions scripts/advanced_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,7 +1263,7 @@ def topological_sort(self, symbols: Set[Symbol]) -> List[Symbol]:

for symbol in symbols:
for dep in symbol.dependencies:
if dep in symbols:
if dep in symbols and dep != symbol: # 忽略自引用
graph[dep].add(symbol)
in_degree[symbol] += 1

Expand Down Expand Up @@ -1293,13 +1293,95 @@ def topological_sort(self, symbols: Set[Symbol]) -> List[Symbol]:
# 检查循环依赖
if len(sorted_symbols) != len(symbols):
remaining = symbols - set(sorted_symbols)
raise CircularDependencyError(
f"Circular dependency detected among symbols: "
f"{', '.join(s.qname for s in remaining)}"
)
# 找出循环依赖的详细路径
cycles = self._find_cycles_in_graph(graph, remaining)
error_msg = "Circular dependency detected:\n"
if cycles:
for i, cycle in enumerate(cycles, 1):
cycle_str = " -> ".join(s.qname for s in cycle)
error_msg += f" Cycle {i}: {cycle_str}\n"
else:
# 如果找不到明确的循环,列出所有剩余符号
error_msg += f" Involved symbols: {', '.join(s.qname for s in remaining)}\n"

# 添加更详细的调试信息
error_msg += "\nDetailed dependency information:\n"
for sym in sorted(remaining, key=lambda s: s.qname): # 显示所有剩余符号
# 显示所有依赖,不只是在remaining中的
all_deps = [d.qname for d in sym.dependencies]
deps_in_remaining = [d.qname for d in sym.dependencies if d in remaining]
deps_in_output = [d.qname for d in sym.dependencies if d in symbols and d not in remaining]

if all_deps:
error_msg += f" {sym.qname}:\n"
error_msg += f" All dependencies: {', '.join(all_deps)}\n"
if deps_in_remaining:
error_msg += f" Dependencies in cycle: {', '.join(deps_in_remaining)}\n"
if deps_in_output:
error_msg += f" Dependencies already processed: {', '.join(deps_in_output)}\n"

raise CircularDependencyError(error_msg)

return sorted_symbols

def _find_cycles_in_graph(self, graph: Dict[Symbol, Set[Symbol]], candidates: Set[Symbol]) -> List[List[Symbol]]:
"""在图中查找循环依赖的路径

使用DFS算法查找所有循环路径
"""
cycles = []

# 构建反向图(用于查找谁依赖于某个符号)
reverse_graph = defaultdict(set)
for node, deps in graph.items():
for dep in deps:
reverse_graph[dep].add(node)

# Tarjan算法查找强连通分量
index_counter = [0]
stack = []
lowlinks = {}
index = {}
on_stack = defaultdict(bool)

def strongconnect(node: Symbol):
"""Tarjan算法的核心函数"""
index[node] = index_counter[0]
lowlinks[node] = index_counter[0]
index_counter[0] += 1
stack.append(node)
on_stack[node] = True

# 考虑后继节点
for successor in graph.get(node, []):
if successor in candidates:
if successor not in index:
strongconnect(successor)
lowlinks[node] = min(lowlinks[node], lowlinks[successor])
elif on_stack[successor]:
lowlinks[node] = min(lowlinks[node], index[successor])

# 如果node是强连通分量的根
if lowlinks[node] == index[node]:
component = []
while True:
w = stack.pop()
on_stack[w] = False
component.append(w)
if w == node:
break
# 只添加包含多个节点的强连通分量(即循环)
if len(component) > 1:
component.reverse() # 恢复正确的顺序
cycles.append(component)

# 对每个候选节点运行Tarjan算法
for symbol in candidates:
if symbol not in index:
strongconnect(symbol)

return cycles

def _ast_equal(self, node1: ast.AST, node2: ast.AST) -> bool:
"""检查两个 AST 节点是否相等"""
return ast.dump(node1) == ast.dump(node2)
Expand Down Expand Up @@ -1639,6 +1721,24 @@ def merge_script(self, script_path: Path) -> str:
for s in self.needed_symbols:
if s.symbol_type in ('import_alias', 'module', 'parameter'):
continue

# 检查是否是类的方法(通过判断qname中是否包含类名)
is_class_method = False
if s.symbol_type == 'function' and '.' in s.qname:
# 检查是否有对应的类符号
parts = s.qname.rsplit('.', 1)
if len(parts) == 2:
potential_class_qname = parts[0]
# 查找是否有对应的类
for other_s in self.needed_symbols:
if other_s.symbol_type == 'class' and other_s.qname == potential_class_qname:
is_class_method = True
break

if is_class_method:
# 类的方法不需要单独输出,它们会随类一起输出
continue

if s.is_nested:
# 检查是否有其他符号依赖这个嵌套符号
# 暂时保留所有嵌套类,因为它们可能被属性访问使用
Expand Down
181 changes: 181 additions & 0 deletions tests/test_circular_dependency_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
"""测试循环依赖检测和诊断功能"""
import pytest
import tempfile
import shutil
from pathlib import Path
import sys

# 添加项目根目录到 Python 路径
sys.path.insert(0, str(Path(__file__).parent.parent))

from scripts.advanced_merge import CircularDependencyError, AdvancedCodeMerger


class TestCircularDependencyDetection:
"""测试循环依赖检测功能"""

def setup_method(self):
"""设置测试环境"""
self.test_dir = Path(tempfile.mkdtemp())

def teardown_method(self):
"""清理测试环境"""
shutil.rmtree(self.test_dir)

def test_simple_circular_dependency(self):
"""测试简单的循环依赖检测"""
# 创建测试文件
module_a = self.test_dir / "module_a.py"
module_a.write_text("""
from module_b import func_b

def func_a():
return func_b()
""")

module_b = self.test_dir / "module_b.py"
module_b.write_text("""
from module_a import func_a

def func_b():
return func_a()
""")

main_script = self.test_dir / "main.py"
main_script.write_text("""
from module_a import func_a

if __name__ == "__main__":
print(func_a())
""")

# 测试合并器应该检测到循环依赖
merger = AdvancedCodeMerger(self.test_dir)
with pytest.raises(CircularDependencyError) as exc_info:
merger.merge_script(main_script)

error_msg = str(exc_info.value)
assert "Circular dependency detected" in error_msg
# 应该包含循环路径信息
assert "func_a" in error_msg and "func_b" in error_msg

def test_class_method_circular_dependency(self):
"""测试类和方法之间的循环依赖"""
# 创建测试文件
class_a = self.test_dir / "class_a.py"
class_a.write_text("""
from class_b import ClassB

class ClassA:
def method_a(self):
b = ClassB()
return b.method_b()
""")

class_b = self.test_dir / "class_b.py"
class_b.write_text("""
from class_a import ClassA

class ClassB:
def method_b(self):
a = ClassA()
return a.method_a()
""")

main_script = self.test_dir / "main.py"
main_script.write_text("""
from class_a import ClassA

if __name__ == "__main__":
a = ClassA()
print(a.method_a())
""")

# 测试合并器应该检测到循环依赖
merger = AdvancedCodeMerger(self.test_dir)
with pytest.raises(CircularDependencyError) as exc_info:
merger.merge_script(main_script)

error_msg = str(exc_info.value)
assert "Circular dependency detected" in error_msg

def test_complex_circular_dependency(self):
"""测试复杂的多节点循环依赖"""
# 创建测试文件
module_a = self.test_dir / "module_a.py"
module_a.write_text("""
from module_b import func_b

def func_a():
return func_b()
""")

module_b = self.test_dir / "module_b.py"
module_b.write_text("""
from module_c import func_c

def func_b():
return func_c()
""")

module_c = self.test_dir / "module_c.py"
module_c.write_text("""
from module_a import func_a

def func_c():
return func_a()
""")

main_script = self.test_dir / "main.py"
main_script.write_text("""
from module_a import func_a

if __name__ == "__main__":
print(func_a())
""")

# 测试合并器应该检测到循环依赖
merger = AdvancedCodeMerger(self.test_dir)
with pytest.raises(CircularDependencyError) as exc_info:
merger.merge_script(main_script)

error_msg = str(exc_info.value)
assert "Circular dependency detected" in error_msg
# 应该显示完整的循环路径
assert "func_a" in error_msg and "func_b" in error_msg and "func_c" in error_msg

def test_no_circular_dependency(self):
"""测试没有循环依赖的情况"""
# 创建测试文件
module_a = self.test_dir / "module_a.py"
module_a.write_text("""
def func_a():
return "A"
""")

module_b = self.test_dir / "module_b.py"
module_b.write_text("""
from module_a import func_a

def func_b():
return func_a() + "B"
""")

main_script = self.test_dir / "main.py"
main_script.write_text("""
from module_b import func_b

if __name__ == "__main__":
print(func_b())
""")

# 测试合并器应该成功
merger = AdvancedCodeMerger(self.test_dir)
merged_code = merger.merge_script(main_script)
assert merged_code is not None
assert "func_a" in merged_code
assert "func_b" in merged_code


if __name__ == "__main__":
pytest.main([__file__, "-v"])
Loading