Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 127 additions & 14 deletions scripts/advanced_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,15 @@ def visit_Module(self, node: ast.Module):
if isinstance(stmt, (ast.Import, ast.ImportFrom)):
# 处理导入
self.visit(stmt)
elif isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef,
ast.ClassDef, ast.Assign, ast.AnnAssign)):
# 定义语句
elif isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
# 函数和类定义
self.visit(stmt)
elif isinstance(stmt, (ast.Assign, ast.AnnAssign)):
# Issue #37 修复:赋值语句既是定义也是初始化
# 需要访问以创建变量符号
self.visit(stmt)
# 同时也需要作为初始化语句保留
init_statements.append(stmt)
elif isinstance(stmt, ast.Try) and self._is_try_import_error(stmt):
# 特殊处理 try...except ImportError 块
# 保留完整的块作为初始化语句
Expand Down Expand Up @@ -364,6 +369,17 @@ def visit_Import(self, node: ast.Import):
is_runtime_import=self.in_try_import_error
)

# Issue #37 修复:为外部导入创建模块符号依赖
# 创建一个虚拟的模块符号作为依赖
module_symbol = Symbol(
name=module_name,
qname=module_name,
symbol_type='module',
def_node=None,
scope=None
)
symbol.dependencies.add(module_symbol)

self.current_scope().symbols[alias_name] = symbol
self.all_symbols[symbol.qname] = symbol
self.defnode_to_scope[symbol.def_node] = symbol.scope
Expand Down Expand Up @@ -452,8 +468,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom):
is_runtime_import=self.in_try_import_error
)

# 对于外部导入,我们不设置依赖关系
# 但仍然注册符号
# Issue #37 修复:为外部导入创建模块符号依赖
# 创建一个虚拟的模块符号作为依赖
if module_name: # from xxx import yyy 的情况
module_symbol = Symbol(
name=module_name,
qname=module_name,
symbol_type='module',
def_node=None,
scope=None
)
symbol.dependencies.add(module_symbol)

