Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 34 additions & 18 deletions android_env/environment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@
import numpy as np


def _create_mock_coordinator() -> coordinator_lib.Coordinator:
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
coordinator.action_spec.return_value = {
'action_type':
dm_env.specs.DiscreteArray(num_values=3),
'touch_position':
dm_env.specs.BoundedArray(
shape=(2,), dtype=np.float32, minimum=0.0, maximum=1.0),
def _get_fake_action_spec() -> dict[str, dm_env.specs.Array]:
return {
'action_type': dm_env.specs.DiscreteArray(num_values=3),
'touch_position': dm_env.specs.BoundedArray(
shape=(2,), dtype=np.float32, minimum=0.0, maximum=1.0
),
}
coordinator.observation_spec.return_value = {


def _get_fake_observation_spec() -> dict[str, dm_env.specs.Array]:
return {
'pixels': dm_env.specs.Array(shape=(123, 456, 3), dtype=np.uint8),
'timedelta': dm_env.specs.Array(shape=(), dtype=np.int64),
'orientation': dm_env.specs.Array(shape=(4,), dtype=np.uint8),
}
return coordinator


def _create_fake_simulator() -> fake_simulator.FakeSimulator:
Expand All @@ -57,7 +57,9 @@ class AndroidEnvTest(absltest.TestCase):

def test_specs(self):
simulator = _create_fake_simulator()
coordinator = _create_mock_coordinator()
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
coordinator.action_spec.return_value = _get_fake_action_spec()
coordinator.observation_spec.return_value = _get_fake_observation_spec()
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
Expand Down Expand Up @@ -92,7 +94,7 @@ def test_specs(self):

def test_reset_and_step(self):
simulator = _create_fake_simulator()
coordinator = _create_mock_coordinator()
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
coordinator.action_spec.return_value = {
'action_type':
Expand Down Expand Up @@ -161,7 +163,10 @@ def test_reset_and_step(self):
discount=0.0,
observation=latest_observation,
)
ts = env.step({'action_type': 1, 'touch_position': (10, 20)})
ts = env.step({
'action_type': np.array(1),
'touch_position': np.array((10, 20)),
})
self.assertIsInstance(ts, dm_env.TimeStep)
# The StepType now should NOT be FIRST.
self.assertFalse(ts.first())
Expand Down Expand Up @@ -192,7 +197,10 @@ def test_reset_and_step(self):
reward=0.0,
observation=None,
)
ts = env.step({'action_type': 1, 'touch_position': (10, 20)})
ts = env.step({
'action_type': np.array(1),
'touch_position': np.array((10, 20)),
})
self.assertIsInstance(ts, dm_env.TimeStep)
# Assert the observation matches the latest observation.
obs = ts.observation
Expand All @@ -207,7 +215,9 @@ def test_reset_and_step(self):

def test_adb_call(self):
simulator = _create_fake_simulator()
coordinator = _create_mock_coordinator()
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
coordinator.action_spec.return_value = _get_fake_action_spec()
coordinator.observation_spec.return_value = _get_fake_observation_spec()
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
Expand All @@ -225,7 +235,9 @@ def test_adb_call(self):

def test_load_state(self):
simulator = mock.create_autospec(base_simulator.BaseSimulator)
coordinator = _create_mock_coordinator()
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
coordinator.action_spec.return_value = _get_fake_action_spec()
coordinator.observation_spec.return_value = _get_fake_observation_spec()
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
Expand All @@ -243,7 +255,9 @@ def test_load_state(self):

def test_save_state(self):
simulator = mock.create_autospec(base_simulator.BaseSimulator)
coordinator = _create_mock_coordinator()
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
coordinator.action_spec.return_value = _get_fake_action_spec()
coordinator.observation_spec.return_value = _get_fake_observation_spec()
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
Expand All @@ -259,7 +273,9 @@ def test_save_state(self):

def test_double_close(self):
simulator = _create_fake_simulator()
coordinator = _create_mock_coordinator()
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
coordinator.action_spec.return_value = _get_fake_action_spec()
coordinator.observation_spec.return_value = _get_fake_observation_spec()
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
Expand Down
Loading