Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions android_env/components/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
self._task_manager = task_manager
self._config = config or config_classes.CoordinatorConfig()
self._device_settings = device_settings
self._adb_call_parser: adb_call_parser.AdbCallParser = None
self._adb_call_parser: adb_call_parser.AdbCallParser | None = None

# Initialize stats.
self._stats = {
Expand All @@ -72,7 +72,7 @@ def __init__(

# Initialize counters.
self._simulator_healthy = False
self._latest_observation_time = 0
self._latest_observation_time = 0.0
self._simulator_start_time = None

logging.info('Starting the simulator...')
Expand Down Expand Up @@ -145,7 +145,7 @@ def _launch_simulator(self, max_retries: int = 3):
except errors.AdbControllerError as e:
logging.exception('device_settings.update() failed.')
self._stats['relaunch_count_update_settings'] += 1
self._latest_error = e
latest_error = e
num_tries += 1
continue

Expand Down Expand Up @@ -175,6 +175,10 @@ def _create_adb_call_parser(self):
)

def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
assert self._adb_call_parser is not None, (
'_adb_call_parser is None. `start()` or `launch()` must be called '
'before executing ADB calls.'
)
return self._adb_call_parser.parse(call)

def rl_reset(self) -> dm_env.TimeStep:
Expand All @@ -185,10 +189,10 @@ def rl_reset(self) -> dm_env.TimeStep:
self._launch_simulator()

# Reset counters.
self._latest_observation_time = 0
self._latest_observation_time = 0.0
for key in self._stats:
if key.startswith('episode'):
self._stats[key] = 0.0
self._stats[key] = 0

# Execute a lift action before resetting the task.
if not action_fns.send_action_to_simulator(
Expand Down Expand Up @@ -231,6 +235,7 @@ def rl_step(self, agent_action: dict[str, np.ndarray]) -> dm_env.TimeStep:
):
self._stats['relaunch_count_execute_action'] += 1
self._simulator_healthy = False
return dm_env.truncation(reward=0.0, observation=None)

# Get data from the simulator.
try:
Expand All @@ -239,8 +244,6 @@ def rl_step(self, agent_action: dict[str, np.ndarray]) -> dm_env.TimeStep:
logging.exception('Unable to fetch observation. Restarting simulator.')
self._stats['relaunch_count_fetch_observation'] += 1
self._simulator_healthy = False

if not self._simulator_healthy:
return dm_env.truncation(reward=0.0, observation=None)

return self._task_manager.rl_step(simulator_signals)
Expand Down
Loading