self.current_scope().symbols[alias_name] = symbol
self.all_symbols[symbol.qname] = symbol
self.defnode_to_scope[symbol.def_node] = symbol.scope
Expand Down Expand Up @@ -1413,12 +1439,26 @@ def _needs_reinject(self, symbol: Symbol, output_symbols: Set[Symbol], visited:
def _collect_and_reinject_imports(self, output_symbols: Set[Symbol]) -> List[str]:
"""收集并重新注入必要的导入别名

Issue #37 修复:避免重复注入已经在 external_imports 中的导入

返回需要添加的导入语句列表
"""
imports_to_reinject = []
import_set = set() # 用于去重
processed_symbols = set() # 跟踪已处理的符号

# Issue #37 修复:记录已经处理过的外部导入模块
external_modules = set()
for imp in self.visitor.external_imports:
if imp.startswith('import '):
parts = imp.split()
module = parts[1]
external_modules.add(module)
elif imp.startswith('from '):
parts = imp.split()
module = parts[1]
external_modules.add(module)

# 收集所有相关模块中的导入别名
# 不仅仅是 needed_symbols,还包括所有被访问过的模块中的导入
all_import_aliases = []
Expand Down Expand Up @@ -1480,6 +1520,10 @@ def _collect_and_reinject_imports(self, output_symbols: Set[Symbol]) -> List[str
module_name = alias.name
alias_name = alias.asname if alias.asname else alias.name

# Issue #37 修复:跳过已经在 external_imports 中的模块
if module_name in external_modules:
continue

# 检查当前别名对应的符号
if symbol.name != alias_name:
continue
Expand Down Expand Up @@ -1514,6 +1558,10 @@ def _collect_and_reinject_imports(self, output_symbols: Set[Symbol]) -> List[str
# 跳过相对导入
if level > 0:
continue

# Issue #37 修复:跳过已经在 external_imports 中的模块
if module in external_modules:
continue

for alias in node.names:
name = alias.name
Expand Down Expand Up @@ -1612,6 +1660,29 @@ def generate_name_mappings(self, symbols: Set[Symbol]):
type_qname = f"{symbol.qname}#{symbol.symbol_type}"
if type_qname in self.visitor.all_symbols:
self.name_mappings[type_qname] = symbol.name

# Issue #37 修复:为外部导入的别名添加正确的映射
# 检查 import_registry 中的映射,并更新 name_mappings
module_alias_map = {} # module -> alias 的映射
for key in self.import_registry:
if key[0] == 'import' and len(key) == 3:
# ('import', 'json', 'json__mod')
module = key[1]
alias = key[2]
module_alias_map[module] = alias

# 更新所有 import_alias 符号的映射
for qname, symbol in self.visitor.all_symbols.items():
if symbol.symbol_type == 'import_alias' and symbol.dependencies:
# 检查依赖的模块
for dep in symbol.dependencies:
if dep.symbol_type == 'module' and dep.name in module_alias_map:
# 更新映射到新的别名
self.name_mappings[symbol.qname] = module_alias_map[dep.name]
# 同时更新带类型后缀的版本
type_qname = f"{symbol.qname}#{symbol.symbol_type}"
if type_qname in self.visitor.all_symbols:
self.name_mappings[type_qname] = module_alias_map[dep.name]

def _write_symbol(self, symbol: Symbol, transformer: 'AdvancedNodeTransformer', result_lines: List[str]):
"""解决 #1: 写入符号时进行冲突检测"""
Expand Down Expand Up @@ -1651,9 +1722,11 @@ def _write_symbol(self, symbol: Symbol, transformer: 'AdvancedNodeTransformer',
result_lines.append("")

def _process_imports(self, imports: Set[str]) -> List[str]:
"""解决 #2: 处理导入去重和 alias 冲突"""
"""解决 #2: 处理导入去重和 alias 冲突

修复:使用更精确的去重键,包含导入样式信息
"""
result = []
alias_map = {} # alias -> module 的映射

for imp in sorted(imports):
# 解析导入语句
Expand All @@ -1669,12 +1742,14 @@ def _process_imports(self, imports: Set[str]) -> List[str]:
# B2 修复:为别名添加 __mod 后缀
new_alias = f"{alias}__mod"
new_imp = f"from {module} import {name} as {new_alias}"
key = (module, name, new_alias)
# Issue #37 修复:包含导入样式在去重键中
key = ('from', module, name, new_alias)
else:
# from X import Y
module = parts[1]
name = parts[3]
key = (module, name, name)
# Issue #37 修复:包含导入样式在去重键中
key = ('from', module, name, name)
new_imp = imp
else:
# import X as Y 或 import X
Expand All @@ -1687,15 +1762,17 @@ def _process_imports(self, imports: Set[str]) -> List[str]:
# B2 修复:为别名添加 __mod 后缀
new_alias = f"{alias}__mod"
new_imp = f"import {module} as {new_alias}"
key = (module, new_alias)
# Issue #37 修复:包含导入样式在去重键中
key = ('import', module, new_alias)
else:
# import X
module = parts[1]
# B2 修复:对于没有别名的导入,也添加别名以避免冲突
alias = module.split('.')[0]
new_alias = f"{alias}__mod"
new_imp = f"import {module} as {new_alias}"
key = (module, new_alias)
# Issue #37 修复:包含导入样式在去重键中
key = ('import', module, new_alias)

# 检查是否已存在
if key not in self.import_registry:
Expand All @@ -1722,6 +1799,12 @@ def merge_script(self, script_path: Path) -> str:
if s.symbol_type in ('import_alias', 'module', 'parameter'):
continue

# Issue #37 修复:过滤掉入口模块中的变量定义
# 这些变量会作为主代码的一部分输出,不需要单独输出
if (s.symbol_type == 'variable' and
s.scope.module_path == script_path):
continue

# 检查是否是类的方法(通过判断qname中是否包含类名)
is_class_method = False
if s.symbol_type == 'function' and '.' in s.qname:
Expand Down Expand Up @@ -1750,6 +1833,12 @@ def merge_script(self, script_path: Path) -> str:
# 4. 拓扑排序
sorted_symbols = self.topological_sort(output_symbols)

# Issue #37 修复:在生成名称映射之前,先处理外部导入
# 这样import_registry会被填充,供generate_name_mappings使用
processed_imports = []
if self.visitor.external_imports:
processed_imports = self._process_imports(self.visitor.external_imports)

# 5. 生成名称映射
self.generate_name_mappings(output_symbols)

Expand All @@ -1771,9 +1860,8 @@ def merge_script(self, script_path: Path) -> str:
result_lines.extend(sorted(self.visitor.future_imports))
result_lines.append("")

# 外部导入(通过去重处理)
if self.visitor.external_imports:
processed_imports = self._process_imports(self.visitor.external_imports)
# 外部导入(使用之前已经处理好的)
if processed_imports:
result_lines.extend(processed_imports)
result_lines.append("")

Expand Down Expand Up @@ -1827,6 +1915,31 @@ def merge_script(self, script_path: Path) -> str:
# 解决 #3: 跳过非入口模块的 __main__ 块
if should_skip_main and self._is_dunder_main(stmt):
continue

# Issue #37 修复:跳过已经作为符号输出的赋值语句
# 只跳过简单的变量赋值,保留其他语句
if isinstance(stmt, (ast.Assign, ast.AnnAssign)):
# 检查是否是简单的变量定义(没有函数调用等副作用)
skip = False
if isinstance(stmt, ast.Assign):
# 检查是否是简单赋值
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
var_name = stmt.targets[0].id
# 检查这个变量是否已经被输出为符号
var_qname = f"{module_qname}.{var_name}"
if var_qname in self.name_mappings:
# 检查赋值的值是否是常量
if isinstance(stmt.value, (ast.Constant, ast.Name, ast.UnaryOp, ast.BinOp)):
skip = True
elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
var_name = stmt.target.id
var_qname = f"{module_qname}.{var_name}"
if var_qname in self.name_mappings and stmt.value:
if isinstance(stmt.value, (ast.Constant, ast.Name, ast.UnaryOp, ast.BinOp)):
skip = True

if skip:
continue

transformed_stmt = transformer.visit(copy.deepcopy(stmt))
result_lines.append(ast.unparse(transformed_stmt))
Expand Down
Loading
Loading