From fbc0b825943c186e92fc00f73e7f1fd4b2ef8779 Mon Sep 17 00:00:00 2001 From: Daniel Toyama Date: Wed, 20 May 2026 00:33:55 -0700 Subject: [PATCH] Refactor gRPC reconnection channels to explicitly close previous sockets, preventing file descriptor leaks during emulator restarts. Improve reconnect decorator logging and clean up background test thread leaks. PiperOrigin-RevId: 918274262 --- .../simulators/emulator/emulator_simulator.py | 52 ++++++++++++------ .../emulator/emulator_simulator_test.py | 54 +++++++++++++++++-- 2 files changed, 87 insertions(+), 19 deletions(-) diff --git a/android_env/components/simulators/emulator/emulator_simulator.py b/android_env/components/simulators/emulator/emulator_simulator.py index 2c612ac9..0616a7b1 100644 --- a/android_env/components/simulators/emulator/emulator_simulator.py +++ b/android_env/components/simulators/emulator/emulator_simulator.py @@ -15,6 +15,8 @@ """A class that manages an Android Emulator.""" +from collections.abc import Callable +import functools import os import time from typing import Any @@ -94,6 +96,34 @@ class EmulatorCrashError(errors.SimulatorError): """Raised when a simulator crashed.""" +def _reconnect_on_grpc_error( + func: Callable[..., Any], +) -> Callable[..., Any]: + """Decorates a function to reconnect to the emulator upon gRPC errors.""" + + @functools.wraps(func) + def wrapper(self: 'EmulatorSimulator', *args: Any, **kwargs: Any) -> Any: + try: + return func(self, *args, **kwargs) + except grpc.RpcError as error: + logging.exception( + 'RpcError caught while calling %s with args=%s, kwargs=%s. ' + 'Error details: %s. Reconnecting to emulator...', + func.__name__, + args, + kwargs, + error, + ) + # pylint: disable=protected-access + self._emulator_stub, self._snapshot_stub = self._connect_to_emulator( + self._config.emulator_launcher.grpc_port + ) + # pylint: enable=protected-access + return func(self, *args, **kwargs) + + return wrapper + + class EmulatorSimulator(base_simulator.BaseSimulator): """Controls an Android Emulator.""" @@ -101,7 +131,7 @@ def __init__(self, config: config_classes.EmulatorConfig): """Instantiates an EmulatorSimulator.""" super().__init__(config) - self._config = config + self._config: config_classes.EmulatorConfig = config # If adb_port, console_port and grpc_port are all already provided, # we assume the emulator already exists and there's no need to launch. @@ -163,21 +193,6 @@ def __init__(self, config: config_classes.EmulatorConfig): self._config.logfile_path or self._launcher.logfile_path() ) - def _reconnect_on_grpc_error(func): - """Decorator function for reconnecting to emulator upon grpc errors.""" - - def wrapper(self, *args, **kwargs): - try: - return func(self, *args, **kwargs) - except grpc.RpcError: - logging.exception('RpcError caught. Reconnecting to emulator...') - self._emulator_stub, self._snapshot_stub = self._connect_to_emulator( - self._config.emulator_launcher.grpc_port - ) - return func(self, *args, **kwargs) - - return wrapper - def get_logs(self) -> str: """Returns logs recorded by the emulator.""" if self._logfile_path and os.path.exists(self._logfile_path): @@ -338,6 +353,11 @@ def _connect_to_emulator( ]: """Connects to an emulator and returns a corresponsing stub.""" + if hasattr(self, '_channel') and self._channel is not None: + logging.info('Closing previous gRPC channel before reconnecting.') + self._channel.close() + self._channel = None + logging.info('Creating gRPC channel to the emulator on port %r', grpc_port) port = f'localhost:{grpc_port}' options = [('grpc.max_send_message_length', -1), diff --git a/android_env/components/simulators/emulator/emulator_simulator_test.py b/android_env/components/simulators/emulator/emulator_simulator_test.py index 12bb2d5d..7bc937e9 100644 --- a/android_env/components/simulators/emulator/emulator_simulator_test.py +++ b/android_env/components/simulators/emulator/emulator_simulator_test.py @@ -305,10 +305,11 @@ def test_get_screenshot(self): ), ) simulator = emulator_simulator.EmulatorSimulator(config) + self.addCleanup(simulator.close) # The simulator should launch and not crash. simulator.launch() - + self.assertIsNotNone(simulator._emulator_stub) simulator._emulator_stub.getScreenshot = mock.MagicMock( return_value=emulator_controller_pb2.Image( format=emulator_controller_pb2.ImageFormat(width=5678, height=1234), @@ -320,6 +321,53 @@ def test_get_screenshot(self): # and it should have 3 channels (RGB). self.assertEqual(screenshot.shape, (1234, 5678, 3)) + def test_get_screenshot_reconnects_on_grpc_error(self): + config = config_classes.EmulatorConfig( + interaction_rate_sec=0.0, + emulator_launcher=config_classes.EmulatorLauncherConfig( + grpc_port=1234, tmp_dir=self.create_tempdir().full_path + ), + adb_controller=config_classes.AdbControllerConfig( + adb_path='/my/adb', + adb_server_port=5037, + ), + ) + simulator = emulator_simulator.EmulatorSimulator(config) + self.addCleanup(simulator.close) + simulator.launch() + self.assertIsNotNone(simulator._emulator_stub) + + # Mock first call to raise grpc.RpcError, second call to succeed. + side_effect = [ + grpc.RpcError('gRPC error'), + emulator_controller_pb2.Image( + format=emulator_controller_pb2.ImageFormat(width=5678, height=1234), + image=Image.new('RGBA', (1234, 5678)).tobytes(), + timestampUs=123, + ), + ] + get_screenshot_mock = mock.MagicMock(side_effect=side_effect) + simulator._emulator_stub.getScreenshot = get_screenshot_mock + + # Mock _connect_to_emulator to return the same mock stub on reconnect. + mock_connect = self.enter_context( + mock.patch.object(simulator, '_connect_to_emulator', autospec=True) + ) + mock_connect.return_value = ( + simulator._emulator_stub, + simulator._snapshot_stub, + ) + + screenshot = simulator.get_screenshot() + + # Assertions: + self.assertEqual(screenshot.shape, (1234, 5678, 3)) + mock_connect.assert_called_once_with( + simulator._config.emulator_launcher.grpc_port + ) + # Verify getScreenshot was called twice (first failed, second succeeded). + self.assertEqual(simulator._emulator_stub.getScreenshot.call_count, 2) + def test_load_state(self): config = config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( @@ -430,7 +478,7 @@ def test_send_touch(self): # The simulator should launch and not crash. simulator.launch() - + self.assertIsNotNone(simulator._emulator_stub) simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None) simulator.send_touch([(123, 456, True, 0), (135, 246, True, 1)]) @@ -487,7 +535,7 @@ def test_send_key(self): # The simulator should launch and not crash. simulator.launch() - + self.assertIsNotNone(simulator._emulator_stub) simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None) simulator.send_key(123, 'keydown')