diff --git a/scripts/advanced_merge.py b/scripts/advanced_merge.py index c9e5b05..e6fa61d 100644 --- a/scripts/advanced_merge.py +++ b/scripts/advanced_merge.py @@ -292,10 +292,15 @@ def visit_Module(self, node: ast.Module): if isinstance(stmt, (ast.Import, ast.ImportFrom)): # 处理导入 self.visit(stmt) - elif isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, - ast.ClassDef, ast.Assign, ast.AnnAssign)): - # 定义语句 + elif isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + # 函数和类定义 self.visit(stmt) + elif isinstance(stmt, (ast.Assign, ast.AnnAssign)): + # Issue #37 修复:赋值语句既是定义也是初始化 + # 需要访问以创建变量符号 + self.visit(stmt) + # 同时也需要作为初始化语句保留 + init_statements.append(stmt) elif isinstance(stmt, ast.Try) and self._is_try_import_error(stmt): # 特殊处理 try...except ImportError 块 # 保留完整的块作为初始化语句 @@ -364,6 +369,17 @@ def visit_Import(self, node: ast.Import): is_runtime_import=self.in_try_import_error ) + # Issue #37 修复:为外部导入创建模块符号依赖 + # 创建一个虚拟的模块符号作为依赖 + module_symbol = Symbol( + name=module_name, + qname=module_name, + symbol_type='module', + def_node=None, + scope=None + ) + symbol.dependencies.add(module_symbol) + self.current_scope().symbols[alias_name] = symbol self.all_symbols[symbol.qname] = symbol self.defnode_to_scope[symbol.def_node] = symbol.scope @@ -452,8 +468,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom): is_runtime_import=self.in_try_import_error ) - # 对于外部导入,我们不设置依赖关系 - # 但仍然注册符号 + # Issue #37 修复:为外部导入创建模块符号依赖 + # 创建一个虚拟的模块符号作为依赖 + if module_name: # from xxx import yyy 的情况 + module_symbol = Symbol( + name=module_name, + qname=module_name, + symbol_type='module', + def_node=None, + scope=None + ) + symbol.dependencies.add(module_symbol) + self.current_scope().symbols[alias_name] = symbol self.all_symbols[symbol.qname] = symbol self.defnode_to_scope[symbol.def_node] = symbol.scope @@ -1413,12 +1439,26 @@ def _needs_reinject(self, symbol: Symbol, output_symbols: Set[Symbol], visited: def _collect_and_reinject_imports(self, output_symbols: Set[Symbol]) -> List[str]: """收集并重新注入必要的导入别名 + Issue #37 修复:避免重复注入已经在 external_imports 中的导入 + 返回需要添加的导入语句列表 """ imports_to_reinject = [] import_set = set() # 用于去重 processed_symbols = set() # 跟踪已处理的符号 + # Issue #37 修复:记录已经处理过的外部导入模块 + external_modules = set() + for imp in self.visitor.external_imports: + if imp.startswith('import '): + parts = imp.split() + module = parts[1] + external_modules.add(module) + elif imp.startswith('from '): + parts = imp.split() + module = parts[1] + external_modules.add(module) + # 收集所有相关模块中的导入别名 # 不仅仅是 needed_symbols,还包括所有被访问过的模块中的导入 all_import_aliases = [] @@ -1480,6 +1520,10 @@ def _collect_and_reinject_imports(self, output_symbols: Set[Symbol]) -> List[str module_name = alias.name alias_name = alias.asname if alias.asname else alias.name + # Issue #37 修复:跳过已经在 external_imports 中的模块 + if module_name in external_modules: + continue + # 检查当前别名对应的符号 if symbol.name != alias_name: continue @@ -1514,6 +1558,10 @@ def _collect_and_reinject_imports(self, output_symbols: Set[Symbol]) -> List[str # 跳过相对导入 if level > 0: continue + + # Issue #37 修复:跳过已经在 external_imports 中的模块 + if module in external_modules: + continue for alias in node.names: name = alias.name @@ -1612,6 +1660,29 @@ def generate_name_mappings(self, symbols: Set[Symbol]): type_qname = f"{symbol.qname}#{symbol.symbol_type}" if type_qname in self.visitor.all_symbols: self.name_mappings[type_qname] = symbol.name + + # Issue #37 修复:为外部导入的别名添加正确的映射 + # 检查 import_registry 中的映射,并更新 name_mappings + module_alias_map = {} # module -> alias 的映射 + for key in self.import_registry: + if key[0] == 'import' and len(key) == 3: + # ('import', 'json', 'json__mod') + module = key[1] + alias = key[2] + module_alias_map[module] = alias + + # 更新所有 import_alias 符号的映射 + for qname, symbol in self.visitor.all_symbols.items(): + if symbol.symbol_type == 'import_alias' and symbol.dependencies: + # 检查依赖的模块 + for dep in symbol.dependencies: + if dep.symbol_type == 'module' and dep.name in module_alias_map: + # 更新映射到新的别名 + self.name_mappings[symbol.qname] = module_alias_map[dep.name] + # 同时更新带类型后缀的版本 + type_qname = f"{symbol.qname}#{symbol.symbol_type}" + if type_qname in self.visitor.all_symbols: + self.name_mappings[type_qname] = module_alias_map[dep.name] def _write_symbol(self, symbol: Symbol, transformer: 'AdvancedNodeTransformer', result_lines: List[str]): """解决 #1: 写入符号时进行冲突检测""" @@ -1651,9 +1722,11 @@ def _write_symbol(self, symbol: Symbol, transformer: 'AdvancedNodeTransformer', result_lines.append("") def _process_imports(self, imports: Set[str]) -> List[str]: - """解决 #2: 处理导入去重和 alias 冲突""" + """解决 #2: 处理导入去重和 alias 冲突 + + 修复:使用更精确的去重键,包含导入样式信息 + """ result = [] - alias_map = {} # alias -> module 的映射 for imp in sorted(imports): # 解析导入语句 @@ -1669,12 +1742,14 @@ def _process_imports(self, imports: Set[str]) -> List[str]: # B2 修复:为别名添加 __mod 后缀 new_alias = f"{alias}__mod" new_imp = f"from {module} import {name} as {new_alias}" - key = (module, name, new_alias) + # Issue #37 修复:包含导入样式在去重键中 + key = ('from', module, name, new_alias) else: # from X import Y module = parts[1] name = parts[3] - key = (module, name, name) + # Issue #37 修复:包含导入样式在去重键中 + key = ('from', module, name, name) new_imp = imp else: # import X as Y 或 import X @@ -1687,7 +1762,8 @@ def _process_imports(self, imports: Set[str]) -> List[str]: # B2 修复:为别名添加 __mod 后缀 new_alias = f"{alias}__mod" new_imp = f"import {module} as {new_alias}" - key = (module, new_alias) + # Issue #37 修复:包含导入样式在去重键中 + key = ('import', module, new_alias) else: # import X module = parts[1] @@ -1695,7 +1771,8 @@ def _process_imports(self, imports: Set[str]) -> List[str]: alias = module.split('.')[0] new_alias = f"{alias}__mod" new_imp = f"import {module} as {new_alias}" - key = (module, new_alias) + # Issue #37 修复:包含导入样式在去重键中 + key = ('import', module, new_alias) # 检查是否已存在 if key not in self.import_registry: @@ -1722,6 +1799,12 @@ def merge_script(self, script_path: Path) -> str: if s.symbol_type in ('import_alias', 'module', 'parameter'): continue + # Issue #37 修复:过滤掉入口模块中的变量定义 + # 这些变量会作为主代码的一部分输出,不需要单独输出 + if (s.symbol_type == 'variable' and + s.scope.module_path == script_path): + continue + # 检查是否是类的方法(通过判断qname中是否包含类名) is_class_method = False if s.symbol_type == 'function' and '.' in s.qname: @@ -1750,6 +1833,12 @@ def merge_script(self, script_path: Path) -> str: # 4. 拓扑排序 sorted_symbols = self.topological_sort(output_symbols) + # Issue #37 修复:在生成名称映射之前,先处理外部导入 + # 这样import_registry会被填充,供generate_name_mappings使用 + processed_imports = [] + if self.visitor.external_imports: + processed_imports = self._process_imports(self.visitor.external_imports) + # 5. 生成名称映射 self.generate_name_mappings(output_symbols) @@ -1771,9 +1860,8 @@ def merge_script(self, script_path: Path) -> str: result_lines.extend(sorted(self.visitor.future_imports)) result_lines.append("") - # 外部导入(通过去重处理) - if self.visitor.external_imports: - processed_imports = self._process_imports(self.visitor.external_imports) + # 外部导入(使用之前已经处理好的) + if processed_imports: result_lines.extend(processed_imports) result_lines.append("") @@ -1827,6 +1915,31 @@ def merge_script(self, script_path: Path) -> str: # 解决 #3: 跳过非入口模块的 __main__ 块 if should_skip_main and self._is_dunder_main(stmt): continue + + # Issue #37 修复:跳过已经作为符号输出的赋值语句 + # 只跳过简单的变量赋值,保留其他语句 + if isinstance(stmt, (ast.Assign, ast.AnnAssign)): + # 检查是否是简单的变量定义(没有函数调用等副作用) + skip = False + if isinstance(stmt, ast.Assign): + # 检查是否是简单赋值 + if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): + var_name = stmt.targets[0].id + # 检查这个变量是否已经被输出为符号 + var_qname = f"{module_qname}.{var_name}" + if var_qname in self.name_mappings: + # 检查赋值的值是否是常量 + if isinstance(stmt.value, (ast.Constant, ast.Name, ast.UnaryOp, ast.BinOp)): + skip = True + elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name): + var_name = stmt.target.id + var_qname = f"{module_qname}.{var_name}" + if var_qname in self.name_mappings and stmt.value: + if isinstance(stmt.value, (ast.Constant, ast.Name, ast.UnaryOp, ast.BinOp)): + skip = True + + if skip: + continue transformed_stmt = transformer.visit(copy.deepcopy(stmt)) result_lines.append(ast.unparse(transformed_stmt)) diff --git a/tests/test_issue_37_audit_failures.py b/tests/test_issue_37_audit_failures.py new file mode 100644 index 0000000..8899548 --- /dev/null +++ b/tests/test_issue_37_audit_failures.py @@ -0,0 +1,322 @@ +""" +测试Issue #37: 合并脚本的静态审计失败问题 + +本测试专门验证advanced_merge.py生成的代码能否通过ASTAuditor的检查 +""" + +import ast +import tempfile +import pytest +from pathlib import Path + +from scripts.advanced_merge import AdvancedCodeMerger +from pysymphony.auditor.auditor import ASTAuditor + + +class TestIssue37AuditFailures: + """测试合并后代码的静态审计问题""" + + def test_duplicate_mod_imports(self, tmp_path): + """测试重复的_mod导入问题""" + # 创建测试文件结构 + pkg_dir = tmp_path / "test_pkg" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("") + + # module_a.py - 导入json + module_a = pkg_dir / "module_a.py" + module_a.write_text(""" +import json + +def func_a(): + return json.dumps({"a": 1}) +""") + + # module_b.py - 也导入json + module_b = pkg_dir / "module_b.py" + module_b.write_text(""" +import json + +def func_b(): + return json.dumps({"b": 2}) +""") + + # main.py - 使用两个模块 + main_script = tmp_path / "main.py" + main_script.write_text(""" +from test_pkg.module_a import func_a +from test_pkg.module_b import func_b + +def main(): + print(func_a()) + print(func_b()) + +if __name__ == "__main__": + main() +""") + + # 执行合并 + merger = AdvancedCodeMerger(tmp_path) + merged_content = merger.merge_script(main_script) + + # 解析合并后的AST + merged_ast = ast.parse(merged_content) + + # 使用ASTAuditor检查 + auditor = ASTAuditor() + audit_result = auditor.audit(merged_ast) + + # 应该通过审计(修复后) + assert audit_result is True, f"ASTAuditor发现错误:\n{auditor.get_report()}" + + # 检查没有重复的_mod导入 + imports = [node for node in merged_ast.body if isinstance(node, ast.Import)] + json_mod_count = sum(1 for imp in imports + for alias in imp.names + if alias.asname and 'json' in alias.asname and '_mod' in alias.asname) + assert json_mod_count <= 1, f"发现{json_mod_count}个json__mod导入" + + def test_undefined_loop_variables(self, tmp_path): + """测试循环变量未定义问题""" + # 创建测试模块 + module = tmp_path / "loop_module.py" + module.write_text(""" +def process_items(items): + result = [] + for i, item in enumerate(items): + # 使用循环变量 + result.append((i, item)) + + # 列表推导式中的变量 + squares = [x**2 for x in range(10)] + + return result, squares +""") + + # 主脚本 + main_script = tmp_path / "main_loop.py" + main_script.write_text(""" +from loop_module import process_items + +items = ["a", "b", "c"] +result, squares = process_items(items) +print(result) +print(squares) +""") + + # 执行合并 + merger = AdvancedCodeMerger(tmp_path) + merged_content = merger.merge_script(main_script) + merged_ast = ast.parse(merged_content) + + # 审计检查 + auditor = ASTAuditor() + audit_result = auditor.audit(merged_ast) + + assert audit_result is True, f"循环变量应该被正确识别:\n{auditor.get_report()}" + + def test_store_target_mapping(self, tmp_path): + """测试赋值目标的名称映射""" + # 创建有名称冲突的模块 + module_a = tmp_path / "mod_a.py" + module_a.write_text(""" +x = 100 +y = 200 + +def get_x(): + return x + +def set_x(value): + global x + x = value +""") + + module_b = tmp_path / "mod_b.py" + module_b.write_text(""" +x = 300 # 同名变量 +y = 400 + +def get_x_b(): + return x +""") + + main_script = tmp_path / "main_store.py" + main_script.write_text(""" +from mod_a import get_x, set_x +from mod_b import get_x_b + +print(get_x()) +print(get_x_b()) +set_x(500) +print(get_x()) +""") + + # 执行合并 + merger = AdvancedCodeMerger(tmp_path) + merged_content = merger.merge_script(main_script) + merged_ast = ast.parse(merged_content) + + # 审计检查 + auditor = ASTAuditor() + audit_result = auditor.audit(merged_ast) + + assert audit_result is True, f"Store目标应该被正确映射:\n{auditor.get_report()}" + + def test_external_import_preservation(self, tmp_path): + """测试外部导入的保留""" + # 创建使用多个外部库的模块 + module = tmp_path / "external_module.py" + module.write_text(""" +import os +import sys +import json +from typing import List, Dict +from pathlib import Path + +def get_cwd(): + return os.getcwd() + +def get_python_version(): + return sys.version + +def save_json(data: Dict, path: Path): + with open(path, 'w') as f: + json.dump(data, f) +""") + + main_script = tmp_path / "main_external.py" + main_script.write_text(""" +from external_module import get_cwd, save_json +from pathlib import Path + +cwd = get_cwd() +save_json({"cwd": cwd}, Path("output.json")) +""") + + # 执行合并 + merger = AdvancedCodeMerger(tmp_path) + merged_content = merger.merge_script(main_script) + merged_ast = ast.parse(merged_content) + + # 审计检查 + auditor = ASTAuditor() + audit_result = auditor.audit(merged_ast) + + assert audit_result is True, f"外部导入应该被正确处理:\n{auditor.get_report()}" + + # 验证外部导入存在 + import_names = set() + for node in merged_ast.body: + if isinstance(node, ast.Import): + for alias in node.names: + import_names.add(alias.name) + elif isinstance(node, ast.ImportFrom): + import_names.add(node.module) + + assert 'os' in import_names or any('os' in name for name in import_names) + assert 'sys' in import_names or any('sys' in name for name in import_names) + assert 'json' in import_names or any('json' in name for name in import_names) + + def test_complex_real_world_scenario(self, tmp_path): + """测试复杂的真实场景""" + # 创建一个模拟真实项目的结构 + utils_dir = tmp_path / "utils" + utils_dir.mkdir() + (utils_dir / "__init__.py").write_text("") + + # config.py + (utils_dir / "config.py").write_text(""" +import json +import os + +CONFIG_FILE = "config.json" + +def load_config(): + if os.path.exists(CONFIG_FILE): + with open(CONFIG_FILE) as f: + return json.load(f) + return {} + +config = load_config() +""") + + # logger.py + (utils_dir / "logger.py").write_text(""" +import logging +import sys + +def setup_logger(name): + logger = logging.getLogger(name) + handler = logging.StreamHandler(sys.stdout) + logger.addHandler(handler) + return logger + +logger = setup_logger("app") +""") + + # data.py + (utils_dir / "data.py").write_text(""" +import json +from typing import List, Dict + +def process_data(items: List[Dict]) -> List[Dict]: + result = [] + for i, item in enumerate(items): + # 处理每个项目 + processed = { + "index": i, + "original": item, + "processed": True + } + result.append(processed) + + # 使用列表推导式 + values = [x["value"] for x in items if "value" in x] + + return result + +# 模块级变量 +data_cache = {} +""") + + # main.py + main_script = tmp_path / "main_complex.py" + main_script.write_text(""" +from utils.config import config +from utils.logger import logger +from utils.data import process_data, data_cache + +def main(): + # 使用配置 + logger.info(f"Config: {config}") + + # 处理数据 + items = [{"value": i} for i in range(5)] + result = process_data(items) + + # 更新缓存 + data_cache["result"] = result + + logger.info(f"Processed {len(result)} items") + +if __name__ == "__main__": + main() +""") + + # 执行合并 + merger = AdvancedCodeMerger(tmp_path) + merged_content = merger.merge_script(main_script) + merged_ast = ast.parse(merged_content) + + # 最关键的测试:ASTAuditor审计 + auditor = ASTAuditor() + audit_result = auditor.audit(merged_ast) + + # 断言没有错误 + assert audit_result is True, f"复杂场景审计失败:\n{auditor.get_report()}" + + # 额外验证:编译检查 + try: + compile(merged_ast, "", "exec") + except Exception as e: + pytest.fail(f"合并后的代码无法编译: {e}") \ No newline at end of file