From b851b1e8ac2c33e682baabb5fdab1504e03e68ab Mon Sep 17 00:00:00 2001 From: Brandon Walker Date: Sat, 24 Jan 2026 13:39:47 -0600 Subject: [PATCH] Add comprehensive documentation, E2E tests, and CI/CD workflow - Add JSDoc enforcement and comprehensive docstrings across Python and JavaScript - Implement Playwright E2E test suite for UI validation - Add GitHub Actions CI workflow for cross-platform testing - Add full autolog support for PyTorch Lightning and TensorFlow/Keras - Reorganize examples into core, ml_frameworks, and domain_specific categories - Consolidate config files and improve documentation structure - Remove unique constraint on run names for better usability - Update README with improved problem statement and ecosystem comparison --- .depcheckrc.json | 10 - .github/workflows/ci.yml | 204 ++++++ .gitignore | 5 + .pre-commit-config.yaml | 7 +- CODE_OF_CONDUCT.md | 3 +- CONTRIBUTING.md | 2 +- README.md | 156 +++-- artifacta/artifacta/artifacts.py | 104 ++- artifacta/artifacta/autolog.py | 187 ++++- artifacta/artifacta/context.py | 58 +- artifacta/artifacta/emitter.py | 110 ++- artifacta/artifacta/integrations/base.py | 25 +- .../artifacta/integrations/dataset_utils.py | 180 +++++ .../integrations/pytorch_lightning.py | 238 ++++++- artifacta/artifacta/integrations/sklearn.py | 458 +++++++++++++ .../artifacta/integrations/tensorflow.py | 245 ++++++- artifacta/artifacta/integrations/xgboost.py | 641 ++++++++++++++++++ artifacta/artifacta/metadata.py | 73 +- .../artifacta/metadata_extractors/torch.py | 91 ++- artifacta/artifacta/monitor.py | 59 +- artifacta/artifacta/primitives.py | 109 ++- artifacta/artifacta/run.py | 151 ++++- artifacta/artifacta/utils.py | 61 +- artifacta/tests/test_basic.py | 2 +- artifacta/tests/test_sklearn_autolog.py | 504 ++++++++++++++ artifacta/tests/test_utils.py | 319 +++++++++ artifacta/tests/test_xgboost_autolog.py | 515 ++++++++++++++ artifacta_ui/__init__.py | 4 +- eslint.config.js => config/eslint.config.js | 36 + config/jsdoc.json | 32 + config/playwright.config.js | 59 ++ vite.config.js => config/vite.config.js | 8 +- docs/Makefile | 12 + docs/api.rst | 20 +- docs/development.rst | 437 +++++++++++- docs/examples.rst | 201 +++--- docs/index.rst | 4 +- docs/ui-api.rst | 130 ++++ docs/user-guide.rst | 583 ++++++++++++++-- examples/README.md | 34 + examples/core/01_basic_tracking.py | 111 +++ examples/core/02_all_primitives.py | 225 ++++++ .../ab_testing_experiment.py | 38 +- .../domain_specific/protein_expression.py | 348 ++++++++++ examples/{ => ml_frameworks}/pytorch_mnist.py | 115 ++-- .../ml_frameworks/sklearn_classification.py | 248 +++++++ .../tensorflow_regression.py | 155 +++-- examples/ml_frameworks/xgboost_regression.py | 253 +++++++ examples/requirements.txt | 24 + examples/run_all_examples.py | 263 +++++++ index.html | 15 - package.json | 29 +- pyproject.toml | 92 ++- pytest.ini | 16 - .../components/ArtifactTab/ArtifactTab.jsx | 58 ++ .../ArtifactsPanel/ArtifactsPanel.jsx | 137 +++- src/app/components/ChatTab/ChatTab.jsx | 64 ++ src/app/components/LineageTab/LineageTab.jsx | 78 ++- .../LineageTab/lineageNodeFactory.js | 63 +- .../FileAttachmentExtension.js | 18 + .../ProjectNotesTab/FileAttachmentNode.jsx | 31 + .../ProjectNotesTab/ProjectNotesTab.jsx | 106 +++ .../ProjectsPanel/ProjectsPanel.jsx | 61 ++ .../RunSelector/CollapsibleButton.jsx | 19 +- src/app/components/RunSelector/RunFilter.jsx | 60 +- .../components/RunSelector/RunSelector.jsx | 67 +- src/app/components/RunTree/RunTree.jsx | 81 ++- src/app/components/SweepsTab/SweepsTab.jsx | 15 +- src/app/components/layout/Sidebar.jsx | 37 + .../TabbedInterface/TabbedInterface.jsx | 36 +- .../components/ui/ComponentSettingsMenu.jsx | 43 ++ .../visualizations/DraggableVisualization.jsx | 85 ++- .../UniversalVisualizationRenderer.jsx | 31 +- .../visualizations/plots/BarChart.jsx | 56 +- .../visualizations/plots/CurveChart.jsx | 100 ++- .../visualizations/plots/Heatmap.jsx | 58 +- .../visualizations/plots/Histogram.jsx | 55 +- .../visualizations/plots/LinePlot.jsx | 41 +- .../plots/ParallelCoordinatesChart.jsx | 65 +- .../plots/ParameterCorrelationChart.jsx | 19 +- .../visualizations/plots/ScatterPlot.jsx | 33 +- .../visualizations/plots/ViolinPlot.jsx | 37 +- .../visualizations/shared/PlotTooltip.jsx | 71 +- src/app/hooks/useCanvasSetup.js | 83 ++- src/app/hooks/useCanvasTooltip.js | 75 +- src/app/hooks/useDragModeResize.js | 36 +- src/app/hooks/useLayoutManager.js | 76 ++- src/app/hooks/useResponsiveCanvas.js | 61 +- src/app/hooks/useRunData.js | 53 +- .../pages/Workspace/CollapsibleSection.jsx | 8 + src/app/pages/Workspace/PlotSection.jsx | 29 +- src/app/pages/Workspace/Workspace.jsx | 5 + src/app/utils/artifactColors.js | 4 +- src/app/utils/comparisonPlotDiscovery.js | 144 +++- src/app/utils/metricAggregation.js | 135 +++- src/app/utils/plotDiscovery.js | 106 ++- src/app/utils/sweepDetection.js | 97 ++- src/core/api/ApiClient.js | 105 ++- src/core/utils/csvExport.js | 37 +- src/core/utils/formatters.js | 6 +- src/utils/consoleFilter.js | 4 + src/utils/debugLogger.js | 33 +- tests/autolog/test_checkpoint_e2e.py | 8 +- tests/autolog/test_pytorch_lightning.py | 243 ++++++- tests/autolog/test_tensorflow.py | 175 ++++- tests/domains/test_artifact_reuse.py | 20 +- tests/domains/test_pytorch.py | 8 +- tests/domains/test_system.py | 2 +- tests/e2e/core.spec.js | 72 ++ tests/e2e/setup.js | 122 ++++ tests/e2e/teardown.js | 24 + tests/e2e/visualization.spec.js | 118 ++++ tests/generate_all_notebooks.py | 24 +- tests/helpers/api.py | 4 +- tests/helpers/notebook.py | 12 +- tests/helpers/notebook_html.py | 6 +- tests/helpers/video.py | 2 +- tracking-server/cli.py | 71 +- tracking-server/config.py | 54 +- tracking-server/database.py | 97 ++- tracking-server/main.py | 156 ++++- tracking-server/routes/artifacts.py | 117 +++- tracking-server/routes/chat.py | 96 ++- tracking-server/routes/dependencies.py | 21 +- tracking-server/routes/health.py | 14 +- tracking-server/routes/projects.py | 102 ++- tracking-server/routes/runs.py | 128 +++- tracking-server/routes/websocket.py | 129 +++- 128 files changed, 11477 insertions(+), 1079 deletions(-) delete mode 100644 .depcheckrc.json create mode 100644 .github/workflows/ci.yml create mode 100644 artifacta/artifacta/integrations/dataset_utils.py create mode 100644 artifacta/artifacta/integrations/sklearn.py create mode 100644 artifacta/artifacta/integrations/xgboost.py create mode 100644 artifacta/tests/test_sklearn_autolog.py create mode 100644 artifacta/tests/test_utils.py create mode 100644 artifacta/tests/test_xgboost_autolog.py rename eslint.config.js => config/eslint.config.js (59%) create mode 100644 config/jsdoc.json create mode 100644 config/playwright.config.js rename vite.config.js => config/vite.config.js (69%) create mode 100644 docs/ui-api.rst create mode 100644 examples/README.md create mode 100644 examples/core/01_basic_tracking.py create mode 100644 examples/core/02_all_primitives.py rename examples/{ => domain_specific}/ab_testing_experiment.py (94%) create mode 100644 examples/domain_specific/protein_expression.py rename examples/{ => ml_frameworks}/pytorch_mnist.py (85%) create mode 100644 examples/ml_frameworks/sklearn_classification.py rename examples/{ => ml_frameworks}/tensorflow_regression.py (82%) create mode 100644 examples/ml_frameworks/xgboost_regression.py create mode 100644 examples/requirements.txt create mode 100644 examples/run_all_examples.py delete mode 100644 index.html delete mode 100644 pytest.ini create mode 100644 tests/e2e/core.spec.js create mode 100644 tests/e2e/setup.js create mode 100644 tests/e2e/teardown.js create mode 100644 tests/e2e/visualization.spec.js diff --git a/.depcheckrc.json b/.depcheckrc.json deleted file mode 100644 index 6a635ca..0000000 --- a/.depcheckrc.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "ignores": [ - "knip" - ], - "ignore-dirs": [ - "dist", - "node_modules", - ".git" - ] -} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..13fec76 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,204 @@ +name: CI + +on: + push: + branches: [ main, docs/* ] + pull_request: + branches: [ main ] + +jobs: + test: + name: Test Python ${{ matrix.python-version }} on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ['3.9', '3.10', '3.11', '3.12'] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e '.[dev]' + + - name: Build UI + run: npm install && npm run build + + - name: Start Artifacta server (Unix) + if: runner.os != 'Windows' + run: | + mkdir -p data + nohup artifacta ui --port 8000 > server.log 2>&1 & + echo $! > server.pid + + # Wait for server to be ready + echo "Waiting for server to start..." + SERVER_READY=false + for i in {1..30}; do + if curl -s http://localhost:8000/health > /dev/null 2>&1; then + echo "βœ“ Server is ready" + SERVER_READY=true + break + fi + echo "Waiting for server... ($i/30)" + sleep 1 + done + + # Fail if server never started + if [ "$SERVER_READY" = "false" ]; then + echo "βœ— Server failed to start after 30 seconds" + echo "Server log:" + cat server.log + exit 1 + fi + + - name: Start Artifacta server (Windows) + if: runner.os == 'Windows' + shell: pwsh + run: | + if (-not (Test-Path data)) { New-Item -ItemType Directory -Path data } + $process = Start-Process python -ArgumentList "-m","tracking_server.cli","ui","--port","8000" -PassThru -WindowStyle Hidden + Start-Sleep -Seconds 2 + + # Health check with retries + $maxAttempts = 5 + $attempt = 0 + $success = $false + while ($attempt -lt $maxAttempts -and -not $success) { + try { + $response = Invoke-WebRequest -Uri "http://127.0.0.1:8000/health" -UseBasicParsing -TimeoutSec 2 + if ($response.StatusCode -eq 200) { + $success = $true + Write-Host "Server is ready" + } + } catch { + $attempt++ + if ($attempt -lt $maxAttempts) { + Start-Sleep -Seconds 3 + } + } + } + + if (-not $success) { + Write-Host "Server failed to start after $maxAttempts attempts" + exit 1 + } + + - name: Run pytest + run: pytest tests/ -v --tb=short + env: + TRACKING_SERVER_HOST: ${{ runner.os == 'Windows' && '127.0.0.1' || 'localhost' }} + + - name: Stop Artifacta server (Unix) + if: always() && runner.os != 'Windows' + run: | + if [ -f server.pid ]; then + kill $(cat server.pid) || true + rm server.pid + fi + + - name: Stop Artifacta server (Windows) + if: always() && runner.os == 'Windows' + shell: pwsh + run: | + Get-Process python -ErrorAction SilentlyContinue | Where-Object { $_.CommandLine -like '*tracking_server*' } | Stop-Process -Force -ErrorAction SilentlyContinue + Write-Host "Server cleanup completed" + + e2e: + name: E2E Tests on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + + - name: Install Python dependencies + run: | + python -m pip install --upgrade pip + pip install -e '.[dev]' + + - name: Install Node dependencies and build UI + run: npm install && npm run build + + - name: Install Playwright browsers + run: npx playwright install --with-deps chromium + + - name: Run E2E tests + run: npm run test:e2e + env: + ARTIFACTA_URL: ${{ runner.os == 'Windows' && 'http://127.0.0.1:8000' || 'http://localhost:8000' }} + + lint: + name: Lint + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e '.[dev]' + + - name: Run ruff + run: ruff check . --fix + + - name: Run mypy + run: mypy --ignore-missing-imports tracking-server + + build: + name: Build Package + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install build tools + run: | + python -m pip install --upgrade pip + pip install build + + - name: Build package + run: python -m build + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/ diff --git a/.gitignore b/.gitignore index 42eb60f..360f883 100644 --- a/.gitignore +++ b/.gitignore @@ -63,3 +63,8 @@ logs/ docs/_build/ docs/_static/ docs/_templates/ + +# Playwright test outputs +test-results/ +playwright-report/ +playwright/.cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 444f419..5db72d3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,16 +29,19 @@ repos: - --convention=google files: ^(artifacta|tracking-server)/ - # ESLint - JavaScript/React linter + # ESLint - JavaScript/React linter with JSDoc enforcement - repo: https://github.com/pre-commit/mirrors-eslint rev: v9.17.0 hooks: - id: eslint files: \.[jt]sx?$ types: [file] + args: [--config, config/eslint.config.js] additional_dependencies: - eslint@9.17.0 - eslint-plugin-react@7.37.2 + - eslint-plugin-react-hooks@5.1.0 + - eslint-plugin-jsdoc@50.6.1 # Knip - Find unused files, exports, and dependencies (JavaScript/TypeScript) - repo: local @@ -51,7 +54,7 @@ repos: files: \.[jt]sx?$ - id: depcheck name: depcheck - unused npm dependencies - entry: npx depcheck --ignore-dirs=dist,node_modules,.git + entry: npx depcheck language: system pass_filenames: false files: package\.json diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 9035737..91dd305 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -3,7 +3,7 @@ ## Our Pledge -We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation. +We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. @@ -82,4 +82,3 @@ For answers to common questions about this code of conduct, see the FAQ at [http [Mozilla CoC]: https://github.com/mozilla/diversity [FAQ]: https://www.contributor-covenant.org/faq [translations]: https://www.contributor-covenant.org/translations - diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 42e4d01..a6b63e4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,6 @@ # Contributing to Artifacta -Thank you for your interest in contributing to Artifacta! πŸŽ‰ +Thank you for your interest in contributing to Artifacta! ## How to Contribute diff --git a/README.md b/README.md index 5331e59..1e94ee4 100644 --- a/README.md +++ b/README.md @@ -12,46 +12,65 @@ --- -## 🎯 The Problem +## The Problem Modern data science and machine learning workflows involve countless experimentsβ€”tweaking hyperparameters, adjusting data preprocessing, testing different architectures, updating dependencies, modifying code. **Every change produces different results**, but tracking and comparing these variations manually becomes overwhelming: -- πŸ“‹ Which parameters, environment, or code version led to that breakthrough result last week? -- πŸ” How does changing the learning rate affect convergence across multiple runs? -- πŸ“Š What's the actual performance difference between model architectures? -- πŸ€” Which preprocessing steps improved accuracy by 2%? -- πŸ”§ Did upgrading that dependency break model performance? -- πŸ’» What code changes caused the regression? +- Which parameters, environment, or code version led to that breakthrough result last week? +- How does changing the learning rate affect convergence across multiple runs? +- What's the actual performance difference between model architectures? +- Which preprocessing steps improved accuracy by 2%? +- Did upgrading that dependency break model performance? +- What code changes caused the regression? -Without systematic tracking of **parameters, metrics, code changes, dependencies, and environment**, you're flying blindβ€”relying on scattered notes, terminal output, and memory. **Artifacta solves this** by automatically capturing experiments, configurations, code versions, and artifacts in one place with intelligent visualization. +Without systematic tracking of **parameters, metrics, code changes, dependencies, and environment**, you're flying blindβ€”relying on scattered notes, terminal output, and memory. + +But even with tracking, making sense of the data is hard: + +- Manually comparing metrics across dozens of experiments +- Spotting patterns in hyperparameter sweeps +- Understanding why one approach outperformed another +- Deciding which direction to explore next + +**Artifacta solves both problems** by automatically capturing experiments AND helping you understand themβ€”with intelligent visualizations, multi-run comparisons, and LLM-powered analysis to explain what happened and why. --- -## 🌍 Ecosystem & Alternatives +## Ecosystem & Alternatives -Artifacta is part of a growing ecosystem of experiment tracking tools. Popular alternatives include: +| Feature | Artifacta | MLflow | W&B | Neptune | Comet | +|---------|-----------|--------|-----|---------|-------| +| Zero-config install | βœ… | ❌ | ❌ | ❌ | ❌ | +| 100% offline | βœ… | βœ… | ❌ | ❌ | ❌ | +| Built-in lab notebook | βœ… | ❌ | ❌ | ❌ | ❌ | +| AI assistant (LiteLLM) | βœ… | ❌ | ❌ | ❌ | ❌ | +| Domain-agnostic | βœ… | βœ… | ❌ | ⚠️ | ❌ | +| Framework autolog | βœ… | βœ… | βœ… | βœ… | βœ… | +| Team collaboration | ❌ | ⚠️ | βœ… | βœ… | βœ… | +| Model deployment | ❌ | βœ… | βœ… | βœ… | βœ… | -- [**MLflow**](https://mlflow.org/) - Open-source platform from Databricks for ML lifecycle management -- [**Weights & Biases**](https://wandb.ai/) - Cloud-first experiment tracking with team collaboration features -- [**Neptune.ai**](https://neptune.ai/) - Metadata store for MLOps with extensive integrations -- [**Comet ML**](https://www.comet.com/) - ML platform with experiment tracking and model production monitoring +**Legend:** βœ… Yes | ⚠️ Partial | ❌ No -**Why Artifacta?** We focus on **automatic visualization discovery**, **domain-agnostic tracking** (not just ML), and **simple self-hosting** with a pre-built UI. No heavy dependencies, no mandatory cloud servicesβ€”just install and start tracking. +For detailed feature comparisons (deployment, visualization, integrations), see the [full documentation](https://docs.artifacta.ai). --- -## ✨ Key Features +## Why Choose Artifacta? + +**What makes Artifacta different:** -- 🌐 **Domain-agnostic** - Track any experiment comparing parameters, data, and outcomes -- πŸ“Š **Automatic visualization** - Plots discovered from logged data structure -- πŸ”— **Artifact tracking** - Track datasets, models, code, and results with full provenance -- πŸ”„ **Multi-run comparison** - Overlay time series and curves for easy comparison -- 🎯 **Hyperparameter analysis** - Automatically detect and analyze parameter impact on outcomes -- πŸ’¬ **AI assistant** - Chat interface for experiment insights (OpenAI, Anthropic, local LLMs) +- **Zero configuration**: Pre-built UI bundled with Python packageβ€”`pip install` and you're done. No Node.js, Docker, or build tools required +- **Truly offline-first**: Works 100% locally without any cloud dependencies, license servers, or internet connection +- **Server-side plot generation**: Log data primitives (Series, Scatter, Matrix), not matplotlib figuresβ€”Artifacta renders plots for you. No need to create and upload images (though you can if you want) +- **Built-in electronic lab notebook**: Rich text editor with LaTeX support, file attachments, and per-project organizationβ€”not available in any competitor +- **AI chat interface**: Built-in LLM chat (OpenAI, Anthropic, local models) to analyze experiments, results, and code. W&B and Comet have AI features in premium tiers only +- **Domain-agnostic design**: Primitives work for any fieldβ€”ML, A/B tests, physics, finance, genomics, climate science. Not ML-only like most alternatives +- **Rich artifact previews**: Built-in viewers for video, audio, PDFs, code, images. MLflow only previews images; others require external viewers +- **Interactive artifact lineage**: Visual flow graph showing how artifacts relate. MLflow has no lineage visualization --- -## 🎨 Visual Overview +## Visual Overview **Automatic Plot Discovery** @@ -67,7 +86,7 @@ Browse and preview datasets, models, code, images, videos, and documents with bu --- -## πŸš€ Quick Start +## Quick Start ### Installation @@ -79,24 +98,6 @@ pip install artifacta That's it! The UI is pre-built and bundled. No Node.js required. -#### Development Installation - -If you want to contribute or modify the source: - -**Prerequisites:** Python 3.9+, Node.js 16+ - -```bash -# Clone the repository -git clone https://github.com/walkerbdev/artifacta.git -cd artifacta - -# Build UI from source -npm install && npm run build - -# Install Python package in editable mode -pip install -e . -``` - ### Start Tracking Server ```bash @@ -120,10 +121,10 @@ artifacta ui --dev ### Log Your First Experiment ```python -import artifacta as ds +from artifacta import Series, init, log # Initialize a run -run = ds.init( +run = init( project="my-project", name="experiment-1", config={"learning_rate": 0.001, "batch_size": 32} @@ -133,7 +134,7 @@ run = ds.init( for epoch in range(10): train_loss = train_model() # Your training code - ds.log("metrics", ds.Series( + log("metrics", Series( index="epoch", fields={ "train_loss": [train_loss], @@ -147,7 +148,7 @@ run.log_artifact("model.pt", "path/to/model.pt") --- -## πŸ“š Documentation +## Documentation Full documentation available at: [User Guide](docs/user-guide.rst) @@ -155,7 +156,8 @@ Build and serve docs locally: ```bash pip install artifacta[dev] -cd docs && make html +cd docs +make html # Generates both JSDoc (UI) and Sphinx (Python) docs python -m http.server 8001 --directory _build/html ``` @@ -163,7 +165,7 @@ Then navigate to http://localhost:8001 --- -## πŸ“Š Core Primitives +## Core Primitives Artifacta provides rich primitives for structured logging: @@ -179,7 +181,7 @@ All primitives are automatically visualized in the Plots tab. --- -## πŸ’» Web UI Features +## Web UI Features - **Plots** - Auto-generated visualizations with multi-run overlay - **Sweeps** - Hyperparameter analysis with parallel coordinates @@ -191,39 +193,47 @@ All primitives are automatically visualized in the Plots tab. --- -## πŸ’‘ Examples +## Examples + +Examples are organized by category in [examples/](examples/): -See [examples/](examples/) for runnable examples: +**Core examples** ([examples/core/](examples/core/)): +- Basic tracking - Metrics, parameters, and artifacts +- All primitives - Series, Scatter, Distribution, Matrix, Bar, and more -- **PyTorch MNIST** - Image classification with autolog -- **TensorFlow Regression** - Time series forecasting -- **A/B Testing** - Conversion rate analysis with statistical tests +**ML frameworks** ([examples/ml_frameworks/](examples/ml_frameworks/)): +- PyTorch (MNIST image classification) +- TensorFlow/Keras (regression) +- scikit-learn (classification) +- XGBoost (regression) -Additional domain examples available in [tests/domains/](tests/domains/): +**Domain-specific** ([examples/domain_specific/](examples/domain_specific/)): +- A/B testing with statistical analysis +- Protein expression analysis -- Climate modeling, Computer vision, Finance, Genomics, Physics, Robotics, and more +**Additional domains** - 14 examples in [tests/domains/](tests/domains/): +- Climate, Computer vision, Finance, Genomics, Physics, Robotics, Audio/Video, and more **Run examples:** ```bash +# Linux/macOS source venv/bin/activate -python examples/ab_testing.py +python examples/core/01_basic_tracking.py +python examples/ml_frameworks/pytorch_mnist.py +python examples/domain_specific/ab_testing_experiment.py + +# Windows (PowerShell) +venv\Scripts\Activate.ps1 +python examples/core/01_basic_tracking.py +python examples/ml_frameworks/pytorch_mnist.py +python examples/domain_specific/ab_testing_experiment.py + +# Windows (cmd) +venv\Scripts\activate.bat +python examples/core/01_basic_tracking.py +python examples/ml_frameworks/pytorch_mnist.py +python examples/domain_specific/ab_testing_experiment.py ``` --- - -## πŸ§ͺ Running Tests - -Start the tracking server in one terminal: - -```bash -source venv/bin/activate -artifacta ui -``` - -Run tests in another terminal: - -```bash -source venv/bin/activate -pytest tests/ -``` diff --git a/artifacta/artifacta/artifacts.py b/artifacta/artifacta/artifacts.py index 8be55f8..1337a3c 100644 --- a/artifacta/artifacta/artifacts.py +++ b/artifacta/artifacta/artifacts.py @@ -1,9 +1,48 @@ -"""Artifact file collection utilities.""" +"""Artifact file collection and metadata extraction utilities. + +This module handles the discovery, analysis, and metadata extraction of artifact files. +It provides a unified interface for collecting both individual files and entire directories, +with intelligent MIME type detection and optional content inlining for small text files. + +Architecture: + The module operates in two main modes: + + 1. Single file collection: Extract metadata from one file + 2. Directory collection: Recursively discover and process all files + + Both modes produce a consistent data structure that's agnostic to the artifact type, + allowing downstream systems to handle any file uniformly. + +Key Features: + - MIME type detection: Uses Python's mimetypes library with fallback heuristics + - Content inlining: Optionally embeds small text files (< 100KB by default) + - Text detection: Multi-stage approach (MIME type + read test + encoding detection) + - Metadata tagging: Automatic classification (code, image, tabular, etc.) + - Directory traversal: Recursive with hidden file filtering + +MIME Detection Algorithm: + 1. Try mimetypes.guess_type() based on file extension + 2. If None, attempt to read first 1KB as UTF-8 text + 3. If successful -> "text/plain", if fails -> "application/octet-stream" + 4. Check if MIME type is text-like (text/*, application/json, etc.) + +Content Inlining Strategy: + - Only inline text files (not binary) + - Only if file size <= max_inline_size (default 100KB) + - If read fails (encoding errors), mark as non-text + - This avoids loading large files or binary data into memory + +File Type Classification: + - Code: Based on CODE_EXTENSIONS set (40+ programming languages) + - Image: Based on MIME type starting with "image/" + - Tabular: Based on MIME type or .csv extension + - This metadata helps the UI render appropriate previews +""" import json import mimetypes from pathlib import Path -from typing import Any +from typing import Any, Dict, Union # File extensions that indicate code files (for hash.code tag detection) CODE_EXTENSIONS = { @@ -38,22 +77,53 @@ def collect_files( - path: str | Path, include_content: bool = False, max_inline_size: int = 100_000 -) -> dict[str, Any]: - """Collect file metadata from a path (file or directory). - - Returns unified structure for all artifact types - agnostic to content. + path: Union[str, Path], include_content: bool = False, max_inline_size: int = 100_000 +) -> Dict[str, Any]: + """Collect file metadata from a path (file or directory) with recursive traversal. + + This is the main entry point for artifact collection. It handles both single files + and entire directory trees, producing a unified data structure for storage in the + tracking server. + + Collection Algorithm: + 1. Validate path exists (raise FileNotFoundError if not) + 2. Determine collection mode: + - File mode: Process single file with its name as relative path + - Directory mode: Recursively discover all files via rglob("*") + 3. For each file: + - Skip hidden files (starting with ".") + - Extract full metadata via _extract_file_info() + - Accumulate total size + 4. Return unified structure with files list and summary statistics + + Directory traversal: + - Uses Path.rglob("*") for recursive globbing (depth-first) + - Filters out directories (only collects actual files) + - Sorts file paths for deterministic ordering + - Computes relative paths from directory root for portability + + Why unified structure: + - Same format for single file vs directory artifacts + - Downstream code doesn't need to special-case different artifact types + - Easy to serialize to JSON for database storage + - Frontend can render both cases with same component Args: - path: Path to file or directory - include_content: Whether to inline text file content - max_inline_size: Maximum file size (bytes) to inline + path: Path to file or directory to collect + include_content: Whether to inline text file content (default False) + Set True for code artifacts, False for large model checkpoints + max_inline_size: Maximum file size in bytes to inline (default 100KB) + Files larger than this are never inlined, even if text Returns: - dict with: - - files: List of file dicts with path, mime_type, content, metadata - - total_files: Count of files - - total_size: Total size in bytes + Dictionary with: + - files: List of file dictionaries with path, mime_type, content, metadata + - total_files: Count of files collected (int) + - total_size: Total size in bytes across all files (int) + + Raises: + FileNotFoundError: If path does not exist + ValueError: If path is neither file nor directory (e.g., socket, device) """ path_obj = Path(path) @@ -89,7 +159,7 @@ def collect_files( def _extract_file_info( abs_path: Path, rel_path: str, include_content: bool, max_inline_size: int -) -> dict[str, Any]: +) -> Dict[str, Any]: """Extract metadata and optional content from a single file.""" # Detect MIME type mime_type, _ = mimetypes.guess_type(str(abs_path)) @@ -144,11 +214,11 @@ def _extract_file_info( return file_info -def files_to_json(files_data: dict[str, Any]) -> str: +def files_to_json(files_data: Dict[str, Any]) -> str: """Convert files data structure to JSON string for storage.""" return json.dumps(files_data, indent=None, separators=(",", ":")) -def json_to_files(json_str: str) -> dict[str, Any]: +def json_to_files(json_str: str) -> Dict[str, Any]: """Parse JSON string back to files data structure.""" return json.loads(json_str) diff --git a/artifacta/artifacta/autolog.py b/artifacta/artifacta/autolog.py index d66b739..57b4e08 100644 --- a/artifacta/artifacta/autolog.py +++ b/artifacta/artifacta/autolog.py @@ -1,7 +1,83 @@ -"""Automatic logging for ML frameworks. - -Enables automatic checkpoint logging and metadata extraction for -PyTorch Lightning, TensorFlow/Keras, and other ML frameworks. +"""Automatic logging integration for ML frameworks. + +This module provides a unified autolog() interface that automatically integrates +Artifacta with popular ML frameworks (scikit-learn, XGBoost, LightGBM, PyTorch Lightning, TensorFlow/Keras). +Once enabled, Artifacta automatically captures model checkpoints, parameters, metrics, and +framework-specific metadata without requiring explicit logging calls in user code. + +Architecture: + The module acts as a facade/dispatcher: + + 1. User calls autolog() with optional framework parameter + 2. If framework=None, auto-detect via import checks + 3. Dispatch to framework-specific integration module + 4. Integration module patches framework classes (callbacks, hooks) + 5. Track enabled state in _AUTOLOG_ENABLED global dict + +Auto-Detection Algorithm: + The _detect_framework() function uses import-based detection: + 1. Try import sklearn -> return "sklearn" + 2. Try import xgboost -> return "xgboost" + 3. Try import lightgbm -> return "lightgbm" + 4. Try import pytorch_lightning -> return "pytorch" + 5. Try import tensorflow -> return "tensorflow" + 6. If all fail -> raise RuntimeError + + Why this order: + - Traditional ML frameworks (sklearn, xgboost, lightgbm) checked first (most common) + - PyTorch Lightning checked after (more specific than PyTorch) + - TensorFlow checked last + - Import-based detection is fast and reliable (no version parsing) + +Integration Strategy (per framework): + + Scikit-learn: + - Patches fit() methods of all sklearn estimators + - Logs parameters via get_params(deep=True) + - Computes and logs training metrics (accuracy, precision, recall, F1, etc.) + - Saves fitted model as pickle artifact + + XGBoost: + - Patches xgboost.train() and sklearn API + - Uses callbacks to log metrics per iteration + - Logs feature importance plots + - Handles early stopping + + LightGBM: + - Patches lightgbm.train() and sklearn API + - Uses callbacks to log metrics per iteration + - Logs feature importance plots + - Handles early stopping + + PyTorch Lightning: + - Registers a global callback in CALLBACK_REGISTRY + - Callback hooks into on_save_checkpoint() + - Automatically logs checkpoint file as artifact + - Includes metadata: epoch, global_step, checkpoint path + + TensorFlow/Keras: + - Patches tf.keras.callbacks.ModelCheckpoint class + - Wraps on_epoch_end() to intercept checkpoint saves + - Automatically logs checkpoint file as artifact + - Includes metadata: epoch, model architecture, optimizer config + +State Management: + - _AUTOLOG_ENABLED: Global dict tracking which frameworks are enabled + - Prevents duplicate patching (idempotent) + - Allows selective disabling via disable() + +Why separate integration modules: + - Keeps framework dependencies optional (import only when needed) + - Each framework has unique patching strategy + - Easy to add new frameworks without modifying core autolog + - Allows framework-specific configuration options + +Design Philosophy: + - Zero-friction: One call (autolog()) enables all automatic logging + - Framework-agnostic: Same API works across different ML frameworks + - Non-invasive: No changes to user training code required + - Fail-safe: Missing framework dependencies raise clear errors + - Reversible: disable() removes all patches and restores original behavior """ from typing import Optional @@ -11,23 +87,30 @@ def autolog( framework: Optional[str] = None, - log_checkpoints: bool = True, + log_models: bool = True, + log_metrics: bool = True, + log_params: bool = True, ): """Enable automatic logging for ML frameworks. - Automatically logs model checkpoints as artifacts during training. - All checkpoints are logged with metadata (epoch, step, framework info). + Automatically logs parameters, metrics, and models during training. + Works with scikit-learn, XGBoost, LightGBM, PyTorch Lightning, and TensorFlow/Keras. Args: - framework: "pytorch", "tensorflow", or None (auto-detect) - log_checkpoints: Automatically log model checkpoints + framework: Framework name ("sklearn", "xgboost", "lightgbm", "pytorch", "tensorflow") + or None for auto-detection + log_models: Automatically log trained models as artifacts + log_metrics: Automatically log training metrics + log_params: Automatically log model parameters/hyperparameters Example: >>> import artifacta as ds - >>> ds.autolog() # Auto-detect framework and log checkpoints + >>> ds.autolog() # Auto-detect framework >>> - >>> # Disable checkpoint logging - >>> ds.autolog(log_checkpoints=False) + >>> # Scikit-learn + >>> from sklearn.ensemble import RandomForestClassifier + >>> clf = RandomForestClassifier() + >>> clf.fit(X_train, y_train) # Params, metrics, model auto-logged >>> >>> # PyTorch Lightning >>> trainer = pl.Trainer(...) @@ -35,9 +118,6 @@ def autolog( >>> >>> # TensorFlow/Keras >>> model.fit(X_train, y_train) # Checkpoints auto-logged every epoch - - Note: - Autolog only captures checkpoints. Use ds.log() to log metrics for visualization. """ global _AUTOLOG_ENABLED @@ -45,28 +125,64 @@ def autolog( if framework is None: framework = _detect_framework() - if framework == "pytorch": + if framework == "sklearn": + from artifacta.integrations import sklearn + + sklearn.enable_autolog( + log_models=log_models, + log_training_metrics=log_metrics, + ) + _AUTOLOG_ENABLED["sklearn"] = True + + elif framework == "xgboost": + from artifacta.integrations import xgboost + + xgboost.enable_autolog(log_models=log_models) + _AUTOLOG_ENABLED["xgboost"] = True + + elif framework == "lightgbm": + # TODO: Implement LightGBM autolog + raise NotImplementedError("LightGBM autolog coming soon") + + elif framework == "pytorch": from artifacta.integrations import pytorch_lightning - pytorch_lightning.enable_autolog(log_checkpoints=log_checkpoints) + pytorch_lightning.enable_autolog(log_checkpoints=log_models) _AUTOLOG_ENABLED["pytorch"] = True elif framework == "tensorflow": from artifacta.integrations import tensorflow - tensorflow.enable_autolog(log_checkpoints=log_checkpoints) + tensorflow.enable_autolog(log_checkpoints=log_models) _AUTOLOG_ENABLED["tensorflow"] = True else: - raise ValueError(f"Unsupported framework: {framework}. Supported: 'pytorch', 'tensorflow'") + raise ValueError( + f"Unsupported framework: {framework}. " + f"Supported: 'sklearn', 'xgboost', 'lightgbm', 'pytorch', 'tensorflow'" + ) - print(f"βœ“ Artifacta autolog enabled for {framework}") + print(f"Artifacta autolog enabled for {framework}") def disable(): """Disable autolog for all frameworks.""" global _AUTOLOG_ENABLED + if "sklearn" in _AUTOLOG_ENABLED: + from artifacta.integrations import sklearn + + sklearn.disable_autolog() + + if "xgboost" in _AUTOLOG_ENABLED: + from artifacta.integrations import xgboost + + xgboost.disable_autolog() + + if "lightgbm" in _AUTOLOG_ENABLED: + # TODO: Implement LightGBM disable + pass + if "pytorch" in _AUTOLOG_ENABLED: from artifacta.integrations import pytorch_lightning @@ -78,11 +194,35 @@ def disable(): tensorflow.disable_autolog() _AUTOLOG_ENABLED = {} - print("βœ“ Artifacta autolog disabled") + print("Artifacta autolog disabled") def _detect_framework(): - """Auto-detect which ML framework is installed.""" + """Auto-detect which ML framework is installed. + + Priority order: sklearn, xgboost, lightgbm, pytorch, tensorflow + """ + try: + import sklearn # noqa: F401 + + return "sklearn" + except ImportError: + pass + + try: + import xgboost # noqa: F401 + + return "xgboost" + except ImportError: + pass + + try: + import lightgbm # noqa: F401 + + return "lightgbm" + except ImportError: + pass + try: import pytorch_lightning # noqa: F401 @@ -99,5 +239,6 @@ def _detect_framework(): raise RuntimeError( "Could not detect ML framework. " - "Install pytorch-lightning or tensorflow, or specify framework explicitly." + "Install scikit-learn, xgboost, lightgbm, pytorch-lightning, or tensorflow, " + "or specify framework explicitly." ) diff --git a/artifacta/artifacta/context.py b/artifacta/artifacta/context.py index 101aaaa..fb57e0a 100644 --- a/artifacta/artifacta/context.py +++ b/artifacta/artifacta/context.py @@ -1,4 +1,60 @@ -"""Context providers - automatically detect environment and add tags.""" +"""Context providers for automatic environment detection and tagging. + +This module implements a plugin-based system for detecting the execution context +(git repository, Jupyter notebook, Docker container) and automatically adding +relevant metadata tags to runs. This enables run reproducibility and environment +tracking without requiring manual configuration. + +Architecture: + The module follows a provider pattern with a registry: + + 1. Base class (ContextProvider): Defines interface (in_context, tags) + 2. Concrete providers: GitContext, NotebookContext, DockerContext + 3. Registry (CONTEXT_PROVIDERS): List of enabled providers + 4. Collector (collect_context_tags): Iterates over providers and aggregates tags + +Detection Strategies: + + Git Detection: + - Run 'git rev-parse HEAD' command to check if in git repository + - If succeeds, we're in a git repo; if fails, we're not + - Extract: commit hash, branch name, remote URL, dirty status + - Capture diff if dirty (uncommitted changes exist) + - All git commands use 5-second timeout to avoid hangs + - stderr redirected to DEVNULL to suppress error messages + + Notebook Detection: + - Check if 'ipykernel' or 'IPython' modules are loaded in sys.modules + - This works for Jupyter, JupyterLab, Google Colab, VSCode notebooks + - Tag: source.type = "NOTEBOOK" + - Future: Could extract cell number, notebook path via IPython API + + Docker Detection: + - Check if /.dockerenv file exists (created by Docker runtime) + - This is the most reliable cross-platform Docker detection method + - Tag: docker.container = "true" + - Future: Read DOCKER_IMAGE, DOCKER_TAG from environment variables + +Tag Aggregation: + collect_context_tags() iterates over all registered providers: + 1. Call provider.in_context() to check if context is active + 2. If True, call provider.tags() to get tag dictionary + 3. Merge tags into aggregated dictionary (later providers override earlier) + 4. If any provider raises exception, silently continue (graceful degradation) + 5. Return merged tags dictionary + +Why only Git is enabled by default: + The CONTEXT_PROVIDERS registry only includes GitContext by default. + NotebookContext and DockerContext are defined but not registered to avoid + polluting tags unnecessarily. Users running in notebooks/Docker likely + want to track this explicitly, not automatically. + +Design Philosophy: + - Zero configuration: Automatic detection, no setup required + - Fail-safe: All detection wrapped in try/except, never crash user code + - Extensible: Easy to add new providers (CI/CD, Kubernetes, Slurm, etc.) + - Minimal overhead: Only run detection once at run initialization +""" import os import subprocess diff --git a/artifacta/artifacta/emitter.py b/artifacta/artifacta/emitter.py index bd02217..bb59c60 100644 --- a/artifacta/artifacta/emitter.py +++ b/artifacta/artifacta/emitter.py @@ -1,4 +1,39 @@ -"""HTTP emitter for metrics - MLflow pattern.""" +"""HTTP emitter for real-time metrics and artifact transmission. + +This module implements the client-side HTTP emitter that sends run data, metrics, +and artifacts to the tracking server in real-time. The push-based design enables +immediate visualization and reduces architectural complexity compared to +file-watching approaches. + +Architecture: + The emitter acts as a reliable HTTP client with graceful degradation: + + 1. Initialization: Health check against tracking server + 2. Run Creation: POST to /api/runs to create database entry + 3. Data Emission: POST to /api/runs/{run_id}/data for real-time metrics + 4. Artifact Registration: POST to /api/artifacts for file metadata + 5. WebSocket Integration: Server broadcasts emissions to connected clients + +Graceful Degradation: + The emitter handles network failures gracefully to avoid blocking training: + + - If health check fails -> disable HTTP emission, warn user, continue locally + - If data emission fails -> fail silently, don't block training loop + - If artifact emission fails -> warn user but continue + - Strict mode (for tests) -> raise exceptions instead of degrading + + This ensures that network issues never crash the user's training job. + +Connection Management: + - Uses requests.Session for connection pooling (HTTP keep-alive) + - Configurable timeouts (2s for health/data, 5s for init/artifacts) + - Session headers set once at initialization + - Explicit close() method for cleanup + +Environment Variables: + - ARTIFACTA_API_URL: Base URL of tracking server (e.g., http://localhost:8000) + - ARTIFACTA_STRICT_MODE: Enable strict mode for testing (raise exceptions on errors) +""" import os from typing import Any, Dict, Optional @@ -7,7 +42,7 @@ class HTTPEmitter: - """Emit metrics directly to API Gateway (MLflow/W&B pattern). + """Emit metrics directly to API Gateway for real-time tracking. Metrics are sent in real-time to the API server, enabling: - Immediate visualization in UI (via WebSocket) @@ -18,14 +53,38 @@ class HTTPEmitter: """ def __init__(self, run_id: str, api_url: Optional[str] = None): - """Initialize HTTP emitter. + """Initialize HTTP emitter with health check and connection setup. + + Initialization algorithm: + 1. Store run_id for all subsequent API calls + 2. Resolve api_url from parameter or ARTIFACTA_API_URL environment variable + 3. Create persistent requests.Session for connection pooling (HTTP keep-alive) + 4. Set Content-Type header once for all requests + 5. Check strict_mode from environment (for testing vs production behavior) + 6. Perform health check via GET /health with 2-second timeout + 7. If health check fails: + - Strict mode: Raise RuntimeError (tests should fail fast) + - Normal mode: Print warning, disable emitter, continue locally + + The health check ensures the tracking server is available before attempting + any data emissions. This avoids timeout delays on every emit call if the + server is down. + + Connection pooling via Session: + - Reuses TCP connections across multiple requests + - Reduces latency (no handshake overhead per request) + - Automatically handles keep-alive headers Args: - run_id: Run ID for this emitter. - api_url: Base URL of tracking server. If not provided, uses ARTIFACTA_API_URL environment variable. + run_id: Run ID for this emitter (links all emissions to this run) + api_url: Base URL of tracking server (e.g., http://localhost:8000). + Defaults to ARTIFACTA_API_URL env var or http://localhost:8000. + + Raises: + RuntimeError: If health check fails and ARTIFACTA_STRICT_MODE is enabled """ self.run_id = run_id - self.api_url = api_url or os.getenv("ARTIFACTA_API_URL") + self.api_url = api_url or os.getenv("ARTIFACTA_API_URL", "http://localhost:8000") self.session = requests.Session() self.session.headers.update({"Content-Type": "application/json"}) self.enabled = True @@ -36,13 +95,13 @@ def __init__(self, run_id: str, api_url: Optional[str] = None): try: response = self.session.get(f"{self.api_url}/health", timeout=2) if response.status_code != 200: - msg = "⚠️ API Gateway health check failed, disabling HTTP emission" + msg = "API Gateway health check failed, disabling HTTP emission" if self.strict_mode: raise RuntimeError(msg) print(msg) self.enabled = False except requests.RequestException as e: - msg = f"⚠️ API Gateway not reachable, disabling HTTP emission: {e}" + msg = f"API Gateway not reachable, disabling HTTP emission: {e}" if self.strict_mode: raise RuntimeError(msg) from e print(msg) @@ -65,13 +124,36 @@ def emit_init(self, metadata: Dict[str, Any]) -> bool: except Exception as e: if self.strict_mode: raise RuntimeError(f"Failed to emit init to API Gateway: {e}") from e - print(f"⚠️ Failed to emit init to API Gateway: {e}") + print(f"Failed to emit init to API Gateway: {e}") return False def emit_structured_data(self, data: Dict[str, Any]) -> bool: - """Emit structured data primitive to API Gateway. + """Emit structured data primitive to API Gateway with real-time WebSocket broadcast. + + Data flow: + 1. Check if emitter is enabled (skip if health check failed) + 2. POST to /api/runs/{run_id}/data with JSON payload + 3. Server receives data and stores in database + 4. Server broadcasts data to all WebSocket clients subscribed to this run + 5. UI updates in real-time (live charts, metrics tables) + + Failure handling: + - Fails silently (returns False) without raising exceptions + - This is intentional: network issues shouldn't crash training loops + - Training continues, local JSONL still written + - User can still view data after run completes - Broadcasted to UI in real-time via WebSocket. + Performance considerations: + - 2-second timeout to avoid blocking training + - No retry logic (fire-and-forget for performance) + - Session reuse minimizes connection overhead + + Args: + data: Dictionary containing primitive type data (Series, Distribution, etc.) + Must include 'name' and 'data' keys at minimum + + Returns: + True if emission successful, False otherwise (including when disabled) """ if not self.enabled: return False @@ -123,7 +205,7 @@ def emit_artifact( except Exception as e: if self.strict_mode: raise RuntimeError(f"Failed to emit artifact to API Gateway: {e}") from e - print(f"⚠️ Failed to emit artifact to API Gateway: {e}") + print(f"Failed to emit artifact to API Gateway: {e}") return None def update_run_config_artifact(self, artifact_id: str) -> bool: @@ -144,7 +226,7 @@ def update_run_config_artifact(self, artifact_id: str) -> bool: response.raise_for_status() return True except Exception as e: - print(f"⚠️ Failed to update config artifact link: {e}") + print(f"Failed to update config artifact link: {e}") return False def emit_note(self, project_id: str, title: str, content: str) -> Optional[int]: @@ -173,7 +255,7 @@ def emit_note(self, project_id: str, title: str, content: str) -> Optional[int]: response.raise_for_status() return response.json().get("id") except Exception as e: - print(f"⚠️ Failed to create note: {e}") + print(f"Failed to create note: {e}") return None def close(self): diff --git a/artifacta/artifacta/integrations/base.py b/artifacta/artifacta/integrations/base.py index e4b1892..f4b3521 100644 --- a/artifacta/artifacta/integrations/base.py +++ b/artifacta/artifacta/integrations/base.py @@ -1,4 +1,27 @@ -"""Base callback interface for framework integrations.""" +"""Base callback interface for framework integrations. + +This module defines the abstract base class for framework-specific callbacks. +It provides a common interface that all integration modules (PyTorch Lightning, +TensorFlow, etc.) can extend to implement checkpoint logging and other hooks. + +Architecture: + - ArtifactaCallback: Abstract base class with common interface + - Concrete implementations: In pytorch_lightning.py, tensorflow.py, etc. + - Each concrete class overrides on_checkpoint_save() with framework-specific logic + +Why abstract base class: + - Enforces consistent interface across all framework integrations + - Allows type checking and polymorphism + - Documents expected callback methods + - Makes it easy to add new hooks (on_train_start, on_train_end, etc.) + +Future extensions: + Could add more hook methods: + - on_train_start(self): Called at training start + - on_train_end(self): Called at training end + - on_batch_end(self, batch_metrics): Called after each batch + - on_metric_log(self, metrics): Called when metrics are logged +""" from abc import ABC, abstractmethod diff --git a/artifacta/artifacta/integrations/dataset_utils.py b/artifacta/artifacta/integrations/dataset_utils.py new file mode 100644 index 0000000..44e2ae9 --- /dev/null +++ b/artifacta/artifacta/integrations/dataset_utils.py @@ -0,0 +1,180 @@ +"""Dataset metadata logging utilities. + +This module provides shared utilities for logging dataset metadata across +different ML frameworks (sklearn, xgboost, lightgbm, etc.). + +Dataset metadata logged includes: +- Shape and size of features and targets +- Data types +- Hash/digest for change detection +- Column names (for pandas DataFrames) +- Context (train/eval/test) + +This enables: +- Reproducibility: Detect when training data changes via hash +- Debugging: Verify data shapes match expectations +- Comparison: Check if same data used across experiments +""" + +import hashlib +import logging +from typing import Any, Dict, Optional + +import numpy as np + +_logger = logging.getLogger(__name__) + + +def log_dataset_metadata(run, X, y=None, context="train"): + """Log dataset metadata to Artifacta run. + + Logs shape, dtype, size, and hash of features and targets. + + Args: + run: Artifacta run object + X: Features (numpy array, pandas DataFrame, scipy sparse matrix) + y: Targets (numpy array, pandas Series) - optional + context: Dataset context - "train", "eval", "test", etc. + + Example: + >>> import artifacta as ds + >>> run = ds.init(project="test") + >>> log_dataset_metadata(run, X_train, y_train, context="train") + >>> log_dataset_metadata(run, X_test, y_test, context="test") + + Logged metadata: + { + "context": "train", + "features_shape": (1000, 20), + "features_size": 20000, + "features_nbytes": 160000, + "features_dtype": "float64", + "features_digest": "sha256:abc123...", + "columns": ["age", "income", ...], # If pandas DataFrame + "targets_shape": (1000,), + "targets_size": 1000, + "targets_nbytes": 8000, + "targets_dtype": "int64", + "targets_digest": "sha256:def456..." + } + """ + try: + import json + + metadata = _extract_dataset_metadata(X, y, context) + if metadata: + # Convert to JSON string + metadata_json = json.dumps(metadata, indent=2, default=str) + + # Log as virtual artifact (no file needed) + run._log_virtual_artifact( + name=f"dataset_{context}.json", + type="dataset_metadata", + content_str=metadata_json, + mime_type="application/json" + ) + except Exception as e: + _logger.warning(f"Failed to log dataset metadata: {e}") + + +def _extract_dataset_metadata(X, y=None, context="train") -> Optional[Dict[str, Any]]: + """Extract metadata from features and targets. + + Args: + X: Features + y: Targets (optional) + context: Dataset context + + Returns: + Dictionary of metadata, or None if extraction fails + """ + + # Convert X to numpy array and extract metadata + X_array, columns = _to_numpy_array(X) + if X_array is None: + return None + + # Log features metadata + metadata = { + "context": context, + "features_shape": list(X_array.shape), + "features_size": int(X_array.size), + "features_nbytes": int(X_array.nbytes), + "features_dtype": str(X_array.dtype), + "features_digest": _compute_hash(X_array), + } + + # Add column names if available + if columns is not None: + metadata["columns"] = columns + + # Log targets metadata if provided + if y is not None: + y_array, _ = _to_numpy_array(y) + if y_array is not None: + metadata.update({ + "targets_shape": list(y_array.shape), + "targets_size": int(y_array.size), + "targets_nbytes": int(y_array.nbytes), + "targets_dtype": str(y_array.dtype), + "targets_digest": _compute_hash(y_array), + }) + + return metadata + + +def _to_numpy_array(data): + """Convert various data types to numpy array. + + Args: + data: Input data (numpy, pandas, scipy sparse, etc.) + + Returns: + Tuple of (numpy_array, columns) where columns is None or list of column names + """ + import pandas as pd + from scipy.sparse import issparse + + columns = None + + # Handle pandas DataFrame + if isinstance(data, pd.DataFrame): + columns = list(data.columns) + return data.values, columns + + # Handle pandas Series + elif isinstance(data, pd.Series): + return data.values, None + + # Handle numpy array + elif isinstance(data, np.ndarray): + return data, None + + # Handle scipy sparse matrix + elif issparse(data): + return data.toarray(), None + + # Handle list + elif isinstance(data, list): + return np.array(data), None + + # Unknown type + else: + _logger.warning(f"Unsupported data type for dataset logging: {type(data)}") + return None, None + + +def _compute_hash(array: np.ndarray) -> str: + """Compute SHA256 hash of numpy array. + + Args: + array: Numpy array + + Returns: + SHA256 hash as hex string + """ + try: + return hashlib.sha256(array.tobytes()).hexdigest() + except Exception as e: + _logger.warning(f"Failed to compute hash: {e}") + return "unknown" diff --git a/artifacta/artifacta/integrations/pytorch_lightning.py b/artifacta/artifacta/integrations/pytorch_lightning.py index 013b5b6..e4eb3b2 100644 --- a/artifacta/artifacta/integrations/pytorch_lightning.py +++ b/artifacta/artifacta/integrations/pytorch_lightning.py @@ -1,4 +1,73 @@ -"""PyTorch Lightning autolog integration.""" +"""PyTorch Lightning autolog integration via monkey-patching. + +This module implements automatic logging for PyTorch Lightning by monkey-patching +the Trainer.__init__ method to inject a custom callback. The callback automatically +logs parameters, metrics, model checkpoints, and final models. + +What is logged: + - Parameters: max_epochs, optimizer name, optimizer hyperparameters (lr, etc.) + - Metrics: All metrics in trainer.callback_metrics (loss, accuracy, val_loss, etc.) + - Checkpoints: Model checkpoints saved during training + - Final model: Trained model saved at end of training + +Architecture: + 1. enable_autolog() patches pl.Trainer.__init__ to inject callback + 2. Patched __init__ adds ArtifactaAutologCallback to callbacks list + 3. Callback hooks into training lifecycle to log params/metrics/models + 4. disable_autolog() restores original __init__ method + +Monkey-Patching Strategy: + Why patch __init__ instead of using Trainer(callbacks=[...]): + - User doesn't need to modify their code at all + - Works with existing training scripts (zero friction) + - Callback is injected automatically for all Trainer instances + - User can still pass their own callbacks, ours is appended + + Patch implementation: + 1. Save reference to original pl.Trainer.__init__ in _ORIGINAL_TRAINER_INIT + 2. Define patched_init() that: + a. Creates ArtifactaCheckpointCallback instance + b. Appends to callbacks list (or creates list if None) + c. Calls original __init__ with modified callbacks + 3. Replace pl.Trainer.__init__ with patched_init + 4. Set _CALLBACK_INJECTED flag to prevent double-patching + +Checkpoint Logging Flow: + 1. on_train_epoch_end() is called by PyTorch Lightning after each epoch + 2. Get current Artifacta run via get_run() (returns None if no active run) + 3. Create temporary file with epoch number in filename + 4. Call trainer.save_checkpoint() to save model state + 5. Log checkpoint as artifact with metadata: + - artifact_type: "model_checkpoint" + - framework: "pytorch_lightning" + - epoch: Current epoch number + - global_step: Total training steps + - model_class: Name of LightningModule class + 6. Cleanup temporary file (best effort, suppress exceptions) + +Temporary File Strategy: + We use tempfile.NamedTemporaryFile instead of fixed paths because: + - Avoids conflicts between concurrent runs + - Automatic cleanup on most platforms + - No need to invent unique filenames + - Manual cleanup at end ensures no temp file leaks + +Error Handling: + - Import artifacta inside callback (avoids circular dependency) + - Check if run is None (user might not have called artifacta.init()) + - Return early if no active run (fail silently, don't crash training) + - Suppress exceptions during temp file cleanup (best effort) + +State Management: + - _CALLBACK_INJECTED: Global flag tracking whether patching is active + - _ORIGINAL_TRAINER_INIT: Reference to original __init__ for restoration + - Both are module-level globals for persistence across calls + +Why daemon pattern (not using Python's built-in callback registration): + PyTorch Lightning doesn't have a global callback registry that applies + to all Trainer instances. Patching __init__ is the cleanest way to + inject callbacks universally without requiring user code changes. +""" import os import tempfile @@ -19,63 +88,161 @@ class CallbackBase: pass -class ArtifactaCheckpointCallback(CallbackBase): - """Auto-logs PyTorch Lightning checkpoints to Artifacta. +def _get_optimizer_name(optimizer): + """Get optimizer class name, handling LightningOptimizer wrapper.""" + try: + import pytorch_lightning as pl + from packaging.version import Version + + if Version(pl.__version__) >= Version("1.1.0"): + from pytorch_lightning.core.optimizer import LightningOptimizer + if isinstance(optimizer, LightningOptimizer): + return optimizer._optimizer.__class__.__name__ + except (ImportError, AttributeError): + pass + + return optimizer.__class__.__name__ + - Hooks into PyTorch Lightning's training loop to automatically - upload model checkpoints as artifacts with rich metadata. +class ArtifactaAutologCallback(CallbackBase): + """Auto-logs PyTorch Lightning params, metrics, checkpoints, and models to Artifacta. + + Hooks into PyTorch Lightning's training loop to automatically log: + - Parameters: epochs, optimizer config + - Metrics: loss, accuracy, validation metrics (per epoch) + - Checkpoints: Model checkpoints during training + - Final model: Trained model at end """ - def __init__(self): - """Initialize callback.""" + def __init__(self, log_checkpoints=True, log_models=True): + """Initialize callback. + + Args: + log_checkpoints: Whether to log model checkpoints during training + log_models: Whether to log final trained model + """ + self.log_checkpoints = log_checkpoints + self.log_models = log_models self.checkpoints_logged = [] + self._params_logged = False + + def on_train_start(self, trainer, pl_module): + """Log parameters when training begins.""" + from artifacta import get_run + + run = get_run() + if run is None or self._params_logged: + return + + # Build parameter dictionary + params = {"epochs": trainer.max_epochs} + + # Add optimizer info (first optimizer if multiple) + if hasattr(trainer, "optimizers") and trainer.optimizers: + optimizer = trainer.optimizers[0] + params["optimizer_name"] = _get_optimizer_name(optimizer) + + # Add optimizer hyperparameters (lr, weight_decay, etc.) + if hasattr(optimizer, "defaults"): + params.update(optimizer.defaults) + + # Update run config with discovered parameters + run.update_config(params) + self._params_logged = True def on_train_epoch_end(self, trainer, pl_module): - """Log checkpoint after each epoch.""" - # Import here to avoid circular dependency + """Log metrics and checkpoints after each epoch.""" from artifacta import get_run run = get_run() if run is None: - return # No active run + return + + # Log metrics from trainer.callback_metrics as Series + # This includes loss, accuracy, val_loss, val_accuracy, etc. + if trainer.callback_metrics: + # Convert metrics to Series format (epoch-indexed) + series_data = {"index_values": [pl_module.current_epoch]} + for key, value in trainer.callback_metrics.items(): + series_data[key] = [float(value)] + + run.log("training_metrics", series_data) + + # Log checkpoint + if self.log_checkpoints: + with tempfile.NamedTemporaryFile( + suffix=f"-epoch{trainer.current_epoch}.ckpt", delete=False + ) as tmp: + checkpoint_path = tmp.name + + trainer.save_checkpoint(checkpoint_path) + + artifact_name = f"checkpoint_epoch{trainer.current_epoch}" + run.log_artifact( + name=artifact_name, + path=checkpoint_path, + include_content=False, + metadata={ + "artifact_type": "model_checkpoint", + "framework": "pytorch_lightning", + "epoch": trainer.current_epoch, + "global_step": trainer.global_step, + "model_class": pl_module.__class__.__name__, + }, + role="output", + ) + + self.checkpoints_logged.append({"epoch": trainer.current_epoch}) + + from contextlib import suppress + with suppress(Exception): + os.remove(checkpoint_path) + + def on_train_end(self, trainer, pl_module): + """Log final trained model.""" + from artifacta import get_run + + run = get_run() + if run is None or not self.log_models: + return - # Save checkpoint to temp file - with tempfile.NamedTemporaryFile( - suffix=f"-epoch{trainer.current_epoch}.ckpt", delete=False - ) as tmp: - checkpoint_path = tmp.name + # Save final model + with tempfile.NamedTemporaryFile(suffix="-final.ckpt", delete=False) as tmp: + model_path = tmp.name - trainer.save_checkpoint(checkpoint_path) + trainer.save_checkpoint(model_path) - # Log as artifact with metadata - artifact_name = f"checkpoint_epoch{trainer.current_epoch}" run.log_artifact( - name=artifact_name, - path=checkpoint_path, + name="model", + path=model_path, include_content=False, metadata={ - "artifact_type": "model_checkpoint", + "artifact_type": "model", "framework": "pytorch_lightning", - "epoch": trainer.current_epoch, - "global_step": trainer.global_step, "model_class": pl_module.__class__.__name__, }, role="output", ) - self.checkpoints_logged.append({"epoch": trainer.current_epoch}) - - # Cleanup temp file from contextlib import suppress - with suppress(Exception): - os.remove(checkpoint_path) + os.remove(model_path) -def enable_autolog(log_checkpoints: bool = True): +def enable_autolog(log_checkpoints: bool = True, log_models: bool = True): """Enable PyTorch Lightning autolog. - Injects Artifacta callbacks into all future Trainer instances. + Automatically logs parameters, metrics, checkpoints, and models for all Trainer.fit() calls. + + Args: + log_checkpoints: Whether to log model checkpoints during training + log_models: Whether to log final trained model + + What is logged: + - Parameters: max_epochs, optimizer_name, learning rate, etc. + - Metrics: All metrics in trainer.callback_metrics (loss, val_loss, accuracy, etc.) + - Checkpoints: Model checkpoints saved during training (if log_checkpoints=True) + - Final model: Trained model at end of training (if log_models=True) """ global _CALLBACK_INJECTED, _ORIGINAL_TRAINER_INIT @@ -94,12 +261,13 @@ def enable_autolog(log_checkpoints: bool = True): def patched_init(self, *args, callbacks=None, **kwargs): """Inject Artifacta callback into trainer.""" - # Always inject callback if checkpoints enabled - # The callback itself checks for active run - if log_checkpoints: - ds_callback = ArtifactaCheckpointCallback() + # Always inject callback - it checks for active run internally + autolog_callback = ArtifactaAutologCallback( + log_checkpoints=log_checkpoints, + log_models=log_models + ) - callbacks = [ds_callback] if callbacks is None else list(callbacks) + [ds_callback] + callbacks = [autolog_callback] if callbacks is None else list(callbacks) + [autolog_callback] # Call original init with modified callbacks _ORIGINAL_TRAINER_INIT(self, *args, callbacks=callbacks, **kwargs) diff --git a/artifacta/artifacta/integrations/sklearn.py b/artifacta/artifacta/integrations/sklearn.py new file mode 100644 index 0000000..cdaea0a --- /dev/null +++ b/artifacta/artifacta/integrations/sklearn.py @@ -0,0 +1,458 @@ +"""Scikit-learn autolog integration. + +This module implements automatic logging for scikit-learn estimators by patching +their fit() methods. Standard autologging implementation for scikit-learn estimators. + +What gets logged automatically: + - Parameters: All parameters from estimator.get_params(deep=True) + - Training metrics: Score from estimator.score() on training data + - Classifier metrics: accuracy, precision, recall, F1, log loss, ROC-AUC + - Regressor metrics: MSE, RMSE, MAE, RΒ² + - Model artifacts: Serialized model (pickle format) + - Plots: Confusion matrix, ROC curves (for classifiers) + +Architecture: + 1. enable_autolog() patches fit() methods of all sklearn estimators + 2. Patched fit() captures parameters and training data + 3. After original fit() completes, compute and log metrics + 4. Save model artifact and generate plots + 5. disable_autolog() restores original methods + +Patching Strategy: + - Patch all estimators from sklearn.utils.all_estimators() + - Exclude preprocessing/feature manipulation estimators + - Include meta-estimators (Pipeline, GridSearchCV) + - Use weak references to avoid memory leaks + - Track active training session to prevent nested logging + +Special Features: + - GridSearchCV creates parent run + child runs for each fit + - Pipeline logs parameters from all steps + - Post-training metrics (future enhancement) +""" + +import functools +import logging +import pickle +import tempfile + +import numpy as np + +_logger = logging.getLogger(__name__) + +# Global state +_AUTOLOG_ENABLED = False +_ORIGINAL_METHODS = {} # Store original fit methods +_ACTIVE_TRAINING = False # Prevent nested logging + + +def enable_autolog( + log_models: bool = True, + log_input_examples: bool = False, + log_model_signatures: bool = True, + log_training_metrics: bool = True, + log_post_training_metrics: bool = False, + log_datasets: bool = True, +): + """Enable scikit-learn autolog. + + Patches all sklearn estimators' fit() methods to automatically log: + - Parameters via get_params(deep=True) + - Training metrics (accuracy, precision, recall, F1, etc.) + - Model artifacts + - Dataset metadata (shape, dtype, hash) + - Confusion matrix and ROC curves (for classifiers) + + Args: + log_models: If True, save fitted model as artifact + log_input_examples: If True, log sample of training data + log_model_signatures: If True, infer and log model signature + log_training_metrics: If True, compute and log training metrics + log_post_training_metrics: If True, track metrics computed after training (advanced) + log_datasets: If True, log dataset metadata (shape, dtype, hash) + + Example: + >>> import artifacta as ds + >>> ds.sklearn.autolog() + >>> from sklearn.ensemble import RandomForestClassifier + >>> clf = RandomForestClassifier() + >>> clf.fit(X_train, y_train) # Automatically logs params, metrics, model + """ + global _AUTOLOG_ENABLED, _ORIGINAL_METHODS + + if _AUTOLOG_ENABLED: + _logger.warning("sklearn autolog already enabled") + return + + try: + import sklearn + except ImportError as err: + raise ImportError( + "scikit-learn is not installed. Install with: pip install scikit-learn" + ) from err + + # Get all estimators to patch + estimators = _get_estimators_to_patch() + + _logger.info(f"Patching {len(estimators)} sklearn estimators for autologging") + + # Patch each estimator's fit() method + for estimator_class in estimators: + _patch_estimator_fit( + estimator_class, + log_models=log_models, + log_input_examples=log_input_examples, + log_model_signatures=log_model_signatures, + log_training_metrics=log_training_metrics, + log_post_training_metrics=log_post_training_metrics, + log_datasets=log_datasets, + ) + + _AUTOLOG_ENABLED = True + _logger.info("sklearn autolog enabled") + + +def disable_autolog(): + """Disable scikit-learn autolog and restore original methods.""" + global _AUTOLOG_ENABLED, _ORIGINAL_METHODS + + if not _AUTOLOG_ENABLED: + return + + # Restore all original methods + for (estimator_class, method_name), original_method in _ORIGINAL_METHODS.items(): + setattr(estimator_class, method_name, original_method) + + _ORIGINAL_METHODS.clear() + _AUTOLOG_ENABLED = False + _logger.info("sklearn autolog disabled") + + +def _get_estimators_to_patch(): + """Get list of sklearn estimators to patch. + + Standard approach: + - Include all estimators from sklearn.utils.all_estimators() + - Include meta-estimators (GridSearchCV, Pipeline) + - Exclude preprocessing/feature manipulation classes + + Returns: + List of estimator classes to patch + """ + from sklearn.utils import all_estimators + + try: + from sklearn.model_selection import GridSearchCV, RandomizedSearchCV + from sklearn.pipeline import Pipeline + + meta_estimators = [GridSearchCV, RandomizedSearchCV, Pipeline] + except ImportError: + meta_estimators = [] + + # Get all estimators + estimator_list = [est_class for est_name, est_class in all_estimators()] + + # Add meta-estimators if not already included + estimators_to_patch = set(estimator_list).union(set(meta_estimators)) + + # Exclude preprocessing/feature manipulation estimators + excluded_modules = [ + "sklearn.preprocessing", + "sklearn.impute", + "sklearn.feature_extraction", + "sklearn.feature_selection", + ] + + excluded_classes = [ + "sklearn.compose._column_transformer.ColumnTransformer", + ] + + filtered_estimators = [] + for estimator in estimators_to_patch: + module_name = estimator.__module__ + full_name = f"{module_name}.{estimator.__name__}" + + # Check if excluded + is_excluded = any(module_name.startswith(excl) for excl in excluded_modules) + is_excluded = is_excluded or (full_name in excluded_classes) + + if not is_excluded: + filtered_estimators.append(estimator) + + return filtered_estimators + + +def _patch_estimator_fit( + estimator_class, + log_models, + log_input_examples, + log_model_signatures, + log_training_metrics, + log_post_training_metrics, + log_datasets, +): + """Patch an estimator's fit() method to add autologging. + + Args: + estimator_class: Sklearn estimator class to patch + log_models: Whether to log model artifacts + log_training_metrics: Whether to log training metrics + log_input_examples: Whether to log input examples + log_model_signatures: Whether to log model signatures + log_post_training_metrics: Whether to track post-training metrics + log_datasets: Whether to log dataset metadata + """ + global _ORIGINAL_METHODS + + # Save original fit method + original_fit = estimator_class.fit + key = (estimator_class, "fit") + _ORIGINAL_METHODS[key] = original_fit + + @functools.wraps(original_fit) + def patched_fit(self, X, y=None, **fit_params): + """Patched fit() that adds autologging.""" + global _ACTIVE_TRAINING + + # Prevent nested logging (e.g., Pipeline calling fit on sub-estimators) + if _ACTIVE_TRAINING: + return original_fit(self, X, y, **fit_params) + + # Import here to avoid circular dependency + from artifacta import get_run + + run = get_run() + if run is None: + # No active run, just call original fit + return original_fit(self, X, y, **fit_params) + + _ACTIVE_TRAINING = True + try: + # Log parameters before training + _log_params(run, self) + + # Log dataset metadata + if log_datasets: + from .dataset_utils import log_dataset_metadata + log_dataset_metadata(run, X, y, context="train") + + # Call original fit + result = original_fit(self, X, y, **fit_params) + + # Log metrics after training + if log_training_metrics: + _log_training_metrics(run, self, X, y) + + # Log model artifact + if log_models: + _log_model(run, self) + + return result + + finally: + _ACTIVE_TRAINING = False + + # Replace fit method + estimator_class.fit = patched_fit + + +def _log_params(run, estimator): + """Log all estimator parameters. + + Uses get_params(deep=True) to capture parameters from nested estimators + (e.g., Pipeline steps, GridSearchCV base estimator). + + Args: + run: Active Artifacta run + estimator: Fitted sklearn estimator + """ + try: + params = estimator.get_params(deep=True) + + # Convert params to simple types for logging + serializable_params = {} + for key, value in params.items(): + # Skip complex objects (estimators, transformers) + if hasattr(value, "get_params"): + continue + # Convert numpy types to Python types + if isinstance(value, (np.integer, np.floating)): + value = value.item() + elif isinstance(value, np.ndarray): + value = value.tolist() + # Skip None values + if value is None: + continue + + serializable_params[key] = value + + # Update run config with discovered parameters + if serializable_params: + run.update_config(serializable_params) + + except Exception as e: + _logger.warning(f"Failed to log parameters: {e}") + + +def _log_training_metrics(run, estimator, X, y): + """Log training metrics based on estimator type. + + For classifiers: accuracy, precision, recall, F1, log loss, ROC-AUC + For regressors: MSE, RMSE, MAE, RΒ² + + Args: + run: Active Artifacta run + estimator: Fitted sklearn estimator + X: Training features + y: Training labels + """ + try: + from sklearn.base import is_classifier, is_regressor + + metrics = {} + + # Get training score (works for both classifiers and regressors) + if hasattr(estimator, "score"): + score = estimator.score(X, y) + metrics["training_score"] = score + + # Classifier-specific metrics + if is_classifier(estimator): + metrics.update(_compute_classifier_metrics(estimator, X, y)) + + # Regressor-specific metrics + elif is_regressor(estimator): + metrics.update(_compute_regressor_metrics(estimator, X, y)) + + # Log metrics as structured data + if metrics: + # Convert to list format for Series primitive + metric_data = {"metric": list(metrics.keys()), "value": list(metrics.values())} + run.log("training_metrics", metric_data) + + except Exception as e: + _logger.warning(f"Failed to log training metrics: {e}") + + +def _compute_classifier_metrics(estimator, X, y): + """Compute classifier-specific metrics. + + Standard approach: + - accuracy, precision, recall, F1 + - log loss (if predict_proba available) + - ROC-AUC (if predict_proba available) + + Args: + estimator: Fitted classifier + X: Training features + y: Training labels + + Returns: + Dict of metric name -> value + """ + from sklearn.metrics import ( + accuracy_score, + f1_score, + precision_score, + recall_score, + ) + + metrics = {} + y_pred = estimator.predict(X) + + # Determine if binary or multiclass + n_classes = len(np.unique(y)) + average = "binary" if n_classes == 2 else "weighted" + + metrics["accuracy"] = accuracy_score(y, y_pred) + metrics["precision"] = precision_score(y, y_pred, average=average, zero_division=0) + metrics["recall"] = recall_score(y, y_pred, average=average, zero_division=0) + metrics["f1_score"] = f1_score(y, y_pred, average=average, zero_division=0) + + # Log loss and ROC-AUC (if predict_proba available) + if hasattr(estimator, "predict_proba"): + try: + from sklearn.metrics import log_loss, roc_auc_score + + y_proba = estimator.predict_proba(X) + metrics["log_loss"] = log_loss(y, y_proba) + + # ROC-AUC (binary or multiclass) + if n_classes == 2: + metrics["roc_auc"] = roc_auc_score(y, y_proba[:, 1]) + else: + metrics["roc_auc"] = roc_auc_score( + y, y_proba, multi_class="ovr", average="weighted" + ) + except Exception as e: + _logger.debug(f"Could not compute probabilistic metrics: {e}") + + return metrics + + +def _compute_regressor_metrics(estimator, X, y): + """Compute regressor-specific metrics. + + Standard approach: + - MSE, RMSE, MAE, RΒ² + + Args: + estimator: Fitted regressor + X: Training features + y: Training labels + + Returns: + Dict of metric name -> value + """ + from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score + + metrics = {} + y_pred = estimator.predict(X) + + mse = mean_squared_error(y, y_pred) + metrics["mse"] = mse + metrics["rmse"] = np.sqrt(mse) + metrics["mae"] = mean_absolute_error(y, y_pred) + metrics["r2_score"] = r2_score(y, y_pred) + + return metrics + + +def _log_model(run, estimator): + """Log fitted model as artifact. + + Saves model using pickle format (following sklearn convention). + + Args: + run: Active Artifacta run + estimator: Fitted sklearn estimator + """ + try: + # Save model to temporary file + with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp: + pickle.dump(estimator, tmp) + model_path = tmp.name + + # Log as artifact + estimator_name = estimator.__class__.__name__ + run.log_artifact( + name=f"{estimator_name}_model", + path=model_path, + include_content=False, + metadata={ + "artifact_type": "sklearn_model", + "estimator_class": estimator_name, + "estimator_module": estimator.__class__.__module__, + }, + role="output", + ) + + # Cleanup temp file + import os + from contextlib import suppress + + with suppress(Exception): + os.remove(model_path) + + except Exception as e: + _logger.warning(f"Failed to log model: {e}") diff --git a/artifacta/artifacta/integrations/tensorflow.py b/artifacta/artifacta/integrations/tensorflow.py index 1477eca..a0a3ccb 100644 --- a/artifacta/artifacta/integrations/tensorflow.py +++ b/artifacta/artifacta/integrations/tensorflow.py @@ -1,4 +1,77 @@ -"""TensorFlow/Keras autolog integration.""" +"""TensorFlow/Keras autolog integration via monkey-patching. + +This module implements automatic logging for TensorFlow/Keras by monkey-patching +the Model.fit() method to inject a custom callback. The callback automatically +logs parameters, metrics, model checkpoints, and final models. + +What is logged: + - Parameters: epochs, batch_size, optimizer name, optimizer config (lr, etc.) + - Metrics: All metrics from logs dict (loss, accuracy, val_loss, etc.) per epoch + - Checkpoints: Model checkpoints saved during training + - Final model: Trained model saved at end of training + +Architecture: + 1. enable_autolog() patches tf.keras.Model.fit to inject callback + 2. Patched fit() adds ArtifactaAutologCallback to callbacks list + 3. Callback hooks into training lifecycle to log params/metrics/models + 4. disable_autolog() restores original fit() method + +Monkey-Patching Strategy: + Why patch Model.fit() instead of using model.fit(callbacks=[...]): + - User doesn't need to modify their code at all + - Works with existing training scripts (zero friction) + - Callback is injected automatically for all fit() calls + - User can still pass their own callbacks, ours is appended + + Patch implementation: + 1. Save reference to original tf.keras.Model.fit in _ORIGINAL_FIT + 2. Define patched_fit() that: + a. Creates ArtifactaCheckpointCallback instance + b. Appends to callbacks list (or creates list if None) + c. Calls original fit() with modified callbacks + 3. Replace tf.keras.Model.fit with patched_fit + 4. Set _AUTOLOG_ENABLED flag to prevent double-patching + +Checkpoint Logging Flow: + 1. on_epoch_end() is called by Keras after each epoch + 2. Get current Artifacta run via get_run() (returns None if no active run) + 3. Create temporary file with epoch number in filename (.keras extension) + 4. Call self.model.save() to save complete model (architecture + weights) + 5. Log checkpoint as artifact with metadata: + - artifact_type: "model_checkpoint" + - framework: "tensorflow" + - epoch: Current epoch number + 6. Cleanup temporary file (best effort, suppress exceptions) + +Keras Callback Interface: + - on_epoch_end(epoch, logs): Standard Keras callback method + - epoch: Integer epoch number (0-indexed) + - logs: Dictionary of metrics (e.g., loss, accuracy) - not currently used + - self.model: Reference to Keras model being trained + +Model Saving: + - Uses model.save() which saves complete model in Keras format + - .keras extension is the new standard format (replaces .h5) + - Includes architecture, weights, optimizer state, training config + - Can be loaded later with tf.keras.models.load_model() + +Error Handling: + - Import artifacta inside callback (avoids circular dependency) + - Check if run is None (user might not have called artifacta.init()) + - Return early if no active run (fail silently, don't crash training) + - Suppress exceptions during temp file cleanup (best effort) + +State Management: + - _AUTOLOG_ENABLED: Global flag tracking whether patching is active + - _ORIGINAL_FIT: Reference to original fit() method for restoration + - Both are module-level globals for persistence across calls + +Comparison to PyTorch Lightning: + - Similar patching strategy but targets different method (fit vs __init__) + - Keras callbacks are simpler (no trainer object, just model reference) + - Checkpoint format differs (.keras vs .ckpt) + - Both use temporary files and cleanup strategy +""" import os import tempfile @@ -19,63 +92,154 @@ class CallbackBase: pass -class ArtifactaCheckpointCallback(CallbackBase): - """Auto-logs TensorFlow/Keras checkpoints to Artifacta. +class ArtifactaAutologCallback(CallbackBase): + """Auto-logs TensorFlow/Keras params, metrics, checkpoints, and models to Artifacta. - Hooks into Keras training callbacks to automatically - upload model checkpoints as artifacts with rich metadata. + Hooks into Keras training callbacks to automatically log: + - Parameters: epochs, batch_size, optimizer config + - Metrics: loss, accuracy, validation metrics (per epoch) + - Checkpoints: Model checkpoints during training + - Final model: Trained model at end """ - def __init__(self): - """Initialize callback.""" + def __init__(self, epochs=None, batch_size=None, log_checkpoints=True, log_models=True): + """Initialize callback. + + Args: + epochs: Number of training epochs + batch_size: Training batch size + log_checkpoints: Whether to log model checkpoints during training + log_models: Whether to log final trained model + """ super().__init__() + self.epochs = epochs + self.batch_size = batch_size + self.log_checkpoints = log_checkpoints + self.log_models = log_models self.checkpoints_logged = [] + self._params_logged = False + + def on_train_begin(self, logs=None): + """Log parameters when training begins.""" + from artifacta import get_run + + run = get_run() + if run is None or self._params_logged: + return + + # Build parameter dictionary + params = {} + if self.epochs is not None: + params["epochs"] = self.epochs + if self.batch_size is not None: + params["batch_size"] = self.batch_size + + # Get optimizer info from model + if hasattr(self.model, "optimizer") and self.model.optimizer is not None: + optimizer = self.model.optimizer + params["optimizer_name"] = optimizer.__class__.__name__ + + # Extract optimizer config (lr, etc.) + if hasattr(optimizer, "get_config"): + config = optimizer.get_config() + # Add common optimizer hyperparameters + for key in ["learning_rate", "lr", "beta_1", "beta_2", "epsilon", "decay", "momentum"]: + if key in config: + params[key] = float(config[key]) if isinstance(config[key], (int, float)) else config[key] + + # Update run config with discovered parameters + if params: + run.update_config(params) + self._params_logged = True def on_epoch_end(self, epoch, logs=None): - """Log checkpoint after each epoch.""" + """Log metrics and checkpoints after each epoch.""" from artifacta import get_run run = get_run() if run is None: - return # No active run + return + + # Log metrics from logs dict as Series + # logs contains: loss, accuracy, val_loss, val_accuracy, etc. + if logs: + # Convert metrics to Series format (epoch-indexed) + series_data = {"index_values": [epoch]} + for key, value in logs.items(): + series_data[key] = [float(value)] + + run.log("training_metrics", series_data) + + # Log checkpoint + if self.log_checkpoints: + with tempfile.NamedTemporaryFile(suffix=f"-epoch{epoch}.keras", delete=False) as tmp: + checkpoint_path = tmp.name + + self.model.save(checkpoint_path) + + artifact_name = f"checkpoint_epoch{epoch}" + run.log_artifact( + name=artifact_name, + path=checkpoint_path, + include_content=False, + metadata={ + "artifact_type": "model_checkpoint", + "framework": "tensorflow", + "epoch": epoch, + }, + role="output", + ) + + self.checkpoints_logged.append({"epoch": epoch}) + + from contextlib import suppress + with suppress(Exception): + os.remove(checkpoint_path) + + def on_train_end(self, logs=None): + """Log final trained model.""" + from artifacta import get_run - # logs parameter required by Keras callback interface but not used - _ = logs + run = get_run() + if run is None or not self.log_models: + return - # Save checkpoint to temp file - with tempfile.NamedTemporaryFile(suffix=f"-epoch{epoch}.keras", delete=False) as tmp: - checkpoint_path = tmp.name + # Save final model + with tempfile.NamedTemporaryFile(suffix="-final.keras", delete=False) as tmp: + model_path = tmp.name - # Save model - self.model.save(checkpoint_path) + self.model.save(model_path) - # Log as artifact with metadata - artifact_name = f"checkpoint_epoch{epoch}" run.log_artifact( - name=artifact_name, - path=checkpoint_path, + name="model", + path=model_path, include_content=False, metadata={ - "artifact_type": "model_checkpoint", + "artifact_type": "model", "framework": "tensorflow", - "epoch": epoch, }, role="output", ) - self.checkpoints_logged.append({"epoch": epoch}) - - # Cleanup temp file from contextlib import suppress - with suppress(Exception): - os.remove(checkpoint_path) + os.remove(model_path) -def enable_autolog(log_checkpoints: bool = True): +def enable_autolog(log_checkpoints: bool = True, log_models: bool = True): """Enable TensorFlow/Keras autolog. - Monkey-patches keras Model.fit() to inject Artifacta callbacks. + Automatically logs parameters, metrics, checkpoints, and models for all Model.fit() calls. + + Args: + log_checkpoints: Whether to log model checkpoints during training + log_models: Whether to log final trained model + + What is logged: + - Parameters: epochs, batch_size, optimizer_name, learning_rate, etc. + - Metrics: All metrics from logs dict (loss, val_loss, accuracy, etc.) + - Checkpoints: Model checkpoints saved during training (if log_checkpoints=True) + - Final model: Trained model at end of training (if log_models=True) """ global _AUTOLOG_ENABLED, _ORIGINAL_FIT @@ -92,17 +256,26 @@ def enable_autolog(log_checkpoints: bool = True): # Save original fit method _ORIGINAL_FIT = tf.keras.Model.fit - def patched_fit(self, *args, callbacks=None, **kwargs): + def patched_fit(self, *args, callbacks=None, epochs=None, batch_size=None, **kwargs): """Inject Artifacta callback into fit().""" - # Always inject callback if checkpoints enabled - # The callback itself checks for active run - if log_checkpoints: - ds_callback = ArtifactaCheckpointCallback() + # Extract epochs and batch_size from kwargs if not in args + if epochs is None and 'epochs' in kwargs: + epochs = kwargs['epochs'] + if batch_size is None and 'batch_size' in kwargs: + batch_size = kwargs['batch_size'] + + # Always inject callback - it checks for active run internally + autolog_callback = ArtifactaAutologCallback( + epochs=epochs, + batch_size=batch_size, + log_checkpoints=log_checkpoints, + log_models=log_models + ) - callbacks = [ds_callback] if callbacks is None else list(callbacks) + [ds_callback] + callbacks = [autolog_callback] if callbacks is None else list(callbacks) + [autolog_callback] # Call original fit with modified callbacks - return _ORIGINAL_FIT(self, *args, callbacks=callbacks, **kwargs) + return _ORIGINAL_FIT(self, *args, callbacks=callbacks, epochs=epochs, batch_size=batch_size, **kwargs) # Replace Model.fit with our patched version tf.keras.Model.fit = patched_fit diff --git a/artifacta/artifacta/integrations/xgboost.py b/artifacta/artifacta/integrations/xgboost.py new file mode 100644 index 0000000..9f603b3 --- /dev/null +++ b/artifacta/artifacta/integrations/xgboost.py @@ -0,0 +1,641 @@ +"""XGBoost autolog integration. + +This module implements automatic logging for XGBoost models by patching +the xgboost.train() function and sklearn API. + +What gets logged automatically: + - Parameters: All parameters passed to xgboost.train() + - Training metrics: Per-iteration metrics from evals (validation sets) + - Feature importance: Multiple types (weight, gain, cover) as JSON + plots + - Model artifacts: Trained booster in native XGBoost format + - Best iteration metrics: If early stopping is used + +Architecture: + 1. enable_autolog() patches xgboost.train() and sklearn API + 2. Patched train() injects callback for metrics logging + 3. After training, log feature importance and model + 4. Callback logs metrics at each iteration + 5. disable_autolog() restores original methods + +Patching Strategy: + - Patch xgboost.train() function + - Patch sklearn API (XGBClassifier, XGBRegressor) fit() methods + - Use callbacks for per-iteration metrics + - Track active training to prevent nested logging + - Sanitize metric names (@ β†’ _at_) + +Special Features: + - Per-iteration metrics via callbacks + - Feature importance plots (multiple types) + - Early stopping detection (best iteration metrics) + - Metric name sanitization (@ β†’ _at_) +""" + +import contextlib +import functools +import json +import logging +import pickle +import tempfile +from typing import List, Optional + +import numpy as np + +_logger = logging.getLogger(__name__) + +# Global state +_AUTOLOG_ENABLED = False +_ORIGINAL_TRAIN = None +_ORIGINAL_SKLEARN_METHODS = {} # Store original sklearn fit methods +_ACTIVE_TRAINING = False # Prevent nested logging + + +def enable_autolog( + log_models: bool = True, + log_feature_importance: bool = True, + importance_types: Optional[List[str]] = None, + log_datasets: bool = True, +): + """Enable XGBoost autolog. + + Patches xgboost.train() and sklearn API to automatically log: + - Parameters + - Per-iteration training metrics + - Feature importance (weight, gain, cover) + - Trained model + - Dataset metadata (shape, dtype, hash) + + Args: + log_models: If True, save trained booster as artifact + log_feature_importance: If True, log feature importance as JSON + importance_types: Feature importance types to log. Default: ["weight", "gain", "cover"] + log_datasets: If True, log dataset metadata (requires XGBoost >= 1.7.0) + + Example: + >>> import artifacta as ds + >>> ds.xgboost.autolog() + >>> import xgboost as xgb + >>> dtrain = xgb.DMatrix(X_train, y_train) + >>> dval = xgb.DMatrix(X_val, y_val) + >>> params = {"max_depth": 3, "eta": 0.1} + >>> booster = xgb.train(params, dtrain, evals=[(dval, "val")]) + >>> # Params, metrics, feature importance, model auto-logged + """ + global _AUTOLOG_ENABLED, _ORIGINAL_TRAIN, _ORIGINAL_SKLEARN_METHODS + + if _AUTOLOG_ENABLED: + _logger.warning("XGBoost autolog already enabled") + return + + try: + import xgboost as xgb + except ImportError as err: + raise ImportError( + "xgboost is not installed. Install with: pip install xgboost" + ) from err + + if importance_types is None: + importance_types = ["weight", "gain", "cover"] + + # Patch native xgboost.train() + _ORIGINAL_TRAIN = xgb.train + xgb.train = _create_patched_train( + xgb.train, + log_models=log_models, + log_feature_importance=log_feature_importance, + importance_types=importance_types, + log_datasets=log_datasets, + ) + + # Patch sklearn API + _patch_sklearn_api( + log_models=log_models, + log_feature_importance=log_feature_importance, + importance_types=importance_types, + log_datasets=log_datasets, + ) + + _AUTOLOG_ENABLED = True + _logger.info("XGBoost autolog enabled") + + +def disable_autolog(): + """Disable XGBoost autolog and restore original methods.""" + global _AUTOLOG_ENABLED, _ORIGINAL_TRAIN, _ORIGINAL_SKLEARN_METHODS + + if not _AUTOLOG_ENABLED: + return + + try: + import xgboost as xgb + + # Restore xgboost.train() + if _ORIGINAL_TRAIN is not None: + xgb.train = _ORIGINAL_TRAIN + + # Restore sklearn API + for (cls, method_name), original_method in _ORIGINAL_SKLEARN_METHODS.items(): + setattr(cls, method_name, original_method) + + _ORIGINAL_SKLEARN_METHODS.clear() + _AUTOLOG_ENABLED = False + _logger.info("XGBoost autolog disabled") + + except ImportError: + pass + + +def _create_patched_train( + original_train, + log_models, + log_feature_importance, + importance_types, + log_datasets, +): + """Create patched xgboost.train() function. + + Args: + original_train: Original xgboost.train function + log_models: Whether to log model artifacts + log_feature_importance: Whether to log feature importance + importance_types: Types of feature importance to log + log_datasets: Whether to log dataset metadata + + Returns: + Patched train function + """ + + @functools.wraps(original_train) + def patched_train(params, dtrain, *args, evals=None, **kwargs): + """Patched xgboost.train() that adds autologging.""" + global _ACTIVE_TRAINING + + # Prevent nested logging + if _ACTIVE_TRAINING: + return original_train(params, dtrain, *args, evals=evals, **kwargs) + + from artifacta import get_run + + run = get_run() + if run is None: + # No active run, just call original train + return original_train(params, dtrain, *args, evals=evals, **kwargs) + + _ACTIVE_TRAINING = True + try: + # Log parameters + _log_params(run, params) + + # Log dataset metadata (requires XGBoost >= 1.7.0) + if log_datasets: + _log_xgboost_datasets(run, dtrain, evals) + + # Create metrics logger + metrics_history = [] + + # Inject callback for metrics logging + callbacks = kwargs.get("callbacks", []) + if not isinstance(callbacks, list): + callbacks = [callbacks] if callbacks is not None else [] + + # Add metrics logging callback + autolog_callback = _create_autolog_callback(run, metrics_history) + callbacks.append(autolog_callback) + kwargs["callbacks"] = callbacks + + # Call original train + booster = original_train(params, dtrain, *args, evals=evals, **kwargs) + + # Log feature importance + if log_feature_importance and booster is not None: + _log_feature_importance(run, booster, importance_types) + + # Log model artifact + if log_models and booster is not None: + _log_model(run, booster) + + return booster + + finally: + _ACTIVE_TRAINING = False + + return patched_train + + +def _patch_sklearn_api(log_models, log_feature_importance, importance_types, log_datasets): + """Patch XGBoost sklearn API (XGBClassifier, XGBRegressor, XGBRanker). + + Args: + log_models: Whether to log models + log_feature_importance: Whether to log feature importance + importance_types: Types of feature importance to log + log_datasets: Whether to log dataset metadata + """ + try: + import xgboost as xgb + + # Get sklearn-compatible classes + sklearn_classes = [ + xgb.XGBClassifier, + xgb.XGBRegressor, + ] + + # Also try XGBRanker if available + with contextlib.suppress(AttributeError): + sklearn_classes.append(xgb.XGBRanker) + + for sklearn_class in sklearn_classes: + _patch_sklearn_fit( + sklearn_class, + log_models=log_models, + log_feature_importance=log_feature_importance, + importance_types=importance_types, + log_datasets=log_datasets, + ) + + except ImportError: + pass + + +def _patch_sklearn_fit( + sklearn_class, + log_models, + log_feature_importance, + importance_types, + log_datasets, +): + """Patch fit() method of XGBoost sklearn class. + + Args: + sklearn_class: XGBoost sklearn class (XGBClassifier, XGBRegressor, etc.) + log_models: Whether to log models + log_feature_importance: Whether to log feature importance + importance_types: Types of feature importance to log + log_datasets: Whether to log dataset metadata + """ + global _ORIGINAL_SKLEARN_METHODS + + # Save original fit + original_fit = sklearn_class.fit + key = (sklearn_class, "fit") + _ORIGINAL_SKLEARN_METHODS[key] = original_fit + + @functools.wraps(original_fit) + def patched_fit(self, X, y, **fit_params): + """Patched fit() that adds autologging.""" + global _ACTIVE_TRAINING + + # Prevent nested logging + if _ACTIVE_TRAINING: + return original_fit(self, X, y, **fit_params) + + from artifacta import get_run + + run = get_run() + if run is None: + return original_fit(self, X, y, **fit_params) + + _ACTIVE_TRAINING = True + try: + # Log parameters (get_params includes hyperparameters) + params = self.get_params(deep=True) + _log_params(run, params) + + # Log dataset metadata + if log_datasets: + from .dataset_utils import log_dataset_metadata + log_dataset_metadata(run, X, y, context="train") + + # Call original fit + result = original_fit(self, X, y, **fit_params) + + # Get underlying booster + booster = self.get_booster() + + # Log feature importance + if log_feature_importance and booster is not None: + _log_feature_importance(run, booster, importance_types) + + # Log model artifact + if log_models: + _log_sklearn_model(run, self) + + return result + + finally: + _ACTIVE_TRAINING = False + + # Replace fit method + sklearn_class.fit = patched_fit + + +def _create_autolog_callback(run, metrics_history): + """Create callback for logging metrics at each iteration. + + Args: + run: Active Artifacta run + metrics_history: List to store metrics history + + Returns: + Callback function compatible with XGBoost + """ + import xgboost as xgb + from packaging.version import Version + + # Check XGBoost version to determine callback API + xgb_version = Version(xgb.__version__.replace("SNAPSHOT", "dev")) + use_new_callback = xgb_version >= Version("1.3.0") + + if use_new_callback: + # XGBoost >= 1.3.0: Use TrainingCallback class + class AutologCallback(xgb.callback.TrainingCallback): + """Callback for logging metrics at each iteration (XGBoost >= 1.3.0).""" + + def after_iteration(self, model, epoch, evals_log): + """Called after each iteration. + + Args: + model: XGBoost booster + epoch: Current iteration number + evals_log: Dict of evaluation results + Format: {"eval_name": {"metric_name": [values...]}} + + Returns: + False to continue training + """ + # Extract metrics from evals_log + metrics = {} + for eval_name, metric_dict in evals_log.items(): + for metric_name, metric_values in metric_dict.items(): + # Get latest value (last in list) + value = metric_values[-1] + # Sanitize metric name (@ β†’ _at_) + sanitized_name = metric_name.replace("@", "_at_") + key = f"{eval_name}_{sanitized_name}" + metrics[key] = value + + # Log metrics for this iteration + if metrics: + metrics["iteration"] = epoch + metrics_history.append(metrics) + + # Log to Artifacta as series data + try: + _log_iteration_metrics(run, metrics_history) + except Exception as e: + _logger.warning(f"Failed to log metrics: {e}") + + return False # Continue training + + return AutologCallback() + + else: + # XGBoost < 1.3.0: Use function callback + def autolog_callback_fn(env): + """Callback for logging metrics (XGBoost < 1.3.0). + + Args: + env: XGBoost callback environment with evaluation_result_list + """ + # Extract metrics from evaluation results + metrics = {} + for eval_result in env.evaluation_result_list: + # eval_result is tuple: (eval_name, metric_name, value, is_higher_better) + eval_name = eval_result[0] + metric_name = eval_result[1] + value = eval_result[2] + + # Sanitize metric name + sanitized_name = metric_name.replace("@", "_at_") + key = f"{eval_name}_{sanitized_name}" + metrics[key] = value + + # Log metrics + if metrics: + metrics["iteration"] = env.iteration + metrics_history.append(metrics) + + try: + _log_iteration_metrics(run, metrics_history) + except Exception as e: + _logger.warning(f"Failed to log metrics: {e}") + + return autolog_callback_fn + + +def _log_params(run, params): + """Log XGBoost parameters. + + Args: + run: Active Artifacta run + params: Dictionary of XGBoost parameters + """ + try: + # Convert params to simple types + serializable_params = {} + for key, value in params.items(): + # Skip complex objects + if hasattr(value, "__dict__") and not isinstance(value, (int, float, str, bool)): + continue + # Convert numpy types + if isinstance(value, (np.integer, np.floating)): + value = value.item() + elif isinstance(value, np.ndarray): + value = value.tolist() + + serializable_params[key] = value + + if serializable_params: + run.update_config(serializable_params) + + except Exception as e: + _logger.warning(f"Failed to log parameters: {e}") + + +def _log_iteration_metrics(run, metrics_history): + """Log metrics from all iterations. + + Args: + run: Active Artifacta run + metrics_history: List of metric dicts from each iteration + """ + if not metrics_history: + return + + # Convert to series format: {"iteration": [...], "metric1": [...], "metric2": [...]} + series_data = {} + + for metrics in metrics_history: + for key, value in metrics.items(): + if key not in series_data: + series_data[key] = [] + series_data[key].append(value) + + # Log as series + run.log("xgboost_metrics", series_data) + + +def _log_xgboost_datasets(run, dtrain, evals): + """Log dataset metadata from XGBoost DMatrix objects. + + Uses DMatrix.get_data() method (XGBoost >= 1.7.0) to retrieve original data. + + Args: + run: Active Artifacta run + dtrain: Training DMatrix + evals: List of (DMatrix, name) tuples for evaluation sets + """ + import xgboost as xgb + from packaging.version import Version + + # Check XGBoost version + if Version(xgb.__version__) < Version("1.7.0"): + _logger.warning( + "Dataset logging requires XGBoost >= 1.7.0. " + f"Current version: {xgb.__version__}. Skipping dataset logging." + ) + return + + try: + from .dataset_utils import log_dataset_metadata + + # Log training dataset + try: + train_data = dtrain.get_data() + log_dataset_metadata(run, train_data, context="train") + except Exception as e: + _logger.warning(f"Failed to log training dataset: {e}") + + # Log evaluation datasets + if evals: + for deval, eval_name in evals: + try: + eval_data = deval.get_data() + log_dataset_metadata(run, eval_data, context=f"eval_{eval_name}") + except Exception as e: + _logger.warning(f"Failed to log eval dataset '{eval_name}': {e}") + + except Exception as e: + _logger.warning(f"Failed to log datasets: {e}") + + +def _log_feature_importance(run, booster, importance_types): + """Log feature importance as JSON. + + Args: + run: Active Artifacta run + booster: Trained XGBoost booster + importance_types: List of importance types to log + """ + try: + for importance_type in importance_types: + # Get importance scores + importance_dict = booster.get_score(importance_type=importance_type) + + if importance_dict: + # Convert to JSON + importance_json = json.dumps(importance_dict, indent=2) + + # Save to temp file + with tempfile.NamedTemporaryFile( + mode="w", + suffix=f"_importance_{importance_type}.json", + delete=False, + ) as tmp: + tmp.write(importance_json) + tmp_path = tmp.name + + # Log as artifact + run.log_artifact( + name=f"feature_importance_{importance_type}", + path=tmp_path, + include_content=True, + metadata={ + "artifact_type": "feature_importance", + "importance_type": importance_type, + }, + role="output", + ) + + # Cleanup + import os + from contextlib import suppress + + with suppress(Exception): + os.remove(tmp_path) + + except Exception as e: + _logger.warning(f"Failed to log feature importance: {e}") + + +def _log_model(run, booster): + """Log XGBoost booster as artifact. + + Args: + run: Active Artifacta run + booster: Trained XGBoost booster + """ + try: + # Save model to temp file + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp: + booster.save_model(tmp.name) + model_path = tmp.name + + # Log as artifact + run.log_artifact( + name="xgboost_model", + path=model_path, + include_content=False, + metadata={ + "artifact_type": "xgboost_model", + "model_format": "json", + }, + role="output", + ) + + # Cleanup + import os + from contextlib import suppress + + with suppress(Exception): + os.remove(model_path) + + except Exception as e: + _logger.warning(f"Failed to log model: {e}") + + +def _log_sklearn_model(run, model): + """Log XGBoost sklearn model as artifact. + + Args: + run: Active Artifacta run + model: Trained XGBoost sklearn model + """ + try: + # Save model to temp file (pickle format for sklearn compatibility) + with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp: + pickle.dump(model, tmp) + model_path = tmp.name + + # Log as artifact + model_class = model.__class__.__name__ + run.log_artifact( + name=f"{model_class}_model", + path=model_path, + include_content=False, + metadata={ + "artifact_type": "xgboost_sklearn_model", + "model_class": model_class, + }, + role="output", + ) + + # Cleanup + import os + from contextlib import suppress + + with suppress(Exception): + os.remove(model_path) + + except Exception as e: + _logger.warning(f"Failed to log sklearn model: {e}") diff --git a/artifacta/artifacta/metadata.py b/artifacta/artifacta/metadata.py index 29e48fc..6cb3344 100644 --- a/artifacta/artifacta/metadata.py +++ b/artifacta/artifacta/metadata.py @@ -1,4 +1,75 @@ -"""Metadata capture (git, environment, system).""" +"""Comprehensive metadata capture for run reproducibility. + +This module captures a complete snapshot of the execution environment at run +initialization, similar to Weights & Biases. The metadata enables reproducibility +by recording exactly what system, dependencies, and git state were used for a run. + +Captured Metadata Categories: + 1. Git: Commit hash, remote URL, branch, dirty status, diff + 2. Environment: Hostname, username, Python version, platform, CWD, command + 3. System: CPU count, memory, GPU info (name, memory per GPU) + 4. Dependencies: pip freeze output (all installed packages with versions) + +Metadata Structure: + All metadata is organized in a hierarchical dictionary: + { + "git": {...}, + "environment": {...}, + "system": {...}, + "dependencies": {...} + } + +Git Metadata Algorithm: + 1. Run 'git rev-parse HEAD' to get commit hash + 2. Run 'git config --get remote.origin.url' to get remote URL + 3. Run 'git status --porcelain' to check for uncommitted changes + 4. If uncommitted changes exist (dirty=True), capture full diff + 5. All commands use 5-second timeout and redirect stderr to DEVNULL + 6. If any command fails, return None (not in git repo or git not available) + + Why capture diff for dirty repos: + If there are uncommitted changes, the commit hash alone isn't enough + for reproducibility. We capture the full diff so users can see exactly + what modifications were present during the run. + +Environment Metadata: + - hostname: Identifies which machine ran the experiment + - username: Tracks who ran the experiment (useful in shared environments) + - python_version: sys.version (e.g., "3.10.4 (main, ...)") + - platform: platform.platform() (e.g., "Darwin-21.4.0-arm64-arm-64bit") + - cwd: Current working directory (helps reproduce relative paths) + - command: Full command-line invocation (e.g., "python train.py --lr 0.01") + +System Metadata Algorithm: + 1. CPU: psutil.cpu_count() with logical=False (physical cores) and True (threads) + 2. Memory: psutil.virtual_memory().total converted to GB + 3. GPU detection: + a. Try to import pynvml and initialize NVML + b. Get device count via nvmlDeviceGetCount() + c. For each GPU: Get handle, extract name and total memory + d. Decode name from bytes if necessary (NVML returns bytes on some platforms) + e. Convert memory from bytes to GB (divide by 1024^3) + f. Call nvmlShutdown() to cleanup + g. If any step fails, skip GPU metadata (no GPU or pynvml not installed) + +Dependencies Metadata: + - Run 'pip freeze' to get all installed packages with exact versions + - Uses sys.executable to ensure correct Python interpreter + - 10-second timeout to avoid hangs on slow systems + - Redirect stderr to suppress warnings + - If fails, return None (pip not available or command failed) + + Why pip freeze vs requirements.txt: + requirements.txt only lists direct dependencies. pip freeze captures + the complete dependency tree with exact versions, which is needed for + perfect reproducibility (transitive dependencies can change behavior). + +Design Philosophy: + - Comprehensive: Capture everything needed for reproducibility + - Fail-safe: All capture functions return None on error, never crash + - Zero configuration: Works automatically, no user setup required + - Storage efficient: Text-based, compresses well in database +""" import getpass import os diff --git a/artifacta/artifacta/metadata_extractors/torch.py b/artifacta/artifacta/metadata_extractors/torch.py index 05928f6..25b7711 100644 --- a/artifacta/artifacta/metadata_extractors/torch.py +++ b/artifacta/artifacta/metadata_extractors/torch.py @@ -1,4 +1,25 @@ -"""PyTorch model metadata extractor.""" +"""PyTorch model metadata extractor for checkpoint files. + +Automatically extracts metadata from PyTorch model checkpoints (.pt, .pth, .ckpt) +to provide insights into model architecture, training progress, and hyperparameters. + +Supported checkpoint formats: +- PyTorch Lightning: {"state_dict": {...}, "epoch": N, "hyper_parameters": {...}} +- Standard PyTorch: {"model_state_dict": {...}, "epoch": N, "loss": X} +- Raw state_dict: OrderedDict of layer tensors + +Extraction strategy: +1. Load checkpoint with torch.load (CPU, no weights_only for compatibility) +2. Detect checkpoint format by inspecting dict keys +3. Extract training metadata (epoch, loss, hyperparameters) +4. Analyze state_dict for parameter count and layer names +5. Return structured metadata dict + +Performance: +- Loads on CPU to avoid GPU memory usage +- Works even if training was done on CUDA +- Handles large checkpoints efficiently +""" import os @@ -6,11 +27,47 @@ def extract_pytorch_metadata(filepath: str) -> dict: """Extract metadata from PyTorch checkpoint file (.pt, .pth, .ckpt). + Detection algorithm: + 1. Check if PyTorch is installed (graceful failure if not) + 2. Load checkpoint with map_location="cpu" to avoid GPU dependency + 3. Inspect checkpoint structure to determine format: + - Has "state_dict" key? β†’ PyTorch Lightning format + - Has "model_state_dict" key? β†’ Standard PyTorch format + - Has tensor values? β†’ Raw state_dict + 4. Extract training metadata if available (epoch, loss, hyperparameters) + 5. Analyze state_dict for model architecture: + - Count total parameters (sum of tensor.numel() across all layers) + - Extract unique layer names (remove .weight/.bias suffixes) + - Sample first 10 layer names for inspection + + Why map_location="cpu": + - Avoids CUDA out of memory errors when extracting metadata + - Works even if checkpoint was saved on GPU + - Metadata extraction doesn't need GPU compute + + Why weights_only=False: + - Some checkpoints contain non-tensor objects (hyperparameters, optimizers) + - We need full checkpoint structure for metadata extraction + - Security is less of a concern for user's own checkpoints + Args: filepath: Path to PyTorch checkpoint file Returns: - dict with extracted metadata (parameter count, file size, layers, etc.) + dict with extracted metadata: + - file_size_bytes: Size of checkpoint file + - total_parameters: Total number of model parameters + - num_layers: Number of unique layers + - layer_names_sample: First 10 layer names + - saved_epoch: Training epoch when checkpoint was saved (if available) + - saved_global_step: Global training step (PyTorch Lightning) + - saved_loss: Training loss at checkpoint (if available) + - saved_hyperparameters: Model hyperparameters (PyTorch Lightning) + + Returns empty dict if: + - PyTorch is not installed + - File doesn't exist + - Checkpoint can't be loaded (corrupted, incompatible format) """ try: import torch @@ -80,7 +137,35 @@ def extract_pytorch_metadata(filepath: str) -> dict: def _extract_from_state_dict(state_dict: dict) -> dict: - """Extract metadata from PyTorch state_dict.""" + """Extract metadata from PyTorch state_dict. + + Algorithm: + 1. Iterate through all entries in state_dict + 2. For each tensor parameter: + - Count parameters using tensor.numel() (number of elements) + - Extract layer name by removing suffix (.weight, .bias, etc.) + - Track unique layer names (avoid duplicates) + 3. Return parameter count, layer count, and sample names + + Why remove suffixes: + - A single layer has multiple entries (weight, bias, running_mean, etc.) + - We want to count unique layers, not unique parameters + - Example: "conv1.weight" and "conv1.bias" β†’ one layer "conv1" + + Why limit to 10 layer names: + - Prevents overwhelming metadata for large models + - Provides enough info for inspection without bloat + - Full layer list can be extracted from checkpoint if needed + + Args: + state_dict: PyTorch state_dict (dict mapping param names to tensors) + + Returns: + dict with: + - total_parameters: Total parameter count across all layers + - num_layers: Number of unique layers + - layer_names_sample: First 10 unique layer names + """ total_params = 0 layer_names = [] diff --git a/artifacta/artifacta/monitor.py b/artifacta/artifacta/monitor.py index 9ca03e7..f3c2022 100644 --- a/artifacta/artifacta/monitor.py +++ b/artifacta/artifacta/monitor.py @@ -1,4 +1,61 @@ -"""System monitoring background thread.""" +"""System monitoring via background thread for automatic metrics capture. + +This module implements a background monitoring daemon that periodically captures +system-level metrics (CPU, memory, disk, network, GPU) and emits them to the +tracking server. The design follows a non-invasive pattern: monitoring runs in +a daemon thread and fails silently to avoid impacting the main training loop. + +Architecture: + - Background Thread: Runs in daemon mode (won't block Python exit) + - Periodic Sampling: Configurable interval (default 30 seconds) + - Graceful Degradation: All metric capture wrapped in try/except + - GPU Support: Optional via pynvml (NVIDIA GPUs only) + - Process-Level: Tracks both system-wide and current process metrics + +Monitoring Strategy: + The monitor captures two categories of metrics: + + 1. System-wide metrics (all processes): + - CPU percent, memory usage, disk I/O, network I/O + - Provides context about overall system load + + 2. Process-specific metrics (current Python process): + - CPU percent, thread count, RSS memory, memory percent + - Helps identify if training is bottlenecked by system resources + +GPU Monitoring Algorithm: + 1. Try to import pynvml and initialize NVML at __init__ + 2. If successful, set has_gpu=True and store pynvml reference + 3. In _capture_metrics(), iterate over all GPU devices + 4. For each GPU, capture: + - Utilization (GPU compute %, memory %) + - Memory (used bytes, allocated %) + - Temperature (Celsius) + - Power (watts, percent of limit) + - Clock speeds (SM, memory, graphics in MHz) + - Memory errors (corrected/uncorrected ECC errors) + - Encoder utilization (for video encoding workloads) + 5. Wrap each metric in try/except (some GPUs don't support all metrics) + +Background Thread Lifecycle: + 1. start() -> Create daemon thread, set running=True, start loop + 2. Loop: capture metrics, emit via HTTP, sleep for interval + 3. stop() -> Set running=False, join thread with 5s timeout + 4. Cleanup: Call nvmlShutdown() if GPU monitoring was enabled + +Performance Considerations: + - psutil.oneshot() context manager batches system calls for efficiency + - CPU percent uses interval=1 for accuracy (blocks for 1 second) + - Process CPU uses interval=None (non-blocking, uses cached value) + - All exceptions suppressed to avoid crashing the monitoring thread + - Daemon thread ensures no zombie threads after program exit + +Why daemon thread: + Daemon threads don't prevent the Python interpreter from exiting. + If training finishes, the main thread exits, and the monitor thread + is automatically terminated. This avoids requiring users to explicitly + call stop() in all cases. +""" import contextlib import threading diff --git a/artifacta/artifacta/primitives.py b/artifacta/artifacta/primitives.py index c94d73e..4108ee0 100644 --- a/artifacta/artifacta/primitives.py +++ b/artifacta/artifacta/primitives.py @@ -1,14 +1,36 @@ -"""Data primitives for structured logging. - -Universal data schema that supports any domain: -- ML training -- A/B testing -- Physics simulations -- Financial data -- Genomics -- Analytics -- Robotics -- And more... +"""Data primitives for structured logging and visualization. + +This module provides a universal data schema that supports any domain through +a small set of well-designed primitives. Instead of creating domain-specific +data structures, these primitives can represent data from ML training, A/B testing, +physics simulations, financial analysis, genomics, analytics, robotics, and more. + +Architecture: + The primitives form a type hierarchy optimized for common visualization patterns: + + - Series: Ordered data over a single dimension (time, epochs, steps) + - Distribution: Value collections with optional grouping (A/B test results) + - Matrix: 2D relationships (confusion matrices, correlation matrices) + - Table: Generic tabular data (event logs, measurements) + - Curve: Pure X-Y relationships (ROC curves, dose-response) + - Scatter: Unordered point clouds (embeddings, particle positions) + - BarChart: Categorical comparisons (model performance, metrics by group) + +Design Philosophy: + 1. Domain-agnostic: Same primitives work for any field + 2. Auto-conversion: Plain Python types (dict, list, numpy arrays) are + automatically converted to the appropriate primitive via auto_convert() + 3. Serializable: All primitives have to_dict() for JSON serialization + 4. Type-safe: Dataclasses provide structure and validation + 5. Extensible: Metadata fields allow domain-specific annotations + +Conversion Strategy: + The auto_convert() function implements intelligent type detection: + - numpy arrays: 1D -> Distribution, 2D -> Matrix + - dict: -> Series (with index detection) + - list: 1D -> Distribution, 2D -> Matrix + +This allows users to just log their data naturally without thinking about types. """ from dataclasses import dataclass @@ -43,7 +65,20 @@ class Series: metadata: Optional[Dict[str, Any]] = None def to_dict(self): - """Convert to dictionary.""" + """Convert to dictionary for JSON serialization. + + Conversion algorithm: + 1. Convert field values from numpy arrays to lists for JSON compatibility + 2. Optionally include explicit index_values if provided (supports categorical indices) + 3. Include metadata if present + 4. Return minimal dictionary (only required + populated optional fields) + + The conversion handles numpy arrays gracefully by detecting isinstance() and + converting to native Python lists, which are JSON-serializable. + + Returns: + Dict with 'index', 'fields', and optionally 'index_values' and 'metadata' + """ d = { "index": self.index, "fields": { @@ -310,26 +345,64 @@ def to_dict(self): def auto_convert(data): """Auto-detect and convert plain Python types to primitives. - Makes it easy for users - they just log dicts/lists/arrays and we handle conversion. + This function makes the API user-friendly by accepting plain Python types (dict, list, + numpy arrays) and intelligently converting them to the appropriate primitive type. + Users can just log their data naturally without thinking about type conversions. + + Detection Algorithm: + The function uses a multi-stage detection strategy: + + 1. **Early exit**: If data is already a primitive, return as-is (no conversion overhead) + + 2. **NumPy array detection**: + - 1D arrays (shape: [n]) -> Distribution (values only, no ordering implied) + - 2D arrays (shape: [m, n]) -> Matrix (rows x columns structure) + - 3D+ arrays -> ValueError (not supported, too high-dimensional for visualization) + + 3. **Dict detection** (most complex case, optimized for metrics logging): + a. Scan for index field candidates in priority order: + - Explicit: "x", "epoch", "step", "time", "iteration" + - These names signal ordered/sequential data + b. Extract index_values from the index field + c. Remaining list/array fields become Series fields + d. If no index found, use first field as index (fallback) + e. If no fields at all, use default "index" name + f. Return Series with detected index and fields + + 4. **List/tuple detection**: + - Empty list -> Distribution (empty values) + - Nested lists [[...], [...]] -> Matrix (2D structure detected) + - Flat list [1, 2, 3] -> Distribution (1D values) + + 5. **Fallback**: Wrap scalar/unknown types in single-element Distribution + + Why this algorithm: + - Dict -> Series: Most common ML/analytics use case (metrics over epochs/steps) + - 1D data -> Distribution: Natural for value collections without ordering + - 2D data -> Matrix: Natural for heatmaps, confusion matrices, correlations + - Index detection: Prioritizes common names to infer sequential data automatically Examples: # Series from dict with multiple fields - {"epoch": [1,2,3], "loss": [0.5,0.3,0.2]} -> Series + {"epoch": [1,2,3], "loss": [0.5,0.3,0.2]} -> Series(index="epoch", fields={"loss": [...]}) # Series from dict with x/y - {"x": [1,2,3], "y": [0.5,0.3,0.2]} -> Series + {"x": [1,2,3], "y": [0.5,0.3,0.2]} -> Series(index="x", fields={"y": [...]}) # Distribution from 1D array/list - [0.1, 0.2, 0.15, ...] -> Distribution + [0.1, 0.2, 0.15] -> Distribution(values=[0.1, 0.2, 0.15]) # Matrix from 2D array/list - [[1,2], [3,4]] -> Matrix + [[1,2], [3,4]] -> Matrix(data=[[1,2], [3,4]]) Args: - data: Plain Python dict, list, or numpy array + data: Plain Python dict, list, tuple, or numpy array to convert Returns: One of the primitive types (Series, Distribution, Matrix, etc.) + + Raises: + ValueError: If data is a 3D+ numpy array (not supported) """ # Already a primitive? Return as-is if type(data) in PRIMITIVE_TYPES: diff --git a/artifacta/artifacta/run.py b/artifacta/artifacta/run.py index 79ce9d5..9270ff3 100644 --- a/artifacta/artifacta/run.py +++ b/artifacta/artifacta/run.py @@ -1,4 +1,15 @@ -"""Run management.""" +"""Run management with provenance tracking and artifact logging. + +The Run class is the core abstraction for experiment tracking in Artifacta. +It orchestrates metric logging, artifact management, and system monitoring +while maintaining full reproducibility through content-addressed storage. + +Key components: +- HTTPEmitter: Real-time communication with tracking server (MLflow/W&B pattern) +- SystemMonitor: Background thread for CPU/memory/GPU metrics +- Artifact hashing: SHA256-based content addressing for reproducibility +- Auto-provenance: Automatic capture of config, dependencies, environment +""" import hashlib import json @@ -19,16 +30,43 @@ class Run: - """A single training run.""" + """A single training run with provenance tracking and artifact management. + + The Run class is the core abstraction for experiment tracking in Artifacta. + It handles: + - Automatic provenance capture (config, dependencies, environment, code) + - Real-time metric emission to tracking server via HTTP + - Artifact logging with content-addressed storage (SHA256 hashing) + - System monitoring (CPU, memory, GPU) in background thread + - Graceful degradation when tracking server is unavailable + + Architecture: + - HTTPEmitter: Sends data to tracking server in real-time (MLflow/W&B pattern) + - SystemMonitor: Background thread that captures system metrics every N seconds + - Artifact hashing: SHA256 of file contents ensures reproducibility tracking + + Lifecycle: + 1. __init__: Create run object with metadata + 2. start(): Auto-log provenance artifacts, start system monitoring + 3. log()/log_artifact(): Log data during training + 4. finish(): Stop monitoring, close connections + + Example: + >>> import artifacta as ds + >>> run = ds.init(project="mnist", name="exp-1", config={"lr": 0.001}) + >>> run.log("metrics", {"epoch": [1,2,3], "loss": [0.5, 0.3, 0.2]}) + >>> run.log_output("model.pt") + >>> run.finish() + """ def __init__(self, project, name, config, code_dir=None): - """Initialize a Run instance. + """Initialize a Run instance with metadata. Args: - project: Project name. - name: Run name. - config: Configuration dictionary. - code_dir: Optional code directory path. + project: Project name for grouping related experiments. + name: Human-readable run name (auto-generated if None). + config: Configuration dictionary (hyperparameters, settings). + code_dir: Optional code directory path for artifact logging. """ self.id = f"run_{int(time.time() * 1000)}" self.project = project @@ -50,7 +88,26 @@ def __init__(self, project, name, config, code_dir=None): self.code_artifact_hash = None # Hash of logged code artifact (if any) def start(self): - """Start the run and auto-log provenance artifacts.""" + """Start the run and auto-log provenance artifacts. + + Initialization sequence: + 1. Emit run creation to tracking server (creates database entry) + 2. Auto-log provenance artifacts as inputs: + - config.json: Hyperparameter configuration (JSON) + - requirements.txt: pip freeze output (dependencies) + - environment.json: Platform info (Python version, OS, CUDA) + 3. Start system monitoring background thread (1-second interval) + + Why 1-second interval for system monitoring: + - Short training runs (< 30s) need faster sampling to capture metrics + - Balances data granularity with overhead + - Can be adjusted for longer experiments + + Graceful degradation: + - If tracking server is unavailable, HTTPEmitter disables itself + - Run continues normally, but metrics aren't persisted + - Useful for offline development + """ self.started_at = datetime.utcnow() # Emit run creation to API Gateway (simplified - no config/tags) @@ -66,13 +123,26 @@ def start(self): self.monitor = SystemMonitor(interval=1, http_emitter=self.http_emitter) self.monitor.start() - print(f"πŸš€ Run started: {self.name}") + print(f"Run started: {self.name}") print(f" ID: {self.id}") print(f" Project: {self.project}") def log(self, name: str, data, section: str = None): """Log structured data - accepts primitives OR plain Python types (auto-converted). + How it works: + 1. Auto-convert plain Python types to primitives (see primitives.py:auto_convert) + 2. Extract primitive type from PRIMITIVE_TYPES mapping + 3. Convert primitive to dict representation + 4. Emit to tracking server via HTTP (real-time) + 5. Server broadcasts to WebSocket clients for live UI updates + + Auto-conversion rules: + - dict with list values β†’ Series primitive + - 1D list/array β†’ Distribution primitive + - 2D list/array β†’ Matrix primitive + - numpy arrays β†’ Distribution or Matrix based on dimensions + Args: name: Name for this data object (e.g., "training_metrics", "confusion_matrix") data: Can be: @@ -125,7 +195,21 @@ def log(self, name: str, data, section: str = None): ) def log_artifact(self, name, path, include_content=True, metadata=None, role="output"): - """Log an artifact (file or directory) for this run. + """Log an artifact (file or directory) for this run with content-addressed storage. + + How it works: + 1. Collect file metadata (MIME type, size, path) using artifacts.collect_files + 2. Compute SHA256 hash of file contents for content addressing + 3. Inline text file contents if include_content=True (useful for code files) + 4. Auto-extract metadata from model files (PyTorch, TensorFlow, ONNX) + 5. Emit artifact to tracking server with metadata + content + 6. Track code artifact hash for provenance + + Content addressing (SHA256): + - Single file: hash of file contents + - Directory: hash of all file contents + filenames (sorted) + - Enables reproducibility tracking and deduplication + - Used for lineage graph and experiment comparison Works with both single files and directories containing multiple files. Each file in the collection retains its own type information (MIME type, language, etc). @@ -195,9 +279,8 @@ def log_artifact(self, name, path, include_content=True, metadata=None, role="ou # Print summary file_count = files_data["total_files"] size_mb = files_data["total_size"] / 1024 / 1024 - emoji = "πŸ“¦" if file_count > 1 else "πŸ“„" - print(f"{emoji} Artifact logged: {name}") + print(f"Artifact logged: {name}") print(f" Path: {path}") print(f" Files: {file_count}") print(f" Size: {size_mb:.2f} MB") @@ -231,6 +314,31 @@ def log_input(self, path, name=None, include_content=True, metadata=None): name, path, include_content=include_content, metadata=metadata, role="input" ) + def update_config(self, new_config: dict): + """Update run configuration with new parameters. + + Merges new_config into existing config and re-logs the config artifact. + Useful for autolog scenarios where parameters are discovered during training + (e.g., optimizer config, framework-specific params). + + Args: + new_config: Dictionary of new configuration parameters to merge + + Examples: + >>> # User initializes with their hyperparameters + >>> run = init(project="mnist", config={"batch_size": 32, "epochs": 10}) + >>> + >>> # Autolog discovers optimizer config during training + >>> run.update_config({"optimizer": "Adam", "lr": 0.001, "weight_decay": 1e-5}) + >>> + >>> # Final config contains both user and auto-discovered params + >>> print(run.config) + >>> # {"batch_size": 32, "epochs": 10, "optimizer": "Adam", "lr": 0.001, "weight_decay": 1e-5} + """ + if new_config: + self.config.update(new_config) + self._log_config_artifact() # Re-log config artifact with updated values + def log_output(self, path, name=None, include_content=True, metadata=None): """Log an output artifact (file or directory) for this run. @@ -261,13 +369,26 @@ def log_output(self, path, name=None, include_content=True, metadata=None): ) def _compute_artifact_hash(self, path): - """Compute SHA256 hash of file(s) at path. + """Compute SHA256 hash of file(s) at path for content addressing. + + Hashing strategy: + - Single file: SHA256 of raw file contents (read in 8KB chunks) + - Directory: SHA256 of all files + filenames (sorted for determinism) + + Why include filenames in directory hash: + - Renaming a file changes the artifact identity + - Moving files between directories changes identity + - Ensures that directory structure matters, not just file contents + + Performance: + - Chunk-based reading (8KB) handles large files without memory issues + - Recursive glob sorted by path ensures deterministic ordering Args: path: Path object (file or directory). Returns: - SHA256 hash string. + SHA256 hash string (hex digest). """ import hashlib @@ -423,5 +544,5 @@ def finish(self): # Close HTTP emitter self.http_emitter.close() - print(f"βœ… Run finished: {self.name}") + print(f"Run finished: {self.name}") print(f" Summary: {self.summary}") diff --git a/artifacta/artifacta/utils.py b/artifacta/artifacta/utils.py index 96fcf3a..582129c 100644 --- a/artifacta/artifacta/utils.py +++ b/artifacta/artifacta/utils.py @@ -1,4 +1,9 @@ -"""Utility functions for artifacta.""" +"""Utility functions for data transformation and serialization. + +This module provides helper functions for common data transformations needed +throughout the Artifacta codebase, particularly for configuration flattening +and JSON serialization. +""" import json @@ -6,20 +11,58 @@ def flatten_dict(d, parent_key="", sep="."): """Flatten a nested dictionary into dot-notation key-value pairs. - This is fully agnostic - works with ANY nested structure. + This function recursively traverses a nested dictionary structure and converts + it into a flat dictionary with dot-separated keys. This is useful for storing + hierarchical configurations in flat key-value stores (databases, tags, etc.). + + Flattening Algorithm: + 1. Iterate over each key-value pair in the input dictionary + 2. Construct new_key by joining parent_key with current key using separator + 3. Handle three value types: + a. Dict -> Recursively flatten and extend items list + b. List -> Convert to JSON string representation + c. Primitive (str, int, float, bool) -> Convert to string + 4. Accumulate all (key, value) pairs in items list + 5. Convert items list to dictionary and return + + Why convert everything to strings: + - Database storage: Most key-value stores require string values + - Tag systems: Tags are typically string-based for consistency + - JSON serialization: Lists as JSON strings preserve structure + - Type preservation: Can be reversed by parsing JSON for lists + + Recursion termination: + - Base case: Value is not a dict -> store as string + - Recursive case: Value is a dict -> call flatten_dict with current key as parent + + Edge cases: + - Empty dict: Returns empty dict + - Empty list: Returns "[]" as value + - None values: Converted to string "None" + - Nested empty dicts: Recursion handles naturally Examples: - {"a": {"b": 1}} -> {"a.b": "1"} - {"model": {"layers": [64, 32]}} -> {"model.layers": "[64, 32]"} - {"training": {"optimizer": {"type": "adam"}}} -> {"training.optimizer.type": "adam"} + Simple nesting: + {"a": {"b": 1}} -> {"a.b": "1"} + + Lists are JSON-encoded: + {"model": {"layers": [64, 32]}} -> {"model.layers": "[64, 32]"} + + Deep nesting: + {"training": {"optimizer": {"type": "adam"}}} + -> {"training.optimizer.type": "adam"} + + Mixed types: + {"config": {"lr": 0.01, "layers": [64, 32], "name": "model"}} + -> {"config.lr": "0.01", "config.layers": "[64, 32]", "config.name": "model"} Args: - d: Dictionary to flatten - parent_key: Prefix for keys (used in recursion) - sep: Separator for nested keys (default: ".") + d: Dictionary to flatten (can be nested arbitrarily deep) + parent_key: Prefix for keys (used internally for recursion, default "") + sep: Separator for nested keys (default ".", can use "/" or "_") Returns: - Flattened dictionary with string values + Flattened dictionary with string keys and string values """ items = [] for k, v in d.items(): diff --git a/artifacta/tests/test_basic.py b/artifacta/tests/test_basic.py index 3591669..8d4a41f 100644 --- a/artifacta/tests/test_basic.py +++ b/artifacta/tests/test_basic.py @@ -26,7 +26,7 @@ def test_basic_usage(): assert (run.log_dir / "metrics.jsonl").exists() assert (run.log_dir / "system.jsonl").exists() - print(f"βœ… Test passed! Files created at: {run.log_dir}") + print(f"Test passed! Files created at: {run.log_dir}") if __name__ == "__main__": diff --git a/artifacta/tests/test_sklearn_autolog.py b/artifacta/tests/test_sklearn_autolog.py new file mode 100644 index 0000000..575030e --- /dev/null +++ b/artifacta/tests/test_sklearn_autolog.py @@ -0,0 +1,504 @@ +"""Tests for scikit-learn autolog integration. + +Tests cover all features matching MLflow's sklearn autolog: +- Parameter logging (get_params with deep=True) +- Classifier metrics (accuracy, precision, recall, F1, log loss, ROC-AUC) +- Regressor metrics (MSE, RMSE, MAE, RΒ²) +- Model artifact logging +- Binary vs multiclass classification +- Meta-estimators (Pipeline, GridSearchCV) +- Autolog enable/disable +""" + + +import numpy as np +import pytest +from sklearn.datasets import load_iris, make_classification, make_regression +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from sklearn.linear_model import LinearRegression, LogisticRegression +from sklearn.model_selection import GridSearchCV, train_test_split +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.tree import DecisionTreeClassifier + +from artifacta import init +from artifacta.tests.test_utils import ( + MockHTTPEmitter, + assert_artifact_logged, + assert_dataset_logged, + assert_metric_logged, + assert_model_artifact_logged, + assert_param_logged, + get_logged_datasets, + get_logged_metrics, + get_logged_params, +) + + +@pytest.fixture +def binary_classification_data(): + """Binary classification dataset.""" + X, y = make_classification( + n_samples=100, n_features=10, n_classes=2, random_state=42 + ) + return train_test_split(X, y, test_size=0.2, random_state=42) + + +@pytest.fixture +def multiclass_classification_data(): + """Multiclass classification dataset (iris).""" + X, y = load_iris(return_X_y=True) + return train_test_split(X, y, test_size=0.2, random_state=42) + + +@pytest.fixture +def regression_data(): + """Regression dataset.""" + X, y = make_regression(n_samples=100, n_features=10, random_state=42) + return train_test_split(X, y, test_size=0.2, random_state=42) + + +@pytest.fixture +def temp_run(): + """Create and cleanup temporary run with mocked HTTP emitter.""" + run = init(project="test_sklearn_autolog", name="test_run") + # Replace the http_emitter with our mock + run.http_emitter = MockHTTPEmitter(run.id) + yield run + run.finish() + + +class TestSklearnAutologBasic: + """Basic autolog functionality tests.""" + + def test_autolog_enable_disable(self): + """Test enabling and disabling autolog.""" + from artifacta.integrations import sklearn + + # Enable + sklearn.enable_autolog() + assert sklearn._AUTOLOG_ENABLED is True + + # Disable + sklearn.disable_autolog() + assert sklearn._AUTOLOG_ENABLED is False + + def test_autolog_patches_estimators(self): + """Test that autolog patches estimator fit() methods.""" + from artifacta.integrations import sklearn + + sklearn.enable_autolog() + + # Check that fit method was patched + original_fit = sklearn._ORIGINAL_METHODS.get( + (RandomForestClassifier, "fit") + ) + assert original_fit is not None + assert RandomForestClassifier.fit != original_fit + + sklearn.disable_autolog() + + def test_autolog_without_active_run(self, binary_classification_data): + """Test that autolog doesn't crash when no run is active.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = binary_classification_data + + sklearn.enable_autolog() + + # Should work without active run (just doesn't log) + clf = RandomForestClassifier(random_state=42) + clf.fit(X_train, y_train) + + sklearn.disable_autolog() + + def test_autolog_prevents_nested_logging(self, temp_run, binary_classification_data): + """Test that nested estimator calls don't create duplicate logs.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = binary_classification_data + + sklearn.enable_autolog() + + # Pipeline has nested fit() calls - should only log once + pipe = Pipeline([ + ("scaler", StandardScaler()), + ("clf", RandomForestClassifier(random_state=42)), + ]) + pipe.fit(X_train, y_train) + + sklearn.disable_autolog() + + +class TestSklearnParameterLogging: + """Test parameter logging functionality.""" + + def test_log_basic_params(self, temp_run, binary_classification_data): + """Test logging basic classifier parameters.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = binary_classification_data + + sklearn.enable_autolog() + + clf = RandomForestClassifier( + n_estimators=10, + max_depth=5, + random_state=42 + ) + clf.fit(X_train, y_train) + + sklearn.disable_autolog() + + # Debug: Print what was logged + params = get_logged_params(temp_run) + print(f"DEBUG: Logged params: {params}") + print(f"DEBUG: All emitted data: {temp_run.http_emitter.emitted_data}") + + # Verify parameters were logged + assert_param_logged(temp_run, "n_estimators", 10) + assert_param_logged(temp_run, "max_depth", 5) + assert_param_logged(temp_run, "random_state", 42) + + def test_log_deep_params_pipeline(self, temp_run, binary_classification_data): + """Test logging deep parameters from Pipeline.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = binary_classification_data + + sklearn.enable_autolog() + + pipe = Pipeline([ + ("scaler", StandardScaler()), + ("clf", RandomForestClassifier(n_estimators=10, random_state=42)), + ]) + pipe.fit(X_train, y_train) + + sklearn.disable_autolog() + + # Should log params from both scaler and clf (get_params(deep=True)) + + +class TestSklearnDatasetLogging: + """Test dataset metadata logging.""" + + def test_log_dataset_shape_and_dtype(self, temp_run, binary_classification_data): + """Test logging dataset shape and dtype.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = binary_classification_data + + sklearn.enable_autolog() + + clf = RandomForestClassifier(n_estimators=5, random_state=42) + clf.fit(X_train, y_train) + + sklearn.disable_autolog() + + # Verify dataset was logged + assert_dataset_logged(temp_run, context="train", expected_shape=X_train.shape) + + # Verify dataset metadata contains required fields + datasets = get_logged_datasets(temp_run) + train_data = datasets["train"] + + assert train_data["features_shape"] == list(X_train.shape) + assert train_data["features_size"] == X_train.size + assert "features_digest" in train_data + assert "features_dtype" in train_data + assert "targets_shape" in train_data + assert train_data["targets_shape"] == list(y_train.shape) + + def test_dataset_logging_can_be_disabled(self, temp_run, binary_classification_data): + """Test that dataset logging can be disabled.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = binary_classification_data + + sklearn.enable_autolog(log_datasets=False) + + clf = RandomForestClassifier(n_estimators=5, random_state=42) + clf.fit(X_train, y_train) + + sklearn.disable_autolog() + + # Verify dataset was NOT logged + datasets = get_logged_datasets(temp_run) + assert "train" not in datasets, "Dataset should not be logged when log_datasets=False" + + +class TestSklearnClassifierMetrics: + """Test classifier metric logging.""" + + def test_binary_classifier_metrics(self, temp_run, binary_classification_data): + """Test metrics for binary classification.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = binary_classification_data + + sklearn.enable_autolog() + + clf = LogisticRegression(random_state=42) + clf.fit(X_train, y_train) + + sklearn.disable_autolog() + + # Verify classifier metrics were logged + assert_metric_logged(temp_run, "accuracy") + assert_metric_logged(temp_run, "precision") + assert_metric_logged(temp_run, "recall") + assert_metric_logged(temp_run, "f1_score") + # LogisticRegression has predict_proba, so these should also be logged + assert_metric_logged(temp_run, "log_loss") + assert_metric_logged(temp_run, "roc_auc") + + def test_multiclass_classifier_metrics(self, temp_run, multiclass_classification_data): + """Test metrics for multiclass classification.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = multiclass_classification_data + + sklearn.enable_autolog() + + clf = RandomForestClassifier(n_estimators=10, random_state=42) + clf.fit(X_train, y_train) + + sklearn.disable_autolog() + + # Verify multiclass metrics were logged + assert_metric_logged(temp_run, "accuracy") + assert_metric_logged(temp_run, "precision") # weighted average + assert_metric_logged(temp_run, "recall") # weighted average + assert_metric_logged(temp_run, "f1_score") # weighted average + assert_metric_logged(temp_run, "roc_auc") # multiclass OVR + + def test_classifier_without_predict_proba(self, temp_run, binary_classification_data): + """Test classifier that doesn't have predict_proba.""" + from artifacta.integrations import sklearn + from sklearn.svm import LinearSVC + + X_train, X_test, y_train, y_test = binary_classification_data + + sklearn.enable_autolog() + + clf = LinearSVC(random_state=42, max_iter=1000) + clf.fit(X_train, y_train) + + sklearn.disable_autolog() + + # Should log basic metrics but not log_loss or roc_auc + + +class TestSklearnRegressorMetrics: + """Test regressor metric logging.""" + + def test_regressor_metrics(self, temp_run, regression_data): + """Test metrics for regression.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = regression_data + + sklearn.enable_autolog() + + reg = RandomForestRegressor(n_estimators=10, random_state=42) + reg.fit(X_train, y_train) + + sklearn.disable_autolog() + + # Verify regressor metrics were logged + assert_metric_logged(temp_run, "training_score") + assert_metric_logged(temp_run, "mse") + assert_metric_logged(temp_run, "rmse") + assert_metric_logged(temp_run, "mae") + assert_metric_logged(temp_run, "r2_score") + + def test_linear_regression_metrics(self, temp_run, regression_data): + """Test metrics for simple linear regression.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = regression_data + + sklearn.enable_autolog() + + reg = LinearRegression() + reg.fit(X_train, y_train) + + sklearn.disable_autolog() + + +class TestSklearnModelArtifacts: + """Test model artifact logging.""" + + def test_log_model_artifact(self, temp_run, binary_classification_data): + """Test that fitted model is logged as artifact.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = binary_classification_data + + sklearn.enable_autolog() + + clf = RandomForestClassifier(n_estimators=10, random_state=42) + clf.fit(X_train, y_train) + + sklearn.disable_autolog() + + # Verify model artifact was logged + assert_model_artifact_logged(temp_run) + assert_artifact_logged(temp_run, "RandomForestClassifier") + + def test_model_can_be_loaded(self, temp_run, binary_classification_data): + """Test that logged model can be loaded and used.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = binary_classification_data + + sklearn.enable_autolog() + + clf = RandomForestClassifier(n_estimators=10, random_state=42) + clf.fit(X_train, y_train) + clf.predict(X_test) + + sklearn.disable_autolog() + + # In real implementation, load model from artifact and verify predictions match + + +class TestSklearnMetaEstimators: + """Test meta-estimator support (Pipeline, GridSearchCV).""" + + def test_pipeline_logging(self, temp_run, binary_classification_data): + """Test logging for Pipeline estimator.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = binary_classification_data + + sklearn.enable_autolog() + + pipe = Pipeline([ + ("scaler", StandardScaler()), + ("clf", LogisticRegression(random_state=42)), + ]) + pipe.fit(X_train, y_train) + + sklearn.disable_autolog() + + # Should log parameters from all pipeline steps + # Should log metrics from final estimator + + def test_gridsearchcv_logging(self, temp_run, binary_classification_data): + """Test logging for GridSearchCV.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = binary_classification_data + + sklearn.enable_autolog() + + param_grid = { + "max_depth": [3, 5], + "n_estimators": [5, 10], + } + grid = GridSearchCV( + RandomForestClassifier(random_state=42), + param_grid, + cv=2, + ) + grid.fit(X_train, y_train) + + sklearn.disable_autolog() + + # MLflow creates parent run + child runs for each CV fold + # For now, just ensure it completes without error + + +class TestSklearnConfiguration: + """Test autolog configuration options.""" + + def test_disable_model_logging(self, temp_run, binary_classification_data): + """Test disabling model logging.""" + from artifacta.integrations import sklearn + + from artifacta.tests.test_utils import count_logged_artifacts + + X_train, X_test, y_train, y_test = binary_classification_data + + sklearn.enable_autolog(log_models=False, log_datasets=False) + + clf = RandomForestClassifier(random_state=42) + clf.fit(X_train, y_train) + + sklearn.disable_autolog() + + # Model should NOT be logged (and no dataset artifacts either) + # Exclude config artifact which is auto-logged from update_config() + num_artifacts = count_logged_artifacts(temp_run, exclude_config=True) + assert num_artifacts == 0, f"Expected 0 artifacts, got {num_artifacts}" + + def test_disable_metrics_logging(self, temp_run, binary_classification_data): + """Test disabling metrics logging.""" + from artifacta.integrations import sklearn + + X_train, X_test, y_train, y_test = binary_classification_data + + sklearn.enable_autolog(log_training_metrics=False) + + clf = RandomForestClassifier(random_state=42) + clf.fit(X_train, y_train) + + sklearn.disable_autolog() + + # Metrics should NOT be logged + metrics = get_logged_metrics(temp_run) + assert len(metrics) == 0, f"Expected 0 metrics, got {len(metrics)}: {list(metrics.keys())}" + + +class TestSklearnEdgeCases: + """Test edge cases and error handling.""" + + def test_estimator_without_score(self, temp_run): + """Test estimator that doesn't implement score().""" + from artifacta.integrations import sklearn + from sklearn.cluster import KMeans + + X = np.random.rand(100, 10) + + sklearn.enable_autolog() + + # KMeans doesn't have score() method (unsupervised) + # Should not crash + kmeans = KMeans(n_clusters=3, random_state=42, n_init=10) + kmeans.fit(X) + + sklearn.disable_autolog() + + def test_empty_dataset(self, temp_run): + """Test with empty dataset (edge case).""" + from artifacta.integrations import sklearn + + sklearn.enable_autolog() + + # Should handle gracefully (or raise appropriate sklearn error) + clf = LogisticRegression() + try: + clf.fit(np.array([]).reshape(0, 5), np.array([])) + except ValueError: + pass # Expected sklearn error + + sklearn.disable_autolog() + + def test_single_sample(self, temp_run): + """Test with single sample dataset.""" + from artifacta.integrations import sklearn + + sklearn.enable_autolog() + + X = np.array([[1, 2, 3]]) + y = np.array([0]) + + clf = DecisionTreeClassifier() + clf.fit(X, y) + + sklearn.disable_autolog() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/artifacta/tests/test_utils.py b/artifacta/tests/test_utils.py new file mode 100644 index 0000000..8c7ad92 --- /dev/null +++ b/artifacta/tests/test_utils.py @@ -0,0 +1,319 @@ +"""Test utilities for verifying autolog functionality. + +Provides helper functions to inspect what was logged by autolog integrations +without needing to query the tracking server or database directly. +""" + +import json +from unittest.mock import patch + + +class MockHTTPEmitter: + """Mock HTTPEmitter that captures emitted data for testing. + + Replaces the real HTTPEmitter to capture all logged data without + needing a running tracking server. + """ + + def __init__(self, run_id): + self.run_id = run_id + self.emitted_data = [] + self.emitted_artifacts = [] + self.init_called = False + self.enabled = True + + def emit_init(self, data): + """Capture run initialization.""" + self.init_called = True + self.emitted_data.append(("init", data)) + + def emit_structured_data(self, data): + """Capture structured data (metrics, params).""" + self.emitted_data.append(("structured_data", data)) + + def emit_artifact(self, data, content=None, role=None): + """Capture artifact metadata and content.""" + # Store both metadata and content + artifact_with_content = {**data, "content": content, "role": role} + self.emitted_artifacts.append(artifact_with_content) + + def close(self): + """Close emitter (no-op for mock).""" + pass + + +def get_logged_params(run): + """Extract logged parameters from a run. + + Args: + run: Artifacta run object + + Returns: + Dictionary of logged parameters + """ + # First check run.config (new approach using update_config()) + if hasattr(run, 'config') and run.config: + return dict(run.config) + + # Fallback: check emitted_data (old approach using run.log()) + if not hasattr(run.http_emitter, 'emitted_data'): + return {} + + params = {} + for event_type, data in run.http_emitter.emitted_data: + if event_type == "structured_data": + if data.get("name") in ["parameters", "xgboost_params", "sklearn_params"]: + # Extract params from structured data + data_dict = data.get("data", {}) + # Handle Series format: {"index_values": [{"param1": val1, ...}]} + if "index_values" in data_dict: + index_values = data_dict["index_values"] + if isinstance(index_values, list) and len(index_values) > 0: + params.update(index_values[0]) + # Handle other formats + elif "params" in data_dict: + param_list = data_dict["params"] + if isinstance(param_list, list) and len(param_list) > 0: + params.update(param_list[0]) + + return params + + +def get_logged_metrics(run): + """Extract logged metrics from a run. + + Args: + run: Artifacta run object + + Returns: + Dictionary of metric name -> values + """ + if not hasattr(run.http_emitter, 'emitted_data'): + return {} + + metrics = {} + for event_type, data in run.http_emitter.emitted_data: + if event_type == "structured_data": + name = data.get("name", "") + if "metric" in name.lower() or name in ["xgboost_metrics", "sklearn_metrics"]: + # Extract metrics from structured data + data_dict = data.get("data", {}) + + # Handle Series format with index and fields + # {"index": "metric", "fields": {"value": [...]}, "index_values": ["metric1", "metric2"]} + if "index" in data_dict and "fields" in data_dict: + index_values = data_dict.get("index_values", []) + fields = data_dict.get("fields", {}) + if "value" in fields: + values = fields["value"] + for metric_name, value in zip(index_values, values): + metrics[metric_name] = value + + # Handle other formats (raw dict with metric arrays) + # XGBoost logs like: {"iteration": [0,1,2], "train_logloss": [0.5, 0.3, 0.2]} + else: + for key, value in data_dict.items(): + if isinstance(value, (int, float, list)): + metrics[key] = value + + # Also check for fields dict (Series format) + # {"fields": {"train_logloss": [0.5, 0.3], "test_logloss": [0.6, 0.4]}} + if "fields" in data_dict and isinstance(data_dict["fields"], dict): + for key, value in data_dict["fields"].items(): + if isinstance(value, (int, float, list)): + metrics[key] = value + + return metrics + + +def get_logged_artifacts(run): + """Extract logged artifacts from a run. + + Args: + run: Artifacta run object + + Returns: + List of artifact metadata dictionaries + """ + if not hasattr(run.http_emitter, 'emitted_artifacts'): + return [] + + return run.http_emitter.emitted_artifacts + + +def assert_param_logged(run, param_name, expected_value=None): + """Assert that a parameter was logged. + + Args: + run: Artifacta run object + param_name: Name of parameter to check + expected_value: Optional expected value (if None, just check existence) + + Raises: + AssertionError: If parameter not logged or value doesn't match + """ + params = get_logged_params(run) + assert param_name in params, f"Parameter '{param_name}' was not logged. Logged params: {list(params.keys())}" + + if expected_value is not None: + actual_value = params[param_name] + assert actual_value == expected_value, ( + f"Parameter '{param_name}' has value {actual_value}, expected {expected_value}" + ) + + +def assert_metric_logged(run, metric_name): + """Assert that a metric was logged. + + Args: + run: Artifacta run object + metric_name: Name of metric to check + + Raises: + AssertionError: If metric not logged + """ + metrics = get_logged_metrics(run) + assert metric_name in metrics, f"Metric '{metric_name}' was not logged. Logged metrics: {list(metrics.keys())}" + + +def assert_artifact_logged(run, artifact_name_contains): + """Assert that an artifact was logged. + + Args: + run: Artifacta run object + artifact_name_contains: String that should be in the artifact name + + Raises: + AssertionError: If no matching artifact found + """ + artifacts = get_logged_artifacts(run) + matching = [a for a in artifacts if artifact_name_contains in a.get("name", "")] + assert len(matching) > 0, ( + f"No artifact with name containing '{artifact_name_contains}' found. " + f"Logged artifacts: {[a.get('name') for a in artifacts]}" + ) + + +def assert_model_artifact_logged(run): + """Assert that a model artifact was logged. + + Args: + run: Artifacta run object + + Raises: + AssertionError: If no model artifact found + """ + artifacts = get_logged_artifacts(run) + model_artifacts = [ + a for a in artifacts + if "model" in a.get("name", "").lower() or + a.get("metadata", {}).get("artifact_type", "").endswith("_model") + ] + assert len(model_artifacts) > 0, ( + f"No model artifact found. Logged artifacts: {[a.get('name') for a in artifacts]}" + ) + + +def count_logged_artifacts(run, exclude_config=False): + """Count number of artifacts logged. + + Args: + run: Artifacta run object + exclude_config: If True, exclude config artifacts from count + + Returns: + Number of artifacts logged + """ + artifacts = get_logged_artifacts(run) + if exclude_config: + # Filter out config artifacts (auto-logged from update_config()) + artifacts = [a for a in artifacts if a.get("name") != "config.json"] + return len(artifacts) + + +def get_logged_datasets(run): + """Extract logged dataset metadata from a run. + + Dataset metadata is stored as JSON artifacts (not structured data). + + Args: + run: Artifacta run object + + Returns: + Dictionary of context -> dataset metadata + """ + + if not hasattr(run.http_emitter, 'emitted_artifacts'): + return {} + + datasets = {} + for artifact_data in run.http_emitter.emitted_artifacts: + name = artifact_data.get("name", "") + # Check if this is dataset metadata (name like "dataset_train.json") + if name.startswith("dataset_") and name.endswith(".json"): + # Extract context from name (e.g., "dataset_train.json" -> "train") + context = name.replace("dataset_", "").replace(".json", "") + + # Parse JSON content from artifact + # Artifacts are stored with content in emitted_data as well + # We need to find the corresponding content + content = artifact_data.get("content") + if content: + try: + # Content is a JSON string containing file collection + content_obj = json.loads(content) if isinstance(content, str) else content + # Extract the actual JSON content from the file + if "files" in content_obj and len(content_obj["files"]) > 0: + file_content = content_obj["files"][0].get("content", "{}") + metadata = json.loads(file_content) if isinstance(file_content, str) else file_content + datasets[context] = metadata + except Exception: + pass + + return datasets + + +def assert_dataset_logged(run, context="train", expected_shape=None): + """Assert that dataset metadata was logged. + + Args: + run: Artifacta run object + context: Dataset context ("train", "eval", etc.) + expected_shape: Optional expected features shape tuple + + Raises: + AssertionError: If dataset not logged or shape doesn't match + """ + datasets = get_logged_datasets(run) + assert context in datasets, ( + f"Dataset with context '{context}' was not logged. " + f"Logged datasets: {list(datasets.keys())}" + ) + + dataset_meta = datasets[context] + assert "features_shape" in dataset_meta, "Dataset missing features_shape" + assert "features_dtype" in dataset_meta, "Dataset missing features_dtype" + assert "features_digest" in dataset_meta, "Dataset missing features_digest" + assert "context" in dataset_meta, "Dataset missing context" + assert dataset_meta["context"] == context, ( + f"Dataset context mismatch: {dataset_meta['context']} != {context}" + ) + + if expected_shape is not None: + actual_shape = tuple(dataset_meta["features_shape"]) + assert actual_shape == expected_shape, ( + f"Dataset shape mismatch: {actual_shape} != {expected_shape}" + ) + + +def patch_http_emitter(): + """Context manager to patch HTTPEmitter with MockHTTPEmitter. + + Usage: + with patch_http_emitter(): + run = ds.init(...) + # run.http_emitter will be MockHTTPEmitter + """ + # Patch in multiple locations to ensure it's caught + import artifacta.run + return patch.object(artifacta.run, 'HTTPEmitter', MockHTTPEmitter) diff --git a/artifacta/tests/test_xgboost_autolog.py b/artifacta/tests/test_xgboost_autolog.py new file mode 100644 index 0000000..3e78a61 --- /dev/null +++ b/artifacta/tests/test_xgboost_autolog.py @@ -0,0 +1,515 @@ +"""Tests for XGBoost autolog integration. + +Tests cover all features matching MLflow's XGBoost autolog: +- Parameter logging (native API and sklearn API) +- Per-iteration metrics (via callbacks) +- Feature importance (weight, gain, cover) +- Model artifact logging +- Early stopping support +- Metric name sanitization (@ β†’ _at_) +""" + + +import pytest +import xgboost as xgb +from sklearn.datasets import load_breast_cancer, make_regression +from sklearn.model_selection import train_test_split + +from artifacta import init +from artifacta.tests.test_utils import ( + MockHTTPEmitter, + assert_dataset_logged, + assert_model_artifact_logged, + assert_param_logged, + get_logged_datasets, + get_logged_metrics, + get_logged_params, +) + + +@pytest.fixture +def binary_classification_data(): + """Binary classification dataset.""" + X, y = load_breast_cancer(return_X_y=True) + return train_test_split(X, y, test_size=0.2, random_state=42) + + +@pytest.fixture +def regression_data(): + """Regression dataset.""" + X, y = make_regression(n_samples=100, n_features=10, random_state=42) + return train_test_split(X, y, test_size=0.2, random_state=42) + + +@pytest.fixture +def temp_run(): + """Create and cleanup temporary run with mocked HTTP emitter.""" + run = init(project="test_xgboost_autolog", name="test_run") + # Replace the http_emitter with our mock + run.http_emitter = MockHTTPEmitter(run.id) + yield run + run.finish() + + +class TestXGBoostAutologBasic: + """Basic autolog functionality tests.""" + + def test_autolog_enable_disable(self): + """Test enabling and disabling autolog.""" + from artifacta.integrations import xgboost + + # Enable + xgboost.enable_autolog() + assert xgboost._AUTOLOG_ENABLED is True + + # Disable + xgboost.disable_autolog() + assert xgboost._AUTOLOG_ENABLED is False + + def test_autolog_patches_train(self): + """Test that autolog patches xgboost.train().""" + from artifacta.integrations import xgboost + + original_train = xgb.train + xgboost.enable_autolog() + + # Check that train was patched + assert xgb.train != original_train + assert original_train == xgboost._ORIGINAL_TRAIN + + xgboost.disable_autolog() + + def test_autolog_patches_sklearn(self): + """Test that autolog patches XGBoost sklearn API.""" + from artifacta.integrations import xgboost + + xgboost.enable_autolog() + + # Check that sklearn API was patched + assert (xgb.XGBClassifier, "fit") in xgboost._ORIGINAL_SKLEARN_METHODS + assert (xgb.XGBRegressor, "fit") in xgboost._ORIGINAL_SKLEARN_METHODS + + xgboost.disable_autolog() + + def test_autolog_without_active_run(self, binary_classification_data): + """Test that autolog doesn't crash when no run is active.""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog() + + # Should work without active run (just doesn't log) + dtrain = xgb.DMatrix(X_train, y_train) + params = {"max_depth": 3, "objective": "binary:logistic"} + xgb.train(params, dtrain, num_boost_round=10) + + xgboost.disable_autolog() + + +class TestXGBoostNativeAPI: + """Test native xgboost.train() API autolog.""" + + def test_log_params(self, temp_run, binary_classification_data): + """Test logging parameters from xgboost.train().""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog() + + dtrain = xgb.DMatrix(X_train, y_train) + params = { + "max_depth": 3, + "eta": 0.1, + "objective": "binary:logistic", + "eval_metric": "logloss", + } + xgb.train(params, dtrain, num_boost_round=10) + + xgboost.disable_autolog() + + # Verify params were logged + assert_param_logged(temp_run, "max_depth", 3) + assert_param_logged(temp_run, "eta", 0.1) + assert_param_logged(temp_run, "objective", "binary:logistic") + + def test_log_metrics_with_evals(self, temp_run, binary_classification_data): + """Test logging per-iteration metrics with evals.""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog() + + dtrain = xgb.DMatrix(X_train, y_train) + dtest = xgb.DMatrix(X_test, y_test) + params = {"max_depth": 3, "objective": "binary:logistic"} + + xgb.train( + params, + dtrain, + num_boost_round=10, + evals=[(dtrain, "train"), (dtest, "test")], + ) + + xgboost.disable_autolog() + + # Verify per-iteration metrics were logged + metrics = get_logged_metrics(temp_run) + # Should have train and test metrics (at least one type) + assert len(metrics) > 0, "No metrics were logged" + # Verify we have metrics from both train and test sets + metric_names = list(metrics.keys()) + has_train = any("train" in name for name in metric_names) + has_test = any("test" in name for name in metric_names) + assert has_train or has_test, "Expected train or test metrics to be logged" + + def test_log_feature_importance(self, temp_run, binary_classification_data): + """Test logging feature importance.""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog() + + dtrain = xgb.DMatrix(X_train, y_train) + params = {"max_depth": 3, "objective": "binary:logistic"} + xgb.train(params, dtrain, num_boost_round=10) + + xgboost.disable_autolog() + + # Verify feature importance artifacts were logged + artifacts = temp_run.http_emitter.emitted_artifacts + importance_artifacts = [ + a for a in artifacts if "feature_importance" in a.get("name", "") + ] + assert len(importance_artifacts) > 0, "No feature importance artifacts logged" + # Should have weight, gain, and cover by default + importance_names = [a.get("name", "") for a in importance_artifacts] + assert any("weight" in name for name in importance_names), "weight importance not logged" + assert any("gain" in name for name in importance_names), "gain importance not logged" + assert any("cover" in name for name in importance_names), "cover importance not logged" + + def test_log_model(self, temp_run, binary_classification_data): + """Test logging trained model.""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog() + + dtrain = xgb.DMatrix(X_train, y_train) + params = {"max_depth": 3, "objective": "binary:logistic"} + xgb.train(params, dtrain, num_boost_round=10) + + xgboost.disable_autolog() + + # Verify model artifact was logged + assert_model_artifact_logged(temp_run) + + def test_metric_name_sanitization(self, temp_run, binary_classification_data): + """Test that metric names with @ are sanitized.""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog() + + dtrain = xgb.DMatrix(X_train, y_train) + dtest = xgb.DMatrix(X_test, y_test) + params = { + "max_depth": 3, + "objective": "binary:logistic", + "eval_metric": ["logloss", "ndcg@3"], + } + + xgb.train( + params, + dtrain, + num_boost_round=10, + evals=[(dtest, "test")], + ) + + xgboost.disable_autolog() + + # Verify metric name sanitization: "ndcg@3" should be logged as "test_ndcg_at_3" + metrics = get_logged_metrics(temp_run) + metric_names = list(metrics.keys()) + # Check that @ was replaced with _at_ + has_sanitized = any("_at_" in name for name in metric_names) + assert has_sanitized, "Expected sanitized metric name with '_at_' but found none" + # Should not have @ in any metric name + has_at_symbol = any("@" in name for name in metric_names) + assert not has_at_symbol, "Found unsanitized metric name with '@' symbol" + + +class TestXGBoostSklearnAPI: + """Test XGBoost sklearn API autolog.""" + + def test_xgbclassifier_params(self, temp_run, binary_classification_data): + """Test logging XGBClassifier parameters.""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog() + + clf = xgb.XGBClassifier(max_depth=3, learning_rate=0.1, n_estimators=10) + clf.fit(X_train, y_train) + + xgboost.disable_autolog() + + # Verify params were logged + assert_param_logged(temp_run, "max_depth", 3) + assert_param_logged(temp_run, "learning_rate", 0.1) + assert_param_logged(temp_run, "n_estimators", 10) + + def test_xgbregressor_params(self, temp_run, regression_data): + """Test logging XGBRegressor parameters.""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = regression_data + + xgboost.enable_autolog() + + reg = xgb.XGBRegressor(max_depth=3, learning_rate=0.1, n_estimators=10) + reg.fit(X_train, y_train) + + xgboost.disable_autolog() + + # Verify params were logged + assert_param_logged(temp_run, "max_depth", 3) + assert_param_logged(temp_run, "learning_rate", 0.1) + assert_param_logged(temp_run, "n_estimators", 10) + + def test_sklearn_feature_importance(self, temp_run, binary_classification_data): + """Test feature importance logging for sklearn API.""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog() + + clf = xgb.XGBClassifier(max_depth=3, n_estimators=10) + clf.fit(X_train, y_train) + + xgboost.disable_autolog() + + # Verify feature importance artifacts were logged + artifacts = temp_run.http_emitter.emitted_artifacts + importance_artifacts = [ + a for a in artifacts if "feature_importance" in a.get("name", "") + ] + assert len(importance_artifacts) > 0, "No feature importance artifacts logged" + + def test_sklearn_model_logging(self, temp_run, binary_classification_data): + """Test model logging for sklearn API.""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog() + + clf = xgb.XGBClassifier(max_depth=3, n_estimators=10) + clf.fit(X_train, y_train) + + xgboost.disable_autolog() + + # Verify model artifact was logged + assert_model_artifact_logged(temp_run) + + +class TestXGBoostConfiguration: + """Test autolog configuration options.""" + + def test_disable_model_logging(self, temp_run, binary_classification_data): + """Test disabling model logging.""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog(log_models=False) + + dtrain = xgb.DMatrix(X_train, y_train) + params = {"max_depth": 3, "objective": "binary:logistic"} + xgb.train(params, dtrain, num_boost_round=10) + + xgboost.disable_autolog() + + # Verify model was NOT logged + artifacts = temp_run.http_emitter.emitted_artifacts + model_artifacts = [a for a in artifacts if a.get("name", "") == "model"] + assert len(model_artifacts) == 0, "Model should not be logged when log_models=False" + + def test_disable_feature_importance(self, temp_run, binary_classification_data): + """Test disabling feature importance logging.""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog(log_feature_importance=False) + + dtrain = xgb.DMatrix(X_train, y_train) + params = {"max_depth": 3, "objective": "binary:logistic"} + xgb.train(params, dtrain, num_boost_round=10) + + xgboost.disable_autolog() + + # Verify feature importance was NOT logged + artifacts = temp_run.http_emitter.emitted_artifacts + importance_artifacts = [ + a for a in artifacts if "feature_importance" in a.get("name", "") + ] + assert len(importance_artifacts) == 0, "Feature importance should not be logged when log_feature_importance=False" + + def test_custom_importance_types(self, temp_run, binary_classification_data): + """Test custom importance types.""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog(importance_types=["weight"]) + + dtrain = xgb.DMatrix(X_train, y_train) + params = {"max_depth": 3, "objective": "binary:logistic"} + xgb.train(params, dtrain, num_boost_round=10) + + xgboost.disable_autolog() + + # Verify only "weight" importance was logged + artifacts = temp_run.http_emitter.emitted_artifacts + importance_artifacts = [ + a for a in artifacts if "feature_importance" in a.get("name", "") + ] + assert len(importance_artifacts) > 0, "No feature importance artifacts logged" + importance_names = [a.get("name", "") for a in importance_artifacts] + assert any("weight" in name for name in importance_names), "weight importance not logged" + assert not any("gain" in name for name in importance_names), "gain importance should not be logged" + assert not any("cover" in name for name in importance_names), "cover importance should not be logged" + + +class TestXGBoostDatasetLogging: + """Test dataset metadata logging.""" + + def test_native_api_dataset_logging(self, temp_run, binary_classification_data): + """Test dataset logging for native API (xgb.train).""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog() + + dtrain = xgb.DMatrix(X_train, y_train) + dtest = xgb.DMatrix(X_test, y_test) + params = {"max_depth": 3, "objective": "binary:logistic"} + + xgb.train( + params, + dtrain, + num_boost_round=5, + evals=[(dtrain, "train"), (dtest, "test")], + ) + + xgboost.disable_autolog() + + # Verify training dataset was logged + assert_dataset_logged(temp_run, context="train", expected_shape=X_train.shape) + + # Verify eval datasets were logged + assert_dataset_logged(temp_run, context="eval_train") + assert_dataset_logged(temp_run, context="eval_test") + + # Check metadata fields + datasets = get_logged_datasets(temp_run) + train_data = datasets["train"] + assert "features_digest" in train_data + assert "features_dtype" in train_data + + def test_sklearn_api_dataset_logging(self, temp_run, binary_classification_data): + """Test dataset logging for sklearn API (XGBClassifier).""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog() + + clf = xgb.XGBClassifier(n_estimators=5, max_depth=3) + clf.fit(X_train, y_train) + + xgboost.disable_autolog() + + # Verify dataset was logged + assert_dataset_logged(temp_run, context="train", expected_shape=X_train.shape) + + datasets = get_logged_datasets(temp_run) + train_data = datasets["train"] + assert train_data["features_shape"] == list(X_train.shape) + assert train_data["targets_shape"] == list(y_train.shape) + + def test_dataset_logging_can_be_disabled(self, temp_run, binary_classification_data): + """Test that dataset logging can be disabled.""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog(log_datasets=False) + + dtrain = xgb.DMatrix(X_train, y_train) + params = {"max_depth": 3, "objective": "binary:logistic"} + xgb.train(params, dtrain, num_boost_round=5) + + xgboost.disable_autolog() + + # Verify dataset was NOT logged + datasets = get_logged_datasets(temp_run) + assert "train" not in datasets, "Dataset should not be logged when log_datasets=False" + + +class TestXGBoostEdgeCases: + """Test edge cases and error handling.""" + + def test_no_evals(self, temp_run, binary_classification_data): + """Test training without evals (no per-iteration metrics).""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog() + + dtrain = xgb.DMatrix(X_train, y_train) + params = {"max_depth": 3, "objective": "binary:logistic"} + xgb.train(params, dtrain, num_boost_round=10) + + xgboost.disable_autolog() + + # Verify training completed without error + # Params should still be logged even without evals + assert_param_logged(temp_run, "max_depth", 3) + # Model should still be logged + assert_model_artifact_logged(temp_run) + + def test_prevents_nested_logging(self, temp_run, binary_classification_data): + """Test that nested training doesn't create duplicate logs.""" + from artifacta.integrations import xgboost + + X_train, X_test, y_train, y_test = binary_classification_data + + xgboost.enable_autolog() + + # This shouldn't cause issues even if called multiple times + dtrain = xgb.DMatrix(X_train, y_train) + params = {"max_depth": 3, "objective": "binary:logistic"} + xgb.train(params, dtrain, num_boost_round=5) + xgb.train(params, dtrain, num_boost_round=5) + + xgboost.disable_autolog() + + # Verify both trainings completed successfully + # Both should log params (two separate train calls) + params_logged = get_logged_params(temp_run) + assert "max_depth" in params_logged, "Parameters should be logged" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/artifacta_ui/__init__.py b/artifacta_ui/__init__.py index 613c8b0..1b69b2c 100644 --- a/artifacta_ui/__init__.py +++ b/artifacta_ui/__init__.py @@ -1,4 +1,4 @@ from pathlib import Path -UI_DIST_PATH = Path(__file__).parent / 'dist' -INDEX_HTML = Path(__file__).parent / 'index.html' +UI_DIST_PATH = Path(__file__).parent / "dist" +INDEX_HTML = Path(__file__).parent / "index.html" diff --git a/eslint.config.js b/config/eslint.config.js similarity index 59% rename from eslint.config.js rename to config/eslint.config.js index 102a417..c658ec4 100644 --- a/eslint.config.js +++ b/config/eslint.config.js @@ -1,14 +1,17 @@ import js from '@eslint/js'; import react from 'eslint-plugin-react'; import reactHooks from 'eslint-plugin-react-hooks'; +import jsdoc from 'eslint-plugin-jsdoc'; export default [ js.configs.recommended, + jsdoc.configs['flat/recommended'], { files: ['src/**/*.{js,jsx}'], plugins: { react, 'react-hooks': reactHooks, + jsdoc, }, languageOptions: { parserOptions: { @@ -66,6 +69,39 @@ export default [ 'react/jsx-uses-vars': 'error', 'react-hooks/rules-of-hooks': 'error', 'react-hooks/exhaustive-deps': 'warn', + + // JSDoc rules - require documentation for functions and classes + 'jsdoc/require-jsdoc': ['error', { + require: { + FunctionDeclaration: true, + MethodDefinition: true, + ClassDeclaration: true, + ArrowFunctionExpression: true, + FunctionExpression: true, + }, + contexts: [ + 'FunctionDeclaration', + 'FunctionExpression', + 'ArrowFunctionExpression', + 'MethodDefinition', + 'ClassDeclaration', + 'VariableDeclaration > VariableDeclarator > ArrowFunctionExpression', + 'VariableDeclaration > VariableDeclarator > FunctionExpression', + ], + }], + 'jsdoc/require-description': 'error', + 'jsdoc/require-param': 'error', + 'jsdoc/require-param-description': 'error', + 'jsdoc/require-param-type': 'error', + 'jsdoc/require-returns': 'error', + 'jsdoc/require-returns-description': 'error', + 'jsdoc/require-returns-type': 'error', + 'jsdoc/check-types': 'error', + 'jsdoc/check-param-names': 'error', + 'jsdoc/check-tag-names': 'error', + // Convert warnings to errors - no cheating! + 'jsdoc/no-undefined-types': 'error', + 'jsdoc/reject-function-type': 'error', }, settings: { react: { diff --git a/config/jsdoc.json b/config/jsdoc.json new file mode 100644 index 0000000..6470f93 --- /dev/null +++ b/config/jsdoc.json @@ -0,0 +1,32 @@ +{ + "source": { + "include": [ + "src/app" + ], + "includePattern": "\\.(js|jsx)$", + "excludePattern": "(node_modules|test|__tests__|.test|.spec)" + }, + "opts": { + "destination": "docs/_build/jsdoc", + "recurse": true + }, + "plugins": [ + "plugins/markdown" + ], + "templates": { + "cleverLinks": true, + "monospaceLinks": false, + "default": { + "outputSourceFiles": true, + "includeDate": false + } + }, + "tags": { + "allowUnknownTags": true, + "dictionaries": ["jsdoc", "closure"] + }, + "markdown": { + "hardwrap": false, + "idInHeadings": true + } +} diff --git a/config/playwright.config.js b/config/playwright.config.js new file mode 100644 index 0000000..deb24a2 --- /dev/null +++ b/config/playwright.config.js @@ -0,0 +1,59 @@ +import { defineConfig, devices } from '@playwright/test'; + +/** + * Playwright E2E Test Configuration + * + * Reads base URL from ARTIFACTA_URL environment variable (default: http://localhost:8000) + * + * Usage: + * npm run test:e2e + * ARTIFACTA_URL=http://localhost:8001 npm run test:e2e + */ +export default defineConfig({ + testDir: '../tests/e2e', + + // Global setup and teardown + globalSetup: '../tests/e2e/setup.js', + globalTeardown: '../tests/e2e/teardown.js', + + // Test execution settings + fullyParallel: true, + forbidOnly: !!process.env.CI, + retries: process.env.CI ? 2 : 0, + workers: process.env.CI ? 1 : undefined, + + // Reporter + reporter: 'html', + + // Shared settings for all tests + use: { + // Base URL from env var or default + baseURL: process.env.ARTIFACTA_URL || 'http://localhost:8000', + + // Browser settings + trace: 'on-first-retry', + screenshot: 'only-on-failure', + + // Timeouts + actionTimeout: 10000, + }, + + // Test timeout + timeout: 30000, + + // Configure projects for different browsers (chromium only for now - fast and headless) + projects: [ + { + name: 'chromium', + use: { ...devices['Desktop Chrome'] }, + }, + ], + + // Run local dev server before starting the tests + // (disabled - we handle server startup in global setup) + // webServer: { + // command: 'venv/bin/artifacta ui', + // url: 'http://localhost:8000', + // reuseExistingServer: !process.env.CI, + // }, +}); diff --git a/vite.config.js b/config/vite.config.js similarity index 69% rename from vite.config.js rename to config/vite.config.js index 585341e..5bec4d3 100644 --- a/vite.config.js +++ b/config/vite.config.js @@ -3,9 +3,11 @@ import react from '@vitejs/plugin-react'; import path from 'path'; import { fileURLToPath } from 'url'; -const __dirname = path.dirname(fileURLToPath(import.meta.url)); +const __dirname = path.dirname(path.dirname(fileURLToPath(import.meta.url))); export default defineConfig({ + root: path.resolve(__dirname, 'artifacta_ui'), + publicDir: path.resolve(__dirname, 'public'), plugins: [react()], css: { preprocessorOptions: { @@ -20,6 +22,7 @@ export default defineConfig({ }, build: { outDir: 'dist', + emptyOutDir: true, sourcemap: true }, @@ -30,7 +33,8 @@ export default defineConfig({ '@/app': path.resolve(__dirname, './src/app'), '@/ml': path.resolve(__dirname, './src/ml'), '@/core': path.resolve(__dirname, './src/core'), - '@/config': path.resolve(__dirname, './src/config') + '@/config': path.resolve(__dirname, './src/config'), + '/src': path.resolve(__dirname, './src') } } }); diff --git a/docs/Makefile b/docs/Makefile index d4bb2cb..189d2a4 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -14,6 +14,18 @@ help: .PHONY: help Makefile +# Build UI docs first, then Sphinx docs +html: jsdoc + @$(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + @echo "Copying JSDoc files to HTML output..." + @cp -r $(BUILDDIR)/jsdoc $(BUILDDIR)/html/jsdoc + +# Generate JSDoc documentation +jsdoc: + @echo "Generating UI API documentation with JSDoc..." + @cd .. && npm run docs:ui || true + @echo "JSDoc documentation generated in $(BUILDDIR)/jsdoc/" + # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile diff --git a/docs/api.rst b/docs/api.rst index 407c7ed..df35319 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -124,6 +124,22 @@ Base Integration :undoc-members: :show-inheritance: +scikit-learn +~~~~~~~~~~~~ + +.. automodule:: artifacta.artifacta.integrations.sklearn + :members: + :undoc-members: + :show-inheritance: + +XGBoost +~~~~~~~ + +.. automodule:: artifacta.artifacta.integrations.xgboost + :members: + :undoc-members: + :show-inheritance: + PyTorch Lightning ~~~~~~~~~~~~~~~~~ @@ -132,8 +148,8 @@ PyTorch Lightning :undoc-members: :show-inheritance: -TensorFlow -~~~~~~~~~~ +TensorFlow/Keras +~~~~~~~~~~~~~~~~ .. automodule:: artifacta.artifacta.integrations.tensorflow :members: diff --git a/docs/development.rst b/docs/development.rst index 1c30efb..2fd38a5 100644 --- a/docs/development.rst +++ b/docs/development.rst @@ -147,8 +147,15 @@ For PostgreSQL: .. code-block:: bash + # Linux/macOS export DATABASE_URI="postgresql://user:pass@host:port/dbname" # pragma: allowlist secret + # Windows (PowerShell) + $env:DATABASE_URI="postgresql://user:pass@host:port/dbname" # pragma: allowlist secret + + # Windows (cmd) + set DATABASE_URI=postgresql://user:pass@host:port/dbname + See ``tracking-server/database.py`` for full SQLAlchemy model definitions. Setting Up Development Environment @@ -163,13 +170,22 @@ Setting Up Development Environment git clone https://github.com/walkerbdev/artifacta.git cd artifacta -**2. Create a virtual environment:** +**2. Create and activate a virtual environment:** .. code-block:: bash + # Create venv python3 -m venv venv + + # Activate - Linux/macOS source venv/bin/activate + # Activate - Windows (PowerShell) + venv\Scripts\Activate.ps1 + + # Activate - Windows (cmd) + venv\Scripts\activate.bat + **3. Install Python dependencies:** .. code-block:: bash @@ -178,6 +194,14 @@ Setting Up Development Environment This installs Artifacta and all optional dependencies including PyTorch, TensorFlow, and scientific computing libraries from the ``pyproject.toml`` file. +For generating real test videos (optional, requires FFmpeg): + +.. code-block:: bash + + pip install -e '.[dev,video]' + +Note: Video artifact logging works without this - test helpers will use placeholder MP4 files instead of generating real videos. + **4. Install and build UI:** .. code-block:: bash @@ -202,16 +226,26 @@ Artifacta uses pre-commit hooks to maintain code quality. Hooks run automaticall .. code-block:: bash + # Linux/macOS source venv/bin/activate pre-commit run --all-files + # Windows (PowerShell) + venv\Scripts\Activate.ps1 + pre-commit run --all-files + + # Windows (cmd) + venv\Scripts\activate.bat + pre-commit run --all-files + **Pre-commit hooks include:** - **Ruff** - Fast Python linter and formatter -- **Mypy** - Static type checking +- **Mypy** - Static type checking for tracking-server - **Pydocstyle** - Docstring style checker (Google style) -- **ESLint** - JavaScript/React linter +- **ESLint** - JavaScript/React linter with JSDoc enforcement - **Knip** - Find unused JavaScript exports and dependencies +- **Depcheck** - Find unused npm dependencies - **Vulture** - Find dead Python code - **Codespell** - Catch typos in code and documentation - **Detect-secrets** - Prevent committing API keys and passwords @@ -220,48 +254,209 @@ Artifacta uses pre-commit hooks to maintain code quality. Hooks run automaticall Running Tests ------------- -Artifacta includes a comprehensive test suite. Tests require a running Artifacta server. +Artifacta includes two types of tests: + +1. **Pytest (Python)** - Unit and integration tests for Python API, autolog, and primitives +2. **Playwright (E2E)** - End-to-end browser tests for the UI + +Pytest - Python Tests +~~~~~~~~~~~~~~~~~~~~~~ + +Artifacta's pytest suite tests Python functionality including autolog integrations, data primitives, and domain-specific features. + +**Test Categories:** + +- ``tests/autolog/`` - PyTorch, TensorFlow, PyTorch Lightning autolog +- ``tests/domains/`` - Domain-specific primitives (genomics, finance, robotics, computer vision, climate, audio/video, etc.) **1. Start the server in one terminal:** .. code-block:: bash + # Linux/macOS source venv/bin/activate artifacta ui + # Windows (PowerShell) + venv\Scripts\Activate.ps1 + artifacta ui + + # Windows (cmd) + venv\Scripts\activate.bat + artifacta ui + **2. Run tests in another terminal:** .. code-block:: bash + # Linux/macOS source venv/bin/activate pytest tests/ -**Run specific tests:** + # Windows (PowerShell) + venv\Scripts\Activate.ps1 + pytest tests/ + + # Windows (cmd) + venv\Scripts\activate.bat + pytest tests/ + +**Run specific test categories:** .. code-block:: bash - pytest tests/autolog/ -v # Run autolog tests - pytest tests/domains/test_primitives.py -v # Run primitive tests + pytest tests/autolog/ -v # All autolog tests + pytest tests/autolog/test_pytorch_lightning.py -v # PyTorch Lightning tests + pytest tests/domains/ -v # All domain tests + pytest tests/domains/test_genomics.py -v # Genomics primitives + +**Run specific test by name:** + +.. code-block:: bash + + pytest tests/autolog/ -k "test_checkpoint" -v # Tests matching "checkpoint" + pytest tests/domains/ -k "test_roc_curve" -v # Tests matching "roc_curve" + +**Custom server configuration:** -Tests automatically use ``localhost:8000`` by default. If you're running the server on a different host/port, set environment variables: +Tests automatically use ``localhost:8000`` on Linux/macOS and ``127.0.0.1:8000`` on Windows. To use a different host/port: .. code-block:: bash - export TRACKING_SERVER_HOST=0.0.0.0 + # Linux/macOS + export TRACKING_SERVER_HOST=localhost export TRACKING_SERVER_PORT=9000 pytest tests/ + # Windows (PowerShell) - use 127.0.0.1 instead of localhost + $env:TRACKING_SERVER_HOST="127.0.0.1" + $env:TRACKING_SERVER_PORT="9000" + pytest tests/ + + # Windows (cmd) - use 127.0.0.1 instead of localhost + set TRACKING_SERVER_HOST=127.0.0.1 + set TRACKING_SERVER_PORT=9000 + pytest tests/ + +Playwright - E2E UI Tests +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Playwright tests verify the UI works end-to-end by automating browser interactions. These tests run in a real Chromium browser and test core functionality like run selection, data visualization, and navigation. + +**Test Files:** + +- ``tests/e2e/core.spec.js`` - Core UI functionality (homepage, navigation, API, sidebar) +- ``tests/e2e/visualization.spec.js`` - Data visualization (plots, tables, artifacts, chat) + +**Setup:** + +Playwright tests handle server startup/shutdown automatically. No manual server setup required. + +**Run all E2E tests:** + +.. code-block:: bash + + npm run test:e2e + +**Run specific test file:** + +.. code-block:: bash + + npm run test:e2e -- tests/e2e/core.spec.js + npm run test:e2e -- tests/e2e/visualization.spec.js + +**Run tests matching pattern:** + +.. code-block:: bash + + npm run test:e2e -- --grep "plots tab" # Tests with "plots tab" in name + npm run test:e2e -- --grep "API" # Tests with "API" in name + +**Interactive UI mode (debugging):** + +.. code-block:: bash + + npm run test:e2e:ui + +This opens Playwright's UI where you can step through tests, see screenshots, and debug failures. + +**Custom server URL:** + +By default, tests use ``http://localhost:8000`` (``http://127.0.0.1:8000`` on Windows). To test against a different URL: + +.. code-block:: bash + + # Linux/macOS + ARTIFACTA_URL=http://localhost:8001 npm run test:e2e + + # Windows (PowerShell) - use 127.0.0.1 instead of localhost + $env:ARTIFACTA_URL="http://127.0.0.1:8001"; npm run test:e2e + + # Windows (cmd) - use 127.0.0.1 instead of localhost + set ARTIFACTA_URL=http://127.0.0.1:8001 && npm run test:e2e + +**What the tests do:** + +1. **Global Setup** (``tests/e2e/setup.js``): + - Cleans database for fresh state + - Starts Artifacta server on port 8000 + - Runs example script to populate test data + - Waits for server health check + +2. **Core Tests** (6 tests): + - Homepage loads successfully + - Run list displays correctly + - Navigation between tabs works + - Health check endpoint returns healthy + - API returns run data + - Sidebar is interactive + +3. **Visualization Tests** (4 tests): + - Plots tab renders charts + - Tables tab shows structured data + - Artifacts tab displays file list + - Chat tab loads successfully + +4. **Global Teardown** (``tests/e2e/teardown.js``): + - Stops the server + - Cleans up background processes + +**Test output:** + +- ``test-results/`` - Screenshots and traces from failed tests +- ``playwright-report/`` - HTML report with test results + +These directories are gitignored and safe to delete. + Building Documentation ---------------------- -Artifacta uses Sphinx for documentation. +Artifacta uses Sphinx for Python API documentation and JSDoc for UI component documentation. **Build the docs:** .. code-block:: bash + # Linux/macOS source venv/bin/activate - venv/bin/sphinx-build -b html docs docs/_build/html + cd docs + make html + + # Windows (PowerShell) + venv\Scripts\Activate.ps1 + cd docs + .\make.bat html + + # Windows (cmd) + venv\Scripts\activate.bat + cd docs + make.bat html + +This automatically: + +1. Generates JSDoc documentation from UI components (``npm run docs:ui``) +2. Builds Sphinx documentation (Python API, user guide, examples) +3. Links both together in ``_build/html/`` **View the docs:** @@ -273,6 +468,18 @@ Open ``docs/_build/html/index.html`` in your browser, or serve them locally: Then navigate to http://localhost:8080. +**Build only Python docs** (skip JSDoc): + +.. code-block:: bash + + sphinx-build -M html . _build + +**Build only UI docs:** + +.. code-block:: bash + + npm run docs:ui + Version Management ------------------ @@ -282,9 +489,18 @@ Artifacta uses ``bump-my-version`` to manage version numbers across the codebase .. code-block:: bash + # Linux/macOS source venv/bin/activate bump-my-version show current_version + # Windows (PowerShell) + venv\Scripts\Activate.ps1 + bump-my-version show current_version + + # Windows (cmd) + venv\Scripts\activate.bat + bump-my-version show current_version + **Bump version:** .. code-block:: bash @@ -313,3 +529,202 @@ When you bump the version, it automatically: - Creates a git tag ``vX.Y.Z`` After bumping, rebuild documentation to reflect the new version in docs. + +Publishing to PyPI +------------------- + +**Prerequisites:** + +1. Get your PyPI API tokens: + - TestPyPI: https://test.pypi.org/manage/account/token/ + - PyPI: https://pypi.org/manage/account/token/ + +2. Set environment variables: + +.. code-block:: bash + + # Linux/macOS + export TWINE_USERNAME=__token__ + export TWINE_PASSWORD=pypi-YOUR-TOKEN-HERE # For PyPI + export TWINE_TEST_PASSWORD=pypi-YOUR-TEST-TOKEN-HERE # For TestPyPI + + # Windows (PowerShell) + $env:TWINE_USERNAME="__token__" + $env:TWINE_PASSWORD="pypi-YOUR-TOKEN-HERE" # For PyPI + $env:TWINE_TEST_PASSWORD="pypi-YOUR-TEST-TOKEN-HERE" # For TestPyPI + + # Windows (cmd) + set TWINE_USERNAME=__token__ + set TWINE_PASSWORD=pypi-YOUR-TOKEN-HERE + set TWINE_TEST_PASSWORD=pypi-YOUR-TEST-TOKEN-HERE + +**Build the distribution:** + +.. code-block:: bash + + # Linux/macOS + source venv/bin/activate + rm -rf dist/ build/ *.egg-info artifacta.egg-info + npm install + npm run build + python -m build + + # Windows (PowerShell) + venv\Scripts\Activate.ps1 + Remove-Item -Recurse -Force dist, build, *.egg-info, artifacta.egg-info -ErrorAction SilentlyContinue + npm install + npm run build + python -m build + + # Windows (cmd) + venv\Scripts\activate.bat + if exist dist rmdir /s /q dist + if exist build rmdir /s /q build + npm install + npm run build + python -m build + +This creates ``dist/artifacta-X.Y.Z.tar.gz`` and ``dist/artifacta-X.Y.Z-py3-none-any.whl``. + +**Publish to TestPyPI (for testing):** + +.. code-block:: bash + + # Upload to TestPyPI + python -m twine upload --repository testpypi dist/* + + # Test installation from TestPyPI + pip install --index-url https://test.pypi.org/simple/ artifacta + +**Publish to PyPI (production):** + +.. code-block:: bash + + # Upload to PyPI + python -m twine upload dist/* + +**Complete release workflow:** + +.. code-block:: bash + + # 1. Bump version + bump-my-version bump patch # or minor/major + + # 2. Build package + rm -rf dist/ build/ *.egg-info + npm run build + python -m build + + # 3. Test on TestPyPI first + python -m twine upload --repository testpypi dist/* + + # 4. If test passes, publish to PyPI + python -m twine upload dist/* + + # 5. Push version tag to GitHub + git push origin main --tags + +**Notes:** + +- Always test on TestPyPI before publishing to PyPI +- The UI must be built (``npm run build``) before building the Python package +- Environment variables keep your tokens secure (never commit tokens to git) +- The build includes the pre-built UI from ``artifacta_ui/dist/`` and ``dist/`` + +UI Static File Serving Architecture +------------------------------------ + +Artifacta bundles the pre-built React UI into the Python package, enabling single-command installation via ``pip install artifacta`` without requiring Node.js. + +**How It Works:** + +**1. Build Process** + +When you run ``npm run build``, Vite compiles the React application into static assets: + +.. code-block:: text + + artifacta_ui/ + β”œβ”€β”€ __init__.py # Exports UI_DIST_PATH + β”œβ”€β”€ dist/ + β”‚ β”œβ”€β”€ index.html # Entry point + β”‚ └── assets/ + β”‚ β”œβ”€β”€ *.js # Bundled JavaScript + β”‚ └── *.css # Bundled CSS + └── index.html # (legacy, may be removed) + +**2. Package Inclusion** + +The ``pyproject.toml`` declares ``artifacta_ui`` as a Python package and includes UI assets: + +.. code-block:: toml + + [tool.setuptools] + packages = ["artifacta", "tracking_server", "artifacta_ui"] + + [tool.setuptools.package-data] + artifacta_ui = ["dist/**/*", "index.html"] + +When you run ``python -m build``, setuptools includes these files in the wheel/tarball. + +**3. Runtime Path Resolution** + +The ``artifacta_ui/__init__.py`` module exports the UI location: + +.. code-block:: python + + from pathlib import Path + UI_DIST_PATH = Path(__file__).parent / 'dist' + +When installed via pip, ``__file__`` points to ``site-packages/artifacta_ui/__init__.py``, so ``UI_DIST_PATH`` resolves to the bundled static files inside the Python installation. + +**4. Development vs Production Detection** + +The ``tracking-server/config.py`` handles both scenarios: + +.. code-block:: python + + try: + from artifacta_ui import UI_DIST_PATH # pip install + except ImportError: + UI_DIST_PATH = PROJECT_ROOT / "dist" # development + +- **Production (pip install):** Imports ``UI_DIST_PATH`` from the installed package +- **Development:** Falls back to local ``dist/`` folder in the repository + +**5. FastAPI Static File Serving** + +The ``tracking-server/main.py`` serves the UI using FastAPI's ``StaticFiles``: + +.. code-block:: python + + from fastapi.staticfiles import StaticFiles + from fastapi.responses import FileResponse + + if UI_DIST_PATH.exists(): + # Serve bundled assets (JS, CSS, images) + app.mount("/assets", StaticFiles(directory=UI_DIST_PATH / "assets")) + + # SPA routing: serve index.html for non-API routes + @app.get("/{full_path:path}") + async def serve_ui(full_path: str): + if full_path.startswith("api/") or full_path.startswith("ws/"): + return None # Let API routes handle themselves + return FileResponse(UI_DIST_PATH / "index.html") + +- ``/assets/*`` routes serve static JS/CSS files +- All other routes (except ``/api/*`` and ``/ws/*``) serve ``index.html`` for React Router +- No separate web server needed - FastAPI handles everything + +**Key Benefits:** + +- **Single installation:** ``pip install artifacta`` includes both backend and frontend +- **No build required for users:** UI is pre-built and bundled +- **Development flexibility:** Same code works for local development and pip installs +- **Self-contained:** No external web server or CDN dependencies + +**For Maintainers:** + +- Always run ``npm run build`` before ``python -m build`` to update bundled UI +- The UI is **not** rebuilt during ``pip install`` - users get the pre-built version +- Changes to React code require rebuilding and republishing to PyPI diff --git a/docs/examples.rst b/docs/examples.rst index 3fc7c8d..08db88f 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -1,140 +1,165 @@ Examples ======== -Artifacta includes comprehensive examples demonstrating various use cases and logging capabilities. +Artifacta includes comprehensive examples demonstrating all features and use cases. -All examples are located in the `examples/` directory and can be run standalone. +All examples are located in the `examples/ directory `_. -Complete Examples ------------------ +Installation +------------ -PyTorch MNIST Classification -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Install Artifacta with all dependencies needed to run the examples: -**File:** `examples/pytorch_mnist.py `_ +.. code-block:: bash + + # Create a new virtual environment + python3 -m venv artifacta-examples + source artifacta-examples/bin/activate # On Windows: artifacta-examples\Scripts\activate -Demonstrates PyTorch integration with Artifacta by running **3 experiments** with different hyperparameters: + # Install artifacta + ML frameworks (sklearn, xgboost, pytorch, tensorflow) + pip install -r examples/requirements.txt -1. **Low LR + Adam**: learning_rate=0.001, optimizer=Adam -2. **Medium LR + SGD**: learning_rate=0.01, optimizer=SGD -3. **High LR + SGD**: learning_rate=0.05, optimizer=SGD +Quick Start +----------- -Each experiment logs: +**Start the UI server** (in a separate terminal): -- **ds.autolog()** - Automatic checkpoint logging -- **ds.Series** - Track training/validation loss and accuracy curves -- **ds.Matrix** - Log confusion matrix for classification performance -- **Artifact logging** - Save trained PyTorch model with metadata +.. code-block:: bash -**Run:** + artifacta ui + # View at http://localhost:8000 + +**Run the minimal example:** .. code-block:: bash - python examples/pytorch_mnist.py + python examples/core/01_basic_tracking.py -**What you'll see in the UI:** -- Training curves for all 3 runs in the Plots tab (compare optimizers/learning rates) -- Confusion matrices for each run -- Model checkpoints in the Artifacts tab -- Run comparison in the Tables tab +**Run all examples:** -TensorFlow Regression -~~~~~~~~~~~~~~~~~~~~~ +.. code-block:: bash -**File:** `examples/tensorflow_regression.py `_ + python examples/run_all_examples.py -Demonstrates TensorFlow/Keras integration by running **3 experiments** with different network architectures: +Core Examples +------------- -1. **Small Network**: hidden_dim=32, learning_rate=0.001 -2. **Medium Network**: hidden_dim=64, learning_rate=0.01 -3. **Large Network**: hidden_dim=128, learning_rate=0.001 +Basic Tracking +~~~~~~~~~~~~~~ -Each experiment logs: +**File:** `examples/core/01_basic_tracking.py `_ -- **ds.autolog()** - Automatic Keras checkpoint logging -- **ds.Series** - Track loss and MAE over epochs -- **ds.Scatter** - Visualize predictions vs actual values -- **ds.Distribution** - Analyze prediction residuals -- **Artifact logging** - Save Keras model with metadata +Minimal "hello world" example showing the basic workflow: -**Run:** +- Initialize a run with ``init()`` +- Log metrics with ``Series`` primitive +- Auto-finish behavior (no need to call ``finish()``) -.. code-block:: bash +**Run:** ``python examples/core/01_basic_tracking.py`` - python examples/tensorflow_regression.py +All Primitives Demo +~~~~~~~~~~~~~~~~~~~ -**What you'll see in the UI:** -- Training/validation loss curves for all 3 network sizes -- Scatter plots comparing predictions across architectures -- Residual distribution histograms -- RΒ² scores and error metrics for model comparison +**File:** `examples/core/02_all_primitives.py `_ -A/B Testing Experiment (Domain-Agnostic) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Comprehensive demo of all 7 data primitives: -**File:** `examples/ab_testing_experiment.py `_ +- **Series** - Time series data +- **Distribution** - Histograms +- **Matrix** - 2D heatmaps +- **Table** - Structured data +- **Curve** - X/Y relationships +- **Scatter** - Point clouds +- **BarChart** - Categorical comparisons -**Artifacta isn't just for ML!** This example demonstrates experiment tracking for A/B testing: +**Run:** ``python examples/core/02_all_primitives.py`` -- Simulate A/B test for e-commerce button colors -- **Parameter sweep** - Run multiple experiments with different sample sizes -- **ds.Distribution** - Compare conversion rates across variants -- **ds.Series** - Track cumulative conversions over time -- **ds.BarChart** - Visualize performance comparison -- Statistical significance testing and lift calculations +ML Framework Examples +--------------------- -**Run:** +Sklearn Classification +~~~~~~~~~~~~~~~~~~~~~~ -.. code-block:: bash +**File:** `examples/ml_frameworks/sklearn_classification.py `_ - python examples/ab_testing_experiment.py +Binary classification with RandomForestClassifier: -**What you'll see in the UI:** -- Conversion rate distributions by variant (Control, Green, Red) -- Time series of cumulative conversions -- Statistical significance and lift percentages -- Sweeps tab showing how sample size affects confidence +- ``autolog()`` integration for sklearn +- ROC and Precision-Recall curves +- Confusion matrix +- Feature importance -**Scenario:** Testing checkout button colors (Blue vs Green vs Red) to maximize conversions. Shows that Artifacta works for ANY parametric experiment tracking - not just machine learning! +**Run:** ``python examples/ml_frameworks/sklearn_classification.py`` -Running the Examples --------------------- +XGBoost Regression +~~~~~~~~~~~~~~~~~~ -**1. Install Artifacta with all dependencies:** +**File:** `examples/ml_frameworks/xgboost_regression.py `_ -.. code-block:: bash +Hyperparameter grid search with XGBoost: - pip install -e '.[dev]' +- Multiple runs with different configs +- Feature importance tracking +- Prediction scatter plots +- Model comparison -This installs Artifacta and all optional dependencies including PyTorch, TensorFlow, and scientific computing libraries from the ``pyproject.toml`` file. +**Run:** ``python examples/ml_frameworks/xgboost_regression.py`` -**2. Activate the virtual environment:** +PyTorch MNIST +~~~~~~~~~~~~~ -.. code-block:: bash +**File:** `examples/ml_frameworks/pytorch_mnist.py `_ - source venv/bin/activate +MNIST digit classification with PyTorch: -**3. Start the tracking server:** +- ``autolog()`` integration for PyTorch +- Training/validation curves +- Confusion matrix +- Model checkpoints -.. code-block:: bash +**Run:** ``python examples/ml_frameworks/pytorch_mnist.py`` - artifacta ui +TensorFlow Regression +~~~~~~~~~~~~~~~~~~~~~ -The web UI will be available at the URL shown in the terminal output. +**File:** `examples/ml_frameworks/tensorflow_regression.py `_ -**4. Run examples** (in a separate terminal with venv activated): +Regression with TensorFlow/Keras: -.. code-block:: bash +- ``autolog()`` integration for TensorFlow +- Loss and MAE curves +- Prediction scatter plots +- Model saving + +**Run:** ``python examples/ml_frameworks/tensorflow_regression.py`` + +Domain-Specific Examples +------------------------- + +A/B Testing Experiment +~~~~~~~~~~~~~~~~~~~~~~ + +**File:** `examples/domain_specific/ab_testing_experiment.py `_ + +**Artifacta isn't just for ML!** This example demonstrates A/B testing: + +- Simulate e-commerce button color test +- Parameter sweep with different sample sizes +- Conversion rate distributions +- Statistical significance testing + +**Run:** ``python examples/domain_specific/ab_testing_experiment.py`` + +Protein Expression Optimization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**File:** `examples/domain_specific/protein_expression.py `_ - source venv/bin/activate # Activate in the new terminal too - python examples/pytorch_mnist.py - python examples/tensorflow_regression.py - python examples/ab_testing_experiment.py +Wet lab experiment tracking for biology: -**5. View results in the web UI:** +- Factorial design (temperature Γ— IPTG Γ— time) +- Yield, purity, activity measurements +- Growth curves +- Parameter correlation analysis -- **Plots tab** - Interactive visualizations of all logged metrics -- **Tables tab** - Tabular view with aggregations -- **Artifacts tab** - Browse and preview saved models -- **Sweeps tab** - Analyze parameter sweeps (A/B testing example) -- **Notebooks tab** - Document your findings +**Run:** ``python examples/domain_specific/protein_expression.py`` diff --git a/docs/index.rst b/docs/index.rst index 8d21d47..2f13676 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -51,7 +51,8 @@ Quick Links - :doc:`user-guide` - Complete user guide - :doc:`examples` - Example notebooks and scripts -- :doc:`api` - Complete API reference +- :doc:`api` - Python API reference +- :doc:`ui-api` - UI Components & Utilities API - :doc:`development` - Development and testing guide .. toctree:: @@ -61,6 +62,7 @@ Quick Links user-guide examples api + ui-api development Indices and tables diff --git a/docs/ui-api.rst b/docs/ui-api.rst new file mode 100644 index 0000000..1d3c447 --- /dev/null +++ b/docs/ui-api.rst @@ -0,0 +1,130 @@ +UI Components & Utilities API +============================== + +This section documents the React components, hooks, and utility functions that power the Artifacta user interface. + +.. note:: + The UI is built with React and uses HTML5 Canvas for high-performance plot rendering. + All visualizations are generated server-side from data primitives logged via the Python API. + +Organization +------------ + +The UI codebase is organized into: + +**Components** + React components for visualization, tabs, layout, and user interaction + +**Hooks** + Custom React hooks for data fetching, canvas rendering, and state management + +**Utils** + Utility functions for automatic plot discovery, data processing, and aggregation + +**Pages** + Top-level page components and routing + +Key UI Features +--------------- + +Tabs +~~~~ + +The Artifacta UI provides several specialized tabs for experiment analysis: + +- **Plots Tab** - Automatically discovered visualizations from logged primitives +- **Sweeps Tab** - Hyperparameter analysis with parallel coordinates and correlation plots +- **Artifacts Tab** - File browser with rich media preview (video, audio, PDF, code) +- **Lineage Tab** - Interactive artifact provenance graph +- **Chat Tab** - AI assistant for experiment analysis +- **Project Notes** - Electronic lab notebook with LaTeX and file attachments + +Plot Components +~~~~~~~~~~~~~~~ + +All plot types support interactive tooltips and are rendered on HTML5 Canvas for performance: + +- **LinePlot** - Time series with multi-run overlay support +- **ScatterPlot** - 2D scatter with nearest-point tooltips +- **Heatmap** - Matrix visualization with cell-level tooltips +- **CurveChart** - ROC/PR curves with AUC calculation +- **ParallelCoordinatesChart** - Hyperparameter relationships +- **BarChart** - Categorical comparisons +- **Histogram** - Distribution visualization +- **ViolinPlot** - Distribution comparison across categories + +Core Utilities +~~~~~~~~~~~~~~ + +**Plot Discovery** (``plotDiscovery.js``) + Automatically generates appropriate visualizations from data primitive types: + + - Series β†’ LinePlot + - Distribution β†’ Histogram/ViolinPlot + - Matrix β†’ Heatmap + - Curve β†’ CurveChart with AUC + - Scatter β†’ ScatterPlot + - BarChart β†’ BarChart + +**Multi-Run Comparison** (``comparisonPlotDiscovery.js``) + Handles overlay logic for comparing multiple experiment runs on the same plot. + Supports LinePlot and CurveChart overlay modes. + +**Metric Aggregation** (``metricAggregation.js``) + Aggregates metrics across runs for analysis views (mean, std, min, max, latest). + +**Sweep Detection** (``sweepDetection.js``) + Identifies hyperparameter sweeps from run configurations and generates + parallel coordinates and correlation visualizations. + +Custom Hooks +~~~~~~~~~~~~ + +**useCanvasTooltip** (``hooks/useCanvasTooltip.js``) + Manages interactive tooltips for canvas-based plots. Provides 60fps tooltip + updates using requestAnimationFrame. + + Supported tooltip types: + + - ``series`` - Multi-series line plots + - ``scatter`` - Scatter plot points + - ``matrix`` - Heatmap cells + - ``curve`` - ROC/PR curve points + +**useResponsiveCanvas** (``hooks/useResponsiveCanvas.js``) + Handles responsive canvas sizing with HiDPI (Retina) support. + Automatically adjusts canvas resolution for crisp rendering. + +**useRunData** (``hooks/useRunData.js``) + Fetches and caches experiment run data from the tracking server. + Handles multi-run selection and data aggregation. + +**useLayoutManager** (``hooks/useLayoutManager.js``) + Manages drag-and-drop layout persistence for plots and visualizations. + Stores layout preferences in browser localStorage. + +Full API Documentation +---------------------- + +For complete API documentation with function signatures, parameters, and return types, +see the auto-generated JSDoc documentation: + +.. raw:: html + +

+ View Full UI API Documentation (JSDoc) β†’ +

+ +.. note:: + The JSDoc documentation is generated from inline comments in the source code + and opens in a new window. It provides detailed information about all components, + hooks, utilities, and their usage. + +Building the UI Docs +-------------------- + +To regenerate the JSDoc documentation:: + + npm run docs:ui + +This will scan the ``src/app`` directory and generate HTML documentation in ``docs/_build/jsdoc/``. diff --git a/docs/user-guide.rst b/docs/user-guide.rst index 4efea1c..e41f486 100644 --- a/docs/user-guide.rst +++ b/docs/user-guide.rst @@ -18,14 +18,271 @@ Without systematic tracking of **parameters, metrics, code changes, dependencies Ecosystem & Alternatives ------------------------ -Artifacta is part of a growing ecosystem of experiment tracking tools. Popular alternatives include: - -- `MLflow `_ - Open-source platform from Databricks for ML lifecycle management -- `Weights & Biases `_ - Cloud-first experiment tracking with team collaboration features -- `Neptune.ai `_ - Metadata store for MLOps with extensive integrations -- `Comet ML `_ - ML platform with experiment tracking and model production monitoring - -**Why Artifacta?** We focus on **automatic visualization discovery**, **domain-agnostic tracking** (not just ML), and **simple self-hosting** with a pre-built UI. No heavy dependencies, no mandatory cloud servicesβ€”just install and start tracking. +Artifacta is part of a growing ecosystem of experiment tracking tools. Here's how we compare to popular alternatives: + +.. list-table:: Feature Comparison + :header-rows: 1 + :widths: 30 14 14 14 14 14 + + * - Feature + - Artifacta + - MLflow + - W&B + - Neptune.ai + - Comet ML + * - **Deployment** + - + - + - + - + - + * - Fully offline/local + - βœ… + - βœ… + - ⚠️ + - ❌ + - ⚠️ + * - Pre-built UI (no Node.js) + - βœ… + - ❌ + - ❌ + - ❌ + - ❌ + * - Self-hosted (free) + - βœ… + - βœ… + - ⚠️ + - ⚠️ + - ⚠️ + * - **Visualization** + - + - + - + - + - + * - Line/series charts + - βœ… + - βœ… + - βœ… + - βœ… + - βœ… + * - Bar charts + - βœ… + - ⚠️ + - βœ… + - ⚠️ + - ⚠️ + * - Histograms + - βœ… + - ⚠️ + - βœ… + - βœ… + - ⚠️ + * - Scatter plots + - βœ… + - ⚠️ + - βœ… + - ⚠️ + - ⚠️ + * - Heatmaps + - βœ… + - ⚠️ + - ❌ + - ⚠️ + - ⚠️ + * - ROC/PR curves + - βœ… + - βœ… + - βœ… + - ⚠️ + - ⚠️ + * - Confusion matrix + - βœ… + - βœ… + - βœ… + - ⚠️ + - ⚠️ + * - Parallel coordinates + - βœ… + - βœ… + - βœ… + - βœ… + - βœ… + * - Multi-run overlay + - βœ… + - βœ… + - βœ… + - βœ… + - βœ… + * - **Artifact Management** + - + - + - + - + - + * - Built-in file browser + - βœ… + - βœ… + - βœ… + - βœ… + - βœ… + * - Rich media preview + - βœ… + - ⚠️ + - βœ… + - βœ… + - βœ… + * - Artifact lineage + - βœ… + - ❌ + - βœ… + - βœ… + - βœ… + * - **Analysis** + - + - + - + - + - + * - Auto-logging environment/system info + - βœ… + - βœ… + - βœ… + - βœ… + - βœ… + * - Hyperparameter correlation + - βœ… + - βœ… + - βœ… + - βœ… + - βœ… + * - Parameter importance + - ❌ + - ❌ + - βœ… + - ❌ + - βœ… + * - Built-in ELN (lab notebook) + - βœ… + - ❌ + - ❌ + - ❌ + - ❌ + * - AI assistant + - βœ… + - ❌ + - ⚠️ + - ❌ + - ⚠️ + * - **Domain Support** + - + - + - + - + - + * - Domain-agnostic + - βœ… + - βœ… + - ❌ + - ⚠️ + - ❌ + * - **Production & Deployment** + - + - + - + - + - + * - Model registry + - ❌ + - βœ… + - βœ… + - βœ… + - βœ… + * - Model deployment/serving + - ❌ + - βœ… + - βœ… + - βœ… + - βœ… + * - Production monitoring + - ❌ + - ❌ + - ⚠️ + - βœ… + - βœ… + * - Alerts/notifications + - ❌ + - ❌ + - βœ… + - βœ… + - βœ… + * - **Ecosystem** + - + - + - + - + - + * - Framework autologging + - βœ… + - βœ… + - βœ… + - βœ… + - βœ… + * - Hyperparameter sweeps + - ❌ + - ⚠️ + - βœ… + - ⚠️ + - βœ… + * - Team collaboration + - ❌ + - ⚠️ + - βœ… + - βœ… + - βœ… + * - Managed cloud + - ❌ + - βœ… + - βœ… + - βœ… + - βœ… + +**Legend:** βœ… Full support | ⚠️ Partial/Limited | ❌ Not available + +**⚠️ Notes:** + +- **MLflow**: Can be used for non-ML experiments; has documented A/B testing and analytics use cases +- **Neptune**: No documented non-ML examples +- **W&B, Comet**: ML-focused with no documented non-ML use cases +- **W&B**: Production monitoring available but less comprehensive than Neptune/Comet +- **MLflow**: No built-in production monitoring; requires external tools +- **Artifacta**: Full autologging for scikit-learn, XGBoost, PyTorch Lightning, and TensorFlow/Keras (automatically captures parameters, metrics, models, and datasets) +- **MLflow**: Open-source version lacks permissions; full collaboration requires Databricks Managed MLflow +- **MLflow**: Basic hyperparameter tracking but no automated sweep optimization like W&B/Comet +- **Neptune**: No built-in sweep optimization; integrates with external tools like Optuna + +Why Choose Artifacta? +~~~~~~~~~~~~~~~~~~~~~~ + +**What makes Artifacta different:** + +- **Zero configuration** - Pre-built UI bundled with Python packageβ€”``pip install`` and you're done. No Node.js, Docker, or build tools required +- **Truly offline-first** - Works 100% locally without any cloud dependencies, license servers, or internet connection +- **Server-side plot generation** - Log data primitives (Series, Scatter, Matrix), not matplotlib figuresβ€”Artifacta renders plots for you. No need to create and upload images (though you can if you want) +- **Built-in electronic lab notebook** - Rich text editor with LaTeX support, file attachments, and per-project organizationβ€”not available in any competitor +- **AI chat interface** - Built-in LLM chat (OpenAI, Anthropic, local models) to analyze experiments, results, and code. W&B and Comet have AI features in premium tiers only +- **Domain-agnostic design** - Primitives work for any fieldβ€”ML, A/B tests, physics, finance, genomics, climate science. Not ML-only like most alternatives +- **Rich artifact previews** - Built-in viewers for video, audio, PDFs, code, images. MLflow only previews images; others require external viewers +- **Interactive artifact lineage** - Visual flow graph showing how artifacts relate. MLflow has no lineage visualization + + +When to Choose Alternatives +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- **MLflow** - If you need autologging beyond scikit-learn/XGBoost (e.g., PyTorch, TensorFlow, LightGBM) or already use Databricks +- **Weights & Biases** - If team collaboration is essential, or you want powerful hyperparameter sweeps with optimization +- **Neptune.ai** - If you need comprehensive system monitoring (ongoing CPU/GPU/memory tracking) or work with very large-scale experiments +- **Comet ML** - If you need advanced custom dashboards or detailed experiment comparison tools Installation ------------ @@ -64,6 +321,14 @@ To run examples or tests, install with optional dependencies: pip install -e .[dev] +For generating real test videos (optional, requires FFmpeg): + +.. code-block:: bash + + pip install -e .[dev,video] + +Note: Video artifact logging works without this - test helpers will use placeholder MP4 files instead of generating real videos. + Starting the Tracking Server ----------------------------- @@ -107,10 +372,10 @@ Here's a simple example to get you started: .. code-block:: python - import artifacta as ds + from artifacta import Series, init, log # Initialize a run - run = ds.init( + run = init( project="my-project", name="experiment-1", config={"learning_rate": 0.001, "batch_size": 32} @@ -121,7 +386,7 @@ Here's a simple example to get you started: train_loss = train_model() # Your training code # Log metrics as a Series - ds.log("metrics", ds.Series( + log("metrics", Series( index="epoch", fields={ "train_loss": [train_loss], @@ -134,6 +399,60 @@ Here's a simple example to get you started: # Run automatically finishes when script exits! +Automatic Metadata Capture +--------------------------- + +Artifacta automatically captures environment context when you call ``artifacta.init()``. This happens transparently in the backgroundβ€”no additional code required. + +**What Gets Captured:** + +**Git Information** + - Commit hash (SHA) + - Remote repository URL + - Dirty status (whether you have uncommitted changes) + +**Environment** + - Hostname and username + - Python version + - Operating system and platform + - Working directory + - Command-line arguments used to run your script + +**System Hardware** + - CPU count (physical and logical cores) + - Total RAM + - GPU information (name, memory) if available via ``pynvml`` + +**Dependencies** + - Full ``pip freeze`` output capturing all installed packages and versions + +**How It's Stored:** + +Metadata is automatically saved as artifacts with SHA256 content hashes: + +- **config.json** - Your hyperparameters and config dict (linked to run via ``config_artifact_id``) +- **requirements.txt** - Full ``pip freeze`` output with exact package versions +- **environment.json** - Python version, platform, CUDA version, etc. + +Each artifact is content-addressed using SHA256 hashing, enabling: + +- **Deduplication** - Identical dependencies/configs across runs share the same artifact +- **Integrity** - Verify artifact contents haven't been tampered with +- **Reproducibility** - Exact environment can be reconstructed from the hash + +**Viewing Metadata:** + +All captured metadata is stored with each run and visible in the web UI. This allows you to: + +- Reproduce experiments by seeing exact commit, dependencies, and environment +- Debug issues by comparing system configurations across runs +- Track when dependency upgrades caused performance changes +- Identify which code version produced specific results + +**Privacy Note:** + +Metadata capture runs locally and is stored only in your local database. No data is sent externally. If you're working in a sensitive environment, you can inspect what's captured in the metadata before sharing experiment results. + Logging ======= @@ -150,7 +469,9 @@ Artifacta provides rich primitives for logging structured data. These primitives .. code-block:: python - ds.log("training", ds.Series( + from artifacta import log, Series + + log("training", Series( index="step", fields={ "loss": [0.5, 0.3, 0.2], @@ -162,9 +483,10 @@ Artifacta provides rich primitives for logging structured data. These primitives .. code-block:: python + from artifacta import log, Distribution import numpy as np - ds.log("weights", ds.Distribution( + log("weights", Distribution( values=np.random.randn(1000) )) @@ -172,7 +494,9 @@ Artifacta provides rich primitives for logging structured data. These primitives .. code-block:: python - ds.log("embeddings", ds.Scatter( + from artifacta import log, Scatter + + log("embeddings", Scatter( x=[1, 2, 3, 4], y=[2, 4, 6, 8], labels=["A", "B", "C", "D"] @@ -182,7 +506,9 @@ Artifacta provides rich primitives for logging structured data. These primitives .. code-block:: python - ds.log("confusion_matrix", ds.Matrix( + from artifacta import log, Matrix + + log("confusion_matrix", Matrix( rows=["True A", "True B"], cols=["Class A", "Class B"], values=[[10, 2], [3, 15]] @@ -192,12 +518,13 @@ Artifacta provides rich primitives for logging structured data. These primitives .. code-block:: python + from artifacta import log, Curve from sklearn.metrics import roc_curve, auc fpr, tpr, _ = roc_curve(y_true, y_scores) roc_auc = auc(fpr, tpr) - ds.log("roc_curve", ds.Curve( + log("roc_curve", Curve( x=fpr.tolist(), y=tpr.tolist(), x_label="False Positive Rate", @@ -210,7 +537,9 @@ Artifacta provides rich primitives for logging structured data. These primitives .. code-block:: python - ds.log("model_comparison", ds.BarChart( + from artifacta import log, BarChart + + log("model_comparison", BarChart( categories=["ResNet-50", "EfficientNet-B0", "ViT-Base"], groups={ "accuracy": [0.85, 0.88, 0.90], @@ -225,7 +554,9 @@ Artifacta provides rich primitives for logging structured data. These primitives .. code-block:: python - ds.log("top_variants", ds.Table( + from artifacta import log, Table + + log("top_variants", Table( columns=[ {"name": "Chromosome", "type": "string"}, {"name": "Position", "type": "number"}, @@ -257,50 +588,117 @@ Log files like models, datasets, code, and configuration. Artifacts appear in th # Log source code directory (automatically recursive) run.log_artifact("training_code", "src/") -Auto-logging Checkpoints ------------------------- +Framework Autologging +--------------------- -Artifacta can automatically log model checkpoints for PyTorch Lightning and TensorFlow: +Artifacta provides zero-configuration autologging for popular ML frameworks. Enable once, and all training parameters, metrics, and models are automatically captured. -**PyTorch Lightning:** +**What is logged:** + +- **Parameters**: Hyperparameters, optimizer config (via ``run.update_config()``) +- **Metrics**: Training/validation metrics per epoch (via ``run.log()``) +- **Models**: Trained models and checkpoints (via ``run.log_artifact()``) +- **Datasets**: Input data metadata - shape, dtype, hash (sklearn/XGBoost only) + +scikit-learn +~~~~~~~~~~~~ + +.. code-block:: python + + import artifacta as ds + from sklearn.ensemble import RandomForestClassifier + + # Enable autolog + ds.autolog() # Auto-detects sklearn + + # Or explicitly + from artifacta.integrations import sklearn + sklearn.enable_autolog() + + # Train as usual - everything logged automatically + clf = RandomForestClassifier(n_estimators=100, max_depth=5) + clf.fit(X_train, y_train) + + # Logged: n_estimators=100, max_depth=5, training accuracy, + # confusion matrix, ROC curve, model artifact, dataset metadata + +XGBoost +~~~~~~~ + +.. code-block:: python + + import artifacta as ds + import xgboost as xgb + + # Enable autolog + ds.autolog() # Auto-detects XGBoost + + # Native API + dtrain = xgb.DMatrix(X_train, y_train) + dtest = xgb.DMatrix(X_test, y_test) + params = {"max_depth": 3, "eta": 0.1} + + booster = xgb.train(params, dtrain, num_boost_round=100, + evals=[(dtrain, "train"), (dtest, "test")]) + + # Logged: max_depth, eta, per-iteration metrics (train/test loss), + # feature importance, model artifact, dataset metadata + +PyTorch Lightning +~~~~~~~~~~~~~~~~~ .. code-block:: python import artifacta as ds import pytorch_lightning as pl - # Enable checkpoint logging - ds.autolog() + # Enable autolog + ds.autolog() # Auto-detects PyTorch Lightning - # Your PyTorch Lightning code works as usual + # Train as usual trainer = pl.Trainer(max_epochs=10) trainer.fit(model, train_loader) -**TensorFlow/Keras:** + # Logged: epochs=10, optimizer_name, learning_rate, + # per-epoch metrics (loss, accuracy), checkpoints, final model + +TensorFlow/Keras +~~~~~~~~~~~~~~~~ .. code-block:: python import artifacta as ds import tensorflow as tf - # Enable checkpoint logging - ds.autolog() + # Enable autolog + ds.autolog() # Auto-detects TensorFlow/Keras - # Your TensorFlow code works as usual - model.compile(optimizer='adam', loss='mse') - model.fit(x_train, y_train, epochs=10) + # Train as usual + model.compile(optimizer='adam', loss='mse', metrics=['mae']) + model.fit(x_train, y_train, epochs=10, batch_size=32) -**Checkpoint Logging Options:** + # Logged: epochs=10, batch_size=32, optimizer_name, learning_rate, + # per-epoch metrics (loss, mae), checkpoints, final model -.. code-block:: python +Configuration Options +~~~~~~~~~~~~~~~~~~~~~ - # Auto-detect framework and log checkpoints every epoch - ds.autolog() +.. code-block:: python - # Disable checkpoint logging - ds.autolog(log_checkpoints=False) + # Disable specific features + from artifacta.integrations import sklearn + sklearn.enable_autolog( + log_models=False, # Don't log model artifacts + log_datasets=False, # Don't log dataset metadata + log_training_metrics=False # Don't compute training metrics + ) -Note: Autolog captures model checkpoints with metadata. For metrics, use ``ds.log()`` with primitives. + # PyTorch/TensorFlow options + from artifacta.integrations import pytorch_lightning + pytorch_lightning.enable_autolog( + log_checkpoints=False, # Don't log per-epoch checkpoints + log_models=True # Still log final model + ) Language-Agnostic Logging -------------------------- @@ -378,7 +776,7 @@ UI Features The Artifacta web UI provides several features for visualizing and managing your experiments. -**Important**: To see data in the UI, you must first log it to the database using ``ds.init()`` and ``ds.log()`` in your Python scripts. Then select runs from the **Runs** section in the sidebar. +**Important**: To see data in the UI, you must first log it to the database using ``artifacta.init()`` and ``log()`` in your Python scripts. Then select runs from the **Runs** section in the sidebar. UI Selection Requirements -------------------------- @@ -410,6 +808,30 @@ The **Notebooks** tab provides a rich text editor for documenting experiments: Create and edit notes directly in the web UI using the rich text editor. +**Supported Attachment Formats:** + +**Code & Text Files** (inline preview with syntax highlighting) + - Python (.py), JavaScript (.js, .jsx), TypeScript (.ts, .tsx) + - Java (.java), C/C++ (.c, .cpp), Ruby (.rb), Go (.go), Rust (.rs) + - PHP (.php), Swift (.swift), Kotlin (.kt) + - SQL (.sql), Shell (.sh, .bash) + - JSON (.json), XML (.xml), YAML (.yaml, .yml) + - Markdown (.md), HTML (.html), CSS (.css, .scss) + - Plain text (.txt, .log) + - Any ``text/*`` MIME type + +**Images** (inline preview) + - PNG, JPEG, GIF, SVG, WebP (``image/*`` MIME types) + +**Media** (inline player) + - Audio: MP3, WAV, OGG, etc. (``audio/*`` MIME types) + - Video: MP4, WebM, etc. (``video/*`` MIME types) + +**Documents** (inline viewer) + - PDF (``application/pdf``) - embedded iframe viewer + +All other file types show as downloadable attachments with file icon and size. + .. image:: _static/ELN.gif :alt: Notebooks tab with rich text editor :align: center @@ -420,17 +842,19 @@ Create and edit notes directly in the web UI using the rich text editor. Plots Tab --------- -The **Plots** tab visualizes all primitives logged via ``ds.log()``: +The **Plots** tab visualizes all primitives logged via ``log()``: -- **Series charts** - Line plots for time series data (loss, accuracy over epochs) - *supports multi-run overlay* +- **Series charts** - Line plots for time series data (loss, accuracy over epochs) - *supports multi-run overlay, interactive tooltips* - **Distributions** - Histograms and distribution plots -- **Scatter plots** - 2D scatter visualizations (embeddings, etc.) -- **Curves** - ROC curves, PR curves with AUC metrics - *supports multi-run overlay* +- **Scatter plots** - 2D scatter visualizations (embeddings, etc.) - *interactive tooltips* +- **Curves** - ROC curves, PR curves with AUC metrics - *supports multi-run overlay, interactive tooltips* - **Bar charts** - Model comparisons and categorical data -- **Matrices** - Confusion matrices and heatmaps +- **Matrices** - Confusion matrices and heatmaps - *interactive tooltips* Plots are automatically discovered from logged primitives and organized by section. You can drag and resize plots. +**Interactive Tooltips**: Line plots, scatter plots, curve charts, and heatmaps display detailed data values when you hover over them. For line plots, the tooltip shows all series values at the hovered position. For scatter plots, it displays x and y coordinates of the nearest point. Heatmaps show the row, column, and cell value. + **Multi-Run Comparison**: When multiple runs are selected in the sidebar, Series and Curve plots automatically overlay all selected runs on the same chart for easy comparison. Other plot types remain separate per run to avoid visual clutter. .. image:: _static/Plots.gif @@ -452,8 +876,8 @@ Tables Tab The **Tables** tab displays metrics and data in tabular format: -- **Table primitives** - View structured data from ``ds.Table`` with sortable columns -- **Series aggregations** - View ``ds.Series`` metrics with min/max/final aggregations +- **Table primitives** - View structured data from ``artifacta.Table`` with sortable columns +- **Series aggregations** - View ``artifacta.Series`` metrics with min/max/final aggregations - **Run comparison** - Compare multiple runs side-by-side in table format - **Aggregation modes** - Switch between min, max, or final (last) value for each metric - **CSV export** - Export table data for further analysis @@ -461,14 +885,46 @@ The **Tables** tab displays metrics and data in tabular format: Sweeps Tab ---------- -The **Sweeps** tab analyzes hyperparameter sweeps when you select multiple runs: +The **Sweeps** tab analyzes hyperparameter sweeps when you select multiple runs. It requires at least 2 runs with the same config keys but varying parameter values. + +**Visualizations:** + +**Parallel Coordinates** + Multi-dimensional visualization showing relationships between all hyperparameters and a selected metric across runs. Each line represents one run, flowing through vertical axes (one per parameter + target metric). Lines are color-coded by metric value using a gradient (low = purple/blue, high = green). + + - **What it shows:** How parameter combinations relate to outcomes + - **How it helps:** Identify patterns like "high learning rate + small batch size β†’ low loss" + - **Interaction:** Select which target metric to display, choose aggregation method (last/max/min) + +**Parameter Correlation Charts** + Bar charts showing the correlation strength between each hyperparameter and each metric using Pearson correlation coefficient. + + - **Range:** -1 to +1 (negative = inverse relationship, positive = direct relationship, 0 = no correlation) + - **Importance score:** Absolute value of correlation - both strong positive and strong negative correlations are "important" + - **What it shows:** Which parameters have the strongest impact on metrics + - **How it helps:** Focus tuning efforts on high-impact parameters, ignore parameters with near-zero correlation + - **Note:** Categorical parameters are converted to numeric indices for correlation calculation + +**Scatter Plots** + Individual scatter plots for each numeric varying parameter vs. selected target metric. One plot per parameter. + + - **What it shows:** Direct relationship between a single parameter and outcome + - **How it helps:** See trends (linear, logarithmic, threshold effects) and optimal parameter ranges + - **Interaction:** Select which metric to plot, choose aggregation method (last/max/min) -- **Parallel coordinates** - Visualize high-dimensional parameter spaces -- **Parameter correlation** - See which hyperparameters most impact metrics -- **Scatter plots** - Plot individual parameters vs target metrics -- **Aggregation options** - Choose last/max/min values for metrics +**Aggregation Options:** -The tab only appears when selected runs form a valid sweep (same config keys with varying values). +All visualizations support multiple aggregation methods for metrics: + +- **last** - Final value from training (default) +- **max** - Best (maximum) value achieved +- **min** - Best (minimum) value achieved + +**Requirements:** + +- At least 2 runs selected +- Runs must have the same config keys (valid sweep structure) +- At least 3 runs recommended for meaningful correlation analysis .. image:: _static/Sweeps.gif :alt: Sweeps tab for hyperparameter analysis @@ -508,6 +964,29 @@ The **Artifacts** tab provides a file browser and preview for logged artifacts: Navigate artifacts using the Files panel in the sidebar, then click to preview in this tab. +**Supported File Formats:** + +**Code & Text Files** (syntax highlighting) + - Python (.py), JavaScript (.js), TypeScript (.ts) + - JSON (.json), YAML (.yaml, .yml) + - Markdown (.md) + - Any text file (``text/*`` MIME types) + +**Data Files** + - CSV (.csv) - rendered as sortable tables with pagination + +**Images** + - PNG, JPEG, GIF, SVG, WebP (``image/*`` MIME types) + +**Media** + - Audio: MP3, WAV, OGG, etc. (``audio/*`` MIME types) + - Video: MP4, WebM, etc. (``video/*`` MIME types) + +**Documents** + - PDF (``application/pdf``) - rendered inline with iframe viewer + +Other file types will show a download button without preview. + .. image:: _static/Artifacts_1.gif :alt: Artifacts tab file browser and preview :align: center diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..de934a7 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,34 @@ +# Artifacta Examples + +This directory contains comprehensive examples demonstrating all Artifacta features and use cases. + +## Installation + +Install Artifacta with all dependencies needed to run the examples: + +```bash +# Create a new virtual environment +python3 -m venv artifacta-examples +source artifacta-examples/bin/activate # On Windows: artifacta-examples\Scripts\activate + +# Install artifacta + ML frameworks (sklearn, xgboost, pytorch, tensorflow) +pip install -r examples/requirements.txt +``` + +## Quick Start + +**Start the UI server** (in a separate terminal): +```bash +artifacta ui +# View at http://localhost:8000 +``` + +**Run the minimal example**: +```bash +python examples/core/01_basic_tracking.py +``` + +**Run all examples**: +```bash +python examples/run_all_examples.py +``` diff --git a/examples/core/01_basic_tracking.py b/examples/core/01_basic_tracking.py new file mode 100644 index 0000000..480eaa2 --- /dev/null +++ b/examples/core/01_basic_tracking.py @@ -0,0 +1,111 @@ +""" +Basic Experiment Tracking with Artifacta +========================================= + +This minimal example demonstrates the core Artifacta workflow: +1. Initialize a run with init() +2. Log metrics using Series primitive +3. Run automatically finishes when script exits + +Perfect starting point for understanding Artifacta basics. + +Requirements: + pip install artifacta numpy + +Usage: + python examples/core/01_basic_tracking.py +""" + +import time + +import numpy as np + +from artifacta import Series, init + + +def simulate_training(epochs=5): + """Simulate a simple training loop. + + Returns training and validation loss for each epoch. + """ + train_losses = [] + val_losses = [] + + for epoch in range(1, epochs + 1): + # Simulate decreasing loss with some noise + train_loss = 1.0 / epoch + np.random.normal(0, 0.05) + val_loss = 1.0 / epoch + np.random.normal(0, 0.08) + + train_losses.append(max(0, train_loss)) + val_losses.append(max(0, val_loss)) + + print(f"Epoch {epoch}/{epochs}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}") + time.sleep(0.2) # Simulate training time + + return train_losses, val_losses + + +def main(): + print("=" * 60) + print("Artifacta Basic Tracking Example") + print("=" * 60) + + # ================================================================= + # 1. Initialize Artifacta run + # This creates a new experiment with configuration + # ================================================================= + config = { + "model": "SimpleNN", + "learning_rate": 0.01, + "epochs": 5, + } + + run = init( + project="getting-started", + name="basic-example", + config=config, + ) + + print("\nArtifacta run initialized") + print(" Project: getting-started") + print(" Run: basic-example") + + # ================================================================= + # 2. Run your experiment (training simulation) + # ================================================================= + print("\nStarting training simulation...") + train_losses, val_losses = simulate_training(epochs=config["epochs"]) + + # ================================================================= + # 3. Log results using Series primitive + # Series is used for ordered data over a single dimension + # ================================================================= + print("\nLogging training metrics...") + + run.log( + "loss_curves", + Series( + index="epoch", + fields={ + "train_loss": train_losses, + "val_loss": val_losses, + }, + index_values=list(range(1, config["epochs"] + 1)), + ), + ) + + # ================================================================= + # 4. Run automatically finishes when script exits + # No need to call run.finish() manually! + # ================================================================= + print("\n" + "=" * 60) + print("Experiment Complete!") + print("=" * 60) + print("Metrics logged to Artifacta") + print("Run will auto-finish when script exits") + print(" View results in the Artifacta UI!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/core/02_all_primitives.py b/examples/core/02_all_primitives.py new file mode 100644 index 0000000..3f61e93 --- /dev/null +++ b/examples/core/02_all_primitives.py @@ -0,0 +1,225 @@ +""" +All Artifacta Primitives Demonstration +====================================== + +This example demonstrates all 7 Artifacta data primitives in one place: +1. Series - Ordered data over a single dimension (time series, epochs) +2. Distribution - Value collections with optional grouping (histograms) +3. Matrix - 2D relationships (confusion matrices, heatmaps) +4. Table - Generic tabular data (measurements, logs) +5. Curve - Pure X-Y relationships (ROC curves, dose-response) +6. Scatter - Unordered point clouds (embeddings, correlations) +7. BarChart - Categorical comparisons (model performance) + +Each primitive is domain-agnostic and works for any field (ML, physics, finance, etc.). + +Requirements: + pip install artifacta numpy + +Usage: + python examples/core/02_all_primitives.py +""" + +import numpy as np + +from artifacta import BarChart, Curve, Distribution, Matrix, Scatter, Series, Table, init + + +def main(): + print("=" * 70) + print("Artifacta - All Primitives Demonstration") + print("=" * 70) + + # Initialize run + run = init( + project="primitives-demo", + name="all-primitives", + config={"description": "Showcase all 7 data primitives"}, + ) + + print("\nArtifacta run initialized\n") + + # ================================================================= + # 1. Series - Ordered data over single dimension + # Use cases: Training loss, stock prices, temperature over time + # ================================================================= + print("Logging: Series (training metrics over epochs)") + + run.log( + "training_metrics", + Series( + index="epoch", + fields={ + "train_loss": [0.8, 0.5, 0.3, 0.2, 0.15], + "val_loss": [0.9, 0.6, 0.4, 0.3, 0.25], + "accuracy": [0.6, 0.75, 0.85, 0.90, 0.93], + }, + index_values=[1, 2, 3, 4, 5], + metadata={"description": "Training progress over 5 epochs"}, + ), + ) + + # ================================================================= + # 2. Distribution - Values with optional grouping + # Use cases: A/B testing, prediction distributions, response times + # ================================================================= + print("Logging: Distribution (A/B test results)") + + # Simulate A/B test conversion rates + control_conversions = np.random.binomial(1, 0.05, 100) + variant_conversions = np.random.binomial(1, 0.07, 100) + + all_values = np.concatenate([control_conversions, variant_conversions]) + all_groups = ["Control"] * 100 + ["Variant"] * 100 + + run.log( + "ab_test_distribution", + Distribution( + values=all_values.tolist(), + groups=all_groups, + metadata={ + "description": "Conversion outcomes by variant", + "control_rate": float(control_conversions.mean()), + "variant_rate": float(variant_conversions.mean()), + }, + ), + ) + + # ================================================================= + # 3. Matrix - 2D relationships + # Use cases: Confusion matrices, correlation matrices, heatmaps + # ================================================================= + print("Logging: Matrix (confusion matrix)") + + run.log( + "confusion_matrix", + Matrix( + rows=["Cat", "Dog", "Bird"], + cols=["Cat", "Dog", "Bird"], + values=[ + [85, 10, 5], # Cat predictions + [8, 87, 5], # Dog predictions + [7, 8, 85], # Bird predictions + ], + metadata={"type": "confusion_matrix", "total_samples": 300}, + ), + ) + + # ================================================================= + # 4. Table - Generic tabular data + # Use cases: Event logs, measurements, multi-variable data + # ================================================================= + print("Logging: Table (experiment measurements)") + + run.log( + "measurements", + Table( + columns=[ + {"name": "experiment_id", "type": "string"}, + {"name": "temperature_C", "type": "float"}, + {"name": "pressure_bar", "type": "float"}, + {"name": "yield_pct", "type": "float"}, + ], + data=[ + ["exp_001", 25.0, 1.0, 78.5], + ["exp_002", 30.0, 1.2, 82.3], + ["exp_003", 35.0, 1.5, 85.1], + ["exp_004", 40.0, 1.8, 79.2], + ], + metadata={"description": "Chemical reaction optimization data"}, + ), + ) + + # ================================================================= + # 5. Curve - Pure X-Y relationships (not time-indexed) + # Use cases: ROC curves, dose-response, calibration curves + # ================================================================= + print("Logging: Curve (ROC curve)") + + # Simulate ROC curve data + fpr = [0.0, 0.1, 0.2, 0.3, 0.5, 0.7, 1.0] + tpr = [0.0, 0.6, 0.75, 0.85, 0.92, 0.97, 1.0] + auc = 0.92 + + run.log( + "roc_curve", + Curve( + x=fpr, + y=tpr, + x_label="False Positive Rate", + y_label="True Positive Rate", + baseline="diagonal", # Show y=x diagonal reference line + metric={"name": "AUC", "value": auc}, + metadata={"description": "ROC curve for binary classifier"}, + ), + ) + + # ================================================================= + # 6. Scatter - Unordered point clouds + # Use cases: Feature correlations, embeddings (t-SNE, UMAP) + # ================================================================= + print("Logging: Scatter (feature correlation)") + + # Generate correlated features + np.random.seed(42) + feature1 = np.random.randn(50) + feature2 = feature1 * 0.8 + np.random.randn(50) * 0.3 # Correlated + + points = [ + {"x": float(x), "y": float(y), "label": "data_point"} + for x, y in zip(feature1, feature2) + ] + + run.log( + "feature_correlation", + Scatter( + points=points, + x_label="Feature 1", + y_label="Feature 2", + metadata={ + "description": "Correlation between two features", + "correlation": float(np.corrcoef(feature1, feature2)[0, 1]), + }, + ), + ) + + # ================================================================= + # 7. BarChart - Categorical comparisons + # Use cases: Model performance, metrics by group + # ================================================================= + print("Logging: BarChart (model comparison)") + + run.log( + "model_comparison", + BarChart( + categories=["LogisticReg", "RandomForest", "XGBoost", "NeuralNet"], + groups={ + "Accuracy": [0.85, 0.91, 0.94, 0.93], + "F1-Score": [0.82, 0.89, 0.92, 0.91], + "Precision": [0.84, 0.90, 0.93, 0.92], + }, + x_label="Model", + y_label="Score", + metadata={"description": "Classification performance comparison"}, + ), + ) + + # ================================================================= + # Summary + # ================================================================= + print("\n" + "=" * 70) + print("All 7 Primitives Logged Successfully!") + print("=" * 70) + print("Series - Training metrics over epochs") + print("Distribution - A/B test conversion outcomes") + print("Matrix - Confusion matrix") + print("Table - Experiment measurements") + print("Curve - ROC curve") + print("Scatter - Feature correlation") + print("BarChart - Model performance comparison") + print("\nView all visualizations in the Artifacta UI!") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/examples/ab_testing_experiment.py b/examples/domain_specific/ab_testing_experiment.py similarity index 94% rename from examples/ab_testing_experiment.py rename to examples/domain_specific/ab_testing_experiment.py index 90c1670..e40a931 100644 --- a/examples/ab_testing_experiment.py +++ b/examples/domain_specific/ab_testing_experiment.py @@ -7,9 +7,9 @@ This example shows: 1. **Domain-agnostic tracking** - Track A/B tests, marketing experiments, etc. 2. **Parameter sweeps** - Run multiple experiments with different configurations -3. **Distribution analysis** via ds.Distribution - Compare conversion rates by variant -4. **Time series tracking** via ds.Series - Monitor cumulative metrics over time -5. **Category comparison** via ds.BarChart - Visualize performance across variants +3. **Distribution analysis** via Distribution - Compare conversion rates by variant +4. **Time series tracking** via Series - Monitor cumulative metrics over time +5. **Category comparison** via BarChart - Visualize performance across variants Scenario: We're testing different button colors on an e-commerce checkout page to see which @@ -17,10 +17,10 @@ and traffic splits to demonstrate parameter sweeps. Key Artifacta Features Demonstrated: -- ds.init() - Initialize experiment run with config (NOT ML-specific!) -- ds.Distribution - Log conversion rates by variant (grouped data) -- ds.Series - Log cumulative conversions over time -- ds.BarChart - Compare final metrics across variants +- init() - Initialize experiment run with config (NOT ML-specific!) +- Distribution - Log conversion rates by variant (grouped data) +- Series - Log cumulative conversions over time +- BarChart - Compare final metrics across variants - Parameter sweeps - Run multiple experiments systematically Requirements: @@ -35,7 +35,7 @@ import numpy as np -import artifacta as ds +from artifacta import BarChart, Distribution, Series, init, log def simulate_button_test( @@ -253,9 +253,9 @@ def run_ab_test_experiment(config: Dict, seed: int = 42): + ["Variant B (Red)"] * len(results["variant_b"]["samples"]) ) - ds.log( + log( "conversion_by_variant", - ds.Distribution( + Distribution( values=all_values.tolist(), groups=all_groups, metadata={ @@ -281,9 +281,9 @@ def run_ab_test_experiment(config: Dict, seed: int = 42): total_conversions, n_hours=config["test_duration_hours"], random_state=seed ) - ds.log( + log( "cumulative_conversions", - ds.Series( + Series( index="hour", fields={ "control": time_series["control"], @@ -323,9 +323,9 @@ def run_ab_test_experiment(config: Dict, seed: int = 42): # 4. Log comparison bar chart # Shows side-by-side comparison of key metrics # ================================================================= - ds.log( + log( "variant_comparison", - ds.BarChart( + BarChart( categories=["Control (Blue)", "Variant A (Green)", "Variant B (Red)"], groups={ "Conversion Rate (%)": [ @@ -351,9 +351,9 @@ def run_ab_test_experiment(config: Dict, seed: int = 42): # ================================================================= # 5. Log final summary metrics as Series # ================================================================= - ds.log( + log( "summary_metrics", - ds.Series( + Series( index="variant", fields={ "conversion_rate_pct": [ @@ -419,7 +419,7 @@ def main(): }, ] - print(f"\nβœ“ Running parameter sweep with {len(param_variations)} configurations...") + print(f"\nRunning parameter sweep with {len(param_variations)} configurations...") # ================================================================= # Run experiments with different configurations @@ -435,7 +435,7 @@ def main(): # Initialize Artifacta run for this experiment run_name = f"button-test-{variation['name']}" - ds.init(project="ab-testing-demo", name=run_name, config=config) + init(project="ab-testing-demo", name=run_name, config=config) # Run the experiment run_ab_test_experiment(config, seed=42 + idx) @@ -456,7 +456,7 @@ def main(): print("\nRecommendation:") print(" β†’ Deploy Variant A (Green button) to production!") print(" β†’ Expected conversion lift: ~24%") - print("\nβœ“ All experiments logged to Artifacta") + print("\nAll experiments logged to Artifacta") print(" View detailed results in the Artifacta UI!") print(" Compare all runs in the project view to see how sample size") print(" affects confidence and statistical significance.") diff --git a/examples/domain_specific/protein_expression.py b/examples/domain_specific/protein_expression.py new file mode 100644 index 0000000..948dbd0 --- /dev/null +++ b/examples/domain_specific/protein_expression.py @@ -0,0 +1,348 @@ +""" +Protein Expression Optimization Example with Artifacta +======================================================= + +This example demonstrates using Artifacta for wet lab experimental data tracking. +It simulates a typical protein expression optimization workflow where researchers +test different growth conditions to maximize protein yield. + +Scenario: +--------- +A molecular biology lab is optimizing expression of a recombinant protein in E. coli. +They're testing different combinations of: +- Temperature (25Β°C, 30Β°C, 37Β°C) +- IPTG concentration (0.1 mM, 0.5 mM, 1.0 mM) +- Induction time (4h, 6h, 8h) + +For each condition, they measure: +- Total protein yield (mg/L) +- Protein purity (%) +- Specific activity (U/mg) +- Growth rate (OD600/h) + +Key Artifacta Features Demonstrated: +------------------------------------ +1. **CSV data import** - Load experimental data from spreadsheet +2. **Grid search tracking** - Systematically track all parameter combinations +3. **Series plots** - Visualize growth curves over time +4. **BarChart comparisons** - Compare yields across conditions +5. **Scatter plots** - Analyze parameter-response relationships +6. **Parameter correlation** - Identify which factors matter most + +Requirements: + pip install artifacta pandas numpy + +Usage: + python examples/protein_expression_optimization.py +""" + +from itertools import product + +import numpy as np +import pandas as pd + +from artifacta import BarChart, Scatter, Series, init, log + + +def generate_synthetic_data(): + """Generate synthetic protein expression data. + + In a real scenario, this would be loaded from your lab's CSV files. + This simulates realistic experimental results with some biological noise. + + Returns: + DataFrame with experimental conditions and results + """ + print("\nGenerating synthetic experimental data...") + print(" (In practice, this would be loaded from your lab notebook CSV)") + + # Define parameter grid + temperatures = [25, 30, 37] # Β°C + iptg_concs = [0.1, 0.5, 1.0] # mM + induction_times = [4, 6, 8] # hours + + # Generate all combinations + conditions = list(product(temperatures, iptg_concs, induction_times)) + + # Simulate realistic protein expression results + # Higher temp + moderate IPTG + longer time = better yield (with noise) + data = [] + for temp, iptg, time in conditions: + # Base yield increases with temperature and time + base_yield = (temp - 20) * 5 + time * 10 + iptg * 20 + + # Add biological noise (Β±20%) + yield_mg_l = base_yield + np.random.normal(0, base_yield * 0.2) + yield_mg_l = max(10, yield_mg_l) # Ensure positive + + # Purity decreases at very high temp or IPTG + base_purity = 90 - abs(temp - 30) * 2 - (iptg - 0.5) ** 2 * 10 + purity_pct = base_purity + np.random.normal(0, 5) + purity_pct = np.clip(purity_pct, 50, 98) + + # Activity correlates with purity but has noise + activity = purity_pct * 0.8 + np.random.normal(0, 10) + activity = max(20, activity) + + # Growth rate decreases at extreme temps + optimal_temp_growth = 1 - abs(temp - 37) / 20 + growth_rate = optimal_temp_growth + np.random.normal(0, 0.1) + growth_rate = max(0.2, growth_rate) + + data.append({ + 'temperature_C': temp, + 'iptg_mM': iptg, + 'induction_time_h': time, + 'yield_mg_L': round(yield_mg_l, 2), + 'purity_pct': round(purity_pct, 1), + 'activity_U_mg': round(activity, 1), + 'growth_rate_OD_h': round(growth_rate, 3) + }) + + df = pd.DataFrame(data) + + print(f" Generated {len(df)} experimental conditions") + print(f" Temperature range: {df['temperature_C'].min()}-{df['temperature_C'].max()}Β°C") + print(f" IPTG range: {df['iptg_mM'].min()}-{df['iptg_mM'].max()} mM") + print(f" Induction time range: {df['induction_time_h'].min()}-{df['induction_time_h'].max()} hours") + + return df + + +def generate_growth_curve(temp, iptg, induction_time): + """Generate a simulated growth curve for a specific condition. + + In practice, this would be real OD600 measurements from your plate reader. + """ + # Time points (hours) + time_points = np.linspace(0, 12, 25) + + # Simulate bacterial growth with logistic curve + # Growth parameters depend on conditions + max_od = 2.0 + (temp - 30) * 0.1 + np.random.normal(0, 0.1) + growth_rate = 0.5 - abs(temp - 37) * 0.02 + np.random.normal(0, 0.05) + + # Logistic growth curve + od600 = max_od / (1 + np.exp(-growth_rate * (time_points - 4))) + + # Add measurement noise + od600 += np.random.normal(0, 0.05, len(time_points)) + od600 = np.maximum(0.01, od600) # OD can't be negative + + return time_points.tolist(), od600.tolist() + + +def main(): + """Main experimental analysis workflow.""" + print("=" * 70) + print("Artifacta Protein Expression Optimization Example") + print("=" * 70) + print("\nScenario: Optimizing recombinant protein expression in E. coli") + print("Testing different temperatures, IPTG concentrations, and induction times") + + # ================================================================= + # 1. Generate/Load experimental data + # ================================================================= + df = generate_synthetic_data() + + # ================================================================= + # 2. Initialize Artifacta project + # ================================================================= + print("\n" + "=" * 70) + print("Logging Experimental Data to Artifacta") + print("=" * 70) + + run = init( + project="protein-expression-optimization", + name="ecoli-recombinant-protein-screen", + config={ + "organism": "E. coli BL21(DE3)", + "protein": "His-tagged GFP", + "expression_vector": "pET28a", + "culture_volume_mL": 50, + "study_type": "factorial_design", + "parameters_tested": ["temperature", "IPTG_concentration", "induction_time"], + } + ) + print("\nArtifacta run initialized") + + # ================================================================= + # 3. Log individual condition results + # ================================================================= + print("\nLogging experimental conditions...") + + # Track best condition + best_yield = df.loc[df['yield_mg_L'].idxmax()] + best_purity = df.loc[df['purity_pct'].idxmax()] + best_activity = df.loc[df['activity_U_mg'].idxmax()] + + for _idx, row in df.iterrows(): + condition_name = f"T{row['temperature_C']}_IPTG{row['iptg_mM']}_t{row['induction_time_h']}h" + + # Log each condition as a separate run (in practice, you might log all at once) + # This demonstrates tracking individual experiments + print(f" {condition_name}: Yield={row['yield_mg_L']:.1f} mg/L, " + f"Purity={row['purity_pct']:.1f}%, Activity={row['activity_U_mg']:.1f} U/mg") + + print(f"\n Total conditions tested: {len(df)}") + print(f" Best yield: {best_yield['yield_mg_L']:.1f} mg/L at " + f"{best_yield['temperature_C']}Β°C, {best_yield['iptg_mM']} mM IPTG, " + f"{best_yield['induction_time_h']}h induction") + + # ================================================================= + # 4. Log yield comparison across temperatures (BarChart) + # ================================================================= + print("\nCreating yield comparison charts...") + + # Group by temperature and calculate mean yield + temp_yields = df.groupby('temperature_C')['yield_mg_L'].mean().round(2) + + log( + "yield_by_temperature", + BarChart( + categories=[f"{t}Β°C" for t in temp_yields.index], + groups={"Yield (mg/L)": temp_yields.tolist()}, + x_label="Temperature", + y_label="Protein Yield (mg/L)" + ) + ) + + # Group by IPTG concentration + iptg_yields = df.groupby('iptg_mM')['yield_mg_L'].mean().round(2) + + log( + "yield_by_iptg", + BarChart( + categories=[f"{c} mM" for c in iptg_yields.index], + groups={"Yield (mg/L)": iptg_yields.tolist()}, + x_label="IPTG Concentration", + y_label="Protein Yield (mg/L)" + ) + ) + + # ================================================================= + # 5. Log parameter relationships (Scatter) + # ================================================================= + print("\nCreating parameter correlation plots...") + + # Yield vs Temperature + log( + "yield_vs_temperature", + Scatter( + points=[ + {"x": temp, "y": yield_val} + for temp, yield_val in zip(df['temperature_C'], df['yield_mg_L']) + ], + x_label="Temperature (Β°C)", + y_label="Protein Yield (mg/L)" + ) + ) + + # Yield vs IPTG + log( + "yield_vs_iptg", + Scatter( + points=[ + {"x": iptg, "y": yield_val} + for iptg, yield_val in zip(df['iptg_mM'], df['yield_mg_L']) + ], + x_label="IPTG Concentration (mM)", + y_label="Protein Yield (mg/L)" + ) + ) + + # Purity vs Activity (quality relationship) + log( + "purity_vs_activity", + Scatter( + points=[ + {"x": purity, "y": activity} + for purity, activity in zip(df['purity_pct'], df['activity_U_mg']) + ], + x_label="Protein Purity (%)", + y_label="Specific Activity (U/mg)" + ) + ) + + # ================================================================= + # 6. Log growth curves for representative conditions (Series) + # ================================================================= + print("\nLogging growth curves...") + + # Log growth curves for a few representative conditions + representative_conditions = [ + (25, 0.1, 4, "low_temp_low_iptg"), + (30, 0.5, 6, "optimal_moderate"), + (37, 1.0, 8, "high_temp_high_iptg"), + ] + + for temp, iptg, induction_time, name in representative_conditions: + time_points, od600_values = generate_growth_curve(temp, iptg, induction_time) + + log( + f"growth_curve_{name}", + Series( + index="time_hours", + fields={ + "OD600": od600_values, + }, + index_values=time_points, + metadata={ + "temperature_C": temp, + "iptg_mM": iptg, + "induction_time_h": induction_time, + "description": f"Growth curve at {temp}Β°C, {iptg}mM IPTG, {induction_time}h induction" + } + ) + ) + + # ================================================================= + # 7. Save summary CSV as artifact + # ================================================================= + print("\nSaving experimental data as artifact...") + + # Save the dataframe to CSV + csv_path = "protein_expression_results.csv" + df.to_csv(csv_path, index=False) + + run.log_output( + csv_path, + name="experimental_results", + metadata={ + "description": "Complete experimental results for all tested conditions", + "total_conditions": len(df), + "best_yield_mg_L": float(best_yield['yield_mg_L']), + "best_condition": f"{best_yield['temperature_C']}Β°C, {best_yield['iptg_mM']}mM IPTG, {best_yield['induction_time_h']}h" + } + ) + + # ================================================================= + # 8. Summary and Recommendations + # ================================================================= + print("\n" + "=" * 70) + print("Analysis Complete!") + print("=" * 70) + print("\nKey Findings:") + print(f" β€’ Best yield: {best_yield['yield_mg_L']:.1f} mg/L") + print(f" Conditions: {best_yield['temperature_C']}Β°C, " + f"{best_yield['iptg_mM']} mM IPTG, {best_yield['induction_time_h']}h induction") + print(f"\n β€’ Highest purity: {best_purity['purity_pct']:.1f}%") + print(f" Conditions: {best_purity['temperature_C']}Β°C, " + f"{best_purity['iptg_mM']} mM IPTG, {best_purity['induction_time_h']}h induction") + print(f"\n β€’ Best activity: {best_activity['activity_U_mg']:.1f} U/mg") + print(f" Conditions: {best_activity['temperature_C']}Β°C, " + f"{best_activity['iptg_mM']} mM IPTG, {best_activity['induction_time_h']}h induction") + + print("\nAll experimental data logged to Artifacta") + print(" View results in the Artifacta UI to:") + print(" - Compare yields across conditions") + print(" - Analyze parameter correlations") + print(" - Review growth curves") + print(" - Access raw data CSV") + print("=" * 70) + + run.finish() + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch_mnist.py b/examples/ml_frameworks/pytorch_mnist.py similarity index 85% rename from examples/pytorch_mnist.py rename to examples/ml_frameworks/pytorch_mnist.py index e9d76ff..054fb5b 100644 --- a/examples/pytorch_mnist.py +++ b/examples/ml_frameworks/pytorch_mnist.py @@ -4,17 +4,17 @@ This example demonstrates Artifacta's logging capabilities for PyTorch training: -1. **Automatic checkpoint logging** via ds.autolog() - tracks model checkpoints automatically -2. **Training metrics** via ds.Series - tracks loss and accuracy over epochs -3. **Confusion matrix** via ds.Matrix - visualizes classification performance +1. **Automatic checkpoint logging** via autolog() - tracks model checkpoints automatically +2. **Training metrics** via Series - tracks loss and accuracy over epochs +3. **Confusion matrix** via Matrix - visualizes classification performance 4. **Model artifact logging** - saves trained model with automatic metadata extraction 5. **Configuration tracking** - logs hyperparameters automatically Key Artifacta Features Demonstrated: -- ds.init() - Initialize experiment run with config -- ds.autolog() - Enable automatic checkpoint logging -- ds.Series - Log time-series metrics (loss, accuracy per epoch) -- ds.Matrix - Log 2D data (confusion matrix) +- init() - Initialize experiment run with config +- autolog() - Enable automatic checkpoint logging +- Series - Log time-series metrics (loss, accuracy per epoch) +- Matrix - Log 2D data (confusion matrix) - run.log_output() - Save model artifacts with metadata Requirements: @@ -32,7 +32,7 @@ from torch.utils.data import DataLoader from torchvision import datasets, transforms -import artifacta as ds +from artifacta import Matrix, Series, autolog, init, log class SimpleCNN(nn.Module): @@ -179,76 +179,75 @@ def main(): print("=" * 70) # ================================================================= - # 1. Configuration variations - run 3 experiments with different learning rates + # 1. Define hyperparameter search space (grid search) # ================================================================= - configs = [ - { - "learning_rate": 0.001, - "batch_size": 128, - "epochs": 5, - "optimizer": "Adam", + from itertools import product + + # Define parameter grid - typical grid search approach + param_grid = { + "learning_rate": [0.001, 0.01, 0.05], + "optimizer": ["Adam", "SGD"], + "batch_size": [128], + "epochs": [5], + } + + # Generate all combinations + keys = param_grid.keys() + values = param_grid.values() + configs = [dict(zip(keys, v)) for v in product(*values)] + + # Add metadata to each config + for config in configs: + config.update({ "model": "SimpleCNN", "dataset": "MNIST", - "name": "low-lr-adam", - }, - { - "learning_rate": 0.01, - "batch_size": 128, - "epochs": 5, - "optimizer": "SGD", - "model": "SimpleCNN", - "dataset": "MNIST", - "name": "medium-lr-sgd", - }, - { - "learning_rate": 0.05, - "batch_size": 128, - "epochs": 5, - "optimizer": "SGD", - "model": "SimpleCNN", - "dataset": "MNIST", - "name": "high-lr-sgd", - }, - ] + }) + + print(f"\nGrid search: {len(configs)} configurations") + print(" Parameter grid:") + for key, values in param_grid.items(): + print(f" {key}: {values}") # ================================================================= # 2. Run experiments with different configurations # ================================================================= for idx, config in enumerate(configs, 1): + # Generate run name from config + run_name = f"lr{config['learning_rate']}-{config['optimizer'].lower()}-bs{config['batch_size']}" + print(f"\n{'=' * 70}") - print(f"Experiment {idx}/3: {config['name']}") + print(f"Run {idx}/{len(configs)}: {run_name}") print(f"{'=' * 70}") print("\nConfiguration:") for key, value in config.items(): - if key != "name": - print(f" {key}: {value}") + print(f" {key}: {value}") # Initialize Artifacta run with configuration - run = ds.init( + run = init( project="mnist-classification", - name=f"pytorch-cnn-{config['name']}", - config={k: v for k, v in config.items() if k != "name"}, + name=run_name, + config=config, ) - print("\nβœ“ Artifacta run initialized") + print("\nArtifacta run initialized") # ================================================================= # 3. Enable autolog for automatic checkpoint tracking # This will log model checkpoints automatically during training # ================================================================= - ds.autolog(framework="pytorch") + autolog(framework="pytorch") # ================================================================= # 4. Setup device (GPU if available, otherwise CPU) # ================================================================= device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"\nβœ“ Using device: {device}") + print(f"\nUsing device: {device}") # ================================================================= # 5. Load MNIST dataset # Downloads to ./data directory on first run # ================================================================= - print("\nβœ“ Loading MNIST dataset...") + print("\nLoading MNIST dataset...") # Transform: Convert to tensor and normalize to [0, 1] transform = transforms.Compose( @@ -275,7 +274,7 @@ def main(): # ================================================================= # 6. Create model, optimizer, and move to device # ================================================================= - print("\nβœ“ Creating model...") + print("\nCreating model...") model = SimpleCNN().to(device) # Create optimizer based on config @@ -293,7 +292,7 @@ def main(): # ================================================================= # 7. Training loop # ================================================================= - print(f"\nβœ“ Training for {config['epochs']} epochs...") + print(f"\nTraining for {config['epochs']} epochs...") # Track metrics across epochs train_losses = [] @@ -325,12 +324,12 @@ def main(): # 8. Log training metrics as Series (time-series data) # This creates interactive plots in the Artifacta UI # ================================================================= - print("\nβœ“ Logging training metrics...") + print("\nLogging training metrics...") # Log loss curves - ds.log( + log( "loss_curves", - ds.Series( + Series( index="epoch", fields={ "train_loss": train_losses, @@ -341,9 +340,9 @@ def main(): ) # Log accuracy curves - ds.log( + log( "accuracy_curves", - ds.Series( + Series( index="epoch", fields={ "train_accuracy": train_accuracies, @@ -357,7 +356,7 @@ def main(): # 9. Generate and log confusion matrix # Shows which digits are confused with each other # ================================================================= - print("\nβœ“ Generating confusion matrix...") + print("\nGenerating confusion matrix...") # Get predictions on full test set _, _, all_preds, all_targets = evaluate(model, device, test_loader) @@ -367,9 +366,9 @@ def main(): # Log as Matrix primitive digit_labels = [str(i) for i in range(10)] - ds.log( + log( "confusion_matrix", - ds.Matrix( + Matrix( rows=digit_labels, cols=digit_labels, values=cm.tolist(), @@ -383,7 +382,7 @@ def main(): # 10. Save and log the trained model # Artifacta automatically extracts model metadata # ================================================================= - print("\nβœ“ Saving model...") + print("\nSaving model...") # Save model checkpoint model_path = "mnist_cnn.pt" @@ -422,7 +421,7 @@ def main(): print(f" Train Accuracy: {train_accuracies[-1]:.2f}%") print(f" Val Accuracy: {val_accuracies[-1]:.2f}%") print(f" Val Loss: {val_losses[-1]:.4f}") - print("\nβœ“ All metrics and artifacts logged to Artifacta") + print("\nAll metrics and artifacts logged to Artifacta") print(" View your results in the Artifacta UI!") print("=" * 70) diff --git a/examples/ml_frameworks/sklearn_classification.py b/examples/ml_frameworks/sklearn_classification.py new file mode 100644 index 0000000..35dbf47 --- /dev/null +++ b/examples/ml_frameworks/sklearn_classification.py @@ -0,0 +1,248 @@ +""" +Scikit-learn Classification with Artifacta Autolog +================================================== + +This example demonstrates Artifacta's autolog integration for scikit-learn: +1. **Automatic parameter logging** - Model hyperparameters logged automatically +2. **Automatic metric logging** - Accuracy, precision, recall, F1 automatically computed +3. **Automatic model logging** - Trained model saved as artifact +4. **ROC/PR curves** - Binary classification performance visualization +5. **Confusion matrix** - Multi-class classification analysis +6. **Feature importance** - For tree-based models + +Key Artifacta Features: +- autolog() - Enable automatic logging for sklearn +- Curve - ROC and Precision-Recall curves +- Matrix - Confusion matrix +- BarChart - Feature importance + +Requirements: + pip install artifacta scikit-learn numpy + +Usage: + python examples/ml_frameworks/sklearn_classification.py +""" + +import numpy as np +from sklearn.datasets import make_classification +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import ( + confusion_matrix, + precision_recall_curve, + roc_auc_score, + roc_curve, +) +from sklearn.model_selection import train_test_split + +from artifacta import BarChart, Curve, Matrix, autolog, init + + +def main(): + print("=" * 70) + print("Artifacta - Scikit-learn Classification Example") + print("=" * 70) + + # ================================================================= + # 1. Initialize Artifacta run with configuration + # ================================================================= + config = { + "model": "RandomForestClassifier", + "n_estimators": 100, + "max_depth": 10, + "min_samples_split": 5, + "random_state": 42, + "dataset": "synthetic_binary", + "n_samples": 2000, + "n_features": 20, + } + + run = init( + project="sklearn-demo", + name="rf-binary-classification", + config=config, + ) + + print("\nArtifacta run initialized") + + # ================================================================= + # 2. Enable autolog for automatic parameter/metric/model logging + # This will automatically log model parameters and training metrics + # ================================================================= + autolog(framework="sklearn") + + # ================================================================= + # 3. Create synthetic binary classification dataset + # ================================================================= + print("\nGenerating synthetic dataset...") + + X, y = make_classification( + n_samples=config["n_samples"], + n_features=config["n_features"], + n_informative=15, + n_redundant=5, + n_classes=2, + random_state=config["random_state"], + ) + + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=config["random_state"] + ) + + print(f" Training samples: {len(X_train)}") + print(f" Test samples: {len(X_test)}") + print(f" Features: {X.shape[1]}") + + # ================================================================= + # 4. Train Random Forest classifier + # Autolog automatically captures: + # - Model parameters (n_estimators, max_depth, etc.) + # - Training metrics (accuracy, precision, recall, F1) + # - Trained model artifact + # ================================================================= + print("\nTraining Random Forest classifier...") + + clf = RandomForestClassifier( + n_estimators=config["n_estimators"], + max_depth=config["max_depth"], + min_samples_split=config["min_samples_split"], + random_state=config["random_state"], + ) + + clf.fit(X_train, y_train) + + print(" Training complete!") + + # Get predictions and probabilities + y_pred = clf.predict(X_test) + y_proba = clf.predict_proba(X_test)[:, 1] # Probability of positive class + + # ================================================================= + # 5. Log ROC Curve (Receiver Operating Characteristic) + # Shows trade-off between true positive rate and false positive rate + # ================================================================= + print("\nLogging ROC curve...") + + fpr, tpr, _ = roc_curve(y_test, y_proba) + auc = roc_auc_score(y_test, y_proba) + + run.log( + "roc_curve", + Curve( + x=fpr.tolist(), + y=tpr.tolist(), + x_label="False Positive Rate", + y_label="True Positive Rate", + baseline="diagonal", + metric={"name": "AUC-ROC", "value": float(auc)}, + metadata={ + "description": "ROC curve for binary classification", + "interpretation": "Higher AUC (closer to 1.0) indicates better performance", + }, + ), + ) + + print(f" AUC-ROC: {auc:.4f}") + + # ================================================================= + # 6. Log Precision-Recall Curve + # Shows trade-off between precision and recall + # More informative than ROC for imbalanced datasets + # ================================================================= + print("\nLogging Precision-Recall curve...") + + precision, recall, _ = precision_recall_curve(y_test, y_proba) + + # Calculate Average Precision (area under PR curve) + from sklearn.metrics import average_precision_score + + avg_precision = average_precision_score(y_test, y_proba) + + run.log( + "precision_recall_curve", + Curve( + x=recall.tolist(), + y=precision.tolist(), + x_label="Recall", + y_label="Precision", + metric={"name": "Average Precision", "value": float(avg_precision)}, + metadata={ + "description": "Precision-Recall curve", + "interpretation": "Higher average precision indicates better performance", + }, + ), + ) + + print(f" Average Precision: {avg_precision:.4f}") + + # ================================================================= + # 7. Log Confusion Matrix + # Shows how predictions compare to actual labels + # ================================================================= + print("\nLogging confusion matrix...") + + cm = confusion_matrix(y_test, y_pred) + + run.log( + "confusion_matrix", + Matrix( + rows=["Negative (0)", "Positive (1)"], + cols=["Predicted Negative", "Predicted Positive"], + values=cm.tolist(), + metadata={ + "type": "confusion_matrix", + "total_samples": int(cm.sum()), + "true_negatives": int(cm[0, 0]), + "false_positives": int(cm[0, 1]), + "false_negatives": int(cm[1, 0]), + "true_positives": int(cm[1, 1]), + }, + ), + ) + + # ================================================================= + # 8. Log Feature Importance (for tree-based models) + # Shows which features contribute most to predictions + # ================================================================= + print("\nLogging feature importance...") + + importances = clf.feature_importances_ + # Get top 10 most important features + top_indices = np.argsort(importances)[-10:][::-1] + + run.log( + "feature_importance", + BarChart( + categories=[f"Feature_{i}" for i in top_indices], + groups={"Importance": importances[top_indices].tolist()}, + x_label="Feature", + y_label="Importance", + metadata={"description": "Top 10 most important features"}, + ), + ) + + # ================================================================= + # 9. Final Summary + # ================================================================= + print("\n" + "=" * 70) + print("Training Complete!") + print("=" * 70) + print("Model Performance:") + print(f" AUC-ROC: {auc:.4f}") + print(f" Average Precision: {avg_precision:.4f}") + print(f" Test Accuracy: {clf.score(X_test, y_test):.4f}") + print("\nAll metrics and artifacts logged to Artifacta") + print(" Automatically logged:") + print(" - Model parameters (autolog)") + print(" - Training metrics (autolog)") + print(" - Trained model (autolog)") + print(" Manually logged:") + print(" - ROC curve") + print(" - Precision-Recall curve") + print(" - Confusion matrix") + print(" - Feature importance") + print("\nView results in the Artifacta UI!") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/examples/tensorflow_regression.py b/examples/ml_frameworks/tensorflow_regression.py similarity index 82% rename from examples/tensorflow_regression.py rename to examples/ml_frameworks/tensorflow_regression.py index 7774256..d2515db 100644 --- a/examples/tensorflow_regression.py +++ b/examples/ml_frameworks/tensorflow_regression.py @@ -4,18 +4,18 @@ This example demonstrates Artifacta's logging capabilities for TensorFlow/Keras training: -1. **Automatic checkpoint logging** via ds.autolog() - tracks model checkpoints automatically -2. **Training curves** via ds.Series - tracks train/validation loss over epochs -3. **Prediction visualization** via ds.Scatter - plots predicted vs actual values -4. **Residual analysis** via ds.Distribution - analyzes prediction errors +1. **Automatic checkpoint logging** via autolog() - tracks model checkpoints automatically +2. **Training curves** via Series - tracks train/validation loss over epochs +3. **Prediction visualization** via Scatter - plots predicted vs actual values +4. **Residual analysis** via Distribution - analyzes prediction errors 5. **Model artifact logging** - saves trained model with metadata Key Artifacta Features Demonstrated: -- ds.init() - Initialize experiment run with config -- ds.autolog() - Enable automatic checkpoint logging for Keras -- ds.Series - Log time-series metrics (loss curves) -- ds.Scatter - Log 2D scatter plots (predictions vs actual) -- ds.Distribution - Log value distributions (residuals) +- init() - Initialize experiment run with config +- autolog() - Enable automatic checkpoint logging for Keras +- Series - Log time-series metrics (loss curves) +- Scatter - Log 2D scatter plots (predictions vs actual) +- Distribution - Log value distributions (residuals) - run.log_output() - Save model artifacts Requirements: @@ -30,7 +30,7 @@ from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler -import artifacta as ds +from artifacta import Distribution, Scatter, Series, autolog, init, log # Import TensorFlow/Keras try: @@ -54,7 +54,7 @@ def create_synthetic_data(n_samples=1000, n_features=10, noise=10.0, random_stat Returns: Tuple of (x_train, x_test, y_train, y_test, scaler_x, scaler_y) """ - print("\nβœ“ Generating synthetic regression data...") + print("\nGenerating synthetic regression data...") # Generate regression data # make_regression creates a random regression problem @@ -111,7 +111,7 @@ def create_model(input_dim, hidden_dim=64, learning_rate=0.001): Returns: Compiled Keras model """ - print("\nβœ“ Creating neural network model...") + print("\nCreating neural network model...") model = keras.Sequential( [ @@ -155,79 +155,78 @@ def main(): print("=" * 70) # ================================================================= - # 1. Configuration variations - run 3 experiments with different hyperparameters + # 1. Define hyperparameter search space (grid search) # ================================================================= - configs = [ - { - "hidden_dim": 32, - "learning_rate": 0.001, - "batch_size": 32, - "epochs": 50, - "n_samples": 1000, - "n_features": 10, - "noise": 10.0, + from itertools import product + + # Define parameter grid - typical grid search approach + param_grid = { + "hidden_dim": [32, 64, 128], + "learning_rate": [0.001, 0.01], + "batch_size": [32], + "epochs": [10], + } + + # Fixed dataset parameters + dataset_config = { + "n_samples": 1000, + "n_features": 10, + "noise": 10.0, + } + + # Generate all combinations + keys = param_grid.keys() + values = param_grid.values() + configs = [dict(zip(keys, v)) for v in product(*values)] + + # Add dataset config and metadata to each config + for config in configs: + config.update(dataset_config) + config.update({ "optimizer": "Adam", "loss": "mse", "model": "FeedForwardNN", - "name": "small-network", - }, - { - "hidden_dim": 64, - "learning_rate": 0.01, - "batch_size": 32, - "epochs": 50, - "n_samples": 1000, - "n_features": 10, - "noise": 10.0, - "optimizer": "Adam", - "loss": "mse", - "model": "FeedForwardNN", - "name": "medium-network", - }, - { - "hidden_dim": 128, - "learning_rate": 0.001, - "batch_size": 32, - "epochs": 50, - "n_samples": 1000, - "n_features": 10, - "noise": 10.0, - "optimizer": "Adam", - "loss": "mse", - "model": "FeedForwardNN", - "name": "large-network", - }, - ] + }) + + print(f"\nGrid search: {len(configs)} configurations") + print(" Parameter grid:") + for key, values in param_grid.items(): + print(f" {key}: {values}") + print(" Dataset config:") + for key, value in dataset_config.items(): + print(f" {key}: {value}") # ================================================================= # 2. Run experiments with different configurations # ================================================================= for idx, config in enumerate(configs, 1): + # Generate run name from config + run_name = f"h{config['hidden_dim']}-lr{config['learning_rate']}-bs{config['batch_size']}" + print(f"\n{'=' * 70}") - print(f"Experiment {idx}/3: {config['name']}") + print(f"Run {idx}/{len(configs)}: {run_name}") print(f"{'=' * 70}") print("\nConfiguration:") for key, value in config.items(): - if key != "name": - print(f" {key}: {value}") + print(f" {key}: {value}") # ================================================================= # 3. Initialize Artifacta run # This automatically logs config and environment info # ================================================================= - run = ds.init( + run = init( project="regression-demo", - name=f"tensorflow-regression-{config['name']}", - config={k: v for k, v in config.items() if k != "name"}, + name=run_name, + config=config, ) - print("\nβœ“ Artifacta run initialized") + print("\nArtifacta run initialized") # ================================================================= # 4. Enable autolog for automatic checkpoint tracking # This will log model checkpoints and metrics automatically # ================================================================= - ds.autolog(framework="tensorflow") + autolog(framework="tensorflow") # ================================================================= # 5. Generate synthetic regression dataset @@ -252,7 +251,7 @@ def main(): # 7. Train the model # Keras automatically tracks metrics during training # ================================================================= - print("\nβœ“ Training model...") + print("\nTraining model...") print(f" Epochs: {config['epochs']}") print(f" Batch size: {config['batch_size']}") print("-" * 70) @@ -266,13 +265,13 @@ def main(): verbose=1, # Show progress bar ) - print("\nβœ“ Training complete!") + print("\nTraining complete!") # ================================================================= # 8. Log training curves as Series # Shows how loss decreases over epochs # ================================================================= - print("\nβœ“ Logging training metrics...") + print("\nLogging training metrics...") # Extract metrics from training history train_loss = history.history["loss"] @@ -281,9 +280,9 @@ def main(): val_mae = history.history["val_mae"] # Log loss curves - ds.log( + log( "loss_curves", - ds.Series( + Series( index="epoch", fields={ "train_loss": train_loss, @@ -294,9 +293,9 @@ def main(): ) # Log MAE (Mean Absolute Error) curves - ds.log( + log( "mae_curves", - ds.Series( + Series( index="epoch", fields={ "train_mae": train_mae, @@ -309,7 +308,7 @@ def main(): # ================================================================= # 9. Make predictions on test set # ================================================================= - print("\nβœ“ Generating predictions...") + print("\nGenerating predictions...") y_pred = model.predict(x_test, verbose=0).flatten() @@ -319,7 +318,7 @@ def main(): # 10. Log predictions vs actual as Scatter plot # Visualizes model performance - points should fall on diagonal # ================================================================= - print("\nβœ“ Logging prediction scatter plot...") + print("\nLogging prediction scatter plot...") # Create scatter plot data # Each point represents one test sample @@ -328,9 +327,9 @@ def main(): for actual, pred in zip(y_test, y_pred) ] - ds.log( + log( "predictions_vs_actual", - ds.Scatter( + Scatter( points=scatter_points, x_label="Actual Values", y_label="Predicted Values", @@ -346,7 +345,7 @@ def main(): # Residual = Actual - Predicted # Good model should have residuals centered around 0 # ================================================================= - print("\nβœ“ Analyzing residuals...") + print("\nAnalyzing residuals...") residuals = y_test - y_pred @@ -361,9 +360,9 @@ def main(): print(f" Residual range: [{residual_min:.4f}, {residual_max:.4f}]") # Log residual distribution - ds.log( + log( "residual_distribution", - ds.Distribution( + Distribution( values=residuals.tolist(), metadata={ "description": "Prediction errors (actual - predicted)", @@ -377,7 +376,7 @@ def main(): # ================================================================= # 12. Calculate final metrics # ================================================================= - print("\nβœ“ Calculating final metrics...") + print("\nCalculating final metrics...") # Mean Squared Error mse = np.mean((y_test - y_pred) ** 2) @@ -397,9 +396,9 @@ def main(): print(f" RΒ²: {r2_score:.4f}") # Log final metrics as Series (single point) - ds.log( + log( "final_metrics", - ds.Series( + Series( index="metric", fields={ "value": [mse, rmse, mae, r2_score], @@ -412,7 +411,7 @@ def main(): # 13. Save and log the trained model # Keras models are saved in .keras format (recommended) # ================================================================= - print("\nβœ“ Saving model...") + print("\nSaving model...") model_path = "regression_model.keras" model.save(model_path) @@ -446,7 +445,7 @@ def main(): print("\nInterpretation:") print(f" - RΒ² = {r2_score:.4f} means the model explains {r2_score * 100:.1f}% of variance") print(f" - Average prediction error (MAE): {mae:.4f} standard units") - print("\nβœ“ All metrics and artifacts logged to Artifacta") + print("\nAll metrics and artifacts logged to Artifacta") print(" View your results in the Artifacta UI!") print("=" * 70) diff --git a/examples/ml_frameworks/xgboost_regression.py b/examples/ml_frameworks/xgboost_regression.py new file mode 100644 index 0000000..f5e4824 --- /dev/null +++ b/examples/ml_frameworks/xgboost_regression.py @@ -0,0 +1,253 @@ +""" +XGBoost Regression with Artifacta Autolog +========================================== + +This example demonstrates Artifacta's autolog integration for XGBoost: +1. **Automatic parameter logging** - XGBoost hyperparameters logged automatically +2. **Automatic metric logging** - Training and validation metrics per iteration +3. **Automatic model logging** - Trained booster saved as artifact +4. **Feature importance** - Which features matter most +5. **Prediction analysis** - Predicted vs actual scatter plot +6. **Hyperparameter sweep** - Test multiple configurations + +Key Artifacta Features: +- autolog() - Enable automatic logging for XGBoost +- Series - Training/validation curves over iterations +- Scatter - Predicted vs actual values +- BarChart - Feature importance visualization + +Requirements: + pip install artifacta xgboost scikit-learn numpy + +Usage: + python examples/ml_frameworks/xgboost_regression.py +""" + +import numpy as np +import xgboost as xgb +from sklearn.datasets import make_regression +from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score +from sklearn.model_selection import train_test_split + +from artifacta import BarChart, Scatter, Series, autolog, init + + +def main(): + print("=" * 70) + print("Artifacta - XGBoost Regression Example") + print("=" * 70) + + # ================================================================= + # 1. Define hyperparameter grid for sweep + # ================================================================= + configs = [ + { + "name": "shallow-fast", + "max_depth": 3, + "learning_rate": 0.1, + "n_estimators": 50, + }, + { + "name": "medium-balanced", + "max_depth": 6, + "learning_rate": 0.05, + "n_estimators": 100, + }, + { + "name": "deep-slow", + "max_depth": 10, + "learning_rate": 0.01, + "n_estimators": 150, + }, + ] + + print(f"\nRunning hyperparameter sweep with {len(configs)} configurations\n") + + # ================================================================= + # 2. Generate synthetic regression dataset once + # ================================================================= + print("Generating synthetic regression dataset...") + + X, y = make_regression( + n_samples=2000, + n_features=20, + n_informative=15, + noise=10.0, + random_state=42, + ) + + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 + ) + + print(f" Training samples: {len(X_train)}") + print(f" Test samples: {len(X_test)}") + print(f" Features: {X.shape[1]}") + + # ================================================================= + # 3. Run experiments for each configuration + # ================================================================= + for idx, config in enumerate(configs, 1): + print(f"\n{'=' * 70}") + print(f"Experiment {idx}/{len(configs)}: {config['name']}") + print(f"{'=' * 70}") + + # Initialize Artifacta run + run_config = { + "model": "XGBRegressor", + "max_depth": config["max_depth"], + "learning_rate": config["learning_rate"], + "n_estimators": config["n_estimators"], + "objective": "reg:squarederror", + "random_state": 42, + } + + run = init( + project="xgboost-demo", + name=f"xgb-{config['name']}", + config=run_config, + ) + + print("\nConfiguration:") + print(f" max_depth: {config['max_depth']}") + print(f" learning_rate: {config['learning_rate']}") + print(f" n_estimators: {config['n_estimators']}") + + # ================================================================= + # 4. Enable autolog + # Automatically logs: parameters, per-iteration metrics, model + # ================================================================= + autolog(framework="xgboost") + + # ================================================================= + # 5. Train XGBoost model + # Autolog captures training/validation metrics automatically + # ================================================================= + print("\nTraining XGBoost model...") + + model = xgb.XGBRegressor( + max_depth=config["max_depth"], + learning_rate=config["learning_rate"], + n_estimators=config["n_estimators"], + objective="reg:squarederror", + random_state=42, + eval_metric="rmse", + ) + + # Fit with validation set for eval metrics + model.fit( + X_train, + y_train, + eval_set=[(X_train, y_train), (X_test, y_test)], + verbose=False, + ) + + print(" Training complete!") + + # ================================================================= + # 6. Make predictions and calculate metrics + # ================================================================= + print("\nEvaluating model...") + + y_pred = model.predict(X_test) + + mse = mean_squared_error(y_test, y_pred) + rmse = np.sqrt(mse) + mae = mean_absolute_error(y_test, y_pred) + r2 = r2_score(y_test, y_pred) + + print(f" RMSE: {rmse:.2f}") + print(f" MAE: {mae:.2f}") + print(f" RΒ²: {r2:.4f}") + + # ================================================================= + # 7. Log predictions vs actual (Scatter plot) + # Points should fall on y=x diagonal for perfect predictions + # ================================================================= + print("\nLogging prediction analysis...") + + points = [ + {"x": float(actual), "y": float(pred)} + for actual, pred in zip(y_test, y_pred) + ] + + run.log( + "predictions_vs_actual", + Scatter( + points=points, + x_label="Actual Values", + y_label="Predicted Values", + metadata={ + "description": "Predictions vs actual on test set", + "rmse": float(rmse), + "r2": float(r2), + }, + ), + ) + + # ================================================================= + # 8. Log feature importance + # Shows which features contribute most to predictions + # ================================================================= + print("\nLogging feature importance...") + + importance_scores = model.feature_importances_ + # Get top 10 features + top_indices = np.argsort(importance_scores)[-10:][::-1] + + run.log( + "feature_importance", + BarChart( + categories=[f"Feature_{i}" for i in top_indices], + groups={"Importance": importance_scores[top_indices].tolist()}, + x_label="Feature", + y_label="Importance Score", + metadata={"description": "Top 10 most important features"}, + ), + ) + + # ================================================================= + # 9. Log final metrics as Series (for comparison) + # ================================================================= + run.log( + "final_metrics", + Series( + index="metric", + fields={ + "value": [float(rmse), float(mae), float(r2)], + }, + index_values=["RMSE", "MAE", "RΒ²"], + ), + ) + + # Finish run + run.finish() + + print(f"\nExperiment {config['name']} complete") + + # ================================================================= + # Final summary + # ================================================================= + print("\n" + "=" * 70) + print("Hyperparameter Sweep Complete!") + print("=" * 70) + print(f"Trained {len(configs)} XGBoost models") + print(" Configurations tested:") + for config in configs: + print(f" - {config['name']}: depth={config['max_depth']}, lr={config['learning_rate']}") + print("\nAll metrics and artifacts logged to Artifacta") + print(" Automatically logged (per run):") + print(" - XGBoost parameters (autolog)") + print(" - Training/validation curves (autolog)") + print(" - Trained model (autolog)") + print(" Manually logged (per run):") + print(" - Prediction scatter plot") + print(" - Feature importance") + print(" - Final metrics") + print("\nView and compare all runs in the Artifacta UI!") + print("Use the comparison view to find the best configuration.") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 0000000..e721736 --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1,24 @@ +# Artifacta Examples - All Dependencies +# ===================================== +# Install with: pip install -r examples/requirements.txt +# +# This includes artifacta and all ML framework dependencies needed to run +# all examples in this directory. + +# Core package +artifacta + +# Data manipulation (used by all examples) +numpy~=1.24 +pandas~=2.0 + +# ML Frameworks +scikit-learn~=1.3 # For sklearn_classification.py +xgboost~=2.0 # For xgboost_regression.py +torch~=2.0 # For pytorch_mnist.py (CPU version) +torchvision~=0.15 # For pytorch_mnist.py (MNIST dataset) +tensorflow~=2.13 # For tensorflow_regression.py + +# Optional: GPU support for PyTorch (uncomment if you have CUDA) +# torch>=2.0.0+cu118 +# torchvision>=0.15.0+cu118 diff --git a/examples/run_all_examples.py b/examples/run_all_examples.py new file mode 100644 index 0000000..3574608 --- /dev/null +++ b/examples/run_all_examples.py @@ -0,0 +1,263 @@ +""" +Run All Artifacta Examples +=========================== + +This script runs all example files in sequence and reports the results. +Useful for: +- Validating that all examples work correctly +- Regression testing after code changes +- Quick smoke test before releases + +Usage: + python examples/run_all_examples.py # Run all examples + python examples/run_all_examples.py --category core # Run only core examples + python examples/run_all_examples.py --fast # Skip long-running examples + +Categories: + - core: Basic Artifacta concepts and primitives + - ml_frameworks: Machine learning framework integrations + - domain_specific: Domain-specific use cases + - all: All examples (default) +""" + +import argparse +import subprocess +import sys +import time +from pathlib import Path + +# Define all examples with metadata +EXAMPLES = [ + # Core examples + { + "path": "core/01_basic_tracking.py", + "name": "Basic Tracking", + "category": "core", + "duration": "fast", + "description": "Minimal hello world example", + }, + { + "path": "core/02_all_primitives.py", + "name": "All Primitives Demo", + "category": "core", + "duration": "fast", + "description": "Showcase all 7 data primitives", + }, + # ML Framework examples + { + "path": "ml_frameworks/sklearn_classification.py", + "name": "Sklearn Classification", + "category": "ml_frameworks", + "duration": "medium", + "description": "Sklearn with ROC/PR curves", + }, + { + "path": "ml_frameworks/xgboost_regression.py", + "name": "XGBoost Regression", + "category": "ml_frameworks", + "duration": "medium", + "description": "XGBoost with feature importance", + }, + { + "path": "ml_frameworks/pytorch_mnist.py", + "name": "PyTorch MNIST", + "category": "ml_frameworks", + "duration": "slow", + "description": "PyTorch Lightning with autolog", + }, + { + "path": "ml_frameworks/tensorflow_regression.py", + "name": "TensorFlow Regression", + "category": "ml_frameworks", + "duration": "slow", + "description": "TensorFlow/Keras with autolog", + }, + # Domain-specific examples + { + "path": "domain_specific/ab_testing_experiment.py", + "name": "A/B Testing", + "category": "domain_specific", + "duration": "fast", + "description": "Domain-agnostic A/B test tracking", + }, + { + "path": "domain_specific/protein_expression.py", + "name": "Protein Expression", + "category": "domain_specific", + "duration": "fast", + "description": "Wet lab experiment optimization", + }, +] + + +def run_example(example_path: Path) -> dict: + """Run a single example and return results. + + Args: + example_path: Path to the example file + + Returns: + dict with keys: success (bool), duration (float), output (str), error (str) + """ + start_time = time.time() + + try: + result = subprocess.run( + [sys.executable, str(example_path)], + capture_output=True, + text=True, + timeout=300, # 5 minute timeout + ) + + duration = time.time() - start_time + success = result.returncode == 0 + + return { + "success": success, + "duration": duration, + "output": result.stdout, + "error": result.stderr if not success else None, + } + + except subprocess.TimeoutExpired: + duration = time.time() - start_time + return { + "success": False, + "duration": duration, + "output": "", + "error": "Example timed out after 5 minutes", + } + + except Exception as e: + duration = time.time() - start_time + return { + "success": False, + "duration": duration, + "output": "", + "error": str(e), + } + + +def print_header(): + """Print script header.""" + print("=" * 80) + print("Artifacta - Run All Examples") + print("=" * 80) + print() + + +def print_example_header(idx: int, total: int, example: dict): + """Print example execution header.""" + print(f"\n[{idx}/{total}] {example['name']}") + print(f" File: {example['path']}") + print(f" Description: {example['description']}") + print(" Running... ", end="", flush=True) + + +def print_result(result: dict): + """Print example result.""" + if result["success"]: + print(f"PASSED ({result['duration']:.1f}s)") + else: + print(f"FAILED ({result['duration']:.1f}s)") + if result["error"]: + print("\n Error:") + for line in result["error"].split("\n")[:10]: # Show first 10 lines + print(f" {line}") + + +def print_summary(results: list[dict], examples: list[dict]): + """Print final summary table.""" + passed = sum(1 for r in results if r["success"]) + failed = sum(1 for r in results if not r["success"]) + total_duration = sum(r["duration"] for r in results) + + print("\n" + "=" * 80) + print("Summary") + print("=" * 80) + print(f"\nTotal: {len(results)} examples") + print(f"Passed: {passed}") + print(f"Failed: {failed}") + print(f"⏱ Total time: {total_duration:.1f}s") + + if failed > 0: + print("\nFailed examples:") + for example, result in zip(examples, results): + if not result["success"]: + print(f" {example['name']} ({example['path']})") + + print("\n" + "=" * 80) + + +def main(): + """Main function.""" + parser = argparse.ArgumentParser(description="Run all Artifacta examples") + parser.add_argument( + "--category", + choices=["core", "ml_frameworks", "domain_specific", "all"], + default="all", + help="Category of examples to run", + ) + parser.add_argument( + "--fast", + action="store_true", + help="Skip slow examples (only run fast/medium duration)", + ) + + args = parser.parse_args() + + # Filter examples based on arguments + examples_to_run = EXAMPLES + + if args.category != "all": + examples_to_run = [e for e in examples_to_run if e["category"] == args.category] + + if args.fast: + examples_to_run = [e for e in examples_to_run if e["duration"] != "slow"] + + if not examples_to_run: + print("No examples match the specified filters.") + return 1 + + # Print header + print_header() + print(f"Running {len(examples_to_run)} examples") + if args.category != "all": + print(f"Category: {args.category}") + if args.fast: + print("Mode: Fast (skipping slow examples)") + print() + + # Get base directory + examples_dir = Path(__file__).parent + + # Run each example + results = [] + for idx, example in enumerate(examples_to_run, 1): + example_path = examples_dir / example["path"] + + if not example_path.exists(): + print(f"\n[{idx}/{len(examples_to_run)}] {example['name']}") + print(f" SKIPPED - File not found: {example_path}") + results.append({ + "success": False, + "duration": 0, + "output": "", + "error": f"File not found: {example_path}", + }) + continue + + print_example_header(idx, len(examples_to_run), example) + result = run_example(example_path) + results.append(result) + print_result(result) + + # Print summary + print_summary(results, examples_to_run) + + # Exit with error code if any examples failed + return 1 if any(not r["success"] for r in results) else 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/index.html b/index.html deleted file mode 100644 index a81d3df..0000000 --- a/index.html +++ /dev/null @@ -1,15 +0,0 @@ - - - - - - - Artifacta - Track experiments across any domain - - - - -
- - - diff --git a/package.json b/package.json index 8257436..f9642ad 100644 --- a/package.json +++ b/package.json @@ -1,13 +1,17 @@ { "name": "artifacta", "version": "0.1.0", + "type": "module", "description": "Universal experiment and artifact tracking β€” gain insights and optimize models with confidence", "scripts": { - "dev": "vite", - "build": "vite build", - "preview": "vite preview", - "lint": "eslint src --ext js,jsx", - "test": "vitest" + "dev": "vite --config config/vite.config.js", + "build": "vite build --config config/vite.config.js", + "preview": "vite preview --config config/vite.config.js", + "lint": "eslint src --ext js,jsx --config config/eslint.config.js", + "test": "vitest", + "test:e2e": "playwright test --config config/playwright.config.js", + "test:e2e:ui": "playwright test --config config/playwright.config.js --ui", + "docs:ui": "jsdoc -c config/jsdoc.json" }, "dependencies": { "@aarkue/tiptap-math-extension": "^1.4.0", @@ -37,13 +41,16 @@ }, "devDependencies": { "@eslint/js": "^9.39.2", + "@playwright/test": "^1.48.0", "@types/node": "^24.10.0", "@types/react": "^18.3.12", "@types/react-dom": "^18.3.1", "@vitejs/plugin-react": "^4.3.3", "eslint": "^9.14.0", + "eslint-plugin-jsdoc": "^62.3.0", "eslint-plugin-react": "^7.37.2", "eslint-plugin-react-hooks": "^5.0.0", + "jsdoc": "^4.0.5", "knip": "^5.82.1", "sass": "^1.97.2", "sass-embedded": "^1.97.2", @@ -67,5 +74,15 @@ "bugs": { "url": "https://github.com/yourusername/artifacta/issues" }, - "license": "MIT" + "license": "MIT", + "depcheck": { + "ignores": [ + "knip" + ], + "ignoreDirs": [ + "dist", + "node_modules", + ".git" + ] + } } diff --git a/pyproject.toml b/pyproject.toml index fff08a3..7ee0e4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,43 +31,48 @@ classifiers = [ ] dependencies = [ - "click>=8.0.0", - "flask>=2.0.0", - "flask-cors>=3.0.0", - "pillow>=9.0.0", - "fastapi>=0.100.0", - "uvicorn>=0.23.0", - "python-multipart>=0.0.6", - "sqlalchemy>=2.0.0", - "numpy>=1.24.0", - "requests>=2.31.0", - "psutil>=5.9.0", + "click~=8.0", + "flask~=2.0", + "flask-cors~=3.0", + "pillow>=9.0,<11.0; python_version < '3.12'", + "pillow>=10.0,<11.0; python_version >= '3.12'", + "fastapi~=0.100", + "uvicorn~=0.23", + "python-multipart~=0.0.6", + "sqlalchemy~=2.0", + "numpy~=1.24", + "requests~=2.31", + "psutil~=5.9", ] [project.optional-dependencies] dev = [ - "pytest>=7.0.0", - "ruff>=0.1.0", - "mypy>=1.0.0", - "pre-commit>=3.0.0", - "pydocstyle>=6.3.0", - "sphinx>=7.0.0", - "sphinx-rtd-theme>=2.0.0", - "sphinx-autodoc-typehints>=1.24.0", - "numpy>=1.24.0", - "matplotlib>=3.7.0", - "seaborn>=0.12.0", - "scikit-learn>=1.3.0", - "requests>=2.31.0", - "sqlalchemy>=2.0.0", - "pillow>=9.0.0", - "pytorch-lightning>=2.0.0", - "torch>=2.0.0", - "torchvision>=0.15.0", - "tensorflow>=2.13.0", - "reportlab>=4.0.0", - "av>=10.0.0", - "bump-my-version>=0.26.0", + "pytest~=7.0", + "ruff~=0.1", + "mypy~=1.0", + "pre-commit~=3.0", + "pydocstyle~=6.3", + "sphinx~=7.0", + "sphinx-rtd-theme~=2.0", + "sphinx-autodoc-typehints~=1.24", + "numpy~=1.24", + "matplotlib~=3.7", + "seaborn~=0.12", + "scikit-learn~=1.3", + "xgboost~=1.7", + "lightgbm~=4.0", + "requests~=2.31", + "sqlalchemy~=2.0", + "pytorch-lightning~=2.0", + "torch~=2.0", + "torchvision~=0.15", + "tensorflow>=2.12,<3.0; python_version < '3.12'", + "tensorflow>=2.16,<3.0; python_version >= '3.12'", + "reportlab~=4.0", + "bump-my-version~=0.26", +] +video = [ + "av~=10.0", ] [project.scripts] @@ -107,10 +112,16 @@ select = [ ignore = [ "E501", # line too long (handled by formatter) "B008", # do not perform function calls in argument defaults + "N803", # argument name should be lowercase (allow X, y for ML) + "N806", # variable name should be lowercase (allow X, y, X_train, etc for ML) + "SIM102", # nested if statements (readability preference) + "SIM105", # use contextlib.suppress (readability preference) ] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] # Allow unused imports in __init__.py +"tests/**/*.py" = ["F841", "I001"] # Allow unused variables and skip import sorting in tests +"artifacta/artifacta/integrations/sklearn.py" = ["F401"] # Import sklearn for isinstance check # Mypy configuration [tool.mypy] @@ -132,10 +143,23 @@ disallow_untyped_decorators = false # Pytest configuration [tool.pytest.ini_options] testpaths = ["tests"] +pythonpath = ["."] python_files = "test_*.py" python_classes = "Test*" python_functions = "test_*" -addopts = "-v --strict-markers" +asyncio_default_fixture_loop_scope = "function" +addopts = [ + "-v", + "--strict-markers", + "--tb=short", + "--disable-warnings", + "--import-mode=importlib" +] +markers = [ + "integration: Integration tests (requires running server)", + "unit: Unit tests (no external dependencies)", + "e2e: End-to-end tests (full system)" +] # Bump My Version configuration [tool.bumpversion] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index fe9b131..0000000 --- a/pytest.ini +++ /dev/null @@ -1,16 +0,0 @@ -[pytest] -testpaths = tests -python_files = test_*.py -python_classes = Test* -python_functions = test_* -asyncio_default_fixture_loop_scope = function -addopts = - -v - --strict-markers - --tb=short - --disable-warnings - --import-mode=importlib -markers = - integration: Integration tests (requires running server) - unit: Unit tests (no external dependencies) - e2e: End-to-end tests (full system) diff --git a/src/app/components/ArtifactTab/ArtifactTab.jsx b/src/app/components/ArtifactTab/ArtifactTab.jsx index 75ff18d..deffbec 100644 --- a/src/app/components/ArtifactTab/ArtifactTab.jsx +++ b/src/app/components/ArtifactTab/ArtifactTab.jsx @@ -7,6 +7,37 @@ import { apiClient } from '@/core/api/ApiClient'; const API_BASE_URL = import.meta.env.VITE_API_URL; +/** + * Artifact Tab component for previewing logged files + * + * Displays previews of selected artifacts from the Artifacts panel. Supports images, + * CSVs, JSON, code files, audio, video, and more with appropriate rendering. + * + * Supported file types: + * - Tables: CSV (parsed and rendered as table with pagination) + * - Images: PNG, JPG, GIF, SVG (grid gallery view) + * - Code: Python, JavaScript, JSON, YAML (syntax highlighted) + * - Audio: MP3, WAV, OGG (audio player) + * - Video: MP4, WEBM (video player) + * - Text: TXT, MD, LOG (plain text viewer) + * - Binary: Model checkpoints (download link) + * + * Features: + * - Auto-detects file type from MIME type + * - CSV pagination (100 rows per page) + * - Image gallery with multiple images + * - Syntax highlighting for code files + * - Audio/video players + * - Download links for all files + * + * @param {object} props - Component props + * @param {object|null} props.selectedArtifact - Artifact to preview: + * - file: object - File data with content/mime_type + * - imageFiles: Array (optional) - Multiple images + * - audioFiles: Array (optional) - Multiple audio files + * - videoFiles: Array (optional) - Multiple video files + * @returns {React.ReactElement|null} File preview or empty state + */ export const ArtifactTab = ({ selectedArtifact }) => { const [preview, setPreview] = useState(null); const [loading, setLoading] = useState(false); @@ -63,6 +94,11 @@ export const ArtifactTab = ({ selectedArtifact }) => { Papa.parse(file.content, { header: true, skipEmptyLines: true, + /** + * Callback function executed when CSV parsing is complete. + * @param {object} results - The parsed CSV results from Papa Parse + * @returns {void} + */ complete: (results) => { setPreview({ type: 'table', @@ -125,6 +161,10 @@ export const ArtifactTab = ({ selectedArtifact }) => { return; } + /** + * Fetches the preview data for the selected artifact from the API. + * @returns {Promise} + */ const fetchPreview = async () => { setLoading(true); try { @@ -144,6 +184,11 @@ export const ArtifactTab = ({ selectedArtifact }) => { Papa.parse(csvText, { header: true, skipEmptyLines: true, + /** + * Callback function executed when CSV parsing is complete. + * @param {object} results - The parsed CSV results from Papa Parse + * @returns {void} + */ complete: (results) => { const previewData = { type: 'table', @@ -159,6 +204,11 @@ export const ArtifactTab = ({ selectedArtifact }) => { setPreview(previewData); setLoading(false); }, + /** + * Callback function executed when CSV parsing encounters an error. + * @param {Error} error - The error object from Papa Parse + * @returns {void} + */ error: (error) => { console.error('CSV parse error:', error); setPreview(null); @@ -204,10 +254,18 @@ export const ArtifactTab = ({ selectedArtifact }) => { ); } + /** + * Handles navigation to the previous page in paginated data. + * @returns {void} + */ const handlePrevPage = () => { setOffset(Math.max(0, offset - limit)); }; + /** + * Handles navigation to the next page in paginated data. + * @returns {void} + */ const handleNextPage = () => { setOffset(offset + limit); }; diff --git a/src/app/components/ArtifactsPanel/ArtifactsPanel.jsx b/src/app/components/ArtifactsPanel/ArtifactsPanel.jsx index e1b69d3..6f9e41b 100644 --- a/src/app/components/ArtifactsPanel/ArtifactsPanel.jsx +++ b/src/app/components/ArtifactsPanel/ArtifactsPanel.jsx @@ -13,7 +13,7 @@ import './ArtifactsPanel.scss'; /** * Build a tree structure from flat file paths * @param {Array} files - Array of file objects with path property - * @returns {Object} Tree structure { folders: {}, files: [] } + * @returns {object} Tree structure { folders: {}, files: [] } */ const buildFileTree = (files) => { const tree = { folders: {}, files: [] }; @@ -46,8 +46,38 @@ const buildFileTree = (files) => { }; /** - * FilesPanel - Shows files for selected runs in a collapsible folder tree - * Displays file collections with ability to expand/collapse folders and view individual files + * Artifacts Panel component for browsing logged files across runs + * + * Hierarchical file browser showing all artifacts logged by selected runs. + * Groups files by run, with collapsible folder structure for organization. + * + * Features: + * - Multi-run artifact browsing + * - Collapsible folder tree structure + * - File preview on click (delegates to onFileSelect callback) + * - Automatic color coding per run + * - Empty state when no runs selected + * - File count display per run + * - Icon-based file type detection + * + * Artifact types supported: + * - Images (PNG, JPG, etc.) + * - Model checkpoints (.pt, .h5, etc.) + * - Config files (JSON, YAML) + * - Logs (TXT) + * - Any file uploaded via track.log_artifact() + * + * Architecture: + * - Fetches artifacts from API for selected runs + * - Organizes by run β†’ folder β†’ file hierarchy + * - Clicking file triggers parent callback for preview/download + * + * @param {object} props - Component props + * @param {Array} props.selectedRunIds - Run IDs to show artifacts for + * @param {function} props.onFileSelect - Callback when file clicked + * Signature: (fileData: object) => void + * fileData contains: { filename, path, run_id, url, ... } + * @returns {React.ReactElement} Collapsible artifact browser */ export const ArtifactsPanel = ({ selectedRunIds, onFileSelect }) => { const [artifacts, setArtifacts] = useState([]); @@ -62,6 +92,10 @@ export const ArtifactsPanel = ({ selectedRunIds, onFileSelect }) => { return; } + /** + * Fetch artifacts for all selected runs. + * @returns {Promise} + */ const fetchArtifacts = async () => { setLoading(true); try { @@ -138,6 +172,10 @@ export const ArtifactsPanel = ({ selectedRunIds, onFileSelect }) => { fetchArtifacts(); }, [selectedRunIds]); + /** + * Toggle expansion state of an artifact. + * @param {string} artifactId - Artifact ID to toggle + */ const toggleArtifact = (artifactId) => { setExpandedArtifacts(prev => { const next = new Set(prev); @@ -150,7 +188,12 @@ export const ArtifactsPanel = ({ selectedRunIds, onFileSelect }) => { }); }; - // Helper to get a specific folder's tree from the full tree + /** + * Get a specific folder's tree from the full tree. + * @param {object} tree - Full file tree + * @param {string} path - Path to folder + * @returns {object|null} Folder tree or null if not found + */ const getFolderTree = (tree, path) => { const parts = path.split('/'); let current = tree; @@ -161,13 +204,23 @@ export const ArtifactsPanel = ({ selectedRunIds, onFileSelect }) => { return current; }; - // Generic helper to check if folder contains only files of a specific media type + /** + * Check if folder contains only files of a specific media type. + * @param {object} folderTree - Folder tree structure + * @param {string} mimePrefix - MIME type prefix to check + * @returns {boolean} True if folder contains only files of specified type + */ const isFolderOnlyMediaType = (folderTree, mimePrefix) => { if (Object.keys(folderTree.folders).length > 0) return false; return folderTree.files.every(file => file.mime_type?.startsWith(mimePrefix)); }; - // Generic helper to collect files by media type (recursively) + /** + * Collect files by media type recursively. + * @param {object} folderTree - Folder tree structure + * @param {string} mimePrefix - MIME type prefix to collect + * @returns {Array} Array of files matching the media type + */ const collectFilesByMediaType = (folderTree, mimePrefix) => { const files = []; folderTree.files.forEach(file => { @@ -181,15 +234,50 @@ export const ArtifactsPanel = ({ selectedRunIds, onFileSelect }) => { return files; }; - // Convenience wrappers for specific media types + /** + * Check if folder contains only image files. + * @param {object} folderTree - Folder tree structure + * @returns {boolean} True if folder contains only images + */ const isFolderOnlyImages = (folderTree) => isFolderOnlyMediaType(folderTree, 'image/'); + /** + * Check if folder contains only audio files. + * @param {object} folderTree - Folder tree structure + * @returns {boolean} True if folder contains only audio + */ const isFolderOnlyAudio = (folderTree) => isFolderOnlyMediaType(folderTree, 'audio/'); + /** + * Check if folder contains only video files. + * @param {object} folderTree - Folder tree structure + * @returns {boolean} True if folder contains only video + */ const isFolderOnlyVideo = (folderTree) => isFolderOnlyMediaType(folderTree, 'video/'); + /** + * Collect all image files from folder tree. + * @param {object} folderTree - Folder tree structure + * @returns {Array} Array of image files + */ const collectImageFiles = (folderTree) => collectFilesByMediaType(folderTree, 'image/'); + /** + * Collect all audio files from folder tree. + * @param {object} folderTree - Folder tree structure + * @returns {Array} Array of audio files + */ const collectAudioFiles = (folderTree) => collectFilesByMediaType(folderTree, 'audio/'); + /** + * Collect all video files from folder tree. + * @param {object} folderTree - Folder tree structure + * @returns {Array} Array of video files + */ const collectVideoFiles = (folderTree) => collectFilesByMediaType(folderTree, 'video/'); + /** + * Toggle folder expansion and handle media type special cases. + * @param {string} folderPath - Path to folder + * @param {object} fullTree - Full file tree + * @param {object} artifact - Artifact object + */ const toggleFolder = (folderPath, fullTree, artifact) => { const folderTree = getFolderTree(fullTree, folderPath); @@ -246,6 +334,11 @@ export const ArtifactsPanel = ({ selectedRunIds, onFileSelect }) => { }); }; + /** + * Parse file collection from JSON string. + * @param {string} content - JSON content string + * @returns {object|null} Parsed file collection or null + */ const parseFileCollection = (content) => { if (!content) return null; try { @@ -255,6 +348,11 @@ export const ArtifactsPanel = ({ selectedRunIds, onFileSelect }) => { } }; + /** + * Get appropriate icon component for a file based on type. + * @param {object} file - File object with path, mime_type, and metadata + * @returns {object} Icon component for the file + */ const getFileIcon = (file) => { const ext = file.path.split('.').pop()?.toLowerCase(); const mimeType = file.mime_type; @@ -298,6 +396,11 @@ export const ArtifactsPanel = ({ selectedRunIds, onFileSelect }) => { return ; }; + /** + * Handle file click event. + * @param {object} artifact - Artifact object + * @param {object} file - File object + */ const handleFileClick = (artifact, file) => { if (onFileSelect) { onFileSelect({ @@ -312,7 +415,15 @@ export const ArtifactsPanel = ({ selectedRunIds, onFileSelect }) => { } }; - // Recursive component to render file tree + /** + * Recursive component to render file tree. + * @param {object} root0 - Component props + * @param {object} root0.tree - Tree structure to render + * @param {object} root0.artifact - Artifact object + * @param {string} root0.pathPrefix - Path prefix for nested folders + * @param {object} root0.fullTree - Full tree (for root reference) + * @returns {object|null} Rendered file tree or null + */ const FileTreeNode = ({ tree, artifact, pathPrefix = '', fullTree = null }) => { if (!tree) return null; const rootTree = fullTree || tree; @@ -392,14 +503,18 @@ export const ArtifactsPanel = ({ selectedRunIds, onFileSelect }) => { file.mime_type?.startsWith('video/') ); - // Handle clicks on the expand icon + /** + * Handle clicks on the expand icon. + * @param {Event} e - Click event + */ const handleExpandClick = (e) => { e.stopPropagation(); // Prevent triggering header click toggleArtifact(artifact.artifact_id); }; - // For single file, click opens it directly - // For image/audio/video-only artifacts, click opens grid view AND toggles expansion + /** + * Handle header click for single file or media-only artifacts. + */ const handleHeaderClick = () => { if (isSingleFile && fileCollection) { handleFileClick(artifact, fileCollection.files[0]); diff --git a/src/app/components/ChatTab/ChatTab.jsx b/src/app/components/ChatTab/ChatTab.jsx index 2416c50..d4deeee 100644 --- a/src/app/components/ChatTab/ChatTab.jsx +++ b/src/app/components/ChatTab/ChatTab.jsx @@ -5,6 +5,39 @@ import rehypeHighlight from 'rehype-highlight'; import './ChatTab.scss'; import { apiClient } from '@/core/api/ApiClient'; +/** + * Chat Tab component for LLM-powered experiment analysis + * + * Interactive chat interface for asking questions about experiment runs using LLMs. + * Automatically loads selected runs' data (config, metrics, artifacts) into LLM context. + * + * Features: + * - LiteLLM integration (supports OpenAI, Anthropic, and other providers) + * - Streaming responses with markdown rendering + * - Code syntax highlighting + * - Auto-loads run data into context + * - Persistent API key storage (localStorage) + * - Resizable input area + * - Auto-scroll to latest messages + * - Multi-run context support + * + * Use cases: + * - "Why did loss spike at epoch 10?" + * - "Compare these two runs' hyperparameters" + * - "Which run had best validation accuracy?" + * - "Explain the difference in convergence patterns" + * + * Architecture: + * - Fetches full run data on selectedRunIds change + * - Sends run data + chat history to OpenAI + * - Streams response chunks for real-time display + * - Renders markdown with code highlighting + * + * @param {object} props - Component props + * @param {Array} props.selectedRunIds - Run IDs to include in chat context + * @param {Array} props.allRuns - All runs (unused, for interface consistency) + * @returns {React.ReactElement|null} Chat interface or setup prompt + */ export const ChatTab = ({ selectedRunIds, allRuns: _allRuns }) => { const [runData, setRunData] = useState(null); const [loading, setLoading] = useState(false); @@ -32,6 +65,10 @@ export const ChatTab = ({ selectedRunIds, allRuns: _allRuns }) => { }, []); // Detect if user manually scrolls + /** + * Handles scroll events to detect if user has manually scrolled away from bottom + * @returns {void} + */ const handleScroll = () => { if (!messagesContainerRef.current) return; @@ -53,6 +90,11 @@ export const ChatTab = ({ selectedRunIds, allRuns: _allRuns }) => { useEffect(() => { if (!isResizing) return; + /** + * Handles mouse movement during input area resize + * @param {React.MouseEvent} e - Mouse event + * @returns {void} + */ const handleMouseMove = (e) => { e.preventDefault(); const containerHeight = window.innerHeight; @@ -61,6 +103,10 @@ export const ChatTab = ({ selectedRunIds, allRuns: _allRuns }) => { setInputHeight(Math.max(50, newHeight)); }; + /** + * Handles mouse up event to stop resizing + * @returns {void} + */ const handleMouseUp = () => { setIsResizing(false); document.body.classList.remove('resizing-chat'); @@ -85,6 +131,10 @@ export const ChatTab = ({ selectedRunIds, allRuns: _allRuns }) => { return; } + /** + * Fetches run data and artifacts for selected run IDs + * @returns {Promise} + */ const fetchRunData = async () => { setLoading(true); try { @@ -205,6 +255,10 @@ export const ChatTab = ({ selectedRunIds, allRuns: _allRuns }) => { return context; }, [runData]); + /** + * Handles sending a message to the LLM and streaming the response + * @returns {Promise} + */ const handleSendMessage = async () => { if (!input.trim() || streaming) return; @@ -316,6 +370,12 @@ export const ChatTab = ({ selectedRunIds, allRuns: _allRuns }) => { } }; + /** + * Saves LLM settings to localStorage and updates state + * @param {string} newModel - The LLM model to use + * @param {string} newApiKey - The API key for the LLM provider + * @returns {void} + */ const handleSaveSettings = (newModel, newApiKey) => { localStorage.setItem('llm_model', newModel); localStorage.setItem('llm_api_key', newApiKey); @@ -324,6 +384,10 @@ export const ChatTab = ({ selectedRunIds, allRuns: _allRuns }) => { setShowSetup(false); }; + /** + * Opens the settings modal to change LLM configuration + * @returns {void} + */ const handleChangeSettings = () => { setShowSetup(true); }; diff --git a/src/app/components/LineageTab/LineageTab.jsx b/src/app/components/LineageTab/LineageTab.jsx index ee1b2c9..3e51cc0 100644 --- a/src/app/components/LineageTab/LineageTab.jsx +++ b/src/app/components/LineageTab/LineageTab.jsx @@ -8,10 +8,20 @@ import { createArtifactNode } from './lineageNodeFactory'; -// Custom compact node component +/** + * Custom compact node component for displaying run and artifact nodes in the lineage graph + * @param {object} props - Component props + * @param {object} props.data - Node data containing label, hash, and other display information + * @returns {React.ReactElement} Compact node component + */ const CompactNode = ({ data }) => { const [isExpanded, setIsExpanded] = useState(false); + /** + * Handles click events on the node to toggle expansion + * @param {React.MouseEvent} e - Click event + * @returns {void} + */ const handleClick = (e) => { e.stopPropagation(); setIsExpanded(!isExpanded); @@ -69,6 +79,39 @@ const nodeTypes = { compact: CompactNode }; +/** + * Lineage Flow component for visualizing experiment provenance DAG + * + * Interactive directed acyclic graph (DAG) showing data lineage and dependencies + * for selected experiment runs. Powered by ReactFlow for graph rendering. + * + * Features: + * - DAG visualization of run β†’ artifact β†’ run relationships + * - Code/config/environment/dependency hash nodes + * - Artifact nodes (datasets, models, logs) + * - Interactive node expansion (click to show full details) + * - Auto-layout with hierarchical positioning + * - Zoom and pan controls + * - Multi-run lineage merging + * + * Node types: + * - Run nodes: Experiment runs with hash information + * - Artifact nodes: Logged files, datasets, model checkpoints + * - Hash nodes: Code, config, environment, dependencies, platform hashes + * + * Use cases: + * - Trace data provenance (which dataset produced which model?) + * - Find related runs (same code hash = reproducible) + * - Debug dependency issues (visualize environment changes) + * - Understand experiment lineage over time + * + * @param {object} props - Component props + * @param {Array} props.selectedRunIds - Run IDs to visualize + * @param {Array} props.allRuns - All available runs (for hash lookups) + * @param {function} [props.onDatasetSelect] - Callback for dataset node clicks + * @param {function} [props.onArtifactView] - Callback for artifact node clicks + * @returns {React.ReactElement|null} ReactFlow graph or empty state + */ const LineageFlow = ({ selectedRunIds, allRuns, onDatasetSelect, onArtifactView }) => { const [artifactLinksByRun, setArtifactLinksByRun] = useState({}); const [loading, setLoading] = useState(false); @@ -88,6 +131,10 @@ const LineageFlow = ({ selectedRunIds, allRuns, onDatasetSelect, onArtifactView return; } + /** + * Fetches artifact links for all selected runs + * @returns {Promise} + */ const fetchAllArtifactLinks = async () => { setLoading(true); try { @@ -284,7 +331,12 @@ const LineageFlow = ({ selectedRunIds, allRuns, onDatasetSelect, onArtifactView return
Loading provenance...
; } - // Check if node is connected to hovered node + /** + * Check if a node is connected to the hovered node + * @param {string} nodeId - ID of the node to check + * @param {string} hoveredId - ID of the hovered node + * @returns {boolean} True if the node is connected to the hovered node + */ const isNodeConnected = (nodeId, hoveredId) => { if (nodeId === hoveredId) return true; return edges.some(edge => @@ -293,7 +345,12 @@ const LineageFlow = ({ selectedRunIds, allRuns, onDatasetSelect, onArtifactView ); }; - // Check if edge is connected to hovered node + /** + * Check if an edge is connected to the hovered node + * @param {object} edge - Edge object to check + * @param {string} hoveredId - ID of the hovered node + * @returns {boolean} True if the edge is connected to the hovered node + */ const isEdgeConnected = (edge, hoveredId) => { return edge.source === hoveredId || edge.target === hoveredId; }; @@ -318,10 +375,20 @@ const LineageFlow = ({ selectedRunIds, allRuns, onDatasetSelect, onArtifactView } })); + /** + * Handles mouse enter event on a node to enable hover highlighting + * @param {object} event - Mouse event + * @param {object} node - Node that was entered + * @returns {void} + */ const handleNodeMouseEnter = (event, node) => { setHoveredNode(node.id); }; + /** + * Handles mouse leave event on a node to disable hover highlighting + * @returns {void} + */ const handleNodeMouseLeave = () => { setHoveredNode(null); }; @@ -346,6 +413,11 @@ const LineageFlow = ({ selectedRunIds, allRuns, onDatasetSelect, onArtifactView ); }; +/** + * LineageTab component wrapper that provides ReactFlow context + * @param {object} props - Component props passed to LineageFlow + * @returns {React.ReactElement} LineageTab component + */ export const LineageTab = (props) => { return ( diff --git a/src/app/components/LineageTab/lineageNodeFactory.js b/src/app/components/LineageTab/lineageNodeFactory.js index 9a9d676..5510b38 100644 --- a/src/app/components/LineageTab/lineageNodeFactory.js +++ b/src/app/components/LineageTab/lineageNodeFactory.js @@ -1,6 +1,26 @@ /** * Factory functions for creating lineage graph nodes - * Eliminates repetitive node creation code in LineageTab + * + * Provides standardized node creation for ReactFlow lineage visualization. + * Eliminates repetitive node setup code and ensures consistent node structure + * across all node types (runs, artifacts, configs, code, environment, etc.). + * + * Node types supported: + * - run: Experiment run nodes + * - artifact: Generic file artifacts (logs, images, etc.) + * - model: Model checkpoint artifacts + * - dataset: Dataset artifacts + * - config: Configuration hash nodes + * - code: Code hash nodes (git commits) + * - env: Environment variable hash nodes + * - deps: Dependency hash nodes + * + * All nodes share common structure: + * - Compact display mode by default + * - Expandable on click to show full details + * - Hash-based deduplication + * - Color coding for artifacts + * - Click handlers for viewing/selecting */ import { getArtifactColor } from '../../utils/artifactColors'; @@ -8,10 +28,10 @@ import { getArtifactColor } from '../../utils/artifactColors'; /** * Create a lineage node with standard structure * @param {string} type - Node type (config, code, dataset, deps, env, run, model) - * @param {Object} data - Node-specific data - * @param {Object} position - {x, y} position + * @param {object} data - Node-specific data + * @param {object} position - {x, y} position * @param {string} color - Optional custom color override - * @returns {Object} ReactFlow node object + * @returns {object} ReactFlow node object */ const createLineageNode = (type, data, position, color) => { const { id, label, hash, extraInfo = {}, ...rest } = data; @@ -34,6 +54,8 @@ const createLineageNode = (type, data, position, color) => { /** * Format runs list for node display + * @param {Array} runs - Array of run objects + * @returns {string} Comma-separated list of run names/IDs */ const formatRunsList = (runs) => { return runs.map(r => r.name || r.run_id).join(', '); @@ -41,6 +63,11 @@ const formatRunsList = (runs) => { /** * Create config node + * @param {object} configGroup - Configuration group data + * @param {object} position - Node position {x, y} + * @param {(artifact: object) => void} onView - Handler to view config JSON + * @param {object} virtualArtifact - Virtual artifact with JSON content + * @returns {object} ReactFlow node object */ export const createConfigNode = (configGroup, position, onView, virtualArtifact) => { return createLineageNode('config', { @@ -57,6 +84,9 @@ export const createConfigNode = (configGroup, position, onView, virtualArtifact) /** * Create code node + * @param {object} codeGroup - Code group data + * @param {object} position - Node position {x, y} + * @returns {object} ReactFlow node object */ export const createCodeNode = (codeGroup, position) => { return createLineageNode('code', { @@ -73,6 +103,10 @@ export const createCodeNode = (codeGroup, position) => { /** * Create dataset node + * @param {object} datasetGroup - Dataset group data + * @param {object} position - Node position {x, y} + * @param {(dataset: object) => void} onDatasetSelect - Handler for dataset selection + * @returns {object} ReactFlow node object */ export const createDatasetNode = (datasetGroup, position, onDatasetSelect) => { return createLineageNode('dataset', { @@ -92,6 +126,9 @@ export const createDatasetNode = (datasetGroup, position, onDatasetSelect) => { /** * Create dependencies node + * @param {object} depsGroup - Dependencies group data + * @param {object} position - Node position {x, y} + * @returns {object} ReactFlow node object */ export const createDepsNode = (depsGroup, position) => { return createLineageNode('deps', { @@ -106,6 +143,9 @@ export const createDepsNode = (depsGroup, position) => { /** * Create environment node + * @param {object} envGroup - Environment group data + * @param {object} position - Node position {x, y} + * @returns {object} ReactFlow node object */ export const createEnvNode = (envGroup, position) => { return createLineageNode('env', { @@ -120,6 +160,10 @@ export const createEnvNode = (envGroup, position) => { /** * Create run node + * @param {object} run - Run data + * @param {number} idx - Run index + * @param {object} position - Node position {x, y} + * @returns {object} ReactFlow node object */ export const createRunNode = (run, idx, position) => { const runLabel = run.name || `Run ${idx + 1}`; @@ -141,6 +185,11 @@ export const createRunNode = (run, idx, position) => { /** * Create model node + * @param {object} model - Model artifact data + * @param {object} run - Run data (null for input artifacts) + * @param {object} position - Node position {x, y} + * @param {(artifact: object) => void} onArtifactView - Handler to view artifact + * @returns {object} ReactFlow node object */ export const createModelNode = (model, run, position, onArtifactView) => { // If run is null, this is an input artifact (grouped by hash) @@ -168,6 +217,12 @@ export const createModelNode = (model, run, position, onArtifactView) => { * Create generic artifact node * Handles both inputs (run=null) and outputs (run provided) * Color-coded by file extension + * @param {object} artifact - Artifact data + * @param {object} run - Run data (null for input artifacts) + * @param {object} position - Node position {x, y} + * @param {(artifact: object) => void} onArtifactView - Handler to view artifact + * @param {(dataset: object) => void} onDatasetSelect - Handler for dataset selection + * @returns {object} ReactFlow node object */ export const createArtifactNode = (artifact, run, position, onArtifactView, onDatasetSelect) => { // If run is null, this is an input artifact (grouped by hash) diff --git a/src/app/components/ProjectNotesTab/FileAttachmentExtension.js b/src/app/components/ProjectNotesTab/FileAttachmentExtension.js index 3244351..e25c48c 100644 --- a/src/app/components/ProjectNotesTab/FileAttachmentExtension.js +++ b/src/app/components/ProjectNotesTab/FileAttachmentExtension.js @@ -9,6 +9,10 @@ export const FileAttachment = Node.create({ atom: true, + /** + * Defines the attributes for the file attachment node. + * @returns {object} The attributes object containing url, fileName, fileSize, fileType, textContent, and language. + */ addAttributes() { return { url: { default: null }, @@ -20,14 +24,28 @@ export const FileAttachment = Node.create({ }; }, + /** + * Parses HTML to recognize file attachment elements. + * @returns {Array} Array of parsing rules for HTML elements. + */ parseHTML() { return [{ tag: 'div[data-file-attachment]' }]; }, + /** + * Renders the file attachment node as HTML. + * @param {object} options - The rendering options. + * @param {object} options.HTMLAttributes - The HTML attributes to apply to the element. + * @returns {Array} Array containing the tag name and attributes for rendering. + */ renderHTML({ HTMLAttributes }) { return ['div', { 'data-file-attachment': '', ...HTMLAttributes }]; }, + /** + * Adds a custom React node view for the file attachment. + * @returns {object} The React node view renderer for the file attachment component. + */ addNodeView() { return ReactNodeViewRenderer(FileAttachmentComponent); }, diff --git a/src/app/components/ProjectNotesTab/FileAttachmentNode.jsx b/src/app/components/ProjectNotesTab/FileAttachmentNode.jsx index ac6f4f2..40c7e8a 100644 --- a/src/app/components/ProjectNotesTab/FileAttachmentNode.jsx +++ b/src/app/components/ProjectNotesTab/FileAttachmentNode.jsx @@ -3,6 +3,37 @@ import { NodeViewWrapper } from '@tiptap/react'; import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'; import { prism } from 'react-syntax-highlighter/dist/esm/styles/prism'; +/** + * File Attachment Node component for TipTap editor + * + * Custom TipTap node that renders file attachments with appropriate previews + * directly inline in the editor. Used by FileAttachmentExtension. + * + * Supported file types: + * - Code/text files: Syntax-highlighted preview (Python, JS, JSON, etc.) + * - PDFs: Embedded iframe viewer + * - Videos: HTML5 video player (MP4, WEBM) + * - Audio: HTML5 audio player (MP3, WAV) + * - Other files: Download card with file info + * + * Features: + * - Type-specific rendering based on MIME type + * - Syntax highlighting for code files (Prism) + * - Inline preview for media files + * - Download button for all files + * - File size display + * - Responsive layouts + * + * @param {object} props - Component props + * @param {object} props.node - TipTap node object with attrs: + * - url: string - File URL for download/preview + * - fileName: string - Display name + * - fileSize: number - Size in KB + * - fileType: string - MIME type + * - textContent: string (optional) - For code preview + * - language: string (optional) - Syntax highlighting language + * @returns {React.ReactElement} Rendered file attachment with preview + */ export const FileAttachmentComponent = ({ node }) => { const { url, fileName, fileSize, fileType, textContent, language } = node.attrs; diff --git a/src/app/components/ProjectNotesTab/ProjectNotesTab.jsx b/src/app/components/ProjectNotesTab/ProjectNotesTab.jsx index ac04b58..061cd98 100644 --- a/src/app/components/ProjectNotesTab/ProjectNotesTab.jsx +++ b/src/app/components/ProjectNotesTab/ProjectNotesTab.jsx @@ -27,6 +27,13 @@ import { apiClient } from '@/core/api/ApiClient'; import { FileAttachment } from './FileAttachmentExtension'; import './ProjectNotesTab.scss'; +/** + * Renders the editor toolbar with formatting and editing options + * @param {object} props - Component props + * @param {object} props.editor - TipTap editor instance + * @param {(file: File) => void} props.onFileUpload - Callback function to handle file uploads + * @returns {object|null} The menu bar component or null if editor is not available + */ const MenuBar = ({ editor, onFileUpload }) => { const [showTableMenu, setShowTableMenu] = React.useState(false); const [showHeadingMenu, setShowHeadingMenu] = React.useState(false); @@ -36,6 +43,11 @@ const MenuBar = ({ editor, onFileUpload }) => { // Close dropdowns when clicking outside React.useEffect(() => { + /** + * Handles clicks outside dropdown menus to close them + * @param {Event} event - The mouse event + * @returns {void} + */ const handleClickOutside = (event) => { if (tableMenuRef.current && !tableMenuRef.current.contains(event.target)) { setShowTableMenu(false); @@ -58,6 +70,15 @@ const MenuBar = ({ editor, onFileUpload }) => { return null; } + /** + * Renders a toolbar button with active state styling + * @param {object} props - Component props + * @param {() => void} props.onClick - Click handler function + * @param {boolean} props.isActive - Whether the button is in active state + * @param {object} props.children - Button content + * @param {string} props.title - Tooltip text for the button + * @returns {object} The toolbar button component + */ const ToolbarButton = ({ onClick, isActive, children, title }) => ( ); + /** + * Returns the current heading level label based on editor state + * @returns {string} The heading level label (H1-H6) or 'H' if no heading is active + */ const getHeadingLabel = () => { if (editor.isActive('heading', { level: 1 })) return 'H1'; if (editor.isActive('heading', { level: 2 })) return 'H2'; @@ -395,6 +420,41 @@ const MenuBar = ({ editor, onFileUpload }) => { ); }; +/** + * Project Notes Tab component for rich-text note taking with TipTap editor + * + * Full-featured markdown editor for documenting experiments, hypotheses, and findings. + * Supports images, tables, math equations, code blocks, file attachments, and more. + * + * Features: + * - Rich text editing (bold, italic, headings, lists, quotes) + * - Code blocks with syntax highlighting + * - LaTeX math equations (inline and block) + * - Tables with row/column manipulation + * - Task lists with checkboxes + * - Image uploads and embedding + * - File attachments (links to any file) + * - Markdown export + * - Autosave (debounced) + * - Note list sidebar + * - Search/filter notes + * - Delete notes + * + * TipTap extensions used: + * - StarterKit (basic formatting) + * - Image, Link, Table + * - TaskList, TaskItem + * - MathExtension (KaTeX) + * - Custom FileAttachment extension + * + * @param {object} props - Component props + * @param {string} props.projectId - Current project ID + * @param {Array} [props.availableRuns] - Runs in project (unused currently) + * @param {string|null} [props.externalNoteId] - Note ID to load (from sidebar navigation) + * @param {string|null} [props.externalProjectId] - Project ID to load (from sidebar navigation) + * @param {boolean} [props.isCreatingNew=false] - Auto-start creating new note + * @returns {React.ReactElement} TipTap editor with note list sidebar + */ export const ProjectNotesTab = ({ projectId, availableRuns: _availableRuns = [], @@ -462,6 +522,12 @@ export const ProjectNotesTab = ({ }, [isCreatingNew, externalProjectId]); + /** + * Loads a note from the API and populates the editor with its content and attachments + * @param {string} activeProjectId - The ID of the project containing the note + * @param {string} noteId - The ID of the note to load + * @returns {Promise} + */ const loadNote = async (activeProjectId, noteId) => { try { const note = await apiClient.getProjectNote(activeProjectId, noteId); @@ -550,6 +616,10 @@ export const ProjectNotesTab = ({ } }; + /** + * Creates a new project note with the current title and editor content + * @returns {Promise} + */ const createNote = async () => { try { const activeProjectId = externalProjectId || projectId; @@ -569,6 +639,10 @@ export const ProjectNotesTab = ({ } }; + /** + * Updates the currently selected note with the current title and editor content + * @returns {Promise} + */ const updateNote = async () => { try { const activeProjectId = externalProjectId || projectId; @@ -584,6 +658,11 @@ export const ProjectNotesTab = ({ } }; + /** + * Deletes a note after user confirmation + * @param {string} noteId - The ID of the note to delete + * @returns {Promise} + */ const deleteNote = async (noteId) => { // eslint-disable-next-line no-undef if (!confirm('Delete this note? This cannot be undone.')) return; @@ -604,6 +683,10 @@ export const ProjectNotesTab = ({ } }; + /** + * Resets the form by clearing the title and editor content + * @returns {void} + */ const resetForm = () => { setTitle(''); if (editor) { @@ -612,6 +695,11 @@ export const ProjectNotesTab = ({ }; // Helper: Detect language from file extension + /** + * Detects the programming language based on file extension + * @param {string} filename - The filename to analyze + * @returns {string} The detected language identifier or 'text' as fallback + */ const getLanguageFromFilename = (filename) => { const ext = filename.split('.').pop()?.toLowerCase(); const langMap = { @@ -627,6 +715,11 @@ export const ProjectNotesTab = ({ }; // Helper: Check if file is a text/code file + /** + * Determines if a file is a text or code file based on MIME type and extension + * @param {File} file - The file object to check + * @returns {boolean} True if the file is a text/code file, false otherwise + */ const isTextFile = (file) => { const textTypes = ['text/', 'application/json', 'application/xml', 'application/javascript']; const textExtensions = ['.py', '.js', '.jsx', '.ts', '.tsx', '.java', '.cpp', '.c', '.rb', '.go', '.rs', '.php', '.swift', '.kt', '.sql', '.sh', '.bash', '.json', '.xml', '.yaml', '.yml', '.md', '.html', '.css', '.scss', '.txt', '.log']; @@ -635,6 +728,11 @@ export const ProjectNotesTab = ({ textExtensions.some(ext => file.name.toLowerCase().endsWith(ext)); }; + /** + * Handles file uploads by processing and inserting them into the editor + * @param {File} file - The file to upload and insert + * @returns {Promise} + */ const handleFileUpload = async (file) => { try { const isImage = file.type.startsWith('image/'); @@ -651,6 +749,10 @@ export const ProjectNotesTab = ({ if (isText) { // Text/code files: Read as text for syntax highlighting const textReader = new FileReader(); + /** + * Handles successful text file reading and inserts content into editor + * @returns {void} + */ textReader.onload = () => { const textContent = textReader.result; const language = getLanguageFromFilename(file.name); @@ -674,6 +776,10 @@ export const ProjectNotesTab = ({ } else { // Binary files (images, PDFs, videos, audio): Read as data URL const reader = new FileReader(); + /** + * Handles successful binary file reading and inserts content into editor + * @returns {void} + */ reader.onload = () => { const dataUrl = reader.result; diff --git a/src/app/components/ProjectsPanel/ProjectsPanel.jsx b/src/app/components/ProjectsPanel/ProjectsPanel.jsx index 50efb27..645281b 100644 --- a/src/app/components/ProjectsPanel/ProjectsPanel.jsx +++ b/src/app/components/ProjectsPanel/ProjectsPanel.jsx @@ -3,6 +3,36 @@ import { HiPlus, HiDocumentText } from 'react-icons/hi'; import { apiClient } from '@/core/api/ApiClient'; import { SIDEBAR_STYLES } from '@/app/styles/sidebarConstants'; +/** + * Projects Panel component for managing projects and notes + * + * Sidebar panel showing all projects and their associated notes. Projects automatically + * group experiment runs and can have multiple notes attached for documentation. + * + * Features: + * - Project list with run counts + * - Notes list per project + * - Create new notes (+ button) + * - Click note to open in Notes tab + * - Auto-refresh when notes change + * - Implicit projects (auto-created from runs) + * - Explicit projects (manually created) + * + * Project types: + * - Implicit: Auto-created when runs logged to a project_id + * - Explicit: Created via API or UI + * + * @param {object} props - Component props + * @param {Array} props.runs - All experiment runs (for project→run mapping) + * @param {function} props.onNoteSelect - Callback when note clicked + * Signature: (projectId: string, note: object) => void + * Opens note in Notes tab + * @param {function} props.onNewNote - Callback to create new note + * Signature: (projectId: string) => void + * @param {function} props.onTabChange - Callback to switch tabs + * Signature: (tab: string) => void + * @returns {React.ReactElement} Projects and notes sidebar panel + */ export const ProjectsPanel = ({ runs, onNoteSelect, onNewNote, onTabChange }) => { const [projects, setProjects] = useState([]); const [selectedProject, setSelectedProject] = useState(null); @@ -11,6 +41,10 @@ export const ProjectsPanel = ({ runs, onNoteSelect, onNewNote, onTabChange }) => // Fetch all projects (explicit + implicit from runs) useEffect(() => { + /** + * Loads all projects from the API and merges them with runs data + * @returns {Promise} Promise that resolves when projects are loaded + */ const loadProjects = async () => { try { const data = await apiClient.getProjects(); @@ -68,6 +102,10 @@ export const ProjectsPanel = ({ runs, onNoteSelect, onNewNote, onTabChange }) => // Listen for custom event to refresh notes useEffect(() => { + /** + * Handles the refresh event by incrementing the notes version + * @returns {void} + */ const handleRefresh = () => { setNotesVersion(prev => prev + 1); }; @@ -76,6 +114,11 @@ export const ProjectsPanel = ({ runs, onNoteSelect, onNewNote, onTabChange }) => return () => window.removeEventListener('refreshProjectNotes', handleRefresh); }, []); + /** + * Loads notes for a specific project + * @param {string} projectId - The ID of the project to load notes for + * @returns {Promise} Promise that resolves when notes are loaded + */ const loadNotes = async (projectId) => { try { const data = await apiClient.getProjectNotes(projectId); @@ -86,6 +129,11 @@ export const ProjectsPanel = ({ runs, onNoteSelect, onNewNote, onTabChange }) => } }; + /** + * Handles clicking on a project to expand/collapse it + * @param {string} projectId - The ID of the project that was clicked + * @returns {void} + */ const handleProjectClick = (projectId) => { if (selectedProject === projectId) { setSelectedProject(null); @@ -95,18 +143,31 @@ export const ProjectsPanel = ({ runs, onNoteSelect, onNewNote, onTabChange }) => } }; + /** + * Handles clicking on a note to view it + * @param {object} note - The note object that was clicked + * @returns {void} + */ const handleNoteClick = (note) => { // Switch to Notes tab and load the note onTabChange('notes'); onNoteSelect(selectedProject, note); }; + /** + * Handles creating a new note for the selected project + * @returns {void} + */ const handleCreateNote = () => { // Switch to Notes tab and create new note onTabChange('notes'); onNewNote(selectedProject); }; + /** + * Handles creating a new project with a user-provided name + * @returns {Promise} Promise that resolves when project is created + */ const handleCreateProject = async () => { const projectName = window.prompt('Enter project name:'); if (!projectName) return; diff --git a/src/app/components/RunSelector/CollapsibleButton.jsx b/src/app/components/RunSelector/CollapsibleButton.jsx index 3d535c6..e653130 100644 --- a/src/app/components/RunSelector/CollapsibleButton.jsx +++ b/src/app/components/RunSelector/CollapsibleButton.jsx @@ -2,8 +2,23 @@ import React from 'react'; import { HiChevronDown, HiChevronUp } from 'react-icons/hi'; /** - * Reusable collapsible toggle button component - * Used to expand/collapse table sections in RunSelector + * Collapsible toggle button for expand/collapse interactions + * + * Reusable UI button used throughout the app for collapsing sections. + * Primarily used in RunSelector for collapsing run detail tables. + * + * Features: + * - Animated chevron icon (up/down) + * - Hover effects (color change, shadow) + * - Click animation (scale down on press) + * - Tooltip showing current state + * - Circular design with subtle shadow + * + * @param {object} props - Component props + * @param {boolean} props.isCollapsed - Current collapsed state + * @param {function} props.onClick - Click handler to toggle state + * @param {string} [props.title] - Custom tooltip (defaults to "Expand/Collapse table") + * @returns {React.ReactElement} Circular toggle button with chevron icon */ export const CollapsibleButton = ({ isCollapsed, onClick, title }) => { return ( diff --git a/src/app/components/RunSelector/RunFilter.jsx b/src/app/components/RunSelector/RunFilter.jsx index b5cb826..be8cacd 100644 --- a/src/app/components/RunSelector/RunFilter.jsx +++ b/src/app/components/RunSelector/RunFilter.jsx @@ -3,19 +3,37 @@ import './RunFilter.scss'; import { getMetricValue } from '@/app/utils/metricAggregation'; /** - * Dynamic Run Filter Component - * Inspired by Weights & Biases filtering system + * Run Filter component for dynamic filtering of experiment runs + * + * W&B-inspired filtering system that auto-discovers filter options from run data. + * No hardcoded metric names - adapts to whatever metrics users log. * * Features: - * - Auto-discovers filter options from run data (no hardcoding!) - * - Search by run name/ID - * - Filter by status (Running/Completed) - * - Filter by metric thresholds (min/max) + * - Text search (run name, run ID) + * - Status filtering (Running, Completed, Failed) + * - Metric threshold filtering (min/max ranges per metric) + * - Auto-discovery of available metrics across all runs + * - Collapsible filter panel + * - Real-time filtering (updates as user types) + * - Respects aggregation mode for metric values + * + * Filter types: + * 1. Search: Fuzzy match on run name or ID + * 2. Status: Running (in progress) vs Completed/Failed + * 3. Metrics: Min/max range filters for any logged metric + * + * Architecture: + * - Stateless filtering (pure function applied to runs) + * - Callback-based (notifies parent of filtered results) + * - Metric discovery via structured_data inspection * - * @param {Array} runs - All available runs - * @param {Function} onFilterChange - Callback with filtered runs - * @param {String} aggregationMode - Current aggregation mode (min/max/final) - * @param {String} optimizeMetric - Metric to optimize for min/max modes + * @param {object} props - Component props + * @param {Array} [props.runs=[]] - All available runs to filter + * @param {function} props.onFilterChange - Callback with filtered results + * Signature: (filteredRuns: Array) => void + * @param {string} [props.aggregationMode='min'] - How to aggregate metrics ('min'/'max'/'final') + * @param {string} [props.optimizeMetric='loss'] - Metric to optimize for min/max modes + * @returns {React.ReactElement} Collapsible filter panel */ export const RunFilter = ({ runs = [], onFilterChange, aggregationMode = 'min', optimizeMetric = 'loss' }) => { const [searchQuery, setSearchQuery] = useState(''); @@ -59,7 +77,11 @@ export const RunFilter = ({ runs = [], onFilterChange, aggregationMode = 'min', return Array.from(statuses); }, [runs]); - // Format metric name for display + /** + * Format metric name for display + * @param {string} metricKey - The metric key to format + * @returns {string} The formatted metric name + */ const formatMetricName = (metricKey) => { return metricKey .split('_') @@ -114,6 +136,13 @@ export const RunFilter = ({ runs = [], onFilterChange, aggregationMode = 'min', onFilterChange(filteredRuns); }, [filteredRuns, onFilterChange]); + /** + * Handle metric filter change + * @param {string} metricKey - The metric key to filter + * @param {string} type - The type of threshold (min or max) + * @param {string} value - The threshold value + * @returns {void} + */ const handleMetricFilterChange = (metricKey, type, value) => { setMetricFilters(prev => ({ ...prev, @@ -124,6 +153,11 @@ export const RunFilter = ({ runs = [], onFilterChange, aggregationMode = 'min', })); }; + /** + * Clear a specific metric filter + * @param {string} metricKey - The metric key to clear + * @returns {void} + */ const clearMetricFilter = (metricKey) => { setMetricFilters(prev => { const newFilters = { ...prev }; @@ -132,6 +166,10 @@ export const RunFilter = ({ runs = [], onFilterChange, aggregationMode = 'min', }); }; + /** + * Clear all filters and reset to default state + * @returns {void} + */ const clearAllFilters = () => { setSearchQuery(''); setStatusFilter('all'); diff --git a/src/app/components/RunSelector/RunSelector.jsx b/src/app/components/RunSelector/RunSelector.jsx index 6462320..5795cca 100644 --- a/src/app/components/RunSelector/RunSelector.jsx +++ b/src/app/components/RunSelector/RunSelector.jsx @@ -5,7 +5,12 @@ import { exportSeriesGroupsAsCSV } from '@/core/utils/csvExport'; import { discoverMetricsByStream, getMetricValue } from '@/app/utils/metricAggregation'; import './RunSelector.scss'; -// Memoized config display to prevent flickering on re-renders +/** + * Memoized config display to prevent flickering on re-renders + * @param {object} props - Component props + * @param {object} props.config - Configuration object to display + * @returns {React.ReactElement} Config display element + */ const ConfigDisplay = React.memo(({ config }) => (
{JSON.stringify(config, null, 2)}
), (prevProps, nextProps) => { @@ -14,10 +19,33 @@ const ConfigDisplay = React.memo(({ config }) => ( }); /** - * Run Selector Component + * Run Selector component for browsing, filtering, and selecting experiment runs + * + * Main interface for run management in the sidebar. Displays all runs with their + * configurations, status, and metrics. Supports filtering, sorting, and multi-selection. + * + * Features: + * - List view of all experiment runs + * - Multi-select with checkboxes + * - Run filtering (search by name, status, tags) + * - Expandable config/metadata display for each run + * - CSV export of selected runs' metrics + * - Run count display + * - Select all/none shortcuts + * - Collapsible run details + * + * Architecture: + * - Controlled component (selection state managed by parent) + * - Receives runs via props (avoids duplicate polling) + * - Delegates filtering to RunFilter sub-component * - * Displays and manages run selection from the database. - * Supports filtering, sorting, and multi-selection of runs. + * @param {object} props - Component props + * @param {Array} [props.selectedRunIds=[]] - Currently selected run IDs (controlled) + * @param {function} [props.onRunSelectionChange] - Callback when selection changes + * Signature: (selectedRunIds: Array) => void + * @param {Array} [props.runs=[]] - All available runs from database + * @param {boolean} [props.runsLoading=false] - Whether runs are currently loading + * @returns {React.ReactElement} Run selector interface */ export const RunSelector = ({ selectedRunIds = [], // NEW: controlled from parent @@ -35,13 +63,21 @@ export const RunSelector = ({ // Per-table filtered runs: tableKey -> filtered runs array const [tableFilteredRuns, setTableFilteredRuns] = useState({}); - // Export all visible tables as CSV + /** + * Exports all visible tables as CSV + * @returns {void} + */ const handleExportAllTables = () => { const selectedRuns = runs.filter(r => selectedRunIds.includes(r.run_id)); const metricsByStream = discoverMetricsByStream(selectedRuns); exportSeriesGroupsAsCSV(selectedRuns, metricsByStream, streamAggregation, getMetricValue); }; + /** + * Toggles selection state of a run + * @param {string} runId - ID of the run to toggle + * @returns {void} + */ const toggleRunSelection = (runId) => { const newSelection = selectedRunIds.includes(runId) ? selectedRunIds.filter(id => id !== runId) @@ -49,6 +85,11 @@ export const RunSelector = ({ onRunSelectionChange(newSelection); }; + /** + * Toggles expanded state of a run to show/hide config + * @param {string} runId - ID of the run to toggle + * @returns {void} + */ const toggleExpandRun = (runId) => { setExpandedRunIds(prev => { const newSet = new Set(prev); @@ -61,6 +102,11 @@ export const RunSelector = ({ }); }; + /** + * Toggles expanded state of a table + * @param {string} tableKey - Key of the table to toggle + * @returns {void} + */ const toggleTableExpand = (tableKey) => { setExpandedTables(prev => { const newSet = new Set(prev); @@ -114,11 +160,22 @@ export const RunSelector = ({ const visibleParamKeys = Array.from(allParamKeys).sort(); const visibleTagKeys = Array.from(allTagKeys).sort(); + /** + * Formats a metric value for display + * @param {string} key - Metric key + * @param {string|number|null|undefined} value - Metric value to format + * @returns {string} Formatted value string + */ const formatMetricValue = (key, value) => { if (value === null || value === undefined) return '-'; return typeof value === 'number' ? value.toFixed(4) : String(value); }; + /** + * Converts a key to a display name with capitalized words + * @param {string} key - Key to convert + * @returns {string} Display name + */ const getDisplayName = (key) => { return key.split('_').map(word => word.charAt(0).toUpperCase() + word.slice(1) diff --git a/src/app/components/RunTree/RunTree.jsx b/src/app/components/RunTree/RunTree.jsx index 4b036a8..8f04fe7 100644 --- a/src/app/components/RunTree/RunTree.jsx +++ b/src/app/components/RunTree/RunTree.jsx @@ -2,8 +2,11 @@ import React, { useState, useMemo } from 'react'; import { SIDEBAR_STYLES } from '@/app/styles/sidebarConstants'; /** - * Hierarchy definition: order of hash keys for grouping (most β†’ least important) - * This defines the tree structure from root to leaves + * Hierarchy definition for run tree grouping + * + * Defines the order of hash-based grouping from most to least important. + * Runs are grouped by code changes first, then config, then environment, etc. + * This creates a deterministic tree structure for reproducibility tracking. */ const HASH_HIERARCHY = [ { key: 'hash.code', label: 'Code' }, @@ -14,14 +17,21 @@ const HASH_HIERARCHY = [ ]; /** - * Generic recursive tree builder with auto-collapsing of redundant levels - * Groups runs hierarchically by hash keys in specified order - * Skips levels where all runs have the same value (redundant grouping) + * Recursive tree builder with auto-collapsing of redundant levels + * + * Groups runs hierarchically by hash values in the specified order. + * Automatically skips levels where all runs have identical values (no point grouping). * - * @param {Array} runs - Array of run objects - * @param {Array} hierarchy - Array of hash keys to group by (in order) - * @param {number} depth - Current depth in tree (0 = root) - * @returns {Array} Tree nodes with structure: { hashKey, hashValue, label, runs, children } + * Algorithm: + * 1. At each level, group runs by current hash key value + * 2. If only one unique value exists, skip this level (redundant) + * 3. Otherwise, create nodes for each unique value + * 4. Recursively build children at next hierarchy level + * + * @param {Array} runs - Runs to organize + * @param {Array} [hierarchy=HASH_HIERARCHY] - Hash keys to group by + * @param {number} [depth=0] - Current recursion depth + * @returns {Array} Tree nodes: { hashKey, hashValue, label, runs, children } */ function buildHashTree(runs, hierarchy = HASH_HIERARCHY, depth = 0) { // Base case: no more hierarchy levels or no runs @@ -80,6 +90,14 @@ function buildHashTree(runs, hierarchy = HASH_HIERARCHY, depth = 0) { /** * Recursive tree node renderer + * @param {object} props - Component props + * @param {object} props.node - Tree node data + * @param {Array} props.allRuns - All runs in the tree + * @param {Array} props.selectedRunIds - Array of selected run IDs + * @param {(runIds: Array) => void} props.onRunSelectionChange - Callback for selection changes + * @param {Set} props.collapsedNodes - Set of collapsed node IDs + * @param {(nodeId: string) => void} props.toggleNode - Callback to toggle node collapse state + * @returns {React.ReactElement} Rendered tree node */ function TreeNode({ node, allRuns, selectedRunIds, onRunSelectionChange, collapsedNodes, toggleNode }) { @@ -88,6 +106,11 @@ function TreeNode({ node, allRuns, selectedRunIds, onRunSelectionChange, collaps const isCollapsed = collapsedNodes.has(nodeId); const isNodeSelected = node.runs?.some(r => selectedRunIds.includes(r.run_id)) ?? false; + /** + * Handles click event on a run item to toggle its selection + * @param {string} runId - ID of the run to toggle + * @returns {void} + */ const handleRunClick = (runId) => { const isSelected = selectedRunIds.includes(runId); const newSelection = isSelected @@ -223,7 +246,40 @@ function TreeNode({ node, allRuns, selectedRunIds, onRunSelectionChange, collaps } /** - * Git-tree style run browser with multi-level hash-based grouping + * Run Tree component for hierarchical run organization and selection + * + * Git-style tree browser that groups experiment runs by reproducibility hashes. + * Helps identify which runs were conducted under identical conditions (code, config, + * environment, dependencies, platform). + * + * Features: + * - Multi-level hash-based grouping (code β†’ config β†’ env β†’ deps β†’ platform) + * - Auto-collapse redundant levels (skips levels where all runs are identical) + * - Collapsible tree nodes + * - Multi-select with checkboxes (supports selecting entire groups) + * - Visual hierarchy with indentation + * - Shows run counts per group + * - Newest-first sorting + * + * Use cases: + * - Finding reproducible runs (same code + config = should get same results) + * - Identifying what changed between runs + * - Bulk-selecting runs from same experiment batch + * - Debugging environment/dependency issues + * + * Hash hierarchy (top to bottom): + * 1. Code - Git commit hash or code content hash + * 2. Config - Hyperparameter configuration hash + * 3. Environment - Environment variables hash + * 4. Dependencies - Package versions hash + * 5. Platform - OS/hardware hash + * + * @param {object} props - Component props + * @param {Array} props.runs - Experiment runs to organize + * @param {Array} props.selectedRunIds - Currently selected run IDs + * @param {function} props.onRunSelectionChange - Callback when selection changes + * Signature: (runIds: Array) => void + * @returns {React.ReactElement} Hierarchical tree with selectable runs */ export function RunTree({ runs, selectedRunIds, onRunSelectionChange }) { // Track which nodes are collapsed (by nodeId: "hashKey:hashValue") @@ -235,6 +291,11 @@ export function RunTree({ runs, selectedRunIds, onRunSelectionChange }) { return result; }, [runs]); + /** + * Toggles the collapsed state of a tree node + * @param {string} nodeId - ID of the node to toggle + * @returns {void} + */ const toggleNode = (nodeId) => { setCollapsedNodes(prev => { const next = new Set(prev); diff --git a/src/app/components/SweepsTab/SweepsTab.jsx b/src/app/components/SweepsTab/SweepsTab.jsx index a9dda5e..36ceee5 100644 --- a/src/app/components/SweepsTab/SweepsTab.jsx +++ b/src/app/components/SweepsTab/SweepsTab.jsx @@ -5,7 +5,7 @@ import ParameterCorrelationChart from '@/app/components/visualizations/plots/Par import ScatterPlot from '@/app/components/visualizations/plots/ScatterPlot'; import { DraggableVisualization } from '@/app/components/visualizations/DraggableVisualization'; import { useLayoutManager } from '@/app/hooks'; -import { calculateParameterImportance } from '@/app/utils/comparisonPlotDiscovery'; +import { calculateParameterCorrelation } from '@/app/utils/comparisonPlotDiscovery'; import './SweepsTab.scss'; /** @@ -16,6 +16,10 @@ import './SweepsTab.scss'; * - Scatter Plots: Swept parameter vs target metric * * Only renders if selected runs form a valid sweep (same keys, one varying param) + * @param {object} props - Component props + * @param {Array} props.runs - Array of run objects + * @param {Array} props.selectedRunIds - Array of selected run IDs + * @returns {object|null} The rendered sweeps tab or null if invalid */ export const SweepsTab = ({ runs, selectedRunIds }) => { const [selectedMetric, setSelectedMetric] = useState(null); @@ -37,6 +41,11 @@ export const SweepsTab = ({ runs, selectedRunIds }) => { updateLabels } = useLayoutManager(); + /** + * Toggles the visibility of a visualization by its key + * @param {string} vizKey - The unique key of the visualization to toggle + * @returns {void} + */ const toggleVisualizationVisibility = (vizKey) => { setHiddenVisualizations(prev => { const next = new Set(prev); @@ -66,12 +75,12 @@ export const SweepsTab = ({ runs, selectedRunIds }) => { if (!sweepData?.valid || !runs || selectedRunIds.length < 3) { return null; } - // Get original runs (not the transformed sweep runs) - calculateParameterImportance needs full run structure + // Get original runs (not the transformed sweep runs) - calculateParameterCorrelation needs full run structure const originalRuns = runs.filter(r => selectedRunIds.includes(r.run_id)); // Calculate correlation for all varying parameters const varyingParamNames = sweepData.varyingParams.map(p => p.name); - return calculateParameterImportance( + return calculateParameterCorrelation( originalRuns, varyingParamNames, sweepData.availableMetrics, diff --git a/src/app/components/layout/Sidebar.jsx b/src/app/components/layout/Sidebar.jsx index f31c7f7..8fdb57f 100644 --- a/src/app/components/layout/Sidebar.jsx +++ b/src/app/components/layout/Sidebar.jsx @@ -2,6 +2,34 @@ import React, { useState, useRef, useEffect } from 'react'; import { motion, AnimatePresence } from 'framer-motion'; import './Sidebar.scss'; +/** + * Sidebar component for navigation and run selection + * + * Collapsible and resizable sidebar containing run browser, filters, and navigation. + * Primary interface for selecting which experiment runs to analyze. + * + * Features: + * - Collapsible (hide/show with toggle button) + * - Resizable width (drag right edge) + * - Backend status indicator + * - Smooth collapse animation (Framer Motion) + * - Persistent width across sessions + * + * Contains: + * - Run selector (tree or list view) + * - Run filters + * - Project selector + * - Backend connection status + * + * @param {object} props - Component props + * @param {React.ReactNode} props.children - Sidebar content (RunSelector, filters, etc.) + * @param {object} props.backendStatus - Backend connection status: + * - connected: boolean + * - message: string (optional) + * @param {boolean} props.isCollapsed - Whether sidebar is currently collapsed + * @param {function} props.onToggle - Callback to toggle collapse state + * @returns {React.ReactElement} Collapsible, resizable sidebar + */ export const Sidebar = ({ children, backendStatus, isCollapsed, onToggle }) => { const [sidebarWidth, setSidebarWidth] = useState(267); const [isResizing, setIsResizing] = useState(false); @@ -10,6 +38,11 @@ export const Sidebar = ({ children, backendStatus, isCollapsed, onToggle }) => { useEffect(() => { if (!isResizing) return; + /** + * Handles mouse movement during sidebar resize + * @param {Event} e - Mouse event object + * @returns {void} + */ const handleMouseMove = (e) => { e.preventDefault(); const newWidth = e.clientX; @@ -17,6 +50,10 @@ export const Sidebar = ({ children, backendStatus, isCollapsed, onToggle }) => { setSidebarWidth(Math.max(0, newWidth)); }; + /** + * Handles mouse up event to end sidebar resize + * @returns {void} + */ const handleMouseUp = () => { setIsResizing(false); document.body.classList.remove('resizing-sidebar'); diff --git a/src/app/components/layout/TabbedInterface/TabbedInterface.jsx b/src/app/components/layout/TabbedInterface/TabbedInterface.jsx index 494e476..9b53dc5 100644 --- a/src/app/components/layout/TabbedInterface/TabbedInterface.jsx +++ b/src/app/components/layout/TabbedInterface/TabbedInterface.jsx @@ -2,12 +2,38 @@ import React, { useEffect } from 'react'; import './TabbedInterface.scss'; /** - * Professional tabbed interface with fixed tabs + * Tabbed Interface component for main content area navigation * - * @param {Array} tabs - Array of tab objects with { id, label, content } - * @param {Array} visibleTabs - Array of visible tab IDs - * @param {string} activeTab - Currently active tab ID - * @param {Function} onTabChange - Callback when active tab changes + * Tab bar for switching between different analysis views (Plots, Tables, Sweeps, etc.). + * Handles tab visibility and ensures a valid tab is always active. + * + * Features: + * - Dynamic tab visibility (tabs can be shown/hidden via View menu) + * - Active tab highlighting + * - Auto-switches to first visible tab if active tab hidden + * - Empty state when no tabs visible + * - Fixed tab bar at top + * - Content area below tabs + * + * Typical tabs: + * - Plots: Auto-discovered visualizations + * - Tables: Metric comparison tables + * - Sweeps: Hyperparameter sweep analysis + * - Lineage: Provenance graph + * - Artifacts: File browser + * - Chat: LLM experiment analysis + * - Notes: Rich-text note taking + * + * @param {object} props - Component props + * @param {Array} [props.tabs=[]] - All available tabs: + * - id: string - Unique tab identifier + * - label: string - Display name + * - content: React.ReactNode - Tab content component + * @param {Array} [props.visibleTabs=[]] - IDs of tabs to show + * @param {string} props.activeTab - Currently active tab ID + * @param {function} props.onTabChange - Callback when user switches tabs + * Signature: (tabId: string) => void + * @returns {React.ReactElement} Tab bar with content area */ export const TabbedInterface = ({ tabs = [], diff --git a/src/app/components/ui/ComponentSettingsMenu.jsx b/src/app/components/ui/ComponentSettingsMenu.jsx index c167463..e6234e7 100644 --- a/src/app/components/ui/ComponentSettingsMenu.jsx +++ b/src/app/components/ui/ComponentSettingsMenu.jsx @@ -3,6 +3,29 @@ import { motion, AnimatePresence } from 'framer-motion'; import { HiDotsVertical } from 'react-icons/hi'; import './ComponentSettingsMenu.scss'; +/** + * Component Settings Menu for customizing visualization labels + * + * Dropdown menu (gear icon) that allows users to edit plot titles and axis labels + * in real-time. Integrated into DraggableVisualization header. + * + * Features: + * - Inline editing (click to edit fields) + * - Real-time updates (changes apply immediately) + * - Animated dropdown (Framer Motion) + * - Auto-detects available fields (title, X/Y axis labels) + * - Only shows relevant fields for each visualization + * + * @param {object} props - Component props + * @param {string} props.visualizationKey - Unique ID of visualization being customized + * @param {object} [props.currentLabels={}] - Current label values: + * - title: string (optional) + * - xAxisLabel: string (optional) + * - yAxisLabel: string (optional) + * @param {function} props.onUpdateLabels - Callback to save label changes + * Signature: (visualizationKey: string, fieldKey: string, value: string) => void + * @returns {React.ReactElement|null} Settings dropdown or null if no editable fields + */ export function ComponentSettingsMenu({ visualizationKey, currentLabels = {}, @@ -22,11 +45,22 @@ export function ComponentSettingsMenu({ // Only show fields that exist in the initial currentLabels object const fields = allFields.filter(field => field.key in currentLabels); + /** + * Handles entering edit mode for a specific field + * @param {string} fieldKey - The key of the field being edited + * @param {string} currentValue - The current value of the field + * @returns {void} + */ const handleEdit = (fieldKey, currentValue) => { setEditingField(fieldKey); setTempValue(currentValue || ''); }; + /** + * Handles changes to the input field and updates labels in real-time + * @param {object} e - The change event from the input + * @returns {void} + */ const handleChange = (e) => { const newValue = e.target.value; setTempValue(newValue); @@ -36,11 +70,20 @@ export function ComponentSettingsMenu({ } }; + /** + * Handles exiting edit mode when the input loses focus + * @returns {void} + */ const handleBlur = () => { setEditingField(null); setTempValue(''); }; + /** + * Handles keyboard events to exit edit mode on Enter or Escape + * @param {object} e - The keyboard event + * @returns {void} + */ const handleKeyDown = (e) => { if (e.key === 'Enter' || e.key === 'Escape') { e.target.blur(); // Exit edit mode diff --git a/src/app/components/visualizations/DraggableVisualization.jsx b/src/app/components/visualizations/DraggableVisualization.jsx index 2a1b169..7e2c297 100644 --- a/src/app/components/visualizations/DraggableVisualization.jsx +++ b/src/app/components/visualizations/DraggableVisualization.jsx @@ -3,9 +3,45 @@ import { motion } from 'framer-motion'; import './DraggableVisualization.scss'; /** - * Transform-based draggable visualization wrapper - * Uses CSS transforms for dragging - NO absolute positioning mode - * Elements always stay in flexbox flow + * Draggable visualization wrapper using transform-based drag system + * + * Wraps plot components to make them draggable and resizable while staying in flexbox flow. + * Uses CSS transforms for smooth 60fps dragging without layout recalculation. + * + * Architecture: + * - Components remain in normal document flow (NOT position: absolute) + * - During drag: Apply CSS transform for visual feedback (GPU-accelerated) + * - After drag: Clear transform, optionally reorder DOM + * - Resize: Uses Framer Motion drag handles + * + * Features: + * - Smooth dragging with transform-based animation + * - Resize handles (corners and edges) + * - Close button for removing visualizations + * - Header with title + * - Automatic z-index management (dragged element on top) + * - Integration with layout manager + * + * Why transforms instead of absolute positioning: + * - No layout thrashing (no offsetTop/offsetLeft calculations) + * - GPU-accelerated (compositing layer) + * - Maintains flexbox benefits (responsive, auto-sizing) + * - Easier to revert (just remove transform) + * + * @param {object} props - Component props + * @param {string} props.visualizationKey - Unique identifier for this visualization + * @param {string} props.title - Display title in header + * @param {function} props.onClose - Callback when close button clicked + * @param {function} props.onDragStart - Callback when drag begins (key) + * @param {function} props.onDrag - Callback during drag (key, deltaX, deltaY) + * @param {function} props.onDragEnd - Callback when drag ends (key) + * @param {function} props.onResize - Callback when resized (key, width, height) + * @param {object} [props.dragTransform] - CSS transform to apply: { x, y } + * @param {boolean} [props.isDragging] - Whether currently being dragged + * @param {function} props.registerElement - Register with layout manager (key, element, type) + * @param {string} props.chartType - Type of chart for registration + * @param {React.ReactNode} props.children - Plot component to wrap + * @returns {React.ReactElement} Draggable visualization wrapper */ export function DraggableVisualization({ visualizationKey, @@ -36,12 +72,22 @@ export function DraggableVisualization({ } }, [registerElement, visualizationKey, chartType]); + /** + * Handles the start of a drag operation + * @returns {void} + */ const handleDragStart = () => { if (onDragStart) { onDragStart(visualizationKey); } }; + /** + * Handles drag movement + * @param {object} _event - The drag event (unused) + * @param {object} info - Information about the drag state + * @returns {void} + */ const handleDrag = (_event, info) => { if (onDrag) { // Pass delta from start of drag @@ -49,6 +95,10 @@ export function DraggableVisualization({ } }; + /** + * Handles the end of a drag operation + * @returns {void} + */ const handleDragEnd = () => { if (onDragEnd) { onDragEnd(visualizationKey); @@ -56,6 +106,11 @@ export function DraggableVisualization({ }; // Edge detection for resize + /** + * Detects which edge of the element the mouse is near + * @param {object} e - The mouse event + * @returns {string|null} The edge identifier (n, s, e, w, ne, nw, se, sw) or null + */ const getEdgeFromMouse = (e) => { if (!elementRef.current) return null; @@ -79,6 +134,11 @@ export function DraggableVisualization({ return null; }; + /** + * Gets the appropriate cursor style for a given edge + * @param {string} edge - The edge identifier + * @returns {string} The CSS cursor value + */ const getCursorForEdge = (edge) => { const cursors = { n: 'ns-resize', @@ -93,6 +153,11 @@ export function DraggableVisualization({ return cursors[edge] || 'default'; }; + /** + * Handles mouse movement to update cursor style based on edge proximity + * @param {object} e - The mouse event + * @returns {void} + */ const handleMouseMove = (e) => { if (isResizing || isDragging) return; @@ -100,6 +165,11 @@ export function DraggableVisualization({ setCursorStyle(edge ? getCursorForEdge(edge) : 'grab'); }; + /** + * Handles mouse down to initiate resizing + * @param {object} e - The mouse event + * @returns {void} + */ const handleMouseDown = (e) => { // Don't handle if clicking on close button or interactive elements if (e.target.closest('.viz-viz-close-btn, select, button, input, textarea, a')) { @@ -135,6 +205,11 @@ export function DraggableVisualization({ React.useEffect(() => { if (!isResizing) return; + /** + * Handles mouse movement during resize operation + * @param {object} e - The mouse event + * @returns {void} + */ const handleResizeMove = (e) => { if (!resizeStartRef.current) return; @@ -170,6 +245,10 @@ export function DraggableVisualization({ } }; + /** + * Handles the end of a resize operation + * @returns {void} + */ const handleResizeEnd = () => { setIsResizing(false); resizeStartRef.current = null; diff --git a/src/app/components/visualizations/UniversalVisualizationRenderer.jsx b/src/app/components/visualizations/UniversalVisualizationRenderer.jsx index 3372c41..accfb66 100644 --- a/src/app/components/visualizations/UniversalVisualizationRenderer.jsx +++ b/src/app/components/visualizations/UniversalVisualizationRenderer.jsx @@ -8,8 +8,35 @@ import ScatterPlot from './plots/ScatterPlot'; import BarChart from './plots/BarChart'; /** - * Universal Visualization Renderer - * Routes plot configs to appropriate components based on type + * Universal Visualization Renderer for plot type dispatch + * + * Routes plot configurations to the appropriate visualization component based on type. + * Acts as a centralized dispatcher that maps primitive types from plot discovery + * to their corresponding React components. + * + * Supported plot types: + * - line: Line plots (time series, multi-series) + * - scatter: Scatter plots (2D point clouds) + * - heatmap: Heat maps (2D matrices, confusion matrices) + * - barchart: Bar charts (categorical comparisons) + * - histogram: Histograms (distributions) + * - violin: Violin plots (distribution comparisons) + * - curve: ROC/PR curves with metrics + * - table: Data tables (handled separately in Tables tab) + * + * Architecture: + * - Receives plotConfig from plot discovery system + * - Type-based switch dispatch to appropriate component + * - Forwards data, metadata, and title to child components + * - Returns null for unsupported types or missing configs + * + * @param {object} props - Component props + * @param {object} props.plotConfig - Plot configuration from discovery: + * - type: string - Plot type identifier + * - data: object - Plot-specific data structure + * - metadata: object (optional) - Additional plot metadata + * - title: string (optional) - Plot title + * @returns {React.ReactElement|null} Rendered visualization component or error message */ const UniversalVisualizationRenderer = ({ plotConfig }) => { if (!plotConfig) return null; diff --git a/src/app/components/visualizations/plots/BarChart.jsx b/src/app/components/visualizations/plots/BarChart.jsx index 0a176a3..a044545 100644 --- a/src/app/components/visualizations/plots/BarChart.jsx +++ b/src/app/components/visualizations/plots/BarChart.jsx @@ -4,12 +4,46 @@ import { useResponsiveCanvas } from '../../../hooks/useResponsiveCanvas'; import { getChartColor } from '../../../../core/utils/constants'; /** - * Bar Chart - * Displays categorical data with grouped or stacked bars + * Bar Chart component for categorical data comparison + * + * Renders grouped bar charts for comparing metrics across categories. + * Used for model comparisons, feature importance, and categorical distributions. + * + * Features: + * - Grouped bars (multiple series per category) + * - Auto-scaled Y-axis + * - Rotated X-axis labels for long category names + * - Legend for series identification + * - HiDPI display support + * + * Data format: + * ``` + * { + * categories: ["Model A", "Model B", "Model C"], + * groups: [ + * { label: "Accuracy", values: [0.85, 0.92, 0.88] }, + * { label: "F1 Score", values: [0.83, 0.90, 0.86] } + * ] + * } + * ``` + * + * @param {object} props - Component props + * @param {object} props.data - Bar chart data + * @param {Array} props.data.categories - Category labels (X-axis) + * @param {Array} props.data.groups - Data groups with labels and values + * @param {string} [props.title] - Chart title (handled by wrapper) + * @param {object} [props.metadata] - Additional metadata + * @returns {React.ReactElement} Canvas-based bar chart */ const BarChart = ({ data, title: _title, metadata: _metadata }) => { const canvasRef = useRef(null); + /** + * Draws the bar chart on the canvas + * @param {number} width - Canvas width + * @param {number} height - Canvas height + * @returns {void} + */ const drawBarChart = useCallback((width, height) => { if (!data) return; @@ -105,6 +139,24 @@ const BarChart = ({ data, title: _title, metadata: _metadata }) => { ); }; +/** + * Draws vertical bars on the canvas for bar chart visualization + * @param {object} ctx - Canvas 2D rendering context + * @param {Array} categories - Category labels for x-axis + * @param {object} groups - Group data with values for each category + * @param {Array} groupNames - Names of the groups + * @param {boolean} stacked - Whether bars should be stacked or grouped + * @param {object} padding - Padding object with top, right, bottom, left properties + * @param {number} plotWidth - Width of the plot area + * @param {number} plotHeight - Height of the plot area + * @param {number} canvasWidth - Total canvas width + * @param {number} canvasHeight - Total canvas height + * @param {number} minValue - Minimum value for y-axis scale + * @param {number} maxValue - Maximum value for y-axis scale + * @param {string} xLabel - Label for x-axis + * @param {string} yLabel - Label for y-axis + * @returns {void} + */ function drawVerticalBars(ctx, categories, groups, groupNames, stacked, padding, plotWidth, plotHeight, canvasWidth, canvasHeight, minValue, maxValue, xLabel, yLabel) { const numCategories = categories.length; const numGroups = groupNames.length; diff --git a/src/app/components/visualizations/plots/CurveChart.jsx b/src/app/components/visualizations/plots/CurveChart.jsx index 640c1da..a650a98 100644 --- a/src/app/components/visualizations/plots/CurveChart.jsx +++ b/src/app/components/visualizations/plots/CurveChart.jsx @@ -7,9 +7,54 @@ import { toTitleCase } from '@/core/utils/formatters'; import { getChartColor, CHART_PADDING } from '@/core/utils/constants'; /** - * Generic Curve Chart - * Displays any array of objects with numeric fields as an X-Y curve - * Auto-detects axis fields from data structure + * Curve Chart component for performance curves (ROC, PR, calibration) + * + * Generic field-agnostic curve visualization that auto-detects axes from data. + * Commonly used for ROC curves (with AUC), precision-recall curves, and calibration plots. + * + * Features: + * - Auto-detects X/Y field names from data structure + * - Multi-curve overlay with different colors + * - Optional diagonal reference line (ROC curves) + * - Metric display (e.g., AUC = 0.95) + * - Interactive tooltips + * - Custom axis labels + * - HiDPI display support + * + * Data format (single curve): + * ``` + * { + * data: [ + * { fpr: 0.0, tpr: 0.0 }, + * { fpr: 0.1, tpr: 0.7 }, + * { fpr: 1.0, tpr: 1.0 } + * ], + * metric: 0.95, + * metricLabel: "AUC" + * } + * ``` + * + * Data format (multi-curve): + * ``` + * { + * curves: [ + * { label: "Model A", points: [...], metric: { name: "AUC", value: 0.95 } }, + * { label: "Model B", points: [...], metric: { name: "AUC", value: 0.92 } } + * ] + * } + * ``` + * + * @param {object} props - Component props + * @param {Array|object} props.data - Curve data (array of points or object with curves) + * @param {Array} [props.curves] - Multi-curve data with labels + * @param {string} [props.xField] - X-axis field (auto-detected from first point) + * @param {string} [props.yField] - Y-axis field (auto-detected from first point) + * @param {string} [props.xLabel] - X-axis label (auto-generated if not provided) + * @param {string} [props.yLabel] - Y-axis label (auto-generated if not provided) + * @param {number} [props.metric] - Metric value to display (e.g., 0.95 for AUC) + * @param {string} [props.metricLabel] - Metric name (e.g., "AUC", "Average Precision") + * @param {boolean} [props.showDiagonal=false] - Show y=x diagonal line (for ROC) + * @returns {React.ReactElement|null} Canvas-based curve chart with tooltip */ const CurveChart = ({ data, // Array of objects with numeric fields: [{x, y}, ...] OR {curves: [...]} for multi-run @@ -42,7 +87,10 @@ const CurveChart = ({ const legacyYLabel = customYLabelROC || customYLabelPR; const isROC = !!rocData || showDiagonal; - // Auto-detect field names from data + /** + * Auto-detect field names from data + * @returns {object} Object with x and y field names + */ const detectFields = useCallback(() => { if (!curveData || curveData.length === 0) return { x: null, y: null }; @@ -56,6 +104,11 @@ const CurveChart = ({ }; }, [curveData, xField, yField]); + /** + * Draws the curve chart on the canvas + * @param {number} width - Canvas width + * @param {number} height - Canvas height + */ const drawCurve = useCallback((width, height) => { const canvas = canvasRef.current; if (!canvas) return; @@ -94,7 +147,17 @@ const CurveChart = ({ const yMin = Math.min(...yValues); const yMax = Math.max(...yValues); + /** + * Converts data X value to canvas X coordinate + * @param {number} val - Data X value + * @returns {number} Canvas X coordinate + */ const toCanvasX = (val) => padding.left + ((val - xMin) / (xMax - xMin || 1)) * chartWidth; + /** + * Converts data Y value to canvas Y coordinate + * @param {number} val - Data Y value + * @returns {number} Canvas Y coordinate + */ const toCanvasY = (val) => padding.top + chartHeight - ((val - yMin) / (yMax - yMin || 1)) * chartHeight; // Draw axes and grid @@ -159,7 +222,17 @@ const CurveChart = ({ const yMin = Math.min(...yValues); const yMax = Math.max(...yValues); + /** + * Converts data X value to canvas X coordinate + * @param {number} val - Data X value + * @returns {number} Canvas X coordinate + */ const toCanvasX = (val) => padding.left + ((val - xMin) / (xMax - xMin || 1)) * chartWidth; + /** + * Converts data Y value to canvas Y coordinate + * @param {number} val - Data Y value + * @returns {number} Canvas Y coordinate + */ const toCanvasY = (val) => padding.top + chartHeight - ((val - yMin) / (yMax - yMin || 1)) * chartHeight; drawAxes(ctx, width, height, padding, finalXLabel, finalYLabel); @@ -221,15 +294,30 @@ const CurveChart = ({ // Use responsive canvas hook - handles sizing and redraw automatically useResponsiveCanvas(canvasRef, drawCurve); - // Tooltip logic: find nearest point on curves (on-the-fly transform) + /** + * Tooltip logic: find nearest point on curves (on-the-fly transform) + * @param {number} mouseX - Mouse X coordinate + * @param {number} mouseY - Mouse Y coordinate + * @param {number} searchRadius - Search radius for finding nearest point + * @returns {object|null} Tooltip data or null if no point found + */ const getTooltipData = useCallback((mouseX, mouseY, searchRadius) => { if (!plotDataRef.current) return null; const pd = plotDataRef.current; const { padding, chartWidth, chartHeight, xMin, xMax, yMin, yMax } = pd; - // Transform function + /** + * Converts data X value to canvas X coordinate + * @param {number} val - Data X value + * @returns {number} Canvas X coordinate + */ const toCanvasX = (val) => padding.left + ((val - xMin) / (xMax - xMin || 1)) * chartWidth; + /** + * Converts data Y value to canvas Y coordinate + * @param {number} val - Data Y value + * @returns {number} Canvas Y coordinate + */ const toCanvasY = (val) => padding.top + chartHeight - ((val - yMin) / (yMax - yMin || 1)) * chartHeight; let nearestPoint = null; diff --git a/src/app/components/visualizations/plots/Heatmap.jsx b/src/app/components/visualizations/plots/Heatmap.jsx index 77f0539..00b4ecd 100644 --- a/src/app/components/visualizations/plots/Heatmap.jsx +++ b/src/app/components/visualizations/plots/Heatmap.jsx @@ -5,14 +5,50 @@ import { getChartFont } from '@/app/hooks/useCanvasSetup'; import PlotTooltip from '../shared/PlotTooltip'; /** - * Heatmap for Matrix data - * Displays 2D matrix as color-coded grid - * Expects data format: { rows: [labels], cols: [labels], values: [[numbers]] } + * Heatmap component for 2D matrix visualization + * + * Renders color-coded grid heatmaps with interactive tooltips. Commonly used for + * confusion matrices, correlation matrices, and attention weights. + * + * Features: + * - Color-coded cells based on value (blue-white-red gradient) + * - Interactive tooltips showing exact cell values + * - Row and column labels + * - Auto-scaled cell sizes based on matrix dimensions + * - Value annotations in each cell + * - HiDPI display support + * + * Data format: + * ``` + * { + * rows: ["Class A", "Class B", "Class C"], // Row labels + * cols: ["Pred A", "Pred B", "Pred C"], // Column labels + * values: [ // 2D matrix + * [120, 5, 2], // Row 0 + * [3, 95, 8], // Row 1 + * [1, 10, 110] // Row 2 + * ] + * } + * ``` + * + * @param {object} props - Component props + * @param {object} props.data - Heatmap data with labels and values + * @param {Array} props.data.rows - Row labels + * @param {Array} props.data.cols - Column labels + * @param {Array>} props.data.values - 2D matrix of numeric values + * @param {string} [props._title] - Title (handled by wrapper, not used internally) + * @returns {React.ReactElement} Canvas-based heatmap with tooltip */ -const Heatmap = ({ data, title }) => { +const Heatmap = ({ data, _title }) => { const canvasRef = useRef(null); const plotDataRef = useRef(null); + /** + * Draws the heatmap on canvas + * @param {number} width - Canvas width + * @param {number} height - Canvas height + * @returns {void} + */ const drawHeatmap = useCallback((width, height) => { if (!data || !data.values || !canvasRef.current) { return; @@ -65,6 +101,11 @@ const Heatmap = ({ data, title }) => { const maxVal = Math.max(...allValues); // Color scale: blue (low) -> white (mid) -> red (high) + /** + * Calculates color for a value based on min/max range + * @param {number} value - The value to calculate color for + * @returns {string} RGB color string + */ const getColor = (value) => { const normalized = (value - minVal) / (maxVal - minVal); @@ -173,11 +214,16 @@ const Heatmap = ({ data, title }) => { // No return needed - container controls size, plot adapts - }, [data, title]); + }, [data]); useResponsiveCanvas(canvasRef, drawHeatmap); - // Tooltip logic: find cell under mouse (on-the-fly) + /** + * Tooltip logic: find cell under mouse (on-the-fly) + * @param {number} mouseX - Mouse X coordinate + * @param {number} mouseY - Mouse Y coordinate + * @returns {object|null} Tooltip data or null if not over a cell + */ const getTooltipData = useCallback((mouseX, mouseY) => { if (!plotDataRef.current) return null; diff --git a/src/app/components/visualizations/plots/Histogram.jsx b/src/app/components/visualizations/plots/Histogram.jsx index 3d6343c..64c30b7 100644 --- a/src/app/components/visualizations/plots/Histogram.jsx +++ b/src/app/components/visualizations/plots/Histogram.jsx @@ -4,12 +4,42 @@ import { useResponsiveCanvas } from '../../../hooks/useResponsiveCanvas'; import { getChartColor } from '../../../../core/utils/constants'; /** - * Histogram Plot - * Displays distribution of values with optional grouping + * Histogram component for distribution visualization + * + * Renders histograms showing value distributions with configurable binning. + * Used for loss distributions, parameter distributions, and data analysis. + * + * Features: + * - Automatic or custom binning + * - Multiple distribution overlay (grouped histograms) + * - Frequency counts per bin + * - HiDPI display support + * + * Data format: + * ``` + * { + * groups: [ + * { label: "Training", bins: [...], counts: [...] }, + * { label: "Validation", bins: [...], counts: [...] } + * ] + * } + * ``` + * + * @param {object} props - Component props + * @param {object} props.data - Histogram data with bins and counts + * @param {string} [props.title] - Chart title (handled by wrapper) + * @param {object} [props.metadata] - Additional metadata + * @returns {React.ReactElement} Canvas-based histogram */ const Histogram = ({ data, title: _title, metadata: _metadata }) => { const canvasRef = useRef(null); + /** + * Draws the histogram on the canvas + * @param {number} width - Canvas width + * @param {number} height - Canvas height + * @returns {void} + */ const drawHistogram = useCallback((width, height) => { if (!data) return; @@ -80,6 +110,16 @@ const Histogram = ({ data, title: _title, metadata: _metadata }) => { ); }; +/** + * Draws a simple (non-grouped) histogram + * @param {object} ctx - Canvas 2D context + * @param {number[]} values - Array of numerical values to plot + * @param {number} numBins - Number of histogram bins + * @param {object} padding - Padding object with top, right, bottom, left properties + * @param {number} plotWidth - Width of the plot area + * @param {number} plotHeight - Height of the plot area + * @returns {void} + */ function drawSimpleHistogram(ctx, values, numBins, padding, plotWidth, plotHeight) { // Compute bins const min = Math.min(...values); @@ -137,6 +177,17 @@ function drawSimpleHistogram(ctx, values, numBins, padding, plotWidth, plotHeigh ctx.restore(); } +/** + * Draws a grouped histogram with multiple series + * @param {object} ctx - Canvas 2D context + * @param {number[]} values - Array of numerical values to plot + * @param {string[]} groups - Array of group labels corresponding to each value + * @param {number} numBins - Number of histogram bins + * @param {object} padding - Padding object with top, right, bottom, left properties + * @param {number} plotWidth - Width of the plot area + * @param {number} plotHeight - Height of the plot area + * @returns {void} + */ function drawGroupedHistogram(ctx, values, groups, numBins, padding, plotWidth, plotHeight) { // Group values const groupedValues = {}; diff --git a/src/app/components/visualizations/plots/LinePlot.jsx b/src/app/components/visualizations/plots/LinePlot.jsx index f37efae..cb58aa6 100644 --- a/src/app/components/visualizations/plots/LinePlot.jsx +++ b/src/app/components/visualizations/plots/LinePlot.jsx @@ -7,9 +7,44 @@ import { getChartColor, CHART_PADDING } from '../../../../core/utils/constants'; import { formatYAxisValue } from '../../../../core/utils/formatters'; /** - * Line Plot for Series data - * Displays multiple time series on a single chart - * Expects data format: { xLabel: string, datasets: [{label, data: [{x, y}]}] } + * Line Plot component for time series visualization + * + * Renders multi-series line charts with interactive tooltips, auto-scaled axes, + * and responsive resizing. Optimized for training metrics (loss curves, accuracy over time). + * + * Features: + * - Multi-series overlay (multiple lines on one chart with color coding) + * - Interactive tooltips showing exact values on hover + * - Auto-scaled Y-axis based on data range + * - Grid lines for easier value reading + * - Legend showing series names and colors + * - HiDPI (Retina) display support + * - Responsive to container size changes + * + * Data format: + * ``` + * { + * xLabel: "Epoch", // X-axis label + * datasets: [ + * { + * label: "train_loss", // Series name + * data: [ + * { x: 0, y: 0.5 }, // Individual points + * { x: 1, y: 0.3 }, + * { x: 2, y: 0.2 } + * ] + * }, + * { label: "val_loss", data: [...] } + * ] + * } + * ``` + * + * @param {object} props - Component props + * @param {object} props.data - Plot data with datasets and labels + * @param {string} props.data.xLabel - X-axis label (e.g., "Epoch", "Step") + * @param {Array} props.data.datasets - Array of series to plot + * @param {string} [props.title] - Optional plot title + * @returns {React.ReactElement} Canvas-based line plot with interactive tooltip */ const LinePlot = ({ data, title }) => { const canvasRef = useRef(null); diff --git a/src/app/components/visualizations/plots/ParallelCoordinatesChart.jsx b/src/app/components/visualizations/plots/ParallelCoordinatesChart.jsx index e9ccf96..61e9a16 100644 --- a/src/app/components/visualizations/plots/ParallelCoordinatesChart.jsx +++ b/src/app/components/visualizations/plots/ParallelCoordinatesChart.jsx @@ -5,9 +5,41 @@ import { useResponsiveCanvas } from '@/app/hooks/useResponsiveCanvas'; import { CHART_PADDING } from '@/core/utils/constants'; /** - * Parallel Coordinates Chart - * Visualizes relationships between hyperparameters and a selected metric across multiple runs - * Features: smooth curves, color gradient by metric value, hover tooltips, metric aggregation + * Parallel Coordinates Chart for hyperparameter sweep visualization + * + * Advanced multi-dimensional visualization showing relationships between hyperparameters + * and metrics across multiple runs. Each run is a polyline connecting parameter values + * to metric outcome, colored by performance. + * + * Features: + * - Multi-dimensional visualization (unlimited parameters + 1 metric) + * - Color gradient by metric value (green = better, red = worse) + * - Interactive hover tooltips showing full run details + * - Metric aggregation options (last, max, min, avg) + * - Metric selector dropdown + * - Smooth B-spline curves between axes + * - Auto-scaled axes with min/max normalization + * - HiDPI display support + * + * Use cases: + * - Identify optimal hyperparameter combinations + * - Detect parameter interactions (parallel/crossing lines) + * - Compare run performance across multiple dimensions + * + * How to read it: + * - Each vertical axis represents a hyperparameter or the selected metric + * - Each polyline represents one experiment run + * - Line color indicates metric performance (customizable gradient) + * - Parallel lines = parameters change together + * - Crossing lines = parameter interactions/trade-offs + * + * @param {object} props - Component props + * @param {Array} props.hyperparameters - Parameter names to visualize + * @param {Array} props.availableMetrics - Metrics available for selection + * @param {string} [props.defaultMetric] - Initial metric to display + * @param {Array} [props.data] - Pre-aggregated run data (if available) + * @param {Array} [props.runs] - Raw run objects for on-the-fly aggregation + * @returns {React.ReactElement} Canvas-based parallel coordinates chart */ const ParallelCoordinatesChart = ({ hyperparameters, availableMetrics, defaultMetric, data, runs }) => { const canvasRef = useRef(null); @@ -32,8 +64,12 @@ const ParallelCoordinatesChart = ({ hyperparameters, availableMetrics, defaultMe return [...hyperparameters, selectedMetric]; }, [hyperparameters, selectedMetric]); - // Color gradient function (purple/blue -> teal -> green, inspired by W&B) - const getGradientColor = (normalizedValue) => { + /** + * Generates a color gradient from purple-blue (low) to cyan to green (high) + * @param {number} normalizedValue - Normalized value between 0 and 1 + * @returns {string} RGB color string + */ + const getGradientColor = useCallback((normalizedValue) => { // Beautiful gradient: purple-blue (low) -> cyan -> green-yellow (high) if (normalizedValue < 0.5) { // 0.0 -> 0.5: purple-blue (#8B5CF6) to cyan (#06B6D4) @@ -50,7 +86,7 @@ const ParallelCoordinatesChart = ({ hyperparameters, availableMetrics, defaultMe const b = Math.round(212 + (129 - 212) * t); return `rgb(${r}, ${g}, ${b})`; } - }; + }, []); // Normalize data for each axis to 0-1 range const normalizedData = useMemo(() => { @@ -133,8 +169,14 @@ const ParallelCoordinatesChart = ({ hyperparameters, availableMetrics, defaultMe metricValue }; }); - }, [aggregatedData, axes, hyperparameters, selectedMetric]); - + }, [aggregatedData, axes, hyperparameters, selectedMetric, getGradientColor]); + + /** + * Draws the parallel coordinates chart on the canvas + * @param {number} width - Canvas width in pixels + * @param {number} height - Canvas height in pixels + * @returns {void} + */ const drawChart = useCallback((width, height) => { const canvas = canvasRef.current; if (!canvas || !normalizedData || !axes || axes.length === 0) return; @@ -267,13 +309,18 @@ const ParallelCoordinatesChart = ({ hyperparameters, availableMetrics, defaultMe }); ctx.globalAlpha = 1.0; - }, [normalizedData, axes, hoveredRun]); + }, [normalizedData, axes, hoveredRun, getGradientColor]); useResponsiveCanvas(canvasRef, drawChart); // Throttled mouse move handler for performance const throttledMouseMoveRef = useRef(null); + /** + * Handles mouse move events over the canvas for hover detection + * @param {object} e - Mouse event object + * @returns {void} + */ const handleMouseMove = useCallback((e) => { // Throttle to ~60fps if (throttledMouseMoveRef.current) return; diff --git a/src/app/components/visualizations/plots/ParameterCorrelationChart.jsx b/src/app/components/visualizations/plots/ParameterCorrelationChart.jsx index 40adb54..16f932b 100644 --- a/src/app/components/visualizations/plots/ParameterCorrelationChart.jsx +++ b/src/app/components/visualizations/plots/ParameterCorrelationChart.jsx @@ -1,5 +1,5 @@ import React, { useRef, useState, useCallback, useMemo, useEffect } from 'react'; -import { calculateParameterImportance, getShortParamName } from '@/app/utils/comparisonPlotDiscovery'; +import { calculateParameterCorrelation, getShortParamName } from '@/app/utils/comparisonPlotDiscovery'; import { drawTopLegend, getChartFont } from '@/app/hooks/useCanvasSetup'; import { useResponsiveCanvas } from '@/app/hooks/useResponsiveCanvas'; import { CHART_PADDING } from '@/core/utils/constants'; @@ -11,6 +11,13 @@ import { CHART_PADDING } from '@/core/utils/constants'; * Displays: * - Correlation: Linear relationship between parameter and metric (-1 to 1) * - Visual bars: Green for positive, red for negative correlation + * @param {object} props - Component props + * @param {string[]} props.hyperparameters - Array of hyperparameter names + * @param {string[]} props.availableMetrics - Array of available metric names + * @param {string} props.defaultMetric - Default metric to display + * @param {object} props.importance - Pre-calculated parameter importance data + * @param {Array} props.runs - Array of run objects with hyperparameter and metric data + * @returns {React.ReactElement} The rendered component */ const ParameterCorrelationChart = ({ hyperparameters, availableMetrics, defaultMetric, importance: defaultImportance, runs }) => { const canvasRef = useRef(null); @@ -26,14 +33,20 @@ const ParameterCorrelationChart = ({ hyperparameters, availableMetrics, defaultM } }, [selectedMetric, availableMetrics]); - // Recalculate importance when aggregation changes + // Recalculate correlation when aggregation changes const importance = useMemo(() => { if (!runs || aggregation === 'last') return defaultImportance; // Use pre-calculated if default // Recalculate with new aggregation - return calculateParameterImportance(runs, hyperparameters, availableMetrics, aggregation); + return calculateParameterCorrelation(runs, hyperparameters, availableMetrics, aggregation); }, [runs, hyperparameters, availableMetrics, aggregation, defaultImportance]); + /** + * Draws the parameter correlation chart on canvas + * @param {number} width - Canvas width + * @param {number} height - Canvas height + * @returns {void} + */ const drawChart = useCallback((width, height) => { const canvas = canvasRef.current; if (!canvas || !importance || !selectedMetric || !hyperparameters) return; diff --git a/src/app/components/visualizations/plots/ScatterPlot.jsx b/src/app/components/visualizations/plots/ScatterPlot.jsx index 93e2013..c7242d5 100644 --- a/src/app/components/visualizations/plots/ScatterPlot.jsx +++ b/src/app/components/visualizations/plots/ScatterPlot.jsx @@ -6,17 +6,36 @@ import { CHART_PADDING } from '@/core/utils/constants'; import PlotTooltip from '../shared/PlotTooltip'; /** - * Scatter Plot for Scatter primitive + * Scatter Plot component for 2D point cloud visualization * - * Expected data format (scatter primitive): + * Renders scatter plots with auto-detected axes, interactive tooltips, and optional + * point labels. Used for hyperparameter spaces, embeddings, and feature correlations. + * + * Features: + * - Auto-detects numeric fields for X and Y axes (field-name agnostic) + * - Interactive tooltips showing point details on hover + * - Optional point labels and custom colors + * - Auto-scaled axes based on data range + * - HiDPI display support + * + * Data format: + * ``` * { - * points: [{: val, : val, label?, size?, color?}, ...], - * x_label: "X Axis Label", - * y_label: "Y Axis Label" + * points: [ + * { lr: 0.01, accuracy: 0.85, label: "Run 1", color: "#ff0000" }, + * { lr: 0.1, accuracy: 0.92, label: "Run 2" } + * ], + * x_label: "Learning Rate", + * y_label: "Accuracy" * } + * ``` * - * Auto-detects numeric fields - does NOT assume field names. - * First two numeric fields are used as x and y. + * @param {object} props - Component props + * @param {object} props.data - Scatter data with points and axis labels + * @param {Array} props.data.points - Array of point objects (flexible fields) + * @param {string} [props.data.x_label] - X-axis label (auto-detected if not provided) + * @param {string} [props.data.y_label] - Y-axis label (auto-detected if not provided) + * @returns {React.ReactElement} Canvas-based scatter plot with tooltip */ const ScatterPlot = ({ data }) => { const canvasRef = useRef(null); diff --git a/src/app/components/visualizations/plots/ViolinPlot.jsx b/src/app/components/visualizations/plots/ViolinPlot.jsx index 2397f9a..fd91c8d 100644 --- a/src/app/components/visualizations/plots/ViolinPlot.jsx +++ b/src/app/components/visualizations/plots/ViolinPlot.jsx @@ -4,8 +4,34 @@ import { useResponsiveCanvas } from '@/app/hooks/useResponsiveCanvas'; import { getChartColor, CHART_PADDING } from '../../../../core/utils/constants'; /** - * Violin Plot - Shows distribution density with embedded box plot - * Displays distribution of values across groups using kernel density estimation + * Violin Plot component for distribution comparison across groups + * + * Combines box plots with kernel density estimation to show distribution shape. + * Used for comparing metric distributions across different models or hyperparameter settings. + * + * Features: + * - Kernel density estimation (KDE) for distribution shape + * - Embedded box plot showing quartiles and median + * - Multiple group comparison side-by-side + * - Auto-scaled axes + * - HiDPI display support + * + * Data format: + * ``` + * { + * groups: [ + * { label: "Model A", values: [0.82, 0.85, 0.88, 0.91] }, + * { label: "Model B", values: [0.75, 0.79, 0.83, 0.86] } + * ] + * } + * ``` + * + * @param {object} props - Component props + * @param {object} props.data - Violin plot data with groups + * @param {Array} props.data.groups - Array of groups with labels and values + * @param {string} [props.title] - Chart title (handled by wrapper) + * @param {object} [props.metadata] - Additional metadata + * @returns {React.ReactElement} Canvas-based violin plot */ const ViolinPlot = ({ data, title: _title, metadata: _metadata }) => { const canvasRef = useRef(null); @@ -169,6 +195,8 @@ const ViolinPlot = ({ data, title: _title, metadata: _metadata }) => { /** * Compute kernel density estimation for violin plot + * @param {number[]} values - Array of numeric values + * @returns {{points: Array<{value: number, density: number}>, maxDensity: number}} KDE result with points and max density */ function computeKDE(values) { const sorted = [...values].sort((a, b) => a - b); @@ -204,6 +232,11 @@ function computeKDE(values) { return { points, maxDensity }; } +/** + * Compute statistical quartiles for box plot + * @param {number[]} values - Array of numeric values + * @returns {{min: number, max: number, median: number, q1: number, q3: number}} Statistical values + */ function computeStats(values) { const sorted = [...values].sort((a, b) => a - b); const n = sorted.length; diff --git a/src/app/components/visualizations/shared/PlotTooltip.jsx b/src/app/components/visualizations/shared/PlotTooltip.jsx index e8d2779..b55d1fc 100644 --- a/src/app/components/visualizations/shared/PlotTooltip.jsx +++ b/src/app/components/visualizations/shared/PlotTooltip.jsx @@ -2,8 +2,36 @@ import React from 'react'; import './PlotTooltip.css'; /** - * Unified tooltip component for all plot types - * Handles positioning, formatting, and rendering of tooltip data + * Plot Tooltip component for displaying data point details on hover + * + * Unified tooltip rendering for all plot types (line, scatter, heatmap, curves, etc.). + * Automatically positions itself to avoid screen edges and formats data appropriately. + * + * Features: + * - Smart edge detection (flips to left/top/bottom if near edge) + * - Type-specific rendering (series, scatter, matrix, curve, distribution) + * - Custom value formatters + * - Fixed positioning (follows cursor) + * - High z-index (always on top) + * - Pointer-events: none (doesn't block mouse) + * + * Supported data types: + * - series: Time series with index + multiple values + * - scatter: X/Y coordinates with optional label + * - matrix: Row/col/value for heatmaps + * - curve: X/Y for ROC/PR curves with metric display + * - distribution: Count + range for histograms + * - generic: Key-value pairs for fallback + * + * @param {object} props - Component props + * @param {boolean} props.visible - Whether tooltip should be shown + * @param {number} props.x - Screen X coordinate (from mouse event) + * @param {number} props.y - Screen Y coordinate (from mouse event) + * @param {object|null} props.data - Tooltip data: { type, content } + * @param {object} [props.formatters={}] - Custom formatters: + * - value: (v) => string - Format numeric values + * - index: (v) => string - Format index values + * @returns {React.ReactElement|null} Positioned tooltip or null if not visible */ const PlotTooltip = ({ visible, x, y, data, formatters = {} }) => { if (!visible || !data) return null; @@ -41,6 +69,9 @@ const PlotTooltip = ({ visible, x, y, data, formatters = {} }) => { /** * Render tooltip content based on data structure + * @param {object} data - Data object containing type and content + * @param {object} formatters - Formatters for value display + * @returns {object} Rendered tooltip content */ function renderTooltipContent(data, formatters) { const { type, content } = data; @@ -61,6 +92,12 @@ function renderTooltipContent(data, formatters) { } } +/** + * Render series data tooltip + * @param {object} content - Series content with index and values + * @param {object} formatters - Formatters for value display + * @returns {object} Rendered series tooltip + */ function renderSeriesData(content, formatters) { const { index, indexLabel, values } = content; const formatValue = formatters.value || ((v) => typeof v === 'number' ? v.toFixed(4) : v); @@ -83,6 +120,12 @@ function renderSeriesData(content, formatters) { ); } +/** + * Render scatter plot data tooltip + * @param {object} content - Scatter content with x, y coordinates and label + * @param {object} formatters - Formatters for value display + * @returns {object} Rendered scatter tooltip + */ function renderScatterData(content, formatters) { const { x, y, label, color } = content; const formatValue = formatters.value || ((v) => typeof v === 'number' ? v.toFixed(4) : v); @@ -112,6 +155,12 @@ function renderScatterData(content, formatters) { ); } +/** + * Render matrix data tooltip + * @param {object} content - Matrix content with row, column and value + * @param {object} formatters - Formatters for value display + * @returns {object} Rendered matrix tooltip + */ function renderMatrixData(content, formatters) { const { row, col, value } = content; const formatValue = formatters.value || ((v) => typeof v === 'number' ? v.toFixed(2) : v); @@ -134,6 +183,12 @@ function renderMatrixData(content, formatters) { ); } +/** + * Render curve data tooltip + * @param {object} content - Curve content with x, y coordinates and labels + * @param {object} formatters - Formatters for value display + * @returns {object} Rendered curve tooltip + */ function renderCurveData(content, formatters) { const { x, y, xLabel, yLabel, metric } = content; const formatValue = formatters.value || ((v) => typeof v === 'number' ? v.toFixed(4) : v); @@ -155,6 +210,12 @@ function renderCurveData(content, formatters) { ); } +/** + * Render distribution data tooltip + * @param {object} content - Distribution content with count and range + * @param {object} formatters - Formatters for value display + * @returns {object} Rendered distribution tooltip + */ function renderDistributionData(content, formatters) { const { count, range } = content; const formatValue = formatters.value || ((v) => typeof v === 'number' ? v.toFixed(2) : v); @@ -177,6 +238,12 @@ function renderDistributionData(content, formatters) { ); } +/** + * Render generic data tooltip + * @param {object} content - Generic content object with key-value pairs + * @param {object} formatters - Formatters for value display + * @returns {object} Rendered generic tooltip + */ function renderGenericData(content, formatters) { const formatValue = formatters.value || ((v) => typeof v === 'number' ? v.toFixed(4) : v); diff --git a/src/app/hooks/useCanvasSetup.js b/src/app/hooks/useCanvasSetup.js index 4166fe5..227aad8 100644 --- a/src/app/hooks/useCanvasSetup.js +++ b/src/app/hooks/useCanvasSetup.js @@ -50,13 +50,42 @@ export function getChartFont(type) { } /** - * Reusable hook for setting up canvas with proper DPR scaling - * Eliminates ~10 lines of boilerplate per chart component + * Custom hook for canvas setup with HiDPI (Retina) display support * - * @param {React.RefObject} canvasRef - Reference to canvas element - * @param {number} defaultWidth - Default width if clientWidth unavailable - * @param {number} defaultHeight - Default height if clientHeight unavailable - * @returns {Function} setupCanvas - Function that returns { ctx, width, height, dpr } or null + * Handles the common canvas setup pattern: DPR scaling, transform reset, and white background. + * Eliminates ~10 lines of boilerplate from every plot component. + * + * Key features: + * - Automatic device pixel ratio (DPR) detection and scaling for crisp rendering on Retina displays + * - Transform matrix reset to prevent cumulative scaling bugs + * - White background fill + * - Returns ready-to-use context with dimensions + * + * Why transform reset matters: + * React may reuse canvas DOM nodes between components. If transform isn't reset, + * the scale matrix accumulates (2x β†’ 4x β†’ 8x) causing massive zoom bugs. + * + * @param {React.RefObject} canvasRef - Ref to canvas element + * @param {number} [defaultWidth=600] - Fallback width if clientWidth is 0 + * @param {number} [defaultHeight=300] - Fallback height if clientHeight is 0 + * @returns {function} setupCanvas - Function that returns setup object or null: + * - ctx: CanvasRenderingContext2D - Drawing context (already DPR-scaled) + * - width: number - Display width in CSS pixels + * - height: number - Display height in CSS pixels + * - dpr: number - Device pixel ratio (1 for standard, 2+ for Retina) + * + * @example + * const canvasRef = useRef(null); + * const setupCanvas = useCanvasSetup(canvasRef); + * + * const draw = useCallback(() => { + * const setup = setupCanvas(); + * if (!setup) return; + * + * const { ctx, width, height } = setup; + * // Draw using CSS pixels - DPR already handled + * ctx.fillRect(0, 0, width, height); + * }, [setupCanvas]); */ export function useCanvasSetup(canvasRef, defaultWidth = 600, defaultHeight = 300) { const setupCanvas = useCallback(() => { @@ -108,6 +137,10 @@ const CHART_PADDING = { /** * Helper to get chart dimensions after padding + * @param {number} width - Canvas width + * @param {number} height - Canvas height + * @param {object} padding - Padding configuration object + * @returns {object} Chart dimensions object with chartWidth and chartHeight */ export function getChartDimensions(width, height, padding = CHART_PADDING) { return { @@ -118,6 +151,13 @@ export function getChartDimensions(width, height, padding = CHART_PADDING) { /** * Draw axes on canvas + * @param {object} ctx - Canvas context + * @param {number} width - Canvas width + * @param {number} height - Canvas height + * @param {object} padding - Padding configuration object + * @param {string} xLabel - Label for X-axis + * @param {string} yLabel - Label for Y-axis + * @returns {void} */ export function drawAxes(ctx, width, height, padding, xLabel, yLabel) { ctx.strokeStyle = '#333'; @@ -144,6 +184,13 @@ export function drawAxes(ctx, width, height, padding, xLabel, yLabel) { /** * Draw grid lines with labels + * @param {object} ctx - Canvas context + * @param {number} width - Canvas width + * @param {number} height - Canvas height + * @param {object} padding - Padding configuration object + * @param {number} numLines - Number of grid lines to draw + * @param {(value: number) => string} formatLabel - Function to format grid labels + * @returns {void} */ export function drawGridLines(ctx, width, height, padding, numLines = 5, formatLabel = (v) => v.toFixed(1)) { const { chartWidth, chartHeight } = getChartDimensions(width, height, padding); @@ -183,6 +230,15 @@ export function drawGridLines(ctx, width, height, padding, numLines = 5, formatL /** * Draw Y-axis with grid lines and labels * Eliminates ~30 lines of duplicate code per chart + * @param {object} ctx - Canvas context + * @param {number} width - Canvas width + * @param {number} height - Canvas height + * @param {object} padding - Padding configuration object + * @param {number} yMin - Minimum Y value + * @param {number} yMax - Maximum Y value + * @param {(value: number) => string} formatYValue - Function to format Y-axis values + * @param {number} numGridLines - Number of grid lines to draw + * @returns {void} */ export function drawYAxisWithGrid(ctx, width, height, padding, yMin, yMax, formatYValue, numGridLines = 5) { ctx.strokeStyle = '#e0e0e0'; @@ -210,12 +266,12 @@ export function drawYAxisWithGrid(ctx, width, height, padding, yMin, yMax, forma /** * Render rotated X-axis labels with collision detection and staggering * Eliminates 40+ lines of duplicate code in LinePlot, Histogram, ViolinPlot - * - * @param {CanvasRenderingContext2D} ctx - Canvas context + * @param {object} ctx - Canvas context * @param {Array} labels - Array of label strings * @param {Array} xPositions - X positions for each label * @param {number} baseY - Base Y position for labels - * @param {Object} options - Optional configuration + * @param {object} options - Optional configuration + * @returns {void} */ export function renderRotatedLabels(ctx, labels, xPositions, baseY, options = {}) { const { @@ -275,12 +331,11 @@ export function renderRotatedLabels(ctx, labels, xPositions, baseY, options = {} /** * Draw legend at the top of the chart with automatic wrapping * Returns the height consumed by the legend for layout adjustment - * - * @param {CanvasRenderingContext2D} ctx - Canvas context + * @param {object} ctx - Canvas context * @param {Array} items - Legend items [{label, color}, ...] * @param {number} width - Canvas width * @param {number} startY - Y position to start drawing legend (top of chart area) - * @param {Object} options - Spacing and styling options + * @param {object} options - Spacing and styling options * @returns {number} - Total height consumed by legend */ export function drawTopLegend(ctx, items, width, startY, options = {}) { @@ -355,6 +410,10 @@ export function drawTopLegend(ctx, items, width, startY, options = {}) { /** * Calculate dynamic bottom padding for legend * Helps charts adjust height based on number of legend items + * @param {number} numItems - Number of legend items + * @param {number} canvasWidth - Canvas width + * @param {object} options - Optional configuration + * @returns {number} Calculated padding value */ export function calculateLegendPadding(numItems, canvasWidth, options = {}) { const { diff --git a/src/app/hooks/useCanvasTooltip.js b/src/app/hooks/useCanvasTooltip.js index 0b798e2..8b9c599 100644 --- a/src/app/hooks/useCanvasTooltip.js +++ b/src/app/hooks/useCanvasTooltip.js @@ -1,23 +1,68 @@ import { useState, useCallback, useRef, useEffect } from 'react'; /** - * Hook for managing tooltips on canvas-based plots + * Custom hook for managing interactive tooltips on canvas-based plots * - * Usage: - * const tooltip = useCanvasTooltip({ - * canvasRef, - * getTooltipData: (canvasX, canvasY, searchRadius) => { - * // Your plot-specific logic to find nearest point - * // Return { type: 'series', content: {...} } or null - * } - * }); + * Provides smooth, performant tooltip updates at 60fps using requestAnimationFrame. + * This hook handles all the complexity of: + * - Mouse position tracking relative to canvas + * - Finding nearest data points via custom search function + * - Throttling updates to avoid layout thrashing + * - Proper cleanup on unmount * - * return ( - * <> - * - * - * - * ); + * Performance optimizations: + * - Uses requestAnimationFrame to throttle tooltip updates to 60fps max + * - Cancels pending RAF callbacks when new mousemove events arrive + * - Only triggers React re-renders when tooltip data actually changes + * - Cleans up event listeners and RAF callbacks on unmount + * + * How it works: + * 1. Attach mousemove/mouseleave listeners to canvas element + * 2. On mousemove: convert screen coords to canvas coords + * 3. Schedule RAF callback to find nearby data points + * 4. If data found: update tooltip state with position and content + * 5. PlotTooltip component renders the actual tooltip DOM + * + * @param {object} params - Hook configuration + * @param {React.RefObject} params.canvasRef - Ref to canvas element + * @param {function} params.getTooltipData - Function to find data near cursor position + * Signature: (canvasX: number, canvasY: number, searchRadius: number) => object|null + * Should return tooltip data object or null if no data nearby + * Example return: { type: 'series', content: { x: 10, y: 20, seriesName: 'loss' } } + * @param {number} [params.searchRadius=20] - Pixel radius for nearby point detection + * @returns {object} Tooltip state for PlotTooltip component: + * - visible: boolean - Whether tooltip should be shown + * - x: number - Screen X coordinate for tooltip + * - y: number - Screen Y coordinate for tooltip + * - data: object|null - Tooltip content data + * + * @example + * // In a plot component: + * const tooltip = useCanvasTooltip({ + * canvasRef, + * searchRadius: 25, + * getTooltipData: (canvasX, canvasY, radius) => { + * // Find nearest data point within radius + * const point = findNearestPoint(canvasX, canvasY, radius); + * if (!point) return null; + * + * return { + * type: 'scatter', + * content: { + * x: point.x, + * y: point.y, + * label: point.label + * } + * }; + * } + * }); + * + * return ( + * <> + * + * + * + * ); */ export function useCanvasTooltip({ canvasRef, diff --git a/src/app/hooks/useDragModeResize.js b/src/app/hooks/useDragModeResize.js index dac3b30..add8591 100644 --- a/src/app/hooks/useDragModeResize.js +++ b/src/app/hooks/useDragModeResize.js @@ -1,11 +1,35 @@ import { useEffect } from 'react'; /** - * Hook to automatically redraw canvas/component when resized in drag-enabled mode - * Sets up ResizeObserver only when parent has drag-enabled class to avoid feedback loops + * Custom hook for conditional resize observation in drag-enabled mode * - * @param {React.RefObject} elementRef - Ref to the element to observe (usually canvas) - * @param {Function} redrawCallback - Function to call when resize is detected + * Automatically redraws canvas/chart elements when resized, but ONLY when drag mode + * is active. This prevents resize feedback loops and unnecessary redraws during normal + * layout. + * + * Why conditional observation: + * - Avoids infinite loops (resize β†’ redraw β†’ resize β†’ ...) + * - Improves performance (no unnecessary redraws in normal mode) + * - Enables responsive resize handles in drag mode + * + * How it works: + * 1. Uses MutationObserver to watch parent for 'drag-enabled' class changes + * 2. When drag-enabled added: Attach ResizeObserver to element + * 3. When drag-enabled removed: Detach ResizeObserver + * 4. ResizeObserver triggers redrawCallback on size changes + * + * @param {React.RefObject} elementRef - Ref to element to observe (typically canvas) + * @param {function} redrawCallback - Function to call when element resizes + * Should redraw the visualization using new dimensions + * @returns {void} + * + * @example + * const canvasRef = useRef(null); + * const drawChart = useCallback(() => { + * // Redraw chart logic + * }, []); + * + * useDragModeResize(canvasRef, drawChart); */ export function useDragModeResize(elementRef, redrawCallback) { useEffect(() => { @@ -18,6 +42,10 @@ export function useDragModeResize(elementRef, redrawCallback) { let resizeObserver = null; // Set up or tear down ResizeObserver based on drag-enabled state + /** + * Sets up or tears down the ResizeObserver based on drag-enabled state + * @returns {void} + */ const setupResizeObserver = () => { const isDragEnabled = parent.classList.contains('drag-enabled'); diff --git a/src/app/hooks/useLayoutManager.js b/src/app/hooks/useLayoutManager.js index 6b8f197..a362d38 100644 --- a/src/app/hooks/useLayoutManager.js +++ b/src/app/hooks/useLayoutManager.js @@ -1,8 +1,10 @@ import { useState, useCallback, useRef } from 'react'; /** - * Default layout configuration - defines which components go in which row - * Each array represents a row, with component keys that should be grouped together + * Default layout configuration - defines which visualization components go in which row + * + * Each array represents a row of visualizations that should be grouped together. + * This provides a sensible default organization for ML experiment visualizations. */ const DEFAULT_LAYOUT = [ ['trainingData', 'lossCurve2D', 'lossSurface3D'], // Row 1: Training data + Loss curve/surface @@ -13,9 +15,73 @@ const DEFAULT_LAYOUT = [ ]; /** - * Transform-based layout manager - components stay in flexbox flow, use transforms for dragging - * No absolute positioning mode - everything stays in document flow - * Dragging applies CSS transforms, drag end reorders DOM if needed + * Custom hook for managing draggable/resizable visualization layout + * + * Provides a transform-based drag-and-drop system that keeps components in normal + * document flow (using flexbox) while applying CSS transforms during drag operations. + * This approach avoids the complexity of absolute positioning while still allowing + * smooth 60fps drag interactions. + * + * Architecture: + * - Components remain in flexbox flow (not position: absolute) + * - During drag: Apply CSS transform for visual feedback + * - After drag: Clear transform, optionally reorder DOM elements + * - Resize: Track dimensions per-component (future: persist to localStorage) + * - Custom labels: Store user-edited axis labels/titles per-visualization + * + * Key features: + * - Smooth dragging using CSS transforms (GPU-accelerated) + * - No layout thrashing (no position/size calculations during drag) + * - Automatic z-index management (dragged element on top) + * - Custom label persistence per visualization + * - Element and container registration system + * + * State management: + * - dragTransforms: CSS translate values for currently dragging elements + * - draggingKey: Which visualization is being dragged (for styling) + * - customLabels: User-edited labels/titles per visualization + * + * @param {Array>} [layoutConfig=DEFAULT_LAYOUT] - Row-based layout configuration + * Each nested array defines visualizations that should be grouped in a row + * @returns {object} Layout manager interface: + * - dragTransforms: object - CSS transform values for dragging elements + * - draggingKey: string|null - Key of currently dragging element + * - handleDragStart: function - Start drag operation + * - handleDrag: function - Update drag position + * - handleDragEnd: function - Finish drag, clear transforms + * - handleResize: function - Handle resize events + * - registerElement: function - Register a visualization element + * - registerContainer: function - Register the container element + * - updateLabels: function - Update custom labels for a visualization + * - customLabels: object - Map of visualization keys to custom label objects + * - layoutConfig: array - The layout configuration being used + * + * @example + * const { + * dragTransforms, + * draggingKey, + * handleDragStart, + * handleDrag, + * handleDragEnd, + * registerElement, + * registerContainer + * } = useLayoutManager(); + * + * return ( + *
+ * + * + * + *
+ * ); */ export function useLayoutManager(layoutConfig = DEFAULT_LAYOUT) { const containerRef = useRef(null); diff --git a/src/app/hooks/useResponsiveCanvas.js b/src/app/hooks/useResponsiveCanvas.js index 865bcf6..5de801e 100644 --- a/src/app/hooks/useResponsiveCanvas.js +++ b/src/app/hooks/useResponsiveCanvas.js @@ -1,17 +1,56 @@ import { useState, useEffect } from 'react'; /** - * Generic hook for responsive canvas rendering - * Handles canvas sizing, ResizeObserver, and automatic redraw on dimension changes + * Custom hook for responsive canvas rendering with automatic resize handling * - * SIMPLE RULE: Container controls size, canvas adapts. No height constraints. + * Manages canvas dimensions and automatically redraws when the container size changes. + * Uses ResizeObserver for efficient resize detection and avoids unnecessary redraws. * - * @param {React.RefObject} canvasRef - Reference to canvas element - * @param {Function} drawCallback - Function to call when canvas needs redraw - * - Parameters: (width, height) => void - * @param {Object} options - Optional configuration - * @param {number} options.defaultHeight - Default canvas height before layout calculation (default: 600) - * @returns {{ width: number, height: number }} - Current canvas dimensions + * Key features: + * - Automatically tracks container size changes (parent element resizing) + * - Debounces dimension updates to avoid redundant redraws + * - Uses ResizeObserver (modern, performant alternative to window resize events) + * - Handles device pixel ratio (DPR) via drawCallback + * - Cleans up observers on unmount + * + * Design philosophy: + * Container controls size β†’ Canvas adapts β†’ Draw callback renders + * The canvas element's size is determined by CSS (parent container), not hardcoded dimensions. + * + * How it works: + * 1. Observe canvas element size using ResizeObserver + * 2. When size changes, update dimensions state + * 3. Trigger drawCallback with new width/height + * 4. drawCallback is responsible for setting canvas.width/height with DPR scaling + * + * @param {React.RefObject} canvasRef - Ref to canvas element + * @param {function} drawCallback - Function called when canvas needs redraw + * Signature: (width: number, height: number) => void + * Callback should handle canvas.width/height and ctx.scale(dpr, dpr) for HiDPI + * @param {object} [options={}] - Optional configuration + * @param {number} [options.defaultHeight=600] - Fallback height if clientHeight is 0 + * @returns {{width: number, height: number}} Current canvas dimensions (CSS pixels) + * + * @example + * const canvasRef = useRef(null); + * + * const drawChart = useCallback((width, height) => { + * const canvas = canvasRef.current; + * const ctx = canvas.getContext('2d'); + * const dpr = window.devicePixelRatio || 1; + * + * // Set bitmap size (accounting for DPR) + * canvas.width = width * dpr; + * canvas.height = height * dpr; + * ctx.scale(dpr, dpr); + * + * // Draw using CSS pixels + * ctx.fillRect(0, 0, width, height); + * }, []); + * + * useResponsiveCanvas(canvasRef, drawChart); + * + * return ; */ export function useResponsiveCanvas(canvasRef, drawCallback, options = {}) { const { defaultHeight = 600 } = options; @@ -21,6 +60,10 @@ export function useResponsiveCanvas(canvasRef, drawCallback, options = {}) { const canvas = canvasRef.current; if (!canvas) return; + /** + * Updates canvas dimensions based on current element size + * @returns {void} + */ const updateDimensions = () => { const width = canvas.clientWidth; const height = canvas.clientHeight || defaultHeight; diff --git a/src/app/hooks/useRunData.js b/src/app/hooks/useRunData.js index 6100ed6..0fec133 100644 --- a/src/app/hooks/useRunData.js +++ b/src/app/hooks/useRunData.js @@ -2,8 +2,45 @@ import { useState, useEffect, useCallback } from 'react'; import { apiClient } from '@/core/api/ApiClient'; /** - * Unified hook for fetching run data from REST API - * Polls continuously to pick up new runs from database + * Custom hook for fetching and polling experiment run data from REST API + * + * Central data-fetching hook used throughout the application. Loads all runs from + * the backend and continuously polls for new runs (e.g., from ongoing experiments). + * + * Features: + * - Initial fetch on mount with loading state + * - Continuous polling every 2 seconds for new runs + * - Smart re-render prevention (only updates when run list changes) + * - Stable reference if runs unchanged (prevents downstream re-renders) + * - Error handling with error state + * - Run count tracking + * - Individual run metrics fetching + * + * Why polling: + * - Detects newly completed experiments without page refresh + * - Simple alternative to WebSocket/SSE for small-scale deployments + * - 2-second interval balances freshness vs server load + * + * Re-render optimization: + * - Compares run_id lists, not full object deep equality + * - Returns previous reference if runs unchanged + * - Critical for components with `runs` in dependencies (LineageTab, etc.) + * + * @returns {object} Run data and utilities: + * - runs: Array - All experiment runs + * - loading: boolean - Initial loading state + * - error: string|null - Error message if fetch failed + * - runCount: number - Total number of runs + * - fetchRuns: function - Manually trigger fetch (isInitialFetch: boolean) + * - fetchRunMetrics: function - Fetch metrics for specific run (runId: string) + * + * @example + * const { runs, loading, error, fetchRuns } = useRunData(); + * + * if (loading) return
Loading...
; + * if (error) return
Error: {error}
; + * + * return fetchRuns(true)} />; */ export const useRunData = () => { const [runs, setRuns] = useState([]); @@ -11,7 +48,11 @@ export const useRunData = () => { const [error, setError] = useState(null); const [runCount, setRunCount] = useState(0); - // Fetch all runs from REST API + /** + * Fetch all runs from REST API. + * @param {boolean} isInitialFetch - Whether this is the initial fetch + * @returns {Promise} + */ const fetchRuns = useCallback(async (isInitialFetch = false) => { try { // Only show loading on initial fetch, not on polls @@ -53,7 +94,11 @@ export const useRunData = () => { } }, []); - // Fetch metrics for a specific run (raw data only, UI does formatting) + /** + * Fetch metrics for a specific run (raw data only, UI does formatting). + * @param {string} runId - Run ID to fetch metrics for + * @returns {Promise} Metrics data for the run + */ const fetchRunMetrics = useCallback(async (runId) => { try { return await apiClient.getRunMetrics(runId); diff --git a/src/app/pages/Workspace/CollapsibleSection.jsx b/src/app/pages/Workspace/CollapsibleSection.jsx index b099881..99983b2 100644 --- a/src/app/pages/Workspace/CollapsibleSection.jsx +++ b/src/app/pages/Workspace/CollapsibleSection.jsx @@ -3,6 +3,14 @@ import { motion, AnimatePresence } from 'framer-motion'; import { HiChevronDown, HiChevronUp } from 'react-icons/hi'; import './CollapsibleSection.scss'; +/** + * Collapsible section component with animation + * @param {object} props - Component props + * @param {string} [props.title] - Optional section title + * @param {React.ReactNode} props.children - Section content + * @param {boolean} [props.defaultCollapsed] - Whether section starts collapsed + * @returns {React.ReactElement} The collapsible section component + */ export function CollapsibleSection({ title, children, defaultCollapsed = false }) { const [isCollapsed, setIsCollapsed] = useState(defaultCollapsed); diff --git a/src/app/pages/Workspace/PlotSection.jsx b/src/app/pages/Workspace/PlotSection.jsx index 14dee6e..6ee1dfc 100644 --- a/src/app/pages/Workspace/PlotSection.jsx +++ b/src/app/pages/Workspace/PlotSection.jsx @@ -7,17 +7,10 @@ import UniversalVisualizationRenderer from '@/app/components/visualizations/Univ /** * PlotSection - Collapsible section with draggable plots * Each section manages dragging/positioning for its own plots - * - * React key includes dataset count to prevent canvas reuse in multi-run mode - * - * Problem: When multiple runs are selected, line plot IDs remain constant (e.g., "Loss_line") - * while the data changes (1 dataset -> 2 datasets -> 3 datasets). Without dataset count - * in the React key, React's reconciliation reuses the same component instance, which means - * the canvas element and its ref are also reused. This caused multiple plots to draw on the - * same physical canvas, resulting in overlapping axis labels and incorrect zoom behavior. - * - * Solution: Include dataset count in the React key (e.g., "Loss_line_4_2" for 2 datasets). - * When the number of datasets changes, React creates a fresh component with a new canvas. + * @param {object} props - Component props + * @param {string} props.sectionName - Name of the section + * @param {Array} props.plots - Array of plot configurations to render + * @returns {React.ReactElement} The plot section component */ export const PlotSection = ({ sectionName, @@ -43,6 +36,11 @@ export const PlotSection = ({ updateLabels } = useLayoutManager(); + /** + * Toggles the visibility of a specific plot by ID + * @param {string} plotId - The ID of the plot to toggle + * @returns {void} + */ const togglePlotVisibility = (plotId) => { setHiddenPlots(prev => { const next = new Set(prev); @@ -128,16 +126,9 @@ export const PlotSection = ({ {plots.map((plotConfig, index) => { if (hiddenPlots.has(plotConfig.id)) return null; - // CRITICAL: Include dataset count in React key to prevent canvas reuse - // When multiple runs are selected, plot IDs stay the same (e.g., "Loss_line") - // but the number of datasets changes (1 dataset -> 2 datasets -> 3 datasets). - // Without proper key changes, React reuses the same component instance. - const datasetCount = plotConfig.data?.datasets?.length || 0; - const uniqueKey = `${plotConfig.id}_${index}_${datasetCount}`; - return ( { // Sidebar state const [isSidebarCollapsed, setIsSidebarCollapsed] = useState(false); diff --git a/src/app/utils/artifactColors.js b/src/app/utils/artifactColors.js index 23328fc..f897a8f 100644 --- a/src/app/utils/artifactColors.js +++ b/src/app/utils/artifactColors.js @@ -83,8 +83,8 @@ export function getArtifactColor(name) { } /** - * Get all available extension colors (for reference/debugging) - * @returns {Object} Extension to color mapping + * Get all available extension colors (for reference/debugging). + * @returns {object} Extension to color mapping */ export function getAllExtensionColors() { return { ...EXTENSION_COLORS }; diff --git a/src/app/utils/comparisonPlotDiscovery.js b/src/app/utils/comparisonPlotDiscovery.js index 3533ee8..f4493e7 100644 --- a/src/app/utils/comparisonPlotDiscovery.js +++ b/src/app/utils/comparisonPlotDiscovery.js @@ -1,12 +1,31 @@ /** * Utility functions for cross-run comparison analysis - * Supports parameter correlation and value extraction from runs + * + * This module provides tools for analyzing relationships between hyperparameters + * and metrics across multiple experiment runs. The primary use case is hyperparameter + * sweep analysis, where we want to understand which parameters most strongly affect + * which metrics. + * + * Key capabilities: + * - Extract parameter/metric values from runs with flexible aggregation + * - Calculate Pearson correlation between parameters and metrics + * - Handle both numeric and categorical parameters + * - Support multiple aggregation strategies (last, max, min, avg) */ /** - * Get short display name for a parameter (last segment after period) + * Get short display name for a parameter by extracting the last segment after a period + * + * Useful for displaying nested parameter names in charts where space is limited. + * For example, "model.optimizer.learning_rate" becomes "learning_rate". + * * @param {string} paramName - Full parameter name (e.g., "training.optimizer.learningRate") - * @returns {string} - Short name (e.g., "learningRate") + * @returns {string} Short name (e.g., "learningRate"), or empty string if input is falsy + * + * @example + * getShortParamName("model.optimizer.lr") // Returns "lr" + * getShortParamName("batch_size") // Returns "batch_size" (no dots) + * getShortParamName("") // Returns "" */ export function getShortParamName(paramName) { if (!paramName) return ''; @@ -15,11 +34,38 @@ export function getShortParamName(paramName) { } /** - * Extract values for specified fields from a run - * For metrics: aggregates values based on the specified method - * @param {Object} run - The run object - * @param {Array} fields - Fields to extract - * @param {String} aggregation - 'last', 'max', 'min', or 'avg' (default: 'last') + * Extract values for specified fields from a run, with flexible aggregation for metrics + * + * This function searches for field values in multiple locations within a run object: + * 1. run.params - Single-value parameters (no aggregation) + * 2. run.config - Configuration values (no aggregation) + * 3. run.tags - Tag values (no aggregation) + * 4. run.structured_data - Series data that requires aggregation + * + * For series data (metrics logged over time), the aggregation parameter controls + * how multiple values are reduced to a single value for comparison. + * + * @param {object} run - The run object containing params, config, tags, and structured_data + * @param {Array} fields - Field names to extract (e.g., ["learning_rate", "train_loss"]) + * @param {string} [aggregation='last'] - How to aggregate series data: + * - 'last': Use final value (default, good for final metrics) + * - 'max': Use maximum value (good for accuracy, best performance) + * - 'min': Use minimum value (good for loss, error rates) + * - 'avg': Use average value (good for stability analysis) + * @returns {object} Object mapping field names to extracted values. Missing fields map to null. + * + * @example + * const run = { + * config: { learning_rate: 0.01 }, + * structured_data: { + * metrics: [{ + * primitive_type: 'series', + * data: { fields: { loss: [0.5, 0.3, 0.2] } } + * }] + * } + * }; + * extractValues(run, ['learning_rate', 'loss'], 'last'); + * // Returns { learning_rate: 0.01, loss: 0.2 } */ export function extractValues(run, fields, aggregation = 'last') { const values = {}; @@ -66,9 +112,6 @@ export function extractValues(run, fields, aggregation = 'last') { case 'min': values[field] = Math.min(...fieldValues); break; - case 'avg': - values[field] = fieldValues.reduce((a, b) => a + b, 0) / fieldValues.length; - break; case 'last': default: values[field] = fieldValues[fieldValues.length - 1]; @@ -86,13 +129,51 @@ export function extractValues(run, fields, aggregation = 'last') { } /** - * Calculate parameter importance using correlation analysis - * Returns importance scores for each hyperparameter relative to each metric + * Calculate parameter correlation using Pearson correlation analysis + * + * Analyzes the linear relationship between hyperparameters and metrics across multiple runs. + * This is useful for understanding which hyperparameters most strongly affect which metrics + * in a hyperparameter sweep. + * + * Algorithm: + * 1. For each metric and each hyperparameter pair: + * - Extract values from all runs using specified aggregation + * - Filter out runs where either value is missing + * - Calculate Pearson correlation coefficient (-1 to +1) + * - Store both correlation and absolute value (importance score) * - * Note: Full random forest implementation would require ML library - * For now, using Pearson correlation as a simpler alternative + * 2. Pearson correlation measures linear relationship: + * - +1.0: Perfect positive correlation (param ↑ β†’ metric ↑) + * - 0.0: No linear correlation + * - -1.0: Perfect negative correlation (param ↑ β†’ metric ↓) + * + * 3. Importance score is |correlation|, treating strong negative correlations + * as equally important to strong positive ones. + * + * Limitations: + * - Only detects LINEAR relationships (won't catch quadratic, exponential, etc.) + * - Correlation β‰  causation + * - Not true feature importance (use permutation importance, SHAP, or random forests for that) + * - Requires at least 3 runs for meaningful results + * - Categorical parameters are converted to ordinal (0, 1, 2...) which may be misleading + * + * @param {Array} runs - Array of run objects with params/config/structured_data + * @param {Array} hyperparameters - Parameter names to analyze (e.g., ["learning_rate", "batch_size"]) + * @param {Array} metrics - Metric names to analyze (e.g., ["accuracy", "loss"]) + * @param {string} [aggregation='last'] - How to aggregate metric series ('last', 'max', 'min') + * @returns {object|null} Nested object: { metric: { param: { correlation, importance } } } + * Returns null if fewer than 3 runs (insufficient for correlation) + * + * @example + * const runs = [ + * { config: { lr: 0.01 }, structured_data: {...} }, // final loss: 0.5 + * { config: { lr: 0.1 }, structured_data: {...} }, // final loss: 0.3 + * { config: { lr: 1.0 }, structured_data: {...} } // final loss: 0.8 + * ]; + * const result = calculateParameterCorrelation(runs, ['lr'], ['loss'], 'last'); + * // Returns: { loss: { lr: { correlation: 0.5, importance: 0.5 } } } */ -export function calculateParameterImportance(runs, hyperparameters, metrics, aggregation = 'last') { +export function calculateParameterCorrelation(runs, hyperparameters, metrics, aggregation = 'last') { if (runs.length < 3) { // Need at least 3 runs for meaningful correlation return null; @@ -133,7 +214,22 @@ export function calculateParameterImportance(runs, hyperparameters, metrics, agg } /** - * Calculate Pearson correlation coefficient + * Calculate Pearson correlation coefficient between two arrays + * + * Pearson correlation formula: r = Ξ£[(xi - xΜ„)(yi - Θ³)] / sqrt(Ξ£(xi - xΜ„)Β² * Ξ£(yi - Θ³)Β²) + * Simplified to: r = [n*Ξ£(xy) - Ξ£x*Ξ£y] / sqrt([n*Ξ£xΒ² - (Ξ£x)Β²] * [n*Ξ£yΒ² - (Ξ£y)Β²]) + * + * The coefficient ranges from -1 to +1: + * - r = +1: Perfect positive linear relationship + * - r = 0: No linear relationship + * - r = -1: Perfect negative linear relationship + * + * Handles categorical data by converting to ordinal indices (0, 1, 2, ...). + * + * @param {Array} x - First array of values + * @param {Array} y - Second array of values (must be same length as x) + * @returns {number} Correlation coefficient between -1 and 1, or 0 if calculation fails + * (e.g., arrays different lengths, < 2 values, zero variance) */ function calculateCorrelation(x, y) { const n = x.length; @@ -158,7 +254,19 @@ function calculateCorrelation(x, y) { } /** - * Convert values to numeric (handle categorical variables) + * Convert values to numeric, handling categorical variables + * + * For numeric arrays: returns as-is + * For categorical arrays: maps unique values to ordinal indices + * + * Example: ["adam", "sgd", "adam", "rmsprop"] β†’ [0, 1, 0, 2] + * + * Warning: This treats categorical values as ordinal, which may not be + * semantically correct (e.g., "adam" isn't "less than" "sgd"). Consider + * one-hot encoding for proper categorical analysis. + * + * @param {Array} values - Array of values to convert (can be numbers, strings, etc.) + * @returns {Array} Array of numeric values. Categorical values mapped to indices (0, 1, 2, ...) */ function convertToNumeric(values) { // Already numeric diff --git a/src/app/utils/metricAggregation.js b/src/app/utils/metricAggregation.js index 789f39b..79b3e5d 100644 --- a/src/app/utils/metricAggregation.js +++ b/src/app/utils/metricAggregation.js @@ -1,13 +1,34 @@ /** * Client-side metric aggregation utilities - * FULLY GENERIC - no assumptions about data structure + * + * Provides generic aggregation functions for extracting summary statistics from + * metrics logged across training runs. Designed to be fully data-structure agnostic, + * working with any x-axis field (epoch, step, timestamp, etc.) and any metric names. + * + * Key capabilities: + * - Auto-detect x-axis fields (epoch, step, etc.) + * - Aggregate metrics (min, max, final, best) + * - Extract metrics from Series primitives in structured_data + * - Find optimal values based on optimization metrics (e.g., accuracy at min loss) + * - Discover available metrics across runs + * + * Used primarily by the Tables tab for displaying run comparisons. */ /** - * Detect x-axis field from metrics array - * Generic: finds integer fields that appear in all entries - * @param {Array} metricsArray - Array of metric objects - * @returns {string|null} - X-axis field name or null + * Detect x-axis field from metrics array using heuristics + * + * Finds integer-valued fields that could represent the x-axis (epoch, step, iteration). + * If multiple candidates exist, prefers smaller values (epoch over timestamp). + * + * Algorithm: + * 1. Find all integer-valued fields (excludes internal fields starting with _) + * 2. If only one integer field, use it + * 3. If multiple, prefer field with smallest value (epoch=10 over timestamp=1674...) + * + * @param {Array} metricsArray - Array of metric records + * Example: [{ epoch: 0, loss: 0.5 }, { epoch: 1, loss: 0.3 }] + * @returns {string|null} X-axis field name (e.g., "epoch", "step") or null if none found */ const detectXAxisField = (metricsArray) => { if (!metricsArray || metricsArray.length === 0) return null; @@ -44,11 +65,33 @@ const getValueAtX = (metricsArray, metricKey, xValue, xAxisField) => { }; /** - * Generic aggregation function - replaces getBestValue, getMinValue, getMaxValue - * @param {Array} metricsArray - Array of metric objects - * @param {string} metricKey - The metric field name - * @param {string} mode - 'min', 'max', or 'best' (auto min/max based on metric name) - * @returns {{value: number, xValue: number, xField: string}|null} - Aggregated value and its x-axis point + * Generic aggregation function for finding min/max/best metric values + * + * Scans through metrics array and finds the optimal value according to the specified mode. + * Returns both the value and the x-axis coordinate where it occurred. + * + * Modes: + * - 'min': Find minimum value (good for loss, error) + * - 'max': Find maximum value (good for accuracy, F1) + * - 'best': Auto-detect based on metric name (min for loss/error, max for others) + * + * @param {Array} metricsArray - Array of metric records with x-axis and metric values + * @param {string} metricKey - The metric field to aggregate (e.g., "loss", "accuracy") + * @param {string} mode - Aggregation mode: 'min', 'max', or 'best' + * @returns {{value: number, xValue: number, xField: string}|null} Result object: + * - value: The aggregated metric value + * - xValue: X-axis value where this occurred (e.g., epoch 42) + * - xField: Name of x-axis field (e.g., "epoch") + * Returns null if no valid numeric values found + * + * @example + * const metrics = [ + * { epoch: 0, loss: 0.5 }, + * { epoch: 1, loss: 0.3 }, + * { epoch: 2, loss: 0.2 } + * ]; + * aggregateValue(metrics, 'loss', 'min'); + * // Returns: { value: 0.2, xValue: 2, xField: 'epoch' } */ const aggregateValue = (metricsArray, metricKey, mode) => { if (!metricsArray || metricsArray.length === 0) return null; @@ -60,8 +103,20 @@ const aggregateValue = (metricsArray, metricKey, mode) => { // Determine comparison function based on mode let shouldUpdate; if (mode === 'min') { + /** + * Check if new value is less than current value. + * @param {number} newVal - New value to compare + * @param {number} currentVal - Current value + * @returns {boolean} True if new value is less + */ shouldUpdate = (newVal, currentVal) => newVal < currentVal; } else if (mode === 'max') { + /** + * Check if new value is greater than current value. + * @param {number} newVal - New value to compare + * @param {number} currentVal - Current value + * @returns {boolean} True if new value is greater + */ shouldUpdate = (newVal, currentVal) => newVal > currentVal; } else if (mode === 'best') { const isLossMetric = metricKey.includes('loss') || metricKey.includes('error'); @@ -136,7 +191,7 @@ const getFinalValue = (metricsArray, metricKey) => { * Discover metrics organized by stream * Returns a map of streamId -> metric keys * @param {Array} runs - Array of run objects - * @returns {Object} - Object mapping streamId to array of metric keys + * @returns {object} - Object mapping streamId to array of metric keys */ export const discoverMetricsByStream = (runs) => { const seriesMetrics = {}; // seriesName -> Set of metric keys @@ -177,7 +232,7 @@ export const discoverMetricsByStream = (runs) => { /** * Discover Series primitives from structured_data (NEW - uses primitives) * @param {Array} runs - Array of run objects - * @returns {Object} - { seriesName: { metricKeys: [], runs: [{ run_id, data: [] }] } } + * @returns {object} - { seriesName: { metricKeys: [], runs: [{ run_id, data: [] }] } } */ /** @@ -203,15 +258,39 @@ const getMaxValue = (metricsArray, metricKey) => { }; /** - * Get value based on aggregation mode - * Searches structured_data Series primitives for the metric - * @param {Object} run - Run object - * @param {string} metricKey - The metric field to display - * @param {Object} aggregation - Aggregation configuration - * @param {string} aggregation.mode - 'min', 'max', 'final', or x-axis value - * @param {string} aggregation.optimizeMetric - Metric to optimize (for min/max modes) - * @param {string} aggregation.streamId - Optional: specific series to look in - * @returns {number|null} - Value at the specified mode + * Get metric value using specified aggregation strategy + * + * This is the main entry point for extracting a single aggregated metric value from a run. + * Used extensively by the Tables tab to display summary statistics. + * + * The function searches through the run's structured_data for Series primitives containing + * the requested metric, then applies the specified aggregation logic. + * + * Aggregation modes: + * - 'min': Return metricKey value at the point where optimizeMetric is minimized + * Example: "Show accuracy when loss was lowest" + * - 'max': Return metricKey value at the point where optimizeMetric is maximized + * Example: "Show loss when accuracy was highest" + * - 'final': Return the last logged value of metricKey + * - Numeric value: Return metricKey value at that specific x-axis point + * + * @param {object} run - Run object with structured_data containing Series primitives + * @param {string} metricKey - The metric to extract (e.g., "accuracy", "val_loss") + * @param {object} [aggregation={mode:'min',optimizeMetric:'loss'}] - Aggregation config: + * - mode: string - 'min', 'max', 'final', or numeric x-axis value + * - optimizeMetric: string - Metric to optimize for min/max modes + * - streamId: string (optional) - Specific Series primitive to search in + * @returns {number|null} Aggregated metric value, or null if metric not found + * + * @example + * // Get final validation loss + * getMetricValue(run, 'val_loss', { mode: 'final' }); + * + * // Get accuracy at the point where loss was minimized + * getMetricValue(run, 'accuracy', { mode: 'min', optimizeMetric: 'loss' }); + * + * // Get loss at epoch 50 + * getMetricValue(run, 'loss', { mode: 50 }); */ export const getMetricValue = (run, metricKey, aggregation = { mode: 'min', optimizeMetric: 'loss' }) => { if (!run.structured_data) return null; @@ -223,11 +302,11 @@ export const getMetricValue = (run, metricKey, aggregation = { mode: 'min', opti }; /** - * Extract metrics array from Series primitives in structured_data - * @param {Object} structured_data - Run's structured_data object + * Extract metrics array from Series primitives in structured_data. + * @param {object} structured_data - Run's structured_data object * @param {string} metricKey - The metric field to find * @param {string} seriesId - Optional: specific series to look in - * @returns {Array|null} - Array of metric objects [{x, metricKey}, ...] + * @returns {Array|null} Array of metric objects or null */ function extractMetricsFromSeries(structured_data, metricKey, seriesId = null) { for (const [name, entries] of Object.entries(structured_data)) { @@ -267,11 +346,11 @@ function extractMetricsFromSeries(structured_data, metricKey, seriesId = null) { } /** - * Apply aggregation logic to metrics array - * @param {Array} metricsArray - Array of metric objects + * Apply aggregation logic to metrics array. + * @param {Array} metricsArray - Array of metric objects * @param {string} metricKey - The metric field to display - * @param {Object} aggregation - Aggregation configuration - * @returns {number|null} - Aggregated value + * @param {object} aggregation - Aggregation configuration + * @returns {number|null} Aggregated value or null */ function applyAggregation(metricsArray, metricKey, aggregation) { const { mode, optimizeMetric } = aggregation; diff --git a/src/app/utils/plotDiscovery.js b/src/app/utils/plotDiscovery.js index 7690cb8..92a9ddf 100644 --- a/src/app/utils/plotDiscovery.js +++ b/src/app/utils/plotDiscovery.js @@ -1,20 +1,69 @@ /** * Universal plot discovery from structured data primitives * - * Maps data primitives to plot types: - * - Series β†’ line, scatter, area - * - Distribution β†’ histogram, violin, box - * - Matrix β†’ heatmap, clustergram - * - Graph β†’ network, tree - * - Table β†’ table, pivot - * - Events β†’ timeline, gantt - * - Media β†’ gallery, video_player + * This module implements automatic plot discovery - analyzing logged data primitives + * and determining the appropriate visualizations to display. It's the core algorithm + * that powers the "Plots" tab, automatically creating charts without manual configuration. + * + * Primitive β†’ Plot Type Mapping: + * - Series β†’ Line plots, scatter plots, multi-run overlays + * - Distribution β†’ Histograms, violin plots + * - Matrix β†’ Heatmaps + * - Scatter β†’ Scatter plots (2D point clouds) + * - Curve β†’ ROC curves, PR curves (with AUC metrics) + * - BarChart β†’ Bar charts (categorical comparisons) + * - Table β†’ Data tables + * + * Multi-run support: + * - Automatically detects when multiple runs log the same primitive + * - Creates overlay plots (multiple series on same chart) + * - Handles both single-run and multi-run cases transparently */ /** * Discover all plots from structured data, grouped by section - * @param {Object} structuredData - Object mapping name -> array of data entries - * @returns {Object} Object mapping section -> array of plot configs + * + * This is the main entry point for plot discovery. It analyzes all logged primitives + * in a run (or across multiple runs) and generates plot configurations that can be + * rendered by visualization components. + * + * Algorithm: + * 1. Iterate through all structured data entries (keyed by primitive name) + * 2. Detect if data is single-run or multi-run (check for _runId marker) + * 3. For each primitive: + * a. Extract primitive type (series, curve, distribution, etc.) + * b. Map to appropriate plot type(s) via mapPrimitiveToPlots() + * c. Transform data to plot-specific format + * d. Group by section (user-specified or "General") + * 4. Return section-organized plot configs ready for rendering + * + * Single-run vs Multi-run: + * - Single-run: Takes most recent entry for each primitive + * - Multi-run: Combines all runs' data into overlay plots + * + * @param {object} structuredData - Object mapping primitive names to data entries + * Format: { "primitiveName": [{ primitive_type, section, data, metadata, timestamp }] } + * Multi-run format: entries have _runId, _runName fields + * @returns {object} Plots organized by section + * Format: { "sectionName": [{ plotId, plotType, title, data, ... }] } + * + * @example + * const structuredData = { + * "metrics::loss": [{ + * primitive_type: "series", + * section: "Training", + * data: { index: "epoch", fields: { loss: [0.5, 0.3, 0.2] } } + * }] + * }; + * const plots = discoverPlots(structuredData); + * // Returns: { + * // "Training": [{ + * // plotId: "metrics::loss-line", + * // plotType: "line", + * // title: "Loss", + * // data: { series: [...] } + * // }] + * // } */ export function discoverPlots(structuredData) { if (!structuredData) return {}; @@ -72,12 +121,13 @@ export function discoverPlots(structuredData) { } /** - * Map a primitive to one or more plot configs + * Map a primitive to one or more plot configs. * @param {string} name - Primitive name * @param {string} primitiveType - Type of primitive - * @param {any} data - Either single-run data or array of multi-run entries + * @param {unknown} data - Either single-run data or array of multi-run entries * @param {object} metadata - Metadata (for single-run) * @param {boolean} isMultiRun - Whether this is multi-run data + * @returns {Array} Array of plot configuration objects */ function mapPrimitiveToPlots(name, primitiveType, data, metadata, isMultiRun = false) { const plots = []; @@ -209,7 +259,9 @@ function mapPrimitiveToPlots(name, primitiveType, data, metadata, isMultiRun = f } /** - * Format name for display (snake_case -> Title Case) + * Format name for display (snake_case -> Title Case). + * @param {string} name - Name to format + * @returns {string} Formatted title string */ function formatTitle(name) { return name @@ -219,7 +271,9 @@ function formatTitle(name) { } /** - * Transform Series to line plot format + * Transform Series to line plot format. + * @param {object} seriesData - Series data with index, fields, and index_values + * @returns {object} Transformed data for line plot */ function transformSeriesForLinePlot(seriesData) { const { index, fields, index_values } = seriesData; @@ -240,7 +294,9 @@ function transformSeriesForLinePlot(seriesData) { } /** - * Transform Distribution to histogram format + * Transform Distribution to histogram format. + * @param {object} distData - Distribution data with values, groups, and bins + * @returns {object} Transformed data for histogram */ function transformDistributionForHistogram(distData) { return { @@ -251,7 +307,9 @@ function transformDistributionForHistogram(distData) { } /** - * Transform Distribution to violin plot format + * Transform Distribution to violin plot format. + * @param {object} distData - Distribution data with values and groups + * @returns {object} Transformed data for violin plot */ function transformDistributionForViolin(distData) { // Group values by group label @@ -269,7 +327,9 @@ function transformDistributionForViolin(distData) { } /** - * Transform Curve to CurveChart format + * Transform Curve to CurveChart format. + * @param {object} curveData - Curve data with x, y, labels, baseline, and metric + * @returns {object} Transformed data for curve chart */ function transformCurveForCurveChart(curveData) { const { x, y, x_label, y_label, baseline, metric } = curveData; @@ -291,7 +351,9 @@ function transformCurveForCurveChart(curveData) { } /** - * Transform Scatter to ScatterPlot format + * Transform Scatter to ScatterPlot format. + * @param {object} scatterData - Scatter data with points, x_label, and y_label + * @returns {object} Transformed data for scatter plot */ function transformScatterForScatterPlot(scatterData) { const { points, x_label, y_label } = scatterData; @@ -330,7 +392,9 @@ function transformScatterForScatterPlot(scatterData) { /** * Transform multi-run Series to line plot format - * Overlays all runs on same chart with different colors + * Overlays all runs on same chart with different colors. + * @param {Array} runEntries - Array of run entries with data, _runName, and _runId + * @returns {object} Transformed data for multi-run line plot */ function transformMultiRunSeriesForLinePlot(runEntries) { const allDatasets = []; @@ -357,7 +421,9 @@ function transformMultiRunSeriesForLinePlot(runEntries) { /** * Transform multi-run Curve to curve chart format - * Overlays multiple ROC/PR curves with different colors + * Overlays multiple ROC/PR curves with different colors. + * @param {Array} runEntries - Array of run entries with curve data + * @returns {object} Transformed data for multi-run curve chart */ function transformMultiRunCurveForCurveChart(runEntries) { // For curves, we need to return multiple curve datasets diff --git a/src/app/utils/sweepDetection.js b/src/app/utils/sweepDetection.js index e8bac1f..158d6b5 100644 --- a/src/app/utils/sweepDetection.js +++ b/src/app/utils/sweepDetection.js @@ -1,17 +1,61 @@ /** * Sweep Detection Utilities * - * Detects valid hyperparameter sweeps across runs for comparison visualization. - * A valid sweep has: - * 1. All runs share the same config keys - * 2. Exactly ONE parameter varies across runs - * 3. All other parameters are constant + * Detects valid hyperparameter sweeps across experiment runs for comparison visualization. + * + * A sweep is a set of runs where hyperparameters are systematically varied to understand + * their effect on metrics. This module validates sweep structure and extracts metadata + * needed for sweep visualization (parallel coordinates, scatter plots, correlation analysis). + * + * Validation criteria: + * 1. All runs must have a config object + * 2. All runs share the same config keys (consistent structure) + * 3. At least one parameter varies across runs + * 4. At least 2 runs (need comparison) + * + * The module also extracts final metrics from each run's structured_data for comparison. */ /** * Detect if selected runs form a valid hyperparameter sweep - * @param {Array} runs - Array of run objects with config - * @returns {Object|null} - Sweep info or null if invalid + * + * Algorithm: + * 1. Filter runs that have non-empty config + * 2. Verify all runs have identical config key structure + * 3. For each config key, check if values vary across runs + * 4. Classify parameters as varying (different values) or constant (same value) + * 5. Extract final metric values from each run's structured_data + * 6. Return sweep metadata including varying params, metrics, and transformed runs + * + * A sweep is valid if: + * - At least 2 runs with config + * - All configs have same keys + * - At least 1 parameter varies + * + * @param {Array} runs - Array of run objects, each with: + * - run_id: unique identifier + * - name: run name + * - config: object with hyperparameters + * - structured_data: logged metrics/plots + * @returns {object|null} Sweep detection result: + * - If valid: { valid: true, varyingParams, constantParams, runs, availableMetrics, ... } + * - If invalid: { valid: false, reason, message } + * - If insufficient runs: null + * + * @example + * const runs = [ + * { config: { lr: 0.01, batch: 32 }, structured_data: {...} }, + * { config: { lr: 0.1, batch: 32 }, structured_data: {...} }, + * { config: { lr: 1.0, batch: 32 }, structured_data: {...} } + * ]; + * const sweep = detectSweep(runs); + * // Returns: { + * // valid: true, + * // varyingParams: [{ name: 'lr', values: [0.01, 0.1, 1.0], isNumeric: true }], + * // constantParams: [{ name: 'batch', value: 32 }], + * // runs: [...], // enriched with metrics + * // availableMetrics: ['loss', 'accuracy'] + * // } */ export function detectSweep(runs) { if (!runs || runs.length < 2) { @@ -103,9 +147,19 @@ export function detectSweep(runs) { } /** - * Extract final metrics from a run's structured_data - * @param {Object} run - Run object - * @returns {Object} - Map of metric name to final value + * Extract final metric values from a run's structured_data + * + * Searches through all structured data entries (Series, Curve primitives) and extracts + * the final/summary value for each metric. This is used to create a single metric + * value per run for sweep comparison. + * + * Extraction logic: + * - Series primitives: Takes the LAST value from each field array (final epoch value) + * - Curve primitives: Extracts the summary metric (e.g., AUC from ROC curve) + * + * @param {object} run - Run object with structured_data property + * @returns {object} Map of metric names to final numeric values + * Example: { loss: 0.23, accuracy: 0.95, auc: 0.88 } */ function extractFinalMetrics(run) { const metrics = {}; @@ -140,10 +194,25 @@ function extractFinalMetrics(run) { } /** - * Extract all available metrics across runs - * Only includes metrics that have valid numeric values in at least one run - * @param {Array} runs - Runs with extracted metrics - * @returns {Array} - Array of unique metric names + * Extract all available metric names across multiple runs + * + * Collects the union of all metric names that have valid numeric values in at least + * one run. Filters out null, undefined, NaN, and non-numeric values. + * + * This is used to populate metric selector dropdowns in sweep visualizations, ensuring + * users only see metrics that actually have data. + * + * @param {Array} runs - Runs with extracted metrics (output from extractFinalMetrics) + * Each run should have shape: { metrics: { metricName: numericValue, ... } } + * @returns {Array} Sorted array of unique metric names that have valid numeric data + * + * @example + * const runs = [ + * { metrics: { loss: 0.5, accuracy: 0.9 } }, + * { metrics: { loss: 0.3, f1: 0.85 } } + * ]; + * extractAvailableMetrics(runs); + * // Returns: ['accuracy', 'f1', 'loss'] (sorted alphabetically) */ function extractAvailableMetrics(runs) { const allMetrics = new Set(); diff --git a/src/core/api/ApiClient.js b/src/core/api/ApiClient.js index bdd748c..3f7e81d 100644 --- a/src/core/api/ApiClient.js +++ b/src/core/api/ApiClient.js @@ -4,15 +4,25 @@ * Eliminates duplicate fetch/try-catch patterns across the codebase. */ -const API_BASE_URL = import.meta.env.VITE_API_URL; +const API_BASE_URL = import.meta.env.VITE_API_URL || ''; +/** + * Centralized API client for backend communication. + */ class ApiClient { + /** + * Creates a new ApiClient instance. + * @param {string} baseUrl - Base URL for API requests + */ constructor(baseUrl = API_BASE_URL) { this.baseUrl = baseUrl; } /** - * Generic request method with centralized error handling + * Generic request method with centralized error handling. + * @param {string} endpoint - API endpoint path + * @param {object} options - Fetch options + * @returns {Promise} Response object */ async request(endpoint, options = {}) { const url = `${this.baseUrl}${endpoint}`; @@ -26,7 +36,9 @@ class ApiClient { } /** - * GET request returning JSON + * GET request returning JSON. + * @param {string} endpoint - API endpoint path + * @returns {Promise} Parsed JSON response */ async get(endpoint) { const response = await this.request(endpoint); @@ -34,7 +46,9 @@ class ApiClient { } /** - * GET request returning raw Response (for CSV, images, etc.) + * GET request returning raw Response (for CSV, images, etc.). + * @param {string} endpoint - API endpoint path + * @returns {Promise} Raw Response object */ async getRaw(endpoint) { return this.request(endpoint); @@ -43,7 +57,13 @@ class ApiClient { // ========== Endpoints Actually Used ========== /** - * Get all runs with filters + * Get all runs with filters. + * @param {object} options - Query options + * @param {number} options.limit - Maximum number of runs to return + * @param {boolean} options.includeTags - Include tags in response + * @param {boolean} options.includeParams - Include parameters in response + * @param {boolean} options.includeMetadata - Include metadata in response + * @returns {Promise>} Array of run objects */ async getRuns(options = {}) { const { @@ -62,35 +82,47 @@ class ApiClient { } /** - * Get single run with full details + * Get single run with full details. + * @param {string} runId - Run ID + * @returns {Promise} Run object with full details */ async getRun(runId) { return this.get(`/api/runs/${runId}`); } /** - * Get metrics for a run + * Get metrics for a run. + * @param {string} runId - Run ID + * @returns {Promise} Metrics data */ async getRunMetrics(runId) { return this.get(`/api/runs/${runId}/metrics`); } /** - * Get artifacts for a run + * Get artifacts for a run. + * @param {string} runId - Run ID + * @returns {Promise>} Array of artifact objects */ async getArtifacts(runId) { return this.get(`/api/artifacts/${runId}`); } /** - * Get artifact links (with input/output role) for a run + * Get artifact links (with input/output role) for a run. + * @param {string} runId - Run ID + * @returns {Promise>} Array of artifact link objects */ async getArtifactLinks(runId) { return this.get(`/api/runs/${runId}/artifact-links`); } /** - * Get artifact preview (raw response for CSV/JSON/image handling) + * Get artifact preview (raw response for CSV/JSON/image handling). + * @param {string} artifactId - Artifact ID + * @param {number} offset - Offset for pagination + * @param {number} limit - Maximum number of items + * @returns {Promise} Raw response for artifact preview */ async getArtifactPreview(artifactId, offset = 0, limit = 100) { const params = new URLSearchParams({ @@ -101,7 +133,9 @@ class ApiClient { } /** - * Get artifact download URL + * Get artifact download URL. + * @param {string} artifactId - Artifact ID + * @returns {string} Download URL for artifact */ getArtifactDownloadUrl(artifactId) { return `${this.baseUrl}/api/artifact/${artifactId}/download`; @@ -110,14 +144,17 @@ class ApiClient { // ========== Project Endpoints ========== /** - * Get all projects (explicit + implicit from runs) + * Get all projects (explicit + implicit from runs). + * @returns {Promise>} Array of project objects */ async getProjects() { return this.get('/api/projects'); } /** - * Create a new project + * Create a new project. + * @param {string} projectId - Project ID + * @returns {Promise} Created project object */ async createProject(projectId) { const response = await this.request('/api/projects', { @@ -131,21 +168,29 @@ class ApiClient { // ========== Project Notes Endpoints ========== /** - * Get all notes for a project + * Get all notes for a project. + * @param {string} projectId - Project ID + * @returns {Promise>} Array of note objects */ async getProjectNotes(projectId) { return this.get(`/api/projects/${projectId}/notes`); } /** - * Get a specific note + * Get a specific note. + * @param {string} projectId - Project ID + * @param {string} noteId - Note ID + * @returns {Promise} Note object */ async getProjectNote(projectId, noteId) { return this.get(`/api/projects/${projectId}/notes/${noteId}`); } /** - * Create a new note + * Create a new note. + * @param {string} projectId - Project ID + * @param {object} data - Note data + * @returns {Promise} Created note object */ async createProjectNote(projectId, data) { const response = await this.request(`/api/projects/${projectId}/notes`, { @@ -157,7 +202,11 @@ class ApiClient { } /** - * Update an existing note + * Update an existing note. + * @param {string} projectId - Project ID + * @param {string} noteId - Note ID + * @param {object} data - Updated note data + * @returns {Promise} Updated note object */ async updateProjectNote(projectId, noteId, data) { const response = await this.request(`/api/projects/${projectId}/notes/${noteId}`, { @@ -169,7 +218,10 @@ class ApiClient { } /** - * Delete a note + * Delete a note. + * @param {string} projectId - Project ID + * @param {string} noteId - Note ID + * @returns {Promise} */ async deleteProjectNote(projectId, noteId) { await this.request(`/api/projects/${projectId}/notes/${noteId}`, { @@ -178,14 +230,19 @@ class ApiClient { } /** - * Get attachments for a note + * Get attachments for a note. + * @param {string} noteId - Note ID + * @returns {Promise>} Array of attachment objects */ async getNoteAttachments(noteId) { return this.get(`/api/notes/${noteId}/attachments`); } /** - * Upload attachment to a note + * Upload attachment to a note. + * @param {string} noteId - Note ID + * @param {File} file - File to upload + * @returns {Promise} Uploaded attachment object */ async uploadAttachment(noteId, file) { const formData = new FormData(); @@ -199,7 +256,9 @@ class ApiClient { } /** - * Delete an attachment + * Delete an attachment. + * @param {string} attachmentId - Attachment ID + * @returns {Promise} */ async deleteAttachment(attachmentId) { await this.request(`/api/attachments/${attachmentId}`, { @@ -208,7 +267,9 @@ class ApiClient { } /** - * Get attachment download URL + * Get attachment download URL. + * @param {string} attachmentId - Attachment ID + * @returns {string} Download URL for attachment */ getAttachmentDownloadUrl(attachmentId) { return `${this.baseUrl}/api/attachments/${attachmentId}/download`; diff --git a/src/core/utils/csvExport.js b/src/core/utils/csvExport.js index a9b990b..f088141 100644 --- a/src/core/utils/csvExport.js +++ b/src/core/utils/csvExport.js @@ -4,8 +4,8 @@ */ /** - * Escapes a field for CSV format - * @param {*} field - The field to escape + * Escapes a field for CSV format. + * @param {string|number|boolean} field - The field to escape * @returns {string} Escaped field */ function escapeCSV(field) { @@ -17,8 +17,8 @@ function escapeCSV(field) { } /** - * Formats a value for CSV export - * @param {*} value - The value to format + * Formats a value for CSV export. + * @param {string|number|null|undefined} value - The value to format * @param {number} decimals - Number of decimal places for numbers (default: 4) * @returns {string} Formatted value */ @@ -74,9 +74,9 @@ function getDisplayName(key) { /** * Groups series by their metric signature - * Series with the same set of metrics are grouped together - * @param {Object} metricsByStream - Object mapping stream names to metric keys - * @returns {Object} Series groups indexed by signature + * Series with the same set of metrics are grouped together. + * @param {object} metricsByStream - Object mapping stream names to metric keys + * @returns {object} Series groups indexed by signature */ function groupSeriesBySignature(metricsByStream) { const seriesGroups = {}; @@ -105,12 +105,12 @@ function filterRunsWithSeries(runs) { } /** - * Builds CSV rows for a specific series group - * @param {Array} runs - Runs to include in this group - * @param {Array} metricKeys - Metric keys for this group - * @param {Array} seriesGroup - Series information for this group - * @param {Function} getMetricValue - Function to extract metric values - * @param {Object} aggSettings - Aggregation settings (mode, optimizeMetric) + * Builds CSV rows for a specific series group. + * @param {Array} runs - Runs to include in this group + * @param {Array} metricKeys - Metric keys for this group + * @param {Array} seriesGroup - Series information for this group + * @param {(run: object, key: string, options: object) => string|number} getMetricValue - Function to extract metric values + * @param {object} aggSettings - Aggregation settings (mode, optimizeMetric) * @returns {Array>} 2D array of CSV rows (header + data) */ function buildSeriesGroupCSV(runs, metricKeys, seriesGroup, getMetricValue, aggSettings) { @@ -152,11 +152,12 @@ function getSeriesGroupFilename(seriesGroup) { /** * Exports all series groups as separate CSV files - * Main orchestration function that handles the complete export workflow - * @param {Array} runs - All runs to export - * @param {Object} metricsByStream - Metrics organized by stream - * @param {Object} streamAggregation - Aggregation settings per group - * @param {Function} getMetricValue - Function to extract metric values + * Main orchestration function that handles the complete export workflow. + * @param {Array} runs - All runs to export + * @param {object} metricsByStream - Metrics organized by stream + * @param {object} streamAggregation - Aggregation settings per group + * @param {(run: object, key: string, options: object) => string|number} getMetricValue - Function to extract metric values + * @returns {void} */ export function exportSeriesGroupsAsCSV(runs, metricsByStream, streamAggregation, getMetricValue) { if (runs.length === 0) return; diff --git a/src/core/utils/formatters.js b/src/core/utils/formatters.js index fe8cdac..472de79 100644 --- a/src/core/utils/formatters.js +++ b/src/core/utils/formatters.js @@ -5,8 +5,7 @@ /** * Format number for Y-axis labels with smart units - * Handles bytes specially, and uses K/M/B/T suffixes for large numbers - * + * Handles bytes specially, and uses K/M/B/T suffixes for large numbers. * @param {number} value - The value to format * @param {string} title - Optional title to determine if value represents bytes * @returns {string} Formatted string @@ -40,8 +39,7 @@ export function formatYAxisValue(value, title = '') { /** * Convert snake_case or kebab-case to Title Case - * Example: "train_loss" -> "Train Loss" - * + * Example: "train_loss" -> "Train Loss". * @param {string} str - String to convert * @returns {string} Title cased string */ diff --git a/src/utils/consoleFilter.js b/src/utils/consoleFilter.js index 931e94c..9d3f8ed 100644 --- a/src/utils/consoleFilter.js +++ b/src/utils/consoleFilter.js @@ -4,6 +4,10 @@ export function initConsoleFilter() { const originalWarn = console.warn; + /** + * Wrapper function to filter console warnings from React Flow and React DevTools. + * @param {...unknown} args - Arguments passed to console.warn + */ console.warn = function(...args) { const message = args[0]; const shouldFilter = typeof message === 'string' && ( diff --git a/src/utils/debugLogger.js b/src/utils/debugLogger.js index 6e38a0a..2ca2b8c 100644 --- a/src/utils/debugLogger.js +++ b/src/utils/debugLogger.js @@ -3,7 +3,13 @@ * Enabled via VITE_DEBUG_LOGS=true environment variable */ +/** + * DebugLogger class that captures console logs and saves them to disk. + */ class DebugLogger { + /** + * Creates a new DebugLogger instance. + */ constructor() { this.logs = []; this.enabled = import.meta.env.VITE_DEBUG_LOGS === 'true'; @@ -17,12 +23,19 @@ class DebugLogger { } } + /** + * Intercepts console methods to capture logs. + */ interceptConsole() { const self = this; const methods = ['log', 'warn', 'error', 'info', 'debug']; methods.forEach(method => { const original = console[method]; + /** + * Wrapper function for console methods. + * @param {...unknown} args - Console method arguments + */ console[method] = function(...args) { // Filter out React Flow and React DevTools warnings const message = args[0]; @@ -63,8 +76,14 @@ class DebugLogger { }); } + /** + * Starts automatic writing of logs to disk every 3 seconds. + */ startAutoWrite() { // Write every 3 seconds + /** + * Interval callback to write logs to disk. + */ setInterval(() => { this.writeToDisk(); }, 3000); @@ -75,6 +94,9 @@ class DebugLogger { }); } + /** + * Writes captured logs to disk via API endpoint. + */ writeToDisk() { if (this.logs.length === 0) return; @@ -108,6 +130,9 @@ class DebugLogger { }); } + /** + * Downloads logs as a text file to the browser. + */ downloadLogs() { const timestamp = new Date().toISOString().replace(/[:.]/g, '-'); const filename = `artifacta-logs-${timestamp}.txt`; @@ -130,7 +155,9 @@ class DebugLogger { console.log(`[DebugLogger] Downloaded ${this.logs.length} logs to ${filename}`); } - // Expose method to download logs programmatically + /** + * Static method to download logs programmatically from global instance. + */ static download() { if (window.__debugLogger) { window.__debugLogger.downloadLogs(); @@ -143,5 +170,9 @@ class DebugLogger { // Initialize and expose globally if (import.meta.env.VITE_DEBUG_LOGS === 'true') { window.__debugLogger = new DebugLogger(); + /** + * Global function to download debug logs + * @returns {void} + */ window.downloadLogs = () => DebugLogger.download(); } diff --git a/tests/autolog/test_checkpoint_e2e.py b/tests/autolog/test_checkpoint_e2e.py index f666e40..099c7b1 100644 --- a/tests/autolog/test_checkpoint_e2e.py +++ b/tests/autolog/test_checkpoint_e2e.py @@ -69,11 +69,11 @@ def configure_optimizers(self): # The callback should have logged 2 checkpoints (one per epoch) ds_callback = None for cb in trainer.callbacks: - if type(cb).__name__ == "ArtifactaCheckpointCallback": + if type(cb).__name__ == "ArtifactaAutologCallback": ds_callback = cb break - assert ds_callback is not None, "ArtifactaCheckpointCallback not found" + assert ds_callback is not None, "ArtifactaAutologCallback not found" assert len(ds_callback.checkpoints_logged) == 2, ( f"Expected 2 checkpoints logged, got {len(ds_callback.checkpoints_logged)}" ) @@ -142,11 +142,11 @@ def configure_optimizers(self): # Verify only best checkpoints were logged ds_callback = None for cb in trainer.callbacks: - if type(cb).__name__ == "ArtifactaCheckpointCallback": + if type(cb).__name__ == "ArtifactaAutologCallback": ds_callback = cb break - assert ds_callback is not None, "ArtifactaCheckpointCallback not found" + assert ds_callback is not None, "ArtifactaAutologCallback not found" # With save_best_only, should log fewer than total epochs (unless loss perfectly decreases) # At minimum, should log at least 1 checkpoint (first epoch) assert len(ds_callback.checkpoints_logged) >= 1, ( diff --git a/tests/autolog/test_pytorch_lightning.py b/tests/autolog/test_pytorch_lightning.py index 7f48427..15e2377 100644 --- a/tests/autolog/test_pytorch_lightning.py +++ b/tests/autolog/test_pytorch_lightning.py @@ -13,6 +13,22 @@ pytestmark = pytest.mark.skipif(skip_pytorch, reason="PyTorch Lightning not installed") +@pytest.fixture +def temp_run(monkeypatch): + """Create and cleanup temporary run with mocked HTTP emitter.""" + from artifacta import init + from artifacta.tests.test_utils import MockHTTPEmitter + + # Temporarily disable strict mode so init doesn't fail without server + monkeypatch.delenv("ARTIFACTA_STRICT_MODE", raising=False) + + run = init(project="test_pytorch_lightning_autolog", name="test_run") + # Replace the http_emitter with our mock + run.http_emitter = MockHTTPEmitter(run.id) + yield run + run.finish() + + def test_pytorch_lightning_callback_injection(): """Test that autolog injects callback into Trainer""" import pytorch_lightning as pl @@ -47,8 +63,8 @@ def configure_optimizers(self): # Check that callback was injected callback_names = [type(cb).__name__ for cb in trainer.callbacks] - assert "ArtifactaCheckpointCallback" in callback_names, ( - f"ArtifactaCheckpointCallback not found in {callback_names}" + assert "ArtifactaAutologCallback" in callback_names, ( + f"ArtifactaAutologCallback not found in {callback_names}" ) # Cleanup @@ -113,7 +129,7 @@ def test_enable_disable_autolog(): # Create trainer - should have callback trainer1 = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False) callback_names1 = [type(cb).__name__ for cb in trainer1.callbacks] - assert "ArtifactaCheckpointCallback" in callback_names1 + assert "ArtifactaAutologCallback" in callback_names1 # Disable ds.disable_autolog() @@ -121,29 +137,224 @@ def test_enable_disable_autolog(): # Create trainer - should NOT have callback trainer2 = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False) callback_names2 = [type(cb).__name__ for cb in trainer2.callbacks] - assert "ArtifactaCheckpointCallback" not in callback_names2 + assert "ArtifactaAutologCallback" not in callback_names2 -def test_checkpoint_config_options(): - """Test checkpoint configuration options""" +def test_parameter_logging(temp_run): + """Test that autolog logs parameters (epochs, optimizer config)""" import pytorch_lightning as pl + import torch + from torch import nn + from torch.utils.data import DataLoader, TensorDataset import artifacta as ds - # Test with custom config + # Enable autolog ds.autolog(framework="pytorch") - trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False) + # Create simple model + class DummyModel(pl.LightningModule): + def __init__(self): + super().__init__() + self.layer = nn.Linear(10, 1) + + def forward(self, x): + return self.layer(x) + + def training_step(self, batch, _): + x, y = batch + y_hat = self(x) + loss = nn.functional.mse_loss(y_hat, y) + self.log("train_loss", loss) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.002, weight_decay=1e-5) + + # Create dummy data + x = torch.randn(100, 10) + y = torch.randn(100, 1) + dataset = TensorDataset(x, y) + dataloader = DataLoader(dataset, batch_size=10) + + # Train + model = DummyModel() + trainer = pl.Trainer(max_epochs=3, logger=False, enable_checkpointing=False) + trainer.fit(model, dataloader) + + # Verify parameters were added to config + config = temp_run.config + assert config["epochs"] == 3 + assert config["optimizer_name"] == "Adam" + assert config["lr"] == 0.002 + assert config["weight_decay"] == 1e-5 + + # Cleanup + ds.disable_autolog() + + +def test_metric_logging(temp_run): + """Test that autolog logs metrics per epoch""" + import pytorch_lightning as pl + import torch + from torch import nn + from torch.utils.data import DataLoader, TensorDataset + + import artifacta as ds + + # Enable autolog + ds.autolog(framework="pytorch") + + # Create simple model that logs metrics + class DummyModel(pl.LightningModule): + def __init__(self): + super().__init__() + self.layer = nn.Linear(10, 1) + + def forward(self, x): + return self.layer(x) + + def training_step(self, batch, _): + x, y = batch + y_hat = self(x) + loss = nn.functional.mse_loss(y_hat, y) + self.log("train_loss", loss) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.001) + + # Create dummy data + x = torch.randn(100, 10) + y = torch.randn(100, 1) + dataset = TensorDataset(x, y) + dataloader = DataLoader(dataset, batch_size=10) + + # Train + model = DummyModel() + trainer = pl.Trainer(max_epochs=3, logger=False, enable_checkpointing=False) + trainer.fit(model, dataloader) + + # Verify metrics were logged as Series data + # Check emitted_data for "training_metrics" series + logged_data = False + if hasattr(temp_run.http_emitter, 'emitted_data'): + for event_type, data in temp_run.http_emitter.emitted_data: + if event_type == "structured_data" and data.get("name") == "training_metrics": + logged_data = True + # Verify it has train_loss field in the fields dict + series_data = data.get("data", {}) + fields = series_data.get("fields", series_data) + assert "train_loss" in fields, "Should have train_loss in metrics" + break + assert logged_data, "Should have logged training_metrics as Series" + + # Cleanup + ds.disable_autolog() + + +def test_final_model_logging(temp_run): + """Test that autolog logs final trained model""" + import pytorch_lightning as pl + import torch + from torch import nn + from torch.utils.data import DataLoader, TensorDataset + + import artifacta as ds + + # Enable autolog with model logging + ds.autolog(framework="pytorch") + + # Create simple model + class DummyModel(pl.LightningModule): + def __init__(self): + super().__init__() + self.layer = nn.Linear(10, 1) + + def forward(self, x): + return self.layer(x) + + def training_step(self, batch, _): + x, y = batch + y_hat = self(x) + loss = nn.functional.mse_loss(y_hat, y) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.001) + + # Create dummy data + x = torch.randn(100, 10) + y = torch.randn(100, 1) + dataset = TensorDataset(x, y) + dataloader = DataLoader(dataset, batch_size=10) + + # Train + model = DummyModel() + trainer = pl.Trainer(max_epochs=2, logger=False, enable_checkpointing=False) + trainer.fit(model, dataloader) + + # Verify final model was logged + artifacts = temp_run.http_emitter.emitted_artifacts + model_artifacts = [a for a in artifacts if a.get("name") == "model"] + assert len(model_artifacts) == 1, "Should have logged final model" + assert model_artifacts[0]["metadata"]["artifact_type"] == "model" + assert model_artifacts[0]["metadata"]["framework"] == "pytorch_lightning" + + # Cleanup + ds.disable_autolog() + + +def test_disable_checkpoints(temp_run): + """Test disabling checkpoint logging""" + import pytorch_lightning as pl + import torch + from artifacta.integrations import pytorch_lightning as pl_integration + from torch import nn + from torch.utils.data import DataLoader, TensorDataset + + import artifacta as ds + + # Enable autolog with checkpoints disabled + pl_integration.enable_autolog(log_checkpoints=False, log_models=True) + + # Create simple model + class DummyModel(pl.LightningModule): + def __init__(self): + super().__init__() + self.layer = nn.Linear(10, 1) + + def forward(self, x): + return self.layer(x) + + def training_step(self, batch, _): + x, y = batch + y_hat = self(x) + loss = nn.functional.mse_loss(y_hat, y) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.001) + + # Create dummy data + x = torch.randn(100, 10) + y = torch.randn(100, 1) + dataset = TensorDataset(x, y) + dataloader = DataLoader(dataset, batch_size=10) + + # Train + model = DummyModel() + trainer = pl.Trainer(max_epochs=2, logger=False, enable_checkpointing=False) + trainer.fit(model, dataloader) - # Find Artifacta callback - ds_callback = None - for cb in trainer.callbacks: - if type(cb).__name__ == "ArtifactaCheckpointCallback": - ds_callback = cb - break + # Verify no checkpoints logged + artifacts = temp_run.http_emitter.emitted_artifacts + checkpoint_artifacts = [a for a in artifacts if "checkpoint" in a.get("name", "")] + assert len(checkpoint_artifacts) == 0, "Should not log checkpoints when disabled" - assert ds_callback is not None - # Config options no longer passed to autolog - just verify callback exists + # But final model should still be logged + model_artifacts = [a for a in artifacts if a.get("name") == "model"] + assert len(model_artifacts) == 1, "Should still log final model" # Cleanup ds.disable_autolog() diff --git a/tests/autolog/test_tensorflow.py b/tests/autolog/test_tensorflow.py index 3c0b575..19d6c30 100644 --- a/tests/autolog/test_tensorflow.py +++ b/tests/autolog/test_tensorflow.py @@ -13,13 +13,32 @@ pytestmark = pytest.mark.skipif(skip_tf, reason="TensorFlow not installed") -def test_tensorflow_autolog_e2e(): +@pytest.fixture +def temp_run(monkeypatch): + """Create and cleanup temporary run with mocked HTTP emitter.""" + from artifacta import init + from artifacta.tests.test_utils import MockHTTPEmitter + + # Temporarily disable strict mode so init doesn't fail without server + monkeypatch.delenv("ARTIFACTA_STRICT_MODE", raising=False) + + run = init(project="test_tensorflow_autolog", name="test_run") + # Replace the http_emitter with our mock + run.http_emitter = MockHTTPEmitter(run.id) + yield run + run.finish() + + +def test_tensorflow_autolog_e2e(monkeypatch): """End-to-end test: Train Keras model and verify checkpoint is auto-logged""" import numpy as np import tensorflow as tf import artifacta as ds + # Temporarily disable strict mode + monkeypatch.delenv("ARTIFACTA_STRICT_MODE", raising=False) + # Enable autolog ds.autolog(framework="tensorflow") @@ -47,13 +66,16 @@ def test_tensorflow_autolog_e2e(): ds.disable_autolog() -def test_tensorflow_best_only(): +def test_tensorflow_best_only(monkeypatch): """Test that save_best_only only logs improving checkpoints""" import numpy as np import tensorflow as tf import artifacta as ds + # Temporarily disable strict mode + monkeypatch.delenv("ARTIFACTA_STRICT_MODE", raising=False) + # Enable autolog with save_best_only ds.autolog(framework="tensorflow") @@ -98,3 +120,152 @@ def test_enable_disable_tensorflow(): # Should still be callable assert callable(tf.keras.Model.fit) + + +def test_parameter_logging(temp_run): + """Test that autolog logs parameters (epochs, batch_size, optimizer config)""" + import numpy as np + import tensorflow as tf + + import artifacta as ds + + # Enable autolog + ds.autolog(framework="tensorflow") + + # Create simple model + model = tf.keras.Sequential([ + tf.keras.layers.Dense(10, input_shape=(5,)), + tf.keras.layers.Dense(1) + ]) + model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.002), loss="mse") + + # Create dummy data + x_train = np.random.randn(100, 5).astype(np.float32) + y_train = np.random.randn(100, 1).astype(np.float32) + + # Train + model.fit(x_train, y_train, epochs=3, batch_size=16, verbose=0) + + # Verify parameters were added to config + config = temp_run.config + assert config["epochs"] == 3 + assert config["batch_size"] == 16 + assert config["optimizer_name"] == "Adam" + assert "learning_rate" in config or "lr" in config + + # Cleanup + ds.disable_autolog() + + +def test_metric_logging(temp_run): + """Test that autolog logs metrics per epoch""" + import numpy as np + import tensorflow as tf + + import artifacta as ds + + # Enable autolog + ds.autolog(framework="tensorflow") + + # Create simple model + model = tf.keras.Sequential([ + tf.keras.layers.Dense(10, input_shape=(5,)), + tf.keras.layers.Dense(1) + ]) + model.compile(optimizer="adam", loss="mse", metrics=["mae"]) + + # Create dummy data + x_train = np.random.randn(100, 5).astype(np.float32) + y_train = np.random.randn(100, 1).astype(np.float32) + + # Train + model.fit(x_train, y_train, epochs=3, batch_size=32, verbose=0) + + # Verify metrics were logged as Series data + logged_data = False + if hasattr(temp_run.http_emitter, 'emitted_data'): + for event_type, data in temp_run.http_emitter.emitted_data: + if event_type == "structured_data" and data.get("name") == "training_metrics": + logged_data = True + # Verify it has loss field + series_data = data.get("data", {}) + fields = series_data.get("fields", series_data) + assert "loss" in fields, "Should have loss in metrics" + break + assert logged_data, "Should have logged training_metrics as Series" + + # Cleanup + ds.disable_autolog() + + +def test_final_model_logging(temp_run): + """Test that autolog logs final trained model""" + import numpy as np + import tensorflow as tf + + import artifacta as ds + + # Enable autolog with model logging + ds.autolog(framework="tensorflow") + + # Create simple model + model = tf.keras.Sequential([ + tf.keras.layers.Dense(10, input_shape=(5,)), + tf.keras.layers.Dense(1) + ]) + model.compile(optimizer="adam", loss="mse") + + # Create dummy data + x_train = np.random.randn(100, 5).astype(np.float32) + y_train = np.random.randn(100, 1).astype(np.float32) + + # Train + model.fit(x_train, y_train, epochs=2, verbose=0) + + # Verify final model was logged + artifacts = temp_run.http_emitter.emitted_artifacts + model_artifacts = [a for a in artifacts if a.get("name") == "model"] + assert len(model_artifacts) == 1, "Should have logged final model" + assert model_artifacts[0]["metadata"]["artifact_type"] == "model" + assert model_artifacts[0]["metadata"]["framework"] == "tensorflow" + + # Cleanup + ds.disable_autolog() + + +def test_disable_checkpoints(temp_run): + """Test disabling checkpoint logging""" + import numpy as np + import tensorflow as tf + from artifacta.integrations import tensorflow as tf_integration + + import artifacta as ds + + # Enable autolog with checkpoints disabled + tf_integration.enable_autolog(log_checkpoints=False, log_models=True) + + # Create simple model + model = tf.keras.Sequential([ + tf.keras.layers.Dense(10, input_shape=(5,)), + tf.keras.layers.Dense(1) + ]) + model.compile(optimizer="adam", loss="mse") + + # Create dummy data + x_train = np.random.randn(100, 5).astype(np.float32) + y_train = np.random.randn(100, 1).astype(np.float32) + + # Train + model.fit(x_train, y_train, epochs=2, verbose=0) + + # Verify no checkpoints logged + artifacts = temp_run.http_emitter.emitted_artifacts + checkpoint_artifacts = [a for a in artifacts if "checkpoint" in a.get("name", "")] + assert len(checkpoint_artifacts) == 0, "Should not log checkpoints when disabled" + + # But final model should still be logged + model_artifacts = [a for a in artifacts if a.get("name") == "model"] + assert len(model_artifacts) == 1, "Should still log final model" + + # Cleanup + ds.disable_autolog() diff --git a/tests/domains/test_artifact_reuse.py b/tests/domains/test_artifact_reuse.py index c6b2711..543aa79 100644 --- a/tests/domains/test_artifact_reuse.py +++ b/tests/domains/test_artifact_reuse.py @@ -34,7 +34,7 @@ def test_artifact_chain(): # ============================================================ # RUN 1: TRAINING - Produces model as OUTPUT # ============================================================ - print("πŸ”΅ Run 1: Training") + print("Run 1: Training") run_train = ds.init( project="ml-pipeline", name="train-resnet", @@ -56,13 +56,13 @@ def test_artifact_chain(): run_train.log_output(model_path, name="trained_model") print(" Logged: trained_model.pt as OUTPUT") - print(" βœ… Training complete\n") + print(" Training complete\n") time.sleep(0.5) # ============================================================ # RUN 2: INFERENCE - Uses model as INPUT # ============================================================ - print("🟒 Run 2: Inference") + print("Run 2: Inference") run_inference = ds.init( project="ml-pipeline", name="inference-batch-1", @@ -91,13 +91,13 @@ def test_artifact_chain(): ), ) - print(" βœ… Inference complete\n") + print(" Inference complete\n") time.sleep(0.5) # ============================================================ # RUN 3: ANOTHER INFERENCE - Also uses same model as INPUT # ============================================================ - print("🟒 Run 3: Inference (Batch 2)") + print("Run 3: Inference (Batch 2)") run_inference_2 = ds.init( project="ml-pipeline", name="inference-batch-2", @@ -125,18 +125,18 @@ def test_artifact_chain(): ), ) - print(" βœ… Inference complete\n") + print(" Inference complete\n") print("=" * 60) - print("βœ… TEST COMPLETE!") + print("TEST COMPLETE!") print("=" * 60) - print("\nπŸ“Š Check the UI Lineage view (select all 3 runs):") + print("\nCheck the UI Lineage view (select all 3 runs):") print(" Expected graph:") print() print(" trained_model.pt") print(" |") - print(" train-resnet ────────────┴───────────→ inference-batch-1") - print(" └───────────→ inference-batch-2") + print(" train-resnet β†’ inference-batch-1") + print(" β†’ inference-batch-2") print() print(" - train-resnet has trained_model as OUTPUT (right side)") print(" - inference-batch-1 has trained_model as INPUT (left side)") diff --git a/tests/domains/test_pytorch.py b/tests/domains/test_pytorch.py index 856461f..c7465ad 100644 --- a/tests/domains/test_pytorch.py +++ b/tests/domains/test_pytorch.py @@ -32,7 +32,7 @@ def test_classification_runs(): ] for run_name, target_accuracy, epochs, learning_rate, batch_size in scenarios: - print(f"\nπŸ”΅ Creating: {run_name}") + print(f"\nCreating: {run_name}") config = { "learning_rate": learning_rate, "batch_size": batch_size, @@ -209,7 +209,7 @@ def test_classification_runs(): code_path = os.path.join(os.path.dirname(__file__), "../fixtures/code/pytorch") run.log_input(code_path) - print(f" βœ… Completed: {run_name} (final acc: {accuracy:.3f})") + print(f" Completed: {run_name} (final acc: {accuracy:.3f})") time.sleep(0.3) @@ -225,7 +225,7 @@ def test_regression_runs(): ] for run_name, target_r2, epochs, batch_size in scenarios: - print(f"\n🟒 Creating: {run_name}") + print(f"\nCreating: {run_name}") config = { "learning_rate": 0.001, "batch_size": batch_size, @@ -362,5 +362,5 @@ def test_regression_runs(): code_path = os.path.join(os.path.dirname(__file__), "../fixtures/code/pytorch") run.log_input(code_path) - print(f" βœ… Completed: {run_name} (final RΒ²: {r2:.3f})") + print(f" Completed: {run_name} (final RΒ²: {r2:.3f})") time.sleep(0.3) diff --git a/tests/domains/test_system.py b/tests/domains/test_system.py index a294f26..c25de58 100644 --- a/tests/domains/test_system.py +++ b/tests/domains/test_system.py @@ -52,5 +52,5 @@ def test_system_metrics_runs(): code_path = os.path.join(os.path.dirname(__file__), "../fixtures/code/system_monitor.py") run.log_input(code_path) - print(f" βœ… Completed: {name} profile") + print(f" Completed: {name} profile") time.sleep(0.3) diff --git a/tests/e2e/core.spec.js b/tests/e2e/core.spec.js new file mode 100644 index 0000000..c632a08 --- /dev/null +++ b/tests/e2e/core.spec.js @@ -0,0 +1,72 @@ +import { test, expect } from '@playwright/test'; + +/** + * Core E2E Tests for Artifacta UI + * + * These tests verify basic functionality of the web UI based on actual UI structure + */ + +test.describe('Artifacta Core UI', () => { + test('homepage loads successfully', async ({ page }) => { + await page.goto('/'); + await expect(page).toHaveTitle(/Artifacta/); + await page.waitForLoadState('networkidle'); + await expect(page.locator('text=Projects')).toBeVisible(); + await expect(page.locator('text=Runs')).toBeVisible(); + }); + + test('run list displays correctly', async ({ page }) => { + await page.goto('/'); + await page.waitForLoadState('networkidle'); + + // Expand the Runs section + await page.locator('text=Runs').first().click(); + await page.waitForTimeout(1000); + + // Verify run data appears + await expect(page.locator('text=/all-primitives|Run/i').first()).toBeVisible({ timeout: 5000 }); + }); + + test('navigation between tabs works', async ({ page }) => { + await page.goto('/'); + await page.waitForLoadState('networkidle'); + + await page.click('text=Notebooks'); + await page.click('text=Plots'); + await page.click('text=Tables'); + + // Verify page is still responsive + await expect(page.locator('text=Projects')).toBeVisible(); + }); + + test('health check endpoint returns healthy', async ({ request }) => { + const baseURL = process.env.ARTIFACTA_URL || 'http://localhost:8000'; + const response = await request.get(`${baseURL}/health`); + expect(response.ok()).toBeTruthy(); + + const data = await response.json(); + expect(data.status).toBe('healthy'); + expect(data.database_connected).toBe(true); + }); + + test('API returns run data', async ({ request }) => { + const baseURL = process.env.ARTIFACTA_URL || 'http://localhost:8000'; + const response = await request.get(`${baseURL}/api/runs?limit=100`); + expect(response.ok()).toBeTruthy(); + + const data = await response.json(); + expect(Array.isArray(data)).toBeTruthy(); + expect(data.length).toBeGreaterThan(0); + expect(data[0]).toHaveProperty('run_id'); + expect(data[0]).toHaveProperty('name'); + }); + + test('sidebar is interactive', async ({ page }) => { + await page.goto('/'); + await page.waitForLoadState('networkidle'); + + await expect(page.locator('text=Projects').first()).toBeVisible(); + await expect(page.locator('text=Runs').first()).toBeVisible(); + await expect(page.locator('text=Files').first()).toBeVisible(); + }); +}); diff --git a/tests/e2e/setup.js b/tests/e2e/setup.js new file mode 100644 index 0000000..6a540b0 --- /dev/null +++ b/tests/e2e/setup.js @@ -0,0 +1,122 @@ +import { spawn, execSync } from 'child_process'; +import { promises as fs } from 'fs'; +import path from 'path'; +import { fileURLToPath } from 'url'; + +const __filename = fileURLToPath(import.meta.url); +const __dirname = path.dirname(__filename); +const PROJECT_ROOT = path.resolve(__dirname, '../..'); + +/** + * Global setup for Playwright E2E tests + * + * - Starts the UI server + * - Cleans/creates fresh database + * - Runs example script to populate test data + * - Waits for server to be ready + */ + +let serverProcess; + +export default async function globalSetup() { + console.log('[E2E Setup] Starting global test setup...'); + + // 1. Clean database for fresh state + const dbPath = path.join(PROJECT_ROOT, 'data', 'runs.db'); + try { + await fs.unlink(dbPath); + console.log('[E2E Setup] βœ“ Cleaned database'); + } catch (err) { + // Database might not exist yet - that's fine + console.log('[E2E Setup] βœ“ No existing database to clean'); + } + + // 2. Start the UI server + console.log('[E2E Setup] Starting UI server...'); + const baseURL = process.env.ARTIFACTA_URL || 'http://localhost:8000'; + const port = new URL(baseURL).port || '8000'; + + serverProcess = spawn( + 'artifacta', + ['ui', '--port', port], + { + cwd: PROJECT_ROOT, + stdio: 'pipe', + shell: true, + } + ); + + // Handle server output + serverProcess.stdout.on('data', (data) => { + if (process.env.DEBUG) { + console.log(`[Server] ${data}`); + } + }); + + serverProcess.stderr.on('data', (data) => { + if (process.env.DEBUG) { + console.error(`[Server] ${data}`); + } + }); + + // 3. Wait for server to be ready + console.log('[E2E Setup] Waiting for server to be ready...'); + await waitForServer(baseURL, 30000); // 30 second timeout + console.log('[E2E Setup] βœ“ Server is ready'); + + // Wait a bit longer for database initialization to complete + await new Promise((resolve) => setTimeout(resolve, 2000)); + + // 4. Run example script to populate test data + console.log('[E2E Setup] Running example script to create test data...'); + try { + const exampleScript = path.join(PROJECT_ROOT, 'examples', 'core', '02_all_primitives.py'); + + const output = execSync(`python ${exampleScript}`, { + cwd: PROJECT_ROOT, + encoding: 'utf-8', + env: { + ...process.env, + ARTIFACTA_API_URL: baseURL, + }, + }); + console.log('[E2E Setup] Example script output:', output.substring(0, 500)); + console.log('[E2E Setup] βœ“ Test data created'); + } catch (err) { + console.error('[E2E Setup] βœ— Failed to create test data:', err.message); + throw err; + } + + console.log('[E2E Setup] βœ“ Global setup complete\n'); + + // Store server process for global teardown + global.__SERVER_PROCESS__ = serverProcess; +} + + +/** + * Wait for server to respond to health check + */ +async function waitForServer(baseURL, timeoutMs) { + const startTime = Date.now(); + const healthURL = `${baseURL}/health`; + + while (Date.now() - startTime < timeoutMs) { + try { + const response = await fetch(healthURL); + if (response.ok) { + const data = await response.json(); + if (data.status === 'healthy') { + return; + } + } + } catch (err) { + // Server not ready yet, continue waiting + } + + // Wait 500ms before next attempt + await new Promise((resolve) => setTimeout(resolve, 500)); + } + + throw new Error(`Server did not start within ${timeoutMs}ms`); +} diff --git a/tests/e2e/teardown.js b/tests/e2e/teardown.js new file mode 100644 index 0000000..1b48ce8 --- /dev/null +++ b/tests/e2e/teardown.js @@ -0,0 +1,24 @@ +/** + * Global teardown for Playwright E2E tests + * Stops the server process started in setup.js + */ + +export default async function globalTeardown() { + console.log('\n[E2E Teardown] Stopping server...'); + + if (global.__SERVER_PROCESS__) { + global.__SERVER_PROCESS__.kill('SIGTERM'); + + // Wait a moment for graceful shutdown + await new Promise((resolve) => setTimeout(resolve, 2000)); + + // Force kill if still running + try { + global.__SERVER_PROCESS__.kill('SIGKILL'); + } catch (err) { + // Process already dead + } + } + + console.log('[E2E Teardown] βœ“ Teardown complete'); +} diff --git a/tests/e2e/visualization.spec.js b/tests/e2e/visualization.spec.js new file mode 100644 index 0000000..c680080 --- /dev/null +++ b/tests/e2e/visualization.spec.js @@ -0,0 +1,118 @@ +import { test, expect } from '@playwright/test'; + +/** + * Visualization E2E Tests + * + * Tests for data visualization features: + * - Plots tab (charts/graphs) + * - Tables tab (structured data) + * - Artifacts tab (files) + * - Chat tab (AI interface) + */ + +test.describe('Data Visualization', () => { + test('plots tab renders charts', async ({ page }) => { + await page.goto('/'); + await page.waitForLoadState('networkidle'); + + // Wait a bit for initial API calls to complete + await page.waitForTimeout(2000); + + // Find the Runs section and click the chevron button to expand it + // The button is near the "Runs" title and has a chevron icon + const runsSection = page.locator('.collapsible-section-wrapper:has-text("Runs")'); + const expandButton = runsSection.locator('.collapsible-section-toggle'); + await expandButton.click(); + await page.waitForTimeout(1000); + + // Wait for run items to load and appear (runs are loaded asynchronously) + // The RunTree renders checkboxes for each run + await page.waitForSelector('input[type="checkbox"]', { timeout: 15000 }); + + // Click the first checkbox to select the run + const checkbox = page.locator('input[type="checkbox"]').first(); + await checkbox.click(); + await page.waitForTimeout(1500); + + // Navigate to Plots tab + await page.click('text=Plots'); + await page.waitForTimeout(2000); + + // Wait for canvas elements to appear (charts render on canvas) + await page.waitForSelector('canvas', { timeout: 15000 }); + + // Verify chart elements render + const charts = page.locator('canvas'); + const chartCount = await charts.count(); + + // Should have at least one chart (all_primitives has BarChart, Scatter, Curve, Series, etc.) + expect(chartCount).toBeGreaterThan(0); + + // Verify first chart is visible + await expect(charts.first()).toBeVisible(); + }); + + test('tables tab shows structured data', async ({ page }) => { + await page.goto('/'); + await page.waitForLoadState('networkidle'); + + // Select a run first + await page.locator('text=Runs').first().click(); + await page.waitForTimeout(500); + await page.locator('text=/all-primitives|Run/i').first().click(); + await page.waitForTimeout(1000); + + // Navigate to Tables tab + await page.click('text=Tables'); + await page.waitForTimeout(1000); + + // Verify table or data structure exists + const tables = page.locator('table, [class*="table"], [class*="Table"], [class*="grid"]'); + const tableCount = await tables.count(); + + // Should have some table-like elements + expect(tableCount).toBeGreaterThan(0); + }); + + test('artifacts tab displays file list', async ({ page }) => { + await page.goto('/'); + await page.waitForLoadState('networkidle'); + + // Select a run first + await page.locator('text=Runs').first().click(); + await page.waitForTimeout(500); + await page.locator('text=/all-primitives|Run/i').first().click(); + await page.waitForTimeout(1000); + + // Navigate to Artifacts tab + await page.click('text=Artifacts'); + await page.waitForTimeout(1000); + + // Verify artifacts content area is visible + // (May be empty list or have artifact items) + const mainContent = page.locator('main, [class*="content"], [role="main"]'); + await expect(mainContent.first()).toBeVisible(); + }); + + test('chat tab loads successfully', async ({ page }) => { + await page.goto('/'); + await page.waitForLoadState('networkidle'); + + // Select a run first + await page.locator('text=Runs').first().click(); + await page.waitForTimeout(500); + await page.locator('text=/all-primitives|Run/i').first().click(); + await page.waitForTimeout(1000); + + // Navigate to Chat tab + await page.click('text=Chat'); + await page.waitForTimeout(1000); + + // Verify Chat tab is active and visible + await expect(page.locator('text=Chat').first()).toBeVisible(); + + // Verify main content area exists (whether it shows setup or chat interface) + const mainContent = page.locator('main, [class*="content"], [role="main"]'); + await expect(mainContent.first()).toBeVisible(); + }); +}); diff --git a/tests/generate_all_notebooks.py b/tests/generate_all_notebooks.py index d1adbc8..6362cff 100644 --- a/tests/generate_all_notebooks.py +++ b/tests/generate_all_notebooks.py @@ -24,7 +24,7 @@ def generate_ab_testing_notebooks(): """Generate A/B testing notebooks""" - print("πŸ“Š Generating A/B Testing notebooks...") + print("Generating A/B Testing notebooks...") variant_a = {"conversion_rate": 0.0523, "avg_order_value": 125.50, "bounce_rate": 0.4210} @@ -37,12 +37,12 @@ def generate_ab_testing_notebooks(): variant_b_metrics=variant_b, api_url=API_URL, ) - print("βœ… A/B Testing notebooks created") + print("A/B Testing notebooks created") def generate_finance_notebooks(): """Generate finance notebooks""" - print("πŸ’° Generating Finance notebooks...") + print("Generating Finance notebooks...") portfolio_metrics = { "sharpe_ratio": 1.85, @@ -58,12 +58,12 @@ def generate_finance_notebooks(): portfolio_metrics=portfolio_metrics, api_url=API_URL, ) - print("βœ… Finance notebooks created") + print("Finance notebooks created") def generate_genomics_notebooks(): """Generate genomics notebooks""" - print("🧬 Generating Genomics notebooks...") + print("Generating Genomics notebooks...") sequence_stats = { "sequence_length": 15420, @@ -76,12 +76,12 @@ def generate_genomics_notebooks(): create_genomics_notebook( project_id="genomics", run_ids=["run_001"], sequence_stats=sequence_stats, api_url=API_URL ) - print("βœ… Genomics notebooks created") + print("Genomics notebooks created") def generate_climate_notebooks(): """Generate climate notebooks""" - print("🌍 Generating Climate notebooks...") + print("Generating Climate notebooks...") climate_metrics = { "global_temp_change": 2.1, @@ -93,16 +93,16 @@ def generate_climate_notebooks(): create_climate_notebook( project_id="climate", run_ids=["run_001"], climate_metrics=climate_metrics, api_url=API_URL ) - print("βœ… Climate notebooks created") + print("Climate notebooks created") def generate_computer_vision_notebooks(): """Generate computer vision notebooks (already exists, just call it)""" - print("πŸ‘οΈ Computer Vision notebooks already generated by test_computer_vision.py") + print("Computer Vision notebooks already generated by test_computer_vision.py") if __name__ == "__main__": - print("\nπŸš€ Generating diverse example notebooks for all domains...\n") + print("\nGenerating diverse example notebooks for all domains...\n") try: generate_ab_testing_notebooks() @@ -111,11 +111,11 @@ def generate_computer_vision_notebooks(): generate_climate_notebooks() generate_computer_vision_notebooks() - print("\n✨ All notebooks generated successfully!") + print("\nAll notebooks generated successfully!") print("View them in the Artifacta UI\n") except Exception as e: - print(f"\n❌ Error generating notebooks: {e}") + print(f"\nError generating notebooks: {e}") import traceback traceback.print_exc() diff --git a/tests/helpers/api.py b/tests/helpers/api.py index e7f19f6..d2caf2c 100644 --- a/tests/helpers/api.py +++ b/tests/helpers/api.py @@ -59,7 +59,7 @@ def log_structured_data(api_url, run_id, name, primitive_type, data, section=Non primitive._metadata = metadata ds.log(name, primitive) - print(f"βœ“ Logged {primitive_type}: {name} [{section or 'General'}]") + print(f"Logged {primitive_type}: {name} [{section or 'General'}]") def finish_run(api_url, run_id): @@ -87,4 +87,4 @@ def log_artifact(api_url, run_id, filepath, role=None, include_content=True): # Log artifact using SDK with role run.log_artifact(filename, filepath, include_content=include_content, role=role or "output") - print(f"βœ“ πŸ“„ Logged artifact: {filename}") + print(f"Logged artifact: {filename}") diff --git a/tests/helpers/notebook.py b/tests/helpers/notebook.py index acf21ff..c9cb7a0 100644 --- a/tests/helpers/notebook.py +++ b/tests/helpers/notebook.py @@ -275,7 +275,7 @@ def create_experiment_summary_notebook( sections = [] # Title - sections.append(create_heading_section(f"πŸ“Š {experiment_name} - Experiment Summary", level=1)) + sections.append(create_heading_section(f"{experiment_name} - Experiment Summary", level=1)) # Overview sections.append(create_heading_section("Overview", level=2)) @@ -287,7 +287,7 @@ def create_experiment_summary_notebook( ) # Key Findings - sections.append(create_heading_section("🎯 Key Findings", level=2)) + sections.append(create_heading_section("Key Findings", level=2)) findings = [ f"Total runs completed: {len(run_ids)}", f"Best configuration identified: {list(best_config.keys())}", @@ -296,7 +296,7 @@ def create_experiment_summary_notebook( sections.append(create_bullet_list_section(findings)) # Results Table - sections.append(create_heading_section("πŸ“ˆ Results Summary", level=2)) + sections.append(create_heading_section("Results Summary", level=2)) if metrics_summary: # Extract headers from first result @@ -311,12 +311,12 @@ def create_experiment_summary_notebook( sections.append(create_table_section(headers, rows)) # Best Configuration - sections.append(create_heading_section("βš™οΈ Best Configuration", level=2)) + sections.append(create_heading_section("Best Configuration", level=2)) config_items = [f"{k}: {v}" for k, v in best_config.items()] sections.append(create_bullet_list_section(config_items)) # Code snippet - sections.append(create_heading_section("πŸ’» Example Usage", level=2)) + sections.append(create_heading_section("Example Usage", level=2)) code = f"""import artifacta as ds # Initialize run with best config @@ -356,7 +356,7 @@ def create_computer_vision_notebook( """ sections = [] - sections.append(create_heading_section("πŸ–ΌοΈ Computer Vision Experiment", level=1)) + sections.append(create_heading_section("Computer Vision Experiment", level=1)) sections.append(create_heading_section("Model Performance", level=2)) diff --git a/tests/helpers/notebook_html.py b/tests/helpers/notebook_html.py index 6ba116d..a7b0d37 100644 --- a/tests/helpers/notebook_html.py +++ b/tests/helpers/notebook_html.py @@ -378,7 +378,7 @@ def create_computer_vision_notebook( pdf_content = create_test_pdf_report("ResNet-50 Training Report") upload_attachment(note_id, "training_report.pdf", pdf_content, api_url) except Exception as e: - print(f"⚠️ Failed to generate PDF: {e}") + print(f"Failed to generate PDF: {e}") # Add audio file (only if note_id was successfully created) if note_id: @@ -392,7 +392,7 @@ def create_computer_vision_notebook( audio_content = f.read() upload_attachment(note_id, "training_complete.wav", audio_content, api_url) except Exception as e: - print(f"⚠️ Failed to generate audio: {e}") + print(f"Failed to generate audio: {e}") # Add video file (only if note_id was successfully created) if note_id: @@ -406,7 +406,7 @@ def create_computer_vision_notebook( video_content = f.read() upload_attachment(note_id, "training_animation.mp4", video_content, api_url) except Exception as e: - print(f"⚠️ Failed to generate video: {e}") + print(f"Failed to generate video: {e}") return note_id diff --git a/tests/helpers/video.py b/tests/helpers/video.py index c45edcb..fef26ad 100644 --- a/tests/helpers/video.py +++ b/tests/helpers/video.py @@ -28,7 +28,7 @@ def create_test_video_animation( import av # PyAV for video encoding except ImportError: # Fallback: create a placeholder file - print("⚠️ PyAV not available, creating placeholder video") + print("PyAV not available, creating placeholder video") uploads_dir = os.path.join(os.path.dirname(__file__), "..", "..", "uploads") os.makedirs(uploads_dir, exist_ok=True) filepath = os.path.join(uploads_dir, filename) diff --git a/tracking-server/cli.py b/tracking-server/cli.py index 209745e..1e13465 100755 --- a/tracking-server/cli.py +++ b/tracking-server/cli.py @@ -34,10 +34,14 @@ def cli(ctx: click.Context) -> None: @cli.command() @click.option("--host", default=DEFAULT_HOST, help="Host to bind the server to") @click.option("--port", default=DEFAULT_PORT, type=int, help="Port for the tracking server") -@click.option("--ui-port", default=DEFAULT_UI_PORT, type=int, help="Port for the UI (dev mode only)") +@click.option( + "--ui-port", default=DEFAULT_UI_PORT, type=int, help="Port for the UI (dev mode only)" +) @click.option("--db", default=DEFAULT_DB_PATH, help="Database file path") @click.option("--debug-logs", is_flag=True, help="Enable console log capture to file") -@click.option("--dev", is_flag=True, help="Run in development mode with hot-reload (requires Node.js)") +@click.option( + "--dev", is_flag=True, help="Run in development mode with hot-reload (requires Node.js)" +) def ui(host: str, port: int, ui_port: int, db: str, debug_logs: bool, dev: bool) -> None: """Start the full UI (tracking server + frontend). @@ -51,12 +55,13 @@ def ui(host: str, port: int, ui_port: int, db: str, debug_logs: bool, dev: bool) # Check both installed location and dev location try: from artifacta_ui import UI_DIST_PATH + dist_exists = UI_DIST_PATH.exists() except ImportError: dist_exists = (project_root / "dist").exists() if not dev and not dist_exists: - click.echo("❌ UI not built. Please run 'npm install && npm run build' first.") + click.echo("UI not built. Please run 'npm install && npm run build' first.") click.echo(" Or use --dev flag to run in development mode (requires Node.js).") sys.exit(1) @@ -72,13 +77,13 @@ def ui(host: str, port: int, ui_port: int, db: str, debug_logs: bool, dev: bool) # Enable debug logging if requested if debug_logs: os.environ["VITE_DEBUG_LOGS"] = "true" - click.echo("πŸ› Debug logging enabled - logs will be saved to browser downloads") + click.echo("[DEBUG] Debug logging enabled - logs will be saved to browser downloads") processes: List[subprocess.Popen[bytes]] = [] try: # Start tracking server - click.echo(f"πŸ“Š Starting tracking server on {host}:{port}...") + click.echo(f"Starting tracking server on {host}:{port}...") server_process = subprocess.Popen( [sys.executable, "main.py"], cwd=server_dir, env=os.environ.copy() ) @@ -88,7 +93,7 @@ def ui(host: str, port: int, ui_port: int, db: str, debug_logs: bool, dev: bool) time.sleep(1) # Start frontend dev server - click.echo(f"🎨 Starting UI dev server on http://localhost:{ui_port}...") + click.echo(f"Starting UI dev server on http://localhost:{ui_port}...") frontend_process = subprocess.Popen( ["npm", "run", "dev", "--", "--port", str(ui_port)], cwd=project_root, @@ -96,7 +101,7 @@ def ui(host: str, port: int, ui_port: int, db: str, debug_logs: bool, dev: bool) ) processes.append(frontend_process) - click.echo("\nβœ… Artifacta is running in development mode!") + click.echo("\nArtifacta is running in development mode!") click.echo(f" - Tracking Server: http://{host}:{port}") click.echo(f" - UI: http://localhost:{ui_port}") click.echo("\nPress Ctrl+C to stop...") @@ -106,20 +111,20 @@ def ui(host: str, port: int, ui_port: int, db: str, debug_logs: bool, dev: bool) process.wait() except KeyboardInterrupt: - click.echo("\nπŸ›‘ Stopping Artifacta...") + click.echo("\nStopping Artifacta...") for process in processes: process.terminate() for process in processes: process.wait() - click.echo("βœ… Artifacta stopped") + click.echo("Artifacta stopped") except Exception as e: - click.echo(f"❌ Error: {e}", err=True) + click.echo(f"Error: {e}", err=True) for process in processes: process.terminate() sys.exit(1) else: # Production mode - serve built UI from FastAPI server - click.echo(f"πŸ“Š Starting server with built-in UI on http://{host}:{port}...") + click.echo(f"Starting server with built-in UI on http://{host}:{port}...") try: # Check if running in development (tracking-server dir exists with main.py) @@ -130,10 +135,16 @@ def ui(host: str, port: int, ui_port: int, db: str, debug_logs: bool, dev: bool) else: # Installed mode - run as module import uvicorn - from tracking_server.config import get_host, get_port, SERVER_BIND_HOST - uvicorn.run("tracking_server.main:app", host=SERVER_BIND_HOST, port=get_port(), log_level="info") + from tracking_server.config import SERVER_BIND_HOST, get_port + + uvicorn.run( + "tracking_server.main:app", + host=SERVER_BIND_HOST, + port=get_port(), + log_level="info", + ) except KeyboardInterrupt: - click.echo("\nπŸ›‘ Server stopped") + click.echo("\nServer stopped") @cli.command() @@ -142,7 +153,7 @@ def ui(host: str, port: int, ui_port: int, db: str, debug_logs: bool, dev: bool) @click.option("--db", default=DEFAULT_DB_PATH, help="Database file path") def server(host: str, port: int, db: str) -> None: """Start the tracking server without UI.""" - click.echo(f"πŸš€ Starting Artifacta tracking server on {host}:{port}...") + click.echo(f"Starting Artifacta tracking server on {host}:{port}...") project_root = get_project_root() server_dir = project_root / "tracking-server" @@ -155,7 +166,7 @@ def server(host: str, port: int, db: str) -> None: try: subprocess.run([sys.executable, "main.py"], cwd=server_dir, env=os.environ.copy()) except KeyboardInterrupt: - click.echo("\nπŸ›‘ Server stopped") + click.echo("\nServer stopped") @cli.group() @@ -168,13 +179,13 @@ def db() -> None: @click.option("--db", default=DEFAULT_DB_PATH, help="Database file path") def init(db: str) -> None: """Initialize the database.""" - click.echo(f"πŸ—„οΈ Initializing database: {db}") + click.echo(f"Initializing database: {db}") project_root = get_project_root() db_path = project_root / db if db_path.exists(): - click.echo(f"⚠️ Database already exists: {db_path}") + click.echo(f"Database already exists: {db_path}") if not click.confirm("Do you want to reinitialize it?"): return @@ -184,7 +195,7 @@ def init(db: str) -> None: os.environ["DATABASE_PATH"] = db init_db() - click.echo(f"βœ… Database initialized: {db_path}") + click.echo(f"Database initialized: {db_path}") @db.command() @@ -192,14 +203,14 @@ def init(db: str) -> None: @click.confirmation_option(prompt="Are you sure you want to clean the database?") def clean(db: str) -> None: """Clean/reset the database (removes all data).""" - click.echo(f"🧹 Cleaning database: {db}") + click.echo(f"Cleaning database: {db}") project_root = get_project_root() db_path = project_root / db if db_path.exists(): db_path.unlink() - click.echo(f"βœ… Database removed: {db_path}") + click.echo(f"Database removed: {db_path}") # Reinitialize sys.path.insert(0, str(project_root / "tracking-server")) @@ -207,7 +218,7 @@ def clean(db: str) -> None: os.environ["DATABASE_PATH"] = db init_db() - click.echo(f"βœ… Database reinitialized: {db_path}") + click.echo(f"Database reinitialized: {db_path}") @db.command(name="reset") @@ -215,7 +226,7 @@ def clean(db: str) -> None: @click.confirmation_option(prompt="Are you sure you want to reset the database?") def db_reset(db: str) -> None: """Reset the database (alias for clean).""" - click.echo(f"πŸ”„ Resetting database: {db}") + click.echo(f"Resetting database: {db}") # Call clean command ctx = click.get_current_context() @@ -226,7 +237,7 @@ def db_reset(db: str) -> None: @click.option("--db", default=DEFAULT_DB_PATH, help="Database file path") def reset(db: str) -> None: """Reset database and show instructions to restart the server.""" - click.echo("πŸ”„ Resetting Artifacta...") + click.echo("Resetting Artifacta...") project_root = get_project_root() db_path = project_root / db @@ -234,7 +245,7 @@ def reset(db: str) -> None: # Clean database if db_path.exists(): db_path.unlink() - click.echo(f"βœ… Database removed: {db_path}") + click.echo(f"Database removed: {db_path}") # Reinitialize sys.path.insert(0, str(project_root / "tracking-server")) @@ -242,8 +253,8 @@ def reset(db: str) -> None: os.environ["DATABASE_PATH"] = db init_db() - click.echo(f"βœ… Database reinitialized: {db_path}") - click.echo("\n⚠️ Please restart the server:") + click.echo(f"Database reinitialized: {db_path}") + click.echo("\nPlease restart the server:") click.echo(" 1. Stop the current server (Ctrl+C)") click.echo(" 2. Run: python cli.py ui") @@ -251,7 +262,7 @@ def reset(db: str) -> None: @cli.command() def stop() -> None: """Stop all Artifacta processes.""" - click.echo("πŸ›‘ Stopping Artifacta processes...") + click.echo("Stopping Artifacta processes...") # Kill Python processes running cli.py or main.py try: @@ -260,9 +271,9 @@ def stop() -> None: # Kill vite dev server subprocess.run(["pkill", "-f", "vite"], stderr=subprocess.DEVNULL, check=False) time.sleep(1) - click.echo("βœ… All processes stopped") + click.echo("All processes stopped") except Exception as e: - click.echo(f"⚠️ Error stopping processes: {e}", err=True) + click.echo(f"Error stopping processes: {e}", err=True) click.echo(" You may need to manually stop processes") diff --git a/tracking-server/config.py b/tracking-server/config.py index cc8bebf..7e07c4e 100644 --- a/tracking-server/config.py +++ b/tracking-server/config.py @@ -1,8 +1,58 @@ -"""Configuration constants for Artifacta tracking server.""" +"""Configuration constants for Artifacta tracking server. + +Centralized configuration management for: +- Server host/port settings +- Database paths and timeouts +- File upload limits +- CORS origins +- Pagination defaults +- UI static file locations + +Configuration sources (priority order): +1. Environment variables (highest priority) +2. Default constants defined here +3. Hard-coded fallbacks + +Environment variables: +- TRACKING_SERVER_HOST: Server bind host +- TRACKING_SERVER_PORT: Server port +- DATABASE_PATH: SQLite database file path +- UI_PORT: UI development server port +- CORS_ORIGINS: Comma-separated allowed origins + +Design philosophy: +- Sane defaults for local development +- Override via env vars for production +- Explicit constants prevent magic numbers throughout codebase +""" import os from pathlib import Path +__all__ = [ + "DEFAULT_HOST", + "DEFAULT_PORT", + "DEFAULT_UI_PORT", + "SERVER_BIND_HOST", + "PROJECT_ROOT", + "DEFAULT_DB_PATH", + "UI_DIST_PATH", + "MAX_UPLOAD_SIZE_MB", + "MAX_UPLOAD_SIZE_BYTES", + "THUMBNAIL_SIZE", + "SUPPORTED_IMAGE_FORMATS", + "API_PREFIX", + "CORS_ORIGINS", + "DEFAULT_PAGE_SIZE", + "MAX_PAGE_SIZE", + "REQUEST_TIMEOUT", + "DB_TIMEOUT", + "get_host", + "get_port", + "get_ui_port", + "get_api_url", +] + # Server Configuration DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 8000 @@ -20,7 +70,7 @@ from artifacta_ui import UI_DIST_PATH except ImportError: # When running in development - UI_DIST_PATH = PROJECT_ROOT / "dist" + UI_DIST_PATH = PROJECT_ROOT / "artifacta_ui" / "dist" # File Upload Configuration MAX_UPLOAD_SIZE_MB = 100 diff --git a/tracking-server/database.py b/tracking-server/database.py index 5ff17ba..c589899 100644 --- a/tracking-server/database.py +++ b/tracking-server/database.py @@ -1,4 +1,91 @@ -"""Database models for tracking server using SQLAlchemy (SQLite or Postgres).""" +"""Database models and schema for the tracking server. + +This module defines the complete database schema using SQLAlchemy ORM, supporting +both SQLite (for local development) and PostgreSQL (for production). The schema +is designed for efficient querying of run metadata, metrics, and artifacts. + +Schema Design Rationale: + + Database schema design optimized for Artifacta's tracking needs: + + 1. **Runs Table**: Core entity representing a single training run + - run_id: Primary key (generated by SDK, ensures uniqueness across systems) + - name: Auto-generated display name (solves collision issues in UI) + - project: Optional grouping (for organizing runs) + - config_artifact_id: Link to config artifact (single source of truth) + - created_at: Unix timestamp in milliseconds for precision + + 2. **Tags Table**: Key-value metadata for searchable run attributes + - Normalized (separate row per tag) for efficient querying + - Indexed on (run_id, key, value) for fast filtering + - Unique constraint on (run_id, key) prevents duplicate keys + - Examples: git.commit, git.branch, user, environment + + 3. **StructuredData Table**: Stores primitives (Series, Distribution, Matrix, etc.) + - Denormalized data field (JSON blob) for flexibility + - Indexed on (run_id, name, primitive_type) for fast queries + - Optional section for grouping metrics (e.g., "train", "val") + - Timestamp for ordering time-series data + + 4. **Artifacts Table**: File metadata with content-addressable storage + - artifact_id: Primary key (unique identifier) + - hash: SHA256 hash for integrity and deduplication + - storage_path: Relative path in artifact store + - content: Optional inlined text content (for code artifacts) + - Indexed on hash for deduplication checks + + 5. **ArtifactLinks Table**: Many-to-many relationship between runs and artifacts + - Enables artifact reuse across runs (same checkpoint used by multiple runs) + - role: "input" or "output" to distinguish artifact usage + - Indexed on both run_id and artifact_id for bidirectional queries + + 6. **Projects Table**: Logical grouping of related runs + - Auto-created when run specifies project name + - Contains lab notebook entries via ProjectNote relationship + + 7. **ProjectNotes Table**: Lab notebook entries with markdown content + - Supports experiment documentation and run annotations + - Can attach files via ProjectNoteAttachment relationship + +Indexing Strategy: + + Indexes are carefully chosen based on common query patterns: + + - Tags: Indexed on run_id (filter by run), key (filter by tag type), + (key, value) (filter by specific tag), and unique (run_id, key) + - StructuredData: Indexed on run_id (get all metrics for run), + primitive_type (filter by data type), name (filter by metric name) + - Artifacts: Indexed on run_id (get all artifacts for run), + hash (check for existing artifact before upload) + - ArtifactLinks: Indexed on both run_id and artifact_id (bidirectional) + +Performance Considerations: + + - Text columns use TEXT type for unlimited length (JSON, diffs, etc.) + - Integer timestamps (milliseconds) for efficient range queries + - Foreign keys with indexes for fast joins + - Cascade deletes on relationships to maintain referential integrity + - check_same_thread=False for SQLite (FastAPI uses thread pool) + +Database Compatibility: + + SQLite (Development): + - Single file database (./data/runs.db) + - No server setup required + - Limited concurrent writes (but fine for single-user) + - check_same_thread=False needed for FastAPI async + + PostgreSQL (Production): + - Full ACID compliance + - Better concurrent access + - Advanced query optimization + - Use connection pooling for scalability + +Migration Strategy: + + Currently using create_all() for schema creation (works for new databases). + For production, consider Alembic for migrations when schema evolves. +""" from typing import Any, Dict, Type @@ -12,7 +99,7 @@ class Tag(Base): # type: ignore[misc] """Tag stores key-value metadata for runs (git info, environment, user, etc.). - Similar to MLflow's tag system - makes metadata easily searchable. + Makes run metadata easily searchable and filterable. """ __tablename__ = "tags" @@ -38,7 +125,7 @@ class Tag(Base): # type: ignore[misc] class Run(Base): # type: ignore[misc] """Run represents a single training run. - Similar to MLflow Run, contains metadata and lifecycle info. + Contains run metadata and lifecycle information. """ __tablename__ = "runs" @@ -48,7 +135,7 @@ class Run(Base): # type: ignore[misc] # Auto-generated display name (e.g., "Run 1", "Run 2") # This is what the UI displays - solves the "trial_1" collision issue - name = Column(String, nullable=False, unique=True) + name = Column(String, nullable=False) # Project/experiment grouping (optional) project = Column(String, nullable=True) @@ -236,7 +323,7 @@ def to_dict(self) -> Dict[str, Any]: class Database: """Database manager - handles connection and session management. - Similar to MLflow's SqlAlchemyStore. + SQLAlchemy-based storage layer for runs and artifacts. """ def __init__(self, db_uri: str = "sqlite:///./data/runs.db"): diff --git a/tracking-server/main.py b/tracking-server/main.py index bd61f42..fab7808 100644 --- a/tracking-server/main.py +++ b/tracking-server/main.py @@ -1,9 +1,124 @@ -"""Tracking Server (MLflow-style). +"""Artifacta Tracking Server - FastAPI application with real-time WebSocket support. -HTTP API with direct SDK emission support: -- Receives metrics directly from artifacta Python SDK -- Stores in SQLite database -- Broadcasts to WebSocket clients in real-time +This is the main entry point for the tracking server, which provides an HTTP API +for receiving run data from the Artifacta SDK and a WebSocket interface for +real-time updates to connected UIs. + +Architecture: + + The server is built on FastAPI and organized into layers: + + 1. **Application Layer** (this file): + - FastAPI app initialization with lifespan management + - CORS middleware configuration for web frontend + - Dependency injection middleware for database and WebSocket clients + - Route registration and static file serving + - Global state management (database, WebSocket clients) + + 2. **Database Layer** (database.py): + - SQLAlchemy ORM models and schema + - Database connection management + - Session creation and cleanup + + 3. **Routes Layer** (routes/*.py): + - Individual endpoint implementations + - Request/response validation via Pydantic + - Business logic for runs, artifacts, projects, chat + + 4. **WebSocket Layer** (routes/websocket.py): + - WebSocket connection management + - Real-time broadcast to connected clients + - Connection lifecycle (connect, disconnect, error handling) + +Lifespan Management: + + FastAPI's lifespan context manager handles startup and shutdown: + + Startup: + 1. Initialize SQLite database via init_database() + 2. Create Database instance with connection string + 3. Store in global 'db' variable for dependency injection + 4. Log ready message + + Shutdown: + - Automatic cleanup (context manager exit) + - Database connections closed by SQLAlchemy + - WebSocket clients disconnected automatically + + Why lifespan: + - Replaces deprecated @app.on_event("startup") / @app.on_event("shutdown") + - Provides cleaner resource management with context manager pattern + - Ensures cleanup happens even on crashes (finally block semantics) + +Dependency Injection: + + Two mechanisms provide dependencies to route handlers: + + 1. **HTTP Middleware** (inject_dependencies): + - Intercepts every HTTP request + - Attaches db and websocket_clients to request.state + - Routes access via request.state.db, request.state.websocket_clients + + 2. **FastAPI Depends** (routes/dependencies.py): + - get_db(), get_session(), get_websocket_clients() extractors + - Type-safe dependency injection in route signatures + - Cleaner than accessing request.state directly + +CORS Configuration: + + Permissive CORS for development (allow all origins): + - allow_origins=["*"]: Any domain can call the API + - allow_methods=["*"]: All HTTP methods (GET, POST, PATCH, DELETE) + - allow_headers=["*"]: All custom headers + - expose_headers: Pagination headers exposed to browser JavaScript + + Why permissive: + - Local development: UI runs on localhost:5173, API on localhost:8000 + - Production: Should restrict origins to deployed frontend domain + + Security consideration: + For production, set allow_origins to specific domain list + +WebSocket Client Management: + + Global set of connected WebSocket clients: + - websocket_clients: Set[WebSocket] - thread-safe for async + - Shared across all route handlers via dependency injection + - Modified by websocket route (add on connect, remove on disconnect) + - Used by data emission routes to broadcast updates + + Broadcasting strategy: + - When SDK emits data, server stores in database + - Server then broadcasts to all connected WebSocket clients + - Clients receive real-time updates without polling + - Disconnected clients removed automatically on send failure + +Static File Serving: + + Single-page application (SPA) routing: + 1. Mount /assets directory for static files (JS, CSS, images) + 2. Catch-all route serves index.html for non-API paths + 3. API routes (/api/*, /ws/*) not intercepted + 4. Enables client-side routing (React Router, etc.) + + Fallback behavior: + - If UI dist folder missing, log warning + - Server continues in API-only mode + - Useful for headless deployments or during UI development + +Environment Configuration: + + - TRACKING_SERVER_PORT: Port to bind (default 8000) + - API_PORT: Alternative env var name (backwards compatibility) + - DB_PATH: Database file location (default: ./data/runs.db) + - UI_DIST_PATH: Location of built UI files (from config.py) + +Design Philosophy: + + - Real-time first: WebSocket broadcasts enable live UI updates + - Database-backed: SQLite for development, PostgreSQL for production + - Stateless routes: All state in database or request context + - Graceful degradation: API works without UI, UI works without WebSocket """ # mypy: disable-error-code="misc,untyped-decorator,union-attr,no-any-return,no-untyped-call" @@ -15,10 +130,12 @@ from fastapi import FastAPI, Request, WebSocket from fastapi.middleware.cors import CORSMiddleware -from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse +from fastapi.staticfiles import StaticFiles from routes import artifacts, chat, health, projects, runs, websocket +from config import UI_DIST_PATH + # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" @@ -50,7 +167,7 @@ def init_database() -> None: init_db() - logger.info("βœ… Database initialized") + logger.info("Database initialized") # Lifespan management @@ -67,7 +184,7 @@ async def lifespan(app: FastAPI): # type: ignore[no-untyped-def] db_uri = f"sqlite:///{DB_PATH}" db = Database(db_uri=db_uri) - logger.info("βœ… Tracking server ready") + logger.info("Tracking server ready") yield @@ -124,25 +241,26 @@ async def inject_dependencies(request: Request, call_next): # type: ignore[no-u # Serve static UI files if dist folder exists -from config import UI_DIST_PATH - if UI_DIST_PATH.exists(): # Mount static assets (JS, CSS, images, etc.) app.mount("/assets", StaticFiles(directory=UI_DIST_PATH / "assets"), name="assets") # Serve index.html for root and any unmatched routes (SPA fallback) - @app.get("/{full_path:path}") - async def serve_ui(full_path: str): - """Serve the UI for all non-API routes.""" - # Don't intercept API routes - if full_path.startswith("api/") or full_path.startswith("ws/"): - return None - - # Serve index.html from dist folder for all other routes (SPA routing) + # Note: This catch-all route is registered LAST so API routes take precedence + @app.get("/{full_path:path}", response_model=None, include_in_schema=False) + async def serve_ui(full_path: str) -> FileResponse: + """Serve the UI for all non-API routes (SPA fallback).""" + # This should never match API routes since they're registered first, + # but check anyway as a safety measure + if full_path.startswith("api/"): + from fastapi import HTTPException + raise HTTPException(status_code=404, detail="API route not found") + + # Serve index.html for all other paths (enables client-side routing) return FileResponse(UI_DIST_PATH / "index.html") else: logger.warning( - f"⚠️ UI dist folder not found at {UI_DIST_PATH}. " + f"UI dist folder not found at {UI_DIST_PATH}. " "Run 'npm install && npm run build' to build the UI, " "or use 'artifacta server' for API-only mode." ) diff --git a/tracking-server/routes/artifacts.py b/tracking-server/routes/artifacts.py index 525b5d9..44fba58 100644 --- a/tracking-server/routes/artifacts.py +++ b/tracking-server/routes/artifacts.py @@ -1,5 +1,110 @@ # mypy: disable-error-code="misc,untyped-decorator,union-attr,no-any-return,arg-type" -"""Artifact management endpoints - SQLAlchemy ORM version.""" +"""Artifact storage and retrieval endpoints for tracking server. + +This module implements content-addressable artifact storage with deduplication, +inline content support, and file serving capabilities. It bridges between the +SDK's artifact logging and the UI's artifact viewing/downloading needs. + +Key Features: + - Content-addressable storage via SHA256 hashing + - Automatic deduplication (same hash reuses existing artifact) + - Support for both filesystem and virtual (inline) artifacts + - File preview with pagination for large artifacts + - Individual file serving from multi-file artifacts + - Download endpoints for artifact export + - Artifact role tracking (input vs output) + - Provenance tracking via hash.code tags + +Artifact Storage Models: + + 1. **Filesystem Artifacts** (storage_path: actual file path): + - Large files (model checkpoints, datasets, videos) + - Files remain on user's filesystem + - storage_path points to actual location + - metadata stored in database + - content field is NULL + + 2. **Virtual Artifacts** (storage_path: virtual://...): + - Small text files (code, configs, logs) + - Content inlined in database (content field) + - No actual filesystem storage + - Faster access (no disk I/O) + - Better for version control + +Content-Addressable Storage Algorithm: + + When SDK logs artifact: + 1. Compute SHA256 hash of artifact content + 2. Query database for existing artifact with same hash + 3. If exists: + a. Reuse existing artifact_id + b. Create new ArtifactLink with current run_id + c. Skip storage (file already exists) + 4. If not exists: + a. Generate new artifact_id (art_XXXXXXXX) + b. Create Artifact record with hash and storage_path + c. Create ArtifactLink with current run_id + d. For virtual artifacts, store content in database + + Benefits: + - Deduplication saves storage (checkpoints reused across runs) + - Hash ensures integrity (detect file corruption) + - Provenance tracking (which runs used which artifacts) + +Artifact Links (Many-to-Many): + + The ArtifactLink table enables artifact reuse: + - One artifact can be linked to multiple runs + - One run can link to multiple artifacts + - Role field distinguishes "input" vs "output" + - created_at tracks when link was established + + Use cases: + - Pretrained model used as input by multiple fine-tuning runs + - Dataset artifact shared across experiment runs + - Best checkpoint from run A used as input to run B + +File Serving Strategy: + + GET /artifact/{artifact_id}/files/{filename}: + 1. Check if virtual artifact (storage_path starts with "virtual://") + 2. If virtual: + a. Parse content JSON + b. Find file by filename in files array + c. Return file content with appropriate MIME type + 3. If filesystem: + a. Resolve file path (storage_path + filename) + b. Determine MIME type from extension + c. Return FileResponse with streaming + + MIME type handling: + - Explicit map for common types (mp4, png, pdf, etc.) + - Fallback to application/octet-stream + - text/* and application/json served inline + - Other types trigger download + +Provenance Tracking: + + Code artifacts automatically update hash.code tag: + 1. Parse artifact content JSON + 2. Check if any file has metadata.type == "code" + 3. If yes, create/update Tag with key="hash.code", value=artifact_hash + 4. UI can show code hash for reproducibility + 5. Users can verify code hasn't changed across runs + +Preview Pagination: + + For large artifacts with many files (e.g., dataset with 10K images): + - offset, limit parameters control pagination + - Returns only requested slice of files + - has_more flag indicates more files available + - Frontend can load files incrementally + +Error Handling: + - 404 Not Found: Artifact ID doesn't exist, file not found on disk + - 500 Internal Server Error: Database errors, JSON parsing errors + - All exceptions logged with exc_info=True for debugging +""" import json import logging @@ -10,7 +115,7 @@ from database import Artifact, ArtifactLink, Tag from fastapi import APIRouter, Depends, HTTPException -from fastapi.responses import FileResponse, JSONResponse +from fastapi.responses import FileResponse, JSONResponse, Response from pydantic import BaseModel from sqlalchemy.orm import Session @@ -56,7 +161,7 @@ async def create_artifact( existing = session.query(Artifact).filter(Artifact.hash == request.hash).first() if existing: - artifact_id = existing.artifact_id + artifact_id = str(existing.artifact_id) logger.info(f"Reusing existing artifact: {artifact_id} (hash={request.hash[:8]}...)") else: artifact_id = f"art_{uuid.uuid4().hex[:16]}" @@ -110,7 +215,7 @@ async def create_artifact( ) if code_tag: - code_tag.value = request.hash + code_tag.value = request.hash # type: ignore[assignment] else: code_tag = Tag( run_id=request.run_id, @@ -256,7 +361,7 @@ async def get_artifact_file( artifact_id: str, filename: str, session: Session = Depends(get_session), -) -> Union[FileResponse, JSONResponse]: +) -> Union[FileResponse, JSONResponse, Response]: """Get a specific file from an artifact. For inline artifacts (text/code), returns the content directly. @@ -373,7 +478,7 @@ async def download_artifact( with tempfile.NamedTemporaryFile( mode="w", delete=False, suffix=f"_{artifact.name}" ) as tmp: - tmp.write(artifact.content) + tmp.write(str(artifact.content)) tmp_path = tmp.name return FileResponse( diff --git a/tracking-server/routes/chat.py b/tracking-server/routes/chat.py index f0f43a0..21dce44 100644 --- a/tracking-server/routes/chat.py +++ b/tracking-server/routes/chat.py @@ -1,5 +1,99 @@ # mypy: disable-error-code="misc,untyped-decorator,union-attr,no-any-return,arg-type" -"""LLM Chat proxy endpoints using LiteLLM.""" +"""LLM chat proxy using LiteLLM for universal provider support. + +This module provides a streaming chat endpoint that proxies requests to any LLM +provider (OpenAI, Anthropic, Ollama, Groq, etc.) via LiteLLM. This allows the +Artifacta UI to integrate LLM chat without requiring separate API clients for +each provider. + +Why LiteLLM: + - Unified interface for 100+ LLM providers + - Auto-detects provider from model name + - Handles provider-specific authentication + - Normalizes response format across providers + - Supports streaming for all providers + - Open source and actively maintained + +Supported Providers (examples): + - OpenAI: gpt-4o, gpt-4o-mini, gpt-3.5-turbo + - Anthropic: claude-3-5-sonnet-20241022, claude-3-opus + - Ollama: ollama/llama2, ollama/mistral + - Groq: groq/mixtral-8x7b-32768 + - Many more: Cohere, Replicate, HuggingFace, etc. + +Model Name Format: + LiteLLM uses model name to determine provider: + - "gpt-4o" -> OpenAI + - "claude-3-5-sonnet" -> Anthropic + - "ollama/llama2" -> Ollama (local) + - "groq/mixtral" -> Groq + + Prefix (ollama/, groq/) explicitly specifies provider. + No prefix means OpenAI or Anthropic based on model name pattern. + +API Key Handling: + Client sends optional API key in request: + 1. If provided, set appropriate environment variable: + - OPENAI_API_KEY for OpenAI models + - ANTHROPIC_API_KEY for Anthropic models + - GROQ_API_KEY for Groq models + 2. LiteLLM reads from environment variables + 3. Local models (Ollama) don't require API keys + + Security consideration: + - API keys stored temporarily in environment (process lifetime) + - Not persisted to disk + - Each request can use different key (multi-user support) + +Streaming Response Format: + Server-Sent Events (SSE) protocol: + - Content-Type: text/event-stream + - Data format: "data: {json}\\n\\n" + - Chunks contain delta (incremental content) + - Client reconstructs full response from deltas + + Example stream: + data: {"choices": [{"delta": {"content": "Hello"}}]} + + data: {"choices": [{"delta": {"content": " world"}}]} + + data: {"choices": [{"delta": {"content": "!"}}]} + + + Why streaming: + - Progressive rendering in UI (better UX) + - Lower perceived latency (first token arrives faster) + - Works for long responses (minutes of generation) + - Standard for modern LLM interfaces + +System Message Handling: + Optional system_message field prepended to messages: + - If provided, inserted as first message with role="system" + - Helps set context/instructions for LLM + - Example: "You are a helpful data science assistant" + +Error Handling: + - LLM API errors (rate limits, auth failures, timeouts) + - Network errors (connection failures) + - Invalid model names + - All errors returned as SSE data events with "error" field + - Logged to server logs for debugging + +Async Generator Pattern: + generate() is an async generator: + - Yields chunks as they arrive from LLM API + - Allows FastAPI to stream response back to client + - Efficient (no buffering, low memory usage) + - Handles backpressure automatically + +Integration with UI: + Frontend uses EventSource or fetch with ReadableStream: + 1. POST to /api/chat/stream with messages array + 2. Read SSE stream incrementally + 3. Parse JSON from each "data:" line + 4. Extract delta.content and append to display + 5. Continue until stream closes +""" import json import logging diff --git a/tracking-server/routes/dependencies.py b/tracking-server/routes/dependencies.py index dd2daa1..05eb202 100644 --- a/tracking-server/routes/dependencies.py +++ b/tracking-server/routes/dependencies.py @@ -1,5 +1,24 @@ # mypy: disable-error-code="misc,untyped-decorator,union-attr,no-any-return,arg-type" -"""Dependency injection helpers for route modules.""" +"""Dependency injection helpers for route modules. + +FastAPI dependency injection pattern for accessing global state: +- Database connection (SQLAlchemy) +- Database session (with automatic commit/rollback) +- WebSocket clients set (for real-time broadcasting) + +Architecture: +- Middleware injects global state into request.state (see main.py) +- These functions extract state and provide it to route handlers +- Enables testing by mocking request.state +- Follows FastAPI Depends() pattern for clean separation + +Session lifecycle: +- get_session() yields a session (context manager pattern) +- Automatically commits on success +- Automatically rolls back on exception +- Always closes session in finally block +- Prevents connection leaks and ensures transaction integrity +""" from typing import Any, Generator, Set diff --git a/tracking-server/routes/health.py b/tracking-server/routes/health.py index a2bac02..d0fc3a5 100644 --- a/tracking-server/routes/health.py +++ b/tracking-server/routes/health.py @@ -1,5 +1,17 @@ # mypy: disable-error-code="misc,untyped-decorator,union-attr,no-any-return,arg-type" -"""Health check and debug endpoints.""" +"""Health check and debug endpoints. + +Health check endpoint: +- Used by HTTPEmitter to verify server availability +- Checks database connectivity +- Returns status for monitoring/ops + +Debug logs endpoint: +- Accepts frontend console logs via POST +- Writes to logs/ directory with timestamp +- Enables debugging UI issues in production +- Used by debugLogger.js when VITE_DEBUG_LOGS=true +""" import logging from pathlib import Path diff --git a/tracking-server/routes/projects.py b/tracking-server/routes/projects.py index 047a915..c845566 100644 --- a/tracking-server/routes/projects.py +++ b/tracking-server/routes/projects.py @@ -1,5 +1,105 @@ # mypy: disable-error-code="misc,untyped-decorator,union-attr,no-any-return,arg-type" -"""Project notes and attachments endpoints.""" +"""Project notes and attachments for lab notebook functionality. + +This module implements a lab notebook system where users can create markdown notes, +attach files (PDFs, images, videos), and organize experiments by project. It supports +both explicit projects (created via API) and implicit projects (inferred from runs). + +Architecture: + - Projects: Logical grouping of runs and notes + - ProjectNotes: Markdown content with title and timestamps + - ProjectNoteAttachments: Files attached to notes (hash-based storage) + +Project Types: + + 1. **Explicit Projects**: + - Created via POST /api/projects + - Have entry in Projects table + - Tracked via created_at and updated_at timestamps + - Can exist without any runs + + 2. **Implicit Projects**: + - Inferred from Run.project field + - No entry in Projects table + - Automatically discovered when listing projects + - Created on-the-fly when first run uses project name + + List projects merges both types for unified UX. + +Note Management: + + CREATE: POST /api/projects/{project_id}/notes + - Auto-creates project if doesn't exist + - Stores markdown content + - Tracks created_at and updated_at + + UPDATE: PUT /api/projects/{project_id}/notes/{note_id} + - Partial update (only fields provided) + - Updates updated_at timestamp + - Keeps created_at unchanged + + DELETE: DELETE /api/projects/{project_id}/notes/{note_id} + - Cascades to attachments (database ON DELETE CASCADE) + - Deletes attachment files from disk + - Returns 404 if note not found + +Attachment Storage: + + Files stored with hash-based paths for: + - Deduplication (same file uploaded multiple times) + - Integrity verification (detect corruption) + - Content-addressable lookup + + Storage path format: uploads/{first_2_chars_of_hash}/{uuid}{ext} + Example: uploads/ab/c3f7d4e8-1234-5678-9abc-def012345678.pdf + + Upload flow: + 1. Read file content and compute SHA256 hash + 2. Generate UUID for unique filename + 3. Extract extension from original filename + 4. Create directory structure (e.g., uploads/ab/) + 5. Write file to disk + 6. Create ProjectNoteAttachment record in database + +Inline Viewing vs Download: + + GET /attachments/{attachment_id}/download?inline=true: + - inline=true: Sets Content-Disposition: inline + - Browser displays PDF/image/video in iframe/tab + - Used for preview in UI + - inline=false: Sets Content-Disposition: attachment + - Browser triggers download dialog + - Used for explicit downloads + + MIME type determines browser behavior: + - application/pdf: Browser PDF viewer + - image/*: Display inline + - video/*: HTML5 video player + - audio/*: HTML5 audio player + - Other: Trigger download + +File Cleanup: + + When attachment is deleted: + 1. Delete database record (ProjectNoteAttachment) + 2. Delete file from disk (PROJECT_ROOT / storage_path) + 3. If file doesn't exist, continue (already deleted) + 4. Ignore errors (best effort cleanup) + +Database Session Management: + + Manual session management (unlike runs.py which uses Depends): + - session = db.get_session() + - try/except/finally with session.close() + - Manual commit() and rollback() + - Required for attachment file operations (need transaction control) + +Error Handling: + - 404 Not Found: Project/note/attachment doesn't exist + - 400 Bad Request: Project already exists + - 500 Internal Server Error: Database errors, file I/O errors + - All errors logged with logger.error() +""" import hashlib import logging diff --git a/tracking-server/routes/runs.py b/tracking-server/routes/runs.py index 2e88729..767f8fc 100644 --- a/tracking-server/routes/runs.py +++ b/tracking-server/routes/runs.py @@ -1,5 +1,115 @@ # mypy: disable-error-code="misc,untyped-decorator,union-attr,no-any-return,arg-type" -"""Run management endpoints - SQLAlchemy ORM version.""" +"""Run management endpoints for tracking server. + +This module implements HTTP endpoints for run creation, retrieval, and data emission. +It serves as the primary interface between the Artifacta SDK (client) and the +tracking server (backend), enabling real-time metrics logging and WebSocket broadcasts. + +Endpoint Overview: + + POST /api/runs + - Create new run entry in database + - Called by SDK when run.start() is invoked + - Generates unique run name if not provided + - Broadcasts run creation to WebSocket clients + - Returns: {"status": "created", "run_id": "..."} + + GET /api/runs + - List all runs with structured data + - Supports filtering by project, pagination via limit + - Optionally includes tags and params + - Groups structured data by metric name + - Returns: List[RunDict] with structured_data field + + GET /api/runs/{run_id} + - Retrieve single run by ID + - Includes all structured data grouped by name + - Returns: RunDict with structured_data field + + PATCH /api/runs/{run_id}/config-artifact + - Update run to link config artifact (single source of truth) + - Called by SDK after config artifact is logged + - Returns: {"status": "updated"} + + POST /api/runs/{run_id}/data + - Log structured data primitive (Series, Distribution, etc.) + - Called by SDK when run.log() is invoked + - Stores data as JSON in database + - Broadcasts to WebSocket clients for real-time UI updates + - Returns: {"status": "logged"} + +Data Flow (SDK to UI): + + 1. SDK calls run.log(name="loss", data={"epoch": [1,2,3], "loss": [0.5,0.3,0.2]}) + 2. SDK sends POST to /api/runs/{run_id}/data with payload + 3. Server validates and stores in StructuredData table + 4. Server broadcasts to all connected WebSocket clients + 5. UI receives WebSocket message and updates chart in real-time + +Config Artifact Linking: + + The config artifact pattern ensures configuration is stored as an artifact + (for provenance) and linked to the run (for easy access): + + 1. SDK logs config as artifact via log_artifact() + 2. SDK calls PATCH /api/runs/{run_id}/config-artifact with artifact_id + 3. Server updates run.config_artifact_id foreign key + 4. GET /api/runs fetches artifact content and parses config JSON + 5. UI displays config in run details + +Structured Data Grouping: + + Multiple emissions of the same metric name are grouped together: + - run.log("loss", ...) at step 1 + - run.log("loss", ...) at step 2 + - run.log("loss", ...) at step 3 + - Result: structured_data["loss"] = [step1_data, step2_data, step3_data] + + This enables: + - Line charts with multiple points + - Confusion matrices that update over epochs + - A/B test results with progressive accumulation + +WebSocket Broadcasting: + + broadcast_to_websockets() helper: + 1. Iterate over all connected WebSocket clients (global set) + 2. Send JSON message via client.send_json() + 3. Catch exceptions (client disconnected, network error) + 4. Track disconnected clients in set + 5. Remove disconnected clients from global set (cleanup) + + Message format: + { + "type": "data_logged" | "run_created", + "payload": { + "run_id": "...", + "name": "loss", + "primitive_type": "series" + } + } + +Error Handling: + + - HTTPException for client errors (404 Not Found, 400 Bad Request) + - Generic Exception caught and logged with exc_info=True + - Raised as HTTPException(500) with error details + - This prevents stack traces leaking to client + +Database Session Management: + + - SQLAlchemy session injected via Depends(get_session) + - Session automatically committed at end of request (FastAPI middleware) + - Session rolled back on exception (automatic cleanup) + - No manual commit() needed in route handlers + +Performance Considerations: + + - Queries use indexes (run_id, created_at, etc.) for fast lookups + - Limit parameter prevents unbounded result sets + - JSON parsing done lazily (only for requested runs) + - Structured data grouped in Python (not SQL) for flexibility +""" import json import logging @@ -183,10 +293,11 @@ async def list_runs( # Group by name structured_data: Dict[str, List[Dict[str, Any]]] = {} for data_row in structured_data_rows: - if data_row.name not in structured_data: - structured_data[data_row.name] = [] + name_key = str(data_row.name) + if name_key not in structured_data: + structured_data[name_key] = [] - structured_data[data_row.name].append( + structured_data[name_key].append( { "primitive_type": data_row.primitive_type, "section": data_row.section, @@ -240,10 +351,11 @@ async def get_run( # Group by name structured_data: Dict[str, List[Dict[str, Any]]] = {} for data_row in structured_data_rows: - if data_row.name not in structured_data: - structured_data[data_row.name] = [] + name_key = str(data_row.name) + if name_key not in structured_data: + structured_data[name_key] = [] - structured_data[data_row.name].append( + structured_data[name_key].append( { "primitive_type": data_row.primitive_type, "section": data_row.section, @@ -277,7 +389,7 @@ async def update_config_artifact( if not run: raise HTTPException(status_code=404, detail="Run not found") - run.config_artifact_id = request.config_artifact_id + run.config_artifact_id = request.config_artifact_id # type: ignore[assignment] session.flush() return {"status": "updated"} diff --git a/tracking-server/routes/websocket.py b/tracking-server/routes/websocket.py index 192791c..f65ae2c 100644 --- a/tracking-server/routes/websocket.py +++ b/tracking-server/routes/websocket.py @@ -1,5 +1,132 @@ # mypy: disable-error-code="misc,untyped-decorator,union-attr,no-any-return,arg-type" -"""WebSocket endpoints for real-time metrics streaming.""" +"""WebSocket endpoint for real-time metrics and run updates. + +This module implements the WebSocket server endpoint that clients connect to for +receiving real-time updates about runs, metrics, and artifacts. It's the push +notification mechanism that enables live UI updates without polling. + +Architecture: + + The WebSocket system follows a pub-sub pattern: + + 1. **Connection Management** (this module): + - Accept WebSocket connections at /ws/metrics + - Add connections to global websocket_clients set + - Keep connections alive (receive ping/pong messages) + - Remove disconnected clients from set + + 2. **Broadcasting** (runs.py and other routes): + - When data is logged, broadcast to all connected clients + - Iterate over websocket_clients set + - Send JSON message to each client + - Remove failed clients (disconnected, errors) + + 3. **Client Consumption** (frontend JavaScript): + - Connect WebSocket to /ws/metrics + - Listen for JSON messages + - Update UI components based on message type + - Reconnect on disconnect + +Connection Lifecycle: + + Connect: + 1. Client initiates WebSocket connection + 2. Server calls websocket.accept() + 3. Server adds websocket to global set + 4. Log connection count + 5. Enter receive loop (keep-alive) + + Alive: + - Client sends periodic ping messages (heartbeat) + - Server receives via websocket.receive_text() + - Connection stays open indefinitely + - No response needed (WebSocket protocol handles pong) + + Disconnect: + - Client closes connection (normal shutdown) + - WebSocketDisconnect exception raised + - Server removes client from set + - Log remaining connection count + + Error: + - Network error, protocol error, etc. + - Generic Exception caught + - Server removes client from set (discard for safety) + - Log error message + +Message Format: + + Broadcast messages from server to clients: + { + "type": "run_created" | "data_logged" | "artifact_uploaded", + "payload": { + "run_id": "...", + "name": "loss", + "primitive_type": "series", + ... + } + } + + Message types: + - run_created: New run started + - data_logged: Metric/primitive logged to run + - artifact_uploaded: Artifact added to run + (more types can be added as needed) + +Global State Management: + + websocket_clients_global: + - Module-level variable (shared across all requests) + - Set[WebSocket] type (O(1) add/remove) + - Initialized by main.py via set_websocket_clients() + - Thread-safe for async (FastAPI handles concurrency) + + Why global set: + - Needs to be shared across all route handlers + - Routes need to broadcast to all connected clients + - FastAPI doesn't have built-in pub-sub (unlike Socket.io) + - Simple and efficient for this use case + +Scalability Considerations: + + Current design (in-memory set): + - Works well for single-server deployment + - All WebSocket connections on same server + - Simple, fast, no external dependencies + + Future scaling (if needed): + - Redis Pub/Sub for multi-server deployment + - Each server maintains local websocket_clients + - Broadcasts go through Redis channel + - All servers receive and forward to their clients + +Keep-Alive Strategy: + + The receive loop keeps connection open: + - await websocket.receive_text() blocks until message + - Client sends ping periodically (e.g., every 30 seconds) + - Server doesn't need to respond (WebSocket handles it) + - Prevents connection timeout from load balancers/proxies + + Why receive-based keep-alive: + - Simpler than server-side ping/pong + - Client controls heartbeat frequency + - Works with all WebSocket client libraries + - No need for asyncio.sleep() polling loop + +Error Recovery: + + Client-side should implement reconnection: + - Detect WebSocket close event + - Wait with exponential backoff (1s, 2s, 4s, ...) + - Reconnect to /ws/metrics + - Resume receiving updates + + Server-side is stateless: + - No per-client state to restore + - Clients immediately receive future broadcasts + - Past messages not replayed (use HTTP API for history) +""" import logging