diff --git a/Makefile b/Makefile
index 59731544..dc6d05b8 100644
--- a/Makefile
+++ b/Makefile
@@ -45,6 +45,7 @@ help:
 	@echo ""
 	@echo "📊 Example Datasets:"
 	@echo "  make run-titanic                       Run on Titanic dataset (medium)"
+	@echo "  make run-titanic-explicit-test-split   Run Titanic with explicit train+test inputs"
 	@echo "  make run-titanic-proba                 Run Titanic with probability-focused intent"
 	@echo "  make run-house-prices                  Run on House Prices dataset (regression)"
 	@echo ""
@@ -332,6 +333,33 @@ run-titanic: build
 		--spark-mode local \
 		--enable-final-evaluation
+
+# Spaceship Titanic dataset with explicit test split input
+.PHONY: run-titanic-explicit-test-split
+run-titanic-explicit-test-split: build
+	@echo "📊 Running on Spaceship Titanic dataset (explicit train + test splits)..."
+	$(eval TIMESTAMP := $(shell date +%Y%m%d_%H%M%S))
+	docker run --rm \
+		--add-host=host.docker.internal:host-gateway \
+		$(CONFIG_MOUNT) \
+		$(CONFIG_ENV) \
+		-v $(PWD)/examples/datasets:/data:ro \
+		-v $(PWD)/workdir:/workdir \
+		-e OPENAI_API_KEY=$(OPENAI_API_KEY) \
+		-e ANTHROPIC_API_KEY=$(ANTHROPIC_API_KEY) \
+		-e SPARK_LOCAL_CORES=4 \
+		-e SPARK_DRIVER_MEMORY=4g \
+		plexe:py$(PYTHON_VERSION) \
+		python -m plexe.main \
+		--train-dataset-uri /data/spaceship-titanic/train.parquet \
+		--test-dataset-uri /data/spaceship-titanic/test.csv \
+		--user-id dev_user \
+		--intent "predict whether a passenger was transported" \
+		--experiment-id titanic_explicit_test \
+		--max-iterations 10 \
+		--work-dir /workdir/titanic_explicit_test/$(TIMESTAMP) \
+		--spark-mode local \
+		--enable-final-evaluation
 
 # Spaceship Titanic dataset with probability-focused objective
 .PHONY: run-titanic-proba
 run-titanic-proba: build
diff --git a/plexe/CODE_INDEX.md b/plexe/CODE_INDEX.md
index f3ee5380..51b1da24 100644
--- a/plexe/CODE_INDEX.md
+++ b/plexe/CODE_INDEX.md
@@ -1,6 +1,6 @@
 # Code Index: plexe
 
-> Generated on 2026-03-03 05:08:33
+> Generated on 2026-03-05 21:32:55
 
 Code structure and public interface documentation for the **plexe** package.
 
@@ -17,7 +17,7 @@ Dataset Splitter Agent.
 
 **`DatasetSplitterAgent`** - Agent that generates PySpark code for intelligent dataset splitting.
 - `__init__(self, spark: SparkSession, dataset_uri: str, context: BuildContext, config: Config)`
-- `run(self, split_ratios: dict[str, float], output_dir: str | Path) -> tuple[str, str, str]` - Generate and execute intelligent dataset splitting.
+- `run(self, split_ratios: dict[str, float], output_dir: str | Path) -> tuple[str, str, str | None]` - Generate and execute intelligent dataset splitting.
 
 ---
 ## `agents/feature_processor.py`
@@ -306,7 +306,7 @@ Amazon S3 storage helper.
 Universal entry point for plexe.
 
 **Functions:**
-- `main(intent: str, data_refs: list[str], integration: WorkflowIntegration | None, spark_mode: str, user_id: str, experiment_id: str, max_iterations: int, global_seed: int | None, work_dir: Path, test_dataset_uri: str | None, enable_final_evaluation: bool, max_epochs: int | None, allowed_model_types: list[str] | None, is_retrain: bool, original_model_uri: str | None, original_experiment_id: str | None, auto_mode: bool, user_feedback: dict | None, enable_otel: bool, otel_endpoint: str | None, otel_headers: dict[str, str] | None, external_storage_uri: str | None, csv_delimiter: str, csv_header: bool)` - Main model building function.
+- `main(intent: str, data_refs: list[str] | None, integration: WorkflowIntegration | None, spark_mode: str, user_id: str, experiment_id: str, max_iterations: int, global_seed: int | None, work_dir: Path, train_dataset_uri: str | None, val_dataset_uri: str | None, test_dataset_uri: str | None, enable_final_evaluation: bool, nn_default_epochs: int | None, nn_max_epochs: int | None, allowed_model_types: list[str] | None, is_retrain: bool, original_model_uri: str | None, original_experiment_id: str | None, auto_mode: bool, user_feedback: dict | None, enable_otel: bool, otel_endpoint: str | None, otel_headers: dict[str, str] | None, external_storage_uri: str | None, csv_delimiter: str, csv_header: bool)` - Main model building function. --- ## `models.py` @@ -728,10 +728,10 @@ Streamlit dashboard for plexe. Main workflow orchestrator. **Functions:** -- `build_model(spark: SparkSession, train_dataset_uri: str, test_dataset_uri: str | None, user_id: str, intent: str, experiment_id: str, work_dir: Path, runner: TrainingRunner, search_policy: SearchPolicy, config: Config, integration: WorkflowIntegration, enable_final_evaluation: bool, on_checkpoint_saved: Callable[[str, Path, Path], None] | None, pause_points: list[str] | None, on_pause: Callable[[str], None] | None, user_feedback: dict | None) -> tuple[Solution, dict, EvaluationReport | None] | None` - Main workflow orchestrator. +- `build_model(spark: SparkSession, train_dataset_uri: str, val_dataset_uri: str | None, test_dataset_uri: str | None, user_id: str, intent: str, experiment_id: str, work_dir: Path, runner: TrainingRunner, search_policy: SearchPolicy, config: Config, integration: WorkflowIntegration, enable_final_evaluation: bool, on_checkpoint_saved: Callable[[str, Path, Path], None] | None, pause_points: list[str] | None, on_pause: Callable[[str], None] | None, user_feedback: dict | None) -> tuple[Solution, dict, EvaluationReport | None] | None` - Main workflow orchestrator. - `sanitize_dataset_column_names(spark: SparkSession, dataset_uri: str, context: BuildContext) -> str` - Sanitize column names by replacing special characters with underscores. - `analyze_data(spark: SparkSession, dataset_uri: str, context: BuildContext, config: Config, on_checkpoint_saved: Callable[[str, Path, Path], None] | None)` - Phase 1: Layout detection + Statistical + ML task analysis + metric selection. -- `prepare_data(spark: SparkSession, training_dataset_uri: str, test_dataset_uri: str | None, context: BuildContext, config: Config, integration: WorkflowIntegration, generate_test_set: bool, on_checkpoint_saved: Callable[[str, Path, Path], None] | None)` - Phase 2: Split dataset and extract sample. +- `prepare_data(spark: SparkSession, training_dataset_uri: str, val_dataset_uri: str | None, test_dataset_uri: str | None, context: BuildContext, config: Config, integration: WorkflowIntegration, generate_test_set: bool, on_checkpoint_saved: Callable[[str, Path, Path], None] | None)` - Phase 2: Split dataset and extract sample. - `build_baselines(spark: SparkSession, context: BuildContext, config: Config, on_checkpoint_saved: Callable[[str, Path, Path], None] | None)` - Phase 3: Build baseline models. 
- `search_models(spark: SparkSession, context: BuildContext, runner: TrainingRunner, search_policy: SearchPolicy, config: Config, integration: WorkflowIntegration, on_checkpoint_saved: Callable[[str, Path, Path], None] | None, restored_journal: SearchJournal | None, restored_insight_store: InsightStore | None) -> Solution | None` - Phase 4: Iterative tree-search for best model. - `retrain_on_full_dataset(spark: SparkSession, best_solution: Solution, context: BuildContext, runner: TrainingRunner, config: Config) -> Solution` - Retrain best solution on FULL dataset. diff --git a/plexe/agents/dataset_splitter.py b/plexe/agents/dataset_splitter.py index 5d4837a2..99c3a76c 100644 --- a/plexe/agents/dataset_splitter.py +++ b/plexe/agents/dataset_splitter.py @@ -21,7 +21,7 @@ from plexe.tools.submission import get_save_split_uris_tool from plexe.utils.tracing import agent_span from plexe.config import get_routing_for_model -from plexe.validation.validators import validate_dataset_splits +from plexe.validation.validators import canonicalize_split_ratios, validate_dataset_splits logger = logging.getLogger(__name__) @@ -57,15 +57,20 @@ def _build_agent(self, split_ratios: dict[str, float]) -> CodeAgent: """Build CodeAgent with splitting tool.""" # Get routing configuration for this agent's model api_base, headers = get_routing_for_model(self.config.routing_config, self.llm_model) - # TODO(splitter-prompts): Make split instructions conditional on requested split mode. - # 2-way modes should not instruct writing test.parquet or passing test_uri. + # Clean up ratio names and shape before deciding split mode. + split_ratios = canonicalize_split_ratios(split_ratios) + expects_test_split = split_ratios.get("test", 0.0) > 0 return CodeAgent( name="DatasetSplitter", instructions=( "## YOUR ROLE:\n" - "Intelligently split datasets into train/validation/test sets. This is NOT trivial - " - "how you split SIGNIFICANTLY impacts model quality, validity, and data leakage prevention.\n" + + ( + "Intelligently split datasets into train/validation/test sets. This is NOT trivial - " + if expects_test_split + else "Intelligently split datasets into train/validation sets. This is NOT trivial - " + ) + + "how you split SIGNIFICANTLY impacts model quality, validity, and data leakage prevention.\n" "\n" "## CONTEXT AVAILABLE:\n" "Review prior analysis to inform your strategy:\n" @@ -76,66 +81,81 @@ def _build_agent(self, split_ratios: dict[str, float]) -> CodeAgent: "Your PySpark code has access to these variables:\n" "- `spark`: SparkSession\n" "- `dataset_uri`: Path to parquet dataset\n" - "- `split_ratios`: Dict with 'train'/'val'/'test' fractions (e.g., {\"train\": 0.7, \"val\": 0.15, \"test\": 0.15})\n" - "- `output_dir`: Directory for writing split files\n" - "- `task_type`: Task type string (e.g., 'binary_classification', 'regression', 'time_series')\n" - "- `output_targets`: List of target column names\n" - "\n" - "## STRATEGY SELECTION:\n" - "\n" - "Choose based on data characteristics:\n" - "1. **Time Series**: If temporal columns exist AND task requires forecasting → Chronological split (prevent future leakage)\n" - "2. **Classification**: If task_type contains 'classification' → Stratified split (preserve class balance)\n" - "3. **Small Datasets**: If <10,000 rows → Use 90/5/5 instead of given ratios\n" - "4. **Group Preservation**: If user_id, session_id, group_id columns exist → Keep groups intact\n" - "5. 
**Regression** (default): Random split\n" - "\n" - "## TEMPORAL SPLIT DECISION GUIDE:\n" - "\n" - "IMPORTANT: Timestamp columns do NOT automatically require chronological splitting.\n" - "\n" - "**USE chronological split when:**\n" - "- Forecasting future values (predict tomorrow from today)\n" - "- Target has trends/seasonality that change over time\n" - "- Concept drift likely (patterns evolve)\n" - "- Production will predict on FUTURE time periods\n" - "\n" - "**Random/stratified split is acceptable when:**\n" - "- Timestamp is metadata only (created_at, upload_date)\n" - "- Cross-sectional classification (predict attribute, not future event)\n" - "- Target distribution stable across time\n" - "- Time is a feature, not the prediction axis\n" - "\n" - "**Quick test:** Will model predict SAME time period (unseen records) → random OK. " - "Will model predict FUTURE time periods → chronological required.\n" - "\n" - "## KEY PATTERNS:\n" - "\n" - "**Stratified (Classification):**\n" - "Use `df.sampleBy(target_col, fractions, seed=42)` to sample proportionally from each class.\n" - "Apply to train, then remainder for val/test.\n" - "\n" - "**Chronological (Time-Series):**\n" - "Detect time column → sort by it → use Window.row_number() → filter by cutoffs.\n" - "Train=oldest, Test=newest to simulate production (no future leakage).\n" - "\n" - "**Random (Regression):**\n" - "Use `df.randomSplit(weights, seed=42)` with normalized ratios.\n" - "\n" - "## YOUR TASK:\n" - "1. Load: df = spark.read.parquet(dataset_uri)\n" - "2. Inspect stats_report/task_analysis for time columns, groups, imbalance\n" - "3. Generate appropriate PySpark split code\n" - "4. Write to {output_dir}/train.parquet, val.parquet, test.parquet (mode='overwrite')\n" - "5. Call save_split_uris(train_uri, val_uri, test_uri)\n" - "\n" - "## CRITICAL RULES:\n" - "- PySpark only (NOT pandas) - data may be 200GB\n" - "- seed=42 for reproducibility\n" - "- Classification → MUST stratify (unless forecasting task)\n" - "- Forecasting/predicting future → MUST be chronological\n" - "- Timestamp alone does NOT mandate chronological split (see TEMPORAL SPLIT DECISION GUIDE)\n" - "- NO data leakage between splits\n" + + ( + "- `split_ratios`: Dict with 'train'/'val'/'test' fractions " + '(e.g., {"train": 0.7, "val": 0.15, "test": 0.15})\n' + if expects_test_split + else "- `split_ratios`: Dict with 'train'/'val' fractions " '(e.g., {"train": 0.85, "val": 0.15})\n' + ) + + "- `output_dir`: Directory for writing split files\n" + + "- `task_type`: Task type string (e.g., 'binary_classification', 'regression', 'time_series')\n" + + "- `output_targets`: List of target column names\n" + + "\n" + + "## STRATEGY SELECTION:\n" + + "\n" + + "Choose based on data characteristics:\n" + + "1. **Time Series**: If temporal columns exist AND task requires forecasting → Chronological split (prevent future leakage)\n" + + "2. **Classification**: If task_type contains 'classification' → Stratified split (preserve class balance)\n" + + f"3. **Small Datasets**: If <10,000 rows → Use {'90/5/5' if expects_test_split else '90/10'} instead of given ratios\n" + + "4. **Group Preservation**: If user_id, session_id, group_id columns exist → Keep groups intact\n" + + "5. 
**Regression** (default): Random split\n" + + "\n" + + "## TEMPORAL SPLIT DECISION GUIDE:\n" + + "\n" + + "IMPORTANT: Timestamp columns do NOT automatically require chronological splitting.\n" + + "\n" + + "**USE chronological split when:**\n" + + "- Forecasting future values (predict tomorrow from today)\n" + + "- Target has trends/seasonality that change over time\n" + + "- Concept drift likely (patterns evolve)\n" + + "- Production will predict on FUTURE time periods\n" + + "\n" + + "**Random/stratified split is acceptable when:**\n" + + "- Timestamp is metadata only (created_at, upload_date)\n" + + "- Cross-sectional classification (predict attribute, not future event)\n" + + "- Target distribution stable across time\n" + + "- Time is a feature, not the prediction axis\n" + + "\n" + + "**Quick test:** Will model predict SAME time period (unseen records) → random OK. " + + "Will model predict FUTURE time periods → chronological required.\n" + + "\n" + + "## KEY PATTERNS:\n" + + "\n" + + "**Stratified (Classification):**\n" + + "Use `df.sampleBy(target_col, fractions, seed=42)` to sample proportionally from each class.\n" + + "Apply to train, then remainder for val/test.\n" + + "\n" + + "**Chronological (Time-Series):**\n" + + "Detect time column → sort by it → use Window.row_number() → filter by cutoffs.\n" + + "Train=oldest, Test=newest to simulate production (no future leakage).\n" + + "\n" + + "**Random (Regression):**\n" + + "Use `df.randomSplit(weights, seed=42)` with normalized ratios.\n" + + "\n" + + "## YOUR TASK:\n" + + "1. Load: df = spark.read.parquet(dataset_uri)\n" + + "2. Inspect stats_report/task_analysis for time columns, groups, imbalance\n" + + "3. Generate appropriate PySpark split code\n" + + ( + "4. Write to {output_dir}/train.parquet, val.parquet, test.parquet (mode='overwrite')\n" + "5. Call save_split_uris(train_uri, val_uri, test_uri)\n" + if expects_test_split + else "4. Write to {output_dir}/train.parquet and val.parquet (mode='overwrite')\n" + "5. Call save_split_uris(train_uri, val_uri) with NO test_uri\n" + ) + + "\n" + + "## CRITICAL RULES:\n" + + "- PySpark only (NOT pandas) - data may be 200GB\n" + + "- seed=42 for reproducibility\n" + + "- Classification → MUST stratify (unless forecasting task)\n" + + "- Forecasting/predicting future → MUST be chronological\n" + + "- Timestamp alone does NOT mandate chronological split (see TEMPORAL SPLIT DECISION GUIDE)\n" + + "- NO data leakage between splits\n" + + ( + "- Create all three splits (train, val, test)\n" + if expects_test_split + else "- Create only train and validation splits (no test split for this run)\n" + ) ), model=PlexeLiteLLMModel( model_id=self.llm_model, @@ -160,7 +180,7 @@ def _build_agent(self, split_ratios: dict[str, float]) -> CodeAgent: ) @agent_span("DatasetSplitterAgent") - def run(self, split_ratios: dict[str, float], output_dir: str | Path) -> tuple[str, str, str]: + def run(self, split_ratios: dict[str, float], output_dir: str | Path) -> tuple[str, str, str | None]: """ Generate and execute intelligent dataset splitting. 
@@ -172,6 +192,7 @@ def run(self, split_ratios: dict[str, float], output_dir: str | Path) -> tuple[s (train_uri, val_uri, test_uri) """ + split_ratios = canonicalize_split_ratios(split_ratios) logger.info(f"Starting intelligent dataset splitting with ratios: {split_ratios}") # Convert to Path if local path string, or keep as S3 URI string @@ -243,9 +264,11 @@ def _build_task_prompt(self, split_ratios: dict[str, float], output_dir: str) -> output_targets = self.context.output_targets data_challenges = self.context.task_analysis.get("data_challenges", []) recommended_split = self.context.task_analysis.get("recommended_split", {}) + normalized_ratios = canonicalize_split_ratios(split_ratios) + expects_test_split = normalized_ratios.get("test", 0.0) > 0 prompt = ( - f"Split the dataset into train/validation/test sets.\n" + f"Split the dataset into {'train/validation/test' if expects_test_split else 'train/validation'} sets.\n" f"\n" f"Task Type: {task_type}\n" f"Output Targets: {output_targets}\n" @@ -265,32 +288,41 @@ def _build_task_prompt(self, split_ratios: dict[str, float], output_dir: str) -> prompt += ( "\n" - # TODO(splitter-prompts): Make this task prompt explicitly 2-way vs 3-way. - # Current wording always asks for train/val/test outputs, which can induce accidental holdouts. - "Based on the task type and data characteristics, choose the appropriate splitting strategy:\n" - "- Classification → Stratified split (preserve class balance)\n" - "- Forecasting future events/values → Chronological split (train on past, test on future)\n" - "- Timestamp exists but task is cross-sectional → Random/stratified split is acceptable\n" - "- Small dataset (<10K rows) → Adjust ratios to 90/5/5\n" - "- Group-based data (user_id, session_id) → Preserve groups within splits\n" - "- Regression → Simple random split\n" - "\n" - "Review the stats_report and task_analysis (available in additional_args) to detect:\n" - "- Temporal columns (datetime types or time/date in name)\n" - "- Group columns (user_id, customer_id, session_id patterns)\n" - "- Class imbalance (from target_analysis)\n" - "- Dataset size (from stats_report)\n" - "\n" - "Generate PySpark code to:\n" - "1. Load dataset from dataset_uri\n" - "2. Apply appropriate split strategy (use examples in instructions)\n" - "3. Write three parquet files to output_dir:\n" - " - {output_dir}/train.parquet\n" - " - {output_dir}/val.parquet\n" - " - {output_dir}/test.parquet\n" - "4. 
Call save_split_uris(train_path, val_path, test_path) with the full file paths\n" - "\n" - "CRITICAL: Ensure splits are appropriate for this specific dataset and task.\n" + + "Based on the task type and data characteristics, choose the appropriate splitting strategy:\n" + + "- Classification → Stratified split (preserve class balance)\n" + + "- Forecasting future events/values → Chronological split (train on past, test on future)\n" + + "- Timestamp exists but task is cross-sectional → Random/stratified split is acceptable\n" + + ( + "- Small dataset (<10K rows) → Adjust ratios to 90/5/5\n" + if expects_test_split + else "- Small dataset (<10K rows) → Adjust ratios to 90/10\n" + ) + + "- Group-based data (user_id, session_id) → Preserve groups within splits\n" + + "- Regression → Simple random split\n" + + "\n" + + "Review the stats_report and task_analysis (available in additional_args) to detect:\n" + + "- Temporal columns (datetime types or time/date in name)\n" + + "- Group columns (user_id, customer_id, session_id patterns)\n" + + "- Class imbalance (from target_analysis)\n" + + "- Dataset size (from stats_report)\n" + + "\n" + + "Generate PySpark code to:\n" + + "1. Load dataset from dataset_uri\n" + + "2. Apply appropriate split strategy (use examples in instructions)\n" + + ( + "3. Write three parquet files to output_dir:\n" + " - {output_dir}/train.parquet\n" + " - {output_dir}/val.parquet\n" + " - {output_dir}/test.parquet\n" + "4. Call save_split_uris(train_path, val_path, test_path) with the full file paths\n" + if expects_test_split + else "3. Write two parquet files to output_dir:\n" + " - {output_dir}/train.parquet\n" + " - {output_dir}/val.parquet\n" + "4. Call save_split_uris(train_path, val_path) with NO test_path\n" + "5. Do NOT create or submit a test split in this run\n" + ) + + "\nCRITICAL: Ensure splits are appropriate for this specific dataset and task.\n" ) return prompt diff --git a/plexe/main.py b/plexe/main.py index 49fb4f52..28a64bac 100644 --- a/plexe/main.py +++ b/plexe/main.py @@ -24,6 +24,7 @@ from plexe.integrations.base import WorkflowIntegration from plexe.config import setup_logging, setup_litellm, get_config from plexe.constants import DirNames, PhaseNames +from plexe.execution.dataproc.dataset_io import DatasetNormalizer from plexe.execution.dataproc.session import get_or_create_spark_session, stop_spark_session from plexe.execution.training.local_runner import LocalProcessRunner from plexe.search.tree_policy import TreeSearchPolicy @@ -36,7 +37,7 @@ def main( intent: str, - data_refs: list[str], # TODO: Support multiple datasets + join_strategy when multi-dataset joining is implemented + data_refs: list[str] | None = None, # Deprecated fallback for backward compatibility integration: WorkflowIntegration | None = None, spark_mode: str = "local", user_id: str = "default_user", @@ -44,9 +45,12 @@ def main( max_iterations: int = 10, global_seed: int | None = None, work_dir: Path = Path("/tmp/model_builder_v2"), + train_dataset_uri: str | None = None, + val_dataset_uri: str | None = None, test_dataset_uri: str | None = None, enable_final_evaluation: bool = False, - max_epochs: int | None = None, + nn_default_epochs: int | None = None, + nn_max_epochs: int | None = None, allowed_model_types: list[str] | None = None, is_retrain: bool = False, original_model_uri: str | None = None, @@ -65,7 +69,7 @@ def main( Args: intent: ML task description - data_refs: Dataset references + data_refs: Deprecated dataset references list (only first element used) 
integration: WorkflowIntegration instance (default: StandaloneIntegration) spark_mode: Spark backend ("local" or "databricks") user_id: User identifier @@ -73,9 +77,12 @@ def main( max_iterations: Maximum search iterations global_seed: Global seed for reproducible runs (random + numpy + search policies) work_dir: Working directory for artifacts + train_dataset_uri: Required training dataset URI (preferred over data_refs) + val_dataset_uri: Optional validation dataset URI test_dataset_uri: Optional test dataset URI enable_final_evaluation: Whether to run test evaluation - max_epochs: Cap Keras epochs (for testing) + nn_default_epochs: Override default epochs for neural network training + nn_max_epochs: Override max epochs cap for neural network training allowed_model_types: Restrict model types is_retrain: Whether this is a retraining job original_model_uri: URI to original model.tar.gz (for retraining) @@ -110,10 +117,19 @@ def main( # Apply CLI argument overrides (highest priority) config.max_search_iterations = max_iterations config.spark_mode = spark_mode - if max_epochs: - config.nn_max_epochs = max_epochs - if config.nn_default_epochs > config.nn_max_epochs: - config.nn_default_epochs = config.nn_max_epochs + epoch_overrides = {} + if nn_default_epochs is not None: + epoch_overrides["nn_default_epochs"] = nn_default_epochs + if nn_max_epochs is not None: + epoch_overrides["nn_max_epochs"] = nn_max_epochs + if epoch_overrides: + merged_config = config.model_dump() | epoch_overrides + # Preserve prior max-only override behavior: lowering the cap alone should + # also lower the default when it would otherwise violate validation. + if "nn_max_epochs" in epoch_overrides and "nn_default_epochs" not in epoch_overrides: + if merged_config.get("nn_default_epochs", 0) > epoch_overrides["nn_max_epochs"]: + merged_config["nn_default_epochs"] = epoch_overrides["nn_max_epochs"] + config = config.__class__.model_validate(merged_config) if allowed_model_types: config.allowed_model_types = allowed_model_types if global_seed is not None: @@ -155,21 +171,76 @@ def main( work_dir.mkdir(parents=True, exist_ok=True) integration.prepare_workspace(experiment_id, work_dir) - # Normalize dataset to parquet - if not data_refs: - raise ValueError("No dataset references provided") - input_uri = data_refs[0] - normalized_output = integration.get_artifact_location("normalized", input_uri, experiment_id, work_dir) - spark = get_or_create_spark_session(config) - from plexe.execution.dataproc.dataset_io import DatasetNormalizer + # Resolve training dataset input (new train_dataset_uri preferred) + if train_dataset_uri: + resolved_train_input_uri = train_dataset_uri + if data_refs: + logger.warning( + "Both train_dataset_uri and deprecated data_refs provided; " + "using train_dataset_uri and ignoring data_refs" + ) + elif data_refs: + resolved_train_input_uri = data_refs[0] + logger.warning( + "data_refs is deprecated and will be removed in a future release. " + "Please pass train_dataset_uri instead." + ) + if len(data_refs) > 1: + logger.warning( + f"Multiple dataset refs provided, using first: {resolved_train_input_uri}. " + "Multi-dataset joining is not supported yet." 
+ ) + else: + raise ValueError("train_dataset_uri is required (or use deprecated data_refs=[...])") + # Normalize training dataset to parquet + normalized_output = integration.get_artifact_location( + "normalized", resolved_train_input_uri, experiment_id, work_dir + ) + spark = get_or_create_spark_session(config) normalizer = DatasetNormalizer(spark) - csv_options = {"sep": csv_delimiter, "header": csv_header} + csv_options = {"sep": config.csv_delimiter, "header": config.csv_header} train_dataset_uri, input_format = normalizer.normalize( - input_uri=input_uri, output_uri=normalized_output, read_options=csv_options + input_uri=resolved_train_input_uri, output_uri=normalized_output, read_options=csv_options ) input_format = input_format.value + # Optional explicit split datasets are ignored in retraining mode + if is_retrain and (val_dataset_uri or test_dataset_uri): + logger.warning("val_dataset_uri/test_dataset_uri are ignored in retraining mode") + val_dataset_uri = None + test_dataset_uri = None + + # Normalize optional provided val/test datasets using existing normalizer logic + val_input_format = None + test_input_format = None + if not is_retrain: + normalized_output_base = integration.get_artifact_location( + "normalized", train_dataset_uri, experiment_id, work_dir + ) + + def normalize_dataset_uri(dataset_uri: str, split_name: str) -> tuple[str, str]: + if normalized_output_base.startswith("s3://"): + output_uri = f"{normalized_output_base}/normalized_inputs/{split_name}.parquet" + else: + output_uri = str(Path(normalized_output_base) / "normalized_inputs" / f"{split_name}.parquet") + + normalized_uri, detected_format = normalizer.normalize( + input_uri=dataset_uri, + output_uri=output_uri, + read_options=csv_options, + ) + return normalized_uri, detected_format.value + + if val_dataset_uri: + val_dataset_uri, val_input_format = normalize_dataset_uri(val_dataset_uri, "val") + if test_dataset_uri: + test_dataset_uri, test_input_format = normalize_dataset_uri(test_dataset_uri, "test") + + if test_dataset_uri and not enable_final_evaluation: + logger.info("test_dataset_uri provided; auto-enabling final evaluation") + enable_final_evaluation = True + # Prepare original model if retraining if is_retrain: # Pick which reference to use (prioritize explicit URI over experiment ID) @@ -189,6 +260,10 @@ def main( logger.info(f"LiteLLM routing: {'custom config' if config.routing_config else 'default providers'}") logger.info(f"Intent: {intent}") logger.info(f"Dataset: {train_dataset_uri} (format: {input_format}) | Max iterations: {max_iterations}") + if val_dataset_uri: + logger.info(f"Validation dataset: {val_dataset_uri} (format: {val_input_format})") + if test_dataset_uri: + logger.info(f"Test dataset: {test_dataset_uri} (format: {test_input_format})") if config.global_seed is not None and config.max_parallel_variants > 1: logger.info( "Reproducibility note: max_parallel_variants>1 can introduce nondeterminism under threading. 
" @@ -219,7 +294,6 @@ def _on_checkpoint(phase_name, checkpoint_path, work_dir): evaluation_report = None else: # Normal build workflow - spark = get_or_create_spark_session(config) search_policy = TreeSearchPolicy(seed=config.global_seed) # search_policy = EvolutionarySearchPolicy(seed=config.global_seed) TODO: enable after testing @@ -233,6 +307,7 @@ def _on_checkpoint(phase_name, checkpoint_path, work_dir): result = build_model( spark=spark, train_dataset_uri=train_dataset_uri, + val_dataset_uri=val_dataset_uri, test_dataset_uri=test_dataset_uri, user_id=user_id, intent=intent, @@ -286,6 +361,7 @@ def _on_checkpoint(phase_name, checkpoint_path, work_dir): parser.add_argument( "--train-dataset-uri", required=True, help="Path to training dataset (CSV, ORC, Avro, or Parquet)" ) + parser.add_argument("--val-dataset-uri", help="Optional: Path to validation dataset (CSV, ORC, Avro, or Parquet)") parser.add_argument("--test-dataset-uri", help="Optional: Path to test dataset (CSV, ORC, Avro, or Parquet)") parser.add_argument("--user-id", default=os.getenv("USER_ID", "default_user"), help="User identifier") parser.add_argument("--intent", required=True, help="ML task description") @@ -294,7 +370,16 @@ def _on_checkpoint(phase_name, checkpoint_path, work_dir): parser.add_argument("--seed", type=int, help="Global seed for reproducible runs") parser.add_argument("--work-dir", type=Path, default=Path("/tmp/model_builder_v2"), help="Working directory") parser.add_argument("--enable-final-evaluation", action="store_true", help="Enable test set evaluation") - parser.add_argument("--max-epochs", type=int, help="Cap neural network epochs (Keras, PyTorch)") + parser.add_argument( + "--nn-default-epochs", + type=int, + help="Override default epochs for neural network training (Keras, PyTorch)", + ) + parser.add_argument( + "--nn-max-epochs", + type=int, + help="Override max epochs cap for neural network training (Keras, PyTorch)", + ) parser.add_argument( "--allowed-model-types", nargs="+", @@ -418,16 +503,18 @@ def _on_checkpoint(phase_name, checkpoint_path, work_dir): try: main( intent=args.intent, - data_refs=[args.train_dataset_uri], + train_dataset_uri=args.train_dataset_uri, spark_mode=spark_mode, user_id=args.user_id, experiment_id=args.experiment_id, max_iterations=args.max_iterations, global_seed=args.seed, work_dir=args.work_dir, + val_dataset_uri=args.val_dataset_uri, test_dataset_uri=args.test_dataset_uri, enable_final_evaluation=enable_final_evaluation, - max_epochs=args.max_epochs, + nn_default_epochs=args.nn_default_epochs, + nn_max_epochs=args.nn_max_epochs, allowed_model_types=args.allowed_model_types, is_retrain=args.is_retrain, original_model_uri=args.original_model_uri, diff --git a/plexe/workflow.py b/plexe/workflow.py index 88a709d8..d81ab93c 100644 --- a/plexe/workflow.py +++ b/plexe/workflow.py @@ -107,6 +107,7 @@ def _apply_allowed_model_types_on_resume(context: BuildContext, config: Config, def build_model( spark: SparkSession, train_dataset_uri: str, + val_dataset_uri: str | None, test_dataset_uri: str | None, user_id: str, intent: str, @@ -128,6 +129,7 @@ def build_model( Args: spark: SparkSession train_dataset_uri: URI to training dataset + val_dataset_uri: Optional URI to separate validation dataset test_dataset_uri: Optional URI to separate test dataset user_id: User identifier intent: Natural language description of ML task @@ -244,64 +246,38 @@ def build_model( train_uri_to_use = context.scratch.get("_filtered_dataset_uri") or context.scratch.get( "_sanitized_dataset_uri", 
train_dataset_uri ) + splits_output_dir = integration.get_artifact_location( + "splits", train_uri_to_use, context.experiment_id, context.work_dir + ) - # Sanitize test dataset too if provided (must match training data schema) + # Sanitize explicit val/test datasets too if provided (must match training data schema) + val_uri_to_use = val_dataset_uri test_uri_to_use = test_dataset_uri - if test_dataset_uri and "_original_column_names" in context.scratch: - import re - - # Reuse training data's column mapping to ensure consistency - column_mapping = context.scratch["_original_column_names"] - test_df = spark.read.parquet(test_dataset_uri) - - # Apply same transformations as training data - for original, sanitized in column_mapping.items(): - if original in test_df.columns: - # Check if target name already exists (would be silently overwritten) - if sanitized in test_df.columns and sanitized != original: - raise ValueError( - f"Cannot sanitize test dataset: column '{original}' would rename to " - f"'{sanitized}', but '{sanitized}' already exists in test dataset. " - f"This would cause data loss. Please ensure test dataset columns match " - f"training dataset or have unique sanitized names." - ) - test_df = test_df.withColumnRenamed(original, sanitized) - - # Handle columns in test that weren't in training data - test_only_mapping = {} - # Track all existing names: training-sanitized + already-processed test columns - all_existing_names = set(column_mapping.values()) - - for idx, col_name in enumerate(test_df.columns): - if col_name not in column_mapping.values() and any( - char in col_name for char in [".", " ", "-", "(", ")", "[", "]"] - ): - safe_name = re.sub(r"[.\s\-\(\)\[\]]", "_", col_name) - safe_name = re.sub(r"_+", "_", safe_name).strip("_") - if not safe_name: - safe_name = f"col_{idx}" - - # Check for collisions with training names AND other test columns - original_safe_name = safe_name - counter = 1 - while safe_name in all_existing_names: - safe_name = f"{original_safe_name}_{counter}" - counter += 1 - - test_df = test_df.withColumnRenamed(col_name, safe_name) - test_only_mapping[col_name] = safe_name - all_existing_names.add(safe_name) # Mark as used - logger.info(f" Test-only column: '{col_name}' → '{safe_name}'") - - # Save sanitized test dataset - test_uri_to_use = f"{context.work_dir}/{DirNames.BUILD_DIR}/data/test_sanitized.parquet" - test_df.write.mode("overwrite").parquet(test_uri_to_use) - logger.info("✓ Test dataset sanitized using training column mapping") - context.scratch["_sanitized_test_dataset_uri"] = test_uri_to_use + column_mapping = context.scratch.get("_original_column_names") + if column_mapping: + if val_dataset_uri: + val_uri_to_use = _sanitize_aux_dataset_column_names( + spark=spark, + dataset_uri=val_dataset_uri, + column_mapping=column_mapping, + context=context, + split_name="val", + output_dir=splits_output_dir, + ) + if test_dataset_uri: + test_uri_to_use = _sanitize_aux_dataset_column_names( + spark=spark, + dataset_uri=test_dataset_uri, + column_mapping=column_mapping, + context=context, + split_name="test", + output_dir=splits_output_dir, + ) prepare_data( spark, train_uri_to_use, + val_uri_to_use, test_uri_to_use, context, config, @@ -615,6 +591,64 @@ def sanitize_dataset_column_names(spark: SparkSession, dataset_uri: str, context return sanitized_uri +def _sanitize_aux_dataset_column_names( + spark: SparkSession, + dataset_uri: str, + column_mapping: dict[str, str], + context: BuildContext, + split_name: str, + output_dir: str, +) -> str: + """ + 
Apply training-derived column sanitization mapping to auxiliary datasets (val/test). + """ + import re + + df = spark.read.parquet(dataset_uri) + + # Apply exact mapping from training sanitization first + for original, sanitized in column_mapping.items(): + if original in df.columns: + if sanitized in df.columns and sanitized != original: + raise ValueError( + f"Cannot sanitize {split_name} dataset: column '{original}' would rename to " + f"'{sanitized}', but '{sanitized}' already exists in {split_name} dataset." + ) + df = df.withColumnRenamed(original, sanitized) + + # Track all names currently present in aux dataset after applying training mapping. + # New sanitized names must not collide with any existing aux column name. + all_existing_names = set(df.columns) + for idx, col_name in enumerate(df.columns): + if col_name not in column_mapping.values() and any( + char in col_name for char in [".", " ", "-", "(", ")", "[", "]"] + ): + safe_name = re.sub(r"[.\s\-\(\)\[\]]", "_", col_name) + safe_name = re.sub(r"_+", "_", safe_name).strip("_") + if not safe_name: + safe_name = f"col_{idx}" + + original_safe_name = safe_name + counter = 1 + while safe_name in all_existing_names: + safe_name = f"{original_safe_name}_{counter}" + counter += 1 + + df = df.withColumnRenamed(col_name, safe_name) + all_existing_names.add(safe_name) + logger.info(f" {split_name.capitalize()}-only column: '{col_name}' → '{safe_name}'") + + if output_dir.startswith("s3://"): + output_uri = f"{output_dir}/{split_name}_sanitized.parquet" + else: + output_uri = str(Path(output_dir) / f"{split_name}_sanitized.parquet") + + df.write.mode("overwrite").parquet(output_uri) + context.scratch[f"_sanitized_{split_name}_dataset_uri"] = output_uri + logger.info(f"✓ {split_name.capitalize()} dataset sanitized using training column mapping") + return output_uri + + def _set_noop_filtered_dataset_uri(context: BuildContext, dataset_uri: str) -> str: """Store no-op filtering result in context and return original URI.""" context.excluded_columns = [] @@ -848,9 +882,39 @@ def analyze_data( # ============================================ +def _materialize_explicit_split( + spark: SparkSession, + dataset_uri: str, + split_name: str, + context: BuildContext, + output_dir: str, +) -> str: + """Copy explicit val/test split into build directory and apply excluded-column drops.""" + split_df = spark.read.parquet(dataset_uri) + + if context.excluded_columns: + excluded_column_names = [ + entry.get("column") for entry in context.excluded_columns if isinstance(entry, dict) and entry.get("column") + ] + columns_to_drop = [col for col in excluded_column_names if col in split_df.columns] + if columns_to_drop: + split_df = split_df.drop(*columns_to_drop) + logger.info(f"Excluded {len(columns_to_drop)} columns from {split_name} dataset: {columns_to_drop}") + + if output_dir.startswith("s3://"): + output_uri = f"{output_dir}/{split_name}.parquet" + else: + output_uri = str(Path(output_dir) / f"{split_name}.parquet") + + split_df.write.mode("overwrite").parquet(output_uri) + logger.info(f"Copied {split_name} dataset to: {output_uri}") + return output_uri + + def prepare_data( spark: SparkSession, training_dataset_uri: str, + val_dataset_uri: str | None, test_dataset_uri: str | None, context: BuildContext, config: Config, @@ -861,91 +925,111 @@ def prepare_data( """ Phase 2: Split dataset and extract sample. - Supports two modes: - 1. Single dataset: Split training_dataset_uri into train/val(/test if enabled) - 2. 
Separate datasets: Split training_dataset_uri into train/val, use test_dataset_uri as test + Supports explicit split injection while generating missing splits from training data. Args: spark: SparkSession training_dataset_uri: URI to training dataset + val_dataset_uri: Optional URI to separate validation dataset test_dataset_uri: Optional URI to separate test dataset context: Build context config: Configuration integration: WorkflowIntegration for infrastructure queries - generate_test_set: Whether to create test set from training data (ignored if test_dataset_uri provided) + generate_test_set: Whether to create test set from training data when not explicitly provided on_checkpoint_saved: Optional callback for platform integration """ logger.info("=== Phase 2: Data Preparation ===") - # Step 1: Handle Test Dataset (if provided separately) - if test_dataset_uri: - logger.info(f"Separate test dataset mode: {test_dataset_uri}") - - # Copy test dataset to DirNames.BUILD_DIR/data/ for consistency - test_df = spark.read.parquet(test_dataset_uri) - if context.excluded_columns: - excluded_column_names = [ - entry.get("column") - for entry in context.excluded_columns - if isinstance(entry, dict) and entry.get("column") - ] - columns_to_drop = [col for col in excluded_column_names if col in test_df.columns] - if columns_to_drop: - test_df = test_df.drop(*columns_to_drop) - logger.info(f"Excluded {len(columns_to_drop)} columns from test dataset: {columns_to_drop}") - test_uri = str(context.work_dir / DirNames.BUILD_DIR / "data" / "test.parquet") - test_df.write.mode("overwrite").parquet(test_uri) - logger.info(f"Copied test dataset to: {test_uri}") - - # Always 2-way split when separate test provided - split_ratios = {"train": 0.85, "val": 0.15} - logger.info("Splitting training data into train/val only (test provided separately)") + provided_val = val_dataset_uri is not None + provided_test = test_dataset_uri is not None - else: - # Single dataset mode: create test from split if requested - test_uri = None - if generate_test_set: - # 3-way split: train/val/test - # Handle both new schema (ratios nested) and legacy (ratios at top level) - recommended_split = context.task_analysis.get("recommended_split", {}) - if "ratios" in recommended_split: - # New schema: extract ratios from nested structure - split_ratios = recommended_split["ratios"] - elif "train" in recommended_split: - # Legacy schema: ratios at top level - split_ratios = recommended_split - else: - # Default fallback - split_ratios = {"train": 0.7, "val": 0.15, "test": 0.15} + splits_output_dir = integration.get_artifact_location( + "splits", training_dataset_uri, context.experiment_id, context.work_dir + ) + val_uri = ( + _materialize_explicit_split(spark, val_dataset_uri, "val", context, splits_output_dir) if provided_val else None + ) + test_uri = ( + _materialize_explicit_split(spark, test_dataset_uri, "test", context, splits_output_dir) + if provided_test + else None + ) + train_uri = training_dataset_uri + + def run_split(split_ratios: dict[str, float], output_dir: str = splits_output_dir) -> tuple[str, str, str | None]: + splitter = DatasetSplitterAgent(spark=spark, dataset_uri=training_dataset_uri, context=context, config=config) + return splitter.run(split_ratios=split_ratios, output_dir=output_dir) + + def get_generated_split_output_dir() -> str: + if splits_output_dir.startswith("s3://"): + return f"{splits_output_dir}/generated" + return str(Path(splits_output_dir) / "generated") + + def get_three_way_split_ratios() -> dict[str, 
float]: + recommended_split = (context.task_analysis or {}).get("recommended_split", {}) + if "ratios" in recommended_split: + split_ratios = recommended_split["ratios"] + elif "train" in recommended_split: + split_ratios = recommended_split + else: + split_ratios = {"train": 0.7, "val": 0.15, "test": 0.15} + + split_ratios = canonicalize_split_ratios(split_ratios) + if {"train", "val", "test"}.issubset(split_ratios): + return split_ratios - split_ratios = canonicalize_split_ratios(split_ratios) - if not {"train", "val", "test"}.issubset(split_ratios): + logger.warning( + "Recommended split ratios are missing one of train/val/test (%s); " + "falling back to default 70/15/15 for final evaluation.", + split_ratios, + ) + return {"train": 0.7, "val": 0.15, "test": 0.15} + + if provided_val and provided_test: + logger.info("Using provided train/val/test splits without splitter") + elif provided_val and not provided_test: + if generate_test_set: + logger.info("Validation split provided - generating missing test split from training dataset") + split_ratios = get_three_way_split_ratios() + test_ratio = split_ratios.get("test", 0.15) + split_ratios = {"train": max(1.0 - test_ratio, 0.01), "val": test_ratio} + generated_output_dir = get_generated_split_output_dir() + train_uri, generated_test_uri, split_test_uri = run_split(split_ratios, output_dir=generated_output_dir) + if split_test_uri: logger.warning( - "Recommended split ratios are missing one of train/val/test (%s); " - "falling back to default 70/15/15 for final evaluation.", - split_ratios, + "Splitter returned an unexpected third split while generating test-only split; " + "using returned test URI" ) - split_ratios = {"train": 0.7, "val": 0.15, "test": 0.15} - - logger.info("Creating train/val/test splits from single dataset (final evaluation enabled)") + test_uri = split_test_uri + else: + # In this path the 2-way splitter "val" output is intentionally repurposed as test split. 
+ test_uri = generated_test_uri + else: + logger.info("Validation split provided - test split not requested") + elif not provided_val and provided_test: + logger.info("Test split provided - generating missing validation split from training dataset") + val_ratio = get_three_way_split_ratios().get("val", 0.15) + split_ratios = {"train": max(1.0 - val_ratio, 0.01), "val": val_ratio} + train_uri, val_uri, split_test_uri = run_split(split_ratios) + if split_test_uri: + logger.warning("Splitter returned unexpected test split in train/val mode; ignoring generated test URI") + else: + if generate_test_set: + split_ratios = get_three_way_split_ratios() + logger.info("Creating train/val/test splits from training dataset") else: - # 2-way split: train/val only split_ratios = {"train": 0.85, "val": 0.15} - logger.info("Creating train/val splits only (final evaluation disabled)") + logger.info("Creating train/val splits from training dataset") - # Step 2: Split Training Dataset - splitter = DatasetSplitterAgent(spark=spark, dataset_uri=training_dataset_uri, context=context, config=config) - - # Get splits output location from integration (based on dataset location) - splits_output_dir = integration.get_artifact_location( - "splits", training_dataset_uri, context.experiment_id, context.work_dir - ) - - train_uri, val_uri, split_test_uri = splitter.run(split_ratios=split_ratios, output_dir=splits_output_dir) + train_uri, val_uri, split_test_uri = run_split(split_ratios) + if generate_test_set: + test_uri = split_test_uri + elif split_test_uri: + logger.warning("Splitter returned unexpected test split while final evaluation disabled; ignoring it") - # Use separate test if provided, otherwise use split test - test_uri = test_uri if test_dataset_uri else split_test_uri + if not val_uri: + raise ValueError("Validation split URI is required for sampling and was not resolved") # Step 3: Create Intelligent Samples sampler = SamplingAgent(spark=spark, context=context, config=config) @@ -974,36 +1058,8 @@ def prepare_data( val_sample_uri=val_sample_uri, ) - # Step 5: Validate schema compatibility (if separate test provided) - if test_dataset_uri: - logger.info("Validating test dataset schema compatibility...") - train_df = spark.read.parquet(train_uri) - test_df = spark.read.parquet(test_uri) - - # Check target column exists in test - if context.output_targets[0] not in test_df.columns: - raise ValueError( - f"Test dataset missing target column '{context.output_targets[0]}'. " f"Test columns: {test_df.columns}" - ) - - # Check feature overlap - train_features = set(train_df.columns) - set(context.output_targets) - test_features = set(test_df.columns) - set(context.output_targets) - - missing_in_test = train_features - test_features - if missing_in_test: - raise ValueError( - f"Test dataset missing {len(missing_in_test)} features from training data: {sorted(missing_in_test)}\n" - f"Model cannot make predictions without these features." - ) - - extra_in_test = test_features - train_features - if extra_in_test: - logger.warning( - f"Test dataset has {len(extra_in_test)} extra features not in training data (will be ignored)" - ) - - logger.info("✓ Test dataset schema validation complete") + # TODO(explicit-split-schema-validation): Validate explicit val/test schema compatibility + # against training schema (target + feature alignment) in a dedicated follow-up PR. 
     logger.info(f"Data preparation complete: train={train_uri}, val={val_uri}, test={test_uri}")
     logger.info(f"Samples created: train_sample={train_sample_uri}, val_sample={val_sample_uri}")
diff --git a/pyproject.toml b/pyproject.toml
index aee2a460..0b3a01bb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "plexe"
-version = "1.4.3"
+version = "1.4.4"
 description = "An agentic framework for building ML models from natural language"
 authors = [
     "Marcello De Bernardi ",
diff --git a/tests/CODE_INDEX.md b/tests/CODE_INDEX.md
index 66b0f26a..1fbe1b68 100644
--- a/tests/CODE_INDEX.md
+++ b/tests/CODE_INDEX.md
@@ -1,6 +1,6 @@
 # Code Index: tests
 
-> Generated on 2026-03-03 05:08:33
+> Generated on 2026-03-05 21:32:55
 
 Test suite structure and test case documentation.
 
@@ -52,6 +52,14 @@ Stage 3 integration tests: run evaluation/packaging and validate predictors.
 **Functions:**
 - `test_resume_and_run_eval_then_predict(model_type: str, artifact_root, repo_root) -> None` - Resume from stage 2 checkpoints, run to completion, and validate predictor inference.
 
+---
+## `unit/agents/test_dataset_splitter_prompt.py`
+Prompt-level tests for DatasetSplitterAgent split-mode instructions.
+
+**Functions:**
+- `test_build_task_prompt_for_two_way_split_avoids_test_output()` - No description
+- `test_build_task_prompt_for_three_way_split_requires_test_output()` - No description
+
 ---
 ## `unit/agents/test_feedback.py`
 Tests for user feedback integration in agents.
@@ -275,6 +283,17 @@ Unit tests for LightGBM predictor template.
 - `test_lightgbm_predictor_predict_proba_without_metadata(tmp_path: Path) -> None` - No description
 - `test_lightgbm_predictor_predict_proba_raises_for_regression(tmp_path: Path) -> None` - No description
 
+---
+## `unit/test_main_dataset_inputs.py`
+Unit tests for main() dataset input handling.
+
+**Functions:**
+- `test_main_prefers_train_dataset_uri_and_forwards_optional_splits(monkeypatch, tmp_path)` - No description
+- `test_main_auto_enables_final_evaluation_when_test_dataset_is_provided(monkeypatch, tmp_path)` - No description
+- `test_main_nn_max_epochs_override_clamps_default_when_only_cap_is_set(monkeypatch, tmp_path)` - No description
+- `test_main_uses_data_refs_fallback_when_train_dataset_uri_missing(monkeypatch, tmp_path)` - No description
+- `test_main_requires_train_dataset_uri_or_data_refs(monkeypatch, tmp_path)` - No description
+
 ---
 ## `unit/test_models.py`
 Unit tests for core model dataclasses.
@@ -421,6 +440,15 @@ Unit tests for model card generation.
 - `test_generate_model_card_full_context(tmp_path: Path) -> None` - No description
 - `test_generate_model_card_minimal_context(tmp_path: Path) -> None` - No description
 
+---
+## `unit/workflow/test_prepare_data_explicit_splits.py`
+Unit tests for prepare_data split resolution with explicit val/test datasets.
+
+**Functions:**
+- `test_prepare_data_uses_all_provided_splits_without_running_splitter(monkeypatch, tmp_path)` - No description
+- `test_prepare_data_generates_missing_test_when_only_val_is_provided(monkeypatch, tmp_path)` - No description
+- `test_prepare_data_generates_missing_val_when_only_test_is_provided(monkeypatch, tmp_path)` - No description
+
 ---
 ## `unit/workflow/test_resume_model_type_filtering.py`
 Tests for resume-time model type filtering.
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 8d316284..730baecf 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -251,6 +251,7 @@ def build_seed_workflow(work_dir: Path, dataset_input: Path, intent: str, experi return build_model( spark=spark, train_dataset_uri=train_dataset_uri, + val_dataset_uri=None, test_dataset_uri=None, user_id="integration_test", intent=intent, @@ -306,6 +307,7 @@ def resume_workflow( return build_model( spark=spark, train_dataset_uri=resume_context["dataset_uri"], + val_dataset_uri=None, test_dataset_uri=None, user_id=resume_context["user_id"], intent=resume_context["intent"], diff --git a/tests/unit/agents/test_dataset_splitter_prompt.py b/tests/unit/agents/test_dataset_splitter_prompt.py new file mode 100644 index 00000000..31467ead --- /dev/null +++ b/tests/unit/agents/test_dataset_splitter_prompt.py @@ -0,0 +1,56 @@ +"""Prompt-level tests for DatasetSplitterAgent split-mode instructions.""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +pytest.importorskip("dataclasses_json") +pytest.importorskip("smolagents") + +from plexe.agents.dataset_splitter import DatasetSplitterAgent +from plexe.models import BuildContext + + +def _make_context(): + context = BuildContext( + user_id="user-1", + experiment_id="exp-1", + dataset_uri="train.parquet", + work_dir="/tmp", + intent="predict churn", + ) + context.task_analysis = {"task_type": "binary_classification", "data_challenges": [], "recommended_split": {}} + context.output_targets = ["target"] + return context + + +def test_build_task_prompt_for_two_way_split_avoids_test_output(): + agent = DatasetSplitterAgent( + spark=object(), + dataset_uri="train.parquet", + context=_make_context(), + config=SimpleNamespace(dataset_splitting_llm="test-model"), + ) + + prompt = agent._build_task_prompt({"train": 0.85, "val": 0.15}, "/tmp/splits") + + assert "train/validation sets" in prompt + assert "save_split_uris(train_path, val_path) with NO test_path" in prompt + assert "Do NOT create or submit a test split in this run" in prompt + + +def test_build_task_prompt_for_three_way_split_requires_test_output(): + agent = DatasetSplitterAgent( + spark=object(), + dataset_uri="train.parquet", + context=_make_context(), + config=SimpleNamespace(dataset_splitting_llm="test-model"), + ) + + prompt = agent._build_task_prompt({"train": 0.7, "val": 0.15, "test": 0.15}, "/tmp/splits") + + assert "train/validation/test sets" in prompt + assert "save_split_uris(train_path, val_path, test_path)" in prompt + assert "Do NOT create or submit a test split in this run" not in prompt diff --git a/tests/unit/test_main_dataset_inputs.py b/tests/unit/test_main_dataset_inputs.py new file mode 100644 index 00000000..ce74487a --- /dev/null +++ b/tests/unit/test_main_dataset_inputs.py @@ -0,0 +1,217 @@ +"""Unit tests for main() dataset input handling.""" + +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +import pytest + +pytest.importorskip("pydantic_settings") +pytest.importorskip("pyspark") + +import plexe.main as main_module + + +class _FakeIntegration: + def __init__(self): + self.workspace_calls: list[tuple[str, Path]] = [] + + def prepare_workspace(self, experiment_id: str, work_dir: Path) -> None: + self.workspace_calls.append((experiment_id, work_dir)) + + def get_artifact_location(self, artifact_type: str, dataset_uri: str, experiment_id: str, work_dir: Path) -> str: + _ = 
dataset_uri, experiment_id + return str(work_dir / ".build" / "data" / artifact_type) + + def ensure_local(self, uris: list[str], work_dir: Path) -> list[str]: + _ = work_dir + return uris + + def prepare_original_model(self, model_reference: str, work_dir: Path) -> str: + _ = work_dir + return model_reference + + def on_checkpoint(self, experiment_id: str, phase_name: str, checkpoint_path: Path, work_dir: Path) -> None: + _ = experiment_id, phase_name, checkpoint_path, work_dir + + def on_completion(self, experiment_id: str, work_dir: Path, final_metrics: dict, evaluation_report) -> None: + _ = experiment_id, work_dir, final_metrics, evaluation_report + + def on_failure(self, experiment_id: str, error: Exception) -> None: + _ = experiment_id, error + + def on_pause(self, phase_name: str) -> None: + _ = phase_name + + +def _patch_main_dependencies(monkeypatch, build_model_spy: dict, normalize_calls: list[tuple]): + class _FakeConfig(SimpleNamespace): + def model_dump(self): + return self.__dict__.copy() + + @classmethod + def model_validate(cls, payload): + return cls(**payload) + + fake_config = _FakeConfig( + max_search_iterations=10, + spark_mode="local", + nn_max_epochs=10, + nn_default_epochs=10, + allowed_model_types=None, + global_seed=None, + csv_delimiter=",", + csv_header=True, + enable_otel=False, + otel_endpoint=None, + otel_headers={}, + routing_config=None, + max_parallel_variants=1, + ) + + monkeypatch.setattr(main_module, "get_config", lambda: fake_config) + monkeypatch.setattr(main_module, "setup_logging", lambda *_args, **_kwargs: None) + monkeypatch.setattr(main_module, "setup_litellm", lambda *_args, **_kwargs: None) + monkeypatch.setattr(main_module, "setup_opentelemetry", lambda *_args, **_kwargs: None) + monkeypatch.setattr(main_module, "stop_spark_session", lambda *_args, **_kwargs: None) + monkeypatch.setattr(main_module, "get_or_create_spark_session", lambda *_args, **_kwargs: object()) + monkeypatch.setattr(main_module, "TreeSearchPolicy", lambda *args, **kwargs: object()) + monkeypatch.setattr(main_module, "LocalProcessRunner", lambda *args, **kwargs: object()) + + def _fake_build_model(**kwargs): + build_model_spy["kwargs"] = kwargs + return SimpleNamespace(performance=0.82), {"performance": 0.82}, None + + monkeypatch.setattr(main_module, "build_model", _fake_build_model) + + class _FakeNormalizer: + def __init__(self, _spark): + pass + + def normalize(self, input_uri, output_uri, read_options): + _ = read_options + output_path = Path(output_uri) + if output_path.name.endswith(".parquet"): + split_name = output_path.stem + else: + split_name = "train" + normalize_calls.append((split_name, input_uri)) + return f"normalized_{split_name}.parquet", SimpleNamespace(value="csv") + + monkeypatch.setattr(main_module, "DatasetNormalizer", _FakeNormalizer) + + +def test_main_prefers_train_dataset_uri_and_forwards_optional_splits(monkeypatch, tmp_path): + fake_integration = _FakeIntegration() + build_model_spy: dict = {} + normalize_calls: list[tuple] = [] + _patch_main_dependencies(monkeypatch, build_model_spy, normalize_calls) + + result = main_module.main( + intent="predict churn", + train_dataset_uri="s3://bucket/new-train.csv", + data_refs=["s3://bucket/legacy-train.csv"], + val_dataset_uri="s3://bucket/val.csv", + test_dataset_uri="s3://bucket/test.csv", + integration=fake_integration, + spark_mode="local", + work_dir=tmp_path, + user_id="user-1", + experiment_id="exp-1", + enable_final_evaluation=True, + ) + + assert result[0].performance == pytest.approx(0.82) + 
assert fake_integration.workspace_calls == [("exp-1", tmp_path)] + assert normalize_calls == [ + ("train", "s3://bucket/new-train.csv"), + ("val", "s3://bucket/val.csv"), + ("test", "s3://bucket/test.csv"), + ] + assert build_model_spy["kwargs"]["train_dataset_uri"] == "normalized_train.parquet" + assert build_model_spy["kwargs"]["val_dataset_uri"] == "normalized_val.parquet" + assert build_model_spy["kwargs"]["test_dataset_uri"] == "normalized_test.parquet" + + +def test_main_auto_enables_final_evaluation_when_test_dataset_is_provided(monkeypatch, tmp_path): + fake_integration = _FakeIntegration() + build_model_spy: dict = {} + normalize_calls: list[tuple] = [] + _patch_main_dependencies(monkeypatch, build_model_spy, normalize_calls) + + main_module.main( + intent="predict churn", + train_dataset_uri="s3://bucket/train.csv", + test_dataset_uri="s3://bucket/test.csv", + integration=fake_integration, + spark_mode="local", + work_dir=tmp_path, + user_id="user-1", + experiment_id="exp-1", + ) + + assert normalize_calls == [("train", "s3://bucket/train.csv"), ("test", "s3://bucket/test.csv")] + assert build_model_spy["kwargs"]["enable_final_evaluation"] is True + + +def test_main_nn_max_epochs_override_clamps_default_when_only_cap_is_set(monkeypatch, tmp_path): + fake_integration = _FakeIntegration() + build_model_spy: dict = {} + normalize_calls: list[tuple] = [] + _patch_main_dependencies(monkeypatch, build_model_spy, normalize_calls) + + main_module.main( + intent="predict churn", + train_dataset_uri="s3://bucket/train.csv", + nn_max_epochs=5, + integration=fake_integration, + spark_mode="local", + work_dir=tmp_path, + user_id="user-1", + experiment_id="exp-1", + ) + + used_config = build_model_spy["kwargs"]["config"] + assert used_config.nn_max_epochs == 5 + assert used_config.nn_default_epochs == 5 + + +def test_main_uses_data_refs_fallback_when_train_dataset_uri_missing(monkeypatch, tmp_path): + fake_integration = _FakeIntegration() + build_model_spy: dict = {} + normalize_calls: list[tuple] = [] + _patch_main_dependencies(monkeypatch, build_model_spy, normalize_calls) + + main_module.main( + intent="predict churn", + data_refs=["s3://bucket/train-a.csv", "s3://bucket/train-b.csv"], + integration=fake_integration, + spark_mode="local", + work_dir=tmp_path, + user_id="user-1", + experiment_id="exp-1", + ) + + assert fake_integration.workspace_calls == [("exp-1", tmp_path)] + assert normalize_calls == [("train", "s3://bucket/train-a.csv")] + assert build_model_spy["kwargs"]["train_dataset_uri"] == "normalized_train.parquet" + + +def test_main_requires_train_dataset_uri_or_data_refs(monkeypatch, tmp_path): + fake_integration = _FakeIntegration() + build_model_spy: dict = {} + normalize_calls: list[tuple] = [] + _patch_main_dependencies(monkeypatch, build_model_spy, normalize_calls) + + with pytest.raises(ValueError, match="train_dataset_uri is required"): + main_module.main( + intent="predict churn", + train_dataset_uri=None, + data_refs=None, + integration=fake_integration, + spark_mode="local", + work_dir=tmp_path, + user_id="user-1", + experiment_id="exp-1", + ) diff --git a/tests/unit/workflow/test_prepare_data_explicit_splits.py b/tests/unit/workflow/test_prepare_data_explicit_splits.py new file mode 100644 index 00000000..2047dc6c --- /dev/null +++ b/tests/unit/workflow/test_prepare_data_explicit_splits.py @@ -0,0 +1,181 @@ +"""Unit tests for prepare_data split resolution with explicit val/test datasets.""" + +from __future__ import annotations + +from types import SimpleNamespace + 
+import pytest + +pytest.importorskip("dataclasses_json") +pytest.importorskip("pyspark") + +from plexe.models import BuildContext +import plexe.workflow as workflow + + +class _DummyIntegration: + def get_artifact_location(self, artifact_type, dataset_uri, experiment_id, work_dir): # noqa: D401 + _ = dataset_uri, experiment_id + return str(work_dir / ".build" / "data" / artifact_type) + + def ensure_local(self, uris, work_dir): # noqa: D401 + _ = work_dir + return uris + + +def _make_context(tmp_path) -> BuildContext: + context = BuildContext( + user_id="user-1", + experiment_id="exp-1", + dataset_uri="train_input.parquet", + work_dir=tmp_path, + intent="predict churn", + ) + context.output_targets = ["target"] + context.task_analysis = {"recommended_split": {"ratios": {"train": 0.6, "val": 0.2, "test": 0.2}}} + return context + + +def test_prepare_data_uses_all_provided_splits_without_running_splitter(monkeypatch, tmp_path): + context = _make_context(tmp_path) + config = SimpleNamespace(train_sample_size=100, val_sample_size=40) + integration = _DummyIntegration() + calls = {"materialize": [], "sampler": None} + + def _materialize(_spark, dataset_uri, split_name, _context, output_dir): + calls["materialize"].append((split_name, dataset_uri)) + assert output_dir == str(tmp_path / ".build" / "data" / "splits") + return f"copied_{split_name}.parquet" + + class _FailingSplitter: + def __init__(self, *args, **kwargs): + raise AssertionError("Splitter should not be used when val and test are both provided") + + class _FakeSampler: + def __init__(self, *args, **kwargs): + pass + + def run(self, train_uri, val_uri, train_sample_size, val_sample_size, output_dir): + calls["sampler"] = (train_uri, val_uri, train_sample_size, val_sample_size, output_dir) + return "train_sample.parquet", "val_sample.parquet" + + monkeypatch.setattr(workflow, "_materialize_explicit_split", _materialize) + monkeypatch.setattr(workflow, "DatasetSplitterAgent", _FailingSplitter) + monkeypatch.setattr(workflow, "SamplingAgent", _FakeSampler) + monkeypatch.setattr(workflow, "_save_phase_checkpoint", lambda *args, **kwargs: None) + + workflow.prepare_data( + spark=object(), + training_dataset_uri="train_input.parquet", + val_dataset_uri="val_input.parquet", + test_dataset_uri="test_input.parquet", + context=context, + config=config, + integration=integration, + generate_test_set=True, + ) + + assert context.train_uri == "train_input.parquet" + assert context.val_uri == "copied_val.parquet" + assert context.test_uri == "copied_test.parquet" + assert calls["materialize"] == [("val", "val_input.parquet"), ("test", "test_input.parquet")] + assert calls["sampler"][0] == "train_input.parquet" + assert calls["sampler"][1] == "copied_val.parquet" + + +def test_prepare_data_generates_missing_test_when_only_val_is_provided(monkeypatch, tmp_path): + context = _make_context(tmp_path) + config = SimpleNamespace(train_sample_size=100, val_sample_size=40) + integration = _DummyIntegration() + calls = {"split_ratios": None, "split_output_dir": None} + + def _materialize(_spark, dataset_uri, split_name, _context, output_dir): + assert split_name == "val" + assert output_dir == str(tmp_path / ".build" / "data" / "splits") + return "copied_val.parquet" + + class _FakeSplitter: + def __init__(self, *args, **kwargs): + pass + + def run(self, split_ratios, output_dir): + calls["split_ratios"] = split_ratios + calls["split_output_dir"] = output_dir + return "split_train.parquet", "generated_test.parquet", None + + class _FakeSampler: + def 
__init__(self, *args, **kwargs): + pass + + def run(self, train_uri, val_uri, train_sample_size, val_sample_size, output_dir): + return "train_sample.parquet", "val_sample.parquet" + + monkeypatch.setattr(workflow, "_materialize_explicit_split", _materialize) + monkeypatch.setattr(workflow, "DatasetSplitterAgent", _FakeSplitter) + monkeypatch.setattr(workflow, "SamplingAgent", _FakeSampler) + monkeypatch.setattr(workflow, "_save_phase_checkpoint", lambda *args, **kwargs: None) + + workflow.prepare_data( + spark=object(), + training_dataset_uri="train_input.parquet", + val_dataset_uri="val_input.parquet", + test_dataset_uri=None, + context=context, + config=config, + integration=integration, + generate_test_set=True, + ) + + assert calls["split_ratios"] == {"train": 0.8, "val": 0.2} + assert calls["split_output_dir"] == str(tmp_path / ".build" / "data" / "splits" / "generated") + assert context.train_uri == "split_train.parquet" + assert context.val_uri == "copied_val.parquet" + assert context.test_uri == "generated_test.parquet" + + +def test_prepare_data_generates_missing_val_when_only_test_is_provided(monkeypatch, tmp_path): + context = _make_context(tmp_path) + config = SimpleNamespace(train_sample_size=100, val_sample_size=40) + integration = _DummyIntegration() + calls = {"split_ratios": None} + + def _materialize(_spark, dataset_uri, split_name, _context, output_dir): + assert split_name == "test" + assert output_dir == str(tmp_path / ".build" / "data" / "splits") + return "copied_test.parquet" + + class _FakeSplitter: + def __init__(self, *args, **kwargs): + pass + + def run(self, split_ratios, output_dir): + calls["split_ratios"] = split_ratios + return "split_train.parquet", "split_val.parquet", None + + class _FakeSampler: + def __init__(self, *args, **kwargs): + pass + + def run(self, train_uri, val_uri, train_sample_size, val_sample_size, output_dir): + return "train_sample.parquet", "val_sample.parquet" + + monkeypatch.setattr(workflow, "_materialize_explicit_split", _materialize) + monkeypatch.setattr(workflow, "DatasetSplitterAgent", _FakeSplitter) + monkeypatch.setattr(workflow, "SamplingAgent", _FakeSampler) + monkeypatch.setattr(workflow, "_save_phase_checkpoint", lambda *args, **kwargs: None) + + workflow.prepare_data( + spark=object(), + training_dataset_uri="train_input.parquet", + val_dataset_uri=None, + test_dataset_uri="test_input.parquet", + context=context, + config=config, + integration=integration, + generate_test_set=False, + ) + + assert calls["split_ratios"] == {"train": 0.8, "val": 0.2} + assert context.train_uri == "split_train.parquet" + assert context.val_uri == "split_val.parquet" + assert context.test_uri == "copied_test.parquet"
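
For readers of this patch, the three `prepare_data` tests above pin down the intended split-resolution behaviour: explicitly supplied val/test URIs are copied into the build workspace via `_materialize_explicit_split`, the `DatasetSplitterAgent` runs only for whichever split is still missing (with an 80/20 train/val ratio, writing under a `generated/` subdirectory of the splits location), and the `SamplingAgent` always samples from the resolved train/val URIs. The sketch below is a reconstruction of that branch from the test assertions only; it is not the actual `plexe.workflow.prepare_data` body, and the agent constructor arguments and the sampler output directory are assumptions.

# Reviewer sketch: split resolution implied by the tests above. This is a
# reconstruction from the test assertions, not the real prepare_data code;
# agent constructor arguments and the sampler output_dir are assumptions.
from pathlib import Path

import plexe.workflow as workflow


def resolve_explicit_splits(spark, train_uri, val_uri, test_uri, context, config, integration):
    splits_dir = integration.get_artifact_location("splits", train_uri, context.experiment_id, context.work_dir)

    # Copy any explicitly provided split into the build workspace.
    if val_uri is not None:
        val_uri = workflow._materialize_explicit_split(spark, val_uri, "val", context, splits_dir)
    if test_uri is not None:
        test_uri = workflow._materialize_explicit_split(spark, test_uri, "test", context, splits_dir)

    # Invoke the splitter agent only when exactly one of val/test is missing;
    # the case where neither split is provided is not covered by this sketch.
    if (val_uri is None) != (test_uri is None):
        splitter = workflow.DatasetSplitterAgent(spark, train_uri, context, config)  # assumed ctor args
        train_uri, generated_uri, _ = splitter.run(
            split_ratios={"train": 0.8, "val": 0.2},
            output_dir=str(Path(splits_dir) / "generated"),
        )
        if val_uri is None:
            val_uri = generated_uri
        else:
            test_uri = generated_uri

    # Samples are always drawn from the resolved train/val splits.
    sampler = workflow.SamplingAgent(spark, context, config)  # assumed ctor args
    train_sample_uri, val_sample_uri = sampler.run(
        train_uri, val_uri, config.train_sample_size, config.val_sample_size, splits_dir
    )
    return train_uri, val_uri, test_uri, train_sample_uri, val_sample_uri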