diff --git a/pyproject.toml b/pyproject.toml
index 346736f..fb970bb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -70,6 +70,7 @@ check-domain = "workflow.scripts.check_domain:app"
 gcmt-auto-simulate = "workflow.scripts.gcmt_auto_simulate:app"
 import-realisation = "workflow.scripts.import_realisation:app"
 lf-to-xarray = "workflow.scripts.lf_to_xarray:app"
+migrate = "workflow.scripts.migrate:app"
 
 [tool.setuptools.package-dir]
 workflow = "workflow"
diff --git a/workflow/defaults.py b/workflow/defaults.py
index c651372..37a70fd 100644
--- a/workflow/defaults.py
+++ b/workflow/defaults.py
@@ -3,11 +3,11 @@ import importlib
 from enum import StrEnum
 from importlib import resources
-from typing import Any
 
 import yaml
 
 import workflow.default_parameters.root as root
+from workflow import utils
 
 
 class DefaultsVersion(StrEnum):
@@ -19,30 +19,6 @@ class DefaultsVersion(StrEnum):
     develop = "develop"
 
 
-def _merge_defaults(defaults_a: dict[str, Any], defaults_b: dict[str, Any]) -> None:
-    """Deep-merge dictionaries in place, updating the first with values from the second.
-
-    Parameters
-    ----------
-    defaults_a : dict[str, Any]
-        Base dictionary to be updated. This dictionary is modified in place with
-        merged values.
-    defaults_b : dict[str, Any]
-        Dictionary providing overriding values. Keys in this dictionary are
-        preferred when keys conflict. This dictionary is not modified.
-    """
-
-    for key, value in defaults_b.items():
-        if (
-            key in defaults_a
-            and isinstance(defaults_a[key], dict)
-            and isinstance(value, dict)
-        ):
-            _merge_defaults(defaults_a[key], defaults_b[key])
-        else:
-            defaults_a[key] = value
-
-
 def load_defaults(version: DefaultsVersion) -> dict[str, int | float | str]:
     """Load default parameters for EMOD3D simulation from a YAML file.
@@ -72,5 +48,5 @@ def load_defaults(version: DefaultsVersion) -> dict[str, int | float | str]:
     defaults_path = resources.files(defaults_package) / "defaults.yaml"
     with defaults_path.open(encoding="utf-8") as emod3d_defaults_file_handle:
         defaults = yaml.safe_load(emod3d_defaults_file_handle)
-    _merge_defaults(root_defaults, defaults)
+    utils.merge_dictionaries(root_defaults, defaults)
     return root_defaults
diff --git a/workflow/scripts/migrate.py b/workflow/scripts/migrate.py
new file mode 100644
index 0000000..c8416e6
--- /dev/null
+++ b/workflow/scripts/migrate.py
@@ -0,0 +1,587 @@
+"""Check that realisation can be loaded, if it can't automatically trim extraneous tags and offer to fill in default values."""
+
+import difflib
+import inspect
+import json
+import re
+import shutil
+from collections import defaultdict
+from collections.abc import MutableMapping
+from enum import Enum, auto
+from pathlib import Path
+from typing import Annotated, TypeVar
+
+import parse
+import schema
+import typer
+from rich.console import Console
+
+from qcore import cli
+from workflow import realisations, utils
+from workflow.defaults import DefaultsVersion
+from workflow.realisations import RealisationMetadata, Seeds
+
+app = typer.Typer()
+console = Console()
+
+
+def is_realisation_configuration(cls: type) -> bool:
+    """Returns True if the class is a subclass of realisation configuration.
+
+    Parameters
+    ----------
+    cls : type
+        Type to check.
+
+    Returns
+    -------
+    bool
+        True if class is a realisation configuration.
+    """
+    return (
+        cls != realisations.RealisationConfiguration
+        and inspect.isclass(cls)
+        and issubclass(cls, realisations.RealisationConfiguration)
+    )
+
+
+ConfigType = TypeVar("ConfigType", bound=realisations.RealisationConfiguration)
+
+
+def realisation_configurations() -> list[ConfigType]:
+    """Return a list of all realisation configurations.
+
+    Returns
+    -------
+    list[ConfigType]
+        A list of all realisation configuration types.
+    """
+    return [
+        cls
+        for name, cls in inspect.getmembers(realisations)
+        if is_realisation_configuration(cls)
+    ]
+
+
+def loadable_defaults(
+    configurations: list[type], defaults: DefaultsVersion
+) -> dict[ConfigType, realisations.RealisationConfiguration]:
+    """Filter a list of realisation configurations for those with loadable defaults.
+
+    Parameters
+    ----------
+    configurations : list[type]
+        Configurations to filter.
+    defaults : defaults.DefaultsVersion
+        Defaults to try and load.
+
+    Returns
+    -------
+    dict[ConfigType, realisations.RealisationConfiguration]
+        A mapping from realisation configuration types to their
+        defaults specified by ``defaults``.
+
+    Raises
+    ------
+    TypeError
+        If ``configurations`` contains a type that is not a
+        realisation configuration.
+    """
+    config_defaults = {}
+    for config in configurations:
+        if not is_realisation_configuration(config):
+            raise TypeError(
+                f"{config=} should be a subclass of realisations.RealisationConfiguration"
+            )
+        else:
+            try:
+                default_config = config.read_from_defaults(defaults)  # type: ignore[unresolved-attribute]
+                config_defaults[config] = default_config
+            except realisations.RealisationParseError:
+                # Configurations without defaults for this version are simply omitted.
+                continue
+    return config_defaults
+
+
+class Response(Enum):
+    """Enum for response to prompts asked of user."""
+
+    YES = auto()
+    NO = auto()
+    AUTO = auto()  # Always (!)
+    NEVER = auto()  # Never (N)
+
+
+class Action(Enum):
+    """Migration actions that can be taken on realisation configuration."""
+
+    MIGRATE = auto()
+    TRIM = auto()
+    FILL = auto()
+    UPDATE = auto()
+
+
+def yes_no_always_prompt(raw_prompt: str) -> Response:
+    """Prompt user for a decision, handling y, n, !, and N.
+
+    Parameters
+    ----------
+    raw_prompt : str
+        Prompt to prepend to options.
+
+    Returns
+    -------
+    Response
+        Response from user.
+    """
+
+    prompt = f"{raw_prompt} (y/n/!/N): "
+    response_map = {
+        "N": Response.NEVER,
+        "!": Response.AUTO,
+        "A": Response.AUTO,
+        "y": Response.YES,
+        "n": Response.NO,
+    }
+    while True:
+        raw_response = input(prompt).strip()
+        if raw_response in response_map:
+            return response_map[raw_response]
+
+
+def autofill(
+    realisation: Path,
+    config: realisations.RealisationConfiguration,
+    dry_run: bool,
+) -> None:
+    """Autofill realisation with defaults from config.
+
+    Parameters
+    ----------
+    realisation : Path
+        Realisation to write to.
+    config : realisations.RealisationConfiguration
+        Config to write.
+    dry_run : bool
+        If True, print to console instead of writing.
+    """
+    if dry_run:
+        console.print(
+            f"DRY RUN: Would merge with {config.__class__.__name__} defaults in {realisation}"
+        )
+    else:
+        config.write_to_realisation(realisation)
+
+
+def extract_error(
+    name: str, config_schema: schema.Schema, e: schema.SchemaError
+) -> tuple[str, list[str]]:
+    """Returns the formatted error string and a list of extraneous keys found.
+
+    Parameters
+    ----------
+    name : str
+        Name of configuration to parse.
+    config_schema : schema.Schema
+        Schema to read.
+    e : schema.SchemaError
+        Schema error encountered.
+
+    Returns
+    -------
+    str
+        Human readable error message.
+    list[str]
+        Unknown keys identified in error.
+    """
+
+    last_error = e.autos[-1] if e.autos else str(e)
+    extraneous_keys = []
+    assert isinstance(last_error, str)
+    if "Wrong keys" in last_error:
+        extraneous_keys = re.findall(r"'(.*?)'", last_error.split(" in {")[0])
+        error_msg = f"Extraneous keys found: [red]{', '.join(extraneous_keys)}[/red]"
+        return f"Error in {name}: {error_msg}", extraneous_keys
+
+    if match := re.match(r"^Wrong key '(.*?)'", last_error):
+        unknown_key = match.group(1)
+        return f"Error in {name}: Unknown key '{unknown_key}'", [unknown_key]
+
+    return f"Error in {name}: {last_error}", []
+
+
+def should_trim_keys(config: ConfigType, extra_keys: list[str]) -> Response:
+    """Prompts user if they want to trim extra keys.
+
+    Parameters
+    ----------
+    config : ConfigType
+        Config to trim keys from.
+    extra_keys : list[str]
+        Extra keys to trim.
+
+    Returns
+    -------
+    Response
+        Response from user to prompt.
+    """
+    return yes_no_always_prompt(
+        f"Remove extraneous keys {extra_keys} from {config._config_key}?"
+    )
+
+
+def should_update(config: ConfigType) -> Response:
+    """Prompt user to merge config with default values.
+
+    Parameters
+    ----------
+    config : ConfigType
+        Config to merge with.
+
+    Returns
+    -------
+    Response
+        Response from user to prompt.
+    """
+    return yes_no_always_prompt(f"Merge with defaults for {config._config_key}?")
+
+
+def trim_keys(
+    realisation: Path,
+    config: ConfigType,
+    extra_keys: list[str],
+    dry_run: bool,
+) -> None:
+    """Trim extra keys from realisation.
+
+    Parameters
+    ----------
+    realisation : Path
+        Path to realisation.
+    config : ConfigType
+        Config to trim from.
+    extra_keys : list[str]
+        Keys to trim.
+    dry_run : bool
+        If True, print instead of trimming.
+    """
+    if dry_run:
+        console.print(f"DRY RUN: Would remove {extra_keys} from {realisation}")
+    else:
+        with open(realisation, "r") as f:
+            data = json.load(f)
+
+        config_data = data[config._config_key]
+        for k in extra_keys:
+            config_data.pop(k, None)
+
+        with open(realisation, "w") as f:
+            json.dump(data, f, indent=4)
+
+
+def print_diff(config_a: dict, config_b: dict) -> None:
+    """Pretty print diff between two dictionaries.
+
+    Parameters
+    ----------
+    config_a : dict
+        Dictionary a.
+    config_b : dict
+        Dictionary b.
+    """
+    config_a_str = json.dumps(config_a, indent=4)
+    config_b_str = json.dumps(config_b, indent=4)
+
+    diff = difflib.context_diff(
+        config_a_str.splitlines(keepends=True),
+        config_b_str.splitlines(keepends=True),
+        fromfile="Current",
+        tofile="Defaults",
+    )
+
+    for line in diff:
+        if line.startswith("+ "):
+            console.print(f"[green]{line}[/green]", end="")
+        elif line.startswith("- "):
+            console.print(f"[red]{line}[/red]", end="")
+        elif line.startswith("!"):
+            console.print(f"[yellow]{line}[/yellow]", end="")
+        else:
+            console.print(line, end="")
+
+
+def migrate(
+    realisation: Path,
+    defaults_version: DefaultsVersion,
+    check_configs: list[ConfigType],
+    defaults: dict[ConfigType, realisations.RealisationConfiguration],
+    auto_response: MutableMapping[tuple[ConfigType, Action], Response],
+    dry_run: bool,
+) -> None:
+    """Attempt to migrate realisation to new defaults set.
+
+    Parameters
+    ----------
+    realisation : Path
+        Path to realisation.
+    defaults_version : DefaultsVersion
+        Defaults to update to.
+    check_configs : list[ConfigType]
+        Configurations to check.
+    defaults : dict[ConfigType, realisations.RealisationConfiguration]
+        Defaults to use.
+    auto_response : MutableMapping[tuple[ConfigType, Action], Response]
+        Auto response map recording user's always and never requests.
+    dry_run : bool
+        If True, print instead of writing to realisations.
+    """
+    metadata = realisations.RealisationMetadata.read_from_realisation(realisation)
+    if metadata.defaults_version != defaults_version:
+        console.print(
+            f"Updating defaults in {realisation} from {metadata.defaults_version} to {defaults_version}"
+        )
+        if not dry_run:
+            metadata.defaults_version = defaults_version
+            metadata.write_to_realisation(realisation)
+    try:
+        with open(realisation, "r") as f:
+            json_data = json.load(f)
+    except json.JSONDecodeError:
+        console.print(
+            f"[bold red]Invalid JSON in {realisation}, skipping...[/bold red]"
+        )
+        return
+
+    for config in check_configs:
+        default_config = defaults.get(config)
+        if not default_config:
+            continue
+        default_config_dict = default_config.to_dict()
+        current_config = json_data.get(config._config_key, dict())
+        if current_config != default_config_dict:
+            print_diff(current_config, default_config_dict)
+            print("")
+            response = auto_response.get((config, Action.UPDATE)) or should_update(
+                config
+            )
+
+            if response in (Response.AUTO, Response.NEVER):
+                auto_response[(config, Action.UPDATE)] = response
+
+            if response in (Response.AUTO, Response.YES):
+                autofill(
+                    realisation,
+                    default_config,
+                    dry_run=dry_run,
+                )
+
+        try:
+            _ = config.read_from_realisation(realisation)
+        except realisations.RealisationParseError:
+            if config not in defaults and config != realisations.Seeds:
+                console.print(
+                    f"[bold red]Missing required configuration {config.__class__.__name__}[/bold red]"
+                )
+        except schema.SchemaError as error:
+            console.print(f"[red]Schema error for {realisation}[/red]")
+
+            error, extra_keys = extract_error(config._config_key, config._schema, error)
+            console.print(error)
+            if extra_keys:
+                response = auto_response.get((config, Action.TRIM)) or should_trim_keys(
+                    config, extra_keys
+                )
+
+                if response in (Response.AUTO, Response.NEVER):
+                    auto_response[(config, Action.TRIM)] = response
+
+                if response in (Response.AUTO, Response.YES):
+                    trim_keys(realisation, config, extra_keys, dry_run)
+                    # Try to read one more time
+                    try:
+                        _ = config.read_from_realisation(realisation)
+                    except schema.SchemaError as error:
+                        error, _ = extract_error(
+                            config._config_key, config._schema, error
+                        )
+                        console.print(
+                            f"[bold red]Unrecoverable schema error for {realisation}[/bold red]"
+                        )
+                        console.print(error)
+
+        except Exception as e:  # noqa: BLE001
+            console.print(
+                f"[bold red]Could not load realisation {realisation} for unrecoverable reason:[/bold red]"
+            )
+            console.print(str(e))
+
+
+@cli.from_docstring(app, name="migrate")  # type: ignore[invalid-argument-type]
+def migrate_all(
+    realisation_directory: Annotated[
+        Path, typer.Argument(exists=True, file_okay=False)
+    ],
+    defaults_version: DefaultsVersion,
+    glob: str = "*.json",
+    backup: str | None = None,
+    dry_run: bool = False,
+) -> None:
+    """Migrate all realisations in a directory to the current workflow version.
+
+    Parameters
+    ----------
+    realisation_directory : Path
+        Path containing realisations.
+    defaults_version : DefaultsVersion
+        Defaults version to migrate to.
+    glob : str
+        Glob pattern to look for realisations.
+    backup : str | None
+        If given, backup the realisation file with named suffix before
+        running migration. Equivalent to the ``-iext`` flag used in
+        sed. Has no effect when combined with dry run.
+    dry_run : bool
+        If given, print instead of writing. Useful to check what would
+        be migrated.
+    """
+    auto_response = dict()
+    configs = realisation_configurations()
+    defaults = loadable_defaults(configs, defaults_version)
+
+    for realisation in realisation_directory.rglob(glob):
+        if backup and not dry_run:  # only make a copy if we actually modify the file.
+            shutil.copy(
+                realisation, realisation.with_suffix(realisation.suffix + backup)
+            )
+
+        migrate(
+            realisation,
+            defaults_version,
+            configs,
+            defaults,
+            auto_response,
+            dry_run,
+        )
+
+
+@cli.from_docstring(app)
+def copy(
+    realisation_template: Annotated[Path, typer.Argument(exists=True, dir_okay=False)],
+    realisation_directory: Annotated[
+        Path, typer.Argument(exists=True, file_okay=False)
+    ],
+    configs: list[str] | None = None,
+    backup: str | None = None,
+    glob: str = "*.json",
+) -> None:
+    """Utility to copy blocks of configurations between a template and a directory of realisations.
+
+    Realisation configurations can be partially specified, so that
+    some values can be replaced without replacing all of the others.
+
+    Parameters
+    ----------
+    realisation_template : Path
+        Template realisation to copy from.
+    realisation_directory : Path
+        Directory containing realisation files.
+    configs : list[str] | None
+        Configurations to copy. If None, will copy all configurations
+        in realisation file.
+    backup : str | None
+        If given, backup the realisation file with named suffix before
+        running migration. Equivalent to the ``-iext`` flag used in
+        sed. Has no effect when combined with dry run.
+    glob : str
+        Glob pattern to look for realisations.
+    """
+    with open(realisation_template) as f:
+        template = json.load(f)
+
+    configs = configs or list(template)
+    # Honour the configs selection: only the chosen top-level blocks are copied.
+    selected = {key: value for key, value in template.items() if key in configs}
+
+    for realisation_path in realisation_directory.rglob(glob):
+        if backup:
+            shutil.copy(
+                realisation_path,
+                realisation_path.with_suffix(realisation_path.suffix + backup),
+            )
+
+        with open(realisation_path) as f:
+            realisation = json.load(f)
+
+        utils.merge_dictionaries(realisation, selected)
+
+        with open(realisation_path, "w") as f:
+            json.dump(realisation, f, indent=4)
+
+
+@cli.from_docstring(app)
+def clone(
+    realisation_directory: Annotated[
+        Path, typer.Argument(exists=True, file_okay=False)
+    ],
+    num_realisations: int,
+    realisation_template: str = "{event}_R{realisation:d}",
+    regenerate_seeds: bool = True,
+) -> None:
+    """Utility to clone realisations with updated seeds.
+
+    Parameters
+    ----------
+    realisation_directory : Path
+        Directory containing realisation files.
+    num_realisations : int
+        Number of realisations to copy.
+    realisation_template : str, optional
+        Template structure for realisation names.
+    regenerate_seeds : bool, optional
+        If set, re-roll seeds configuration.
+    """
+
+    realisations_by_event = defaultdict(set)
+    for realisation in realisation_directory.iterdir():
+        realisation_path = realisation / "realisation.json"
+        parsed_content = parse.parse(realisation_template, realisation.name)
+        # is_dir() must be *called*: the bare method reference is always truthy.
+        if not (realisation.is_dir() and realisation_path.exists() and parsed_content):
+            continue
+        assert isinstance(parsed_content, parse.Result)
+        event = parsed_content["event"]
+        realisation_number = int(parsed_content["realisation"])
+        realisations_by_event[event].add(realisation_number)
+
+    for event, existing_realisations in realisations_by_event.items():
+        base_realisation = min(existing_realisations)
+        base_realisation_path = realisation_directory / realisation_template.format(
+            event=event, realisation=base_realisation
+        )
+        for i in range(base_realisation + 1, num_realisations + 1):
+            # Handles cases like clarence_R1, clarence_R3 existing already.
+            if i in existing_realisations:
+                continue
+            realisation_path = realisation_directory / realisation_template.format(
+                event=event, realisation=i
+            )
+            shutil.copytree(base_realisation_path, realisation_path)
+            if regenerate_seeds:
+                realisation_json = realisation_path / "realisation.json"
+                seeds = Seeds.random_seeds()
+                seeds.write_to_realisation(realisation_json)
+
+
+if __name__ == "__main__":
+    app()
diff --git a/workflow/utils.py b/workflow/utils.py
index b9acb22..32ed41f 100644
--- a/workflow/utils.py
+++ b/workflow/utils.py
@@ -1,5 +1,6 @@
 """Miscellaneous workflow utilities that couldn't go anywhere else."""
 
 import hashlib
