Skip to content

Test failures in policies_test.py #103

@samuela

Description

@samuela

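I'm seeing the following test failures on commit 2a6919d. Note that both tests pass the expected value as the *first* argument of `np.testing.assert_array_equal`, whose signature is `(actual, desired)`, so the "ACTUAL" and "DESIRED" labels in the output below are swapped: the policy actually selected action 2 (expected 1), and the search produced root visit counts [2, 2, 6, 7] (expected [6, 2, 2, 7]).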

============================= test session starts ==============================
platform linux -- Python 3.12.9, pytest-8.3.5, pluggy-1.5.0
rootdir: /build/source
collected 20 items

mctx/_src/tests/mctx_test.py .                                           [  5%]
mctx/_src/tests/policies_test.py ....FF....                              [ 55%]
mctx/_src/tests/qtransforms_test.py ..                                   [ 65%]
mctx/_src/tests/seq_halving_test.py .......                              [100%]

=================================== FAILURES ===================================
____________________ PoliciesTest.test_gumbel_muzero_policy ____________________

self = <policies_test.PoliciesTest testMethod=test_gumbel_muzero_policy>

    def test_gumbel_muzero_policy(self):
      root_value = jnp.array([-5.0])
      root = mctx.RootFnOutput(
          prior_logits=jnp.array([
              [0.0, -1.0, 2.0, 3.0],
          ]),
          value=root_value,
          embedding=(),
      )
      rewards = jnp.array([
          [20.0, 3.0, -1.0, 10.0],
      ])
      invalid_actions = jnp.array([
          [1.0, 0.0, 0.0, 1.0],
      ])

      value_scale = 0.05
      maxvisit_init = 60
      num_simulations = 17
      max_depth = 3
      qtransform = functools.partial(
          mctx.qtransform_completed_by_mix_value,
          value_scale=value_scale,
          maxvisit_init=maxvisit_init,
          rescale_values=True)
      policy_output = mctx.gumbel_muzero_policy(
          params=(),
          rng_key=jax.random.PRNGKey(0),
          root=root,
          recurrent_fn=_make_bandit_recurrent_fn(rewards),
          num_simulations=num_simulations,
          invalid_actions=invalid_actions,
          max_depth=max_depth,
          qtransform=qtransform,
          gumbel_scale=1.0)
      # Testing the action.
      expected_action = jnp.array([1], dtype=jnp.int32)
>     np.testing.assert_array_equal(expected_action, policy_output.action)

mctx/_src/tests/policies_test.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (Array([1], dtype=int32), Array([2], dtype=int32)), kwargs = {}
old_name = 'y', new_name = 'desired'

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        for old_name, new_name in zip(old_names, new_names):
            if old_name in kwargs:
                if dep_version:
                    end_version = dep_version.split('.')
                    end_version[1] = str(int(end_version[1]) + 2)
                    end_version = '.'.join(end_version)
                    msg = (f"Use of keyword argument `{old_name}` is "
                           f"deprecated and replaced by `{new_name}`. "
                           f"Support for `{old_name}` will be removed "
                           f"in NumPy {end_version}.")
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
                if new_name in kwargs:
                    msg = (f"{fun.__name__}() got multiple values for "
                           f"argument now known as `{new_name}`")
                    raise TypeError(msg)
                kwargs[new_name] = kwargs.pop(old_name)
>       return fun(*args, **kwargs)
E       AssertionError:
E       Arrays are not equal
E
E       Mismatched elements: 1 / 1 (100%)
E       Max absolute difference among violations: 1
E       Max relative difference among violations: 0.5
E        ACTUAL: array([1], dtype=int32)
E        DESIRED: array([2], dtype=int32)

/nix/store/s3k7qby931y3hc7b2phvyay054idkfcg-python3.12-numpy-2.2.3/lib/python3.12/site-packages/numpy/_utils/__init__.py:85: AssertionError
________ PoliciesTest.test_gumbel_muzero_policy_without_invalid_actions ________

self = <policies_test.PoliciesTest testMethod=test_gumbel_muzero_policy_without_invalid_actions>

    def test_gumbel_muzero_policy_without_invalid_actions(self):
      root_value = jnp.array([-5.0])
      root = mctx.RootFnOutput(
          prior_logits=jnp.array([
              [0.0, -1.0, 2.0, 3.0],
          ]),
          value=root_value,
          embedding=(),
      )
      rewards = jnp.array([
          [20.0, 3.0, -1.0, 10.0],
      ])

      value_scale = 0.05
      maxvisit_init = 60
      num_simulations = 17
      max_depth = 3
      qtransform = functools.partial(
          mctx.qtransform_completed_by_mix_value,
          value_scale=value_scale,
          maxvisit_init=maxvisit_init,
          rescale_values=True)
      policy_output = mctx.gumbel_muzero_policy(
          params=(),
          rng_key=jax.random.PRNGKey(0),
          root=root,
          recurrent_fn=_make_bandit_recurrent_fn(rewards),
          num_simulations=num_simulations,
          invalid_actions=None,
          max_depth=max_depth,
          qtransform=qtransform,
          gumbel_scale=1.0)
      # Testing the action.
      expected_action = jnp.array([3], dtype=jnp.int32)
      np.testing.assert_array_equal(expected_action, policy_output.action)

      # Testing the action_weights.
      summary = policy_output.search_tree.summary()
      completed_qvalues = rewards
      max_value = jnp.max(completed_qvalues, axis=-1, keepdims=True)
      min_value = jnp.min(completed_qvalues, axis=-1, keepdims=True)
      total_value_scale = (maxvisit_init + summary.visit_counts.max()
                           ) * value_scale
      rescaled_qvalues = total_value_scale * (completed_qvalues - min_value) / (
          max_value - min_value)
      expected_action_weights = jax.nn.softmax(
          root.prior_logits + rescaled_qvalues)
      np.testing.assert_allclose(expected_action_weights,
                                 policy_output.action_weights,
                                 atol=1e-6)

      # Testing the visit_counts.
      expected_visit_counts = jnp.array(
          [[6, 2, 2, 7]])
>     np.testing.assert_array_equal(expected_visit_counts, summary.visit_counts)

mctx/_src/tests/policies_test.py:307:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (Array([[6, 2, 2, 7]], dtype=int32), Array([[2., 2., 6., 7.]], dtype=float32))
kwargs = {}, old_name = 'y', new_name = 'desired'

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        for old_name, new_name in zip(old_names, new_names):
            if old_name in kwargs:
                if dep_version:
                    end_version = dep_version.split('.')
                    end_version[1] = str(int(end_version[1]) + 2)
                    end_version = '.'.join(end_version)
                    msg = (f"Use of keyword argument `{old_name}` is "
                           f"deprecated and replaced by `{new_name}`. "
                           f"Support for `{old_name}` will be removed "
                           f"in NumPy {end_version}.")
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
                if new_name in kwargs:
                    msg = (f"{fun.__name__}() got multiple values for "
                           f"argument now known as `{new_name}`")
                    raise TypeError(msg)
                kwargs[new_name] = kwargs.pop(old_name)
>       return fun(*args, **kwargs)
E       AssertionError:
E       Arrays are not equal
E
E       Mismatched elements: 2 / 4 (50%)
E       Max absolute difference among violations: 4.
E       Max relative difference among violations: 2.
E        ACTUAL: array([[6, 2, 2, 7]], dtype=int32)
E        DESIRED: array([[2., 2., 6., 7.]], dtype=float32)

/nix/store/s3k7qby931y3hc7b2phvyay054idkfcg-python3.12-numpy-2.2.3/lib/python3.12/site-packages/numpy/_utils/__init__.py:85: AssertionError
=========================== short test summary info ============================
FAILED mctx/_src/tests/policies_test.py::PoliciesTest::test_gumbel_muzero_policy - AssertionError:
FAILED mctx/_src/tests/policies_test.py::PoliciesTest::test_gumbel_muzero_policy_without_invalid_actions - AssertionError:
======================== 2 failed, 18 passed in 14.88s =========================
error: builder for '/nix/store/08wrk930cpjywd59b6x8vyc21p6wc25m-python3.12-mctx-0-unstable-2025-04-04.drv' failed with exit code 1;
       last 25 log lines:
       >                     msg = (f"Use of keyword argument `{old_name}` is "
       >                            f"deprecated and replaced by `{new_name}`. "
       >                            f"Support for `{old_name}` will be removed "
       >                            f"in NumPy {end_version}.")
       >                     warnings.warn(msg, DeprecationWarning, stacklevel=2)
       >                 if new_name in kwargs:
       >                     msg = (f"{fun.__name__}() got multiple values for "
       >                            f"argument now known as `{new_name}`")
       >                     raise TypeError(msg)
       >                 kwargs[new_name] = kwargs.pop(old_name)
       > >       return fun(*args, **kwargs)
       > E       AssertionError:
       > E       Arrays are not equal
       > E
       > E       Mismatched elements: 2 / 4 (50%)
       > E       Max absolute difference among violations: 4.
       > E       Max relative difference among violations: 2.
       > E        ACTUAL: array([[6, 2, 2, 7]], dtype=int32)
       > E        DESIRED: array([[2., 2., 6., 7.]], dtype=float32)
       >
       > /nix/store/s3k7qby931y3hc7b2phvyay054idkfcg-python3.12-numpy-2.2.3/lib/python3.12/site-packages/numpy/_utils/__init__.py:85: AssertionError
       > =========================== short test summary info ============================
       > FAILED mctx/_src/tests/policies_test.py::PoliciesTest::test_gumbel_muzero_policy - AssertionError:
       > FAILED mctx/_src/tests/policies_test.py::PoliciesTest::test_gumbel_muzero_policy_without_invalid_actions - AssertionError:
       > ======================== 2 failed, 18 passed in 14.88s =========================
       For full logs, run 'nix log /nix/store/08wrk930cpjywd59b6x8vyc21p6wc25m-python3.12-mctx-0-unstable-2025-04-04.drv'.
error: 1 dependencies of derivation '/nix/store/1g944w7bws4kkqc3zcryhj4yk8an05bf-python3-3.12.9-env.drv' failed to build

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions