From 70247bfb86bf53e3258e5d57f9e3ecf8505c6722 Mon Sep 17 00:00:00 2001 From: Daniel Toyama Date: Tue, 19 May 2026 14:15:35 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 918023705 --- android_env/components/coordinator.py | 17 +++--- .../simulators/emulator/emulator_launcher.py | 5 +- .../simulators/emulator/emulator_simulator.py | 52 ++++++++++++------ .../emulator/emulator_simulator_test.py | 54 +++++++++++++++++-- android_env/environment_test.py | 52 +++++++++++------- 5 files changed, 135 insertions(+), 45 deletions(-) diff --git a/android_env/components/coordinator.py b/android_env/components/coordinator.py index 90506574..31b9313b 100644 --- a/android_env/components/coordinator.py +++ b/android_env/components/coordinator.py @@ -54,7 +54,7 @@ def __init__( self._task_manager = task_manager self._config = config or config_classes.CoordinatorConfig() self._device_settings = device_settings - self._adb_call_parser: adb_call_parser.AdbCallParser = None + self._adb_call_parser: adb_call_parser.AdbCallParser | None = None # Initialize stats. self._stats = { @@ -72,7 +72,7 @@ def __init__( # Initialize counters. self._simulator_healthy = False - self._latest_observation_time = 0 + self._latest_observation_time = 0.0 self._simulator_start_time = None logging.info('Starting the simulator...') @@ -145,7 +145,7 @@ def _launch_simulator(self, max_retries: int = 3): except errors.AdbControllerError as e: logging.exception('device_settings.update() failed.') self._stats['relaunch_count_update_settings'] += 1 - self._latest_error = e + latest_error = e num_tries += 1 continue @@ -175,6 +175,10 @@ def _create_adb_call_parser(self): ) def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse: + assert self._adb_call_parser is not None, ( + '_adb_call_parser is None. `start()` or `launch()` must be called ' + 'before executing ADB calls.' + ) return self._adb_call_parser.parse(call) def rl_reset(self) -> dm_env.TimeStep: @@ -185,10 +189,10 @@ def rl_reset(self) -> dm_env.TimeStep: self._launch_simulator() # Reset counters. - self._latest_observation_time = 0 + self._latest_observation_time = 0.0 for key in self._stats: if key.startswith('episode'): - self._stats[key] = 0.0 + self._stats[key] = 0 # Execute a lift action before resetting the task. if not action_fns.send_action_to_simulator( @@ -231,6 +235,7 @@ def rl_step(self, agent_action: dict[str, np.ndarray]) -> dm_env.TimeStep: ): self._stats['relaunch_count_execute_action'] += 1 self._simulator_healthy = False + return dm_env.truncation(reward=0.0, observation=None) # Get data from the simulator. try: @@ -239,8 +244,6 @@ def rl_step(self, agent_action: dict[str, np.ndarray]) -> dm_env.TimeStep: logging.exception('Unable to fetch observation. Restarting simulator.') self._stats['relaunch_count_fetch_observation'] += 1 self._simulator_healthy = False - - if not self._simulator_healthy: return dm_env.truncation(reward=0.0, observation=None) return self._task_manager.rl_step(simulator_signals) diff --git a/android_env/components/simulators/emulator/emulator_launcher.py b/android_env/components/simulators/emulator/emulator_launcher.py index 10f01d28..bca135b4 100644 --- a/android_env/components/simulators/emulator/emulator_launcher.py +++ b/android_env/components/simulators/emulator/emulator_launcher.py @@ -155,8 +155,11 @@ def confirm_shutdown(self) -> None: self._emulator.returncode) self._emulator.kill() self._emulator = None + + if self._emulator_output is not None: self._emulator_output.close() - logging.info('The emulator process has finished.') + self._emulator_output = None + logging.info('The emulator process has finished.') def close(self): """Clean up launcher files and processes.""" diff --git a/android_env/components/simulators/emulator/emulator_simulator.py b/android_env/components/simulators/emulator/emulator_simulator.py index 2c612ac9..0616a7b1 100644 --- a/android_env/components/simulators/emulator/emulator_simulator.py +++ b/android_env/components/simulators/emulator/emulator_simulator.py @@ -15,6 +15,8 @@ """A class that manages an Android Emulator.""" +from collections.abc import Callable +import functools import os import time from typing import Any @@ -94,6 +96,34 @@ class EmulatorCrashError(errors.SimulatorError): """Raised when a simulator crashed.""" +def _reconnect_on_grpc_error( + func: Callable[..., Any], +) -> Callable[..., Any]: + """Decorates a function to reconnect to the emulator upon gRPC errors.""" + + @functools.wraps(func) + def wrapper(self: 'EmulatorSimulator', *args: Any, **kwargs: Any) -> Any: + try: + return func(self, *args, **kwargs) + except grpc.RpcError as error: + logging.exception( + 'RpcError caught while calling %s with args=%s, kwargs=%s. ' + 'Error details: %s. Reconnecting to emulator...', + func.__name__, + args, + kwargs, + error, + ) + # pylint: disable=protected-access + self._emulator_stub, self._snapshot_stub = self._connect_to_emulator( + self._config.emulator_launcher.grpc_port + ) + # pylint: enable=protected-access + return func(self, *args, **kwargs) + + return wrapper + + class EmulatorSimulator(base_simulator.BaseSimulator): """Controls an Android Emulator.""" @@ -101,7 +131,7 @@ def __init__(self, config: config_classes.EmulatorConfig): """Instantiates an EmulatorSimulator.""" super().__init__(config) - self._config = config + self._config: config_classes.EmulatorConfig = config # If adb_port, console_port and grpc_port are all already provided, # we assume the emulator already exists and there's no need to launch. @@ -163,21 +193,6 @@ def __init__(self, config: config_classes.EmulatorConfig): self._config.logfile_path or self._launcher.logfile_path() ) - def _reconnect_on_grpc_error(func): - """Decorator function for reconnecting to emulator upon grpc errors.""" - - def wrapper(self, *args, **kwargs): - try: - return func(self, *args, **kwargs) - except grpc.RpcError: - logging.exception('RpcError caught. Reconnecting to emulator...') - self._emulator_stub, self._snapshot_stub = self._connect_to_emulator( - self._config.emulator_launcher.grpc_port - ) - return func(self, *args, **kwargs) - - return wrapper - def get_logs(self) -> str: """Returns logs recorded by the emulator.""" if self._logfile_path and os.path.exists(self._logfile_path): @@ -338,6 +353,11 @@ def _connect_to_emulator( ]: """Connects to an emulator and returns a corresponsing stub.""" + if hasattr(self, '_channel') and self._channel is not None: + logging.info('Closing previous gRPC channel before reconnecting.') + self._channel.close() + self._channel = None + logging.info('Creating gRPC channel to the emulator on port %r', grpc_port) port = f'localhost:{grpc_port}' options = [('grpc.max_send_message_length', -1), diff --git a/android_env/components/simulators/emulator/emulator_simulator_test.py b/android_env/components/simulators/emulator/emulator_simulator_test.py index 12bb2d5d..7bc937e9 100644 --- a/android_env/components/simulators/emulator/emulator_simulator_test.py +++ b/android_env/components/simulators/emulator/emulator_simulator_test.py @@ -305,10 +305,11 @@ def test_get_screenshot(self): ), ) simulator = emulator_simulator.EmulatorSimulator(config) + self.addCleanup(simulator.close) # The simulator should launch and not crash. simulator.launch() - + self.assertIsNotNone(simulator._emulator_stub) simulator._emulator_stub.getScreenshot = mock.MagicMock( return_value=emulator_controller_pb2.Image( format=emulator_controller_pb2.ImageFormat(width=5678, height=1234), @@ -320,6 +321,53 @@ def test_get_screenshot(self): # and it should have 3 channels (RGB). self.assertEqual(screenshot.shape, (1234, 5678, 3)) + def test_get_screenshot_reconnects_on_grpc_error(self): + config = config_classes.EmulatorConfig( + interaction_rate_sec=0.0, + emulator_launcher=config_classes.EmulatorLauncherConfig( + grpc_port=1234, tmp_dir=self.create_tempdir().full_path + ), + adb_controller=config_classes.AdbControllerConfig( + adb_path='/my/adb', + adb_server_port=5037, + ), + ) + simulator = emulator_simulator.EmulatorSimulator(config) + self.addCleanup(simulator.close) + simulator.launch() + self.assertIsNotNone(simulator._emulator_stub) + + # Mock first call to raise grpc.RpcError, second call to succeed. + side_effect = [ + grpc.RpcError('gRPC error'), + emulator_controller_pb2.Image( + format=emulator_controller_pb2.ImageFormat(width=5678, height=1234), + image=Image.new('RGBA', (1234, 5678)).tobytes(), + timestampUs=123, + ), + ] + get_screenshot_mock = mock.MagicMock(side_effect=side_effect) + simulator._emulator_stub.getScreenshot = get_screenshot_mock + + # Mock _connect_to_emulator to return the same mock stub on reconnect. + mock_connect = self.enter_context( + mock.patch.object(simulator, '_connect_to_emulator', autospec=True) + ) + mock_connect.return_value = ( + simulator._emulator_stub, + simulator._snapshot_stub, + ) + + screenshot = simulator.get_screenshot() + + # Assertions: + self.assertEqual(screenshot.shape, (1234, 5678, 3)) + mock_connect.assert_called_once_with( + simulator._config.emulator_launcher.grpc_port + ) + # Verify getScreenshot was called twice (first failed, second succeeded). + self.assertEqual(simulator._emulator_stub.getScreenshot.call_count, 2) + def test_load_state(self): config = config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( @@ -430,7 +478,7 @@ def test_send_touch(self): # The simulator should launch and not crash. simulator.launch() - + self.assertIsNotNone(simulator._emulator_stub) simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None) simulator.send_touch([(123, 456, True, 0), (135, 246, True, 1)]) @@ -487,7 +535,7 @@ def test_send_key(self): # The simulator should launch and not crash. simulator.launch() - + self.assertIsNotNone(simulator._emulator_stub) simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None) simulator.send_key(123, 'keydown') diff --git a/android_env/environment_test.py b/android_env/environment_test.py index cf0e3876..94e808be 100644 --- a/android_env/environment_test.py +++ b/android_env/environment_test.py @@ -30,21 +30,21 @@ import numpy as np -def _create_mock_coordinator() -> coordinator_lib.Coordinator: - coordinator = mock.create_autospec(coordinator_lib.Coordinator) - coordinator.action_spec.return_value = { - 'action_type': - dm_env.specs.DiscreteArray(num_values=3), - 'touch_position': - dm_env.specs.BoundedArray( - shape=(2,), dtype=np.float32, minimum=0.0, maximum=1.0), +def _get_fake_action_spec() -> dict[str, dm_env.specs.Array]: + return { + 'action_type': dm_env.specs.DiscreteArray(num_values=3), + 'touch_position': dm_env.specs.BoundedArray( + shape=(2,), dtype=np.float32, minimum=0.0, maximum=1.0 + ), } - coordinator.observation_spec.return_value = { + + +def _get_fake_observation_spec() -> dict[str, dm_env.specs.Array]: + return { 'pixels': dm_env.specs.Array(shape=(123, 456, 3), dtype=np.uint8), 'timedelta': dm_env.specs.Array(shape=(), dtype=np.int64), 'orientation': dm_env.specs.Array(shape=(4,), dtype=np.uint8), } - return coordinator def _create_fake_simulator() -> fake_simulator.FakeSimulator: @@ -57,7 +57,9 @@ class AndroidEnvTest(absltest.TestCase): def test_specs(self): simulator = _create_fake_simulator() - coordinator = _create_mock_coordinator() + coordinator = mock.create_autospec(coordinator_lib.Coordinator) + coordinator.action_spec.return_value = _get_fake_action_spec() + coordinator.observation_spec.return_value = _get_fake_observation_spec() task_manager = mock.create_autospec(task_manager_lib.TaskManager) env = environment.AndroidEnv( simulator=simulator, coordinator=coordinator, task_manager=task_manager @@ -92,7 +94,7 @@ def test_specs(self): def test_reset_and_step(self): simulator = _create_fake_simulator() - coordinator = _create_mock_coordinator() + coordinator = mock.create_autospec(coordinator_lib.Coordinator) task_manager = mock.create_autospec(task_manager_lib.TaskManager) coordinator.action_spec.return_value = { 'action_type': @@ -161,7 +163,10 @@ def test_reset_and_step(self): discount=0.0, observation=latest_observation, ) - ts = env.step({'action_type': 1, 'touch_position': (10, 20)}) + ts = env.step({ + 'action_type': np.array(1), + 'touch_position': np.array((10, 20)), + }) self.assertIsInstance(ts, dm_env.TimeStep) # The StepType now should NOT be FIRST. self.assertFalse(ts.first()) @@ -192,7 +197,10 @@ def test_reset_and_step(self): reward=0.0, observation=None, ) - ts = env.step({'action_type': 1, 'touch_position': (10, 20)}) + ts = env.step({ + 'action_type': np.array(1), + 'touch_position': np.array((10, 20)), + }) self.assertIsInstance(ts, dm_env.TimeStep) # Assert the observation matches the latest observation. obs = ts.observation @@ -207,7 +215,9 @@ def test_reset_and_step(self): def test_adb_call(self): simulator = _create_fake_simulator() - coordinator = _create_mock_coordinator() + coordinator = mock.create_autospec(coordinator_lib.Coordinator) + coordinator.action_spec.return_value = _get_fake_action_spec() + coordinator.observation_spec.return_value = _get_fake_observation_spec() task_manager = mock.create_autospec(task_manager_lib.TaskManager) env = environment.AndroidEnv( simulator=simulator, coordinator=coordinator, task_manager=task_manager @@ -225,7 +235,9 @@ def test_adb_call(self): def test_load_state(self): simulator = mock.create_autospec(base_simulator.BaseSimulator) - coordinator = _create_mock_coordinator() + coordinator = mock.create_autospec(coordinator_lib.Coordinator) + coordinator.action_spec.return_value = _get_fake_action_spec() + coordinator.observation_spec.return_value = _get_fake_observation_spec() task_manager = mock.create_autospec(task_manager_lib.TaskManager) env = environment.AndroidEnv( simulator=simulator, coordinator=coordinator, task_manager=task_manager @@ -243,7 +255,9 @@ def test_load_state(self): def test_save_state(self): simulator = mock.create_autospec(base_simulator.BaseSimulator) - coordinator = _create_mock_coordinator() + coordinator = mock.create_autospec(coordinator_lib.Coordinator) + coordinator.action_spec.return_value = _get_fake_action_spec() + coordinator.observation_spec.return_value = _get_fake_observation_spec() task_manager = mock.create_autospec(task_manager_lib.TaskManager) env = environment.AndroidEnv( simulator=simulator, coordinator=coordinator, task_manager=task_manager @@ -259,7 +273,9 @@ def test_save_state(self): def test_double_close(self): simulator = _create_fake_simulator() - coordinator = _create_mock_coordinator() + coordinator = mock.create_autospec(coordinator_lib.Coordinator) + coordinator.action_spec.return_value = _get_fake_action_spec() + coordinator.observation_spec.return_value = _get_fake_observation_spec() task_manager = mock.create_autospec(task_manager_lib.TaskManager) env = environment.AndroidEnv( simulator=simulator, coordinator=coordinator, task_manager=task_manager