diff --git a/pysymphony/auditor/auditor.py b/pysymphony/auditor/auditor.py index 03085a8..d43f0fe 100644 --- a/pysymphony/auditor/auditor.py +++ b/pysymphony/auditor/auditor.py @@ -95,6 +95,15 @@ def visit_FunctionDef(self, node: ast.FunctionDef): # 添加参数到函数作用域 for arg in node.args.args: self.add_symbol(arg.arg, arg, 'variable') + # 添加 *args + if node.args.vararg: + self.add_symbol(node.args.vararg.arg, node.args.vararg, 'variable') + # 添加 **kwargs + if node.args.kwarg: + self.add_symbol(node.args.kwarg.arg, node.args.kwarg, 'variable') + # 添加仅关键字参数 + for arg in node.args.kwonlyargs: + self.add_symbol(arg.arg, arg, 'variable') self.generic_visit(node) self.exit_scope() @@ -102,8 +111,18 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): """访问异步函数定义""" self.add_symbol(node.name, node, 'function') self.enter_scope(node.name, 'function') + # 添加参数到函数作用域 for arg in node.args.args: self.add_symbol(arg.arg, arg, 'variable') + # 添加 *args + if node.args.vararg: + self.add_symbol(node.args.vararg.arg, node.args.vararg, 'variable') + # 添加 **kwargs + if node.args.kwarg: + self.add_symbol(node.args.kwarg.arg, node.args.kwarg, 'variable') + # 添加仅关键字参数 + for arg in node.args.kwonlyargs: + self.add_symbol(arg.arg, arg, 'variable') self.generic_visit(node) self.exit_scope() diff --git a/scripts/advanced_merge.py b/scripts/advanced_merge.py index e6fa61d..2d20359 100644 --- a/scripts/advanced_merge.py +++ b/scripts/advanced_merge.py @@ -303,9 +303,8 @@ def visit_Module(self, node: ast.Module): init_statements.append(stmt) elif isinstance(stmt, ast.Try) and self._is_try_import_error(stmt): # 特殊处理 try...except ImportError 块 - # 保留完整的块作为初始化语句 - init_statements.append(stmt) - # 访问以分析内部的导入 + # 不再将运行时导入块保留为初始化语句 + # 仅访问以分析内部的导入,这些导入会被转换为普通导入 self.visit(stmt) else: # 其他顶层语句(副作用初始化) @@ -1059,6 +1058,46 @@ def __init__(self, project_root: Path): self.import_registry: Set[Tuple[str, str]] = set() # 解决 #2: (module, alias) self.entry_module_qname: Optional[str] = None # 解决 #3: 
def _fix_import_alias_dependencies(self) -> None:
    """Link every ``from ... import ...`` alias symbol to the symbol it imports.

    Intended to run after all modules have been analysed, so that the
    target of an alias is already registered in ``visitor.all_symbols``
    even when modules import each other circularly.

    For each ``import_alias`` symbol whose defining node is an
    ``ast.ImportFrom``, the imported module is resolved to a path, the
    qualified name ``<module qname>.<imported name>`` is looked up in
    ``visitor.all_symbols`` and, when present, added to the alias's
    ``dependencies`` set.  Aliases that cannot be resolved are left
    untouched.
    """
    visitor = self.visitor
    all_symbols = visitor.all_symbols

    # Iterate over a snapshot so the registry is never walked while a
    # helper called below might touch it.
    for symbol in list(all_symbols.values()):
        if symbol.symbol_type != 'import_alias':
            continue

        node = symbol.def_node
        # Only ``from X import Y`` statements name a concrete target symbol.
        if not isinstance(node, ast.ImportFrom):
            continue

        module_name = node.module or ''
        level = node.level or 0

        if level > 0:
            # A relative import can only be resolved against the importing
            # module's path.  Symbols without a scope (or without a module
            # path) are skipped instead of raising AttributeError, matching
            # the ``symbol.scope and symbol.scope.module_path`` guards used
            # by the rest of the merger.
            scope = symbol.scope
            if not (scope and scope.module_path):
                continue
            from_module = visitor.get_absolute_module_name(
                module_name, scope.module_path, level
            )
        else:
            from_module = module_name

        module_path = visitor.resolve_module_path(from_module)
        if not module_path:
            # The module could not be resolved to a path; nothing to link.
            continue

        for alias in node.names:
            # The name bound in the importing module is the ``as`` name
            # when present, otherwise the imported name itself.
            if (alias.asname or alias.name) != symbol.name:
                continue

            target_qname = f"{visitor.get_module_qname(module_path)}.{alias.name}"
            target_symbol = all_symbols.get(target_qname)
            if target_symbol is not None:
                symbol.dependencies.add(target_symbol)
