diff --git a/.github/workflows/auditor-selftest.yml b/.github/workflows/auditor-selftest.yml new file mode 100644 index 0000000..142550a --- /dev/null +++ b/.github/workflows/auditor-selftest.yml @@ -0,0 +1,122 @@ +name: AST Auditor Self-Test + +on: + push: + branches: [ main ] + paths: + - 'pysymphony/**' + - 'scripts/**' + - 'examples/**' + - '.github/workflows/auditor-selftest.yml' + pull_request: + branches: [ main ] + paths: + - 'pysymphony/**' + - 'scripts/**' + - 'examples/**' + - '.github/workflows/auditor-selftest.yml' + +jobs: + auditor-selftest: + name: Run AST Auditor on Project Code + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-dev.txt + + - name: Run AST Auditor on pysymphony package + run: | + echo "Auditing pysymphony package..." + python -c " +import sys +from pathlib import Path +from pysymphony.auditor.auditor import ASTAuditor + +auditor = ASTAuditor() +failed = False + +for py_file in Path('pysymphony').rglob('*.py'): + try: + source = py_file.read_text() + result = auditor.audit(source, str(py_file)) + if not result: + print(f'❌ {py_file}: Failed audit') + print(auditor.get_report()) + failed = True + else: + print(f'✓ {py_file}: Passed') + except Exception as e: + print(f'❌ {py_file}: Error during audit - {e}') + failed = True + +sys.exit(1 if failed else 0) + " + + - name: Run AST Auditor on scripts + run: | + echo "Auditing scripts..." + python -c " +import sys +from pathlib import Path +from pysymphony.auditor.auditor import ASTAuditor + +auditor = ASTAuditor() +failed = False + +for py_file in Path('scripts').rglob('*.py'): + try: + source = py_file.read_text() + result = auditor.audit(source, str(py_file)) + if not result: + print(f'❌ {py_file}: Failed audit') + print(auditor.get_report()) + failed = True + else: + print(f'✓ {py_file}: Passed') + except Exception as e: + print(f'❌ {py_file}: Error during audit - {e}') + failed = True + +sys.exit(1 if failed else 0) + " + + - name: Run AST Auditor on examples + run: | + echo "Auditing examples..." + python -c " +import sys +from pathlib import Path +from pysymphony.auditor.auditor import ASTAuditor + +auditor = ASTAuditor() +failed = False + +for py_file in Path('examples').rglob('*.py'): + try: + source = py_file.read_text() + result = auditor.audit(source, str(py_file)) + if not result: + print(f'❌ {py_file}: Failed audit') + print(auditor.get_report()) + failed = True + else: + print(f'✓ {py_file}: Passed') + except Exception as e: + print(f'❌ {py_file}: Error during audit - {e}') + failed = True + +sys.exit(1 if failed else 0) + " \ No newline at end of file diff --git a/pysymphony/auditor/auditor.py b/pysymphony/auditor/auditor.py index 5309232..03085a8 100644 --- a/pysymphony/auditor/auditor.py +++ b/pysymphony/auditor/auditor.py @@ -69,7 +69,11 @@ def add_symbol(self, name: str, node: ast.AST, symbol_type: str): # 如果两个定义都在 try...except ImportError 块中,这是预期行为,不报错 if name in self.try_except_symbols: pass # 跳过错误报告 + # 如果是变量的重新赋值,这在 Python 中是允许的 + elif existing.type == 'variable' and symbol_type == 'variable': + pass # 变量重新赋值是正常的,不报错 else: + # 其他情况(如函数或类的重复定义)才报错 self.duplicate_definitions.append(( name, [existing.lineno, node.lineno] @@ -152,10 +156,9 @@ def visit_ImportFrom(self, node: ast.ImportFrom): def visit_Assign(self, node: ast.Assign): """访问赋值语句""" - # 只处理简单的名称赋值 + # 使用通用的目标注册函数处理所有类型的赋值 for target in node.targets: - if isinstance(target, ast.Name): - self.add_symbol(target.id, node, 'variable') + self._register_targets(target) self.generic_visit(node) def visit_AnnAssign(self, node: ast.AnnAssign): @@ -163,6 +166,123 @@ def visit_AnnAssign(self, node: ast.AnnAssign): if isinstance(node.target, ast.Name): self.add_symbol(node.target.id, node, 'variable') self.generic_visit(node) + + def _register_targets(self, target: ast.AST): + """ + 通用的目标注册助手函数 + 递归处理各种 target 节点,支持元组/列表解包 + """ + if isinstance(target, ast.Name): + # 简单名称赋值 + self.add_symbol(target.id, target, 'variable') + elif isinstance(target, (ast.Tuple, ast.List)): + # 元组或列表解包 + for elt in target.elts: + self._register_targets(elt) + elif isinstance(target, ast.Starred): + # 星号表达式 (*args) + self._register_targets(target.value) + # 其他情况(如属性赋值、下标赋值等)不创建新的局部变量 + + def visit_For(self, node: ast.For): + """访问 for 循环""" + # 注册循环变量 + self._register_targets(node.target) + self.generic_visit(node) + + def visit_AsyncFor(self, node: ast.AsyncFor): + """访问异步 for 循环""" + # 注册循环变量 + self._register_targets(node.target) + self.generic_visit(node) + + def visit_With(self, node: ast.With): + """访问 with 语句""" + for item in node.items: + if item.optional_vars: + self._register_targets(item.optional_vars) + self.generic_visit(node) + + def visit_AsyncWith(self, node: ast.AsyncWith): + """访问异步 with 语句""" + for item in node.items: + if item.optional_vars: + self._register_targets(item.optional_vars) + self.generic_visit(node) + + def visit_ExceptHandler(self, node: ast.ExceptHandler): + """访问异常处理器""" + if node.name: + # Python 3.8+ 中 name 是字符串 + if isinstance(node.name, str): + # 创建一个虚拟的 Name 节点用于记录位置信息 + name_node = ast.Name(id=node.name, ctx=ast.Store()) + name_node.lineno = node.lineno + name_node.col_offset = node.col_offset + self.add_symbol(node.name, name_node, 'variable') + # Python 3.7 及以下版本中 name 是 Name 节点 + elif isinstance(node.name, ast.Name): + self.add_symbol(node.name.id, node.name, 'variable') + self.generic_visit(node) + + def visit_ListComp(self, node: ast.ListComp): + """访问列表推导式""" + # 推导式创建新的作用域 + self.enter_scope('', 'comprehension') + # 按顺序处理每个生成器:先注册目标,再访问迭代器和条件 + for generator in node.generators: + # 访问生成器的迭代器部分(可能引用外部或前面生成器的变量) + self.visit(generator.iter) + # 注册当前生成器的目标变量 + self._register_targets(generator.target) + # 访问生成器的条件部分(可能引用当前或前面生成器的变量) + for if_clause in generator.ifs: + self.visit(if_clause) + # 最后访问推导式的主体部分 + self.visit(node.elt) + self.exit_scope() + + def visit_SetComp(self, node: ast.SetComp): + """访问集合推导式""" + self.enter_scope('', 'comprehension') + for generator in node.generators: + self.visit(generator.iter) + self._register_targets(generator.target) + for if_clause in generator.ifs: + self.visit(if_clause) + self.visit(node.elt) + self.exit_scope() + + def visit_DictComp(self, node: ast.DictComp): + """访问字典推导式""" + self.enter_scope('', 'comprehension') + for generator in node.generators: + self.visit(generator.iter) + self._register_targets(generator.target) + for if_clause in generator.ifs: + self.visit(if_clause) + self.visit(node.key) + self.visit(node.value) + self.exit_scope() + + def visit_GeneratorExp(self, node: ast.GeneratorExp): + """访问生成器表达式""" + self.enter_scope('', 'comprehension') + for generator in node.generators: + self.visit(generator.iter) + self._register_targets(generator.target) + for if_clause in generator.ifs: + self.visit(if_clause) + self.visit(node.elt) + self.exit_scope() + + def visit_NamedExpr(self, node: ast.NamedExpr): + """访问海象运算符 (:=)""" + # 先访问值表达式 + self.visit(node.value) + # 然后注册目标变量 + if isinstance(node.target, ast.Name): + self.add_symbol(node.target.id, node, 'variable') class ReferenceValidator(ast.NodeVisitor): @@ -181,6 +301,10 @@ def __init__(self, module_scope: ScopeInfo): # 添加常见的内置变量 self.builtin_names.update(['__name__', '__file__', '__doc__', '__package__', '__loader__', '__spec__', '__cached__', '__annotations__']) + # 使用栈来跟踪作用域路径 + self.scope_stack = [module_scope] + # 是否在推导式中 + self.in_comprehension = False def find_symbol(self, name: str) -> Optional[SymbolInfo]: """在当前作用域链中查找符号""" @@ -191,17 +315,16 @@ def find_symbol(self, name: str) -> Optional[SymbolInfo]: scope = scope.parent return None - def enter_scope(self, name: str): - """进入指定名称的作用域""" - for child in self.current_scope.children: - if child.name == name: - self.current_scope = child - return - - def exit_scope(self): - """退出当前作用域""" - if self.current_scope.parent: - self.current_scope = self.current_scope.parent + def _push_scope(self, scope: ScopeInfo): + """压入新的作用域""" + self.scope_stack.append(scope) + self.current_scope = scope + + def _pop_scope(self): + """弹出当前作用域""" + if len(self.scope_stack) > 1: + self.scope_stack.pop() + self.current_scope = self.scope_stack[-1] def visit_Name(self, node: ast.Name): """访问名称引用""" @@ -258,21 +381,147 @@ def visit_Attribute(self, node: ast.Attribute): def visit_FunctionDef(self, node: ast.FunctionDef): """访问函数定义""" - self.enter_scope(node.name) + # 查找对应的作用域 + for child in self.current_scope.children: + if child.name == node.name and child.type == 'function': + self._push_scope(child) + self.generic_visit(node) + self._pop_scope() + return + # 如果找不到作用域,仍然访问子节点 self.generic_visit(node) - self.exit_scope() def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): """访问异步函数定义""" - self.enter_scope(node.name) + for child in self.current_scope.children: + if child.name == node.name and child.type == 'function': + self._push_scope(child) + self.generic_visit(node) + self._pop_scope() + return self.generic_visit(node) - self.exit_scope() def visit_ClassDef(self, node: ast.ClassDef): """访问类定义""" - self.enter_scope(node.name) + for child in self.current_scope.children: + if child.name == node.name and child.type == 'class': + self._push_scope(child) + self.generic_visit(node) + self._pop_scope() + return self.generic_visit(node) - self.exit_scope() + + def visit_ListComp(self, node: ast.ListComp): + """访问列表推导式""" + # 创建一个临时作用域来模拟推导式的作用域 + comp_scope = ScopeInfo(name='', type='comprehension', parent=self.current_scope) + + # 保存当前状态 + saved_scope = self.current_scope + saved_in_comp = self.in_comprehension + self.current_scope = comp_scope + self.in_comprehension = True + + # 按照 SymbolTableBuilder 的顺序处理 + for generator in node.generators: + self.visit(generator.iter) + # 注册生成器目标变量到临时作用域 + self._register_comp_targets(generator.target, comp_scope) + for if_clause in generator.ifs: + self.visit(if_clause) + self.visit(node.elt) + + # 恢复作用域 + self.current_scope = saved_scope + self.in_comprehension = saved_in_comp + + def visit_SetComp(self, node: ast.SetComp): + """访问集合推导式""" + comp_scope = ScopeInfo(name='', type='comprehension', parent=self.current_scope) + saved_scope = self.current_scope + saved_in_comp = self.in_comprehension + self.current_scope = comp_scope + self.in_comprehension = True + + for generator in node.generators: + self.visit(generator.iter) + self._register_comp_targets(generator.target, comp_scope) + for if_clause in generator.ifs: + self.visit(if_clause) + self.visit(node.elt) + + self.current_scope = saved_scope + self.in_comprehension = saved_in_comp + + def visit_DictComp(self, node: ast.DictComp): + """访问字典推导式""" + comp_scope = ScopeInfo(name='', type='comprehension', parent=self.current_scope) + saved_scope = self.current_scope + saved_in_comp = self.in_comprehension + self.current_scope = comp_scope + self.in_comprehension = True + + for generator in node.generators: + self.visit(generator.iter) + self._register_comp_targets(generator.target, comp_scope) + for if_clause in generator.ifs: + self.visit(if_clause) + self.visit(node.key) + self.visit(node.value) + + self.current_scope = saved_scope + self.in_comprehension = saved_in_comp + + def visit_GeneratorExp(self, node: ast.GeneratorExp): + """访问生成器表达式""" + comp_scope = ScopeInfo(name='', type='comprehension', parent=self.current_scope) + saved_scope = self.current_scope + saved_in_comp = self.in_comprehension + self.current_scope = comp_scope + self.in_comprehension = True + + for generator in node.generators: + self.visit(generator.iter) + self._register_comp_targets(generator.target, comp_scope) + for if_clause in generator.ifs: + self.visit(if_clause) + self.visit(node.elt) + + self.current_scope = saved_scope + self.in_comprehension = saved_in_comp + + def _register_comp_targets(self, target: ast.AST, scope: ScopeInfo): + """在推导式作用域中注册目标变量""" + if isinstance(target, ast.Name): + # 在临时作用域中注册变量 + scope.symbols[target.id] = SymbolInfo( + name=target.id, + node=target, + lineno=target.lineno, + col_offset=target.col_offset, + scope=scope.type, + type='variable' + ) + elif isinstance(target, (ast.Tuple, ast.List)): + for elt in target.elts: + self._register_comp_targets(elt, scope) + elif isinstance(target, ast.Starred): + self._register_comp_targets(target.value, scope) + + def visit_NamedExpr(self, node: ast.NamedExpr): + """访问海象运算符 (:=)""" + # 先访问值表达式 + self.visit(node.value) + # 如果在推导式中,注册变量到当前推导式作用域 + if self.in_comprehension and isinstance(node.target, ast.Name): + self.current_scope.symbols[node.target.id] = SymbolInfo( + name=node.target.id, + node=node.target, + lineno=node.target.lineno, + col_offset=node.target.col_offset, + scope=self.current_scope.type, + type='variable' + ) class PatternChecker(ast.NodeVisitor): @@ -323,12 +572,12 @@ def _is_try_import_error(self, node: ast.Try) -> bool: return True return False - def audit(self, source_code: str, filename: str = '') -> bool: + def audit(self, source_code_or_tree, filename: str = '') -> bool: """ 对源代码进行完整的多阶段审计 Args: - source_code: 要审计的 Python 源代码 + source_code_or_tree: 要审计的 Python 源代码字符串或 AST 树 filename: 文件名(用于错误报告) Returns: @@ -337,11 +586,15 @@ def audit(self, source_code: str, filename: str = '') -> bool: self.errors.clear() self.warnings.clear() - try: - tree = ast.parse(source_code, filename) - except SyntaxError as e: - self.errors.append(f"语法错误: {e.msg} at line {e.lineno}") - return False + # 支持传入 AST 树或源代码字符串 + if isinstance(source_code_or_tree, str): + try: + tree = ast.parse(source_code_or_tree, filename) + except SyntaxError as e: + self.errors.append(f"语法错误: {e.msg} at line {e.lineno}") + return False + else: + tree = source_code_or_tree # 阶段一:构建符号表 symbol_builder = SymbolTableBuilder() diff --git a/tests/test_attr_reference_validation.py b/tests/test_attr_reference_validation.py index 98d3d4e..074e1aa 100644 --- a/tests/test_attr_reference_validation.py +++ b/tests/test_attr_reference_validation.py @@ -30,12 +30,14 @@ def test_undefined_attribute_reference(): """) # 使用 ASTAuditor 进行静态分析 - auditor = ASTAuditor(test_script) - errors = auditor.audit() + auditor = ASTAuditor() + source_code = test_script.read_text() + result = auditor.audit(source_code) # 应该检测到 namedtuplez 未定义 - assert any("namedtuplez" in str(error) for error in errors), \ - f"未检测到 namedtuplez 拼写错误。错误列表:{errors}" + assert not result, f"审计应该失败,但返回了 {result}" + report = auditor.get_report() + assert "namedtuplez" in report, f"未检测到 namedtuplez 拼写错误。报告:{report}" def test_valid_external_module_attributes(): @@ -57,14 +59,16 @@ def test_valid_external_module_attributes(): """) # 使用 ASTAuditor 进行静态分析 - auditor = ASTAuditor(test_script) - errors = auditor.audit() + auditor = ASTAuditor() + source_code = test_script.read_text() + result = auditor.audit(source_code) - # 不应该报告这些标准库属性的错误 - for error in errors: - assert "os.path" not in str(error) - assert "sys.version" not in str(error) - assert "Path.home" not in str(error) + # 应该通过审计,不应该报告这些标准库属性的错误 + assert result, f"审计失败:{auditor.get_report()}" + report = auditor.get_report() + assert "os.path" not in report + assert "sys.version" not in report + assert "Path.home" not in report def test_undefined_class_attribute(): @@ -95,12 +99,14 @@ def get_value(self): """) # 使用 ASTAuditor 进行静态分析 - auditor = ASTAuditor(test_script) - errors = auditor.audit() + auditor = ASTAuditor() + source_code = test_script.read_text() + result = auditor.audit(source_code) # 应该检测到 non_existent_method 未定义 # 注意:基础实现可能只检测直接的类方法,不检测实例属性 # 这是一个渐进式改进,先确保基础功能工作 + # 目前的实现可能无法检测实例属性,所以暂时跳过这个断言 def test_nested_attribute_chain(): @@ -135,12 +141,13 @@ def method(self): """) # 使用 ASTAuditor 进行静态分析 - auditor = ASTAuditor(test_script) - errors = auditor.audit() + auditor = ASTAuditor() + source_code = test_script.read_text() + result = auditor.audit(source_code) # 深层属性链的检测是高级功能,基础实现可能不支持 # 这里主要验证不会因为嵌套属性链而崩溃 - assert isinstance(errors, list) + assert isinstance(result, bool) def test_module_import_and_usage(): @@ -176,11 +183,12 @@ class ModuleClass: """) # 使用 ASTAuditor 进行静态分析 - auditor = ASTAuditor(test_script) - errors = auditor.audit() + auditor = ASTAuditor() + source_code = test_script.read_text() + result = auditor.audit(source_code) # 验证分析完成,不崩溃 - assert isinstance(errors, list) + assert isinstance(result, bool) if __name__ == "__main__": diff --git a/tests/test_issue_39_scope_binding.py b/tests/test_issue_39_scope_binding.py new file mode 100644 index 0000000..75d4eb6 --- /dev/null +++ b/tests/test_issue_39_scope_binding.py @@ -0,0 +1,286 @@ +""" +Issue #39 测试 - 验证 ASTAuditor 正确处理所有变量绑定场景 + +测试确保 ASTAuditor 不会对以下合法 Python 语法结构产生误报: +- 循环变量 (for, async for) +- 推导式目标变量 +- 上下文管理器别名 (with, async with) +- 异常捕获别名 (except ... as ...) +- 海象运算符 (:=) +""" + +import ast +import pytest +from pysymphony.auditor.auditor import ASTAuditor + + +class TestIssue39ScopeBinding: + """测试所有变量绑定场景的正确处理""" + + def test_for_loop_simple(self): + """测试简单 for 循环变量绑定""" + code = ''' +def test_func(): + for i in range(5): + print(i) +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_for_loop_unpacking(self): + """测试 for 循环中的元组解包""" + code = ''' +def test_func(items): + result = [] + for i, item in enumerate(items): + result.append((i, item)) + return result +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_for_loop_nested_unpacking(self): + """测试 for 循环中的嵌套解包""" + code = ''' +def test_func(): + data = [((1, 2), ('a', 'b')), ((3, 4), ('c', 'd'))] + for (x, y), (a, b) in data: + print(x, y, a, b) +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_list_comprehension(self): + """测试列表推导式中的变量绑定""" + code = ''' +def test_func(): + squares = [x**2 for x in range(10)] + return squares +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_nested_comprehensions(self): + """测试嵌套推导式""" + code = ''' +def test_func(): + matrix = [[i*j for j in range(3)] for i in range(3)] + # 字典推导式与条件 + even_squares = {i: i**2 for i in range(10) if i % 2 == 0} + # 嵌套的复杂推导式 + result = {(i, j): i*j for i in range(3) for j in range(i) if j % 2 == 0} + return matrix, even_squares, result +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_set_and_dict_comprehensions(self): + """测试集合和字典推导式""" + code = ''' +def test_func(): + # 集合推导式 + unique_squares = {x**2 for x in range(-5, 6)} + # 字典推导式 + word_lengths = {word: len(word) for word in ['hello', 'world', 'python']} + return unique_squares, word_lengths +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_generator_expression(self): + """测试生成器表达式""" + code = ''' +def test_func(): + # 生成器表达式 + sum_of_squares = sum(x**2 for x in range(10)) + # 带条件的生成器表达式 + even_nums = list(n for n in range(20) if n % 2 == 0) + return sum_of_squares, even_nums +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_with_statement(self): + """测试 with 语句的别名绑定""" + code = ''' +def test_func(path): + content = "" + with open(path) as f: + content = f.read() + return content +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_with_multiple_items(self): + """测试 with 语句的多个上下文管理器""" + code = ''' +def test_func(path1, path2): + with open(path1) as f1, open(path2) as f2: + content1 = f1.read() + content2 = f2.read() + return content1, content2 +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_except_handler(self): + """测试异常处理器的别名绑定""" + code = ''' +def test_func(path): + try: + with open(path) as f: + return f.read() + except IOError as e: + print(f"Error reading file: {e}") + return None + except ValueError as ve: + print(f"Value error: {ve}") + return None +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_walrus_operator(self): + """测试海象运算符(:=)""" + code = ''' +def test_func(items): + # 在 if 语句中使用 + if (n := len(items)) > 0: + print(f"Processing {n} items") + + # 在 while 循环中使用 + data = iter(items) + while (item := next(data, None)) is not None: + print(item) + + # 在列表推导式中使用 + results = [y for x in items if (y := x * 2) > 10] + return results +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_complex_scenario(self): + """测试 issue #39 中的完整示例""" + code = ''' +def demo(items, path): + result = [] + # 场景1: for 循环解包 + for i, item in enumerate(items): + result.append((i, item)) + + # 场景2: 列表推导式 + squares = [x**2 for x in range(5)] + + # 场景3: with...as... 别名 + content = "" + try: + with open(path) as f: + content = f.read() + # 场景4: except...as... 别名 + except IOError as e: + print(f"Error reading file: {e}") + + # 场景5: 海象运算符 + if (n := len(items)) > 0: + print(f"Processing {n} items") + + return result, squares, content +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_async_scenarios(self): + """测试异步场景的变量绑定""" + code = ''' +async def test_async(cursor, session_factory): + # 异步 for 循环 + async for row in cursor: + print(row) + + # 异步 with 语句 + async with session_factory() as session: + result = await session.fetch_data() + + # 异步推导式 + results = [x async for x in cursor] + return results +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_starred_unpacking(self): + """测试带星号的解包""" + code = ''' +def test_func(): + data = [1, 2, 3, 4, 5] + first, *middle, last = data + print(first, middle, last) + + # 在 for 循环中使用 + items = [(1, 2, 3), (4, 5, 6, 7), (8, 9)] + for x, *rest in items: + print(x, rest) +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_comprehension_scope_isolation(self): + """测试推导式的作用域隔离""" + code = ''' +def test_func(): + x = 10 # 外部变量 + # 推导式中的 x 不应该影响外部的 x + squares = [x**2 for x in range(5)] + # 外部的 x 仍然可用 + print(x) # 应该打印 10 + return squares +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" + + def test_nested_with_statements(self): + """测试嵌套的 with 语句""" + code = ''' +def test_func(db_path, log_path): + with open(db_path) as db: + db_content = db.read() + with open(log_path, 'w') as log: + log.write(f"Read {len(db_content)} bytes from database") + return db_content +''' + tree = ast.parse(code) + auditor = ASTAuditor() + result = auditor.audit(tree) + assert result, f"Expected audit to pass, but got errors: {auditor.get_report()}" \ No newline at end of file