diff --git a/pyproject.toml b/pyproject.toml index 9667022..1dd9d3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ yapf = "^0.43.0" [tool.poetry.scripts] pytrajplot = "pytrajplot.main:cli" +pytrajplot-aws = "pytrajplot.aws_wrapper:cli" [tool.yapf] based_on_style = "pep8" diff --git a/pytrajplot/aws_wrapper.py b/pytrajplot/aws_wrapper.py new file mode 100644 index 0000000..8e86eca --- /dev/null +++ b/pytrajplot/aws_wrapper.py @@ -0,0 +1,97 @@ +"""AWS wrapper for pytrajplot. + +Downloads input files from S3, invokes the standard pytrajplot CLI, then uploads output to S3. +Intended for use in AWS Step Functions / ECS Fargate tasks. + +S3-specific options are defined here; all other pytrajplot options are passed through unchanged. +""" +import logging +import os +import tempfile + +import boto3 +import click + +from pytrajplot.main import cli as pytrajplot_cli +from pytrajplot.main import print_version +from pytrajplot.s3_utils import download_s3_prefix, upload_dir_to_s3 + +log_level = os.getenv("LOG_LEVEL", "INFO").upper() +logging.basicConfig(level=log_level) +logger = logging.getLogger(__name__) + + +@click.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, + epilog="Any additional options are forwarded to pytrajplot unchanged (see: pytrajplot --help).", +) +@click.option( + "--s3-input-bucket", + required=True, + envvar="S3_INPUT_BUCKET", + help="S3 bucket containing input files.", +) +@click.option( + "--model-name", + required=True, + envvar="LM_NL_C_TTAG", + help="Model name, first path segment of the S3 input prefix (e.g. 'ICON-CH1-EPS').", +) +@click.option( + "--model-base-time", + required=True, + envvar="MODEL_BASE_TIME", + help="Model base time in YYYYMMDDHHMM format (e.g. '202504030900').", +) +@click.option( + "--s3-output-bucket", + required=True, + envvar="S3_OUTPUT_BUCKET", + help="S3 bucket for output files.", +) +@click.option( + "--s3-output-prefix", + default="", + envvar="S3_OUTPUT_PREFIX", + help="S3 key prefix (folder) for output files. Defaults to empty (bucket root).", +) +@click.option( + "--version", + "-V", + help="Print version and exit.", + is_flag=True, + expose_value=False, + callback=print_version, +) +@click.argument("pytrajplot_args", nargs=-1, type=click.UNPROCESSED) +def cli( + *, + s3_input_bucket: str, + model_name: str, + model_base_time: str, + s3_output_bucket: str, + s3_output_prefix: str, + pytrajplot_args: tuple[str, ...], +) -> None: + """Run pytrajplot with input/output backed by S3. + + S3 options can be supplied as environment variables for ECS / Step Functions deployments. + All standard pytrajplot options (--language, --domain, --datatype, etc.) are passed through. + """ + s3_client = boto3.client("s3") + s3_input_prefix = f"{model_name}/{model_base_time[:8]}_{model_base_time[8:]}" + s3_output_prefix = s3_output_prefix or s3_input_prefix + + with tempfile.TemporaryDirectory() as input_dir, tempfile.TemporaryDirectory() as output_dir: + logger.info("Downloading input files from s3://%s/%s", s3_input_bucket, s3_input_prefix) + download_s3_prefix(s3_client, s3_input_bucket, s3_input_prefix, input_dir) + + pytrajplot_cli.main( + args=[input_dir, output_dir, *pytrajplot_args], + standalone_mode=False, + ) + + logger.info("Uploading output files to s3://%s/%s", s3_output_bucket, s3_output_prefix) + upload_dir_to_s3(s3_client, output_dir, s3_output_bucket, s3_output_prefix) + + print("--- Done.") diff --git a/pytrajplot/main.py b/pytrajplot/main.py index 63efe99..75dc44a 100644 --- a/pytrajplot/main.py +++ b/pytrajplot/main.py @@ -167,7 +167,7 @@ def cli( if not plot_info_created: if ssm_parameter_path : - logger.error("File %s/%s does not exist and plot_info could not be created from SSM parameter.", input_dir, info_name) + logger.error("File %s/%s doesn't exist and couldn't be created from SSM parameter.", input_dir, info_name) raise click.ClickException("Missing plot_info file and failed to create from SSM parameter.") logger.error("File %s/%s does not exist.", input_dir, info_name) diff --git a/pytrajplot/parsing/plot_info.py b/pytrajplot/parsing/plot_info.py index 1d1f263..505c5b4 100644 --- a/pytrajplot/parsing/plot_info.py +++ b/pytrajplot/parsing/plot_info.py @@ -3,6 +3,7 @@ # Standard library from typing import Any from typing import Dict +import datetime import os import logging from pathlib import Path @@ -24,7 +25,7 @@ class PLOT_INFO: """ - def __init__(self, file) -> None: + def __init__(self, file: str | Path) -> None: """Create an instance of ``PLOT_INFO``. Args: @@ -40,38 +41,57 @@ def __init__(self, file) -> None: def _parse(self) -> None: """Parse the plot info file.""" - # read the plot_info file with open(self.file, "r") as file: - for line in file: - elements = line.strip().split(":", maxsplit=1) - # Skip extraction of header information if line contains no ":" - if len(elements) == 1: - continue - key, data = elements[0], elements[1].lstrip() - if key == "Model base time": - self.data["mbt"] = "".join(data) - if key == "Model name": - self.data["model_name"] = "".join(data) + raw_content = file.read() + + logger.info("plot_info file content (%s):\n%s", self.file, raw_content) + + for line in raw_content.splitlines(): + elements = line.strip().split(":", maxsplit=1) + # Skip extraction of header information if line contains no ":" + if len(elements) == 1: + continue + key, data = elements[0], elements[1].lstrip() + if key == "Model base time": + self.data["mbt"] = "".join(data) + if key == "Model name": + self.data["model_name"] = "".join(data) + + logger.info("plot_info parsed dict: %s", self.data) + + +def _format_model_base_time(raw: str) -> str: + """Convert MODEL_BASE_TIME from YYYYMMDDHHMM to YYYY-MM-DD HH:MM UTC. + + Args: + raw: Model base time string in YYYYMMDDHHMM format (e.g. '202504030900') + + Returns: + Formatted string suitable for plot_info (e.g. '2025-04-03 09:00 UTC') + """ + return datetime.datetime.strptime(raw, "%Y%m%d%H%M").strftime("%Y-%m-%d %H:%M UTC") def replace_variables(template_content: str) -> str: """ Replace $VAR with actual environment variable values. + MODEL_BASE_TIME is converted from YYYYMMDDHHMM to YYYY-MM-DD HH:MM UTC before substitution. Args: template_content: Template string with $VARIABLE placeholders Returns: String with variables replaced by environment values """ result = template_content - # Get all environment variables as dict env_vars = dict(os.environ) - # Replace variables found in the template + if "MODEL_BASE_TIME" in env_vars: + env_vars["MODEL_BASE_TIME"] = _format_model_base_time(env_vars["MODEL_BASE_TIME"]) + for env_key, env_value in env_vars.items(): placeholder = f'${env_key}' if placeholder in result: result = result.replace(placeholder, env_value) - logger.info(f"Replaced {placeholder} with {env_value}") + logger.info("Replaced %s with %s", placeholder, env_value) return result @@ -91,20 +111,17 @@ def check_plot_info_file(input_dir: str, info_name: str, ssm_parameter_path: str # If file exists, use it regardless of SSM config if plot_info_file.exists(): - logger.info(f"Plot info file already exists: {plot_info_file}") + logger.info("Plot info file already exists: %s", plot_info_file) return True - # File doesn't exist, try to create it from SSM parameter ssm_param_path = ssm_parameter_path or os.environ.get('SSM_PARAMETER_PATH') if not ssm_param_path: - logger.error(f"Plot info file not found and no ssm parameter set: {plot_info_file}") + logger.error("Plot info file not found and no ssm parameter set: %s", plot_info_file) return False try: - # Get SSM parameter path from argument or environment - #ssm_param_path = ssm_parameter_path or os.environ.get('SSM_PARAMETER_PATH') - logger.info(f"Fetching SSM parameter: {ssm_param_path}") - + logger.info("Fetching SSM parameter: %s", ssm_param_path) + # Fetch template from SSM Parameter ssm_client = boto3.client('ssm') response = ssm_client.get_parameter( @@ -114,7 +131,7 @@ def check_plot_info_file(input_dir: str, info_name: str, ssm_parameter_path: str # Get the template content template_content = response['Parameter']['Value'] - logger.info(f"Template content length: {len(template_content)} chars") + logger.info("Template content length: %s chars", len(template_content)) # Replace variables with environment variable values substituted_content = replace_variables(template_content) @@ -123,10 +140,10 @@ def check_plot_info_file(input_dir: str, info_name: str, ssm_parameter_path: str with open(plot_info_file, 'w') as f: f.write(substituted_content) - logger.info(f"Successfully created plot info file: {plot_info_file}") + logger.info("Successfully created plot info file: %s", plot_info_file) return True except Exception as e: - logger.error(f"Failed to create plot info file from SSM parameter: {str(e)}") - logger.error(f"SSM parameter path: {ssm_parameter_path or os.environ.get('SSM_PARAMETER_PATH', 'not_set')}") + logger.error("Failed to create plot info file from SSM parameter: %s", e) + logger.error("SSM parameter path: %s", ssm_parameter_path or os.environ.get('SSM_PARAMETER_PATH', 'not_set')) return False diff --git a/pytrajplot/s3_utils.py b/pytrajplot/s3_utils.py new file mode 100644 index 0000000..54394d9 --- /dev/null +++ b/pytrajplot/s3_utils.py @@ -0,0 +1,72 @@ +"""S3 utility functions for pytrajplot AWS integration.""" +import logging +import os +from typing import Any + +from botocore.exceptions import ClientError + +log_level = os.getenv("LOG_LEVEL", "INFO").upper() +logging.basicConfig(level=log_level) +logger = logging.getLogger(__name__) + + +def download_s3_prefix(s3_client: Any, bucket: str, prefix: str, local_dir: str) -> None: + """Download all objects under an S3 prefix into a local directory, preserving relative paths. + + Raises: + RuntimeError: If no files are found under the given S3 prefix. + RuntimeError: If the bucket does not exist or access is denied. + """ + try: + paginator = s3_client.get_paginator("list_objects_v2") + downloaded = 0 + for page in paginator.paginate(Bucket=bucket, Prefix=prefix): + for obj in page.get("Contents", []): + key: str = obj["Key"] + relative_path = key[len(prefix):].lstrip("/") + if not relative_path: + continue + local_path = os.path.join(local_dir, relative_path) + os.makedirs(os.path.dirname(local_path), exist_ok=True) + logger.info("Downloading s3://%s/%s -> %s", bucket, key, local_path) + s3_client.download_file(bucket, key, local_path) + downloaded += 1 + except ClientError as e: + code = e.response["Error"]["Code"] + if code == "NoSuchBucket": + raise RuntimeError(f"S3 bucket does not exist: {bucket}") from e + if code == "AccessDenied": + raise RuntimeError(f"Access denied reading from s3://{bucket}/{prefix}") from e + raise + + if downloaded == 0: + raise RuntimeError(f"No files found at s3://{bucket}/{prefix} — bucket prefix is empty or does not exist.") + + +def upload_dir_to_s3(s3_client: Any, local_dir: str, bucket: str, prefix: str) -> None: + """Upload all files in a local directory to an S3 prefix. + + Raises: + RuntimeError: If the output directory is empty (pytrajplot produced no output). + RuntimeError: If the bucket does not exist or access is denied. + """ + uploaded = 0 + try: + for root, _, files in os.walk(local_dir): + for filename in files: + local_path = os.path.join(root, filename) + relative_path = os.path.relpath(local_path, local_dir) + s3_key = f"{prefix.rstrip('/')}/{relative_path}" if prefix else relative_path + logger.info("Uploading %s -> s3://%s/%s", local_path, bucket, s3_key) + s3_client.upload_file(local_path, bucket, s3_key) + uploaded += 1 + except ClientError as e: + code = e.response["Error"]["Code"] + if code == "NoSuchBucket": + raise RuntimeError(f"S3 bucket does not exist: {bucket}") from e + if code == "AccessDenied": + raise RuntimeError(f"Access denied writing to s3://{bucket}/{prefix}") from e + raise + + if uploaded == 0: + raise RuntimeError(f"No output files were produced — nothing uploaded to s3://{bucket}/{prefix}.") diff --git a/test/unit/test_aws_wrapper.py b/test/unit/test_aws_wrapper.py new file mode 100644 index 0000000..aef9727 --- /dev/null +++ b/test/unit/test_aws_wrapper.py @@ -0,0 +1,112 @@ +"""Unit tests for pytrajplot.aws_wrapper.""" +from unittest.mock import MagicMock, patch +from click.testing import CliRunner + +from pytrajplot.aws_wrapper import cli +from pytrajplot.main import __version__ + + +REQUIRED_ARGS = [ + "--s3-input-bucket", "input-bucket", + "--model-name", "ICON-CH1-EPS", + "--model-base-time", "202504030900", + "--s3-output-bucket", "output-bucket", +] + + +class TestAwsWrapperCli: + """Tests for the aws_wrapper CLI command.""" + + def call(self, args=None): + runner = CliRunner() + return runner.invoke(cli, args) + + def test_help(self): + result = self.call(["--help"]) + assert result.exit_code == 0 + assert "S3" in result.output + + def test_version(self): + result = self.call(["-V"]) + assert result.exit_code == 0 + assert __version__ in result.output + + def test_missing_required_options_fails(self): + result = self.call([]) + assert result.exit_code != 0 + + @patch("pytrajplot.aws_wrapper.upload_dir_to_s3") + @patch("pytrajplot.aws_wrapper.download_s3_prefix") + @patch("pytrajplot.aws_wrapper.pytrajplot_cli") + @patch("pytrajplot.aws_wrapper.boto3.client") + def test_success_flow(self, mock_boto3, mock_pytrajplot, mock_download, mock_upload): + mock_boto3.return_value = MagicMock() + + result = self.call(REQUIRED_ARGS) + + assert result.exit_code == 0 + mock_download.assert_called_once() + mock_pytrajplot.main.assert_called_once() + mock_upload.assert_called_once() + + @patch("pytrajplot.aws_wrapper.upload_dir_to_s3") + @patch("pytrajplot.aws_wrapper.download_s3_prefix") + @patch("pytrajplot.aws_wrapper.pytrajplot_cli") + @patch("pytrajplot.aws_wrapper.boto3.client") + def test_s3_input_prefix_format(self, mock_boto3, mock_pytrajplot, mock_download, mock_upload): + """Input prefix is built as model_name/YYYYMMDD_HHMM.""" + mock_boto3.return_value = MagicMock() + + self.call(REQUIRED_ARGS) + + _, _, prefix, _ = mock_download.call_args.args + assert prefix == "ICON-CH1-EPS/20250403_0900" + + @patch("pytrajplot.aws_wrapper.upload_dir_to_s3") + @patch("pytrajplot.aws_wrapper.download_s3_prefix") + @patch("pytrajplot.aws_wrapper.pytrajplot_cli") + @patch("pytrajplot.aws_wrapper.boto3.client") + def test_passthrough_args_forwarded_to_pytrajplot( + self, mock_boto3, mock_pytrajplot, mock_download, mock_upload + ): + """Extra args after required options are forwarded to pytrajplot.""" + mock_boto3.return_value = MagicMock() + + self.call(REQUIRED_ARGS + ["--language", "de", "--domain", "ch"]) + + call_args = mock_pytrajplot.main.call_args + forwarded = call_args.kwargs.get("args") or call_args.args[0] + assert "--language" in forwarded + assert "de" in forwarded + assert "--domain" in forwarded + assert "ch" in forwarded + + @patch("pytrajplot.aws_wrapper.upload_dir_to_s3") + @patch("pytrajplot.aws_wrapper.download_s3_prefix") + @patch("pytrajplot.aws_wrapper.pytrajplot_cli") + @patch("pytrajplot.aws_wrapper.boto3.client") + def test_output_prefix_passed_to_upload( + self, mock_boto3, mock_pytrajplot, mock_download, mock_upload + ): + mock_boto3.return_value = MagicMock() + + self.call(REQUIRED_ARGS + ["--s3-output-prefix", "results/2025/"]) + + _, _, bucket, prefix = mock_upload.call_args.args + assert bucket == "output-bucket" + assert prefix == "results/2025/" + + @patch("pytrajplot.aws_wrapper.upload_dir_to_s3") + @patch("pytrajplot.aws_wrapper.download_s3_prefix") + @patch("pytrajplot.aws_wrapper.pytrajplot_cli") + @patch("pytrajplot.aws_wrapper.boto3.client") + def test_output_prefix_defaults_to_input_prefix( + self, mock_boto3, mock_pytrajplot, mock_download, mock_upload + ): + """When --s3-output-prefix is omitted, output prefix matches the input prefix.""" + mock_boto3.return_value = MagicMock() + + self.call(REQUIRED_ARGS) + + _, _, _, prefix = mock_upload.call_args.args + assert prefix == "ICON-CH1-EPS/20250403_0900" diff --git a/test/unit/test_plot_info.py b/test/unit/test_plot_info.py index 888c32b..7e34e8c 100644 --- a/test/unit/test_plot_info.py +++ b/test/unit/test_plot_info.py @@ -7,10 +7,24 @@ from click.testing import CliRunner from botocore.exceptions import ClientError -from pytrajplot.parsing.plot_info import replace_variables, check_plot_info_file +from pytrajplot.parsing.plot_info import replace_variables, check_plot_info_file, _format_model_base_time from pytrajplot.main import cli +class TestFormatModelBaseTime: + """Test the _format_model_base_time helper.""" + + def test_formats_correctly(self): + assert _format_model_base_time("202504030900") == "2025-04-03 09:00 UTC" + + def test_midnight(self): + assert _format_model_base_time("202501010000") == "2025-01-01 00:00 UTC" + + def test_invalid_format_raises(self): + with pytest.raises(ValueError): + _format_model_base_time("20250403_0900") + + class TestReplaceVariables: """Test the replace_variables function.""" @@ -33,6 +47,22 @@ def test_replace_multiple_variables(self, monkeypatch): assert result == "Start value1 middle value2 end value3" + def test_model_base_time_is_formatted(self, monkeypatch): + """MODEL_BASE_TIME must be converted from YYYYMMDDHHMM to YYYY-MM-DD HH:MM UTC.""" + monkeypatch.setenv("MODEL_BASE_TIME", "202504030900") + + result = replace_variables("Model base time: $MODEL_BASE_TIME") + + assert result == "Model base time: 2025-04-03 09:00 UTC" + + def test_model_base_time_raw_value_unchanged_in_env(self, monkeypatch): + """replace_variables must not mutate os.environ.""" + monkeypatch.setenv("MODEL_BASE_TIME", "202504030900") + + replace_variables("$MODEL_BASE_TIME") + + assert os.environ["MODEL_BASE_TIME"] == "202504030900" + class TestCheckPlotInfoFile: """Test the check_plot_info_file function.""" @@ -84,8 +114,8 @@ def test_create_plot_info_from_ssm_success(self, mock_boto3_client, tmp_path, mo @patch('boto3.client') def test_create_plot_info_template(self, mock_boto3_client, tmp_path, monkeypatch): - """Test plot_info template and vars substitutions.""" - monkeypatch.setenv("LAGRANTO_MODEL_BASE_TIME", "20240115_00") + """Test plot_info template and vars substitutions including MODEL_BASE_TIME formatting.""" + monkeypatch.setenv("MODEL_BASE_TIME", "202401150000") monkeypatch.setenv("LM_NL_C_TTAG", "ICON-CH2-EPS") monkeypatch.setenv("LM_NL_POLLONLM_C", "-170.0") monkeypatch.setenv("LM_NL_POLLATLM_C", "43.0") @@ -98,7 +128,7 @@ def test_create_plot_info_template(self, mock_boto3_client, tmp_path, monkeypatc # Lagranto configuration template template_content = ''' - Model base time: $LAGRANTO_MODEL_BASE_TIME + Model base time: $MODEL_BASE_TIME Model name: $LM_NL_C_TTAG North pole longitude: $LM_NL_POLLONLM_C North pole latitude: $LM_NL_POLLATLM_C @@ -111,7 +141,7 @@ def test_create_plot_info_template(self, mock_boto3_client, tmp_path, monkeypatc ''' expected_content = ''' - Model base time: 20240115_00 + Model base time: 2024-01-15 00:00 UTC Model name: ICON-CH2-EPS North pole longitude: -170.0 North pole latitude: 43.0 @@ -180,7 +210,7 @@ def test_cli_with_existing_plot_info(self, tmp_path): # Mock the generate_pdf function with patch('pytrajplot.main.check_input_dir') as mock_check, \ - patch('pytrajplot.main.generate_pdf') as mock_generate: + patch('pytrajplot.main.generate_pdf'): mock_check.return_value = ({}, {}) diff --git a/test/unit/test_s3_utils.py b/test/unit/test_s3_utils.py new file mode 100644 index 0000000..b11cf73 --- /dev/null +++ b/test/unit/test_s3_utils.py @@ -0,0 +1,111 @@ +"""Unit tests for pytrajplot.s3_utils.""" +import pytest +from unittest.mock import MagicMock +from botocore.exceptions import ClientError + +from pytrajplot.s3_utils import download_s3_prefix, upload_dir_to_s3 + + +def _client_error(code: str) -> ClientError: + return ClientError({"Error": {"Code": code, "Message": code}}, "operation") + + +class TestDownloadS3Prefix: + """Tests for download_s3_prefix.""" + + def _make_paginator(self, pages: list[list[dict]]) -> MagicMock: + paginator = MagicMock() + paginator.paginate.return_value = [ + {"Contents": page} for page in pages + ] + return paginator + + def test_downloads_files_to_local_dir(self, tmp_path): + s3 = MagicMock() + s3.get_paginator.return_value = self._make_paginator( + [[{"Key": "prefix/subdir/file.txt"}, {"Key": "prefix/other.txt"}]] + ) + + download_s3_prefix(s3, "my-bucket", "prefix/", str(tmp_path)) + + assert s3.download_file.call_count == 2 + s3.download_file.assert_any_call( + "my-bucket", "prefix/subdir/file.txt", str(tmp_path / "subdir" / "file.txt") + ) + s3.download_file.assert_any_call( + "my-bucket", "prefix/other.txt", str(tmp_path / "other.txt") + ) + + def test_skips_prefix_itself(self, tmp_path): + """Objects whose key equals the prefix (no relative path) are skipped.""" + s3 = MagicMock() + s3.get_paginator.return_value = self._make_paginator( + [[{"Key": "prefix/"}, {"Key": "prefix/file.txt"}]] + ) + + download_s3_prefix(s3, "my-bucket", "prefix/", str(tmp_path)) + + assert s3.download_file.call_count == 1 + + def test_empty_prefix_raises_runtime_error(self, tmp_path): + s3 = MagicMock() + s3.get_paginator.return_value = self._make_paginator([[]]) + + with pytest.raises(RuntimeError, match="No files found"): + download_s3_prefix(s3, "my-bucket", "empty/", str(tmp_path)) + + def test_no_such_bucket_raises_runtime_error(self, tmp_path): + s3 = MagicMock() + paginator = MagicMock() + paginator.paginate.side_effect = _client_error("NoSuchBucket") + s3.get_paginator.return_value = paginator + + with pytest.raises(RuntimeError, match="does not exist"): + download_s3_prefix(s3, "missing-bucket", "prefix/", str(tmp_path)) + + def test_access_denied_raises_runtime_error(self, tmp_path): + s3 = MagicMock() + paginator = MagicMock() + paginator.paginate.side_effect = _client_error("AccessDenied") + s3.get_paginator.return_value = paginator + + with pytest.raises(RuntimeError, match="Access denied"): + download_s3_prefix(s3, "my-bucket", "prefix/", str(tmp_path)) + +class TestUploadDirToS3: + """Tests for upload_dir_to_s3.""" + + def test_uploads_all_files(self, tmp_path): + (tmp_path / "file1.pdf").write_text("a") + (tmp_path / "sub").mkdir() + (tmp_path / "sub" / "file2.png").write_text("b") + + s3 = MagicMock() + upload_dir_to_s3(s3, str(tmp_path), "my-bucket", "output/") + + assert s3.upload_file.call_count == 2 + uploaded_keys = {c.args[2] for c in s3.upload_file.call_args_list} + assert "output/file1.pdf" in uploaded_keys + assert "output/sub/file2.png" in uploaded_keys + + def test_empty_directory_raises_runtime_error(self, tmp_path): + s3 = MagicMock() + + with pytest.raises(RuntimeError, match="No output files"): + upload_dir_to_s3(s3, str(tmp_path), "my-bucket", "output/") + + def test_no_such_bucket_raises_runtime_error(self, tmp_path): + (tmp_path / "file.txt").write_text("x") + s3 = MagicMock() + s3.upload_file.side_effect = _client_error("NoSuchBucket") + + with pytest.raises(RuntimeError, match="does not exist"): + upload_dir_to_s3(s3, str(tmp_path), "missing-bucket", "prefix/") + + def test_access_denied_raises_runtime_error(self, tmp_path): + (tmp_path / "file.txt").write_text("x") + s3 = MagicMock() + s3.upload_file.side_effect = _client_error("AccessDenied") + + with pytest.raises(RuntimeError, match="Access denied"): + upload_dir_to_s3(s3, str(tmp_path), "my-bucket", "prefix/")