Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 36 additions & 16 deletions android_env/components/simulators/emulator/emulator_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

"""A class that manages an Android Emulator."""

from collections.abc import Callable
import functools
import os
import time
from typing import Any
Expand Down Expand Up @@ -94,14 +96,42 @@ class EmulatorCrashError(errors.SimulatorError):
"""Raised when a simulator crashed."""


def _reconnect_on_grpc_error(
func: Callable[..., Any],
) -> Callable[..., Any]:
"""Decorates a function to reconnect to the emulator upon gRPC errors."""

@functools.wraps(func)
def wrapper(self: 'EmulatorSimulator', *args: Any, **kwargs: Any) -> Any:
try:
return func(self, *args, **kwargs)
except grpc.RpcError as error:
logging.exception(
'RpcError caught while calling %s with args=%s, kwargs=%s. '
'Error details: %s. Reconnecting to emulator...',
func.__name__,
args,
kwargs,
error,
)
# pylint: disable=protected-access
self._emulator_stub, self._snapshot_stub = self._connect_to_emulator(
self._config.emulator_launcher.grpc_port
)
# pylint: enable=protected-access
return func(self, *args, **kwargs)

return wrapper


class EmulatorSimulator(base_simulator.BaseSimulator):
"""Controls an Android Emulator."""

def __init__(self, config: config_classes.EmulatorConfig):
"""Instantiates an EmulatorSimulator."""

super().__init__(config)
self._config = config
self._config: config_classes.EmulatorConfig = config

# If adb_port, console_port and grpc_port are all already provided,
# we assume the emulator already exists and there's no need to launch.
Expand Down Expand Up @@ -163,21 +193,6 @@ def __init__(self, config: config_classes.EmulatorConfig):
self._config.logfile_path or self._launcher.logfile_path()
)

def _reconnect_on_grpc_error(func):
"""Decorator function for reconnecting to emulator upon grpc errors."""

def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except grpc.RpcError:
logging.exception('RpcError caught. Reconnecting to emulator...')
self._emulator_stub, self._snapshot_stub = self._connect_to_emulator(
self._config.emulator_launcher.grpc_port
)
return func(self, *args, **kwargs)

return wrapper

def get_logs(self) -> str:
"""Returns logs recorded by the emulator."""
if self._logfile_path and os.path.exists(self._logfile_path):
Expand Down Expand Up @@ -338,6 +353,11 @@ def _connect_to_emulator(
]:
"""Connects to an emulator and returns a corresponsing stub."""

if hasattr(self, '_channel') and self._channel is not None:
logging.info('Closing previous gRPC channel before reconnecting.')
self._channel.close()
self._channel = None

logging.info('Creating gRPC channel to the emulator on port %r', grpc_port)
port = f'localhost:{grpc_port}'
options = [('grpc.max_send_message_length', -1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,11 @@ def test_get_screenshot(self):
),
)
simulator = emulator_simulator.EmulatorSimulator(config)
self.addCleanup(simulator.close)

# The simulator should launch and not crash.
simulator.launch()

self.assertIsNotNone(simulator._emulator_stub)
simulator._emulator_stub.getScreenshot = mock.MagicMock(
return_value=emulator_controller_pb2.Image(
format=emulator_controller_pb2.ImageFormat(width=5678, height=1234),
Expand All @@ -320,6 +321,53 @@ def test_get_screenshot(self):
# and it should have 3 channels (RGB).
self.assertEqual(screenshot.shape, (1234, 5678, 3))

def test_get_screenshot_reconnects_on_grpc_error(self):
config = config_classes.EmulatorConfig(
interaction_rate_sec=0.0,
emulator_launcher=config_classes.EmulatorLauncherConfig(
grpc_port=1234, tmp_dir=self.create_tempdir().full_path
),
adb_controller=config_classes.AdbControllerConfig(
adb_path='/my/adb',
adb_server_port=5037,
),
)
simulator = emulator_simulator.EmulatorSimulator(config)
self.addCleanup(simulator.close)
simulator.launch()
self.assertIsNotNone(simulator._emulator_stub)

# Mock first call to raise grpc.RpcError, second call to succeed.
side_effect = [
grpc.RpcError('gRPC error'),
emulator_controller_pb2.Image(
format=emulator_controller_pb2.ImageFormat(width=5678, height=1234),
image=Image.new('RGBA', (1234, 5678)).tobytes(),
timestampUs=123,
),
]
get_screenshot_mock = mock.MagicMock(side_effect=side_effect)
simulator._emulator_stub.getScreenshot = get_screenshot_mock

# Mock _connect_to_emulator to return the same mock stub on reconnect.
mock_connect = self.enter_context(
mock.patch.object(simulator, '_connect_to_emulator', autospec=True)
)
mock_connect.return_value = (
simulator._emulator_stub,
simulator._snapshot_stub,
)

screenshot = simulator.get_screenshot()

# Assertions:
self.assertEqual(screenshot.shape, (1234, 5678, 3))
mock_connect.assert_called_once_with(
simulator._config.emulator_launcher.grpc_port
)
# Verify getScreenshot was called twice (first failed, second succeeded).
self.assertEqual(simulator._emulator_stub.getScreenshot.call_count, 2)

def test_load_state(self):
config = config_classes.EmulatorConfig(
emulator_launcher=config_classes.EmulatorLauncherConfig(
Expand Down Expand Up @@ -430,7 +478,7 @@ def test_send_touch(self):

# The simulator should launch and not crash.
simulator.launch()

self.assertIsNotNone(simulator._emulator_stub)
simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None)

simulator.send_touch([(123, 456, True, 0), (135, 246, True, 1)])
Expand Down Expand Up @@ -487,7 +535,7 @@ def test_send_key(self):

# The simulator should launch and not crash.
simulator.launch()

self.assertIsNotNone(simulator._emulator_stub)
simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None)

simulator.send_key(123, 'keydown')
Expand Down
Loading