diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 84b7a106e..92a07d62c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,7 +54,7 @@ jobs: run: | # Ruff has no max-module-lines rule. This check prevents new files from # exceeding the current worst case. Tighten the threshold over time. - MAX_LINES=2450 # current max: 2404 (github_downloader.py) + MAX_LINES=2000 # Stage 1 (was 2450); target 1400 deferred to Stage 2 VIOLATIONS=$(find src/ -name '*.py' -print0 | xargs -0 -I{} awk -v max="$MAX_LINES" \ 'END { if (NR > max) printf "%s: %d lines (max %d)\n", FILENAME, NR, max }' {}) if [ -n "$VIOLATIONS" ]; then diff --git a/pyproject.toml b/pyproject.toml index a6d61ad9b..f7493a8a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,13 +96,16 @@ ignore = [ # High initial thresholds set just above current codebase maximums. # Prevents new code from exceeding the worst existing violations. # Tighten these over time via dedicated refactoring PRs. -max-statements = 275 # current max: 269 (mcp_integrator.py::install) -max-args = 18 # current max: 16 (commands/install.py) -max-branches = 115 # current max: 108 (mcp_integrator.py::install) -max-returns = 18 # current max: 16 (marketplace/publisher.py) +# Stage 1 thresholds (PR #1464, issue #1077). +# Roadmap: Stage 2 targets max-complexity<=20, max-branches<=25 (McCabe standard). +max-statements = 200 # Stage 1 (was 275) +max-args = 15 # Stage 1 (was 18) +max-branches = 60 # Stage 1 (was 115) +max-returns = 12 # Stage 1 (was 18) [tool.ruff.lint.mccabe] -max-complexity = 100 # current max: 97 (mcp_integrator.py::install) +# Stage 1 (was 100). Stage 2 target: <=20 (McCabe industry standard). +max-complexity = 50 [tool.ruff.lint.per-file-ignores] # Subprocess calls are intentional in a CLI tool @@ -120,7 +123,6 @@ max-complexity = 100 # current max: 97 (mcp_integrator.py::install) "src/apm_cli/commands/init.py" = ["F821", "S603", "S607"] "src/apm_cli/install/pipeline.py" = ["F821", "S603", "S607"] "src/apm_cli/integration/skill_integrator.py" = ["F821", "S603", "S607"] - [tool.ruff.format] quote-style = "double" indent-style = "space" diff --git a/src/apm_cli/commands/compile/cli.py b/src/apm_cli/commands/compile/cli.py index 5dd10cf8f..c35e1e38e 100644 --- a/src/apm_cli/commands/compile/cli.py +++ b/src/apm_cli/commands/compile/cli.py @@ -318,6 +318,429 @@ def _resolve_effective_target( return detected_target, detection_reason, config_target +def _validate_project(logger: CommandLogger, dry_run: bool) -> None: + """Check APM project exists and has content. + + Calls ``sys.exit(1)`` on fatal errors. In dry-run mode the function + emits diagnostic messages but does *not* exit so callers can test the + full compile path even without real content. + """ + from ...compilation.constitution import find_constitution + + if not Path(APM_YML_FILENAME).exists(): + logger.error("Not an APM project - no apm.yml found") + logger.progress(" To initialize an APM project, run:") + logger.progress(" apm init") + sys.exit(1) + + # Check if there are any instruction files to compile + apm_modules_exists = Path(APM_MODULES_DIR).exists() + constitution_exists = find_constitution(Path(".")).exists() + + # Check if .apm directory has actual content + apm_dir = Path(APM_DIR) + local_apm_has_content = apm_dir.exists() and ( + any(apm_dir.rglob("*.instructions.md")) or any(apm_dir.rglob("*.chatmode.md")) + ) + + # If no primitive sources exist, check deeper to provide better feedback + if not apm_modules_exists and not local_apm_has_content and not constitution_exists: + # Check if .apm directories exist but are empty + has_empty_apm = ( + apm_dir.exists() + and not any(apm_dir.rglob("*.instructions.md")) + and not any(apm_dir.rglob("*.chatmode.md")) + ) + + if has_empty_apm: + logger.error("No instruction files found in .apm/ directory") + logger.progress(" To add instructions, create files like:") + logger.progress(" .apm/instructions/coding-standards.instructions.md") + logger.progress(" .apm/chatmodes/backend-engineer.chatmode.md") + else: + logger.error("No APM content found to compile") + logger.progress(" To get started:") + logger.progress(" 1. Install APM dependencies: apm install /") + logger.progress(" 2. Or create local instructions: mkdir -p .apm/instructions") + logger.progress(" 3. Then create .instructions.md or .chatmode.md files") + + if not dry_run: # Don't exit on dry-run to allow testing + sys.exit(1) + + +def _run_validation_mode(logger: CommandLogger, verbose: bool) -> None: + """Run validation-only mode (``--validate`` flag). + + Discovers all primitives, validates them, and prints a structured + summary. Calls ``sys.exit(1)`` when validation errors are found. + """ + logger.start("Validating APM context...", symbol="gear") + compiler = AgentsCompiler(".") + try: + primitives = discover_primitives(".") + except Exception as e: + logger.error(f"Failed to discover primitives: {e}") + logger.progress(f" Error details: {type(e).__name__}") + sys.exit(1) + + validation_errors = compiler.validate_primitives(primitives) + if validation_errors: + _display_validation_errors(validation_errors) + logger.error(f"Validation failed with {len(validation_errors)} errors") + sys.exit(1) + + logger.success("All primitives validated successfully!") + logger.progress(f"Validated {primitives.count()} primitives:") + logger.progress(f" * {len(primitives.chatmodes)} chatmodes") + logger.progress(f" * {len(primitives.instructions)} instructions") + logger.progress(f" * {len(primitives.contexts)} contexts") + + # Show MCP dependency validation count + try: + from ...models.apm_package import APMPackage + + apm_pkg = APMPackage.from_apm_yml(Path(APM_YML_FILENAME)) + mcp_count = len(apm_pkg.get_mcp_dependencies()) + if mcp_count > 0: + logger.progress(f" * {mcp_count} MCP dependencies") + except Exception: + pass + + +def _run_watch_mode( + logger: CommandLogger, + target: str | list[str] | None, + output: str, + chatmode: str | None, + no_links: bool, + dry_run: bool, + verbose: bool, +) -> None: + """Set up and run watch mode (``--watch`` flag). + + Resolves the effective compile target using the same logic as the + one-shot path so that ``targets: [claude, cursor]`` in apm.yml does + not silently regress on every recompile (#1345), then delegates to + :func:`_watch_mode`. + """ + effective_target, _detection_reason, config_target = _resolve_effective_target(target) + _watch_mode( + output, + chatmode, + no_links, + dry_run, + verbose=verbose, + effective_target=effective_target, + target_label_user=target, + target_label_config=config_target, + ) + + +def _run_compilation( + logger: CommandLogger, + target: str | list[str] | None, + output: str, + dry_run: bool, + no_links: bool, + chatmode: str | None, + with_constitution: bool, + single_agents: bool, + verbose: bool, + local_only: bool, + clean: bool, +) -> None: + """Main compilation flow: target resolution, config, compile, and output. + + Handles both distributed (default) and single-file (``--single-agents``) + strategies, emits the canonical target-provenance line, runs the + compiler, reports results, and hard-fails on critical security findings. + """ + from ...core.target_detection import ( + REASON_NO_TARGET_FOLDER, + ResolvedTargets, + format_provenance, + get_target_description, + ) + + logger.start("Starting context compilation...", symbol="cogs") + + # Resolve effective target using the shared helper (mirrors watch-mode path). + effective_target, detection_reason, config_target = _resolve_effective_target(target) + + # Emit canonical provenance line BEFORE compilation -- mirrors + # `apm install` so users see the same `[i] Targets: ... + # (source: ...)` line on both surfaces. Use the user-facing + # source values (target / config_target) NOT the compiler-family + # expansion in effective_target -- install shows the schema names + # the user wrote (e.g. "copilot"), so compile must too, otherwise + # parity drifts (compile would print "agents, vscode" for the + # same input). + def _coerce_provenance_targets(value): + if value is None: + return [] + if isinstance(value, str): + return [t.strip() for t in value.split(",") if t.strip()] + if isinstance(value, list): + return [str(t) for t in value] + if isinstance(value, frozenset): + return sorted(value) + return [] + + if detection_reason == "explicit --target flag": + _provenance_targets = _coerce_provenance_targets(target) + _provenance_source = "--target flag" + elif detection_reason == "apm.yml target": + _provenance_targets = _coerce_provenance_targets(config_target) + _provenance_source = "apm.yml" + else: + if isinstance(effective_target, frozenset): + _provenance_targets = sorted(effective_target) + elif isinstance(effective_target, str): + _provenance_targets = [effective_target] + else: + _provenance_targets = [] + _provenance_source = f"auto-detect ({detection_reason})" + + if _provenance_targets: + _rich_info( + format_provenance( + ResolvedTargets( + targets=sorted(set(_provenance_targets)), + source=_provenance_source, + auto_create=True, + ) + ), + symbol="info", + ) + + # Build config with distributed compilation flags (Task 7) + config = CompilationConfig.from_apm_yml( + output_path=output if output != AGENTS_MD_FILENAME else None, + chatmode=chatmode, + resolve_links=not no_links if no_links else None, + dry_run=dry_run, + single_agents=single_agents, + trace=verbose, + local_only=local_only, + debug=verbose, + clean_orphaned=clean, + target=effective_target, + ) + config.with_constitution = with_constitution + + # Show target-aware progress message for the chosen strategy. + if config.strategy == "distributed" and not single_agents: + if isinstance(effective_target, frozenset): + # Multi-target compile (from CLI `--target a,b` OR apm.yml + # `target: [a, b]`): show what the compiler will produce. + if isinstance(target, list): + _target_label = f"--target {','.join(target)}" + elif isinstance(config_target, list): + _target_label = f"apm.yml target: [{', '.join(config_target)}]" + else: + _target_label = "multi-target" + from ...core.target_detection import ( + should_compile_agents_md, + should_compile_claude_md, + should_compile_gemini_md, + ) + + _parts = [] + if should_compile_agents_md(effective_target): + _parts.append("AGENTS.md") + if should_compile_claude_md(effective_target): + _parts.append("CLAUDE.md") + if should_compile_gemini_md(effective_target): + _parts.append("GEMINI.md") + logger.progress(f"Compiling for {' + '.join(_parts)} ({_target_label})") + elif ( + isinstance(effective_target, str) + and effective_target == "vscode" + and detection_reason == REASON_NO_TARGET_FOLDER + ): + logger.progress(f"Compiling for AGENTS.md only ({detection_reason})") + logger.progress( + " Create .github/, .claude/, .codex/, .opencode/ or .cursor/ folder for full integration", + symbol="light_bulb", + ) + else: + description = get_target_description(effective_target) + logger.progress(f"Compiling for {description} - {detection_reason}") + + if dry_run: + logger.dry_run_notice("showing placement without writing files") + if verbose: + logger.verbose_detail("Verbose mode: showing source attribution and optimizer analysis") + else: + logger.progress("Using single-file compilation (legacy mode)", symbol="page") + + # Perform compilation + compiler = AgentsCompiler(".") + result = compiler.compile(config, logger=logger) + compile_has_critical = result.has_critical_security + + if result.success: + # Handle different compilation modes + if config.strategy == "distributed" and not single_agents: + # Distributed compilation results - output already shown by professional formatter + # Just show final success message + if dry_run: + # Success message for dry run already included in formatter output + pass + else: + # Defense-in-depth (#820): don't claim "completed + # successfully" when zero files were emitted. With + # parse_target_field as the upstream gatekeeper this is + # unreachable in normal flow, but silent zero-effect + # success is the worst-case package-manager DX. + # + # Pattern-based stat scan (instead of a hardcoded key + # list) so new compile-time targets pick up the guard + # automatically: any stat ending in ``_files_written`` + # or ``_files_generated`` contributes to the total. + _files_written = sum( + int(v or 0) + for k, v in result.stats.items() + if k.endswith(("_files_written", "_files_generated")) + ) + if _files_written > 0: + logger.success( + "Compilation completed successfully!", + symbol="check", + ) + else: + # Zero-output compile is the silent-success failure + # mode #820 guards against. Don't claim success; + # surface what the user can act on. The cause is + # usually one of: target dirs not present (auto- + # detect found nothing), explicit target rejected + # by policy, or no primitives in the project. + logger.warning( + "Compilation completed but produced no output " + "files. Check that target directories exist " + "(e.g. .github/, .claude/) or set 'target:' " + "in apm.yml / pass --target explicitly." + ) + + else: + # Traditional single-file compilation - keep existing logic + # Perform initial compilation in dry-run to get generated body (without constitution) + # TODO: Refactor to use dataclasses.replace() once CompilationConfig fields stabilise + intermediate_config = CompilationConfig( + output_path=config.output_path, + chatmode=config.chatmode, + resolve_links=config.resolve_links, + dry_run=True, # force + with_constitution=config.with_constitution, + strategy="single-file", + target=config.target, + ) + intermediate_result = compiler.compile(intermediate_config) + + if intermediate_result.success: + # Perform constitution injection / preservation + from ...compilation.injector import ConstitutionInjector + + injector = ConstitutionInjector(base_dir=".") + output_path = Path(config.output_path) + final_content, c_status, c_hash = injector.inject( + intermediate_result.content, + with_constitution=config.with_constitution, + output_path=output_path, + ) + + if not dry_run: + # Only rewrite when content materially changes (creation, update, missing constitution case) + if c_status in ("CREATED", "UPDATED", "MISSING"): + # Defense-in-depth: scan compiled output before writing + from ...security.gate import WARN_POLICY, SecurityGate + + verdict = SecurityGate.scan_text( + final_content, str(output_path), policy=WARN_POLICY + ) + if verdict.has_findings: + actionable = verdict.critical_count + verdict.warning_count + if verdict.has_critical: + compile_has_critical = True + if actionable: + logger.warning( + f"Compiled output contains {actionable} hidden character(s) " + f"-- run 'apm audit --file {output_path}' to inspect" + ) + try: + from ...compilation.output_writer import CompiledOutputWriter + + CompiledOutputWriter().write(output_path, final_content) + except OSError as e: + logger.error(f"Failed to write final AGENTS.md: {e}") + sys.exit(1) + else: + logger.progress( + "No changes detected; preserving existing AGENTS.md for idempotency" + ) + + # Report success at the top + if dry_run: + logger.success( + "Context compilation completed successfully (dry run)", + symbol="check", + ) + else: + logger.success( + f"Context compiled successfully to {output_path}", + ) + + stats = ( + intermediate_result.stats + ) # timestamp removed; stats remain version + counts + + # Add spacing before summary table + _rich_blank_line() + + _display_single_file_summary(stats, c_status, c_hash, output_path, dry_run) + + if dry_run: + preview = final_content[:500] + ("..." if len(final_content) > 500 else "") + _rich_panel(preview, title=" Generated Content Preview", style="cyan") + else: + _display_next_steps(output) + + # Display warnings for all compilation modes + if result.warnings: + logger.warning(f"Compilation completed with {len(result.warnings)} warning(s):") + for warning in result.warnings: + logger.warning(f" {warning}") + + if result.errors: + logger.error(f"Compilation failed with {len(result.errors)} errors:") + for error in result.errors: + logger.error(f" {error}") + sys.exit(1) + + # Check for orphaned packages after successful compilation + try: + orphaned_packages = _check_orphaned_packages() + if orphaned_packages: + _rich_blank_line() + logger.warning( + f"Found {len(orphaned_packages)} orphaned package(s) that were included in compilation:" + ) + for pkg in orphaned_packages: + logger.progress(f" * {pkg}") + logger.progress(" Run 'apm prune' to remove orphaned packages") + except Exception: + pass # Continue if orphan check fails + + # Hard-fail when critical security findings were detected in compiled + # output. Consistent with apm install and apm unpack behavior. + if compile_has_critical: + logger.error( + "Compiled output contains critical hidden characters" + " -- run 'apm audit' to inspect, 'apm audit --strip' to clean" + ) + sys.exit(1) + + @click.command(help="Compile APM context into distributed AGENTS.md files") @click.option( "--output", @@ -442,454 +865,30 @@ def compile( logger.warning("'--target all' is deprecated; use '--all' instead.") try: - # Check if this is an APM project first - from pathlib import Path - - if not Path(APM_YML_FILENAME).exists(): - logger.error("Not an APM project - no apm.yml found") - logger.progress(" To initialize an APM project, run:") - logger.progress(" apm init") - sys.exit(1) - - # Check if there are any instruction files to compile - from ...compilation.constitution import find_constitution - - apm_modules_exists = Path(APM_MODULES_DIR).exists() - constitution_exists = find_constitution(Path(".")).exists() - - # Check if .apm directory has actual content - apm_dir = Path(APM_DIR) - local_apm_has_content = apm_dir.exists() and ( - any(apm_dir.rglob("*.instructions.md")) or any(apm_dir.rglob("*.chatmode.md")) - ) - - # If no primitive sources exist, check deeper to provide better feedback - if not apm_modules_exists and not local_apm_has_content and not constitution_exists: - # Check if .apm directories exist but are empty - has_empty_apm = ( - apm_dir.exists() - and not any(apm_dir.rglob("*.instructions.md")) - and not any(apm_dir.rglob("*.chatmode.md")) - ) + _validate_project(logger, dry_run) - if has_empty_apm: - logger.error("No instruction files found in .apm/ directory") - logger.progress(" To add instructions, create files like:") - logger.progress(" .apm/instructions/coding-standards.instructions.md") - logger.progress(" .apm/chatmodes/backend-engineer.chatmode.md") - else: - logger.error("No APM content found to compile") - logger.progress(" To get started:") - logger.progress(" 1. Install APM dependencies: apm install /") - logger.progress(" 2. Or create local instructions: mkdir -p .apm/instructions") - logger.progress(" 3. Then create .instructions.md or .chatmode.md files") - - if not dry_run: # Don't exit on dry-run to allow testing - sys.exit(1) - - # Validation-only mode if validate: - logger.start("Validating APM context...", symbol="gear") - compiler = AgentsCompiler(".") - try: - primitives = discover_primitives(".") - except Exception as e: - logger.error(f"Failed to discover primitives: {e}") - logger.progress(f" Error details: {type(e).__name__}") - sys.exit(1) - validation_errors = compiler.validate_primitives(primitives) - if validation_errors: - _display_validation_errors(validation_errors) - logger.error(f"Validation failed with {len(validation_errors)} errors") - sys.exit(1) - logger.success("All primitives validated successfully!") - logger.progress(f"Validated {primitives.count()} primitives:") - logger.progress(f" * {len(primitives.chatmodes)} chatmodes") - logger.progress(f" * {len(primitives.instructions)} instructions") - logger.progress(f" * {len(primitives.contexts)} contexts") - # Show MCP dependency validation count - try: - from ...models.apm_package import APMPackage - - apm_pkg = APMPackage.from_apm_yml(Path(APM_YML_FILENAME)) - mcp_count = len(apm_pkg.get_mcp_dependencies()) - if mcp_count > 0: - logger.progress(f" * {mcp_count} MCP dependencies") - except Exception: - pass + _run_validation_mode(logger, verbose) return - # Watch mode if watch: - # Resolve the same effective target the one-shot path uses so - # `targets: [claude, cursor]` does not silently regress to the - # all-families fanout on every recompile (#1345). - effective_target, _detection_reason, config_target = _resolve_effective_target(target) - _watch_mode( - output, - chatmode, - no_links, - dry_run, - verbose=verbose, - effective_target=effective_target, - target_label_user=target, - target_label_config=config_target, - ) + _run_watch_mode(logger, target, output, chatmode, no_links, dry_run, verbose) return - logger.start("Starting context compilation...", symbol="cogs") - - # Auto-detect target if not explicitly provided - from ...core.target_detection import ( - REASON_NO_TARGET_FOLDER, - detect_target, - get_target_description, + _run_compilation( + logger, + target, + output, + dry_run, + no_links, + chatmode, + with_constitution, + single_agents, + verbose, + local_only, + clean, ) - # Get config target from apm.yml if available. When the file is - # absent we proceed with auto-detection; when it is present but - # malformed we let the parse error surface so users see exactly - # what is wrong (e.g. ``target: opencode,bogus`` -> a ValueError - # naming the bad token), rather than silently falling through to - # auto-detect. See #820. - from ...models.apm_package import APMPackage - - config_target = None - apm_yml_path = Path(APM_YML_FILENAME) - if apm_yml_path.exists(): - apm_pkg = APMPackage.from_apm_yml(apm_yml_path) - config_target = apm_pkg.target - # Parity with `apm install`: also honor canonical plural - # `targets:` key (#1154). APMPackage only reads singular - # `target:`; parse_targets_field handles both keys, raises - # ConflictingTargetsError when both appear, and validates - # tokens against CANONICAL_TARGETS. When only `targets:` is - # present, apm_pkg.target is None and we promote the plural - # list here so compile sees the same schema install sees. - if config_target is None: - try: - from ...core.apm_yml import parse_targets_field - from ...utils.yaml_io import load_yaml - - _raw = load_yaml(apm_yml_path) - if isinstance(_raw, dict): - _yaml_targets = parse_targets_field(_raw) - if _yaml_targets: - config_target = ( - _yaml_targets[0] if len(_yaml_targets) == 1 else _yaml_targets - ) - except Exception: - pass - - # Resolve list targets to compiler-understood value - compile_target = _resolve_compile_target(target) - # Also handle config_target being a list (from apm.yml target: [claude, copilot]) - compile_config_target = _resolve_compile_target(config_target) - - # A frozenset means multiple compiler families were explicitly - # requested -- bypass detect_target() since it only handles strings. - if isinstance(compile_target, frozenset): - effective_target = compile_target - detection_reason = "explicit --target flag" - elif isinstance(compile_config_target, frozenset) and compile_target is None: - effective_target = compile_config_target - detection_reason = "apm.yml target" - else: - # Pass config_target only when it's a string -- detect_target() is - # typed for Optional[str], and a frozenset config_target is already - # handled by the branch above. - detected_target, detection_reason = detect_target( - project_root=Path("."), - explicit_target=compile_target, - config_target=compile_config_target - if isinstance(compile_config_target, str) - else None, - ) - # Keep the detected target intact so the compiler can preserve - # minimal-mode semantics (AGENTS.md only, no .github side outputs). - effective_target = detected_target - - # Emit canonical provenance line BEFORE compilation -- mirrors - # `apm install` so users see the same `[i] Targets: ... - # (source: ...)` line on both surfaces. Use the user-facing - # source values (target / config_target) NOT the compiler-family - # expansion in effective_target -- install shows the schema names - # the user wrote (e.g. "copilot"), so compile must too, otherwise - # parity drifts (compile would print "agents, vscode" for the - # same input). - from ...core.target_detection import ResolvedTargets, format_provenance - from ...utils.console import _rich_info - - def _coerce_provenance_targets(value): - if value is None: - return [] - if isinstance(value, str): - return [t.strip() for t in value.split(",") if t.strip()] - if isinstance(value, list): - return [str(t) for t in value] - if isinstance(value, frozenset): - return sorted(value) - return [] - - if detection_reason == "explicit --target flag": - _provenance_targets = _coerce_provenance_targets(target) - _provenance_source = "--target flag" - elif detection_reason == "apm.yml target": - _provenance_targets = _coerce_provenance_targets(config_target) - _provenance_source = "apm.yml" - else: - if isinstance(effective_target, frozenset): - _provenance_targets = sorted(effective_target) - elif isinstance(effective_target, str): - _provenance_targets = [effective_target] - else: - _provenance_targets = [] - _provenance_source = f"auto-detect ({detection_reason})" - - if _provenance_targets: - _rich_info( - format_provenance( - ResolvedTargets( - targets=sorted(set(_provenance_targets)), - source=_provenance_source, - auto_create=True, - ) - ), - symbol="info", - ) - - # Build config with distributed compilation flags (Task 7) - config = CompilationConfig.from_apm_yml( - output_path=output if output != AGENTS_MD_FILENAME else None, - chatmode=chatmode, - resolve_links=not no_links if no_links else None, - dry_run=dry_run, - single_agents=single_agents, - trace=verbose, - local_only=local_only, - debug=verbose, - clean_orphaned=clean, - target=effective_target, - ) - config.with_constitution = with_constitution - - # Handle distributed vs single-file compilation - if config.strategy == "distributed" and not single_agents: - # Show target-aware message with detection reason. Use - # get_target_description() so any future target added to - # target_detection shows up here automatically. - if isinstance(effective_target, frozenset): - # Multi-target compile (from CLI `--target a,b` OR apm.yml - # `target: [a, b]`): show what the compiler will produce. - if isinstance(target, list): - _target_label = f"--target {','.join(target)}" - elif isinstance(config_target, list): - _target_label = f"apm.yml target: [{', '.join(config_target)}]" - else: - _target_label = "multi-target" - from ...core.target_detection import ( - should_compile_agents_md, - should_compile_claude_md, - should_compile_gemini_md, - ) - - _parts = [] - if should_compile_agents_md(effective_target): - _parts.append("AGENTS.md") - if should_compile_claude_md(effective_target): - _parts.append("CLAUDE.md") - if should_compile_gemini_md(effective_target): - _parts.append("GEMINI.md") - logger.progress(f"Compiling for {' + '.join(_parts)} ({_target_label})") - elif ( - isinstance(effective_target, str) - and effective_target == "vscode" - and detection_reason == REASON_NO_TARGET_FOLDER - ): - logger.progress(f"Compiling for AGENTS.md only ({detection_reason})") - logger.progress( - " Create .github/, .claude/, .codex/, .opencode/ or .cursor/ folder for full integration", - symbol="light_bulb", - ) - else: - description = get_target_description(effective_target) - logger.progress(f"Compiling for {description} - {detection_reason}") - - if dry_run: - logger.dry_run_notice("showing placement without writing files") - if verbose: - logger.verbose_detail( - "Verbose mode: showing source attribution and optimizer analysis" - ) - else: - logger.progress("Using single-file compilation (legacy mode)", symbol="page") - - # Perform compilation - compiler = AgentsCompiler(".") - result = compiler.compile(config, logger=logger) - compile_has_critical = result.has_critical_security - - if result.success: - # Handle different compilation modes - if config.strategy == "distributed" and not single_agents: - # Distributed compilation results - output already shown by professional formatter - # Just show final success message - if dry_run: - # Success message for dry run already included in formatter output - pass - else: - # Defense-in-depth (#820): don't claim "completed - # successfully" when zero files were emitted. With - # parse_target_field as the upstream gatekeeper this is - # unreachable in normal flow, but silent zero-effect - # success is the worst-case package-manager DX. - # - # Pattern-based stat scan (instead of a hardcoded key - # list) so new compile-time targets pick up the guard - # automatically: any stat ending in ``_files_written`` - # or ``_files_generated`` contributes to the total. - _files_written = sum( - int(v or 0) - for k, v in result.stats.items() - if k.endswith(("_files_written", "_files_generated")) - ) - if _files_written > 0: - logger.success( - "Compilation completed successfully!", - symbol="check", - ) - else: - # Zero-output compile is the silent-success failure - # mode #820 guards against. Don't claim success; - # surface what the user can act on. The cause is - # usually one of: target dirs not present (auto- - # detect found nothing), explicit target rejected - # by policy, or no primitives in the project. - logger.warning( - "Compilation completed but produced no output " - "files. Check that target directories exist " - "(e.g. .github/, .claude/) or set 'target:' " - "in apm.yml / pass --target explicitly." - ) - - else: - # Traditional single-file compilation - keep existing logic - # Perform initial compilation in dry-run to get generated body (without constitution) - # TODO: Refactor to use dataclasses.replace() once CompilationConfig fields stabilise - intermediate_config = CompilationConfig( - output_path=config.output_path, - chatmode=config.chatmode, - resolve_links=config.resolve_links, - dry_run=True, # force - with_constitution=config.with_constitution, - strategy="single-file", - target=config.target, - ) - intermediate_result = compiler.compile(intermediate_config) - - if intermediate_result.success: - # Perform constitution injection / preservation - from ...compilation.injector import ConstitutionInjector - - injector = ConstitutionInjector(base_dir=".") - output_path = Path(config.output_path) - final_content, c_status, c_hash = injector.inject( - intermediate_result.content, - with_constitution=config.with_constitution, - output_path=output_path, - ) - - if not dry_run: - # Only rewrite when content materially changes (creation, update, missing constitution case) - if c_status in ("CREATED", "UPDATED", "MISSING"): - # Defense-in-depth: scan compiled output before writing - from ...security.gate import WARN_POLICY, SecurityGate - - verdict = SecurityGate.scan_text( - final_content, str(output_path), policy=WARN_POLICY - ) - if verdict.has_findings: - actionable = verdict.critical_count + verdict.warning_count - if verdict.has_critical: - compile_has_critical = True - if actionable: - logger.warning( - f"Compiled output contains {actionable} hidden character(s) " - f"-- run 'apm audit --file {output_path}' to inspect" - ) - try: - from ...compilation.output_writer import CompiledOutputWriter - - CompiledOutputWriter().write(output_path, final_content) - except OSError as e: - logger.error(f"Failed to write final AGENTS.md: {e}") - sys.exit(1) - else: - logger.progress( - "No changes detected; preserving existing AGENTS.md for idempotency" - ) - - # Report success at the top - if dry_run: - logger.success( - "Context compilation completed successfully (dry run)", - symbol="check", - ) - else: - logger.success( - f"Context compiled successfully to {output_path}", - ) - - stats = ( - intermediate_result.stats - ) # timestamp removed; stats remain version + counts - - # Add spacing before summary table - _rich_blank_line() - - _display_single_file_summary(stats, c_status, c_hash, output_path, dry_run) - - if dry_run: - preview = final_content[:500] + ("..." if len(final_content) > 500 else "") - _rich_panel(preview, title=" Generated Content Preview", style="cyan") - else: - _display_next_steps(output) - - # Display warnings for all compilation modes - if result.warnings: - logger.warning(f"Compilation completed with {len(result.warnings)} warning(s):") - for warning in result.warnings: - logger.warning(f" {warning}") - - if result.errors: - logger.error(f"Compilation failed with {len(result.errors)} errors:") - for error in result.errors: - logger.error(f" {error}") - sys.exit(1) - - # Check for orphaned packages after successful compilation - try: - orphaned_packages = _check_orphaned_packages() - if orphaned_packages: - _rich_blank_line() - logger.warning( - f"Found {len(orphaned_packages)} orphaned package(s) that were included in compilation:" - ) - for pkg in orphaned_packages: - logger.progress(f" * {pkg}") - logger.progress(" Run 'apm prune' to remove orphaned packages") - except Exception: - pass # Continue if orphan check fails - - # Hard-fail when critical security findings were detected in compiled - # output. Consistent with apm install and apm unpack behavior. - if compile_has_critical: - logger.error( - "Compiled output contains critical hidden characters" - " -- run 'apm audit' to inspect, 'apm audit --strip' to clean" - ) - sys.exit(1) - except ImportError as e: logger.error(f"Compilation module not available: {e}") logger.progress("This might be a development environment issue.") diff --git a/src/apm_cli/commands/install.py b/src/apm_cli/commands/install.py index f81f828b5..b713e490d 100644 --- a/src/apm_cli/commands/install.py +++ b/src/apm_cli/commands/install.py @@ -796,7 +796,6 @@ def _handle_mcp_install( runtime, exclude, verbose, - dry_run, logger, no_policy, validated_registry_url, @@ -852,14 +851,14 @@ def _handle_mcp_install( mcp_deps=[_preflight_dep], no_policy=no_policy, logger=logger, - dry_run=dry_run, + dry_run=logger.dry_run, ) except PolicyBlockError: # Diagnostics already emitted by the helper + logger. logger.render_summary() sys.exit(1) - if dry_run: + if logger.dry_run: # C1: validate eagerly so dry-run rejects what real install would. _validate_mcp_dry_run_entry( mcp_name, @@ -885,9 +884,7 @@ def _handle_mcp_install( force=force, runtime=runtime, exclude=exclude, - verbose=verbose, logger=logger, - manifest_path=mcp_manifest_path, apm_dir=mcp_apm_dir, scope=mcp_scope, registry_url=validated_registry_url, @@ -1326,9 +1323,7 @@ def install( # noqa: PLR0913 global_=global_, only=only, update=update, - use_ssh=use_ssh, - use_https=use_https, - allow_protocol_fallback=allow_protocol_fallback, + any_transport_flag=use_ssh or use_https or allow_protocol_fallback, registry_url=validated_registry_url, ) @@ -1354,7 +1349,6 @@ def install( # noqa: PLR0913 runtime=runtime, exclude=exclude, verbose=verbose, - dry_run=dry_run, logger=logger, no_policy=no_policy, validated_registry_url=validated_registry_url, diff --git a/src/apm_cli/commands/pack.py b/src/apm_cli/commands/pack.py index 67869858d..349b32f50 100644 --- a/src/apm_cli/commands/pack.py +++ b/src/apm_cli/commands/pack.py @@ -75,6 +75,79 @@ def _emit_json_error_or_raise(ctx, json_output: bool, code: str, message: str): raise click.ClickException(message) +def _parse_path_overrides( + marketplace_path_overrides: "tuple[str, ...]", + ctx, + json_output: bool, +) -> "dict[str, str] | None": + """Parse --marketplace-path KEY=VALUE pairs. + + Returns a dict mapping format name -> path, or ``None`` on the first + validation error (after emitting the error via *ctx*). + """ + from ..marketplace.output_profiles import known_output_names + from ..utils.path_security import validate_path_segments + + path_overrides: dict[str, str] = {} + for override in marketplace_path_overrides: + if "=" not in override: + msg = f"--marketplace-path must be FORMAT=PATH, got: {override!r}" + _emit_json_error_or_raise(ctx, json_output, "cli_error", msg) + return None + fmt_name, path_val = override.split("=", 1) + fmt_name = fmt_name.strip() + path_val = path_val.strip() + if fmt_name not in known_output_names(): + msg = ( + f"Unknown marketplace format '{fmt_name}' in --marketplace-path. " + f"Known formats: {', '.join(sorted(known_output_names()))}" + ) + _emit_json_error_or_raise(ctx, json_output, "unknown_format", msg) + return None + # Security: validate path to prevent traversal attacks + try: + validate_path_segments(path_val, context="--marketplace-path", allow_current_dir=True) + except Exception as exc: + _emit_json_error_or_raise(ctx, json_output, "path_error", str(exc)) + return None + path_overrides[fmt_name] = path_val + return path_overrides + + +def _parse_marketplace_filter( + marketplace_filter: "str | None", + ctx, + json_output: bool, +) -> "tuple[str, ...] | None": + """Parse the --marketplace filter value. + + Returns: + - ``None`` -- build all configured outputs + - empty ``tuple`` -- skip marketplace entirely (``--marketplace none``) + - non-empty tuple -- build only the named formats + - ``None`` on validation error (after emitting the error via *ctx*) + """ + from ..marketplace.output_profiles import known_output_names + + if marketplace_filter is None: + return None + if marketplace_filter.strip().lower() == "none": + return () + if marketplace_filter.strip().lower() == "all": + return None # all configured + requested = [f.strip() for f in marketplace_filter.split(",") if f.strip()] + known = known_output_names() + for r in requested: + if r not in known: + msg = ( + f"Unknown marketplace format '{r}' in --marketplace. " + f"Known formats: {', '.join(sorted(known))}" + ) + _emit_json_error_or_raise(ctx, json_output, "unknown_format", msg) + return None + return tuple(requested) + + @click.command(name="pack", help=_PACK_HELP) @click.option( "--format", @@ -192,7 +265,7 @@ def _emit_json_error_or_raise(ctx, json_output: bool, code: str, message: str): ), ) @click.pass_context -def pack_cmd( +def pack_cmd( # noqa: PLR0913 -- Click handler, one param per CLI option ctx, fmt, target, @@ -212,16 +285,13 @@ def pack_cmd( check_clean, ): """Pack APM artifacts: bundle and/or marketplace.json.""" - from ..marketplace.output_profiles import known_output_names - from ..utils.path_security import validate_path_segments - # -- Stream discipline: under --json, route ALL output to stderr -- if json_output: set_console_stderr(True) logger = CommandLogger("pack", verbose=verbose, dry_run=dry_run) - # -- Deprecation: --marketplace-output → --marketplace-path claude=PATH -- + # -- Deprecation: --marketplace-output -> --marketplace-path claude=PATH -- if marketplace_output is not None: translated = f"--marketplace-path claude={marketplace_output}" click.echo( @@ -236,49 +306,14 @@ def pack_cmd( marketplace_output = None # -- Parse --marketplace-path overrides -- - path_overrides: dict[str, str] = {} - for override in marketplace_path_overrides: - if "=" not in override: - msg = f"--marketplace-path must be FORMAT=PATH, got: {override!r}" - _emit_json_error_or_raise(ctx, json_output, "cli_error", msg) - return - fmt_name, path_val = override.split("=", 1) - fmt_name = fmt_name.strip() - path_val = path_val.strip() - if fmt_name not in known_output_names(): - msg = ( - f"Unknown marketplace format '{fmt_name}' in --marketplace-path. " - f"Known formats: {', '.join(sorted(known_output_names()))}" - ) - _emit_json_error_or_raise(ctx, json_output, "unknown_format", msg) - return - # Security: validate path to prevent traversal attacks - try: - validate_path_segments(path_val, context="--marketplace-path", allow_current_dir=True) - except Exception as exc: - _emit_json_error_or_raise(ctx, json_output, "path_error", str(exc)) - return - path_overrides[fmt_name] = path_val + path_overrides_result = _parse_path_overrides(marketplace_path_overrides, ctx, json_output) + if path_overrides_result is None: + return + path_overrides = path_overrides_result # -- Parse --marketplace filter -- - marketplace_formats: tuple[str, ...] | None = None - if marketplace_filter is not None: - if marketplace_filter.strip().lower() == "none": - marketplace_formats = () - elif marketplace_filter.strip().lower() == "all": - marketplace_formats = None # all configured - else: - requested = [f.strip() for f in marketplace_filter.split(",") if f.strip()] - known = known_output_names() - for r in requested: - if r not in known: - msg = ( - f"Unknown marketplace format '{r}' in --marketplace. " - f"Known formats: {', '.join(sorted(known))}" - ) - _emit_json_error_or_raise(ctx, json_output, "unknown_format", msg) - return - marketplace_formats = tuple(requested) + marketplace_formats = _parse_marketplace_filter(marketplace_filter, ctx, json_output) + # _parse_marketplace_filter raises/exits on error via _emit_json_error_or_raise project_root = Path(".").resolve() # Issue #1207 D1: when --target is not given, detect the project's # actual target so the embedded ``pack.target`` reflects what was diff --git a/src/apm_cli/install/drift.py b/src/apm_cli/install/drift.py index 4c6673f07..81686de73 100644 --- a/src/apm_cli/install/drift.py +++ b/src/apm_cli/install/drift.py @@ -394,7 +394,7 @@ def run_replay(config: ReplayConfig, logger: CheckLogger) -> Path: Surfaced verbatim when a locked dep is not in the cache. """ from apm_cli.deps.lockfile import _SELF_KEY, LockFile - from apm_cli.install.services import integrate_package_primitives + from apm_cli.install.services import IntegratorBundle, integrate_package_primitives from apm_cli.integration.targets import resolve_targets from apm_cli.utils.diagnostics import DiagnosticCollector @@ -473,12 +473,14 @@ def run_replay(config: ReplayConfig, logger: CheckLogger) -> Path: package_info, scratch_root, targets=targets, - prompt_integrator=integrators["prompt"], - agent_integrator=integrators["agent"], - skill_integrator=integrators["skill"], - instruction_integrator=integrators["instruction"], - command_integrator=integrators["command"], - hook_integrator=integrators["hook"], + integrators=IntegratorBundle( + prompt=integrators["prompt"], + agent=integrators["agent"], + skill=integrators["skill"], + instruction=integrators["instruction"], + command=integrators["command"], + hook=integrators["hook"], + ), force=True, managed_files=set(), diagnostics=diagnostics, diff --git a/src/apm_cli/install/mcp/command.py b/src/apm_cli/install/mcp/command.py index 55faf0bb3..d492e9882 100644 --- a/src/apm_cli/install/mcp/command.py +++ b/src/apm_cli/install/mcp/command.py @@ -49,15 +49,18 @@ def run_mcp_install( force: bool, runtime: str | None, exclude: str | None, - verbose: bool, logger, - manifest_path: Path, apm_dir: Path, scope: str | None, registry_url: str | None = None, ) -> None: """Execute the --mcp install path. ``registry_url`` is the validated - --registry value; the caller resolved precedence vs MCP_REGISTRY_URL.""" + --registry value; the caller resolved precedence vs MCP_REGISTRY_URL. + ``manifest_path`` is derived from ``apm_dir`` (``apm_dir / 'apm.yml'``).""" + from ...constants import APM_YML_FILENAME + + manifest_path = apm_dir / APM_YML_FILENAME + verbose = logger.verbose from ...models.dependency.mcp import MCPDependency env = parse_env_pairs(env_pairs) diff --git a/src/apm_cli/install/mcp/conflicts.py b/src/apm_cli/install/mcp/conflicts.py index 6c8f726a2..b0e79d524 100644 --- a/src/apm_cli/install/mcp/conflicts.py +++ b/src/apm_cli/install/mcp/conflicts.py @@ -38,12 +38,14 @@ def validate_mcp_conflicts( global_: bool, only: str | None, update: bool, - use_ssh: bool, - use_https: bool, - allow_protocol_fallback: bool, + any_transport_flag: bool, registry_url: str | None = None, ) -> None: - """Apply conflict matrix E1-E15. Raises ``click.UsageError`` on hit.""" + """Apply conflict matrix E1-E15. Raises ``click.UsageError`` on hit. + + ``any_transport_flag`` should be ``use_ssh or use_https or + allow_protocol_fallback`` (pre-evaluated by the caller). + """ # E10: flags require --mcp -- run first so users get the right hint. if mcp_name is None: flag_values = { @@ -84,7 +86,7 @@ def validate_mcp_conflicts( raise click.UsageError("cannot use --only apm with --mcp") # E4: transport selection flags do not apply. - if use_ssh or use_https or allow_protocol_fallback: + if any_transport_flag: raise click.UsageError( "transport selection flags (--ssh/--https/--allow-protocol-fallback) " "don't apply to MCP entries" diff --git a/src/apm_cli/install/phases/resolve.py b/src/apm_cli/install/phases/resolve.py index 969f2ce40..7a25893b4 100644 --- a/src/apm_cli/install/phases/resolve.py +++ b/src/apm_cli/install/phases/resolve.py @@ -31,23 +31,18 @@ _logger = logging.getLogger(__name__) -def run(ctx: InstallContext) -> None: - """Execute the resolve phase. +# ------------------------------------------------------------------ +# Private helpers (each mutates ctx in-place, following existing pattern) +# ------------------------------------------------------------------ - On return every field listed in the *Resolve phase outputs* section of - :class:`~apm_cli.install.context.InstallContext` is populated. - """ - from apm_cli.core.auth import AuthResolver - from apm_cli.core.scope import InstallScope, get_modules_dir - from apm_cli.deps import github_downloader as _ghd_mod - from apm_cli.deps.apm_resolver import APMDependencyResolver - from apm_cli.deps.lockfile import LockFile, get_lockfile_path - from apm_cli.install.phases.local_content import _copy_local_package - from apm_cli.models.apm_package import DependencyReference +def _load_lockfile(ctx: InstallContext) -> None: + """Load ``apm.lock.yaml`` and populate ``ctx.existing_lockfile`` / ``ctx.lockfile_path``.""" # ------------------------------------------------------------------ # 1. Lockfile loading # ------------------------------------------------------------------ + from apm_cli.deps.lockfile import LockFile, get_lockfile_path + lockfile_path = get_lockfile_path(ctx.apm_dir) ctx.lockfile_path = lockfile_path existing_lockfile = None @@ -78,16 +73,29 @@ def run(ctx: InstallContext) -> None: ctx.logger.lockfile_entry(locked_dep.get_unique_key(), ref=_ref, sha=_sha) ctx.existing_lockfile = existing_lockfile + +def _ensure_modules_dir(ctx: InstallContext) -> None: + """Create the ``apm_modules/`` directory and populate ``ctx.apm_modules_dir``.""" # ------------------------------------------------------------------ # 2. apm_modules directory # ------------------------------------------------------------------ + from apm_cli.core.scope import get_modules_dir + apm_modules_dir = get_modules_dir(ctx.scope) apm_modules_dir.mkdir(parents=True, exist_ok=True) ctx.apm_modules_dir = apm_modules_dir + +def _setup_downloader(ctx: InstallContext) -> None: + """Create auth resolver and downloader; populate ``ctx.auth_resolver`` / ``ctx.downloader``.""" # ------------------------------------------------------------------ # 3. Auth resolver + downloader # ------------------------------------------------------------------ + import os as _os + + from apm_cli.core.auth import AuthResolver + from apm_cli.deps import github_downloader as _ghd_mod + if ctx.auth_resolver is None: ctx.auth_resolver = AuthResolver() @@ -100,8 +108,7 @@ def run(ctx: InstallContext) -> None: # WS2a (#1116): attach a per-run shared clone cache so subdirectory # deps from the same upstream repo+ref share a single git clone. - # The cache is cleaned up in the resolve phase's finally-equivalent - # (after resolution completes, whether success or failure). + # The cache is cleaned up after resolution completes (see _resolve_dependencies). from apm_cli.deps.shared_clone_cache import SharedCloneCache shared_cache = SharedCloneCache() @@ -109,8 +116,6 @@ def run(ctx: InstallContext) -> None: # WS3 (#1116): attach persistent cross-run git cache unless disabled # via APM_NO_CACHE environment variable. - import os as _os - if not _os.environ.get("APM_NO_CACHE"): from apm_cli.cache.paths import get_cache_root @@ -158,6 +163,26 @@ def run(ctx: InstallContext) -> None: exc, ) + +def _resolve_dependencies(ctx: InstallContext) -> None: + """Run ``APMDependencyResolver``, handle errors; populate ``ctx.deps_to_install`` and ``ctx.dependency_graph``. + + Also wires the download callback (which handles transitive package fetching), + builds ``ctx.dep_base_dirs``, writes ancillary state to ``ctx``, and cleans up + the shared clone cache. + """ + import threading as _threading + + from apm_cli.core.scope import InstallScope + from apm_cli.deps.apm_resolver import APMDependencyResolver + from apm_cli.install.insecure_policy import ( + _check_insecure_dependencies, + _collect_insecure_dependency_infos, + _guard_transitive_insecure_dependencies, + _warn_insecure_dependencies, + ) + from apm_cli.install.phases.local_content import _copy_local_package + # ------------------------------------------------------------------ # 4. Tracking variables (phase-local except where noted) # ------------------------------------------------------------------ @@ -173,8 +198,6 @@ def run(ctx: InstallContext) -> None: # ``callback_downloaded`` (e.g. duplicate-key races) are not. A single # narrow lock around the result-recording sites is sufficient and # cheap; the heavy I/O work runs OUTSIDE the lock. - import threading as _threading - callback_lock = _threading.Lock() # ------------------------------------------------------------------ @@ -186,7 +209,8 @@ def run(ctx: InstallContext) -> None: project_root = ctx.project_root update_refs = ctx.update_refs logger = ctx.logger - verbose = ctx.verbose # noqa: F841 + existing_lockfile = ctx.existing_lockfile + downloader = ctx.downloader def download_callback(dep_ref, modules_dir, parent_chain="", parent_pkg=None): """Download a package during dependency resolution. @@ -333,7 +357,7 @@ def download_callback(dep_ref, modules_dir, parent_chain="", parent_pkg=None): # 6. Resolver creation + dependency resolution # ------------------------------------------------------------------ resolver = APMDependencyResolver( - apm_modules_dir=apm_modules_dir, + apm_modules_dir=ctx.apm_modules_dir, download_callback=download_callback, ) @@ -380,49 +404,6 @@ def download_callback(dep_ref, modules_dir, parent_chain="", parent_pkg=None): flat_deps = dependency_graph.flattened_dependencies deps_to_install = flat_deps.get_installation_list() - # ------------------------------------------------------------------ - # 7. --only filtering - # ------------------------------------------------------------------ - if ctx.only_packages: - # Build identity set from user-supplied package specs. - # Accepts any input form: git URLs, FQDN, shorthand. - only_identities = builtins.set() - for p in ctx.only_packages: - try: - ref = DependencyReference.parse(p) - only_identities.add(ref.get_identity()) - except Exception: - only_identities.add(p) - - # Expand the set to include transitive descendants of the - # requested packages so their MCP servers, primitives, etc. - # are correctly installed and written to the lockfile. - tree = dependency_graph.dependency_tree - - def _collect_descendants(node, visited=None): - """Walk the tree and add every child identity (cycle-safe).""" - if visited is None: - visited = builtins.set() - for child in node.children: - identity = child.dependency_ref.get_identity() - if identity not in visited: - visited.add(identity) - only_identities.add(identity) - _collect_descendants(child, visited) - - for node in tree.nodes.values(): - if node.dependency_ref.get_identity() in only_identities: - _collect_descendants(node) - - deps_to_install = [dep for dep in deps_to_install if dep.get_identity() in only_identities] - - from apm_cli.install.insecure_policy import ( - _check_insecure_dependencies, - _collect_insecure_dependency_infos, - _guard_transitive_insecure_dependencies, - _warn_insecure_dependencies, - ) - _check_insecure_dependencies( ctx.all_apm_deps, ctx.allow_insecure, @@ -503,11 +484,6 @@ def _collect_descendants(node, visited=None): dep_base_dirs = {} ctx.dep_base_dirs = dep_base_dirs - # ------------------------------------------------------------------ - # 8. Orphan detection: intended_dep_keys - # ------------------------------------------------------------------ - ctx.intended_dep_keys = builtins.set(d.get_unique_key() for d in deps_to_install) - # ------------------------------------------------------------------ # Write ancillary state to ctx for later phases # ------------------------------------------------------------------ @@ -518,7 +494,9 @@ def _collect_descendants(node, visited=None): # WS2a (#1116): release shared clone temp dirs now that all subdir # deps have extracted their subpaths. Safe to call even if no # subdir deps were processed (no-op in that case). - shared_cache.cleanup() + shared_cache = getattr(ctx.downloader, "shared_clone_cache", None) + if shared_cache is not None: + shared_cache.cleanup() # Perf #1433: emit ref-resolver tier hit counts at the end of the # resolve phase. Verbose only; one line; lets reviewers see which @@ -529,3 +507,68 @@ def _collect_descendants(node, visited=None): # tier_summary is install-only; other loggers degrade silently. if hasattr(ctx.logger, "tier_summary"): ctx.logger.tier_summary(_tier_stats) + + +def _apply_only_filter(ctx: InstallContext) -> None: + """Filter ``ctx.deps_to_install`` to the ``--only`` package(s) and their subtrees.""" + # ------------------------------------------------------------------ + # 7. --only filtering + # ------------------------------------------------------------------ + from apm_cli.models.apm_package import DependencyReference + + # Build identity set from user-supplied package specs. + # Accepts any input form: git URLs, FQDN, shorthand. + only_identities: builtins.set = builtins.set() + for p in ctx.only_packages: + try: + ref = DependencyReference.parse(p) + only_identities.add(ref.get_identity()) + except Exception: + only_identities.add(p) + + # Expand the set to include transitive descendants of the + # requested packages so their MCP servers, primitives, etc. + # are correctly installed and written to the lockfile. + tree = ctx.dependency_graph.dependency_tree + + def _collect_descendants(node: object, visited: builtins.set | None = None) -> None: + """Walk the tree and add every child identity (cycle-safe).""" + if visited is None: + visited = builtins.set() + for child in node.children: # type: ignore[attr-defined] + identity = child.dependency_ref.get_identity() + if identity not in visited: + visited.add(identity) + only_identities.add(identity) + _collect_descendants(child, visited) + + for node in tree.nodes.values(): + if node.dependency_ref.get_identity() in only_identities: + _collect_descendants(node) + + ctx.deps_to_install = [ + dep for dep in ctx.deps_to_install if dep.get_identity() in only_identities + ] + + +def _compute_intended_dep_keys(ctx: InstallContext) -> None: + """Populate ``ctx.intended_dep_keys`` (manifest-intent set for orphan cleanup).""" + # ------------------------------------------------------------------ + # 8. Orphan detection: intended_dep_keys + # ------------------------------------------------------------------ + ctx.intended_dep_keys = builtins.set(d.get_unique_key() for d in ctx.deps_to_install) + + +def run(ctx: InstallContext) -> None: + """Execute the resolve phase. + + On return every field listed in the *Resolve phase outputs* section of + :class:`~apm_cli.install.context.InstallContext` is populated. + """ + _load_lockfile(ctx) + _ensure_modules_dir(ctx) + _setup_downloader(ctx) + _resolve_dependencies(ctx) + if ctx.only_packages: + _apply_only_filter(ctx) + _compute_intended_dep_keys(ctx) diff --git a/src/apm_cli/install/services.py b/src/apm_cli/install/services.py index 1b62e22ba..3ec07891e 100644 --- a/src/apm_cli/install/services.py +++ b/src/apm_cli/install/services.py @@ -19,6 +19,7 @@ from __future__ import annotations import builtins +from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any @@ -26,6 +27,7 @@ from ..core.command_logger import InstallLogger from ..core.scope import InstallScope from ..install.context import InstallContext + from ..integration.base_integrator import BaseIntegrator from ..utils.diagnostics import DiagnosticCollector @@ -38,6 +40,23 @@ dict = builtins.dict +@dataclass(frozen=True) +class IntegratorBundle: + """Groups the six primitive integrators passed to ``integrate_package_primitives``. + + Using a bundle reduces the public argument count of + ``integrate_package_primitives`` below the PLR0913 threshold (≤15) while + keeping the integrator objects strongly typed and discoverable. + """ + + prompt: BaseIntegrator + agent: BaseIntegrator + skill: BaseIntegrator + instruction: BaseIntegrator + command: BaseIntegrator + hook: BaseIntegrator + + def _deployed_path_entry( target_path: Path, project_root: Path, @@ -101,12 +120,7 @@ def integrate_package_primitives( project_root: Path, *, targets: Any, - prompt_integrator: Any, - agent_integrator: Any, - skill_integrator: Any, - instruction_integrator: Any, - command_integrator: Any, - hook_integrator: Any, + integrators: IntegratorBundle, force: bool, managed_files: Any, diagnostics: DiagnosticCollector, @@ -242,12 +256,12 @@ def _format_target_collapse(paths: list[str], verbose: bool) -> tuple[str, list[ _verbose = bool(getattr(ctx, "verbose", False)) if ctx is not None else False _INTEGRATOR_KWARGS = { - "prompts": prompt_integrator, - "agents": agent_integrator, - "commands": command_integrator, - "instructions": instruction_integrator, - "hooks": hook_integrator, - "skills": skill_integrator, + "prompts": integrators.prompt, + "agents": integrators.agent, + "commands": integrators.command, + "instructions": integrators.instruction, + "hooks": integrators.hook, + "skills": integrators.skill, } # Aggregate per-primitive across targets so we emit ONE line per kind @@ -375,7 +389,7 @@ def _format_target_collapse(paths: list[str], verbose: bool) -> tuple[str, list[ " |-- workflows arrive disabled; enable from the Copilot App's Workflows tab" ) - skill_result = skill_integrator.integrate_package_skill( + skill_result = integrators.skill.integrate_package_skill( package_info, project_root, diagnostics=diagnostics, @@ -478,12 +492,14 @@ def integrate_local_content( local_info, project_root, targets=targets, - prompt_integrator=prompt_integrator, - agent_integrator=agent_integrator, - skill_integrator=skill_integrator, - instruction_integrator=instruction_integrator, - command_integrator=command_integrator, - hook_integrator=hook_integrator, + integrators=IntegratorBundle( + prompt=prompt_integrator, + agent=agent_integrator, + skill=skill_integrator, + instruction=instruction_integrator, + command=command_integrator, + hook=hook_integrator, + ), force=force, managed_files=managed_files, diagnostics=diagnostics, diff --git a/src/apm_cli/install/template.py b/src/apm_cli/install/template.py index 45aaaf9eb..a9e99583a 100644 --- a/src/apm_cli/install/template.py +++ b/src/apm_cli/install/template.py @@ -16,7 +16,7 @@ from typing import Dict, Optional # noqa: F401, UP035 from apm_cli.install.helpers.security_scan import _pre_deploy_security_scan -from apm_cli.install.services import integrate_package_primitives +from apm_cli.install.services import IntegratorBundle, integrate_package_primitives from apm_cli.install.sources import DependencySource, Materialization @@ -77,12 +77,14 @@ def _integrate_materialization( m.package_info, ctx.project_root, targets=ctx.targets, - prompt_integrator=ctx.integrators["prompt"], - agent_integrator=ctx.integrators["agent"], - skill_integrator=ctx.integrators["skill"], - instruction_integrator=ctx.integrators["instruction"], - command_integrator=ctx.integrators["command"], - hook_integrator=ctx.integrators["hook"], + integrators=IntegratorBundle( + prompt=ctx.integrators["prompt"], + agent=ctx.integrators["agent"], + skill=ctx.integrators["skill"], + instruction=ctx.integrators["instruction"], + command=ctx.integrators["command"], + hook=ctx.integrators["hook"], + ), force=ctx.force, managed_files=ctx.managed_files, diagnostics=diagnostics, diff --git a/src/apm_cli/install/validation.py b/src/apm_cli/install/validation.py index 06fe69f80..43aea11a4 100644 --- a/src/apm_cli/install/validation.py +++ b/src/apm_cli/install/validation.py @@ -134,16 +134,584 @@ def _local_path_no_markers_hint(local_dir, logger=None): _rich_echo(f" ... and {len(found) - 5} more", color="dim") +def _validate_local_package(dep_ref, logger) -> bool: + """Validate a local-path package: directory must exist and contain package markers. + + Returns True if the directory exists and has ``apm.yml``, ``SKILL.md``, or + a ``plugin.json`` file. Returns False and optionally surfaces a sub-package + hint when markers are absent. + """ + local = Path(dep_ref.local_path).expanduser() + if not local.is_absolute(): + local = Path.cwd() / local + local = local.resolve() + if not local.is_dir(): + return False + # Must contain apm.yml, SKILL.md, or plugin.json + if (local / "apm.yml").exists() or (local / "SKILL.md").exists(): + return True + from apm_cli.utils.helpers import find_plugin_json + + if find_plugin_json(local) is not None: + return True + # Directory exists but lacks package markers -- surface a hint + _local_path_no_markers_hint(local, logger=logger) + return False + + +def _validate_virtual_package( + dep_ref, + auth_resolver, + verbose: bool, + verbose_log, + package: str, + logger, +) -> bool: + """Validate a virtual package using ``GitHubPackageDownloader``. + + Returns True when ``PROXY_REGISTRY_ONLY=1`` (proxy handles the 404 case), + or delegates to the downloader's ``validate_virtual_package_exists`` and + surfaces a verbose auth context on failure. + """ + from apm_cli.deps.github_downloader import GitHubPackageDownloader + + from ..deps.registry_proxy import is_enforce_only + + if is_enforce_only(): + # PROXY_REGISTRY_ONLY=1: skip virtual package validation probe. + # The download step will surface a proxy 404 if the package is absent. + if logger: + logger.info( + "Skipping virtual package validation for" + f" {dep_ref.host or 'remote'}: proxy-only mode is active" + ) + return True + + ctx = auth_resolver.resolve_for_dep(dep_ref) + host = dep_ref.host or default_host() + org = dep_ref.repo_url.split("/")[0] if dep_ref.repo_url and "/" in dep_ref.repo_url else None + if verbose_log: + verbose_log( + f"Auth resolved: host={host}, org={org}, source={ctx.source}, type={ctx.token_type}" + ) + virtual_downloader = GitHubPackageDownloader(auth_resolver=auth_resolver) + + def _warn(msg: str) -> None: + # Round-4 panel fix (cli-logging + devx-ux converge): + # * Yellow warnings MUST reach the user in BOTH + # verbose and non-verbose modes -- the git-fallback + # signal is security-relevant (a scoped PAT may + # have correctly rejected the package on the API + # surface and the broader git-credential chain + # accepted it). Operators must see this in default + # CI logs. + # * Strip the "Run with --verbose for details." + # suffix only when --verbose is already set; the + # suffix is meaningful only when it tells the user + # a follow-up is available. + # * Fall back to ``_rich_warning`` when ``logger`` is + # None so production callers without a + # CommandLogger still emit the yellow signal -- + # comments are not enforcement. + display = msg + verbose_suffix = " Run with --verbose for details." + if verbose and msg.endswith(verbose_suffix): + display = msg[: -len(verbose_suffix)] + if logger: + logger.warning(display) + else: + _rich_warning(display) + + result = virtual_downloader.validate_virtual_package_exists( + dep_ref, + verbose_callback=verbose_log, + warn_callback=_warn, + ) + if not result and verbose_log: + try: + err_ctx = auth_resolver.build_error_context( + host, + f"accessing {package}", + org=org, + port=dep_ref.port, + dep_url=dep_ref.repo_url, + ) + for line in err_ctx.splitlines(): + verbose_log(line) + except Exception: + pass + return result + + +def _validate_ado_git_package( + dep_ref, + auth_resolver, + verbose_log, + package: str, + logger, +) -> bool: + """Validate an ADO, GHES, or generic-git-host package via ``git ls-remote``. + + Handles: + - Proxy-only short-circuit (``PROXY_REGISTRY_ONLY=1``) + - Host classification (GitLab, generic, ADO/GHES) + - Authenticated URL construction with the correct auth scheme + - Strict vs. fallback protocol ordering (``APM_ALLOW_PROTOCOL_FALLBACK``) + - ADO bearer-token fallback when a PAT is rejected + - Typed ``AuthenticationError`` for auth failures on managed hosts + + Returns True when the repo is reachable, False otherwise. + Raises ``AuthenticationError`` for auth failures on non-generic managed hosts. + """ + import os + import subprocess + + from apm_cli.deps.github_downloader import GitHubPackageDownloader + from apm_cli.deps.transport_selection import is_fallback_allowed + from apm_cli.utils.github_host import is_azure_devops_hostname, is_github_hostname + + from ..deps.registry_proxy import is_enforce_only + + if is_enforce_only(): + # PROXY_REGISTRY_ONLY=1: skip direct git ls-remote probe for ADO/GHES. + # The download step will surface a proxy 404 if the package is absent. + if logger: + logger.info( + "Skipping direct git ls-remote for" + f" {dep_ref.host or 'remote'}: proxy-only mode is active" + ) + return True + + # Determine host type before building the URL so we know whether to + # embed a token. Generic (non-GitHub, non-ADO) hosts are excluded + # from APM-managed auth; they rely on git credential helpers via the + # relaxed validate_env below. GitLab hosts are managed when classified + # as GitLab because they need oauth2 HTTPS token formatting. + is_gitlab = auth_resolver.classify_host(dep_ref.host).kind == "gitlab" + is_generic = ( + not is_github_hostname(dep_ref.host) + and not is_azure_devops_hostname(dep_ref.host) + and not is_gitlab + ) + + # For GHES / ADO: resolve per-dependency auth up front so the URL + # carries an embedded token and avoids triggering OS credential + # helper popups during git ls-remote validation. + _url_token = None + _dep_ctx = None + _auth_scheme = "basic" + if not is_generic: + _dep_ctx = auth_resolver.resolve_for_dep(dep_ref) + _url_token = _dep_ctx.token + _auth_scheme = getattr(_dep_ctx, "auth_scheme", "basic") or "basic" + + ado_downloader = GitHubPackageDownloader(auth_resolver=auth_resolver) + # Set the host + if dep_ref.host: + ado_downloader.github_host = dep_ref.host + + # Build authenticated URL using the resolved per-dep token. + # #1015: pass auth_scheme so bearer tokens use extraheader + # injection instead of embedding a ~1.5KB JWT in the userinfo. + package_url = ado_downloader._build_repo_url( + dep_ref.repo_url, + use_ssh=False, + dep_ref=dep_ref, + token=_url_token, + auth_scheme=_auth_scheme, + ) + + explicit_scheme = (getattr(dep_ref, "explicit_scheme", None) or "").lower() or None + is_insecure = bool(getattr(dep_ref, "is_insecure", False)) + + # Strict-by-default cross-protocol policy (issue microsoft/apm#992): + # an explicit ``http://`` / ``https://`` / ``ssh://`` URL is honored + # exactly and does NOT silently fall back to a different protocol. + # This mirrors the strict default of ``_clone_with_fallback`` / + # :class:`TransportSelector` and prevents the foot-gun where a user + # types ``https://corp-bitbucket.example/...`` and the validation + # pre-check silently retries SSH on port 22, masking the real HTTPS + # failure (auth/redirect/etc.) behind a 30s SSH timeout. The + # ``APM_ALLOW_PROTOCOL_FALLBACK=1`` env var (the same escape-hatch + # the clone path honors) restores the legacy permissive chain. + allow_fallback_env = is_fallback_allowed() + + # For generic hosts (not GitHub, not ADO), relax the env so native + # credential helpers (macOS Keychain, credential-store, + # manager-core, SSH agent, etc.) can work. Config isolation + # (GIT_CONFIG_GLOBAL=/dev/null, GIT_CONFIG_NOSYSTEM=1) is only + # enforced for insecure plaintext HTTP connections where + # credential leakage is a real risk; HTTPS connections need + # access to user-configured helpers in ~/.gitconfig. This + # matches _clone_with_fallback() and git_reference_resolver. + if is_generic: + validate_env = ado_downloader._build_noninteractive_git_env( + preserve_config_isolation=is_insecure, + suppress_credential_helpers=is_insecure, + ) + else: + # #1015: merge _dep_ctx.git_env (bearer-aware GIT_CONFIG_* + # overrides) into the subprocess env so `git ls-remote` + # actually sends the Authorization header for AAD tokens. + _ctx_git_env = getattr(_dep_ctx, "git_env", {}) if _dep_ctx else {} + validate_env = {**os.environ, **ado_downloader.git_env, **_ctx_git_env} + + # Build the probe order. Non-generic hosts (GHES/ADO) always probe + # a single authenticated URL. Generic hosts: + # - explicit https/http -> web URL only (strict) + # - explicit ssh -> SSH URL only (strict) + # - shorthand (no scheme) -> legacy [SSH, HTTPS] chain + # ``APM_ALLOW_PROTOCOL_FALLBACK=1`` re-appends the opposite scheme + # for the explicit cases to match clone semantics exactly. + if is_generic: + ssh_url = ado_downloader._build_repo_url(dep_ref.repo_url, use_ssh=True, dep_ref=dep_ref) + if explicit_scheme in ("http", "https"): + urls_to_try: list[str] = ( + [package_url] if not allow_fallback_env else [package_url, ssh_url] + ) + elif explicit_scheme == "ssh": + urls_to_try = [ssh_url] if not allow_fallback_env else [ssh_url, package_url] + else: + # Shorthand has no user-stated transport; keep the legacy + # SSH-first chain so existing flows (e.g. SSH-key users on + # corporate hosts) keep validating successfully. + urls_to_try = [ssh_url, package_url] + else: + urls_to_try = [package_url] + + if verbose_log: + attempt_word = "attempt" if len(urls_to_try) == 1 else "attempts" + verbose_log(f"Trying git ls-remote for {dep_ref.host} ({len(urls_to_try)} {attempt_word})") + + def _scheme_of(url: str) -> str: + return url.split("://", 1)[0] if "://" in url else "ssh" + + def _log_attempt_result(probe_url: str, run_result) -> None: + """Per-attempt sanitized verbose logging. + + The previous implementation only logged the final attempt's + result, which masked the actual failure (typically the HTTPS + leg) behind the SSH-fallback timeout. Logging each attempt + gives users the diagnostic data they need to act. + """ + if not verbose_log: + return + scheme = _scheme_of(probe_url) + if run_result.returncode == 0: + verbose_log(f"git ls-remote ({scheme}) rc=0 for {package}") + return + raw_stderr = (run_result.stderr or "").strip()[:200] + stderr_snippet = ado_downloader._sanitize_git_error(raw_stderr) + for env_var in ("GIT_ASKPASS", "GIT_CONFIG_GLOBAL"): + env_val = validate_env.get(env_var, "") + if env_val: + stderr_snippet = stderr_snippet.replace(env_val, "***") + verbose_log(f"git ls-remote ({scheme}) rc={run_result.returncode}: {stderr_snippet}") + + result = None + for probe_url in urls_to_try: + cmd = ["git", "ls-remote", "--heads", "--exit-code", probe_url] + result = subprocess.run( + cmd, + capture_output=True, + text=True, + encoding="utf-8", + timeout=30, + env=validate_env, + ) + _log_attempt_result(probe_url, result) + if result.returncode == 0: + break + + # ADO bearer fallback: if PAT was rejected (rc != 0 with auth-failure + # signal) AND the dep is on Azure DevOps AND we resolved a PAT, + # silently retry with az-cli bearer token. + if ( + result is not None + and result.returncode != 0 + and dep_ref.is_azure_devops() + and _url_token is not None # we had a PAT + and is_ado_auth_failure_signal(result.stderr or "") + ): + try: + from apm_cli.core.azure_cli import AzureCliBearerError, get_bearer_provider + + provider = get_bearer_provider() + if provider.is_available(): + try: + bearer = provider.get_bearer_token() + bearer_url = ado_downloader._build_repo_url( + dep_ref.repo_url, + use_ssh=False, + dep_ref=dep_ref, + token=None, + auth_scheme="bearer", + ) + # SECURITY: build a CLEAN env via _build_git_env(scheme="bearer") + # rather than {**validate_env, **build_ado_bearer_git_env(bearer)}. + # validate_env still carries the PAT-context GIT_CONFIG_* + # entries from _ctx_git_env; merging the bearer env on top + # would keep the rejected PAT visible in the child-process + # env (visible in /proc//environ on Linux). _build_git_env + # explicitly skips GIT_TOKEN for scheme="bearer" and emits + # only the bearer-specific GIT_CONFIG_* injection. + bearer_env = auth_resolver._build_git_env( + bearer, scheme="bearer", host_kind="ado" + ) + cmd = ["git", "ls-remote", "--heads", "--exit-code", bearer_url] + bearer_result = subprocess.run( + cmd, + capture_output=True, + text=True, + encoding="utf-8", + timeout=30, + env=bearer_env, + ) + if bearer_result.returncode == 0: + # Emit deferred stale-PAT warning via resolver + auth_resolver.emit_stale_pat_diagnostic(dep_ref.host or "dev.azure.com") + if verbose_log: + verbose_log( + f"git ls-remote rc=0 for {package} (via AAD bearer fallback)" + ) + return True + except AzureCliBearerError: + pass + except ImportError: + pass + + # Per-attempt verbose logging is emitted inside the probe loop + # (and by the bearer-fallback branch above), so the result is + # already on screen by the time we get here. Stderr is sanitized + # via ``GitHubPackageDownloader._sanitize_git_error`` to scrub + # any token-bearing URLs / env values before logging. + + # #1015: distinguish auth failures from non-auth failures (DNS, + # timeout, repo-truly-not-found 404). Auth failures get a typed + # exception with actionable diagnostics; non-auth failures keep + # the legacy False return so the caller can word its own message. + if result.returncode != 0 and not is_generic: + if is_ado_auth_failure_signal(result.stderr or ""): + _host = dep_ref.host or "dev.azure.com" + _org = ( + dep_ref.repo_url.split("/")[0] + if dep_ref.repo_url and "/" in dep_ref.repo_url + else None + ) + _diag = auth_resolver.build_error_context( + _host, + "validate", + org=_org, + dep_url=dep_ref.repo_url, + ) + raise AuthenticationError( + f"Authentication failed for {_host}", + diagnostic_context=_diag, + ) + + return result.returncode == 0 + + +def _validate_github_package( + dep_ref, + auth_resolver, + verbose: bool, + verbose_log, + package: str, + logger, +) -> bool: + """Validate a GitHub.com (or GHES) package via the GitHub REST API. + + Uses ``AuthResolver.try_with_fallback`` with ``unauth_first=True`` so + public repos are probed anonymously before burning a rate-limited token. + Returns True/False; surfaces verbose auth context on failure. + """ + from ..deps.registry_proxy import is_enforce_only + + host = dep_ref.host or default_host() + port = dep_ref.port + org = dep_ref.repo_url.split("/")[0] if dep_ref.repo_url and "/" in dep_ref.repo_url else None + host_info = auth_resolver.classify_host(host, port=port) + + if is_enforce_only(): + # PROXY_REGISTRY_ONLY=1: skip the GitHub API probe. + # Marketplace/lockfile resolution already ran through the proxy; + # the download step will surface a proxy 404 if absent. + if logger: + logger.info(f"Skipping direct GitHub API probe for {host}: proxy-only mode is active") + return True + + if verbose_log: + ctx = auth_resolver.resolve(host, org=org, port=port) + verbose_log( + f"Auth resolved: host={host_info.display_name}, org={org}, " + f"source={ctx.source}, type={ctx.token_type}" + ) + + def _check_repo(token, git_env) -> bool: + """Check repo accessibility via GitHub API.""" + api_base = host_info.api_base + api_url = f"{api_base}/repos/{dep_ref.repo_url}" + headers = { + "Accept": "application/vnd.github+json", + "User-Agent": "apm-cli", + } + if token: + headers["Authorization"] = f"Bearer {token}" + + try: + resp = requests.get(api_url, headers=headers, timeout=15) + except requests.exceptions.SSLError as e: + raise RuntimeError(f"TLS verification failed for {host_info.display_name}") from e + except requests.exceptions.RequestException as e: + if verbose_log: + verbose_log(f"API request failed: {e}") + raise + + if verbose_log: + verbose_log(f"API {api_url} -> {resp.status_code}") + if resp.ok: + return True + if resp.status_code == 404 and token: + # 404 with token could mean no access -- raise to trigger fallback + raise RuntimeError(f"API returned {resp.status_code}") + raise RuntimeError(f"API returned {resp.status_code}: {resp.reason}") + + try: + return auth_resolver.try_with_fallback( + host, + _check_repo, + org=org, + port=port, + # dep_ref.repo_url is owner/repo (never a full URL per the + # DependencyReference invariant); forwarded as path= so GCM + # multi-account users get per-URL credential matching. + path=dep_ref.repo_url, + unauth_first=True, + verbose_callback=verbose_log, + ) + except Exception as exc: + if _is_tls_failure(exc): + _log_tls_failure(host_info.display_name, exc, verbose_log, logger) + return False + if verbose_log: + try: + ctx = auth_resolver.build_error_context( + host, + f"accessing {package}", + org=org, + port=port, + dep_url=getattr(dep_ref, "repo_url", None), + ) + for line in ctx.splitlines(): + verbose_log(line) + except Exception: + pass + return False + + +def _validate_parse_failure_fallback( + package: str, + auth_resolver, + verbose_log, + logger, +) -> bool: + """Fallback validation used when ``DependencyReference.parse`` raises. + + Treats *package* as a raw ``owner/repo`` slug and probes the GitHub.com + API. Rejects anything that doesn't match the strict slug pattern so + path-confusion sequences cannot reach the API URL or git credential fill. + """ + from ..deps.registry_proxy import is_enforce_only + + host = default_host() + org = package.split("/")[0] if "/" in package else None + repo_path = package # owner/repo format + # Defensive owner/repo guard: when DependencyReference.parse raises, + # we fall back to embedding `repo_path` directly into an API URL and + # forwarding it as `path=` to git credential fill. Reject anything + # that isn't a strict / slug so path-confusion sequences + # (`../`, embedded slashes, control bytes) cannot reach either sink. + # Allows GitHub's documented owner/repo characters: alphanumeric, + # dot, underscore, hyphen. + if not re.fullmatch(r"[A-Za-z0-9._-]+/[A-Za-z0-9._-]+", repo_path): + return False + + if is_enforce_only(): + # PROXY_REGISTRY_ONLY=1: skip the GitHub API fallback probe. + # The download step will surface a proxy 404 if the package is absent. + if logger: + logger.info( + f"Skipping direct GitHub API fallback probe for {host}: proxy-only mode is active" + ) + return True + + def _check_repo_fallback(token, git_env) -> bool: + host_info = auth_resolver.classify_host(host) + api_url = f"{host_info.api_base}/repos/{repo_path}" + headers = { + "Accept": "application/vnd.github+json", + "User-Agent": "apm-cli", + } + if token: + headers["Authorization"] = f"Bearer {token}" + + try: + resp = requests.get(api_url, headers=headers, timeout=15) + except requests.exceptions.SSLError as e: + raise RuntimeError(f"TLS verification failed for {host_info.display_name}") from e + except requests.exceptions.RequestException as e: + if verbose_log: + verbose_log(f"API fallback failed: {e}") + raise + + if resp.ok: + return True + if verbose_log: + verbose_log(f"API fallback -> {resp.status_code} {resp.reason}") + raise RuntimeError(f"API returned {resp.status_code}") + + try: + return auth_resolver.try_with_fallback( + host, + _check_repo_fallback, + org=org, + path=repo_path, + unauth_first=True, + verbose_callback=verbose_log, + ) + except Exception as exc: + if _is_tls_failure(exc): + # See note above: logged once here, skip auth context render. + _log_tls_failure(host, exc, verbose_log, logger) + return False + if verbose_log: + try: + ctx = auth_resolver.build_error_context( + host, f"accessing {package}", org=org, dep_url=package + ) + for line in ctx.splitlines(): + verbose_log(line) + except Exception: + pass + return False + + def _validate_package_exists(package, verbose=False, auth_resolver=None, logger=None, dep_ref=None): """Validate that a package exists and is accessible on GitHub, Azure DevOps, or locally. When *dep_ref* is provided (for example, marketplace GitLab monorepo resolution), use it instead of reparsing *package* so explicit ``git`` + ``path`` semantics are preserved. - """ - import os - import subprocess + Dispatches to per-backend helpers: + + - ``_validate_local_package`` -- local filesystem paths + - ``_validate_virtual_package`` -- virtual monorepo packages + - ``_validate_ado_git_package`` -- ADO / GHES / generic git hosts + - ``_validate_github_package`` -- GitHub.com REST API + - ``_validate_parse_failure_fallback`` -- raw slug fallback + """ from apm_cli.core.auth import AuthResolver if logger: @@ -155,36 +723,19 @@ def _validate_package_exists(package, verbose=False, auth_resolver=None, logger= auth_resolver = AuthResolver() try: - # Parse the package to check if it's a virtual package or ADO - from apm_cli.deps.github_downloader import GitHubPackageDownloader from apm_cli.models.apm_package import DependencyReference + from apm_cli.utils.github_host import is_github_hostname if dep_ref is None: dep_ref = DependencyReference.parse(package) # For local packages, validate directory exists and has valid package content if dep_ref.is_local and dep_ref.local_path: - local = Path(dep_ref.local_path).expanduser() - if not local.is_absolute(): - local = Path.cwd() / local - local = local.resolve() - if not local.is_dir(): - return False - # Must contain apm.yml, SKILL.md, or plugin.json - if (local / "apm.yml").exists() or (local / "SKILL.md").exists(): - return True - from apm_cli.utils.helpers import find_plugin_json - - if find_plugin_json(local) is not None: - return True - # Directory exists but lacks package markers -- surface a hint - _local_path_no_markers_hint(local, logger=logger) - return False - - from apm_cli.utils.github_host import is_azure_devops_hostname, is_github_hostname - - from ..deps.registry_proxy import is_enforce_only + return _validate_local_package(dep_ref, logger) + # ``virtual_subdir_repo_probe``: a virtual subdirectory on a non-GitHub, + # non-ADO host must validate the clone root via git ls-remote rather than + # the virtual downloader so SSH/credential-helper flows are preserved. virtual_subdir_repo_probe = ( dep_ref.is_virtual and dep_ref.is_virtual_subdirectory() @@ -196,417 +747,23 @@ def _validate_package_exists(package, verbose=False, auth_resolver=None, logger= # the virtual path is a subdirectory on a non-GitHub host. Those should # validate the clone root with git, preserving SSH/credential-helper flows. if dep_ref.is_virtual and not virtual_subdir_repo_probe: - if is_enforce_only(): - # PROXY_REGISTRY_ONLY=1: skip virtual package validation probe. - # The download step will surface a proxy 404 if the package is absent. - if logger: - logger.info( - "Skipping virtual package validation for" - f" {dep_ref.host or 'remote'}: proxy-only mode is active" - ) - return True - ctx = auth_resolver.resolve_for_dep(dep_ref) - host = dep_ref.host or default_host() - org = ( - dep_ref.repo_url.split("/")[0] - if dep_ref.repo_url and "/" in dep_ref.repo_url - else None + return _validate_virtual_package( + dep_ref, auth_resolver, verbose, verbose_log, package, logger ) - if verbose_log: - verbose_log( - f"Auth resolved: host={host}, org={org}, source={ctx.source}, type={ctx.token_type}" - ) - virtual_downloader = GitHubPackageDownloader(auth_resolver=auth_resolver) - - def _warn(msg: str) -> None: - # Round-4 panel fix (cli-logging + devx-ux converge): - # * Yellow warnings MUST reach the user in BOTH - # verbose and non-verbose modes -- the git-fallback - # signal is security-relevant (a scoped PAT may - # have correctly rejected the package on the API - # surface and the broader git-credential chain - # accepted it). Operators must see this in default - # CI logs. - # * Strip the "Run with --verbose for details." - # suffix only when --verbose is already set; the - # suffix is meaningful only when it tells the user - # a follow-up is available. - # * Fall back to ``_rich_warning`` when ``logger`` is - # None so production callers without a - # CommandLogger still emit the yellow signal -- - # comments are not enforcement. - display = msg - verbose_suffix = " Run with --verbose for details." - if verbose and msg.endswith(verbose_suffix): - display = msg[: -len(verbose_suffix)] - if logger: - logger.warning(display) - else: - _rich_warning(display) - - result = virtual_downloader.validate_virtual_package_exists( - dep_ref, - verbose_callback=verbose_log, - warn_callback=_warn, - ) - if not result and verbose_log: - try: - err_ctx = auth_resolver.build_error_context( - host, - f"accessing {package}", - org=org, - port=dep_ref.port, - dep_url=dep_ref.repo_url, - ) - for line in err_ctx.splitlines(): - verbose_log(line) - except Exception: - pass - return result # For Azure DevOps or GitHub Enterprise (non-github.com hosts), - # use the downloader which handles authentication properly + # use git ls-remote which handles authentication properly. if ( virtual_subdir_repo_probe or dep_ref.is_azure_devops() or (dep_ref.host and dep_ref.host != "github.com") ): - if is_enforce_only(): - # PROXY_REGISTRY_ONLY=1: skip direct git ls-remote probe for ADO/GHES. - # The download step will surface a proxy 404 if the package is absent. - if logger: - logger.info( - "Skipping direct git ls-remote for" - f" {dep_ref.host or 'remote'}: proxy-only mode is active" - ) - return True - - # Determine host type before building the URL so we know whether to - # embed a token. Generic (non-GitHub, non-ADO) hosts are excluded - # from APM-managed auth; they rely on git credential helpers via the - # relaxed validate_env below. GitLab hosts are managed when classified - # as GitLab because they need oauth2 HTTPS token formatting. - is_gitlab = auth_resolver.classify_host(dep_ref.host).kind == "gitlab" - is_generic = ( - not is_github_hostname(dep_ref.host) - and not is_azure_devops_hostname(dep_ref.host) - and not is_gitlab - ) - - # For GHES / ADO: resolve per-dependency auth up front so the URL - # carries an embedded token and avoids triggering OS credential - # helper popups during git ls-remote validation. - _url_token = None - _dep_ctx = None - _auth_scheme = "basic" - if not is_generic: - _dep_ctx = auth_resolver.resolve_for_dep(dep_ref) - _url_token = _dep_ctx.token - _auth_scheme = getattr(_dep_ctx, "auth_scheme", "basic") or "basic" - - ado_downloader = GitHubPackageDownloader(auth_resolver=auth_resolver) - # Set the host - if dep_ref.host: - ado_downloader.github_host = dep_ref.host - - # Build authenticated URL using the resolved per-dep token. - # #1015: pass auth_scheme so bearer tokens use extraheader - # injection instead of embedding a ~1.5KB JWT in the userinfo. - package_url = ado_downloader._build_repo_url( - dep_ref.repo_url, - use_ssh=False, - dep_ref=dep_ref, - token=_url_token, - auth_scheme=_auth_scheme, - ) - - explicit_scheme = (getattr(dep_ref, "explicit_scheme", None) or "").lower() or None - is_insecure = bool(getattr(dep_ref, "is_insecure", False)) - - # Strict-by-default cross-protocol policy (issue microsoft/apm#992): - # an explicit ``http://`` / ``https://`` / ``ssh://`` URL is honored - # exactly and does NOT silently fall back to a different protocol. - # This mirrors the strict default of ``_clone_with_fallback`` / - # :class:`TransportSelector` and prevents the foot-gun where a user - # types ``https://corp-bitbucket.example/...`` and the validation - # pre-check silently retries SSH on port 22, masking the real HTTPS - # failure (auth/redirect/etc.) behind a 30s SSH timeout. The - # ``APM_ALLOW_PROTOCOL_FALLBACK=1`` env var (the same escape-hatch - # the clone path honors) restores the legacy permissive chain. - from apm_cli.deps.transport_selection import is_fallback_allowed - - allow_fallback_env = is_fallback_allowed() - - # For generic hosts (not GitHub, not ADO), relax the env so native - # credential helpers (macOS Keychain, credential-store, - # manager-core, SSH agent, etc.) can work. Config isolation - # (GIT_CONFIG_GLOBAL=/dev/null, GIT_CONFIG_NOSYSTEM=1) is only - # enforced for insecure plaintext HTTP connections where - # credential leakage is a real risk; HTTPS connections need - # access to user-configured helpers in ~/.gitconfig. This - # matches _clone_with_fallback() and git_reference_resolver. - if is_generic: - validate_env = ado_downloader._build_noninteractive_git_env( - preserve_config_isolation=is_insecure, - suppress_credential_helpers=is_insecure, - ) - else: - # #1015: merge _dep_ctx.git_env (bearer-aware GIT_CONFIG_* - # overrides) into the subprocess env so `git ls-remote` - # actually sends the Authorization header for AAD tokens. - _ctx_git_env = getattr(_dep_ctx, "git_env", {}) if _dep_ctx else {} - validate_env = {**os.environ, **ado_downloader.git_env, **_ctx_git_env} - - # Build the probe order. Non-generic hosts (GHES/ADO) always probe - # a single authenticated URL. Generic hosts: - # - explicit https/http -> web URL only (strict) - # - explicit ssh -> SSH URL only (strict) - # - shorthand (no scheme) -> legacy [SSH, HTTPS] chain - # ``APM_ALLOW_PROTOCOL_FALLBACK=1`` re-appends the opposite scheme - # for the explicit cases to match clone semantics exactly. - urls_to_try = [] - if is_generic: - ssh_url = ado_downloader._build_repo_url( - dep_ref.repo_url, use_ssh=True, dep_ref=dep_ref - ) - if explicit_scheme in ("http", "https"): - urls_to_try = ( - [package_url] if not allow_fallback_env else [package_url, ssh_url] - ) - elif explicit_scheme == "ssh": - urls_to_try = [ssh_url] if not allow_fallback_env else [ssh_url, package_url] - else: - # Shorthand has no user-stated transport; keep the legacy - # SSH-first chain so existing flows (e.g. SSH-key users on - # corporate hosts) keep validating successfully. - urls_to_try = [ssh_url, package_url] - else: - urls_to_try = [package_url] - - if verbose_log: - attempt_word = "attempt" if len(urls_to_try) == 1 else "attempts" - verbose_log( - f"Trying git ls-remote for {dep_ref.host} ({len(urls_to_try)} {attempt_word})" - ) - - def _scheme_of(url: str) -> str: - return url.split("://", 1)[0] if "://" in url else "ssh" - - def _log_attempt_result(probe_url: str, run_result): - """Per-attempt sanitized verbose logging. - - The previous implementation only logged the final attempt's - result, which masked the actual failure (typically the HTTPS - leg) behind the SSH-fallback timeout. Logging each attempt - gives users the diagnostic data they need to act. - """ - if not verbose_log: - return - scheme = _scheme_of(probe_url) - if run_result.returncode == 0: - verbose_log(f"git ls-remote ({scheme}) rc=0 for {package}") - return - raw_stderr = (run_result.stderr or "").strip()[:200] - stderr_snippet = ado_downloader._sanitize_git_error(raw_stderr) - for env_var in ("GIT_ASKPASS", "GIT_CONFIG_GLOBAL"): - env_val = validate_env.get(env_var, "") - if env_val: - stderr_snippet = stderr_snippet.replace(env_val, "***") - verbose_log( - f"git ls-remote ({scheme}) rc={run_result.returncode}: {stderr_snippet}" - ) - - result = None - for probe_url in urls_to_try: - cmd = ["git", "ls-remote", "--heads", "--exit-code", probe_url] - result = subprocess.run( - cmd, - capture_output=True, - text=True, - encoding="utf-8", - timeout=30, - env=validate_env, - ) - _log_attempt_result(probe_url, result) - if result.returncode == 0: - break - - # ADO bearer fallback: if PAT was rejected (rc != 0 with auth-failure - # signal) AND the dep is on Azure DevOps AND we resolved a PAT, - # silently retry with az-cli bearer token. - if ( - result is not None - and result.returncode != 0 - and dep_ref.is_azure_devops() - and _url_token is not None # we had a PAT - and is_ado_auth_failure_signal(result.stderr or "") - ): - try: - from apm_cli.core.azure_cli import AzureCliBearerError, get_bearer_provider - - provider = get_bearer_provider() - if provider.is_available(): - try: - bearer = provider.get_bearer_token() - bearer_url = ado_downloader._build_repo_url( - dep_ref.repo_url, - use_ssh=False, - dep_ref=dep_ref, - token=None, - auth_scheme="bearer", - ) - # SECURITY: build a CLEAN env via _build_git_env(scheme="bearer") - # rather than {**validate_env, **build_ado_bearer_git_env(bearer)}. - # validate_env still carries the PAT-context GIT_CONFIG_* - # entries from _ctx_git_env; merging the bearer env on top - # would keep the rejected PAT visible in the child-process - # env (visible in /proc//environ on Linux). _build_git_env - # explicitly skips GIT_TOKEN for scheme="bearer" and emits - # only the bearer-specific GIT_CONFIG_* injection. - bearer_env = auth_resolver._build_git_env( - bearer, scheme="bearer", host_kind="ado" - ) - cmd = ["git", "ls-remote", "--heads", "--exit-code", bearer_url] - bearer_result = subprocess.run( - cmd, - capture_output=True, - text=True, - encoding="utf-8", - timeout=30, - env=bearer_env, - ) - if bearer_result.returncode == 0: - # Emit deferred stale-PAT warning via resolver - auth_resolver.emit_stale_pat_diagnostic( - dep_ref.host or "dev.azure.com" - ) - if verbose_log: - verbose_log( - f"git ls-remote rc=0 for {package} " - f"(via AAD bearer fallback)" - ) - return True - except AzureCliBearerError: - pass - except ImportError: - pass - - # Per-attempt verbose logging is emitted inside the probe loop - # (and by the bearer-fallback branch above), so the result is - # already on screen by the time we get here. Stderr is sanitized - # via ``GitHubPackageDownloader._sanitize_git_error`` to scrub - # any token-bearing URLs / env values before logging. - - # #1015: distinguish auth failures from non-auth failures (DNS, - # timeout, repo-truly-not-found 404). Auth failures get a typed - # exception with actionable diagnostics; non-auth failures keep - # the legacy False return so the caller can word its own message. - if result.returncode != 0 and not is_generic: - if is_ado_auth_failure_signal(result.stderr or ""): - _host = dep_ref.host or "dev.azure.com" - _org = ( - dep_ref.repo_url.split("/")[0] - if dep_ref.repo_url and "/" in dep_ref.repo_url - else None - ) - _diag = auth_resolver.build_error_context( - _host, - "validate", - org=_org, - dep_url=dep_ref.repo_url, - ) - raise AuthenticationError( - f"Authentication failed for {_host}", - diagnostic_context=_diag, - ) - - return result.returncode == 0 + return _validate_ado_git_package(dep_ref, auth_resolver, verbose_log, package, logger) # For GitHub.com, use AuthResolver with unauth-first fallback - host = dep_ref.host or default_host() - port = dep_ref.port - org = ( - dep_ref.repo_url.split("/")[0] if dep_ref.repo_url and "/" in dep_ref.repo_url else None + return _validate_github_package( + dep_ref, auth_resolver, verbose, verbose_log, package, logger ) - host_info = auth_resolver.classify_host(host, port=port) - - if is_enforce_only(): - # PROXY_REGISTRY_ONLY=1: skip the GitHub API probe. - # Marketplace/lockfile resolution already ran through the proxy; - # the download step will surface a proxy 404 if absent. - if logger: - logger.info( - f"Skipping direct GitHub API probe for {host}: proxy-only mode is active" - ) - return True - - if verbose_log: - ctx = auth_resolver.resolve(host, org=org, port=port) - verbose_log( - f"Auth resolved: host={host_info.display_name}, org={org}, " - f"source={ctx.source}, type={ctx.token_type}" - ) - - def _check_repo(token, git_env): - """Check repo accessibility via GitHub API (or git ls-remote for non-GitHub).""" - api_base = host_info.api_base - api_url = f"{api_base}/repos/{dep_ref.repo_url}" - headers = { - "Accept": "application/vnd.github+json", - "User-Agent": "apm-cli", - } - if token: - headers["Authorization"] = f"Bearer {token}" - - try: - resp = requests.get(api_url, headers=headers, timeout=15) - except requests.exceptions.SSLError as e: - raise RuntimeError(f"TLS verification failed for {host_info.display_name}") from e - except requests.exceptions.RequestException as e: - if verbose_log: - verbose_log(f"API request failed: {e}") - raise - - if verbose_log: - verbose_log(f"API {api_url} -> {resp.status_code}") - if resp.ok: - return True - if resp.status_code == 404 and token: - # 404 with token could mean no access -- raise to trigger fallback - raise RuntimeError(f"API returned {resp.status_code}") - raise RuntimeError(f"API returned {resp.status_code}: {resp.reason}") - - try: - return auth_resolver.try_with_fallback( - host, - _check_repo, - org=org, - port=port, - # dep_ref.repo_url is owner/repo (never a full URL per the - # DependencyReference invariant); forwarded as path= so GCM - # multi-account users get per-URL credential matching. - path=dep_ref.repo_url, - unauth_first=True, - verbose_callback=verbose_log, - ) - except Exception as exc: - if _is_tls_failure(exc): - _log_tls_failure(host_info.display_name, exc, verbose_log, logger) - return False - if verbose_log: - try: - ctx = auth_resolver.build_error_context( - host, - f"accessing {package}", - org=org, - port=port, - dep_url=getattr(dep_ref, "repo_url", None), - ) - for line in ctx.splitlines(): - verbose_log(line) - except Exception: - pass - return False except AuthenticationError: # #1015: let auth failures propagate to the caller for proper @@ -614,77 +771,4 @@ def _check_repo(token, git_env): raise except Exception: # If parsing fails, assume it's a regular GitHub package - host = default_host() - org = package.split("/")[0] if "/" in package else None - repo_path = package # owner/repo format - # Defensive owner/repo guard: when DependencyReference.parse raises, - # we fall back to embedding `repo_path` directly into an API URL and - # forwarding it as `path=` to git credential fill. Reject anything - # that isn't a strict / slug so path-confusion sequences - # (`../`, embedded slashes, control bytes) cannot reach either sink. - # Allows GitHub's documented owner/repo characters: alphanumeric, - # dot, underscore, hyphen. - if not re.fullmatch(r"[A-Za-z0-9._-]+/[A-Za-z0-9._-]+", repo_path): - return False - - from ..deps.registry_proxy import is_enforce_only - - if is_enforce_only(): - # PROXY_REGISTRY_ONLY=1: skip the GitHub API fallback probe. - # The download step will surface a proxy 404 if the package is absent. - if logger: - logger.info( - f"Skipping direct GitHub API fallback probe for {host}:" - " proxy-only mode is active" - ) - return True - - def _check_repo_fallback(token, git_env): - host_info = auth_resolver.classify_host(host) - api_url = f"{host_info.api_base}/repos/{repo_path}" - headers = { - "Accept": "application/vnd.github+json", - "User-Agent": "apm-cli", - } - if token: - headers["Authorization"] = f"Bearer {token}" - - try: - resp = requests.get(api_url, headers=headers, timeout=15) - except requests.exceptions.SSLError as e: - raise RuntimeError(f"TLS verification failed for {host_info.display_name}") from e - except requests.exceptions.RequestException as e: - if verbose_log: - verbose_log(f"API fallback failed: {e}") - raise - - if resp.ok: - return True - if verbose_log: - verbose_log(f"API fallback -> {resp.status_code} {resp.reason}") - raise RuntimeError(f"API returned {resp.status_code}") - - try: - return auth_resolver.try_with_fallback( - host, - _check_repo_fallback, - org=org, - path=repo_path, - unauth_first=True, - verbose_callback=verbose_log, - ) - except Exception as exc: - if _is_tls_failure(exc): - # See note above: logged once here, skip auth context render. - _log_tls_failure(host, exc, verbose_log, logger) - return False - if verbose_log: - try: - ctx = auth_resolver.build_error_context( - host, f"accessing {package}", org=org, dep_url=package - ) - for line in ctx.splitlines(): - verbose_log(line) - except Exception: - pass - return False + return _validate_parse_failure_fallback(package, auth_resolver, verbose_log, logger) diff --git a/src/apm_cli/integration/mcp_integrator.py b/src/apm_cli/integration/mcp_integrator.py index 411b57dbb..02369321e 100644 --- a/src/apm_cli/integration/mcp_integrator.py +++ b/src/apm_cli/integration/mcp_integrator.py @@ -10,6 +10,7 @@ """ import builtins +import json import logging import re import shutil @@ -49,6 +50,145 @@ def _is_vscode_available(project_root: Path | str | None = None) -> bool: return shutil.which("code") is not None or (root / ".vscode").is_dir() +def _clean_json_mcp_config( + config_path: Path, + stale_names: builtins.set, + logger, + label: str, + servers_key: str = "mcpServers", + trailing_newline: bool = False, + use_rich: bool = False, +) -> int: + """Remove stale entries from a JSON-based MCP config file. + + Args: + config_path: Path to the JSON config file. + stale_names: Set of server names to remove (expanded form). + logger: Command logger for progress messages. + label: Human-readable config label used in log messages. + servers_key: Key under which MCP servers are stored (default: ``"mcpServers"``). + trailing_newline: When True, append a trailing newline after JSON serialisation. + use_rich: When True, emit removal notices via ``_rich_success``; otherwise use + ``logger.progress``. + + Returns: + Number of entries removed. + """ + if not config_path.exists(): + return 0 + try: + config = json.loads(config_path.read_text(encoding="utf-8")) + servers = config.get(servers_key, {}) + removed = [n for n in stale_names if n in servers] + for name in removed: + del servers[name] + if removed: + text = json.dumps(config, indent=2) + if trailing_newline: + text += "\n" + config_path.write_text(text, encoding="utf-8") + for name in removed: + msg = f"Removed stale MCP server '{name}' from {label}" + if use_rich: + _rich_success(msg, symbol="check") + else: + logger.progress(msg) + return len(removed) + except Exception: + _log.debug("Failed to clean stale MCP servers from %s", label, exc_info=True) + return 0 + + +def _clean_toml_mcp_config( + config_path: Path, + stale_names: builtins.set, + label: str, + logger=None, + use_rich: bool = True, +) -> int: + """Remove stale entries from a TOML-based MCP config file. + + Args: + config_path: Path to the TOML config file. + stale_names: Set of server names to remove (expanded form). + label: Human-readable config label used in log messages. + logger: Optional command logger for progress messages. When provided + and *use_rich* is False, removal notices use ``logger.progress``. + use_rich: When True (default), emit removal notices via ``_rich_success``; + otherwise use ``logger.progress``. + + Returns: + Number of entries removed. + """ + if not config_path.exists(): + return 0 + try: + import toml as _toml + + config = _toml.loads(config_path.read_text(encoding="utf-8")) + servers = config.get("mcp_servers", {}) + removed = [n for n in stale_names if n in servers] + for name in removed: + del servers[name] + if removed: + config_path.write_text(_toml.dumps(config), encoding="utf-8") + for name in removed: + msg = f"Removed stale MCP server '{name}' from {label}" + if use_rich: + _rich_success(msg, symbol="check") + elif logger is not None: + logger.progress(msg) + return len(removed) + except Exception: + _log.debug("Failed to clean stale MCP servers from %s", label, exc_info=True) + return 0 + + +def _clean_claude_config( + config_path: Path, + stale_names: builtins.set, + logger, + is_user_scope: bool = False, +) -> int: + """Remove stale entries from a Claude Code JSON config file. + + Handles both the project-level ``.mcp.json`` and the user-level + ``~/.claude.json``, which share the same JSON structure but differ in + scope-validation requirements and log labels. + + Args: + config_path: Path to the Claude JSON config file. + stale_names: Set of server names to remove (expanded form). + logger: Command logger for progress messages. + is_user_scope: When True, validates that the top-level config is a dict + (``~/.claude.json`` guard) and uses the user-scope log label. + + Returns: + Number of entries removed. + """ + label = "~/.claude.json" if is_user_scope else ".mcp.json" + if not config_path.exists(): + return 0 + try: + config = json.loads(config_path.read_text(encoding="utf-8")) + if is_user_scope and not isinstance(config, dict): + return 0 + servers = config.get("mcpServers", {}) + if not isinstance(servers, dict): + servers = {} + removed = [n for n in stale_names if n in servers] + for name in removed: + del servers[name] + if removed: + config_path.write_text(json.dumps(config, indent=2) + "\n", encoding="utf-8") + for name in removed: + logger.progress(f"Removed stale MCP server '{name}' from {label}") + return len(removed) + except Exception: + _log.debug("Failed to clean stale MCP servers from %s", label, exc_info=True) + return 0 + + class MCPIntegrator: """MCP lifecycle orchestrator -- dependency resolution, installation, and cleanup. @@ -501,54 +641,24 @@ def remove_stale( project_root_path = Path(project_root) if project_root is not None else Path.cwd() - # Clean .vscode/mcp.json + # Per-runtime cleanup -- each helper reads, diffs, writes, and logs. if "vscode" in target_runtimes: - vscode_mcp = project_root_path / ".vscode" / "mcp.json" - if vscode_mcp.exists(): - try: - import json as _json - - config = _json.loads(vscode_mcp.read_text(encoding="utf-8")) - servers = config.get("servers", {}) - removed = [n for n in expanded_stale if n in servers] - for name in removed: - del servers[name] - if removed: - vscode_mcp.write_text(_json.dumps(config, indent=2), encoding="utf-8") - for name in removed: - logger.progress( - f"Removed stale MCP server '{name}' from .vscode/mcp.json" - ) - except Exception: - _log.debug( - "Failed to clean stale MCP servers from .vscode/mcp.json", - exc_info=True, - ) + _clean_json_mcp_config( + project_root_path / ".vscode" / "mcp.json", + expanded_stale, + logger, + ".vscode/mcp.json", + servers_key="servers", + ) - # Clean ~/.copilot/mcp-config.json if "copilot" in target_runtimes: - copilot_mcp = Path.home() / ".copilot" / "mcp-config.json" - if copilot_mcp.exists(): - try: - import json as _json - - config = _json.loads(copilot_mcp.read_text(encoding="utf-8")) - servers = config.get("mcpServers", {}) - removed = [n for n in expanded_stale if n in servers] - for name in removed: - del servers[name] - if removed: - copilot_mcp.write_text(_json.dumps(config, indent=2), encoding="utf-8") - for name in removed: - _rich_success( - f"Removed stale MCP server '{name}' from Copilot CLI config", - symbol="check", - ) - except Exception: - _log.debug( - "Failed to clean stale MCP servers from Copilot CLI config", - exc_info=True, - ) + _clean_json_mcp_config( + Path.home() / ".copilot" / "mcp-config.json", + expanded_stale, + logger, + "Copilot CLI config", + use_rich=True, + ) # Clean the scope-resolved Codex config.toml (mcp_servers section) if "codex" in target_runtimes: @@ -561,184 +671,62 @@ def remove_stale( user_scope=user_scope, ).get_config_path() ) - if codex_cfg.exists(): - try: - import toml as _toml - - config = _toml.loads(codex_cfg.read_text(encoding="utf-8")) - servers = config.get("mcp_servers", {}) - removed = [n for n in expanded_stale if n in servers] - for name in removed: - del servers[name] - if removed: - codex_cfg.write_text(_toml.dumps(config), encoding="utf-8") - for name in removed: - _rich_success( - f"Removed stale MCP server '{name}' from Codex CLI config", - symbol="check", - ) - except Exception: - _log.debug( - "Failed to clean stale MCP servers from Codex CLI config", - exc_info=True, - ) + _clean_toml_mcp_config(codex_cfg, expanded_stale, "Codex CLI config") - # Clean .cursor/mcp.json (only if .cursor/ directory exists) if "cursor" in target_runtimes: - cursor_mcp = project_root_path / ".cursor" / "mcp.json" - if cursor_mcp.exists(): - try: - import json as _json - - config = _json.loads(cursor_mcp.read_text(encoding="utf-8")) - servers = config.get("mcpServers", {}) - removed = [n for n in expanded_stale if n in servers] - for name in removed: - del servers[name] - if removed: - cursor_mcp.write_text(_json.dumps(config, indent=2), encoding="utf-8") - for name in removed: - _rich_success( - f"Removed stale MCP server '{name}' from .cursor/mcp.json", - symbol="check", - ) - except Exception: - _log.debug( - "Failed to clean stale MCP servers from .cursor/mcp.json", - exc_info=True, - ) + _clean_json_mcp_config( + project_root_path / ".cursor" / "mcp.json", + expanded_stale, + logger, + ".cursor/mcp.json", + use_rich=True, + ) # Clean opencode.json (only if .opencode/ directory exists) if "opencode" in target_runtimes: - opencode_cfg = project_root_path / "opencode.json" - if opencode_cfg.exists() and (project_root_path / ".opencode").is_dir(): - try: - import json as _json - - config = _json.loads(opencode_cfg.read_text(encoding="utf-8")) - servers = config.get("mcp", {}) - removed = [n for n in expanded_stale if n in servers] - for name in removed: - del servers[name] - if removed: - opencode_cfg.write_text(_json.dumps(config, indent=2), encoding="utf-8") - for name in removed: - logger.progress(f"Removed stale MCP server '{name}' from opencode.json") - except Exception: - _log.debug( - "Failed to clean stale MCP servers from opencode.json", - exc_info=True, - ) + if (project_root_path / ".opencode").is_dir(): + _clean_json_mcp_config( + project_root_path / "opencode.json", + expanded_stale, + logger, + "opencode.json", + servers_key="mcp", + ) - # Clean ~/.codeium/windsurf/mcp_config.json if "windsurf" in target_runtimes: - windsurf_mcp = Path.home() / ".codeium" / "windsurf" / "mcp_config.json" - if windsurf_mcp.exists(): - try: - import json as _json - - config = _json.loads(windsurf_mcp.read_text(encoding="utf-8")) - servers = config.get("mcpServers", {}) - removed = [n for n in expanded_stale if n in servers] - for name in removed: - del servers[name] - if removed: - windsurf_mcp.write_text(_json.dumps(config, indent=2), encoding="utf-8") - for name in removed: - _rich_success( - f"Removed stale MCP server '{name}' from Windsurf config", - symbol="check", - ) - except Exception: - _log.debug( - "Failed to clean stale MCP servers from Windsurf config", - exc_info=True, - ) + _clean_json_mcp_config( + Path.home() / ".codeium" / "windsurf" / "mcp_config.json", + expanded_stale, + logger, + "Windsurf config", + use_rich=True, + ) - # Clean .gemini/settings.json (only if .gemini/ directory exists) if "gemini" in target_runtimes: - gemini_cfg = project_root_path / ".gemini" / "settings.json" - if gemini_cfg.exists(): - try: - import json as _json - - config = _json.loads(gemini_cfg.read_text(encoding="utf-8")) - servers = config.get("mcpServers", {}) - removed = [n for n in expanded_stale if n in servers] - for name in removed: - del servers[name] - if removed: - gemini_cfg.write_text(_json.dumps(config, indent=2), encoding="utf-8") - for name in removed: - if logger: - logger.progress( - f"Removed stale MCP server '{name}' from .gemini/settings.json" - ) - else: - _rich_success( - f"Removed stale MCP server '{name}' from .gemini/settings.json", - symbol="check", - ) - except Exception: - _log.debug( - "Failed to clean stale MCP servers from .gemini/settings.json", - exc_info=True, - ) + _clean_json_mcp_config( + project_root_path / ".gemini" / "settings.json", + expanded_stale, + logger, + ".gemini/settings.json", + ) # Clean Claude Code project .mcp.json (only if .claude/ directory exists) if clean_claude_project: - claude_mcp = project_root_path / ".mcp.json" - if claude_mcp.exists() and (project_root_path / ".claude").is_dir(): - try: - import json as _json - - config = _json.loads(claude_mcp.read_text(encoding="utf-8")) - servers = config.get("mcpServers", {}) - if not isinstance(servers, dict): - servers = {} - removed = [n for n in expanded_stale if n in servers] - for name in removed: - del servers[name] - if removed: - claude_mcp.write_text( - _json.dumps(config, indent=2) + "\n", encoding="utf-8" - ) - for name in removed: - logger.progress(f"Removed stale MCP server '{name}' from .mcp.json") - except Exception: - _log.debug( - "Failed to clean stale MCP servers from .mcp.json", - exc_info=True, - ) + if (project_root_path / ".claude").is_dir(): + _clean_claude_config( + project_root_path / ".mcp.json", + expanded_stale, + logger, + ) # Clean Claude Code user ~/.claude.json (USER scope only) if clean_claude_user: - claude_user = Path.home() / ".claude.json" - if claude_user.exists(): - try: - import json as _json - - config = _json.loads(claude_user.read_text(encoding="utf-8")) - if isinstance(config, dict): - servers = config.get("mcpServers", {}) - if not isinstance(servers, dict): - servers = {} - removed = [n for n in expanded_stale if n in servers] - for name in removed: - del servers[name] - if removed: - claude_user.write_text( - _json.dumps(config, indent=2) + "\n", encoding="utf-8" - ) - for name in removed: - logger.progress( - f"Removed stale MCP server '{name}' from ~/.claude.json" - ) - except Exception: - _log.debug( - "Failed to clean stale MCP servers from ~/.claude.json", - exc_info=True, - ) + _clean_claude_config( + Path.home() / ".claude.json", + expanded_stale, + logger, + is_user_scope=True, + ) # ------------------------------------------------------------------ # Lockfile persistence diff --git a/src/apm_cli/integration/mcp_integrator_install.py b/src/apm_cli/integration/mcp_integrator_install.py index 2bcb8110b..be4cabf89 100644 --- a/src/apm_cli/integration/mcp_integrator_install.py +++ b/src/apm_cli/integration/mcp_integrator_install.py @@ -184,115 +184,37 @@ def _install_registry_group( return configured_count -def run_mcp_install( - mcp_deps: list, - runtime: str | None = None, - exclude: str | None = None, - verbose: bool = False, - apm_config: dict | None = None, - stored_mcp_configs: dict | None = None, - project_root=None, - user_scope: bool = False, - explicit_target: str | None = None, - logger=None, - diagnostics=None, - scope: InstallScope | None = None, -) -> int: - """Install MCP dependencies. - - Args: - mcp_deps: List of MCP dependency entries (registry strings or - MCPDependency objects). - runtime: Target specific runtime only. - exclude: Exclude specific runtime from installation. - verbose: Show detailed installation information. - apm_config: The parsed apm.yml configuration dict (optional). - When not provided, this function loads ``apm.yml`` from the project - root if it exists. - stored_mcp_configs: Previously stored MCP configs from lockfile - for diff-aware installation. When provided, servers whose - manifest config has changed are re-applied automatically. - project_root: Project root for repo-local runtime configs. - user_scope: Whether runtime configuration is being resolved at user scope. - explicit_target: Explicit target selected by CLI or manifest. - scope: InstallScope (PROJECT or USER). When USER, only - runtimes whose adapter declares ``supports_user_scope`` - are targeted; workspace-only runtimes are skipped. - - Returns: - Number of MCP servers newly configured or updated. +def _resolve_target_runtimes( + runtime: str | None, + exclude: str | None, + verbose: bool, + apm_config: dict | None, + project_root, + user_scope: bool, + explicit_target: str | None, + scope: InstallScope | None, + logger, + console, +) -> list[str] | None: + """Detect, filter, and gate the target runtimes for MCP installation. + + Returns a (possibly empty) list of runtime names to target, or ``None`` + when the caller should immediately return 0 (e.g. all runtimes excluded, + no user-scope-capable runtimes available). """ - # Local import: ``mcp_integrator`` must finish loading before this module - # is first imported (``MCPIntegrator.install`` delegates here lazily). from apm_cli.integration.mcp_integrator import ( MCPIntegrator, - _get_console, _is_vscode_available, ) - if logger is None: - logger = NullCommandLogger() - if not mcp_deps: - logger.warning("No MCP dependencies found in apm.yml") - return 0 - - from apm_cli.core.scope import InstallScope - - # The explicit scope enum takes precedence over the raw user_scope bool - # so callers cannot accidentally mix user-scope runtime filtering with - # project-scope config writes (or the inverse). - if scope is InstallScope.USER: - user_scope = True - elif scope is InstallScope.PROJECT: - user_scope = False - - # Split into registry-resolved and self-defined deps - # Backward compat: plain strings are treated as registry deps - registry_deps = [ - dep - for dep in mcp_deps - if isinstance(dep, str) - or (hasattr(dep, "is_registry_resolved") and dep.is_registry_resolved) - ] - self_defined_deps = [ - dep for dep in mcp_deps if hasattr(dep, "is_self_defined") and dep.is_self_defined - ] - registry_dep_names = [dep.name if hasattr(dep, "name") else dep for dep in registry_deps] - - console = _get_console() - # Track servers that were re-applied due to config drift - servers_to_update: builtins.set = builtins.set() - # Track successful updates separately so the summary counts are accurate - # even when some drift-detected servers fail to install. - successful_updates: builtins.set = builtins.set() - if stored_mcp_configs is None: - stored_mcp_configs = {} - - # Start MCP section with clean header - if console: - try: - from rich.text import Text - - header = Text() - header.append("+- MCP Servers (", style="cyan") - header.append(str(len(mcp_deps)), style="cyan bold") - header.append(")", style="cyan") - console.print(header) - except Exception: - logger.progress(f"Installing MCP dependencies ({len(mcp_deps)})...") - else: - logger.progress(f"Installing MCP dependencies ({len(mcp_deps)})...") - - # Runtime detection and multi-runtime installation if runtime: - # Single runtime mode - target_runtimes = [runtime] + # Single runtime mode — skip auto-discovery entirely. logger.progress(f"Targeting specific runtime: {runtime}") + target_runtimes: list[str] = [runtime] else: project_root_path = Path(project_root) if project_root is not None else Path.cwd() if apm_config is None: - # Lazy load -- only when the caller doesn't provide it try: apm_yml = project_root_path / "apm.yml" if apm_yml.exists(): @@ -308,7 +230,7 @@ def run_mcp_install( from apm_cli.runtime.manager import RuntimeManager manager = RuntimeManager() - installed_runtimes = [] + installed_runtimes: list[str] = [] for runtime_name in [ "copilot", @@ -438,7 +360,7 @@ def run_mcp_install( logger.warning( f"All installed runtimes excluded (--exclude {exclude}), skipping MCP configuration" ) - return 0 + return None # Fall back to VS Code only if no runtimes are installed at all if not target_runtimes and not installed_runtimes: @@ -459,11 +381,13 @@ def run_mcp_install( # Explicit runtime/exclusion/gating can leave nothing to configure. if not target_runtimes: - return 0 + return None # Scope filtering: at USER scope, keep only global-capable runtimes. # Applied after both explicit --runtime and auto-discovery paths. - if scope is InstallScope.USER: + from apm_cli.core.scope import InstallScope as _IS + + if scope is _IS.USER: from apm_cli.factory import ClientFactory as _CF pre_filter = list(target_runtimes) @@ -488,7 +412,265 @@ def run_mcp_install( logger.warning( "No runtimes support user-scope MCP installation (supported: copilot, codex, gemini)" ) - return 0 + return None + + return target_runtimes + + +def _install_self_defined_deps( + self_defined_deps: list, + target_runtimes: list[str], + stored_mcp_configs: dict, + servers_to_update: builtins.set, + successful_updates: builtins.set, + project_root, + user_scope: bool, + verbose: bool, + console, + logger, +) -> int: + """Install self-defined (``registry: false``) MCP deps for all target runtimes. + + Mutates ``servers_to_update`` and ``successful_updates`` in-place. + Returns the number of servers newly configured or updated. + """ + from apm_cli.integration.mcp_integrator import MCPIntegrator + + configured_count = 0 + self_defined_names = [dep.name for dep in self_defined_deps] + self_defined_to_install = MCPIntegrator._check_self_defined_servers_needing_installation( + self_defined_names, + target_runtimes, + project_root=project_root, + user_scope=user_scope, + ) + already_configured_candidates_sd = [ + name for name in self_defined_names if name not in self_defined_to_install + ] + + # Detect config drift for "already configured" self-defined servers + if stored_mcp_configs and already_configured_candidates_sd: + drifted_sd_deps = [ + dep for dep in self_defined_deps if dep.name in already_configured_candidates_sd + ] + drifted_sd = MCPIntegrator._detect_mcp_config_drift( + drifted_sd_deps, + stored_mcp_configs, + ) + if drifted_sd: + servers_to_update.update(drifted_sd) + MCPIntegrator._append_drifted_to_install_list(self_defined_to_install, drifted_sd) + already_configured_self_defined = [ + name for name in already_configured_candidates_sd if name not in servers_to_update + ] + + if already_configured_self_defined: + if console: + for name in already_configured_self_defined: + console.print( + f"| [green]{STATUS_SYMBOLS['check']}[/green] {name} " + f"[dim](already configured)[/dim]" + ) + else: + count = len(already_configured_self_defined) + logger.success(f"{count} self-defined server(s) already configured") + for name in already_configured_self_defined: + logger.verbose_detail(f"{name} already configured, skipping") + + for dep in self_defined_deps: + if dep.name not in self_defined_to_install: + continue + + is_update = dep.name in servers_to_update + synthetic_info = MCPIntegrator._build_self_defined_info(dep) + self_defined_cache = {dep.name: synthetic_info} + self_defined_env = dep.env or {} + + transport_label = dep.transport or "stdio" + action_text = "Updating" if is_update else "Configuring" + if console: + console.print( + f"| [cyan]{STATUS_SYMBOLS['running']}[/cyan] {dep.name} " + f"[dim](self-defined, {transport_label})[/dim]" + ) + console.print( + f"| +- {action_text} for {', '.join([rt.title() for rt in target_runtimes])}..." + ) + else: + logger.progress( + f"{dep.name}: {action_text.lower()} for {', '.join(target_runtimes)}..." + ) + + any_ok = False + for rt in target_runtimes: + if verbose: + logger.verbose_detail(f"Configuring {dep.name} for {rt}...") + if MCPIntegrator._install_for_runtime( + rt, + [dep.name], + self_defined_env, + self_defined_cache, + project_root=project_root, + user_scope=user_scope, + logger=logger, + ): + any_ok = True + + if any_ok: + if console: + label = "updated" if is_update else "configured" + console.print( + f"| [green]{STATUS_SYMBOLS['check']}[/green] {dep.name} -> " + f"{', '.join([rt.title() for rt in target_runtimes])}" + f" [dim]({label})[/dim]" + ) + configured_count += 1 + if is_update: + successful_updates.add(dep.name) + elif console: + console.print( + f"| [red]{STATUS_SYMBOLS['cross']}[/red] {dep.name} -- failed for all runtimes" + ) + else: + logger.error(f"{dep.name} -- failed for all runtimes") + + return configured_count + + +def _print_mcp_summary( + console, + configured_count: int, + successful_updates: builtins.set, +) -> None: + """Print the MCP install summary footer panel.""" + if not console: + return + if configured_count > 0: + # Use successful_updates (not servers_to_update) for accurate counts. + # servers_to_update = all drift-detected servers (some may have failed). + # successful_updates = servers that were re-applied AND succeeded. + update_count = builtins.len(successful_updates) + new_count = configured_count - update_count + parts = [] + if new_count > 0: + parts.append(f"configured {new_count} server{'s' if new_count != 1 else ''}") + if update_count > 0: + parts.append(f"updated {update_count} server{'s' if update_count != 1 else ''}") + console.print(f"[green]{STATUS_SYMBOLS['success']} {', '.join(parts).capitalize()}[/green]") + else: + console.print(f"[green]{STATUS_SYMBOLS['success']} All servers up to date[/green]") + + +def run_mcp_install( + mcp_deps: list, + runtime: str | None = None, + exclude: str | None = None, + verbose: bool = False, + apm_config: dict | None = None, + stored_mcp_configs: dict | None = None, + project_root=None, + user_scope: bool = False, + explicit_target: str | None = None, + logger=None, + diagnostics=None, + scope: InstallScope | None = None, +) -> int: + """Install MCP dependencies. + + Args: + mcp_deps: List of MCP dependency entries (registry strings or + MCPDependency objects). + runtime: Target specific runtime only. + exclude: Exclude specific runtime from installation. + verbose: Show detailed installation information. + apm_config: The parsed apm.yml configuration dict (optional). + When not provided, this function loads ``apm.yml`` from the project + root if it exists. + stored_mcp_configs: Previously stored MCP configs from lockfile + for diff-aware installation. When provided, servers whose + manifest config has changed are re-applied automatically. + project_root: Project root for repo-local runtime configs. + user_scope: Whether runtime configuration is being resolved at user scope. + explicit_target: Explicit target selected by CLI or manifest. + scope: InstallScope (PROJECT or USER). When USER, only + runtimes whose adapter declares ``supports_user_scope`` + are targeted; workspace-only runtimes are skipped. + + Returns: + Number of MCP servers newly configured or updated. + """ + # Local import: ``mcp_integrator`` must finish loading before this module + # is first imported (``MCPIntegrator.install`` delegates here lazily). + from apm_cli.integration.mcp_integrator import _get_console + + if logger is None: + logger = NullCommandLogger() + if not mcp_deps: + logger.warning("No MCP dependencies found in apm.yml") + return 0 + + from apm_cli.core.scope import InstallScope + + # The explicit scope enum takes precedence over the raw user_scope bool + # so callers cannot accidentally mix user-scope runtime filtering with + # project-scope config writes (or the inverse). + if scope is InstallScope.USER: + user_scope = True + elif scope is InstallScope.PROJECT: + user_scope = False + + # Split into registry-resolved and self-defined deps + # Backward compat: plain strings are treated as registry deps + registry_deps = [ + dep + for dep in mcp_deps + if isinstance(dep, str) + or (hasattr(dep, "is_registry_resolved") and dep.is_registry_resolved) + ] + self_defined_deps = [ + dep for dep in mcp_deps if hasattr(dep, "is_self_defined") and dep.is_self_defined + ] + registry_dep_names = [dep.name if hasattr(dep, "name") else dep for dep in registry_deps] + + console = _get_console() + # Track servers that were re-applied due to config drift + servers_to_update: builtins.set = builtins.set() + # Track successful updates separately so the summary counts are accurate + # even when some drift-detected servers fail to install. + successful_updates: builtins.set = builtins.set() + if stored_mcp_configs is None: + stored_mcp_configs = {} + + # Start MCP section with clean header + if console: + try: + from rich.text import Text + + header = Text() + header.append("+- MCP Servers (", style="cyan") + header.append(str(len(mcp_deps)), style="cyan bold") + header.append(")", style="cyan") + console.print(header) + except Exception: + logger.progress(f"Installing MCP dependencies ({len(mcp_deps)})...") + else: + logger.progress(f"Installing MCP dependencies ({len(mcp_deps)})...") + + # Runtime detection, gating, and scope filtering + target_runtimes = _resolve_target_runtimes( + runtime=runtime, + exclude=exclude, + verbose=verbose, + apm_config=apm_config, + project_root=project_root, + user_scope=user_scope, + explicit_target=explicit_target, + scope=scope, + logger=logger, + console=console, + ) + if target_runtimes is None: + return 0 # Use the new registry operations module for better server detection configured_count = 0 @@ -539,122 +721,20 @@ def run_mcp_install( # --- Self-defined deps (registry: false) --- if self_defined_deps: - self_defined_names = [dep.name for dep in self_defined_deps] - self_defined_to_install = MCPIntegrator._check_self_defined_servers_needing_installation( - self_defined_names, - target_runtimes, + configured_count += _install_self_defined_deps( + self_defined_deps=self_defined_deps, + target_runtimes=target_runtimes, + stored_mcp_configs=stored_mcp_configs, + servers_to_update=servers_to_update, + successful_updates=successful_updates, project_root=project_root, user_scope=user_scope, + verbose=verbose, + console=console, + logger=logger, ) - already_configured_candidates_sd = [ - name for name in self_defined_names if name not in self_defined_to_install - ] - - # Detect config drift for "already configured" self-defined servers - if stored_mcp_configs and already_configured_candidates_sd: - drifted_sd_deps = [ - dep for dep in self_defined_deps if dep.name in already_configured_candidates_sd - ] - drifted_sd = MCPIntegrator._detect_mcp_config_drift( - drifted_sd_deps, - stored_mcp_configs, - ) - if drifted_sd: - servers_to_update.update(drifted_sd) - MCPIntegrator._append_drifted_to_install_list(self_defined_to_install, drifted_sd) - already_configured_self_defined = [ - name for name in already_configured_candidates_sd if name not in servers_to_update - ] - - if already_configured_self_defined: - if console: - for name in already_configured_self_defined: - console.print( - f"| [green]{STATUS_SYMBOLS['check']}[/green] {name} " - f"[dim](already configured)[/dim]" - ) - else: - count = len(already_configured_self_defined) - logger.success(f"{count} self-defined server(s) already configured") - for name in already_configured_self_defined: - logger.verbose_detail(f"{name} already configured, skipping") - - for dep in self_defined_deps: - if dep.name not in self_defined_to_install: - continue - - is_update = dep.name in servers_to_update - synthetic_info = MCPIntegrator._build_self_defined_info(dep) - self_defined_cache = {dep.name: synthetic_info} - self_defined_env = dep.env or {} - - transport_label = dep.transport or "stdio" - action_text = "Updating" if is_update else "Configuring" - if console: - console.print( - f"| [cyan]{STATUS_SYMBOLS['running']}[/cyan] {dep.name} " - f"[dim](self-defined, {transport_label})[/dim]" - ) - console.print( - f"| +- {action_text} for " - f"{', '.join([rt.title() for rt in target_runtimes])}..." - ) - else: - logger.progress( - f"{dep.name}: {action_text.lower()} for {', '.join(target_runtimes)}..." - ) - - any_ok = False - for rt in target_runtimes: - if verbose: - logger.verbose_detail(f"Configuring {dep.name} for {rt}...") - if MCPIntegrator._install_for_runtime( - rt, - [dep.name], - self_defined_env, - self_defined_cache, - project_root=project_root, - user_scope=user_scope, - logger=logger, - ): - any_ok = True - - if any_ok: - if console: - label = "updated" if is_update else "configured" - console.print( - f"| [green]{STATUS_SYMBOLS['check']}[/green] {dep.name} -> " - f"{', '.join([rt.title() for rt in target_runtimes])}" - f" [dim]({label})[/dim]" - ) - configured_count += 1 - if is_update: - successful_updates.add(dep.name) - elif console: - console.print( - f"| [red]{STATUS_SYMBOLS['cross']}[/red] {dep.name} " - "-- failed for all runtimes" - ) - else: - logger.error(f"{dep.name} -- failed for all runtimes") # Close the panel - if console: - if configured_count > 0: - # Use successful_updates (not servers_to_update) for accurate counts. - # servers_to_update = all drift-detected servers (some may have failed). - # successful_updates = servers that were re-applied AND succeeded. - update_count = builtins.len(successful_updates) - new_count = configured_count - update_count - parts = [] - if new_count > 0: - parts.append(f"configured {new_count} server{'s' if new_count != 1 else ''}") - if update_count > 0: - parts.append(f"updated {update_count} server{'s' if update_count != 1 else ''}") - console.print( - f"[green]{STATUS_SYMBOLS['success']} {', '.join(parts).capitalize()}[/green]" - ) - else: - console.print(f"[green]{STATUS_SYMBOLS['success']} All servers up to date[/green]") + _print_mcp_summary(console, configured_count, successful_updates) return configured_count diff --git a/src/apm_cli/marketplace/publisher.py b/src/apm_cli/marketplace/publisher.py index d4de9cb41..22fb11a20 100644 --- a/src/apm_cli/marketplace/publisher.py +++ b/src/apm_cli/marketplace/publisher.py @@ -36,7 +36,10 @@ from datetime import datetime, timezone from enum import Enum from pathlib import Path -from typing import Any, Optional # noqa: F401 +from typing import TYPE_CHECKING, Any, Optional # noqa: F401 + +if TYPE_CHECKING: + from .semver import SemVer import yaml @@ -500,6 +503,153 @@ def _process(idx: int, target: ConsumerTarget) -> TargetResult: # Return in plan.targets order return [results[i] for i in range(len(plan.targets))] + # -- per-target helpers ------------------------------------------------- + + def _load_consumer_manifest( + self, + clone_dir: Path, + target: ConsumerTarget, + plan: PublishPlan, + ) -> tuple[dict | None, Path, TargetResult | None]: + """Load and validate consumer apm.yml. + + Returns ``(data, apm_yml_path, None)`` on success or + ``(None, apm_yml_path, TargetResult)`` on first error. + """ + apm_yml_path = clone_dir / target.path_in_repo + try: + ensure_path_within(apm_yml_path, clone_dir) + except PathTraversalError: + return ( + None, + apm_yml_path, + TargetResult( + target=target, + outcome=PublishOutcome.FAILED, + message="Path traversal rejected: " + target.path_in_repo, + ), + ) + + if not apm_yml_path.exists(): + return ( + None, + apm_yml_path, + TargetResult( + target=target, + outcome=PublishOutcome.FAILED, + message=f"File not found: {target.path_in_repo}", + ), + ) + + try: + raw_text = apm_yml_path.read_text(encoding="utf-8") + data = yaml.safe_load(raw_text) + except (yaml.YAMLError, OSError) as exc: + return ( + None, + apm_yml_path, + TargetResult( + target=target, + outcome=PublishOutcome.FAILED, + message=f"Failed to parse {target.path_in_repo}: {exc}", + ), + ) + + if not isinstance(data, dict): + return ( + None, + apm_yml_path, + TargetResult( + target=target, + outcome=PublishOutcome.FAILED, + message="Invalid apm.yml: expected a mapping", + ), + ) + + deps = data.get("dependencies") + if not isinstance(deps, dict): + return ( + None, + apm_yml_path, + TargetResult( + target=target, + outcome=PublishOutcome.FAILED, + message=f"Marketplace '{plan.marketplace_name}' not referenced in apm.yml", + ), + ) + + apm_deps = deps.get("apm") + if not isinstance(apm_deps, list): + return ( + None, + apm_yml_path, + TargetResult( + target=target, + outcome=PublishOutcome.FAILED, + message=f"Marketplace '{plan.marketplace_name}' not referenced in apm.yml", + ), + ) + + return data, apm_yml_path, None + + def _check_ref_guards( + self, + matches: list[tuple[int, str, str | None, str]], + target: ConsumerTarget, + plan: PublishPlan, + new_sv: SemVer | None, + ) -> TargetResult | None: + """Check ref-change and downgrade guards. Returns error result or None.""" + new_ref = plan.new_ref + for _idx, _pname, old_ref, entry_str in matches: + if old_ref == new_ref: + continue + + # Ref-change guard + if old_ref is None: + if not plan.allow_ref_change: + return TargetResult( + target=target, + outcome=PublishOutcome.SKIPPED_REF_CHANGE, + message=( + f"Entry '{entry_str}' uses implicit " + "latest; pass allow_ref_change to pin" + ), + old_version=None, + new_version=new_ref, + ) + else: + old_sv = parse_semver(old_ref.lstrip("vV")) + if old_sv is None and new_sv is not None: + if not plan.allow_ref_change: + return TargetResult( + target=target, + outcome=PublishOutcome.SKIPPED_REF_CHANGE, + message=( + f"Entry '{entry_str}' uses " + f"non-semver ref '{old_ref}'; " + "pass allow_ref_change to switch" + ), + old_version=old_ref, + new_version=new_ref, + ) + + # Downgrade guard + if old_sv and new_sv and new_sv < old_sv: + if not plan.allow_downgrade: + return TargetResult( + target=target, + outcome=PublishOutcome.SKIPPED_DOWNGRADE, + message=( + f"Downgrade from {old_ref} to " + f"{new_ref}; pass allow_downgrade " + "to override" + ), + old_version=old_ref, + new_version=new_ref, + ) + return None + # -- per-target processing ---------------------------------------------- def _process_single_target( @@ -556,56 +706,12 @@ def _process_single_target( ) # 3. Load consumer apm.yml - apm_yml_path = clone_dir / target.path_in_repo - try: - ensure_path_within(apm_yml_path, clone_dir) - except PathTraversalError: - return TargetResult( - target=target, - outcome=PublishOutcome.FAILED, - message=("Path traversal rejected: " + target.path_in_repo), - ) - - if not apm_yml_path.exists(): - return TargetResult( - target=target, - outcome=PublishOutcome.FAILED, - message=f"File not found: {target.path_in_repo}", - ) - - try: - raw_text = apm_yml_path.read_text(encoding="utf-8") - data = yaml.safe_load(raw_text) - except (yaml.YAMLError, OSError) as exc: - return TargetResult( - target=target, - outcome=PublishOutcome.FAILED, - message=(f"Failed to parse {target.path_in_repo}: {exc}"), - ) - - if not isinstance(data, dict): - return TargetResult( - target=target, - outcome=PublishOutcome.FAILED, - message="Invalid apm.yml: expected a mapping", - ) + data, apm_yml_path, manifest_err = self._load_consumer_manifest(clone_dir, target, plan) + if manifest_err is not None: + return manifest_err # 4. Find matching marketplace entries in dependencies.apm - deps = data.get("dependencies") - if not isinstance(deps, dict): - return TargetResult( - target=target, - outcome=PublishOutcome.FAILED, - message=(f"Marketplace '{plan.marketplace_name}' not referenced in apm.yml"), - ) - - apm_deps = deps.get("apm") - if not isinstance(apm_deps, list): - return TargetResult( - target=target, - outcome=PublishOutcome.FAILED, - message=(f"Marketplace '{plan.marketplace_name}' not referenced in apm.yml"), - ) + apm_deps = data["dependencies"]["apm"] # Parse each entry with parse_marketplace_ref new_ref = plan.new_ref @@ -643,56 +749,9 @@ def _process_single_target( # 6. Guards -- check every entry that would change new_sv = parse_semver(new_ref.lstrip("vV")) - - for _idx, _pname, old_ref, entry_str in matches: - if old_ref == new_ref: - continue # Already at target -- no guard needed - - # Ref-change guard - if old_ref is None: - # Implicit latest -> explicit pin - if not plan.allow_ref_change: - return TargetResult( - target=target, - outcome=PublishOutcome.SKIPPED_REF_CHANGE, - message=( - f"Entry '{entry_str}' uses implicit " - "latest; pass allow_ref_change to pin" - ), - old_version=None, - new_version=new_ref, - ) - else: - old_sv = parse_semver(old_ref.lstrip("vV")) - if old_sv is None and new_sv is not None: - # Non-semver ref -> semver tag - if not plan.allow_ref_change: - return TargetResult( - target=target, - outcome=(PublishOutcome.SKIPPED_REF_CHANGE), - message=( - f"Entry '{entry_str}' uses " - f"non-semver ref '{old_ref}'; " - "pass allow_ref_change to switch" - ), - old_version=old_ref, - new_version=new_ref, - ) - - # Downgrade guard - if old_sv and new_sv and new_sv < old_sv: - if not plan.allow_downgrade: - return TargetResult( - target=target, - outcome=(PublishOutcome.SKIPPED_DOWNGRADE), - message=( - f"Downgrade from {old_ref} to " - f"{new_ref}; pass allow_downgrade " - "to override" - ), - old_version=old_ref, - new_version=new_ref, - ) + guard_err = self._check_ref_guards(matches, target, plan, new_sv) + if guard_err is not None: + return guard_err # 7. No-change check needs_update = any(old_ref != new_ref for _, _, old_ref, _ in matches) diff --git a/src/apm_cli/marketplace/semver.py b/src/apm_cli/marketplace/semver.py index 612c9873f..88134f2d6 100644 --- a/src/apm_cli/marketplace/semver.py +++ b/src/apm_cli/marketplace/semver.py @@ -168,6 +168,15 @@ def satisfies_range(version: SemVer, range_spec: str) -> bool: return _satisfies_single(version, spec) +# Comparison operators -- longest prefix first so ">=" is tested before ">". +_CMP_OPS: list[tuple[str, object]] = [ + (">=", lambda v, b: v >= b), + (">", lambda v, b: v > b), + ("<=", lambda v, b: v <= b), + ("<", lambda v, b: v < b), +] + + def _satisfies_single(version: SemVer, spec: str) -> bool: """Check a single constraint.""" spec = spec.strip() @@ -201,19 +210,11 @@ def _satisfies_single(version: SemVer, spec: str) -> bool: # ~1.2.3 := >=1.2.3 <1.3.0 return version >= base and version.major == base.major and version.minor == base.minor - # Comparison operators - if spec.startswith(">="): - base = parse_semver(spec[2:]) - return base is not None and version >= base - if spec.startswith(">") and not spec.startswith(">="): - base = parse_semver(spec[1:]) - return base is not None and version > base - if spec.startswith("<="): - base = parse_semver(spec[2:]) - return base is not None and version <= base - if spec.startswith("<") and not spec.startswith("<="): - base = parse_semver(spec[1:]) - return base is not None and version < base + # Comparison operators (table-driven dispatch) + for prefix, cmp in _CMP_OPS: + if spec.startswith(prefix): + base = parse_semver(spec[len(prefix) :]) + return base is not None and cmp(version, base) # Wildcard: 1.2.x or 1.2.* wildcard_match = re.match(r"^(\d+)\.(\d+)\.[xX*]$", spec) diff --git a/src/apm_cli/policy/policy_checks.py b/src/apm_cli/policy/policy_checks.py index 1c2e5b5b7..2ad27a234 100644 --- a/src/apm_cli/policy/policy_checks.py +++ b/src/apm_cli/policy/policy_checks.py @@ -875,24 +875,17 @@ def _run(check: CheckResult) -> bool: if _run(_check_mcp_self_defined(mcp_list, policy.mcp)): return result - # -- Target / compilation checks (11-13) ----------------------- - # Skipped when effective_target is None -- those run in a separate - # post-targets call (W2-target-aware). + # -- Target / compilation + manifest tail checks ---------------- + # Collect applicable tail checks, then run in a single loop so + # the function stays within the max-returns threshold. + tail_checks: list[CheckResult] = [] if effective_target is not None: - # Build a minimal raw_yml dict so _check_compilation_target - # sees the effective (possibly CLI-overridden) target value - # rather than what is literally on disk. synthetic_yml = {"target": effective_target} - if _run(_check_compilation_target(synthetic_yml, policy.compilation)): - return result - - # -- Manifest-level explicit-includes check -------------------- - # Only run when the caller supplied the manifest includes value. - # Dep-only seams that lack manifest context (legacy callers) skip - # this check; the install pipeline and ``apm audit`` wrappers both - # supply it. + tail_checks.append(_check_compilation_target(synthetic_yml, policy.compilation)) if manifest_includes is not _INCLUDES_NOT_PROVIDED: - if _run(_check_includes_explicit(manifest_includes, policy.manifest)): + tail_checks.append(_check_includes_explicit(manifest_includes, policy.manifest)) + for check in tail_checks: + if _run(check): return result # NOTE: compilation strategy, source attribution, manifest fields, diff --git a/tests/integration/test_coverage_phase4.py b/tests/integration/test_coverage_phase4.py index 3c40bb505..cf4721b8f 100644 --- a/tests/integration/test_coverage_phase4.py +++ b/tests/integration/test_coverage_phase4.py @@ -643,9 +643,7 @@ def test_skipped_status_returns_early(self, tmp_path: Path) -> None: force=False, runtime=None, exclude=None, - verbose=False, logger=logger, - manifest_path=apm_yml, apm_dir=tmp_path, scope=None, ) @@ -685,9 +683,7 @@ def test_added_string_entry_no_deps_available(self, tmp_path: Path) -> None: force=False, runtime=None, exclude=None, - verbose=False, logger=logger, - manifest_path=apm_yml, apm_dir=tmp_path, scope=None, ) @@ -732,9 +728,7 @@ def test_replaced_dict_entry_no_deps_available(self, tmp_path: Path) -> None: force=True, runtime=None, exclude=None, - verbose=False, logger=logger, - manifest_path=apm_yml, apm_dir=tmp_path, scope=None, ) @@ -772,9 +766,7 @@ def test_build_entry_value_error_becomes_usage_error(self, tmp_path: Path) -> No force=False, runtime=None, exclude=None, - verbose=False, logger=logger, - manifest_path=apm_yml, apm_dir=tmp_path, scope=None, ) @@ -829,9 +821,7 @@ def test_mcp_integrator_failure_raises_click_exception(self, tmp_path: Path) -> force=False, runtime=None, exclude=None, - verbose=True, logger=logger, - manifest_path=apm_yml, apm_dir=tmp_path, scope=None, ) @@ -887,9 +877,7 @@ def test_mcp_integrator_success_updates_lockfile(self, tmp_path: Path) -> None: force=False, runtime=None, exclude=None, - verbose=False, logger=logger, - manifest_path=apm_yml, apm_dir=tmp_path, scope=None, ) diff --git a/tests/integration/test_deps_models_phase3c.py b/tests/integration/test_deps_models_phase3c.py index c6d2360cb..a7cda61d2 100644 --- a/tests/integration/test_deps_models_phase3c.py +++ b/tests/integration/test_deps_models_phase3c.py @@ -32,49 +32,15 @@ # --------------------------------------------------------------------------- -def _make_dep_ref( - repo_url: str = "owner/repo", - host: str | None = "github.com", - port: int | None = None, - reference: str | None = None, - is_virtual: bool = False, - virtual_path: str | None = None, - is_local: bool = False, - local_path: str | None = None, - is_insecure: bool = False, - ado_organization: str | None = None, - ado_project: str | None = None, - ado_repo: str | None = None, - alias: str | None = None, - is_parent_repo_inheritance: bool = False, - explicit_scheme: str | None = None, - skill_subset: list[str] | None = None, - artifactory_prefix: str | None = None, - allow_insecure: bool = False, -) -> Any: +def _make_dep_ref(**kwargs: Any) -> Any: """Build a DependencyReference without network calls.""" from apm_cli.models.dependency.reference import DependencyReference - return DependencyReference( - repo_url=repo_url, - host=host, - port=port, - reference=reference, - is_virtual=is_virtual, - virtual_path=virtual_path, - is_local=is_local, - local_path=local_path, - is_insecure=is_insecure, - ado_organization=ado_organization, - ado_project=ado_project, - ado_repo=ado_repo, - alias=alias, - is_parent_repo_inheritance=is_parent_repo_inheritance, - explicit_scheme=explicit_scheme, - skill_subset=skill_subset, - artifactory_prefix=artifactory_prefix, - allow_insecure=allow_insecure, - ) + defaults = { + "repo_url": "owner/repo", + "host": "github.com", + } + return DependencyReference(**{**defaults, **kwargs}) def _make_downloader() -> Any: diff --git a/tests/integration/test_deps_models_validation.py b/tests/integration/test_deps_models_validation.py index a02dedb1b..ac58be8d3 100644 --- a/tests/integration/test_deps_models_validation.py +++ b/tests/integration/test_deps_models_validation.py @@ -32,49 +32,15 @@ # --------------------------------------------------------------------------- -def _make_dep_ref( - repo_url: str = "owner/repo", - host: str | None = "github.com", - port: int | None = None, - reference: str | None = None, - is_virtual: bool = False, - virtual_path: str | None = None, - is_local: bool = False, - local_path: str | None = None, - is_insecure: bool = False, - ado_organization: str | None = None, - ado_project: str | None = None, - ado_repo: str | None = None, - alias: str | None = None, - is_parent_repo_inheritance: bool = False, - explicit_scheme: str | None = None, - skill_subset: list[str] | None = None, - artifactory_prefix: str | None = None, - allow_insecure: bool = False, -) -> Any: +def _make_dep_ref(**kwargs: Any) -> Any: """Build a DependencyReference without network calls.""" from apm_cli.models.dependency.reference import DependencyReference - return DependencyReference( - repo_url=repo_url, - host=host, - port=port, - reference=reference, - is_virtual=is_virtual, - virtual_path=virtual_path, - is_local=is_local, - local_path=local_path, - is_insecure=is_insecure, - ado_organization=ado_organization, - ado_project=ado_project, - ado_repo=ado_repo, - alias=alias, - is_parent_repo_inheritance=is_parent_repo_inheritance, - explicit_scheme=explicit_scheme, - skill_subset=skill_subset, - artifactory_prefix=artifactory_prefix, - allow_insecure=allow_insecure, - ) + defaults = { + "repo_url": "owner/repo", + "host": "github.com", + } + return DependencyReference(**{**defaults, **kwargs}) def _make_downloader() -> Any: diff --git a/tests/integration/test_install_services_orchestration.py b/tests/integration/test_install_services_orchestration.py index b5f03407d..3f88cdbe8 100644 --- a/tests/integration/test_install_services_orchestration.py +++ b/tests/integration/test_install_services_orchestration.py @@ -7,6 +7,7 @@ import pytest from apm_cli.install import services +from apm_cli.install.services import IntegratorBundle from apm_cli.integration.base_integrator import IntegrationResult @@ -163,7 +164,14 @@ def invoke_integrate( skill_subset=skill_subset, ctx=ctx, scratch_root=scratch_root, - **integrators, + integrators=IntegratorBundle( + prompt=integrators["prompt_integrator"], + agent=integrators["agent_integrator"], + skill=integrators["skill_integrator"], + instruction=integrators["instruction_integrator"], + command=integrators["command_integrator"], + hook=integrators["hook_integrator"], + ), ) return result, integrators, diagnostics, logger @@ -321,14 +329,16 @@ def test_cowork_warning_emits_once_and_sets_context(self, tmp_path: Path) -> Non make_package_info(package_dir), tmp_path, targets=[target], - prompt_integrator=MagicMock(), - agent_integrator=MagicMock(), - skill_integrator=MagicMock( - integrate_package_skill=MagicMock(return_value=make_skill_result()) + integrators=IntegratorBundle( + prompt=MagicMock(), + agent=MagicMock(), + skill=MagicMock( + integrate_package_skill=MagicMock(return_value=make_skill_result()) + ), + instruction=MagicMock(), + command=MagicMock(), + hook=MagicMock(), ), - instruction_integrator=MagicMock(), - command_integrator=MagicMock(), - hook_integrator=MagicMock(), force=False, managed_files=set(), diagnostics=diagnostics, @@ -808,14 +818,16 @@ def test_package_name_override_is_used_in_cowork_warning(self, tmp_path: Path) - make_package_info(package_dir), tmp_path, targets=[make_target(name="copilot-cowork", primitives={})], - prompt_integrator=MagicMock(), - agent_integrator=MagicMock(), - skill_integrator=MagicMock( - integrate_package_skill=MagicMock(return_value=make_skill_result()) + integrators=IntegratorBundle( + prompt=MagicMock(), + agent=MagicMock(), + skill=MagicMock( + integrate_package_skill=MagicMock(return_value=make_skill_result()) + ), + instruction=MagicMock(), + command=MagicMock(), + hook=MagicMock(), ), - instruction_integrator=MagicMock(), - command_integrator=MagicMock(), - hook_integrator=MagicMock(), force=False, managed_files=set(), diagnostics=diagnostics, @@ -841,14 +853,16 @@ def test_package_info_name_is_used_when_package_name_empty(self, tmp_path: Path) pkg_info, tmp_path, targets=[make_target(name="copilot-cowork", primitives={})], - prompt_integrator=MagicMock(), - agent_integrator=MagicMock(), - skill_integrator=MagicMock( - integrate_package_skill=MagicMock(return_value=make_skill_result()) + integrators=IntegratorBundle( + prompt=MagicMock(), + agent=MagicMock(), + skill=MagicMock( + integrate_package_skill=MagicMock(return_value=make_skill_result()) + ), + instruction=MagicMock(), + command=MagicMock(), + hook=MagicMock(), ), - instruction_integrator=MagicMock(), - command_integrator=MagicMock(), - hook_integrator=MagicMock(), force=False, managed_files=set(), diagnostics=diagnostics, diff --git a/tests/integration/test_install_services_phase3w5.py b/tests/integration/test_install_services_phase3w5.py index b5f03407d..3f88cdbe8 100644 --- a/tests/integration/test_install_services_phase3w5.py +++ b/tests/integration/test_install_services_phase3w5.py @@ -7,6 +7,7 @@ import pytest from apm_cli.install import services +from apm_cli.install.services import IntegratorBundle from apm_cli.integration.base_integrator import IntegrationResult @@ -163,7 +164,14 @@ def invoke_integrate( skill_subset=skill_subset, ctx=ctx, scratch_root=scratch_root, - **integrators, + integrators=IntegratorBundle( + prompt=integrators["prompt_integrator"], + agent=integrators["agent_integrator"], + skill=integrators["skill_integrator"], + instruction=integrators["instruction_integrator"], + command=integrators["command_integrator"], + hook=integrators["hook_integrator"], + ), ) return result, integrators, diagnostics, logger @@ -321,14 +329,16 @@ def test_cowork_warning_emits_once_and_sets_context(self, tmp_path: Path) -> Non make_package_info(package_dir), tmp_path, targets=[target], - prompt_integrator=MagicMock(), - agent_integrator=MagicMock(), - skill_integrator=MagicMock( - integrate_package_skill=MagicMock(return_value=make_skill_result()) + integrators=IntegratorBundle( + prompt=MagicMock(), + agent=MagicMock(), + skill=MagicMock( + integrate_package_skill=MagicMock(return_value=make_skill_result()) + ), + instruction=MagicMock(), + command=MagicMock(), + hook=MagicMock(), ), - instruction_integrator=MagicMock(), - command_integrator=MagicMock(), - hook_integrator=MagicMock(), force=False, managed_files=set(), diagnostics=diagnostics, @@ -808,14 +818,16 @@ def test_package_name_override_is_used_in_cowork_warning(self, tmp_path: Path) - make_package_info(package_dir), tmp_path, targets=[make_target(name="copilot-cowork", primitives={})], - prompt_integrator=MagicMock(), - agent_integrator=MagicMock(), - skill_integrator=MagicMock( - integrate_package_skill=MagicMock(return_value=make_skill_result()) + integrators=IntegratorBundle( + prompt=MagicMock(), + agent=MagicMock(), + skill=MagicMock( + integrate_package_skill=MagicMock(return_value=make_skill_result()) + ), + instruction=MagicMock(), + command=MagicMock(), + hook=MagicMock(), ), - instruction_integrator=MagicMock(), - command_integrator=MagicMock(), - hook_integrator=MagicMock(), force=False, managed_files=set(), diagnostics=diagnostics, @@ -841,14 +853,16 @@ def test_package_info_name_is_used_when_package_name_empty(self, tmp_path: Path) pkg_info, tmp_path, targets=[make_target(name="copilot-cowork", primitives={})], - prompt_integrator=MagicMock(), - agent_integrator=MagicMock(), - skill_integrator=MagicMock( - integrate_package_skill=MagicMock(return_value=make_skill_result()) + integrators=IntegratorBundle( + prompt=MagicMock(), + agent=MagicMock(), + skill=MagicMock( + integrate_package_skill=MagicMock(return_value=make_skill_result()) + ), + instruction=MagicMock(), + command=MagicMock(), + hook=MagicMock(), ), - instruction_integrator=MagicMock(), - command_integrator=MagicMock(), - hook_integrator=MagicMock(), force=False, managed_files=set(), diagnostics=diagnostics, diff --git a/tests/integration/test_wave7_policy_registry_coverage.py b/tests/integration/test_wave7_policy_registry_coverage.py index 6e9737659..95fa6d6f5 100644 --- a/tests/integration/test_wave7_policy_registry_coverage.py +++ b/tests/integration/test_wave7_policy_registry_coverage.py @@ -793,6 +793,35 @@ def test_fail_fast_false_collects_all_checks(self): assert "dependency-allowlist" in names assert "dependency-denylist" in names + def test_consolidated_tail_checks_cover_all_categories(self): + """Regression trap for the consolidated tail_checks loop (PR #1464). + + Exercises compilation-target AND manifest-includes in a single + call together with an MCP denylist violation, confirming no check + category is silently dropped by the consolidated loop. + """ + policy = ApmPolicy( + compilation=CompilationPolicy( + target=CompilationTargetPolicy(enforce="vscode"), + ), + manifest=ManifestPolicy(require_explicit_includes=True), + mcp=McpPolicy(deny=("evil-srv",)), + ) + mcp = _mcp("evil-srv") + result = run_dependency_policy_checks( + [], + policy=policy, + fail_fast=False, + effective_target="claude", + manifest_includes=None, + mcp_deps=[mcp], + ) + names = [c.name for c in result.checks] + assert "mcp-denylist" in names + assert "compilation-target" in names + assert "explicit-includes" in names + assert not result.passed + _MINIMAL_APM_YML = dedent("""\ name: test-project diff --git a/tests/unit/commands/test_pack.py b/tests/unit/commands/test_pack.py index 80ecc83a9..0694fd45a 100644 --- a/tests/unit/commands/test_pack.py +++ b/tests/unit/commands/test_pack.py @@ -5,11 +5,17 @@ import textwrap as _tw from pathlib import Path from types import SimpleNamespace +from unittest.mock import MagicMock import pytest from click.testing import CliRunner -from apm_cli.commands.pack import _render_marketplace_result, pack_cmd +from apm_cli.commands.pack import ( + _parse_marketplace_filter, + _parse_path_overrides, + _render_marketplace_result, + pack_cmd, +) from apm_cli.marketplace.builder import BuildReport, MarketplaceOutputReport @@ -44,6 +50,139 @@ def test_pack_help_recommends_manifest_marketplace_output_config() -> None: assert "--claude-output" not in result.output +# --------------------------------------------------------------------------- +# _parse_path_overrides unit tests +# --------------------------------------------------------------------------- + + +def _make_ctx(json_output: bool = False): + """Return a minimal Click context mock for helper tests.""" + ctx = MagicMock() + exited = [] + + def _exit(code=0): + exited.append(code) + + ctx.exit.side_effect = _exit + ctx._exited = exited + return ctx + + +class TestParsePathOverrides: + """Unit tests for _parse_path_overrides().""" + + def test_empty_tuple_returns_empty_dict(self) -> None: + ctx = _make_ctx() + result = _parse_path_overrides((), ctx, json_output=False) + assert result == {} + + def test_valid_single_override(self) -> None: + ctx = _make_ctx() + result = _parse_path_overrides(("claude=dist/marketplace.json",), ctx, json_output=False) + assert result == {"claude": "dist/marketplace.json"} + + def test_valid_multiple_overrides(self) -> None: + ctx = _make_ctx() + result = _parse_path_overrides( + ("claude=dist/claude.json", "codex=dist/codex.json"), + ctx, + json_output=False, + ) + assert result == {"claude": "dist/claude.json", "codex": "dist/codex.json"} + + def test_missing_equals_returns_none(self) -> None: + ctx = _make_ctx(json_output=True) + result = _parse_path_overrides(("claude-no-equals",), ctx, json_output=True) + assert result is None + + def test_unknown_format_returns_none(self) -> None: + ctx = _make_ctx(json_output=True) + result = _parse_path_overrides(("unknown_format=dist/foo.json",), ctx, json_output=True) + assert result is None + + def test_path_traversal_returns_none(self) -> None: + ctx = _make_ctx(json_output=True) + result = _parse_path_overrides(("claude=../../etc/passwd",), ctx, json_output=True) + assert result is None + + def test_missing_equals_raises_click_exception_non_json(self) -> None: + import click as _click + + ctx = _make_ctx() + with pytest.raises(_click.ClickException): + _parse_path_overrides(("no-equals",), ctx, json_output=False) + + def test_path_traversal_raises_click_exception_non_json(self) -> None: + import click as _click + + ctx = _make_ctx() + with pytest.raises(_click.ClickException): + _parse_path_overrides(("claude=../../etc/passwd",), ctx, json_output=False) + + def test_strips_whitespace_around_name_and_path(self) -> None: + ctx = _make_ctx() + result = _parse_path_overrides( + (" claude = dist/marketplace.json ",), ctx, json_output=False + ) + assert result == {"claude": "dist/marketplace.json"} + + +# --------------------------------------------------------------------------- +# _parse_marketplace_filter unit tests +# --------------------------------------------------------------------------- + + +class TestParseMarketplaceFilter: + """Unit tests for _parse_marketplace_filter().""" + + def test_none_returns_none(self) -> None: + ctx = _make_ctx() + result = _parse_marketplace_filter(None, ctx, json_output=False) + assert result is None + + def test_none_string_returns_empty_tuple(self) -> None: + ctx = _make_ctx() + result = _parse_marketplace_filter("none", ctx, json_output=False) + assert result == () + + def test_none_string_case_insensitive(self) -> None: + ctx = _make_ctx() + result = _parse_marketplace_filter("NONE", ctx, json_output=False) + assert result == () + + def test_all_string_returns_none(self) -> None: + ctx = _make_ctx() + result = _parse_marketplace_filter("all", ctx, json_output=False) + assert result is None + + def test_all_string_case_insensitive(self) -> None: + ctx = _make_ctx() + result = _parse_marketplace_filter("ALL", ctx, json_output=False) + assert result is None + + def test_single_known_format(self) -> None: + ctx = _make_ctx() + result = _parse_marketplace_filter("claude", ctx, json_output=False) + assert result == ("claude",) + + def test_multiple_known_formats(self) -> None: + ctx = _make_ctx() + result = _parse_marketplace_filter("claude,codex", ctx, json_output=False) + assert result == ("claude", "codex") + + def test_formats_with_whitespace(self) -> None: + ctx = _make_ctx() + result = _parse_marketplace_filter(" claude , codex ", ctx, json_output=False) + assert result == ("claude", "codex") + + def test_unknown_format_exits_with_error(self) -> None: + ctx = _make_ctx(json_output=True) + result = _parse_marketplace_filter("unknown_format", ctx, json_output=True) + # With json_output=True, ctx.exit(1) is called and function returns None + assert result is None + assert ctx._exited == [1] + + def test_marketplace_fallback_renders_warnings_and_package_count() -> None: logger = _RecordingLogger() fallback_report = SimpleNamespace( diff --git a/tests/unit/install/test_mcp_conflicts.py b/tests/unit/install/test_mcp_conflicts.py index 8c229f9ee..3cd2c5c86 100644 --- a/tests/unit/install/test_mcp_conflicts.py +++ b/tests/unit/install/test_mcp_conflicts.py @@ -34,9 +34,7 @@ def _call(**overrides) -> None: global_=False, only=None, update=False, - use_ssh=False, - use_https=False, - allow_protocol_fallback=False, + any_transport_flag=False, registry_url=None, ) defaults.update(overrides) @@ -201,18 +199,18 @@ def test_only_none_ok(self) -> None: class TestE4TransportSelectionFlags: def test_use_ssh_raises(self) -> None: with pytest.raises(click.UsageError, match=r"transport selection flags"): - _call(use_ssh=True) + _call(any_transport_flag=True) def test_use_https_raises(self) -> None: with pytest.raises(click.UsageError, match=r"transport selection flags"): - _call(use_https=True) + _call(any_transport_flag=True) def test_allow_protocol_fallback_raises(self) -> None: with pytest.raises(click.UsageError, match=r"transport selection flags"): - _call(allow_protocol_fallback=True) + _call(any_transport_flag=True) def test_none_set_ok(self) -> None: - _call(use_ssh=False, use_https=False, allow_protocol_fallback=False) + _call(any_transport_flag=False) # --------------------------------------------------------------------------- diff --git a/tests/unit/install/test_services.py b/tests/unit/install/test_services.py index 316c7fa96..8bb37eb92 100644 --- a/tests/unit/install/test_services.py +++ b/tests/unit/install/test_services.py @@ -9,9 +9,25 @@ import pytest -from apm_cli.install.services import _deployed_path_entry +from apm_cli.install.services import IntegratorBundle, _deployed_path_entry from apm_cli.integration.targets import KNOWN_TARGETS +# --------------------------------------------------------------------------- +# Helper: convert legacy integrators dict to IntegratorBundle +# --------------------------------------------------------------------------- + + +def _to_bundle(d: dict) -> IntegratorBundle: + return IntegratorBundle( + prompt=d["prompt_integrator"], + agent=d["agent_integrator"], + skill=d["skill_integrator"], + instruction=d["instruction_integrator"], + command=d["command_integrator"], + hook=d["hook_integrator"], + ) + + # --------------------------------------------------------------------------- # Shared fixtures # --------------------------------------------------------------------------- @@ -55,6 +71,30 @@ def _make_copilot_app_target(app_root: Path) -> Any: return replace(KNOWN_TARGETS["copilot-app"], resolved_deploy_root=app_root) +# --------------------------------------------------------------------------- +# IntegratorBundle frozen contract +# --------------------------------------------------------------------------- + + +class TestIntegratorBundleFrozen: + """IntegratorBundle must be immutable (frozen=True contract).""" + + def test_cannot_mutate_field(self) -> None: + from dataclasses import FrozenInstanceError + from unittest.mock import MagicMock + + bundle = IntegratorBundle( + prompt=MagicMock(), + agent=MagicMock(), + skill=MagicMock(), + instruction=MagicMock(), + command=MagicMock(), + hook=MagicMock(), + ) + with pytest.raises(FrozenInstanceError): + bundle.prompt = MagicMock() # type: ignore[misc] + + # --------------------------------------------------------------------------- # TestDeployedPathEntry # --------------------------------------------------------------------------- @@ -237,7 +277,7 @@ def test_warning_fires_once_per_run_with_non_skill_primitives( package_name="test-pkg", logger=logger, ctx=ctx, - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, ) @@ -261,7 +301,7 @@ def test_warning_fires_once_per_run_with_non_skill_primitives( package_name="test-pkg2", logger=logger, ctx=ctx, - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, ) @@ -312,7 +352,7 @@ def test_warning_does_not_fire_when_only_skills( diagnostics=MagicMock(), logger=logger, ctx=ctx, - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, ) @@ -360,7 +400,7 @@ def test_warning_does_not_fire_when_cowork_not_active( diagnostics=MagicMock(), logger=logger, ctx=ctx, - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, ) @@ -408,7 +448,7 @@ def test_warning_does_not_fire_when_ctx_is_none( diagnostics=MagicMock(), logger=logger, ctx=None, - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, ) @@ -456,7 +496,7 @@ def test_warning_msg_text_includes_package_name_and_primitive_types( package_name="my-awesome-pkg", logger=logger, ctx=ctx, - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, ) @@ -509,7 +549,7 @@ def test_warning_also_emitted_to_diagnostics_warn( package_name="diag-pkg", logger=logger, ctx=ctx, - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, ) @@ -568,7 +608,7 @@ def test_warning_with_prompts_only_does_not_mention_commands( package_name="prompts-only-pkg", logger=logger, ctx=ctx, - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, ) diff --git a/tests/unit/install/test_services_branches.py b/tests/unit/install/test_services_branches.py index 37e82731a..80f164dc9 100644 --- a/tests/unit/install/test_services_branches.py +++ b/tests/unit/install/test_services_branches.py @@ -23,6 +23,7 @@ import pytest from apm_cli.install.services import ( + IntegratorBundle, _deployed_path_entry, _integrate_local_content, _integrate_package_primitives, @@ -32,6 +33,19 @@ ) from apm_cli.integration.targets import KNOWN_TARGETS + +def _to_bundle(d: dict) -> IntegratorBundle: + """Convert a dict of old-style integrator kwargs to an IntegratorBundle.""" + return IntegratorBundle( + prompt=d["prompt_integrator"], + agent=d["agent_integrator"], + skill=d["skill_integrator"], + instruction=d["instruction_integrator"], + command=d["command_integrator"], + hook=d["hook_integrator"], + ) + + # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @@ -159,7 +173,7 @@ def test_empty_targets_returns_zero_counts(self, tmp_path: Path) -> None: tmp_path, targets=[], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, ) @@ -196,7 +210,7 @@ def test_scratch_root_inside_itself_is_valid(self, tmp_path: Path) -> None: project_in_scratch, targets=[copilot], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, scratch_root=scratch, @@ -227,7 +241,7 @@ def test_scratch_root_outside_raises(self, tmp_path: Path) -> None: project_root, targets=[copilot], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, scratch_root=scratch, @@ -272,7 +286,7 @@ def _call_with_skill_paths( tmp_path, targets=[copilot], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, logger=logger, @@ -343,7 +357,7 @@ def test_sub_skills_promoted_logged_single_path(self, tmp_path: Path) -> None: tmp_path, targets=[copilot], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, logger=logger, @@ -371,7 +385,7 @@ def test_files_unchanged_line_logged(self, tmp_path: Path) -> None: tmp_path, targets=[copilot], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, logger=logger, @@ -418,7 +432,7 @@ def test_cowork_skill_path_outside_project_labeled_correctly(self, tmp_path: Pat project_root, targets=[cowork_target], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, ) @@ -441,7 +455,12 @@ def test_delegates_to_integrate_package_primitives(self, tmp_path: Path) -> None tmp_path, targets=[KNOWN_TARGETS["copilot"]], diagnostics=MagicMock(), - **integrators, + prompt_integrator=integrators["prompt_integrator"], + agent_integrator=integrators["agent_integrator"], + skill_integrator=integrators["skill_integrator"], + instruction_integrator=integrators["instruction_integrator"], + command_integrator=integrators["command_integrator"], + hook_integrator=integrators["hook_integrator"], force=False, managed_files=None, ) @@ -468,7 +487,12 @@ def _capture(pkg_info, *args, **kwargs): tmp_path, targets=[KNOWN_TARGETS["copilot"]], diagnostics=MagicMock(), - **integrators, + prompt_integrator=integrators["prompt_integrator"], + agent_integrator=integrators["agent_integrator"], + skill_integrator=integrators["skill_integrator"], + instruction_integrator=integrators["instruction_integrator"], + command_integrator=integrators["command_integrator"], + hook_integrator=integrators["hook_integrator"], force=False, managed_files=None, ) @@ -752,7 +776,7 @@ def test_copilot_app_path_triggers_workflow_hint(self, tmp_path: Path) -> None: tmp_path, targets=[mock_target], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, logger=logger, diff --git a/tests/unit/install/test_services_hook_scope.py b/tests/unit/install/test_services_hook_scope.py index cd3aef4d2..42213845e 100644 --- a/tests/unit/install/test_services_hook_scope.py +++ b/tests/unit/install/test_services_hook_scope.py @@ -30,7 +30,7 @@ import pytest from apm_cli.core.scope import InstallScope -from apm_cli.install.services import integrate_package_primitives +from apm_cli.install.services import IntegratorBundle, integrate_package_primitives from apm_cli.integration.base_integrator import IntegrationResult from apm_cli.integration.targets import KNOWN_TARGETS from apm_cli.utils.diagnostics import DiagnosticCollector @@ -103,12 +103,14 @@ def _call(scope: InstallScope, project_root: Path) -> MagicMock: package_info, project_root, targets=[_claude_hooks_only_target()], - prompt_integrator=MagicMock(), - agent_integrator=MagicMock(), - skill_integrator=_make_skill_integrator(), - instruction_integrator=MagicMock(), - command_integrator=MagicMock(), - hook_integrator=hook_integrator, + integrators=IntegratorBundle( + prompt=MagicMock(), + agent=MagicMock(), + skill=_make_skill_integrator(), + instruction=MagicMock(), + command=MagicMock(), + hook=hook_integrator, + ), force=False, managed_files=None, diagnostics=DiagnosticCollector(), @@ -202,12 +204,14 @@ def test_non_hook_integrators_never_receive_user_scope(tmp_path: Path) -> None: package_info, tmp_path, targets=[target], - prompt_integrator=MagicMock(), - agent_integrator=MagicMock(), - skill_integrator=_make_skill_integrator(), - instruction_integrator=MagicMock(), - command_integrator=command_integrator, - hook_integrator=hook_integrator, + integrators=IntegratorBundle( + prompt=MagicMock(), + agent=MagicMock(), + skill=_make_skill_integrator(), + instruction=MagicMock(), + command=command_integrator, + hook=hook_integrator, + ), force=False, managed_files=None, diagnostics=DiagnosticCollector(), diff --git a/tests/unit/install/test_services_phase3.py b/tests/unit/install/test_services_phase3.py index 4bcf4d139..f3f22ae97 100644 --- a/tests/unit/install/test_services_phase3.py +++ b/tests/unit/install/test_services_phase3.py @@ -23,6 +23,7 @@ import pytest from apm_cli.install.services import ( + IntegratorBundle, _deployed_path_entry, _integrate_local_content, _integrate_package_primitives, @@ -32,6 +33,19 @@ ) from apm_cli.integration.targets import KNOWN_TARGETS + +def _to_bundle(d: dict) -> IntegratorBundle: + """Convert a dict of old-style integrator kwargs to an IntegratorBundle.""" + return IntegratorBundle( + prompt=d["prompt_integrator"], + agent=d["agent_integrator"], + skill=d["skill_integrator"], + instruction=d["instruction_integrator"], + command=d["command_integrator"], + hook=d["hook_integrator"], + ) + + # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @@ -159,7 +173,7 @@ def test_empty_targets_returns_zero_counts(self, tmp_path: Path) -> None: tmp_path, targets=[], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, ) @@ -196,7 +210,7 @@ def test_scratch_root_inside_itself_is_valid(self, tmp_path: Path) -> None: project_in_scratch, targets=[copilot], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, scratch_root=scratch, @@ -227,7 +241,7 @@ def test_scratch_root_outside_raises(self, tmp_path: Path) -> None: project_root, targets=[copilot], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, scratch_root=scratch, @@ -272,7 +286,7 @@ def _call_with_skill_paths( tmp_path, targets=[copilot], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, logger=logger, @@ -343,7 +357,7 @@ def test_sub_skills_promoted_logged_single_path(self, tmp_path: Path) -> None: tmp_path, targets=[copilot], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, logger=logger, @@ -371,7 +385,7 @@ def test_files_unchanged_line_logged(self, tmp_path: Path) -> None: tmp_path, targets=[copilot], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, logger=logger, @@ -418,7 +432,7 @@ def test_cowork_skill_path_outside_project_labeled_correctly(self, tmp_path: Pat project_root, targets=[cowork_target], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, ) @@ -441,7 +455,12 @@ def test_delegates_to_integrate_package_primitives(self, tmp_path: Path) -> None tmp_path, targets=[KNOWN_TARGETS["copilot"]], diagnostics=MagicMock(), - **integrators, + prompt_integrator=integrators["prompt_integrator"], + agent_integrator=integrators["agent_integrator"], + skill_integrator=integrators["skill_integrator"], + instruction_integrator=integrators["instruction_integrator"], + command_integrator=integrators["command_integrator"], + hook_integrator=integrators["hook_integrator"], force=False, managed_files=None, ) @@ -468,7 +487,12 @@ def _capture(pkg_info, *args, **kwargs): tmp_path, targets=[KNOWN_TARGETS["copilot"]], diagnostics=MagicMock(), - **integrators, + prompt_integrator=integrators["prompt_integrator"], + agent_integrator=integrators["agent_integrator"], + skill_integrator=integrators["skill_integrator"], + instruction_integrator=integrators["instruction_integrator"], + command_integrator=integrators["command_integrator"], + hook_integrator=integrators["hook_integrator"], force=False, managed_files=None, ) @@ -752,7 +776,7 @@ def test_copilot_app_path_triggers_workflow_hint(self, tmp_path: Path) -> None: tmp_path, targets=[mock_target], diagnostics=MagicMock(), - **integrators, + integrators=_to_bundle(integrators), force=False, managed_files=None, logger=logger, diff --git a/tests/unit/install/test_services_rendering.py b/tests/unit/install/test_services_rendering.py index 0d3ccf7a2..c61c77962 100644 --- a/tests/unit/install/test_services_rendering.py +++ b/tests/unit/install/test_services_rendering.py @@ -20,8 +20,22 @@ import pytest +from apm_cli.install.services import IntegratorBundle from apm_cli.integration.targets import KNOWN_TARGETS + +def _to_bundle(d: dict) -> IntegratorBundle: + """Convert a dict of old-style integrator kwargs to an IntegratorBundle.""" + return IntegratorBundle( + prompt=d["prompt_integrator"], + agent=d["agent_integrator"], + skill=d["skill_integrator"], + instruction=d["instruction_integrator"], + command=d["command_integrator"], + hook=d["hook_integrator"], + ) + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -158,7 +172,7 @@ def _run( ctx=_ctx(verbose=verbose), force=False, managed_files=None, - **kwargs, + integrators=_to_bundle(kwargs), ) return _logger_lines(logger) @@ -254,7 +268,7 @@ def _run( ctx=_ctx(), force=False, managed_files=None, - **kwargs, + integrators=_to_bundle(kwargs), ) return _logger_lines(logger) @@ -322,7 +336,7 @@ def test_emits_annotation_when_no_files_integrated(self, tmp_path: Path) -> None ctx=_ctx(), force=False, managed_files=None, - **kwargs, + integrators=_to_bundle(kwargs), ) lines = _logger_lines(logger) @@ -351,7 +365,7 @@ def test_no_annotation_when_files_integrated(self, tmp_path: Path) -> None: ctx=_ctx(), force=False, managed_files=None, - **kwargs, + integrators=_to_bundle(kwargs), ) lines = _logger_lines(logger) @@ -386,7 +400,7 @@ def test_no_annotation_when_skill_created(self, tmp_path: Path) -> None: ctx=_ctx(), force=False, managed_files=None, - **kwargs, + integrators=_to_bundle(kwargs), ) lines = _logger_lines(logger) @@ -422,7 +436,7 @@ def test_counter_equals_sum_across_targets(self, tmp_path: Path) -> None: ctx=_ctx(), force=False, managed_files=None, - **kwargs, + integrators=_to_bundle(kwargs), ) assert result["agents"] == 7 diff --git a/tests/unit/integration/test_command_integrator.py b/tests/unit/integration/test_command_integrator.py index b009f1d9e..7af995d4f 100644 --- a/tests/unit/integration/test_command_integrator.py +++ b/tests/unit/integration/test_command_integrator.py @@ -16,12 +16,24 @@ import frontmatter import pytest +from apm_cli.install.services import IntegratorBundle from apm_cli.integration.command_integrator import ( CommandIntegrator, _extract_input_names, ) +def _to_bundle(d: dict) -> IntegratorBundle: + return IntegratorBundle( + prompt=d["prompt_integrator"], + agent=d["agent_integrator"], + skill=d["skill_integrator"], + instruction=d["instruction_integrator"], + command=d["command_integrator"], + hook=d["hook_integrator"], + ) + + def _make_package(project_root, prompts): """Create a test package with .prompt.md files and return a mock PackageInfo. @@ -517,7 +529,7 @@ def test_copilot_only_does_not_dispatch_commands(self): managed_files=set(), force=False, diagnostics=diagnostics, - **integrators, + integrators=_to_bundle(integrators), ) integrators["command_integrator"].integrate_commands_for_target.assert_not_called() @@ -547,7 +559,7 @@ def test_claude_target_dispatches_commands(self): managed_files=set(), force=False, diagnostics=diagnostics, - **integrators, + integrators=_to_bundle(integrators), ) integrators["command_integrator"].integrate_commands_for_target.assert_called_once() @@ -576,7 +588,7 @@ def test_cursor_target_dispatches_commands(self): managed_files=set(), force=False, diagnostics=diagnostics, - **integrators, + integrators=_to_bundle(integrators), ) integrators["command_integrator"].integrate_commands_for_target.assert_called_once() @@ -625,12 +637,14 @@ def test_full_dispatch_deploys_to_cursor(self, temp_project): pkg_info, temp_project, targets=[KNOWN_TARGETS["cursor"]], - prompt_integrator=PromptIntegrator(), - agent_integrator=AgentIntegrator(), - skill_integrator=SkillIntegrator(), - instruction_integrator=InstructionIntegrator(), - command_integrator=CommandIntegrator(), - hook_integrator=HookIntegrator(), + integrators=IntegratorBundle( + prompt=PromptIntegrator(), + agent=AgentIntegrator(), + skill=SkillIntegrator(), + instruction=InstructionIntegrator(), + command=CommandIntegrator(), + hook=HookIntegrator(), + ), force=False, managed_files=set(), diagnostics=DiagnosticCollector(), @@ -779,12 +793,14 @@ def test_full_dispatch_maps_input_to_arguments(self, temp_project): pkg_info, temp_project, targets=[KNOWN_TARGETS["claude"]], - prompt_integrator=PromptIntegrator(), - agent_integrator=AgentIntegrator(), - skill_integrator=SkillIntegrator(), - instruction_integrator=InstructionIntegrator(), - command_integrator=CommandIntegrator(), - hook_integrator=HookIntegrator(), + integrators=IntegratorBundle( + prompt=PromptIntegrator(), + agent=AgentIntegrator(), + skill=SkillIntegrator(), + instruction=InstructionIntegrator(), + command=CommandIntegrator(), + hook=HookIntegrator(), + ), force=False, managed_files=set(), diagnostics=DiagnosticCollector(), diff --git a/tests/unit/integration/test_data_driven_dispatch.py b/tests/unit/integration/test_data_driven_dispatch.py index afddb1cf8..3a5e8b442 100644 --- a/tests/unit/integration/test_data_driven_dispatch.py +++ b/tests/unit/integration/test_data_driven_dispatch.py @@ -13,6 +13,7 @@ from unittest.mock import MagicMock from apm_cli.commands.install import _integrate_package_primitives +from apm_cli.install.services import IntegratorBundle from apm_cli.integration.base_integrator import BaseIntegrator, IntegrationResult from apm_cli.integration.targets import KNOWN_TARGETS, PrimitiveMapping, TargetProfile @@ -21,6 +22,17 @@ # ------------------------------------------------------------------ +def _to_bundle(d: dict) -> IntegratorBundle: + return IntegratorBundle( + prompt=d["prompt_integrator"], + agent=d["agent_integrator"], + skill=d["skill_integrator"], + instruction=d["instruction_integrator"], + command=d["command_integrator"], + hook=d["hook_integrator"], + ) + + def _make_integration_result(n=0): """Return an IntegrationResult with *n* files integrated.""" return IntegrationResult( @@ -96,7 +108,7 @@ def _dispatch(targets, integrators=None, package_info=None, project_root=None): force=False, managed_files=set(), diagnostics=None, - **integrators, + integrators=_to_bundle(integrators), ), integrators