def resolve_transitive_deps(symbol: "Symbol") -> "Set[Symbol]":
    """Return the dependencies of *symbol* with import aliases expanded.

    A dependency that is an ``import_alias`` with known targets is replaced
    by those targets: ``function``/``class``/``variable`` targets are
    returned directly, and a target that is itself an ``import_alias`` is
    expanded in turn.  Any other dependency (including an alias with no
    recorded targets) is returned unchanged.

    The walk is iterative and remembers every symbol it has already
    expanded, so alias chains that loop back on themselves (as produced by
    circular imports) terminate instead of recursing without bound.
    """
    concrete_kinds = ('function', 'class', 'variable')
    resolved_deps = set()
    expanded = set()      # symbols whose dependency sets were already walked
    pending = [symbol]    # symbols whose dependency sets still need walking

    while pending:
        current = pending.pop()
        if current in expanded:
            # Already handled: either a repeated alias or a cycle.
            continue
        expanded.add(current)

        for dep in current.dependencies:
            if dep.symbol_type == 'import_alias' and dep.dependencies:
                for target in dep.dependencies:
                    if target.symbol_type in concrete_kinds:
                        resolved_deps.add(target)
                    elif target.symbol_type == 'import_alias':
                        # Follow the alias chain one hop further.
                        pending.append(target)
            else:
                resolved_deps.add(dep)

    return resolved_deps
