From 7d6f96482236a2c23edf1fbd4e3dec3eeda16779 Mon Sep 17 00:00:00 2001 From: Daniel Toyama Date: Wed, 20 May 2026 00:38:42 -0700 Subject: [PATCH] Enable Pyrefly type checking for root AndroidEnv and AndroidEnvInterface API classes. PiperOrigin-RevId: 918276061 --- android_env/environment_test.py | 52 +++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/android_env/environment_test.py b/android_env/environment_test.py index cf0e3876..94e808be 100644 --- a/android_env/environment_test.py +++ b/android_env/environment_test.py @@ -30,21 +30,21 @@ import numpy as np -def _create_mock_coordinator() -> coordinator_lib.Coordinator: - coordinator = mock.create_autospec(coordinator_lib.Coordinator) - coordinator.action_spec.return_value = { - 'action_type': - dm_env.specs.DiscreteArray(num_values=3), - 'touch_position': - dm_env.specs.BoundedArray( - shape=(2,), dtype=np.float32, minimum=0.0, maximum=1.0), +def _get_fake_action_spec() -> dict[str, dm_env.specs.Array]: + return { + 'action_type': dm_env.specs.DiscreteArray(num_values=3), + 'touch_position': dm_env.specs.BoundedArray( + shape=(2,), dtype=np.float32, minimum=0.0, maximum=1.0 + ), } - coordinator.observation_spec.return_value = { + + +def _get_fake_observation_spec() -> dict[str, dm_env.specs.Array]: + return { 'pixels': dm_env.specs.Array(shape=(123, 456, 3), dtype=np.uint8), 'timedelta': dm_env.specs.Array(shape=(), dtype=np.int64), 'orientation': dm_env.specs.Array(shape=(4,), dtype=np.uint8), } - return coordinator def _create_fake_simulator() -> fake_simulator.FakeSimulator: @@ -57,7 +57,9 @@ class AndroidEnvTest(absltest.TestCase): def test_specs(self): simulator = _create_fake_simulator() - coordinator = _create_mock_coordinator() + coordinator = mock.create_autospec(coordinator_lib.Coordinator) + coordinator.action_spec.return_value = _get_fake_action_spec() + coordinator.observation_spec.return_value = _get_fake_observation_spec() task_manager = mock.create_autospec(task_manager_lib.TaskManager) env = environment.AndroidEnv( simulator=simulator, coordinator=coordinator, task_manager=task_manager @@ -92,7 +94,7 @@ def test_specs(self): def test_reset_and_step(self): simulator = _create_fake_simulator() - coordinator = _create_mock_coordinator() + coordinator = mock.create_autospec(coordinator_lib.Coordinator) task_manager = mock.create_autospec(task_manager_lib.TaskManager) coordinator.action_spec.return_value = { 'action_type': @@ -161,7 +163,10 @@ def test_reset_and_step(self): discount=0.0, observation=latest_observation, ) - ts = env.step({'action_type': 1, 'touch_position': (10, 20)}) + ts = env.step({ + 'action_type': np.array(1), + 'touch_position': np.array((10, 20)), + }) self.assertIsInstance(ts, dm_env.TimeStep) # The StepType now should NOT be FIRST. self.assertFalse(ts.first()) @@ -192,7 +197,10 @@ def test_reset_and_step(self): reward=0.0, observation=None, ) - ts = env.step({'action_type': 1, 'touch_position': (10, 20)}) + ts = env.step({ + 'action_type': np.array(1), + 'touch_position': np.array((10, 20)), + }) self.assertIsInstance(ts, dm_env.TimeStep) # Assert the observation matches the latest observation. obs = ts.observation @@ -207,7 +215,9 @@ def test_reset_and_step(self): def test_adb_call(self): simulator = _create_fake_simulator() - coordinator = _create_mock_coordinator() + coordinator = mock.create_autospec(coordinator_lib.Coordinator) + coordinator.action_spec.return_value = _get_fake_action_spec() + coordinator.observation_spec.return_value = _get_fake_observation_spec() task_manager = mock.create_autospec(task_manager_lib.TaskManager) env = environment.AndroidEnv( simulator=simulator, coordinator=coordinator, task_manager=task_manager @@ -225,7 +235,9 @@ def test_adb_call(self): def test_load_state(self): simulator = mock.create_autospec(base_simulator.BaseSimulator) - coordinator = _create_mock_coordinator() + coordinator = mock.create_autospec(coordinator_lib.Coordinator) + coordinator.action_spec.return_value = _get_fake_action_spec() + coordinator.observation_spec.return_value = _get_fake_observation_spec() task_manager = mock.create_autospec(task_manager_lib.TaskManager) env = environment.AndroidEnv( simulator=simulator, coordinator=coordinator, task_manager=task_manager @@ -243,7 +255,9 @@ def test_load_state(self): def test_save_state(self): simulator = mock.create_autospec(base_simulator.BaseSimulator) - coordinator = _create_mock_coordinator() + coordinator = mock.create_autospec(coordinator_lib.Coordinator) + coordinator.action_spec.return_value = _get_fake_action_spec() + coordinator.observation_spec.return_value = _get_fake_observation_spec() task_manager = mock.create_autospec(task_manager_lib.TaskManager) env = environment.AndroidEnv( simulator=simulator, coordinator=coordinator, task_manager=task_manager @@ -259,7 +273,9 @@ def test_save_state(self): def test_double_close(self): simulator = _create_fake_simulator() - coordinator = _create_mock_coordinator() + coordinator = mock.create_autospec(coordinator_lib.Coordinator) + coordinator.action_spec.return_value = _get_fake_action_spec() + coordinator.observation_spec.return_value = _get_fake_observation_spec() task_manager = mock.create_autospec(task_manager_lib.TaskManager) env = environment.AndroidEnv( simulator=simulator, coordinator=coordinator, task_manager=task_manager