Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions android_env/components/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
self._task_manager = task_manager
self._config = config or config_classes.CoordinatorConfig()
self._device_settings = device_settings
self._adb_call_parser: adb_call_parser.AdbCallParser = None
self._adb_call_parser: adb_call_parser.AdbCallParser | None = None

# Initialize stats.
self._stats = {
Expand All @@ -72,7 +72,7 @@ def __init__(

# Initialize counters.
self._simulator_healthy = False
self._latest_observation_time = 0
self._latest_observation_time = 0.0
self._simulator_start_time = None

logging.info('Starting the simulator...')
Expand Down Expand Up @@ -145,7 +145,7 @@ def _launch_simulator(self, max_retries: int = 3):
except errors.AdbControllerError as e:
logging.exception('device_settings.update() failed.')
self._stats['relaunch_count_update_settings'] += 1
self._latest_error = e
latest_error = e
num_tries += 1
continue

Expand Down Expand Up @@ -175,6 +175,10 @@ def _create_adb_call_parser(self):
)

def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
assert self._adb_call_parser is not None, (
'_adb_call_parser is None. `start()` or `launch()` must be called '
'before executing ADB calls.'
)
return self._adb_call_parser.parse(call)

def rl_reset(self) -> dm_env.TimeStep:
Expand All @@ -185,10 +189,10 @@ def rl_reset(self) -> dm_env.TimeStep:
self._launch_simulator()

# Reset counters.
self._latest_observation_time = 0
self._latest_observation_time = 0.0
for key in self._stats:
if key.startswith('episode'):
self._stats[key] = 0.0
self._stats[key] = 0

# Execute a lift action before resetting the task.
if not action_fns.send_action_to_simulator(
Expand Down Expand Up @@ -231,6 +235,7 @@ def rl_step(self, agent_action: dict[str, np.ndarray]) -> dm_env.TimeStep:
):
self._stats['relaunch_count_execute_action'] += 1
self._simulator_healthy = False
return dm_env.truncation(reward=0.0, observation=None)

# Get data from the simulator.
try:
Expand All @@ -239,8 +244,6 @@ def rl_step(self, agent_action: dict[str, np.ndarray]) -> dm_env.TimeStep:
logging.exception('Unable to fetch observation. Restarting simulator.')
self._stats['relaunch_count_fetch_observation'] += 1
self._simulator_healthy = False

if not self._simulator_healthy:
return dm_env.truncation(reward=0.0, observation=None)

return self._task_manager.rl_step(simulator_signals)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,11 @@ def confirm_shutdown(self) -> None:
self._emulator.returncode)
self._emulator.kill()
self._emulator = None

if self._emulator_output is not None:
self._emulator_output.close()
logging.info('The emulator process has finished.')
self._emulator_output = None
logging.info('The emulator process has finished.')

def close(self):
"""Clean up launcher files and processes."""
Expand Down
52 changes: 36 additions & 16 deletions android_env/components/simulators/emulator/emulator_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

"""A class that manages an Android Emulator."""

from collections.abc import Callable
import functools
import os
import time
from typing import Any
Expand Down Expand Up @@ -94,14 +96,42 @@ class EmulatorCrashError(errors.SimulatorError):
"""Raised when a simulator crashed."""


def _reconnect_on_grpc_error(
func: Callable[..., Any],
) -> Callable[..., Any]:
"""Decorates a function to reconnect to the emulator upon gRPC errors."""

@functools.wraps(func)
def wrapper(self: 'EmulatorSimulator', *args: Any, **kwargs: Any) -> Any:
try:
return func(self, *args, **kwargs)
except grpc.RpcError as error:
logging.exception(
'RpcError caught while calling %s with args=%s, kwargs=%s. '
'Error details: %s. Reconnecting to emulator...',
func.__name__,
args,
kwargs,
error,
)
# pylint: disable=protected-access
self._emulator_stub, self._snapshot_stub = self._connect_to_emulator(
self._config.emulator_launcher.grpc_port
)
# pylint: enable=protected-access
return func(self, *args, **kwargs)

return wrapper


class EmulatorSimulator(base_simulator.BaseSimulator):
"""Controls an Android Emulator."""

def __init__(self, config: config_classes.EmulatorConfig):
"""Instantiates an EmulatorSimulator."""

super().__init__(config)
self._config = config
self._config: config_classes.EmulatorConfig = config

