Skip to content

Fix JAX transformation compatibility with functional updates#2

Merged
juehang merged 2 commits into
mainfrom
fix/jax-transformation-compatibility
Nov 20, 2025
Merged

Fix JAX transformation compatibility with functional updates#2
juehang merged 2 commits into
mainfrom
fix/jax-transformation-compatibility

Conversation

@juehang
Copy link
Copy Markdown
Collaborator

@juehang juehang commented Nov 20, 2025

Summary

This PR improves JAX compatibility for dataloader transformation methods by replacing incomplete PR #1 with a more robust functional approach.

Changes

  • Fixes both methods: data_transformation and data_inv_transformation now work with jax.grad
  • Functional approach: Uses jnp.concatenate instead of in-place mutations
  • Cleaner code: Removes unnecessary copy.deepcopy calls
  • Universal compatibility: Works uniformly with NumPy and JAX arrays

Test Results

✅ All 3 JAX grad compatibility tests pass:

  • test_transformation_grad_compatible[data_transformation]
  • test_transformation_grad_compatible[data_inv_transformation]
  • test_round_trip_grad_precision

Why This Approach

The functional pattern is more idiomatic JAX code that naturally handles both NumPy and JAX arrays without conditional logic.

…unctional updates

Improves upon the dataloader transformations to ensure robust JAX compatibility:
- Fixes both data_transformation and data_inv_transformation for jax.grad
- Uses purely functional approach with jnp.concatenate instead of in-place mutations
- Removes unnecessary copy.deepcopy calls
- Works uniformly with NumPy and JAX arrays without type checking
- Passes all jax.grad compatibility tests
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR improves JAX transformation compatibility for the dataloader by replacing in-place array mutations with functional operations using jnp.concatenate, enabling these methods to work with jax.grad.

Key Changes

  • Replaced in-place mutations with functional jnp.concatenate operations in data_transformation and data_inv_transformation
  • Removed unnecessary copy.deepcopy calls that were blocking JAX gradient flow
  • Added comprehensive JAX gradient compatibility tests

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description
src/probabilistic_posrec/dataloader/base.py Refactored data_transformation and data_inv_transformation to use functional concatenation instead of in-place mutations, enabling JAX gradient compatibility
tests/test_dataloader_jax_grad.py Added new test file with gradient compatibility tests for both transformation methods and round-trip precision validation

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@juehang juehang merged commit 6981650 into main Nov 20, 2025
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants