Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ yapf = "^0.43.0"

[tool.poetry.scripts]
pytrajplot = "pytrajplot.main:cli"
pytrajplot-aws = "pytrajplot.aws_wrapper:cli"

[tool.yapf]
based_on_style = "pep8"
Expand Down
97 changes: 97 additions & 0 deletions pytrajplot/aws_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""AWS wrapper for pytrajplot.

Downloads input files from S3, invokes the standard pytrajplot CLI, then uploads output to S3.
Intended for use in AWS Step Functions / ECS Fargate tasks.

S3-specific options are defined here; all other pytrajplot options are passed through unchanged.
"""
import logging
import os
import tempfile

import boto3
import click

from pytrajplot.main import cli as pytrajplot_cli
from pytrajplot.main import print_version
from pytrajplot.s3_utils import download_s3_prefix, upload_dir_to_s3

log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=log_level)
logger = logging.getLogger(__name__)


@click.command(
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
epilog="Any additional options are forwarded to pytrajplot unchanged (see: pytrajplot --help).",
)
@click.option(
"--s3-input-bucket",
required=True,
envvar="S3_INPUT_BUCKET",
help="S3 bucket containing input files.",
)
@click.option(
"--model-name",
required=True,
envvar="LM_NL_C_TTAG",
help="Model name, first path segment of the S3 input prefix (e.g. 'ICON-CH1-EPS').",
)
@click.option(
"--model-base-time",
required=True,
envvar="MODEL_BASE_TIME",
help="Model base time in YYYYMMDDHHMM format (e.g. '202504030900').",
)
@click.option(
"--s3-output-bucket",
required=True,
envvar="S3_OUTPUT_BUCKET",
help="S3 bucket for output files.",
)
@click.option(
"--s3-output-prefix",
default="",
envvar="S3_OUTPUT_PREFIX",
help="S3 key prefix (folder) for output files. Defaults to empty (bucket root).",
)
@click.option(
"--version",
"-V",
help="Print version and exit.",
is_flag=True,
expose_value=False,
callback=print_version,
)
@click.argument("pytrajplot_args", nargs=-1, type=click.UNPROCESSED)
def cli(
*,
s3_input_bucket: str,
model_name: str,
model_base_time: str,
s3_output_bucket: str,
s3_output_prefix: str,
pytrajplot_args: tuple[str, ...],
) -> None:
"""Run pytrajplot with input/output backed by S3.

S3 options can be supplied as environment variables for ECS / Step Functions deployments.
All standard pytrajplot options (--language, --domain, --datatype, etc.) are passed through.
"""
s3_client = boto3.client("s3")
s3_input_prefix = f"{model_name}/{model_base_time[:8]}_{model_base_time[8:]}"
s3_output_prefix = s3_output_prefix or s3_input_prefix

with tempfile.TemporaryDirectory() as input_dir, tempfile.TemporaryDirectory() as output_dir:
logger.info("Downloading input files from s3://%s/%s", s3_input_bucket, s3_input_prefix)
download_s3_prefix(s3_client, s3_input_bucket, s3_input_prefix, input_dir)

pytrajplot_cli.main(
args=[input_dir, output_dir, *pytrajplot_args],
standalone_mode=False,
)

logger.info("Uploading output files to s3://%s/%s", s3_output_bucket, s3_output_prefix)
upload_dir_to_s3(s3_client, output_dir, s3_output_bucket, s3_output_prefix)

print("--- Done.")
2 changes: 1 addition & 1 deletion pytrajplot/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def cli(

if not plot_info_created:
if ssm_parameter_path :
logger.error("File %s/%s does not exist and plot_info could not be created from SSM parameter.", input_dir, info_name)
logger.error("File %s/%s doesn't exist and couldn't be created from SSM parameter.", input_dir, info_name)
raise click.ClickException("Missing plot_info file and failed to create from SSM parameter.")

logger.error("File %s/%s does not exist.", input_dir, info_name)
Expand Down
69 changes: 43 additions & 26 deletions pytrajplot/parsing/plot_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Standard library
from typing import Any
from typing import Dict
import datetime
import os
import logging
from pathlib import Path
Expand All @@ -24,7 +25,7 @@ class PLOT_INFO:

"""