# If adb_port, console_port and grpc_port are all already provided,
# we assume the emulator already exists and there's no need to launch.
Expand Down Expand Up @@ -163,21 +193,6 @@ def __init__(self, config: config_classes.EmulatorConfig):
self._config.logfile_path or self._launcher.logfile_path()
)

def _reconnect_on_grpc_error(func):
"""Decorator function for reconnecting to emulator upon grpc errors."""

def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except grpc.RpcError:
logging.exception('RpcError caught. Reconnecting to emulator...')
self._emulator_stub, self._snapshot_stub = self._connect_to_emulator(
self._config.emulator_launcher.grpc_port
)
return func(self, *args, **kwargs)

return wrapper

def get_logs(self) -> str:
"""Returns logs recorded by the emulator."""
if self._logfile_path and os.path.exists(self._logfile_path):
Expand Down Expand Up @@ -338,6 +353,11 @@ def _connect_to_emulator(
]:
"""Connects to an emulator and returns a corresponsing stub."""

if hasattr(self, '_channel') and self._channel is not None:
logging.info('Closing previous gRPC channel before reconnecting.')
self._channel.close()
self._channel = None

logging.info('Creating gRPC channel to the emulator on port %r', grpc_port)
port = f'localhost:{grpc_port}'
options = [('grpc.max_send_message_length', -1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,11 @@ def test_get_screenshot(self):
),
)
simulator = emulator_simulator.EmulatorSimulator(config)
self.addCleanup(simulator.close)

# The simulator should launch and not crash.
simulator.launch()

