diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
deleted file mode 100644
index 61fb1eec..00000000
--- a/.github/workflows/tests.yml
+++ /dev/null
@@ -1,40 +0,0 @@
-name: tests
-
-on:
- push:
- branches: [ main ]
- pull_request:
- branches: [ main ]
- workflow_dispatch:
- inputs:
- git-ref:
- description: Git Ref (Optional)
- required: false
-
-jobs:
- build:
- runs-on: ubuntu-latest
- env:
- TEST_TMPDIR: '/tmp'
- strategy:
- matrix:
- python-version: ["3.11", "3.12", "3.13"]
- steps:
- - uses: actions/checkout@v4
-
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5
- with:
- python-version: ${{ matrix.python-version }}
-
- - name: Install dependencies
- run: |
- pip install --upgrade pip setuptools
- python setup.py install
- pip install .[testing]
-
- - name: Run tests
- run: |
- # Find all test files, print their names and execute them in parallel
- # with a maximum of 20 proccesses.
- find . -type f -name "*_test.py" -print0 | xargs -t -0 -n1 -P 20 python3
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
deleted file mode 100644
index 03c65248..00000000
--- a/CONTRIBUTING.md
+++ /dev/null
@@ -1,25 +0,0 @@
-
-# How to Contribute
-
-# Pull Requests
-
-Please send in fixes or feature additions through Pull Requests.
-
-## Contributor License Agreement
-
-Contributions to this project must be accompanied by a Contributor License
-Agreement. You (or your employer) retain the copyright to your contribution,
-this simply gives us permission to use and redistribute your contributions as
-part of the project. Head over to to see
-your current agreements on file or to sign a new one.
-
-You generally only need to submit a CLA once, so if you've already submitted one
-(even if it was for a different project), you probably don't need to do it
-again.
-
-## Code reviews
-
-All submissions, including submissions by project members, require review. We
-use GitHub pull requests for this purpose. Consult
-[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
-information on using pull requests.
diff --git a/LICENSE b/LICENSE
deleted file mode 100644
index 7a4a3ea2..00000000
--- a/LICENSE
+++ /dev/null
@@ -1,202 +0,0 @@
-
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
deleted file mode 100644
index b48c6bce..00000000
--- a/README.md
+++ /dev/null
@@ -1,163 +0,0 @@
-
-# AndroidEnv - The Android Learning Environment
-
-
-
-[AndroidEnv](https://github.com/deepmind/android_env) is a Python library that
-exposes an [Android](https://www.android.com/) device as a Reinforcement
-Learning (RL) environment. The library provides a flexible platform for defining
-custom tasks on top of the Android Operating System, including any Android
-application. Agents interact with the device through a universal action
-interface - the touchscreen - by sending localized touch and lift events to the
-system. The library processes these events and returns pixel observations and
-rewards as provided by specific [task definitions](docs/tasks_guide.md). For
-example, rewards might be given for events such as successfully scrolling down a
-page, sending an email, or achieving some score in a game, depending on the
-research purpose and how the user configures the task.
-
-[](https://github.com/deepmind/android_env/actions/workflows/tests.yml)
-[](https://badge.fury.io/py/android-env)
-[](https://pepy.tech/project/android-env)
-
-## Index
-
-* [Environment details](docs/environment.md)
-* [Running AndroidEnv](docs/instructions.md)
-* [Setting up a virtual Android device](docs/emulator_guide.md)
-* [Defining a task in AndroidEnv](docs/tasks_guide.md)
-* [Example tasks available for download](docs/example_tasks.md)
-
-## Environment features
-
-There are a number of aspects that make AndroidEnv a challenging yet suitable
-environment for Reinforcement Learning research:
-
-* Allowing agents to interact with a system used daily by billions of users
- around the world, AndroidEnv offers a platform for RL agents to navigate,
- learn tasks and have direct impact in **real-world contexts**. The
- environment wraps a simulated Android device, which runs independently from
- the environment, completely unaltered, and works in exactly the same way as
- the devices that humans use, exposing exactly the same features and
- services.
-
-* The platform offers a virtually infinite **range of possible tasks**, all
- sharing a common action interface. The library facilitates the design of
- Reinforcement Learning tasks for any existing or custom built Android
- application. For example, it exposes the broad world of Android games,
- ranging from card games, puzzle games, time reactive games, all requiring a
- diverse set of action combinations and interaction types.
-
-* The environment runs on top of a **real-time simulation** of an Android
- device. In other words, the environment dynamics does not wait for the agent
- to deliberate, and the speed of the simulation cannot be increased.
-
-* The observation is a collection of **RGB values** corresponding to the
- displayed pixels on the screen. The exact screen resolution depends on the
- simulated device, but in general it will be considered relatively large in
- an RL context. However, users have the option of downsampling each
- observation.
-
-* The learning environment has an interesting, **complex action space** unique
- to the touchscreen interface of Android.
-
- * The raw, **hybrid action space** consists of a continuous tuple
- signifying the action location, and a discrete signal determining
- whether the agent wants to touch the screen or lift its virtual finger.
- * Raw actions are highly **composable**: the Android UI and most
- applications were designed so that they could be intuitively navigated
- via common
- [touchscreen gestures](https://developer.android.com/training/gestures/detector)
- such as tapping, scrolling, swiping, pinching, drag & drop etc. This is
- still the case in AndroidEnv: to trigger meaningful changes in the
- environment, the agent often has to perform carefully timed and
- positioned sequences of raw actions. For example, in order to navigate
- to the next image in a photo gallery, the agent would have to perform a
- *swipe*, touching the screen multiple times, gradually shifting the
- actions' positions to the right. Thus, in most contexts raw actions do
- not trigger changes in the state of the environment unless correctly
- chained together to make up a human gesture.
- * The action interface is **closely related to the observation space**, as
- meaningful touch and lift events are often either co-localized or
- strongly correlated to the location or movement of salient objects in
- the observation. For example, the position of a button on the screen
- aligns with the location of the actions that trigger the button press.
- * The library provides tools for flexibly **altering the action
- interface** if needed for particular studies, such as discretization or
- hard-coding gesture skills. Still, we believe that the real challenge
- remains in devising agents that are capable of dealing with a large
- suite of diverse tasks, through acting and learning in the complex
- unifying action interface.
-
-# Getting started
-
-### Installation
-
-The easiest way to get AndroidEnv is with pip:
-
-```shell
-$ python3 -m pip install android-env
-```
-
-Please note that `/examples` are not included in this package.
-
-Alternatively, you can clone the repository from git's `main` branch:
-
-```shell
-$ git clone https://github.com/deepmind/android_env/
-$ cd android_env
-$ python3 setup.py install
-```
-
-Update: the environment now runs on Windows, but please keep in mind that this
-option is not well-maintained or widely supported, as Unix-based systems are the
-primary target platforms of this project.
-
-### Create a simulator
-
-Before running the environment, you will need access to an emulated Android
-device. For instructions on creating a virtual Android device, see the
-[Emulator guide](docs/emulator_guide.md).
-
-### Define a task
-
-Then, you will want to define what the agent's *task* is. At this point, the
-agent will be able to communicate with the emulated device, but it will not yet
-have an objective, or access to signals such as rewards or RL episode ends.
-Learn [how to define an RL task](docs/tasks_guide.md) of your own, or use one of
-the [existing task definitions](docs/example_tasks.md) for training.
-
-### Load and run
-
-To find out how to run and train agents on AndroidEnv, see these
-[detailed instructions](docs/instructions.md). Here you can also find example
-scripts demonstrating how to run a random agent, an
-[acme](https://github.com/deepmind/acme) agent, or a human agent on AndroidEnv.
-
-## About
-
-This library is developed and maintained by [DeepMind](http://deepmind.com). \
-You can find the [technical report](https://arxiv.org/abs/2105.13231) on Arxiv,
-as well as an introductory
-[blog
-post](https://www.deepmind.com/publications/androidenv-the-android-learning-environment)
-on DeepMind's website.
-
-If you use AndroidEnv in your research, you can cite the paper using the
-following BibTeX:
-
-```
-@article{ToyamaEtAl2021AndroidEnv,
- title = {{AndroidEnv}: A Reinforcement Learning Platform for Android},
- author = {Daniel Toyama and Philippe Hamel and Anita Gergely and
- Gheorghe Comanici and Amelia Glaese and Zafarali Ahmed and Tyler
- Jackson and Shibl Mourad and Doina Precup},
- year = {2021},
- eprint = {2105.13231},
- archivePrefix = {arXiv},
- primaryClass = {cs.LG},
- volume = {abs/2105.13231},
- url = {http://arxiv.org/abs/2105.13231},
-}
-```
-
-Disclaimer: This is not an official Google product.
diff --git a/android_env/__init__.py b/android_env/__init__.py
deleted file mode 100644
index 2f66bf75..00000000
--- a/android_env/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarder.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarder.kt
deleted file mode 100644
index 6e9bf82d..00000000
--- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarder.kt
+++ /dev/null
@@ -1,280 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package com.google.androidenv.accessibilityforwarder
-
-import android.accessibilityservice.AccessibilityService
-import android.util.Log
-import android.view.accessibility.AccessibilityEvent
-import android.view.accessibility.AccessibilityNodeInfo
-import android.view.accessibility.AccessibilityWindowInfo
-import com.google.androidenv.accessibilityforwarder.A11yServiceGrpcKt.A11yServiceCoroutineStub
-import io.grpc.ManagedChannel
-import io.grpc.ManagedChannelBuilder
-import io.grpc.ProxyDetector
-import io.grpc.StatusException
-import kotlinx.coroutines.TimeoutCancellationException
-import kotlinx.coroutines.runBlocking
-import kotlinx.coroutines.withTimeout
-
-/**
- * An Android service that listens to accessibility events and forwards them via gRPC.
- *
- * This service also logs the accessibility tree if [LogFlags.logAccessibilityTree] is set and if
- * [LogFlags.grpcPort] is positive.
- *
- * Please see
- * https://developer.android.com/reference/android/view/accessibility/AccessibilityEvent#getEventType()
- * for a comprehensive list of events emitted by Android.
- */
-class AccessibilityForwarder(
- private val channelFactory: (host: String, port: Int) -> ManagedChannel = { host, port ->
- ManagedChannelBuilder.forAddress(host, port)
- .proxyDetector(ProxyDetector { _ -> null })
- .usePlaintext()
- .build()
- }
-) : AccessibilityService() {
-
- init {
- // Spawn long-running thread for periodically logging the tree.
- Thread(
- Runnable {
- while (LogFlags.a11yTreePeriodMs > 0) {
- try {
- val windows = this.windows
- logAccessibilityTree(windows)
- } catch (e: ConcurrentModificationException) {
- continue
- }
-
- Thread.sleep(/* millis= */ LogFlags.a11yTreePeriodMs)
- }
- }
- )
- .start()
- }
-
- // grpcStub has a backing property that can be reset to null.
- private var _grpcStub: A11yServiceCoroutineStub? = null
- val grpcStub: A11yServiceCoroutineStub
- get() {
- if (_grpcStub == null) {
- Log.i(TAG, "Building channel on ${LogFlags.grpcHost}:${LogFlags.grpcPort}.")
- _grpcStub = A11yServiceCoroutineStub(channelFactory(LogFlags.grpcHost, LogFlags.grpcPort))
- }
- return _grpcStub!!
- }
-
- private fun resetGrpcStub() {
- _grpcStub = null
- }
-
- override fun onInterrupt() {
- LogFlags.a11yTreePeriodMs = 0 // Turn off periodic tree forwarding.
- }
-
- override fun onAccessibilityEvent(event: AccessibilityEvent?) {
- if (event == null) {
- Log.i(TAG, "`event` is null.")
- return
- }
-
- logExtrasForEvent(event)
- val eventType = event.eventType
- val eventTypeStr: String = AccessibilityEvent.eventTypeToString(eventType)
- if (eventTypeStr.isNotEmpty()) {
- Log.i(TAG, eventTypeStr)
- }
- }
-
- private fun logAccessibilityTree(windows: List) {
- if (!LogFlags.logAccessibilityTree) {
- Log.i(TAG, "Not logging accessibility tree")
- return
- }
-
- // Check gRPC port before actually building the forest.
- if (LogFlags.grpcPort <= 0) {
- Log.w(TAG, "Can't log accessibility tree because gRPC port has not been set.")
- return
- }
-
- val forest = creator.buildForest(windows)
- try {
- val grpcTimeoutMillis = 1000L
- val response: ForestResponse =
- with(grpcStub) {
- Log.i(TAG, "sending (blocking) gRPC request for tree.")
- runBlocking { withTimeout(grpcTimeoutMillis) { sendForest(forest) } }
- }
- if (response.error.isNotEmpty()) {
- Log.w(TAG, "gRPC response.error: ${response.error}")
- } else {
- Log.i(TAG, "gRPC request for tree succeeded.")
- }
- } catch (e: StatusException) {
- Log.w(TAG, "gRPC StatusException; are you sure networking is turned on?")
- Log.i(TAG, "extra: exception ['$e']")
- resetGrpcStub()
- } catch (e: TimeoutCancellationException) {
- Log.w(TAG, "gRPC TimeoutCancellationException; are you sure networking is turned on?")
- Log.i(TAG, "extra: exception ['$e']")
- resetGrpcStub()
- }
- }
-
- /** Logs extras for all event types. */
- private fun logExtrasForEvent(event: AccessibilityEvent) {
-
- val events: MutableMap = mutableMapOf()
-
- val sourceDescription = event.source?.contentDescription()
- if (!sourceDescription.isNullOrEmpty()) {
- events.put("source_content_description", sourceDescription)
- }
-
- // Output the event text.
- val eventText = event.text.joinToString(", ")
- if (eventText.isNotEmpty()) {
- events.put("event_text", eventText)
- }
-
- // Output the source text.
- val sourceText = event.source?.text?.toString()
- if (!sourceText.isNullOrEmpty()) {
- events.put("source_text", sourceText)
- }
-
- val eventTypeStr: String = AccessibilityEvent.eventTypeToString(event.eventType)
- if (eventTypeStr.isNotEmpty()) {
- events.put("event_type", eventTypeStr)
- }
-
- val className = event.source?.className?.toString()
- if (!className.isNullOrEmpty()) {
- events.put("source_class_name", className)
- }
-
- val packageName = event.packageName?.toString()
- if (!packageName.isNullOrEmpty()) {
- events.put("event_package_name", packageName)
- }
-
- // Text editing properties.
- val beforeText = event.beforeText?.toString()
- if (!beforeText.isNullOrEmpty()) {
- events.put("before_text", beforeText)
- }
-
- val fromIndex = event.fromIndex
- if (fromIndex != -1) {
- events.put("from_index", fromIndex.toString())
- }
-
- val toIndex = event.toIndex
- if (toIndex != -1) {
- events.put("to_index", toIndex.toString())
- }
-
- val addedCount = event.addedCount
- if (addedCount != -1) {
- events.put("added_count", addedCount.toString())
- }
-
- val removedCount = event.removedCount
- if (removedCount != -1) {
- events.put("removed_count", removedCount.toString())
- }
-
- // Text traversal properties
- val movementGranularity = event.movementGranularity
- if (movementGranularity != 0) {
- events.put("movement_granularity", movementGranularity.toString())
- }
-
- val action = event.action
- if (action != 0) {
- events.put("action", action.toString())
- }
-
- // Scrolling properties.
- if (eventTypeStr == "TYPE_VIEW_SCROLLED") {
- events.put("scroll_delta_x", event.scrollDeltaX.toString())
- events.put("scroll_delta_y", event.scrollDeltaY.toString())
- }
-
- // Report viewID so we know exactly where the event came from.
- val viewId = event.source?.viewIdResourceName?.toString()
- if (!viewId.isNullOrEmpty()) {
- events.put("view_id", viewId)
- }
-
- // Format [events] as a Python dict.
- if (events.isNotEmpty()) {
- events.put("event_timestamp_ms", event.eventTime.toString(10))
- // Check if we want to use gRPC.
- if (LogFlags.grpcPort > 0) {
- try {
- val grpcTimeoutMillis = 1000L
- val request = eventRequest { this.event.putAll(events) }
- val response: EventResponse =
- with(grpcStub) {
- Log.i(TAG, "sending (blocking) gRPC request for event.")
- runBlocking { withTimeout(grpcTimeoutMillis) { sendEvent(request) } }
- }
- if (response.error.isNotEmpty()) {
- Log.w(TAG, "gRPC response.error: ${response.error}")
- } else {
- Log.i(TAG, "gRPC request for event succeeded.")
- }
- } catch (e: StatusException) {
- Log.w(TAG, "gRPC StatusException; are you sure networking is turned on?")
- Log.i(TAG, "extra: exception ['$e']")
- resetGrpcStub()
- } catch (e: TimeoutCancellationException) {
- Log.w(TAG, "gRPC TimeoutCancellationException; are you sure networking is turned on?")
- Log.i(TAG, "extra: exception ['$e']")
- resetGrpcStub()
- }
- } else {
- Log.w(TAG, "Can't log accessibility event because gRPC port has not been set.")
- }
- }
- }
-
- /** Recursively climbs the accessibility tree until the root, collecting descriptions. */
- private fun AccessibilityNodeInfo?.contentDescription(): String {
- if (this == null) {
- return ""
- }
-
- val descriptions = mutableListOf()
- var current: AccessibilityNodeInfo? = this
- while (current != null) {
- val description = current.contentDescription
- if (description != null) {
- descriptions.add(description.toString())
- }
-
- current = current.parent
- }
- return descriptions.joinToString(", ")
- }
-
- companion object {
- private const val TAG = "AndroidRLTask"
- private val creator = AccessibilityTreeCreator()
- }
-}
diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarderTest.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarderTest.kt
deleted file mode 100644
index b446db7a..00000000
--- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarderTest.kt
+++ /dev/null
@@ -1,516 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package com.google.androidenv.accessibilityforwarder
-
-import android.view.accessibility.AccessibilityEvent
-import android.view.accessibility.AccessibilityNodeInfo
-import android.view.accessibility.AccessibilityWindowInfo
-import com.google.common.truth.Truth.assertThat
-import io.grpc.Status
-import io.grpc.StatusException
-import io.grpc.inprocess.InProcessChannelBuilder
-import io.grpc.inprocess.InProcessServerBuilder
-import io.grpc.testing.GrpcCleanupRule
-import org.junit.Assert.assertFalse
-import org.junit.Rule
-import org.junit.Test
-import org.junit.runner.RunWith
-import org.robolectric.RobolectricTestParameterInjector
-import org.robolectric.Shadows.shadowOf
-
-@RunWith(RobolectricTestParameterInjector::class)
-class AccessibilityForwarderTest {
-
- @get:Rule(order = 1) val cleanupRule = GrpcCleanupRule()
-
- class FakeAccessibilityService : A11yServiceGrpcKt.A11yServiceCoroutineImplBase() {
- var sendForestChecker: (AndroidAccessibilityForest) -> String = { _ -> "" }
- var sendEventChecker: (EventRequest) -> String = { _ -> "" }
-
- override suspend fun sendForest(request: AndroidAccessibilityForest) = forestResponse {
- error = sendForestChecker(request)
- }
-
- override suspend fun sendEvent(request: EventRequest) = eventResponse {
- error = sendEventChecker(request)
- }
- }
-
- protected lateinit var forwarder: AccessibilityForwarder
- protected val fakeA11yService = FakeAccessibilityService()
- protected val channel by lazy {
- val serverName: String = InProcessServerBuilder.generateName()
- cleanupRule.register(
- InProcessServerBuilder.forName(serverName)
- .directExecutor()
- .addService(fakeA11yService)
- .build()
- .start()
- )
- cleanupRule.register(InProcessChannelBuilder.forName(serverName).directExecutor().build())
- }
-
- /** Initializes [forwarder] and [LogFlags] from the given args. */
- fun createForwarder(
- logAccessibilityTree: Boolean = false,
- a11yTreePeriodMs: Long = 0,
- grpcHost: String = "10.0.2.2",
- grpcPort: Int = 0,
- a11yWindows: MutableList? = null,
- ) {
- LogFlags.logAccessibilityTree = logAccessibilityTree
- LogFlags.a11yTreePeriodMs = a11yTreePeriodMs
- LogFlags.grpcHost = grpcHost
- LogFlags.grpcPort = grpcPort
- forwarder = AccessibilityForwarder({ _, _ -> channel })
- if (a11yWindows == null) {
- shadowOf(forwarder).setWindows(mutableListOf(AccessibilityWindowInfo.obtain()))
- } else {
- shadowOf(forwarder).setWindows(a11yWindows)
- }
- }
-
- @Test
- fun onInterrupt_doesNotCrash() {
- // Arrange.
- createForwarder(logAccessibilityTree = false)
- fakeA11yService.sendEventChecker = { _: EventRequest ->
- assertFalse(true) // This should not be called.
- "" // This should be unreachable
- }
-
- // Act.
- forwarder.onInterrupt()
-
- // Assert.
- // See `sendEventChecker` above.
- }
-
- @Test
- fun onAccessibilityEvent_nullEventShouldBeIgnored() {
- // Arrange.
- createForwarder(logAccessibilityTree = false)
- fakeA11yService.sendEventChecker = { _: EventRequest ->
- assertFalse(true) // This should not be called.
- "" // This should be unreachable
- }
-
- // Act.
- forwarder.onAccessibilityEvent(null)
-
- // Assert.
- // See `sendEventChecker` above.
- }
-
- @Test
- fun onAccessibilityEvent_knownEventWithNoInformationShouldNotBeEmitted() {
- // Arrange.
- createForwarder(logAccessibilityTree = false)
- var nodeInfo = AccessibilityNodeInfo()
- nodeInfo.setContentDescription("")
- var event = AccessibilityEvent()
- shadowOf(event).setSourceNode(nodeInfo)
- fakeA11yService.sendEventChecker = { _: EventRequest ->
- assertFalse(true) // This should not be called.
- "" // This should be unreachable
- }
-
- // Act.
- forwarder.onAccessibilityEvent(event)
-
- // Assert.
- // See `sendEventChecker` above.
- }
-
- @Test
- fun onAccessibilityEvent_typeViewClicked_sendEventViaGrpc() {
- // Arrange.
- createForwarder(logAccessibilityTree = false, grpcPort = 1234)
- forwarder = AccessibilityForwarder({ _, _ -> channel })
- var nodeInfo = AccessibilityNodeInfo()
- nodeInfo.setContentDescription("My Content Description")
- nodeInfo.setText("My Source Text")
- nodeInfo.setClassName("AwesomeClass")
- var event = AccessibilityEvent()
- event.setEventTime(1357924680)
- event.setEventType(AccessibilityEvent.TYPE_VIEW_CLICKED)
- event.getText().add("Some text!")
- event.setPackageName("some.loooong.package.name")
- shadowOf(event).setSourceNode(nodeInfo)
- fakeA11yService.sendEventChecker = { request: EventRequest ->
- // Check that all fields are consistent with how they were set above.
- assertThat(request.eventMap.get("event_type")).isEqualTo("TYPE_VIEW_CLICKED")
- assertThat(request.eventMap.get("event_package_name")).isEqualTo("some.loooong.package.name")
- assertThat(request.eventMap.get("source_content_description"))
- .isEqualTo("My Content Description")
- assertThat(request.eventMap.get("source_text")).isEqualTo("My Source Text")
- assertThat(request.eventMap.get("source_class_name")).isEqualTo("AwesomeClass")
- assertThat(request.eventMap.get("event_text")).isEqualTo("Some text!")
- assertThat(request.eventMap.get("event_timestamp_ms")).isEqualTo("1357924680")
- // No error message
- ""
- }
-
- // Act.
- forwarder.onAccessibilityEvent(event)
-
- // Assert.
- // See `sendEventChecker` above.
- }
-
- @Test
- fun onAccessibilityEvent_typeViewTextChanged_ensureAllFieldsForwarded() {
- // Arrange.
- createForwarder(logAccessibilityTree = false, grpcPort = 1234)
- var nodeInfo = AccessibilityNodeInfo()
- nodeInfo.setContentDescription("My Content Description")
- nodeInfo.setText("My Source Text")
- nodeInfo.setClassName("AwesomeClass")
- var event = AccessibilityEvent()
- event.setEventTime(1357924680)
- event.setEventType(AccessibilityEvent.TYPE_VIEW_TEXT_CHANGED)
- event.getText().add("Some text!")
- event.fromIndex = 7
- event.beforeText = "Old words"
- event.addedCount = 12
- event.removedCount = 9
- event.setPackageName("some.loooong.package.name")
- shadowOf(event).setSourceNode(nodeInfo)
- fakeA11yService.sendEventChecker = { request: EventRequest ->
- // Check that all fields are consistent with how they were set above.
- assertThat(request.eventMap.get("event_type")).isEqualTo("TYPE_VIEW_TEXT_CHANGED")
- assertThat(request.eventMap.get("event_package_name")).isEqualTo("some.loooong.package.name")
- assertThat(request.eventMap.get("source_content_description"))
- .isEqualTo("My Content Description")
- assertThat(request.eventMap.get("source_text")).isEqualTo("My Source Text")
- assertThat(request.eventMap.get("source_class_name")).isEqualTo("AwesomeClass")
- assertThat(request.eventMap.get("event_text")).isEqualTo("Some text!")
- assertThat(request.eventMap.get("event_timestamp_ms")).isEqualTo("1357924680")
- assertThat(request.eventMap.get("from_index")).isEqualTo("7")
- assertThat(request.eventMap.get("before_text")).isEqualTo("Old words")
- assertThat(request.eventMap.get("added_count")).isEqualTo("12")
- assertThat(request.eventMap.get("removed_count")).isEqualTo("9")
- assertFalse(request.eventMap.containsKey("to_index"))
- assertFalse(request.eventMap.containsKey("view_id"))
- assertFalse(request.eventMap.containsKey("action"))
- assertFalse(request.eventMap.containsKey("movement_granularity"))
- assertFalse(request.eventMap.containsKey("scroll_delta_x"))
- assertFalse(request.eventMap.containsKey("scroll_delta_y"))
- // No error message
- ""
- }
-
- // Act.
- forwarder.onAccessibilityEvent(event)
-
- // Assert.
- // See `sendEventChecker` above.
- }
-
- @Test
- fun onAccessibilityEvent_typeViewScrolled_ensureAllFieldsForwarded() {
- // Arrange.
- createForwarder(logAccessibilityTree = false, grpcPort = 1234)
- var nodeInfo = AccessibilityNodeInfo()
- nodeInfo.setContentDescription("My Content Description")
- nodeInfo.setText("My Source Text")
- nodeInfo.setClassName("AwesomeClass")
- var event = AccessibilityEvent()
- event.setEventTime(1357924680)
- event.setEventType(AccessibilityEvent.TYPE_VIEW_SCROLLED)
- event.getText().add("Some text!")
- event.scrollDeltaX = 13
- event.scrollDeltaY = 27
- event.setPackageName("some.loooong.package.name")
- shadowOf(event).setSourceNode(nodeInfo)
- fakeA11yService.sendEventChecker = { request: EventRequest ->
- // Check that all fields are consistent with how they were set above.
- assertThat(request.eventMap.get("event_type")).isEqualTo("TYPE_VIEW_SCROLLED")
- assertThat(request.eventMap.get("event_package_name")).isEqualTo("some.loooong.package.name")
- assertThat(request.eventMap.get("source_content_description"))
- .isEqualTo("My Content Description")
- assertThat(request.eventMap.get("source_text")).isEqualTo("My Source Text")
- assertThat(request.eventMap.get("source_class_name")).isEqualTo("AwesomeClass")
- assertThat(request.eventMap.get("event_text")).isEqualTo("Some text!")
- assertThat(request.eventMap.get("event_timestamp_ms")).isEqualTo("1357924680")
- assertThat(request.eventMap.get("scroll_delta_x")).isEqualTo("13")
- assertThat(request.eventMap.get("scroll_delta_y")).isEqualTo("27")
- assertFalse(request.eventMap.containsKey("from_index"))
- assertFalse(request.eventMap.containsKey("to_index"))
- assertFalse(request.eventMap.containsKey("before_text"))
- assertFalse(request.eventMap.containsKey("added_count"))
- assertFalse(request.eventMap.containsKey("removed_count"))
- // No error message
- ""
- }
-
- // Act.
- forwarder.onAccessibilityEvent(event)
-
- // Assert.
- // See `sendEventChecker` above.
- }
-
- @Test
- fun onAccessibilityEvent_typeViewTextTraversedAtMovementGranularity_ensureAllFieldsForwarded() {
- // Arrange.
- createForwarder(logAccessibilityTree = false, grpcPort = 1234)
- var nodeInfo = AccessibilityNodeInfo()
- nodeInfo.setContentDescription("My Content Description")
- nodeInfo.setText("My Source Text")
- nodeInfo.setClassName("AwesomeClass")
- nodeInfo.viewIdResourceName = "this.big.old.view.id"
- var event = AccessibilityEvent()
- event.setEventTime(1357924680)
- event.setEventType(AccessibilityEvent.TYPE_VIEW_TEXT_TRAVERSED_AT_MOVEMENT_GRANULARITY)
- event.getText().add("Some text!")
- event.setPackageName("some.loooong.package.name")
- event.movementGranularity = 5
- event.fromIndex = 6
- event.toIndex = 8
- event.action = 23
- shadowOf(event).setSourceNode(nodeInfo)
- fakeA11yService.sendEventChecker = { request: EventRequest ->
- // Check that all fields are consistent with how they were set above.
- assertThat(request.eventMap.get("event_type"))
- .isEqualTo("TYPE_VIEW_TEXT_TRAVERSED_AT_MOVEMENT_GRANULARITY")
- assertThat(request.eventMap.get("event_package_name")).isEqualTo("some.loooong.package.name")
- assertThat(request.eventMap.get("source_content_description"))
- .isEqualTo("My Content Description")
- assertThat(request.eventMap.get("source_text")).isEqualTo("My Source Text")
- assertThat(request.eventMap.get("source_class_name")).isEqualTo("AwesomeClass")
- assertThat(request.eventMap.get("event_text")).isEqualTo("Some text!")
- assertThat(request.eventMap.get("event_timestamp_ms")).isEqualTo("1357924680")
- assertThat(request.eventMap.get("movement_granularity")).isEqualTo("5")
- assertThat(request.eventMap.get("from_index")).isEqualTo("6")
- assertThat(request.eventMap.get("to_index")).isEqualTo("8")
- assertThat(request.eventMap.get("view_id")).isEqualTo("this.big.old.view.id")
- assertThat(request.eventMap.get("action")).isEqualTo("23")
- // No error message
- ""
- }
-
- // Act.
- forwarder.onAccessibilityEvent(event)
-
- // Assert.
- // See `sendEventChecker` above.
- }
-
- @Test
- fun onAccessibilityEvent_sendingevent_grpcTimeout() {
- // Arrange.
- createForwarder(
- logAccessibilityTree = false,
- a11yTreePeriodMs = 0,
- grpcHost = "amazing.host",
- grpcPort = 4321,
- )
- var nodeInfo = AccessibilityNodeInfo()
- nodeInfo.setContentDescription("My Content Description")
- nodeInfo.setText("My Source Text")
- nodeInfo.setClassName("AwesomeClass")
- var event = AccessibilityEvent()
- event.setEventTime(1357924680)
- event.setEventType(AccessibilityEvent.TYPE_VIEW_CLICKED)
- event.getText().add("Some text!")
- event.setPackageName("some.loooong.package.name")
- shadowOf(event).setSourceNode(nodeInfo)
- fakeA11yService.sendEventChecker = { _ ->
- // Delay the request to prompt a timeout
- Thread.sleep(1500L)
- "" // Return no error.
- }
-
- // Act.
- forwarder.onAccessibilityEvent(event)
-
- // Run a second request to ensure that the channel gets rebuilt.
- fakeA11yService.sendEventChecker = { _ -> "" }
- forwarder.onAccessibilityEvent(event)
-
- // Assert.
- // See `sendEventChecker` above.
- }
-
- @Test
- fun onAccessibilityEvent_sendingevent_grpcStatusException() {
- // Arrange.
- createForwarder(logAccessibilityTree = false, grpcHost = "amazing.host", grpcPort = 4321)
- var nodeInfo = AccessibilityNodeInfo()
- nodeInfo.setContentDescription("My Content Description")
- nodeInfo.setText("My Source Text")
- nodeInfo.setClassName("AwesomeClass")
- var event = AccessibilityEvent()
- event.setEventTime(1357924680)
- event.setEventType(AccessibilityEvent.TYPE_VIEW_CLICKED)
- event.getText().add("Some text!")
- event.setPackageName("some.loooong.package.name")
- shadowOf(event).setSourceNode(nodeInfo)
- fakeA11yService.sendEventChecker = { _ -> throw StatusException(Status.UNAVAILABLE) }
-
- // Act.
- forwarder.onAccessibilityEvent(event)
-
- // Run a second request to ensure that the channel gets rebuilt.
- fakeA11yService.sendEventChecker = { _ -> "" }
- forwarder.onAccessibilityEvent(event)
-
- // Assert.
- // See `sendEventChecker` above.
- }
-
- @Test
- fun logAccessibilityTreeFalse_doesNotLogAccessibilityTree() {
- // Arrange.
- createForwarder(logAccessibilityTree = false, a11yTreePeriodMs = 10, grpcPort = 13579)
- fakeA11yService.sendForestChecker = { _: AndroidAccessibilityForest ->
- assertFalse(true) // This should not be called.
- "" // This should be unreachable
- }
-
- // Act.
- Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.
-
- // Assert.
- // See `sendForestChecker` above.
- }
-
- @Test
- fun grpcPortZero_doesNotSendTree() {
- // Arrange.
- createForwarder(logAccessibilityTree = true, a11yTreePeriodMs = 10, grpcPort = 0)
- fakeA11yService.sendForestChecker = { _: AndroidAccessibilityForest ->
- assertFalse(true) // This should not be called.
- "" // This should be unreachable
- }
-
- // Act.
- Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.
-
- // Assert.
- // See `sendForestChecker` above.
- }
-
- @Test
- fun grpcPortPositive_shouldSendTreeViaGrpc() {
- // Arrange.
- val window = AccessibilityWindowInfo()
- shadowOf(window).setType(AccessibilityWindowInfo.TYPE_SYSTEM)
- createForwarder(
- logAccessibilityTree = true,
- a11yTreePeriodMs = 10,
- grpcPort = 1234,
- a11yWindows = mutableListOf(window),
- )
- fakeA11yService.sendForestChecker = { request: AndroidAccessibilityForest ->
- // Check that we get only a single window.
- assertThat(request.windowsList.size).isEqualTo(1)
- // And that its type is what we set above.
- assertThat(request.windowsList[0].windowType)
- .isEqualTo(AndroidAccessibilityWindowInfo.WindowType.TYPE_SYSTEM)
- // The error message
- "Something went wrong!"
- }
-
- // Act.
- Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.
-
- // Assert.
- // See `sendForestChecker` above.
- }
-
- @Test
- fun grpcPortPositiveAndHost_shouldSendTreeViaGrpc() {
- // Arrange.
- fakeA11yService.sendForestChecker = { request: AndroidAccessibilityForest ->
- // Check that we get only a single window.
- assertThat(request.windowsList.size).isEqualTo(1)
- // And that its type is what we set above.
- assertThat(request.windowsList[0].windowType)
- .isEqualTo(AndroidAccessibilityWindowInfo.WindowType.TYPE_ACCESSIBILITY_OVERLAY)
- "" // Return no error.
- }
- val window = AccessibilityWindowInfo()
- shadowOf(window).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)
- createForwarder(
- logAccessibilityTree = true,
- a11yTreePeriodMs = 500,
- grpcHost = "amazing.host",
- grpcPort = 4321,
- a11yWindows = mutableListOf(window),
- )
-
- // Act.
- Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.
-
- // Assert.
- // See `sendForestChecker` above.
- }
-
- @Test
- fun sendingForest_grpcTimeout() {
- // Arrange.
- fakeA11yService.sendForestChecker = { _ ->
- // Delay the request to prompt a timeout
- Thread.sleep(1500L)
- "" // Return no error.
- }
- val window = AccessibilityWindowInfo()
- shadowOf(window).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)
- createForwarder(
- logAccessibilityTree = true,
- a11yTreePeriodMs = 10,
- grpcHost = "amazing.host",
- grpcPort = 4321,
- a11yWindows = mutableListOf(window),
- )
-
- // Act.
- Thread.sleep(2000) // Sleep a bit to give time to trigger the tree logging function.
-
- // Run a second request to ensure that the channel gets rebuilt.
- fakeA11yService.sendForestChecker = { _ -> "" }
- Thread.sleep(2000) // Sleep a bit to give time to trigger the tree logging function.
-
- // Assert.
- // See `sendForestChecker` above.
- }
-
- @Test
- fun sendingForest_grpcStatusException() {
- // Arrange.
- val window = AccessibilityWindowInfo()
- shadowOf(window).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)
- createForwarder(
- logAccessibilityTree = true,
- a11yTreePeriodMs = 10,
- grpcHost = "amazing.host",
- grpcPort = 4321,
- a11yWindows = mutableListOf(window),
- )
- fakeA11yService.sendForestChecker = { _ -> throw StatusException(Status.UNAVAILABLE) }
-
- // Act.
- Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.
-
- // Run a second request to ensure that the channel gets rebuilt.
- fakeA11yService.sendForestChecker = { _ -> "" }
- Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.
-
- // Assert.
- // See `sendForestChecker` above.
- }
-}
diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreator.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreator.kt
deleted file mode 100644
index aeaee706..00000000
--- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreator.kt
+++ /dev/null
@@ -1,235 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package com.google.androidenv.accessibilityforwarder
-
-import android.graphics.Rect
-import android.util.Log
-import android.view.accessibility.AccessibilityNodeInfo
-import android.view.accessibility.AccessibilityWindowInfo
-import com.google.androidenv.accessibilityforwarder.AndroidAccessibilityWindowInfo.WindowType
-import java.util.concurrent.ConcurrentHashMap
-import java.util.stream.Collectors
-import kotlin.collections.mutableListOf
-import kotlinx.coroutines.Deferred
-import kotlinx.coroutines.async
-import kotlinx.coroutines.awaitAll
-import kotlinx.coroutines.runBlocking
-
-/** Helper methods for creating the android accessibility info extra. */
-class AccessibilityTreeCreator() {
-
- /** Creates an accessibility forest proto. */
- fun buildForest(windowInfos: List): AndroidAccessibilityForest {
- val sourcesMap: ConcurrentHashMap =
- ConcurrentHashMap()
- val windows: List =
- processWindowsAndBlock(windowInfos, sourcesMap)
- return androidAccessibilityForest { this.windows += windows }
- }
-
- private fun processWindowsAndBlock(
- windowInfos: List,
- sourcesMap: ConcurrentHashMap,
- ): List {
- val windows: List
- runBlocking { windows = processWindows(windowInfos, sourcesMap) }
- return windows
- }
-
- private suspend fun processWindows(
- windowInfos: List,
- sourcesMap: ConcurrentHashMap,
- ): List {
- var windowInfoProtos = mutableListOf()
- for (i in windowInfos.size - 1 downTo 0) {
- val windowInfoProto = processWindow(windowInfos.get(i), sourcesMap)
- windowInfoProto?.let { windowInfoProtos.add(windowInfoProto) }
- }
- return windowInfoProtos.toList()
- }
-
- private suspend fun processWindow(
- windowInfo: AccessibilityWindowInfo,
- sources: ConcurrentHashMap,
- ): AndroidAccessibilityWindowInfo? {
- val bounds = Rect()
- windowInfo.getBoundsInScreen(bounds)
- val root: AccessibilityNodeInfo? = windowInfo.root
- if (root == null) {
- Log.i(TAG, "window root is null")
- return androidAccessibilityWindowInfo {
- this.tree = androidAccessibilityTree {}
- this.isActive = windowInfo.isActive
- this.id = windowInfo.id
- this.layer = windowInfo.layer
- this.isAccessibilityFocused = windowInfo.isAccessibilityFocused
- this.isFocused = windowInfo.isFocused
- this.boundsInScreen = convertToRectProto(bounds)
- this.windowType = toWindowType(windowInfo.type)
- }
- }
- val treeDeferred: Deferred
- runBlocking { treeDeferred = async { processNodesInWindow(root, sources) } }
- return androidAccessibilityWindowInfo {
- this.tree = treeDeferred.await()
- this.isActive = windowInfo.isActive
- this.id = windowInfo.id
- this.layer = windowInfo.layer
- this.isAccessibilityFocused = windowInfo.isAccessibilityFocused
- this.isFocused = windowInfo.isFocused
- this.boundsInScreen = convertToRectProto(bounds)
- this.windowType = toWindowType(windowInfo.type)
- }
- }
-
- private suspend fun processNodesInWindow(
- root: AccessibilityNodeInfo,
- sources: ConcurrentHashMap,
- ): AndroidAccessibilityTree {
- Log.d(TAG, "processNodesInWindow()")
- val traversalQueue = ArrayDeque()
- traversalQueue.add(ParentChildNodePair.builder().child(root).build())
- val uniqueIdsCache: UniqueIdsGenerator = UniqueIdsGenerator()
- var currentDepth = 0
- val nodesDeferred = mutableListOf>()
- val seenNodes: HashSet = HashSet()
- seenNodes.add(root)
- runBlocking {
- while (!traversalQueue.isEmpty()) {
- // Traverse the tree layer-by-layer.
- // The first layer has only the root and depth 0.
- // The second layer has all the root's children and depth 1.
- for (nodesAtCurrentDepth in traversalQueue.size downTo 1) {
- val nodePair: ParentChildNodePair = traversalQueue.removeFirst()
- for (i in 0 until nodePair.child().childCount) {
- val childNode: AccessibilityNodeInfo? = nodePair.child().getChild(i)
- if (childNode != null && !seenNodes.contains(childNode)) {
- traversalQueue.add(
- ParentChildNodePair.builder().child(childNode).parent(nodePair.child()).build()
- )
- seenNodes.add(childNode)
- }
- }
- val thisDepth = currentDepth
- var deferred = async { processNode(nodePair, sources, uniqueIdsCache, thisDepth) }
- nodesDeferred.add(deferred)
- }
- currentDepth++
- }
- }
- return androidAccessibilityTree { this.nodes += nodesDeferred.awaitAll() }
- }
-
- companion object {
- private const val TAG = "AndroidRLTask"
- }
-}
-
-private fun processNode(
- nodePair: ParentChildNodePair,
- sourceBuilder: ConcurrentHashMap,
- uniqueIdsCache: UniqueIdsGenerator,
- nodeDepth: Int,
-): AndroidAccessibilityNodeInfo {
- val node: AccessibilityNodeInfo = nodePair.child()
- val immutableNode: AndroidAccessibilityNodeInfo =
- createAndroidAccessibilityNode(
- node,
- uniqueIdsCache.getUniqueId(node),
- nodeDepth,
- getChildUniqueIds(node, uniqueIdsCache),
- )
- sourceBuilder.put(immutableNode, node)
- return immutableNode
-}
-
-private fun createAndroidAccessibilityNode(
- node: AccessibilityNodeInfo,
- nodeId: Int,
- depth: Int,
- childIds: List,
-): AndroidAccessibilityNodeInfo {
- val bounds = Rect()
- node.getBoundsInScreen(bounds)
- val actions = node.getActionList().stream().map(::createAction).collect(Collectors.toList())
- return androidAccessibilityNodeInfo {
- this.actions += actions
- this.boundsInScreen = convertToRectProto(bounds)
- this.isCheckable = node.isCheckable
- this.isChecked = node.isChecked
- this.className = stringFromNullableCharSequence(node.getClassName())
- this.isClickable = node.isClickable
- this.contentDescription = stringFromNullableCharSequence(node.getContentDescription())
- this.isEditable = node.isEditable
- this.isEnabled = node.isEnabled
- this.isFocusable = node.isFocusable
- this.hintText = stringFromNullableCharSequence(node.getHintText())
- this.isLongClickable = node.isLongClickable
- this.packageName = stringFromNullableCharSequence(node.getPackageName())
- this.isPassword = node.isPassword
- this.isScrollable = node.isScrollable
- this.isSelected = node.isSelected
- this.text = stringFromNullableCharSequence(node.getText())
- this.textSelectionEnd = node.getTextSelectionEnd().toLong()
- this.textSelectionStart = node.getTextSelectionStart().toLong()
- this.viewIdResourceName = node.getViewIdResourceName() ?: ""
- this.isVisibleToUser = node.isVisibleToUser
- this.windowId = node.windowId
- this.uniqueId = nodeId
- this.childIds += childIds
- this.drawingOrder = node.drawingOrder
- this.tooltipText = stringFromNullableCharSequence(node.getTooltipText())
- this.depth = depth
- }
-}
-
-private fun createAction(
- action: AccessibilityNodeInfo.AccessibilityAction
-): AndroidAccessibilityAction =
- AndroidAccessibilityAction.newBuilder()
- .setId(action.id)
- .setLabel(stringFromNullableCharSequence(action.label))
- .build()
-
-private fun getChildUniqueIds(
- node: AccessibilityNodeInfo,
- uniqueIdsCache: UniqueIdsGenerator,
-): List {
- val ids = mutableListOf()
- for (childId in 0 until node.getChildCount()) {
- val child: AccessibilityNodeInfo = node.getChild(childId) ?: continue
- ids.add(uniqueIdsCache.getUniqueId(child))
- }
- return ids.toList()
-}
-
-fun stringFromNullableCharSequence(cs: CharSequence?): String = cs?.toString() ?: ""
-
-fun convertToRectProto(rect: Rect) = protoRect {
- left = rect.left
- top = rect.top
- right = rect.right
- bottom = rect.bottom
-}
-
-private fun toWindowType(type: Int): WindowType =
- when (type) {
- AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY -> WindowType.TYPE_ACCESSIBILITY_OVERLAY
- AccessibilityWindowInfo.TYPE_APPLICATION -> WindowType.TYPE_APPLICATION
- AccessibilityWindowInfo.TYPE_INPUT_METHOD -> WindowType.TYPE_INPUT_METHOD
- AccessibilityWindowInfo.TYPE_SYSTEM -> WindowType.TYPE_SYSTEM
- AccessibilityWindowInfo.TYPE_SPLIT_SCREEN_DIVIDER -> WindowType.TYPE_SPLIT_SCREEN_DIVIDER
- else -> WindowType.UNKNOWN_TYPE
- }
diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreatorTest.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreatorTest.kt
deleted file mode 100644
index b05d23a3..00000000
--- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreatorTest.kt
+++ /dev/null
@@ -1,85 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package com.google.androidenv.accessibilityforwarder
-
-import android.view.accessibility.AccessibilityNodeInfo
-import android.view.accessibility.AccessibilityWindowInfo
-import kotlin.test.assertEquals
-import org.junit.Test
-import org.junit.runner.RunWith
-import org.robolectric.RobolectricTestRunner
-import org.robolectric.Shadows.shadowOf
-
-@RunWith(RobolectricTestRunner::class)
-class AccessibilityTreeCreatorTest {
-
- @Test
- fun buildForest_buildsAccessibilityForestCorrectly() {
- val creator = AccessibilityTreeCreator()
-
- val forest = creator.buildForest(mutableListOf(createWindowInfo()))
-
- assertEquals(forest.windowsCount, 1)
- assertEquals(forest.getWindows(0).tree.nodesCount, 3)
- var rootNode: AndroidAccessibilityNodeInfo? = null
- var checkableNode: AndroidAccessibilityNodeInfo? = null
- val nodes = forest.getWindows(0).tree.nodesList
- for (i in nodes.size - 1 downTo 0) {
- if (nodes[i].text == "root node") {
- rootNode = nodes[i]
- }
- if (nodes[i].isCheckable == true) {
- checkableNode = nodes[i]
- }
- }
- assertEquals(rootNode?.childIdsCount, 2)
- assertEquals(checkableNode?.text, "Check box")
- }
-
- @Test
- fun buildForest_noRootInWindow_returnsEmptyTree() {
- val creator = AccessibilityTreeCreator()
- val windowInfo = AccessibilityWindowInfo.obtain()
- shadowOf(windowInfo).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)
-
- val forest = creator.buildForest(mutableListOf(windowInfo))
-
- assertEquals(0, forest.getWindows(0).tree.nodesList.size)
- }
-
- private fun createAccessibilityNodeInfo(): AccessibilityNodeInfo {
- val root = AccessibilityNodeInfo.obtain()
- root.text = "root node"
- root.isClickable = true
- val accessibilityNodeInfo = AccessibilityNodeInfo.obtain()
- accessibilityNodeInfo.viewIdResourceName = "test"
- accessibilityNodeInfo.isClickable = true
- accessibilityNodeInfo.isEditable = true
- accessibilityNodeInfo.hintText = "Please enter your address"
- shadowOf(root).addChild(accessibilityNodeInfo)
- val anotherChildNode = AccessibilityNodeInfo.obtain()
- anotherChildNode.isCheckable = true
- anotherChildNode.text = "Check box"
- shadowOf(root).addChild(anotherChildNode)
- return root
- }
-
- private fun createWindowInfo(): AccessibilityWindowInfo {
- val windowInfo = AccessibilityWindowInfo.obtain()
- shadowOf(windowInfo).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)
- shadowOf(windowInfo).setRoot(createAccessibilityNodeInfo())
- return windowInfo
- }
-}
diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest.xml b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest.xml
deleted file mode 100644
index debf611f..00000000
--- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest.xml
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest_lite.xml b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest_lite.xml
deleted file mode 100644
index 7bf2e851..00000000
--- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest_lite.xml
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-
-
-
-
-
-
diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiver.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiver.kt
deleted file mode 100644
index 1ce5e763..00000000
--- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiver.kt
+++ /dev/null
@@ -1,60 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package com.google.androidenv.accessibilityforwarder
-
-import android.content.BroadcastReceiver
-import android.content.Context
-import android.content.Intent
-import android.util.Log
-
-/** Broadcast receiver responsible for enabling or disabling flags. */
-class FlagsBroadcastReceiver() : BroadcastReceiver() {
-
- override fun onReceive(context: Context?, intent: Intent?) {
- val action = intent?.action
- Log.i(TAG, "Received broadcast intent with action: " + action)
- when (action) {
- ACTION_ENABLE_ACCESSIBILITY_TREE_LOGS -> {
- Log.i(TAG, "Enabling Accessibility Tree logging.")
- LogFlags.logAccessibilityTree = true
- }
- ACTION_DISABLE_ACCESSIBILITY_TREE_LOGS -> {
- Log.i(TAG, "Disabling Accessibility Tree logging.")
- LogFlags.logAccessibilityTree = false
- }
- ACTION_SET_GRPC -> {
- // The Android Emulator uses 10.0.2.2 as a redirect to the workstation's IP. Most often the
- // gRPC server will be running locally so it makes sense to use this as the default value.
- // See https://developer.android.com/studio/run/emulator-networking#networkaddresses.
- val host = intent.getStringExtra("host") ?: "10.0.2.2"
- // The TCP port to connect. If <=0 gRPC is disabled.
- val port = intent.getIntExtra("port", 0)
- Log.i(TAG, "Setting gRPC endpoint to ${host}:${port}.")
- LogFlags.grpcHost = host
- LogFlags.grpcPort = port
- }
- else -> Log.w(TAG, "Unknown action: ${action}")
- }
- }
-
- companion object {
- private const val TAG = "FlagsBroadcastReceiver"
- private const val ACTION_ENABLE_ACCESSIBILITY_TREE_LOGS =
- "accessibility_forwarder.intent.action.ENABLE_ACCESSIBILITY_TREE_LOGS"
- private const val ACTION_DISABLE_ACCESSIBILITY_TREE_LOGS =
- "accessibility_forwarder.intent.action.DISABLE_ACCESSIBILITY_TREE_LOGS"
- private const val ACTION_SET_GRPC = "accessibility_forwarder.intent.action.SET_GRPC"
- }
-}
diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiverTest.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiverTest.kt
deleted file mode 100644
index f643cfab..00000000
--- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiverTest.kt
+++ /dev/null
@@ -1,166 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package com.google.androidenv.accessibilityforwarder
-
-import android.content.Intent
-import com.google.common.truth.Truth.assertThat
-import org.junit.Test
-import org.junit.runner.RunWith
-import org.robolectric.RobolectricTestRunner
-
-@RunWith(RobolectricTestRunner::class)
-class FlagsBroadcastReceiverTest {
-
- @Test
- fun onReceive_nullIntent_shouldNotLogAnything() {
- // Arrange.
- LogFlags.logAccessibilityTree = false
- val receiver = FlagsBroadcastReceiver()
-
- // Act.
- receiver.onReceive(context = null, intent = null)
-
- // Assert.
- assertThat(LogFlags.logAccessibilityTree).isFalse()
- }
-
- @Test
- fun onReceive_nullIntent_actionShouldNotLogAnything() {
- // Arrange.
- LogFlags.logAccessibilityTree = false
- val receiver = FlagsBroadcastReceiver()
- val intent = Intent()
-
- // Act.
- receiver.onReceive(context = null, intent = intent)
-
- // Assert.
- assertThat(LogFlags.logAccessibilityTree).isFalse()
- }
-
- @Test
- fun onReceive_unknownIntent_actionShouldIssueWarning() {
- // Arrange.
- LogFlags.logAccessibilityTree = false
- val receiver = FlagsBroadcastReceiver()
- val intent = Intent("SOME_WEIRD_ACTION")
-
- // Act.
- receiver.onReceive(context = null, intent = intent)
-
- // Assert.
- assertThat(LogFlags.logAccessibilityTree).isFalse()
- }
-
- @Test
- fun onReceive_intentWithDisableAction_shouldDisableTreeLogging() {
- // Arrange.
- LogFlags.logAccessibilityTree = true
- val receiver = FlagsBroadcastReceiver()
- val intent = Intent("accessibility_forwarder.intent.action.DISABLE_ACCESSIBILITY_TREE_LOGS")
-
- // Act.
- receiver.onReceive(context = null, intent = intent)
-
- // Assert.
- assertThat(LogFlags.logAccessibilityTree).isFalse()
- }
-
- @Test
- fun onReceive_intentWithEnableAction_shouldEnableTreeLogging() {
- // Arrange.
- LogFlags.logAccessibilityTree = false
- val receiver = FlagsBroadcastReceiver()
- val intent = Intent("accessibility_forwarder.intent.action.ENABLE_ACCESSIBILITY_TREE_LOGS")
-
- // Act.
- receiver.onReceive(context = null, intent = intent)
-
- // Assert.
- assertThat(LogFlags.logAccessibilityTree).isTrue()
- }
-
- @Test
- fun onReceive_intentWithSetGrpcActionNoArgs_shouldDefaultToEmuIpAndPortZero() {
- // Arrange.
- LogFlags.grpcHost = "some_host"
- LogFlags.grpcPort = 9999
- val receiver = FlagsBroadcastReceiver()
- val intent = Intent("accessibility_forwarder.intent.action.SET_GRPC")
-
- // Act.
- receiver.onReceive(context = null, intent = intent)
-
- // Assert.
- assertThat(LogFlags.grpcHost).isEqualTo("10.0.2.2")
- assertThat(LogFlags.grpcPort).isEqualTo(0)
- }
-
- @Test
- fun onReceive_intentWithSetGrpcActionWithHostNoPort_shouldDefaultPortToZero() {
- // Arrange.
- LogFlags.grpcHost = "some_host"
- LogFlags.grpcPort = 9999
- val receiver = FlagsBroadcastReceiver()
- val intent =
- Intent("accessibility_forwarder.intent.action.SET_GRPC").apply {
- putExtra("host", "awesome.server.ca")
- }
-
- // Act.
- receiver.onReceive(context = null, intent = intent)
-
- // Assert.
- assertThat(LogFlags.grpcHost).isEqualTo("awesome.server.ca")
- assertThat(LogFlags.grpcPort).isEqualTo(0)
- }
-
- @Test
- fun onReceive_intentWithSetGrpcActionWithPortNoHost_shouldDefaultHostToEmuIp() {
- // Arrange.
- LogFlags.grpcHost = "some_host"
- LogFlags.grpcPort = 9999
- val receiver = FlagsBroadcastReceiver()
- val intent =
- Intent("accessibility_forwarder.intent.action.SET_GRPC").apply { putExtra("port", 54321) }
-
- // Act.
- receiver.onReceive(context = null, intent = intent)
-
- // Assert.
- assertThat(LogFlags.grpcHost).isEqualTo("10.0.2.2")
- assertThat(LogFlags.grpcPort).isEqualTo(54321)
- }
-
- @Test
- fun onReceive_intentWithSetGrpcActionWithHostAndPort_shouldSetBoth() {
- // Arrange.
- LogFlags.grpcHost = "some_host"
- LogFlags.grpcPort = 9999
- val receiver = FlagsBroadcastReceiver()
- val intent =
- Intent("accessibility_forwarder.intent.action.SET_GRPC").apply {
- putExtra("host", "grpc.ca")
- putExtra("port", 54321)
- }
-
- // Act.
- receiver.onReceive(context = null, intent = intent)
-
- // Assert.
- assertThat(LogFlags.grpcHost).isEqualTo("grpc.ca")
- assertThat(LogFlags.grpcPort).isEqualTo(54321)
- }
-}
diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/LogFlags.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/LogFlags.kt
deleted file mode 100644
index 7af6fbde..00000000
--- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/LogFlags.kt
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package com.google.androidenv.accessibilityforwarder
-
-/**
- * Controls global settings in AccessibilityForwarder.
- *
- * Please note that this class is not thread safe.
- */
-object LogFlags {
- // Whether to log the accessibility tree.
- var logAccessibilityTree: Boolean = false
- // How frequent to emit a11y trees (in milliseconds).
- var a11yTreePeriodMs: Long = 100
-
- // The gRPC server to connect to. (Only available if grpcPort>0).
- var grpcHost: String = ""
- // If >0 this represents the gRPC port number to connect to.
- var grpcPort: Int = 0
-}
diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/ParentChildNodePair.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/ParentChildNodePair.kt
deleted file mode 100644
index 4773a5ca..00000000
--- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/ParentChildNodePair.kt
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package com.google.androidenv.accessibilityforwarder
-
-import android.view.accessibility.AccessibilityNodeInfo
-import com.google.auto.value.AutoValue
-
-/** Parent and child [AccessibilityNodeInfo] relationship. */
-@AutoValue
-internal abstract class ParentChildNodePair {
- abstract fun parent(): AccessibilityNodeInfo?
-
- abstract fun child(): AccessibilityNodeInfo
-
- /** [ParentChildNodePair] builder. */
- @AutoValue.Builder
- abstract class Builder {
- abstract fun parent(parent: AccessibilityNodeInfo?): Builder
-
- abstract fun child(child: AccessibilityNodeInfo): Builder
-
- abstract fun build(): ParentChildNodePair
- }
-
- companion object {
- @JvmStatic fun builder(): Builder = AutoValue_ParentChildNodePair.Builder()
- }
-}
diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/UniqueIdsGenerator.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/UniqueIdsGenerator.kt
deleted file mode 100644
index eeedacd5..00000000
--- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/UniqueIdsGenerator.kt
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package com.google.androidenv.accessibilityforwarder
-
-import java.util.concurrent.ConcurrentHashMap
-import java.util.concurrent.atomic.AtomicInteger
-import java.util.function.Function
-
-/** Thread-safe helper class for assigning a unique ID to an object. */
-internal class UniqueIdsGenerator {
- private val nextId = AtomicInteger(0)
- private val uniqueIdsByNode = ConcurrentHashMap()
-
- fun getUniqueId(a: A): Int {
- return uniqueIdsByNode.computeIfAbsent(a, Function { _: A -> nextId.getAndIncrement() })
- }
-}
diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/res/xml/accessibility_forwarder_service.xml b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/res/xml/accessibility_forwarder_service.xml
deleted file mode 100644
index e73943d0..00000000
--- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/res/xml/accessibility_forwarder_service.xml
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-
-
-
diff --git a/android_env/components/__init__.py b/android_env/components/__init__.py
deleted file mode 100644
index 2f66bf75..00000000
--- a/android_env/components/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
diff --git a/android_env/components/action_fns.py b/android_env/components/action_fns.py
deleted file mode 100644
index e290e9d0..00000000
--- a/android_env/components/action_fns.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Functions to convert actions between different components' formats."""
-
-from absl import logging
-from android_env.components import action_type as action_type_lib
-from android_env.components import errors
-from android_env.components import pixel_fns
-from android_env.components.simulators import base_simulator
-import numpy as np
-
-
-def send_action_to_simulator(
- action: dict[str, np.ndarray],
- simulator: base_simulator.BaseSimulator,
- screen_width: int,
- screen_height: int,
- num_fingers: int,
-) -> bool:
- """Sends the selected action to the given simulator.
-
- The simulator will interpret the action according to `action["action_type"]`.
- The effect this action triggers in the Android OS will be determined by the
- currently running application.
-
- Args:
- action: action which will get interpreted as a touchscreen event.
- simulator: The simulator that will receive the action.
- screen_width: The width of the touchscreen in pixels.
- screen_height: The height of the touchscreen in pixels.
- num_fingers: The number of fingers used in this simulator.
- """
-
- try:
- match action['action_type']:
- # If the action is a TOUCH or LIFT, send a touch event to the simulator.
- case action_type_lib.ActionType.TOUCH | action_type_lib.ActionType.LIFT:
- prepared_action = _prepare_touch_action(
- action, screen_width, screen_height, num_fingers
- )
- simulator.send_touch(prepared_action)
- # If the action is a key event, send a key event to the simulator.
- case action_type_lib.ActionType.KEYDOWN:
- simulator.send_key(action['keycode'].item(0), event_type='keydown')
- case action_type_lib.ActionType.KEYUP:
- simulator.send_key(action['keycode'].item(0), event_type='keyup')
- case action_type_lib.ActionType.KEYPRESS:
- simulator.send_key(action['keycode'].item(0), event_type='keypress')
- except errors.SendActionError:
- logging.exception('Unable to execute action: %r', action)
- return False
-
- return True
-
-
-def _prepare_touch_action(
- action: dict[str, np.ndarray],
- screen_width: int,
- screen_height: int,
- num_fingers: int,
-) -> list[tuple[int, int, bool, int]]:
- """Turns an AndroidEnv action into values that the simulator can interpret.
-
- Converts float-valued 'touch_position' to integer coordinates corresponding
- to specific pixels, and 'action_type' to booleans indicating whether the
- screen is touched at said location or not. The result of this function can
- be sent directly to the underlying simulator (e.g. the Android Emulator,
- virtual machine, or a phone).
-
- Args:
- action: An action containing 'action_type' and 'touch_position'.
-
- Returns:
- A tuple with the format (x: int, y: int, down/up: bool, finger_index: int).
- """
-
- touch_events = []
- for i, finger_action in enumerate(_split_touch_action(action, num_fingers)):
- is_touch = finger_action['action_type'] == action_type_lib.ActionType.TOUCH
- touch_position = finger_action['touch_position']
- touch_pixels = pixel_fns.touch_position_to_pixel_position(
- touch_position, width_height=(screen_width, screen_height)
- )
- touch_events.append((touch_pixels[0], touch_pixels[1], is_touch, i))
- return touch_events
-
-
-def _split_touch_action(
- action: dict[str, np.ndarray], num_fingers: int
-) -> list[dict[str, np.ndarray]]:
- """Splits a multitouch action into a list of single-touch actions."""
-
- single_touch_actions = [{
- 'action_type': action['action_type'],
- 'touch_position': action['touch_position'],
- }]
- for i in range(2, num_fingers + 1):
- single_touch_actions.append({
- 'action_type': action[f'action_type_{i}'],
- 'touch_position': action[f'touch_position_{i}'],
- })
- return single_touch_actions
-
-
-def lift_all_fingers_action(num_fingers: int) -> dict[str, np.ndarray]:
- """A lift action with each finger."""
-
- # There's always at least one finger.
- lift_action = {
- 'action_type': np.array(action_type_lib.ActionType.LIFT),
- 'touch_position': np.array([0, 0]),
- }
- # Subsequent fingers have separate dict entries.
- for i in range(2, num_fingers + 1):
- lift_action |= {
- f'action_type_{i}': np.array(action_type_lib.ActionType.LIFT),
- f'touch_position_{i}': np.array([0, 0]),
- }
- return lift_action
diff --git a/android_env/components/action_fns_test.py b/android_env/components/action_fns_test.py
deleted file mode 100644
index 5612b6da..00000000
--- a/android_env/components/action_fns_test.py
+++ /dev/null
@@ -1,227 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from unittest import mock
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from android_env.components import action_fns
-from android_env.components import action_type as action_type_lib
-from android_env.components import errors
-from android_env.components.simulators import base_simulator
-import numpy as np
-
-
-class ActionFnsTest(parameterized.TestCase):
-
- def test_send_action_to_simulator_missing_action_type(self):
- """A `KeyError` should be raised if the action is missing "action_type"."""
-
- # Arrange.
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- action = {'some_key': np.array(123, np.int32)}
-
- # Act & Assert.
- self.assertRaises(
- KeyError,
- action_fns.send_action_to_simulator,
- action,
- simulator,
- 800,
- 600,
- 1,
- )
-
- def test_send_action_to_simulator_sendactionerror(self):
- """Returns `False` if the simulator raises a SendActionError."""
-
- # Arrange.
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- simulator.send_touch.side_effect = errors.SendActionError('oops!')
- action = {
- 'action_type': action_type_lib.ActionType.TOUCH,
- 'touch_position': np.array([0.3, 0.5], np.float32),
- }
-
- # Act.
- output = action_fns.send_action_to_simulator(
- action,
- simulator,
- 800,
- 600,
- 1,
- )
-
- # Assert.
- self.assertFalse(output)
- simulator.send_touch.assert_called_once()
-
- def test_send_action_to_simulator_touch_success_one_finger(self):
- """Returns `True` with a proper 1-finger touch action."""
-
- # Arrange.
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- action = {
- 'action_type': action_type_lib.ActionType.TOUCH,
- 'touch_position': np.array([0.2, 0.5], np.float32),
- }
-
- # Act.
- output = action_fns.send_action_to_simulator(
- action,
- simulator,
- 800,
- 600,
- 1,
- )
-
- # Assert.
- self.assertTrue(output)
- simulator.send_touch.assert_called_once_with(
- [(np.int32(160), np.int32(300), True, 0)]
- )
-
- def test_send_action_to_simulator_touch_success_multiple_finger(self):
- """Returns `True` with a proper 3-finger touch action."""
-
- # Arrange.
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- action = {
- 'action_type': action_type_lib.ActionType.TOUCH,
- 'touch_position': np.array([0.2, 0.5], np.float32),
- 'action_type_2': action_type_lib.ActionType.LIFT,
- 'touch_position_2': np.array([0.1, 0.2], np.float32),
- 'action_type_3': action_type_lib.ActionType.TOUCH,
- 'touch_position_3': np.array([0.5, 0.2], np.float32),
- }
-
- # Act.
- output = action_fns.send_action_to_simulator(
- action,
- simulator,
- 800,
- 600,
- 3,
- )
-
- # Assert.
- self.assertTrue(output)
- simulator.send_touch.assert_called_once_with([
- (np.int32(160), np.int32(300), True, 0),
- (np.int32(80), np.int32(120), False, 1),
- (np.int32(400), np.int32(120), True, 2),
- ])
-
- def test_send_action_to_simulator_keydown_success(self):
- """Returns `True` with a proper keydown action."""
-
- # Arrange.
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- action = {
- 'action_type': action_type_lib.ActionType.KEYDOWN,
- 'keycode': np.array([21], np.int32),
- }
-
- # Act.
- output = action_fns.send_action_to_simulator(
- action,
- simulator,
- 800,
- 600,
- 1,
- )
-
- # Assert.
- self.assertTrue(output)
- simulator.send_key.assert_called_once_with(21, event_type='keydown')
-
- def test_send_action_to_simulator_keyup_success(self):
- """Returns `True` with a proper keyup action."""
-
- # Arrange.
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- action = {
- 'action_type': action_type_lib.ActionType.KEYUP,
- 'keycode': np.array([42], np.int32),
- }
-
- # Act.
- output = action_fns.send_action_to_simulator(
- action,
- simulator,
- 800,
- 600,
- 1,
- )
-
- # Assert.
- self.assertTrue(output)
- simulator.send_key.assert_called_once_with(42, event_type='keyup')
-
- def test_send_action_to_simulator_keypress_success(self):
- """Returns `True` with a proper keypress action."""
-
- # Arrange.
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- action = {
- 'action_type': action_type_lib.ActionType.KEYPRESS,
- 'keycode': np.array([96], np.int32),
- }
-
- # Act.
- output = action_fns.send_action_to_simulator(
- action,
- simulator,
- 800,
- 600,
- 1,
- )
-
- # Assert.
- self.assertTrue(output)
- simulator.send_key.assert_called_once_with(96, event_type='keypress')
-
- @parameterized.named_parameters(
- (
- 'one_finger',
- 1,
- {
- 'action_type': np.array(action_type_lib.ActionType.LIFT),
- 'touch_position': np.array([0, 0]),
- },
- ),
- (
- 'two_fingers',
- 2,
- {
- 'action_type': np.array(action_type_lib.ActionType.LIFT),
- 'touch_position': np.array([0, 0]),
- 'action_type_2': np.array(action_type_lib.ActionType.LIFT),
- 'touch_position_2': np.array([0, 0]),
- },
- ),
- )
- def test_lift_all_fingers_action(
- self, num_fingers: int, expected_action: dict[str, np.ndarray]
- ):
- """Returns the expected action."""
-
- output = action_fns.lift_all_fingers_action(num_fingers)
- for k, v in expected_action.items():
- np.testing.assert_array_equal(v, output[k])
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/action_type.py b/android_env/components/action_type.py
deleted file mode 100644
index da57aa4f..00000000
--- a/android_env/components/action_type.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""The different kinds of actions that AndroidEnv supports.
-
-The native action space of AndroidEnv consists of a tuple consisting of
-- A position (x, y) ∈ [0, 1] x [0, 1], determining the location of the action on
- the screen, and
-- A discrete value, indicating the action type, which is in this file.
-
-See https://arxiv.org/abs/2105.13231, section 2.2 for details.
-"""
-
-import enum
-
-
-@enum.unique
-class ActionType(enum.IntEnum):
- """Integer values to describe each supported action in AndroidEnv.
-
- Note for KEY* types:
- - Only meaningful if connected to a _physical_ keyboard, _not_ virtual
- keyboard.
- - Added afterwards so they did not appear in the paper.
-
- Attributes:
- TOUCH: Touching the screen at a location.
- LIFE: Lifting the (imaginary) pointer from the screen at a location.
- REPEAT: Repeating the last chosen action.
- KEYDOWN: Sending a key down event.
- KEYUP: Sending a key up event.
- KEYPRESS: Sending a key down event, immediately followed by a key up event.
- """
-
- TOUCH = 0
- LIFT = 1
- REPEAT = 2
- KEYDOWN = 3
- KEYUP = 4
- KEYPRESS = 5
diff --git a/android_env/components/adb_call_parser.py b/android_env/components/adb_call_parser.py
deleted file mode 100644
index 53efe478..00000000
--- a/android_env/components/adb_call_parser.py
+++ /dev/null
@@ -1,915 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Processes adb_pb2.AdbRequest commands."""
-
-import os
-import re
-import subprocess
-import sys
-import tempfile
-
-from absl import logging
-from android_env.components import adb_controller as adb_control
-from android_env.proto import adb_pb2
-
-# A mapping from a Button enum to keycode strings.
-#
-# Please see https://developer.android.com/reference/android/view/KeyEvent
-#
-# We currently only accept the following entries:
-_BUTTON_TO_KEYCODE = {
- adb_pb2.AdbRequest.PressButton.Button.HOME: 'KEYCODE_HOME',
- adb_pb2.AdbRequest.PressButton.Button.BACK: 'KEYCODE_BACK',
- adb_pb2.AdbRequest.PressButton.Button.ENTER: 'KEYCODE_ENTER',
-}
-
-
-class AdbCallParser:
- """Parses AdbRequest messages and executes corresponding adb commands."""
-
- def __init__(self, adb_controller: adb_control.AdbController):
- self._adb_controller = adb_controller
- self._handlers = {
- 'install_apk': self._install_apk,
- 'start_activity': self._start_activity,
- 'force_stop': self._force_stop,
- 'tap': self._tap,
- 'press_button': self._press_button,
- 'start_screen_pinning': self._start_screen_pinning,
- 'send_broadcast': self._send_broadcast,
- 'uninstall_package': self._handle_uninstall_package,
- 'get_current_activity': self._get_current_activity,
- 'get_orientation': self._get_orientation,
- 'push': self._push,
- 'pull': self._pull,
- 'input_text': self._input_text,
- 'settings': self._handle_settings,
- 'generic': self._handle_generic,
- 'package_manager': self._handle_package_manager,
- 'dumpsys': self._handle_dumpsys,
- }
-
- def _execute_command(
- self, command_args: list[str], timeout: float | None
- ) -> tuple[adb_pb2.AdbResponse, bytes]:
- """Executes the command, catches errors and populates the response status.
-
- Args:
- command_args: a list of arguments for the ADB request.
- timeout: Timeout in seconds.
-
- Returns:
- A tuple of the AdbResponse with the status populated, and the output
- bytes from the command.
- """
- response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK)
- command_output = b''
- try:
- command_output = self._adb_controller.execute_command(
- command_args, timeout=timeout)
- except subprocess.CalledProcessError as adb_error:
- if adb_error.stdout is not None:
- response.status = adb_pb2.AdbResponse.Status.ADB_ERROR
- response.error_message = adb_error.stdout
- except subprocess.TimeoutExpired:
- response.status = adb_pb2.AdbResponse.Status.TIMEOUT
- response.error_message = 'Timeout'
-
- return response, command_output
-
- def parse(self, request: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
- """Executes `request` and returns an appropriate response."""
-
- response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK)
- command_type = request.WhichOneof('command')
- logging.info('AdbRequest command type: %s', command_type)
- if command_type is None:
- response.status = adb_pb2.AdbResponse.Status.UNKNOWN_COMMAND
- response.error_message = 'AdbRequest.command is None.'
- return response
-
- if request.timeout_sec < 0:
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = ('AdbRequest.timeout_sec cannot be negative. '
- f'Got: {request.timeout_sec}')
- return response
-
- timeout: float | None = request.timeout_sec or None
- return self._handlers[command_type](request, timeout)
-
- def _force_stop(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Stops an application.
-
- Args:
- request: The external request containing the package to force stop.
- timeout: Optional time limit in seconds.
-
- Returns:
- An AdbResponse.
- """
-
- force_stop = request.force_stop
- response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK)
- if not force_stop.package_name:
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = '`force_stop.package_name` cannot be empty.'
- return response
-
- response, _ = self._execute_command(
- ['shell', 'am', 'force-stop', force_stop.package_name], timeout)
-
- return response
-
- def _fetch_current_task_id(
- self, full_activity_name: str, timeout: float | None = None
- ) -> int:
- """Returns the task ID of the given `full_activity_name`.
-
- Args:
- full_activity_name: The full name of the activity whose corresponding
- task id we are looking for.
- timeout: Optional time limit in seconds.
- Returns:
- task_id: An integer corresponding to the specified activity.
- """
-
- stack = self._adb_controller.execute_command(
- ['shell', 'am', 'stack', 'list'], timeout=timeout)
- lines = stack.decode('utf-8').splitlines()
-
- regex = re.compile(
- r'^\ *taskId=(?P[0-9]*): (?P[^\s]*) .*visible=true'
- r'.*topActivity=ComponentInfo{(?P[^\s]*)}$')
-
- for line in lines:
- match = regex.search(line)
- if match is None:
- continue
-
- current_task_id_str = match.group('id')
- base_activity = match.group('base_activity')
- top_activity = match.group('top_activity')
-
- # If neither of the matched activities equals the activity we are
- # looking for, we discard their task id and continue the search.
- if full_activity_name not in {base_activity, top_activity}:
- logging.info('Full activity %s was not found in current line %s',
- full_activity_name, line)
- continue
-
- # Otherwise return the integer task id.
- try:
- return int(current_task_id_str)
- except ValueError:
- logging.info('Failed to parse task ID [%r].', current_task_id_str)
-
- # At this point if we could not find a task ID, there's nothing we can do.
- logging.error('Could not find current activity in stack list: %r', lines)
- return -1
-
- def _start_screen_pinning(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Pins an application.
-
- Args:
- request: The request containing the activity to pin.
- timeout: Optional time limit in seconds.
-
- Returns:
- An AdbResponse.
- """
-
- full_activity = request.start_screen_pinning.full_activity
- response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK)
- if not full_activity:
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = (
- '`start_screen_pinning.full_activity` cannot be empty.')
- return response
-
- current_task_id = self._fetch_current_task_id(full_activity, timeout)
- if current_task_id == -1:
- response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR
- response.error_message = ('Could not find task ID for activity '
- f'[{full_activity}]')
- return response
-
- response, _ = self._execute_command(
- ['shell', 'am', 'task', 'lock',
- str(current_task_id)], timeout=timeout)
-
- return response
-
- def _send_broadcast(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Sends a broadcast.
-
- Args:
- request: The request with the information for the broadcast event.
- timeout: Optional time limit in seconds.
-
- Returns:
- An AdbResponse.
- """
-
- send_broadcast = request.send_broadcast
- response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK)
- if not send_broadcast.action:
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = ('`send_broadcast.{action}` cannot be empty.')
- return response
-
- if send_broadcast.component:
- component_args = ['-n', send_broadcast.component]
- else:
- component_args = []
-
- response, _ = self._execute_command(
- ['shell', 'am', 'broadcast', '-a', send_broadcast.action]
- + component_args,
- timeout=timeout,
- )
-
- return response
-
- def _install_apk(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Installs an app given its local path in the filesystem.
-
- Args:
- request: The external request with an install_apk field.
- Contains information for the .apk installation.
- timeout: Optional time limit in seconds.
-
- Returns:
- An AdbResponse.
- """
-
- install_apk = request.install_apk
- response = adb_pb2.AdbResponse()
- location_type = install_apk.WhichOneof('location')
- logging.info('location_type: %s', location_type)
-
- match location_type:
- case 'filesystem':
- fpath = install_apk.filesystem.path
- if not os.path.exists(fpath):
- response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR
- response.error_message = f'Could not find local_apk_path: {fpath}'
- return response
-
- response, _ = self._execute_command(
- ['install', '-r', '-t', '-g', fpath], timeout=timeout
- )
- case 'blob':
-
- # `delete_on_close` was only added in Python 3.12 so we add a switch
- # here to still support previous Python versions.
- if sys.version_info >= (3, 12):
- kwargs = {'suffix': '.apk', 'delete_on_close': False}
- else:
- kwargs = {'suffix': '.apk'}
-
- with tempfile.NamedTemporaryFile(**kwargs) as f:
- fpath = f.name
- f.write(install_apk.blob.contents)
-
- response, _ = self._execute_command(
- ['install', '-r', '-t', '-g', fpath], timeout=timeout
- )
- case _:
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = (
- f'Unsupported `install_apk.location` type: {location_type}'
- )
- return response
-
- return response
-
- def _start_activity(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Starts a given activity.
-
- Options for `start_activity`:
- `am start` command options:
- -D: enable debugging
- -W: wait for launch to complete
- --start-profiler : start profiler and send results to
- -P : like above, but profiling stops when app goes idle
- -R: repeat the activity launch times. Prior to each repeat,
- the top activity will be finished.
- -S: force stop the target app before starting the activity
- --opengl-trace: enable tracing of OpenGL functions
-
- Args:
- request: The request with information on what activity to start.
- timeout: Optional time limit in seconds.
-
- Returns:
- An AdbResponse. If successful, StartActivityResponse will contain the
- activity name and adb command output.
- """
-
- activity = request.start_activity.full_activity
- if not activity:
- return adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
- error_message='`start_activity.full_activity` cannot be empty.')
-
- force_stop = '-S' if request.start_activity.force_stop else ''
- response, command_output = self._execute_command(
- ['shell', 'am', 'start', force_stop, '-W', '-n', activity] +
- list(request.start_activity.extra_args or []),
- timeout=timeout)
-
- # Check command output for potential errors.
- expected_error = re.compile(r""".*Error.*""", re.VERBOSE)
- if expected_error.match(str(command_output)):
- return adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.INTERNAL_ERROR,
- error_message=f'start_activity failed with error: {command_output}')
-
- response.start_activity.full_activity = activity
- response.start_activity.output = command_output
- return response
-
- def _press_button(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Presses a keyboard key.
-
- Args:
- request: The request with information on what button to press.
- timeout: Optional time limit in seconds.
-
- Returns:
- An AdbResponse.
- """
-
- button = request.press_button.button
- if button not in _BUTTON_TO_KEYCODE:
- return adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
- error_message=('PressButton.button must be one of '
- f'[{_BUTTON_TO_KEYCODE.keys()}]. '
- f'Got: {button}. Please see `adb.proto`.'))
-
- keycode = _BUTTON_TO_KEYCODE[button]
- response, command_output = self._execute_command(
- ['shell', 'input', 'keyevent', keycode], timeout=timeout)
- response.press_button.output = command_output
- return response
-
- def _handle_uninstall_package(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Handles UninstallPackage messages.
-
- Args:
- request: The specification of what to uninstall.
- timeout: Optional time limit in seconds.
-
- Returns:
- An AdbResponse
- """
-
- package_name = request.uninstall_package.package_name
- response = adb_pb2.AdbResponse()
- # Every UninstallPackage should have a package_name.
- if not package_name:
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = (
- '`uninstall_package.package_name` cannot be empty.')
- return response
-
- # Get list of installed packages and issue an uninstall only if it's
- # already installed.
- package_response = self._handle_package_manager(
- adb_pb2.AdbRequest(
- package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
- list=adb_pb2.AdbRequest.PackageManagerRequest.List(
- packages=adb_pb2.AdbRequest.PackageManagerRequest.List
- .Packages()))))
- if package_name in package_response.package_manager.list.items:
- response, _ = self._execute_command(['uninstall', package_name], timeout)
- else:
- msg = (f'Cannot uninstall {package_name} since it is not installed.')
- logging.warning(msg)
- response.error_message = msg
-
- return response
-
- def _get_current_activity(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Fetches current activity.
-
- Args:
- request: The request with the `.get_current_activity` field set. This is
- unused, but it's in the signature so that all calls are uniform.
- timeout: Optional time limit in seconds.
-
- Returns:
- AdbResponse containing the current activity.
- """
-
- del request # Unused.
-
- response, visible_task = self._execute_command(
- ['shell', 'am', 'stack', 'list', '|', 'grep', '-E', 'visible=true'],
- timeout=timeout)
-
- if response.status != adb_pb2.AdbResponse.Status.OK:
- return response
-
- if not visible_task:
- _, am_stack_list = self._execute_command(['shell', 'am', 'stack', 'list'],
- timeout=timeout)
- response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR
- response.error_message = ('Empty visible_task. `am stack list`: '
- f'{am_stack_list}')
- return response
-
- visible_task = visible_task.decode('utf-8')
- if sys.platform == 'win32':
- visible_task_list = re.findall(
- r'visible=true topActivity=ComponentInfo{(.+?)}', visible_task)
- if not visible_task_list:
- visible_task = ''
- else:
- visible_task = 'ComponentInfo{' + visible_task_list[0] + '}'
-
- p = re.compile(r'.*\{(.*)\}')
- matches = p.search(visible_task)
- if matches is None:
- _, am_stack_list = self._execute_command(['shell', 'am', 'stack', 'list'],
- timeout=timeout)
- response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR
- response.error_message = (
- 'Could not extract current activity. Will return nothing. '
- f'`am stack list`: {am_stack_list}')
- return response
-
- response.get_current_activity.full_activity = matches.group(1)
- return response
-
- def _get_orientation(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Fetches current device orientation.
-
- Args:
- request: The request with the `.get_orientation` field set.
- timeout: Optional time limit in seconds.
-
- Returns:
- AdbResponse containing the current device orientation. This is
- unused, but it's in the signature so that all calls are uniform.
- """
-
- del request # Unused.
-
- logging.info('Getting orientation...')
- response = self._handle_dumpsys(
- adb_pb2.AdbRequest(
- dumpsys=adb_pb2.AdbRequest.DumpsysRequest(service='input')),
- timeout=timeout)
- output = response.dumpsys.output
- if not output:
- logging.error('Empty dumpsys output.')
- response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR
- response.error_message = 'Failed to execute `dumpsys input`'
- return response
-
- output = output.decode('utf-8')
- lines = output.split('\n') # Split by lines.
- skip_next = False
- for line in lines:
- # There may be multiple devices in output. An invalid device can be
- # identified by negative PhysicalWidth.
- physical_width = re.match(r'\s+PhysicalWidth:\s+(-?\d+)px', line)
- if physical_width:
- skip_next = int(physical_width.group(1)) < 0
-
- surface_orientation = re.match(
- r'\s+(SurfaceOrientation|InputDeviceOrientation):\s+(\d)', line
- )
-
- if surface_orientation is not None:
- if skip_next:
- continue
- if surface_orientation.re.groups < 2:
- continue
- orientation = surface_orientation.group(2)
- logging.info('Done getting orientation: %r', orientation)
- response.get_orientation.orientation = int(orientation)
- return response
-
- response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR
- response.error_message = (
- 'Could not find SurfaceOrientation/InputDeviceOrientation in dumpsys '
- 'output'
- )
- return response
-
- def _push(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Uploads contents to the device.
-
- Args:
- request: The request with the contents to push to the device.
- timeout: Optional time limit in seconds.
-
- Returns:
- An empty AdbResponse.
- """
-
- path = request.push.path
- if not path:
- return adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
- error_message='Push.path is empty.')
-
- # Create temporary file with `push` contents.
- with tempfile.NamedTemporaryFile(delete=False) as f:
- fname = f.name
- f.write(request.push.content)
- # Issue `adb push` command to upload file.
- logging.info('Uploading %r to %r.', fname, path)
- response, _ = self._execute_command(['push', fname, path], timeout=timeout)
- # Delete it.
- os.remove(fname)
-
- return response
-
- def _pull(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Downloads file content from the device.
-
- Args:
- request: The request with the information on what to get from the device.
- timeout: Optional time limit in seconds.
-
- Returns:
- An AdbResponse with the contents of the specified file.
- """
-
- path = request.pull.path
- if not path:
- return adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
- error_message='Pull.path is empty.')
-
- # Issue `adb pull` command to copy it to a temporary file.
- with tempfile.NamedTemporaryFile(delete=False) as f:
- fname = f.name
- logging.info('Downloading %r to %r.', path, fname)
- response, _ = self._execute_command(['pull', path, fname],
- timeout=timeout)
- # Read the content of the file.
- with open(fname, 'rb') as f:
- response.pull.content = f.read()
- # Delete it.
- os.remove(fname)
-
- return response
-
- def _input_text(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Inserts text as keyboard events.
-
- Args:
- request: The external request.
- timeout: Optional time limit in seconds.
-
- Returns:
- An AdbResponse
- """
-
- text = request.input_text.text
- if not text:
- return adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
- error_message='InputText.text is empty.')
-
- response, _ = self._execute_command(['shell', 'input', 'text', text],
- timeout=timeout)
- return response
-
- def _tap(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Taps the device screen.
-
- Args:
- request: The request with information on where to tap the screen.
- timeout: Optional time limit in seconds.
-
- Returns:
- An AdbResponse
- """
-
- x = request.tap.x
- y = request.tap.y
- # Check for negative coordinates.
- # Notice that zero coordinates are valid coordinates (i.e. the first
- # column/row of the screen).
- if x < 0 or y < 0:
- return adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
- error_message=(
- f'Tap coordinates must be non-negative. Got: {request.tap}.'))
-
- response, _ = self._execute_command(
- ['shell', 'input', 'tap', str(x),
- str(y)], timeout=timeout)
-
- return response
-
- def _handle_settings(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Handles SettingsRequest messages.
-
- Args:
- request: The specification of what to do with settings.
- timeout: Optional time limit in seconds.
-
- Returns:
- An AdbResponse
- """
-
- request = request.settings
- response = adb_pb2.AdbResponse()
- # Every SettingsRequest should have a namespace.
- if request.name_space == adb_pb2.AdbRequest.SettingsRequest.Namespace.UNKNOWN:
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = (
- f'Unknown SettingsRequest.name_space. Got: {request}.')
- return response
-
- namespace = adb_pb2.AdbRequest.SettingsRequest.Namespace.Name(
- request.name_space).lower()
-
- match request.WhichOneof('verb'):
- case 'get':
- get = request.get
- if not get.key:
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = (
- f'Empty SettingsRequest.get.key. Got: {request}.'
- )
- return response
- response, command_output = self._execute_command(
- ['shell', 'settings', 'get', namespace, get.key], timeout=timeout
- )
- response.settings.output = command_output
- case 'put':
- put = request.put
- if not put.key or not put.value:
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = (
- f'Empty SettingsRequest.put key or value. Got: {request}.'
- )
- return response
- response, command_output = self._execute_command(
- ['shell', 'settings', 'put', namespace, put.key, put.value],
- timeout=timeout,
- )
- response.settings.output = command_output
- case 'delete_key':
- delete = request.delete_key
- if not delete.key:
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = (
- f'Empty SettingsRequest.delete_key.key. Got: {request}.'
- )
- return response
- response, command_output = self._execute_command(
- ['shell', 'settings', 'delete', namespace, delete.key],
- timeout=timeout,
- )
- response.settings.output = command_output
- case 'reset':
- reset = request.reset
- # At least one of `package_name` or `mode` should be given.
- if (
- not reset.package_name
- and reset.mode
- == adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNKNOWN
- ):
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = (
- 'At least one of SettingsRequest.reset package_name or mode'
- f' should be given. Got: {request}.'
- )
- return response
-
- mode = adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.Name(
- reset.mode
- ).lower()
- arg = reset.package_name or mode
- response, command_output = self._execute_command(
- ['shell', 'settings', 'reset', namespace, arg], timeout=timeout
- )
- response.settings.output = command_output
- case 'list':
- response, command_output = self._execute_command(
- ['shell', 'settings', 'list', namespace], timeout=timeout
- )
- response.settings.output = command_output
- case _:
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = (
- f'Unknown SettingsRequest.verb. Got: {request}.'
- )
-
- return response
-
- def _handle_generic(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Handles GenericRequest messages.
-
- Args:
- request: The request with the `.generic` field set indicating what `adb`
- shell command to issue
- timeout: Optional time limit in seconds.
-
- Returns:
- An AdbResponse
- """
-
- response, command_output = self._execute_command(
- list(request.generic.args), timeout)
- response.generic.output = command_output
- return response
-
- def _handle_package_manager(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Handles PackageManagerRequest messages.
-
- Args:
- request: The request with the `.package_manager` field set containing the
- sub-commands to issue to `adb pm`.
- timeout: Optional time limit in seconds.
-
- Returns:
- An AdbResponse.
- """
-
- request = request.package_manager
- response = adb_pb2.AdbResponse()
-
- match request.WhichOneof('verb'):
- case 'list':
- what = request.list.WhichOneof('what')
- response, output = self._execute_command(
- ['shell', 'pm', 'list', what], timeout=timeout
- )
-
- if output:
- items = output.decode('utf-8').split()
- # Remove prefix for each item.
- prefix = {
- 'features': 'feature:',
- 'libraries': 'library:',
- 'packages': 'package:',
- }[what]
- items = [x[len(prefix) :] for x in items if x.startswith(prefix)]
- response.package_manager.list.items.extend(items)
- response.package_manager.output = output
- case 'clear':
- package_name = request.clear.package_name
- if not package_name:
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = (
- f'Empty PackageManagerRequest.clear.package_name. Got: {request}.'
- )
- return response
-
- args = ['shell', 'pm', 'clear', package_name]
- if request.clear.user_id:
- args.insert(3, '-f')
- args.insert(4, request.clear.user_id)
- response, response.package_manager.output = self._execute_command(
- args, timeout=timeout
- )
- case 'grant':
- grant = request.grant
- if not grant.package_name:
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = '`grant.package_name` cannot be empty.'
- return response
-
- if not grant.permissions:
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = '`grant.permissions` cannot be empty.'
- return response
-
- for permission in grant.permissions:
- logging.info('Granting permission: %r', permission)
- response, response.package_manager.output = self._execute_command(
- ['shell', 'pm', 'grant', grant.package_name, permission],
- timeout=timeout,
- )
-
- return response
-
- def _handle_dumpsys(
- self, request: adb_pb2.AdbRequest, timeout: float | None = None
- ) -> adb_pb2.AdbResponse:
- """Handles DumpsysRequest messages.
-
- Args:
- request: The request with the `.dumpsys` field set containing
- sub-commands to `adb dumpsys` shell command..
- timeout: Optional time limit in seconds.
-
- Returns:
- An AdbResponse.
- """
-
- request = request.dumpsys
- cmd = ['shell', 'dumpsys']
-
- if request.timeout_sec < 0 or request.timeout_ms < 0:
- response = adb_pb2.AdbResponse()
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = (
- 'DumpsysRequest.timeout_{sec, ms} should be non-negative. '
- f'Got: {request}.')
- return response
-
- if request.list_only:
- # `-l` cannot be combined with the following options.
- if request.service or request.args or request.skip_services:
- response = adb_pb2.AdbResponse()
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = (
- 'DumpsysRequest.list_only cannot be combined with other options. '
- f'Got: {request}.')
- return response
-
- cmd.append('-l')
-
- if request.timeout_sec > 0:
- cmd.append('-t')
- cmd.append(str(request.timeout_sec))
- elif request.timeout_ms > 0:
- cmd.append('-T')
- cmd.append(str(request.timeout_ms))
-
- if request.priority != adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.UNSET:
- cmd.append('--priority')
- cmd.append(adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.Name(
- request.priority))
-
- if request.skip_services:
- if request.service:
- response = adb_pb2.AdbResponse()
- response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- response.error_message = (
- 'DumpsysRequest.skip_services cannot be combined with `service`. '
- f'Got: {request}.')
- return response
-
- cmd.append('--skip')
- cmd.append(','.join(request.skip_services))
-
- if request.service:
- cmd.append(request.service)
-
- if request.args:
- cmd += list(request.args)
-
- if request.proto:
- cmd.append('--proto')
-
- response, response.dumpsys.output = self._execute_command(
- cmd, timeout=timeout)
-
- return response
diff --git a/android_env/components/adb_call_parser_test.py b/android_env/components/adb_call_parser_test.py
deleted file mode 100644
index 7b178d71..00000000
--- a/android_env/components/adb_call_parser_test.py
+++ /dev/null
@@ -1,1215 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import builtins
-import os
-import subprocess
-import sys
-import tempfile
-from unittest import mock
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from android_env.components import adb_call_parser
-from android_env.components import adb_controller
-from android_env.proto import adb_pb2
-
-
-class AdbCallParserTest(parameterized.TestCase):
-
- def test_unknown_command(self):
- """Gets UNKNOWN_COMMAND for an empty request."""
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- response = parser.parse(request)
- self.assertEqual(
- response.status, adb_pb2.AdbResponse.Status.UNKNOWN_COMMAND
- )
-
- def test_invalid_timeout(self):
- """AdbRequest.timeout_sec must be positive."""
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.tap.x = 123
- request.timeout_sec = -5
- response = parser.parse(request)
- self.assertEqual(
- response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- )
-
- @mock.patch.object(os.path, 'exists', autospec=True)
- def test_install_apk_file_not_found(self, mock_exists):
- """Should fail installing APK when it is not found."""
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.install_apk.filesystem.path = '/my/home/game.apk'
- mock_exists.return_value = False
-
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR)
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_not_called()
-
- @mock.patch.object(os.path, 'exists', autospec=True)
- def test_install_apk_successful(self, mock_exists):
- """Should succeed installing an arbitrary APK."""
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.install_apk.filesystem.path = '/my/home/game.apk'
- mock_exists.return_value = True
-
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(
- ['install', '-r', '-t', '-g', '/my/home/game.apk'], None)
-
- @mock.patch.object(tempfile, 'NamedTemporaryFile', autospec=True)
- def test_install_apk_from_blob(self, mock_tempfile):
- """Should succeed installing APK from blob."""
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- blob_content = b'A fake blob content'
- request.install_apk.blob.contents = blob_content
- mock_tempfile.return_value.__enter__.return_value.name = '/my/home/test.apk'
- mock_tempfile.return_value.__enter__.return_value.write.return_value = None
-
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(
- ['install', '-r', '-t', '-g', '/my/home/test.apk'], None
- )
- # pytype: disable=attribute-error
- expected_tempfile_kwargs = (
- {'suffix': '.apk', 'delete_on_close': False}
- if sys.version_info > (3, 12)
- else {'suffix': '.apk'}
- )
- mock_tempfile.assert_has_calls([
- mock.call(**expected_tempfile_kwargs), # Constructor
- mock.call().__enter__(), # Enter context
- mock.call().__enter__().write(blob_content), # Call write function
- mock.call().__exit__(None, None, None), # Exit context
- ])
- # pytype: enable=attribute-error
-
- def test_start_activity_empty_full_activity(self):
- """A start_activity command should always have a nonempty activity."""
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.start_activity.extra_args.extend(['blah'])
- response = parser.parse(request)
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
- self.assertNotEmpty(response.error_message)
-
- def test_start_activity_successful(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- command_output = (b'Stopping: my.project.SplashActivity\n'
- b'Starting: Intent { cmp=my.project.SplashActivity }\n')
- adb.execute_command.return_value = command_output
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.start_activity.full_activity = 'my.project.SplashActivity'
- request.start_activity.extra_args.extend(['blah'])
- request.start_activity.force_stop = True
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_has_calls([
- mock.call([
- 'shell', 'am', 'start', '-S', '-W', '-n',
- 'my.project.SplashActivity', 'blah'
- ],
- timeout=None),
- ])
-
- def test_start_activity_successful_no_force_stop(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- command_output = (b'Stopping: my.project.SplashActivity\n'
- b'Starting: Intent { cmp=my.project.SplashActivity }\n')
- adb.execute_command.return_value = command_output
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.start_activity.full_activity = 'my.project.SplashActivity'
- request.start_activity.extra_args.extend(['blah'])
- request.start_activity.force_stop = False
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_has_calls([
- mock.call([
- 'shell', 'am', 'start', '', '-W', '-n', 'my.project.SplashActivity',
- 'blah'
- ],
- timeout=None),
- ])
-
- def test_start_activity_error(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- command_output = (b'Stopping: my.project.SplashActivity\n'
- b'Starting: Intent { cmp=my.project.SplashActivity }\n'
- b'Error: Activity not started, unknown error code 101\n')
- adb.execute_command.return_value = command_output
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.start_activity.full_activity = 'my.project.SplashActivity'
- request.start_activity.extra_args.extend(['blah'])
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR)
- self.assertEqual(
- response.error_message,
- f'start_activity failed with error: {str(command_output)}')
-
- def test_force_stop(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.force_stop.package_name = 'my.project'
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(
- ['shell', 'am', 'force-stop', 'my.project'], None)
-
- def test_grant_permissions_empty_package_name(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.package_manager.grant.permissions.extend(['perm1', 'perm2'])
- response = parser.parse(request)
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
- self.assertNotEmpty(response.error_message)
-
- def test_grant_permissions_empty_permissions(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.package_manager.grant.package_name = 'my.project'
- response = parser.parse(request)
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
- self.assertNotEmpty(response.error_message)
-
- def test_grant_permissions_successful(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.package_manager.grant.package_name = 'my.project'
- request.package_manager.grant.permissions.extend(['perm1', 'perm2'])
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_has_calls([
- mock.call(['shell', 'pm', 'grant', 'my.project', 'perm1'], None),
- mock.call(['shell', 'pm', 'grant', 'my.project', 'perm2'], None),
- ])
-
- def test_press_button_invalid_button(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.press_button.button = 99999
- response = parser.parse(request)
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
- self.assertNotEmpty(response.error_message)
-
- def test_press_button_successful(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b''
- parser = adb_call_parser.AdbCallParser(adb)
- # HOME.
- request = adb_pb2.AdbRequest()
- request.press_button.button = adb_pb2.AdbRequest.PressButton.Button.HOME
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_with(
- ['shell', 'input', 'keyevent', 'KEYCODE_HOME'], None)
- # BACK.
- request = adb_pb2.AdbRequest()
- request.press_button.button = adb_pb2.AdbRequest.PressButton.Button.BACK
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_with(
- ['shell', 'input', 'keyevent', 'KEYCODE_BACK'], None)
- # ENTER.
- request = adb_pb2.AdbRequest()
- request.press_button.button = adb_pb2.AdbRequest.PressButton.Button.ENTER
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_with(
- ['shell', 'input', 'keyevent', 'KEYCODE_ENTER'], None)
-
- def test_start_screen_pinning_package_not_found(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = (
- b' taskId=12345: my.project.AnotherActivity visible=true'
- b' topActivity=ComponentInfo{my.project.AnotherActivity}')
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.start_screen_pinning.full_activity = 'my.project.AmazingActivity'
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR)
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(
- ['shell', 'am', 'stack', 'list'], None)
-
- def test_start_screen_pinning_successful(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = (
- b' taskId=12345: my.project.AmazingActivity visible=true'
- b' topActivity=ComponentInfo{my.project.AmazingActivity}')
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.start_screen_pinning.full_activity = 'my.project.AmazingActivity'
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_has_calls([
- mock.call(['shell', 'am', 'stack', 'list'], None),
- mock.call(['shell', 'am', 'task', 'lock', '12345'], None),
- ])
-
- def test_start_screen_pinning_base_activity(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = (
- b' taskId=12345: my.project.MainActivity visible=true'
- b' topActivity=ComponentInfo{my.project.TopActivity}')
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.start_screen_pinning.full_activity = 'my.project.MainActivity'
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_has_calls([
- mock.call(['shell', 'am', 'stack', 'list'], None),
- mock.call(['shell', 'am', 'task', 'lock', '12345'], None),
- ])
-
- def test_start_screen_pinning_top_activity(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = (
- b' taskId=12345: my.project.MainActivity visible=true'
- b' topActivity=ComponentInfo{my.project.TopActivity}')
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.start_screen_pinning.full_activity = 'my.project.TopActivity'
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_has_calls([
- mock.call(['shell', 'am', 'stack', 'list'], None),
- mock.call(['shell', 'am', 'task', 'lock', '12345'], None),
- ])
-
- def test_send_broadcast_empty_action(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- send_broadcast=adb_pb2.AdbRequest.SendBroadcast())
- response = parser.parse(request)
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
- self.assertNotEmpty(response.error_message)
-
- def test_send_broadcast_successful(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.send_broadcast.action = 'SOME-ACTION'
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
-
- def test_send_broadcast_with_component_successful(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.send_broadcast.action = 'SOME-ACTION'
- request.send_broadcast.component = 'SOME-COMPONENT'
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
-
- def test_uninstall_package_empty_package_name(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.uninstall_package.package_name = ''
- response = parser.parse(request)
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
- self.assertNotEmpty(response.error_message)
-
- def test_uninstall_package_successful(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'package:my.package'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest()
- request.uninstall_package.package_name = 'my.package'
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
-
- def test_get_current_activity_no_visible_task(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = None
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity())
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR)
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_has_calls([
- mock.call(
- ['shell', 'am', 'stack', 'list', '|', 'grep', '-E', 'visible=true'],
- None),
- mock.call(['shell', 'am', 'stack', 'list'], None),
- ])
-
- def test_get_orientation_empty_dumpsys(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b''
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- get_orientation=adb_pb2.AdbRequest.GetOrientationRequest())
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR)
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(['shell', 'dumpsys', 'input'],
- None)
-
- def test_get_orientation_invalid_device_no_surface_orientation(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b' PhysicalWidth: -123px'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- get_orientation=adb_pb2.AdbRequest.GetOrientationRequest())
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR)
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(['shell', 'dumpsys', 'input'],
- None)
-
- @parameterized.named_parameters(
- ('rotation_0', b""" SurfaceOrientation: 0""", 0),
- ('rotation_90', b""" SurfaceOrientation: 1""", 1),
- ('rotation_180', b""" SurfaceOrientation: 2""", 2),
- ('rotation_270', b""" SurfaceOrientation: 3""", 3),
- ('rotation_0_new', b""" InputDeviceOrientation: 0""", 0),
- ('rotation_90_new', b""" InputDeviceOrientation: 1""", 1),
- ('rotation_180_new', b""" InputDeviceOrientation: 2""", 2),
- ('rotation_270_new', b""" InputDeviceOrientation: 3""", 3),
- )
- def test_get_orientation_success(
- self, orientation: bytes, expected_orientation: int
- ):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = (
- b"""SomeRandomKey: 12345\n""" + orientation + b"""
- MoreRandomStuff: awesome_value
-"""
- )
-
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- get_orientation=adb_pb2.AdbRequest.GetOrientationRequest())
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- self.assertEqual(response.get_orientation.orientation, expected_orientation)
- adb.execute_command.assert_called_once_with(['shell', 'dumpsys', 'input'],
- None)
-
- def test_get_current_activity_no_matches(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity())
- for platform in ['win32', 'linux']:
- with mock.patch.object(
- sys, 'platform', autospec=True, return_value=platform):
- response = parser.parse(request)
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.INTERNAL_ERROR)
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_has_calls([
- mock.call([
- 'shell', 'am', 'stack', 'list', '|', 'grep', '-E',
- 'visible=true'
- ], None),
- mock.call(['shell', 'am', 'stack', 'list'], None),
- ])
-
- def test_get_current_activity_successful(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'{MyAwesomeActivity}'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity())
- for platform in ['win32', 'linux']:
- with mock.patch.object(
- sys, 'platform', autospec=True, return_value=platform):
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- # `execute_command` will be called once for each platform.
- adb.execute_command.assert_called_with(
- ['shell', 'am', 'stack', 'list', '|', 'grep', '-E', 'visible=true'],
- None)
- self.assertEqual(response.get_current_activity.full_activity,
- 'MyAwesomeActivity')
-
- def test_push_no_path(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- push=adb_pb2.AdbRequest.Push(content=b'Has content but no path'))
- response = parser.parse(request)
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_not_called()
-
- def test_push_successful(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- push=adb_pb2.AdbRequest.Push(
- content=b'My text.', path='/sdcard/my_file.txt'))
-
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once()
- args, kwargs = adb.execute_command.call_args
- self.assertLen(args, 1)
- cmd_args = args[0]
- self.assertLen(cmd_args, 3)
- self.assertEqual(cmd_args[0], 'push')
- self.assertEqual(cmd_args[2], '/sdcard/my_file.txt')
- self.assertIn('timeout', kwargs)
- self.assertIsNone(kwargs['timeout'])
-
- def test_pull_no_path(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(pull=adb_pb2.AdbRequest.Pull())
- response = parser.parse(request)
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_not_called()
-
- @mock.patch.object(builtins, 'open', autospec=True)
- def test_pull_successful(self, mock_open):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- mock_open.return_value.__enter__ = mock_open
- mock_open.return_value.read.return_value = b'S3cR3t. dO nOt TeLl ANYONE'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- pull=adb_pb2.AdbRequest.Pull(path='/sdcard/my_file.txt'))
-
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- self.assertEqual(response.pull.content, b'S3cR3t. dO nOt TeLl ANYONE')
- adb.execute_command.assert_called_once()
- args, kwargs = adb.execute_command.call_args
- self.assertLen(args, 1)
- cmd_args = args[0]
- self.assertLen(cmd_args, 3)
- self.assertEqual(cmd_args[0], 'pull')
- self.assertEqual(cmd_args[1], '/sdcard/my_file.txt')
- self.assertIn('timeout', kwargs)
- self.assertIsNone(kwargs['timeout'])
-
- def test_input_text_no_text(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(input_text=adb_pb2.AdbRequest.InputText())
- response = parser.parse(request)
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_not_called()
-
- def test_input_text_successful(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- input_text=adb_pb2.AdbRequest.InputText(
- text='The Greatest Text of All Time'))
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(
- ['shell', 'input', 'text', 'The Greatest Text of All Time'], None)
-
- @parameterized.named_parameters(
- ('negative_x_and_negative_y',
- adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=-1, y=-1))),
- ('negative_x',
- adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=-1, y=123))),
- ('negative_y',
- adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=456, y=-1))),
- )
- def test_tap_failed(self, request: adb_pb2.AdbRequest):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- response = parser.parse(request)
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_not_called()
-
- def test_tap_successful(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=135, y=246))
- response = parser.parse(request)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(
- ['shell', 'input', 'tap', '135', '246'], None)
-
- @parameterized.named_parameters(
- ('empty_request', adb_pb2.AdbRequest.SettingsRequest()),
- ('no_namespace',
- adb_pb2.AdbRequest.SettingsRequest(
- get=adb_pb2.AdbRequest.SettingsRequest.Get(key='my_key'))),
- ('get_no_key',
- adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
- get=adb_pb2.AdbRequest.SettingsRequest.Get())),
- ('put_no_key',
- adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
- put=adb_pb2.AdbRequest.SettingsRequest.Put())),
- ('put_no_value',
- adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
- put=adb_pb2.AdbRequest.SettingsRequest.Put(key='another_key'))),
- ('delete_no_key',
- adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
- delete_key=adb_pb2.AdbRequest.SettingsRequest.Delete())),
- ('reset_no_package_name_and_no_mode',
- adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
- reset=adb_pb2.AdbRequest.SettingsRequest.Reset())),
- )
- def test_settings_failures(self, request):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(settings=request)
- response = parser.parse(request)
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_not_called()
-
- def test_settings_success_get(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'here it is!'
- parser = adb_call_parser.AdbCallParser(adb)
-
- request = adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
- get=adb_pb2.AdbRequest.SettingsRequest.Get(key='some_key'))
- request = adb_pb2.AdbRequest(settings=request)
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- self.assertEqual(response.settings.output, b'here it is!')
- adb.execute_command.assert_called_once_with(
- ['shell', 'settings', 'get', 'system', 'some_key'], None)
-
- def test_settings_success_put(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'Done for ya!'
- parser = adb_call_parser.AdbCallParser(adb)
-
- request = adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SECURE,
- put=adb_pb2.AdbRequest.SettingsRequest.Put(key='key1', value='val2'))
- request = adb_pb2.AdbRequest(settings=request)
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- self.assertEqual(response.settings.output, b'Done for ya!')
- adb.execute_command.assert_called_once_with(
- ['shell', 'settings', 'put', 'secure', 'key1', 'val2'], None)
-
- def test_settings_success_delete(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'Key deleted.'
- parser = adb_call_parser.AdbCallParser(adb)
-
- request = adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
- delete_key=adb_pb2.AdbRequest.SettingsRequest.Delete(key='useless_key'))
- request = adb_pb2.AdbRequest(settings=request)
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- self.assertEqual(response.settings.output, b'Key deleted.')
- adb.execute_command.assert_called_once_with(
- ['shell', 'settings', 'delete', 'global', 'useless_key'], None)
-
- @parameterized.named_parameters(
- ('mode_untrusted_defaults',
- adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_DEFAULTS, '',
- 'untrusted_defaults'),
- ('mode_untrusted_clear',
- adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_CLEAR, '',
- 'untrusted_clear'),
- ('mode_trusted_defaults',
- adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.TRUSTED_DEFAULTS, '',
- 'trusted_defaults'),
- # If `package_name` is given, it takes precedence over `mode`.
- ('mode_unknown_package_given',
- adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNKNOWN, 'great.package',
- 'great.package'),
- ('mode_untrusted_defaults_package_given',
- adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_DEFAULTS,
- 'great.package', 'great.package'),
- ('mode_untrusted_clear_package_given',
- adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_CLEAR,
- 'great.package', 'great.package'),
- ('mode_trusted_defaults_package_given',
- adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.TRUSTED_DEFAULTS,
- 'great.package', 'great.package'),
- )
- def test_settings_success_reset(self, mode, package_name, expected_arg):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'Pkg reset.'
- parser = adb_call_parser.AdbCallParser(adb)
-
- request = adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
- reset=adb_pb2.AdbRequest.SettingsRequest.Reset(
- package_name=package_name, mode=mode))
- request = adb_pb2.AdbRequest(settings=request)
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- self.assertEqual(response.settings.output, b'Pkg reset.')
- adb.execute_command.assert_called_once_with(
- ['shell', 'settings', 'reset', 'global', expected_arg], None)
-
- def test_settings_success_list(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'volume_ring=5\nvolume_system=7'
- parser = adb_call_parser.AdbCallParser(adb)
-
- request = adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
- list=adb_pb2.AdbRequest.SettingsRequest.List())
- request = adb_pb2.AdbRequest(settings=request)
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- self.assertEqual(response.settings.output,
- b'volume_ring=5\nvolume_system=7')
- adb.execute_command.assert_called_once_with(
- ['shell', 'settings', 'list', 'system'], None)
-
- def test_generic_command(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- expected_output = b'generic_output'
- args = ['shell', 'am', 'broadcast', '-n', 'receiver', '-a', 'action']
- adb.execute_command.return_value = expected_output
- parser = adb_call_parser.AdbCallParser(adb)
-
- generic_request = adb_pb2.AdbRequest.GenericRequest(args=args)
- request = adb_pb2.AdbRequest(generic=generic_request)
- response = parser.parse(request)
-
- self.assertEqual(adb_pb2.AdbResponse.Status.OK, response.status)
- self.assertEmpty(response.error_message)
- self.assertEqual(response.generic.output, expected_output)
- adb.execute_command.assert_called_once_with(args, None)
-
- def test_generic_command_adb_error(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- args = ['shell', 'am', 'broadcast', '-n', 'receiver', '-a', 'action']
- adb.execute_command.side_effect = subprocess.CalledProcessError(
- cmd='cmd', output='adb_error', returncode=-1)
- parser = adb_call_parser.AdbCallParser(adb)
-
- generic_request = adb_pb2.AdbRequest.GenericRequest(args=args)
- request = adb_pb2.AdbRequest(generic=generic_request)
- response = parser.parse(request)
-
- self.assertEqual(adb_pb2.AdbResponse.Status.ADB_ERROR, response.status)
- self.assertEqual('adb_error', response.error_message)
- self.assertEmpty(response.generic.output)
- adb.execute_command.assert_called_once_with(args, None)
-
- def test_generic_command_timeout(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- args = ['shell', 'am', 'broadcast', '-n', 'receiver', '-a', 'action']
- adb.execute_command.side_effect = subprocess.TimeoutExpired(
- cmd='cmd', timeout=10)
- parser = adb_call_parser.AdbCallParser(adb)
-
- generic_request = adb_pb2.AdbRequest.GenericRequest(args=args)
- request = adb_pb2.AdbRequest(generic=generic_request)
- response = parser.parse(request)
-
- self.assertEqual(adb_pb2.AdbResponse.Status.TIMEOUT, response.status)
- self.assertEqual('Timeout', response.error_message)
- self.assertEmpty(response.generic.output)
- adb.execute_command.assert_called_once_with(args, None)
-
- @parameterized.named_parameters(
- ('features',
- adb_pb2.AdbRequest(
- package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
- list=adb_pb2.AdbRequest.PackageManagerRequest.List(
- features=adb_pb2.AdbRequest.PackageManagerRequest.List
- .Features())))),
- ('libraries',
- adb_pb2.AdbRequest(
- package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
- list=adb_pb2.AdbRequest.PackageManagerRequest.List(
- libraries=adb_pb2.AdbRequest.PackageManagerRequest.List
- .Libraries())))),
- ('packages',
- adb_pb2.AdbRequest(
- package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
- list=adb_pb2.AdbRequest.PackageManagerRequest.List(
- packages=adb_pb2.AdbRequest.PackageManagerRequest.List
- .Packages())))),
- )
- def test_package_manager_list_bad_output(self, request):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b"""Something irrelevant."""
- parser = adb_call_parser.AdbCallParser(adb)
- response = parser.parse(request)
- response.package_manager.output = b"""Something irrelevant."""
- self.assertEmpty(response.package_manager.list.items)
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once()
-
- def test_package_manager_list_features(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- output = b"""
-feature:android.hardware.audio.output
-feature:android.hardware.bluetooth
-feature:android.hardware.camera
-feature:android.hardware.fingerprint
-feature:android.software.autofill
-feature:android.software.backup
-feature:android.software.webview
-"""
- adb.execute_command.return_value = output
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
- list=adb_pb2.AdbRequest.PackageManagerRequest.List(
- features=adb_pb2.AdbRequest.PackageManagerRequest.List.Features(
- ))))
- response = parser.parse(request)
- self.assertEqual(response.package_manager.output, output)
- self.assertEqual(response.package_manager.list.items, [
- 'android.hardware.audio.output',
- 'android.hardware.bluetooth',
- 'android.hardware.camera',
- 'android.hardware.fingerprint',
- 'android.software.autofill',
- 'android.software.backup',
- 'android.software.webview',
- ])
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(
- ['shell', 'pm', 'list', 'features'], None)
-
- def test_package_manager_list_libraries(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- output = b"""
-library:android.ext.shared
-library:android.hidl.base-V1.0-java
-library:android.hidl.manager-V1.0-java
-library:android.net.ipsec.ike
-library:android.test.base
-library:android.test.mock
-library:android.test.runner
-library:androidx.window.sidecar
-library:com.android.future.usb.accessory
-library:com.android.location.provider
-library:com.android.media.remotedisplay
-library:com.android.mediadrm.signer
-library:com.android.nfc_extras
-library:com.google.android.gms
-library:com.google.android.trichromelibrary
-library:javax.obex
-library:org.apache.http.legacy
-"""
- adb.execute_command.return_value = output
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
- list=adb_pb2.AdbRequest.PackageManagerRequest.List(
- libraries=adb_pb2.AdbRequest.PackageManagerRequest.List
- .Libraries())))
- response = parser.parse(request)
- self.assertEqual(response.package_manager.output, output)
- self.assertEqual(response.package_manager.list.items, [
- 'android.ext.shared',
- 'android.hidl.base-V1.0-java',
- 'android.hidl.manager-V1.0-java',
- 'android.net.ipsec.ike',
- 'android.test.base',
- 'android.test.mock',
- 'android.test.runner',
- 'androidx.window.sidecar',
- 'com.android.future.usb.accessory',
- 'com.android.location.provider',
- 'com.android.media.remotedisplay',
- 'com.android.mediadrm.signer',
- 'com.android.nfc_extras',
- 'com.google.android.gms',
- 'com.google.android.trichromelibrary',
- 'javax.obex',
- 'org.apache.http.legacy',
- ])
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(
- ['shell', 'pm', 'list', 'libraries'], None)
-
- def test_package_manager_list_packages(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- output = b"""
-package:com.android.phone
-package:com.awesome.company
-package:com.another.great.thingie
-"""
- adb.execute_command.return_value = output
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
- list=adb_pb2.AdbRequest.PackageManagerRequest.List(
- packages=adb_pb2.AdbRequest.PackageManagerRequest.List.Packages(
- ))))
- response = parser.parse(request)
- self.assertEqual(response.package_manager.output, output)
- self.assertEqual(response.package_manager.list.items, [
- 'com.android.phone',
- 'com.awesome.company',
- 'com.another.great.thingie',
- ])
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(
- ['shell', 'pm', 'list', 'packages'], None)
-
- def test_package_manager_clear_no_package_name(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b"""Something irrelevant."""
- parser = adb_call_parser.AdbCallParser(adb)
-
- request = adb_pb2.AdbRequest(
- package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
- clear=adb_pb2.AdbRequest.PackageManagerRequest.Clear(
- package_name='')))
- response = parser.parse(request)
-
- self.assertEmpty(response.package_manager.output)
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_not_called()
-
- def test_package_manager_clear_successful_no_user_id(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b"""Some successful message."""
- parser = adb_call_parser.AdbCallParser(adb)
-
- request = adb_pb2.AdbRequest(
- package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
- clear=adb_pb2.AdbRequest.PackageManagerRequest.Clear(
- package_name='my.package')))
- response = parser.parse(request)
-
- self.assertEqual(response.package_manager.output,
- b"""Some successful message.""")
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(
- ['shell', 'pm', 'clear', 'my.package'], None)
-
- def test_package_manager_clear_successful_with_user_id(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b"""Some successful message."""
- parser = adb_call_parser.AdbCallParser(adb)
-
- request = adb_pb2.AdbRequest(
- package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
- clear=adb_pb2.AdbRequest.PackageManagerRequest.Clear(
- package_name='my.package', user_id='mrawesome')))
- response = parser.parse(request)
-
- self.assertEqual(response.package_manager.output,
- b"""Some successful message.""")
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(
- ['shell', 'pm', 'clear', '-f', 'mrawesome', 'my.package'], None)
-
- def test_dumpsys_empty_request(self):
- """An empty `DumpsysRequest` is a valid request."""
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(dumpsys=adb_pb2.AdbRequest.DumpsysRequest())
-
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(['shell', 'dumpsys'],
- timeout=None)
-
- @parameterized.named_parameters(
- ('negative_timeout_sec',
- adb_pb2.AdbRequest(
- dumpsys=adb_pb2.AdbRequest.DumpsysRequest(timeout_sec=-1))),
- ('negative_timeout_ms',
- adb_pb2.AdbRequest(
- dumpsys=adb_pb2.AdbRequest.DumpsysRequest(timeout_ms=-2))),
- )
- def test_dumpsys_negative_timeouts(self, request):
- """`DumpsysRequest.timeout_{sec, ms}` if passed, should be positive."""
- adb = mock.create_autospec(adb_controller.AdbController)
- parser = adb_call_parser.AdbCallParser(adb)
-
- response = parser.parse(request)
-
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_not_called()
-
- @parameterized.named_parameters(
- ('both_timeouts_zero', 0, 0, ['shell', 'dumpsys']),
- ('sec_takes_precedence_zero', 123, 0, ['shell', 'dumpsys', '-t', '123']),
- ('sec_takes_precedence', 123, 456, ['shell', 'dumpsys', '-t', '123']),
- ('ms_if_no_sec', 0, 456, ['shell', 'dumpsys', '-T', '456']),
- )
- def test_dumpsys_timeout_successful(self, timeout_sec, timeout_ms, expected):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- dumpsys=adb_pb2.AdbRequest.DumpsysRequest(
- timeout_sec=timeout_sec, timeout_ms=timeout_ms))
-
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(expected, timeout=None)
-
- @parameterized.named_parameters(
- ('priority_undefined',
- adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.UNSET,
- ['shell', 'dumpsys']),
- ('priority_normal',
- adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.NORMAL,
- ['shell', 'dumpsys', '--priority', 'NORMAL']),
- ('priority_high', adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.HIGH,
- ['shell', 'dumpsys', '--priority', 'HIGH']),
- ('priority_critical',
- adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.CRITICAL,
- ['shell', 'dumpsys', '--priority', 'CRITICAL']),
- )
- def test_dumpsys_priority_timeout_successful(self, priority, expected):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- dumpsys=adb_pb2.AdbRequest.DumpsysRequest(priority=priority))
-
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(expected, timeout=None)
-
- @parameterized.named_parameters(
- (
- 'window_service',
- adb_pb2.AdbRequest.DumpsysRequest(list_only=True, service='window'),
- ),
- (
- 'arbitrary_args',
- adb_pb2.AdbRequest.DumpsysRequest(
- list_only=True, args=['myoption', 'anotheroption']
- ),
- ),
- (
- 'skip_usb',
- adb_pb2.AdbRequest.DumpsysRequest(
- list_only=True, skip_services=['usb']
- ),
- ),
- )
- def test_dumpsys_list_only_cannot_be_combined(
- self, dumpsys_request: adb_pb2.AdbRequest.DumpsysRequest
- ):
- """When `list_only==True`, the request cannot contain a few fields."""
-
- # Arrange.
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(dumpsys=dumpsys_request)
-
- # Act.
- response = parser.parse(request)
-
- # Assert.
- self.assertEqual(
- response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
- )
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_not_called()
-
- def test_dumpsys_list_only_success(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- dumpsys=adb_pb2.AdbRequest.DumpsysRequest(list_only=True))
-
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(['shell', 'dumpsys', '-l'],
- timeout=None)
-
- def test_dumpsys_skip_services_cannot_combine_with_service(self):
- """When using `DumpsysRequest.skip_service`, it cannot contain `.service`."""
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- dumpsys=adb_pb2.AdbRequest.DumpsysRequest(
- service='wifi', skip_services=['window', 'usb']))
-
- response = parser.parse(request)
-
- self.assertEqual(response.status,
- adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
- self.assertNotEmpty(response.error_message)
- adb.execute_command.assert_not_called()
-
- def test_dumpsys_skip_services(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- dumpsys=adb_pb2.AdbRequest.DumpsysRequest(
- skip_services=['window', 'usb']))
-
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(
- ['shell', 'dumpsys', '--skip', 'window,usb'], timeout=None)
-
- def test_dumpsys_single_service(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- dumpsys=adb_pb2.AdbRequest.DumpsysRequest(service='window'))
-
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(['shell', 'dumpsys', 'window'],
- timeout=None)
-
- def test_dumpsys_single_service_with_args(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'whatever'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- dumpsys=adb_pb2.AdbRequest.DumpsysRequest(
- service='window', args=['arg1', 'arg2']))
-
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(
- ['shell', 'dumpsys', 'window', 'arg1', 'arg2'], timeout=None)
-
- def test_dumpsys_single_service_with_proto(self):
- adb = mock.create_autospec(adb_controller.AdbController)
- adb.execute_command.return_value = b'some binary output'
- parser = adb_call_parser.AdbCallParser(adb)
- request = adb_pb2.AdbRequest(
- dumpsys=adb_pb2.AdbRequest.DumpsysRequest(service='window', proto=True))
-
- response = parser.parse(request)
-
- self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
- self.assertEmpty(response.error_message)
- adb.execute_command.assert_called_once_with(
- ['shell', 'dumpsys', 'window', '--proto'], timeout=None)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/adb_controller.py b/android_env/components/adb_controller.py
deleted file mode 100644
index a8a1b9d3..00000000
--- a/android_env/components/adb_controller.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""A class to manage and control an external ADB process."""
-
-import os
-import subprocess
-import time
-
-from absl import logging
-from android_env.components import config_classes
-from android_env.components import errors
-
-
-class AdbController:
- """Manages communication with adb."""
-
- def __init__(self, config: config_classes.AdbControllerConfig):
- """Instantiates an AdbController object."""
-
- self._config = config
- logging.info('config: %r', self._config)
-
- # Unset problematic environment variables. ADB commands will fail if these
- # are set. They are normally exported by AndroidStudio.
- if 'ANDROID_HOME' in os.environ:
- del os.environ['ANDROID_HOME']
- if 'ANDROID_ADB_SERVER_PORT' in os.environ:
- del os.environ['ANDROID_ADB_SERVER_PORT']
-
- # Explicitly expand the $HOME environment variable.
- self._os_env_vars = dict(os.environ).copy()
- self._os_env_vars.update(
- {'HOME': os.path.expandvars(self._os_env_vars.get('HOME', ''))}
- )
- logging.info('self._os_env_vars: %r', self._os_env_vars)
-
- def command_prefix(self, include_device_name: bool = True) -> list[str]:
- """The command for instantiating an adb client to this server."""
- command_prefix = [
- self._config.adb_path,
- '-P',
- str(self._config.adb_server_port),
- ]
- if include_device_name:
- command_prefix.extend(['-s', self._config.device_name])
- return command_prefix
-
- def init_server(self, timeout: float | None = None):
- """Initialize the ADB server deamon on the given port.
-
- This function should be called immediately after initializing the first
- adb_controller, and before launching the simulator.
-
- Args:
- timeout: A timeout to use for this operation. If not set the default
- timeout set on the constructor will be used.
- """
- # Make an initial device-independent call to ADB to start the deamon.
- self.execute_command(['devices'], timeout, device_specific=False)
- time.sleep(0.2)
-
- def _restart_server(self, timeout: float | None = None):
- """Kills and restarts the adb server.
-
- Args:
- timeout: A timeout to use for this operation. If not set the default
- timeout set on the constructor will be used.
- """
- logging.info('Restarting adb server.')
- self.execute_command(
- ['kill-server'], timeout=timeout, device_specific=False)
- time.sleep(0.2)
- cmd_output = self.execute_command(
- ['start-server'], timeout=timeout, device_specific=False)
- logging.info('start-server output: %r', cmd_output.decode('utf-8'))
- time.sleep(2.0)
- self.execute_command(
- ['devices'], timeout=timeout, device_specific=False)
- time.sleep(0.2)
-
- def execute_command(
- self,
- args: list[str],
- timeout: float | None = None,
- device_specific: bool = True,
- ) -> bytes:
- """Executes an adb command.
-
- Args:
- args: A list of strings representing each adb argument.
- For example: ['install', '/my/app.apk']
- timeout: A timeout to use for this operation. If not set the default
- timeout set on the constructor will be used.
- device_specific: Whether the call is device-specific or independent.
-
- Returns:
- The output of running such command as a binary string.
- """
- timeout = self._config.default_timeout if timeout is None else timeout
- command = self.command_prefix(include_device_name=device_specific) + args
- command_str = 'adb ' + ' '.join(command[1:])
-
- n_retries = 2
- n_tries = 1
- latest_error = None
- while n_tries <= n_retries:
- try:
- logging.info('Executing ADB command: [%s]', command_str)
- cmd_output = subprocess.check_output(
- command,
- stderr=subprocess.STDOUT,
- timeout=timeout,
- env=self._os_env_vars,
- )
- logging.debug('ADB command output: %s', cmd_output)
- return cmd_output
- except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
- logging.exception(
- 'Failed to execute ADB command (try %r of 3): [%s]',
- n_tries, command_str)
- if e.stdout is not None:
- logging.error('**stdout**:')
- for line in e.stdout.splitlines():
- logging.error(' %s', line)
- if e.stderr is not None:
- logging.error('**stderr**:')
- for line in e.stderr.splitlines():
- logging.error(' %s', line)
- n_tries += 1
- latest_error = e
- if device_specific and n_tries <= n_retries:
- self._restart_server(timeout=timeout)
-
- raise errors.AdbControllerError(
- f'Error executing adb command: [{command_str}]\n'
- f'Caused by: {latest_error}\n'
- f'adb stdout: [{latest_error.stdout}]\n'
- f'adb stderr: [{latest_error.stderr}]') from latest_error
diff --git a/android_env/components/adb_controller_test.py b/android_env/components/adb_controller_test.py
deleted file mode 100644
index dddd04d3..00000000
--- a/android_env/components/adb_controller_test.py
+++ /dev/null
@@ -1,234 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.components.adb_controller."""
-
-import os
-import subprocess
-import time
-from unittest import mock
-
-from absl.testing import absltest
-from android_env.components import adb_controller as adb_controller_lib
-from android_env.components import config_classes
-from android_env.components import errors
-
-# Timeout to be used by default in tests below. Set to a small value to avoid
-# hanging on a failed test.
-_TIMEOUT = 2
-
-
-class AdbControllerTest(absltest.TestCase):
-
- def setUp(self):
- super().setUp()
- # Set two env vars.
- os.environ['MY_ENV_VAR'] = '/some/path/'
- os.environ['HOME'] = '$MY_ENV_VAR'
- self._env_before = os.environ
- self._adb_controller = adb_controller_lib.AdbController(
- config_classes.AdbControllerConfig(
- adb_path='my_adb',
- device_name='awesome_device',
- adb_server_port=9999,
- )
- )
-
- @mock.patch.object(subprocess, 'check_output', autospec=True)
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_init_server(self, mock_sleep, mock_check_output):
- # Arrange.
- adb_controller = adb_controller_lib.AdbController(
- config_classes.AdbControllerConfig(
- adb_path='my_adb',
- device_name='awesome_device',
- adb_server_port=9999,
- )
- )
-
- # Act.
- adb_controller.init_server(timeout=_TIMEOUT)
-
- # Assert.
- expected_env = self._env_before
- expected_env['HOME'] = '/some/path/'
- mock_check_output.assert_called_once_with(
- ['my_adb', '-P', '9999', 'devices'],
- stderr=subprocess.STDOUT,
- timeout=_TIMEOUT,
- env=expected_env,
- )
- mock_sleep.assert_called_once()
-
- @mock.patch.object(subprocess, 'check_output', autospec=True)
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_restart_server(self, mock_sleep, mock_check_output):
- # Arrange.
- mock_check_output.side_effect = [
- subprocess.CalledProcessError(returncode=1, cmd='blah'),
- ] + ['fake_output'.encode('utf-8')] * 4
- adb_controller = adb_controller_lib.AdbController(
- config_classes.AdbControllerConfig(
- adb_path='my_adb',
- device_name='awesome_device',
- adb_server_port=9999,
- )
- )
-
- # Act.
- adb_controller.execute_command(['my_command'], timeout=_TIMEOUT)
-
- # Assert.
- expected_env = self._env_before
- expected_env['HOME'] = '/some/path/'
- mock_check_output.assert_has_calls([
- mock.call(
- ['my_adb', '-P', '9999', '-s', 'awesome_device', 'my_command'],
- stderr=subprocess.STDOUT,
- timeout=_TIMEOUT,
- env=expected_env,
- ),
- mock.call(
- ['my_adb', '-P', '9999', 'kill-server'],
- stderr=subprocess.STDOUT,
- timeout=_TIMEOUT,
- env=expected_env,
- ),
- mock.call(
- ['my_adb', '-P', '9999', 'start-server'],
- stderr=subprocess.STDOUT,
- timeout=_TIMEOUT,
- env=expected_env,
- ),
- mock.call(
- ['my_adb', '-P', '9999', 'devices'],
- stderr=subprocess.STDOUT,
- timeout=_TIMEOUT,
- env=expected_env,
- ),
- mock.call(
- ['my_adb', '-P', '9999', '-s', 'awesome_device', 'my_command'],
- stderr=subprocess.STDOUT,
- timeout=_TIMEOUT,
- env=expected_env,
- ),
- ])
- mock_sleep.assert_has_calls(
- [mock.call(0.2), mock.call(2.0), mock.call(0.2)])
-
- @mock.patch.object(subprocess, 'check_output', autospec=True)
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_invalid_command(self, mock_sleep, mock_check_output):
- # Arrange.
- restart_sequence = ['fake_output'.encode('utf-8')] * 3
- mock_check_output.side_effect = (
- [
- subprocess.CalledProcessError(returncode=1, cmd='blah'),
- ]
- + restart_sequence
- + [subprocess.CalledProcessError(returncode=1, cmd='blah')]
- # Don't restart if last call fails.
- )
- adb_controller = adb_controller_lib.AdbController(
- config_classes.AdbControllerConfig(
- adb_path='my_adb',
- device_name='awesome_device',
- adb_server_port=9999,
- )
- )
-
- # Act.
- with self.assertRaises(errors.AdbControllerError):
- adb_controller.execute_command(['my_command'], timeout=_TIMEOUT)
-
- # Assert.
- expected_env = self._env_before
- expected_env['HOME'] = '/some/path/'
- mock_check_output.assert_has_calls(
- [
- mock.call(
- ['my_adb', '-P', '9999', '-s', 'awesome_device', 'my_command'],
- stderr=subprocess.STDOUT,
- timeout=_TIMEOUT,
- env=expected_env,
- ),
- mock.call(
- ['my_adb', '-P', '9999', 'kill-server'],
- stderr=subprocess.STDOUT,
- timeout=_TIMEOUT,
- env=expected_env,
- ),
- mock.call(
- ['my_adb', '-P', '9999', 'start-server'],
- stderr=subprocess.STDOUT,
- timeout=_TIMEOUT,
- env=expected_env,
- ),
- mock.call(
- ['my_adb', '-P', '9999', 'devices'],
- stderr=subprocess.STDOUT,
- timeout=_TIMEOUT,
- env=expected_env,
- ),
- mock.call(
- ['my_adb', '-P', '9999', '-s', 'awesome_device', 'my_command'],
- stderr=subprocess.STDOUT,
- timeout=_TIMEOUT,
- env=expected_env,
- ),
- ],
- any_order=False,
- )
- mock_sleep.assert_has_calls(
- [mock.call(0.2), mock.call(2.0), mock.call(0.2)]
- )
-
- @mock.patch.object(subprocess, 'check_output', autospec=True)
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_avoid_infinite_recursion(self, mock_sleep, mock_check_output):
- del mock_sleep
- mock_check_output.side_effect = subprocess.CalledProcessError(
- returncode=1, cmd='blah')
- adb_controller = adb_controller_lib.AdbController(
- config_classes.AdbControllerConfig(
- adb_path='my_adb',
- device_name='awesome_device',
- adb_server_port=9999,
- )
- )
- self.assertRaises(
- errors.AdbControllerError,
- adb_controller.execute_command, ['my_command'], timeout=_TIMEOUT)
-
-
-class AdbControllerInitTest(absltest.TestCase):
-
- def test_deletes_problem_env_vars(self):
- os.environ['ANDROID_HOME'] = '/usr/local/Android/Sdk'
- os.environ['ANDROID_ADB_SERVER_PORT'] = '1337'
- adb_controller_lib.AdbController(
- config_classes.AdbControllerConfig(
- adb_path='my_adb',
- device_name='awesome_device',
- adb_server_port=9999,
- default_timeout=_TIMEOUT,
- )
- )
- self.assertNotIn('ANDROID_HOME', os.environ)
- self.assertNotIn('ANDROID_ADB_SERVER_PORT', os.environ)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/adb_log_stream.py b/android_env/components/adb_log_stream.py
deleted file mode 100644
index 17ad4eb8..00000000
--- a/android_env/components/adb_log_stream.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Class for a stream of logs output by a locally running emulator."""
-
-import subprocess
-
-from absl import logging
-from android_env.components import log_stream
-
-
-_LOGCAT_COMMAND = ['logcat', '-v', 'epoch']
-
-
-class AdbLogStream(log_stream.LogStream):
- """Manages adb logcat process for a locally running emulator."""
-
- def __init__(self, adb_command_prefix: list[str], verbose: bool = False):
- super().__init__(verbose=verbose)
- self._adb_command_prefix = adb_command_prefix
-
- def _get_stream_output(self):
-
- # Before spawning a long-lived process, we issue `logcat -b all -c` to clear
- # all buffers to avoid interference from previous runs.
- clear_buffer_output = subprocess.check_output(
- self._adb_command_prefix + ['logcat', '-b', 'all', '-c'],
- stderr=subprocess.STDOUT,
- timeout=100)
- logging.info('clear_buffer_output: %r', clear_buffer_output)
- cmd = self._adb_command_prefix + _LOGCAT_COMMAND + self._filters
- self._adb_subprocess = subprocess.Popen(
- cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- bufsize=1,
- universal_newlines=True)
- return self._adb_subprocess.stdout
-
- def stop_stream(self):
- if not hasattr(self, '_adb_subprocess') or self._adb_subprocess is None:
- logging.error('`stop_stream()` called before `get_stream_output()`. '
- 'This violates the `LogStream` API.')
- else:
- self._adb_subprocess.kill()
diff --git a/android_env/components/adb_log_stream_test.py b/android_env/components/adb_log_stream_test.py
deleted file mode 100644
index 98ea2120..00000000
--- a/android_env/components/adb_log_stream_test.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for adb_log_stream."""
-
-import subprocess
-from unittest import mock
-
-from absl.testing import absltest
-from android_env.components import adb_log_stream
-
-
-class FakeAdbSubprocess:
-
- @property
- def stdout(self):
- return [f'line_{i}' for i in range(100)]
-
- def kill(self):
- pass
-
-
-class AdbLogStreamTest(absltest.TestCase):
-
- @mock.patch.object(subprocess, 'check_output', return_value=b'')
- @mock.patch.object(subprocess, 'Popen', return_value=FakeAdbSubprocess())
- def test_get_stream_output(self, mock_popen, unused_mock_check_output):
- stream = adb_log_stream.AdbLogStream(adb_command_prefix=['foo'])
- stream.set_log_filters(['bar'])
- stream_output = stream.get_stream_output()
-
- for i, line in enumerate(stream_output):
- self.assertEqual(line, f'line_{i}')
-
- mock_popen.assert_called_with(
- ['foo', 'logcat', '-v', 'epoch', 'bar', '*:S'],
- stderr=subprocess.STDOUT,
- stdout=subprocess.PIPE,
- bufsize=1,
- universal_newlines=True)
-
- def test_stop_stream_before_get_stream_output(self):
- """Calling `stop_stream()` before `get_stream_output()` should not crash."""
-
- # Arrange.
- stream = adb_log_stream.AdbLogStream(adb_command_prefix=['foo'])
-
- # Act.
- stream.stop_stream()
-
- # Assert.
- # Nothing to assert. The test should just finish without raising an
- # exception.
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/app_screen_checker.py b/android_env/components/app_screen_checker.py
deleted file mode 100644
index 742e481c..00000000
--- a/android_env/components/app_screen_checker.py
+++ /dev/null
@@ -1,269 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Determines if the current app screen matches an expected app screen."""
-
-from collections.abc import Callable, Sequence
-import enum
-import re
-import time
-from typing import Self
-
-from absl import logging
-from android_env.components import adb_call_parser as adb_call_parser_lib
-from android_env.components import errors
-from android_env.proto import adb_pb2
-from android_env.proto import task_pb2
-
-
-class _DumpsysNode:
- """A node in a dumpsys tree."""
-
- def __init__(self, data: str):
- self._children = []
- self._data = data
-
- @property
- def data(self) -> str:
- return self._data
-
- @property
- def children(self) -> list[Self]:
- return self._children
-
- def find_child(
- self, predicate: Callable[[Self], bool], max_levels: int = 0
- ) -> Self | None:
- """Returns the first direct child that matches `predicate`, None otherwise.
-
- Args:
- predicate: Function-like that accepts a _DumpsysNode and returns boolean.
- max_levels: Maximum number of levels down the tree to search for a child.
- If non-positive, only direct children will be searched for.
-
- Returns:
- A _DumpsysNode or None.
- """
- if not self.children:
- return None
-
- try:
- return next(x for x in self.children if predicate(x))
- except StopIteration:
- logging.info('Failed to find child. max_levels: %i.', max_levels)
- # Search children.
- if max_levels:
- for child in self.children:
- child_result = child.find_child(predicate, max_levels - 1)
- if child_result is not None:
- return child_result
-
- return None
-
- def __repr__(self):
- return self._data
-
- def print_tree(self, indent: int = 2):
- """Prints this tree in logging.info()."""
- logging.info(' ' * indent + self.data)
- for c in self.children:
- c.print_tree(indent + 2)
-
-
-def build_tree_from_dumpsys_output(dumpsys_output: str) -> _DumpsysNode:
- """Constructs a tree from a dumpsys string output.
-
- Args:
- dumpsys_output: string Verbatim output from adb dumpsys. The expected format
- is a list where each line is a node and the indentation marks the
- relationship with its parent or sibling.
-
- Returns:
- _DumpsysNode The root of the tree.
- """
- lines = dumpsys_output.split('\n') # Split by lines.
- lines = [x.rstrip(' \r') for x in lines]
- lines = [x for x in lines if len(x)] # Remove empty lines.
-
- root = _DumpsysNode('___root___') # The root of all nodes.
- parents_stack = [root]
- for line in lines:
- stripped_line = line.lstrip(' ')
- indent = len(line) - len(stripped_line) # Number of indent spaces.
- new_node = _DumpsysNode(stripped_line) # Create a node without indentation.
-
- parent = parents_stack.pop()
- if parent.data == '___root___': # The root is an exception for indentation.
- parent_indent = -2
- else:
- parent_indent = (len(parents_stack) - 1) * 2
-
- if indent == parent_indent: # `new_node` is a sibiling.
- parent = parents_stack.pop()
- elif indent < parent_indent: # Indentation reduced (i.e. a block finished)
- num_levels = (indent // 2) + 1
- parents_stack = parents_stack[:num_levels]
- parent = parents_stack.pop()
- elif indent > parent_indent: # `new_node` is a child.
- pass # No need to change the current parent.
-
- parent.children.append(new_node)
- parents_stack.append(parent)
- parents_stack.append(new_node)
-
- return root
-
-
-def matches_path(
- dumpsys_activity_output: str,
- expected_view_hierarchy_path: Sequence[re.Pattern[str]],
- max_levels: int = 0,
-) -> bool:
- """Returns True if the current dumpsys output matches the expected path.
-
- Args:
- dumpsys_activity_output: The output of running `dumpsys activity ...`.
- expected_view_hierarchy_path: [regex] A list of regular expressions to be
- tested at each level of the tree.
- max_levels: How many levels to search from root for View Hierarchy.
-
- Returns:
- True if the dumpsys tree contains one path that matches all regexes.
- """
- root = build_tree_from_dumpsys_output(dumpsys_activity_output)
-
- # Find the View Hierarchy.
- view_hierarchy = root.find_child(
- lambda x: x.data.startswith('View Hierarchy'), max_levels)
- if view_hierarchy is None:
- logging.error(
- 'view_hierarchy is None. Dumpsys activity output: %s. tree: %r',
- str(dumpsys_activity_output), root.print_tree())
- logging.error('Tree root: %s', str(root))
- return False
-
- current_node = view_hierarchy
- for i, regex in enumerate(expected_view_hierarchy_path):
-
- def regex_predicate(node, expr=regex):
- matches = expr.match(node.data)
- return matches is not None
-
- child = current_node.find_child(regex_predicate)
- if child is None:
- logging.error('Mismatched regex (%i, %s). current_node: %s', i,
- regex.pattern, current_node)
- logging.error('Dumpsys activity output: %s', str(dumpsys_activity_output))
- logging.error('Tree root: %s', str(root))
- return False
- else:
- current_node = child
- return True
-
-
-class AppScreenChecker:
- """Checks that the current app screen matches an expected screen."""
-
- class Outcome(enum.IntEnum):
- """Possible return vales from checking the current app screen."""
- # The current app screen matches the expected app screen.
- SUCCESS = 0
- # There's no activity to check.
- EMPTY_EXPECTED_ACTIVITY = 1
- # We were unable to determine the current activity.
- FAILED_ACTIVITY_EXTRACTION = 2
- # The current activity does not match the expected activity.
- UNEXPECTED_ACTIVITY = 3
- # The current view hierarchy does not match the expected view hierarchy.
- UNEXPECTED_VIEW_HIERARCHY = 4
-
- def __init__(self, adb_call_parser: adb_call_parser_lib.AdbCallParser,
- expected_app_screen: task_pb2.AppScreen):
- self._adb_call_parser = adb_call_parser
- self._expected_app_screen = expected_app_screen
- self._expected_activity = expected_app_screen.activity
- self._expected_view_hierarchy_path = [
- re.compile(regex) for regex in expected_app_screen.view_hierarchy_path
- ]
-
- # Return type is AppScreenChecker.Outcome, but pytype doesn't understand that.
- def matches_current_app_screen(self) -> enum.IntEnum:
- """Determines whether the current app screen matches `expected_app_screen`."""
- if not self._expected_activity:
- return AppScreenChecker.Outcome.EMPTY_EXPECTED_ACTIVITY
-
- # Check if we are still on the expected Activity.
- response = self._adb_call_parser.parse(
- adb_pb2.AdbRequest(
- get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity()))
- if response.status != adb_pb2.AdbResponse.OK:
- return AppScreenChecker.Outcome.FAILED_ACTIVITY_EXTRACTION
-
- current_activity = response.get_current_activity.full_activity
- if current_activity != self._expected_activity:
- logging.error('current_activity: %s, expected_activity: %s',
- current_activity, self._expected_activity)
- return AppScreenChecker.Outcome.UNEXPECTED_ACTIVITY
-
- # Extract just the package name from the full activity name.
- package_name = self._expected_activity.split('/')[0]
-
- # Check if we are in the expected view hierarchy path.
- if self._expected_view_hierarchy_path:
- dumpsys_response = self._adb_call_parser.parse(
- adb_pb2.AdbRequest(
- dumpsys=adb_pb2.AdbRequest.DumpsysRequest(
- service='activity', args=[package_name, package_name])))
- if dumpsys_response.status != adb_pb2.AdbResponse.OK:
- return AppScreenChecker.Outcome.FAILED_ACTIVITY_EXTRACTION
-
- if dumpsys_response.dumpsys.output:
- if not matches_path(
- dumpsys_response.dumpsys.output.decode('utf-8'),
- self._expected_view_hierarchy_path,
- max_levels=3):
- return AppScreenChecker.Outcome.UNEXPECTED_VIEW_HIERARCHY
-
- return AppScreenChecker.Outcome.SUCCESS
-
- def wait_for_app_screen(self, timeout_sec: float) -> float:
- """Waits for `self._expected_app_screen` to be the current screen.
-
- Args:
- timeout_sec: Maximum total time to wait for the screen to pop up.
-
- Returns:
- The total amount of time in seconds spent waiting for the screen to pop
- up.
- Raises:
- errors.WaitForAppScreenError if the screen does not pop up within
- `timeout_sec`.
- """
-
- logging.info('Waiting for app screen...')
- start_time = time.time()
- while time.time() - start_time < timeout_sec:
- if self.matches_current_app_screen() == AppScreenChecker.Outcome.SUCCESS:
- wait_time = time.time() - start_time
- logging.info('Successfully waited for app screen in %r seconds: [%r]',
- wait_time, self._expected_app_screen)
- return wait_time
- time.sleep(0.1)
-
- wait_time = time.time() - start_time
- logging.error('Failed to wait for app screen in %r seconds: [%r].',
- wait_time, self._expected_app_screen)
-
- raise errors.WaitForAppScreenError()
diff --git a/android_env/components/app_screen_checker_test.py b/android_env/components/app_screen_checker_test.py
deleted file mode 100644
index d9ec6aeb..00000000
--- a/android_env/components/app_screen_checker_test.py
+++ /dev/null
@@ -1,266 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.components.app_screen_checker."""
-
-import re
-from unittest import mock
-
-from absl.testing import absltest
-from android_env.components import adb_call_parser
-from android_env.components import app_screen_checker
-from android_env.components import errors
-from android_env.proto import adb_pb2
-from android_env.proto import task_pb2
-
-
-def _flatten_tree(
- tree: app_screen_checker._DumpsysNode, flat_tree: list[str], indent: int = 2
-):
- """Appends a list of strings to `flat_tree` from `tree`."""
- flat_tree.append(' ' * indent + tree.data)
- for c in tree.children:
- _flatten_tree(c, flat_tree, indent + 2)
-
-
-class AppScreenCheckerTest(absltest.TestCase):
-
- # Ensures that build_tree_from_dumpsys_output produces a node whose flat
- # representation matches our expectation from an arbitrary hierarchy.
- def test_build_tree_from_dumpsys_output(self):
- dumpsys_output = """
-Queen Elizabeth II
- Charles
- William
- George
- Charlotte
- Louis
- Harry
- Archie
- Anne
- Peter
- Savannah
- Isla
- Zara
- Mia
- Lena
- Andrew
- Beatrice
- Eugenie
- Edward
- Louise
- James
-"""
- tree = app_screen_checker.build_tree_from_dumpsys_output(dumpsys_output)
- flat_tree = []
- _flatten_tree(tree, flat_tree, indent=2)
- self.assertEqual(flat_tree, [
- ' ___root___',
- ' Queen Elizabeth II',
- ' Charles',
- ' William',
- ' George',
- ' Charlotte',
- ' Louis',
- ' Harry',
- ' Archie',
- ' Anne',
- ' Peter',
- ' Savannah',
- ' Isla',
- ' Zara',
- ' Mia',
- ' Lena',
- ' Andrew',
- ' Beatrice',
- ' Eugenie',
- ' Edward',
- ' Louise',
- ' James',
- ])
-
- # Ensures that build_tree_from_dumpsys_output produces a node whose flat
- # representation matches our expectation from an arbitrary hierarchy.
- def test_build_forest_from_dumpsys_output(self):
- dumpsys_output = """
-Tree1
- Branch1
- Leaf1
- Leaf2
- Branch2
- Leaf3
- Leaf4
- Leaf5
-Tree2
- Branch3
- Leaf6
- Leaf7
- Branch4
- Leaf8
- Leaf9
- Leaf10
- Leaf11
-"""
- tree = app_screen_checker.build_tree_from_dumpsys_output(dumpsys_output)
- flat_tree = []
- _flatten_tree(tree, flat_tree, indent=2)
- self.assertEqual(flat_tree, [
- ' ___root___',
- ' Tree1',
- ' Branch1',
- ' Leaf1',
- ' Leaf2',
- ' Branch2',
- ' Leaf3',
- ' Leaf4',
- ' Leaf5',
- ' Tree2',
- ' Branch3',
- ' Leaf6',
- ' Leaf7',
- ' Branch4',
- ' Leaf8',
- ' Leaf9',
- ' Leaf10',
- ' Leaf11',
- ])
-
- def test_no_view_hierarchy_matches_path(self):
- dumpsys_output = """
-TASK
- ACTIVITY
- Missing View Hierarchy
- A
- B
- C
- D
- E
- F
-"""
- expected_path = ['^A$', 'B$']
- expected_view_hierarchy_path = [
- re.compile(regex) for regex in expected_path
- ]
- self.assertFalse(
- app_screen_checker.matches_path(dumpsys_output,
- expected_view_hierarchy_path))
-
- def test_matches_path(self):
- dumpsys_output = """
-TASK
- ACTIVITY
- Some node we don't care
- Blah
-
- View Hierarchy
- Hirohito
- Akihito
- Naruhito
- Aiko
- Fumihito
- Mako
- Kako
- Hisahito
- Masahito
-"""
- expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kako$']
- expected_view_hierarchy_path = [
- re.compile(regex) for regex in expected_path
- ]
- self.assertTrue(
- app_screen_checker.matches_path(
- dumpsys_output, expected_view_hierarchy_path, max_levels=2))
-
- # Also check that the following path does not match anything in the tree.
- expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kenji$']
- expected_view_hierarchy_path = [
- re.compile(regex) for regex in expected_path
- ]
- self.assertFalse(
- app_screen_checker.matches_path(dumpsys_output,
- expected_view_hierarchy_path))
-
- def test_matches_path_one_level_deep(self):
- dumpsys_output = """
-TASK
- ACTIVITY
- Some node we don't care
- Blah
-
- Some intermediate node
- View Hierarchy
- Hirohito
- Akihito
- Naruhito
- Aiko
- Fumihito
- Mako
- Kako
- Hisahito
- Masahito
-"""
- expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kako$']
- expected_view_hierarchy_path = [
- re.compile(regex) for regex in expected_path
- ]
- self.assertTrue(
- app_screen_checker.matches_path(
- dumpsys_output, expected_view_hierarchy_path, max_levels=3))
-
- # Also check that the view hierarchy is not found when searching only grand
- # children of TASK.
- expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kako$']
- expected_view_hierarchy_path = [
- re.compile(regex) for regex in expected_path
- ]
- self.assertFalse(
- app_screen_checker.matches_path(
- dumpsys_output, expected_view_hierarchy_path, max_levels=2))
-
- def test_wait_for_app_screen_zero_timeout(self):
- """Ensures that an exception is raised if the timeout is passed."""
- app_screen = task_pb2.AppScreen(activity='whatever.MyActivity')
- call_parser = mock.create_autospec(adb_call_parser.AdbCallParser)
- screen_checker = app_screen_checker.AppScreenChecker(
- adb_call_parser=call_parser,
- expected_app_screen=app_screen)
- # With a zero timeout, the method should never be able to wait for the
- # screen to pop up and an exception should be raised.
- self.assertRaises(
- errors.WaitForAppScreenError,
- screen_checker.wait_for_app_screen,
- timeout_sec=0.0)
-
- def test_wait_for_app_screen_successful(self):
- """Ensures that with the right conditions, the app screen should pop up."""
- app_screen = task_pb2.AppScreen(activity='my.favorite.AwesomeActivity')
- call_parser = mock.create_autospec(adb_call_parser.AdbCallParser)
- call_parser.parse.return_value = adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.OK,
- get_current_activity=adb_pb2.AdbResponse.GetCurrentActivityResponse(
- full_activity='my.favorite.AwesomeActivity'))
-
- screen_checker = app_screen_checker.AppScreenChecker(
- call_parser, app_screen)
- timeout = 1.0
- wait_time = screen_checker.wait_for_app_screen(timeout_sec=timeout)
-
- # The call should not generate an exception and the return value should be
- # less than the timeout given.
- self.assertLess(wait_time, timeout)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/config_classes.py b/android_env/components/config_classes.py
deleted file mode 100644
index b4ed956f..00000000
--- a/android_env/components/config_classes.py
+++ /dev/null
@@ -1,203 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Dataclass definitions used for instantiating AndroidEnv components."""
-
-import dataclasses
-
-
-@dataclasses.dataclass
-class AdbControllerConfig:
- """Settings for instatiating an `AdbController` instance."""
-
- # Filesystem path to the `adb` binary.
- # NOTE: This must be a full path and must not contain environment variables
- # or user folder shorthands (e.g. `~/some/path/to/adb`) since they will not be
- # expanded internally by AndroidEnv.
- adb_path: str = '~/Android/Sdk/platform-tools/adb'
- # Port for adb server.
- adb_server_port: int = 5037
- # Default timeout in seconds for internal commands.
- default_timeout: float = 120.0
- # Name of the device to communicate with.
- device_name: str = ''
-
-
-@dataclasses.dataclass
-class DeviceSettingsConfig:
- """Config class for DeviceSettings."""
-
- # Whether to show circles on the screen indicating touch position.
- show_touches: bool = True
- # Whether to show blue lines on the screen indicating touch position.
- show_pointer_location: bool = True
- # Whether or not to show the status (top) bar.
- show_status_bar: bool = False
- # Whether or not to show the navigation (bottom) bar.
- show_navigation_bar: bool = False
-
-
-@dataclasses.dataclass
-class CoordinatorConfig:
- """Config class for Coordinator."""
-
- # Number of virtual "fingers" of the agent.
- num_fingers: int = 1
- # Whether to enable keyboard key events.
- enable_key_events: bool = False
- # Time between periodic restarts in minutes. If > 0, will trigger
- # a simulator restart at the beginning of the next episode once the time has
- # been reached.
- periodic_restart_time_min: float = 0.0
- # General Android settings.
- device_settings: DeviceSettingsConfig = dataclasses.field(
- default_factory=DeviceSettingsConfig
- )
-
-
-@dataclasses.dataclass
-class SimulatorConfig:
- """Base class for all simulator configs."""
-
- # If true, the log stream of the simulator will be verbose.
- verbose_logs: bool = False
- # How often to (asynchronously) grab the screenshot from the simulator.
- # If <= 0, stepping the environment blocks on fetching the screenshot (the
- # environment is synchronous).
- interaction_rate_sec: float = 0.0
-
-
-@dataclasses.dataclass
-class EmulatorLauncherConfig:
- """Config class for EmulatorLauncher."""
-
- # NOTE: If `adb_port`, `emulator_console_port` and `grpc_port` are defined
- # (i.e. not all equal to 0), it is assumed that the emulator they point to
- # exists already and EmulatorLauncher will be skipped.
-
- # Filesystem path to the `emulator` binary.
- emulator_path: str = '~/Android/Sdk/emulator/emulator'
- # Filesystem path to the Android SDK root.
- android_sdk_root: str = '~/Android/Sdk'
- # Name of the AVD.
- avd_name: str = ''
- # Local directory for AVDs.
- android_avd_home: str = '~/.android/avd'
- # Name of the snapshot to load.
- snapshot_name: str = ''
- # Path to the KVM device.
- kvm_device: str = '/dev/kvm'
- # Path to directory which will hold temporary files.
- tmp_dir: str = '/tmp/android_env/simulator/'
- # GPU mode override.
- # Please see
- # https://developer.android.com/studio/run/emulator-acceleration#accel-graphics.
- gpu_mode: str = 'swangle_indirect' # Alternative: swiftshader_indirect, host
- # Whether to run in headless mode (i.e. without a graphical window).
- run_headless: bool = True
- # Whether to restrict network access.
- # If True, will disable networking on the device. This option is only
- # available for emulator version > 31.3.9 (June 2022).
- restrict_network: bool = False
- # Whether to set `SHOW_PERF_STATS=1` when launching the emulator to display
- # performance and memory statistics.
- show_perf_stats: bool = False
-
- # ADB port for the Android device.
- adb_port: int = 0
- # Port for telnet communication with the emulator.
- emulator_console_port: int = 0
- # Port for gRPC communication with the emulator.
- grpc_port: int = 0
-
-
-@dataclasses.dataclass
-class EmulatorConfig(SimulatorConfig):
- """Config class for EmulatorSimulator."""
-
- # Configuration for launching the Android Emulator.
- emulator_launcher: EmulatorLauncherConfig = dataclasses.field(
- default_factory=EmulatorLauncherConfig
- )
- # Configuration for talking to adb.
- adb_controller: AdbControllerConfig = dataclasses.field(
- default_factory=AdbControllerConfig
- )
- # Path to file which holds emulator logs. If not provided, it will be
- # determined by the EmulatorLauncher.
- logfile_path: str = ''
- # The number of times to try launching the emulator before rebooting (reboot
- # on the n+1-st try).
- launch_n_times_without_reboot: int = 1
- # The number of times to try launching the emulator before reinstalling
- # (reinstall on the n+1-st try).
- launch_n_times_without_reinstall: int = 2
-
-
-@dataclasses.dataclass
-class FakeSimulatorConfig(SimulatorConfig):
- """Config class for FakeSimulator."""
-
- # The dimensions in pixels of the device screen (HxW).
- screen_dimensions: tuple[int, int] = (0, 0)
-
-
-@dataclasses.dataclass
-class TaskManagerConfig:
- """Config class for TaskManager."""
-
- # If max_bad_states episodes finish in a bad state in a row, restart
- # the simulation.
- max_bad_states: int = 3
- # The frequency to check for the current activity and view hierarchy.
- # The unit is raw observation (i.e. each call to AndroidEnv.step()).
- dumpsys_check_frequency: int = 150
- # The maximum number of tries for extracting the current activity before
- # forcing the episode to restart.
- max_failed_current_activity: int = 10
- # The maximum number of extras elements to store. If this number is exceeded,
- # elements are dropped in the order they were received.
- extras_max_buffer_size: int = 100
-
-
-@dataclasses.dataclass
-class TaskConfig:
- """Base config class for loading tasks."""
-
- # The directory for temporary task-related resources.
- tmp_dir: str = ''
-
-
-@dataclasses.dataclass
-class FilesystemTaskConfig(TaskConfig):
- """Config for protobuf files stored in the local filesystem."""
-
- # Filesystem path to `.binarypb` or `.textproto` protobuf Task.
- path: str = ''
-
-
-@dataclasses.dataclass
-class AndroidEnvConfig:
- """Config class for AndroidEnv."""
-
- # Configs for main components.
- task: TaskConfig = dataclasses.field(default_factory=TaskConfig)
- task_manager: TaskManagerConfig = dataclasses.field(
- default_factory=TaskManagerConfig
- )
- coordinator: CoordinatorConfig = dataclasses.field(
- default_factory=CoordinatorConfig
- )
- simulator: SimulatorConfig = dataclasses.field(default_factory=EmulatorConfig)
diff --git a/android_env/components/coordinator.py b/android_env/components/coordinator.py
deleted file mode 100644
index bca636ec..00000000
--- a/android_env/components/coordinator.py
+++ /dev/null
@@ -1,282 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Coordinator handles interaction between internal components of AndroidEnv."""
-
-import copy
-import time
-from typing import Any
-
-from absl import logging
-from android_env.components import action_fns
-from android_env.components import action_type as action_type_lib
-from android_env.components import adb_call_parser
-from android_env.components import config_classes
-from android_env.components import device_settings as device_settings_lib
-from android_env.components import errors
-from android_env.components import pixel_fns
-from android_env.components import specs
-from android_env.components import task_manager as task_manager_lib
-from android_env.components.simulators import base_simulator
-from android_env.proto import adb_pb2
-import dm_env
-import numpy as np
-
-
-class Coordinator:
- """Handles interaction between internal components of AndroidEnv."""
-
- def __init__(
- self,
- simulator: base_simulator.BaseSimulator,
- task_manager: task_manager_lib.TaskManager,
- device_settings: device_settings_lib.DeviceSettings,
- config: config_classes.CoordinatorConfig | None = None,
- ):
- """Handles communication between AndroidEnv and its components.
-
- Args:
- simulator: A BaseSimulator instance.
- task_manager: The TaskManager, responsible for coordinating RL tasks.
- config: Settings to customize this Coordinator.
- """
- self._simulator = simulator
- self._task_manager = task_manager
- self._config = config or config_classes.CoordinatorConfig()
- self._device_settings = device_settings
- self._adb_call_parser: adb_call_parser.AdbCallParser = None
-
- # Initialize stats.
- self._stats = {
- 'relaunch_count': 0,
- 'relaunch_count_periodic': 0,
- 'relaunch_count_setup_steps': 0,
- 'relaunch_count_reset_steps': 0,
- 'relaunch_count_simulator_launch': 0,
- 'relaunch_count_simulator_reset': 0,
- 'relaunch_count_execute_action': 0,
- 'relaunch_count_fetch_observation': 0,
- 'relaunch_count_update_settings': 0,
- 'failed_task_updates': 0,
- }
-
- # Initialize counters.
- self._simulator_healthy = False
- self._latest_observation_time = 0
- self._simulator_start_time = None
-
- logging.info('Starting the simulator...')
- self._launch_simulator()
-
- def action_spec(self) -> dict[str, dm_env.specs.Array]:
- return specs.base_action_spec(
- num_fingers=self._config.num_fingers,
- enable_key_events=self._config.enable_key_events,
- )
-
- def observation_spec(self) -> dict[str, dm_env.specs.Array]:
- return specs.base_observation_spec(
- height=self._device_settings.screen_height(),
- width=self._device_settings.screen_width(),
- )
-
- def _should_periodic_relaunch(self) -> bool:
- """Checks if it is time to restart the simulator.
-
- If a periodic restart time was specified, the Coordinator will re-launch
- the simulator at regular time intervals. This helps to make sure that the
- simulator is not in a stale state even if the environment has been running
- for a significant amount of time.
-
- Returns:
- Boolean indicating if it is time to restart the simulator.
- """
-
- if self._config.periodic_restart_time_min and self._simulator_start_time:
- sim_alive_time = (time.time() - self._simulator_start_time) / 60.0
- logging.info('Simulator has been running for %f mins', sim_alive_time)
- if sim_alive_time > self._config.periodic_restart_time_min:
- logging.info('Maximum alive time reached. Restarting simulator.')
- self._stats['relaunch_count_periodic'] += 1
- return True
- return False
-
- def _launch_simulator(self, max_retries: int = 3):
- """Launches the simulator.
-
- Sets up the simulator and other task-related settings.
-
- Args:
- max_retries: Number of times to attempt a restart before raising an error.
- """
-
- self._simulator_healthy = False
-
- # Attempt to restart the system a given number of times.
- num_tries = 1
- latest_error = None
- while True:
- if num_tries > max_retries:
- raise errors.TooManyRestartsError(
- 'Maximum number of restart attempts reached.'
- ) from latest_error
- logging.info('Simulator launch attempt %d of %d', num_tries, max_retries)
-
- self._task_manager.stop()
-
- # Launch the simulator.
- self._simulator.launch()
- self._simulator_start_time = time.time()
-
- # From here on, the simulator is assumed to be up and running.
- self._adb_call_parser = self._create_adb_call_parser()
- try:
- self._device_settings.update(self._config.device_settings)
- except errors.AdbControllerError as e:
- logging.exception('device_settings.update() failed.')
- self._stats['relaunch_count_update_settings'] += 1
- self._latest_error = e
- num_tries += 1
- continue
-
- # Start the task.
- self._task_manager.start(
- adb_call_parser_factory=self._create_adb_call_parser,
- log_stream=self._simulator.create_log_stream(),
- )
- try:
- self._task_manager.setup_task()
- except errors.StepCommandError as error:
- logging.exception('Failed to set up the task. Restarting simulator.')
- self._stats['relaunch_count_setup_steps'] += 1
- latest_error = error
- num_tries += 1
- continue
-
- # Restart was successful.
- self._simulator_healthy = True
- self._stats['relaunch_count'] += 1
- break
-
- def _create_adb_call_parser(self):
- """Creates a new AdbCallParser instance."""
- return adb_call_parser.AdbCallParser(
- adb_controller=self._simulator.create_adb_controller()
- )
-
- def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
- return self._adb_call_parser.parse(call)
-
- def rl_reset(self) -> dm_env.TimeStep:
- """Resets the RL episode."""
-
- # Relaunch the simulator if necessary.
- if not self._simulator_healthy or self._should_periodic_relaunch():
- self._launch_simulator()
-
- # Reset counters.
- self._latest_observation_time = 0
- for key in self._stats:
- if key.startswith('episode'):
- self._stats[key] = 0.0
-
- # Execute a lift action before resetting the task.
- if not action_fns.send_action_to_simulator(
- action_fns.lift_all_fingers_action(self._config.num_fingers),
- self._simulator,
- self._device_settings.screen_width(),
- self._device_settings.screen_height(),
- self._config.num_fingers,
- ):
- self._stats['relaunch_count_execute_action'] += 1
- self._simulator_healthy = False
-
- # Reset the task.
- self._task_manager.reset_task()
- self._device_settings.get_orientation()
-
- # Get data from the simulator.
- simulator_signals = self._gather_simulator_signals()
-
- return self._task_manager.rl_reset(simulator_signals)
-
- def rl_step(self, agent_action: dict[str, np.ndarray]) -> dm_env.TimeStep:
- """Executes the selected action and returns a timestep.
-
- Args:
- agent_action: Selected action to perform on the simulated Android device.
- If `agent_action` is `None` it means that this is an RL reset (to start
- a new episode).
-
- Returns:
- An RL timestep.
- """
-
- if not action_fns.send_action_to_simulator(
- agent_action,
- self._simulator,
- self._device_settings.screen_width(),
- self._device_settings.screen_height(),
- self._config.num_fingers,
- ):
- self._stats['relaunch_count_execute_action'] += 1
- self._simulator_healthy = False
-
- # Get data from the simulator.
- try:
- simulator_signals = self._gather_simulator_signals()
- except errors.ReadObservationError:
- logging.exception('Unable to fetch observation. Restarting simulator.')
- self._stats['relaunch_count_fetch_observation'] += 1
- self._simulator_healthy = False
-
- if not self._simulator_healthy:
- return dm_env.truncation(reward=0.0, observation=None)
-
- return self._task_manager.rl_step(simulator_signals)
-
- def _gather_simulator_signals(self) -> dict[str, np.ndarray]:
- """Gathers data from various sources to assemble the RL observation."""
-
- # Get current timestamp and update the delta.
- now = time.time()
- timestamp_delta = (
- 0
- if self._latest_observation_time == 0
- else (now - self._latest_observation_time) * 1e6
- )
- self._latest_observation_time = now
-
- return {
- 'pixels': self._simulator.get_screenshot(),
- 'orientation': self._device_settings.get_orientation(),
- 'timedelta': np.array(timestamp_delta, dtype=np.int64),
- }
-
- def __del__(self):
- self.close()
-
- def stats(self) -> dict[str, Any]:
- """Returns various statistics."""
-
- return copy.deepcopy(self._stats)
-
- def close(self):
- """Cleans up the state of this Coordinator."""
-
- if hasattr(self, '_task_manager'):
- self._task_manager.stop()
- if hasattr(self, '_simulator'):
- self._simulator.close()
diff --git a/android_env/components/coordinator_test.py b/android_env/components/coordinator_test.py
deleted file mode 100644
index 78d3a4fe..00000000
--- a/android_env/components/coordinator_test.py
+++ /dev/null
@@ -1,283 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.components.coordinator."""
-
-import tempfile
-import time
-from unittest import mock
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from android_env.components import action_type
-from android_env.components import adb_call_parser
-from android_env.components import config_classes
-from android_env.components import coordinator as coordinator_lib
-from android_env.components import device_settings as device_settings_lib
-from android_env.components import errors
-from android_env.components import task_manager
-from android_env.components.simulators import base_simulator
-from android_env.proto import adb_pb2
-from android_env.proto import state_pb2
-from android_env.proto import task_pb2
-import dm_env
-import numpy as np
-
-
-class CoordinatorTest(parameterized.TestCase):
-
- def setUp(self):
- super().setUp()
- self.addCleanup(mock.patch.stopall) # Disable previous patches.
-
- self._simulator = mock.create_autospec(base_simulator.BaseSimulator)
- self._random_screenshot = np.random.randint(
- low=0, high=255, size=(800, 600, 3), dtype=np.uint8)
- self._simulator.get_screenshot.return_value = self._random_screenshot
- self._task_manager = mock.create_autospec(task_manager.TaskManager)
- self._adb_call_parser = mock.create_autospec(adb_call_parser.AdbCallParser)
- self.enter_context(
- mock.patch.object(
- adb_call_parser,
- 'AdbCallParser',
- autospec=True,
- return_value=self._adb_call_parser))
- self._coordinator = coordinator_lib.Coordinator(
- simulator=self._simulator,
- task_manager=self._task_manager,
- device_settings=device_settings_lib.DeviceSettings(self._simulator),
- )
-
- def tearDown(self):
- super().tearDown()
- self._coordinator.close()
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_relaunch_simulator(self, unused_mock_sleep):
- relaunch_count = self._coordinator.stats()['relaunch_count']
- self._coordinator._launch_simulator()
- self.assertEqual(self._coordinator.stats()['relaunch_count'],
- relaunch_count + 1)
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_reset(self, unused_mock_sleep):
- """'relaunch_count_execute_action' should be zero if there's no error."""
-
- self._coordinator.rl_reset()
- stats = self._coordinator.stats()
- self.assertIn('relaunch_count_execute_action', stats)
- self.assertEqual(stats['relaunch_count_execute_action'], 0)
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_reset_error_sending_action(self, unused_mock_sleep):
- """'relaunch_count_execute_action' should be positive if there's an error."""
-
- self._simulator.send_touch.side_effect = errors.SendActionError()
- self._coordinator.rl_reset()
- stats = self._coordinator.stats()
- self.assertIn('relaunch_count_execute_action', stats)
- self.assertEqual(stats['relaunch_count_execute_action'], 1)
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_lift_all_fingers(self, unused_mock_sleep):
- self._coordinator = coordinator_lib.Coordinator(
- simulator=self._simulator,
- task_manager=self._task_manager,
- device_settings=device_settings_lib.DeviceSettings(self._simulator),
- config=config_classes.CoordinatorConfig(num_fingers=3),
- )
- self._coordinator.rl_reset()
- expected_actions = [
- # (x, y, is_down, identifier).
- (0, 0, False, 0),
- (0, 0, False, 1),
- (0, 0, False, 2),
- ]
- actual_actions = self._simulator.send_touch.call_args[0][0]
- for actual, expected in zip(actual_actions, expected_actions):
- np.testing.assert_array_equal(actual, expected)
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_process_action(self, unused_mock_sleep):
-
- def fake_rl_step(simulator_signals):
- return dm_env.transition(
- reward=10.0,
- observation={
- 'pixels': simulator_signals['pixels'],
- 'orientation': simulator_signals['orientation'],
- 'timedelta': simulator_signals['timedelta'],
- 'extras': {
- 'extra': [0.0]
- }
- })
-
- self._task_manager.rl_step.side_effect = fake_rl_step
- timestep = self._coordinator.rl_step(
- agent_action={
- 'action_type': np.array(action_type.ActionType.LIFT),
- 'touch_position': np.array([0.5, 0.5]),
- })
- obs = timestep.observation
- self.assertEqual(obs['pixels'].shape, (800, 600, 3))
- np.testing.assert_equal(obs['orientation'],
- np.array([0, 0, 0, 0], dtype=np.uint8))
- self.assertEqual(timestep.reward, 10.0)
- self.assertEqual(obs['extras'], {'extra': [0.0]})
- self.assertFalse(timestep.last())
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_process_action_error(self, unused_mock_sleep):
-
- def fake_rl_step(simulator_signals):
- self.assertFalse(simulator_signals['simulator_healthy'])
- return dm_env.truncation(reward=0.0, observation=None)
-
- self._task_manager.rl_step.side_effect = fake_rl_step
- self._simulator.get_screenshot.side_effect = errors.ReadObservationError()
- timestep = self._coordinator.rl_step(
- agent_action={
- 'action_type': np.array(action_type.ActionType.LIFT),
- 'touch_position': np.array([0.5, 0.5]),
- })
- self.assertIsNone(timestep.observation)
- self.assertEqual(timestep.reward, 0.0)
- self.assertTrue(timestep.last())
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_execute_action_touch(self, unused_mock_sleep):
-
- def fake_rl_step(simulator_signals):
- return dm_env.transition(
- reward=123.0,
- observation={
- 'pixels': simulator_signals['pixels'],
- 'orientation': simulator_signals['orientation'],
- 'timedelta': simulator_signals['timedelta'],
- 'extras': {
- 'extra': [0.0]
- }
- })
-
- self._task_manager.rl_step.side_effect = fake_rl_step
- timestep = self._coordinator.rl_step(
- agent_action={
- 'action_type': np.array(action_type.ActionType.TOUCH),
- 'touch_position': np.array([0.5, 0.5])
- })
- self.assertEqual(timestep.reward, 123.0)
- np.testing.assert_equal(timestep.observation['pixels'],
- self._random_screenshot)
- self._simulator.send_touch.assert_called_once_with([(300, 400, True, 0)])
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_execute_multitouch_action(self, unused_mock_sleep):
- self._coordinator = coordinator_lib.Coordinator(
- simulator=self._simulator,
- task_manager=self._task_manager,
- device_settings=device_settings_lib.DeviceSettings(self._simulator),
- config=config_classes.CoordinatorConfig(num_fingers=3),
- )
-
- def fake_rl_step(simulator_signals):
- return dm_env.transition(
- reward=456.0,
- observation={
- 'pixels': simulator_signals['pixels'],
- 'orientation': simulator_signals['orientation'],
- 'timedelta': simulator_signals['timedelta'],
- 'extras': {
- 'extra': [0.0]
- }
- })
-
- self._task_manager.rl_step.side_effect = fake_rl_step
- action = {
- 'action_type': np.array([action_type.ActionType.TOUCH]),
- 'touch_position': np.array([0.25, 0.75]),
- 'action_type_2': np.array([action_type.ActionType.TOUCH]),
- 'touch_position_2': np.array([0.75, 0.25]),
- 'action_type_3': np.array([action_type.ActionType.LIFT]),
- 'touch_position_3': np.array([0.5, 0.5]),
- }
- timestep = self._coordinator.rl_step(action)
- self._simulator.send_touch.assert_called_once_with([(150, 600, True, 0),
- (450, 200, True, 1),
- (300, 400, False, 2)])
- self.assertEqual(timestep.reward, 456.0)
- np.testing.assert_equal(timestep.observation['pixels'],
- self._random_screenshot)
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_execute_action_repeat(self, unused_mock_sleep):
- def fake_rl_step(simulator_signals):
- return dm_env.transition(
- reward=10.0,
- observation={
- 'pixels': simulator_signals['pixels'],
- 'orientation': simulator_signals['orientation'],
- 'timedelta': simulator_signals['timedelta'],
- 'extras': {
- 'extra': [0.0]
- }
- })
-
- self._task_manager.rl_step.side_effect = fake_rl_step
- timestep = self._coordinator.rl_step(
- {'action_type': np.array(action_type.ActionType.REPEAT)})
- self._simulator.send_touch.assert_not_called()
- np.testing.assert_equal(timestep.observation['pixels'],
- self._random_screenshot)
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_execute_action_error(self, unused_mock_sleep):
- def fake_rl_step(simulator_signals):
- self.assertFalse(simulator_signals['simulator_healthy'])
- return dm_env.truncation(reward=0.0, observation=None)
-
- self._task_manager.rl_step.side_effect = fake_rl_step
- self._simulator.send_touch.side_effect = errors.SendActionError
- timestep = self._coordinator.rl_step({
- 'action_type': np.array(action_type.ActionType.TOUCH),
- 'touch_position': np.array([0.3, 0.8])
- })
- self.assertIsNone(timestep.observation)
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_max_restarts_setup_steps(self, unused_mock_sleep):
- init_fn_call = self._task_manager.setup_task.call_count
- self._task_manager.setup_task.side_effect = errors.StepCommandError
- self.assertRaises(errors.TooManyRestartsError,
- self._coordinator._launch_simulator)
- # The method was called three more times when attempting to relaunch.
- self.assertEqual(init_fn_call + 3,
- self._task_manager.setup_task.call_count)
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_execute_adb_call(self, unused_mock_sleep):
- call = adb_pb2.AdbRequest(
- force_stop=adb_pb2.AdbRequest.ForceStop(package_name='blah'))
- expected_response = adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.OK)
- self._adb_call_parser.parse.side_effect = [expected_response]
-
- response = self._coordinator.execute_adb_call(call)
-
- self.assertEqual(response, expected_response)
- self._adb_call_parser.parse.assert_called_with(call)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/device_settings.py b/android_env/components/device_settings.py
deleted file mode 100644
index 8b5b7a5e..00000000
--- a/android_env/components/device_settings.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Sets and gets some global settings on an Android device."""
-
-from typing import Final
-from unittest import mock
-
-from absl import logging
-from android_env.components import adb_call_parser
-from android_env.components import config_classes
-from android_env.components.simulators import base_simulator
-from android_env.proto import adb_pb2
-import numpy as np
-
-
-# The internal `AdbCallParser` instance is lazily instantiated within
-# `DeviceSettings`. If we make it optional (i.e. `| None`), pytype will think
-# that it could be `None`, requiring either explicit runtime checks or escape
-# hatches in every actual call, even if it's never actually `None` if reached
-# via the public API.
-# The trick here is to create this dummy instance of the right type that's used
-# as a sentinel to indicate that it hasn't been initialized yet.
-_PLACEHOLDER_ADB_CALL_PARSER: Final[adb_call_parser.AdbCallParser] = (
- mock.create_autospec(adb_call_parser.AdbCallParser)
-)
-
-
-class DeviceSettings:
- """An abstraction for general properties and settings of an Android device."""
-
- def __init__(self, simulator: base_simulator.BaseSimulator):
- self._simulator = simulator
- self._adb_call_parser = _PLACEHOLDER_ADB_CALL_PARSER
-
- # The size of the device screen in pixels.
- self._screen_width: int = 0
- self._screen_height: int = 0
- # The device orientation.
- self._orientation = np.zeros(4, dtype=np.uint8)
-
- def update(self, config: config_classes.DeviceSettingsConfig) -> None:
- """Sets the configuration of the device according to `config`."""
-
- if self._adb_call_parser is _PLACEHOLDER_ADB_CALL_PARSER:
- self._adb_call_parser = adb_call_parser.AdbCallParser(
- adb_controller=self._simulator.create_adb_controller()
- )
-
- self._update_screen_size()
- self._set_show_touches(config.show_touches)
- self._set_show_pointer_location(config.show_pointer_location)
- self._set_status_navigation_bars(
- config.show_navigation_bar, config.show_status_bar
- )
-
- def screen_width(self) -> int:
- """The screen width in pixels. Only valid after `update()` is called."""
-
- return self._screen_width
-
- def screen_height(self) -> int:
- """The screen height in pixels. Only valid after `update()` is called."""
-
- return self._screen_height
-
- def get_orientation(self) -> np.ndarray:
- """Returns the device orientation. Please see specs.py for details."""
-
- if self._adb_call_parser is _PLACEHOLDER_ADB_CALL_PARSER:
- self._adb_call_parser = adb_call_parser.AdbCallParser(
- adb_controller=self._simulator.create_adb_controller()
- )
-
- self._update_orientation()
- return self._orientation
-
- def _update_screen_size(self) -> None:
- """Sets the screen size from a screenshot ignoring the color channel."""
-
- screenshot = self._simulator.get_screenshot()
- self._screen_height = screenshot.shape[0]
- self._screen_width = screenshot.shape[1]
-
- def _set_show_touches(self, show: bool) -> None:
- """Whether to display circles indicating the touch position."""
-
- self._adb_call_parser.parse(
- adb_pb2.AdbRequest(
- settings=adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
- put=adb_pb2.AdbRequest.SettingsRequest.Put(
- key='show_touches', value='1' if show else '0'
- ),
- )
- )
- )
-
- def _set_show_pointer_location(self, show: bool) -> None:
- """Whether to display blue lines on the screen indicating touch position."""
-
- self._adb_call_parser.parse(
- adb_pb2.AdbRequest(
- settings=adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
- put=adb_pb2.AdbRequest.SettingsRequest.Put(
- key='pointer_location', value='1' if show else '0'
- ),
- )
- )
- )
-
- def _set_status_navigation_bars(
- self, show_navigation: bool, show_status: bool
- ) -> None:
- """Whether to display the status (top) and navigation (bottom) bars."""
-
- if show_navigation and show_status:
- policy_control_value = 'null*'
- elif show_navigation and not show_status:
- policy_control_value = 'immersive.status=*'
- elif not show_navigation and show_status:
- policy_control_value = 'immersive.navigation=*'
- else:
- policy_control_value = 'immersive.full=*'
-
- self._adb_call_parser.parse(
- adb_pb2.AdbRequest(
- settings=adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
- put=adb_pb2.AdbRequest.SettingsRequest.Put(
- key='policy_control', value=policy_control_value
- ),
- )
- )
- )
-
- def _update_orientation(self) -> None:
- """Updates the current device orientation."""
-
- # Skip fetching the orientation if we already have it.
- if not np.all(self._orientation == np.zeros(4)):
- return
-
- orientation_response = self._adb_call_parser.parse(
- adb_pb2.AdbRequest(
- get_orientation=adb_pb2.AdbRequest.GetOrientationRequest()
- )
- )
- if orientation_response.status != adb_pb2.AdbResponse.Status.OK:
- logging.error('Got bad orientation: %r', orientation_response)
- return
-
- orientation = orientation_response.get_orientation.orientation
- if orientation not in {0, 1, 2, 3}:
- logging.error('Got bad orientation: %r', orientation)
- return
-
- # Transform into one-hot format.
- orientation_onehot = np.zeros([4], dtype=np.uint8)
- orientation_onehot[orientation] = 1
- self._orientation = orientation_onehot
diff --git a/android_env/components/device_settings_test.py b/android_env/components/device_settings_test.py
deleted file mode 100644
index 83eb5f55..00000000
--- a/android_env/components/device_settings_test.py
+++ /dev/null
@@ -1,228 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from unittest import mock
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from android_env.components import config_classes
-from android_env.components import device_settings as device_settings_lib
-from android_env.components.simulators import base_simulator
-import numpy as np
-
-
-class DeviceSettingsTest(parameterized.TestCase):
-
- def test_screen_size_before_update(self):
- """The screen size should be 0x0 without calling `update()`."""
-
- # Arrange.
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- device_settings = device_settings_lib.DeviceSettings(simulator)
-
- # Act.
- height = device_settings.screen_height()
- width = device_settings.screen_width()
-
- # Assert.
- self.assertEqual(height, 0)
- self.assertEqual(width, 0)
-
- def test_screen_size_after_update(self):
- """The screen size should be set after calling `update()`."""
-
- # Arrange.
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- simulator.get_screenshot.return_value = np.random.randint(
- low=0, high=255, size=(123, 456, 3), dtype=np.uint8
- )
- adb_controller = simulator.create_adb_controller.return_value
- adb_controller.execute_command.return_value = b''
- device_settings = device_settings_lib.DeviceSettings(simulator)
-
- # Act.
- device_settings.update(config_classes.DeviceSettingsConfig())
- height = device_settings.screen_height()
- width = device_settings.screen_width()
-
- # Assert.
- self.assertEqual(height, 123)
- self.assertEqual(width, 456)
-
- @parameterized.named_parameters(
- (
- 'show_touches',
- config_classes.DeviceSettingsConfig(show_touches=True),
- mock.call(
- ['shell', 'settings', 'put', 'system', 'show_touches', '1'],
- timeout=None,
- ),
- ),
- (
- 'show_touches_false',
- config_classes.DeviceSettingsConfig(show_touches=False),
- mock.call(
- ['shell', 'settings', 'put', 'system', 'show_touches', '0'],
- timeout=None,
- ),
- ),
- (
- 'show_pointer_location',
- config_classes.DeviceSettingsConfig(show_pointer_location=True),
- mock.call(
- ['shell', 'settings', 'put', 'system', 'pointer_location', '1'],
- timeout=None,
- ),
- ),
- (
- 'show_pointer_location_false',
- config_classes.DeviceSettingsConfig(show_pointer_location=False),
- mock.call(
- ['shell', 'settings', 'put', 'system', 'pointer_location', '0'],
- timeout=None,
- ),
- ),
- (
- 'show_navigation_and_status',
- config_classes.DeviceSettingsConfig(
- show_navigation_bar=True, show_status_bar=True
- ),
- mock.call(
- ['shell', 'settings', 'put', 'global', 'policy_control', 'null*'],
- timeout=None,
- ),
- ),
- (
- 'show_navigation_and_no_status',
- config_classes.DeviceSettingsConfig(
- show_navigation_bar=True, show_status_bar=False
- ),
- mock.call(
- [
- 'shell',
- 'settings',
- 'put',
- 'global',
- 'policy_control',
- 'immersive.status=*',
- ],
- timeout=None,
- ),
- ),
- (
- 'show_no_navigation_and_status',
- config_classes.DeviceSettingsConfig(
- show_navigation_bar=False, show_status_bar=True
- ),
- mock.call(
- [
- 'shell',
- 'settings',
- 'put',
- 'global',
- 'policy_control',
- 'immersive.navigation=*',
- ],
- timeout=None,
- ),
- ),
- (
- 'show_no_navigation_and_no_status',
- config_classes.DeviceSettingsConfig(
- show_navigation_bar=False, show_status_bar=False
- ),
- mock.call(
- [
- 'shell',
- 'settings',
- 'put',
- 'global',
- 'policy_control',
- 'immersive.full=*',
- ],
- timeout=None,
- ),
- ),
- )
- def test_update(
- self, settings: config_classes.DeviceSettingsConfig, expected_call
- ):
- """We expect the right call for each setting."""
-
- # Arrange.
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- adb_controller = simulator.create_adb_controller.return_value
- adb_controller.execute_command.return_value = b''
- device_settings = device_settings_lib.DeviceSettings(simulator)
-
- # Act.
- device_settings.update(settings)
-
- # Assert.
- adb_controller.execute_command.assert_has_calls(
- [expected_call], any_order=True
- )
-
- def test_get_orientation_bad_response(self):
- """The orientation should be unset if the underlying response is bad."""
-
- # Arrange.
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- adb_controller = simulator.create_adb_controller.return_value
- adb_controller.execute_command.return_value = b''
- device_settings = device_settings_lib.DeviceSettings(simulator)
-
- # Act.
- orientation = device_settings.get_orientation()
-
- # Assert.
- np.testing.assert_array_equal(orientation, np.zeros(4))
-
- def test_get_orientation_bad_orientation(self):
- """The orientation should be unset if the underlying orientation is bad."""
-
- # Arrange.
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- adb_controller = simulator.create_adb_controller.return_value
- adb_controller.execute_command.return_value = b' InputDeviceOrientation: 9'
- device_settings = device_settings_lib.DeviceSettings(simulator)
-
- # Act.
- orientation = device_settings.get_orientation()
-
- # Assert.
- np.testing.assert_array_equal(orientation, np.zeros(4))
-
- def test_get_orientation_success(self):
- """Checks that the orientation comes back as expected."""
-
- # Arrange.
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- adb_controller = simulator.create_adb_controller.return_value
- adb_controller.execute_command.return_value = b' InputDeviceOrientation: 3'
- device_settings = device_settings_lib.DeviceSettings(simulator)
-
- # Act.
- orientation = device_settings.get_orientation()
- # The output should be idempotent if the underlying system did not change.
- orientation_again = device_settings.get_orientation()
-
- # Assert.
- np.testing.assert_array_equal(orientation, np.array([0, 0, 0, 1]))
- np.testing.assert_array_equal(orientation, orientation_again)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/dumpsys_thread.py b/android_env/components/dumpsys_thread.py
deleted file mode 100644
index 3466307a..00000000
--- a/android_env/components/dumpsys_thread.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""A ThreadFunction that runs and parses adb dumpsys."""
-
-import concurrent.futures
-
-from absl import logging
-from android_env.components import app_screen_checker as app_screen_checker_lib
-
-_Outcome = app_screen_checker_lib.AppScreenChecker.Outcome
-
-
-class DumpsysThread:
- """A thread that checks if the user is in the expected app screen."""
-
- def __init__(
- self,
- app_screen_checker: app_screen_checker_lib.AppScreenChecker,
- check_frequency: int = 10,
- max_failed_current_activity: int = 10,
- ):
- """Initializes the dumpsys reader thread.
-
- This loops forever checking if the user is in the expected screen dictated
- by `app_screen_checker`. These analyses are too expensive to be in the
- critical path of AndroidEnv::step() so we consume them async from this
- separate thread.
-
- Args:
- app_screen_checker: The class that actually determines if the current
- screen matches the expected screen.
- check_frequency: Integer. We only call dumpsys 1/check_frequency times in
- each iteration of the while loop below.
- max_failed_current_activity: Integer. We try to fetch the current activity
- but sometimes it fails. If it fails more than
- `max_failed_current_activity` consecutive times, we declare that the
- user has exited `expected_activity`.
- """
-
- self._app_screen_checker = app_screen_checker
- self._main_loop_counter = 0
- self._check_frequency = check_frequency
- self._max_failed_activity_extraction = max_failed_current_activity
- self._num_failed_activity_extraction = 0
- self._latest_check: concurrent.futures.Future | None = None
-
- def check_user_exited(self, timeout: float | None = None) -> bool:
- """Returns True if the user is not in the expected screen.
-
- Args:
- timeout: An optional time in seconds to block waiting for the result of
- the (expensive) checking operation. If None, the function will return
- immediately with `False`.
-
- Returns:
- Whether the user of the Android device has exited the expected screen
- determined by `AppScreenChecker` given at __init__().
- """
-
- # Update and check loop_counter against check_frequency.
- self._main_loop_counter += 1
- if (self._check_frequency <= 0 or
- self._main_loop_counter < self._check_frequency):
- return False
- self._main_loop_counter = 0
-
- # If the latest check is None, perform a check and return.
- if self._latest_check is None:
- with concurrent.futures.ThreadPoolExecutor() as executor:
- self._latest_check = executor.submit(self._check_impl)
- return False
-
- # If there's a check in flight, continue only if it's finished.
- if not timeout and not self._latest_check.done():
- return False
-
- v = self._latest_check.result(timeout=timeout)
- self._latest_check = None # Reset the check.
- return v
-
- def _check_impl(self) -> bool:
- """The synchronous implementation of Dumpsys."""
-
- outcome = self._app_screen_checker.matches_current_app_screen()
-
- # We were unable to determine the current activity.
- if outcome == _Outcome.FAILED_ACTIVITY_EXTRACTION:
- self._num_failed_activity_extraction += 1
- logging.info('self._num_failed_activity_extraction: %s',
- self._num_failed_activity_extraction)
- if (self._num_failed_activity_extraction >=
- self._max_failed_activity_extraction):
- logging.error('Maximum number of failed activity extraction reached.')
- self._num_failed_activity_extraction = 0
- return True
- else:
- self._num_failed_activity_extraction = 0
-
- # The current app screen matches all expectations.
- if (outcome == _Outcome.SUCCESS or
- outcome == _Outcome.EMPTY_EXPECTED_ACTIVITY):
- return False
-
- # Player has exited the app. Terminate the episode.
- elif outcome == _Outcome.UNEXPECTED_ACTIVITY:
- return True
-
- # Player has exited the main game. Terminate the episode.
- elif outcome == _Outcome.UNEXPECTED_VIEW_HIERARCHY:
- return True
-
- return False
diff --git a/android_env/components/dumpsys_thread_test.py b/android_env/components/dumpsys_thread_test.py
deleted file mode 100644
index fad76656..00000000
--- a/android_env/components/dumpsys_thread_test.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.components.dumpsys_thread."""
-
-from unittest import mock
-
-from absl.testing import absltest
-from android_env.components import app_screen_checker as screen_checker
-from android_env.components import dumpsys_thread
-
-
-class DumpsysThreadTest(absltest.TestCase):
-
- def setUp(self):
- super().setUp()
- self._app_screen_checker = mock.create_autospec(
- screen_checker.AppScreenChecker)
-
- def test_unexpected_activity(self):
- dumpsys = dumpsys_thread.DumpsysThread(
- app_screen_checker=self._app_screen_checker, check_frequency=1)
- outcome = screen_checker.AppScreenChecker.Outcome.UNEXPECTED_ACTIVITY
- self._app_screen_checker.matches_current_app_screen.return_value = outcome
- # The first time that `check_user_exited()` is called, it'll only trigger
- # the processing, but it should return immediately.
- self.assertFalse(dumpsys.check_user_exited(timeout=1.0))
- # The second time it should then wait for the result.
- self.assertTrue(dumpsys.check_user_exited(timeout=1.0))
-
- def test_unexpected_view_hierarchy(self):
- dumpsys = dumpsys_thread.DumpsysThread(
- app_screen_checker=self._app_screen_checker, check_frequency=1)
- outcome = screen_checker.AppScreenChecker.Outcome.UNEXPECTED_VIEW_HIERARCHY
- self._app_screen_checker.matches_current_app_screen.return_value = outcome
- self.assertFalse(dumpsys.check_user_exited(timeout=1.0))
- self.assertTrue(dumpsys.check_user_exited(timeout=1.0))
-
- def test_success(self):
- dumpsys = dumpsys_thread.DumpsysThread(
- app_screen_checker=self._app_screen_checker, check_frequency=1)
- outcome = screen_checker.AppScreenChecker.Outcome.SUCCESS
- self._app_screen_checker.matches_current_app_screen.return_value = outcome
- self.assertFalse(dumpsys.check_user_exited(timeout=1.0))
- self.assertFalse(dumpsys.check_user_exited(timeout=1.0))
-
- def test_skipped(self):
- dumpsys = dumpsys_thread.DumpsysThread(
- app_screen_checker=self._app_screen_checker, check_frequency=5)
- self._app_screen_checker.matches_current_app_screen.side_effect = [
- screen_checker.AppScreenChecker.Outcome.SUCCESS,
- screen_checker.AppScreenChecker.Outcome.FAILED_ACTIVITY_EXTRACTION
- ]
-
- for _ in range(17):
- self.assertFalse(dumpsys.check_user_exited(timeout=1.0))
-
- # The first 4 calls will hit the early exit from `check_frequency`.
- # The 5th call will trigger the processing (increasing the call count to
- # matches_current_app_screen() by 1), but it should return early.
- # The 10th call will find a result of the previous processing, and it should
- # be SUCCESS.
- # The next 4 calls (11, 12, 13, 14) will hit the early exit from
- # `check_frequency`.
- # The 15th call should trigger the processing again (increasing the call
- # count to matches_current_app_screen() by 1), but it should return early.
- # The next 2 calls (16, 17) will hit the early exit from `check_frequency`.
- # In total there should be only two calls to `matches_current_app_screen()`.
- expected_call_count = 2
- self.assertEqual(
- self._app_screen_checker.matches_current_app_screen.call_count,
- expected_call_count)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/errors.py b/android_env/components/errors.py
deleted file mode 100644
index 50ef8b4f..00000000
--- a/android_env/components/errors.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Definitions of exceptions used by AndroidEnv."""
-
-
-class AndroidEnvError(Exception):
- """Base class for all known errors generated by AndroidEnv."""
-
- # An integer that identifies this class of error.
- # Subclasses should use a different value.
- ERROR_CODE: int = 0
-
-
-class ReadObservationError(AndroidEnvError):
- """When the environment is unable to obtain an observation from a simulator."""
-
- ERROR_CODE = 1
-
-
-class CoordinatorError(AndroidEnvError):
- """Error raised by the Coordinator."""
-
- ERROR_CODE = 2
-
-
-class TooManyRestartsError(CoordinatorError):
- """The number of restarts has exceeded _MAX_RESTART_TRIES."""
-
- ERROR_CODE = 3
-
-
-class AdbControllerError(AndroidEnvError):
- """Errors that can be raised by ADBController."""
-
- ERROR_CODE = 4
-
-
-class SimulatorError(AndroidEnvError):
- """Errors that can be raised by a simulator."""
-
- ERROR_CODE = 5
-
-
-class SendActionError(AndroidEnvError):
- """Raised when action couldn't be sent successfully."""
-
- ERROR_CODE = 6
-
-
-class StepCommandError(AndroidEnvError):
- """Raised when setup step interpreter cannot process a command."""
-
- ERROR_CODE = 7
-
-
-class WaitForAppScreenError(StepCommandError):
- """Raised when the wait_for_app_screen success check is not met."""
-
- ERROR_CODE = 8
-
-
-class CheckInstallError(StepCommandError):
- """Raised when the check_install success check is not met."""
-
- ERROR_CODE = 9
-
-
-def from_code(code: int, msg: str = '') -> AndroidEnvError | None:
- """Returns an AndroidEnvError instance from the given arguments."""
-
- code_to_error = {
- 0: AndroidEnvError,
- 1: ReadObservationError,
- 2: CoordinatorError,
- 3: TooManyRestartsError,
- 4: AdbControllerError,
- 5: SimulatorError,
- 6: SendActionError,
- 7: StepCommandError,
- 8: WaitForAppScreenError,
- 9: CheckInstallError,
- }
-
- if code in code_to_error:
- return code_to_error[code](msg)
diff --git a/android_env/components/errors_test.py b/android_env/components/errors_test.py
deleted file mode 100644
index df0deca6..00000000
--- a/android_env/components/errors_test.py
+++ /dev/null
@@ -1,110 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for errors.py."""
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from android_env.components import errors
-
-
-class ErrorsTest(parameterized.TestCase):
-
- @parameterized.parameters(
- (errors.ReadObservationError, 1),
- (errors.CoordinatorError, 2),
- (errors.TooManyRestartsError, 3),
- (errors.AdbControllerError, 4),
- (errors.SimulatorError, 5),
- (errors.SendActionError, 6),
- (errors.StepCommandError, 7),
- (errors.WaitForAppScreenError, 8),
- (errors.CheckInstallError, 9),
- )
- def test_error_codes(self, error, expected_error_code):
- with self.assertRaises(error) as context:
- raise error()
- self.assertEqual(context.exception.ERROR_CODE, expected_error_code)
-
- def test_error_codes_unique(self):
- error_codes = set()
- errors_list = (
- errors.ReadObservationError,
- errors.CoordinatorError,
- errors.TooManyRestartsError,
- errors.AdbControllerError,
- errors.SimulatorError,
- errors.SendActionError,
- errors.StepCommandError,
- errors.WaitForAppScreenError,
- errors.CheckInstallError,
- )
- for error in errors_list:
- self.assertNotIn(error.ERROR_CODE, error_codes)
- error_codes.add(error.ERROR_CODE)
-
- @parameterized.parameters([
- errors.ReadObservationError(),
- errors.CoordinatorError(),
- errors.TooManyRestartsError(),
- errors.AdbControllerError(),
- errors.SimulatorError(),
- errors.SendActionError(),
- errors.StepCommandError(),
- errors.WaitForAppScreenError(),
- errors.CheckInstallError(),
- ])
- def test_all_errors_are_androidenv_errors(self, error):
- self.assertIsInstance(error, errors.AndroidEnvError)
-
- @parameterized.named_parameters([
- ('less_than_zero', -1),
- # The largest `ERROR_CODE` is currently `CheckInstallError == 10`.
- ('greater_than_all_errors', 10 + 1),
- ('less_than_zero_float', -3.14159265),
- ('greater_than_all_errors_float', 123.456),
- ])
- def test_from_code_unsupported_code(self, code: int):
- """Unsupported errors should raise `RuntimeError`."""
-
- self.assertIsNone(errors.from_code(code))
-
- @parameterized.parameters([
- (-1, None, 'No such error code.'),
- (0, errors.AndroidEnvError, 'hello'),
- (0, errors.AndroidEnvError, ''),
- (1, errors.ReadObservationError, 'Could not read obs.'),
- (2, errors.CoordinatorError, 'Some error'),
- (3, errors.TooManyRestartsError, 'Too many already...'),
- (4, errors.AdbControllerError, 'Some adb error...'),
- (5, errors.SimulatorError, 'Simulator is not coping.'),
- (6, errors.SendActionError, 'Could not send action.'),
- (7, errors.StepCommandError, 'Some issue setting up the task.'),
- (8, errors.WaitForAppScreenError, 'Waited for too long!'),
- (9, errors.CheckInstallError, 'App did not install correctly.'),
- ])
- def test_from_code(self, code: int, expected_class: errors.AndroidEnvError,
- msg: str):
- """`from_code` should produce consistent outputs for known errors."""
-
- error = errors.from_code(code, msg)
- if error is not None:
- self.assertIsInstance(error, expected_class)
- self.assertEqual(error.ERROR_CODE, code)
- self.assertEqual(str(error), msg)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/log_stream.py b/android_env/components/log_stream.py
deleted file mode 100644
index d9a047b5..00000000
--- a/android_env/components/log_stream.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Abstract class for handling a stream of logs from a simulator."""
-
-import abc
-from collections.abc import Generator, Sequence
-import threading
-from absl import logging
-
-
-class LogStream(metaclass=abc.ABCMeta):
- """Manages the stream of logs output by a simulator."""
-
- def __init__(self, verbose: bool = False):
- self._verbose = verbose
- self._filters = []
- self._should_stream = threading.Event()
-
- def get_stream_output(self) -> Generator[str, None, None]:
- """Starts log process and returns the stream of logs."""
- for line in self._get_stream_output():
- if self._verbose:
- logging.info('line: %r', line)
- if self._should_stream.is_set():
- yield line
-
- @abc.abstractmethod
- def _get_stream_output(self):
- """Starts log process and returns the stream of logs."""
- pass
-
- @abc.abstractmethod
- def stop_stream(self) -> None:
- """Terminates the log stream.
-
- NOTE: This should only be called _after_ `get_stream_output()`.
- """
-
- def pause_stream(self) -> None:
- """No lines are yielded while the event is not set."""
- logging.info('Pausing LogStream.')
- self._should_stream.clear()
-
- def resume_stream(self) -> None:
- """The stream will continue yielding lines if the event is set."""
- logging.info('Resuming LogStream.')
- self._should_stream.set()
-
- def set_log_filters(self, log_filters: Sequence[str]):
- """Sets the filters for the log stream."""
- self._filters = list(log_filters) + ['*:S']
diff --git a/android_env/components/log_stream_test.py b/android_env/components/log_stream_test.py
deleted file mode 100644
index 90fd7dcf..00000000
--- a/android_env/components/log_stream_test.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for log_stream."""
-
-from absl.testing import absltest
-from android_env.components import log_stream
-
-
-class FakeLogStream(log_stream.LogStream):
-
- def __init__(self, filter_name: str):
- super().__init__()
- self._filter_name = filter_name
-
- def _get_stream_output(self):
- """Starts a log process and returns the stream of logs."""
- lines = [
- f'{self._filter_name} fake_line_1',
- 'fake_line_2',
- f'{self._filter_name} fake_line_3',
- f'{self._filter_name} fake_line_4',
- 'fake_line_5',
- 'fake_line_6',
- ]
- for line in lines:
- if f'{self._filter_name}:V' in self._filters:
- if self._filter_name in line:
- yield line
- else:
- yield line
-
- def stop_stream(self):
- """Stops the log stream from the simulator."""
- pass
-
-
-class LogStreamTest(absltest.TestCase):
-
- def test_get_stream_output(self):
- filter_name = 'AndroidRLTask'
- stream = FakeLogStream(filter_name=filter_name)
- stream.resume_stream()
- stream_output = stream.get_stream_output()
- expected_lines = [
- f'{filter_name} fake_line_1',
- 'fake_line_2',
- f'{filter_name} fake_line_3',
- f'{filter_name} fake_line_4',
- 'fake_line_5',
- 'fake_line_6',
- ]
- for line, expected_line in zip(stream_output, expected_lines):
- self.assertEqual(line, expected_line)
-
- def test_set_log_filters(self):
- filter_name = 'AndroidRLTask'
- stream = FakeLogStream(filter_name=filter_name)
- stream.set_log_filters([f'{filter_name}:V'])
- stream.resume_stream()
- stream_output = stream.get_stream_output()
- expected_lines = [
- f'{filter_name} fake_line_1',
- f'{filter_name} fake_line_3',
- f'{filter_name} fake_line_4',
- ]
- for line, expected_line in zip(stream_output, expected_lines):
- self.assertEqual(line, expected_line)
-
- def test_pause_resume_stream(self):
- filter_name = 'AndroidRLTask'
- stream = FakeLogStream(filter_name=filter_name)
- stream.resume_stream()
- stream_output = stream.get_stream_output()
- expected_lines = [
- f'{filter_name} fake_line_1',
- 'fake_line_2',
- f'{filter_name} fake_line_3',
- f'{filter_name} fake_line_4',
- 'fake_line_5',
- 'fake_line_6',
- ]
- for line, expected_line in zip(stream_output, expected_lines):
- self.assertEqual(line, expected_line)
- # If the stream is paused, we expect no lines to be yielded.
- stream.pause_stream()
- stream_output = list(stream.get_stream_output())
- self.assertEmpty(stream_output)
- # If the stream is resumed, we expect to see all lines yielded.
- stream.resume_stream()
- stream_output = stream.get_stream_output()
- for line, expected_line in zip(stream_output, expected_lines):
- self.assertEqual(line, expected_line)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/logcat_thread.py b/android_env/components/logcat_thread.py
deleted file mode 100644
index fffadbcb..00000000
--- a/android_env/components/logcat_thread.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""A class that launches a thread to read Android log outputs."""
-
-from collections.abc import Callable
-import dataclasses
-import re
-import threading
-
-from absl import logging
-from android_env.components import log_stream as log_stream_lib
-
-
-@dataclasses.dataclass
-class EventListener:
- """A function that's called when an event is triggered."""
-
- regexp: re.Pattern[str]
- handler_fn: Callable[[re.Pattern[str], re.Match[str]], None]
-
-
-class LogcatThread:
- """Reads log entries in a separate thread."""
-
- def __init__(self, log_stream: log_stream_lib.LogStream):
- """Initializes this LogcatThread with optional filters.
-
- Please see https://developer.android.com/studio/command-line/logcat for more
- info on `logcat`.
-
- Args:
- log_stream: Stream of logs from simulator.
- """
-
- self._log_stream = log_stream
- self._listeners = {}
- self._line_ready = threading.Event()
- self._line_ready.set()
- self._should_stop = threading.Event()
- self._thread = threading.Thread(target=self._process_logs)
- self._thread.daemon = True
- self._thread.start()
-
- def add_event_listener(self, event_listener: EventListener) -> None:
- """Adds `fn` to the list of handlers to call when `event` occurs."""
- event_regexp = event_listener.regexp
- if event_regexp not in self._listeners:
- self._listeners[event_regexp] = []
- self._listeners[event_regexp].append(event_listener.handler_fn)
-
- def remove_event_listener(self, event_listener: EventListener) -> None:
- """Removes `fn` from the list of handlers to call when `event` occurs."""
- event_regexp = event_listener.regexp
- if event_regexp not in self._listeners:
- logging.error('Event: %r is not registered.', event_regexp)
- return
- self._listeners[event_regexp].remove(event_listener.handler_fn)
-
- def line_ready(self) -> threading.Event:
- """Indicates whether all listeners have been notified for a given line."""
- return self._line_ready
-
- def pause(self):
- self._log_stream.pause_stream()
-
- def resume(self):
- self._log_stream.resume_stream()
-
- def kill(self):
- self._should_stop.set()
- self._log_stream.stop_stream()
- self._thread.join(timeout=3.0)
-
- def _process_logs(self) -> None:
- """A loop that runs until `self._should_stop` is set()."""
-
- # pylint: disable=g-line-too-long
- # Format is: "TIME_SEC PID TID PRIORITY TAG: MESSAGE"
- #
- # Example:
- # ' 1553110400.424 5583 5658 D NostalgicRacer: com.google.example.games.nostalgicracer.views.renderers.OpenGLRenderDriver@912fb8.onSurfaceChanged 480x320' #
- # pylint: enable=g-line-too-long
-
- logline_regexp = r"""
- ^ # Beginning of the line.
- [ ]+(?P[0-9]+\.[0-9]+) # Spaces and a float.
- [ ]+(?P[0-9]+) # Spaces and an int.
- [ ]+(?P[0-9]+) # Spaces and an int.
- [ ]+(?P.) # Spaces and any single character.
- [ ]+(?P[^:]*): # Spaces and any char that's not ':'.
- [ ](?P.*)$ # The actual log message.
- """
- logline_re = re.compile(logline_regexp, re.VERBOSE)
-
- for line in self._log_stream.get_stream_output():
- if self._should_stop.is_set():
- break
-
- if not line: # Skip empty lines.
- continue
-
- matches = logline_re.match(line)
- if not matches or len(matches.groups()) != 6:
- continue
-
- # Make sure that values are not read until all listeners are notified.
- self._line_ready.clear()
-
- # We're currently only consuming `message`, but we may use the other
- # fields in the future.
- content = matches.group('message')
- for ev, listeners in self._listeners.items():
- ev_matches = ev.match(content)
- if ev_matches:
- for listener in listeners: # Notify listeners.
- listener(ev, ev_matches)
-
- self._line_ready.set() # Allow consumers to read values.
diff --git a/android_env/components/logcat_thread_test.py b/android_env/components/logcat_thread_test.py
deleted file mode 100644
index fadcbf93..00000000
--- a/android_env/components/logcat_thread_test.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.components.logcat_thread."""
-
-import re
-import threading
-
-from absl.testing import absltest
-from android_env.components import log_stream
-from android_env.components import logcat_thread
-from android_env.proto import task_pb2
-
-
-class FakeStream:
- """This class simulates the logs coming from ADB."""
-
- def __init__(self):
- self._values = []
- self._kill = False
- self._lock = threading.Lock()
-
- def send_value(self, value):
- with self._lock:
- self._values.append(value)
-
- def has_next_value(self):
- return bool(self._values)
-
- def kill(self):
- self._kill = True
-
- def __iter__(self):
- while True:
- if self._kill:
- return
- if not self._values:
- continue
- else:
- with self._lock:
- next_value = self._values.pop(0)
- yield next_value
-
-
-def make_stdout(data):
- """Returns a valid log output with given data as message."""
- return ' 1553110400.424 5583 5658 D Tag: %s' % data
-
-
-class FakeLogStream(log_stream.LogStream):
- """FakeLogStream class that wraps a FakeStream."""
-
- def __init__(self):
- super().__init__(verbose=False)
- self.logs = FakeStream()
- self.stream_is_alive = True
-
- def _get_stream_output(self):
- return self.logs
-
- def stop_stream(self):
- self.stream_is_alive = False
- self.logs.kill()
-
-
-class LogcatThreadTest(absltest.TestCase):
-
- def setUp(self):
- super().setUp()
- self.fake_log_stream = FakeLogStream()
-
- def tearDown(self):
- self.fake_log_stream.stop_stream()
- super().tearDown()
-
- def test_set_filters(self):
- log_parsing_config = task_pb2.LogParsingConfig(filters=['AndroidRLTask:V'])
- self.fake_log_stream.set_log_filters(log_parsing_config.filters)
- _ = logcat_thread.LogcatThread(log_stream=self.fake_log_stream)
- expected_filters = ['AndroidRLTask:V', '*:S']
- self.assertEqual(expected_filters, self.fake_log_stream._filters)
-
- def test_kill(self):
- logcat = logcat_thread.LogcatThread(log_stream=self.fake_log_stream)
- self.assertTrue(self.fake_log_stream.stream_is_alive)
- logcat.kill()
- self.assertFalse(self.fake_log_stream.stream_is_alive)
-
- def test_listeners(self):
- """Ensures that we can wait for a specific message without polling."""
- logcat = logcat_thread.LogcatThread(log_stream=self.fake_log_stream)
- # Start yielding lines from LogStream.
- logcat.resume()
-
- # Set up a listener that modifies an arbitrary state.
- some_state = threading.Event()
-
- def my_handler(event: re.Pattern[str], match: re.Match[str]):
- del event, match
- nonlocal some_state
- some_state.set()
-
- # Create a desired event and hook up the listener.
- my_event = re.compile('Hello world')
- listener = logcat_thread.EventListener(my_event, my_handler)
- logcat.add_event_listener(listener)
- self.fake_log_stream.logs.send_value('Hi there!') # This should not match.
- self.assertFalse(some_state.is_set())
- self.fake_log_stream.logs.send_value(make_stdout('Hello world'))
- some_state.wait(timeout=1.0)
- self.assertTrue(some_state.is_set())
-
- # Waiting for any events should also trigger the listener.
- some_state.clear()
- self.fake_log_stream.logs.send_value(make_stdout('Hello world'))
- some_state.wait(timeout=1.0)
- self.assertTrue(some_state.is_set())
-
- # After removing the listener, it should not be called anymore.
- some_state.clear()
- logcat.remove_event_listener(listener)
- self.fake_log_stream.logs.send_value(make_stdout('Hello world'))
- some_state.wait(timeout=1.0)
- self.assertFalse(some_state.is_set())
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/pixel_fns.py b/android_env/components/pixel_fns.py
deleted file mode 100644
index 65bb2eea..00000000
--- a/android_env/components/pixel_fns.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Utils for AndroidEnv."""
-
-from collections.abc import Sequence
-
-from dm_env import specs
-import numpy as np
-
-
-def touch_position_to_pixel_position(
- touch_position: np.ndarray,
- width_height: Sequence[int],
-) -> tuple[int, int]:
- """Maps touch position in [0,1] to the corresponding pixel on the screen."""
- touch_pixels = (touch_position * width_height).astype(np.int32)
- cap_idx = lambda v, idx_len: min(v, idx_len - 1)
- return tuple(map(cap_idx, touch_pixels, width_height))
-
-
-def transpose_pixels(frame: np.ndarray) -> np.ndarray:
- """Converts image from shape (H, W, C) to (W, H, C) and vice-versa."""
- return np.transpose(frame, axes=(1, 0, 2))
-
-
-def orient_pixels(frame: np.ndarray, orientation: int) -> np.ndarray:
- """Rotates screen pixels according to the given orientation."""
-
- match orientation:
- case 0: # PORTRAIT_90
- return frame
- case 1: # LANDSCAPE_90
- return np.rot90(frame, k=3, axes=(0, 1))
- case 2: # PORTRAIT_180
- return np.rot90(frame, k=2, axes=(0, 1))
- case 3: # LANDSCAPE_270
- return np.rot90(frame, k=1, axes=(0, 1))
- case _:
- raise ValueError(
- 'Orientation must be an integer in [0, 3] but is %r' % orientation
- )
-
-
-def convert_int_to_float(data: np.ndarray, data_spec: specs.Array):
- """Converts an array of int values to floats between 0 and 1."""
-
- if not np.issubdtype(data.dtype, np.integer):
- raise TypeError(f'{data.dtype} is not an integer type')
- if isinstance(data_spec, specs.BoundedArray):
- value_min = data_spec.minimum
- value_max = data_spec.maximum
- else:
- # We use the int type to figure out the boundaries.
- iinfo = np.iinfo(data_spec.dtype)
- value_min = iinfo.min
- value_max = iinfo.max
- return np.float32(1.0 * (data - value_min) / (value_max - value_min))
diff --git a/android_env/components/pixel_fns_test.py b/android_env/components/pixel_fns_test.py
deleted file mode 100644
index 82024307..00000000
--- a/android_env/components/pixel_fns_test.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for pixel_fns."""
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from android_env.components import pixel_fns
-from dm_env import specs
-import numpy as np
-
-
-class UtilsTest(parameterized.TestCase):
-
- @parameterized.parameters(
- ([0.5, 0.5], [320, 480], (160, 240)),
- ([0.25, 0.75], [320, 480], (80, 360)),
- ([0.0, 0.0], [320, 480], (0, 0)),
- ([1.0, 1.0], [320, 480], (319, 479)),
- )
- def test_touch_position_to_pixel_position(
- self, touch_pos, width_height, pixel_pos):
- self.assertEqual(
- pixel_fns.touch_position_to_pixel_position(
- np.array(touch_pos), width_height
- ),
- pixel_pos,
- )
-
- def test_transpose_pixels(self):
- image = np.reshape(np.array(range(12)), (3, 2, 2))
- expected = [[[0, 1], [4, 5], [8, 9]], [[2, 3], [6, 7], [10, 11]]]
- self.assertEqual(pixel_fns.transpose_pixels(image).shape, (2, 3, 2))
- self.assertTrue((pixel_fns.transpose_pixels(image) == expected).all())
-
- def test_orient_pixels(self):
- image = np.reshape(np.array(range(12)), (3, 2, 2))
-
- expected_90 = [[[8, 9], [4, 5], [0, 1]], [[10, 11], [6, 7], [2, 3]]]
- rot_90 = 1 # LANDSCAPE_90
- rotated = pixel_fns.orient_pixels(image, rot_90)
- self.assertEqual(rotated.shape, (2, 3, 2))
- self.assertTrue((rotated == expected_90).all())
-
- expected_180 = [[[10, 11], [8, 9]], [[6, 7], [4, 5]], [[2, 3], [0, 1]]]
- rot_180 = 2 # PORTRAIT_180
- rotated = pixel_fns.orient_pixels(image, rot_180)
- self.assertEqual(rotated.shape, (3, 2, 2))
- self.assertTrue((rotated == expected_180).all())
-
- expected_270 = [[[2, 3], [6, 7], [10, 11]], [[0, 1], [4, 5], [8, 9]]]
- rot_270 = 3 # LANDSCAPE_270
- rotated = pixel_fns.orient_pixels(image, rot_270)
- self.assertEqual(rotated.shape, (2, 3, 2))
- self.assertTrue((rotated == expected_270).all())
-
- rot_0 = 0 # PORTRAIT_0
- rotated = pixel_fns.orient_pixels(image, rot_0)
- self.assertEqual(rotated.shape, (3, 2, 2))
- self.assertTrue((rotated == image).all())
-
- def test_convert_int_to_float_bounded_array(self):
- spec = specs.BoundedArray(
- shape=(4,),
- dtype=np.int32,
- minimum=[0, 1, 10, -2],
- maximum=[5, 5, 20, 2],
- name='bounded_array')
- data = np.array([2, 2, 10, 0], dtype=np.int32)
- float_data = pixel_fns.convert_int_to_float(data, spec)
- np.testing.assert_equal(
- np.array([2.0 / 5.0, 1.0 / 4.0, 0.0, 0.5], dtype=np.float32), float_data
- )
-
- def test_convert_int_to_float_bounded_array_broadcast(self):
- spec = specs.BoundedArray(
- shape=(3,), dtype=np.int16, minimum=2, maximum=4, name='bounded_array')
- data = np.array([2, 3, 4], dtype=np.int16)
- float_data = pixel_fns.convert_int_to_float(data, spec)
- np.testing.assert_equal(
- np.array([0.0, 0.5, 1.0], dtype=np.float32), float_data)
-
- def test_convert_int_to_float_no_bounds(self):
- spec = specs.Array(
- shape=(3,),
- dtype=np.int8, # int8 implies min=-128, max=127
- name='bounded_array')
- data = np.array([-128, 0, 127], dtype=np.int16)
- float_data = pixel_fns.convert_int_to_float(data, spec)
- np.testing.assert_equal(
- np.array([0.0, 128. / 255., 1.0], dtype=np.float32), float_data)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/setup_step_interpreter.py b/android_env/components/setup_step_interpreter.py
deleted file mode 100644
index fb509db6..00000000
--- a/android_env/components/setup_step_interpreter.py
+++ /dev/null
@@ -1,182 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""A component that parses and processes SetupSteps."""
-
-from collections.abc import Sequence
-import copy
-import time
-from typing import Any
-
-from absl import logging
-from android_env.components import adb_call_parser as adb_call_parser_lib
-from android_env.components import app_screen_checker
-from android_env.components import errors
-from android_env.proto import adb_pb2
-from android_env.proto import task_pb2
-
-
-class SetupStepInterpreter:
- """An interpreter for SetupSteps."""
-
- def __init__(self, adb_call_parser: adb_call_parser_lib.AdbCallParser):
- """Initializes this interpreter.
-
- Args:
- adb_call_parser: An object to communicate with Android via ADB.
- """
- self._adb_call_parser = adb_call_parser
- self._stats = {
- 'error_count_adb_request': 0,
- 'error_count_wait_for_app_screen': 0,
- 'error_count_check_install': 0,
- 'error_count_wait_for_message': 0,
- 'total_time_waiting_for_app_screen': 0
- }
-
- def stats(self) -> dict[str, Any]:
- return copy.deepcopy(self._stats)
-
- def interpret(self, setup_steps: Sequence[task_pb2.SetupStep]) -> None:
- """Returns True if parsing and processing `setup_steps` is successful."""
- if setup_steps:
- logging.info('Executing setup steps: %s', setup_steps)
- for step in setup_steps:
- self._process_step_command(step)
- logging.info('Done executing setup steps.')
-
- def _process_step_command(self, step_cmd: task_pb2.SetupStep) -> None:
- """Processes a single step command from a reset or extra setup."""
-
- if not step_cmd:
- logging.info('Empty step_cmd')
- return
-
- logging.info('Executing step_cmd: %r', step_cmd)
- step_type = step_cmd.WhichOneof('step')
- success_condition = step_cmd.success_condition
- success_check = success_condition.WhichOneof('check')
- assert step_type or success_check, (
- 'At least one of step and success_condition must be defined.')
-
- num_tries = 0
- max_retries = max(success_condition.num_retries, 3)
- latest_error = None
- while num_tries < max_retries:
-
- num_tries += 1
-
- try:
- unused_adb_response = self._execute_step_cmd(step_cmd, step_type)
- time.sleep(0.5)
- self._check_success(success_check, success_condition)
- return
-
- except NotImplementedError:
- logging.exception('Not implemented error! Skipping this step command.')
- return
-
- except errors.AdbControllerError as error:
- latest_error = error
- self._stats['error_count_adb_request'] += 1
- logging.exception('ADB call [%r] has failed. Try %d of %d.',
- step_cmd.adb_request, num_tries, max_retries)
-
- except errors.WaitForAppScreenError as error:
- latest_error = error
- self._stats['error_count_wait_for_app_screen'] += 1
- logging.exception('Failed to wait for app screen. Try %d of %d.',
- num_tries, max_retries)
-
- except errors.CheckInstallError as error:
- latest_error = error
- self._stats['error_count_check_install'] += 1
- logging.exception('Package [%r] not installed. Try %d of %d.',
- success_condition.check_install.package_name,
- num_tries, max_retries)
-
- raise errors.StepCommandError(
- f'Step failed: [{step_cmd}]') from latest_error
-
- def _execute_step_cmd(
- self, step_cmd: task_pb2.SetupStep, step_type: str | None
- ) -> adb_pb2.AdbResponse | None:
- """Executes a step command of given type."""
-
- match step_type:
- case None:
- return None
- case 'sleep':
- time.sleep(step_cmd.sleep.time_sec)
- return None
- case 'adb_request':
- response = self._adb_call_parser.parse(step_cmd.adb_request)
- if response.status != adb_pb2.AdbResponse.Status.OK:
- raise errors.AdbControllerError(
- f'Failed to execute AdbRequest [{step_cmd.adb_request}].\n'
- f'Status: {response.status}\n'
- f'Error: {response.error_message}'
- )
- return response
- case _:
- raise NotImplementedError(f'No step command of type [{step_type}].')
-
- def _check_success(
- self,
- success_check: str | None,
- success_condition: task_pb2.SuccessCondition,
- ) -> None:
- """Checks whether the given success condition was met."""
-
- match success_check:
- case None:
- return None
- case 'wait_for_app_screen':
- wait_for_app_screen = success_condition.wait_for_app_screen
- screen_checker = app_screen_checker.AppScreenChecker(
- adb_call_parser=self._adb_call_parser,
- expected_app_screen=wait_for_app_screen.app_screen,
- )
- wait_time = screen_checker.wait_for_app_screen(
- timeout_sec=wait_for_app_screen.timeout_sec
- )
- self._stats['total_time_waiting_for_app_screen'] += wait_time
- case 'check_install':
- self._check_install(success_condition.check_install)
- case _:
- raise NotImplementedError(f'No success check called [{success_check}].')
-
- def _check_install(self, check_install: task_pb2.CheckInstall) -> None:
- """Checks that the given package is installed."""
-
- package = check_install.package_name
- logging.info('Checking if package is installed: [%r]', package)
-
- request = adb_pb2.AdbRequest(
- package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
- list=adb_pb2.AdbRequest.PackageManagerRequest.List(
- packages=adb_pb2.AdbRequest.PackageManagerRequest.List.Packages(
- ))))
-
- start_time = time.time()
- while time.time() - start_time < check_install.timeout_sec:
- response = self._adb_call_parser.parse(request)
- if package in response.package_manager.list.items:
- logging.info('Done confirming that package is installed.')
- return
- time.sleep(0.1)
-
- logging.error('Package not found.')
- raise errors.CheckInstallError()
diff --git a/android_env/components/setup_step_interpreter_test.py b/android_env/components/setup_step_interpreter_test.py
deleted file mode 100644
index 26468159..00000000
--- a/android_env/components/setup_step_interpreter_test.py
+++ /dev/null
@@ -1,342 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.components.setup_step_interpreter."""
-
-from unittest import mock
-
-from absl.testing import absltest
-from android_env.components import adb_call_parser
-from android_env.components import errors
-from android_env.components import setup_step_interpreter
-from android_env.proto import adb_pb2
-from android_env.proto import task_pb2
-
-from google.protobuf import text_format
-
-
-def _to_proto(proto_class, text):
- proto = proto_class()
- text_format.Parse(text, proto)
- return proto
-
-
-class SetupStepInterpreterTest(absltest.TestCase):
-
- def setUp(self):
- super().setUp()
- self._parser = mock.create_autospec(
- adb_call_parser.AdbCallParser, instance=True)
-
- def test_empty_setup_steps(self):
- """Simple test where nothing should break, and nothing should be done.
-
- The test simply expects this test to not crash.
- """
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- interpreter.interpret([])
-
- def test_none_setup_steps(self):
- """Simple test where nothing should break, and nothing should be done.
-
- The test simply expects this test to not crash.
- """
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- # Empty setup steps should be ignored.
- interpreter.interpret([])
-
- def test_invalid_setup_step(self):
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- # Empty setup steps should be ignored.
- self.assertRaises(AssertionError, interpreter.interpret,
- [task_pb2.SetupStep()])
-
- def test_adb_install_apk_filesystem(self):
- self._parser.parse.return_value = adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.OK)
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- interpreter.interpret([
- _to_proto(
- task_pb2.SetupStep, """
-adb_request: {
- install_apk: {
- filesystem: {
- path: "/my/favorite/dir/my_apk.apk"
- }
- }
-}""")
- ])
- self._parser.parse.assert_called_once_with(
- adb_pb2.AdbRequest(
- install_apk=adb_pb2.AdbRequest.InstallApk(
- filesystem=adb_pb2.AdbRequest.InstallApk.Filesystem(
- path='/my/favorite/dir/my_apk.apk'))))
-
- def test_adb_force_stop(self):
- self._parser.parse.return_value = adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.OK)
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- interpreter.interpret([
- _to_proto(
- task_pb2.SetupStep, """
-adb_request: { force_stop: { package_name: "my.app.Activity" } }""")
- ])
- self._parser.parse.assert_called_once_with(
- adb_pb2.AdbRequest(
- force_stop=adb_pb2.AdbRequest.ForceStop(
- package_name='my.app.Activity')))
-
- def test_adb_start_activity(self):
- self._parser.parse.return_value = adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.OK)
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- interpreter.interpret([
- _to_proto(
- task_pb2.SetupStep, """
-adb_request: {
- start_activity: {
- full_activity: "my.app.Activity"
- extra_args: "arg1"
- extra_args: "arg2"
- }
-}""")
- ])
- self._parser.parse.assert_called_once_with(
- adb_pb2.AdbRequest(
- start_activity=adb_pb2.AdbRequest.StartActivity(
- full_activity='my.app.Activity', extra_args=['arg1', 'arg2'])))
-
- def test_adb_single_tap(self):
- self._parser.parse.return_value = adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.OK)
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- interpreter.interpret([
- _to_proto(task_pb2.SetupStep, """
-adb_request: {
- tap: {
- x: 321
- y: 654
- }
-}""")
- ])
- self._parser.parse.assert_called_once_with(
- adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=321, y=654)))
-
- def test_adb_press_button(self):
- self._parser.parse.return_value = adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.OK)
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- interpreter.interpret([
- _to_proto(task_pb2.SetupStep,
- """ adb_request: { press_button: { button: HOME } }""")
- ])
- self._parser.parse.assert_called_once_with(
- adb_pb2.AdbRequest(
- press_button=adb_pb2.AdbRequest.PressButton(
- button=adb_pb2.AdbRequest.PressButton.Button.HOME)))
-
- self._parser.reset_mock()
- interpreter.interpret([
- _to_proto(task_pb2.SetupStep,
- """ adb_request: { press_button: { button: BACK } }""")
- ])
- self._parser.parse.assert_called_once_with(
- adb_pb2.AdbRequest(
- press_button=adb_pb2.AdbRequest.PressButton(
- button=adb_pb2.AdbRequest.PressButton.Button.BACK)))
-
- def test_adb_start_screen_pinning(self):
- self._parser.parse.return_value = adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.OK)
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- interpreter.interpret([
- _to_proto(
- task_pb2.SetupStep, """
-adb_request: {
- start_screen_pinning: {
- full_activity: "my.app.HighlanderApp" # "There can be only one".
- }
-}""")
- ])
- self._parser.parse.assert_called_once_with(
- adb_pb2.AdbRequest(
- start_screen_pinning=adb_pb2.AdbRequest.StartScreenPinning(
- full_activity='my.app.HighlanderApp')))
-
- @mock.patch('time.sleep')
- def test_time_sleep(self, mock_sleep):
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- interpreter.interpret(
- [_to_proto(task_pb2.SetupStep, """sleep: { time_sec: 0.875 }""")])
- assert mock_sleep.call_count == 2
- mock_sleep.assert_has_calls([mock.call(0.875), mock.call(0.5)])
-
- @mock.patch('time.sleep')
- def test_wait_for_app_screen_empty_activity(self, unused_mock_sleep):
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- with self.assertRaises(errors.StepCommandError):
- interpreter.interpret([
- _to_proto(task_pb2.SetupStep,
- """success_condition: {wait_for_app_screen: { }}""")
- ])
-
- @mock.patch('time.sleep')
- def test_check_install_not_installed(self, unused_mock_sleep):
- self._parser.parse.return_value = adb_pb2.AdbResponse(
- package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
- list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[
- 'com.some.package',
- 'not.what.you.are.looking.for',
- ])))
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- with self.assertRaises(errors.StepCommandError):
- interpreter.interpret([
- _to_proto(
- task_pb2.SetupStep, """
-success_condition: {
- check_install: {
- package_name: "faz"
- timeout_sec: 0.0001
- }
-}
-""")
- ])
-
- def test_check_install_installed(self):
- self._parser.parse.return_value = adb_pb2.AdbResponse(
- package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
- list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[
- 'com.some.package',
- 'baz',
- ])))
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- # The test checks that this command raises no AssertionError.
- interpreter.interpret([
- _to_proto(
- task_pb2.SetupStep, """
-success_condition: {
- check_install: {
- package_name: "baz"
- timeout_sec: 0.0001
- }
-}""")
- ])
-
- def test_num_retries_failure(self):
- self._parser.parse.side_effect = [
- adb_pb2.AdbResponse(
- package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
- list=adb_pb2.AdbResponse.PackageManagerResponse.List(
- items=[]))),
- ] * 3
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- with self.assertRaises(errors.StepCommandError):
- interpreter.interpret([
- _to_proto(
- task_pb2.SetupStep, """
-success_condition: {
- check_install: {
- package_name: "faz"
- timeout_sec: 0.0001
- }
- num_retries: 3
-}""")
- ])
- # We retried 3 times after the first call, so we expect 3+1 calls.
- self.assertEqual(self._parser.parse.call_count, 3)
-
- @mock.patch('time.sleep')
- def test_num_retries_success(self, unused_mock_sleep):
- self._parser.parse.side_effect = [
- adb_pb2.AdbResponse(
- package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
- list=adb_pb2.AdbResponse.PackageManagerResponse.List(
- items=[]))),
- adb_pb2.AdbResponse(
- package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
- list=adb_pb2.AdbResponse.PackageManagerResponse.List(
- items=[]))),
- adb_pb2.AdbResponse(
- package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
- list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[
- 'com.some.package',
- 'bar',
- ]))),
- adb_pb2.AdbResponse(
- package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
- list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[])))
- ]
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- interpreter.interpret([
- _to_proto(
- task_pb2.SetupStep, """
-success_condition: {
- check_install: {
- package_name: "bar"
- timeout_sec: 0.0001
- }
- num_retries: 5
-}""")
- ])
- # The check should succeed on the third try.
- self.assertEqual(self._parser.parse.call_count, 3)
-
- def test_retry_step(self):
- self._parser.parse.side_effect = [
- adb_pb2.AdbResponse(
- package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
- list=adb_pb2.AdbResponse.PackageManagerResponse.List(
- items=[]))),
- adb_pb2.AdbResponse(
- package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
- list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[
- 'com.some.package',
- 'bar',
- ]))),
- ]
- interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=self._parser)
- interpreter.interpret([
- _to_proto(
- task_pb2.SetupStep, """
-success_condition: {
- check_install: {
- package_name: "bar"
- timeout_sec: 0.0001
- }
- num_retries: 2
-}""")
- ])
- # We expect the check to fail once and succeed on the second pass.
- self.assertEqual(self._parser.parse.call_count, 2)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/simulators/__init__.py b/android_env/components/simulators/__init__.py
deleted file mode 100644
index 2f66bf75..00000000
--- a/android_env/components/simulators/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
diff --git a/android_env/components/simulators/base_simulator.py b/android_env/components/simulators/base_simulator.py
deleted file mode 100644
index 49ab1040..00000000
--- a/android_env/components/simulators/base_simulator.py
+++ /dev/null
@@ -1,208 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""A base class for talking to different types of Android simulators."""
-
-import abc
-from collections.abc import Callable
-import threading
-import time
-
-from absl import logging
-from android_env.components import adb_controller
-from android_env.components import config_classes
-from android_env.components import errors
-from android_env.components import log_stream
-from android_env.proto import state_pb2
-import numpy as np
-
-
-class BaseSimulator(metaclass=abc.ABCMeta):
- """An interface for communicating with an Android simulator."""
-
- def __init__(self, config: config_classes.SimulatorConfig):
- """Instantiates a BaseSimulator object.
-
- The simulator may be an emulator, virtual machine or even a physical device.
- Each simulator has its own AdbController that is used for internal
- bookkeeping.
-
- Args:
- config: Settings for this simulator.
- """
-
- self._config = config
- self._interaction_thread: InteractionThread | None = None
-
- # An increasing number that tracks the attempt at launching the simulator.
- self._num_launch_attempts: int = 0
-
- def get_logs(self) -> str:
- """Returns logs recorded by the simulator (if provided)."""
- return 'No simulator logs provided.'
-
- @abc.abstractmethod
- def adb_device_name(self) -> str:
- """Returns the device name that the adb client will connect to."""
-
- @abc.abstractmethod
- def create_adb_controller(self) -> adb_controller.AdbController:
- """Returns an ADB controller which can communicate with this simulator."""
-
- @abc.abstractmethod
- def create_log_stream(self) -> log_stream.LogStream:
- """Creates a stream of logs from the simulator."""
-
- def launch(self) -> None:
- """Starts the simulator."""
-
- # Stop screenshot thread if it's enabled.
- if self._interaction_thread is not None:
- self._interaction_thread.stop()
- self._interaction_thread.join()
-
- self._num_launch_attempts += 1
- try:
- self._launch_impl()
- except Exception as error:
- for line in self.get_logs().splitlines():
- logging.error(line)
- raise errors.SimulatorError(
- 'Exception caught in simulator. Please see the simulator logs '
- 'above for more details.'
- ) from error
-
- # Start interaction thread.
- if self._config.interaction_rate_sec > 0:
- self._interaction_thread = InteractionThread(
- self._get_screenshot_impl, self._config.interaction_rate_sec
- )
- self._interaction_thread.start()
-
- @abc.abstractmethod
- def _launch_impl(self) -> None:
- """Platform specific launch implementation."""
-
- @abc.abstractmethod
- def send_touch(self, touches: list[tuple[int, int, bool, int]]) -> None:
- """Sends a touch event to be executed on the simulator.
-
- Args:
- touches: A list of touch events. Each element in the list corresponds to a
- single touch event. Each touch event tuple should have:
- 0 x: The horizontal coordinate of this event.
- 1 y: The vertical coordinate of this event.
- 2 is_down: Whether the finger is touching or not the screen.
- 3 identifier: Identifies a particular finger in a multitouch event.
- """
-
- @abc.abstractmethod
- def send_key(self, keycode: np.int32, event_type: str) -> None:
- """Sends a keyboard event.
-
- Args:
- keycode: Represents a specific keyboard key. This is platform and
- simulator-specific.
- event_type: Type of key event to be sent.
- """
-
- def load_state(
- self, request: state_pb2.LoadStateRequest
- ) -> state_pb2.LoadStateResponse:
- """Loads a state.
-
- Args:
- request: A `LoadStateRequest` containing any parameters necessary to
- specify how/what state to load.
-
- Returns:
- A `LoadStateResponse` containing the status, error message (if
- applicable), and any other relevant information.
- """
- raise NotImplementedError('This simulator does not support load_state()')
-
- def save_state(
- self, request: state_pb2.SaveStateRequest
- ) -> state_pb2.SaveStateResponse:
- """Saves a state.
-
- Args:
- request: A `SaveStateRequest` containing any parameters necessary to
- specify how/what state to save.
-
- Returns:
- A `SaveStateResponse` containing the status, error message (if
- applicable), and any other relevant information.
- """
- raise NotImplementedError('This simulator does not support save_state()')
-
- def get_screenshot(self) -> np.ndarray:
- """Returns pixels representing the current screenshot of the simulator."""
-
- if self._config.interaction_rate_sec > 0:
- assert self._interaction_thread is not None
- return self._interaction_thread.screenshot() # Async mode.
- else:
- return self._get_screenshot_impl() # Sync mode.
-
- @abc.abstractmethod
- def _get_screenshot_impl(self) -> np.ndarray:
- """Actual implementation of `get_screenshot()`.
-
- The output numpy array should have shape [height, width, num_channels] and
- can be loaded into PIL using Image.fromarray(img, mode='RGB') and be saved
- as a PNG file using my_pil.save('/tmp/my_screenshot.png', 'PNG').
- """
-
- def close(self):
- """Frees up resources allocated by this object."""
-
- if self._interaction_thread is not None:
- self._interaction_thread.stop()
- self._interaction_thread.join()
-
-
-class InteractionThread(threading.Thread):
- """A thread that gets screenshot in the background."""
-
- def __init__(
- self,
- get_screenshot_fn: Callable[[], np.ndarray],
- interaction_rate_sec: float,
- ):
- super().__init__()
- self._get_screenshot_fn = get_screenshot_fn
- self._interaction_rate_sec = interaction_rate_sec
- self._should_stop = threading.Event()
- self._screenshot = self._get_screenshot_fn()
-
- def run(self):
- last_read = time.time()
- while not self._should_stop.is_set():
- self._screenshot = self._get_screenshot_fn()
- now = time.time()
- elapsed = now - last_read
- last_read = now
- sleep_time = self._interaction_rate_sec - elapsed
- if sleep_time > 0.0:
- time.sleep(sleep_time)
- logging.info('InteractionThread.run() finished.')
-
- def stop(self):
- logging.info('Stopping InteractionThread.')
- self._should_stop.set()
-
- def screenshot(self) -> np.ndarray:
- return self._screenshot
diff --git a/android_env/components/simulators/base_simulator_test.py b/android_env/components/simulators/base_simulator_test.py
deleted file mode 100644
index 6b5457ad..00000000
--- a/android_env/components/simulators/base_simulator_test.py
+++ /dev/null
@@ -1,181 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import itertools
-import time
-from unittest import mock
-
-from absl.testing import absltest
-from android_env.components import config_classes
-from android_env.components import errors
-# fake_simulator.FakeSimulator inherits from BaseSimulator, so there's no need
-# to import it here explicitly.
-from android_env.components.simulators import base_simulator
-from android_env.components.simulators.fake import fake_simulator
-import numpy as np
-
-
-class BaseSimulatorTest(absltest.TestCase):
-
- def test_launch(self):
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig(screen_dimensions=(640, 480))
- )
- # The simulator should launch and not crash.
- simulator.launch()
-
- def test_launch_close(self):
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig()
- )
- # The simulator should launch and not crash.
- simulator.launch()
- # Closing the simulator should also not crash.
- simulator.close()
-
- def test_get_screenshot(self):
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig(screen_dimensions=(640, 480))
- )
- # The simulator should launch and not crash.
- simulator.launch()
-
- screenshot = simulator.get_screenshot()
- np.testing.assert_equal(screenshot.shape, [640, 480, 3])
-
- def test_print_logs_on_exception(self):
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig()
- )
- with mock.patch.object(
- simulator, 'get_logs'
- ) as mock_get_logs, mock.patch.object(
- simulator, '_launch_impl', autospec=True
- ) as mock_launch:
- mock_launch.side_effect = ValueError('Oh no!')
- self.assertRaises(errors.SimulatorError, simulator.launch)
- mock_get_logs.assert_called_once()
-
- def test_get_screenshot_error_async(self):
- """An exception in the underlying interaction thread should bubble up."""
-
- # Arrange.
- mock_interaction_thread = mock.create_autospec(
- base_simulator.InteractionThread
- )
- mock_interaction_thread.screenshot.side_effect = (
- errors.ReadObservationError()
- )
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig(interaction_rate_sec=0.5)
- )
- with mock.patch.object(
- base_simulator,
- 'InteractionThread',
- autospec=True,
- return_value=mock_interaction_thread,
- ):
- simulator.launch()
-
- # Act & Assert.
- self.assertRaises(errors.ReadObservationError, simulator.get_screenshot)
-
- # Cleanup.
- simulator.close()
-
- def test_get_screenshot_faster_than_screenshot_impl(self):
- """Return same screenshot when step is faster than the interaction rate."""
-
- # Arrange.
- slow_rate = 0.5
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig(interaction_rate_sec=slow_rate)
- )
-
- # Act.
- with mock.patch.object(
- simulator, '_get_screenshot_impl', autospec=True
- ) as mock_get_screenshot_impl:
- mock_get_screenshot_impl.side_effect = (
- np.array(i, ndmin=3) for i in itertools.count(0, 1)
- )
- simulator.launch()
- # Get two screenshots one after the other without pausing.
- screenshot1 = simulator.get_screenshot()
- screenshot2 = simulator.get_screenshot()
-
- # Assert.
- self.assertAlmostEqual(screenshot1[0][0][0], screenshot2[0][0][0])
-
- # Cleanup.
- simulator.close()
-
- def test_get_screenshot_slower_than_screenshot_impl(self):
- """Return different screenshots when step slower than the interaction rate."""
-
- # Arrange.
- fast_rate = 0.01
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig(interaction_rate_sec=fast_rate)
- )
-
- # Act.
- with mock.patch.object(
- simulator, '_get_screenshot_impl', autospec=True
- ) as mock_get_screenshot_impl:
- mock_get_screenshot_impl.side_effect = (
- np.array(i, ndmin=3) for i in itertools.count(0, 1)
- )
- simulator.launch()
- # Sleep for 500ms between two screenshots.
- screenshot1 = simulator.get_screenshot()
- time.sleep(0.5)
- screenshot2 = simulator.get_screenshot()
-
- # Assert.
- self.assertNotEqual(screenshot1[0][0][0], screenshot2[0][0][0])
-
- # Cleanup.
- simulator.close()
-
- def test_interaction_thread_closes_upon_relaunch(self):
- """Async interaction should kill the InteractionThread when relaunching."""
-
- # Arrange.
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig(interaction_rate_sec=0.01)
- )
- mock_interaction_thread = mock.create_autospec(
- base_simulator.InteractionThread
- )
-
- # Act & Assert.
- with mock.patch.object(
- base_simulator,
- 'InteractionThread',
- autospec=True,
- return_value=mock_interaction_thread,
- ):
- simulator.launch()
- mock_interaction_thread.stop.assert_not_called()
- mock_interaction_thread.join.assert_not_called()
- simulator.launch()
- mock_interaction_thread.stop.assert_called_once()
- mock_interaction_thread.join.assert_called_once()
- simulator.close()
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/simulators/emulator/__init__.py b/android_env/components/simulators/emulator/__init__.py
deleted file mode 100644
index 2f66bf75..00000000
--- a/android_env/components/simulators/emulator/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
diff --git a/android_env/components/simulators/emulator/emulator_launcher.py b/android_env/components/simulators/emulator/emulator_launcher.py
deleted file mode 100644
index a1abe3f3..00000000
--- a/android_env/components/simulators/emulator/emulator_launcher.py
+++ /dev/null
@@ -1,169 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Prepares and launches an emulator process."""
-
-import glob
-import os
-import subprocess
-import tempfile
-
-from absl import logging
-from android_env.components import config_classes
-
-
-class EmulatorLauncher:
- """Handles launching an emulator."""
-
- def __init__(
- self,
- config: config_classes.EmulatorLauncherConfig,
- adb_controller_config: config_classes.AdbControllerConfig,
- ):
- """Launches an emulator."""
-
- self._config = config
- self._adb_controller_config = adb_controller_config
-
- self._emulator = None
- self._emulator_output = None
- self._is_closed = False
-
- # Create directory for tmp files.
- # Note: this will be deleted once EmulatorLauncher instance is cleaned up.
- os.makedirs(config.tmp_dir, exist_ok=True)
- self._local_tmp_dir_handle = tempfile.TemporaryDirectory(
- dir=config.tmp_dir, prefix='simulator_instance_'
- )
- self._local_tmp_dir = self._local_tmp_dir_handle.name
- self._logfile_path = os.path.join(self._local_tmp_dir, 'emulator_output')
- logging.info('Simulator local_tmp_dir: %s', self._local_tmp_dir)
-
- def logfile_path(self) -> str:
- return self._logfile_path
-
- def launch_emulator_process(self) -> None:
- """Launches the emulator."""
-
- logging.info('Booting new emulator: %s', self._config.emulator_path)
-
- # Set necessary environment variables.
- base_lib_dir = self._config.emulator_path[:-8] + 'lib64/'
- ld_library_path = ':'.join([
- base_lib_dir + 'x11/', base_lib_dir + 'qt/lib/',
- base_lib_dir + 'gles_swiftshader/', base_lib_dir
- ])
- extra_env_vars = {
- 'ANDROID_HOME': '',
- 'ANDROID_SDK_ROOT': self._config.android_sdk_root,
- 'ANDROID_AVD_HOME': self._config.android_avd_home,
- 'ANDROID_EMULATOR_KVM_DEVICE': self._config.kvm_device,
- 'ANDROID_ADB_SERVER_PORT': str(
- self._adb_controller_config.adb_server_port
- ),
- 'LD_LIBRARY_PATH': ld_library_path,
- 'QT_XKB_CONFIG_ROOT': str(
- self._config.emulator_path[:-8] + 'qt_config/'
- ),
- 'ANDROID_EMU_ENABLE_CRASH_REPORTING': '1',
- 'SHOW_PERF_STATS': str(1 if self._config.show_perf_stats else 0),
- }
- logging.info('extra_env_vars: %s',
- ' '.join(f'{k}={v}' for k, v in extra_env_vars.items()))
- env_vars = dict(os.environ).copy()
- env_vars.update(extra_env_vars)
-
- # Compile command.
- grpc_port = (
- ['-grpc', str(self._config.grpc_port)]
- if self._config.grpc_port >= 0
- else []
- )
- run_headless = (
- ['-no-skin', '-no-window'] if self._config.run_headless else []
- )
- ports = [
- '-ports',
- '%s,%s' % (self._config.emulator_console_port, self._config.adb_port),
- ]
- snapshot = [
- '-snapshot',
- self._config.snapshot_name,
- '-feature',
- 'AllowSnapshotMigration,MigratableSnapshotSave',
- ]
- snapshot = snapshot if self._config.snapshot_name else ['-no-snapshot']
- restrict_network_args = [
- '-network-user-mode-options', 'restrict=y', '-wifi-user-mode-options',
- 'restrict=y'
- ]
- network_args = (
- restrict_network_args if self._config.restrict_network else []
- )
- command = (
- [
- self._config.emulator_path,
- '-adb-path',
- self._adb_controller_config.adb_path,
- '-gpu',
- self._config.gpu_mode,
- '-no-audio',
- '-show-kernel',
- '-verbose',
- '-avd',
- self._config.avd_name,
- ]
- + grpc_port
- + run_headless
- + ports
- + snapshot
- + network_args
- )
- logging.info('Emulator launch command: %s', ' '.join(command))
- # Prepare logfile.
- self._emulator_output = open(self._logfile_path, 'wb')
-
- # Spawn the emulator process.
- self._emulator = subprocess.Popen(
- command,
- env=env_vars,
- stdout=self._emulator_output,
- stderr=self._emulator_output)
-
- def confirm_shutdown(self) -> None:
- """Shuts down the emulator process."""
- if self._emulator is not None:
- logging.info('Checking if emulator process has finished...')
- try:
- self._emulator.wait(timeout=30.0)
- except subprocess.TimeoutExpired:
- logging.exception(
- 'The emulator process did not finish after 30s. '
- 'returncode: %s. Will now try to kill() it.',
- self._emulator.returncode)
- self._emulator.kill()
- self._emulator = None
- self._emulator_output.close()
- logging.info('The emulator process has finished.')
-
- def close(self):
- """Clean up launcher files and processes."""
- if not self._is_closed:
- self._local_tmp_dir_handle.cleanup()
- self.confirm_shutdown()
- self._is_closed = True
-
- def __del__(self):
- self.close()
diff --git a/android_env/components/simulators/emulator/emulator_launcher_test.py b/android_env/components/simulators/emulator/emulator_launcher_test.py
deleted file mode 100644
index 3bfd44f1..00000000
--- a/android_env/components/simulators/emulator/emulator_launcher_test.py
+++ /dev/null
@@ -1,294 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.components.emulator_launcher."""
-
-import builtins
-import os
-import subprocess
-import tempfile
-from unittest import mock
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from android_env.components import config_classes
-from android_env.components.simulators.emulator import emulator_launcher
-
-
-class EmulatorLauncherTest(parameterized.TestCase):
-
- def setUp(self):
- super().setUp()
-
- self._emulator_path = 'fake/path/emulator'
- self._adb_path = 'fake/path/adb'
- self._adb_port = 5554
- self._adb_server_port = 1234
- self._emulator_console_port = 5555
- self._avd_name = 'my_avd_name'
-
- self._expected_command = [
- self._emulator_path,
- '-adb-path',
- 'fake/path/adb',
- '-gpu',
- 'swangle_indirect',
- '-no-audio',
- '-show-kernel',
- '-verbose',
- '-avd',
- self._avd_name,
- ]
- self._headless = ['-no-skin', '-no-window']
- self._ports = ['-ports', f'{self._emulator_console_port},{self._adb_port}']
- self._snapshot = ['-no-snapshot']
-
- base_lib_dir = self._emulator_path[:-8] + 'lib64/'
- ld_library_path = ':'.join([
- base_lib_dir + 'x11/', base_lib_dir + 'qt/lib/',
- base_lib_dir + 'gles_swiftshader/', base_lib_dir
- ])
-
- # Instantiate the config to extract default values.
- config = config_classes.EmulatorLauncherConfig()
- self._expected_env_vars = {
- 'ANDROID_HOME': '',
- 'ANDROID_SDK_ROOT': config.android_sdk_root,
- 'ANDROID_AVD_HOME': config.android_avd_home,
- 'ANDROID_EMULATOR_KVM_DEVICE': '/dev/kvm',
- 'ANDROID_ADB_SERVER_PORT': '1234',
- 'LD_LIBRARY_PATH': ld_library_path,
- 'QT_XKB_CONFIG_ROOT': str(self._emulator_path[:-8] + 'qt_config/'),
- 'ANDROID_EMU_ENABLE_CRASH_REPORTING': '1',
- }
-
- @parameterized.named_parameters([
- ('hide_perf_stats', False),
- ('show_perf_stats', True),
- ])
- @mock.patch.object(os, 'makedirs')
- @mock.patch.object(os, 'environ', autospec=True, return_value=dict())
- @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True)
- def test_launch(
- self,
- show_perf_stats: bool,
- mock_tmp_dir,
- unused_os_environ,
- unused_os_makedirs,
- ):
- mock_tmp_dir.return_value.name.return_value = 'local_tmp_dir'
-
- config = config_classes.EmulatorLauncherConfig(
- adb_port=self._adb_port,
- emulator_console_port=self._emulator_console_port,
- emulator_path=self._emulator_path,
- avd_name=self._avd_name,
- grpc_port=-1,
- show_perf_stats=show_perf_stats,
- )
- adb_controller_config = config_classes.AdbControllerConfig(
- adb_path=self._adb_path,
- adb_server_port=self._adb_server_port,
- )
- launcher = emulator_launcher.EmulatorLauncher(
- config=config, adb_controller_config=adb_controller_config
- )
-
- expected_env_vars = self._expected_env_vars
- expected_env_vars['SHOW_PERF_STATS'] = '1' if show_perf_stats else '0'
-
- with mock.patch.object(
- subprocess, 'Popen', autospec=True
- ) as emulator_init, mock.patch.object(builtins, 'open', autospec=True) as f:
- f.return_value.__enter__ = f()
- launcher.launch_emulator_process()
- emulator_init.assert_called_once_with(
- args=self._expected_command
- + self._headless
- + self._ports
- + self._snapshot,
- env=expected_env_vars,
- stdout=f(),
- stderr=f(),
- )
-
- @parameterized.named_parameters([
- ('hide_perf_stats', False),
- ('show_perf_stats', True),
- ])
- @mock.patch.object(os, 'makedirs')
- @mock.patch.object(os, 'environ', autospec=True, return_value=dict())
- @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True)
- def test_grpc_port(
- self,
- show_perf_stats: bool,
- mock_tmp_dir,
- unused_os_environ,
- unused_os_makedirs,
- ):
- mock_tmp_dir.return_value.name.return_value = 'local_tmp_dir'
-
- config = config_classes.EmulatorLauncherConfig(
- adb_port=self._adb_port,
- emulator_console_port=self._emulator_console_port,
- emulator_path=self._emulator_path,
- avd_name=self._avd_name,
- grpc_port=8554,
- show_perf_stats=show_perf_stats,
- )
- adb_controller_config = config_classes.AdbControllerConfig(
- adb_path=self._adb_path,
- adb_server_port=self._adb_server_port,
- )
- launcher = emulator_launcher.EmulatorLauncher(
- config=config, adb_controller_config=adb_controller_config
- )
-
- expected_env_vars = self._expected_env_vars
- expected_env_vars['SHOW_PERF_STATS'] = '1' if show_perf_stats else '0'
-
- with mock.patch.object(
- subprocess, 'Popen', autospec=True
- ) as emulator_init, mock.patch.object(builtins, 'open', autospec=True) as f:
- f.return_value.__enter__ = f()
- launcher.launch_emulator_process()
- emulator_init.assert_called_once_with(
- args=self._expected_command
- + ['-grpc', '8554']
- + self._headless
- + self._ports
- + self._snapshot,
- env=expected_env_vars,
- stdout=f(),
- stderr=f(),
- )
-
- @parameterized.named_parameters([
- ('hide_perf_stats', False),
- ('show_perf_stats', True),
- ])
- @mock.patch.object(os, 'makedirs')
- @mock.patch.object(os, 'environ', autospec=True, return_value=dict())
- @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True)
- def test_snapshot(
- self,
- show_perf_stats: bool,
- mock_tmp_dir,
- unused_os_environ,
- unused_os_makedirs,
- ):
- mock_tmp_dir.return_value.name.return_value = 'local_tmp_dir'
-
- config = config_classes.EmulatorLauncherConfig(
- adb_port=self._adb_port,
- emulator_console_port=self._emulator_console_port,
- emulator_path=self._emulator_path,
- avd_name=self._avd_name,
- grpc_port=-1,
- snapshot_name='my_snapshot',
- show_perf_stats=show_perf_stats,
- )
- adb_controller_config = config_classes.AdbControllerConfig(
- adb_path=self._adb_path,
- adb_server_port=self._adb_server_port,
- )
- launcher = emulator_launcher.EmulatorLauncher(
- config=config, adb_controller_config=adb_controller_config
- )
-
- expected_snapshot = [
- '-snapshot', 'my_snapshot', '-feature',
- 'AllowSnapshotMigration,MigratableSnapshotSave'
- ]
-
- expected_env_vars = self._expected_env_vars
- expected_env_vars['SHOW_PERF_STATS'] = '1' if show_perf_stats else '0'
-
- with mock.patch.object(
- subprocess, 'Popen', autospec=True) as emulator_init, \
- mock.patch.object(builtins, 'open', autospec=True) as f:
- f.return_value.__enter__ = f()
- launcher.launch_emulator_process()
- emulator_init.assert_called_once_with(
- args=self._expected_command
- + self._headless
- + self._ports
- + expected_snapshot,
- env=expected_env_vars,
- stdout=f(),
- stderr=f(),
- )
-
- @parameterized.named_parameters([
- ('hide_perf_stats', False),
- ('show_perf_stats', True),
- ])
- @mock.patch.object(os, 'makedirs')
- @mock.patch.object(os, 'environ', autospec=True, return_value=dict())
- @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True)
- def test_network_restrict(
- self,
- show_perf_stats: bool,
- mock_tmp_dir,
- unused_os_environ,
- unused_os_makedirs,
- ):
- mock_tmp_dir.return_value.name.return_value = 'local_tmp_dir'
-
- config = config_classes.EmulatorLauncherConfig(
- adb_port=self._adb_port,
- emulator_console_port=self._emulator_console_port,
- emulator_path=self._emulator_path,
- avd_name=self._avd_name,
- grpc_port=-1,
- restrict_network=True,
- show_perf_stats=show_perf_stats,
- )
- adb_controller_config = config_classes.AdbControllerConfig(
- adb_path=self._adb_path,
- adb_server_port=self._adb_server_port,
- )
- launcher = emulator_launcher.EmulatorLauncher(
- config=config, adb_controller_config=adb_controller_config
- )
-
- expected_snapshot = ['-no-snapshot']
- expected_network_restrict = [
- '-network-user-mode-options', 'restrict=y', '-wifi-user-mode-options',
- 'restrict=y'
- ]
-
- expected_env_vars = self._expected_env_vars
- expected_env_vars['SHOW_PERF_STATS'] = '1' if show_perf_stats else '0'
-
- with mock.patch.object(
- subprocess, 'Popen', autospec=True) as emulator_init, \
- mock.patch.object(builtins, 'open', autospec=True) as f:
- f.return_value.__enter__ = f()
- launcher.launch_emulator_process()
- emulator_init.assert_called_once_with(
- self._expected_command
- + self._headless
- + self._ports
- + expected_snapshot
- + expected_network_restrict,
- env=expected_env_vars,
- stdout=f(),
- stderr=f(),
- )
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/simulators/emulator/emulator_simulator.py b/android_env/components/simulators/emulator/emulator_simulator.py
deleted file mode 100644
index d7f61874..00000000
--- a/android_env/components/simulators/emulator/emulator_simulator.py
+++ /dev/null
@@ -1,486 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""A class that manages an Android Emulator."""
-
-import os
-import time
-from typing import Any
-
-from absl import logging
-from android_env.components import adb_controller
-from android_env.components import adb_log_stream
-from android_env.components import config_classes
-from android_env.components import errors
-from android_env.components import log_stream
-from android_env.components.simulators import base_simulator
-from android_env.components.simulators.emulator import emulator_launcher
-from android_env.proto import state_pb2
-import grpc
-import numpy as np
-import portpicker
-
-from android_env.proto import emulator_controller_pb2
-from android_env.proto import emulator_controller_pb2_grpc
-from android_env.proto import snapshot_service_pb2
-from android_env.proto import snapshot_service_pb2_grpc
-from google.protobuf import empty_pb2
-
-
-_DEFAULT_SNAPSHOT_NAME: str = 'default_snapshot'
-
-
-def _is_existing_emulator_provided(
- launcher_config: config_classes.EmulatorLauncherConfig,
-) -> bool:
- """Returns true if all necessary args were provided."""
-
- return bool(
- launcher_config.adb_port
- and launcher_config.emulator_console_port
- and launcher_config.grpc_port
- )
-
-
-def _pick_adb_port() -> int:
- """Tries to pick a port in the recommended range 5555-5585.
-
- If no such port can be found, will return a random unused port. More info:
- https://developer.android.com/studio/command-line/adb#howadbworks.
-
- Returns:
- port: an available port for adb.
- """
-
- for p in range(5555, 5587, 2):
- if portpicker.is_port_free(p):
- return p
- return portpicker.pick_unused_port()
-
-
-def _pick_emulator_grpc_port() -> int:
- """Tries to pick the recommended port for grpc.
-
- If no such port can be found, will return a random unused port. More info:
- https://android.googlesource.com/platform/external/qemu/+/emu-master-dev/android/android-grpc/docs/.
-
- Returns:
- port: an available port for emulator grpc.
- """
-
- if portpicker.is_port_free(8554):
- return 8554
- else:
- return portpicker.pick_unused_port()
-
-
-class EmulatorBootError(errors.SimulatorError):
- """Raised when an emulator failed to boot."""
-
-
-class EmulatorCrashError(errors.SimulatorError):
- """Raised when a simulator crashed."""
-
-
-class EmulatorSimulator(base_simulator.BaseSimulator):
- """Controls an Android Emulator."""
-
- def __init__(self, config: config_classes.EmulatorConfig):
- """Instantiates an EmulatorSimulator."""
-
- super().__init__(config)
- self._config = config
-
- # If adb_port, console_port and grpc_port are all already provided,
- # we assume the emulator already exists and there's no need to launch.
- if _is_existing_emulator_provided(self._config.emulator_launcher):
- self._existing_emulator_provided = True
- logging.info('Connecting to existing emulator "%r"',
- self.adb_device_name())
- else:
- self._existing_emulator_provided = False
- self._config.emulator_launcher.adb_port = _pick_adb_port()
- self._config.emulator_launcher.emulator_console_port = (
- portpicker.pick_unused_port()
- )
- self._config.emulator_launcher.grpc_port = _pick_emulator_grpc_port()
-
- self._channel = None
- self._emulator_stub: emulator_controller_pb2_grpc.EmulatorControllerStub | None = (
- None
- )
- self._snapshot_stub = None
- # Set the image format to RGBA. The width and height of the returned
- # screenshots will use the device's width and height.
- self._image_format = emulator_controller_pb2.ImageFormat(
- format=emulator_controller_pb2.ImageFormat.ImgFormat.RGBA8888)
-
- if (
- self._config.launch_n_times_without_reboot
- > self._config.launch_n_times_without_reinstall
- ):
- raise ValueError(
- 'Number of launch attempts before reboot'
- f' ({self._config.launch_n_times_without_reboot}) should not be'
- ' greater than number of launch attempts before reinstall'
- f' ({self._config.launch_n_times_without_reinstall})'
- )
-
- # Initialize own ADB controller.
- self._config.adb_controller.device_name = self.adb_device_name()
- self._adb_controller = self.create_adb_controller()
- self._adb_controller.init_server()
- logging.info(
- 'Initialized simulator with ADB server port %r.',
- self._config.adb_controller.adb_server_port,
- )
-
- # If necessary, create EmulatorLauncher.
- if self._existing_emulator_provided:
- self._logfile_path = self._config.logfile_path or None
- self._launcher = None
- else:
- logging.info(
- 'emulator_launcher config: %r', self._config.emulator_launcher
- )
- self._launcher = emulator_launcher.EmulatorLauncher(
- config=self._config.emulator_launcher,
- adb_controller_config=self._config.adb_controller,
- )
- self._logfile_path = (
- self._config.logfile_path or self._launcher.logfile_path()
- )
-
- def _reconnect_on_grpc_error(func):
- """Decorator function for reconnecting to emulator upon grpc errors."""
-
- def wrapper(self, *args, **kwargs):
- try:
- return func(self, *args, **kwargs)
- except grpc.RpcError:
- logging.exception('RpcError caught. Reconnecting to emulator...')
- self._emulator_stub, self._snapshot_stub = self._connect_to_emulator(
- self._config.emulator_launcher.grpc_port
- )
- return func(self, *args, **kwargs)
-
- return wrapper
-
- def get_logs(self) -> str:
- """Returns logs recorded by the emulator."""
- if self._logfile_path and os.path.exists(self._logfile_path):
- with open(self._logfile_path, 'rb') as f:
- return f.read().decode('utf-8')
- else:
- return f'Logfile does not exist: {self._logfile_path}.'
-
- def adb_device_name(self) -> str:
- return 'emulator-%s' % (self._config.emulator_launcher.adb_port - 1)
-
- def create_adb_controller(self):
- """Returns an ADB controller which can communicate with this simulator."""
- return adb_controller.AdbController(self._config.adb_controller)
-
- def create_log_stream(self) -> log_stream.LogStream:
- return adb_log_stream.AdbLogStream(
- adb_command_prefix=self._adb_controller.command_prefix(),
- verbose=self._config.verbose_logs,
- )
-
- def _launch_impl(self) -> None:
- """Prepares an Android Emulator for RL interaction.
-
- The behavior depends on `self._num_launch_attempts`'s value:
- * <= self._config.launch_n_times_without_reboot -> Normal boot behavior.
- * > self._config.launch_n_times_without_reboot but <=
- self._config.launch_n_times_without_reinstall -> reboot (i.e. process
- is killed and started again).
- * > self._config.launch_n_times_without_reinstall -> reinstall (i.e.
- process is killed, emulator files are deleted and the process started
- again).
- """
-
- logging.info('Attempt %r at launching the Android Emulator (%r)',
- self._num_launch_attempts, self.adb_device_name())
-
- if self._launcher is not None:
- # If not the first time, then shutdown the emulator first.
- if (
- self._emulator_stub is not None
- and self._num_launch_attempts
- > self._config.launch_n_times_without_reboot
- ):
- self._shutdown_emulator()
- # Subsequent attempts cause the emulator files to be reinstalled.
- if (
- self._num_launch_attempts
- > self._config.launch_n_times_without_reinstall
- ):
- logging.info('Closing emulator (%r)', self.adb_device_name())
- self._launcher.close()
- self._launcher = emulator_launcher.EmulatorLauncher(
- config=self._config.emulator_launcher,
- adb_controller_config=self._config.adb_controller,
- )
- self._launcher.launch_emulator_process()
- # Establish grpc connection to emulator process.
- self._emulator_stub, self._snapshot_stub = self._connect_to_emulator(
- self._config.emulator_launcher.grpc_port
- )
-
- # Confirm booted status.
- try:
- self._confirm_booted()
- except EmulatorCrashError:
- logging.exception('Failed to confirm booted status of emulator.')
-
- logging.info('Done booting the Android Emulator.')
-
- def load_state(
- self, request: state_pb2.LoadStateRequest
- ) -> state_pb2.LoadStateResponse:
- """Loads a state using the emulator's snapshotting mechanism.
-
- Args:
- request: The `LoadStateRequest`. In this case, `args` should be a dict
- containing the key 'snapshot_name', representing the name of the
- snapshot to load. If `request.args.snapshot_name` is `None`, a default
- snapshot name is used.
-
- Returns:
- A response indicating whether the snapshot was successfully loaded.
- * If the snapshot was loaded successfully, the status will be `OK`.
- * If no snapshot of the given name was found, the status will be
- `NOT_FOUND`.
- * If an error occurred during the snapshot loading process, the status
- will be `ERROR` and the `error_message` field will be filled.
- """
- assert self._snapshot_stub is not None
- snapshot_name = request.args.get('snapshot_name', _DEFAULT_SNAPSHOT_NAME)
- snapshot_list = self._snapshot_stub.ListSnapshots(
- snapshot_service_pb2.SnapshotFilter(
- statusFilter=snapshot_service_pb2.SnapshotFilter.LoadStatus.All
- )
- )
- if any(
- snapshot.snapshot_id == snapshot_name
- for snapshot in snapshot_list.snapshots
- ):
- snapshot_result = self._snapshot_stub.LoadSnapshot(
- snapshot_service_pb2.SnapshotPackage(snapshot_id=snapshot_name)
- )
- if snapshot_result.success:
- return state_pb2.LoadStateResponse(
- status=state_pb2.LoadStateResponse.Status.OK
- )
- else:
- return state_pb2.LoadStateResponse(
- status=state_pb2.LoadStateResponse.Status.ERROR,
- error_message=snapshot_result.err.decode('utf-8'),
- )
-
- else:
- return state_pb2.LoadStateResponse(
- status=state_pb2.LoadStateResponse.Status.NOT_FOUND
- )
-
- def save_state(
- self, request: state_pb2.SaveStateRequest
- ) -> state_pb2.SaveStateResponse:
- """Saves a state using the emulator's snapshotting mechanism.
-
- Args:
- request: The `SaveStateRequest`. In this case, `args` should be a dict
- containing the key 'snapshot_name', representing the name of the
- snapshot to save. If `request.args.snapshot_name` is `None`, a default
- snapshot name is used.
-
- Returns:
- A response indicating whether the snapshot was successfully saved.
- * If the snapshot was saved successfully, the status will be `OK`.
- * If an error occurred during the snapshot saving process, the status
- will be `ERROR` and the `error_message` field will be filled.
- """
- assert self._snapshot_stub is not None
- snapshot_name = request.args.get('snapshot_name', _DEFAULT_SNAPSHOT_NAME)
- snapshot_result = self._snapshot_stub.SaveSnapshot(
- snapshot_service_pb2.SnapshotPackage(snapshot_id=snapshot_name)
- )
- if snapshot_result.success:
- return state_pb2.SaveStateResponse(
- status=state_pb2.SaveStateResponse.Status.OK
- )
- else:
- return state_pb2.SaveStateResponse(
- status=state_pb2.SaveStateResponse.Status.ERROR,
- error_message=snapshot_result.err.decode('utf-8'),
- )
-
- def _connect_to_emulator(
- self,
- grpc_port: int,
- timeout_sec: int = 100,
- ) -> tuple[
- emulator_controller_pb2_grpc.EmulatorControllerStub,
- snapshot_service_pb2_grpc.SnapshotServiceStub,
- ]:
- """Connects to an emulator and returns a corresponsing stub."""
-
- logging.info('Creating gRPC channel to the emulator on port %r', grpc_port)
- port = f'localhost:{grpc_port}'
- options = [('grpc.max_send_message_length', -1),
- ('grpc.max_receive_message_length', -1)]
- creds = grpc.local_channel_credentials()
-
- try:
- self._channel = grpc.secure_channel(port, creds, options=options)
- grpc.channel_ready_future(self._channel).result(timeout=timeout_sec)
- except (grpc.RpcError, grpc.FutureTimeoutError) as grpc_error:
- logging.exception('Failed to connect to the emulator.')
- raise EmulatorBootError(
- 'Failed to connect to the emulator.') from grpc_error
-
- logging.info('Added gRPC channel for the Emulator on port %s', port)
- emulator_controller_stub = (
- emulator_controller_pb2_grpc.EmulatorControllerStub(self._channel)
- )
- snapshot_stub = snapshot_service_pb2_grpc.SnapshotServiceStub(self._channel)
- return emulator_controller_stub, snapshot_stub
-
- @_reconnect_on_grpc_error
- def _confirm_booted(self, startup_wait_time_sec: int = 300):
- """Waits until the emulator is fully booted."""
-
- assert (
- self._emulator_stub is not None
- ), 'Emulator stub has not been initialized yet.'
- start_time = time.time()
- deadline = start_time + startup_wait_time_sec
- success = False
- while time.time() < deadline:
- emu_status = self._emulator_stub.getStatus(empty_pb2.Empty())
- logging.info('Waiting for emulator (%r) to start... (%rms)',
- self.adb_device_name(), emu_status.uptime)
- if emu_status.booted:
- success = True
- break
- time.sleep(5.0)
-
- elapsed_time = time.time() - start_time
- if not success:
- raise EmulatorCrashError(
- f'The emulator failed to boot after {startup_wait_time_sec} seconds')
-
- logging.info('Done booting the emulator (in %f seconds).', elapsed_time)
- logging.info('********** Emulator logs **********')
- for line in self.get_logs().splitlines():
- logging.info(line)
- logging.info('******* End of emulator logs *******')
- logging.info('See the full emulator logs at %r', self._logfile_path)
-
- @_reconnect_on_grpc_error
- def send_touch(self, touches: list[tuple[int, int, bool, int]]) -> None:
- """Sends a touch event to be executed on the simulator.
-
- Args:
- touches: A list of touch events. Each element in the list corresponds to a
- single touch event. Each touch event tuple should have:
- 0 x: The horizontal coordinate of this event.
- 1 y: The vertical coordinate of this event.
- 2 is_down: Whether the finger is touching or not the screen.
- 3 identifier: Identifies a particular finger in a multitouch event.
- """
-
- assert (
- self._emulator_stub is not None
- ), 'Emulator stub has not been initialized yet.'
- touch_events = [
- emulator_controller_pb2.Touch(
- x=t[0], y=t[1], pressure=int(t[2]), identifier=t[3])
- for t in touches
- ]
- self._emulator_stub.sendTouch(
- emulator_controller_pb2.TouchEvent(touches=touch_events))
-
- @_reconnect_on_grpc_error
- def send_key(self, keycode: np.int32, event_type: str) -> None:
- """Sends a key event to the emulator.
-
- Args:
- keycode: Code representing the desired key press in XKB format.
- See the emulator_controller_pb2 for details.
- event_type: Type of key event to be sent.
- """
-
- event_types = emulator_controller_pb2.KeyboardEvent.KeyEventType.keys()
- if event_type not in event_types:
- raise ValueError(
- f'Event type must be one of {event_types} but is {event_type}.')
-
- assert (
- self._emulator_stub is not None
- ), 'Emulator stub has not been initialized yet.'
- self._emulator_stub.sendKey(
- emulator_controller_pb2.KeyboardEvent(
- codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
- eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType.Value(
- event_type
- ),
- keyCode=int(keycode),
- )
- )
-
- @_reconnect_on_grpc_error
- def _get_screenshot_impl(self) -> np.ndarray:
- """Fetches the latest screenshot from the emulator."""
-
- assert (
- self._emulator_stub is not None
- ), 'Emulator stub has not been initialized yet.'
- assert self._image_format, 'ImageFormat has not been initialized yet.'
- image_proto = self._emulator_stub.getScreenshot(self._image_format)
- h, w = image_proto.format.height, image_proto.format.width
- image = np.frombuffer(image_proto.image, dtype='uint8', count=h * w * 4)
- image.shape = (h, w, 4)
- return image[:, :, :3]
-
- @_reconnect_on_grpc_error
- def _shutdown_emulator(self):
- """Sends a signal to trigger emulator shutdown."""
-
- if self._emulator_stub is None:
- logging.info('Emulator (%r) is not up.', self.adb_device_name())
- return
-
- assert self._launcher is not None, 'Launcher is already down.'
-
- logging.info('Shutting down the emulator (%r)...', self.adb_device_name())
- self._emulator_stub.setVmState(
- emulator_controller_pb2.VmRunState(
- state=emulator_controller_pb2.VmRunState.RunState.SHUTDOWN))
- self._launcher.confirm_shutdown()
-
- def close(self):
- super().close()
-
- if self._launcher is not None:
- self._shutdown_emulator()
- logging.info('Closing emulator (%r)', self.adb_device_name())
- self._launcher.close()
- self._emulator_stub = None
- self._snapshot_stub = None
- if self._channel is not None:
- self._channel.close()
- super().close()
diff --git a/android_env/components/simulators/emulator/emulator_simulator_test.py b/android_env/components/simulators/emulator/emulator_simulator_test.py
deleted file mode 100644
index 97e24711..00000000
--- a/android_env/components/simulators/emulator/emulator_simulator_test.py
+++ /dev/null
@@ -1,543 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.components.emulator_simulator."""
-
-import builtins
-import os
-import time
-from unittest import mock
-
-from absl.testing import absltest
-from android_env.components import adb_call_parser
-from android_env.components import adb_controller
-from android_env.components import config_classes
-from android_env.components.simulators.emulator import emulator_launcher
-from android_env.components.simulators.emulator import emulator_simulator
-from android_env.proto import state_pb2
-import grpc
-from PIL import Image
-import portpicker
-
-from android_env.proto import emulator_controller_pb2
-from android_env.proto import emulator_controller_pb2_grpc
-from android_env.proto import snapshot_service_pb2
-
-
-class EmulatorSimulatorTest(absltest.TestCase):
-
- def setUp(self):
- super().setUp()
- self.addCleanup(mock.patch.stopall) # Disable previous patches.
-
- self._adb_controller = mock.create_autospec(adb_controller.AdbController)
- self._adb_call_parser = mock.create_autospec(adb_call_parser.AdbCallParser)
- self._launcher = mock.create_autospec(emulator_launcher.EmulatorLauncher)
- self._launcher.logfile_path.return_value = 'logfile_path'
- self._emulator_stub = mock.create_autospec(
- emulator_controller_pb2_grpc.EmulatorControllerStub)
-
- self._grpc_channel = mock.create_autospec(grpc.Channel)
- mock.patch.object(
- grpc.aio, 'secure_channel', return_value=self._grpc_channel).start()
- mock.patch.object(
- grpc, 'secure_channel', return_value=self._grpc_channel).start()
- mock.patch.object(
- grpc, 'local_channel_credentials',
- return_value=self._grpc_channel).start()
- self._mock_future = mock.create_autospec(grpc.Future)
- mock.patch.object(
- grpc, 'channel_ready_future', return_value=self._mock_future).start()
- mock.patch.object(time, 'time', return_value=12345).start()
-
- mock.patch.object(
- adb_controller, 'AdbController',
- return_value=self._adb_controller).start()
- mock.patch.object(
- adb_call_parser,
- 'AdbCallParser',
- autospec=True,
- return_value=self._adb_call_parser).start()
- mock.patch.object(
- emulator_launcher, 'EmulatorLauncher',
- return_value=self._launcher).start()
-
- def test_adb_device_name_not_empty(self):
- config = config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- grpc_port=1234, tmp_dir=self.create_tempdir().full_path
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- )
- simulator = emulator_simulator.EmulatorSimulator(config)
- self.assertNotEmpty(simulator.adb_device_name())
-
- @mock.patch.object(os.path, 'exists', autospec=True, return_value=True)
- @mock.patch.object(builtins, 'open', autospec=True)
- def test_logfile_path(self, mock_open, unused_mock_exists):
- config = config_classes.EmulatorConfig(
- logfile_path='fake/logfile/path',
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- grpc_port=1234, tmp_dir=self.create_tempdir().full_path
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- )
- simulator = emulator_simulator.EmulatorSimulator(config)
- mock_open.return_value.__enter__.return_value.read.return_value = (
- 'fake_logs'.encode('utf-8'))
- logs = simulator.get_logs()
- mock_open.assert_called_once_with('fake/logfile/path', 'rb')
- self.assertEqual(logs, 'fake_logs')
-
- @mock.patch.object(portpicker, 'is_port_free', return_value=True)
- def test_grpc_port(self, unused_mock_portpicker):
-
- launcher_config = config_classes.EmulatorLauncherConfig(
- tmp_dir=self.create_tempdir().full_path
- )
- config = config_classes.EmulatorConfig(
- emulator_launcher=launcher_config,
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- )
- simulator = emulator_simulator.EmulatorSimulator(config)
- self.assertEqual(launcher_config.grpc_port, 8554)
-
- @mock.patch.object(portpicker, 'is_port_free', return_value=False)
- def test_grpc_port_unavailable(self, unused_mock_portpicker):
-
- launcher_config = config_classes.EmulatorLauncherConfig(
- tmp_dir=self.create_tempdir().full_path
- )
- config = config_classes.EmulatorConfig(
- emulator_launcher=launcher_config,
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- )
- simulator = emulator_simulator.EmulatorSimulator(config)
- self.assertNotEqual(launcher_config.grpc_port, 8554)
-
- def test_launch_operation_order(self):
- """Makes sure that adb_controller is started before Emulator is launched."""
-
- # Arrange.
- call_order = []
- self._adb_controller.init_server.side_effect = lambda: call_order.append(
- 'init_server'
- )
- self._launcher.launch_emulator_process.side_effect = (
- lambda: call_order.append('launch_emulator_process')
- )
- config = config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- grpc_port=1234, tmp_dir=self.create_tempdir().full_path
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- )
- simulator = emulator_simulator.EmulatorSimulator(config)
-
- # Act.
- simulator.launch() # The simulator should launch and not crash.
-
- # Assert.
- # The adb server should be initialized before launching the emulator.
- self.assertEqual(call_order, ['init_server', 'launch_emulator_process'])
-
- def test_close(self):
- config = config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- grpc_port=1234, tmp_dir=self.create_tempdir().full_path
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- )
- simulator = emulator_simulator.EmulatorSimulator(config)
-
- # The simulator should launch and not crash.
- simulator.launch()
-
- # For whatever reason clients may want to close the EmulatorSimulator.
- # We just want to check that the simulator does not crash and/or leak
- # resources.
- simulator.close()
-
- def test_value_error_if_launch_attempt_params_incorrect(self):
- self.assertRaises(
- ValueError,
- emulator_simulator.EmulatorSimulator,
- config=config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- grpc_port=1234, tmp_dir=self.create_tempdir().full_path
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- launch_n_times_without_reboot=2,
- launch_n_times_without_reinstall=1,
- ),
- )
-
- def test_launch_attempt_reboot(self):
- config = config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- grpc_port=1234, tmp_dir=self.create_tempdir().full_path
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- launch_n_times_without_reboot=1,
- launch_n_times_without_reinstall=2,
- )
- simulator = emulator_simulator.EmulatorSimulator(config)
-
- # The simulator should launch and not crash.
- simulator.launch()
-
- self._launcher.launch_emulator_process.assert_called_once()
- self._launcher.reset_mock()
-
- # Launch attempt 2.
- simulator.launch()
- self._launcher.confirm_shutdown.assert_called_once()
- self._launcher.close.assert_not_called()
- self._launcher.launch_emulator_process.assert_called_once()
-
- def test_launch_attempt_reinstall_after_zero_attempts(self):
- config = config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- grpc_port=1234, tmp_dir=self.create_tempdir().full_path
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- launch_n_times_without_reboot=0,
- launch_n_times_without_reinstall=0,
- )
- simulator = emulator_simulator.EmulatorSimulator(config)
-
- # The simulator should not reboot or reinstall on its very first launch.
- simulator.launch()
- self._launcher.launch_emulator_process.assert_called_once()
- self._launcher.confirm_shutdown.assert_not_called()
- self._launcher.close.assert_not_called()
-
- # Every subsequent attempt should reboot and reinstall.
- self._launcher.reset_mock()
- simulator.launch()
- self._launcher.confirm_shutdown.assert_called_once()
- self._launcher.close.assert_called_once() # Now this should `close()`.
- self._launcher.launch_emulator_process.assert_called_once()
-
- def test_launch_attempt_reinstall(self):
- config = config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- grpc_port=1234, tmp_dir=self.create_tempdir().full_path
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- launch_n_times_without_reboot=1,
- launch_n_times_without_reinstall=2,
- )
- simulator = emulator_simulator.EmulatorSimulator(config)
-
- # The simulator should launch and not crash.
- simulator.launch()
- self._launcher.launch_emulator_process.assert_called_once()
-
- # Launch attempt 2.
- self._launcher.reset_mock()
- simulator.launch()
- self._launcher.confirm_shutdown.assert_called_once()
- self._launcher.close.assert_not_called() # Reboots don't `close()`.
- self._launcher.launch_emulator_process.assert_called_once()
-
- # Launch attempt 3.
- self._launcher.reset_mock()
- simulator.launch()
- self._launcher.confirm_shutdown.assert_called_once()
- self._launcher.close.assert_called_once() # Now this should `close()`.
- self._launcher.launch_emulator_process.assert_called_once()
-
- def test_get_screenshot(self):
- config = config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- grpc_port=1234, tmp_dir=self.create_tempdir().full_path
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- )
- simulator = emulator_simulator.EmulatorSimulator(config)
-
- # The simulator should launch and not crash.
- simulator.launch()
-
- simulator._emulator_stub.getScreenshot = mock.MagicMock(
- return_value=emulator_controller_pb2.Image(
- format=emulator_controller_pb2.ImageFormat(width=5678, height=1234),
- image=Image.new('RGBA', (1234, 5678)).tobytes(),
- timestampUs=123))
-
- screenshot = simulator.get_screenshot()
- # The screenshot should have the same screen dimensions as reported by ADB
- # and it should have 3 channels (RGB).
- self.assertEqual(screenshot.shape, (1234, 5678, 3))
-
- def test_load_state(self):
- config = config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- grpc_port=1234, tmp_dir=self.create_tempdir().full_path
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- )
- simulator = emulator_simulator.EmulatorSimulator(config)
-
- # The simulator should launch and not crash.
- simulator.launch()
-
- with mock.patch.object(
- simulator, '_snapshot_stub', create_autospec=True
- ) as mock_snapshot_stub:
- snapshot_list = snapshot_service_pb2.SnapshotList()
- snapshot_list.snapshots.add(snapshot_id='snapshot_name_foo')
- snapshot_list.snapshots.add(snapshot_id='snapshot_name_bar')
- mock_snapshot_stub.ListSnapshots.return_value = snapshot_list
- mock_snapshot_stub.LoadSnapshot.return_value = (
- snapshot_service_pb2.SnapshotPackage(success=True)
- )
- load_response = simulator.load_state(
- request=state_pb2.LoadStateRequest(
- args={'snapshot_name': 'snapshot_name_foo'}
- )
- )
- self.assertEqual(
- load_response.status, state_pb2.LoadStateResponse.Status.OK
- )
- load_response = simulator.load_state(
- request=state_pb2.LoadStateRequest(
- args={'snapshot_name': 'snapshot_name_baz'}
- )
- )
- self.assertEqual(
- load_response.status, state_pb2.LoadStateResponse.Status.NOT_FOUND
- )
- mock_snapshot_stub.LoadSnapshot.return_value = (
- snapshot_service_pb2.SnapshotPackage(success=False, err=b'error')
- )
- load_response = simulator.load_state(
- request=state_pb2.LoadStateRequest(
- args={'snapshot_name': 'snapshot_name_bar'}
- )
- )
- self.assertEqual(
- load_response.status, state_pb2.LoadStateResponse.Status.ERROR
- )
- self.assertEqual(load_response.error_message, 'error')
-
- def test_save_state(self):
- config = config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- grpc_port=1234, tmp_dir=self.create_tempdir().full_path
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- )
- simulator = emulator_simulator.EmulatorSimulator(config)
-
- # The simulator should launch and not crash.
- simulator.launch()
-
- with mock.patch.object(
- simulator, '_snapshot_stub', create_autospec=True
- ) as mock_snapshot_stub:
- mock_snapshot_stub.SaveSnapshot.return_value = (
- snapshot_service_pb2.SnapshotPackage(success=True)
- )
- save_response = simulator.save_state(
- request=state_pb2.SaveStateRequest(
- args={'snapshot_name': 'snapshot_name_foo'}
- )
- )
- self.assertEqual(
- save_response.status, state_pb2.SaveStateResponse.Status.OK
- )
- mock_snapshot_stub.SaveSnapshot.return_value = (
- snapshot_service_pb2.SnapshotPackage(success=False, err=b'error')
- )
- save_response = simulator.save_state(
- request=state_pb2.SaveStateRequest(
- args={'snapshot_name': 'snapshot_name_bar'}
- )
- )
- self.assertEqual(
- save_response.status, state_pb2.SaveStateResponse.Status.ERROR
- )
- self.assertEqual(save_response.error_message, 'error')
-
- def test_send_touch(self):
- config = config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- grpc_port=1234, tmp_dir=self.create_tempdir().full_path
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- )
- simulator = emulator_simulator.EmulatorSimulator(config)
-
- # The simulator should launch and not crash.
- simulator.launch()
-
- simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None)
-
- simulator.send_touch([(123, 456, True, 0), (135, 246, True, 1)])
- simulator.send_touch([(1, 2, True, 0), (3, 4, True, 1)])
- simulator.send_touch([(321, 654, False, 0), (531, 642, False, 1)])
-
- simulator._emulator_stub.sendTouch.assert_has_calls([
- mock.call(
- emulator_controller_pb2.TouchEvent(touches=[{
- 'x': 123,
- 'y': 456,
- 'pressure': 1
- }, {
- 'x': 135,
- 'y': 246,
- 'pressure': 1,
- 'identifier': 1
- }])),
- mock.call(
- emulator_controller_pb2.TouchEvent(touches=[{
- 'x': 1,
- 'y': 2,
- 'pressure': 1
- }, {
- 'x': 3,
- 'y': 4,
- 'pressure': 1,
- 'identifier': 1
- }])),
- mock.call(
- emulator_controller_pb2.TouchEvent(touches=[{
- 'x': 321,
- 'y': 654,
- 'pressure': 0
- }, {
- 'x': 531,
- 'y': 642,
- 'pressure': 0,
- 'identifier': 1
- }])),
- ])
-
- def test_send_key(self):
- config = config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- grpc_port=1234, tmp_dir=self.create_tempdir().full_path
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='/my/adb',
- adb_server_port=5037,
- ),
- )
- simulator = emulator_simulator.EmulatorSimulator(config)
-
- # The simulator should launch and not crash.
- simulator.launch()
-
- simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None)
-
- simulator.send_key(123, 'keydown')
- simulator.send_key(321, 'keydown')
- simulator.send_key(321, 'keyup')
- simulator.send_key(123, 'keyup')
- simulator.send_key(321, 'keypress')
- simulator.send_key(123, 'keypress')
-
- simulator._emulator_stub.sendKey.assert_has_calls([
- mock.call(
- emulator_controller_pb2.KeyboardEvent(
- codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
- eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType
- .keydown,
- keyCode=123,
- )),
- mock.call(
- emulator_controller_pb2.KeyboardEvent(
- codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
- eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType
- .keydown,
- keyCode=321,
- )),
- mock.call(
- emulator_controller_pb2.KeyboardEvent(
- codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
- eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType
- .keyup,
- keyCode=321,
- )),
- mock.call(
- emulator_controller_pb2.KeyboardEvent(
- codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
- eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType
- .keyup,
- keyCode=123,
- )),
- mock.call(
- emulator_controller_pb2.KeyboardEvent(
- codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
- eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType
- .keypress,
- keyCode=321,
- )),
- mock.call(
- emulator_controller_pb2.KeyboardEvent(
- codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
- eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType
- .keypress,
- keyCode=123,
- ))
- ])
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/simulators/fake/__init__.py b/android_env/components/simulators/fake/__init__.py
deleted file mode 100644
index 2f66bf75..00000000
--- a/android_env/components/simulators/fake/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
diff --git a/android_env/components/simulators/fake/fake_simulator.py b/android_env/components/simulators/fake/fake_simulator.py
deleted file mode 100644
index 79996c8b..00000000
--- a/android_env/components/simulators/fake/fake_simulator.py
+++ /dev/null
@@ -1,138 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Fake Simulator for testing AndroidEnv infrastructure."""
-
-import random
-import threading
-import time
-
-from absl import logging
-from android_env.components import adb_controller
-from android_env.components import config_classes
-from android_env.components import log_stream
-from android_env.components.simulators import base_simulator
-import numpy as np
-
-
-class FakeStream:
- """This class simulates the logs coming from ADB."""
-
- def __init__(self):
- self._values = [
- '',
- self._make_stdout('reward: 0.5'),
- self._make_stdout('reward: 1.0'),
- self._make_stdout('extra: my_extra [1.0]'),
- self._make_stdout('episode end'),
- ]
- self._kill = False
- self._lock = threading.Lock()
-
- def _make_stdout(self, data):
- """Returns a valid log output with given data as message."""
- return f' 1553110400.424 5583 5658 D Tag: {data}'
-
- def kill(self):
- self._kill = True
-
- def __iter__(self):
- while True:
- if self._kill:
- return
- else:
- with self._lock:
- next_value = random.choices(
- self._values, weights=[0.49, 0.15, 0.15, 0.15, 0.01], k=1)[0]
- time.sleep(0.1)
- yield next_value
-
-
-class FakeLogStream(log_stream.LogStream):
- """FakeLogStream class that wraps a FakeStream."""
-
- def __init__(self):
- super().__init__(verbose=False)
- self.stream = FakeStream()
-
- def _get_stream_output(self):
- return self.stream
-
- def stop_stream(self):
- self.stream.kill()
-
-
-class FakeAdbController(adb_controller.AdbController):
- """Fake adb controller for FakeSimulator."""
-
- def execute_command(
- self,
- args: list[str],
- timeout: float | None = None,
- device_specific: bool = True,
- ) -> bytes:
- """Returns fake output for adb commands."""
-
- del timeout, device_specific
-
- # Fake "service is ready" output.
- if args[:3] == ['shell', 'service', 'check']:
- return f'Service {args[-1]}: found'.encode('utf-8')
-
- # Fake dumpsys output for getting orientation.
- if args == ['shell', 'dumpsys', 'input']:
- return b' SurfaceOrientation: 0'
-
- # app_screen_checker: fake_task expects 'fake_activity'.
- if args[:4] == ['shell', 'am', 'stack', 'list']:
- return (b'taskId=0 fake_activity visible=true '
- b'topActivity=ComponentInfo{fake_activity}')
-
- return b'fake output'
-
-
-class FakeSimulator(base_simulator.BaseSimulator):
- """FakeSimulator class."""
-
- def __init__(self, config: config_classes.FakeSimulatorConfig):
- """FakeSimulator class that can replace EmulatorSimulator in AndroidEnv."""
- super().__init__(config)
- self._screen_dimensions = np.array(config.screen_dimensions)
- logging.info('Created FakeSimulator.')
-
- def get_logs(self) -> str:
- return 'FakeSimulator: fake logs'
-
- def adb_device_name(self) -> str:
- return 'fake_simulator'
-
- def create_adb_controller(self):
- return FakeAdbController(config_classes.AdbControllerConfig())
-
- def create_log_stream(self) -> log_stream.LogStream:
- return FakeLogStream()
-
- def _launch_impl(self) -> None:
- pass
-
- def send_touch(self, touches: list[tuple[int, int, bool, int]]) -> None:
- del touches
-
- def send_key(self, keycode: np.int32, event_type: str) -> None:
- del keycode, event_type
-
- def _get_screenshot_impl(self) -> np.ndarray:
- return np.random.randint(
- low=0, high=255, size=(*self._screen_dimensions, 3), dtype=np.uint8)
diff --git a/android_env/components/simulators/fake/fake_simulator_test.py b/android_env/components/simulators/fake/fake_simulator_test.py
deleted file mode 100644
index 5231433e..00000000
--- a/android_env/components/simulators/fake/fake_simulator_test.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for fake_simulator."""
-
-import re
-from absl.testing import absltest
-from android_env.components import config_classes
-from android_env.components.simulators.fake import fake_simulator
-import numpy as np
-
-
-class FakeSimulatorTest(absltest.TestCase):
-
- def test_device_name(self):
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
- )
- self.assertEqual(simulator.adb_device_name(), 'fake_simulator')
-
- def test_launch_close(self):
- # The simulator should launch and not crash.
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
- )
- simulator.launch()
- # Closing the simulator should also not crash.
- simulator.close()
-
- def test_get_screenshot(self):
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
- )
- simulator.launch()
-
- screenshot = simulator.get_screenshot()
- np.testing.assert_equal(screenshot.shape, [320, 480, 3])
- np.testing.assert_equal(screenshot.dtype, np.uint8)
-
- def test_log_stream(self):
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
- )
- simulator.launch()
- log_stream = simulator.create_log_stream()
- # Start yielding lines from LogStream.
- log_stream.resume_stream()
- lines = [
- '',
- ' 1553110400.424 5583 5658 D Tag: reward: 0.5',
- ' 1553110400.424 5583 5658 D Tag: reward: 1.0',
- ' 1553110400.424 5583 5658 D Tag: extra: my_extra [1.0]',
- ' 1553110400.424 5583 5658 D Tag: episode end',
- ]
- for i, line in enumerate(log_stream.get_stream_output()):
- self.assertIn(line, lines)
- if i > 10:
- break
-
- def test_adb_output(self):
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
- )
- simulator.launch()
- adb_controller = simulator.create_adb_controller()
- line = adb_controller.execute_command(['shell', 'dumpsys', 'input'])
- line = line.decode('utf-8')
- matches = re.match(r'\s+SurfaceOrientation:\s+(\d)', line)
- self.assertIsNotNone(matches)
- orientation = matches.group(1)
- self.assertEqual(orientation, '0')
- line = adb_controller.execute_command(['shell', 'service', 'check', 'foo'])
- line = line.decode('utf-8')
- self.assertEqual(line, 'Service foo: found')
- line = adb_controller.execute_command(['shell', 'am', 'stack', 'list'])
- line = line.decode('utf-8')
- self.assertEqual(line, 'taskId=0 fake_activity visible=true '
- 'topActivity=ComponentInfo{fake_activity}')
-
- def test_send_touch(self):
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
- )
- simulator.launch()
- simulator.send_touch([(0, 1, True, 0)])
- simulator.send_touch([(0, 1, False, 0)])
- # No assertions, we just want to ensure that `send_touch()` can be called
- # without crashing anything.
-
- def test_send_key(self):
- simulator = fake_simulator.FakeSimulator(
- config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
- )
- simulator.launch()
- simulator.send_key(np.int32(123), 'keydown')
- simulator.send_key(np.int32(123), 'keyup')
- simulator.send_key(np.int32(123), 'keypress')
- # No assertions, we just want to ensure that `send_key()` can be called
- # without crashing anything.
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/specs.py b/android_env/components/specs.py
deleted file mode 100644
index d17aa7c0..00000000
--- a/android_env/components/specs.py
+++ /dev/null
@@ -1,138 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Base specs for AndroidEnv."""
-
-from android_env.components import action_type
-from android_env.proto import task_pb2
-from dm_env import specs
-import numpy as np
-
-
-_PROTO_DTYPE_TO_NUMPY_DTYPE = {
- task_pb2.ArraySpec.DataType.FLOAT: np.float32,
- task_pb2.ArraySpec.DataType.DOUBLE: np.float64,
- task_pb2.ArraySpec.DataType.INT8: np.int8,
- task_pb2.ArraySpec.DataType.INT16: np.int16,
- task_pb2.ArraySpec.DataType.INT32: np.int32,
- task_pb2.ArraySpec.DataType.INT64: np.int64,
- task_pb2.ArraySpec.DataType.UINT8: np.uint8,
- task_pb2.ArraySpec.DataType.UINT16: np.uint16,
- task_pb2.ArraySpec.DataType.UINT32: np.uint32,
- task_pb2.ArraySpec.DataType.UINT64: np.uint64,
- task_pb2.ArraySpec.DataType.BOOL: np.bool_,
- task_pb2.ArraySpec.DataType.STRING_U1: np.dtype(('U1')),
- task_pb2.ArraySpec.DataType.STRING_U16: np.dtype((' dict[str, specs.Array]:
- """Default action spec for AndroidEnv.
-
- Args:
- num_fingers: Number of virtual fingers of the agent.
- enable_key_events: Whether keyboard key events are enabled.
-
- Returns:
- A dict of action specs, each item corresponding to a virtual finger.
- action_type: An integer of type ActionType: TOUCH=0, LIFT=1, REPEAT=2
- touch_position: Position [x, y] of the touch action, where x, y are float
- values between 0.0 and 1.0 corresponding to the relative position on the
- screen. IGNORED when (action_type != ActionType.TOUCH).
- keycode: code representing the desired key press in XKB format. See the
- emulator_controller_pb2 for details.
- action_type_i: Action type for additional fingers (i>1).
- touch_position_i: Touch position for additional fingers (i>1).
- """
-
- num_actions = len(action_type.ActionType) if enable_key_events else 3
-
- action_spec = {
- 'action_type':
- specs.DiscreteArray(num_values=num_actions, name='action_type'),
- 'touch_position':
- specs.BoundedArray(
- shape=(2,),
- dtype=np.float32,
- minimum=[0.0, 0.0],
- maximum=[1.0, 1.0],
- name='touch_position'),
- }
-
- for i in range(2, num_fingers + 1):
- action_spec.update({
- f'action_type_{i}':
- specs.DiscreteArray(
- num_values=len(action_type.ActionType),
- name=f'action_type_{i}'),
- f'touch_position_{i}':
- specs.BoundedArray(
- shape=(2,),
- dtype=np.float32,
- minimum=[0.0, 0.0],
- maximum=[1.0, 1.0],
- name=f'touch_position_{i}'),
- })
-
- if enable_key_events:
- action_spec['keycode'] = specs.DiscreteArray(
- num_values=(1 << 16) - 1, name='keycode')
-
- return action_spec
-
-
-def base_observation_spec(height: int, width: int) -> dict[str, specs.Array]:
- """Default observation spec for AndroidEnv.
-
- Args:
- height: Height of the device screen in pixels.
- width: Width of the device screen in pixels.
-
- Returns:
- pixels: Spec for the RGB screenshot of the device. Has shape (H, W, 3)
- timedelta: Spec for time delta since the last observation (in microseconds).
- The first timestep immediately after reset() will have this value set to
- 0.
- orientation: Spec for the latest orientation in a one-hot representation:
- [1, 0, 0, 0]: PORTRAIT (0 degrees)
- [0, 1, 0, 0]: LANDSCAPE (90 degrees clockwise)
- [0, 0, 1, 0]: PORTRAIT (180 degrees) ("upside down")
- [0, 0, 0, 1]: LANDSCAPE (270 degrees clockwise)
- """
-
- return {
- 'pixels':
- specs.BoundedArray(
- shape=(height, width, 3),
- dtype=np.uint8,
- name='pixels',
- minimum=0,
- maximum=255),
- 'timedelta':
- specs.Array(shape=(), dtype=np.int64, name='timedelta'),
- 'orientation':
- specs.BoundedArray(
- shape=np.array([4]),
- dtype=np.uint8,
- name='orientation',
- minimum=0,
- maximum=1),
- }
diff --git a/android_env/components/specs_test.py b/android_env/components/specs_test.py
deleted file mode 100644
index 389eba5e..00000000
--- a/android_env/components/specs_test.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for specs.py."""
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from android_env.components import specs
-from android_env.proto import task_pb2
-from dm_env import specs as dm_env_specs
-import numpy as np
-
-
-class SpecsTest(parameterized.TestCase):
-
- def test_base_action_spec(self):
- action_spec = specs.base_action_spec(num_fingers=1)
- for spec in action_spec.values():
- self.assertIsInstance(spec, dm_env_specs.Array)
- self.assertEqual(action_spec['action_type'].shape, ())
- self.assertEqual(action_spec['action_type'].dtype, np.int32)
- self.assertEqual(action_spec['touch_position'].shape, (2,))
- self.assertEqual(action_spec['touch_position'].dtype, np.float32)
-
- def test_base_action_spec_with_key_events(self):
- action_spec = specs.base_action_spec(num_fingers=1, enable_key_events=True)
- for spec in action_spec.values():
- self.assertIsInstance(spec, dm_env_specs.Array)
- self.assertEqual(action_spec['action_type'].shape, ())
- self.assertEqual(action_spec['action_type'].dtype, np.int32)
- self.assertEqual(action_spec['touch_position'].shape, (2,))
- self.assertEqual(action_spec['touch_position'].dtype, np.float32)
- self.assertEqual(action_spec['keycode'].shape, ())
- self.assertEqual(action_spec['keycode'].dtype, np.int32)
-
- def test_base_action_spec_multitouch(self):
- action_spec = specs.base_action_spec(num_fingers=3)
- self.assertLen(action_spec.keys(), 6)
- for spec in action_spec.values():
- self.assertIsInstance(spec, dm_env_specs.Array)
- self.assertEqual(action_spec['action_type'].shape, ())
- self.assertEqual(action_spec['action_type'].dtype, np.int32)
- self.assertEqual(action_spec['touch_position'].shape, (2,))
- self.assertEqual(action_spec['touch_position'].dtype, np.float32)
- self.assertEqual(action_spec['action_type_2'].shape, ())
- self.assertEqual(action_spec['action_type_2'].dtype, np.int32)
- self.assertEqual(action_spec['touch_position_2'].shape, (2,))
- self.assertEqual(action_spec['touch_position_2'].dtype, np.float32)
- self.assertEqual(action_spec['action_type_3'].shape, ())
- self.assertEqual(action_spec['action_type_3'].dtype, np.int32)
- self.assertEqual(action_spec['touch_position_3'].shape, (2,))
- self.assertEqual(action_spec['touch_position_3'].dtype, np.float32)
-
- @parameterized.parameters(
- (480, 320),
- (100, 100),
- (1440, 1960),
- )
- def test_base_observation_spec(self, height, width):
- observation_spec = specs.base_observation_spec(height, width)
- for spec in observation_spec.values():
- self.assertIsInstance(spec, dm_env_specs.Array)
- self.assertEqual(observation_spec['pixels'].shape, (height, width, 3))
- self.assertEqual(observation_spec['pixels'].dtype, np.uint8)
- self.assertEqual(observation_spec['timedelta'].shape, ())
- self.assertEqual(observation_spec['timedelta'].dtype, np.int64)
- self.assertEqual(observation_spec['orientation'].shape, (4,))
- self.assertEqual(observation_spec['orientation'].dtype, np.uint8)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/components/task_manager.py b/android_env/components/task_manager.py
deleted file mode 100644
index 81718546..00000000
--- a/android_env/components/task_manager.py
+++ /dev/null
@@ -1,385 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""TaskManager handles all events and information related to the task."""
-
-import ast
-from collections.abc import Callable
-import copy
-import datetime
-import json
-import re
-import threading
-from typing import Any
-
-from absl import logging
-from android_env.components import adb_call_parser as adb_call_parser_lib
-from android_env.components import app_screen_checker
-from android_env.components import config_classes
-from android_env.components import dumpsys_thread
-from android_env.components import log_stream as log_stream_lib
-from android_env.components import logcat_thread
-from android_env.components import setup_step_interpreter
-from android_env.proto import task_pb2
-import dm_env
-import numpy as np
-
-
-class TaskManager:
- """Handles all events and information related to the task."""
-
- def __init__(
- self,
- task: task_pb2.Task,
- config: config_classes.TaskManagerConfig | None = None,
- ):
- """Controls task-relevant events and information.
-
- Args:
- task: A task proto defining the RL task.
- config: Configuration for this instance.
- """
-
- self._task = task
- self._config = config or config_classes.TaskManagerConfig()
- self._lock = threading.Lock()
- self._logcat_thread = None
- self._dumpsys_thread = None
- self._setup_step_interpreter = None
-
- # Initialize stats.
- self._stats = {
- 'episode_steps': 0,
- 'reset_count_step_timeout': 0,
- 'reset_count_user_exited': 0,
- 'reset_count_episode_end': 0,
- 'reset_count_max_duration_reached': 0,
- 'restart_count_max_bad_states': 0,
- 'task_updates': 0,
- }
-
- # Initialize internal state
- self._task_start_time = None
- self._bad_state_counter = 0
- self._is_bad_episode = False
-
- self._latest_values = {
- 'reward': 0.0,
- 'score': 0.0,
- 'extra': {},
- 'episode_end': False,
- }
-
- logging.info('Task config: %s', self._task)
-
- def stats(self) -> dict[str, Any]:
- """Returns a dictionary of stats.
-
- This method is expected to be called after setup_task() has been called.
- """
- output = copy.deepcopy(self._stats)
- if self._setup_step_interpreter is not None:
- output.update(self._setup_step_interpreter.stats())
- return output
-
- def setup_task(self) -> None:
- """Performs one-off task setup.."""
- self._setup_step_interpreter.interpret(self._task.setup_steps)
-
- def stop(self) -> None:
- """Suspends task processing."""
- self._stop_logcat_thread()
-
- def start(
- self,
- adb_call_parser_factory: Callable[[], adb_call_parser_lib.AdbCallParser],
- log_stream: log_stream_lib.LogStream) -> None:
- """Starts task processing."""
-
- self._start_logcat_thread(log_stream=log_stream)
- self._logcat_thread.resume()
- self._start_dumpsys_thread(adb_call_parser_factory())
- self._start_setup_step_interpreter(adb_call_parser_factory())
-
- def reset_task(self) -> None:
- """Resets a task for a new run."""
-
- self._logcat_thread.pause()
- self._setup_step_interpreter.interpret(self._task.reset_steps)
- self._logcat_thread.resume()
-
- # Reset some other variables.
- if not self._is_bad_episode:
- self._bad_state_counter = 0
- self._is_bad_episode = False
-
- self._task_start_time = datetime.datetime.now()
- with self._lock:
- self._latest_values = {
- 'reward': 0.0,
- 'score': 0.0,
- 'extra': {},
- 'episode_end': False,
- }
-
- def rl_reset(self, observation: dict[str, Any]) -> dm_env.TimeStep:
- """Performs one RL step."""
-
- self._stats['episode_steps'] = 0
-
- self._logcat_thread.line_ready().wait()
- with self._lock:
- extras = self._get_current_extras()
-
- observation['extras'] = extras
-
- return dm_env.TimeStep(
- step_type=dm_env.StepType.FIRST,
- reward=0.0,
- discount=0.0,
- observation=observation)
-
- def rl_step(self, observation: dict[str, Any]) -> dm_env.TimeStep:
- """Performs one RL step."""
-
- self._stats['episode_steps'] += 1
-
- self._logcat_thread.line_ready().wait()
- with self._lock:
- reward = self._get_current_reward()
- extras = self._get_current_extras()
- transition_fn = self._determine_transition_fn()
-
- observation['extras'] = extras
-
- return transition_fn(reward=reward, observation=observation)
-
- def _get_current_reward(self) -> float:
- """Returns total reward accumulated since the last step."""
- reward = self._latest_values['reward']
- self._latest_values['reward'] = 0.0
- return reward
-
- def _get_current_extras(self) -> dict[str, Any]:
- """Returns task extras accumulated since the last step."""
- extras = {}
- for name, values in self._latest_values['extra'].items():
- extras[name] = np.stack(values)
- self._latest_values['extra'] = {}
- return extras
-
- def _determine_transition_fn(self) -> Callable[..., dm_env.TimeStep]:
- """Determines the type of RL transition will be used."""
-
- # Check if user existed the task
- if self._dumpsys_thread.check_user_exited():
- self._increment_bad_state()
- self._stats['reset_count_user_exited'] += 1
- logging.warning('User exited the task. Truncating the episode.')
- logging.info('************* END OF EPISODE *************')
- return dm_env.truncation
-
- # Check if episode has ended
- if self._latest_values['episode_end']:
- self._stats['reset_count_episode_end'] += 1
- logging.info('End of episode from logcat! Ending episode.')
- return dm_env.termination
-
- # Check if step limit or time limit has been reached
- if self._task.max_episode_steps > 0:
- if self._stats['episode_steps'] > self._task.max_episode_steps:
- self._stats['reset_count_max_duration_reached'] += 1
- logging.info('Maximum task duration (%r steps) reached. '
- 'Truncating the episode.', self._task.max_episode_steps)
- return dm_env.truncation
-
- if self._task.max_episode_sec > 0.0:
- task_duration = datetime.datetime.now() - self._task_start_time
- max_episode_sec = self._task.max_episode_sec
- if task_duration > datetime.timedelta(seconds=int(max_episode_sec)):
- self._stats['reset_count_max_duration_reached'] += 1
- logging.info('Maximum task duration (%r sec) reached. '
- 'Truncating the episode.', max_episode_sec)
- return dm_env.truncation
-
- return dm_env.transition
-
- def _start_setup_step_interpreter(
- self, adb_call_parser: adb_call_parser_lib.AdbCallParser):
- self._setup_step_interpreter = setup_step_interpreter.SetupStepInterpreter(
- adb_call_parser=adb_call_parser)
-
- def _start_logcat_thread(self, log_stream: log_stream_lib.LogStream):
- log_stream.set_log_filters(list(self._task.log_parsing_config.filters))
- self._logcat_thread = logcat_thread.LogcatThread(log_stream=log_stream)
-
- for event_listener in self._logcat_listeners():
- self._logcat_thread.add_event_listener(event_listener)
-
- def _start_dumpsys_thread(self,
- adb_call_parser: adb_call_parser_lib.AdbCallParser):
- self._dumpsys_thread = dumpsys_thread.DumpsysThread(
- app_screen_checker=app_screen_checker.AppScreenChecker(
- adb_call_parser=adb_call_parser,
- expected_app_screen=self._task.expected_app_screen,
- ),
- check_frequency=self._config.dumpsys_check_frequency,
- max_failed_current_activity=self._config.max_failed_current_activity,
- )
-
- def _stop_logcat_thread(self):
- if self._logcat_thread is not None:
- self._logcat_thread.kill()
- self._logcat_thread = None
-
- def _increment_bad_state(self) -> None:
- """Increments the bad state counter.
-
- Bad states are errors that shouldn't happen and that trigger an
- episode reset. If enough bad states have been seen consecutively,
- we restart the simulation in the hope of returning the simulation
- to a good state.
- """
- logging.warning('Bad state detected.')
- if self._config.max_bad_states:
- self._is_bad_episode = True
- self._bad_state_counter += 1
- logging.warning('Bad state counter: %d.', self._bad_state_counter)
- if self._bad_state_counter >= self._config.max_bad_states:
- logging.error('Too many consecutive bad states. Restarting simulator.')
- self._stats['restart_count_max_bad_states'] += 1
- self._should_restart = True
- else:
- logging.warning('Max bad states not set, bad states will be ignored.')
-
- def _logcat_listeners(self):
- """Creates list of EventListeners for logcat thread."""
-
- # Defaults to 'a^' since that regex matches no string by definition.
- regexps = self._task.log_parsing_config.log_regexps
- listeners = []
-
- # Reward listeners
- def _reward_handler(event, match):
- del event
- reward = float(match.group(1))
- with self._lock:
- self._latest_values['reward'] += reward
-
- for regexp in regexps.reward:
- listeners.append(logcat_thread.EventListener(
- regexp=re.compile(regexp or 'a^'),
- handler_fn=_reward_handler))
-
- # RewardEvent listeners
- for reward_event in regexps.reward_event:
-
- def get_reward_event_handler(reward):
- def _reward_event_handler(event, match):
- del event, match
- with self._lock:
- self._latest_values['reward'] += reward
- return _reward_event_handler
-
- listeners.append(logcat_thread.EventListener(
- regexp=re.compile(reward_event.event or 'a^'),
- handler_fn=get_reward_event_handler(reward_event.reward)))
-
- # Score listener
- def _score_handler(event, match):
- del event
- current_score = float(match.group(1))
- with self._lock:
- current_reward = current_score - self._latest_values['score']
- self._latest_values['score'] = current_score
- self._latest_values['reward'] += current_reward
-
- listeners.append(logcat_thread.EventListener(
- regexp=re.compile(regexps.score or 'a^'),
- handler_fn=_score_handler))
-
- # Episode end listeners
- def _episode_end_handler(event, match):
- del event, match
- with self._lock:
- self._latest_values['episode_end'] = True
-
- for regexp in regexps.episode_end:
- listeners.append(logcat_thread.EventListener(
- regexp=re.compile(regexp or 'a^'),
- handler_fn=_episode_end_handler))
-
- # Extra listeners
- def _extras_handler(event, match):
- del event
- extra_name = match.group('name')
- extra = match.group('extra')
- if extra:
- try:
- extra = ast.literal_eval(extra)
- except (
- ValueError,
- TypeError,
- SyntaxError,
- MemoryError,
- RecursionError,
- ):
- logging.exception('Could not parse extra: %s', extra)
- # Don't try to process the extra as text; that would probably crash.
- return
- else:
- # No extra value provided for boolean extra. Setting value to True.
- extra = 1
- _process_extra(extra_name, extra)
-
- for regexp in regexps.extra:
- listeners.append(logcat_thread.EventListener(
- regexp=re.compile(regexp or 'a^'),
- handler_fn=_extras_handler))
-
- # JSON extra listeners
- def _json_extras_handler(event, match):
- del event
- extra_data = match.group('json_extra')
- try:
- extra = dict(json.loads(extra_data))
- except ValueError:
- logging.error('JSON string could not be parsed: %s', extra_data)
- return
- for extra_name, extra_value in extra.items():
- _process_extra(extra_name, extra_value)
-
- for regexp in regexps.json_extra:
- listeners.append(logcat_thread.EventListener(
- regexp=re.compile(regexp or 'a^'),
- handler_fn=_json_extras_handler))
-
- def _process_extra(extra_name, extra):
- extra = np.array(extra)
- with self._lock:
- latest_extras = self._latest_values['extra']
- if extra_name in latest_extras:
- # If latest extra is not flushed, append.
- if (
- len(latest_extras[extra_name])
- >= self._config.extras_max_buffer_size
- ):
- latest_extras[extra_name].pop(0)
- latest_extras[extra_name].append(extra)
- else:
- latest_extras[extra_name] = [extra]
- self._latest_values['extra'] = latest_extras
-
- return listeners
diff --git a/android_env/components/task_manager_test.py b/android_env/components/task_manager_test.py
deleted file mode 100644
index 471ecfd9..00000000
--- a/android_env/components/task_manager_test.py
+++ /dev/null
@@ -1,415 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.components.task_manager.py."""
-
-import json
-from unittest import mock
-
-from absl.testing import absltest
-from android_env.components import adb_call_parser as adb_call_parser_lib
-from android_env.components import dumpsys_thread
-from android_env.components import log_stream
-from android_env.components import logcat_thread
-from android_env.components import setup_step_interpreter
-from android_env.components import task_manager
-from android_env.proto import task_pb2
-import numpy as np
-
-
-class TaskManagerTest(absltest.TestCase):
-
- def setUp(self):
- super().setUp()
- self.addCleanup(mock.patch.stopall) # Disable previous patches.
-
- self._setup_step_interpreter = mock.create_autospec(
- setup_step_interpreter.SetupStepInterpreter)
- self._dumpsys_thread = mock.create_autospec(dumpsys_thread.DumpsysThread)
- self._logcat_thread = mock.create_autospec(logcat_thread.LogcatThread)
- self._log_stream = mock.create_autospec(log_stream.LogStream)
-
- mock.patch.object(
- setup_step_interpreter,
- 'SetupStepInterpreter',
- return_value=self._setup_step_interpreter).start()
- mock.patch.object(
- dumpsys_thread, 'DumpsysThread',
- return_value=self._dumpsys_thread).start()
- mock.patch.object(
- logcat_thread, 'LogcatThread',
- return_value=self._logcat_thread).start()
- mock.patch.object(
- log_stream, 'LogStream',
- return_value=self._log_stream).start()
-
- def test_start(self):
- task_mgr = task_manager.TaskManager(task=task_pb2.Task())
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
- task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
- self.assertIsNotNone(task_mgr._logcat_thread)
- self.assertIsNotNone(task_mgr._dumpsys_thread)
- self.assertIsNotNone(task_mgr._setup_step_interpreter)
-
- def test_setup_task(self):
- task_mgr = task_manager.TaskManager(task=task_pb2.Task())
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
- task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
- task_mgr.setup_task()
- self._setup_step_interpreter.interpret.assert_called_once()
-
- def test_step_count(self):
- task_mgr = task_manager.TaskManager(task=task_pb2.Task())
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
- task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
- task_mgr.setup_task()
- task_mgr.rl_reset(observation={})
- self.assertEqual(task_mgr.stats()['episode_steps'], 0)
- task_mgr.rl_step(observation={})
- self.assertEqual(task_mgr.stats()['episode_steps'], 1)
- task_mgr.rl_step(observation={})
- self.assertEqual(task_mgr.stats()['episode_steps'], 2)
- task_mgr.rl_reset(observation={})
- self.assertEqual(task_mgr.stats()['episode_steps'], 0)
-
- def test_get_current_reward(self):
- # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
- # right away.
- def my_add_ev_listener(event_listener: logcat_thread.EventListener):
- # Check that the event matches what's expected.
- match = event_listener.regexp.match('Reward: 123.0')
- if match is None: # Ignore events that are not rewards.
- return
-
- event_listener.handler_fn(event_listener.regexp, match)
-
- task = task_pb2.Task()
- task.log_parsing_config.log_regexps.reward.extend([
- '^[Rr]eward: ([-+]?[0-9]*\\.?[0-9]*)$'
- ])
- task_mgr = task_manager.TaskManager(task=task)
- self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
- task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
- task_mgr.setup_task()
- timestep = task_mgr.rl_step(
- observation={
- 'pixels': np.array([1, 2, 3]),
- })
- self.assertEqual(timestep.reward, 123.0)
- np.testing.assert_equal(timestep.observation['pixels'], np.array([1, 2, 3]))
-
- def test_reward_event(self):
- # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
- # right away.
- def my_add_ev_listener(event_listener: logcat_thread.EventListener):
- # Check that the event matches what's expected.
- match_1 = event_listener.regexp.match('foo_1')
- match_2 = event_listener.regexp.match('foo_2')
- match_3 = event_listener.regexp.match('Reward: 2.0')
- if match_1:
- event_listener.handler_fn(event_listener.regexp, match_1)
- if match_2:
- event_listener.handler_fn(event_listener.regexp, match_2)
- if match_3:
- event_listener.handler_fn(event_listener.regexp, match_3)
-
- task = task_pb2.Task()
- reward_event_1 = task_pb2.LogParsingConfig.LogRegexps.RewardEvent(
- event='foo_1', reward=5.0)
- reward_event_2 = task_pb2.LogParsingConfig.LogRegexps.RewardEvent(
- event='foo_2', reward=-1.0)
- task.log_parsing_config.log_regexps.reward_event.extend(
- [reward_event_1, reward_event_2])
- task.log_parsing_config.log_regexps.reward.extend(
- ['^[Rr]eward: ([-+]?[0-9]*\\.?[0-9]*)$'])
- task_mgr = task_manager.TaskManager(task=task)
- self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
- task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
- task_mgr.setup_task()
- timestep = task_mgr.rl_step(
- observation={
- 'pixels': np.array([1, 2, 3]),
- })
- self.assertEqual(timestep.reward, 6.0)
-
- def test_get_current_reward_via_score(self):
- # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
- # right away.
- def my_add_ev_listener(event_listener: logcat_thread.EventListener):
- # Check that the event matches what's expected.
- event = event_listener.regexp
- match = event.match('score: 200.0')
- if match is None: # Ignore events that are not scores.
- return
-
- event_listener.handler_fn(event, match)
-
- # Scores are accumulated by their differences, so a subsequent lower score
- # means that the final reward decreases.
- match = event.match('score: 185')
- event_listener.handler_fn(event, match)
-
- task = task_pb2.Task()
- task.log_parsing_config.log_regexps.score = (
- '^score: ([-+]?[0-9]*\\.?[0-9]*)$')
- task_mgr = task_manager.TaskManager(task=task)
- self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
- task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
- task_mgr.setup_task()
- timestep = task_mgr.rl_step(
- observation={
- 'pixels': np.array([1, 2, 3]),
- })
- self.assertEqual(timestep.reward, 185.0)
-
- def test_get_current_extras(self):
- # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
- # right away.
- def my_add_ev_listener(event_listener: logcat_thread.EventListener):
- # Check that the event matches what's expected.
- event = event_listener.regexp
- match = event.match('extra: some_extra [1, 2]')
- if match is None: # Ignore events that are not extras.
- return
-
- # Emit events.
- fn = event_listener.handler_fn
- fn(event, event.match('extra: an_extra [1, 2, 3]'))
- fn(event, event.match('extra: an_extra [4, 5, 6]'))
- fn(event, event.match('extra: another_extra 0.5'))
- fn(event, event.match('extra: multi_dimension_extra [[9,8,7],[6,5,4]]'))
- fn(event, event.match('extra: boolean_extra'))
-
- # Setup the task and trigger the listener.
- task = task_pb2.Task()
- task.log_parsing_config.log_regexps.extra.extend([
- '^extra: (?P[^ ]*)[ ]?(?P.*)$'
- ])
- task_mgr = task_manager.TaskManager(task=task)
- self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
- task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
- task_mgr.setup_task()
-
- timestep = task_mgr.rl_step(
- observation={
- 'pixels': np.array([1, 2, 3]),
- })
-
- # Check expectations.
- self.assertIn('extras', timestep.observation)
- extras = timestep.observation['extras']
- np.testing.assert_almost_equal([[1, 2, 3], [4, 5, 6]],
- extras.get('an_extra'))
- np.testing.assert_almost_equal([0.5], extras.get('another_extra'))
- np.testing.assert_almost_equal([[[9, 8, 7], [6, 5, 4]]],
- extras.get('multi_dimension_extra'))
- np.testing.assert_equal([1], extras.get('boolean_extra'))
-
- def test_get_current_extras_json_format(self):
- # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
- # right away.
- def my_add_ev_listener(event_listener: logcat_thread.EventListener):
- # Check that the event matches what's expected.
- event = event_listener.regexp
- match = event.match('json_extra: {}')
- if match is None: # Ignore events that are not extras.
- return
-
- # Emit events.
- extra = {
- 'extra_scalar': 0,
- 'extra_list': [1, 2, 3, 4],
- 'extra_dict': {
- 'foo': 'bar'
- },
- 'extra_string': 'a_string'
- }
- extra_update = {'extra_string': 'a_new_string', 'extra_float': 0.6}
- fn = event_listener.handler_fn
- fn(event, event.match(f'json_extra: {json.dumps(extra)}'))
- fn(event, event.match(f'json_extra: {json.dumps(extra_update)}'))
-
- # Setup the task and trigger the listener.
- task = task_pb2.Task()
- task.log_parsing_config.log_regexps.json_extra.extend([
- '^json_extra: (?P.*)$'
- ])
- task_mgr = task_manager.TaskManager(task=task)
- self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
- task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
- task_mgr.setup_task()
-
- timestep = task_mgr.rl_step(
- observation={
- 'pixels': np.array([1, 2, 3]),
- })
-
- # Check expectations.
- self.assertIn('extras', timestep.observation)
- extras = timestep.observation['extras']
- expected_extra = {
- 'extra_scalar': [0],
- 'extra_list': [[1, 2, 3, 4]],
- 'extra_dict': [{
- 'foo': 'bar'
- }],
- 'extra_string': ['a_string', 'a_new_string'],
- 'extra_float': [0.6]
- }
- np.testing.assert_almost_equal(
- expected_extra.get('extra_scalar'), extras.get('extra_scalar'))
- np.testing.assert_almost_equal(
- expected_extra.get('extra_list'), extras.get('extra_list'))
- np.testing.assert_equal(
- expected_extra.get('extra_string'), extras.get('extra_string'))
- np.testing.assert_almost_equal(
- expected_extra.get('extra_float'), extras.get('extra_float'))
- np.testing.assert_equal(
- expected_extra.get('extra_dict'), extras.get('extra_dict'))
-
- def test_get_current_extras_failed_to_parse(self):
- # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
- # right away.
- def my_add_ev_listener(event_listener: logcat_thread.EventListener):
- # Check that the event matches what's expected.
- event = event_listener.regexp
- match = event.match('extra: some_extra [1, 2]')
- if match is None: # Ignore events that are not extras.
- return
-
- # Emit events.
- fn = event_listener.handler_fn
- fn(event, event.match('extra: extra_with_malformed_1 [1]'))
- fn(event, event.match('extra: extra_with_malformed_1 [\'this is \\ bad]'))
- fn(event, event.match('extra: extra_with_malformed_1 [2]'))
- fn(event, event.match('extra: extra_with_malformed_2 [\'this is bad]'))
- fn(event, event.match('extra: extra_with_malformed_2 [1]'))
- fn(event, event.match('extra: extra_malformed_only [_very_bad_news]'))
-
- # Setup the task and trigger the listener.
- task = task_pb2.Task()
- task.log_parsing_config.log_regexps.extra.extend([
- '^extra: (?P[^ ]*)[ ]?(?P.*)$'
- ])
- task_mgr = task_manager.TaskManager(task=task)
- self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
- task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
- task_mgr.setup_task()
-
- timestep = task_mgr.rl_step(
- observation={
- 'pixels': np.array([1, 2, 3]),
- })
-
- # Check expectations.
- self.assertIn('extras', timestep.observation)
- extras = timestep.observation['extras']
- np.testing.assert_almost_equal(extras.get('extra_with_malformed_1'),
- [[1], [2]])
- np.testing.assert_almost_equal(extras.get('extra_with_malformed_2'), [[1]])
- self.assertNotIn('extra_malformed_only', extras)
-
- def test_multi_log_regexp(self):
- # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
- # right away.
- def my_add_ev_listener(event_listener: logcat_thread.EventListener):
- # Check that the event matches what's expected.
- match = event_listener.regexp.match('Reward_2: 123.0')
- if match is None: # Ignore events that are not rewards.
- return
-
- event_listener.handler_fn(event_listener.regexp, match)
-
- task = task_pb2.Task()
- task.log_parsing_config.log_regexps.reward.extend([
- '^[Rr]eward_1: ([-+]?[0-9]*\\.?[0-9]*)$',
- '^[Rr]eward_2: ([-+]?[0-9]*\\.?[0-9]*)$'
- ])
- task_mgr = task_manager.TaskManager(task=task)
- self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
- task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
- task_mgr.setup_task()
- timestep = task_mgr.rl_step(
- observation={
- 'pixels': np.array([1, 2, 3]),
- })
- self.assertEqual(timestep.reward, 123.0)
-
- def test_multi_reward_regexp(self):
- # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
- # right away.'
-
- def my_add_ev_listener(event_listener: logcat_thread.EventListener):
- # Check that the event matches what's expected.
- match_1 = event_listener.regexp.match('Reward_1: 5.0')
- match_2 = event_listener.regexp.match('Reward_2: 10.0')
-
- if match_1:
- event_listener.handler_fn(event_listener.regexp, match_1)
-
- if match_2:
- event_listener.handler_fn(event_listener.regexp, match_2)
-
- task = task_pb2.Task()
- task.log_parsing_config.log_regexps.reward.extend([
- '^[Rr]eward_1: ([-+]?[0-9]*\\.?[0-9]*)$',
- '^[Rr]eward_2: ([-+]?[0-9]*\\.?[0-9]*)$',
- ])
- task_mgr = task_manager.TaskManager(task=task)
- self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
- task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
- task_mgr.setup_task()
- timestep = task_mgr.rl_step(
- observation={
- 'pixels': np.array([1, 2, 3]),
- })
- self.assertEqual(timestep.reward, 15.0)
-
- def test_determine_transition_fn(self):
- # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
- # right away.
- def my_add_ev_listener(event_listener: logcat_thread.EventListener):
- # Check that the event matches what's expected.
- event = event_listener.regexp
- match = event.match('I am done!')
- if match is None: # Ignore events that are not episode end.
- return
-
- event_listener.handler_fn(event, match)
-
- task = task_pb2.Task()
- task.log_parsing_config.log_regexps.episode_end.extend(['I am done!'])
- task_mgr = task_manager.TaskManager(task=task)
- self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
- adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
- task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
- task_mgr.setup_task()
- timestep = task_mgr.rl_step(
- observation={
- 'pixels': np.array([1, 2, 3]),
- })
- self.assertTrue(timestep.last())
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/env_interface.py b/android_env/env_interface.py
deleted file mode 100644
index 45e8faf3..00000000
--- a/android_env/env_interface.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Abstract AndroidEnv interface.
-
-AndroidEnv is a standard dm_env.Environment instance, but it also offers a few
-extra methods that clients may use for extended functionality.
-"""
-
-import abc
-from typing import Any
-
-from android_env.proto import adb_pb2
-from android_env.proto import state_pb2
-import dm_env
-import numpy as np
-
-
-class AndroidEnvInterface(dm_env.Environment, metaclass=abc.ABCMeta):
- """Pure virtual interface for AndroidEnv implementations."""
-
- # Methods required by dm_env.Environment.
-
- @abc.abstractmethod
- def action_spec(self) -> dict[str, dm_env.specs.Array]:
- """Returns the action specification."""
-
- @abc.abstractmethod
- def observation_spec(self) -> dict[str, dm_env.specs.Array]:
- """Returns the observation specification."""
-
- @abc.abstractmethod
- def reset(self) -> dm_env.TimeStep:
- """Resets the current episode."""
-
- @abc.abstractmethod
- def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
- """Executes `action` and returns a `TimeStep`."""
-
- @abc.abstractmethod
- def close(self) -> None:
- """Frees up resources."""
-
- # Extensions provided by AndroidEnv.
-
- def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]:
- """Returns extra info provided by tasks."""
-
- return {}
-
- @property
- def raw_action(self):
- """Returns the latest action."""
-
- @property
- def raw_observation(self):
- """Returns the latest observation."""
-
- def stats(self) -> dict[str, Any]:
- """Returns information generated inside the implementation."""
-
- return {}
-
- def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
- """Executes `call` and returns its response."""
-
- return adb_pb2.AdbResponse()
-
- def load_state(
- self, request: state_pb2.LoadStateRequest
- ) -> state_pb2.LoadStateResponse:
- """Loads a state.
-
- Args:
- request: A `LoadStateRequest` containing any parameters necessary to
- specify how/what state to load.
-
- Returns:
- A `LoadStateResponse` containing the status, error message (if
- applicable), and any other relevant information.
- """
- raise NotImplementedError('This environment does not support loading state')
-
- def save_state(
- self, request: state_pb2.SaveStateRequest
- ) -> state_pb2.SaveStateResponse:
- """Saves a state.
-
- Args:
- request: A `SaveStateRequest` containing any parameters necessary to
- specify how/what state to save.
-
- Returns:
- A `SaveStateResponse` containing the status, error message (if
- applicable), and any other relevant information.
- """
- raise NotImplementedError('This environment does not support saving state')
diff --git a/android_env/environment.py b/android_env/environment.py
deleted file mode 100644
index 85638e68..00000000
--- a/android_env/environment.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Android environment implementation."""
-
-from typing import Any
-
-from absl import logging
-from android_env import env_interface
-from android_env.components import adb_call_parser
-from android_env.components import coordinator as coordinator_lib
-from android_env.components import task_manager as task_manager_lib
-from android_env.components.simulators import base_simulator
-from android_env.proto import adb_pb2
-from android_env.proto import state_pb2
-import dm_env
-import numpy as np
-
-
-class AndroidEnv(env_interface.AndroidEnvInterface):
- """An RL environment that interacts with Android apps."""
-
- def __init__(
- self,
- simulator: base_simulator.BaseSimulator,
- coordinator: coordinator_lib.Coordinator,
- task_manager: task_manager_lib.TaskManager,
- ):
- """Initializes the state of this AndroidEnv object."""
-
- self._simulator = simulator
- self._coordinator = coordinator
- self._task_manager = task_manager
- self._latest_action = {}
- self._latest_observation = {}
- self._latest_extras = {}
- self._reset_next_step = True
- self._is_closed = False
-
- logging.info('Action spec: %s', self.action_spec())
- logging.info('Observation spec: %s', self.observation_spec())
-
- def __del__(self) -> None:
- self.close()
-
- # Methods required by dm_env.Environment.
-
- def action_spec(self) -> dict[str, dm_env.specs.Array]:
- return self._coordinator.action_spec()
-
- def observation_spec(self) -> dict[str, dm_env.specs.Array]:
- return self._coordinator.observation_spec()
-
- def reset(self) -> dm_env.TimeStep:
- """Resets the environment for a new RL episode."""
-
- logging.info('Resetting AndroidEnv...')
-
- # Execute a reset. Timestep will be of type FIRST.
- timestep = self._coordinator.rl_reset()
-
- # Process relevant information.
- if timestep.observation is not None:
- self._latest_extras = timestep.observation.pop('extras')
- self._latest_observation = timestep.observation.copy()
- else:
- # If the observation is None, we return the latest observation again.
- timestep = timestep._replace(observation=self._latest_observation.copy())
-
- self._latest_action = {}
- self._reset_next_step = False
-
- logging.info('Done resetting AndroidEnv.')
- logging.info('************* NEW EPISODE *************')
-
- return timestep
-
- def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
- """Takes a step in the environment."""
-
- # Check if it's time to reset the episode.
- if self._reset_next_step:
- return self.reset()
-
- # Execute selected action.
- timestep = self._coordinator.rl_step(action)
-
- # Process relevant information.
- if timestep.observation is not None:
- self._latest_extras = timestep.observation.pop('extras')
- self._latest_observation = timestep.observation.copy()
- else:
- # If the observation is None, we return the latest observation again.
- timestep = timestep._replace(observation=self._latest_observation.copy())
-
- self._latest_action = action.copy()
-
- if timestep.last():
- self._reset_next_step = True
- logging.info('************* END OF EPISODE *************')
-
- return timestep
-
- def close(self) -> None:
- """Cleans up running processes, threads and local files."""
- if not self._is_closed:
- logging.info('Cleaning up AndroidEnv...')
- if hasattr(self, '_coordinator'):
- self._coordinator.close()
- logging.info('Done cleaning up AndroidEnv.')
- self._is_closed = True
-
- # Extensions provided by AndroidEnv.
-
- def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]:
- """Returns latest task extras."""
-
- task_extras = {} # Build a copy to avoid reusing objects.
- for k, spec in self._latest_extras.items():
- extra_values = spec.astype(spec.dtype)
- task_extras[k] = extra_values[-1] if latest_only else extra_values
- return task_extras
-
- @property
- def raw_action(self):
- return self._latest_action.copy()
-
- @property
- def raw_observation(self):
- return self._latest_observation.copy()
-
- def stats(self) -> dict[str, Any]:
- coordinator_stats = self._coordinator.stats()
- task_manager_stats = self._task_manager.stats()
- return coordinator_stats | task_manager_stats
-
- def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
- return self._coordinator.execute_adb_call(call)
-
- def load_state(
- self, request: state_pb2.LoadStateRequest
- ) -> state_pb2.LoadStateResponse:
- """Loads a state.
-
- Args:
- request: A `LoadStateRequest` containing any parameters necessary to
- specify how/what state to load.
-
- Returns:
- A `LoadStateResponse` containing the status, error message (if
- applicable), and any other relevant information.
- """
-
- self._task_manager.stop()
- response = self._simulator.load_state(request)
- self._task_manager.start(
- adb_call_parser_factory=lambda: adb_call_parser.AdbCallParser(
- self._simulator.create_adb_controller()
- ),
- log_stream=self._simulator.create_log_stream(),
- )
- return response
-
- def save_state(
- self, request: state_pb2.SaveStateRequest
- ) -> state_pb2.SaveStateResponse:
- """Saves a state.
-
- Args:
- request: A `SaveStateRequest` containing any parameters necessary to
- specify how/what state to save.
-
- Returns:
- A `SaveStateResponse` containing the status, error message (if
- applicable), and any other relevant information.
- """
-
- return self._simulator.save_state(request)
diff --git a/android_env/environment_test.py b/android_env/environment_test.py
deleted file mode 100644
index b74c5989..00000000
--- a/android_env/environment_test.py
+++ /dev/null
@@ -1,273 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Unit tests for AndroidEnv."""
-
-from unittest import mock
-
-from absl.testing import absltest
-from android_env import environment
-from android_env.components import config_classes
-from android_env.components import coordinator as coordinator_lib
-from android_env.components import task_manager as task_manager_lib
-from android_env.components.simulators import base_simulator
-from android_env.components.simulators.fake import fake_simulator
-from android_env.proto import adb_pb2
-from android_env.proto import state_pb2
-import dm_env
-import numpy as np
-
-
-def _create_mock_coordinator() -> coordinator_lib.Coordinator:
- coordinator = mock.create_autospec(coordinator_lib.Coordinator)
- coordinator.action_spec.return_value = {
- 'action_type':
- dm_env.specs.DiscreteArray(num_values=3),
- 'touch_position':
- dm_env.specs.BoundedArray(
- shape=(2,), dtype=np.float32, minimum=0.0, maximum=1.0),
- }
- coordinator.observation_spec.return_value = {
- 'pixels': dm_env.specs.Array(shape=(123, 456, 3), dtype=np.uint8),
- 'timedelta': dm_env.specs.Array(shape=(), dtype=np.int64),
- 'orientation': dm_env.specs.Array(shape=(4,), dtype=np.uint8),
- }
- return coordinator
-
-
-def _create_fake_simulator() -> fake_simulator.FakeSimulator:
- return fake_simulator.FakeSimulator(
- config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456))
- )
-
-
-class AndroidEnvTest(absltest.TestCase):
-
- def test_specs(self):
- simulator = _create_fake_simulator()
- coordinator = _create_mock_coordinator()
- task_manager = mock.create_autospec(task_manager_lib.TaskManager)
- env = environment.AndroidEnv(
- simulator=simulator, coordinator=coordinator, task_manager=task_manager
- )
-
- # Check action spec.
- self.assertNotEmpty(env.action_spec())
- self.assertIn('action_type', env.action_spec())
- self.assertIsInstance(env.action_spec()['action_type'],
- dm_env.specs.DiscreteArray)
- self.assertIn('touch_position', env.action_spec())
- self.assertIsInstance(env.action_spec()['touch_position'],
- dm_env.specs.BoundedArray)
-
- # Check observation spec.
- self.assertNotEmpty(env.observation_spec())
- self.assertIn('pixels', env.observation_spec())
- self.assertIsInstance(env.observation_spec()['pixels'], dm_env.specs.Array)
- # The `pixels` entry in the observation spec should match the screen size of
- # the simulator with three color channels (RGB).
- self.assertEqual(env.observation_spec()['pixels'].shape, (123, 456, 3))
- self.assertIn('timedelta', env.observation_spec())
- self.assertIsInstance(env.observation_spec()['timedelta'],
- dm_env.specs.Array)
- # The `timedelta` should be a scalar.
- self.assertEqual(env.observation_spec()['timedelta'].shape, ())
- self.assertIn('orientation', env.observation_spec())
- # The `orientation` should be a one-hot vector with four dimensions.
- self.assertIsInstance(env.observation_spec()['orientation'],
- dm_env.specs.Array)
- self.assertEqual(env.observation_spec()['orientation'].shape, (4,))
-
- def test_reset_and_step(self):
- simulator = _create_fake_simulator()
- coordinator = _create_mock_coordinator()
- task_manager = mock.create_autospec(task_manager_lib.TaskManager)
- coordinator.action_spec.return_value = {
- 'action_type':
- dm_env.specs.DiscreteArray(num_values=3),
- 'touch_position':
- dm_env.specs.BoundedArray(
- shape=(2,), dtype=np.float32, minimum=0.0, maximum=1.0),
- }
- coordinator.observation_spec.return_value = {
- 'pixels': dm_env.specs.Array(shape=(123, 456, 3), dtype=np.uint8),
- 'timedelta': dm_env.specs.Array(shape=(), dtype=np.int64),
- 'orientation': dm_env.specs.Array(shape=(4,), dtype=np.uint8),
- }
- env = environment.AndroidEnv(
- simulator=simulator, coordinator=coordinator, task_manager=task_manager
- )
- coordinator.rl_reset.return_value = dm_env.TimeStep(
- step_type=dm_env.StepType.FIRST,
- reward=0.0,
- discount=0.0,
- observation={
- 'pixels': np.random.rand(987, 654, 3),
- 'timedelta': 123456,
- 'orientation': np.array((1, 0, 0, 0)),
- 'extras': {
- 'click': np.array([[246]], dtype=np.int64)
- }
- },
- )
-
- ts = env.reset()
- self.assertIsInstance(ts, dm_env.TimeStep)
- # After a `reset()` the TimeStep should follow some expectations.
- self.assertTrue(ts.first())
- self.assertEqual(ts.reward, 0.0)
- self.assertEqual(ts.discount, 0.0)
- obs = ts.observation
- self.assertIn('pixels', obs)
- self.assertEqual(obs['pixels'].shape, (987, 654, 3))
- self.assertIn('timedelta', obs)
- self.assertEqual(obs['timedelta'], 123456)
- self.assertIn('orientation', obs)
- self.assertEqual(obs['orientation'].shape, (4,))
- np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0))
-
- # Extras should also be provided.
- extras = env.task_extras()
- self.assertIn('click', extras)
- self.assertEqual(extras['click'], np.array([246], dtype=np.int64))
-
- coordinator.stats.return_value = {'my_measurement': 135}
- task_manager.stats.return_value = {'another_measurement': 79}
-
- # Step again in the environment and check expectations again.
- pixels = np.random.rand(987, 654, 3)
- latest_observation = {
- 'pixels': pixels,
- 'timedelta': 123456,
- 'orientation': np.array((1, 0, 0, 0)),
- 'extras': {
- 'click': np.array([[246]], dtype=np.int64)
- }
- }
- coordinator.rl_step.return_value = dm_env.transition(
- reward=0.0,
- discount=0.0,
- observation=latest_observation,
- )
- ts = env.step({'action_type': 1, 'touch_position': (10, 20)})
- self.assertIsInstance(ts, dm_env.TimeStep)
- # The StepType now should NOT be FIRST.
- self.assertFalse(ts.first())
- self.assertEqual(ts.reward, 0.0)
- self.assertEqual(ts.discount, 0.0)
- obs = ts.observation
- self.assertIn('pixels', obs)
- self.assertEqual(obs['pixels'].shape, (987, 654, 3))
- self.assertIn('timedelta', obs)
- self.assertEqual(obs['timedelta'], 123456)
- self.assertIn('orientation', obs)
- self.assertEqual(obs['orientation'].shape, (4,))
- np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0))
-
- # Extras should still be provided.
- extras = env.task_extras()
- self.assertIn('click', extras)
- self.assertEqual(extras['click'], np.array([246], dtype=np.int64))
-
- # At this point these methods and properties should return something.
- self.assertNotEmpty(env.stats())
- self.assertNotEmpty(env.raw_observation)
- self.assertNotIn('extras', env.raw_observation)
- self.assertNotEmpty(env.raw_action)
-
- # If the observation is None, we want to return the latest observation.
- coordinator.rl_step.return_value = dm_env.truncation(
- reward=0.0,
- observation=None,
- )
- ts = env.step({'action_type': 1, 'touch_position': (10, 20)})
- self.assertIsInstance(ts, dm_env.TimeStep)
- # Assert the observation matches the latest observation.
- obs = ts.observation
- self.assertIn('pixels', obs)
- self.assertEqual(obs['pixels'].shape, (987, 654, 3))
- np.testing.assert_equal(obs['pixels'], pixels)
- self.assertIn('timedelta', obs)
- self.assertEqual(obs['timedelta'], 123456)
- self.assertIn('orientation', obs)
- self.assertEqual(obs['orientation'].shape, (4,))
- np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0))
-
- def test_adb_call(self):
- simulator = _create_fake_simulator()
- coordinator = _create_mock_coordinator()
- task_manager = mock.create_autospec(task_manager_lib.TaskManager)
- env = environment.AndroidEnv(
- simulator=simulator, coordinator=coordinator, task_manager=task_manager
- )
- call = adb_pb2.AdbRequest(
- force_stop=adb_pb2.AdbRequest.ForceStop(package_name='blah'))
- expected_response = adb_pb2.AdbResponse(
- status=adb_pb2.AdbResponse.Status.OK)
- coordinator.execute_adb_call.return_value = expected_response
-
- response = env.execute_adb_call(call)
-
- self.assertEqual(response, expected_response)
- coordinator.execute_adb_call.assert_called_once_with(call)
-
- def test_load_state(self):
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- coordinator = _create_mock_coordinator()
- task_manager = mock.create_autospec(task_manager_lib.TaskManager)
- env = environment.AndroidEnv(
- simulator=simulator, coordinator=coordinator, task_manager=task_manager
- )
- expected_response = state_pb2.LoadStateResponse(
- status=state_pb2.LoadStateResponse.Status.OK
- )
- request = state_pb2.LoadStateRequest(args={'foo': 'bar'})
- simulator.load_state.return_value = expected_response
- response = env.load_state(request)
- self.assertEqual(response, expected_response)
- simulator.load_state.assert_called_once_with(request)
- task_manager.stop.assert_called_once()
- task_manager.start.assert_called_once()
-
- def test_save_state(self):
- simulator = mock.create_autospec(base_simulator.BaseSimulator)
- coordinator = _create_mock_coordinator()
- task_manager = mock.create_autospec(task_manager_lib.TaskManager)
- env = environment.AndroidEnv(
- simulator=simulator, coordinator=coordinator, task_manager=task_manager
- )
- expected_response = state_pb2.SaveStateResponse(
- status=state_pb2.SaveStateResponse.Status.OK
- )
- request = state_pb2.SaveStateRequest(args={'foo': 'bar'})
- simulator.save_state.return_value = expected_response
- response = env.save_state(request)
- self.assertEqual(response, expected_response)
- simulator.save_state.assert_called_once_with(request)
-
- def test_double_close(self):
- simulator = _create_fake_simulator()
- coordinator = _create_mock_coordinator()
- task_manager = mock.create_autospec(task_manager_lib.TaskManager)
- env = environment.AndroidEnv(
- simulator=simulator, coordinator=coordinator, task_manager=task_manager
- )
- env.close()
- env.close()
- coordinator.close.assert_called_once()
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/loader.py b/android_env/loader.py
deleted file mode 100644
index 0588ff33..00000000
--- a/android_env/loader.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Function for loading AndroidEnv."""
-
-import os
-
-from absl import logging
-from android_env import environment
-from android_env.components import config_classes
-from android_env.components import coordinator as coordinator_lib
-from android_env.components import device_settings as device_settings_lib
-from android_env.components import task_manager as task_manager_lib
-from android_env.components.simulators.emulator import emulator_simulator
-from android_env.components.simulators.fake import fake_simulator
-from android_env.proto import task_pb2
-
-from google.protobuf import text_format
-
-
-def _load_task(task_config: config_classes.TaskConfig) -> task_pb2.Task:
- """Returns the task according to `task_config`."""
-
- task = task_pb2.Task()
- match task_config:
- case config_classes.FilesystemTaskConfig():
- with open(task_config.path, 'r') as proto_file:
- text_format.Parse(proto_file.read(), task)
- case _:
- logging.error('Unsupported TaskConfig: %r', task_config)
-
- return task
-
-
-def load(config: config_classes.AndroidEnvConfig) -> environment.AndroidEnv:
- """Loads an AndroidEnv instance."""
-
- task = _load_task(config.task)
- task_manager = task_manager_lib.TaskManager(task)
-
- match config.simulator:
- case config_classes.EmulatorConfig():
- _process_emulator_launcher_config(config.simulator)
- simulator = emulator_simulator.EmulatorSimulator(config=config.simulator)
- case config_classes.FakeSimulatorConfig():
- simulator = fake_simulator.FakeSimulator(config=config.simulator)
- case _:
- raise ValueError('Unsupported simulator config: {config.simulator}')
-
- device_settings = device_settings_lib.DeviceSettings(simulator)
- coordinator = coordinator_lib.Coordinator(
- simulator, task_manager, device_settings
- )
- return environment.AndroidEnv(
- simulator=simulator, coordinator=coordinator, task_manager=task_manager
- )
-
-
-def _process_emulator_launcher_config(
- emulator_config: config_classes.EmulatorConfig,
-) -> None:
- """Adjusts the configuration of the emulator depending on some conditions."""
-
- # Expand the user directory if specified.
- launcher_config = emulator_config.emulator_launcher
- launcher_config.android_avd_home = os.path.expanduser(
- launcher_config.android_avd_home
- )
- launcher_config.android_sdk_root = os.path.expanduser(
- launcher_config.android_sdk_root
- )
- launcher_config.emulator_path = os.path.expanduser(
- launcher_config.emulator_path
- )
- emulator_config.adb_controller.adb_path = os.path.expanduser(
- emulator_config.adb_controller.adb_path
- )
diff --git a/android_env/loader_test.py b/android_env/loader_test.py
deleted file mode 100644
index 280fa4ab..00000000
--- a/android_env/loader_test.py
+++ /dev/null
@@ -1,182 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for loader."""
-
-import builtins
-import os
-from unittest import mock
-
-from absl.testing import absltest
-from android_env import env_interface
-from android_env import loader
-from android_env.components import config_classes
-from android_env.components import coordinator as coordinator_lib
-from android_env.components import device_settings as device_settings_lib
-from android_env.components import task_manager as task_manager_lib
-from android_env.components.simulators.emulator import emulator_simulator
-from android_env.components.simulators.fake import fake_simulator
-from android_env.proto import task_pb2
-
-
-class LoaderTest(absltest.TestCase):
-
- @mock.patch.object(task_manager_lib, 'TaskManager', autospec=True)
- @mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True)
- @mock.patch.object(device_settings_lib, 'DeviceSettings', autospec=True)
- @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True)
- @mock.patch.object(builtins, 'open', autospec=True)
- def test_load_emulator(
- self,
- mock_open,
- mock_coordinator,
- mock_device_settings,
- mock_simulator_class,
- mock_task_manager,
- ):
-
- # Arrange.
- mock_open.return_value.__enter__ = mock_open
- mock_open.return_value.read.return_value = ''
- config = config_classes.AndroidEnvConfig(
- task=config_classes.FilesystemTaskConfig(path='some/path/'),
- simulator=config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- avd_name='my_avd',
- android_avd_home='~/.android/avd',
- android_sdk_root='~/Android/Sdk',
- emulator_path='~/Android/Sdk/emulator/emulator',
- run_headless=False,
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='~/Android/Sdk/platform-tools/adb',
- ),
- ),
- )
-
- # Act.
- env = loader.load(config)
-
- # Assert.
- self.assertIsInstance(env, env_interface.AndroidEnvInterface)
- mock_simulator_class.assert_called_with(
- config=config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- avd_name='my_avd',
- android_avd_home=os.path.expanduser('~/.android/avd'),
- android_sdk_root=os.path.expanduser('~/Android/Sdk'),
- emulator_path=os.path.expanduser(
- '~/Android/Sdk/emulator/emulator'
- ),
- run_headless=False,
- gpu_mode='swangle_indirect',
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path=os.path.expanduser('~/Android/Sdk/platform-tools/adb'),
- adb_server_port=5037,
- ),
- )
- )
- mock_coordinator.assert_called_with(
- mock_simulator_class.return_value,
- mock_task_manager.return_value,
- mock_device_settings.return_value,
- )
-
- @mock.patch.object(task_manager_lib, 'TaskManager', autospec=True)
- @mock.patch.object(fake_simulator, 'FakeSimulator', autospec=True)
- @mock.patch.object(device_settings_lib, 'DeviceSettings', autospec=True)
- @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True)
- @mock.patch.object(builtins, 'open', autospec=True)
- def test_load_fake_simulator(
- self,
- mock_open,
- mock_coordinator,
- mock_device_settings,
- mock_simulator_class,
- mock_task_manager,
- ):
-
- # Arrange.
- mock_open.return_value.__enter__ = mock_open
- mock_open.return_value.read.return_value = ''
- config = config_classes.AndroidEnvConfig(
- task=config_classes.FilesystemTaskConfig(path='some/path/'),
- simulator=config_classes.FakeSimulatorConfig(
- screen_dimensions=(1234, 5678)
- ),
- )
-
- # Act.
- env = loader.load(config)
-
- # Assert.
- self.assertIsInstance(env, env_interface.AndroidEnvInterface)
- mock_simulator_class.assert_called_with(
- config=config_classes.FakeSimulatorConfig(
- screen_dimensions=(1234, 5678)
- )
- )
- mock_coordinator.assert_called_with(
- mock_simulator_class.return_value,
- mock_task_manager.return_value,
- mock_device_settings.return_value,
- )
-
- @mock.patch.object(task_manager_lib, 'TaskManager', autospec=True)
- @mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True)
- @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True)
- @mock.patch.object(builtins, 'open', autospec=True)
- def test_task(
- self, mock_open, mock_coordinator, mock_simulator, mock_task_manager
- ):
-
- # Arrange.
- del mock_coordinator, mock_simulator
- mock_open.return_value.__enter__ = mock_open
- mock_open.return_value.read.return_value = r'''
-id: "fake_task"
-name: "Fake Task"
-description: "Task for testing loader."
-max_episode_sec: 0
-'''
- config = config_classes.AndroidEnvConfig(
- task=config_classes.FilesystemTaskConfig(path='some/path/'),
- simulator=config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- avd_name='my_avd'
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path='~/Android/Sdk/platform-tools/adb',
- ),
- ),
- )
-
- # Act.
- env = loader.load(config)
-
- # Assert.
- expected_task = task_pb2.Task()
- expected_task.id = 'fake_task'
- expected_task.name = 'Fake Task'
- expected_task.description = 'Task for testing loader.'
- expected_task.max_episode_sec = 0
-
- mock_task_manager.assert_called_with(expected_task)
- self.assertIsInstance(env, env_interface.AndroidEnvInterface)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/proto/__init__.py b/android_env/proto/__init__.py
deleted file mode 100644
index 2f66bf75..00000000
--- a/android_env/proto/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
diff --git a/android_env/proto/a11y/__init__.py b/android_env/proto/a11y/__init__.py
deleted file mode 100644
index 2f66bf75..00000000
--- a/android_env/proto/a11y/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
diff --git a/android_env/proto/a11y/a11y.proto b/android_env/proto/a11y/a11y.proto
deleted file mode 100644
index 90c266bb..00000000
--- a/android_env/proto/a11y/a11y.proto
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto3";
-
-package android_env;
-
-import "android_env/proto/a11y/android_accessibility_forest.proto";
-
-option java_multiple_files = true;
-option java_package = "com.google.androidenv.accessibilityforwarder";
-
-// A service to send Accessibility information to a remote server.
-//
-// The client is assumed to be running inside an Android device (e.g. emulator
-// or real device) while the server is assumed to be running outside (e.g. in a
-// Python process).
-service A11yService {
- // Sends a forest of Accessibility trees to a server.
- rpc SendForest(AndroidAccessibilityForest) returns (ForestResponse) {}
- // Sends an a11y event to a server.
- rpc SendEvent(EventRequest) returns (EventResponse) {}
-
- // Long-lived bidirection communication between the client and the server.
- rpc Bidi(stream ClientToServer) returns (stream ServerToClient) {}
-}
-
-// TODO(b/334952387): Remove `ForestResponse`, `EventRequest` and
-// `EventResponse` once bidi communication is in-place.
-message ForestResponse {
- // The error if anything.
- string error = 1;
-}
-
-// An Accessibility event.
-message EventRequest {
- // A single event as a dictionary.
- map event = 1;
-}
-
-message EventResponse {
- // The error if anything.
- string error = 1;
-}
-
-// The message sent from the Android device to the server running outside of the
-// device.
-message ClientToServer {
- oneof payload {
- EventRequest event = 1;
- AndroidAccessibilityForest forest = 2;
- }
-}
-
-// The message sent from the server running outside of the device to the Android
-// device.
-message ServerToClient {
- // A request to obtain the Accessibility forest.
- message GetA11yForest {}
-
- oneof payload {
- GetA11yForest get_forest = 1;
- }
-}
diff --git a/android_env/proto/a11y/android_accessibility_action.proto b/android_env/proto/a11y/android_accessibility_action.proto
deleted file mode 100644
index 5273d7d9..00000000
--- a/android_env/proto/a11y/android_accessibility_action.proto
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto3";
-
-package android_env;
-
-option java_multiple_files = true;
-option java_package = "com.google.androidenv.accessibilityforwarder";
-
-// An Android Accessibility Action.
-// Next index: 3
-message AndroidAccessibilityAction {
- // Required ID that uniquely identifies the action for this node.
- // Can be one of the standard action IDs listed in the documentation.
- // https://developer.android.com/reference/android/view/accessibility/AccessibilityNodeInfo.AccessibilityAction
- int32 id = 1;
-
- // Optional label describing what the action is.
- string label = 2;
-}
diff --git a/android_env/proto/a11y/android_accessibility_forest.proto b/android_env/proto/a11y/android_accessibility_forest.proto
deleted file mode 100644
index 63ddb8c5..00000000
--- a/android_env/proto/a11y/android_accessibility_forest.proto
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto3";
-
-package android_env;
-
-import "android_env/proto/a11y/android_accessibility_window_info.proto";
-
-option java_multiple_files = true;
-option java_package = "com.google.androidenv.accessibilityforwarder";
-
-// A forest of Android accessibility trees. Each tree belongs to a single
-// window. Next index: 2
-message AndroidAccessibilityForest {
- // All of the windows present on screen.
- repeated AndroidAccessibilityWindowInfo windows = 1;
-}
diff --git a/android_env/proto/a11y/android_accessibility_node_info.proto b/android_env/proto/a11y/android_accessibility_node_info.proto
deleted file mode 100644
index 3c904c86..00000000
--- a/android_env/proto/a11y/android_accessibility_node_info.proto
+++ /dev/null
@@ -1,122 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto3";
-
-package android_env;
-
-import "android_env/proto/a11y/android_accessibility_action.proto";
-import "android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto";
-import "android_env/proto/a11y/rect.proto";
-
-option java_multiple_files = true;
-option java_package = "com.google.androidenv.accessibilityforwarder";
-
-// An Android AccessibilityNodeInfo.
-// Next index: 32
-message AndroidAccessibilityNodeInfo {
- // Unique monotonically-increasing ID.
- int32 unique_id = 1;
-
- // The bounds of this node within the device's screen.
- ProtoRect bounds_in_screen = 2;
-
- // The name of the View class that created this node.
- string class_name = 3;
-
- // The content description of the node.
- string content_description = 4;
-
- // The hint text of the node.
- string hint_text = 5;
-
- // The name of the package this node comes from.
- string package_name = 6;
-
- // The text of this node.
- string text = 7;
-
- // The start index of the text selection.
- int64 text_selection_start = 8;
-
- // The end index of the text selection.
- int64 text_selection_end = 9;
-
- // The view ID resource name of the node.
- string view_id_resource_name = 10;
-
- // The ID of the window this node belongs to.
- int32 window_id = 11;
-
- // If true, this node can be checked.
- bool is_checkable = 12;
-
- // If true, this node is currently checked.
- bool is_checked = 13;
-
- // If true, this node (probably) responds to being clicked.
- bool is_clickable = 14;
-
- // If true, this node's text can be edited by the user.
- bool is_editable = 15;
-
- // If true, this node is enabled (e.g., if it is a button).
- bool is_enabled = 16;
-
- // If true, this node can be focused (e.g., a text input).
- bool is_focusable = 17;
-
- // If true, this node is currently focused.
- bool is_focused = 18;
-
- // If true, this node (probably) responds to being long pressed.
- bool is_long_clickable = 19;
-
- // If true, this node is a password input.
- bool is_password = 20;
-
- // If true, this node can be scrolled.
- bool is_scrollable = 21;
-
- // If true, this node is currently selected.
- bool is_selected = 22;
-
- // If true, this node is (probably) visible to the user.
- bool is_visible_to_user = 23;
-
- // List of actions that can be performed on this node.
- repeated AndroidAccessibilityAction actions = 24;
-
- // Ordered list of child IDs (i.e., unique_id).
- repeated int32 child_ids = 25 [packed = true];
-
- // List of clickable spans present in the node's text or content description.
- repeated AndroidAccessibilityNodeInfoClickableSpan clickable_spans = 26;
-
- // The depth of this node in the accessibility tree.
- int32 depth = 27;
-
- // Unique ID of the node that this node is declaring itself to be labeled by.
- int32 labeled_by_id = 28;
-
- // Unique ID of the node that this is node is declaring itself to be a label
- // for.
- int32 label_for_id = 29;
-
- // The drawing order for the node.
- int32 drawing_order = 30;
-
- // The tooltip text of the node.
- string tooltip_text = 31;
-}
diff --git a/android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto b/android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto
deleted file mode 100644
index f20d0bfd..00000000
--- a/android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto3";
-
-package android_env;
-
-option java_multiple_files = true;
-option java_package = "com.google.androidenv.accessibilityforwarder";
-
-// A single clickable span found in the accessibility node's text.
-// Next index: 6
-message AndroidAccessibilityNodeInfoClickableSpan {
- // The source of the span (so the client can find the correct spannable string
- // in the node).
- // Next index: 3
- enum SpanSource {
- UNKNOWN_TYPE = 0; // Catch all type for forward compatibility.
- TEXT = 1; // The span is from node#getText
- CONTENT_DESCRIPTION = 2; // The span is from node#getContentDescription.
- }
-
- // The text of the span (a substring of the spannable string).
- string text = 1;
-
- // The URL attached to the span if specified.
- string url = 2;
-
- // The source of the span.
- SpanSource source = 3;
-
- // The index of the first character of the span in the spannable string.
- // The end of the span would be a sum of span_start and text.length().
- int32 start = 4;
-
- // The unique_id from the corresponding AndroidAccessibilityNodeInfo.
- int32 node_id = 5;
-}
diff --git a/android_env/proto/a11y/android_accessibility_tree.proto b/android_env/proto/a11y/android_accessibility_tree.proto
deleted file mode 100644
index 4bc48ef9..00000000
--- a/android_env/proto/a11y/android_accessibility_tree.proto
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto3";
-
-package android_env;
-
-import "android_env/proto/a11y/android_accessibility_node_info.proto";
-
-option java_multiple_files = true;
-option java_package = "com.google.androidenv.accessibilityforwarder";
-
-// A tree (actually a graph) of Android accessibility nodes.
-// Next index: 3
-message AndroidAccessibilityTree {
- // All of the nodes in the graph. The root node is the node whose ID is 0.
- repeated AndroidAccessibilityNodeInfo nodes = 1;
-}
diff --git a/android_env/proto/a11y/android_accessibility_window_info.proto b/android_env/proto/a11y/android_accessibility_window_info.proto
deleted file mode 100644
index 2e0baeec..00000000
--- a/android_env/proto/a11y/android_accessibility_window_info.proto
+++ /dev/null
@@ -1,84 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto3";
-
-package android_env;
-
-import "android_env/proto/a11y/android_accessibility_tree.proto";
-import "android_env/proto/a11y/rect.proto";
-
-option java_multiple_files = true;
-option java_package = "com.google.androidenv.accessibilityforwarder";
-
-// An Android AccessibilityWindowInfo.
-// Next index: 12
-message AndroidAccessibilityWindowInfo {
- // Type of the window.
- // Next index: 6
- enum WindowType {
- // The window type is an unknown value.
- UNKNOWN_TYPE = 0;
-
- // A standard application window.
- TYPE_APPLICATION = 1;
-
- // An IME window (e.g. GBoard).
- TYPE_INPUT_METHOD = 2;
-
- // A system window (e.g., a notification).
- TYPE_SYSTEM = 3;
-
- // An accessibility overlay.
- TYPE_ACCESSIBILITY_OVERLAY = 4;
-
- // A system window used to divide the screen in split-screen mode. This type
- // of window is present only in split-screen mode.
- TYPE_SPLIT_SCREEN_DIVIDER = 5;
- }
-
- // Bounds of this window in the device's screen.
- ProtoRect bounds_in_screen = 1;
-
- // A unique ID identifying the display in which this window is shown.
- int32 display_id = 2;
-
- // Unique ID as defined by the Android platform.
- int32 id = 3;
-
- // Z-index of the window. Windows with a greater z-index appear in front of
- // those with a lesser z-index.
- int32 layer = 4;
-
- // The title of the window, if set.
- string title = 5;
-
- // The type of the window.
- WindowType window_type = 6;
-
- // If true, the window is currently accessibility-focused.
- bool is_accessibility_focused = 7;
-
- // If true, the window is currently active.
- bool is_active = 8;
-
- // If true, the window is currently focused.
- bool is_focused = 9;
-
- // If true, the window is in Picture in Picture mode.
- bool is_in_picture_in_picture_mode = 10;
-
- // The associated accessibility tree for this window.
- AndroidAccessibilityTree tree = 11;
-}
diff --git a/android_env/proto/a11y/rect.proto b/android_env/proto/a11y/rect.proto
deleted file mode 100644
index 58167b44..00000000
--- a/android_env/proto/a11y/rect.proto
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto3";
-
-package android_env;
-
-option java_multiple_files = true;
-option java_package = "com.google.androidenv.accessibilityforwarder";
-
-// Proto representation of Android Rect.
-// https://developer.android.com/reference/android/graphics/Rect
-// Next index: 5
-message ProtoRect {
- int32 left = 1;
- int32 top = 2;
- int32 right = 3;
- int32 bottom = 4;
-}
diff --git a/android_env/proto/adb.proto b/android_env/proto/adb.proto
deleted file mode 100644
index 86e65709..00000000
--- a/android_env/proto/adb.proto
+++ /dev/null
@@ -1,433 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto3";
-
-package android_env;
-
-message AdbRequest {
- // Installs an APK into the simulator.
- message InstallApk {
-
- // A location in the filesystem.
- message Filesystem {
- string path = 1;
- }
-
- // A byte sequence of a single APK file.
- message Blob {
- // The serialized file as bytes.
- bytes contents = 1;
- }
-
- oneof location {
- Filesystem filesystem = 2;
- Blob blob = 6;
- }
- }
-
- message StartActivity {
- string full_activity = 1;
- repeated string extra_args = 2;
- // Whether to stop the current app before starting the activity.
- // Notice that if this option is `true`, the activity probably needs the
- // `android:launchMode="singleTop"` attribute in its `AndroidManifest.xml`,
- // otherwise intents may not be received by `onNewIntent()`. Please see more
- // info on `android:launchMode` at
- // https://developer.android.com/guide/topics/manifest/activity-element.
- bool force_stop = 3;
- }
-
- message SendBroadcast {
- // Action to send during the broadcast event.
- string action = 1;
-
- // Specify the component name with package name prefix to create an explicit
- // intent, such as com.example.app/.ExampleActivity (see -n specification at
- // https://developer.android.com/tools/adb#IntentSpec).
- string component = 2;
- }
-
- message UninstallPackage {
- string package_name = 1;
- }
-
- message ForceStop {
- string package_name = 1;
- }
-
- message Tap {
- // NOTE: These are absolute coordinates in the range of the screen
- // resolution. They are NOT floats in [0,1].
- // Precondition: `x` and `y` must be non-negative.
- int32 x = 1;
- int32 y = 2;
- }
-
- message PressButton {
- enum Button {
- HOME = 0;
- BACK = 1;
- ENTER = 2;
- }
- Button button = 1;
- }
-
- // Pins the given activity to the screen.
- // This essentially locks the user into a single app mode (aka "Kiosk mode").
- message StartScreenPinning {
- string full_activity = 1;
- }
-
- // Returns the full activity name that is currently opened to the user.
- // If successful, a GetCurrentActivityResponse is returned.
- message GetCurrentActivity {}
-
- // Returns the orientation of the device.
- message GetOrientationRequest {}
-
- // Performs `adb push`.
- // Please see https://developer.android.com/studio/command-line/adb#copyfiles.
- //
- // Notice that a source destination path for the file is not sent, but raw
- // bytes in `content` instead. Obviously, the `content` can be set from a real
- // file, but this is done to ensure Task definitions are as hermetic as
- // possible, without depending on the environment from where they're run.
- message Push {
- // The contents of the file.
- bytes content = 1;
-
- // Destination path _inside_ Android. E.g. /sdcard/my_file.txt.
- string path = 2;
- }
-
- // Performs `adb pull`.
- // Please see https://developer.android.com/studio/command-line/adb#copyfiles.
- //
- // Notice that a local destination for the copied file is not sent, as raw
- // bytes are returned instead (please see PullResponse). Obviously, these
- // bytes can be written to disk by the caller of this command.
- message Pull {
- // Path _inside_ Android. E.g. /sdcard/my_file.txt.
- string path = 1;
- }
-
- // Inserts text into the current text field (if any).
- // Essentially `adb shell input text `.
- message InputText {
- string text = 1;
- }
-
- // Issues an `adb shell settings` command.
- message SettingsRequest {
- // Each request has an associated namespace.
- enum Namespace {
- UNKNOWN = 0;
- SYSTEM = 1;
- SECURE = 2;
- GLOBAL = 3;
- }
-
- // Retrieves the current value for `key`.
- message Get {
- string key = 1;
- }
-
- // Changes the contents `key` to `value`.
- message Put {
- string key = 1;
- string value = 2;
- }
-
- // Deletes the entry for `key`.
- message Delete {
- string key = 1;
- }
-
- // Resets the global/secure table for a package with the given mode.
- message Reset {
- enum Mode {
- UNKNOWN = 0;
- UNTRUSTED_DEFAULTS = 1;
- UNTRUSTED_CLEAR = 2;
- TRUSTED_DEFAULTS = 3;
- }
-
- string package_name = 1;
- Mode mode = 2;
- }
-
- // Prints all defined keys in the given namespace.
- message List {}
-
- // The part of the system where this command will take place.
- // NOTE: We avoid the identifier `namespace` because it's a keyword in C++.
- Namespace name_space = 1;
-
- // The subcommand to issue to `adb settings`.
- // NOTE: We avoid the identifiers `delete` and `del` because they're
- // keywords in C++ and Python respectively.
- oneof verb {
- Get get = 2;
- Put put = 3;
- Delete delete_key = 4;
- Reset reset = 5;
- List list = 6;
- }
- }
-
- // Generic ADB command. Use this for commands that are not
- // explicitly implemented.
- // Calls `adb [args...]`.
- message GenericRequest {
- repeated string args = 1;
- }
-
- message PackageManagerRequest {
- message List {
- // Lists all features of the system.
- message Features {}
-
- // Lists all system libraries.
- message Libraries {}
-
- // Lists all packages; optionally only those whose name contains the text
- // in `filter`.
- message Packages {
- string filter = 1;
-
- // Extra options that control the output. Please see `pm help` for
- // details.
- repeated string options = 2;
- }
-
- oneof what {
- Features features = 1;
- Libraries libraries = 2;
- Packages packages = 3;
- }
- }
-
- // Deletes all data associated with a package.
- message Clear {
- // The package name to clear its cache.
- string package_name = 1;
-
- // Optional USER_ID.
- string user_id = 2;
- }
-
- message Grant {
- string package_name = 1;
-
- // Possible values listed at
- // https://developer.android.com/reference/android/Manifest.permission
- // To query an app's required permissions, use the following adb command:
- // > adb shell dumpsys package
- // The output will contain things like
- // android.permission.WRITE_SECURE_SETTINGS
- repeated string permissions = 2;
- }
-
- // The subcommand to issue to `pm`.
- oneof verb {
- List list = 1;
- Clear clear = 2;
- Grant grant = 3;
- }
- }
-
- // For executing `dumpsys` commands.
- message DumpsysRequest {
- enum PriorityLevel {
- UNSET = 0;
- NORMAL = 1;
- HIGH = 2;
- CRITICAL = 3;
- }
-
- // The service to dump. If empty, all services will be dumped.
- string service = 1;
-
- // Optional arguments to pass to the specific service dump.
- repeated string args = 2;
-
- // Lists services, does not dump them.
- // This effectively disables dumping information about any particular
- // service.
- bool list_only = 3;
-
- // Timeouts natively supported by `dumpsys`.
- int32 timeout_sec = 4;
- int32 timeout_ms = 5;
-
- // Whether to dump the process ID instead of the usual dump.
- bool pid = 6;
-
- // Whether dumps will be in proto format. Only works for services that
- // support dumping data in proto format.
- bool proto = 7;
-
- // Filters services based on specified priority.
- PriorityLevel priority = 8;
-
- // Excludes services from the dump.
- repeated string skip_services = 9;
- }
-
- oneof command {
- InstallApk install_apk = 1;
- StartActivity start_activity = 2;
- ForceStop force_stop = 3;
- Tap tap = 6;
- PressButton press_button = 7;
- StartScreenPinning start_screen_pinning = 10;
- UninstallPackage uninstall_package = 16;
- GetCurrentActivity get_current_activity = 17;
- GetOrientationRequest get_orientation = 24;
- Push push = 18;
- Pull pull = 19;
- InputText input_text = 20;
- SettingsRequest settings = 21;
- GenericRequest generic = 22;
- PackageManagerRequest package_manager = 23;
- DumpsysRequest dumpsys = 26;
- SendBroadcast send_broadcast = 25;
- }
-
- // Optional (soft) deadline in seconds for completing this command.
- // Expected to be >0. If ==0 (the default), it's ignored.
- // Notice that not all commands accept timeouts, but because it's such a
- // common parameter, we include it here instead of in each separate command.
- float timeout_sec = 100;
-}
-
-message AdbResponse {
- enum Status {
- // Reserved value for unset statuses.
- UNDEFINED = 0;
- // Returned when everything goes well.
- OK = 1;
- // Returned when handling unknown AdbRequest commands.
- UNKNOWN_COMMAND = 2;
- // Returned when an argument does not respect a precondition.
- FAILED_PRECONDITION = 3;
- // Returned when something internal did not work as expected.
- INTERNAL_ERROR = 4;
- // Returned when the adb command failed.
- ADB_ERROR = 5;
- // Returned when the adb command timed out.
- TIMEOUT = 6;
- }
- Status status = 1;
-
- // `error_message` is only populated in case of errors.
- string error_message = 2;
-
- // General stats that components may optionally report.
- map stats = 3;
-
- // Response for GetCurrentActivity requests.
- message GetCurrentActivityResponse {
- // The format of the output is `package/package.ActivityName', for example:
- // "com.example.vokram/com.example.vokram.MainActivity"
- string full_activity = 1;
- }
-
- // Response for GetOrientationRequests.
- message GetOrientationResponse {
- // Possible values are {0, 1, 2, 3} corresponding to {0, 90, 180, 270}
- // degrees respectively.
- // Please see https://developer.android.com/reference/android/view/Surface.
- int32 orientation = 1;
- }
-
- // Response for StartActivity requests.
- message StartActivityResponse {
- // The activity that was actually started. On a failed request, this will be
- // empty.
- string full_activity = 1;
- bytes output = 2;
- }
-
- // Response for PressButton requests.
- message PressButtonResponse {
- // The output, if any, by `adb` after sending a key press.
- // This is intentionally left as `bytes` instead of `string` so that content
- // other than `UTF-8` can be transmitted.
- bytes output = 1;
- }
-
- // Response for Push requests.
- message PushResponse {}
-
- // Response for Pull requests.
- message PullResponse {
- // The contents of the file.
- // This is intentionally left as `bytes` instead of `string` so that content
- // other than `UTF-8` can be transmitted.
- bytes content = 1;
- }
-
- // Response for InputText requests.
- message InputTextResponse {}
-
- // Response for SettingsRequests.
- message SettingsResponse {
- // The output, if any, of the `adb shell settings` command.
- bytes output = 1;
- }
-
- // Response for GenericRequests.
- message GenericResponse {
- // The output, if any, of the generic adb command.
- bytes output = 1;
- }
-
- // Response for PackageManagerRequests.
- message PackageManagerResponse {
- // The output, if any, of the `adb shell pm` command.
- bytes output = 1;
-
- message List {
- // A list of items. The actual content depends on the request, but it
- // could be things like features, libraries or package names.
- repeated string items = 1;
- }
-
- oneof verb {
- List list = 2;
- }
- }
-
- // Response for DumpsysRequests.
- message DumpsysResponse {
- // The output, if any, of the `dumpsys` command.
- bytes output = 1;
- }
-
- oneof payload {
- GetCurrentActivityResponse get_current_activity = 10;
- StartActivityResponse start_activity = 11;
- PressButtonResponse press_button = 12;
- PushResponse push = 13;
- PullResponse pull = 14;
- InputTextResponse input_text = 15;
- SettingsResponse settings = 16;
- GenericResponse generic = 17;
- PackageManagerResponse package_manager = 18;
- GetOrientationResponse get_orientation = 19;
- DumpsysResponse dumpsys = 21;
- }
-}
diff --git a/android_env/proto/emulator_controller.proto b/android_env/proto/emulator_controller.proto
deleted file mode 100644
index 563b2f1f..00000000
--- a/android_env/proto/emulator_controller.proto
+++ /dev/null
@@ -1,1132 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Note that if you add/remove methods in this file you must update
-// the metrics sql as well ./android/scripts/gen-grpc-sql.py
-//
-// Please group deleted methods in a block including the date (MM/DD/YY)
-// it was removed. This enables us to easily keep metrics around after removal
-//
-// list of deleted methods
-// rpc iWasDeleted (03/12/12)
-// ...
-
-// LINT: LEGACY_NAMES
-
-syntax = "proto3";
-
-package android.emulation.control;
-
-import "google/protobuf/empty.proto";
-
-option java_multiple_files = true;
-option java_package = "com.android.emulator.control";
-option objc_class_prefix = "AEC";
-
-// An EmulatorController service lets you control the emulator.
-// Note that this is currently an experimental feature, and that the
-// service definition might change without notice. Use at your own risk!
-//
-// We use the following rough conventions:
-//
-// streamXXX --> streams values XXX (usually for emulator lifetime). Values
-// are updated as soon as they become available.
-// getXXX --> gets a single value XXX
-// setXXX --> sets a single value XXX, does not returning state, these
-// usually have an observable lasting side effect.
-// sendXXX --> send a single event XXX, possibly returning state information.
-// android usually responds to these events.
-service EmulatorController {
- // Set the sensor data
- rpc streamSensor(SensorValue) returns (stream SensorValue) {}
- // Get the sensor data
- rpc getSensor(SensorValue) returns (SensorValue) {}
- // Stream the sensor data
- rpc setSensor(SensorValue) returns (google.protobuf.Empty) {}
-
- // Set the physical model, this is likely the one you are
- // looking for when you wish to modify the device state.
- rpc setPhysicalModel(PhysicalModelValue) returns (google.protobuf.Empty) {}
- // Get the physical model
- rpc getPhysicalModel(PhysicalModelValue) returns (PhysicalModelValue) {}
- // Stream the physical model
- rpc streamPhysicalModel(PhysicalModelValue)
- returns (stream PhysicalModelValue) {}
-
- // Atomically set the current primary clipboard data.
- rpc setClipboard(ClipData) returns (google.protobuf.Empty) {}
- // Atomically get the current primary clipboard data.
- rpc getClipboard(google.protobuf.Empty) returns (ClipData) {}
-
- // Streams the current data on the clipboard. This will immediately produce
- // a result with the current state of the clipboard after which the stream
- // will block and wait until a new clip event is available from the guest.
- // Calling the setClipboard method above will not result in generating a clip
- // event. It is possible to lose clipboard events if the clipboard updates
- // very rapidly.
- rpc streamClipboard(google.protobuf.Empty) returns (stream ClipData) {}
-
- // Set the battery to the given state.
- rpc setBattery(BatteryState) returns (google.protobuf.Empty) {}
- // Get the battery to the given state.
- rpc getBattery(google.protobuf.Empty) returns (BatteryState) {}
-
- // Set the state of the gps, gps support will only work
- // properly if:
- //
- // - no location ui is active. That is the emulator
- // is launched in headless mode (-no-window) or the location
- // ui is disabled (-no-location-ui).
- // - the passiveUpdate is set to false. Setting this to false
- // will disable/break the LocationUI.
- //
- // Keep in mind that android usually only samples the gps at 1 hz.
- rpc setGps(GpsState) returns (google.protobuf.Empty) {}
-
- // Gets the latest gps state as delivered by the setGps call, or location ui
- // if active.
- //
- // Note: this is not necessarily the actual gps coordinate visible at the
- // time, due to gps sample frequency (usually 1hz).
- rpc getGps(google.protobuf.Empty) returns (GpsState) {}
-
- // Simulate a touch event on the finger print sensor.
- rpc sendFingerprint(Fingerprint) returns (google.protobuf.Empty) {}
-
- // Send a keyboard event. Translating the event.
- rpc sendKey(KeyboardEvent) returns (google.protobuf.Empty) {}
-
- // Send touch events. Note that mouse events can be simulated by touch events.
- rpc sendTouch(TouchEvent) returns (google.protobuf.Empty) {}
- // Send mouse events.
- rpc sendMouse(MouseEvent) returns (google.protobuf.Empty) {}
-
- // Make a phone call.
- rpc sendPhone(PhoneCall) returns (PhoneResponse) {}
-
- // Sends an sms message to the emulator.
- rpc sendSms(SmsMessage) returns (PhoneResponse) {}
-
- // Retrieve the status of the emulator. This will contain general
- // hardware information, and whether the device has booted or not.
- rpc getStatus(google.protobuf.Empty) returns (EmulatorStatus) {}
-
- // Gets an individual screenshot in the desired format.
- //
- // The image will be scaled to the desired ImageFormat, while maintaining
- // the aspect ratio. The returned image will never exceed the provided width
- // and height. Not setting the width or height (i.e. they are 0) will result
- // in using the device width and height.
- //
- // The resulting image will be properly oriented and can be displayed
- // directly without post processing. For example, if the device has a
- // 1080x1920 screen and is in landscape mode and called with no width or
- // height parameter, it will return an 1920x1080 image.
- //
- // This method will return an empty image if the display is not visible.
- rpc getScreenshot(ImageFormat) returns (Image) {}
-
- // Streams a series of screenshots in the desired format.
- // A new frame will be delivered whenever the device produces a new frame.
- // (Beware that this can produce a significant amount of data, and that
- // certain translations are (png transform) can be costly).
- //
- // If the requested display is not visible it will send a single empty image
- // and wait start producing images once the display becomes active, again
- // producing a single empty image when the display becomes inactive.
- rpc streamScreenshot(ImageFormat) returns (stream Image) {}
-
- // Streams a series of audio packets in the desired format.
- // A new frame will be delivered whenever the emulated device
- // produces a new audio frame. You can expect packets to be
- // delivered in intervals of 20-30ms.
- //
- // Be aware that this can block when the emulator does not
- // produce any audio whatsoever!
- rpc streamAudio(AudioFormat) returns (stream AudioPacket) {}
-
- // Injects a series of audio packets to the android microphone.
- // A new frame will be delivered whenever the emulated device
- // requests a new audio frame. Audio is usually delivered at a rate
- // that the emulator is requesting frames. Audio will be stored in a
- // temporary buffer that can hold 500ms of audio.
- //
- // Note: Currently the emulator will downsample to 16khz.
- //
- // - INVALID_ARGUMENT (code 3) The sampling rate was too high
- // - INVALID_ARGUMENT (code 3) The audio packet was too large to handle.
- // - FAILED_PRECONDITION (code 9) If there was a microphone registered
- // already.
- rpc injectAudio(stream AudioPacket) returns (google.protobuf.Empty) {}
-
- // Returns the last 128Kb of logcat output from the emulator
- // Note that parsed logcat messages are only available after L (Api >23).
- // it is possible that the logcat buffer gets overwritten, or falls behind.
- rpc getLogcat(LogMessage) returns (LogMessage) {}
-
- // Streams the logcat output from the emulator. The first call
- // can retrieve up to 128Kb. This call will not return.
- // Note that parsed logcat messages are only available after L (Api >23)
- // it is possible that the logcat buffer gets overwritten, or falls behind.
- rpc streamLogcat(LogMessage) returns (stream LogMessage) {}
-
- // Transition the virtual machine to the desired state. Note that
- // some states are only observable. For example you cannot transition
- // to the error state.
- rpc setVmState(VmRunState) returns (google.protobuf.Empty) {}
-
- // Gets the state of the virtual machine.
- rpc getVmState(google.protobuf.Empty) returns (VmRunState) {}
-
- // Atomically changes the current multi-display configuration.
- // After this call the given display configurations will be activated. You
- // can only update secondary displays. Displays with id 0 will be ignored.
- //
- // This call can result in the removal or addition of secondary displays, the
- // final display state can be observed by the returned configuration.
- //
- // The following gRPC error codes can be returned:
- // - FAILED_PRECONDITION (code 9) if the AVD does not support a configurable
- // secondary display.
- // - INVALID_ARGUMENT (code 3) if:
- // - The same display id is defined multiple times.
- // - The display configurations are outside valid ranges
- // (see DisplayConfiguration)
- // - INTERNAL (code 13) if there was an internal emulator failure.
- rpc setDisplayConfigurations(DisplayConfigurations)
- returns (DisplayConfigurations) {}
-
- // Returns all currently valid logical displays.
- // The gRPC error code FAILED_PRECONDITION (code 9) is returned if the AVD
- // does not support a configurable secondary display.
- rpc getDisplayConfigurations(google.protobuf.Empty)
- returns (DisplayConfigurations) {}
-
- // Notifies client of the following changes:
- //
- // - Virtual scene camera status change.
- // - Display configuration changes from extended ui. This will only be fired
- // if the user makes modifications the extended displays through the
- // extended control tab.
- //
- // Note that this method will send the initial virtual scene state
- // immediately.
- rpc streamNotification(google.protobuf.Empty) returns (stream Notification) {}
-
- // RotationRadian is relative to the camera's current orientation.
- rpc rotateVirtualSceneCamera(RotationRadian) returns (google.protobuf.Empty) {
- }
- // Velocity is absolute
- rpc setVirtualSceneCameraVelocity(Velocity) returns (google.protobuf.Empty) {}
- // Set foldable posture
- rpc setPosture(Posture) returns (google.protobuf.Empty) {}
-}
-
-// A Run State that describes the state of the Virtual Machine.
-message VmRunState {
- enum RunState {
- // The emulator is in an unknown state. You cannot transition to this state.
- UNKNOWN = 0;
- // Guest is actively running. You can transition to this state from the
- // paused state.
- RUNNING = 1;
- // Guest is paused to load a snapshot. You cannot transition to this state.
- RESTORE_VM = 2;
- // Guest has been paused. Transitioning to this state will pause the
- // emulator the guest will not be consuming any cpu cycles.
- PAUSED = 3;
- // Guest is paused to take or export a snapshot. You cannot
- // transition to this state.
- SAVE_VM = 4;
- // System shutdown, note that it is similar to power off. It tries to set
- // the system status and notify guest. The system is likely going to
- // disappear soon and do proper cleanup of resources, possibly taking
- // a snapshot. This is the same behavior as closing the emulator by clicking
- // the X (close) in the user interface.
- SHUTDOWN = 5;
- // Immediately terminate the emulator. No resource cleanup will take place.
- // There is a good change to corrupt the system.
- TERMINATE = 7;
- // Will cause the emulator to reset. This is not a state you can observe.
- RESET = 9;
- // Guest experienced some error state, you cannot transition to this state.
- INTERNAL_ERROR = 10;
- }
-
- RunState state = 1;
-}
-
-message ParameterValue {
- repeated float data = 1 [packed = true];
-}
-
-message PhysicalModelValue {
- enum State {
- OK = 0;
- NO_SERVICE = -3; // qemud service is not available/initiated.
- DISABLED = -2; // Sensor is disabled.
- UNKNOWN = -1; // Unknown sensor (should not happen)
- }
-
- // Details on the sensors documentation can be found here:
- // https://developer.android.com/reference/android/hardware/Sensor.html#TYPE_
- // The types must follow the order defined in
- // "external/qemu/android/hw-sensors.h"
- enum PhysicalType {
- POSITION = 0;
-
- // All values are angles in degrees.
- // values = [x,y,z]
- ROTATION = 1;
-
- MAGNETIC_FIELD = 2;
-
- // Temperature in °C
- TEMPERATURE = 3;
-
- // Proximity sensor distance measured in centimeters
- PROXIMITY = 4;
-
- // Ambient light level in SI lux units
- LIGHT = 5;
-
- // Atmospheric pressure in hPa (millibar)
- PRESSURE = 6;
-
- // Relative ambient air humidity in percent
- HUMIDITY = 7;
-
- VELOCITY = 8;
- AMBIENT_MOTION = 9;
-
- // Describing a hinge angle sensor in degrees.
- HINGE_ANGLE0 = 10;
- HINGE_ANGLE1 = 11;
- HINGE_ANGLE2 = 12;
-
- ROLLABLE0 = 13;
- ROLLABLE1 = 14;
- ROLLABLE2 = 15;
- }
- PhysicalType target = 1;
-
- // [Output Only]
- State status = 2;
-
- // Value interpretation depends on sensor, will contain at most 3 values.
- ParameterValue value = 3;
-}
-
-// A single sensor value.
-message SensorValue {
- enum State {
- OK = 0;
- NO_SERVICE = -3; // qemud service is not available/initiated.
- DISABLED = -2; // Sensor is disabled.
- UNKNOWN = -1; // Unknown sensor (should not happen)
- }
-
- // These are the various sensors that can be available in an emulated
- // devices.
- enum SensorType {
- // Measures the acceleration force in m/s2 that is applied to a device
- // on all three physical axes (x, y, and z), including the force of
- // gravity.
- ACCELERATION = 0;
- // Measures a device's rate of rotation in rad/s around each of the
- // three physical axes (x, y, and z).
- GYROSCOPE = 1;
- // Measures the ambient geomagnetic field for all three physical axes
- // (x, y, z) in μT.
- MAGNETIC_FIELD = 2;
- // Measures degrees of rotation that a device makes around all three
- // physical axes (x, y, z)
- ORIENTATION = 3;
- // Measures the temperature of the device in degrees Celsius (°C).
- TEMPERATURE = 4;
- // Measures the proximity of an object in cm relative to the view screen
- // of a device. This sensor is typically used to determine whether a
- // handset is being held up to a person's ear.
- PROXIMITY = 5;
- // Measures the ambient light level (illumination) in lx.
- LIGHT = 6;
- // Measures the ambient air pressure in hPa or mbar.
- PRESSURE = 7;
- // Measures the relative ambient humidity in percent (%).
- HUMIDITY = 8;
- MAGNETIC_FIELD_UNCALIBRATED = 9;
- GYROSCOPE_UNCALIBRATED = 10;
- }
-
- // Type of sensor
- SensorType target = 1;
-
- // [Output Only]
- State status = 2;
-
- // Value interpretation depends on sensor enum, will contain at most 3
- // values.
- ParameterValue value = 3;
-}
-
-message LogMessage {
- // [Output Only] The contents of the log output.
- string contents = 1;
- // The starting byte position of the output that was returned. This
- // should match the start parameter sent with the request. If the serial
- // console output exceeds the size of the buffer, older output will be
- // overwritten by newer content and the start values will be mismatched.
- int64 start = 2;
- //[Output Only] The position of the next byte of content from the serial
- // console output. Use this value in the next request as the start
- // parameter.
- int64 next = 3;
-
- // Set the sort of response you are interested it in.
- // It the type is "Parsed" the entries field will contain the parsed
- // results. otherwise the contents field will be set.
- LogType sort = 4;
-
- // [Output Only] The parsed logcat entries so far. Only set if sort is
- // set to Parsed
- repeated LogcatEntry entries = 5;
-
- enum LogType {
- Text = 0;
- Parsed = 1;
- }
-}
-
-// A parsed logcat entry.
-message LogcatEntry {
- // The possible log levels.
- enum LogLevel {
- UNKNOWN = 0;
- DEFAULT = 1;
- VERBOSE = 2;
- DEBUG = 3;
- INFO = 4;
- WARN = 5;
- ERR = 6;
- FATAL = 7;
- SILENT = 8;
- }
-
- // A Unix timestamps in milliseconds (The number of milliseconds that
- // have elapsed since January 1, 1970 (midnight UTC/GMT), not counting
- // leap seconds)
- uint64 timestamp = 1;
-
- // Process id.
- uint32 pid = 2;
-
- // Thread id.
- uint32 tid = 3;
- LogLevel level = 4;
- string tag = 5;
- string msg = 6;
-}
-
-// Information about the hypervisor that is currently in use.
-message VmConfiguration {
- enum VmHypervisorType {
- // An unknown hypervisor
- UNKNOWN = 0;
-
- // No hypervisor is in use. This usually means that the guest is
- // running on a different CPU than the host, or you are using a
- // platform where no hypervisor is available.
- NONE = 1;
-
- // The Kernel based Virtual Machine
- // (https://www.linux-kvm.org/page/Main_Page)
- KVM = 2;
-
- // Intel® Hardware Accelerated Execution Manager (Intel® HAXM)
- // https://github.com/intel/haxm
- HAXM = 3;
-
- // Hypervisor Framework.
- // https://developer.apple.com/documentation/hypervisor
- HVF = 4;
-
- // Window Hypervisor Platform
- // https://docs.microsoft.com/en-us/virtualization/api/
- WHPX = 5;
-
- GVM = 6;
- }
-
- VmHypervisorType hypervisorType = 1;
- int32 numberOfCpuCores = 2;
- int64 ramSizeBytes = 3;
-}
-
-// Representation of a clipped data object on the clipboard.
-message ClipData {
- // UTF-8 Encoded text.
- string text = 1;
-}
-
-// The Touch interface represents a single contact point on a
-// touch-sensitive device. The contact point is commonly a finger or stylus
-// and the device may be a touchscreen or trackpad.
-message Touch {
- // The horizontal coordinate. This is the physical location on the
- // screen For example 0 indicates the leftmost coordinate.
- int32 x = 1;
-
- // The vertical coordinate. This is the physical location on the screen
- // For example 0 indicates the top left coordinate.
- int32 y = 2;
-
- // The identifier is an arbitrary non-negative integer that is used to
- // identify and track each tool independently when multiple tools are
- // active. For example, when multiple fingers are touching the device,
- // each finger should be assigned a distinct tracking id that is used as
- // long as the finger remains in contact. Tracking ids may be reused
- // when their associated tools move out of range.
- //
- // The emulator currently supports up to 10 concurrent touch events. The
- // identifier can be any uninque value and will be mapped to the next
- // available internal identifier.
- int32 identifier = 3;
-
- // Reports the physical pressure applied to the tip of the tool or the
- // signal strength of the touch contact.
- //
- // The values reported must be non-zero when the tool is touching the
- // device and zero otherwise to indicate that the touch event is
- // completed.
- //
- // Make sure to deliver a pressure of 0 for the given identifier when
- // the touch event is completed, otherwise the touch identifier will not
- // be unregistered!
- int32 pressure = 4;
-
- // Optionally reports the cross-sectional area of the touch contact, or
- // the length of the longer dimension of the touch contact.
- int32 touch_major = 5;
-
- // Optionally reports the length of the shorter dimension of the touch
- // contact. This axis will be ignored if touch_major is reporting an
- // area measurement greater than 0.
- int32 touch_minor = 6;
-
- enum EventExpiration {
- // The system will use the default time of 120s to track
- // the touch event with the given identifier. If no update happens
- // within this timeframe the identifier is considered expired
- // and can be made available for re-use. This means that a touch event
- // with pressure 0 for this identifier will be send to the emulator.
- EVENT_EXPIRATION_UNSPECIFIED = 0;
-
- // Never expire the given slot. You must *ALWAYS* close the identifier
- // by sending a touch event with 0 pressure.
- NEVER_EXPIRE = 1;
- }
-
- EventExpiration expiration = 7;
-}
-
-// A TouchEvent contains a list of Touch objects that are in contact with
-// the touch surface.
-//
-// Touch events are delivered in sequence as specified in the touchList.
-//
-// TouchEvents are delivered to the emulated devices using ["Protocol
-// B"](https://www.kernel.org/doc/Documentation/input/multi-touch-protocol.txt)
-message TouchEvent {
- // The list of Touch objects, note that these do not need to be unique
- repeated Touch touches = 1;
-
- // The display device where the touch event occurred.
- // Omitting or using the value 0 indicates the main display.
- //
- // Touch events cannot be send to displays other than 0, due to
- // https://issuetracker.google.com/issues/150699691
- int32 display = 2;
-}
-
-// The MouseEvent interface represents events that occur due to the user
-// interacting with a pointing device (such as a mouse).
-message MouseEvent {
- // The horizontal coordinate. This is the physical location on the
- // screen For example 0 indicates the leftmost coordinate.
- int32 x = 1;
-
- // The vertical coordinate. This is the physical location on the screen
- // For example 0 indicates the top left coordinate.
- int32 y = 2;
-
- // Indicates which buttons are pressed.
- // 0: No button was pressed
- // 1: Primary button (left)
- // 2: Secondary button (right)
- int32 buttons = 3;
-
- // The display device where the mouse event occurred.
- // Omitting or using the value 0 indicates the main display.
- int32 display = 4;
-}
-
-// KeyboardEvent objects describe a user interaction with the keyboard; each
-// event describes a single interaction between the user and a key (or
-// combination of a key with modifier keys) on the keyboard.
-// This follows the pattern as set by
-// (javascript)[https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent]
-//
-// Note: that only keyCode, key, or text can be set and that the semantics
-// will slightly vary.
-message KeyboardEvent {
- // Code types that the emulator can receive. Note that the emulator
- // will do its best to translate the code to an evdev value that
- // will be send to the emulator. This translation is based on
- // the chromium translation tables. See
- // (this)[https://android.googlesource.com/platform/external/qemu/+/refs/heads/emu-master-dev/android/android-grpc/android/emulation/control/keyboard/keycode_converter_data.inc]
- // for details on the translation.
- enum KeyCodeType {
- Usb = 0;
- Evdev = 1;
- XKB = 2;
- Win = 3;
- Mac = 4;
- }
-
- enum KeyEventType {
- // Indicates that this keyevent should be send to the emulator
- // as a key down event. Meaning that the key event will be
- // translated to an EvDev event type and bit 11 (0x400) will be
- // set before it is sent to the emulator.
- keydown = 0;
-
- // Indicates that the keyevent should be send to the emulator
- // as a key up event. Meaning that the key event will be
- // translated to an EvDev event type and
- // sent to the emulator.
- keyup = 1;
-
- // Indicates that the keyevent will be send to the emulator
- // as e key down event and immediately followed by a keyup event.
- keypress = 2;
- }
-
- // Type of keycode contained in the keyCode field.
- KeyCodeType codeType = 1;
-
- // The type of keyboard event that should be sent to the emulator
- KeyEventType eventType = 2;
-
- // This property represents a physical key on the keyboard (as opposed
- // to the character generated by pressing the key). In other words, this
- // property is a value which isn't altered by keyboard layout or the
- // state of the modifier keys. This value will be interpreted by the
- // emulator depending on the KeyCodeType. The incoming key code will be
- // translated to an evdev code type and send to the emulator.
- // The values in key and text will be ignored.
- int32 keyCode = 3;
-
- // The value of the key pressed by the user, taking into consideration
- // the state of modifier keys such as Shift as well as the keyboard
- // locale and layout. This follows the w3c standard used in browsers.
- // You can find an accurate description of valid values
- // [here](https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key/Key_Values)
- //
- // Note that some keys can result in multiple evdev events that are
- // delivered to the emulator. for example the Key "A" will result in a
- // sequence:
- // ["Shift", "a"] -> [0x2a, 0x1e] whereas "a" results in ["a"] -> [0x1e].
- //
- // Not all documented keys are understood by android, and only printable
- // ASCII [32-127) characters are properly translated.
- //
- // Keep in mind that there are a set of key values that result in android
- // specific behavior
- // [see](https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key/Key_Values#Phone_keys):
- //
- // - "AppSwitch": Behaves as the "Overview" button in android.
- // - "GoBack": The Back button.
- // - "GoHome": The Home button, which takes the user to the phone's main
- // screen (usually an application launcher).
- // - "Power": The Power button.
- string key = 4;
-
- // Series of utf8 encoded characters to send to the emulator. An attempt
- // will be made to translate every character will an EvDev event type and
- // send to the emulator as a keypress event. The values in keyCode,
- // eventType, codeType and key will be ignored.
- //
- // Note that most printable ASCII characters (range [32-127) can be send
- // individually with the "key" param. Do not expect arbitrary UTF symbols to
- // arrive in the emulator (most will be ignored).
- //
- // Note that it is possible to overrun the keyboard buffer by slamming this
- // endpoint with large quantities of text (>1kb). The clipboard api is better
- // suited for transferring large quantities of text.
- string text = 5;
-}
-
-message Fingerprint {
- // True when the fingprint is touched.
- bool isTouching = 1;
-
- // The identifier of the registered fingerprint.
- int32 touchId = 2;
-}
-
-message GpsState {
- // Setting this to false will disable auto updating from the LocationUI,
- // otherwise the location UI will override the location at a frequency of 1hz.
- //
- // - This is unused if the emulator is launched with -no-window, or when he
- // location ui is disabled.
- // - This will BREAK the location ui experience if it is set to false. For
- // example routing will no longer function.
- bool passiveUpdate = 1;
-
- // The latitude, in degrees.
- double latitude = 2;
-
- // The longitude, in degrees.
- double longitude = 3;
-
- // The speed if it is available, in meters/second over ground
- double speed = 4;
-
- // gets the horizontal direction of travel of this device, and is not
- // related to the device orientation. It is guaranteed to be in the
- // range [0.0, 360.0] if the device has a bearing. 0=North, 90=East,
- // 180=South, etc..
- double bearing = 5;
-
- // The altitude if available, in meters above the WGS 84 reference
- // ellipsoid.
- double altitude = 6;
-
- // The number of satellites used to derive the fix
- int32 satellites = 7;
-}
-
-message BatteryState {
- enum BatteryStatus {
- UNKNOWN = 0;
- CHARGING = 1;
- DISCHARGING = 2;
- NOT_CHARGING = 3;
- FULL = 4;
- }
-
- enum BatteryCharger {
- NONE = 0;
- AC = 1;
- USB = 2;
- WIRELESS = 3;
- }
-
- enum BatteryHealth {
- GOOD = 0;
- FAILED = 1;
- DEAD = 2;
- OVERVOLTAGE = 3;
- OVERHEATED = 4;
- }
-
- bool hasBattery = 1;
- bool isPresent = 2;
- BatteryCharger charger = 3;
- int32 chargeLevel = 4;
- BatteryHealth health = 5;
- BatteryStatus status = 6;
-}
-
-// An ImageTransport allows for specifying a side channel for
-// delivering image frames versus using the standard bytes array that is
-// returned with the gRPC request.
-message ImageTransport {
- enum TransportChannel {
- // Return full frames over the gRPC transport
- TRANSPORT_CHANNEL_UNSPECIFIED = 0;
-
- // Write images to the a file/shared memory handle.
- MMAP = 1;
- }
-
- // The desired transport channel used for delivering image frames. Only
- // relevant when streaming screenshots.
- TransportChannel channel = 1;
-
- // Handle used for writing image frames if transport is mmap. The client sets
- // and owns this handle. It can be either a shm region, or a mmap. A mmap
- // should be a url that starts with `file:///`
- // Note: the mmap can result in tearing.
- string handle = 2;
-}
-
-// The aspect ratio (width/height) will be different from the one
-// where the device is unfolded.
-message FoldedDisplay {
- uint32 width = 1;
- uint32 height = 2;
- // It is possible for the screen to be folded in different ways depending
- // on which surface is shown to the user. So xOffset and yOffset indicate
- // the top left corner of the folded screen within the original unfolded
- // screen.
- uint32 xOffset = 3;
- uint32 yOffset = 4;
-}
-
-message ImageFormat {
- enum ImgFormat {
- // Portable Network Graphics format
- // (https://en.wikipedia.org/wiki/Portable_Network_Graphics)
- PNG = 0;
-
- // Three-channel RGB color model supplemented with a fourth alpha
- // channel. https://en.wikipedia.org/wiki/RGBA_color_model
- // Each pixel consists of 4 bytes.
- RGBA8888 = 1;
-
- // Three-channel RGB color model, each pixel consists of 3 bytes
- RGB888 = 2;
- }
-
- // The (desired) format of the resulting bytes.
- ImgFormat format = 1;
-
- // [Output Only] The rotation of the image. The image will be rotated
- // based upon the coarse grained orientation of the device.
- Rotation rotation = 2;
-
- // The (desired) width of the image. When passed as input
- // the image will be scaled to match the given
- // width, while maintaining the aspect ratio of the device.
- // The returned image will never exceed the given width, but can be less.
- // Omitting this value (or passing in 0) will result in no scaling,
- // and the width of the actual device will be used.
- uint32 width = 3;
-
- // The (desired) height of the image. When passed as input
- // the image will be scaled to match the given
- // height, while maintaining the aspect ratio of the device.
- // The returned image will never exceed the given height, but can be less.
- // Omitting this value (or passing in 0) will result in no scaling,
- // and the height of the actual device will be used.
- uint32 height = 4;
-
- // The (desired) display id of the device. Setting this to 0 (or omitting)
- // indicates the main display.
- uint32 display = 5;
-
- // Set this if you wish to use a different transport channel to deliver image
- // frames.
- ImageTransport transport = 6;
-
- // [Output Only] Display configuration when screen is folded. The value is the
- // original configuration before scaling.
- FoldedDisplay foldedDisplay = 7;
-}
-
-message Image {
- ImageFormat format = 1;
-
- uint32 width = 2 [deprecated = true]; // width is contained in format.
- uint32 height = 3 [deprecated = true]; // height is contained in format.
-
- // The organization of the pixels in the image buffer is from left to
- // right and bottom up. This will be empty if an alternative image transport
- // is requested in the image format. In that case the side channel should
- // be used to obtain the image data.
- bytes image = 4;
-
- // [Output Only] Monotonically increasing sequence number in a stream of
- // screenshots. The first screenshot will have a sequence of 0. A single
- // screenshot will always have a sequence number of 0. The sequence is not
- // necessarily contiguous, and can be used to detect how many frames were
- // dropped. An example sequence could be: [0, 3, 5, 7, 9, 11].
- uint32 seq = 5;
-
- // [Output Only] Unix timestamp in microseconds when the emulator estimates
- // the frame was generated. The timestamp is before the actual frame is
- // copied and transformed. This can be used to calculate variance between
- // frame production time, and frame depiction time.
- uint64 timestampUs = 6;
-}
-
-message Rotation {
- enum SkinRotation {
- PORTRAIT = 0; // 0 degrees
- LANDSCAPE = 1; // 90 degrees
- REVERSE_PORTRAIT = 2; // -180 degrees
- REVERSE_LANDSCAPE = 3; // -90 degrees
- }
-
- // The rotation of the device, derived from the sensor state
- // of the emulator. The derivation reflects how android observes
- // the rotation state.
- SkinRotation rotation = 1;
-
- // Specifies the angle of rotation, in degrees [-180, 180]
- double xAxis = 2;
- double yAxis = 3;
- double zAxis = 4;
-}
-
-message PhoneCall {
- enum Operation {
- InitCall = 0;
- AcceptCall = 1;
- RejectCallExplicit = 2;
- RejectCallBusy = 3;
- DisconnectCall = 4;
- PlaceCallOnHold = 5;
- TakeCallOffHold = 6;
- }
- Operation operation = 1;
- string number = 2;
-}
-
-message PhoneResponse {
- enum Response {
- OK = 0;
- BadOperation = 1; // Enum out of range
- BadNumber = 2; // Mal-formed telephone number
- InvalidAction = 3; // E.g., disconnect when no call is in progress
- ActionFailed = 4; // Internal error
- RadioOff = 5; // Radio power off
- }
- Response response = 1;
-}
-
-message Entry {
- string key = 1;
- string value = 2;
-}
-
-message EntryList {
- repeated Entry entry = 1;
-}
-
-message EmulatorStatus {
- // The emulator version string.
- string version = 1;
-
- // The time the emulator has been active in .ms
- uint64 uptime = 2;
-
- // True if the device has completed booting.
- // For P and later this information will accurate,
- // for older images we rely on adb.
- bool booted = 3;
-
- // The current vm configuration
- VmConfiguration vmConfig = 4;
-
- // The hardware configuration of the running emulator as
- // key valure pairs.
- EntryList hardwareConfig = 5;
-}
-
-message AudioFormat {
- enum SampleFormat {
- AUD_FMT_U8 = 0; // Unsigned 8 bit
- AUD_FMT_S16 = 1; // Signed 16 bit (little endian)
- }
-
- enum Channels {
- Mono = 0;
- Stereo = 1;
- }
-
- // Sampling rate to use, defaulting to 44100 if this is not set.
- // Note, that android devices typically will not use a sampling
- // rate higher than 48kHz. See https://developer.android.com/ndk/guides/audio.
- uint64 samplingRate = 1;
- Channels channels = 2;
- SampleFormat format = 3;
-}
-
-message AudioPacket {
- AudioFormat format = 1;
-
- // Unix epoch in us when this frame was captured.
- uint64 timestamp = 2;
-
- // Contains a sample in the given audio format.
- bytes audio = 3;
-}
-
-message SmsMessage {
- // The source address where this message came from.
- //
- // The address should be a valid GSM-formatted address as specified by
- // 3GPP 23.040 Sec 9.1.2.5.
- //
- // For example: +3106225412 or (650) 555-1221
- string srcAddress = 1;
-
- // A utf8 encoded text message that should be delivered.
- string text = 2;
-}
-
-// A DisplayConfiguration describes a primary or secondary
-// display available to the emulator. The screen aspect ratio
-// cannot be longer (or wider) than 21:9 (or 9:21). Screen sizes
-// larger than 4k will be rejected.
-//
-// Common configurations (w x h) are:
-// - 480p (480x720) 142 dpi
-// - 720p (720x1280) 213 dpi
-// - 1080p (1080x1920) 320 dpi
-// - 4K (2160x3840) 320 dpi
-// - 4K (2160x3840) 640 dpi (upscaled)
-//
-// The behavior of the virtual display depends on the flags that are provided to
-// this method. By default, virtual displays are created to be private,
-// non-presentation and unsecure.
-message DisplayConfiguration {
- // These are the set of known android flags and their respective values.
- // you can combine the int values to (de)construct the flags field below.
- enum DisplayFlags {
- DISPLAYFLAGS_UNSPECIFIED = 0;
-
- // When this flag is set, the virtual display is public.
- // A public virtual display behaves just like most any other display
- // that is connected to the system such as an external or wireless
- // display. Applications can open windows on the display and the system
- // may mirror the contents of other displays onto it. see:
- // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_PUBLIC
- VIRTUAL_DISPLAY_FLAG_PUBLIC = 1;
-
- // When this flag is set, the virtual display is registered as a
- // presentation display in the presentation display category.
- // Applications may automatically project their content to presentation
- // displays to provide richer second screen experiences.
- // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_PRESENTATION
- VIRTUAL_DISPLAY_FLAG_PRESENTATION = 2;
-
- // When this flag is set, the virtual display is considered secure as
- // defined by the Display#FLAG_SECURE display flag. The caller promises
- // to take reasonable measures, such as over-the-air encryption, to
- // prevent the contents of the display from being intercepted or
- // recorded on a persistent medium.
- // see:
- // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_SECURE
- VIRTUAL_DISPLAY_FLAG_SECURE = 4;
-
- // This flag is used in conjunction with VIRTUAL_DISPLAY_FLAG_PUBLIC.
- // Ordinarily public virtual displays will automatically mirror the
- // content of the default display if they have no windows of their own.
- // When this flag is specified, the virtual display will only ever show
- // its own content and will be blanked instead if it has no windows. See
- // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_OWN_CONTENT_ONLY
- VIRTUAL_DISPLAY_FLAG_OWN_CONTENT_ONLY = 8;
-
- // Allows content to be mirrored on private displays when no content is
- // being shown.
- // This flag is mutually exclusive with
- // VIRTUAL_DISPLAY_FLAG_OWN_CONTENT_ONLY. If both flags are specified
- // then the own-content only behavior will be applied.
- // see:
- // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_AUTO_MIRROR)
- VIRTUAL_DISPLAY_FLAG_AUTO_MIRROR = 16;
- }
-
- // The width of the display, restricted to:
- // 320 * (dpi / 160) <= width
- uint32 width = 1;
-
- // The heigh of the display, restricted to:
- // * 320 * (dpi / 160) <= height
- uint32 height = 2;
-
- // The pixel density (dpi).
- // See https://developer.android.com/training/multiscreen/screendensities
- // for details. This value should be in the range [120, ..., 640]
- uint32 dpi = 3;
-
- // A combination of virtual display flags. These flags can be constructed
- // by combining the DisplayFlags enum described above.
- //
- // The behavior of the virtual display depends on the flags. By default
- // virtual displays are created to be private, non-presentation and
- // unsecure.
- uint32 flags = 4;
-
- // The id of the display.
- // The primary (default) display has the display ID of 0.
- // A secondary display has a display ID not 0.
- //
- // The id can be used to get or stream a screenshot.
- uint32 display = 5;
-}
-
-message DisplayConfigurations {
- repeated DisplayConfiguration displays = 1;
-}
-
-message Notification {
- enum EventType {
- VIRTUAL_SCENE_CAMERA_INACTIVE = 0;
- VIRTUAL_SCENE_CAMERA_ACTIVE = 1;
-
- // Fired when an update to a display event has been fired through
- // the extended ui. This does not fire events when the display
- // is changed through the console or gRPC endpoint.
- DISPLAY_CONFIGURATIONS_CHANGED_UI = 2;
- // Keep adding more for other event types
- }
-
- EventType event = 1;
-}
-
-message RotationRadian {
- float x = 1; // x axis is horizontal and orthogonal to the view direction.
- float y = 2; // y axis points up and is perpendicular to the floor.
- float z = 3; // z axis is the view direction and is set to 0.0 in
- // rotateVirtualSceneCamera call.
-}
-
-message Velocity {
- float x = 1; // x axis is horizontal and orthogonal to the view direction.
- float y = 2; // y axis points up and is perpendicular to the floor.
- float z = 3; // z axis is the view direction
-}
-
-// must follow the definition in "external/qemu/android/hw-sensors.h"
-message Posture {
- enum PostureValue {
- POSTURE_UNKNOWN = 0;
- POSTURE_CLOSED = 1;
- POSTURE_HALF_OPENED = 2;
- POSTURE_OPENED = 3;
- POSTURE_FLIPPED = 4;
- POSTURE_TENT = 5;
- POSTURE_MAX = 6;
- }
- PostureValue value = 3;
-}
diff --git a/android_env/proto/snapshot.proto b/android_env/proto/snapshot.proto
deleted file mode 100644
index 0b0abc12..00000000
--- a/android_env/proto/snapshot.proto
+++ /dev/null
@@ -1,169 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto2";
-
-// This file must be synchronized between
-// Emulator (branch aosp/emu-master-dev):
-// external/qemu/android/android-emu/android/snapshot/proto/snapshot.proto
-//
-// Android Studio (branch goog/studio-master-dev):
-// tools/adt/idea/android/src/com/android/emulator/snapshot.proto
-//
-// If you modify one, please modify the other.
-
-package emulator_snapshot;
-
-option java_package = "com.android.emulator.snapshot";
-
-message Image {
- enum Type {
- IMAGE_TYPE_UNKNOWN = 0;
- IMAGE_TYPE_KERNEL = 1;
- IMAGE_TYPE_KERNEL_RANCHU = 2;
- IMAGE_TYPE_SYSTEM = 3;
- IMAGE_TYPE_SYSTEM_COPY = 4;
- IMAGE_TYPE_DATA = 5;
- IMAGE_TYPE_DATA_COPY = 6;
- IMAGE_TYPE_RAMDISK = 7;
- IMAGE_TYPE_SDCARD = 8;
- IMAGE_TYPE_CACHE = 9;
- IMAGE_TYPE_VENDOR = 10;
- IMAGE_TYPE_ENCRYPTION_KEY = 11;
- }
-
- optional Type type = 1;
- optional string path = 2;
- optional bool present = 3;
- optional int64 size = 4;
- optional int64 modification_time = 5;
-}
-
-message Host {
- optional string gpu_driver = 4;
- optional int32 hypervisor = 5;
-}
-
-message Config {
- // Features are int32, not enums here to make sure we don't have to update
- // one more protobuf definition with every single new feature flag, even
- // when the code doesn't really care about the actual meaning for them,
- // only for the values.
- repeated int32 enabled_features = 1;
-
- // This holds the renderer; int32 for the same reason as |enabled_features|.
- optional int32 selected_renderer = 2;
-
- optional int32 cpu_core_count = 3;
- optional int64 ram_size_bytes = 4;
-}
-
-message SaveStats {
- // Type of save
- // 0: non-incremental
- // 1: incremental
- optional uint32 incremental = 1;
- // Time taken to save.
- optional uint64 duration = 2;
- // How many changed bytes in RAM.
- optional uint64 ram_changed_bytes = 3;
-}
-
-message Snapshot {
- // Update every time when introducing some breaking changes that make the
- // previous loading code break when trying to load the new snapshot.
- // NOTE: if the old code is fine with just skipping the new fields or not
- // getting the meaning of new values, |version| should remain
- // unchanged.
- optional int32 version = 1;
-
- // Purely informative: when this snapshot was created, Unix timestamp.
- optional int64 creation_time = 2;
-
- // list of mounted disk images used during the snapshot creation.
- repeated Image images = 3;
-
- // Description of the host machine properties needed to load this snapshot.
- optional Host host = 4;
-
- // Description of the emulator configuration needed for this snapshot.
- // NOTE: try not to duplicate the configuration that's already in
- // hardware-qemu.ini; only add what's either not there or what
- // could've been overridden during process initialization.
- optional Config config = 5;
-
- // Set if the snapshot failed to load during the last attempt.
- // Code is up to the application to define, with 0 meaning 'not failed' just
- // in case.
- optional int64 failed_to_load_reason_code = 7;
-
- // Set if data image is mounted.
- // User build and userdebug build mount data partition at different time.
- // But it should be done before boot finished, so this field is very likely
- // to be true.
- // We snapshot it here just in case someday we support snapshot during
- // booting.
- optional bool guest_data_partition_mounted = 8;
-
- // Emulator rotation angle, in right angles (e.g. 1 is 90 degrees, 2 is 180
- // etc).
- optional int32 rotation = 9;
-
- // Number of invalid loads / crashes that happened under this snapshot.
- optional int32 invalid_loads = 10;
-
- // Number of successful loads.
- optional int32 successful_loads = 11;
-
- // The name given to the snapshot by the user. Independent of the
- // file name.
- optional string logical_name = 12;
-
- // The file name of this snapshot's parent. The parent is the
- // snapshot that was loaded into the AVD prior to this snapshot
- // being taken
- optional string parent = 13;
-
- // Arbitrary description added by the user
- optional string description = 14;
-
- // Record of save stats.
- repeated SaveStats save_stats = 15;
-
- // Folded state.
- optional bool folded = 16;
-
- // Emulator boot parameters
- repeated string launch_parameters = 17;
-
- // Emulator build ID
- optional string emulator_build_id = 18;
-
- // System image build ID
- optional string system_image_build_id = 19;
-}
diff --git a/android_env/proto/snapshot_service.proto b/android_env/proto/snapshot_service.proto
deleted file mode 100644
index 1cbf72ec..00000000
--- a/android_env/proto/snapshot_service.proto
+++ /dev/null
@@ -1,289 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Copyright (C) 2018 The Android Open Source Project
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Note that if you add/remove methods in this file you must update
-// the metrics sql as well by running ./android/scripts/gen-grpc-sql.py
-//
-// Please group deleted methods in a block including the date (MM/DD/YY)
-// it was removed. This enables us to easily keep metrics around after removal
-//
-// list of deleted methods
-// rpc iWasDeleted (03/12/12)
-// ...
-syntax = "proto3";
-
-package android.emulation.control;
-
-import "android_env/proto/snapshot.proto";
-
-option java_multiple_files = true;
-option java_package = "com.android.emulator.control";
-option objc_class_prefix = "AEC";
-
-// The SnapshotService enables you to list, insert, store, and retrieve
-// snapshots.
-//
-// Currently there are two types of snapshots:
-//
-// - Local (default): These are snapshots that are created locally. They are
-// stored internally inside qcow2 files and are very efficient. These are
-// the snapshots usually created by interacting with the UI.
-//
-// - Remote: These are snapshots that have been exported at a certain point.
-// an exported snapshot is normalized (completely self contained) and
-// can be imported into an emulator with a similar hardware configuration.
-//
-// Currently the emulator has limited support for importing snapshots:
-// - Once an imported snapshot has been loaded into an emulator it is no longer
-// possible to create new snapshots.
-// - The hardware configuration of the emulator your are pushing a snapshot to
-// must match (or be very similar) to the one you pulled the snapshot from.
-//
-// For example do not expect to be able to restore a snapshot on created on an
-// Intel cpu on an AMD cpu.
-service SnapshotService {
- // Lists all the snapshots, filtered by the given query, that are stored
- // locally for the currently running avd. This includes all the snapshots that
- // were imported (pushed) into this emulator.
- //
- // Returns a list of snapshot_id's and associated details that describes
- // the hardware configuration, logical name, etc of the snapshot.
- rpc ListSnapshots(SnapshotFilter) returns (SnapshotList) {}
-
- // Pulls down the snapshot stored inside the AVD as a tar.gz/tar stream
- // This will normalize the snapshot, all relevant data to push a snapshot
- // into a similar emulator will be placed inside the tar file.
- //
- // Pulling down a snapshot will pause the emulator until the snapshots
- // are rebased and ready for exporting. Once the snapshot is rebased
- // the emulator will continue and downloading should commence.
- //
- // Note that pulling .gz stream is slow.
- //
- // You must provide the snapshot_id and (desired) format.
- //
- // If SnapshotPackage.path is set, the gRPC service will directly write the
- // exported snapshot to SnapshotPackage.path without streaming, which is
- // usually significantly faster. It would require emulator to have direct
- // access to SnapshotPackage.path, which usually means it can only be used
- // when pulling from a local emulator.
- rpc PullSnapshot(SnapshotPackage) returns (stream SnapshotPackage) {}
-
- // Push a tar.gz stream contain the snapshot. The tar file should
- // be a snapshot that was exported through the PullSnapshot in the past.
- // The emulator will try to import the snapshot. The hardware configuration
- // of the current emulator should match the one used for pulling.
- //
- // A detailed description of the snapshot (emulator_snapshot.Snapshot)
- // is stored in the snapshot.pb file inside the tar.
- //
- // You must provide the snapshot_id and format in the first message.
- // Will return success and a possible error message when a failure occurs.
- //
- // If SnapshotPackage.path is set, the gRPC service will directly unzip the
- // exported snapshot from SnapshotPackage.path without streaming, which is
- // usually significantly faster. It would require emulator to have direct
- // access to SnapshotPackage.path, which usually means it can only be used
- // when pushing to a local emulator.
- rpc PushSnapshot(stream SnapshotPackage) returns (SnapshotPackage) {}
-
- // Loads the given snapshot inside the emulator and activates it.
- // The device will be in the state as it was when the snapshot was created.
- //
- // You will no longer be able to call Save if this was an imported
- // snapshot that was pushed into this emulator.
- //
- // You must provide the snapshot_id to indicate which snapshot to load
- // Will return success and a possible error message when a failure occurs.
- rpc LoadSnapshot(SnapshotPackage) returns (SnapshotPackage) {}
-
- // Creates as a snapshot of the current state of the emulator.
- // You can only save a snapshot if you never activated (Load) an imported
- // snapshot (Push).
- //
- // For example:
- // - PushSnapshot("some_snap.tar.gz");
- // - LoadSnapshot("some_snap");
- // - SaveSnapshot("same_newer_snap"); // <--- Will currently fail.
- //
- // You can provide the snapshot_id to indicate the name used for storing.
- // Will return success and a possible error message when a failure occurs.
- rpc SaveSnapshot(SnapshotPackage) returns (SnapshotPackage) {}
-
- // Deletes the snapshot with the given snapshot_id from the avd.
- //
- // You must provide the snapshot_id to indicate which snapshot to delete.
- // Will return success and a possible error message when a failure occurs.
- rpc DeleteSnapshot(SnapshotPackage) returns (SnapshotPackage) {}
-
- // Tracks the given process for automated snapshot creation in case of
- // assert failures.
- //
- // Will return success and a possible error message when a failure occurs.
- // The snapshot_id field will contain the name of the snapshot that
- // will be created. The pid field will contain the process id that is
- // being tracked.
- rpc TrackProcess(IceboxTarget) returns (IceboxTarget) {}
-}
-
-// Sets options for SnapshotService. Used for both request and response
-// messages.
-message SnapshotPackage {
- enum Format {
- TARGZ = 0;
- TAR = 1;
- DIRECTORY = 2;
- }
- // The identifier to the snapshot, only required for request messages. For
- // streaming service, only used in the first stream message of a gRPC call
- // (would be ignored in consequent stream messages of the same call).
- string snapshot_id = 1;
-
- // A stream of bytes. Encoded as a tar (possibly gzipped) file pendinf on the
- // value of format.
- bytes payload = 2;
-
- // [response only] status fields, usually indicates end of transmission.
- bool success = 3;
- bytes err = 4;
-
- // [request only] Format of the payload. Only used in request messages. For
- // streaming service, only used in the first stream message of a gRPC call
- // (would be ignored in consequent stream messages of the same call).
- Format format = 5;
-
- // [request only] Path to the snapshot package file. Only used in request
- // messages.
- //
- // When set in a request, the PullSnapshot/PushSnapshot operation will
- // directly write/read the exported snapshot in path without streaming, which
- // is usually significantly faster. It would require emulator to have direct
- // access to path, which usually means it can only be used with a local
- // emulator.
- string path = 6;
-}
-
-// A snapshot filter can be used to filter the results produced by ListSnapshots
-message SnapshotFilter {
- enum LoadStatus {
- // Only return compatible snapshots
- CompatibleOnly = 0;
-
- // Return all snapshots.
- All = 1;
- }
-
- // Filter snapshots by load status.
- LoadStatus statusFilter = 1;
-}
-
-// Provides detailed information regarding the snapshot.
-message SnapshotDetails {
- enum LoadStatus {
- // The emulator believes that the snapshot is compatible with the emulator
- // that provided this information. The emulator will attempt to load this
- // snapshot when requested.
- //
- // A snapshot is usually compatible when the following statements are true:
- // - The snapshot was taken by the current emulator version. i.e.
- // emulator_build_id in the details field matches the build_id of the
- // emulator that provided this information.
- //
- // - The snapshot was taken on the current running machine, and no hardware
- // changes have taken place between taking and loading the snapshot.
- //
- // - The avd configuration has not changed between when this snapshot was
- // taken and when the snapshot was loaded.
- //
- // - The system images on which the avd is based have not changed.
- Compatible = 0;
-
- // The emulator will not allow loading of the snapshot, as it deems the
- // snapshot to be incompatible. Loading of snapshots can be forced by
- // launching the emulator with the feature "AllowSnapshotMigration" enabled.
- Incompatible = 1;
-
- // This snapshot was successfully loaded in the emulator, and was used at
- // the starting point of the current running emulator. The following holds:
- //
- // A loaded snapshot is a compatible snapshot
- // There is at most one snapshot_id that is in the "Loaded" state
- Loaded = 2;
- }
-
- // The id of this snapshot. Use this id to load/delete/pull the
- // snapshot.
- string snapshot_id = 1;
-
- // Detailed information about this snapshot. This contains a detailed
- // hardware description of the snapshot. These details are the same
- // as the "snapshot.pb" file found in an exported snapshot.
- // Look at the import file for a detailed description of the available
- // fields.
- emulator_snapshot.Snapshot details = 2;
-
- // Provides information about the ability to restore this snapshot.
- LoadStatus status = 3;
-
- // The size of the folder that stores required information to load a snapshot.
- uint64 size = 4;
-}
-
-// A list of on snapshot details.
-message SnapshotList {
- repeated SnapshotDetails snapshots = 1;
-}
-
-message IceboxTarget {
- // This is the process id to attach to, if this value is not set (0)
- // The process name will be used instead.
- int64 pid = 1;
-
- // The process name to attach to if any, if this is not set the pid will
- // be used. This is usually the application name of your application under
- // test, that is passed in to the am instrument command. It is likely
- // what you will find in your AndroidManifest.xml
- string package_name = 2;
-
- // The name of the snapshot that icebox will create if a snapshot is
- // generated.
- string snapshot_id = 3;
-
- // [Output Only] True if icebox failed to track the given target.
- bool failed = 4;
-
- // [Output Only] Detailed error message that might provide more information.
- string err = 5;
-
- // Maximum number of snapshots the emulator can take during one Icebox run.
- // Set to -1 for unlimited number of snapshots.
- int32 max_snapshot_number = 6;
-}
-
-// list of deleted methods:
-//
diff --git a/android_env/proto/state.proto b/android_env/proto/state.proto
deleted file mode 100644
index c562fe52..00000000
--- a/android_env/proto/state.proto
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto3";
-
-package android_env;
-
-option java_multiple_files = true;
-
-message SaveStateRequest {
- map args = 1;
-}
-
-message LoadStateRequest {
- map args = 1;
-}
-
-message SaveStateResponse {
- enum Status {
- // Reserved value for unset statuses.
- UNDEFINED = 0;
- // Returned when everything goes well.
- OK = 1;
- // Returned when something internal did not work as expected.
- ERROR = 2;
- }
- Status status = 1;
- // `error_message` is only populated in case of errors.
- string error_message = 2;
-
- // Any additional info returned during the request; e.g., file paths or sizes.
- map additional_info = 3;
-}
-
-message LoadStateResponse {
- enum Status {
- // Reserved value for unset statuses.
- UNDEFINED = 0;
- // Returned when everything goes well.
- OK = 1;
- // Returned when there is no state to load.
- NOT_FOUND = 2;
- // Returned when something internal did not work as expected.
- ERROR = 3;
- }
- Status status = 1;
- // `error_message` is only populated in case of errors.
- string error_message = 2;
-
- // Any additional info returned during the request; e.g., file paths or sizes.
- map additional_info = 3;
-}
\ No newline at end of file
diff --git a/android_env/proto/task.proto b/android_env/proto/task.proto
deleted file mode 100644
index b39fd3e7..00000000
--- a/android_env/proto/task.proto
+++ /dev/null
@@ -1,211 +0,0 @@
-// Copyright 2024 DeepMind Technologies Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto3";
-
-package android_env;
-
-import "android_env/proto/adb.proto";
-
-// An AppScreen identifies a unique configuration that we can observe on the
-// screen of a device.
-message AppScreen {
- // Fully-qualified name of the activity.
- string activity = 1;
-
- // A list of regexes to match at each level of the current view hierarchy.
- // The environment uses this list to determine whether the agent has "exited"
- // this current task.
- // Example: [
- // "^DecorView@.*\[MainActivity\]$",
- // "^android.widget.LinearLayout\{.*\}$",
- // "^android.widget.FrameLayout\{.*android\:id\/content\}",
- // "^android.widget.RelativeLayout\{.*\}",
- // "^android.widget.FrameLayout\{.*app\:id\/fragment_holder\}",
- // "^android.widget.RelativeLayout\{.*\}",
- // "^com.google.example.games.nostalgicracer.views.RaceView3D\{.*app\:id\/gameplay_screen_3d\}",
- // ],
- repeated string view_hierarchy_path = 2;
-}
-
-// Waits for `app_screen` to be the current app screen shown to the user.
-message WaitForAppScreen {
- AppScreen app_screen = 1;
- // Maximum time in seconds to wait for the activity to become the current one.
- float timeout_sec = 2;
-}
-
-message CheckInstall {
- string package_name = 1;
- // Maximum time in seconds to wait.
- float timeout_sec = 2;
-}
-
-message Sleep {
- float time_sec = 1;
-}
-
-message SuccessCondition {
- int32 num_retries = 1;
-
- oneof check {
- WaitForAppScreen wait_for_app_screen = 2;
- CheckInstall check_install = 3;
- }
-}
-
-message SetupStep {
- SuccessCondition success_condition = 1;
-
- oneof step {
- AdbRequest adb_request = 2;
- Sleep sleep = 3;
- }
-}
-
-// A specification of structured observations
-// Analogous to dm_env.specs.Array()
-
-message ArraySpec {
- // An identifier for this ArraySpec.
- string name = 1;
-
- // The shape of the multi-dimensional values associated with this ArraySpec,
- repeated int32 shape = 2;
-
- enum DataType {
- INVALID_DATA_TYPE = 0;
- FLOAT = 1;
- DOUBLE = 2;
- INT8 = 3;
- INT16 = 4;
- INT32 = 5;
- INT64 = 6;
- UINT8 = 7;
- UINT16 = 8;
- UINT32 = 9;
- UINT64 = 10;
- BOOL = 11;
- STRING_U1 = 12;
- STRING_U16 = 13;
- STRING_U25 = 14;
- STRING_U250 = 15;
- STRING = 16; // String without max length
- OBJECT = 17;
- }
-
- // Data type of elements we expect to see in an array of this spec.
- DataType dtype = 3;
-}
-
-message LogParsingConfig {
- // `filters` are tags used by the app's logging system so that we can
- // identify them in logcat's output. It's the first argument to logging calls
- // such as Log.e("ActivityManager", "My message").
- // Example: "ActivityManager"
- repeated string filters = 1;
-
- // Regular expressions that define how we can extract RL information such as
- // score, extras and episode end from raw logcat messages.
- message LogRegexps {
- // Regexp expected to match:
- // ...a floating point value which gets accumulated over time.
- // A delta in 'score' corresponds to the reward.
- string score = 1;
-
- // Regexp expected to match:
- // ...a floating point value directly forwarded by the environment.
- repeated string reward = 2;
-
- // Regexp expected to match:
- // ...a signal marking the end of an episode.
- repeated string episode_end = 3;
-
- // Regexp expected to match:
- // ...a string representing pairs of extra names and values.
- repeated string extra = 4;
-
- // Regexp expected to match:
- // ...a dict of extra names and values in json format.
- repeated string json_extra = 5;
-
- // Attaches rewards to arbitrary log messages, for example:
- // {event: "coin_collected" reward: 2.3}
- // {event: "car_crashed" reward: -1.4}
- message RewardEvent {
- // If `event` is matched, the environment will give `reward`.
- string event = 1;
-
- // Numerical value to give as reward if `event` is matched.
- float reward = 2;
- }
-
- repeated RewardEvent reward_event = 6;
- }
-
- LogRegexps log_regexps = 2;
-}
-
-// Description of a reinforcement learning task to be solved by an agent.
-message Task {
- // A globally unique identifier for this task.
- string id = 1;
-
- // A human readable name for this task.
- string name = 2;
-
- // A description of the task.
- string description = 3;
-
- repeated SetupStep setup_steps = 4;
- repeated SetupStep reset_steps = 5;
-
- AppScreen expected_app_screen = 6;
-
- // AndroidEnv resets the episode after `max_episode_sec` is passed since the
- // last reset(). Recommended for time sensitive tasks (e.g. reactive games).
- // Note that this is real time as measured by AndroidEnv and is independent of
- // the speed of simulation of Android.
- // If <= 0.0, this logic is disabled.
- float max_episode_sec = 7;
-
- // The maximum number of interactions in a single episode between the
- // environment and an agent.
- // This setting is appropriate for tasks that are not time-dependent or when
- // the performance of the simulation varies dramatically between runs.
- // If <= 0, this logic is disabled.
- int32 max_episode_steps = 8;
-
- // Defines parameters for parsing messages from logcat.
- LogParsingConfig log_parsing_config = 9;
-
- // NOTE: This field is deprecated and will be removed from this Task
- // definition soon.
- //
- // (Optional): The task may also define extras to help the RL agent.
- // An Extra in AndroidEnv is any information that apps may send to aid the
- // understanding of the task. The type of information sent through this
- // channel is usually something difficult to obtain from raw pixels and may
- // include things such as:
- //
- // - The current board configuration (e.g. of a chess game or a tetris game)
- // - The position of the avatar in a map
- // - Events (e.g. whether a button was pressed or a checkpoint was achieved)
- //
- // Notice that these are entirely optional and may not be available at all.
- // This specification ensures that only extras specified in the Task
- // definition will be passed to the agent, everything else is excluded.
- // The name of an extra must be unique across all extras.
- repeated ArraySpec extras_spec = 10;
-}
diff --git a/android_env/wrappers/__init__.py b/android_env/wrappers/__init__.py
deleted file mode 100644
index 2f66bf75..00000000
--- a/android_env/wrappers/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
diff --git a/android_env/wrappers/a11y/__init__.py b/android_env/wrappers/a11y/__init__.py
deleted file mode 100644
index 2f66bf75..00000000
--- a/android_env/wrappers/a11y/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
diff --git a/android_env/wrappers/a11y/a11y_events.py b/android_env/wrappers/a11y/a11y_events.py
deleted file mode 100644
index 2df3e855..00000000
--- a/android_env/wrappers/a11y/a11y_events.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tools for accessing accessibility events."""
-
-from collections.abc import Mapping
-from typing import Any
-
-from absl import logging
-from android_env.proto.a11y import a11y_pb2
-import numpy as np
-
-from google.protobuf import any_pb2
-
-
-_A11Y_EVENT_KEY = 'full_event'
-
-
-def package_events_to_task_extras(
- events: list[a11y_pb2.EventRequest],
-) -> Mapping[str, np.ndarray]:
- if not events:
- return {}
- events = np.stack(events, axis=0)
- return {_A11Y_EVENT_KEY: events}
-
-
-def extract_events_from_task_extras(
- task_extras: Mapping[str, Any] | None = None,
-) -> list[Mapping[str, str]]:
- """Inspects task_extras and extracts all accessibility events detected.
-
- Args:
- task_extras: Task extras forwarded by AndroidEnv. If 'full_event' is not a
- key in task_extras, then this function returns an empty string. Otherwise,
- full_event is expected to be list to be a numpy array with one dimension,
- and contains a list of dictionary describing accessibility events that are
- present in the given task extras. e.g. 'event_type:
- TYPE_WINDOW_CONTENT_CHANGED // event_package_name:
- com.google.android.deskclock // source_class_name:
- android.widget.ImageView'.
-
- Returns:
- List of all events detected
- """
- if task_extras is None or _A11Y_EVENT_KEY not in task_extras:
- return []
-
- if (
- not isinstance(task_extras[_A11Y_EVENT_KEY], np.ndarray)
- or task_extras[_A11Y_EVENT_KEY].ndim != 1
- ):
- raise ValueError(
- f'{_A11Y_EVENT_KEY} task extra should be a numpy array with one'
- ' dimension.'
- )
-
- if task_extras[_A11Y_EVENT_KEY].size == 0:
- return []
-
- events = []
- for e in task_extras[_A11Y_EVENT_KEY]:
- if isinstance(e, a11y_pb2.EventRequest):
- events.append(dict(e.event))
- elif isinstance(e, dict):
- events.append(e)
- logging.warning(
- 'The event should come only from the a11y_grpc_wrapper. '
- 'Please verify that the upacking operation has not been '
- 'called twice. See here for full task_extras: %s',
- task_extras,
- )
- elif isinstance(e, any_pb2.Any):
- ev = a11y_pb2.EventRequest()
- new_any = any_pb2.Any()
- new_any.CopyFrom(e)
- new_any.Unpack(ev)
- events.append(dict(ev.event))
-
- else:
- raise TypeError(
- f'Unexpected event type: {type(e)}. See here for full '
- f'task_extras: {task_extras}.'
- )
-
- return events
-
-
-def keep_latest_event_only(task_extras: dict[str, Any]):
- """Removes all a11y events except the last one observed."""
- if task_extras is None or 'full_event' not in task_extras:
- return
-
- if (
- not isinstance(task_extras[_A11Y_EVENT_KEY], np.ndarray)
- or task_extras[_A11Y_EVENT_KEY].ndim != 1
- ):
- raise ValueError(
- f'{_A11Y_EVENT_KEY} task extra should be a numpy array with one'
- ' dimension.'
- )
-
- if task_extras[_A11Y_EVENT_KEY].size == 0:
- return []
-
- task_extras[_A11Y_EVENT_KEY] = task_extras[_A11Y_EVENT_KEY][-1:]
diff --git a/android_env/wrappers/a11y/a11y_events_test.py b/android_env/wrappers/a11y/a11y_events_test.py
deleted file mode 100644
index 400fd801..00000000
--- a/android_env/wrappers/a11y/a11y_events_test.py
+++ /dev/null
@@ -1,173 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for a11y_events."""
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from android_env.proto.a11y import a11y_pb2
-from android_env.wrappers.a11y import a11y_events
-import numpy as np
-
-from google.protobuf import any_pb2
-
-
-def _event_request(d: dict[str, str]) -> a11y_pb2.EventRequest:
- event_request = a11y_pb2.EventRequest()
- for k, v in d.items():
- event_request.event[k] = v
- return event_request
-
-
-def _event_request_as_any(d: dict[str, str]) -> any_pb2.Any:
- event_request = _event_request(d)
- response = any_pb2.Any()
- response.Pack(event_request)
- return response
-
-
-class A11yEventsTest(parameterized.TestCase):
-
- @parameterized.parameters(
- dict(task_extras={}),
- dict(
- task_extras={'no_full_event': [{'1': '1'}, {'2': '2'}, {'3': '3'}]},
- ),
- dict(
- task_extras={'full_event': np.array([])},
- ),
- dict(
- task_extras={},
- ),
- )
- def test_no_events_in_task_extras(self, task_extras):
- events = a11y_events.extract_events_from_task_extras(task_extras)
- self.assertEmpty(events)
-
- @parameterized.parameters(
- dict(
- task_extras={'full_event': [{'1': '1'}, {'2': '2'}]},
- expected_events=[{'1': '1'}, {'2': '2'}],
- ),
- dict(
- task_extras={'full_event': [{}]},
- expected_events=[{}],
- ),
- dict(
- task_extras={
- 'full_event_wrong_key': [1, 2, 3],
- 'full_event': [{'1': '1'}, {'2': '2'}, {'3': '3'}],
- },
- expected_events=[{'1': '1'}, {'2': '2'}, {'3': '3'}],
- ),
- )
- def test_task_extras(self, task_extras, expected_events):
- event_requests = [_event_request(e) for e in task_extras['full_event']]
- task_extras['full_event'] = np.stack(event_requests, axis=0)
- events = a11y_events.extract_events_from_task_extras(task_extras)
- self.assertEqual(len(events), len(expected_events))
- for i, event in enumerate(expected_events):
- self.assertEqual(len(event), len(expected_events[i]))
- for k, v in event.items():
- self.assertIn(k, expected_events[i])
- self.assertEqual(v, expected_events[i][k])
-
- def test_events_key_has_dict_event_requrests(self):
- event_requests = [
- _event_request({'1': '1'}),
- {'2': '2'},
- _event_request({'3': '3'}),
- ]
- expected_events = [
- {'1': '1'},
- {'2': '2'},
- {'3': '3'},
- ]
- task_extras = {'full_event': np.stack(event_requests, axis=0)}
- events = a11y_events.extract_events_from_task_extras(task_extras)
- self.assertEqual(len(events), len(expected_events))
- for i, event in enumerate(expected_events):
- self.assertEqual(len(event), len(expected_events[i]))
- for k, v in event.items():
- self.assertIn(k, expected_events[i])
- self.assertEqual(v, expected_events[i][k])
-
- def test_events_key_has__event_requrests_packed_as_any(self):
- event_requests = [
- _event_request_as_any({'1': '1'}),
- {'2': '2'},
- _event_request_as_any({'3': '3'}),
- ]
- expected_events = [
- {'1': '1'},
- {'2': '2'},
- {'3': '3'},
- ]
- task_extras = {'full_event': np.stack(event_requests, axis=0)}
- events = a11y_events.extract_events_from_task_extras(task_extras)
- self.assertEqual(len(events), len(expected_events))
- for i, event in enumerate(expected_events):
- self.assertEqual(len(event), len(expected_events[i]))
- for k, v in event.items():
- self.assertIn(k, expected_events[i])
- self.assertEqual(v, expected_events[i][k])
-
- def test_events_key_has_non_event_requrests(self):
- event_requests = [
- _event_request({'1': '1'}),
- 3, # Not an even and not a dict.
- _event_request({'3': '3'}),
- ]
- task_extras = {'full_event': np.stack(event_requests, axis=0)}
- with self.assertRaises(TypeError):
- _ = a11y_events.extract_events_from_task_extras(task_extras)
-
- @parameterized.parameters(
- dict(task_extras={}, expected_extras={}),
- dict(
- task_extras={
- 'no_full_event': 42,
- },
- expected_extras={
- 'no_full_event': 42,
- },
- ),
- dict(
- task_extras={'full_event': np.array([1, 2]), 'no_full_event': 43},
- expected_extras={'full_event': np.array([2]), 'no_full_event': 43},
- ),
- dict(
- task_extras={'full_event': np.array([1, 2, 3])},
- expected_extras={'full_event': np.array([3])},
- ),
- dict(
- task_extras={'full_event': np.array([]), 'no_full_event': 44},
- expected_extras={'full_event': np.array([]), 'no_full_event': 44},
- ),
- )
- def test_keep_latest_only(self, task_extras, expected_extras):
- a11y_events.keep_latest_event_only(task_extras)
- self.assertEqual(len(task_extras), len(expected_extras))
- for k, v in task_extras.items():
- self.assertIn(k, expected_extras)
- if k == 'full_event':
- np.testing.assert_array_equal(v, expected_extras['full_event'])
- else:
- self.assertEqual(v, expected_extras[k])
- pass
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/wrappers/a11y/a11y_forests.py b/android_env/wrappers/a11y/a11y_forests.py
deleted file mode 100644
index 1cd8ef2d..00000000
--- a/android_env/wrappers/a11y/a11y_forests.py
+++ /dev/null
@@ -1,128 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tools for accessing accessibility events."""
-
-from collections.abc import Mapping
-from typing import Any
-
-from android_env.proto.a11y import android_accessibility_forest_pb2
-import numpy as np
-
-from google.protobuf import any_pb2
-
-
-_A11Y_FORESTS_KEY = 'accessibility_tree'
-
-
-def package_forests_to_task_extras(
- forests: list[android_accessibility_forest_pb2.AndroidAccessibilityForest],
-) -> Mapping[str, np.ndarray]:
- if not forests:
- return {}
- forests = np.stack(forests, axis=0)
- return {_A11Y_FORESTS_KEY: forests}
-
-
-def task_extras_has_forests(task_extras: Mapping[str, Any]) -> bool:
- """Checks that the task_extras has any a11y forest information."""
- if _A11Y_FORESTS_KEY not in task_extras:
- return False
-
- payload = task_extras[_A11Y_FORESTS_KEY]
- if not isinstance(payload, np.ndarray) or payload.ndim != 1:
- raise ValueError(
- f'{_A11Y_FORESTS_KEY} task extra should be a numpy array with one'
- f' dimension. payload: {payload}'
- )
-
- if payload.size == 0:
- return False
-
- if any(isinstance(f, any_pb2.Any) for f in payload):
- # Forests were packed as Any.
- return True
-
- return any(
- isinstance(f, android_accessibility_forest_pb2.AndroidAccessibilityForest)
- for f in payload
- )
-
-
-def convert_to_forest(
- forest: android_accessibility_forest_pb2.AndroidAccessibilityForest
- | any_pb2.Any
- | None,
-) -> android_accessibility_forest_pb2.AndroidAccessibilityForest | None:
- """Takes an object and attempts to convert it to a forest."""
- if forest is None:
- return None
-
- if isinstance(forest, any_pb2.Any):
- output = android_accessibility_forest_pb2.AndroidAccessibilityForest()
- new_any = any_pb2.Any()
- new_any.CopyFrom(forest)
- new_any.Unpack(output)
- return output
- elif isinstance(
- forest, android_accessibility_forest_pb2.AndroidAccessibilityForest
- ):
- return forest
- else:
- return None
-
-
-def extract_forests_from_task_extras(
- task_extras: Mapping[str, Any] | None = None,
-) -> list[android_accessibility_forest_pb2.AndroidAccessibilityForest]:
- """Inspects task_extras and extracts all accessibility forests detected.
-
- Args:
- task_extras: Task extras forwarded by AndroidEnv. If 'full_event' is not a
- key in task_extras, then this function returns an empty string. Otherwise,
- full_event is expected to be list to be a numpy array with one dimension,
- and contains a list of dictionary describing accessibility forests that
- are present in the given task extras.
-
- Returns:
- List of all forests detected
- """
- if task_extras is None or not task_extras_has_forests(task_extras):
- return []
-
- forests = []
- for f in task_extras[_A11Y_FORESTS_KEY]:
- f = convert_to_forest(f)
- if f is not None:
- forests.append(f)
- return forests
-
-
-def keep_latest_forest_only(task_extras: dict[str, Any]):
- """Removes all a11y forests except the last one observed."""
- if _A11Y_FORESTS_KEY not in task_extras.keys():
- return
-
- payload = task_extras[_A11Y_FORESTS_KEY]
- if not isinstance(payload, np.ndarray) or payload.ndim != 1:
- raise ValueError(
- f'{_A11Y_FORESTS_KEY} task extra should be a numpy array with one'
- f' dimension. payload: {payload}'
- )
-
- if payload.size == 0:
- return
-
- task_extras[_A11Y_FORESTS_KEY] = payload[-1:]
diff --git a/android_env/wrappers/a11y/a11y_forests_test.py b/android_env/wrappers/a11y/a11y_forests_test.py
deleted file mode 100644
index b57f30aa..00000000
--- a/android_env/wrappers/a11y/a11y_forests_test.py
+++ /dev/null
@@ -1,237 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for a11y_forests."""
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from android_env.proto.a11y import android_accessibility_forest_pb2
-from android_env.wrappers.a11y import a11y_forests
-import numpy as np
-
-from google.protobuf import any_pb2
-
-
-def _pack_any(proto_message) -> any_pb2.Any:
- response = any_pb2.Any()
- response.Pack(proto_message)
- return response
-
-
-def _empty_forest() -> (
- android_accessibility_forest_pb2.AndroidAccessibilityForest
-):
- return android_accessibility_forest_pb2.AndroidAccessibilityForest()
-
-
-def _one_empty_window_forest() -> (
- android_accessibility_forest_pb2.AndroidAccessibilityForest
-):
- forest = android_accessibility_forest_pb2.AndroidAccessibilityForest()
- forest.windows.add()
- return forest
-
-
-def _two_window_forest() -> (
- android_accessibility_forest_pb2.AndroidAccessibilityForest
-):
- forest = android_accessibility_forest_pb2.AndroidAccessibilityForest()
- window = forest.windows.add()
- window.tree.nodes.add(
- class_name='foo', is_clickable=True, hint_text='Foo hint'
- )
- forest.windows.add()
- return forest
-
-
-class A11YForestsTest(parameterized.TestCase):
-
- @parameterized.parameters(
- dict(task_extras={}, expected_forests=[], convert_to_np=[]),
- dict(
- task_extras={'accessibility_tree': []},
- convert_to_np=['accessibility_tree'],
- expected_forests=[],
- ),
- dict(
- task_extras={
- 'not_accessibility_tree': [
- _empty_forest(),
- _one_empty_window_forest(),
- _two_window_forest(),
- ],
- },
- convert_to_np=['not_accessibility_tree'],
- expected_forests=[],
- ),
- dict(
- task_extras={
- 'accessibility_tree': [
- _empty_forest(),
- {'not_a_forest_key': 'nor_a_forest_value'},
- _two_window_forest(),
- ]
- },
- convert_to_np=['accessibility_tree'],
- expected_forests=[_empty_forest(), _two_window_forest()],
- ),
- dict(
- task_extras={
- 'accessibility_tree': [
- {'not_a_forest_key': 'nor_a_forest_value'},
- 3,
- 4,
- {'not_a_forest_key': _empty_forest()},
- ],
- },
- convert_to_np=['accessibility_tree'],
- expected_forests=[],
- ),
- dict(
- task_extras={'accessibility_tree': []},
- convert_to_np=['accessibility_tree'],
- expected_forests=[],
- ),
- dict(
- task_extras={
- 'accessibility_tree_wrong_key': [1, 2, 3],
- 'accessibility_tree': [
- _empty_forest(),
- None,
- None,
- _one_empty_window_forest(),
- _two_window_forest(),
- ],
- },
- convert_to_np=['accessibility_tree', 'accessibility_tree_wrong_key'],
- expected_forests=[
- _empty_forest(),
- _one_empty_window_forest(),
- _two_window_forest(),
- ],
- ),
- dict(
- task_extras={
- 'accessibility_tree_wrong_key': [1, 2, 3],
- 'accessibility_tree': [
- None,
- _pack_any(_empty_forest()),
- _pack_any(_one_empty_window_forest()),
- _pack_any(_two_window_forest()),
- ],
- },
- convert_to_np=['accessibility_tree', 'accessibility_tree_wrong_key'],
- expected_forests=[
- _empty_forest(),
- _one_empty_window_forest(),
- _two_window_forest(),
- ],
- ),
- dict(
- task_extras={
- 'accessibility_tree': [
- _pack_any(_empty_forest()),
- {'not_a_forest_key': 'nor_a_forest_value'},
- None,
- _two_window_forest(),
- None,
- ]
- },
- convert_to_np=['accessibility_tree'],
- expected_forests=[_empty_forest(), _two_window_forest()],
- ),
- )
- def test_task_extras(self, task_extras, expected_forests, convert_to_np):
- for k in convert_to_np:
- if task_extras[k]:
- task_extras[k] = np.stack(task_extras[k], axis=0)
- else:
- task_extras[k] = np.array([])
- forests = a11y_forests.extract_forests_from_task_extras(task_extras)
- self.assertEqual(len(forests), len(expected_forests))
- for idx, f in enumerate(forests):
- self.assertEqual(f, expected_forests[idx])
-
- @parameterized.parameters(
- dict(task_extras={}, expected_extras={}),
- dict(
- task_extras={
- 'no_accessibility_tree': 42,
- },
- expected_extras={
- 'no_accessibility_tree': 42,
- },
- ),
- dict(
- task_extras={'accessibility_tree': []},
- expected_extras={'accessibility_tree': []},
- ),
- dict(
- task_extras={
- 'accessibility_tree': [
- _empty_forest(),
- _one_empty_window_forest(),
- ],
- 'no_accessibility_tree': 43,
- },
- expected_extras={
- 'accessibility_tree': [_one_empty_window_forest()],
- 'no_accessibility_tree': 43,
- },
- ),
- dict(
- task_extras={
- 'accessibility_tree': [
- _empty_forest(),
- _one_empty_window_forest(),
- _two_window_forest(),
- ]
- },
- expected_extras={'accessibility_tree': [_two_window_forest()]},
- ),
- dict(
- task_extras={
- 'accessibility_tree': [],
- 'no_accessibility_tree': 44,
- },
- expected_extras={
- 'accessibility_tree': [],
- 'no_accessibility_tree': 44,
- },
- ),
- )
- def test_keep_latest_only(self, task_extras, expected_extras):
- if 'accessibility_tree' in task_extras:
- if task_extras['accessibility_tree']:
- task_extras['accessibility_tree'] = np.stack(
- task_extras['accessibility_tree'], axis=0
- )
- else:
- task_extras['accessibility_tree'] = np.array([])
-
- a11y_forests.keep_latest_forest_only(task_extras)
- self.assertSameElements(task_extras.keys(), expected_extras.keys())
- for k in task_extras.keys():
- if k == 'accessibility_tree':
- self.assertEqual(len(task_extras[k]), len(expected_extras[k]))
- for idx, f in enumerate(task_extras[k]):
- self.assertEqual(f, expected_extras[k][idx])
- else:
- self.assertEqual(task_extras[k], expected_extras[k])
- pass
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/wrappers/a11y/a11y_servicer.py b/android_env/wrappers/a11y/a11y_servicer.py
deleted file mode 100644
index 82e9e253..00000000
--- a/android_env/wrappers/a11y/a11y_servicer.py
+++ /dev/null
@@ -1,199 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Accessibility Servicer implementation."""
-
-import asyncio
-from collections.abc import AsyncIterator, Generator, Iterable
-import threading
-
-from absl import logging
-from android_env.proto.a11y import a11y_pb2
-from android_env.proto.a11y import a11y_pb2_grpc
-from android_env.proto.a11y import android_accessibility_forest_pb2
-import grpc
-
-
-class A11yServicer(a11y_pb2_grpc.A11yServiceServicer):
- """Services the A11yService requests."""
-
- def __init__(self, latest_forest_only: bool = False):
- self._received_forests: list[
- android_accessibility_forest_pb2.AndroidAccessibilityForest
- ] = []
- self._received_events: list[a11y_pb2.EventRequest] = []
- self._lock_forests = threading.Lock()
- self._lock_events = threading.Lock()
- self._latest_forest_only = latest_forest_only
- self._paused = True
-
- # A11y Forest bookkeeping.
- self._get_forest = asyncio.Event() # Whether to request a forest.
- self._forest_ready = asyncio.Event() # Whether the forest is ready.
- self._latest_forest: (
- android_accessibility_forest_pb2.AndroidAccessibilityForest | None
- ) = None
-
- def SendForest(
- self,
- request: android_accessibility_forest_pb2.AndroidAccessibilityForest,
- context: grpc.ServicerContext,
- ) -> a11y_pb2.ForestResponse:
- self._process_forest(request)
- return a11y_pb2.ForestResponse()
-
- def SendEvent(
- self,
- request: a11y_pb2.EventRequest,
- context: grpc.ServicerContext,
- ) -> a11y_pb2.EventResponse:
- self._process_event(request)
- return a11y_pb2.EventResponse()
-
- async def Bidi(
- self,
- request_iterator: AsyncIterator[a11y_pb2.ClientToServer],
- context: grpc.aio.ServicerContext,
- ) -> AsyncIterator[a11y_pb2.ServerToClient]:
- """Processes incoming ClientToServer requests."""
-
- logging.info('Starting A11yServicer.Bidi()')
-
- # Send a dummy message to unblock clients in their loop.
- yield a11y_pb2.ServerToClient()
-
- # This block defines two coroutines:
- #
- # * `read_client_requests()`
- # * `check_forest()`
- #
- # They cooperate with each other and both populate a queue `q` which is
- # consumed in a loop below, which actually yields requests which are sent to
- # the client. The processing finishes when the clients "closes" the
- # connection, which causes `read_client_requests()` to put a special value,
- # `STOP_ITERATION`, in the queue.
-
- # Queue for communicating from coroutines to `Bidi()`.
- q = asyncio.Queue()
-
- should_run = True
-
- async def read_client_requests():
- """Coroutine for reading client requests."""
-
- nonlocal should_run
- async for request in request_iterator:
- field_name = request.WhichOneof('payload')
- match field_name:
- case 'event':
- self._process_event(request.event)
- case 'forest':
- self._latest_forest = request.forest
- self._forest_ready.set()
- self._get_forest.clear() # Reset the `Event`.
- case _:
- logging.error('Unknown field %r', field_name)
- await q.put(a11y_pb2.ServerToClient())
-
- # Send a special value to stop processing this `Bidi` connection.
- await q.put('STOP_ITERATION')
- should_run = False
-
- async def check_forest():
- """Coroutine for sending "get forest" requests."""
-
- nonlocal should_run
- while should_run:
- await self._get_forest.wait()
- await q.put(a11y_pb2.ServerToClient(get_forest={}))
-
- tasks = asyncio.gather(read_client_requests(), check_forest())
-
- while should_run:
- v = await q.get()
- if v == 'STOP_ITERATION':
- break
- else:
- yield v
-
- await tasks
-
- logging.info('Finishing A11yServicer.Bidi()')
-
- async def get_forest(
- self,
- ) -> android_accessibility_forest_pb2.AndroidAccessibilityForest | None:
- """Issues a request to get the a11y forest from the client."""
-
- self._get_forest.set() # Unblocks coroutine to send a request.
- await self._forest_ready.wait() # Wait for forest to be ready.
- self._forest_ready.clear() # Reset the `Event`.
- return self._latest_forest
-
- def gather_forests(
- self,
- ) -> list[android_accessibility_forest_pb2.AndroidAccessibilityForest]:
- forests = []
- with self._lock_forests:
- forests = self._received_forests
- self._received_forests = []
- return forests
-
- def gather_events(self) -> list[a11y_pb2.EventRequest]:
- events = []
- with self._lock_events:
- events = self._received_events
- self._received_events = []
- return events
-
- def pause_and_clear(self) -> None:
- """Temporarily stop receiving events/forests and clear the queue.
-
- Used when resetting the environment; in this case:
- - all events/forests that have been received since last timestep are things
- that happened in the last episode after its `LAST` timestep (so we should
- ignore them, done by clearing the lists).
- - we're about to receive a bunch of events/forests just as a result of
- resetting the environment. We don't want to count these either; thus we
- temporarily stop receiving new ones.
- """
- self._paused = True
- with self._lock_forests:
- self._received_forests = []
- with self._lock_events:
- self._received_events = []
-
- def resume(self) -> None:
- """Start receiving events/forests (e.g., after a reset)."""
- self._paused = False
-
- def _process_event(self, event: a11y_pb2.EventRequest) -> None:
- """Adds the given event to the internal buffer of events."""
-
- if not self._paused:
- with self._lock_events:
- self._received_events.append(event)
-
- def _process_forest(
- self, forest: android_accessibility_forest_pb2.AndroidAccessibilityForest
- ) -> None:
- """Adds the given forest to the internal buffer of forests."""
-
- if not self._paused:
- with self._lock_forests:
- if self._latest_forest_only:
- self._received_forests = [forest]
- else:
- self._received_forests.append(forest)
diff --git a/android_env/wrappers/a11y/a11y_servicer_test.py b/android_env/wrappers/a11y/a11y_servicer_test.py
deleted file mode 100644
index 91be6024..00000000
--- a/android_env/wrappers/a11y/a11y_servicer_test.py
+++ /dev/null
@@ -1,224 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for a11y_servicer."""
-
-import asyncio
-from collections.abc import AsyncIterator, Iterable
-from typing import TypeVar
-from unittest import IsolatedAsyncioTestCase, mock
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from android_env.proto.a11y import a11y_pb2
-from android_env.proto.a11y import android_accessibility_forest_pb2
-from android_env.wrappers.a11y import a11y_servicer
-import grpc
-
-
-_T = TypeVar('_T')
-
-
-async def _aiter(xs: Iterable[_T]) -> AsyncIterator[_T]:
- """Utility to make an AsyncIterator from Iterable."""
-
- for x in xs:
- yield x
-
-
-def one_window_one_node_forest() -> (
- android_accessibility_forest_pb2.AndroidAccessibilityForest
-):
- forest = android_accessibility_forest_pb2.AndroidAccessibilityForest()
- window = forest.windows.add()
- node = window.tree.nodes.add()
- node.class_name = 'foo'
- node.is_clickable = True
- node.hint_text = 'Foo hint'
- return forest
-
-
-def one_window_two_nodes_forest() -> (
- android_accessibility_forest_pb2.AndroidAccessibilityForest
-):
- forest = android_accessibility_forest_pb2.AndroidAccessibilityForest()
- window = forest.windows.add()
- node = window.tree.nodes.add()
- node.class_name = 'bar'
- node.is_clickable = True
- node.hint_text = 'Bar hint'
- node = window.tree.nodes.add()
- node.class_name = 'bar'
- node.is_clickable = False
- node.hint_text = 'Bar hint 2'
- return forest
-
-
-def empty_dict() -> dict[str, str]:
- return {}
-
-
-def single_item_dict_with_special_chars() -> dict[str, str]:
- return {'foo': 'bar\r\t\nbaz'}
-
-
-class A11yServicerTest(parameterized.TestCase, IsolatedAsyncioTestCase):
-
- def test_servicer_sendforest(self):
- mock_context = mock.create_autospec(grpc.ServicerContext, instance=True)
- servicer = a11y_servicer.A11yServicer()
- servicer.resume()
- response = servicer.SendForest(one_window_one_node_forest(), mock_context)
- self.assertEqual(response.error, '')
- response = servicer.SendForest(one_window_two_nodes_forest(), mock_context)
- self.assertEqual(response.error, '')
- forests = servicer.gather_forests()
- self.assertLen(forests, 2)
- self.assertEqual(forests[0], one_window_one_node_forest())
- self.assertEqual(forests[1], one_window_two_nodes_forest())
-
- async def test_servicer_bidi_forests(self):
- """Checks that the bidirectional interface accepts forests."""
-
- # Arrange.
- mock_context = mock.create_autospec(grpc.ServicerContext, instance=True)
- servicer = a11y_servicer.A11yServicer()
-
- # Act.
- servicer.resume()
- responses = [
- x
- async for x in servicer.Bidi(
- _aiter([
- a11y_pb2.ClientToServer(
- event=a11y_pb2.EventRequest(
- event=single_item_dict_with_special_chars()
- )
- ),
- a11y_pb2.ClientToServer(forest=one_window_two_nodes_forest()),
- ]),
- mock_context,
- )
- ]
- forest = await servicer.get_forest()
-
- # Assert.
- self.assertEqual(responses[0], a11y_pb2.ServerToClient())
- self.assertEqual(responses[1], a11y_pb2.ServerToClient())
- self.assertIsNotNone(forest)
- self.assertEqual(forest, one_window_two_nodes_forest())
-
- def test_servicer_sendforest_latest_only(self):
- mock_context = mock.create_autospec(grpc.ServicerContext, instance=True)
- servicer = a11y_servicer.A11yServicer(latest_forest_only=True)
- servicer.resume()
- response = servicer.SendForest(one_window_one_node_forest(), mock_context)
- self.assertEqual(response.error, '')
- response = servicer.SendForest(one_window_two_nodes_forest(), mock_context)
- self.assertEqual(response.error, '')
- forests = servicer.gather_forests()
- self.assertLen(forests, 1)
- self.assertEqual(forests[0], one_window_two_nodes_forest())
-
- def test_servicer_sendevent(self):
- mock_context = mock.create_autospec(grpc.ServicerContext, instance=True)
- servicer = a11y_servicer.A11yServicer()
- servicer.resume()
- response = servicer.SendEvent(
- a11y_pb2.EventRequest(event=empty_dict()), mock_context
- )
- self.assertEqual(response.error, '')
- response = servicer.SendEvent(
- a11y_pb2.EventRequest(event=single_item_dict_with_special_chars()),
- mock_context,
- )
- self.assertEqual(response.error, '')
- events = servicer.gather_events()
- self.assertLen(events, 2)
- self.assertEqual(events[0].event, empty_dict())
- self.assertEqual(events[1].event, single_item_dict_with_special_chars())
-
- async def test_servicer_bidi_events(self):
- """Checks that the bidirectional interface accepts events."""
-
- # Arrange.
- mock_context = mock.create_autospec(grpc.ServicerContext, instance=True)
- servicer = a11y_servicer.A11yServicer()
-
- # Act.
- servicer.resume()
- responses = [
- x
- async for x in servicer.Bidi(
- _aiter([
- a11y_pb2.ClientToServer(
- event=a11y_pb2.EventRequest(event=empty_dict())
- ),
- a11y_pb2.ClientToServer(
- event=a11y_pb2.EventRequest(
- event=single_item_dict_with_special_chars()
- )
- ),
- ]),
- mock_context,
- )
- ]
- events = servicer.gather_events()
-
- # Assert.
- self.assertEqual(responses[0], a11y_pb2.ServerToClient())
- self.assertEqual(responses[1], a11y_pb2.ServerToClient())
- self.assertLen(events, 2)
- self.assertEqual(events[0].event, empty_dict())
- self.assertEqual(events[1].event, single_item_dict_with_special_chars())
-
- def test_servicer_pause_and_clear_pauses(self):
- mock_context = mock.create_autospec(grpc.ServicerContext, instance=True)
- servicer = a11y_servicer.A11yServicer()
- servicer.resume()
- servicer.pause_and_clear()
- response = servicer.SendEvent(
- a11y_pb2.EventRequest(event=empty_dict()), mock_context
- )
- self.assertEqual(response.error, '')
- response = servicer.SendForest(one_window_one_node_forest(), mock_context)
- self.assertEqual(response.error, '')
- events = servicer.gather_events()
- self.assertEmpty(events)
- forests = servicer.gather_forests()
- self.assertEmpty(forests)
-
- def test_servicer_pause_and_clear_clears(self):
- mock_context = mock.create_autospec(grpc.ServicerContext, instance=True)
- servicer = a11y_servicer.A11yServicer()
- servicer.resume()
- response = servicer.SendEvent(
- a11y_pb2.EventRequest(event=empty_dict()), mock_context
- )
- self.assertEqual(response.error, '')
- response = servicer.SendForest(one_window_one_node_forest(), mock_context)
- self.assertEqual(
- response.error,
- '',
- )
- servicer.pause_and_clear()
- events = servicer.gather_events()
- self.assertEmpty(events)
- forests = servicer.gather_forests()
- self.assertEmpty(forests)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/wrappers/a11y_grpc_wrapper.py b/android_env/wrappers/a11y_grpc_wrapper.py
deleted file mode 100644
index 5b010ad3..00000000
--- a/android_env/wrappers/a11y_grpc_wrapper.py
+++ /dev/null
@@ -1,500 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Wraps AndroidEnv to retrieve accessibility messages from gRPC."""
-
-from concurrent import futures
-import time
-from typing import Any
-
-import urllib
-
-from absl import logging
-from android_env import env_interface
-from android_env.components import action_type as android_action_type_lib
-from android_env.proto import adb_pb2
-from android_env.proto.a11y import a11y_pb2_grpc
-from android_env.wrappers import base_wrapper
-from android_env.wrappers.a11y import a11y_events
-from android_env.wrappers.a11y import a11y_forests
-from android_env.wrappers.a11y import a11y_servicer
-import dm_env
-import grpc
-import numpy as np
-import portpicker
-
-
-def _get_accessibility_forwarder_apk() -> bytes:
- logging.info('Downloading accessibility forwarder apk....')
- with urllib.request.urlopen(
- 'https://storage.googleapis.com/android_env-tasks/2024.05.13-accessibility_forwarder.apk'
- ) as response:
- return response.read()
-
-
-class EnableNetworkingError(ValueError):
- pass
-
-
-class A11yGrpcWrapper(base_wrapper.BaseWrapper):
- """Wrapper which receives A11y events and forests over gRPC.
-
- A11y forest protobufs and event dicts are sent from the Android emulator via
- gRPC from the `AccessibilityForwarder` (for use in developing reward
- functions, etc). This wrapper constructs a server which receives these
- messages and channels them into `task_extras`.
-
- The downside of forwarding this information through gRPC is that no messages
- will be sent if networking is turned off (e.g., if the AVD is in airplane
- mode). To mitigate this problem, the `AccessibilityForwarder` logs an error
- message if it fails to contact the server. This wrapper monitors the logs for
- such error messages, and attempts (in another thread, to not block environment
- transitions) to reconnect the AVD to the network. If this fails to fix the
- problem, this wrapper ends the episode.
-
- This wrapper is implemented to be robust to multiple upstream callers of
- `task_extras`, and to ensure they each receive the same extras at every
- timestep. Thus, the logic is the following:
- * New a11y events/forests are fetched during `reset` and `step`, *not* during
- `task_extras()` calls.
- * If no one has called `task_extras()` since the last `step` or `reset`, the
- extras are accumulated (so that no extras are missed because someone called
- `step()` twice without calling `task_extras()`).
- * If someone *has* called `task_extras()` since last step, the newly fetched
- extras replace the old extras.
- """
-
- def __init__(
- self,
- env: env_interface.AndroidEnvInterface,
- disable_other_network_traffic: bool = False,
- install_a11y_forwarding: bool = False,
- start_a11y_service: bool = True,
- enable_a11y_tree_info: bool = False,
- add_latest_a11y_info_to_obs: bool = False,
- a11y_info_timeout: float | None = None,
- max_enable_networking_attempts: int = 10,
- latest_a11y_info_only: bool = False,
- ):
- """Initializes wrapper.
-
- Args:
- env: Environment to wrap.
- disable_other_network_traffic: When True, all network traffic, other than
- the connection to the servicer, is disabled. NOTE: This requires root
- access on the device (i.e. it uses the `su` command). An
- `AdbControllerError` exception will be raised if the underlying command
- fails.
- install_a11y_forwarding: If True, the wrapper handles the installation of
- all packages required for the servicer to collect a11y information.
- start_a11y_service: If True, starts the a11y forwarding services. NOTE:
- The packages must be installed beforehand, e.g., using the
- install_a11y_forwarding flag.
- enable_a11y_tree_info: When False, this wrapper collects only a11y events
- and not a11y tree.
- add_latest_a11y_info_to_obs: When True, the latest observed a11y forest is
- added to the observation.
- a11y_info_timeout: When larger than zero and add_latest_a11y_info_to_obs
- is set to True, the wrapper will wait the corresponding amount of time,
- measured in seconds, to collect the latest a11y forest.
- max_enable_networking_attempts: When the a11y gRPC service fails to
- provide a11y information, we attempt this many times to re-enable the
- networking. If all these attempts fail, fetching task_extras will raise
- an EnableNetworkingError.
- latest_a11y_info_only: When True, the a11y servicer is setup to save only
- the latest tree it has received from the Android app.
- """
- self._env = env
- if install_a11y_forwarding:
- self._install_a11y_forwarding_apk()
- time.sleep(10.0)
- if start_a11y_service:
- self._start_a11y_services()
- time.sleep(3.0)
- if enable_a11y_tree_info:
- self._enable_a11y_tree_logs()
- self._relaunch_count = 0
- self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
- self._servicer = a11y_servicer.A11yServicer(
- latest_forest_only=latest_a11y_info_only
- )
- a11y_pb2_grpc.add_A11yServiceServicer_to_server(
- self._servicer, self._server
- )
- server_credentials = grpc.local_server_credentials()
- self._port = portpicker.pick_unused_port()
- logging.info('Using port %s', self._port)
- uri_address = f'[::]:{self._port}'
- self._server.add_secure_port(uri_address, server_credentials)
- logging.info('Starting server')
- self._server.start()
- logging.info('Server now running.')
-
- self._max_enable_networking_attempts = max_enable_networking_attempts
- self._reset_enable_networking_attempts()
-
- self._disable_other_network_traffic = disable_other_network_traffic
- self._should_accumulate = False
- self._accumulated_extras = None
- self._add_latest_a11y_info_to_obs = add_latest_a11y_info_to_obs
- self._a11y_info_timeout = a11y_info_timeout
- self._parent_action_spec = self._env.action_spec()
- if self._a11y_info_timeout is not None and self._a11y_info_timeout > 0.0:
- if 'action_type' not in self._parent_action_spec.keys():
- raise ValueError(
- 'action_type not in the parent action spec: '
- f'{self._parent_action_spec}. This is a strong requirement when '
- f'a11y_info_timeout = {a11y_info_timeout} > 0'
- )
-
- def _start_a11y_services(self) -> None:
- """Starts the accessibility forwarder services.
-
- Raises:
- RuntimeError: If accessibility service is not started.
- """
- start_service_request = adb_pb2.AdbRequest(
- settings=adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SECURE,
- put=adb_pb2.AdbRequest.SettingsRequest.Put(
- key='enabled_accessibility_services',
- value=(
- 'com.google.androidenv.accessibilityforwarder/com.google.'
- 'androidenv.accessibilityforwarder.AccessibilityForwarder'
- ),
- ),
- )
- )
- start_service_response = self._env.execute_adb_call(start_service_request)
- if start_service_response.status != adb_pb2.AdbResponse.Status.OK:
- raise RuntimeError(
- 'Could not start accessibility forwarder '
- 'service: '
- f'{start_service_response}.'
- )
-
- def _install_a11y_forwarding_apk(self) -> None:
- """Enables accessibility information forwarding."""
- a11y_fwd_apk = _get_accessibility_forwarder_apk()
- # Install and setup the Accesssibility Forwarder.
- install_request = adb_pb2.AdbRequest(
- install_apk=adb_pb2.AdbRequest.InstallApk(
- blob=adb_pb2.AdbRequest.InstallApk.Blob(contents=a11y_fwd_apk),
- )
- )
- install_response = self._env.execute_adb_call(install_request)
- if install_response.status != adb_pb2.AdbResponse.Status.OK:
- raise ValueError(
- f'Could not install accessibility_forwarder.apk: {install_response}.'
- )
-
- def _enable_a11y_tree_logs(self) -> None:
- enable_tree_logs_request = adb_pb2.AdbRequest(
- send_broadcast=adb_pb2.AdbRequest.SendBroadcast(
- action=(
- 'accessibility_forwarder.intent.action.'
- 'ENABLE_ACCESSIBILITY_TREE_LOGS'
- ),
- component=(
- 'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver'
- ),
- )
- )
- enable_tree_logs_response = self._env.execute_adb_call(
- enable_tree_logs_request
- )
- if enable_tree_logs_response.status != adb_pb2.AdbResponse.Status.OK:
- raise ValueError(
- 'Could not enable accessibility tree logging: '
- f'{enable_tree_logs_response}.'
- )
-
- def _reset_enable_networking_attempts(self) -> None:
- self._enable_networking_attempts_left = self._max_enable_networking_attempts
- self._enabling_networking_future = None
- self._a11y_exception = None
-
- def get_port(self):
- return self._port
-
- def close(self):
- self._server.stop(None)
- logging.info('gRPC server stopped')
- self._env.close()
-
- def attempt_enable_networking(self) -> None:
- """Attempts to turn on networking within the Android device.
-
- Attempt to turn on the networking in the Android device, by:
- - turning off airplane mode;
- - turning on the wifi connection.
- """
- self.execute_adb_call(
- adb_pb2.AdbRequest(
- settings=adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
- put=adb_pb2.AdbRequest.SettingsRequest.Put(
- key='airplane_mode_on', value='0'
- ),
- )
- )
- )
- time.sleep(1.0)
- self.execute_adb_call(
- adb_pb2.AdbRequest(
- generic=adb_pb2.AdbRequest.GenericRequest(
- args=[
- 'shell',
- 'svc',
- 'wifi',
- 'enable',
- ]
- )
- )
- )
- time.sleep(1.0)
-
- def _configure_grpc(self) -> None:
- """Configure networking and set the gRPC port in the AVD."""
-
- if self._disable_other_network_traffic:
- self.execute_adb_call(
- adb_pb2.AdbRequest(
- generic=adb_pb2.AdbRequest.GenericRequest(
- args=[
- 'shell',
- 'su',
- '0',
- 'iptables',
- '-A',
- 'OUTPUT',
- '-p',
- 'tcp',
- '-d',
- '10.0.2.2',
- '--dport',
- str(self._port),
- '-j',
- 'ACCEPT',
- ]
- )
- )
- )
- time.sleep(3.0)
- self.execute_adb_call(
- adb_pb2.AdbRequest(
- generic=adb_pb2.AdbRequest.GenericRequest(
- args=[
- 'shell',
- 'su',
- '0',
- 'iptables',
- '-A',
- 'OUTPUT',
- '-j',
- 'DROP',
- ]
- )
- )
- )
- time.sleep(3.0)
-
- self.execute_adb_call(
- adb_pb2.AdbRequest(
- settings=adb_pb2.AdbRequest.SettingsRequest(
- name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
- put=adb_pb2.AdbRequest.SettingsRequest.Put(
- key='no_proxy', value=f'10.0.2.2:{self._port}'
- ),
- )
- )
- )
- self.attempt_enable_networking()
- self.execute_adb_call(
- adb_pb2.AdbRequest(
- send_broadcast=adb_pb2.AdbRequest.SendBroadcast(
- action=(
- 'accessibility_forwarder.intent.action.SET_GRPC --ei'
- f' "port" {self._port}'
- ),
- component=(
- 'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver'
- ),
- )
- )
- )
-
- def _accumulate_and_return_a11y_info(
- self, timer: float | None = None, get_env_observation: bool = True
- ) -> dict[str, Any]:
- """Accumulates and returns the latest a11y tree info and observation.
-
- Args:
- timer: If larger than 0, the system will wait this long for a11y info to
- accumulate before it returns a value.
- get_env_observation: If False, the corresponding observation is not
- introduced here.
-
- Returns:
- a dict with a11y forest under key 'a11y_forest'. All other fields will
- provide the observation, if requested.
- """
- timer = timer or 0.0
- if timer > 0.0:
- time.sleep(timer)
-
- if get_env_observation:
- # Fetch observation.
- new_ts = self._env.step({
- 'action_type': np.array(
- android_action_type_lib.ActionType.REPEAT,
- dtype=self._parent_action_spec['action_type'].dtype,
- ),
- })
- observation = new_ts.observation
- else:
- observation = {}
-
- extras = self.accumulate_new_extras()
- forests = a11y_forests.extract_forests_from_task_extras(extras)
- if forests:
- observation['a11y_forest'] = forests[-1]
- else:
- observation['a11y_forest'] = None
- return observation
-
- def _fetch_task_extras_and_update_observation(
- self, observation: dict[str, Any], timeout: float = 0.0
- ) -> dict[str, Any]:
- if timeout > 0.0:
- observation = self._accumulate_and_return_a11y_info(
- timeout, get_env_observation=True
- )
- if not self._add_latest_a11y_info_to_obs:
- observation.pop('a11y_forest')
- else:
- new_obs = self._accumulate_and_return_a11y_info(get_env_observation=False)
- if self._add_latest_a11y_info_to_obs:
- observation.update(new_obs)
- return observation
-
- def reset(self) -> dm_env.TimeStep:
- self._reset_enable_networking_attempts()
- self._servicer.pause_and_clear()
- timestep = self._env.reset()
- self._servicer.resume()
- if self._env.stats()['relaunch_count'] > self._relaunch_count:
- self._configure_grpc()
- self._relaunch_count = self._env.stats()['relaunch_count']
- self._accumulated_extras = {}
- timeout = self._a11y_info_timeout or 0.0
- new_observation = self._fetch_task_extras_and_update_observation(
- timestep.observation, timeout
- )
- timestep = timestep._replace(observation=new_observation)
- return timestep
-
- def step(self, action: Any) -> dm_env.TimeStep:
- timeout = float(action.pop('wait_time', self._a11y_info_timeout or 0.0))
- timestep = self._env.step(action)
- new_observation = self._fetch_task_extras_and_update_observation(
- timestep.observation, timeout=timeout
- )
- timestep = timestep._replace(observation=new_observation)
- return timestep
-
- def accumulate_new_extras(self) -> dict[str, Any]:
- new_extras = self._fetch_task_extras()
- if self._should_accumulate:
- for key in new_extras:
- if key in self._accumulated_extras:
- self._accumulated_extras[key] = np.concatenate(
- (self._accumulated_extras[key], new_extras[key]), axis=0
- )
- else:
- self._accumulated_extras[key] = new_extras[key]
- else:
- self._accumulated_extras = new_extras
- self._should_accumulate = True
- return self._accumulated_extras
-
- def _fetch_task_extras(self) -> dict[str, Any]:
- """Fetches task_extras from the services.
-
- NOTE: If you want to access the latest a11y information, please use
- accumulate_and_return_a11y_info instead. This function has the side effect
- of clearing the content from the servicer, hence all the a11y info returned
- here won't be accumulated.
-
- Returns:
- A dict with the corresponding task_extras.
-
- Raises:
- EnableNetworkingError: after a fixed number of attempts to revive the a11y
- services by re-enabling the network connection.
- """
- base_extras = self._env.task_extras(latest_only=False).copy()
- # If the previous future is done, reset it to the initial state.
- if (
- self._enabling_networking_future is not None
- and self._enabling_networking_future.done()
- ):
- self._enabling_networking_future = None
- self._enable_networking_attempts_left -= 1
- logging.info('Finished enabling networking.')
-
- if (
- self._enabling_networking_future is None
- and 'exception' in base_extras
- and base_extras['exception'].shape[0]
- ):
- self._a11y_exception = base_extras['exception']
- logging.warning(
- 'AccessibilityForwarder logged exceptions: %s', self._a11y_exception
- )
- if self._enable_networking_attempts_left > 0:
- logging.warning(
- 'Attempting to enable networking. %s attempts left.',
- self._enable_networking_attempts_left - 1,
- )
- executor = futures.ThreadPoolExecutor(max_workers=1)
- self._enabling_networking_future = executor.submit(
- self.attempt_enable_networking
- )
- else:
- raise EnableNetworkingError(
- 'A11y service failed multiple times with'
- f' exception.{self._a11y_exception}.'
- )
-
- forests = self._servicer.gather_forests()
- if forests:
- base_extras.update(a11y_forests.package_forests_to_task_extras(forests))
- self._reset_enable_networking_attempts()
- events = self._servicer.gather_events()
- if events:
- base_extras.update(a11y_events.package_events_to_task_extras(events))
- self._reset_enable_networking_attempts()
- return base_extras
-
- def task_extras(self, latest_only: bool = False) -> dict[str, Any]:
- if self._accumulated_extras is None:
- raise RuntimeError('You must call .reset() before calling .task_extras()')
- self._should_accumulate = False
- extras = self._accumulated_extras.copy()
- if latest_only:
- a11y_events.keep_latest_event_only(extras)
- a11y_forests.keep_latest_forest_only(extras)
- return extras
diff --git a/android_env/wrappers/a11y_grpc_wrapper_test.py b/android_env/wrappers/a11y_grpc_wrapper_test.py
deleted file mode 100644
index 1b81fd6e..00000000
--- a/android_env/wrappers/a11y_grpc_wrapper_test.py
+++ /dev/null
@@ -1,849 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for a11y_grpc_wrapper."""
-
-import time
-from unittest import mock
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from android_env import env_interface
-from android_env.proto import adb_pb2
-from android_env.proto.a11y import a11y_pb2
-from android_env.proto.a11y import a11y_pb2_grpc
-from android_env.proto.a11y import android_accessibility_forest_pb2
-from android_env.wrappers import a11y_grpc_wrapper
-import dm_env
-import grpc
-import numpy as np
-
-
-def empty_forest() -> (
- android_accessibility_forest_pb2.AndroidAccessibilityForest
-):
- return android_accessibility_forest_pb2.AndroidAccessibilityForest()
-
-
-def one_empty_window_forest() -> (
- android_accessibility_forest_pb2.AndroidAccessibilityForest
-):
- forest = android_accessibility_forest_pb2.AndroidAccessibilityForest()
- _ = forest.windows.add()
- return forest
-
-
-def one_window_one_node_forest() -> (
- android_accessibility_forest_pb2.AndroidAccessibilityForest
-):
- forest = android_accessibility_forest_pb2.AndroidAccessibilityForest()
- window = forest.windows.add()
- node = window.tree.nodes.add()
- node.class_name = 'foo'
- node.is_clickable = True
- node.hint_text = 'Foo hint'
- return forest
-
-
-def one_window_two_nodes_forest() -> (
- android_accessibility_forest_pb2.AndroidAccessibilityForest
-):
- forest = android_accessibility_forest_pb2.AndroidAccessibilityForest()
- window = forest.windows.add()
- node = window.tree.nodes.add()
- node.class_name = 'bar'
- node.is_clickable = True
- node.hint_text = 'Bar hint'
- node = window.tree.nodes.add()
- node.class_name = 'bar'
- node.is_clickable = False
- node.hint_text = 'Bar hint 2'
- return forest
-
-
-def three_windows_forest() -> (
- android_accessibility_forest_pb2.AndroidAccessibilityForest
-):
- forest = android_accessibility_forest_pb2.AndroidAccessibilityForest()
- _ = forest.windows.add()
- window = forest.windows.add()
- node = window.tree.nodes.add()
- node.class_name = 'foo'
- node.is_clickable = True
- node.hint_text = 'hint'
- window = forest.windows.add()
- node = window.tree.nodes.add()
- node.class_name = 'baz'
- node.is_clickable = True
- node.hint_text = 'hint'
- node = window.tree.nodes.add()
- node.class_name = 'foobar'
- node.is_clickable = False
- node.hint_text = 'hint'
- return forest
-
-
-def empty_dict() -> dict[str, str]:
- return {}
-
-
-def single_item_dict() -> dict[str, str]:
- return {'foo': 'bar'}
-
-
-def several_long_items_dict() -> dict[str, str]:
- return {
- 'first_key': 'Lorem ipsum ' * 100,
- 'second_key': 'the beginning is the end is' * 100,
- }
-
-
-def single_item_dict_with_special_chars() -> dict[str, str]:
- return {'foo': 'bar\r\t\nbaz'}
-
-
-def _ok_response():
- return adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK)
-
-
-class A11yGrpcWrapperTest(parameterized.TestCase):
-
- def test_server(self):
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
- base_env.task_extras.return_value = {}
- base_env.stats.return_value = {'relaunch_count': 0}
- base_env.execute_adb_call.return_value = _ok_response()
- wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
- wrapped_env.reset()
- channel_creds = grpc.local_channel_credentials()
- with grpc.secure_channel(
- f'[::]:{wrapped_env.get_port()}', channel_creds
- ) as channel:
- grpc.channel_ready_future(channel).result()
- stub = a11y_pb2_grpc.A11yServiceStub(channel)
- stub.SendForest(one_window_one_node_forest())
- stub.SendForest(one_window_two_nodes_forest())
- wrapped_env.step({})
- extras = wrapped_env.task_extras(latest_only=False)
- self.assertIn('accessibility_tree', extras)
- self.assertEqual(extras['accessibility_tree'].shape[0], 2)
-
- # tests of fetch_task_extras:
- # exception occurs (ensure attempt to enable networking) and recovers
- # exception occurs and enable networking doesn't help
- # exception occurs twice but with a forest sent between
-
- @parameterized.named_parameters(
- ('no_events_or_forests', [], []),
- (
- 'no_events',
- [],
- [one_window_one_node_forest(), one_window_two_nodes_forest()],
- ),
- ('no_forests', [empty_dict(), single_item_dict()], []),
- (
- 'events_and_forests',
- [empty_dict(), single_item_dict()],
- [one_window_one_node_forest(), one_window_two_nodes_forest()],
- ),
- )
- @mock.patch.object(time, 'sleep', autospec=True)
- @mock.patch.object(
- a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
- )
- @mock.patch.object(grpc, 'server', autospec=True)
- def test_fetch_task_extras(
- self,
- received_events,
- received_forests,
- mock_server,
- mock_add_servicer,
- mock_sleep,
- ):
- del mock_server, mock_add_servicer, mock_sleep
- mock_context = mock.create_autospec(grpc.ServicerContext, instance=True)
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
- base_env.task_extras.return_value = {
- 'foo': np.array(['bar', 'baz'], dtype='U'),
- 'some_key': np.array(['some_value'], dtype='U'),
- }
- base_env.stats.return_value = {'relaunch_count': 0}
- base_env.execute_adb_call.return_value = _ok_response()
- wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
- wrapped_env.reset()
- for forest in received_forests:
- wrapped_env._servicer.SendForest(forest, mock_context)
- for event in received_events:
- wrapped_env._servicer.SendEvent(
- a11y_pb2.EventRequest(event=event), mock_context
- )
- with mock.patch.object(
- wrapped_env, 'attempt_enable_networking'
- ) as mock_attempt_enable_networking:
- extras = wrapped_env._fetch_task_extras()
- mock_attempt_enable_networking.assert_not_called()
- self.assertIn('foo', extras)
- np.testing.assert_array_equal(extras['foo'], ['bar', 'baz'])
- self.assertIn('some_key', extras)
- np.testing.assert_array_equal(extras['some_key'], ['some_value'])
- if received_events:
- self.assertIn('full_event', extras)
- self.assertLen(extras['full_event'], len(received_events))
- for i, event in enumerate(received_events):
- event = a11y_pb2.EventRequest(event=event)
- np.testing.assert_array_equal(extras['full_event'][i], event)
- else:
- self.assertNotIn('full_event', extras)
- if received_forests:
- self.assertIn('accessibility_tree', extras)
- self.assertLen(extras['accessibility_tree'], len(received_forests))
- for i, forest in enumerate(received_forests):
- np.testing.assert_array_equal(extras['accessibility_tree'][i], forest)
- else:
- self.assertNotIn('accessibility_tree', extras)
-
- @mock.patch.object(time, 'sleep', autospec=True)
- @mock.patch.object(
- a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
- )
- @mock.patch.object(grpc, 'server', autospec=True)
- def test_fetch_task_extras_enable_networking(
- self,
- mock_server,
- mock_add_servicer,
- mock_sleep,
- ):
- del mock_server, mock_add_servicer, mock_sleep
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
- base_env.task_extras.return_value = {
- 'foo': np.array(['bar'], dtype='U'),
- 'some_key': np.array(['some_value'], dtype='U'),
- 'exception': np.array(['fake exception'], dtype='U'),
- }
- base_env.stats.return_value = {'relaunch_count': 0}
- base_env.execute_adb_call.return_value = _ok_response()
- wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
- with mock.patch.object(
- wrapped_env, 'attempt_enable_networking'
- ) as mock_attempt_enable_networking:
- extras = wrapped_env._fetch_task_extras()
- self.assertNotIn('accessibility_tree', extras)
- self.assertNotIn('full_event', extras)
- future = wrapped_env._enabling_networking_future
- if future is not None:
- future.result()
- mock_attempt_enable_networking.assert_called_once()
-
- @mock.patch.object(time, 'sleep', autospec=True)
- @mock.patch.object(
- a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
- )
- @mock.patch.object(grpc, 'server', autospec=True)
- def test_fetch_task_extras_enable_networking_twice(
- self,
- mock_server,
- mock_add_servicer,
- mock_sleep,
- ):
- del mock_server, mock_add_servicer, mock_sleep
- mock_context = mock.create_autospec(grpc.ServicerContext, instance=True)
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
- base_env.task_extras.return_value = {
- 'foo': np.array(['bar'], dtype='U'),
- 'some_key': np.array(['some_value'], dtype='U'),
- }
-
- base_env.stats.return_value = {'relaunch_count': 0}
- base_env.execute_adb_call.return_value = _ok_response()
- wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
- wrapped_env.reset()
-
- base_env.task_extras.return_value = {
- 'foo': np.array(['bar'], dtype='U'),
- 'some_key': np.array(['some_value'], dtype='U'),
- 'exception': np.array(['fake exception'], dtype='U'),
- }
- with mock.patch.object(
- wrapped_env, 'attempt_enable_networking'
- ) as mock_attempt_enable_networking:
- extras = wrapped_env._fetch_task_extras()
- self.assertNotIn('accessibility_tree', extras)
- self.assertNotIn('full_event', extras)
- future = wrapped_env._enabling_networking_future
- if future is not None:
- future.result()
- mock_attempt_enable_networking.assert_called_once()
- # Fixed networking; send a forest so the wrapper knows it worked.
- wrapped_env._servicer.SendForest(one_window_one_node_forest(), mock_context)
- base_env.task_extras.return_value = {
- 'foo': np.array(['bar'], dtype='U'),
- 'some_key': np.array(['some_value'], dtype='U'),
- }
- extras = wrapped_env._fetch_task_extras()
- self.assertIn('accessibility_tree', extras)
- self.assertEqual(extras['accessibility_tree'].shape[0], 1)
- self.assertEqual(
- extras['accessibility_tree'][0], one_window_one_node_forest()
- )
-
- base_env.task_extras.return_value = {
- 'foo': np.array(['bar'], dtype='U'),
- 'some_key': np.array(['some_value'], dtype='U'),
- 'exception': np.array(['fake exception'], dtype='U'),
- }
- with mock.patch.object(
- wrapped_env, 'attempt_enable_networking'
- ) as mock_attempt_enable_networking:
- extras = wrapped_env._fetch_task_extras()
- self.assertNotIn('accessibility_tree', extras)
- self.assertNotIn('full_event', extras)
- future = wrapped_env._enabling_networking_future
- if future is not None:
- future.result()
- mock_attempt_enable_networking.assert_called_once()
-
- @mock.patch.object(time, 'sleep', autospec=True)
- @mock.patch.object(
- a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
- )
- @mock.patch.object(grpc, 'server', autospec=True)
- def test_task_extras_raises_a11y_info_exception(
- self, mock_sleep, mock_add_servicer, mock_server
- ):
- del mock_server, mock_add_servicer, mock_sleep
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
- base_env.task_extras.return_value = {
- 'foo': np.array(['bar'], dtype='U'),
- 'some_key': np.array(['some_value'], dtype='U'),
- }
-
- base_env.stats.return_value = {'relaunch_count': 0}
- base_env.execute_adb_call.return_value = _ok_response()
- base_env.reset.return_value = dm_env.restart(observation={'dummy': 42})
- base_env.step.return_value = dm_env.transition(
- observation={'dummy': 42}, reward=0.0
- )
- wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
- base_env,
- add_latest_a11y_info_to_obs=True,
- max_enable_networking_attempts=1,
- )
- wrapped_env.reset()
-
- base_env.task_extras.return_value = {
- 'foo': np.array(['bar'], dtype='U'),
- 'some_key': np.array(['some_value'], dtype='U'),
- 'exception': np.array(['fake exception'], dtype='U'),
- }
- with mock.patch.object(
- wrapped_env, 'attempt_enable_networking'
- ) as mock_attempt_enable_networking:
- extras = wrapped_env._fetch_task_extras()
- self.assertNotIn('accessibility_tree', extras)
- self.assertNotIn('full_event', extras)
- # Wait for the the attempt to finish.
- future = wrapped_env._enabling_networking_future
- if future is not None:
- future.result()
- mock_attempt_enable_networking.assert_called_once()
- # The _fetch_task_extras() call inside the next step will force a restart
- self.assertRaises(
- a11y_grpc_wrapper.EnableNetworkingError, wrapped_env.step, {}
- )
-
- @mock.patch.object(
- a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
- )
- @mock.patch.object(grpc, 'server', autospec=True)
- def test_configure_grpc(
- self,
- mock_server,
- mock_add_servicer,
- ):
- del mock_server, mock_add_servicer
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
- base_env.task_extras.return_value = {
- 'foo': np.array(['bar'], dtype='U'),
- 'some_key': np.array(['some_value'], dtype='U'),
- }
-
- base_env.stats.return_value = {'relaunch_count': 1}
- base_env.execute_adb_call.return_value = _ok_response()
- wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
- with mock.patch.object(
- wrapped_env, '_configure_grpc'
- ) as mock_configure_grpc:
- wrapped_env.reset()
- mock_configure_grpc.assert_called_once()
-
- @mock.patch.object(
- a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
- )
- @mock.patch.object(grpc, 'server', autospec=True)
- def test_task_extras_raises_before_reset(
- self, unused_mock_server, unused_mock_add_servicer
- ):
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
- base_env.stats.return_value = {'relaunch_count': 0}
- base_env.execute_adb_call.return_value = _ok_response()
- wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
- with self.assertRaisesRegex(
- RuntimeError,
- r'You must call \.reset\(\) before calling \.task_extras\(\)',
- ):
- wrapped_env.task_extras(latest_only=False)
-
- @mock.patch.object(
- a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
- )
- @mock.patch.object(grpc, 'server', autospec=True)
- def test_extras_accumulate_between_steps(
- self, mock_server, mock_add_servicer
- ):
- del mock_server, mock_add_servicer
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
- base_env.stats.return_value = {'relaunch_count': 0}
- base_env.execute_adb_call.return_value = _ok_response()
- base_env.reset.return_value = dm_env.restart(observation={'dummy': 42})
- base_env.step.return_value = dm_env.transition(
- observation={'dummy': 42}, reward=0.0
- )
- wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
- base_env, add_latest_a11y_info_to_obs=True
- )
- with mock.patch.object(wrapped_env, '_fetch_task_extras'):
- wrapped_env._fetch_task_extras.return_value = {
- 'full_event': np.array(single_item_dict(), ndmin=1, dtype=object),
- 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object),
- }
- timestep = wrapped_env.reset()
- self.assertIn('a11y_forest', timestep.observation)
- self.assertEqual(timestep.observation['a11y_forest'], empty_forest())
- wrapped_env._fetch_task_extras.return_value = {
- 'full_event': np.array(empty_dict(), ndmin=1, dtype=object),
- 'accessibility_tree': np.array(
- one_window_two_nodes_forest(), ndmin=1, dtype=object
- ),
- }
- timestep = wrapped_env.step({})
- self.assertIn('a11y_forest', timestep.observation)
- self.assertEqual(
- timestep.observation['a11y_forest'], one_window_two_nodes_forest()
- )
- timestep = wrapped_env.step({})
- self.assertIn('a11y_forest', timestep.observation)
- self.assertEqual(
- timestep.observation['a11y_forest'], one_window_two_nodes_forest()
- )
- wrapped_env._fetch_task_extras.return_value = {
- 'full_event': np.array(single_item_dict(), ndmin=1, dtype=object),
- }
- timestep = wrapped_env.step({})
- self.assertIn('a11y_forest', timestep.observation)
- self.assertEqual(
- timestep.observation['a11y_forest'], one_window_two_nodes_forest()
- )
- expected_task_extras = {
- 'full_event': np.array(
- [
- single_item_dict(),
- empty_dict(),
- empty_dict(),
- single_item_dict(),
- ],
- dtype=object,
- ),
- 'accessibility_tree': np.array(
- [
- empty_forest(),
- one_window_two_nodes_forest(),
- one_window_two_nodes_forest(),
- ],
- dtype=object,
- ),
- }
- expected_task_extras_latest = {
- 'full_event': np.array([single_item_dict()], dtype=object),
- 'accessibility_tree': np.array(
- [one_window_two_nodes_forest()], dtype=object
- ),
- }
- task_extras = wrapped_env.task_extras(latest_only=False)
- np.testing.assert_equal(
- task_extras['full_event'], expected_task_extras['full_event']
- )
- np.testing.assert_equal(
- task_extras['accessibility_tree'],
- expected_task_extras['accessibility_tree'],
- )
-
- task_extras = wrapped_env.task_extras(latest_only=True)
- np.testing.assert_equal(
- task_extras['full_event'], expected_task_extras_latest['full_event']
- )
- np.testing.assert_equal(
- task_extras['accessibility_tree'],
- expected_task_extras_latest['accessibility_tree'],
- )
-
- @mock.patch.object(
- a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
- )
- @mock.patch.object(grpc, 'server', autospec=True)
- def test_a11y_info_disabled(
- self,
- unused_mock_server,
- unused_mock_add_servicer,
- ):
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
- base_env.action_spec.return_value = {
- 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32)
- }
- base_env.stats.return_value = {'relaunch_count': 0}
- base_env.execute_adb_call.return_value = _ok_response()
- base_env.reset.return_value = dm_env.restart(observation={'dummy': 42})
- base_env.step.return_value = dm_env.transition(
- observation={'dummy': 42}, reward=0.0
- )
- wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
- base_env, add_latest_a11y_info_to_obs=False, a11y_info_timeout=1.0
- )
- with mock.patch.object(wrapped_env, '_fetch_task_extras'):
- wrapped_env._fetch_task_extras.return_value = {
- 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object),
- }
- timestep = wrapped_env.reset()
- self.assertNotIn('a11y_forest', timestep.observation)
- timestep = wrapped_env.step({})
- self.assertNotIn('a11y_forest', timestep.observation)
-
- @mock.patch.object(
- a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
- )
- @mock.patch.object(grpc, 'server', autospec=True)
- def test_a11y_info_with_timer_info_present(
- self,
- unused_mock_server,
- unused_mock_add_servicer,
- ):
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
- base_env.action_spec.return_value = {
- 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32)
- }
- base_env.stats.return_value = {'relaunch_count': 0}
- base_env.execute_adb_call.return_value = _ok_response()
- base_env.reset.return_value = dm_env.restart(observation={'dummy': 42})
- base_env.step.return_value = dm_env.transition(
- observation={'dummy': 42}, reward=0.0
- )
- wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
- base_env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=1.0
- )
- with mock.patch.object(wrapped_env, '_fetch_task_extras'):
- wrapped_env._fetch_task_extras.side_effect = [{
- 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object),
- }]
- timestep = wrapped_env.reset()
- self.assertIn('a11y_forest', timestep.observation)
- self.assertEqual(timestep.observation['a11y_forest'], empty_forest())
-
- @mock.patch.object(
- a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
- )
- @mock.patch.object(grpc, 'server', autospec=True)
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_a11y_info_with_timer_task_extra_returned(
- self, unused_mock_server, unused_mock_add_servicer, unused_mock_sleep
- ):
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
- base_env.action_spec.return_value = {
- 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32)
- }
- base_env.stats.return_value = {'relaunch_count': 0}
- base_env.execute_adb_call.return_value = _ok_response()
- base_env.reset.return_value = dm_env.restart(observation={'dummy': 42})
- base_env.step.return_value = dm_env.transition(
- observation={'dummy': 42}, reward=0.0
- )
- wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
- base_env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=1.0
- )
- with mock.patch.object(wrapped_env, '_fetch_task_extras'):
- wrapped_env._fetch_task_extras.side_effect = [
- {
- 'accessibility_tree': np.array(
- empty_forest(), ndmin=1, dtype=object
- ),
- },
- ]
- timestep = wrapped_env.reset()
- self.assertIn('a11y_forest', timestep.observation)
- self.assertEqual(timestep.observation['a11y_forest'], empty_forest())
-
- @mock.patch.object(
- a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
- )
- @mock.patch.object(grpc, 'server', autospec=True)
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_a11y_info_with_timer_from_action(
- self, unused_mock_server, unused_mock_add_servicer, mock_sleep
- ):
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
- base_env.action_spec.return_value = {
- 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32)
- }
- base_env.stats.return_value = {'relaunch_count': 0}
- base_env.execute_adb_call.return_value = _ok_response()
- base_env.reset.return_value = dm_env.restart(observation={'dummy': 42})
- base_env.step.return_value = dm_env.transition(
- observation={'dummy': 42}, reward=0.0
- )
- wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
- base_env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=0.0
- )
- with mock.patch.object(wrapped_env, '_fetch_task_extras'):
- wrapped_env._fetch_task_extras.side_effect = [
- {
- 'accessibility_tree': np.array(
- empty_forest(), ndmin=1, dtype=object
- ),
- },
- ]
- timestep = wrapped_env.step(action={'wait_time': 1.0})
- self.assertIn('a11y_forest', timestep.observation)
- mock_sleep.assert_called_once()
- self.assertEqual(timestep.observation['a11y_forest'], empty_forest())
-
- @mock.patch.object(
- a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
- )
- @mock.patch.object(grpc, 'server', autospec=True)
- def test_task_extras_same_between_calls(self, mock_server, mock_add_servicer):
- del mock_server, mock_add_servicer
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
- base_env.stats.return_value = {'relaunch_count': 0}
- base_env.execute_adb_call.return_value = _ok_response()
- wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
- expected_task_extras = {
- 'full_event': np.array(single_item_dict(), ndmin=1, dtype=object),
- 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object),
- }
- with mock.patch.object(wrapped_env, '_fetch_task_extras'):
- wrapped_env._fetch_task_extras.return_value = expected_task_extras
- wrapped_env.reset()
- task_extras = wrapped_env.task_extras(latest_only=False)
- np.testing.assert_equal(
- task_extras['full_event'], expected_task_extras['full_event']
- )
- np.testing.assert_equal(
- task_extras['accessibility_tree'],
- expected_task_extras['accessibility_tree'],
- )
-
- task_extras = wrapped_env.task_extras(latest_only=False)
- np.testing.assert_equal(
- task_extras['full_event'], expected_task_extras['full_event']
- )
- np.testing.assert_equal(
- task_extras['accessibility_tree'],
- expected_task_extras['accessibility_tree'],
- )
-
- expected_task_extras = {
- 'full_event': np.array(empty_dict(), ndmin=1, dtype=object),
- 'accessibility_tree': np.array(
- one_window_two_nodes_forest(), ndmin=1, dtype=object
- ),
- }
- with mock.patch.object(wrapped_env, '_fetch_task_extras'):
- wrapped_env._fetch_task_extras.return_value = expected_task_extras
- wrapped_env.step({})
- task_extras = wrapped_env.task_extras(latest_only=False)
- np.testing.assert_equal(
- task_extras['full_event'], expected_task_extras['full_event']
- )
- np.testing.assert_equal(
- task_extras['accessibility_tree'],
- expected_task_extras['accessibility_tree'],
- )
-
- task_extras = wrapped_env.task_extras(latest_only=False)
- np.testing.assert_equal(
- task_extras['full_event'], expected_task_extras['full_event']
- )
- np.testing.assert_equal(
- task_extras['accessibility_tree'],
- expected_task_extras['accessibility_tree'],
- )
-
- @mock.patch.object(
- a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
- )
- @mock.patch.object(grpc, 'server', autospec=True)
- def test_task_extras_clear_if_called_between_step(
- self, mock_server, mock_add_servicer
- ):
- del mock_server, mock_add_servicer
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
- base_env.stats.return_value = {'relaunch_count': 0}
- base_env.execute_adb_call.return_value = _ok_response()
- wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
- with mock.patch.object(wrapped_env, '_fetch_task_extras'):
- expected_task_extras = {
- 'full_event': np.array(empty_dict(), ndmin=1, dtype=object),
- 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object),
- }
- wrapped_env._fetch_task_extras.return_value = expected_task_extras
- wrapped_env.reset()
- task_extras = wrapped_env.task_extras(latest_only=False)
- np.testing.assert_equal(
- task_extras['full_event'], expected_task_extras['full_event']
- )
- np.testing.assert_equal(
- task_extras['accessibility_tree'],
- expected_task_extras['accessibility_tree'],
- )
-
- expected_task_extras = {
- 'full_event': np.array(single_item_dict(), ndmin=1, dtype=object),
- 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object),
- }
- wrapped_env._fetch_task_extras.return_value = expected_task_extras
- wrapped_env.step({})
- task_extras = wrapped_env.task_extras(latest_only=False)
- np.testing.assert_equal(
- task_extras['full_event'], expected_task_extras['full_event']
- )
- np.testing.assert_equal(
- task_extras['accessibility_tree'],
- expected_task_extras['accessibility_tree'],
- )
- expected_task_extras = {
- 'full_event': np.array(empty_dict(), ndmin=1, dtype=object),
- 'accessibility_tree': np.array(
- one_window_two_nodes_forest(), ndmin=1, dtype=object
- ),
- }
- wrapped_env._fetch_task_extras.return_value = expected_task_extras
- wrapped_env.step({})
- task_extras = wrapped_env.task_extras(latest_only=False)
- np.testing.assert_equal(
- task_extras['full_event'], expected_task_extras['full_event']
- )
- np.testing.assert_equal(
- task_extras['accessibility_tree'],
- expected_task_extras['accessibility_tree'],
- )
-
- @parameterized.named_parameters(
- ('none_true', False, False, False, 0),
- ('only_install', True, False, False, 1),
- ('only_start', False, True, False, 1),
- ('only_enable_a11y_tree', False, False, True, 1),
- ('install_and_start_no_a11y_tree', True, True, False, 2),
- ('install_and_a11y_tree', True, False, True, 2),
- ('start_and_a11y_tree', False, True, True, 2),
- ('all_true', True, True, True, 3),
- )
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_apk_install_and_start(
- self,
- install_a11y_forwarding: bool,
- start_a11y_service: bool,
- enable_a11y_tree_logs: bool,
- expected_adb_calls: int,
- unused_mock_sleep,
- ):
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
-
- side_effects = []
- if install_a11y_forwarding:
- side_effects.append(_ok_response()) # install response
- if start_a11y_service:
- side_effects.append(_ok_response()) # start service response
- if enable_a11y_tree_logs:
- side_effects.append(_ok_response()) # enable_tree_request
-
- base_env.execute_adb_call.side_effect = side_effects
-
- _ = a11y_grpc_wrapper.A11yGrpcWrapper(
- base_env,
- install_a11y_forwarding=install_a11y_forwarding,
- start_a11y_service=start_a11y_service,
- enable_a11y_tree_info=enable_a11y_tree_logs,
- )
- self.assertEqual(base_env.execute_adb_call.call_count, expected_adb_calls)
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_component_and_start(self, unused_mock_sleep):
- base_env = mock.create_autospec(
- env_interface.AndroidEnvInterface, instance=True
- )
-
- side_effects = []
- side_effects.append(_ok_response()) # install response
- side_effects.append(_ok_response()) # start service response
- side_effects.append(_ok_response()) # enable_tree_request
-
- base_env.execute_adb_call.side_effect = side_effects
-
- _ = a11y_grpc_wrapper.A11yGrpcWrapper(
- base_env,
- install_a11y_forwarding=True,
- start_a11y_service=True,
- enable_a11y_tree_info=True,
- )
-
- # call_args returns a tuple of which the first member is a tuple containing
- # the most recent args the mock was called with, and execute_adb_call only
- # has one arg (so [0][0] to access the AdbRequest).
- self.assertEqual(
- base_env.execute_adb_call.call_args[0][0].send_broadcast.component,
- 'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver',
- )
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/wrappers/base_wrapper.py b/android_env/wrappers/base_wrapper.py
deleted file mode 100644
index 68aea7e4..00000000
--- a/android_env/wrappers/base_wrapper.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Base class for AndroidEnv wrappers."""
-
-from typing import Any
-
-from absl import logging
-from android_env import env_interface
-from android_env.proto import adb_pb2
-from android_env.proto import state_pb2
-import dm_env
-from dm_env import specs
-import numpy as np
-
-
-class BaseWrapper(env_interface.AndroidEnvInterface):
- """AndroidEnv wrapper."""
-
- def __init__(self, env):
- self._env = env
- logging.info('Wrapping with %s', self.__class__.__name__)
-
- def reset(self) -> dm_env.TimeStep:
- self._reset_state()
- timestep = self._process_timestep(self._env.reset())
- return timestep
-
- def step(self, action: Any) -> dm_env.TimeStep:
- action = self._process_action(action)
- return self._process_timestep(self._env.step(action))
-
- def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]:
- return self._env.task_extras(latest_only=latest_only)
-
- def _reset_state(self):
- pass
-
- def _process_action(self, action: Any) -> Any:
- return action
-
- def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
- return timestep
-
- def observation_spec(self) -> dict[str, specs.Array]:
- return self._env.observation_spec()
-
- def action_spec(self) -> dict[str, specs.Array]:
- return self._env.action_spec()
-
- def reward_spec(self) -> specs.Array:
- return self._env.reward_spec()
-
- def discount_spec(self) -> specs.Array:
- return self._env.discount_spec()
-
- def _wrapper_stats(self) -> dict[str, Any]:
- """Add wrapper specific logging here."""
- return {}
-
- def stats(self) -> dict[str, Any]:
- info = self._env.stats()
- info.update(self._wrapper_stats())
- return info
-
- def load_state(
- self, request: state_pb2.LoadStateRequest
- ) -> state_pb2.LoadStateResponse:
- """Loads a state."""
- return self._env.load_state(request)
-
- def save_state(
- self, request: state_pb2.SaveStateRequest
- ) -> state_pb2.SaveStateResponse:
- """Saves a state.
-
- Args:
- request: A `SaveStateRequest` containing any parameters necessary to
- specify how/what state to save.
-
- Returns:
- A `SaveStateResponse` containing the status, error message (if
- applicable), and any other relevant information.
- """
- return self._env.save_state(request)
-
- def execute_adb_call(self,
- adb_call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
- return self._env.execute_adb_call(adb_call)
-
- @property
- def raw_action(self):
- return self._env.raw_action
-
- @property
- def raw_observation(self):
- return self._env.raw_observation
-
- @property
- def raw_env(self):
- """Recursively unwrap until we reach the true 'raw' env."""
- wrapped = self._env
- if hasattr(wrapped, 'raw_env'):
- return wrapped.raw_env
- return wrapped
-
- def __getattr__(self, attr):
- """Delegate attribute access to underlying environment."""
- return getattr(self._env, attr)
-
- def close(self):
- self._env.close()
diff --git a/android_env/wrappers/base_wrapper_test.py b/android_env/wrappers/base_wrapper_test.py
deleted file mode 100644
index a690139d..00000000
--- a/android_env/wrappers/base_wrapper_test.py
+++ /dev/null
@@ -1,154 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.wrappers.base_wrapper."""
-
-from unittest import mock
-
-from absl import logging
-from absl.testing import absltest
-from android_env import env_interface
-from android_env.proto import state_pb2
-from android_env.wrappers import base_wrapper
-
-
-class BaseWrapperTest(absltest.TestCase):
-
- @mock.patch.object(logging, 'info')
- def test_base_function_forwarding(self, mock_info):
- base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
- wrapped_env = base_wrapper.BaseWrapper(base_env)
- mock_info.assert_called_with('Wrapping with %s', 'BaseWrapper')
-
- fake_ts = 'fake_ts'
- base_env.reset.return_value = fake_ts
- self.assertEqual(fake_ts, wrapped_env.reset())
- base_env.reset.assert_called_once()
-
- fake_ts = 'fake_ts'
- fake_action = 'fake_action'
- base_env.step.return_value = fake_ts
- self.assertEqual(fake_ts, wrapped_env.step(fake_action))
- base_env.step.assert_called_once_with(fake_action)
-
- fake_extras = 'fake_task_extras'
- base_env.task_extras.return_value = fake_extras
- self.assertEqual(fake_extras, wrapped_env.task_extras(latest_only=True))
- base_env.task_extras.assert_called_once_with(latest_only=True)
-
- fake_obs_spec = 'fake_obs_spec'
- base_env.observation_spec.return_value = fake_obs_spec
- self.assertEqual(fake_obs_spec, wrapped_env.observation_spec())
- base_env.observation_spec.assert_called_once()
-
- fake_action_spec = 'fake_action_spec'
- base_env.action_spec.return_value = fake_action_spec
- self.assertEqual(fake_action_spec, wrapped_env.action_spec())
- base_env.action_spec.assert_called_once()
-
- fake_raw_action = 'fake_raw_action'
- type(base_env).raw_action = mock.PropertyMock(return_value=fake_raw_action)
- self.assertEqual(fake_raw_action, wrapped_env.raw_action)
-
- fake_raw_observation = 'fake_raw_observation'
- type(base_env).raw_observation = mock.PropertyMock(
- return_value=fake_raw_observation)
- self.assertEqual(fake_raw_observation, wrapped_env.raw_observation)
-
- load_request = state_pb2.LoadStateRequest(args={})
- expected_response = state_pb2.LoadStateResponse(
- status=state_pb2.LoadStateResponse.Status.OK
- )
- base_env.load_state.return_value = expected_response
- self.assertEqual(wrapped_env.load_state(load_request), expected_response)
- base_env.load_state.assert_called_once_with(load_request)
-
- save_request = state_pb2.SaveStateRequest(args={})
- expected_response = state_pb2.SaveStateResponse(
- status=state_pb2.SaveStateResponse.Status.OK
- )
- base_env.save_state.return_value = expected_response
- self.assertEqual(wrapped_env.save_state(save_request), expected_response)
- base_env.save_state.assert_called_once_with(save_request)
-
- wrapped_env.close()
- base_env.close.assert_called_once()
-
- fake_return_value = 'fake'
- # AndroidEnv::some_random_function() does not exist and calling it should
- # raise an AttributeError.
- with self.assertRaises(AttributeError):
- base_env.some_random_function.return_value = fake_return_value
-
- def test_multiple_wrappers(self):
- base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
- wrapped_env_1 = base_wrapper.BaseWrapper(base_env)
- wrapped_env_2 = base_wrapper.BaseWrapper(wrapped_env_1)
-
- wrapped_env_2.close()
- base_env.close.assert_called_once()
-
- def test_raw_env(self):
- base_env = 'fake_env'
- wrapped_env_1 = base_wrapper.BaseWrapper(base_env)
- wrapped_env_2 = base_wrapper.BaseWrapper(wrapped_env_1)
- self.assertEqual(base_env, wrapped_env_2.raw_env)
-
- def test_stats(self):
- base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
- wrapped_env = base_wrapper.BaseWrapper(base_env)
- base_stats = {'base': 'stats'}
- base_env.stats.return_value = base_stats
- self.assertEqual(base_stats, wrapped_env.stats())
-
- @mock.patch.object(logging, 'info')
- def test_wrapped_stats(self, mock_info):
- base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
-
- class LoggingWrapper1(base_wrapper.BaseWrapper):
-
- def _wrapper_stats(self):
- return {
- 'wrapper1': 'stats',
- 'shared': 1,
- }
-
- class LoggingWrapper2(base_wrapper.BaseWrapper):
-
- def _wrapper_stats(self):
- return {
- 'wrapper2': 'stats',
- 'shared': 2,
- }
-
- wrapped_env = LoggingWrapper2(LoggingWrapper1(base_env))
- mock_info.assert_has_calls([
- mock.call('Wrapping with %s', 'LoggingWrapper1'),
- mock.call('Wrapping with %s', 'LoggingWrapper2'),
- ])
- base_stats = {'base': 'stats'}
- base_env.stats.return_value = base_stats
- expected_stats = {
- 'base': 'stats',
- 'wrapper1': 'stats',
- 'wrapper2': 'stats',
- 'shared': 2,
- }
-
- self.assertEqual(expected_stats, wrapped_env.stats())
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/wrappers/discrete_action_wrapper.py b/android_env/wrappers/discrete_action_wrapper.py
deleted file mode 100644
index 7bd483b7..00000000
--- a/android_env/wrappers/discrete_action_wrapper.py
+++ /dev/null
@@ -1,161 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Wraps the AndroidEnv environment to provide discrete actions."""
-
-from collections.abc import Sequence
-
-from android_env.components import action_type
-from android_env.wrappers import base_wrapper
-import dm_env
-from dm_env import specs
-import numpy as np
-
-
-_NOISE_CLIP_VALUE = 0.4999
-
-
-class DiscreteActionWrapper(base_wrapper.BaseWrapper):
- """AndroidEnv with discrete actions."""
-
- def __init__(
- self,
- env: dm_env.Environment,
- action_grid: Sequence[int] = (10, 10),
- redundant_actions: bool = True,
- noise: float = 0.1,
- ):
- super().__init__(env)
- self._parent_action_spec = self._env.action_spec()
- self._assert_base_env()
- self._action_grid = action_grid # [height, width]
- self._grid_size = np.prod(self._action_grid)
- self._num_action_types = self._parent_action_spec['action_type'].num_values
- self._redundant_actions = redundant_actions
- self._noise = noise
-
- def _assert_base_env(self):
- """Checks that the wrapped env has the right action spec format."""
-
- assert len(self._parent_action_spec) == 2
- assert not self._parent_action_spec['action_type'].shape
- assert self._parent_action_spec['touch_position'].shape == (2,)
-
- @property
- def num_actions(self) -> int:
- """Number of discrete actions."""
-
- if self._redundant_actions:
- return self._grid_size * self._num_action_types
- else:
- return self._grid_size + self._num_action_types - 1
-
- def step(self, action: dict[str, int]) -> dm_env.TimeStep:
- """Take a step in the base environment."""
-
- return self._env.step(self._process_action(action))
-
- def _process_action(self, action: dict[str, int]) -> dict[str, np.ndarray]:
- """Transforms action so that it agrees with AndroidEnv's action spec."""
-
- return {
- 'action_type':
- np.array(self._get_action_type(action['action_id']),
- dtype=self._parent_action_spec['action_type'].dtype),
- 'touch_position':
- np.array(self._get_touch_position(action['action_id']),
- dtype=self._parent_action_spec['touch_position'].dtype)
- }
-
- def _get_action_type(self, action_id: int) -> action_type.ActionType:
- """Compute action type corresponding to the given action_id.
-
- When `self._redundant_actions` == True the `grid_size` is "broadcast" over
- all the possible actions so you end up with `grid_size` discrete actions
- of type 0, `grid_size` discrete actions of type 1, etc. for all action
- types.
-
- When `self._redundant_actions` == False the first `grid_size` actions are
- reserved for "touch" and the rest are just added (NOT multiplied) to the
- total number of discrete actions (exactly one of LIFT and REPEAT).
-
- Args:
- action_id: A discrete action.
- Returns:
- action_type: The action_type of the action.
- """
-
- if self._redundant_actions:
- assert action_id < self._num_action_types * self._grid_size
- return action_id // self._grid_size
-
- else:
- assert action_id <= self._grid_size + 1
- if action_id < self._grid_size:
- return action_type.ActionType.TOUCH
- elif action_id == self._grid_size:
- return action_type.ActionType.LIFT
- else:
- return action_type.ActionType.REPEAT
-
- def _get_touch_position(self, action_id: int) -> Sequence[float]:
- """Compute the position corresponding to the given action_id.
-
- Note: in the touch_position (x, y) of an action, x corresponds to the
- horizontal axis (width), and y corresponds to the vertical axis (height)
- of the screen. BUT, the screen has dimensions (height, width), i.e. the
- first coordinate corresponds to y, and the second coordinate corresponds
- to x. Pay attention to this mismatch in the calculations below.
-
- Args:
- action_id: A discrete action.
- Returns:
- touch_position: The [0,1]x[0,1] coordinate of the action.
- """
-
- position_idx = action_id % self._grid_size
-
- x_pos_grid = position_idx % self._action_grid[1] # WIDTH
- y_pos_grid = position_idx // self._action_grid[1] # HEIGHT
-
- noise_x = np.random.normal(loc=0.0, scale=self._noise)
- noise_y = np.random.normal(loc=0.0, scale=self._noise)
-
- # Noise is clipped so that the action will strictly stay in the cell.
- noise_x = max(min(noise_x, _NOISE_CLIP_VALUE), -_NOISE_CLIP_VALUE)
- noise_y = max(min(noise_y, _NOISE_CLIP_VALUE), -_NOISE_CLIP_VALUE)
-
- x_pos = (x_pos_grid + 0.5 + noise_x) / self._action_grid[1] # WIDTH
- y_pos = (y_pos_grid + 0.5 + noise_y) / self._action_grid[0] # HEIGHT
-
- # Project action space to action_spec ranges. For the default case of
- # minimum = [0, 0] and maximum = [1, 1], this will not do anything.
- x_min, y_min = self._parent_action_spec['touch_position'].minimum
- x_max, y_max = self._parent_action_spec['touch_position'].maximum
-
- x_pos = x_min + x_pos * (x_max - x_min)
- y_pos = y_min + y_pos * (y_max - y_min)
-
- return [x_pos, y_pos]
-
- def action_spec(self) -> dict[str, specs.Array]:
- """Action spec of the wrapped environment."""
-
- return {
- 'action_id':
- specs.DiscreteArray(
- num_values=self.num_actions,
- name='action_id')
- }
diff --git a/android_env/wrappers/discrete_action_wrapper_test.py b/android_env/wrappers/discrete_action_wrapper_test.py
deleted file mode 100644
index bc9d4329..00000000
--- a/android_env/wrappers/discrete_action_wrapper_test.py
+++ /dev/null
@@ -1,386 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.wrappers.discrete_action_wrapper."""
-
-from unittest import mock
-
-from absl.testing import absltest
-from android_env import env_interface
-from android_env.components import action_type as action_type_lib
-from android_env.wrappers import discrete_action_wrapper
-from dm_env import specs
-import numpy as np
-
-ActionType = action_type_lib.ActionType
-
-
-def _make_array_spec(shape, dtype, name):
- assert len(shape) == 1
- return specs.BoundedArray(
- name=name,
- shape=shape,
- dtype=dtype,
- minimum=np.zeros(shape),
- maximum=(shape[0] - 1) * np.ones(shape), # maximum is inclusive.
- )
-
-
-def _valid_shape(action):
- assert len(action) == 2, action
- assert not action['action_type'].shape, (
- 'action: %r, shape: %r' %
- (action['action_type'], action['action_type'].shape))
- assert action['touch_position'].shape == (
- 2,), ('action: %r, shape: %r' %
- (action['touch_position'], action['touch_position'].shape))
-
-
-def _valid_types(action, types):
- for a, t in zip(action.values(), types):
- assert a.dtype == t, '%r is not of dtype %r' % (a, t)
-
-
-class DiscreteActionWrapperTest(absltest.TestCase):
-
- def setUp(self):
- super().setUp()
- self._num_action_types = 3 # Only TOUCH, LIFT, REPEAT.
- self._base_action_spec = {
- 'action_type': specs.DiscreteArray(
- num_values=self._num_action_types, name='action_type'),
- 'touch_position': _make_array_spec(
- shape=(2,), dtype=np.float32, name='touch_position'),
- }
- self.base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
- self.base_env.action_spec.return_value = self._base_action_spec
-
- def test_num_actions(self):
- wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
- self.base_env, action_grid=(3, 3), redundant_actions=True)
- # 27 = 3 * 3 * 2 (H * W * self._num_action_types).
- self.assertEqual(27, wrapped_env.num_actions)
-
- def test_num_actions_non_redundant(self):
- # Check that with `redundant_actions`==False we get an additive term instead
- # of a multiplier in the number of actions.
- non_redudant_wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
- self.base_env, action_grid=(3, 3), redundant_actions=False)
- # 11 = 3 * 3 + 2 (H * W + (self._num_action_types - 1)).
- self.assertEqual(11, non_redudant_wrapped_env.num_actions)
-
- def test_reset(self):
- wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
- self.base_env, redundant_actions=True)
- fake_timestep = 'ts'
- self.base_env.reset.return_value = fake_timestep
- ts = wrapped_env.reset()
- self.base_env.reset.assert_called_once()
- self.assertEqual(fake_timestep, ts)
-
- def test_step_no_noise(self):
- height = 4
- width = 3
- wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
- self.base_env,
- action_grid=(height, width),
- noise=0.0,
- redundant_actions=True)
- self.assertEqual(height * width * self._num_action_types,
- wrapped_env.num_actions)
-
- vertical_half_step = 1. / float(height) / 2.
- horizontal_half_step = 1. / float(width) / 2.
-
- delta = 0.0001
-
- # Testing the four corners with each finger position
- def get_verifier(expected_action_type, lower_x, lower_y):
-
- def verifier(x):
- _valid_shape(x)
- _valid_types(x, [np.int32, np.float32])
- self.assertEqual(
- expected_action_type, x['action_type'])
- if lower_y:
- self.assertAlmostEqual(
- vertical_half_step, x['touch_position'][1], delta=delta)
- else:
- self.assertAlmostEqual(
- 1 - vertical_half_step, x['touch_position'][1], delta=delta)
- if lower_x:
- self.assertAlmostEqual(
- horizontal_half_step, x['touch_position'][0], delta=delta)
- else:
- self.assertAlmostEqual(
- 1 - horizontal_half_step, x['touch_position'][0], delta=delta)
- return True
-
- return verifier
-
- action_tests = {
- 0: get_verifier(0, lower_x=True, lower_y=True),
- 2: get_verifier(0, lower_x=False, lower_y=True),
- 9: get_verifier(0, lower_x=True, lower_y=False),
- 11: get_verifier(0, lower_x=False, lower_y=False),
-
- 12: get_verifier(1, lower_x=True, lower_y=True),
- 14: get_verifier(1, lower_x=False, lower_y=True),
- 21: get_verifier(1, lower_x=True, lower_y=False),
- 23: get_verifier(1, lower_x=False, lower_y=False),
-
- 24: get_verifier(2, lower_x=True, lower_y=True),
- 26: get_verifier(2, lower_x=False, lower_y=True),
- 33: get_verifier(2, lower_x=True, lower_y=False),
- 35: get_verifier(2, lower_x=False, lower_y=False),
- }
-
- fake_timestep = 'ts'
- self.base_env.step.return_value = fake_timestep
-
- for action_id, verifier in action_tests.items():
- ts = wrapped_env.step({'action_id': action_id})
- verifier(self.base_env.step.call_args[0][0])
- self.assertEqual(fake_timestep, ts)
-
- def test_step_redundant_actions_invalid_action_id(self):
- wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
- self.base_env,
- action_grid=(4, 3),
- noise=0.0,
- redundant_actions=True)
- with self.assertRaises(AssertionError):
- _ = wrapped_env.step({'action_id': 36})
-
- def test_step_no_noise_no_redudant_actions(self):
- height = 4
- width = 3
- wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
- self.base_env,
- action_grid=(height, width),
- noise=0.0,
- redundant_actions=False)
- self.assertEqual(height * width + (self._num_action_types - 1),
- wrapped_env.num_actions)
-
- vertical_half_step = 1. / float(height) / 2.
- horizontal_half_step = 1. / float(width) / 2.
-
- delta = 0.0001
-
- # Testing the four corners with each finger position
- def get_verifier(expected_action_type, lower_x, lower_y):
-
- def verifier(x):
- _valid_shape(x)
- _valid_types(x, [np.int32, np.float32])
- self.assertEqual(expected_action_type, x['action_type'])
- # If the action type == TOUCH, then check the coordinate values.
- if x['action_type'] == ActionType.TOUCH:
- if lower_y:
- self.assertAlmostEqual(
- vertical_half_step, x['touch_position'][1], delta=delta)
- else:
- self.assertAlmostEqual(
- 1 - vertical_half_step, x['touch_position'][1], delta=delta)
- if lower_x:
- self.assertAlmostEqual(
- horizontal_half_step, x['touch_position'][0], delta=delta)
- else:
- self.assertAlmostEqual(
- 1 - horizontal_half_step, x['touch_position'][0], delta=delta)
- return True
-
- return verifier
-
- action_tests = {
- # Touch type actions
- 0: get_verifier(0, lower_x=True, lower_y=True),
- 2: get_verifier(0, lower_x=False, lower_y=True),
- 9: get_verifier(0, lower_x=True, lower_y=False),
- 11: get_verifier(0, lower_x=False, lower_y=False),
- # Actions > grid_size return non-touch actions with (0,0) coordinates.
- 12: get_verifier(1, lower_x=False, lower_y=False),
- 13: get_verifier(2, lower_x=False, lower_y=False),
- }
-
- fake_timestep = 'ts'
- self.base_env.step.return_value = fake_timestep
-
- for action_id, verifier in action_tests.items():
- ts = wrapped_env.step({'action_id': action_id})
- verifier(self.base_env.step.call_args[0][0])
- self.assertEqual(fake_timestep, ts)
-
- def test_step_no_redundant_actions_invalid_action_id(self):
- wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
- self.base_env,
- action_grid=(4, 3),
- noise=0.0,
- redundant_actions=False)
- with self.assertRaises(AssertionError):
- _ = wrapped_env.step({'action_id': 14})
-
- def test_step_with_noise(self):
- height = 4
- width = 3
- wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
- self.base_env, action_grid=(height, width), noise=1.0)
- self.assertEqual(height * width * self._num_action_types,
- wrapped_env.num_actions)
-
- vertical_grid_step = 1. / float(height)
- horizontal_grid_step = 1. / float(width)
-
- # Testing the four corners with each finger position
- def get_verifier(expected_up_down, lower_x, lower_y):
-
- def verifier(x):
- _valid_shape(x)
- _valid_types(x, [np.int32, np.float32])
- self.assertEqual(expected_up_down, x['action_type'])
- if lower_y:
- self.assertGreater(vertical_grid_step, x['touch_position'][1])
- else:
- self.assertLess(1 - vertical_grid_step, x['touch_position'][1])
- if lower_x:
- self.assertGreater(horizontal_grid_step, x['touch_position'][0])
- else:
- self.assertLess(1 - horizontal_grid_step, x['touch_position'][0])
- return True
-
- return verifier
-
- action_tests = {
- 0: get_verifier(0, lower_x=True, lower_y=True),
- 2: get_verifier(0, lower_x=False, lower_y=True),
- 9: get_verifier(0, lower_x=True, lower_y=False),
- 11: get_verifier(0, lower_x=False, lower_y=False),
-
- 12: get_verifier(1, lower_x=True, lower_y=True),
- 14: get_verifier(1, lower_x=False, lower_y=True),
- 21: get_verifier(1, lower_x=True, lower_y=False),
- 23: get_verifier(1, lower_x=False, lower_y=False),
-
- 24: get_verifier(2, lower_x=True, lower_y=True),
- 26: get_verifier(2, lower_x=False, lower_y=True),
- 33: get_verifier(2, lower_x=True, lower_y=False),
- 35: get_verifier(2, lower_x=False, lower_y=False),
- }
-
- fake_timestep = 'ts'
- self.base_env.step.return_value = fake_timestep
-
- for action_id, verifier in action_tests.items():
- ts = wrapped_env.step({'action_id': action_id})
- verifier(self.base_env.step.call_args[0][0])
- self.assertEqual(fake_timestep, ts)
-
- def test_parent_spec_type(self):
- base_action_spec = {
- 'action_type': specs.DiscreteArray(
- num_values=self._num_action_types, name='action_type'),
- 'touch_position': _make_array_spec(
- shape=(2,), dtype=np.float64, name='touch_position'),
- }
- base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
- base_env.action_spec.return_value = base_action_spec
-
- wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
- base_env, noise=0.0)
-
- fake_timestep = 'ts'
- base_env.step.return_value = fake_timestep
-
- def verifier(x):
- _valid_types(x, [np.int32, np.float64])
- return True
-
- ts = wrapped_env.step({'action_id': 1})
- verifier(base_env.step.call_args[0][0])
- self.assertEqual(fake_timestep, ts)
-
- def test_observation_spec(self):
- wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
- self.base_env)
- fake_obs_spec = 'fake_obs_spec'
- self.base_env.observation_spec.return_value = fake_obs_spec
- observation_spec = wrapped_env.observation_spec()
- self.base_env.observation_spec.assert_called_once()
- self.assertEqual(fake_obs_spec, observation_spec)
-
- def test_action_spec(self):
- wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
- self.base_env, action_grid=(4, 5), redundant_actions=True)
- expected_action_spec = {
- 'action_id':
- specs.DiscreteArray(
- num_values=4 * 5 * self._num_action_types, name='action_type')
- }
- self.assertEqual(expected_action_spec, wrapped_env.action_spec())
-
- def test_action_spec_non_redundant(self):
- wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
- self.base_env, action_grid=(4, 5), redundant_actions=False)
- num_non_touch_actions = self._num_action_types - 1
- expected_action_spec = {
- 'action_id':
- specs.DiscreteArray(
- num_values=4 * 5 + num_non_touch_actions, name='action_type')
- }
- self.assertEqual(expected_action_spec, wrapped_env.action_spec())
-
- def test_assert_base_env_action_spec_too_short(self):
- self.base_env.action_spec.return_value = {
- 'action_type': specs.DiscreteArray(
- num_values=self._num_action_types, name='action_type'),
- }
- with self.assertRaises(AssertionError):
- _ = discrete_action_wrapper.DiscreteActionWrapper(self.base_env)
-
- def test_assert_base_env_action_spec_too_long(self):
- self.base_env.action_spec.return_value = {
- 'action_type': specs.DiscreteArray(
- num_values=self._num_action_types, name='action_type'),
- 'touch_position': _make_array_spec(
- shape=(2,), dtype=np.float32, name='touch_position'),
- 'too_long': _make_array_spec(
- shape=(1,), dtype=np.float32, name='too_long'),
- }
- with self.assertRaises(AssertionError):
- _ = discrete_action_wrapper.DiscreteActionWrapper(self.base_env)
-
- def test_assert_base_env_action_spec_wrong_shapes(self):
- self.base_env.action_spec.return_value = {
- 'action_type': _make_array_spec(
- shape=(2,), dtype=np.float32, name='action_type'),
- 'touch_position': _make_array_spec(
- shape=(1,), dtype=np.float32, name='touch_position')
- }
- with self.assertRaises(AssertionError):
- _ = discrete_action_wrapper.DiscreteActionWrapper(self.base_env)
-
- def test_assert_base_env_ok(self):
- self.base_env.action_spec.return_value = {
- 'action_type': specs.DiscreteArray(
- num_values=self._num_action_types, name='action_type'),
- 'touch_position': _make_array_spec(
- shape=(2,), dtype=np.float32, name='touch_position'),
- }
- _ = discrete_action_wrapper.DiscreteActionWrapper(self.base_env)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/wrappers/flat_interface_wrapper.py b/android_env/wrappers/flat_interface_wrapper.py
deleted file mode 100644
index 868ca65f..00000000
--- a/android_env/wrappers/flat_interface_wrapper.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Wraps the AndroidEnv environment to make its interface flat."""
-
-from typing import Any
-
-from android_env.wrappers import base_wrapper
-import dm_env
-from dm_env import specs
-import numpy as np
-
-RGB_CHANNELS = (0, 1, 2)
-
-
-def _extract_screen_pixels(obs: np.ndarray):
- """Get only screen pixels by removing previous action layer."""
- is_grayscale_image = obs.shape[-1] == 2
- if is_grayscale_image:
- return np.expand_dims(obs[..., 0], -1)
- return obs[..., RGB_CHANNELS]
-
-
-def _get_no_action_observation_spec(obs_spec: specs.BoundedArray):
- """Create an observation spec without the action layer."""
- shape = np.array(obs_spec.shape)
- shape[2] -= 1
- minimum = obs_spec.minimum
- maximum = obs_spec.maximum
- is_scalar = lambda x: np.isscalar(x) or np.ndim(x) == 0
- if not is_scalar(minimum):
- minimum = _extract_screen_pixels(minimum)
- if not is_scalar(maximum):
- maximum = _extract_screen_pixels(maximum)
- return obs_spec.replace(shape=shape, minimum=minimum, maximum=maximum)
-
-
-class FlatInterfaceWrapper(base_wrapper.BaseWrapper):
- """Simple interface for AndroidEnv.
-
- Removes the structure from observations and actions, keeping only the pixel
- observations. Also exposes action as an int32 scalar, making it easier to use
- with conventional discrete agents. This wrapper expects a discretized action
- space.
- """
-
- def __init__(self,
- env: dm_env.Environment,
- flat_actions: bool = True,
- flat_observations: bool = True,
- keep_action_layer: bool = True):
- super().__init__(env)
- self._flat_actions = flat_actions
- self._flat_observations = flat_observations
- self._keep_action_layer = keep_action_layer
- self._action_name = list(self._env.action_spec())[0]
- self._assert_base_env()
-
- def _assert_base_env(self):
- base_action_spec = self._env.action_spec()
- assert len(base_action_spec) == 1, self._env.action_spec()
- assert isinstance(base_action_spec, dict)
- assert isinstance(base_action_spec[self._action_name], specs.BoundedArray)
-
- def _process_action(self, action: int | np.ndarray | dict[str, Any]):
- if self._flat_actions:
- return {self._action_name: action}
- else:
- return action
-
- def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
- if self._flat_observations:
- step_type, reward, discount, observation = timestep
- # Keep only the pixels.
- pixels = observation['pixels']
- pixels = pixels if self._keep_action_layer else _extract_screen_pixels(
- pixels)
- return dm_env.TimeStep(
- step_type=step_type,
- reward=reward,
- discount=discount,
- observation=pixels)
- else:
- return timestep
-
- def reset(self) -> dm_env.TimeStep:
- timestep = self._env.reset()
- return self._process_timestep(timestep)
-
- def step(self, action: int) -> dm_env.TimeStep:
- timestep = self._env.step(self._process_action(action))
- return self._process_timestep(timestep)
-
- def observation_spec(self) -> specs.Array | dict[str, specs.Array]: # pytype: disable=signature-mismatch # overriding-return-type-checks
- if self._flat_observations:
- pixels_spec = self._env.observation_spec()['pixels']
- if not self._keep_action_layer:
- return _get_no_action_observation_spec(pixels_spec)
- return pixels_spec
- else:
- return self._env.observation_spec()
-
- def action_spec(self) -> specs.BoundedArray | dict[str, specs.Array]: # pytype: disable=signature-mismatch # overriding-return-type-checks
- if self._flat_actions:
- return self._env.action_spec()[self._action_name]
- else:
- return self._env.action_spec()
diff --git a/android_env/wrappers/flat_interface_wrapper_test.py b/android_env/wrappers/flat_interface_wrapper_test.py
deleted file mode 100644
index a650054b..00000000
--- a/android_env/wrappers/flat_interface_wrapper_test.py
+++ /dev/null
@@ -1,166 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.wrappers.flat_interface_wrapper."""
-
-from typing import cast
-from unittest import mock
-
-from absl.testing import absltest
-from android_env.wrappers import flat_interface_wrapper
-import dm_env
-from dm_env import specs
-import numpy as np
-
-
-def _make_array_spec(shape, dtype=np.float32, name=None, maximum=3, minimum=0):
- return specs.BoundedArray(
- shape=shape,
- dtype=dtype,
- name=name,
- maximum=np.ones(shape) * maximum,
- minimum=np.ones(shape) * minimum)
-
-
-def _make_timestep(observation):
- return dm_env.TimeStep(
- step_type='fake_step_type',
- reward='fake_reward',
- discount='fake_discount',
- observation=observation,
- )
-
-
-class FlatInterfaceWrapperTest(absltest.TestCase):
-
- def setUp(self):
- super().setUp()
- self.action_shape = (1,)
- self.base_action_spec: dict[str, specs.DiscreteArray] = {
- 'action_id': specs.DiscreteArray(name='action_id', num_values=4)
- }
- self.int_obs_shape = (3, 4, 2)
- self.float_obs_shape = (2,)
- self.base_observation_spec = {
- 'pixels': _make_array_spec(
- shape=self.int_obs_shape, dtype=np.uint8, name='pixels'),
- 'obs1': _make_array_spec(
- shape=self.float_obs_shape, dtype=np.float32, name='obs1'),
- }
- # Expected.
- self.expected_observation_spec = _make_array_spec(
- shape=self.int_obs_shape, dtype=np.uint8, name='pixels')
- self.image_obs = np.ones(self.int_obs_shape, dtype=np.uint8)
- self.expected_timestep = _make_timestep(self.image_obs)
-
- # Expected for no new action layer shape.
- expected_new_shape_no_action_layer = (3, 4, 1)
- self.expected_observation_spec_no_action_layer = _make_array_spec(
- shape=expected_new_shape_no_action_layer, dtype=np.uint8, name='pixels')
- self.expected_timestep_no_action_layer = _make_timestep(
- np.ones(expected_new_shape_no_action_layer, dtype=np.uint8))
-
- # Base environment.
- self.other_obs = np.ones(self.float_obs_shape, dtype=np.float32)
- self.base_timestep = _make_timestep({
- 'pixels': self.image_obs,
- 'obs1': self.other_obs})
- self.base_env = mock.create_autospec(dm_env.Environment)
- self.base_env.action_spec.return_value = self.base_action_spec
- self.base_env.observation_spec.return_value = self.base_observation_spec
- self.base_env.reset.return_value = self.base_timestep
- self.base_env.step.return_value = self.base_timestep
-
- def test_reset(self):
- wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env)
- ts = wrapped_env.reset()
- self.base_env.reset.assert_called_once()
- self.assertEqual(self.expected_timestep, ts)
-
- def test_reset_no_action_layer(self):
- wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(
- self.base_env, keep_action_layer=False)
- ts = wrapped_env.reset()
- self.base_env.reset.assert_called_once()
- self.assertEqual(
- self.expected_timestep_no_action_layer.observation.tolist(),
- ts.observation.tolist())
-
- def test_step(self):
- wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env)
- action = 2
- ts = wrapped_env.step(action)
-
- def verifier(x):
- self.assertIsInstance(x, dict)
- self.assertIsInstance(x['action_id'], int)
- self.assertEqual(x['action_id'], action)
- return True
- verifier(self.base_env.step.call_args[0][0])
- self.assertEqual(self.expected_timestep, ts)
-
- def test_step_no_action_layer(self):
- wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(
- self.base_env, keep_action_layer=False)
- action = 2
- ts = wrapped_env.step(action)
-
- def verifier(x):
- self.assertIsInstance(x, dict)
- self.assertIsInstance(x['action_id'], int)
- self.assertEqual(x['action_id'], action)
- return True
-
- verifier(self.base_env.step.call_args[0][0])
- self.assertEqual(
- self.expected_timestep_no_action_layer.observation.tolist(),
- ts.observation.tolist())
-
- def test_observation_spec(self):
- wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env)
- observation_spec = wrapped_env.observation_spec()
- self.base_env.observation_spec.assert_called_once()
- self.assertEqual(self.expected_observation_spec, observation_spec)
-
- def test_observation_spec_no_action_layer(self):
- wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(
- self.base_env, keep_action_layer=False)
- observation_spec = wrapped_env.observation_spec()
- self.base_env.observation_spec.assert_called_once()
- self.assertEqual(self.expected_observation_spec_no_action_layer,
- observation_spec)
-
- def test_action_spec(self):
- wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env)
- action_spec = cast(specs.BoundedArray, wrapped_env.action_spec())
- parent_action_spec = self.base_action_spec['action_id']
-
- self.assertEqual(parent_action_spec.name, action_spec.name)
- self.assertEqual((), action_spec.shape)
- self.assertEqual(np.int32, action_spec.dtype)
- self.assertEqual(0, action_spec.minimum)
-
- def test_bad_action_spec_structured_action(self):
- bad_base_env = mock.create_autospec(dm_env.Environment)
- bad_base_env.action_spec.return_value = {
- 'action_id': _make_array_spec((1,)),
- 'too_many': _make_array_spec((1,))
- }
- with self.assertRaises(AssertionError):
- _ = flat_interface_wrapper.FlatInterfaceWrapper(bad_base_env)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/wrappers/float_pixels_wrapper.py b/android_env/wrappers/float_pixels_wrapper.py
deleted file mode 100644
index 43def1f6..00000000
--- a/android_env/wrappers/float_pixels_wrapper.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Converts pixel observation to from int to float32 between 0.0 and 1.0."""
-
-from android_env.components import pixel_fns
-from android_env.wrappers import base_wrapper
-import dm_env
-from dm_env import specs
-import numpy as np
-
-
-class FloatPixelsWrapper(base_wrapper.BaseWrapper):
- """Wraps AndroidEnv for Panultimate agent."""
-
- def __init__(self, env: dm_env.Environment):
- super().__init__(env)
- self._input_spec = self._env.observation_spec()['pixels']
- self._should_convert_int_to_float = np.issubdtype(self._input_spec.dtype,
- np.integer)
-
- def _process_observation(
- self, observation: dict[str, np.ndarray]
- ) -> dict[str, np.ndarray]:
- if self._should_convert_int_to_float:
- float_pixels = pixel_fns.convert_int_to_float(
- observation['pixels'], self._input_spec
- )
- observation['pixels'] = float_pixels
- return observation
-
- def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
- step_type, reward, discount, observation = timestep
- return dm_env.TimeStep(
- step_type=step_type,
- reward=reward,
- discount=discount,
- observation=self._process_observation(observation))
-
- def reset(self) -> dm_env.TimeStep:
- return self._process_timestep(self._env.reset())
-
- def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
- return self._process_timestep(self._env.step(action))
-
- def observation_spec(self) -> dict[str, specs.Array]:
- if self._should_convert_int_to_float:
- observation_spec = self._env.observation_spec()
- observation_spec['pixels'] = specs.BoundedArray(
- shape=self._env.observation_spec()['pixels'].shape,
- dtype=np.float32,
- minimum=0.0,
- maximum=1.0,
- name=self._env.observation_spec()['pixels'].name)
- return observation_spec
- return self._env.observation_spec()
diff --git a/android_env/wrappers/float_pixels_wrapper_test.py b/android_env/wrappers/float_pixels_wrapper_test.py
deleted file mode 100644
index 5868e75b..00000000
--- a/android_env/wrappers/float_pixels_wrapper_test.py
+++ /dev/null
@@ -1,149 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.wrappers.float_pixels_wrapper."""
-
-from unittest import mock
-
-from absl.testing import absltest
-from android_env.wrappers import float_pixels_wrapper
-import dm_env
-from dm_env import specs
-import numpy as np
-
-
-def _make_array_spec(shape, dtype=np.float32, name=None):
- return specs.Array(
- shape=shape,
- dtype=dtype,
- name=name,
- )
-
-
-def _make_bounded_array_spec(
- shape, dtype=np.float32, name=None, maximum=1.0, minimum=0.0):
- return specs.BoundedArray(
- shape=shape,
- dtype=dtype,
- name=name,
- maximum=maximum,
- minimum=minimum,
- )
-
-
-def _simple_timestep(obs_shape, obs_type):
- return dm_env.TimeStep(
- step_type=dm_env.StepType.MID,
- reward=3.14,
- discount=0.9,
- observation=(np.ones(shape=obs_shape, dtype=obs_type),),
- )
-
-
-class FloatPixelsWrapperTest(absltest.TestCase):
-
- def setUp(self):
- super().setUp()
- self.pixels_shape = (3, 4)
- base_pixel_spec = _make_array_spec(
- shape=self.pixels_shape, dtype=np.uint8, name='pixels')
- self.other_obs_spec = _make_array_spec(
- shape=(1,), dtype=np.float32, name='other_obs')
- base_observation_spec = {
- 'pixels': base_pixel_spec,
- 'other_obs': self.other_obs_spec
- }
- self.base_env = mock.create_autospec(dm_env.Environment)
- self.base_env.observation_spec.return_value = base_observation_spec
-
- self.base_timestep = dm_env.TimeStep(
- step_type=dm_env.StepType.MID,
- reward=3.14,
- discount=0.9,
- observation={
- 'pixels': np.ones(shape=self.pixels_shape, dtype=np.uint8),
- 'other_obs': [42.2]})
- self.base_env.step.return_value = self.base_timestep
- self.base_env.reset.return_value = self.base_timestep
-
- def test_float_pixels_wrapper_spec(self):
- expected_pixel_spec = _make_bounded_array_spec(
- shape=self.pixels_shape,
- dtype=np.float32,
- name='pixels',
- minimum=0.0,
- maximum=1.0)
-
- wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(self.base_env)
-
- self.assertLen(wrapped_env.observation_spec(), 2)
- self.assertEqual(expected_pixel_spec,
- wrapped_env.observation_spec()['pixels'])
- self.assertEqual(self.other_obs_spec,
- wrapped_env.observation_spec()['other_obs'])
-
- def test_float_pixels_wrapper_step(self):
- wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(self.base_env)
- ts = wrapped_env.step({'fake_action': np.array([1, 2, 3])})
-
- self.assertEqual(self.base_timestep.step_type, ts.step_type)
- self.assertEqual(self.base_timestep.reward, ts.reward)
- self.assertEqual(self.base_timestep.discount, ts.discount)
- self.assertEqual(self.base_timestep.observation['other_obs'],
- ts.observation['other_obs'])
- expected_pixel_value = 1. / 255. # original values are unit8
- expected_pixels = np.ones(
- self.pixels_shape, dtype=np.float32) * expected_pixel_value
- np.testing.assert_equal(expected_pixels, ts.observation['pixels'])
-
- def test_float_pixels_wrapper_reset(self):
- wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(self.base_env)
- ts = wrapped_env.reset()
-
- self.assertEqual(self.base_timestep.step_type, ts.step_type)
- self.assertEqual(self.base_timestep.reward, ts.reward)
- self.assertEqual(self.base_timestep.discount, ts.discount)
- self.assertEqual(self.base_timestep.observation['other_obs'],
- ts.observation['other_obs'])
- expected_pixel_value = 1. / 255. # original values are unit8
- expected_pixels = np.ones(
- self.pixels_shape, dtype=np.float32) * expected_pixel_value
- np.testing.assert_equal(expected_pixels, ts.observation['pixels'])
-
- def test_float_pixels_wrapper_already_float(self):
- base_pixel_spec = _make_array_spec(
- shape=self.pixels_shape, dtype=np.float64, name='pixels')
- base_observation_spec = {
- 'pixels': base_pixel_spec,
- 'other_obs': self.other_obs_spec
- }
- base_env = mock.create_autospec(dm_env.Environment)
- base_env.observation_spec.return_value = base_observation_spec
-
- wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(base_env)
-
- # If the pixels are already float values, then obs_spec does not change.
- self.assertEqual(base_env.observation_spec(),
- wrapped_env.observation_spec())
-
- # The wrapper should not touch the timestep in this case.
- fake_timestep = ('step_type', 'reward', 'discount', 'obs')
- base_env.step.return_value = fake_timestep
- ts = wrapped_env.step({'fake_action': np.array([1, 2, 3])})
- self.assertEqual(fake_timestep, ts)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/wrappers/gym_wrapper.py b/android_env/wrappers/gym_wrapper.py
deleted file mode 100644
index 41ff44ed..00000000
--- a/android_env/wrappers/gym_wrapper.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Wraps the AndroidEnv to expose an OpenAI Gym interface."""
-
-from typing import Any
-
-from android_env.wrappers import base_wrapper
-import dm_env
-from dm_env import specs
-import gym
-from gym import spaces
-import numpy as np
-
-
-class GymInterfaceWrapper(gym.Env):
- """AndroidEnv with OpenAI Gym interface."""
-
- def __init__(self, env: dm_env.Environment):
- self._env = env
- self.spec = None
- self.action_space = self._spec_to_space(self._env.action_spec())
- self.observation_space = self._spec_to_space(self._env.observation_spec())
- self.metadata = {'render.modes': ['rgb_array']}
- self._latest_observation = None
-
- def _spec_to_space(self, spec: specs.Array) -> spaces.Space:
- """Converts dm_env specs to OpenAI Gym spaces."""
-
- if isinstance(spec, list):
- return spaces.Tuple([self._spec_to_space(s) for s in spec])
-
- if isinstance(spec, dict):
- return spaces.Dict(
- {name: self._spec_to_space(s) for name, s in spec.items()}
- )
-
- if isinstance(spec, specs.DiscreteArray):
- return spaces.Box(
- shape=(),
- dtype=spec.dtype,
- low=0,
- high=spec.num_values-1)
-
- if isinstance(spec, specs.BoundedArray):
- return spaces.Box(
- shape=spec.shape,
- dtype=spec.dtype,
- low=spec.minimum,
- high=spec.maximum)
-
- if isinstance(spec, specs.Array):
- if spec.dtype == np.uint8:
- low = 0
- high = 255
- else:
- low = -np.inf
- high = np.inf
- return spaces.Box(shape=spec.shape, dtype=spec.dtype, low=low, high=high)
-
- raise ValueError('Unknown type for specs: {}'.format(spec))
-
- def render(self, mode='rgb_array'):
- """Renders the environment."""
- if mode == 'rgb_array':
- if self._latest_observation is None:
- return
-
- return self._latest_observation['pixels']
- else:
- raise ValueError('Only supported render mode is rgb_array.')
-
- def reset(self) -> np.ndarray:
- self._latest_observation = None
- timestep = self._env.reset()
- return timestep.observation
-
- def step(self, action: dict[str, int]) -> tuple[Any, ...]:
- """Take a step in the base environment."""
- timestep = self._env.step(action)
- observation = timestep.observation
- self._latest_observation = observation
- reward = timestep.reward
- done = timestep.step_type == dm_env.StepType.LAST
- info = {'discount': timestep.discount}
- return observation, reward, done, info
diff --git a/android_env/wrappers/gym_wrapper_test.py b/android_env/wrappers/gym_wrapper_test.py
deleted file mode 100644
index 9c266699..00000000
--- a/android_env/wrappers/gym_wrapper_test.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.wrappers.gym_wrapper."""
-
-from unittest import mock
-
-from absl.testing import absltest
-from android_env import env_interface
-from android_env.wrappers import gym_wrapper
-import dm_env
-from dm_env import specs
-from gym import spaces
-import numpy as np
-
-
-class GymInterfaceWrapperTest(absltest.TestCase):
-
- def setUp(self):
- super().setUp()
- self._base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
- self._base_env.action_spec.return_value = {
- 'action_type':
- specs.DiscreteArray(
- num_values=3,
- name='action_type'),
- 'touch_position':
- specs.BoundedArray(
- shape=(2,),
- dtype=np.float32,
- minimum=[0.0, 0.0],
- maximum=[1.0, 1.0],
- name='touch_position'),
- }
- self._base_env.observation_spec.return_value = {
- 'pixels':
- specs.Array(
- shape=(480, 320, 3),
- dtype=np.uint8,
- name='pixels'),
- 'timedelta':
- specs.Array(shape=(), dtype=np.int64, name='timedelta'),
- 'orientation':
- specs.Array(
- shape=np.array([4]),
- dtype=np.uint8,
- name='orientation'),
- }
- self._wrapped_env = gym_wrapper.GymInterfaceWrapper(self._base_env)
- self._fake_ts = dm_env.TimeStep(
- step_type=dm_env.StepType.MID,
- observation={'pixels': np.ones(shape=(2, 3))},
- reward=10.0,
- discount=1.0)
-
- def test_render(self):
- self._base_env.step.return_value = self._fake_ts
- _ = self._wrapped_env.step(action=np.zeros(shape=(1,)))
- image = self._wrapped_env.render(mode='rgb_array')
- self.assertTrue(np.array_equal(image, np.ones(shape=(2, 3))))
-
- def test_render_error(self):
- with self.assertRaises(ValueError):
- _ = self._wrapped_env.render(mode='human')
-
- def test_reset(self):
- self._base_env.reset.return_value = dm_env.TimeStep(
- step_type=dm_env.StepType.FIRST,
- observation={'pixels': np.ones(shape=(2, 3))},
- reward=10.0,
- discount=1.0)
- obs = self._wrapped_env.reset()
- self._base_env.reset.assert_called_once()
- self.assertTrue(np.array_equal(obs['pixels'], np.ones(shape=(2, 3))))
-
- def test_step(self):
- self._base_env.step.return_value = self._fake_ts
- obs, _, _, _ = self._wrapped_env.step(action=np.zeros(shape=(1,)))
- self._base_env.step.assert_called_once()
- self.assertTrue(np.array_equal(obs['pixels'], np.ones(shape=(2, 3))))
-
- def test_spec_to_space(self):
-
- spec = specs.Array(
- shape=(2, 3),
- dtype=np.float32)
- space = self._wrapped_env._spec_to_space(spec)
- self.assertEqual(space, spaces.Box(
- low=-np.inf, high=np.inf, shape=spec.shape, dtype=spec.dtype))
-
- spec = specs.BoundedArray(
- shape=(),
- dtype=np.float32,
- minimum=4,
- maximum=5)
- space = self._wrapped_env._spec_to_space(spec)
- self.assertEqual(space, spaces.Box(
- low=4, high=5, shape=spec.shape, dtype=spec.dtype))
-
- spec = specs.DiscreteArray(num_values=4)
- space = self._wrapped_env._spec_to_space(spec)
- self.assertEqual(space, spaces.Box(
- low=0, high=3, shape=(), dtype=np.int32))
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/wrappers/image_rescale_wrapper.py b/android_env/wrappers/image_rescale_wrapper.py
deleted file mode 100644
index 6faeebca..00000000
--- a/android_env/wrappers/image_rescale_wrapper.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Wraps the AndroidEnv environment to rescale the observations."""
-
-from collections.abc import Sequence
-
-from android_env.wrappers import base_wrapper
-import dm_env
-from dm_env import specs
-import numpy as np
-from PIL import Image
-
-
-# Taken from https://pillow.readthedocs.io/en/3.2.x/reference/Image.html#PIL.Image.Image.convert
-#
-# This array maps an RGB image to a grayscale image using the ITU-R 709
-# specification which is good for computer displays and HDTV.
-RGB_TO_GRAYSCALE_COEFFICIENTS = [0.2126, 0.7152, 0.0722]
-
-
-class ImageRescaleWrapper(base_wrapper.BaseWrapper):
- """AndroidEnv with rescaled observations."""
-
- def __init__(
- self,
- env: dm_env.Environment,
- zoom_factors: Sequence[float] | None = (0.5, 0.5),
- grayscale: bool = False,
- ):
- super().__init__(env)
- assert 'pixels' in self._env.observation_spec()
- assert self._env.observation_spec()['pixels'].shape[-1] in [1, 3], (
- 'Number of pixel channels should be 1 or 3.')
- self._grayscale = grayscale
- if zoom_factors is None:
- zoom_factors = (1.0, 1.0)
- # We only zoom the width and height of each layer, and we explicitly do not
- # want to zoom the number of channels so we just multiply it by 1.0.
- self._zoom_factors = tuple(zoom_factors) + (1.0,)
-
- def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
- observation = timestep.observation
- processed_observation = observation.copy()
- processed_observation['pixels'] = self._process_pixels(
- observation['pixels'])
- return timestep._replace(observation=processed_observation)
-
- def _process_pixels(self, raw_observation: np.ndarray) -> np.ndarray:
- # We expect `raw_observation` to have shape (W, H, 3) - 3 for RGB
- new_shape = np.array(
- self._zoom_factors[0:2] * np.array(raw_observation.shape[0:2]),
- dtype=np.int32)[::-1]
- if self._grayscale:
- # When self._grayscale == True, we squash the RGB into a single layer
- image = np.dot(raw_observation, RGB_TO_GRAYSCALE_COEFFICIENTS)
- else:
- image = raw_observation
- return self._resize_image_array(image, new_shape)
-
- def _resize_image_array(
- self, grayscale_or_rbg_array: np.ndarray, new_shape: np.ndarray
- ) -> np.ndarray:
- """Resize color or grayscale/action_layer array to new_shape."""
- assert new_shape.ndim == 1
- assert len(new_shape) == 2
- resized_array = np.array(
- Image.fromarray(grayscale_or_rbg_array.astype('uint8')).resize(
- tuple(new_shape)
- )
- )
- if resized_array.ndim == 2:
- return np.expand_dims(resized_array, axis=-1)
- return resized_array
-
- def reset(self) -> dm_env.TimeStep:
- timestep = self._env.reset()
- return self._process_timestep(timestep)
-
- def step(self, action) -> dm_env.TimeStep:
- timestep = self._env.step(action)
- return self._process_timestep(timestep)
-
- def observation_spec(self) -> dict[str, specs.Array]:
- parent_spec = self._env.observation_spec().copy()
- out_shape = np.multiply(parent_spec['pixels'].shape,
- self._zoom_factors).astype(np.int32)
- if self._grayscale:
- # In grayscale mode we want the output shape to be [W, H, 1]
- out_shape[-1] = 1
- parent_spec['pixels'] = specs.BoundedArray(
- shape=out_shape,
- dtype=parent_spec['pixels'].dtype,
- name=parent_spec['pixels'].name,
- minimum=parent_spec['pixels'].minimum,
- maximum=parent_spec['pixels'].maximum)
- return parent_spec
diff --git a/android_env/wrappers/image_rescale_wrapper_test.py b/android_env/wrappers/image_rescale_wrapper_test.py
deleted file mode 100644
index c55101f1..00000000
--- a/android_env/wrappers/image_rescale_wrapper_test.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.wrappers.image_rescale_wrapper."""
-
-from typing import Any
-from unittest import mock
-
-from absl.testing import absltest
-from android_env import env_interface
-from android_env.wrappers import image_rescale_wrapper
-import dm_env
-from dm_env import specs
-import numpy as np
-
-
-def _simple_spec():
- return specs.BoundedArray(
- shape=np.array([300, 300, 3]),
- dtype=np.uint8,
- name='pixels',
- minimum=0,
- maximum=255)
-
-
-def _simple_timestep():
- observation = np.ones(shape=[300, 300, 3])
- return dm_env.TimeStep(
- step_type=dm_env.StepType.MID,
- reward=3.14,
- discount=0.9,
- observation={'pixels': observation})
-
-
-class ImageRescaleWrapperTest(absltest.TestCase):
-
- def test_100x50_grayscale(self):
- fake_timestep = _simple_timestep()
- fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
- fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
- fake_env.reset.return_value = fake_timestep
- fake_env.step.return_value = fake_timestep
-
- wrapper = image_rescale_wrapper.ImageRescaleWrapper(
- fake_env, zoom_factors=(1.0 / 3, 1.0 / 6.0), grayscale=True)
- self.assertIsNotNone(wrapper)
- self.assertEqual(wrapper.observation_spec()['pixels'].shape, (100, 50, 1))
- reset_timestep = wrapper.reset()
- reset_image = reset_timestep.observation['pixels']
- self.assertEqual(reset_image.shape, (100, 50, 1))
- step_timestep = wrapper.step(action='fake_action')
- step_image = step_timestep.observation['pixels']
- self.assertEqual(step_image.shape, (100, 50, 1))
-
- def test_150x60_full_channels(self):
- fake_timestep = _simple_timestep()
- fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
- fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
- fake_env.reset.return_value = fake_timestep
- fake_env.step.return_value = fake_timestep
-
- wrapper = image_rescale_wrapper.ImageRescaleWrapper(
- fake_env, zoom_factors=(1.0 / 2.0, 1.0 / 5.0))
- self.assertIsNotNone(wrapper)
- self.assertEqual(wrapper.observation_spec()['pixels'].shape, (150, 60, 3))
- reset_timestep = wrapper.reset()
- reset_image = reset_timestep.observation['pixels']
- self.assertEqual(reset_image.shape, (150, 60, 3))
- step_timestep = wrapper.step(action='fake_action')
- step_image = step_timestep.observation['pixels']
- self.assertEqual(step_image.shape, (150, 60, 3))
-
- def test_list_zoom_factor(self):
- fake_timestep = _simple_timestep()
- fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
- fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
- fake_env.reset.return_value = fake_timestep
- fake_env.step.return_value = fake_timestep
-
- wrapper = image_rescale_wrapper.ImageRescaleWrapper(
- fake_env, zoom_factors=[0.5, 0.2])
- self.assertIsNotNone(wrapper)
- self.assertEqual(wrapper.observation_spec()['pixels'].shape, (150, 60, 3))
- reset_timestep = wrapper.reset()
- reset_image = reset_timestep.observation['pixels']
- self.assertEqual(reset_image.shape, (150, 60, 3))
- step_timestep = wrapper.step(action='fake_action')
- step_image = step_timestep.observation['pixels']
- self.assertEqual(step_image.shape, (150, 60, 3))
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/wrappers/last_action_wrapper.py b/android_env/wrappers/last_action_wrapper.py
deleted file mode 100644
index a09633c6..00000000
--- a/android_env/wrappers/last_action_wrapper.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Extends Android observation with the latest action taken."""
-
-from android_env.components import action_type
-from android_env.components import pixel_fns
-from android_env.wrappers import base_wrapper
-import dm_env
-from dm_env import specs
-import numpy as np
-
-
-class LastActionWrapper(base_wrapper.BaseWrapper):
- """Extends Android observations with information about the last action taken.
-
- The position of the last action is denoted by a single white pixel (with a
- value of 255) in a channel of all black pixels (with a value of 0).
- As this wrapper makes use of temporarily stored information about the
- last action taken, it is important to apply on the environment side rather
- than the agent side. Recommended not to apply before an ImageRescaleWrapper,
- to avoid distortion of the single pixel denoting the action position.
- """
-
- def __init__(self,
- env: dm_env.Environment,
- concat_to_pixels: bool = True):
- """Initializes the internal state of this wrapper.
-
- Args:
- env: the environment to wrap.
- concat_to_pixels: If True, will add a channel to the pixel observation.
- If False, will pass the action as an extra observation.
- """
- super().__init__(env)
- self._concat_to_pixels = concat_to_pixels
- self._screen_dimensions = self._env.observation_spec()['pixels'].shape[:2]
-
- def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
- observation = timestep.observation.copy()
- processed_observation = self._process_observation(observation)
- return timestep._replace(observation=processed_observation)
-
- def _process_observation(
- self, observation: dict[str, np.ndarray]
- ) -> dict[str, np.ndarray]:
- """Extends observation with last_action data."""
- processed_observation = observation.copy()
- last_action_layer = self._get_last_action_layer(observation['pixels'])
- if self._concat_to_pixels:
- pixels = observation['pixels'].copy()
- processed_pixels = np.dstack((pixels, last_action_layer))
- processed_observation['pixels'] = processed_pixels
- else:
- processed_observation['last_action'] = last_action_layer
- return processed_observation
-
- def _get_last_action_layer(self, pixels: np.ndarray) -> np.ndarray:
- """Makes sure the rescaling doesn't distort the last_action layer."""
-
- last_action = self._env.raw_action
- last_action_layer = np.zeros(self._screen_dimensions, dtype=pixels.dtype)
-
- if ('action_type' in last_action and
- last_action['action_type'] == action_type.ActionType.TOUCH):
- touch_position = last_action['touch_position']
- x, y = pixel_fns.touch_position_to_pixel_position(
- touch_position, width_height=self._screen_dimensions[::-1]
- )
- last_action_layer[y, x] = 255
-
- return last_action_layer
-
- def reset(self) -> dm_env.TimeStep:
- timestep = self._env.reset()
- return self._process_timestep(timestep)
-
- def step(self, action) -> dm_env.TimeStep:
- timestep = self._env.step(action)
- return self._process_timestep(timestep)
-
- def observation_spec(self) -> dict[str, specs.Array]:
- parent_spec = self._env.observation_spec().copy()
- shape = parent_spec['pixels'].shape
- if self._concat_to_pixels:
- parent_spec['pixels'] = specs.BoundedArray(
- shape=(shape[0], shape[1], shape[2] + 1),
- dtype=parent_spec['pixels'].dtype,
- name=parent_spec['pixels'].name,
- minimum=parent_spec['pixels'].minimum,
- maximum=parent_spec['pixels'].maximum)
- else:
- parent_spec.update({
- 'last_action':
- specs.BoundedArray(
- shape=(shape[0], shape[1]),
- dtype=parent_spec['pixels'].dtype,
- name='last_action',
- minimum=parent_spec['pixels'].minimum,
- maximum=parent_spec['pixels'].maximum)
- })
- return parent_spec
diff --git a/android_env/wrappers/last_action_wrapper_test.py b/android_env/wrappers/last_action_wrapper_test.py
deleted file mode 100644
index 77307c00..00000000
--- a/android_env/wrappers/last_action_wrapper_test.py
+++ /dev/null
@@ -1,162 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for android_env.wrappers.last_action_wrapper."""
-
-from typing import Any
-from unittest import mock
-
-from absl.testing import absltest
-from android_env import env_interface
-from android_env.components import action_type
-from android_env.wrappers import last_action_wrapper
-import dm_env
-from dm_env import specs
-import numpy as np
-
-
-def _simple_spec():
- return specs.BoundedArray(
- shape=np.array([120, 80, 3]),
- dtype=np.uint8,
- name='pixels',
- minimum=0,
- maximum=255)
-
-
-def _simple_timestep():
- observation = np.ones(shape=[120, 80, 3])
- return dm_env.TimeStep(
- step_type=dm_env.StepType.MID,
- reward=3.14,
- discount=0.9,
- observation={'pixels': observation})
-
-
-class LastActionWrapperTest(absltest.TestCase):
-
- def test_concat_to_pixels(self):
- fake_timestep = _simple_timestep()
- fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
- fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
- fake_env.reset.return_value = fake_timestep
- fake_env.step.return_value = fake_timestep
-
- wrapper = last_action_wrapper.LastActionWrapper(
- fake_env, concat_to_pixels=True)
- self.assertIsNotNone(wrapper)
- self.assertEqual(wrapper.observation_spec()['pixels'].shape, (120, 80, 4))
-
- reset_timestep = wrapper.reset()
- reset_image = reset_timestep.observation['pixels']
- self.assertEqual(reset_image.shape, (120, 80, 4))
- last_action_layer = reset_image[:, :, -1]
- self.assertEqual(np.sum(last_action_layer), 0)
-
- action1 = {
- 'action_type': action_type.ActionType.TOUCH,
- 'touch_position': np.array([0.25, 0.75], dtype=np.float32), # (W x H)
- }
- type(fake_env).raw_action = mock.PropertyMock(return_value=action1)
- step_timestep = wrapper.step(action=action1)
- step_image = step_timestep.observation['pixels']
- self.assertEqual(step_image.shape, (120, 80, 4)) # (H x W)
- last_action_layer = step_image[:, :, -1]
- self.assertEqual(np.sum(last_action_layer), 255)
- y, x = np.where(last_action_layer == 255)
- self.assertEqual((y.item(), x.item()), (90, 20))
-
- action2 = {
- 'action_type': action_type.ActionType.LIFT,
- 'touch_position': np.array([0.25, 0.75], dtype=np.float32),
- }
- type(fake_env).raw_action = mock.PropertyMock(return_value=action2)
- step_timestep = wrapper.step(action=action2)
- step_image = step_timestep.observation['pixels']
- self.assertEqual(step_image.shape, (120, 80, 4))
- last_action_layer = step_image[:, :, -1]
- self.assertEqual(np.sum(last_action_layer), 0)
-
- action3 = {
- 'action_type': action_type.ActionType.TOUCH,
- 'touch_position': np.array([0.25, 1.0], dtype=np.float32),
- }
- type(fake_env).raw_action = mock.PropertyMock(return_value=action3)
- step_timestep = wrapper.step(action=action3)
- step_image = step_timestep.observation['pixels']
- self.assertEqual(step_image.shape, (120, 80, 4))
- last_action_layer = step_image[:, :, -1]
- self.assertEqual(np.sum(last_action_layer), 255)
- y, x = np.where(last_action_layer == 255)
- self.assertEqual((y.item(), x.item()), (119, 20))
-
- def test_no_concat_to_pixels(self):
- fake_timestep = _simple_timestep()
- fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
- fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
- fake_env.reset.return_value = fake_timestep
- fake_env.step.return_value = fake_timestep
-
- wrapper = last_action_wrapper.LastActionWrapper(
- fake_env, concat_to_pixels=False)
- self.assertIsNotNone(wrapper)
- self.assertEqual(wrapper.observation_spec()['pixels'].shape, (120, 80, 3))
- self.assertEqual(wrapper.observation_spec()['last_action'].shape, (120, 80))
-
- reset_timestep = wrapper.reset()
- reset_image = reset_timestep.observation['pixels']
- self.assertEqual(reset_image.shape, (120, 80, 3))
- last_action_layer = reset_timestep.observation['last_action']
- self.assertEqual(np.sum(last_action_layer), 0)
-
- action1 = {
- 'action_type': action_type.ActionType.TOUCH,
- 'touch_position': np.array([0.25, 0.75], dtype=np.float32),
- }
- type(fake_env).raw_action = mock.PropertyMock(return_value=action1)
- step_timestep = wrapper.step(action=action1)
- step_image = step_timestep.observation['pixels']
- self.assertEqual(step_image.shape, (120, 80, 3))
- last_action_layer = step_timestep.observation['last_action']
- self.assertEqual(np.sum(last_action_layer), 255)
- y, x = np.where(last_action_layer == 255)
- self.assertEqual((y.item(), x.item()), (90, 20))
-
- action2 = {
- 'action_type': action_type.ActionType.LIFT,
- 'touch_position': np.array([0.25, 0.75], dtype=np.float32),
- }
- type(fake_env).raw_action = mock.PropertyMock(return_value=action2)
- step_timestep = wrapper.step(action=action2)
- step_image = step_timestep.observation['pixels']
- self.assertEqual(step_image.shape, (120, 80, 3))
- last_action_layer = step_timestep.observation['last_action']
- self.assertEqual(np.sum(last_action_layer), 0)
-
- action3 = {
- 'action_type': action_type.ActionType.TOUCH,
- 'touch_position': np.array([1.0, 0.75], dtype=np.float32),
- }
- type(fake_env).raw_action = mock.PropertyMock(return_value=action3)
- step_timestep = wrapper.step(action=action3)
- step_image = step_timestep.observation['pixels']
- self.assertEqual(step_image.shape, (120, 80, 3))
- last_action_layer = step_timestep.observation['last_action']
- self.assertEqual(np.sum(last_action_layer), 255)
- y, x = np.where(last_action_layer == 255)
- self.assertEqual((y.item(), x.item()), (90, 79))
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/wrappers/rate_limit_wrapper.py b/android_env/wrappers/rate_limit_wrapper.py
deleted file mode 100644
index 2439be5b..00000000
--- a/android_env/wrappers/rate_limit_wrapper.py
+++ /dev/null
@@ -1,117 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Limits interactions with the environment to a given rate."""
-
-import enum
-import time
-
-from android_env import env_interface
-from android_env.components import action_type
-from android_env.wrappers import base_wrapper
-import dm_env
-import numpy as np
-
-
-class RateLimitWrapper(base_wrapper.BaseWrapper):
- """Limits interactions with the environment to a given rate."""
-
- class SleepType(enum.IntEnum):
- """Determines how the wrapper interacts with the underlying environment."""
-
- # The wrapper sleeps before calling `step()` on the underlying environment.
- BEFORE = 0
-
- # The wrapper sleeps after calling `step()` on the underlying environment.
- AFTER = 1
-
- # The wrapper first calls `step()`, obtaining a TimeStep which is ignored,
- # then it sleeps, and then it calls `step(REPEAT)` to obtain a TimeStep
- # that's as fresh as possible.
- #
- # Note that for both BEFORE and AFTER_WITH_REPEAT, the _total_ amount of
- # time inside this wrapper may go beyond the rate specified in `rate`
- # because the sleep does not account for the time taken by step().
- AFTER_WITH_REPEAT = 2
-
- def __init__(self,
- env: env_interface.AndroidEnvInterface,
- rate: float,
- sleep_type: SleepType = SleepType.AFTER_WITH_REPEAT):
- """Initializes this wrapper.
-
- Args:
- env: The underlying environment to which this wrapper is applied.
- rate: The desired rate in Hz to interact with the environment. If <=0.0,
- this wrapper will be disabled.
- sleep_type: This determines how the wrapper will interact with the
- underlying AndroidEnv environment.
- """
- super().__init__(env)
- self._assert_base_env()
- self._last_step_time = None
- self._max_wait = 1.0 / rate if rate > 0.0 else 0.0
- self._sleep_type = sleep_type
-
- def _assert_base_env(self):
- """Checks that the wrapped env has the right action spec format."""
- parent_action_spec = self._env.action_spec()
- assert len(parent_action_spec) == 2
- assert not parent_action_spec['action_type'].shape
- assert parent_action_spec['touch_position'].shape == (2,)
-
- def reset(self):
- timestep = self._env.reset()
- self._last_step_time = time.time()
- return timestep
-
- def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
- """Takes a step while maintaining a steady interaction rate."""
-
- # If max_wait is non-positive, the wrapper has no effect.
- if self._max_wait <= 0.0:
- return self._env.step(action)
-
- if self._sleep_type == RateLimitWrapper.SleepType.BEFORE:
- self._wait()
-
- timestep = self._env.step(action)
- if timestep.last():
- return timestep
-
- if self._sleep_type == RateLimitWrapper.SleepType.AFTER_WITH_REPEAT:
- for k in action.keys():
- if k.startswith('action_type'):
- action[k] = np.array(action_type.ActionType.REPEAT, dtype=np.uint8)
- self._wait()
- first_reward = timestep.reward or 0.0
- timestep = self._env.step(action)
- second_reward = timestep.reward or 0.0
- # Accumulate rewards over the two steps taken.
- timestep = timestep._replace(reward=first_reward + second_reward)
-
- elif self._sleep_type == RateLimitWrapper.SleepType.AFTER:
- self._wait()
-
- self._last_step_time = time.time()
-
- return timestep
-
- def _wait(self) -> None:
- if self._max_wait > 0.0 and self._last_step_time is not None:
- time_since_step = time.time() - self._last_step_time
- sec_to_wait = self._max_wait - time_since_step
- if sec_to_wait > 0.0:
- time.sleep(sec_to_wait)
diff --git a/android_env/wrappers/rate_limit_wrapper_test.py b/android_env/wrappers/rate_limit_wrapper_test.py
deleted file mode 100644
index d108de92..00000000
--- a/android_env/wrappers/rate_limit_wrapper_test.py
+++ /dev/null
@@ -1,270 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for rate_limit_wrapper."""
-
-import time
-from typing import Any, Protocol
-from unittest import mock
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from android_env import env_interface
-from android_env.components import action_type
-from android_env.wrappers import rate_limit_wrapper
-import dm_env
-from dm_env import specs
-import numpy as np
-
-
-def _get_base_env():
- env = mock.create_autospec(env_interface.AndroidEnvInterface)
- env.action_spec.return_value = {
- 'action_type':
- specs.DiscreteArray(
- num_values=len(action_type.ActionType),
- name='action_type'),
- 'touch_position':
- specs.BoundedArray(
- shape=(2,),
- dtype=np.float32,
- minimum=[0.0, 0.0],
- maximum=[1.0, 1.0],
- name='touch_position'),
- }
- return env
-
-
-class _FnWithTimestamps(Protocol):
- """A function with `timestamp` and `timestamps` attributes."""
-
- timestamp: float
- timestamps: list[float]
-
-
-def _with_timestamp(fn: Any) -> _FnWithTimestamps:
- return fn
-
-
-class RateLimitWrapperTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- ('zero_rate', 0),
- ('negative_rate', -50),
- )
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_disabled(self, rate, mock_sleep):
- """With a non-positive rate, this wrapper should do nothing."""
- env = _get_base_env()
- wrapper = rate_limit_wrapper.RateLimitWrapper(env, rate=rate)
- _ = wrapper.reset()
- mock_sleep.assert_not_called()
- _ = wrapper.step({
- 'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
- 'touch_position': np.array([0.123, 0.456])
- })
- mock_sleep.assert_not_called()
- # When the wrapper is disabled, base step should only be called once.
- env.step.assert_called_once()
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_enabled(self, mock_sleep):
- """When enabled, the wrapper should sleep for a period in [0, 1/rate]."""
-
- env = _get_base_env()
- env.step.return_value = dm_env.transition(reward=None, observation=None)
- wrapper = rate_limit_wrapper.RateLimitWrapper(env, rate=1/33.33)
-
- _ = wrapper.reset()
- mock_sleep.assert_not_called() # It should never sleep during reset().
-
- # Step for 100 steps.
- for _ in range(100):
- _ = wrapper.step({
- 'action_type':
- np.array(action_type.ActionType.LIFT, dtype=np.uint8),
- 'touch_position':
- np.array([0.123, 0.456])
- })
-
- # Check that there are 100 calls and that they're all within [0, 1/rate].
- self.assertLen(mock_sleep.call_args_list, 100)
- for call in mock_sleep.call_args_list:
- args, unused_kwargs = call
- sleep_time = args[0]
- self.assertBetween(sleep_time, 0.0, 33.33)
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_enabled_sleep_type_before(self, mock_sleep):
- """When sleep_type==BEFORE, sleep should come before step()."""
-
- env = _get_base_env()
- wrapper = rate_limit_wrapper.RateLimitWrapper(
- env,
- rate=1/33.33,
- sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType.BEFORE)
-
- _ = wrapper.reset()
- mock_sleep.assert_not_called() # It should never sleep during reset().
-
- @_with_timestamp
- def _sleep_fn(sleep_time):
- _sleep_fn.timestamp = time.time()
- self.assertBetween(sleep_time, 0.0, 33.33)
-
- mock_sleep.side_effect = _sleep_fn
-
- def _step_fn(action):
- self.assertEqual(
- action['action_type'],
- np.array(action_type.ActionType.LIFT, dtype=np.uint8))
- _step_fn.timestamps.append(time.time())
- return dm_env.transition(reward=None, observation=None)
-
- _step_fn.timestamps = []
-
- env.step.side_effect = _step_fn
-
- _ = wrapper.step({
- 'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
- 'touch_position': np.array([0.123, 0.456])
- })
-
- self.assertLen(_step_fn.timestamps, 1)
- # We expect sleep to have been executed BEFORE a single `step()`.
- self.assertGreaterEqual(_step_fn.timestamps[0], _sleep_fn.timestamp)
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_enabled_sleep_type_after(self, mock_sleep):
- """When sleep_type==AFTER, sleep should come after step()."""
-
- env = _get_base_env()
- wrapper = rate_limit_wrapper.RateLimitWrapper(
- env,
- rate=1/33.33,
- sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType.AFTER)
- _ = wrapper.reset()
- mock_sleep.assert_not_called() # It should never sleep during reset().
-
- @_with_timestamp
- def _sleep_fn(sleep_time):
- _sleep_fn.timestamp = time.time()
- self.assertBetween(sleep_time, 0.0, 33.33)
-
- mock_sleep.side_effect = _sleep_fn
-
- def _step_fn(action):
- self.assertEqual(
- action['action_type'],
- np.array(action_type.ActionType.LIFT, dtype=np.uint8))
- _step_fn.timestamps.append(time.time())
- return dm_env.transition(reward=None, observation=None)
-
- _step_fn.timestamps = []
-
- env.step.side_effect = _step_fn
-
- _ = wrapper.step({
- 'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
- 'touch_position': np.array([0.123, 0.456])
- })
-
- # We expect sleep to have been executed AFTER a single `step()`.
- self.assertLen(_step_fn.timestamps, 1)
- self.assertLessEqual(_step_fn.timestamps[0], _sleep_fn.timestamp)
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_enabled_sleep_type_after_with_repeat(self, mock_sleep):
- """When sleep_type==AFTER_WITH_REPEAT, sleep should be between 2 steps()."""
-
- env = _get_base_env()
- wrapper = rate_limit_wrapper.RateLimitWrapper(
- env,
- rate=1/33.33,
- sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType
- .AFTER_WITH_REPEAT)
-
- _ = wrapper.reset()
- mock_sleep.assert_not_called() # It should never sleep during reset().
-
- @_with_timestamp
- def _sleep_fn(sleep_time):
- _sleep_fn.timestamp = time.time()
- self.assertBetween(sleep_time, 0.0, 33.33)
-
- mock_sleep.side_effect = _sleep_fn
-
- @_with_timestamp
- def _step_fn(action):
- # On even calls the action should be the actual agent action, but on odd
- # calls they should be REPEATs.
- if len(_step_fn.timestamps) % 2 == 0:
- self.assertEqual(
- action['action_type'],
- np.array(action_type.ActionType.LIFT, dtype=np.uint8))
- else:
- self.assertEqual(
- action['action_type'],
- np.array(action_type.ActionType.REPEAT, dtype=np.uint8))
- _step_fn.timestamps.append(time.time())
- return dm_env.transition(reward=1.0, observation=None)
-
- _step_fn.timestamps = []
-
- env.step.side_effect = _step_fn
-
- timestep = wrapper.step({
- 'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
- 'touch_position': np.array([0.123, 0.456])
- })
-
- # When the wrapper is enabled, base step should be called twice.
- self.assertEqual(env.step.call_count, 2)
-
- # `step()` should be called twice: before `sleep()` and after it.
- self.assertLen(_step_fn.timestamps, 2)
- self.assertGreaterEqual(_sleep_fn.timestamp, _step_fn.timestamps[0])
- self.assertLessEqual(_sleep_fn.timestamp, _step_fn.timestamps[1])
- # Rewards should accumulate over the two step() calls
- self.assertEqual(timestep.reward, 2.0)
-
- @mock.patch.object(time, 'sleep', autospec=True)
- def test_enabled_sleep_type_after_with_repeat_last(self, mock_sleep):
- """If the first step is a LAST, second step should not be taken."""
-
- env = _get_base_env()
- wrapper = rate_limit_wrapper.RateLimitWrapper(
- env,
- rate=1/33.33,
- sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType
- .AFTER_WITH_REPEAT)
-
- _ = wrapper.reset()
- mock_sleep.assert_not_called() # It should never sleep during reset().
-
- env.step.return_value = dm_env.termination(reward=None, observation=None)
-
- _ = wrapper.step({
- 'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
- 'touch_position': np.array([0.123, 0.456])
- })
-
- # Second step call should be skipped.
- env.step.assert_called_once()
- mock_sleep.assert_not_called()
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/android_env/wrappers/tap_action_wrapper.py b/android_env/wrappers/tap_action_wrapper.py
deleted file mode 100644
index 5f0ee666..00000000
--- a/android_env/wrappers/tap_action_wrapper.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Wraps the AndroidEnv environment to provide tap actions of a given duration."""
-
-from collections.abc import Sequence
-
-from android_env.components import action_type
-from android_env.wrappers import base_wrapper
-import dm_env
-import numpy as np
-
-
-class TapActionWrapper(base_wrapper.BaseWrapper):
- """AndroidEnv with tap actions."""
-
- def __init__(self,
- env: dm_env.Environment,
- num_frames: int = 5,
- touch_only: bool = False):
- super().__init__(env)
- assert 'action_type' in env.action_spec()
- self._touch_only = touch_only
- self._num_frames = num_frames
- self._env_steps = 0
-
- def stats(self):
- """Returns a dictionary of metrics logged by the environment."""
- logs = self._env.stats()
- logs.update({'env_steps': self._env_steps})
- return logs
-
- def _process_action(
- self, action: dict[str, np.ndarray]
- ) -> Sequence[dict[str, np.ndarray]]:
- if self._touch_only:
- assert action['action_type'] == 0
- touch_action = action.copy()
- touch_action['action_type'] = np.array(
- action_type.ActionType.TOUCH
- ).astype(self.action_spec()['action_type'].dtype)
- actions = [touch_action] * self._num_frames
- lift_action = action.copy()
- lift_action['action_type'] = np.array(action_type.ActionType.LIFT).astype(
- self.action_spec()['action_type'].dtype
- )
- actions.append(lift_action)
-
- else:
- if action['action_type'] == action_type.ActionType.TOUCH:
- actions = [action] * self._num_frames
- lift_action = action.copy()
- lift_action['action_type'] = np.array(
- action_type.ActionType.LIFT
- ).astype(self.action_spec()['action_type'].dtype)
- actions.append(lift_action)
- else:
- actions = [action] * (self._num_frames + 1)
-
- return actions
-
- def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
- """Takes a step in the environment."""
- self._env_steps += self._num_frames + 1
- actions = self._process_action(action)
- total_reward = 0.0
- for idx in range(len(actions)):
- step_type, reward, discount, observation = self._env.step(actions[idx])
- if reward:
- total_reward += reward
- if step_type == dm_env.StepType.LAST:
- return dm_env.TimeStep(
- step_type=step_type,
- reward=total_reward,
- discount=discount,
- observation=observation)
- return dm_env.TimeStep(
- step_type=step_type,
- reward=total_reward,
- discount=discount,
- observation=observation)
-
- def action_spec(self) -> dict[str, dm_env.specs.Array]:
- if self._touch_only:
- return {
- 'action_type':
- dm_env.specs.DiscreteArray(num_values=1, name='action_type'),
- 'touch_position':
- self._env.action_spec()['touch_position'],
- }
- else:
- return self._env.action_spec()
diff --git a/android_env/wrappers/tap_action_wrapper_test.py b/android_env/wrappers/tap_action_wrapper_test.py
deleted file mode 100644
index b6f9ceb3..00000000
--- a/android_env/wrappers/tap_action_wrapper_test.py
+++ /dev/null
@@ -1,168 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for tap_action_wrapper."""
-
-from unittest import mock
-
-from absl.testing import absltest
-from android_env import env_interface
-from android_env.components import action_type
-from android_env.wrappers import tap_action_wrapper
-import dm_env
-from dm_env import specs
-import numpy as np
-
-
-def _make_array_spec(shape, dtype, name):
- return specs.BoundedArray(
- name=name,
- shape=shape,
- dtype=dtype,
- minimum=np.zeros(shape),
- maximum=np.ones(shape), # maximum is inclusive.
- )
-
-
-class TapActionWrapperTest(absltest.TestCase):
-
- def setUp(self):
- super().setUp()
- self._base_action_spec = {
- 'action_type': specs.DiscreteArray(
- num_values=3, name='action_type'),
- 'touch_position': _make_array_spec(
- shape=(2,), dtype=np.float32, name='touch_position'),
- }
- self.base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
- self.base_env.action_spec.return_value = self._base_action_spec
-
- def test_process_action_repeat(self):
- wrapped_env = tap_action_wrapper.TapActionWrapper(
- self.base_env, num_frames=3)
- action = {
- 'action_type': np.array(action_type.ActionType.REPEAT, dtype=np.int32),
- 'touch_position': np.array([0.5, 0.5], dtype=np.float32),
- }
- actions = wrapped_env._process_action(action)
- self.assertLen(actions, wrapped_env._num_frames + 1)
- self.assertEqual(action, actions[-1])
-
- def test_process_action_lift(self):
- wrapped_env = tap_action_wrapper.TapActionWrapper(
- self.base_env, num_frames=3)
- action = {
- 'action_type': np.array(action_type.ActionType.LIFT, dtype=np.int32),
- 'touch_position': np.array([0.5, 0.5], dtype=np.float32),
- }
- actions = wrapped_env._process_action(action)
- self.assertLen(actions, wrapped_env._num_frames + 1)
- self.assertEqual(action, actions[-1])
-
- def test_process_action_touch(self):
- wrapped_env = tap_action_wrapper.TapActionWrapper(
- self.base_env, num_frames=3)
- action = {
- 'action_type': np.array(action_type.ActionType.TOUCH, dtype=np.int32),
- 'touch_position': np.array([0.5, 0.5], dtype=np.float32),
- }
- actions = wrapped_env._process_action(action)
- self.assertLen(actions, wrapped_env._num_frames + 1)
- self.assertEqual(
- actions[-1]['action_type'], np.array(action_type.ActionType.LIFT)
- )
-
- def test_reset(self):
- wrapped_env = tap_action_wrapper.TapActionWrapper(
- self.base_env, num_frames=5)
- fake_timestep = 'ts'
- self.base_env.reset.return_value = fake_timestep
- ts = wrapped_env.reset()
- self.base_env.reset.assert_called_once()
- self.assertEqual(fake_timestep, ts)
-
- def test_step(self):
- # Arrange.
- wrapped_env = tap_action_wrapper.TapActionWrapper(
- self.base_env, num_frames=5)
- fake_timestep = dm_env.TimeStep(
- step_type='fake_type',
- reward=0.0,
- discount=1.0,
- observation='fake_obs')
- self.base_env.step.return_value = fake_timestep
- self.base_env.stats.return_value = {}
-
- # Act.
- ts = wrapped_env.step({
- 'action_type': np.array(action_type.ActionType.REPEAT, dtype=np.int32),
- 'touch_position': np.array([0.5, 0.5], dtype=np.float32),
- })
- stats = wrapped_env.stats()
-
- # Assert.
- self.assertEqual(wrapped_env._num_frames+1, self.base_env.step.call_count)
- self.assertIsInstance(ts, dm_env.TimeStep)
- self.assertIsInstance(stats, dict)
- self.assertIn('env_steps', stats)
- self.assertEqual(stats['env_steps'], 6)
-
- def test_observation_spec(self):
- wrapped_env = tap_action_wrapper.TapActionWrapper(
- self.base_env, num_frames=5)
- fake_obs_spec = 'fake_obs_spec'
- self.base_env.observation_spec.return_value = fake_obs_spec
- observation_spec = wrapped_env.observation_spec()
- self.base_env.observation_spec.assert_called_once()
- self.assertEqual(fake_obs_spec, observation_spec)
-
- def test_action_spec(self):
- wrapped_env = tap_action_wrapper.TapActionWrapper(
- self.base_env, num_frames=5)
- self.base_env.action_spec.return_value = self._base_action_spec
- action_spec = wrapped_env.action_spec()
- self.base_env.action_spec.assert_called()
- self.assertEqual(self.base_env.action_spec(),
- action_spec)
-
- def test_stats(self):
- """Checks that returned stats have expected properties."""
-
- # Arrange.
- self.base_env.stats.return_value = {
- 'some_key': 12345,
- 'another_key': 5.4321,
- }
- wrapped_env = tap_action_wrapper.TapActionWrapper(
- self.base_env, num_frames=5
- )
-
- # Act.
- stats = wrapped_env.stats()
-
- # Assert.
- self.assertIsInstance(stats, dict)
- # Original entries should still be present.
- self.assertIn('some_key', stats)
- self.assertEqual(stats['some_key'], 12345)
- self.assertIn('another_key', stats)
- self.assertEqual(stats['another_key'], 5.4321)
- # TapActionWrapper inserts its own `env_steps`.
- self.assertIn('env_steps', stats)
- self.assertEqual(stats['env_steps'], 0)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/examples/__init__.py b/examples/__init__.py
deleted file mode 100644
index 2f66bf75..00000000
--- a/examples/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
diff --git a/examples/run_acme_agent.py b/examples/run_acme_agent.py
deleted file mode 100644
index 754a0436..00000000
--- a/examples/run_acme_agent.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Acme DQN agent interacting with AndroidEnv."""
-
-from absl import app
-from absl import flags
-from absl import logging
-import acme
-from acme import specs
-from acme import wrappers as acme_wrappers
-from acme.agents.tf import dqn
-from acme.tf import networks
-from android_env import loader
-from android_env.components import config_classes
-from android_env.wrappers import discrete_action_wrapper
-from android_env.wrappers import float_pixels_wrapper
-from android_env.wrappers import image_rescale_wrapper
-
-# Simulator args
-flags.DEFINE_string('avd_name', None, 'Name of AVD to use.')
-flags.DEFINE_string('android_avd_home', '~/.android/avd', 'Path to AVD.')
-flags.DEFINE_string('android_sdk_root', '~/Android/Sdk', 'Path to SDK.')
-flags.DEFINE_string('emulator_path',
- '~/Android/Sdk/emulator/emulator', 'Path to emulator.')
-flags.DEFINE_string('adb_path',
- '~/Android/Sdk/platform-tools/adb', 'Path to ADB.')
-
-# Environment args
-flags.DEFINE_string('task_path', None, 'Path to task textproto file.')
-
-# Experiment args
-flags.DEFINE_integer('num_episodes', 100, 'Number of episodes.')
-
-FLAGS = flags.FLAGS
-
-
-def apply_wrappers(env):
- """Applies a series of wrappers to the environment."""
- env = discrete_action_wrapper.DiscreteActionWrapper(env, action_grid=(10, 10))
- env = image_rescale_wrapper.ImageRescaleWrapper(
- env, zoom_factors=(0.25, 0.25))
- env = float_pixels_wrapper.FloatPixelsWrapper(env)
- env = acme_wrappers.SinglePrecisionWrapper(env)
- return env
-
-
-def main(_):
-
- config = config_classes.AndroidEnvConfig(
- task=config_classes.FilesystemTaskConfig(path=FLAGS.task_path),
- simulator=config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- emulator_path=FLAGS.emulator_path,
- android_sdk_root=FLAGS.android_sdk_root,
- android_avd_home=FLAGS.android_avd_home,
- avd_name=FLAGS.avd_name,
- run_headless=FLAGS.run_headless,
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path=FLAGS.adb_path
- ),
- ),
- )
- with loader.load(config) as env:
-
- env = apply_wrappers(env)
- env_spec = specs.make_environment_spec(env)
-
- agent = dqn.DQN(
- environment_spec=env_spec,
- network=networks.DQNAtariNetwork(
- num_actions=env_spec.actions.num_values),
- batch_size=10,
- samples_per_insert=2,
- min_replay_size=10)
-
- loop = acme.EnvironmentLoop(env, agent)
- loop.run(num_episodes=FLAGS.num_episodes)
-
-
-if __name__ == '__main__':
- logging.set_verbosity('info')
- logging.set_stderrthreshold('info')
- flags.mark_flags_as_required(['task_path', 'avd_name'])
- app.run(main)
diff --git a/examples/run_human_agent.py b/examples/run_human_agent.py
deleted file mode 100644
index 505a729b..00000000
--- a/examples/run_human_agent.py
+++ /dev/null
@@ -1,209 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Loads an interactive session where a human acts on behalf of an agent."""
-
-import time
-from typing import Any
-
-from absl import app
-from absl import flags
-from absl import logging
-from android_env import loader
-from android_env.components import action_type
-from android_env.components import config_classes
-from android_env.components import pixel_fns
-import dm_env
-import numpy as np
-import pygame
-
-# Simulator args.
-flags.DEFINE_string('avd_name', None, 'Name of AVD to use.')
-flags.DEFINE_string('android_avd_home', '~/.android/avd', 'Path to AVD.')
-flags.DEFINE_string('android_sdk_root', '~/Android/Sdk', 'Path to SDK.')
-flags.DEFINE_string('emulator_path',
- '~/Android/Sdk/emulator/emulator', 'Path to emulator.')
-flags.DEFINE_string('adb_path',
- '~/Android/Sdk/platform-tools/adb', 'Path to ADB.')
-flags.DEFINE_boolean('run_headless', True, 'Optionally turn off display.')
-
-# Environment args.
-flags.DEFINE_string('task_path', None, 'Path to task textproto file.')
-
-# Pygame args.
-flags.DEFINE_list('screen_size', '480,720', 'Screen width, height in pixels.')
-flags.DEFINE_float('frame_rate', 1.0/30.0, 'Frame rate in seconds.')
-
-FLAGS = flags.FLAGS
-
-
-def _get_action_from_event(
- event: pygame.event.Event, screen: pygame.Surface, orientation: int
-) -> dict[str, Any]:
- """Returns the current action by reading data from a pygame Event object."""
-
- act_type = action_type.ActionType.LIFT
- if event.type == pygame.MOUSEBUTTONDOWN:
- act_type = action_type.ActionType.TOUCH
-
- return {
- 'action_type':
- np.array(act_type, dtype=np.int32),
- 'touch_position':
- _scale_position(event.pos, screen, orientation),
- }
-
-
-def _get_action_from_mouse(
- screen: pygame.Surface, orientation: int
-) -> dict[str, Any]:
- """Returns the current action by reading data from the mouse."""
-
- act_type = action_type.ActionType.LIFT
- if pygame.mouse.get_pressed()[0]:
- act_type = action_type.ActionType.TOUCH
-
- return {
- 'action_type':
- np.array(act_type, dtype=np.int32),
- 'touch_position':
- _scale_position(pygame.mouse.get_pos(), screen, orientation),
- }
-
-
-def _scale_position(position: np.ndarray, screen: pygame.Surface,
- orientation: int) -> np.ndarray:
- """AndroidEnv accepts mouse inputs as floats so we need to scale it."""
-
- scaled_pos = np.divide(position, screen.get_size(), dtype=np.float32)
- if orientation == 1: # LANDSCAPE_90
- scaled_pos = scaled_pos[::-1]
- scaled_pos[0] = 1 - scaled_pos[0]
- return scaled_pos
-
-
-def _accumulate_reward(
- timestep: dm_env.TimeStep,
- episode_return: float) -> float:
- """Accumulates rewards collected over the course of an episode."""
-
- if timestep.reward and timestep.reward != 0:
- logging.info('Reward: %s', timestep.reward)
- episode_return += timestep.reward
-
- if timestep.first():
- episode_return = 0
- elif timestep.last():
- logging.info('Episode return: %s', episode_return)
-
- return episode_return
-
-
-def _render_pygame_frame(surface: pygame.Surface, screen: pygame.Surface,
- orientation: int, timestep: dm_env.TimeStep) -> None:
- """Displays latest observation on pygame surface."""
-
- frame = timestep.observation['pixels'][:, :, :3] # (H x W x C) (RGB)
- frame = pixel_fns.transpose_pixels(frame) # (W x H x C)
- frame = pixel_fns.orient_pixels(frame, orientation)
-
- pygame.surfarray.blit_array(surface, frame)
- pygame.transform.smoothscale(surface, screen.get_size(), screen)
-
- pygame.display.flip()
-
-
-def main(_):
-
- pygame.init()
- pygame.display.set_caption('android_human_agent')
-
- config = config_classes.AndroidEnvConfig(
- task=config_classes.FilesystemTaskConfig(path=FLAGS.task_path),
- simulator=config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- emulator_path=FLAGS.emulator_path,
- android_sdk_root=FLAGS.android_sdk_root,
- android_avd_home=FLAGS.android_avd_home,
- avd_name=FLAGS.avd_name,
- run_headless=FLAGS.run_headless,
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path=FLAGS.adb_path
- ),
- ),
- )
- with loader.load(config) as env:
-
- # Reset environment.
- first_timestep = env.reset()
- orientation = np.argmax(first_timestep.observation['orientation'])
-
- # Create pygame canvas.
- screen_size = list(map(int, FLAGS.screen_size)) # (W x H)
- obs_shape = env.observation_spec()['pixels'].shape[:2] # (H x W)
-
- if (orientation == 1 or orientation == 3): # LANDSCAPE_90 | LANDSCAPE_270
- screen_size = screen_size[::-1]
- obs_shape = obs_shape[::-1]
-
- screen = pygame.display.set_mode(screen_size) # takes (W x H)
- surface = pygame.Surface(obs_shape[::-1]) # takes (W x H)
-
- # Start game loop.
- prev_frame = time.time()
- episode_return = 0
-
- while True:
- if pygame.key.get_pressed()[pygame.K_ESCAPE]:
- return
-
- all_events = pygame.event.get()
- for event in all_events:
- if event.type == pygame.QUIT:
- return
-
- # Filter event queue for mouse click events.
- mouse_click_events = [
- event for event in all_events
- if event.type in [pygame.MOUSEBUTTONDOWN, pygame.MOUSEBUTTONUP]
- ]
-
- # Process all mouse click events.
- for event in mouse_click_events:
- action = _get_action_from_event(event, screen, orientation)
- timestep = env.step(action)
- episode_return = _accumulate_reward(timestep, episode_return)
- _render_pygame_frame(surface, screen, orientation, timestep)
-
- # Sample the current position of the mouse either way.
- action = _get_action_from_mouse(screen, orientation)
- timestep = env.step(action)
- episode_return = _accumulate_reward(timestep, episode_return)
- _render_pygame_frame(surface, screen, orientation, timestep)
-
- # Limit framerate.
- now = time.time()
- frame_time = now - prev_frame
- if frame_time < FLAGS.frame_rate:
- time.sleep(FLAGS.frame_rate - frame_time)
- prev_frame = now
-
-
-if __name__ == '__main__':
- logging.set_verbosity('info')
- logging.set_stderrthreshold('info')
- flags.mark_flags_as_required(['avd_name', 'task_path'])
- app.run(main)
diff --git a/examples/run_random_agent.py b/examples/run_random_agent.py
deleted file mode 100644
index fc8017b4..00000000
--- a/examples/run_random_agent.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# coding=utf-8
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Example script demonstrating usage of AndroidEnv."""
-
-from absl import app
-from absl import flags
-from absl import logging
-from android_env import loader
-from android_env.components import config_classes
-from dm_env import specs
-import numpy as np
-
-FLAGS = flags.FLAGS
-
-# Simulator args.
-flags.DEFINE_string('avd_name', None, 'Name of AVD to use.')
-flags.DEFINE_string('android_avd_home', '~/.android/avd', 'Path to AVD.')
-flags.DEFINE_string('android_sdk_root', '~/Android/Sdk', 'Path to SDK.')
-flags.DEFINE_string('emulator_path',
- '~/Android/Sdk/emulator/emulator', 'Path to emulator.')
-flags.DEFINE_string('adb_path',
- '~/Android/Sdk/platform-tools/adb', 'Path to ADB.')
-flags.DEFINE_bool('run_headless', False,
- 'Whether to display the emulator window.')
-
-# Environment args.
-flags.DEFINE_string('task_path', None, 'Path to task textproto file.')
-
-# Experiment args.
-flags.DEFINE_integer('num_steps', 1000, 'Number of steps to take.')
-
-
-def main(_):
-
- config = config_classes.AndroidEnvConfig(
- task=config_classes.FilesystemTaskConfig(path=FLAGS.task_path),
- simulator=config_classes.EmulatorConfig(
- emulator_launcher=config_classes.EmulatorLauncherConfig(
- emulator_path=FLAGS.emulator_path,
- android_sdk_root=FLAGS.android_sdk_root,
- android_avd_home=FLAGS.android_avd_home,
- avd_name=FLAGS.avd_name,
- run_headless=FLAGS.run_headless,
- ),
- adb_controller=config_classes.AdbControllerConfig(
- adb_path=FLAGS.adb_path
- ),
- ),
- )
- with loader.load(config) as env:
-
- action_spec = env.action_spec()
-
- def get_random_action() -> dict[str, np.ndarray]:
- """Returns a random AndroidEnv action."""
- action = {}
- for k, v in action_spec.items():
- if isinstance(v, specs.DiscreteArray):
- action[k] = np.random.randint(low=0, high=v.num_values, dtype=v.dtype)
- else:
- action[k] = np.random.random(size=v.shape).astype(v.dtype)
- return action
-
- _ = env.reset()
-
- for step in range(FLAGS.num_steps):
- action = get_random_action()
- timestep = env.step(action=action)
- reward = timestep.reward
- logging.info('Step %r, action: %r, reward: %r', step, action, reward)
-
-
-if __name__ == '__main__':
- logging.set_verbosity('info')
- logging.set_stderrthreshold('info')
- flags.mark_flags_as_required(['avd_name', 'task_path'])
- app.run(main)
diff --git a/pyproject.toml b/pyproject.toml
deleted file mode 100644
index 2504653f..00000000
--- a/pyproject.toml
+++ /dev/null
@@ -1,39 +0,0 @@
-[build-system]
-requires = [
- "setuptools",
- "wheel"
-]
-build-backend = "setuptools.build_meta"
-
-[project]
-name = "android-env"
-version = "1.2.2"
-description = "AndroidEnv environment and library for training agents."
-authors = [{name = "DeepMind"}]
-license = {file = "LICENSE"}
-readme = {text = "Read the README at https://github.com/deepmind/android_env for more information.", content-type = "text/plain"}
-keywords = ["Android", "OS", "reinforcement-learning"]
-requires-python = ">=3.10"
-dependencies = [
- "absl-py>=0.1.0",
- "dm_env",
- "grpcio",
- "numpy>=1.21",
- "portpicker>=1.2.0",
- "protobuf>=2.6",
- "pygame",
-]
-
-[project.optional-dependencies]
-testing = [
- "gym",
- "pillow",
- "pytype",
-]
-acme = ["dm-acme"]
-gym = ["gym"]
-
-[project.urls]
-repository = "https://github.com/deepmind/android_env"
-deepmind = "https://www.deepmind.com/publications/androidenv-the-android-learning-environment"
-arxiv = "https://arxiv.org/abs/2105.13231"
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 2c201d0b..00000000
--- a/setup.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright 2024 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Simple package definition for using with `pip`."""
-
-import importlib
-import os
-
-import setuptools
-from setuptools import find_packages
-from setuptools import setup
-from setuptools.command.build_ext import build_ext
-from setuptools.command.build_py import build_py
-
-_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-
-# Tuple of proto message definitions to build Python bindings for. Paths must
-# be relative to root directory.
-_ANDROID_ENV_PROTOS = (
- 'android_env/proto/adb.proto',
- 'android_env/proto/emulator_controller.proto',
- 'android_env/proto/snapshot.proto',
- 'android_env/proto/snapshot_service.proto',
- 'android_env/proto/state.proto',
- 'android_env/proto/task.proto',
- 'android_env/proto/a11y/a11y.proto',
- 'android_env/proto/a11y/android_accessibility_action.proto',
- 'android_env/proto/a11y/android_accessibility_forest.proto',
- 'android_env/proto/a11y/android_accessibility_node_info.proto',
- 'android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto',
- 'android_env/proto/a11y/android_accessibility_tree.proto',
- 'android_env/proto/a11y/android_accessibility_window_info.proto',
- 'android_env/proto/a11y/rect.proto',
-)
-
-
-class _GenerateProtoFiles(setuptools.Command):
- """Command to generate protobuf bindings for AndroidEnv protos."""
-
- descriptions = 'Generates Python protobuf bindings for AndroidEnv protos.'
- user_options = []
-
- def initialize_options(self):
- pass
-
- def finalize_options(self):
- pass
-
- def run(self):
- # Import grpc_tools here, after setuptools has installed setup_requires
- # dependencies.
- from grpc_tools import protoc # pylint: disable=g-import-not-at-top
-
- with importlib.resources.as_file(
- importlib.resources.files('grpc_tools').joinpath('_proto')
- ) as path:
- grpc_protos_include = str(path)
-
- for proto_path in _ANDROID_ENV_PROTOS:
- proto_args = [
- 'grpc_tools.protoc',
- '--proto_path={}'.format(grpc_protos_include),
- '--proto_path={}'.format(_ROOT_DIR),
- '--python_out={}'.format(_ROOT_DIR),
- '--pyi_out={}'.format(_ROOT_DIR),
- '--grpc_python_out={}'.format(_ROOT_DIR),
- os.path.join(_ROOT_DIR, proto_path),
- ]
- if protoc.main(proto_args) != 0:
- raise RuntimeError('ERROR: {}'.format(proto_args))
-
-
-class _BuildExt(build_ext):
- """Generate protobuf bindings in build_ext stage."""
-
- def run(self):
- self.run_command('generate_protos')
- build_ext.run(self)
-
-
-class _BuildPy(build_py):
- """Generate protobuf bindings in build_py stage."""
-
- def run(self):
- self.run_command('generate_protos')
- build_py.run(self)
-
-setup(
- packages=find_packages(exclude=['examples']),
- package_data={'': ['proto/*.proto']}, # Copy protobuf files.
- include_package_data=True,
- setup_requires=['grpcio-tools'],
- cmdclass={
- 'build_ext': _BuildExt,
- 'build_py': _BuildPy,
- 'generate_protos': _GenerateProtoFiles,
- },
-)