def __init__(self, file) -> None:
def __init__(self, file: str | Path) -> None:
"""Create an instance of ``PLOT_INFO``.

Args:
Expand All @@ -40,38 +41,57 @@ def __init__(self, file) -> None:

def _parse(self) -> None:
"""Parse the plot info file."""
# read the plot_info file
with open(self.file, "r") as file:
for line in file:
elements = line.strip().split(":", maxsplit=1)
# Skip extraction of header information if line contains no ":"
if len(elements) == 1:
continue
key, data = elements[0], elements[1].lstrip()
if key == "Model base time":
self.data["mbt"] = "".join(data)
if key == "Model name":
self.data["model_name"] = "".join(data)
raw_content = file.read()

logger.info("plot_info file content (%s):\n%s", self.file, raw_content)

for line in raw_content.splitlines():
elements = line.strip().split(":", maxsplit=1)
# Skip extraction of header information if line contains no ":"
if len(elements) == 1:
continue
key, data = elements[0], elements[1].lstrip()
if key == "Model base time":
self.data["mbt"] = "".join(data)
if key == "Model name":
self.data["model_name"] = "".join(data)

logger.info("plot_info parsed dict: %s", self.data)


def _format_model_base_time(raw: str) -> str:
"""Convert MODEL_BASE_TIME from YYYYMMDDHHMM to YYYY-MM-DD HH:MM UTC.

Args:
raw: Model base time string in YYYYMMDDHHMM format (e.g. '202504030900')

Returns:
Formatted string suitable for plot_info (e.g. '2025-04-03 09:00 UTC')
"""
return datetime.datetime.strptime(raw, "%Y%m%d%H%M").strftime("%Y-%m-%d %H:%M UTC")


def replace_variables(template_content: str) -> str:
"""
Replace $VAR with actual environment variable values.
MODEL_BASE_TIME is converted from YYYYMMDDHHMM to YYYY-MM-DD HH:MM UTC before substitution.
Args:
template_content: Template string with $VARIABLE placeholders
Returns:
String with variables replaced by environment values
"""
result = template_content
# Get all environment variables as dict
env_vars = dict(os.environ)

# Replace variables found in the template
if "MODEL_BASE_TIME" in env_vars:
env_vars["MODEL_BASE_TIME"] = _format_model_base_time(env_vars["MODEL_BASE_TIME"])

for env_key, env_value in env_vars.items():
placeholder = f'${env_key}'
if placeholder in result:
result = result.replace(placeholder, env_value)
logger.info(f"Replaced {placeholder} with {env_value}")
logger.info("Replaced %s with %s", placeholder, env_value)
return result


Expand All @@ -91,20 +111,17 @@ def check_plot_info_file(input_dir: str, info_name: str, ssm_parameter_path: str

# If file exists, use it regardless of SSM config
if plot_info_file.exists():
logger.info(f"Plot info file already exists: {plot_info_file}")
logger.info("Plot info file already exists: %s", plot_info_file)
return True

# File doesn't exist, try to create it from SSM parameter
ssm_param_path = ssm_parameter_path or os.environ.get('SSM_PARAMETER_PATH')
if not ssm_param_path:
logger.error(f"Plot info file not found and no ssm parameter set: {plot_info_file}")
logger.error("Plot info file not found and no ssm parameter set: %s", plot_info_file)
return False

try:
# Get SSM parameter path from argument or environment
#ssm_param_path = ssm_parameter_path or os.environ.get('SSM_PARAMETER_PATH')
logger.info(f"Fetching SSM parameter: {ssm_param_path}")

logger.info("Fetching SSM parameter: %s", ssm_param_path)

# Fetch template from SSM Parameter
ssm_client = boto3.client('ssm')
response = ssm_client.get_parameter(
Expand All @@ -114,7 +131,7 @@ def check_plot_info_file(input_dir: str, info_name: str, ssm_parameter_path: str

# Get the template content
template_content = response['Parameter']['Value']
logger.info(f"Template content length: {len(template_content)} chars")
logger.info("Template content length: %s chars", len(template_content))

# Replace variables with environment variable values
substituted_content = replace_variables(template_content)
Expand All @@ -123,10 +140,10 @@ def check_plot_info_file(input_dir: str, info_name: str, ssm_parameter_path: str
with open(plot_info_file, 'w') as f:
f.write(substituted_content)

logger.info(f"Successfully created plot info file: {plot_info_file}")
logger.info("Successfully created plot info file: %s", plot_info_file)
return True

except Exception as e:
logger.error(f"Failed to create plot info file from SSM parameter: {str(e)}")
logger.error(f"SSM parameter path: {ssm_parameter_path or os.environ.get('SSM_PARAMETER_PATH', 'not_set')}")
logger.error("Failed to create plot info file from SSM parameter: %s", e)
logger.error("SSM parameter path: %s", ssm_parameter_path or os.environ.get('SSM_PARAMETER_PATH', 'not_set'))
return False
72 changes: 72 additions & 0 deletions pytrajplot/s3_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""S3 utility functions for pytrajplot AWS integration."""
import logging
import os
from typing import Any

