Skip to content

Commit c0634a9

Browse files
dougalmlearned_optimization authors
authored andcommitted
Update libraries to use JAX's limited (and ill-advised) trace-state-querying APIs rather than depending on JAX's deeper internals, which are about to change.
PiperOrigin-RevId: 677843398
1 parent 4bcaeb0 commit c0634a9

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

learned_optimization/jax_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,12 @@ def body_fn(_, operand):
4646

4747
def in_jit() -> bool:
4848
"""Returns true if tracing jit."""
49-
return "DynamicJaxprTrace" in str(
50-
jax.core.thread_local_state.trace_state.trace_stack
51-
)
49+
if jax.__version_info__ <= (0, 4, 33):
50+
return "DynamicJaxprTrace" in str(
51+
jax.core.thread_local_state.trace_state.trace_stack
52+
)
53+
54+
return jax.core.unsafe_am_i_under_a_jit_DO_NOT_USE()
5255

5356

5457
Carry = TypeVar("Carry")

0 commit comments

Comments
 (0)