Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions pysymphony/auditor/auditor.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,34 @@ def visit_FunctionDef(self, node: ast.FunctionDef):
# 添加参数到函数作用域
for arg in node.args.args:
self.add_symbol(arg.arg, arg, 'variable')
# 添加 *args
if node.args.vararg:
self.add_symbol(node.args.vararg.arg, node.args.vararg, 'variable')
# 添加 **kwargs
if node.args.kwarg:
self.add_symbol(node.args.kwarg.arg, node.args.kwarg, 'variable')
# 添加仅关键字参数
for arg in node.args.kwonlyargs:
self.add_symbol(arg.arg, arg, 'variable')
self.generic_visit(node)
self.exit_scope()

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
"""访问异步函数定义"""
self.add_symbol(node.name, node, 'function')
self.enter_scope(node.name, 'function')
# 添加参数到函数作用域
for arg in node.args.args:
self.add_symbol(arg.arg, arg, 'variable')
# 添加 *args
if node.args.vararg:
self.add_symbol(node.args.vararg.arg, node.args.vararg, 'variable')
# 添加 **kwargs
if node.args.kwarg:
self.add_symbol(node.args.kwarg.arg, node.args.kwarg, 'variable')
# 添加仅关键字参数
for arg in node.args.kwonlyargs:
self.add_symbol(arg.arg, arg, 'variable')
self.generic_visit(node)
self.exit_scope()

Expand Down
176 changes: 147 additions & 29 deletions scripts/advanced_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,8 @@ def visit_Module(self, node: ast.Module):
init_statements.append(stmt)
elif isinstance(stmt, ast.Try) and self._is_try_import_error(stmt):
# 特殊处理 try...except ImportError 块
# 保留完整的块作为初始化语句
init_statements.append(stmt)
# 访问以分析内部的导入
# 不再将运行时导入块保留为初始化语句
# 仅访问以分析内部的导入,这些导入会被转换为普通导入
self.visit(stmt)
else:
# 其他顶层语句(副作用初始化)
Expand Down Expand Up @@ -1059,6 +1058,46 @@ def __init__(self, project_root: Path):
self.import_registry: Set[Tuple[str, str]] = set() # 解决 #2: (module, alias)
self.entry_module_qname: Optional[str] = None # 解决 #3: 入口脚本的模块名

def _fix_import_alias_dependencies(self):
"""修复 import_alias 符号的依赖关系

在所有模块都被分析后,重新建立 import_alias 到实际符号的依赖关系。
这解决了循环导入时依赖关系缺失的问题。
"""
for symbol in list(self.visitor.all_symbols.values()):
if symbol.symbol_type != 'import_alias':
continue

# 检查是否是 from ... import ... 类型的导入
if isinstance(symbol.def_node, ast.ImportFrom):
node = symbol.def_node
module_name = node.module or ''
level = node.level or 0

# 解析模块路径
if level > 0:
# 相对导入
from_module = self.visitor.get_absolute_module_name(
module_name, symbol.scope.module_path, level
)
else:
from_module = module_name

# 尝试找到模块路径
module_path = self.visitor.resolve_module_path(from_module)
if not module_path:
continue

# 查找具体的导入项
for alias in node.names:
if alias.asname == symbol.name or (not alias.asname and alias.name == symbol.name):
# 找到对应的符号
target_qname = f"{self.visitor.get_module_qname(module_path)}.{alias.name}"
if target_qname in self.visitor.all_symbols:
target_symbol = self.visitor.all_symbols[target_qname]
# 建立依赖关系
symbol.dependencies.add(target_symbol)

def analyze_entry_script(self, script_path: Path) -> Tuple[Set[Symbol], List[ast.AST]]:
"""
分析入口脚本,返回初始符号集和主代码。
Expand All @@ -1067,6 +1106,9 @@ def analyze_entry_script(self, script_path: Path) -> Tuple[Set[Symbol], List[ast
# 1. 执行唯一且完整的分析过程,此过程会填充 visitor 的所有状态
self.visitor.analyze_module(script_path)
self.entry_module_qname = self.visitor.get_module_qname(script_path) # 记录入口模块名

# 修复 import_alias 的依赖关系
self._fix_import_alias_dependencies()

initial_symbols = set()
main_code = []
Expand Down Expand Up @@ -1286,9 +1328,29 @@ def topological_sort(self, symbols: Set[Symbol]) -> List[Symbol]:

for symbol in symbols:
in_degree[symbol] = 0

# 解析传递依赖:如果符号依赖一个import_alias,需要找到该import_alias的实际目标
def resolve_transitive_deps(symbol: Symbol) -> Set[Symbol]:
"""解析符号的传递依赖,展开import_alias"""
resolved_deps = set()
for dep in symbol.dependencies:
if dep.symbol_type == 'import_alias' and dep.dependencies:
# 如果依赖是import_alias,找到它指向的实际符号
for target in dep.dependencies:
if target.symbol_type in ('function', 'class', 'variable'):
resolved_deps.add(target)
elif target.symbol_type == 'import_alias':
# 递归处理import_alias链
resolved_deps.update(resolve_transitive_deps(target))
else:
resolved_deps.add(dep)
return resolved_deps

for symbol in symbols:
for dep in symbol.dependencies:
# 获取解析后的依赖
resolved_deps = resolve_transitive_deps(symbol)

for dep in resolved_deps:
if dep in symbols and dep != symbol: # 忽略自引用
graph[dep].add(symbol)
in_degree[symbol] += 1
Expand Down Expand Up @@ -1478,7 +1540,7 @@ def _collect_and_reinject_imports(self, output_symbols: Set[Symbol]) -> List[str
qname.startswith(module_qname + '.') and
symbol not in all_import_aliases):
# 只添加那些属于已处理模块的导入别名
if any(s.scope.module_path == module_path for s in self.needed_symbols):
if any(s.scope and s.scope.module_path == module_path for s in self.needed_symbols):
all_import_aliases.append(symbol)

# 遍历所有收集到的导入别名
Expand Down Expand Up @@ -1634,26 +1696,30 @@ def generate_name_mappings(self, symbols: Set[Symbol]):

# 生成映射
for symbol in all_symbols_to_consider:
if name_counts[symbol.name] > 1:
# 有冲突,需要重命名(非 import_alias 的情况)
module_key = self.visitor.get_module_qname(symbol.scope.module_path)
module_key = module_key.replace('.', '_').replace('__init__', 'pkg')

# 对于运行时导入,添加特殊后缀以区分
# 对于 import_alias 符号,总是添加后缀(__mod 或 __rt)
if symbol.symbol_type == 'import_alias':
if symbol.is_runtime_import:
new_name = f"{symbol.name}__rt"
else:
# 对于函数和其他类型,使用模块前缀
new_name = f"{module_key}_{symbol.name}"

new_name = f"{symbol.name}__mod"
self.name_mappings[symbol.qname] = new_name
elif name_counts[symbol.name] > 1:
# 对于其他符号,只在有冲突时重命名
if symbol.scope and symbol.scope.module_path:
module_key = self.visitor.get_module_qname(symbol.scope.module_path)
module_key = module_key.replace('.', '_').replace('__init__', 'pkg')
else:
# 如果没有 scope,使用符号的 qname 前缀
module_key = symbol.qname.rsplit('.', 1)[0] if '.' in symbol.qname else 'unknown'
new_name = f"{module_key}_{symbol.name}"
self.name_mappings[symbol.qname] = new_name

# 同时为带类型后缀的版本添加映射
type_qname = f"{symbol.qname}#{symbol.symbol_type}"
if type_qname in self.visitor.all_symbols:
self.name_mappings[type_qname] = new_name
else:
# 无冲突,保持原名
# 无冲突的非导入符号,保持原名
self.name_mappings[symbol.qname] = symbol.name

# 同时为带类型后缀的版本添加映射
Expand Down Expand Up @@ -1698,8 +1764,11 @@ def _write_symbol(self, symbol: Symbol, transformer: 'AdvancedNodeTransformer',

# 检查是否是完全重复的定义
# 简化处理:如果名称已存在,则重命名
module_qname = self.visitor.get_module_qname(symbol.scope.module_path)
module_alias = module_qname.replace('.', '_')
if symbol.scope and symbol.scope.module_path:
module_qname = self.visitor.get_module_qname(symbol.scope.module_path)
module_alias = module_qname.replace('.', '_')
else:
module_alias = 'unknown'
new_name = f"{target_name}__from_{module_alias}"
self.name_mappings[symbol.qname] = new_name
target_name = new_name
Expand All @@ -1715,7 +1784,7 @@ def _write_symbol(self, symbol: Symbol, transformer: 'AdvancedNodeTransformer',

# 写入结果
if transformed is not None:
if symbol.scope.module_path:
if symbol.scope and symbol.scope.module_path:
rel_path = symbol.scope.module_path.relative_to(self.project_root)
result_lines.append(f"# From {rel_path}")
result_lines.append(ast.unparse(transformed))
Expand All @@ -1727,9 +1796,13 @@ def _process_imports(self, imports: Set[str]) -> List[str]:
修复:使用更精确的去重键,包含导入样式信息
"""
result = []
# 添加别名去重集合,避免相同的别名被多次定义
seen_aliases = set()

for imp in sorted(imports):
# 解析导入语句
new_alias = None # 初始化变量

if imp.startswith('from '):
# from X import Y as Z 或 from X import Y
parts = imp.split()
Expand All @@ -1748,9 +1821,11 @@ def _process_imports(self, imports: Set[str]) -> List[str]:
# from X import Y
module = parts[1]
name = parts[3]
# B2 修复:即使没有别名,也要添加 __mod 后缀
new_alias = f"{name}__mod"
new_imp = f"from {module} import {name} as {new_alias}"
# Issue #37 修复:包含导入样式在去重键中
key = ('from', module, name, name)
new_imp = imp
key = ('from', module, name, new_alias)
else:
# import X as Y 或 import X
parts = imp.split()
Expand All @@ -1774,8 +1849,9 @@ def _process_imports(self, imports: Set[str]) -> List[str]:
# Issue #37 修复:包含导入样式在去重键中
key = ('import', module, new_alias)

# 检查是否已存在
if key not in self.import_registry:
# 检查是否已存在(包括别名去重)
if key not in self.import_registry and new_alias not in seen_aliases:
seen_aliases.add(new_alias)
self.import_registry.add(key)
result.append(new_imp)

Expand All @@ -1802,7 +1878,7 @@ def merge_script(self, script_path: Path) -> str:
# Issue #37 修复:过滤掉入口模块中的变量定义
# 这些变量会作为主代码的一部分输出,不需要单独输出
if (s.symbol_type == 'variable' and
s.scope.module_path == script_path):
s.scope and s.scope.module_path == script_path):
continue

# 检查是否是类的方法(通过判断qname中是否包含类名)
Expand Down Expand Up @@ -1842,12 +1918,15 @@ def merge_script(self, script_path: Path) -> str:
# 5. 生成名称映射
self.generate_name_mappings(output_symbols)

# B2 修复:为所有 import_alias 符号添加 __mod 后缀映射
# B2 修复:为所有 import_alias 符号添加后缀映射
# 这包括那些被过滤掉不输出的外部导入
for symbol in self.visitor.all_symbols.values():
if symbol.symbol_type == 'import_alias' and symbol.qname not in self.name_mappings:
# 为所有导入别名添加 __mod 后缀
new_name = f"{symbol.name}__mod"
# 根据是否是运行时导入选择不同的后缀
if symbol.is_runtime_import:
new_name = f"{symbol.name}__rt"
else:
new_name = f"{symbol.name}__mod"
self.name_mappings[symbol.qname] = new_name

# 6. 生成代码
Expand All @@ -1865,6 +1944,37 @@ def merge_script(self, script_path: Path) -> str:
result_lines.extend(processed_imports)
result_lines.append("")

# 处理运行时导入(带 __rt 后缀)
runtime_imports = []
for symbol in self.visitor.all_symbols.values():
if symbol.symbol_type == 'import_alias' and symbol.is_runtime_import:
# 获取重命名后的名称
new_name = self.name_mappings.get(symbol.qname, symbol.name)

# 根据导入类型生成导入语句
if isinstance(symbol.def_node, ast.Import):
# import xxx as yyy 形式
for alias in symbol.def_node.names:
if alias.asname == symbol.name or (not alias.asname and alias.name.split('.')[0] == symbol.name):
# 如果原本没有别名且新名称不同,或者有别名但需要改名
if new_name != alias.name.split('.')[0]:
runtime_imports.append(f"import {alias.name} as {new_name}")
else:
runtime_imports.append(f"import {alias.name}")
break
elif isinstance(symbol.def_node, ast.ImportFrom):
# from xxx import yyy as zzz 形式
module = symbol.def_node.module or ''
for alias in symbol.def_node.names:
if alias.asname == symbol.name or (not alias.asname and alias.name == symbol.name):
runtime_imports.append(f"from {module} import {alias.name} as {new_name}")
break

if runtime_imports:
result_lines.append("# Runtime imports (originally in try...except ImportError blocks)")
result_lines.extend(sorted(set(runtime_imports)))
result_lines.append("")

# 收集并重新注入必要的导入别名
reinjected_imports = self._collect_and_reinject_imports(output_symbols)
if reinjected_imports:
Expand All @@ -1880,7 +1990,11 @@ def merge_script(self, script_path: Path) -> str:
self._write_symbol(symbol, transformer, result_lines)

# 收集模块初始化语句
module_qname = self.visitor.get_module_qname(symbol.scope.module_path)
if symbol.scope and symbol.scope.module_path:
module_qname = self.visitor.get_module_qname(symbol.scope.module_path)
else:
continue

if module_qname in self.visitor.all_symbols:
module_symbol = self.visitor.all_symbols[module_qname]
# 不要收集入口模块的初始化语句,因为它们已经在 main_code 中处理了
Expand All @@ -1900,7 +2014,7 @@ def merge_script(self, script_path: Path) -> str:

# 设置正确的模块作用域
module_symbol = self.visitor.all_symbols.get(module_qname)
if module_symbol and module_symbol.scope.module_path:
if module_symbol and module_symbol.scope and module_symbol.scope.module_path:
module_path = module_symbol.scope.module_path
if module_path in self.visitor.module_symbols:
module_scope = self.visitor.module_symbols[module_path].get('__scope__')
Expand Down Expand Up @@ -2007,7 +2121,11 @@ def transform_symbol(self, symbol: Symbol) -> ast.AST:
return None

# 设置当前作用域栈
self.current_scope_stack = [symbol.scope]
if symbol.scope:
self.current_scope_stack = [symbol.scope]
else:
# 如果没有 scope,使用默认的模块作用域
self.current_scope_stack = []

# 转换节点
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
Expand Down
Loading
Loading