from botocore.exceptions import ClientError

log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=log_level)
logger = logging.getLogger(__name__)


def download_s3_prefix(s3_client: Any, bucket: str, prefix: str, local_dir: str) -> None:
"""Download all objects under an S3 prefix into a local directory, preserving relative paths.

Raises:
RuntimeError: If no files are found under the given S3 prefix.
RuntimeError: If the bucket does not exist or access is denied.
"""
try:
paginator = s3_client.get_paginator("list_objects_v2")
downloaded = 0
for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
for obj in page.get("Contents", []):
key: str = obj["Key"]
relative_path = key[len(prefix):].lstrip("/")
if not relative_path:
continue
local_path = os.path.join(local_dir, relative_path)
os.makedirs(os.path.dirname(local_path), exist_ok=True)
logger.info("Downloading s3://%s/%s -> %s", bucket, key, local_path)
s3_client.download_file(bucket, key, local_path)
downloaded += 1
except ClientError as e:
code = e.response["Error"]["Code"]
if code == "NoSuchBucket":
raise RuntimeError(f"S3 bucket does not exist: {bucket}") from e
if code == "AccessDenied":
raise RuntimeError(f"Access denied reading from s3://{bucket}/{prefix}") from e
raise

if downloaded == 0:
raise RuntimeError(f"No files found at s3://{bucket}/{prefix} — bucket prefix is empty or does not exist.")


def upload_dir_to_s3(s3_client: Any, local_dir: str, bucket: str, prefix: str) -> None:
"""Upload all files in a local directory to an S3 prefix.

Raises:
RuntimeError: If the output directory is empty (pytrajplot produced no output).
RuntimeError: If the bucket does not exist or access is denied.
"""
uploaded = 0
try:
for root, _, files in os.walk(local_dir):
for filename in files:
local_path = os.path.join(root, filename)
relative_path = os.path.relpath(local_path, local_dir)
s3_key = f"{prefix.rstrip('/')}/{relative_path}" if prefix else relative_path
logger.info("Uploading %s -> s3://%s/%s", local_path, bucket, s3_key)
s3_client.upload_file(local_path, bucket, s3_key)
uploaded += 1
except ClientError as e:
code = e.response["Error"]["Code"]
if code == "NoSuchBucket":
raise RuntimeError(f"S3 bucket does not exist: {bucket}") from e
if code == "AccessDenied":
raise RuntimeError(f"Access denied writing to s3://{bucket}/{prefix}") from e
raise

if uploaded == 0:
raise RuntimeError(f"No output files were produced — nothing uploaded to s3://{bucket}/{prefix}.")
Loading