对于运行时导入,添加特殊后缀以区分 + # 对于 import_alias 符号,总是添加后缀(__mod 或 __rt) + if symbol.symbol_type == 'import_alias': if symbol.is_runtime_import: new_name = f"{symbol.name}__rt" else: - # 对于函数和其他类型,使用模块前缀 - new_name = f"{module_key}_{symbol.name}" - + new_name = f"{symbol.name}__mod" + self.name_mappings[symbol.qname] = new_name + elif name_counts[symbol.name] > 1: + # 对于其他符号,只在有冲突时重命名 + if symbol.scope and symbol.scope.module_path: + module_key = self.visitor.get_module_qname(symbol.scope.module_path) + module_key = module_key.replace('.', '_').replace('__init__', 'pkg') + else: + # 如果没有 scope,使用符号的 qname 前缀 + module_key = symbol.qname.rsplit('.', 1)[0] if '.' in symbol.qname else 'unknown' + new_name = f"{module_key}_{symbol.name}" self.name_mappings[symbol.qname] = new_name # 同时为带类型后缀的版本添加映射 @@ -1653,7 +1719,7 @@ def generate_name_mappings(self, symbols: Set[Symbol]): if type_qname in self.visitor.all_symbols: self.name_mappings[type_qname] = new_name else: - # 无冲突,保持原名 + # 无冲突的非导入符号,保持原名 self.name_mappings[symbol.qname] = symbol.name # 同时为带类型后缀的版本添加映射 @@ -1698,8 +1764,11 @@ def _write_symbol(self, symbol: Symbol, transformer: 'AdvancedNodeTransformer', # 检查是否是完全重复的定义 # 简化处理:如果名称已存在,则重命名 - module_qname = self.visitor.get_module_qname(symbol.scope.module_path) - module_alias = module_qname.replace('.', '_') + if symbol.scope and symbol.scope.module_path: + module_qname = self.visitor.get_module_qname(symbol.scope.module_path) + module_alias = module_qname.replace('.', '_') + else: + module_alias = 'unknown' new_name = f"{target_name}__from_{module_alias}" self.name_mappings[symbol.qname] = new_name target_name = new_name @@ -1715,7 +1784,7 @@ def _write_symbol(self, symbol: Symbol, transformer: 'AdvancedNodeTransformer', # 写入结果 if transformed is not None: - if symbol.scope.module_path: + if symbol.scope and symbol.scope.module_path: rel_path = symbol.scope.module_path.relative_to(self.project_root) result_lines.append(f"# From {rel_path}") 
result_lines.append(ast.unparse(transformed)) @@ -1727,9 +1796,13 @@ def _process_imports(self, imports: Set[str]) -> List[str]: 修复:使用更精确的去重键,包含导入样式信息 """ result = [] + # 添加别名去重集合,避免相同的别名被多次定义 + seen_aliases = set() for imp in sorted(imports): # 解析导入语句 + new_alias = None # 初始化变量 + if imp.startswith('from '): # from X import Y as Z 或 from X import Y parts = imp.split() @@ -1748,9 +1821,11 @@ def _process_imports(self, imports: Set[str]) -> List[str]: # from X import Y module = parts[1] name = parts[3] + # B2 修复:即使没有别名,也要添加 __mod 后缀 + new_alias = f"{name}__mod" + new_imp = f"from {module} import {name} as {new_alias}" # Issue #37 修复:包含导入样式在去重键中 - key = ('from', module, name, name) - new_imp = imp + key = ('from', module, name, new_alias) else: # import X as Y 或 import X parts = imp.split() @@ -1774,8 +1849,9 @@ def _process_imports(self, imports: Set[str]) -> List[str]: # Issue #37 修复:包含导入样式在去重键中 key = ('import', module, new_alias) - # 检查是否已存在 - if key not in self.import_registry: + # 检查是否已存在(包括别名去重) + if key not in self.import_registry and new_alias not in seen_aliases: + seen_aliases.add(new_alias) self.import_registry.add(key) result.append(new_imp) @@ -1802,7 +1878,7 @@ def merge_script(self, script_path: Path) -> str: # Issue #37 修复:过滤掉入口模块中的变量定义 # 这些变量会作为主代码的一部分输出,不需要单独输出 if (s.symbol_type == 'variable' and - s.scope.module_path == script_path): + s.scope and s.scope.module_path == script_path): continue # 检查是否是类的方法(通过判断qname中是否包含类名) @@ -1842,12 +1918,15 @@ def merge_script(self, script_path: Path) -> str: # 5. 
生成名称映射 self.generate_name_mappings(output_symbols) - # B2 修复:为所有 import_alias 符号添加 __mod 后缀映射 + # B2 修复:为所有 import_alias 符号添加后缀映射 # 这包括那些被过滤掉不输出的外部导入 for symbol in self.visitor.all_symbols.values(): if symbol.symbol_type == 'import_alias' and symbol.qname not in self.name_mappings: - # 为所有导入别名添加 __mod 后缀 - new_name = f"{symbol.name}__mod" + # 根据是否是运行时导入选择不同的后缀 + if symbol.is_runtime_import: + new_name = f"{symbol.name}__rt" + else: + new_name = f"{symbol.name}__mod" self.name_mappings[symbol.qname] = new_name # 6. 生成代码 @@ -1865,6 +1944,37 @@ def merge_script(self, script_path: Path) -> str: result_lines.extend(processed_imports) result_lines.append("") + # 处理运行时导入(带 __rt 后缀) + runtime_imports = [] + for symbol in self.visitor.all_symbols.values(): + if symbol.symbol_type == 'import_alias' and symbol.is_runtime_import: + # 获取重命名后的名称 + new_name = self.name_mappings.get(symbol.qname, symbol.name) + + # 根据导入类型生成导入语句 + if isinstance(symbol.def_node, ast.Import): + # import xxx as yyy 形式 + for alias in symbol.def_node.names: + if alias.asname == symbol.name or (not alias.asname and alias.name.split('.')[0] == symbol.name): + # 如果原本没有别名且新名称不同,或者有别名但需要改名 + if new_name != alias.name.split('.')[0]: + runtime_imports.append(f"import {alias.name} as {new_name}") + else: + runtime_imports.append(f"import {alias.name}") + break + elif isinstance(symbol.def_node, ast.ImportFrom): + # from xxx import yyy as zzz 形式 + module = symbol.def_node.module or '' + for alias in symbol.def_node.names: + if alias.asname == symbol.name or (not alias.asname and alias.name == symbol.name): + runtime_imports.append(f"from {module} import {alias.name} as {new_name}") + break + + if runtime_imports: + result_lines.append("# Runtime imports (originally in try...except ImportError blocks)") + result_lines.extend(sorted(set(runtime_imports))) + result_lines.append("") + # 收集并重新注入必要的导入别名 reinjected_imports = self._collect_and_reinject_imports(output_symbols) if reinjected_imports: @@ -1880,7 +1990,11 
class TestIssue41Fixes:
    """End-to-end checks for the four core fixes described in Issue #41.

    Each test writes a small project into a fresh temporary directory,
    merges it with ``AdvancedCodeMerger`` and inspects (and, where useful,
    executes) the merged output.
    """

    def setup_method(self):
        """Create an empty temporary project directory for the test."""
        self.test_dir = Path(tempfile.mkdtemp())

    def teardown_method(self):
        """Remove the temporary project directory."""
        shutil.rmtree(self.test_dir)

    def _write_source(self, name, source):
        """Write *source* to ``<test_dir>/<name>`` and return the path.

        The fixtures contain non-ASCII comments and docstrings, so the file
        is written explicitly as UTF-8 (Python's default source encoding)
        rather than with the platform's locale encoding.
        """
        path = self.test_dir / name
        path.write_text(source, encoding="utf-8")
        return path

    def _run_merged(self, merged_source):
        """Write the merged script to disk, run it, and return the result."""
        merged_file = self._write_source("main_merged.py", merged_source)
        return subprocess.run(
            [sys.executable, str(merged_file)],
            capture_output=True,
            text=True
        )

    def test_circular_dependency_detection(self):
        """Test 1: circular dependencies must be detected accurately."""
        # Two modules whose functions call each other.
        self._write_source("module_a.py", """
from module_b import func_b

def func_a():
    return func_b()
""")

        self._write_source("module_b.py", """
from module_a import func_a

def func_b():
    return func_a()
""")

        main_script = self._write_source("main.py", """
from module_a import func_a

if __name__ == "__main__":
    print(func_a())
""")

        # Merging must raise CircularDependencyError.
        merger = AdvancedCodeMerger(self.test_dir)
        with pytest.raises(CircularDependencyError) as exc_info:
            merger.merge_script(main_script)

        error_msg = str(exc_info.value)
        assert "Circular dependency detected" in error_msg
        assert "func_a" in error_msg and "func_b" in error_msg

    def test_attribute_chain_integrity(self):
        """Test 2: attribute access chains must be left intact."""
        # A module exposing a nested class through a property.
        self._write_source("utils.py", """
class Config:
    class Database:
        def get_connection(self):
            return "DB Connection"

    @property
    def db(self):
        return self.Database()
""")

        main_script = self._write_source("main.py", """
from utils import Config

config = Config()
# 测试属性链
print(config.db.get_connection())

# 测试 super()
class MyConfig(Config):
    def __init__(self):
        super().__init__()
        print("MyConfig initialized")

mc = MyConfig()
""")

        merger = AdvancedCodeMerger(self.test_dir)
        result = merger.merge_script(main_script)

        # The attribute chain and the super() call must survive verbatim.
        assert "config.db.get_connection()" in result
        assert "super().__init__()" in result

        # The merged script must also run.
        proc = self._run_merged(result)

        assert proc.returncode == 0
        assert "DB Connection" in proc.stdout
        assert "MyConfig initialized" in proc.stdout

    def test_naming_conflict_strategy(self):
        """Test 3: rename import aliases first; keep user symbols unchanged."""
        main_script = self._write_source("main.py", """
import json

def json_processor():
    '''用户定义的处理函数'''
    return "user processor"

# 两者都使用,验证没有冲突
data = json.dumps({"test": "value"})
result = json_processor()
print(data)
print(result)
""")

        merger = AdvancedCodeMerger(self.test_dir)
        result = merger.merge_script(main_script)

        # The import alias is renamed ...
        assert "json__mod" in result
        # ... the user-defined function keeps its name ...
        assert "def json_processor():" in result
        # ... and usages refer to the renamed alias.
        assert "json__mod.dumps" in result

        # The merged script must also run.
        proc = self._run_merged(result)

        assert proc.returncode == 0
        assert '"test": "value"' in proc.stdout
        assert "user processor" in proc.stdout

    def test_alias_renaming_rules(self):
        """Test 4: alias suffix rules (static ``__mod``, runtime ``__rt``)."""
        main_script = self._write_source("main.py", """
# 静态导入
import os
import sys as system

# 运行时导入
try:
    import ujson as json
except ImportError:
    import json

# 使用所有导入
print(os.path.exists('.'))
print(system.version_info[0])
print(json.dumps({"runtime": "import"}))
""")

        merger = AdvancedCodeMerger(self.test_dir)
        result = merger.merge_script(main_script)

        # Static imports get the __mod suffix.
        assert "os__mod" in result
        assert "system__mod" in result

        # Runtime imports get the __rt suffix.  (The former
        # ``or "json as json__rt" in result`` alternative was implied by
        # this check and has been dropped.)
        assert "json__rt" in result

        # Usages refer to the renamed aliases.
        assert "os__mod.path.exists" in result
        assert "system__mod.version_info" in result
        assert "json__rt.dumps" in result

    def test_complex_scenario(self):
        """Combined scenario exercising several of the fixes at once."""
        # A module with an attribute chain through a nested class.
        self._write_source("helper.py", """
class Helper:
    class Inner:
        def process(self):
            return "processed"

    def get_inner(self):
        return self.Inner()
""")

        # Entry script mixing static imports, runtime imports and a user
        # function whose name starts with an imported module's name.
        main_script = self._write_source("main.py", """
import sys
from helper import Helper

# 运行时导入
try:
    import ujson as json
except ImportError:
    import json

def sys_info():
    '''用户定义的系统信息函数'''
    return "user sys info"

# 使用属性链
h = Helper()
print(h.get_inner().process())

# 使用导入和用户函数
print(sys.version_info[0])
print(sys_info())
print(json.dumps({"test": True}))
""")

        merger = AdvancedCodeMerger(self.test_dir)
        result = merger.merge_script(main_script)

        # All fixes must be reflected in the merged output.
        assert "h.get_inner().process()" in result  # attribute chain intact
        assert "sys__mod" in result  # static import renamed
        assert "json__rt" in result  # runtime import renamed
        assert "def sys_info():" in result  # user function keeps its name

        # The merged script must also run.
        proc = self._run_merged(result)

        assert proc.returncode == 0
        assert "processed" in proc.stdout
        assert "user sys info" in proc.stdout
        assert '"test": true' in proc.stdout