self.assertIsNotNone(simulator._emulator_stub)
simulator._emulator_stub.getScreenshot = mock.MagicMock(
return_value=emulator_controller_pb2.Image(
format=emulator_controller_pb2.ImageFormat(width=5678, height=1234),
Expand All @@ -320,6 +321,53 @@ def test_get_screenshot(self):
# and it should have 3 channels (RGB).
self.assertEqual(screenshot.shape, (1234, 5678, 3))

def test_get_screenshot_reconnects_on_grpc_error(self):
config = config_classes.EmulatorConfig(
interaction_rate_sec=0.0,
emulator_launcher=config_classes.EmulatorLauncherConfig(
grpc_port=1234, tmp_dir=self.create_tempdir().full_path
),
adb_controller=config_classes.AdbControllerConfig(
adb_path='/my/adb',
adb_server_port=5037,
),
)
simulator = emulator_simulator.EmulatorSimulator(config)
self.addCleanup(simulator.close)
simulator.launch()
self.assertIsNotNone(simulator._emulator_stub)

# Mock first call to raise grpc.RpcError, second call to succeed.
side_effect = [
grpc.RpcError('gRPC error'),
emulator_controller_pb2.Image(
format=emulator_controller_pb2.ImageFormat(width=5678, height=1234),
image=Image.new('RGBA', (1234, 5678)).tobytes(),
timestampUs=123,
),
]
get_screenshot_mock = mock.MagicMock(side_effect=side_effect)
simulator._emulator_stub.getScreenshot = get_screenshot_mock

# Mock _connect_to_emulator to return the same mock stub on reconnect.
mock_connect = self.enter_context(
mock.patch.object(simulator, '_connect_to_emulator', autospec=True)
)
mock_connect.return_value = (
simulator._emulator_stub,
simulator._snapshot_stub,
)

screenshot = simulator.get_screenshot()

# Assertions:
self.assertEqual(screenshot.shape, (1234, 5678, 3))
mock_connect.assert_called_once_with(
simulator._config.emulator_launcher.grpc_port
)
# Verify getScreenshot was called twice (first failed, second succeeded).
self.assertEqual(simulator._emulator_stub.getScreenshot.call_count, 2)

def test_load_state(self):
config = config_classes.EmulatorConfig(
emulator_launcher=config_classes.EmulatorLauncherConfig(
Expand Down Expand Up @@ -430,7 +478,7 @@ def test_send_touch(self):

# The simulator should launch and not crash.
simulator.launch()

self.assertIsNotNone(simulator._emulator_stub)
simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None)

simulator.send_touch([(123, 456, True, 0), (135, 246, True, 1)])
Expand Down Expand Up @@ -487,7 +535,7 @@ def test_send_key(self):

# The simulator should launch and not crash.
simulator.launch()

self.assertIsNotNone(simulator._emulator_stub)
simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None)

simulator.send_key(123, 'keydown')
Expand Down
52 changes: 34 additions & 18 deletions android_env/environment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@
import numpy as np


def _create_mock_coordinator() -> coordinator_lib.Coordinator:
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
coordinator.action_spec.return_value = {
'action_type':
dm_env.specs.DiscreteArray(num_values=3),
'touch_position':
dm_env.specs.BoundedArray(
shape=(2,), dtype=np.float32, minimum=0.0, maximum=1.0),
def _get_fake_action_spec() -> dict[str, dm_env.specs.Array]:
return {
'action_type': dm_env.specs.DiscreteArray(num_values=3),
'touch_position': dm_env.specs.BoundedArray(
shape=(2,), dtype=np.float32, minimum=0.0, maximum=1.0
),
}
coordinator.observation_spec.return_value = {


def _get_fake_observation_spec() -> dict[str, dm_env.specs.Array]:
return {
'pixels': dm_env.specs.Array(shape=(123, 456, 3), dtype=np.uint8),
'timedelta': dm_env.specs.Array(shape=(), dtype=np.int64),
'orientation': dm_env.specs.Array(shape=(4,), dtype=np.uint8),
}
return coordinator


def _create_fake_simulator() -> fake_simulator.FakeSimulator:
Expand All @@ -57,7 +57,9 @@ class AndroidEnvTest(absltest.TestCase):

def test_specs(self):
simulator = _create_fake_simulator()
coordinator = _create_mock_coordinator()
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
coordinator.action_spec.return_value = _get_fake_action_spec()
coordinator.observation_spec.return_value = _get_fake_observation_spec()
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
Expand Down Expand Up @@ -92,7 +94,7 @@ def test_specs(self):

def test_reset_and_step(self):
simulator = _create_fake_simulator()
coordinator = _create_mock_coordinator()
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
coordinator.action_spec.return_value = {
'action_type':
Expand Down Expand Up @@ -161,7 +163,10 @@ def test_reset_and_step(self):
discount=0.0,
observation=latest_observation,
)
ts = env.step({'action_type': 1, 'touch_position': (10, 20)})
ts = env.step({
'action_type': np.array(1),
'touch_position': np.array((10, 20)),
})
self.assertIsInstance(ts, dm_env.TimeStep)
# The StepType now should NOT be FIRST.
self.assertFalse(ts.first())
Expand Down Expand Up @@ -192,7 +197,10 @@ def test_reset_and_step(self):
reward=0.0,
observation=None,
)
ts = env.step({'action_type': 1, 'touch_position': (10, 20)})
ts = env.step({
'action_type': np.array(1),
'touch_position': np.array((10, 20)),
})
self.assertIsInstance(ts, dm_env.TimeStep)
# Assert the observation matches the latest observation.
obs = ts.observation
Expand All @@ -207,7 +215,9 @@ def test_reset_and_step(self):

def test_adb_call(self):
simulator = _create_fake_simulator()
coordinator = _create_mock_coordinator()
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
coordinator.action_spec.return_value = _get_fake_action_spec()
coordinator.observation_spec.return_value = _get_fake_observation_spec()
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
Expand All @@ -225,7 +235,9 @@ def test_adb_call(self):

def test_load_state(self):
simulator = mock.create_autospec(base_simulator.BaseSimulator)
coordinator = _create_mock_coordinator()
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
coordinator.action_spec.return_value = _get_fake_action_spec()
coordinator.observation_spec.return_value = _get_fake_observation_spec()
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
Expand All @@ -243,7 +255,9 @@ def test_load_state(self):

def test_save_state(self):
simulator = mock.create_autospec(base_simulator.BaseSimulator)
coordinator = _create_mock_coordinator()
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
coordinator.action_spec.return_value = _get_fake_action_spec()
coordinator.observation_spec.return_value = _get_fake_observation_spec()
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
Expand All @@ -259,7 +273,9 @@ def test_save_state(self):

def test_double_close(self):
simulator = _create_fake_simulator()
coordinator = _create_mock_coordinator()
coordinator = mock.create_autospec(coordinator_lib.Coordinator)
coordinator.action_spec.return_value = _get_fake_action_spec()
coordinator.observation_spec.return_value = _get_fake_observation_spec()
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
Expand Down
Loading