+import inspect
 import os
 import tempfile
@@ -14,6 +15,7 @@
 from shapely import Geometry, Polygon, geometry
 
 from qcore import coordinates
+from workflow import defaults
 
 NZ_COASTLINE_URL = "https://www.dropbox.com/scl/fi/zkohh794y0s2189t7b1hi/NZ.gmt?rlkey=02011f4morc4toutt9nzojrw1&st=vpz2ri8x&dl=1"
@@ -177,6 +179,25 @@ def dict_zip(*dicts: Mapping[K, Any], strict: bool = True) -> dict[K, tuple[Any,
     return result
 
 
+def merge_dictionaries(dict_a: dict[str, Any], dict_b: dict[str, Any]) -> None:
+    """Deep-merge dictionaries in place, updating the first with values from the second.
+
+    Parameters
+    ----------
+    dict_a : dict[str, Any]
+        Base dictionary to be updated. This dictionary is modified in place with
+        merged values.
+    dict_b : dict[str, Any]
+        Dictionary providing overriding values. Keys in this dictionary are
+        preferred when keys conflict. This dictionary is not modified.
+    """
+
+    for key, value in dict_b.items():
+        if key in dict_a and isinstance(dict_a[key], dict) and isinstance(value, dict):
+            merge_dictionaries(dict_a[key], dict_b[key])
+        else:
+            dict_a[key] = value
+
+
 def stable_hash(value: str, size: int = 4) -> int:
     """Compute stable hashes for strings.