diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml deleted file mode 100644 index 61fb1eec..00000000 --- a/.github/workflows/tests.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: tests - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - workflow_dispatch: - inputs: - git-ref: - description: Git Ref (Optional) - required: false - -jobs: - build: - runs-on: ubuntu-latest - env: - TEST_TMPDIR: '/tmp' - strategy: - matrix: - python-version: ["3.11", "3.12", "3.13"] - steps: - - uses: actions/checkout@v4 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Install dependencies - run: | - pip install --upgrade pip setuptools - python setup.py install - pip install .[testing] - - - name: Run tests - run: | - # Find all test files, print their names and execute them in parallel - # with a maximum of 20 proccesses. - find . -type f -name "*_test.py" -print0 | xargs -t -0 -n1 -P 20 python3 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md deleted file mode 100644 index 03c65248..00000000 --- a/CONTRIBUTING.md +++ /dev/null @@ -1,25 +0,0 @@ - -# How to Contribute - -# Pull Requests - -Please send in fixes or feature additions through Pull Requests. - -## Contributor License Agreement - -Contributions to this project must be accompanied by a Contributor License -Agreement. You (or your employer) retain the copyright to your contribution, -this simply gives us permission to use and redistribute your contributions as -part of the project. Head over to to see -your current agreements on file or to sign a new one. - -You generally only need to submit a CLA once, so if you've already submitted one -(even if it was for a different project), you probably don't need to do it -again. - -## Code reviews - -All submissions, including submissions by project members, require review. We -use GitHub pull requests for this purpose. Consult -[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more -information on using pull requests. diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 7a4a3ea2..00000000 --- a/LICENSE +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md deleted file mode 100644 index b48c6bce..00000000 --- a/README.md +++ /dev/null @@ -1,163 +0,0 @@ - -# AndroidEnv - The Android Learning Environment - - - -[AndroidEnv](https://github.com/deepmind/android_env) is a Python library that -exposes an [Android](https://www.android.com/) device as a Reinforcement -Learning (RL) environment. The library provides a flexible platform for defining -custom tasks on top of the Android Operating System, including any Android -application. Agents interact with the device through a universal action -interface - the touchscreen - by sending localized touch and lift events to the -system. The library processes these events and returns pixel observations and -rewards as provided by specific [task definitions](docs/tasks_guide.md). For -example, rewards might be given for events such as successfully scrolling down a -page, sending an email, or achieving some score in a game, depending on the -research purpose and how the user configures the task. - -[![tests](https://github.com/deepmind/android_env/actions/workflows/tests.yml/badge.svg?branch=main)](https://github.com/deepmind/android_env/actions/workflows/tests.yml) -[![PyPI version](https://badge.fury.io/py/android-env.svg)](https://badge.fury.io/py/android-env) -[![Downloads](https://pepy.tech/badge/android-env)](https://pepy.tech/project/android-env) - -## Index - -* [Environment details](docs/environment.md) -* [Running AndroidEnv](docs/instructions.md) -* [Setting up a virtual Android device](docs/emulator_guide.md) -* [Defining a task in AndroidEnv](docs/tasks_guide.md) -* [Example tasks available for download](docs/example_tasks.md) - -## Environment features - -There are a number of aspects that make AndroidEnv a challenging yet suitable -environment for Reinforcement Learning research: - -* Allowing agents to interact with a system used daily by billions of users - around the world, AndroidEnv offers a platform for RL agents to navigate, - learn tasks and have direct impact in **real-world contexts**. The - environment wraps a simulated Android device, which runs independently from - the environment, completely unaltered, and works in exactly the same way as - the devices that humans use, exposing exactly the same features and - services. - -* The platform offers a virtually infinite **range of possible tasks**, all - sharing a common action interface. The library facilitates the design of - Reinforcement Learning tasks for any existing or custom built Android - application. For example, it exposes the broad world of Android games, - ranging from card games, puzzle games, time reactive games, all requiring a - diverse set of action combinations and interaction types. - -* The environment runs on top of a **real-time simulation** of an Android - device. In other words, the environment dynamics does not wait for the agent - to deliberate, and the speed of the simulation cannot be increased. - -* The observation is a collection of **RGB values** corresponding to the - displayed pixels on the screen. The exact screen resolution depends on the - simulated device, but in general it will be considered relatively large in - an RL context. However, users have the option of downsampling each - observation. - -* The learning environment has an interesting, **complex action space** unique - to the touchscreen interface of Android. - - * The raw, **hybrid action space** consists of a continuous tuple - signifying the action location, and a discrete signal determining - whether the agent wants to touch the screen or lift its virtual finger. - * Raw actions are highly **composable**: the Android UI and most - applications were designed so that they could be intuitively navigated - via common - [touchscreen gestures](https://developer.android.com/training/gestures/detector) - such as tapping, scrolling, swiping, pinching, drag & drop etc. This is - still the case in AndroidEnv: to trigger meaningful changes in the - environment, the agent often has to perform carefully timed and - positioned sequences of raw actions. For example, in order to navigate - to the next image in a photo gallery, the agent would have to perform a - *swipe*, touching the screen multiple times, gradually shifting the - actions' positions to the right. Thus, in most contexts raw actions do - not trigger changes in the state of the environment unless correctly - chained together to make up a human gesture. - * The action interface is **closely related to the observation space**, as - meaningful touch and lift events are often either co-localized or - strongly correlated to the location or movement of salient objects in - the observation. For example, the position of a button on the screen - aligns with the location of the actions that trigger the button press. - * The library provides tools for flexibly **altering the action - interface** if needed for particular studies, such as discretization or - hard-coding gesture skills. Still, we believe that the real challenge - remains in devising agents that are capable of dealing with a large - suite of diverse tasks, through acting and learning in the complex - unifying action interface. - -# Getting started - -### Installation - -The easiest way to get AndroidEnv is with pip: - -```shell -$ python3 -m pip install android-env -``` - -Please note that `/examples` are not included in this package. - -Alternatively, you can clone the repository from git's `main` branch: - -```shell -$ git clone https://github.com/deepmind/android_env/ -$ cd android_env -$ python3 setup.py install -``` - -Update: the environment now runs on Windows, but please keep in mind that this -option is not well-maintained or widely supported, as Unix-based systems are the -primary target platforms of this project. - -### Create a simulator - -Before running the environment, you will need access to an emulated Android -device. For instructions on creating a virtual Android device, see the -[Emulator guide](docs/emulator_guide.md). - -### Define a task - -Then, you will want to define what the agent's *task* is. At this point, the -agent will be able to communicate with the emulated device, but it will not yet -have an objective, or access to signals such as rewards or RL episode ends. -Learn [how to define an RL task](docs/tasks_guide.md) of your own, or use one of -the [existing task definitions](docs/example_tasks.md) for training. - -### Load and run - -To find out how to run and train agents on AndroidEnv, see these -[detailed instructions](docs/instructions.md). Here you can also find example -scripts demonstrating how to run a random agent, an -[acme](https://github.com/deepmind/acme) agent, or a human agent on AndroidEnv. - -## About - -This library is developed and maintained by [DeepMind](http://deepmind.com). \ -You can find the [technical report](https://arxiv.org/abs/2105.13231) on Arxiv, -as well as an introductory -[blog -post](https://www.deepmind.com/publications/androidenv-the-android-learning-environment) -on DeepMind's website. - -If you use AndroidEnv in your research, you can cite the paper using the -following BibTeX: - -``` -@article{ToyamaEtAl2021AndroidEnv, - title = {{AndroidEnv}: A Reinforcement Learning Platform for Android}, - author = {Daniel Toyama and Philippe Hamel and Anita Gergely and - Gheorghe Comanici and Amelia Glaese and Zafarali Ahmed and Tyler - Jackson and Shibl Mourad and Doina Precup}, - year = {2021}, - eprint = {2105.13231}, - archivePrefix = {arXiv}, - primaryClass = {cs.LG}, - volume = {abs/2105.13231}, - url = {http://arxiv.org/abs/2105.13231}, -} -``` - -Disclaimer: This is not an official Google product. diff --git a/android_env/__init__.py b/android_env/__init__.py deleted file mode 100644 index 2f66bf75..00000000 --- a/android_env/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarder.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarder.kt deleted file mode 100644 index 6e9bf82d..00000000 --- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarder.kt +++ /dev/null @@ -1,280 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.androidenv.accessibilityforwarder - -import android.accessibilityservice.AccessibilityService -import android.util.Log -import android.view.accessibility.AccessibilityEvent -import android.view.accessibility.AccessibilityNodeInfo -import android.view.accessibility.AccessibilityWindowInfo -import com.google.androidenv.accessibilityforwarder.A11yServiceGrpcKt.A11yServiceCoroutineStub -import io.grpc.ManagedChannel -import io.grpc.ManagedChannelBuilder -import io.grpc.ProxyDetector -import io.grpc.StatusException -import kotlinx.coroutines.TimeoutCancellationException -import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.withTimeout - -/** - * An Android service that listens to accessibility events and forwards them via gRPC. - * - * This service also logs the accessibility tree if [LogFlags.logAccessibilityTree] is set and if - * [LogFlags.grpcPort] is positive. - * - * Please see - * https://developer.android.com/reference/android/view/accessibility/AccessibilityEvent#getEventType() - * for a comprehensive list of events emitted by Android. - */ -class AccessibilityForwarder( - private val channelFactory: (host: String, port: Int) -> ManagedChannel = { host, port -> - ManagedChannelBuilder.forAddress(host, port) - .proxyDetector(ProxyDetector { _ -> null }) - .usePlaintext() - .build() - } -) : AccessibilityService() { - - init { - // Spawn long-running thread for periodically logging the tree. - Thread( - Runnable { - while (LogFlags.a11yTreePeriodMs > 0) { - try { - val windows = this.windows - logAccessibilityTree(windows) - } catch (e: ConcurrentModificationException) { - continue - } - - Thread.sleep(/* millis= */ LogFlags.a11yTreePeriodMs) - } - } - ) - .start() - } - - // grpcStub has a backing property that can be reset to null. - private var _grpcStub: A11yServiceCoroutineStub? = null - val grpcStub: A11yServiceCoroutineStub - get() { - if (_grpcStub == null) { - Log.i(TAG, "Building channel on ${LogFlags.grpcHost}:${LogFlags.grpcPort}.") - _grpcStub = A11yServiceCoroutineStub(channelFactory(LogFlags.grpcHost, LogFlags.grpcPort)) - } - return _grpcStub!! - } - - private fun resetGrpcStub() { - _grpcStub = null - } - - override fun onInterrupt() { - LogFlags.a11yTreePeriodMs = 0 // Turn off periodic tree forwarding. - } - - override fun onAccessibilityEvent(event: AccessibilityEvent?) { - if (event == null) { - Log.i(TAG, "`event` is null.") - return - } - - logExtrasForEvent(event) - val eventType = event.eventType - val eventTypeStr: String = AccessibilityEvent.eventTypeToString(eventType) - if (eventTypeStr.isNotEmpty()) { - Log.i(TAG, eventTypeStr) - } - } - - private fun logAccessibilityTree(windows: List) { - if (!LogFlags.logAccessibilityTree) { - Log.i(TAG, "Not logging accessibility tree") - return - } - - // Check gRPC port before actually building the forest. - if (LogFlags.grpcPort <= 0) { - Log.w(TAG, "Can't log accessibility tree because gRPC port has not been set.") - return - } - - val forest = creator.buildForest(windows) - try { - val grpcTimeoutMillis = 1000L - val response: ForestResponse = - with(grpcStub) { - Log.i(TAG, "sending (blocking) gRPC request for tree.") - runBlocking { withTimeout(grpcTimeoutMillis) { sendForest(forest) } } - } - if (response.error.isNotEmpty()) { - Log.w(TAG, "gRPC response.error: ${response.error}") - } else { - Log.i(TAG, "gRPC request for tree succeeded.") - } - } catch (e: StatusException) { - Log.w(TAG, "gRPC StatusException; are you sure networking is turned on?") - Log.i(TAG, "extra: exception ['$e']") - resetGrpcStub() - } catch (e: TimeoutCancellationException) { - Log.w(TAG, "gRPC TimeoutCancellationException; are you sure networking is turned on?") - Log.i(TAG, "extra: exception ['$e']") - resetGrpcStub() - } - } - - /** Logs extras for all event types. */ - private fun logExtrasForEvent(event: AccessibilityEvent) { - - val events: MutableMap = mutableMapOf() - - val sourceDescription = event.source?.contentDescription() - if (!sourceDescription.isNullOrEmpty()) { - events.put("source_content_description", sourceDescription) - } - - // Output the event text. - val eventText = event.text.joinToString(", ") - if (eventText.isNotEmpty()) { - events.put("event_text", eventText) - } - - // Output the source text. - val sourceText = event.source?.text?.toString() - if (!sourceText.isNullOrEmpty()) { - events.put("source_text", sourceText) - } - - val eventTypeStr: String = AccessibilityEvent.eventTypeToString(event.eventType) - if (eventTypeStr.isNotEmpty()) { - events.put("event_type", eventTypeStr) - } - - val className = event.source?.className?.toString() - if (!className.isNullOrEmpty()) { - events.put("source_class_name", className) - } - - val packageName = event.packageName?.toString() - if (!packageName.isNullOrEmpty()) { - events.put("event_package_name", packageName) - } - - // Text editing properties. - val beforeText = event.beforeText?.toString() - if (!beforeText.isNullOrEmpty()) { - events.put("before_text", beforeText) - } - - val fromIndex = event.fromIndex - if (fromIndex != -1) { - events.put("from_index", fromIndex.toString()) - } - - val toIndex = event.toIndex - if (toIndex != -1) { - events.put("to_index", toIndex.toString()) - } - - val addedCount = event.addedCount - if (addedCount != -1) { - events.put("added_count", addedCount.toString()) - } - - val removedCount = event.removedCount - if (removedCount != -1) { - events.put("removed_count", removedCount.toString()) - } - - // Text traversal properties - val movementGranularity = event.movementGranularity - if (movementGranularity != 0) { - events.put("movement_granularity", movementGranularity.toString()) - } - - val action = event.action - if (action != 0) { - events.put("action", action.toString()) - } - - // Scrolling properties. - if (eventTypeStr == "TYPE_VIEW_SCROLLED") { - events.put("scroll_delta_x", event.scrollDeltaX.toString()) - events.put("scroll_delta_y", event.scrollDeltaY.toString()) - } - - // Report viewID so we know exactly where the event came from. - val viewId = event.source?.viewIdResourceName?.toString() - if (!viewId.isNullOrEmpty()) { - events.put("view_id", viewId) - } - - // Format [events] as a Python dict. - if (events.isNotEmpty()) { - events.put("event_timestamp_ms", event.eventTime.toString(10)) - // Check if we want to use gRPC. - if (LogFlags.grpcPort > 0) { - try { - val grpcTimeoutMillis = 1000L - val request = eventRequest { this.event.putAll(events) } - val response: EventResponse = - with(grpcStub) { - Log.i(TAG, "sending (blocking) gRPC request for event.") - runBlocking { withTimeout(grpcTimeoutMillis) { sendEvent(request) } } - } - if (response.error.isNotEmpty()) { - Log.w(TAG, "gRPC response.error: ${response.error}") - } else { - Log.i(TAG, "gRPC request for event succeeded.") - } - } catch (e: StatusException) { - Log.w(TAG, "gRPC StatusException; are you sure networking is turned on?") - Log.i(TAG, "extra: exception ['$e']") - resetGrpcStub() - } catch (e: TimeoutCancellationException) { - Log.w(TAG, "gRPC TimeoutCancellationException; are you sure networking is turned on?") - Log.i(TAG, "extra: exception ['$e']") - resetGrpcStub() - } - } else { - Log.w(TAG, "Can't log accessibility event because gRPC port has not been set.") - } - } - } - - /** Recursively climbs the accessibility tree until the root, collecting descriptions. */ - private fun AccessibilityNodeInfo?.contentDescription(): String { - if (this == null) { - return "" - } - - val descriptions = mutableListOf() - var current: AccessibilityNodeInfo? = this - while (current != null) { - val description = current.contentDescription - if (description != null) { - descriptions.add(description.toString()) - } - - current = current.parent - } - return descriptions.joinToString(", ") - } - - companion object { - private const val TAG = "AndroidRLTask" - private val creator = AccessibilityTreeCreator() - } -} diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarderTest.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarderTest.kt deleted file mode 100644 index b446db7a..00000000 --- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarderTest.kt +++ /dev/null @@ -1,516 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.androidenv.accessibilityforwarder - -import android.view.accessibility.AccessibilityEvent -import android.view.accessibility.AccessibilityNodeInfo -import android.view.accessibility.AccessibilityWindowInfo -import com.google.common.truth.Truth.assertThat -import io.grpc.Status -import io.grpc.StatusException -import io.grpc.inprocess.InProcessChannelBuilder -import io.grpc.inprocess.InProcessServerBuilder -import io.grpc.testing.GrpcCleanupRule -import org.junit.Assert.assertFalse -import org.junit.Rule -import org.junit.Test -import org.junit.runner.RunWith -import org.robolectric.RobolectricTestParameterInjector -import org.robolectric.Shadows.shadowOf - -@RunWith(RobolectricTestParameterInjector::class) -class AccessibilityForwarderTest { - - @get:Rule(order = 1) val cleanupRule = GrpcCleanupRule() - - class FakeAccessibilityService : A11yServiceGrpcKt.A11yServiceCoroutineImplBase() { - var sendForestChecker: (AndroidAccessibilityForest) -> String = { _ -> "" } - var sendEventChecker: (EventRequest) -> String = { _ -> "" } - - override suspend fun sendForest(request: AndroidAccessibilityForest) = forestResponse { - error = sendForestChecker(request) - } - - override suspend fun sendEvent(request: EventRequest) = eventResponse { - error = sendEventChecker(request) - } - } - - protected lateinit var forwarder: AccessibilityForwarder - protected val fakeA11yService = FakeAccessibilityService() - protected val channel by lazy { - val serverName: String = InProcessServerBuilder.generateName() - cleanupRule.register( - InProcessServerBuilder.forName(serverName) - .directExecutor() - .addService(fakeA11yService) - .build() - .start() - ) - cleanupRule.register(InProcessChannelBuilder.forName(serverName).directExecutor().build()) - } - - /** Initializes [forwarder] and [LogFlags] from the given args. */ - fun createForwarder( - logAccessibilityTree: Boolean = false, - a11yTreePeriodMs: Long = 0, - grpcHost: String = "10.0.2.2", - grpcPort: Int = 0, - a11yWindows: MutableList? = null, - ) { - LogFlags.logAccessibilityTree = logAccessibilityTree - LogFlags.a11yTreePeriodMs = a11yTreePeriodMs - LogFlags.grpcHost = grpcHost - LogFlags.grpcPort = grpcPort - forwarder = AccessibilityForwarder({ _, _ -> channel }) - if (a11yWindows == null) { - shadowOf(forwarder).setWindows(mutableListOf(AccessibilityWindowInfo.obtain())) - } else { - shadowOf(forwarder).setWindows(a11yWindows) - } - } - - @Test - fun onInterrupt_doesNotCrash() { - // Arrange. - createForwarder(logAccessibilityTree = false) - fakeA11yService.sendEventChecker = { _: EventRequest -> - assertFalse(true) // This should not be called. - "" // This should be unreachable - } - - // Act. - forwarder.onInterrupt() - - // Assert. - // See `sendEventChecker` above. - } - - @Test - fun onAccessibilityEvent_nullEventShouldBeIgnored() { - // Arrange. - createForwarder(logAccessibilityTree = false) - fakeA11yService.sendEventChecker = { _: EventRequest -> - assertFalse(true) // This should not be called. - "" // This should be unreachable - } - - // Act. - forwarder.onAccessibilityEvent(null) - - // Assert. - // See `sendEventChecker` above. - } - - @Test - fun onAccessibilityEvent_knownEventWithNoInformationShouldNotBeEmitted() { - // Arrange. - createForwarder(logAccessibilityTree = false) - var nodeInfo = AccessibilityNodeInfo() - nodeInfo.setContentDescription("") - var event = AccessibilityEvent() - shadowOf(event).setSourceNode(nodeInfo) - fakeA11yService.sendEventChecker = { _: EventRequest -> - assertFalse(true) // This should not be called. - "" // This should be unreachable - } - - // Act. - forwarder.onAccessibilityEvent(event) - - // Assert. - // See `sendEventChecker` above. - } - - @Test - fun onAccessibilityEvent_typeViewClicked_sendEventViaGrpc() { - // Arrange. - createForwarder(logAccessibilityTree = false, grpcPort = 1234) - forwarder = AccessibilityForwarder({ _, _ -> channel }) - var nodeInfo = AccessibilityNodeInfo() - nodeInfo.setContentDescription("My Content Description") - nodeInfo.setText("My Source Text") - nodeInfo.setClassName("AwesomeClass") - var event = AccessibilityEvent() - event.setEventTime(1357924680) - event.setEventType(AccessibilityEvent.TYPE_VIEW_CLICKED) - event.getText().add("Some text!") - event.setPackageName("some.loooong.package.name") - shadowOf(event).setSourceNode(nodeInfo) - fakeA11yService.sendEventChecker = { request: EventRequest -> - // Check that all fields are consistent with how they were set above. - assertThat(request.eventMap.get("event_type")).isEqualTo("TYPE_VIEW_CLICKED") - assertThat(request.eventMap.get("event_package_name")).isEqualTo("some.loooong.package.name") - assertThat(request.eventMap.get("source_content_description")) - .isEqualTo("My Content Description") - assertThat(request.eventMap.get("source_text")).isEqualTo("My Source Text") - assertThat(request.eventMap.get("source_class_name")).isEqualTo("AwesomeClass") - assertThat(request.eventMap.get("event_text")).isEqualTo("Some text!") - assertThat(request.eventMap.get("event_timestamp_ms")).isEqualTo("1357924680") - // No error message - "" - } - - // Act. - forwarder.onAccessibilityEvent(event) - - // Assert. - // See `sendEventChecker` above. - } - - @Test - fun onAccessibilityEvent_typeViewTextChanged_ensureAllFieldsForwarded() { - // Arrange. - createForwarder(logAccessibilityTree = false, grpcPort = 1234) - var nodeInfo = AccessibilityNodeInfo() - nodeInfo.setContentDescription("My Content Description") - nodeInfo.setText("My Source Text") - nodeInfo.setClassName("AwesomeClass") - var event = AccessibilityEvent() - event.setEventTime(1357924680) - event.setEventType(AccessibilityEvent.TYPE_VIEW_TEXT_CHANGED) - event.getText().add("Some text!") - event.fromIndex = 7 - event.beforeText = "Old words" - event.addedCount = 12 - event.removedCount = 9 - event.setPackageName("some.loooong.package.name") - shadowOf(event).setSourceNode(nodeInfo) - fakeA11yService.sendEventChecker = { request: EventRequest -> - // Check that all fields are consistent with how they were set above. - assertThat(request.eventMap.get("event_type")).isEqualTo("TYPE_VIEW_TEXT_CHANGED") - assertThat(request.eventMap.get("event_package_name")).isEqualTo("some.loooong.package.name") - assertThat(request.eventMap.get("source_content_description")) - .isEqualTo("My Content Description") - assertThat(request.eventMap.get("source_text")).isEqualTo("My Source Text") - assertThat(request.eventMap.get("source_class_name")).isEqualTo("AwesomeClass") - assertThat(request.eventMap.get("event_text")).isEqualTo("Some text!") - assertThat(request.eventMap.get("event_timestamp_ms")).isEqualTo("1357924680") - assertThat(request.eventMap.get("from_index")).isEqualTo("7") - assertThat(request.eventMap.get("before_text")).isEqualTo("Old words") - assertThat(request.eventMap.get("added_count")).isEqualTo("12") - assertThat(request.eventMap.get("removed_count")).isEqualTo("9") - assertFalse(request.eventMap.containsKey("to_index")) - assertFalse(request.eventMap.containsKey("view_id")) - assertFalse(request.eventMap.containsKey("action")) - assertFalse(request.eventMap.containsKey("movement_granularity")) - assertFalse(request.eventMap.containsKey("scroll_delta_x")) - assertFalse(request.eventMap.containsKey("scroll_delta_y")) - // No error message - "" - } - - // Act. - forwarder.onAccessibilityEvent(event) - - // Assert. - // See `sendEventChecker` above. - } - - @Test - fun onAccessibilityEvent_typeViewScrolled_ensureAllFieldsForwarded() { - // Arrange. - createForwarder(logAccessibilityTree = false, grpcPort = 1234) - var nodeInfo = AccessibilityNodeInfo() - nodeInfo.setContentDescription("My Content Description") - nodeInfo.setText("My Source Text") - nodeInfo.setClassName("AwesomeClass") - var event = AccessibilityEvent() - event.setEventTime(1357924680) - event.setEventType(AccessibilityEvent.TYPE_VIEW_SCROLLED) - event.getText().add("Some text!") - event.scrollDeltaX = 13 - event.scrollDeltaY = 27 - event.setPackageName("some.loooong.package.name") - shadowOf(event).setSourceNode(nodeInfo) - fakeA11yService.sendEventChecker = { request: EventRequest -> - // Check that all fields are consistent with how they were set above. - assertThat(request.eventMap.get("event_type")).isEqualTo("TYPE_VIEW_SCROLLED") - assertThat(request.eventMap.get("event_package_name")).isEqualTo("some.loooong.package.name") - assertThat(request.eventMap.get("source_content_description")) - .isEqualTo("My Content Description") - assertThat(request.eventMap.get("source_text")).isEqualTo("My Source Text") - assertThat(request.eventMap.get("source_class_name")).isEqualTo("AwesomeClass") - assertThat(request.eventMap.get("event_text")).isEqualTo("Some text!") - assertThat(request.eventMap.get("event_timestamp_ms")).isEqualTo("1357924680") - assertThat(request.eventMap.get("scroll_delta_x")).isEqualTo("13") - assertThat(request.eventMap.get("scroll_delta_y")).isEqualTo("27") - assertFalse(request.eventMap.containsKey("from_index")) - assertFalse(request.eventMap.containsKey("to_index")) - assertFalse(request.eventMap.containsKey("before_text")) - assertFalse(request.eventMap.containsKey("added_count")) - assertFalse(request.eventMap.containsKey("removed_count")) - // No error message - "" - } - - // Act. - forwarder.onAccessibilityEvent(event) - - // Assert. - // See `sendEventChecker` above. - } - - @Test - fun onAccessibilityEvent_typeViewTextTraversedAtMovementGranularity_ensureAllFieldsForwarded() { - // Arrange. - createForwarder(logAccessibilityTree = false, grpcPort = 1234) - var nodeInfo = AccessibilityNodeInfo() - nodeInfo.setContentDescription("My Content Description") - nodeInfo.setText("My Source Text") - nodeInfo.setClassName("AwesomeClass") - nodeInfo.viewIdResourceName = "this.big.old.view.id" - var event = AccessibilityEvent() - event.setEventTime(1357924680) - event.setEventType(AccessibilityEvent.TYPE_VIEW_TEXT_TRAVERSED_AT_MOVEMENT_GRANULARITY) - event.getText().add("Some text!") - event.setPackageName("some.loooong.package.name") - event.movementGranularity = 5 - event.fromIndex = 6 - event.toIndex = 8 - event.action = 23 - shadowOf(event).setSourceNode(nodeInfo) - fakeA11yService.sendEventChecker = { request: EventRequest -> - // Check that all fields are consistent with how they were set above. - assertThat(request.eventMap.get("event_type")) - .isEqualTo("TYPE_VIEW_TEXT_TRAVERSED_AT_MOVEMENT_GRANULARITY") - assertThat(request.eventMap.get("event_package_name")).isEqualTo("some.loooong.package.name") - assertThat(request.eventMap.get("source_content_description")) - .isEqualTo("My Content Description") - assertThat(request.eventMap.get("source_text")).isEqualTo("My Source Text") - assertThat(request.eventMap.get("source_class_name")).isEqualTo("AwesomeClass") - assertThat(request.eventMap.get("event_text")).isEqualTo("Some text!") - assertThat(request.eventMap.get("event_timestamp_ms")).isEqualTo("1357924680") - assertThat(request.eventMap.get("movement_granularity")).isEqualTo("5") - assertThat(request.eventMap.get("from_index")).isEqualTo("6") - assertThat(request.eventMap.get("to_index")).isEqualTo("8") - assertThat(request.eventMap.get("view_id")).isEqualTo("this.big.old.view.id") - assertThat(request.eventMap.get("action")).isEqualTo("23") - // No error message - "" - } - - // Act. - forwarder.onAccessibilityEvent(event) - - // Assert. - // See `sendEventChecker` above. - } - - @Test - fun onAccessibilityEvent_sendingevent_grpcTimeout() { - // Arrange. - createForwarder( - logAccessibilityTree = false, - a11yTreePeriodMs = 0, - grpcHost = "amazing.host", - grpcPort = 4321, - ) - var nodeInfo = AccessibilityNodeInfo() - nodeInfo.setContentDescription("My Content Description") - nodeInfo.setText("My Source Text") - nodeInfo.setClassName("AwesomeClass") - var event = AccessibilityEvent() - event.setEventTime(1357924680) - event.setEventType(AccessibilityEvent.TYPE_VIEW_CLICKED) - event.getText().add("Some text!") - event.setPackageName("some.loooong.package.name") - shadowOf(event).setSourceNode(nodeInfo) - fakeA11yService.sendEventChecker = { _ -> - // Delay the request to prompt a timeout - Thread.sleep(1500L) - "" // Return no error. - } - - // Act. - forwarder.onAccessibilityEvent(event) - - // Run a second request to ensure that the channel gets rebuilt. - fakeA11yService.sendEventChecker = { _ -> "" } - forwarder.onAccessibilityEvent(event) - - // Assert. - // See `sendEventChecker` above. - } - - @Test - fun onAccessibilityEvent_sendingevent_grpcStatusException() { - // Arrange. - createForwarder(logAccessibilityTree = false, grpcHost = "amazing.host", grpcPort = 4321) - var nodeInfo = AccessibilityNodeInfo() - nodeInfo.setContentDescription("My Content Description") - nodeInfo.setText("My Source Text") - nodeInfo.setClassName("AwesomeClass") - var event = AccessibilityEvent() - event.setEventTime(1357924680) - event.setEventType(AccessibilityEvent.TYPE_VIEW_CLICKED) - event.getText().add("Some text!") - event.setPackageName("some.loooong.package.name") - shadowOf(event).setSourceNode(nodeInfo) - fakeA11yService.sendEventChecker = { _ -> throw StatusException(Status.UNAVAILABLE) } - - // Act. - forwarder.onAccessibilityEvent(event) - - // Run a second request to ensure that the channel gets rebuilt. - fakeA11yService.sendEventChecker = { _ -> "" } - forwarder.onAccessibilityEvent(event) - - // Assert. - // See `sendEventChecker` above. - } - - @Test - fun logAccessibilityTreeFalse_doesNotLogAccessibilityTree() { - // Arrange. - createForwarder(logAccessibilityTree = false, a11yTreePeriodMs = 10, grpcPort = 13579) - fakeA11yService.sendForestChecker = { _: AndroidAccessibilityForest -> - assertFalse(true) // This should not be called. - "" // This should be unreachable - } - - // Act. - Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function. - - // Assert. - // See `sendForestChecker` above. - } - - @Test - fun grpcPortZero_doesNotSendTree() { - // Arrange. - createForwarder(logAccessibilityTree = true, a11yTreePeriodMs = 10, grpcPort = 0) - fakeA11yService.sendForestChecker = { _: AndroidAccessibilityForest -> - assertFalse(true) // This should not be called. - "" // This should be unreachable - } - - // Act. - Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function. - - // Assert. - // See `sendForestChecker` above. - } - - @Test - fun grpcPortPositive_shouldSendTreeViaGrpc() { - // Arrange. - val window = AccessibilityWindowInfo() - shadowOf(window).setType(AccessibilityWindowInfo.TYPE_SYSTEM) - createForwarder( - logAccessibilityTree = true, - a11yTreePeriodMs = 10, - grpcPort = 1234, - a11yWindows = mutableListOf(window), - ) - fakeA11yService.sendForestChecker = { request: AndroidAccessibilityForest -> - // Check that we get only a single window. - assertThat(request.windowsList.size).isEqualTo(1) - // And that its type is what we set above. - assertThat(request.windowsList[0].windowType) - .isEqualTo(AndroidAccessibilityWindowInfo.WindowType.TYPE_SYSTEM) - // The error message - "Something went wrong!" - } - - // Act. - Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function. - - // Assert. - // See `sendForestChecker` above. - } - - @Test - fun grpcPortPositiveAndHost_shouldSendTreeViaGrpc() { - // Arrange. - fakeA11yService.sendForestChecker = { request: AndroidAccessibilityForest -> - // Check that we get only a single window. - assertThat(request.windowsList.size).isEqualTo(1) - // And that its type is what we set above. - assertThat(request.windowsList[0].windowType) - .isEqualTo(AndroidAccessibilityWindowInfo.WindowType.TYPE_ACCESSIBILITY_OVERLAY) - "" // Return no error. - } - val window = AccessibilityWindowInfo() - shadowOf(window).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY) - createForwarder( - logAccessibilityTree = true, - a11yTreePeriodMs = 500, - grpcHost = "amazing.host", - grpcPort = 4321, - a11yWindows = mutableListOf(window), - ) - - // Act. - Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function. - - // Assert. - // See `sendForestChecker` above. - } - - @Test - fun sendingForest_grpcTimeout() { - // Arrange. - fakeA11yService.sendForestChecker = { _ -> - // Delay the request to prompt a timeout - Thread.sleep(1500L) - "" // Return no error. - } - val window = AccessibilityWindowInfo() - shadowOf(window).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY) - createForwarder( - logAccessibilityTree = true, - a11yTreePeriodMs = 10, - grpcHost = "amazing.host", - grpcPort = 4321, - a11yWindows = mutableListOf(window), - ) - - // Act. - Thread.sleep(2000) // Sleep a bit to give time to trigger the tree logging function. - - // Run a second request to ensure that the channel gets rebuilt. - fakeA11yService.sendForestChecker = { _ -> "" } - Thread.sleep(2000) // Sleep a bit to give time to trigger the tree logging function. - - // Assert. - // See `sendForestChecker` above. - } - - @Test - fun sendingForest_grpcStatusException() { - // Arrange. - val window = AccessibilityWindowInfo() - shadowOf(window).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY) - createForwarder( - logAccessibilityTree = true, - a11yTreePeriodMs = 10, - grpcHost = "amazing.host", - grpcPort = 4321, - a11yWindows = mutableListOf(window), - ) - fakeA11yService.sendForestChecker = { _ -> throw StatusException(Status.UNAVAILABLE) } - - // Act. - Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function. - - // Run a second request to ensure that the channel gets rebuilt. - fakeA11yService.sendForestChecker = { _ -> "" } - Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function. - - // Assert. - // See `sendForestChecker` above. - } -} diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreator.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreator.kt deleted file mode 100644 index aeaee706..00000000 --- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreator.kt +++ /dev/null @@ -1,235 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.androidenv.accessibilityforwarder - -import android.graphics.Rect -import android.util.Log -import android.view.accessibility.AccessibilityNodeInfo -import android.view.accessibility.AccessibilityWindowInfo -import com.google.androidenv.accessibilityforwarder.AndroidAccessibilityWindowInfo.WindowType -import java.util.concurrent.ConcurrentHashMap -import java.util.stream.Collectors -import kotlin.collections.mutableListOf -import kotlinx.coroutines.Deferred -import kotlinx.coroutines.async -import kotlinx.coroutines.awaitAll -import kotlinx.coroutines.runBlocking - -/** Helper methods for creating the android accessibility info extra. */ -class AccessibilityTreeCreator() { - - /** Creates an accessibility forest proto. */ - fun buildForest(windowInfos: List): AndroidAccessibilityForest { - val sourcesMap: ConcurrentHashMap = - ConcurrentHashMap() - val windows: List = - processWindowsAndBlock(windowInfos, sourcesMap) - return androidAccessibilityForest { this.windows += windows } - } - - private fun processWindowsAndBlock( - windowInfos: List, - sourcesMap: ConcurrentHashMap, - ): List { - val windows: List - runBlocking { windows = processWindows(windowInfos, sourcesMap) } - return windows - } - - private suspend fun processWindows( - windowInfos: List, - sourcesMap: ConcurrentHashMap, - ): List { - var windowInfoProtos = mutableListOf() - for (i in windowInfos.size - 1 downTo 0) { - val windowInfoProto = processWindow(windowInfos.get(i), sourcesMap) - windowInfoProto?.let { windowInfoProtos.add(windowInfoProto) } - } - return windowInfoProtos.toList() - } - - private suspend fun processWindow( - windowInfo: AccessibilityWindowInfo, - sources: ConcurrentHashMap, - ): AndroidAccessibilityWindowInfo? { - val bounds = Rect() - windowInfo.getBoundsInScreen(bounds) - val root: AccessibilityNodeInfo? = windowInfo.root - if (root == null) { - Log.i(TAG, "window root is null") - return androidAccessibilityWindowInfo { - this.tree = androidAccessibilityTree {} - this.isActive = windowInfo.isActive - this.id = windowInfo.id - this.layer = windowInfo.layer - this.isAccessibilityFocused = windowInfo.isAccessibilityFocused - this.isFocused = windowInfo.isFocused - this.boundsInScreen = convertToRectProto(bounds) - this.windowType = toWindowType(windowInfo.type) - } - } - val treeDeferred: Deferred - runBlocking { treeDeferred = async { processNodesInWindow(root, sources) } } - return androidAccessibilityWindowInfo { - this.tree = treeDeferred.await() - this.isActive = windowInfo.isActive - this.id = windowInfo.id - this.layer = windowInfo.layer - this.isAccessibilityFocused = windowInfo.isAccessibilityFocused - this.isFocused = windowInfo.isFocused - this.boundsInScreen = convertToRectProto(bounds) - this.windowType = toWindowType(windowInfo.type) - } - } - - private suspend fun processNodesInWindow( - root: AccessibilityNodeInfo, - sources: ConcurrentHashMap, - ): AndroidAccessibilityTree { - Log.d(TAG, "processNodesInWindow()") - val traversalQueue = ArrayDeque() - traversalQueue.add(ParentChildNodePair.builder().child(root).build()) - val uniqueIdsCache: UniqueIdsGenerator = UniqueIdsGenerator() - var currentDepth = 0 - val nodesDeferred = mutableListOf>() - val seenNodes: HashSet = HashSet() - seenNodes.add(root) - runBlocking { - while (!traversalQueue.isEmpty()) { - // Traverse the tree layer-by-layer. - // The first layer has only the root and depth 0. - // The second layer has all the root's children and depth 1. - for (nodesAtCurrentDepth in traversalQueue.size downTo 1) { - val nodePair: ParentChildNodePair = traversalQueue.removeFirst() - for (i in 0 until nodePair.child().childCount) { - val childNode: AccessibilityNodeInfo? = nodePair.child().getChild(i) - if (childNode != null && !seenNodes.contains(childNode)) { - traversalQueue.add( - ParentChildNodePair.builder().child(childNode).parent(nodePair.child()).build() - ) - seenNodes.add(childNode) - } - } - val thisDepth = currentDepth - var deferred = async { processNode(nodePair, sources, uniqueIdsCache, thisDepth) } - nodesDeferred.add(deferred) - } - currentDepth++ - } - } - return androidAccessibilityTree { this.nodes += nodesDeferred.awaitAll() } - } - - companion object { - private const val TAG = "AndroidRLTask" - } -} - -private fun processNode( - nodePair: ParentChildNodePair, - sourceBuilder: ConcurrentHashMap, - uniqueIdsCache: UniqueIdsGenerator, - nodeDepth: Int, -): AndroidAccessibilityNodeInfo { - val node: AccessibilityNodeInfo = nodePair.child() - val immutableNode: AndroidAccessibilityNodeInfo = - createAndroidAccessibilityNode( - node, - uniqueIdsCache.getUniqueId(node), - nodeDepth, - getChildUniqueIds(node, uniqueIdsCache), - ) - sourceBuilder.put(immutableNode, node) - return immutableNode -} - -private fun createAndroidAccessibilityNode( - node: AccessibilityNodeInfo, - nodeId: Int, - depth: Int, - childIds: List, -): AndroidAccessibilityNodeInfo { - val bounds = Rect() - node.getBoundsInScreen(bounds) - val actions = node.getActionList().stream().map(::createAction).collect(Collectors.toList()) - return androidAccessibilityNodeInfo { - this.actions += actions - this.boundsInScreen = convertToRectProto(bounds) - this.isCheckable = node.isCheckable - this.isChecked = node.isChecked - this.className = stringFromNullableCharSequence(node.getClassName()) - this.isClickable = node.isClickable - this.contentDescription = stringFromNullableCharSequence(node.getContentDescription()) - this.isEditable = node.isEditable - this.isEnabled = node.isEnabled - this.isFocusable = node.isFocusable - this.hintText = stringFromNullableCharSequence(node.getHintText()) - this.isLongClickable = node.isLongClickable - this.packageName = stringFromNullableCharSequence(node.getPackageName()) - this.isPassword = node.isPassword - this.isScrollable = node.isScrollable - this.isSelected = node.isSelected - this.text = stringFromNullableCharSequence(node.getText()) - this.textSelectionEnd = node.getTextSelectionEnd().toLong() - this.textSelectionStart = node.getTextSelectionStart().toLong() - this.viewIdResourceName = node.getViewIdResourceName() ?: "" - this.isVisibleToUser = node.isVisibleToUser - this.windowId = node.windowId - this.uniqueId = nodeId - this.childIds += childIds - this.drawingOrder = node.drawingOrder - this.tooltipText = stringFromNullableCharSequence(node.getTooltipText()) - this.depth = depth - } -} - -private fun createAction( - action: AccessibilityNodeInfo.AccessibilityAction -): AndroidAccessibilityAction = - AndroidAccessibilityAction.newBuilder() - .setId(action.id) - .setLabel(stringFromNullableCharSequence(action.label)) - .build() - -private fun getChildUniqueIds( - node: AccessibilityNodeInfo, - uniqueIdsCache: UniqueIdsGenerator, -): List { - val ids = mutableListOf() - for (childId in 0 until node.getChildCount()) { - val child: AccessibilityNodeInfo = node.getChild(childId) ?: continue - ids.add(uniqueIdsCache.getUniqueId(child)) - } - return ids.toList() -} - -fun stringFromNullableCharSequence(cs: CharSequence?): String = cs?.toString() ?: "" - -fun convertToRectProto(rect: Rect) = protoRect { - left = rect.left - top = rect.top - right = rect.right - bottom = rect.bottom -} - -private fun toWindowType(type: Int): WindowType = - when (type) { - AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY -> WindowType.TYPE_ACCESSIBILITY_OVERLAY - AccessibilityWindowInfo.TYPE_APPLICATION -> WindowType.TYPE_APPLICATION - AccessibilityWindowInfo.TYPE_INPUT_METHOD -> WindowType.TYPE_INPUT_METHOD - AccessibilityWindowInfo.TYPE_SYSTEM -> WindowType.TYPE_SYSTEM - AccessibilityWindowInfo.TYPE_SPLIT_SCREEN_DIVIDER -> WindowType.TYPE_SPLIT_SCREEN_DIVIDER - else -> WindowType.UNKNOWN_TYPE - } diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreatorTest.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreatorTest.kt deleted file mode 100644 index b05d23a3..00000000 --- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreatorTest.kt +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.androidenv.accessibilityforwarder - -import android.view.accessibility.AccessibilityNodeInfo -import android.view.accessibility.AccessibilityWindowInfo -import kotlin.test.assertEquals -import org.junit.Test -import org.junit.runner.RunWith -import org.robolectric.RobolectricTestRunner -import org.robolectric.Shadows.shadowOf - -@RunWith(RobolectricTestRunner::class) -class AccessibilityTreeCreatorTest { - - @Test - fun buildForest_buildsAccessibilityForestCorrectly() { - val creator = AccessibilityTreeCreator() - - val forest = creator.buildForest(mutableListOf(createWindowInfo())) - - assertEquals(forest.windowsCount, 1) - assertEquals(forest.getWindows(0).tree.nodesCount, 3) - var rootNode: AndroidAccessibilityNodeInfo? = null - var checkableNode: AndroidAccessibilityNodeInfo? = null - val nodes = forest.getWindows(0).tree.nodesList - for (i in nodes.size - 1 downTo 0) { - if (nodes[i].text == "root node") { - rootNode = nodes[i] - } - if (nodes[i].isCheckable == true) { - checkableNode = nodes[i] - } - } - assertEquals(rootNode?.childIdsCount, 2) - assertEquals(checkableNode?.text, "Check box") - } - - @Test - fun buildForest_noRootInWindow_returnsEmptyTree() { - val creator = AccessibilityTreeCreator() - val windowInfo = AccessibilityWindowInfo.obtain() - shadowOf(windowInfo).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY) - - val forest = creator.buildForest(mutableListOf(windowInfo)) - - assertEquals(0, forest.getWindows(0).tree.nodesList.size) - } - - private fun createAccessibilityNodeInfo(): AccessibilityNodeInfo { - val root = AccessibilityNodeInfo.obtain() - root.text = "root node" - root.isClickable = true - val accessibilityNodeInfo = AccessibilityNodeInfo.obtain() - accessibilityNodeInfo.viewIdResourceName = "test" - accessibilityNodeInfo.isClickable = true - accessibilityNodeInfo.isEditable = true - accessibilityNodeInfo.hintText = "Please enter your address" - shadowOf(root).addChild(accessibilityNodeInfo) - val anotherChildNode = AccessibilityNodeInfo.obtain() - anotherChildNode.isCheckable = true - anotherChildNode.text = "Check box" - shadowOf(root).addChild(anotherChildNode) - return root - } - - private fun createWindowInfo(): AccessibilityWindowInfo { - val windowInfo = AccessibilityWindowInfo.obtain() - shadowOf(windowInfo).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY) - shadowOf(windowInfo).setRoot(createAccessibilityNodeInfo()) - return windowInfo - } -} diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest.xml b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest.xml deleted file mode 100644 index debf611f..00000000 --- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest.xml +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest_lite.xml b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest_lite.xml deleted file mode 100644 index 7bf2e851..00000000 --- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest_lite.xml +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - - - - - - - diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiver.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiver.kt deleted file mode 100644 index 1ce5e763..00000000 --- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiver.kt +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.androidenv.accessibilityforwarder - -import android.content.BroadcastReceiver -import android.content.Context -import android.content.Intent -import android.util.Log - -/** Broadcast receiver responsible for enabling or disabling flags. */ -class FlagsBroadcastReceiver() : BroadcastReceiver() { - - override fun onReceive(context: Context?, intent: Intent?) { - val action = intent?.action - Log.i(TAG, "Received broadcast intent with action: " + action) - when (action) { - ACTION_ENABLE_ACCESSIBILITY_TREE_LOGS -> { - Log.i(TAG, "Enabling Accessibility Tree logging.") - LogFlags.logAccessibilityTree = true - } - ACTION_DISABLE_ACCESSIBILITY_TREE_LOGS -> { - Log.i(TAG, "Disabling Accessibility Tree logging.") - LogFlags.logAccessibilityTree = false - } - ACTION_SET_GRPC -> { - // The Android Emulator uses 10.0.2.2 as a redirect to the workstation's IP. Most often the - // gRPC server will be running locally so it makes sense to use this as the default value. - // See https://developer.android.com/studio/run/emulator-networking#networkaddresses. - val host = intent.getStringExtra("host") ?: "10.0.2.2" - // The TCP port to connect. If <=0 gRPC is disabled. - val port = intent.getIntExtra("port", 0) - Log.i(TAG, "Setting gRPC endpoint to ${host}:${port}.") - LogFlags.grpcHost = host - LogFlags.grpcPort = port - } - else -> Log.w(TAG, "Unknown action: ${action}") - } - } - - companion object { - private const val TAG = "FlagsBroadcastReceiver" - private const val ACTION_ENABLE_ACCESSIBILITY_TREE_LOGS = - "accessibility_forwarder.intent.action.ENABLE_ACCESSIBILITY_TREE_LOGS" - private const val ACTION_DISABLE_ACCESSIBILITY_TREE_LOGS = - "accessibility_forwarder.intent.action.DISABLE_ACCESSIBILITY_TREE_LOGS" - private const val ACTION_SET_GRPC = "accessibility_forwarder.intent.action.SET_GRPC" - } -} diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiverTest.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiverTest.kt deleted file mode 100644 index f643cfab..00000000 --- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiverTest.kt +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.androidenv.accessibilityforwarder - -import android.content.Intent -import com.google.common.truth.Truth.assertThat -import org.junit.Test -import org.junit.runner.RunWith -import org.robolectric.RobolectricTestRunner - -@RunWith(RobolectricTestRunner::class) -class FlagsBroadcastReceiverTest { - - @Test - fun onReceive_nullIntent_shouldNotLogAnything() { - // Arrange. - LogFlags.logAccessibilityTree = false - val receiver = FlagsBroadcastReceiver() - - // Act. - receiver.onReceive(context = null, intent = null) - - // Assert. - assertThat(LogFlags.logAccessibilityTree).isFalse() - } - - @Test - fun onReceive_nullIntent_actionShouldNotLogAnything() { - // Arrange. - LogFlags.logAccessibilityTree = false - val receiver = FlagsBroadcastReceiver() - val intent = Intent() - - // Act. - receiver.onReceive(context = null, intent = intent) - - // Assert. - assertThat(LogFlags.logAccessibilityTree).isFalse() - } - - @Test - fun onReceive_unknownIntent_actionShouldIssueWarning() { - // Arrange. - LogFlags.logAccessibilityTree = false - val receiver = FlagsBroadcastReceiver() - val intent = Intent("SOME_WEIRD_ACTION") - - // Act. - receiver.onReceive(context = null, intent = intent) - - // Assert. - assertThat(LogFlags.logAccessibilityTree).isFalse() - } - - @Test - fun onReceive_intentWithDisableAction_shouldDisableTreeLogging() { - // Arrange. - LogFlags.logAccessibilityTree = true - val receiver = FlagsBroadcastReceiver() - val intent = Intent("accessibility_forwarder.intent.action.DISABLE_ACCESSIBILITY_TREE_LOGS") - - // Act. - receiver.onReceive(context = null, intent = intent) - - // Assert. - assertThat(LogFlags.logAccessibilityTree).isFalse() - } - - @Test - fun onReceive_intentWithEnableAction_shouldEnableTreeLogging() { - // Arrange. - LogFlags.logAccessibilityTree = false - val receiver = FlagsBroadcastReceiver() - val intent = Intent("accessibility_forwarder.intent.action.ENABLE_ACCESSIBILITY_TREE_LOGS") - - // Act. - receiver.onReceive(context = null, intent = intent) - - // Assert. - assertThat(LogFlags.logAccessibilityTree).isTrue() - } - - @Test - fun onReceive_intentWithSetGrpcActionNoArgs_shouldDefaultToEmuIpAndPortZero() { - // Arrange. - LogFlags.grpcHost = "some_host" - LogFlags.grpcPort = 9999 - val receiver = FlagsBroadcastReceiver() - val intent = Intent("accessibility_forwarder.intent.action.SET_GRPC") - - // Act. - receiver.onReceive(context = null, intent = intent) - - // Assert. - assertThat(LogFlags.grpcHost).isEqualTo("10.0.2.2") - assertThat(LogFlags.grpcPort).isEqualTo(0) - } - - @Test - fun onReceive_intentWithSetGrpcActionWithHostNoPort_shouldDefaultPortToZero() { - // Arrange. - LogFlags.grpcHost = "some_host" - LogFlags.grpcPort = 9999 - val receiver = FlagsBroadcastReceiver() - val intent = - Intent("accessibility_forwarder.intent.action.SET_GRPC").apply { - putExtra("host", "awesome.server.ca") - } - - // Act. - receiver.onReceive(context = null, intent = intent) - - // Assert. - assertThat(LogFlags.grpcHost).isEqualTo("awesome.server.ca") - assertThat(LogFlags.grpcPort).isEqualTo(0) - } - - @Test - fun onReceive_intentWithSetGrpcActionWithPortNoHost_shouldDefaultHostToEmuIp() { - // Arrange. - LogFlags.grpcHost = "some_host" - LogFlags.grpcPort = 9999 - val receiver = FlagsBroadcastReceiver() - val intent = - Intent("accessibility_forwarder.intent.action.SET_GRPC").apply { putExtra("port", 54321) } - - // Act. - receiver.onReceive(context = null, intent = intent) - - // Assert. - assertThat(LogFlags.grpcHost).isEqualTo("10.0.2.2") - assertThat(LogFlags.grpcPort).isEqualTo(54321) - } - - @Test - fun onReceive_intentWithSetGrpcActionWithHostAndPort_shouldSetBoth() { - // Arrange. - LogFlags.grpcHost = "some_host" - LogFlags.grpcPort = 9999 - val receiver = FlagsBroadcastReceiver() - val intent = - Intent("accessibility_forwarder.intent.action.SET_GRPC").apply { - putExtra("host", "grpc.ca") - putExtra("port", 54321) - } - - // Act. - receiver.onReceive(context = null, intent = intent) - - // Assert. - assertThat(LogFlags.grpcHost).isEqualTo("grpc.ca") - assertThat(LogFlags.grpcPort).isEqualTo(54321) - } -} diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/LogFlags.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/LogFlags.kt deleted file mode 100644 index 7af6fbde..00000000 --- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/LogFlags.kt +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.androidenv.accessibilityforwarder - -/** - * Controls global settings in AccessibilityForwarder. - * - * Please note that this class is not thread safe. - */ -object LogFlags { - // Whether to log the accessibility tree. - var logAccessibilityTree: Boolean = false - // How frequent to emit a11y trees (in milliseconds). - var a11yTreePeriodMs: Long = 100 - - // The gRPC server to connect to. (Only available if grpcPort>0). - var grpcHost: String = "" - // If >0 this represents the gRPC port number to connect to. - var grpcPort: Int = 0 -} diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/ParentChildNodePair.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/ParentChildNodePair.kt deleted file mode 100644 index 4773a5ca..00000000 --- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/ParentChildNodePair.kt +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.androidenv.accessibilityforwarder - -import android.view.accessibility.AccessibilityNodeInfo -import com.google.auto.value.AutoValue - -/** Parent and child [AccessibilityNodeInfo] relationship. */ -@AutoValue -internal abstract class ParentChildNodePair { - abstract fun parent(): AccessibilityNodeInfo? - - abstract fun child(): AccessibilityNodeInfo - - /** [ParentChildNodePair] builder. */ - @AutoValue.Builder - abstract class Builder { - abstract fun parent(parent: AccessibilityNodeInfo?): Builder - - abstract fun child(child: AccessibilityNodeInfo): Builder - - abstract fun build(): ParentChildNodePair - } - - companion object { - @JvmStatic fun builder(): Builder = AutoValue_ParentChildNodePair.Builder() - } -} diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/UniqueIdsGenerator.kt b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/UniqueIdsGenerator.kt deleted file mode 100644 index eeedacd5..00000000 --- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/UniqueIdsGenerator.kt +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.androidenv.accessibilityforwarder - -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicInteger -import java.util.function.Function - -/** Thread-safe helper class for assigning a unique ID to an object. */ -internal class UniqueIdsGenerator { - private val nextId = AtomicInteger(0) - private val uniqueIdsByNode = ConcurrentHashMap() - - fun getUniqueId(a: A): Int { - return uniqueIdsByNode.computeIfAbsent(a, Function { _: A -> nextId.getAndIncrement() }) - } -} diff --git a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/res/xml/accessibility_forwarder_service.xml b/android_env/apps/java/com/google/androidenv/accessibilityforwarder/res/xml/accessibility_forwarder_service.xml deleted file mode 100644 index e73943d0..00000000 --- a/android_env/apps/java/com/google/androidenv/accessibilityforwarder/res/xml/accessibility_forwarder_service.xml +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - - - - diff --git a/android_env/components/__init__.py b/android_env/components/__init__.py deleted file mode 100644 index 2f66bf75..00000000 --- a/android_env/components/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/android_env/components/action_fns.py b/android_env/components/action_fns.py deleted file mode 100644 index e290e9d0..00000000 --- a/android_env/components/action_fns.py +++ /dev/null @@ -1,132 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Functions to convert actions between different components' formats.""" - -from absl import logging -from android_env.components import action_type as action_type_lib -from android_env.components import errors -from android_env.components import pixel_fns -from android_env.components.simulators import base_simulator -import numpy as np - - -def send_action_to_simulator( - action: dict[str, np.ndarray], - simulator: base_simulator.BaseSimulator, - screen_width: int, - screen_height: int, - num_fingers: int, -) -> bool: - """Sends the selected action to the given simulator. - - The simulator will interpret the action according to `action["action_type"]`. - The effect this action triggers in the Android OS will be determined by the - currently running application. - - Args: - action: action which will get interpreted as a touchscreen event. - simulator: The simulator that will receive the action. - screen_width: The width of the touchscreen in pixels. - screen_height: The height of the touchscreen in pixels. - num_fingers: The number of fingers used in this simulator. - """ - - try: - match action['action_type']: - # If the action is a TOUCH or LIFT, send a touch event to the simulator. - case action_type_lib.ActionType.TOUCH | action_type_lib.ActionType.LIFT: - prepared_action = _prepare_touch_action( - action, screen_width, screen_height, num_fingers - ) - simulator.send_touch(prepared_action) - # If the action is a key event, send a key event to the simulator. - case action_type_lib.ActionType.KEYDOWN: - simulator.send_key(action['keycode'].item(0), event_type='keydown') - case action_type_lib.ActionType.KEYUP: - simulator.send_key(action['keycode'].item(0), event_type='keyup') - case action_type_lib.ActionType.KEYPRESS: - simulator.send_key(action['keycode'].item(0), event_type='keypress') - except errors.SendActionError: - logging.exception('Unable to execute action: %r', action) - return False - - return True - - -def _prepare_touch_action( - action: dict[str, np.ndarray], - screen_width: int, - screen_height: int, - num_fingers: int, -) -> list[tuple[int, int, bool, int]]: - """Turns an AndroidEnv action into values that the simulator can interpret. - - Converts float-valued 'touch_position' to integer coordinates corresponding - to specific pixels, and 'action_type' to booleans indicating whether the - screen is touched at said location or not. The result of this function can - be sent directly to the underlying simulator (e.g. the Android Emulator, - virtual machine, or a phone). - - Args: - action: An action containing 'action_type' and 'touch_position'. - - Returns: - A tuple with the format (x: int, y: int, down/up: bool, finger_index: int). - """ - - touch_events = [] - for i, finger_action in enumerate(_split_touch_action(action, num_fingers)): - is_touch = finger_action['action_type'] == action_type_lib.ActionType.TOUCH - touch_position = finger_action['touch_position'] - touch_pixels = pixel_fns.touch_position_to_pixel_position( - touch_position, width_height=(screen_width, screen_height) - ) - touch_events.append((touch_pixels[0], touch_pixels[1], is_touch, i)) - return touch_events - - -def _split_touch_action( - action: dict[str, np.ndarray], num_fingers: int -) -> list[dict[str, np.ndarray]]: - """Splits a multitouch action into a list of single-touch actions.""" - - single_touch_actions = [{ - 'action_type': action['action_type'], - 'touch_position': action['touch_position'], - }] - for i in range(2, num_fingers + 1): - single_touch_actions.append({ - 'action_type': action[f'action_type_{i}'], - 'touch_position': action[f'touch_position_{i}'], - }) - return single_touch_actions - - -def lift_all_fingers_action(num_fingers: int) -> dict[str, np.ndarray]: - """A lift action with each finger.""" - - # There's always at least one finger. - lift_action = { - 'action_type': np.array(action_type_lib.ActionType.LIFT), - 'touch_position': np.array([0, 0]), - } - # Subsequent fingers have separate dict entries. - for i in range(2, num_fingers + 1): - lift_action |= { - f'action_type_{i}': np.array(action_type_lib.ActionType.LIFT), - f'touch_position_{i}': np.array([0, 0]), - } - return lift_action diff --git a/android_env/components/action_fns_test.py b/android_env/components/action_fns_test.py deleted file mode 100644 index 5612b6da..00000000 --- a/android_env/components/action_fns_test.py +++ /dev/null @@ -1,227 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -from android_env.components import action_fns -from android_env.components import action_type as action_type_lib -from android_env.components import errors -from android_env.components.simulators import base_simulator -import numpy as np - - -class ActionFnsTest(parameterized.TestCase): - - def test_send_action_to_simulator_missing_action_type(self): - """A `KeyError` should be raised if the action is missing "action_type".""" - - # Arrange. - simulator = mock.create_autospec(base_simulator.BaseSimulator) - action = {'some_key': np.array(123, np.int32)} - - # Act & Assert. - self.assertRaises( - KeyError, - action_fns.send_action_to_simulator, - action, - simulator, - 800, - 600, - 1, - ) - - def test_send_action_to_simulator_sendactionerror(self): - """Returns `False` if the simulator raises a SendActionError.""" - - # Arrange. - simulator = mock.create_autospec(base_simulator.BaseSimulator) - simulator.send_touch.side_effect = errors.SendActionError('oops!') - action = { - 'action_type': action_type_lib.ActionType.TOUCH, - 'touch_position': np.array([0.3, 0.5], np.float32), - } - - # Act. - output = action_fns.send_action_to_simulator( - action, - simulator, - 800, - 600, - 1, - ) - - # Assert. - self.assertFalse(output) - simulator.send_touch.assert_called_once() - - def test_send_action_to_simulator_touch_success_one_finger(self): - """Returns `True` with a proper 1-finger touch action.""" - - # Arrange. - simulator = mock.create_autospec(base_simulator.BaseSimulator) - action = { - 'action_type': action_type_lib.ActionType.TOUCH, - 'touch_position': np.array([0.2, 0.5], np.float32), - } - - # Act. - output = action_fns.send_action_to_simulator( - action, - simulator, - 800, - 600, - 1, - ) - - # Assert. - self.assertTrue(output) - simulator.send_touch.assert_called_once_with( - [(np.int32(160), np.int32(300), True, 0)] - ) - - def test_send_action_to_simulator_touch_success_multiple_finger(self): - """Returns `True` with a proper 3-finger touch action.""" - - # Arrange. - simulator = mock.create_autospec(base_simulator.BaseSimulator) - action = { - 'action_type': action_type_lib.ActionType.TOUCH, - 'touch_position': np.array([0.2, 0.5], np.float32), - 'action_type_2': action_type_lib.ActionType.LIFT, - 'touch_position_2': np.array([0.1, 0.2], np.float32), - 'action_type_3': action_type_lib.ActionType.TOUCH, - 'touch_position_3': np.array([0.5, 0.2], np.float32), - } - - # Act. - output = action_fns.send_action_to_simulator( - action, - simulator, - 800, - 600, - 3, - ) - - # Assert. - self.assertTrue(output) - simulator.send_touch.assert_called_once_with([ - (np.int32(160), np.int32(300), True, 0), - (np.int32(80), np.int32(120), False, 1), - (np.int32(400), np.int32(120), True, 2), - ]) - - def test_send_action_to_simulator_keydown_success(self): - """Returns `True` with a proper keydown action.""" - - # Arrange. - simulator = mock.create_autospec(base_simulator.BaseSimulator) - action = { - 'action_type': action_type_lib.ActionType.KEYDOWN, - 'keycode': np.array([21], np.int32), - } - - # Act. - output = action_fns.send_action_to_simulator( - action, - simulator, - 800, - 600, - 1, - ) - - # Assert. - self.assertTrue(output) - simulator.send_key.assert_called_once_with(21, event_type='keydown') - - def test_send_action_to_simulator_keyup_success(self): - """Returns `True` with a proper keyup action.""" - - # Arrange. - simulator = mock.create_autospec(base_simulator.BaseSimulator) - action = { - 'action_type': action_type_lib.ActionType.KEYUP, - 'keycode': np.array([42], np.int32), - } - - # Act. - output = action_fns.send_action_to_simulator( - action, - simulator, - 800, - 600, - 1, - ) - - # Assert. - self.assertTrue(output) - simulator.send_key.assert_called_once_with(42, event_type='keyup') - - def test_send_action_to_simulator_keypress_success(self): - """Returns `True` with a proper keypress action.""" - - # Arrange. - simulator = mock.create_autospec(base_simulator.BaseSimulator) - action = { - 'action_type': action_type_lib.ActionType.KEYPRESS, - 'keycode': np.array([96], np.int32), - } - - # Act. - output = action_fns.send_action_to_simulator( - action, - simulator, - 800, - 600, - 1, - ) - - # Assert. - self.assertTrue(output) - simulator.send_key.assert_called_once_with(96, event_type='keypress') - - @parameterized.named_parameters( - ( - 'one_finger', - 1, - { - 'action_type': np.array(action_type_lib.ActionType.LIFT), - 'touch_position': np.array([0, 0]), - }, - ), - ( - 'two_fingers', - 2, - { - 'action_type': np.array(action_type_lib.ActionType.LIFT), - 'touch_position': np.array([0, 0]), - 'action_type_2': np.array(action_type_lib.ActionType.LIFT), - 'touch_position_2': np.array([0, 0]), - }, - ), - ) - def test_lift_all_fingers_action( - self, num_fingers: int, expected_action: dict[str, np.ndarray] - ): - """Returns the expected action.""" - - output = action_fns.lift_all_fingers_action(num_fingers) - for k, v in expected_action.items(): - np.testing.assert_array_equal(v, output[k]) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/action_type.py b/android_env/components/action_type.py deleted file mode 100644 index da57aa4f..00000000 --- a/android_env/components/action_type.py +++ /dev/null @@ -1,52 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The different kinds of actions that AndroidEnv supports. - -The native action space of AndroidEnv consists of a tuple consisting of -- A position (x, y) ∈ [0, 1] x [0, 1], determining the location of the action on - the screen, and -- A discrete value, indicating the action type, which is in this file. - -See https://arxiv.org/abs/2105.13231, section 2.2 for details. -""" - -import enum - - -@enum.unique -class ActionType(enum.IntEnum): - """Integer values to describe each supported action in AndroidEnv. - - Note for KEY* types: - - Only meaningful if connected to a _physical_ keyboard, _not_ virtual - keyboard. - - Added afterwards so they did not appear in the paper. - - Attributes: - TOUCH: Touching the screen at a location. - LIFE: Lifting the (imaginary) pointer from the screen at a location. - REPEAT: Repeating the last chosen action. - KEYDOWN: Sending a key down event. - KEYUP: Sending a key up event. - KEYPRESS: Sending a key down event, immediately followed by a key up event. - """ - - TOUCH = 0 - LIFT = 1 - REPEAT = 2 - KEYDOWN = 3 - KEYUP = 4 - KEYPRESS = 5 diff --git a/android_env/components/adb_call_parser.py b/android_env/components/adb_call_parser.py deleted file mode 100644 index 53efe478..00000000 --- a/android_env/components/adb_call_parser.py +++ /dev/null @@ -1,915 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Processes adb_pb2.AdbRequest commands.""" - -import os -import re -import subprocess -import sys -import tempfile - -from absl import logging -from android_env.components import adb_controller as adb_control -from android_env.proto import adb_pb2 - -# A mapping from a Button enum to keycode strings. -# -# Please see https://developer.android.com/reference/android/view/KeyEvent -# -# We currently only accept the following entries: -_BUTTON_TO_KEYCODE = { - adb_pb2.AdbRequest.PressButton.Button.HOME: 'KEYCODE_HOME', - adb_pb2.AdbRequest.PressButton.Button.BACK: 'KEYCODE_BACK', - adb_pb2.AdbRequest.PressButton.Button.ENTER: 'KEYCODE_ENTER', -} - - -class AdbCallParser: - """Parses AdbRequest messages and executes corresponding adb commands.""" - - def __init__(self, adb_controller: adb_control.AdbController): - self._adb_controller = adb_controller - self._handlers = { - 'install_apk': self._install_apk, - 'start_activity': self._start_activity, - 'force_stop': self._force_stop, - 'tap': self._tap, - 'press_button': self._press_button, - 'start_screen_pinning': self._start_screen_pinning, - 'send_broadcast': self._send_broadcast, - 'uninstall_package': self._handle_uninstall_package, - 'get_current_activity': self._get_current_activity, - 'get_orientation': self._get_orientation, - 'push': self._push, - 'pull': self._pull, - 'input_text': self._input_text, - 'settings': self._handle_settings, - 'generic': self._handle_generic, - 'package_manager': self._handle_package_manager, - 'dumpsys': self._handle_dumpsys, - } - - def _execute_command( - self, command_args: list[str], timeout: float | None - ) -> tuple[adb_pb2.AdbResponse, bytes]: - """Executes the command, catches errors and populates the response status. - - Args: - command_args: a list of arguments for the ADB request. - timeout: Timeout in seconds. - - Returns: - A tuple of the AdbResponse with the status populated, and the output - bytes from the command. - """ - response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK) - command_output = b'' - try: - command_output = self._adb_controller.execute_command( - command_args, timeout=timeout) - except subprocess.CalledProcessError as adb_error: - if adb_error.stdout is not None: - response.status = adb_pb2.AdbResponse.Status.ADB_ERROR - response.error_message = adb_error.stdout - except subprocess.TimeoutExpired: - response.status = adb_pb2.AdbResponse.Status.TIMEOUT - response.error_message = 'Timeout' - - return response, command_output - - def parse(self, request: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse: - """Executes `request` and returns an appropriate response.""" - - response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK) - command_type = request.WhichOneof('command') - logging.info('AdbRequest command type: %s', command_type) - if command_type is None: - response.status = adb_pb2.AdbResponse.Status.UNKNOWN_COMMAND - response.error_message = 'AdbRequest.command is None.' - return response - - if request.timeout_sec < 0: - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ('AdbRequest.timeout_sec cannot be negative. ' - f'Got: {request.timeout_sec}') - return response - - timeout: float | None = request.timeout_sec or None - return self._handlers[command_type](request, timeout) - - def _force_stop( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Stops an application. - - Args: - request: The external request containing the package to force stop. - timeout: Optional time limit in seconds. - - Returns: - An AdbResponse. - """ - - force_stop = request.force_stop - response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK) - if not force_stop.package_name: - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = '`force_stop.package_name` cannot be empty.' - return response - - response, _ = self._execute_command( - ['shell', 'am', 'force-stop', force_stop.package_name], timeout) - - return response - - def _fetch_current_task_id( - self, full_activity_name: str, timeout: float | None = None - ) -> int: - """Returns the task ID of the given `full_activity_name`. - - Args: - full_activity_name: The full name of the activity whose corresponding - task id we are looking for. - timeout: Optional time limit in seconds. - Returns: - task_id: An integer corresponding to the specified activity. - """ - - stack = self._adb_controller.execute_command( - ['shell', 'am', 'stack', 'list'], timeout=timeout) - lines = stack.decode('utf-8').splitlines() - - regex = re.compile( - r'^\ *taskId=(?P[0-9]*): (?P[^\s]*) .*visible=true' - r'.*topActivity=ComponentInfo{(?P[^\s]*)}$') - - for line in lines: - match = regex.search(line) - if match is None: - continue - - current_task_id_str = match.group('id') - base_activity = match.group('base_activity') - top_activity = match.group('top_activity') - - # If neither of the matched activities equals the activity we are - # looking for, we discard their task id and continue the search. - if full_activity_name not in {base_activity, top_activity}: - logging.info('Full activity %s was not found in current line %s', - full_activity_name, line) - continue - - # Otherwise return the integer task id. - try: - return int(current_task_id_str) - except ValueError: - logging.info('Failed to parse task ID [%r].', current_task_id_str) - - # At this point if we could not find a task ID, there's nothing we can do. - logging.error('Could not find current activity in stack list: %r', lines) - return -1 - - def _start_screen_pinning( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Pins an application. - - Args: - request: The request containing the activity to pin. - timeout: Optional time limit in seconds. - - Returns: - An AdbResponse. - """ - - full_activity = request.start_screen_pinning.full_activity - response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK) - if not full_activity: - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ( - '`start_screen_pinning.full_activity` cannot be empty.') - return response - - current_task_id = self._fetch_current_task_id(full_activity, timeout) - if current_task_id == -1: - response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR - response.error_message = ('Could not find task ID for activity ' - f'[{full_activity}]') - return response - - response, _ = self._execute_command( - ['shell', 'am', 'task', 'lock', - str(current_task_id)], timeout=timeout) - - return response - - def _send_broadcast( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Sends a broadcast. - - Args: - request: The request with the information for the broadcast event. - timeout: Optional time limit in seconds. - - Returns: - An AdbResponse. - """ - - send_broadcast = request.send_broadcast - response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK) - if not send_broadcast.action: - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ('`send_broadcast.{action}` cannot be empty.') - return response - - if send_broadcast.component: - component_args = ['-n', send_broadcast.component] - else: - component_args = [] - - response, _ = self._execute_command( - ['shell', 'am', 'broadcast', '-a', send_broadcast.action] - + component_args, - timeout=timeout, - ) - - return response - - def _install_apk( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Installs an app given its local path in the filesystem. - - Args: - request: The external request with an install_apk field. - Contains information for the .apk installation. - timeout: Optional time limit in seconds. - - Returns: - An AdbResponse. - """ - - install_apk = request.install_apk - response = adb_pb2.AdbResponse() - location_type = install_apk.WhichOneof('location') - logging.info('location_type: %s', location_type) - - match location_type: - case 'filesystem': - fpath = install_apk.filesystem.path - if not os.path.exists(fpath): - response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR - response.error_message = f'Could not find local_apk_path: {fpath}' - return response - - response, _ = self._execute_command( - ['install', '-r', '-t', '-g', fpath], timeout=timeout - ) - case 'blob': - - # `delete_on_close` was only added in Python 3.12 so we add a switch - # here to still support previous Python versions. - if sys.version_info >= (3, 12): - kwargs = {'suffix': '.apk', 'delete_on_close': False} - else: - kwargs = {'suffix': '.apk'} - - with tempfile.NamedTemporaryFile(**kwargs) as f: - fpath = f.name - f.write(install_apk.blob.contents) - - response, _ = self._execute_command( - ['install', '-r', '-t', '-g', fpath], timeout=timeout - ) - case _: - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ( - f'Unsupported `install_apk.location` type: {location_type}' - ) - return response - - return response - - def _start_activity( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Starts a given activity. - - Options for `start_activity`: - `am start` command options: - -D: enable debugging - -W: wait for launch to complete - --start-profiler : start profiler and send results to - -P : like above, but profiling stops when app goes idle - -R: repeat the activity launch times. Prior to each repeat, - the top activity will be finished. - -S: force stop the target app before starting the activity - --opengl-trace: enable tracing of OpenGL functions - - Args: - request: The request with information on what activity to start. - timeout: Optional time limit in seconds. - - Returns: - An AdbResponse. If successful, StartActivityResponse will contain the - activity name and adb command output. - """ - - activity = request.start_activity.full_activity - if not activity: - return adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION, - error_message='`start_activity.full_activity` cannot be empty.') - - force_stop = '-S' if request.start_activity.force_stop else '' - response, command_output = self._execute_command( - ['shell', 'am', 'start', force_stop, '-W', '-n', activity] + - list(request.start_activity.extra_args or []), - timeout=timeout) - - # Check command output for potential errors. - expected_error = re.compile(r""".*Error.*""", re.VERBOSE) - if expected_error.match(str(command_output)): - return adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.INTERNAL_ERROR, - error_message=f'start_activity failed with error: {command_output}') - - response.start_activity.full_activity = activity - response.start_activity.output = command_output - return response - - def _press_button( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Presses a keyboard key. - - Args: - request: The request with information on what button to press. - timeout: Optional time limit in seconds. - - Returns: - An AdbResponse. - """ - - button = request.press_button.button - if button not in _BUTTON_TO_KEYCODE: - return adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION, - error_message=('PressButton.button must be one of ' - f'[{_BUTTON_TO_KEYCODE.keys()}]. ' - f'Got: {button}. Please see `adb.proto`.')) - - keycode = _BUTTON_TO_KEYCODE[button] - response, command_output = self._execute_command( - ['shell', 'input', 'keyevent', keycode], timeout=timeout) - response.press_button.output = command_output - return response - - def _handle_uninstall_package( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Handles UninstallPackage messages. - - Args: - request: The specification of what to uninstall. - timeout: Optional time limit in seconds. - - Returns: - An AdbResponse - """ - - package_name = request.uninstall_package.package_name - response = adb_pb2.AdbResponse() - # Every UninstallPackage should have a package_name. - if not package_name: - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ( - '`uninstall_package.package_name` cannot be empty.') - return response - - # Get list of installed packages and issue an uninstall only if it's - # already installed. - package_response = self._handle_package_manager( - adb_pb2.AdbRequest( - package_manager=adb_pb2.AdbRequest.PackageManagerRequest( - list=adb_pb2.AdbRequest.PackageManagerRequest.List( - packages=adb_pb2.AdbRequest.PackageManagerRequest.List - .Packages())))) - if package_name in package_response.package_manager.list.items: - response, _ = self._execute_command(['uninstall', package_name], timeout) - else: - msg = (f'Cannot uninstall {package_name} since it is not installed.') - logging.warning(msg) - response.error_message = msg - - return response - - def _get_current_activity( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Fetches current activity. - - Args: - request: The request with the `.get_current_activity` field set. This is - unused, but it's in the signature so that all calls are uniform. - timeout: Optional time limit in seconds. - - Returns: - AdbResponse containing the current activity. - """ - - del request # Unused. - - response, visible_task = self._execute_command( - ['shell', 'am', 'stack', 'list', '|', 'grep', '-E', 'visible=true'], - timeout=timeout) - - if response.status != adb_pb2.AdbResponse.Status.OK: - return response - - if not visible_task: - _, am_stack_list = self._execute_command(['shell', 'am', 'stack', 'list'], - timeout=timeout) - response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR - response.error_message = ('Empty visible_task. `am stack list`: ' - f'{am_stack_list}') - return response - - visible_task = visible_task.decode('utf-8') - if sys.platform == 'win32': - visible_task_list = re.findall( - r'visible=true topActivity=ComponentInfo{(.+?)}', visible_task) - if not visible_task_list: - visible_task = '' - else: - visible_task = 'ComponentInfo{' + visible_task_list[0] + '}' - - p = re.compile(r'.*\{(.*)\}') - matches = p.search(visible_task) - if matches is None: - _, am_stack_list = self._execute_command(['shell', 'am', 'stack', 'list'], - timeout=timeout) - response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR - response.error_message = ( - 'Could not extract current activity. Will return nothing. ' - f'`am stack list`: {am_stack_list}') - return response - - response.get_current_activity.full_activity = matches.group(1) - return response - - def _get_orientation( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Fetches current device orientation. - - Args: - request: The request with the `.get_orientation` field set. - timeout: Optional time limit in seconds. - - Returns: - AdbResponse containing the current device orientation. This is - unused, but it's in the signature so that all calls are uniform. - """ - - del request # Unused. - - logging.info('Getting orientation...') - response = self._handle_dumpsys( - adb_pb2.AdbRequest( - dumpsys=adb_pb2.AdbRequest.DumpsysRequest(service='input')), - timeout=timeout) - output = response.dumpsys.output - if not output: - logging.error('Empty dumpsys output.') - response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR - response.error_message = 'Failed to execute `dumpsys input`' - return response - - output = output.decode('utf-8') - lines = output.split('\n') # Split by lines. - skip_next = False - for line in lines: - # There may be multiple devices in output. An invalid device can be - # identified by negative PhysicalWidth. - physical_width = re.match(r'\s+PhysicalWidth:\s+(-?\d+)px', line) - if physical_width: - skip_next = int(physical_width.group(1)) < 0 - - surface_orientation = re.match( - r'\s+(SurfaceOrientation|InputDeviceOrientation):\s+(\d)', line - ) - - if surface_orientation is not None: - if skip_next: - continue - if surface_orientation.re.groups < 2: - continue - orientation = surface_orientation.group(2) - logging.info('Done getting orientation: %r', orientation) - response.get_orientation.orientation = int(orientation) - return response - - response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR - response.error_message = ( - 'Could not find SurfaceOrientation/InputDeviceOrientation in dumpsys ' - 'output' - ) - return response - - def _push( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Uploads contents to the device. - - Args: - request: The request with the contents to push to the device. - timeout: Optional time limit in seconds. - - Returns: - An empty AdbResponse. - """ - - path = request.push.path - if not path: - return adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION, - error_message='Push.path is empty.') - - # Create temporary file with `push` contents. - with tempfile.NamedTemporaryFile(delete=False) as f: - fname = f.name - f.write(request.push.content) - # Issue `adb push` command to upload file. - logging.info('Uploading %r to %r.', fname, path) - response, _ = self._execute_command(['push', fname, path], timeout=timeout) - # Delete it. - os.remove(fname) - - return response - - def _pull( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Downloads file content from the device. - - Args: - request: The request with the information on what to get from the device. - timeout: Optional time limit in seconds. - - Returns: - An AdbResponse with the contents of the specified file. - """ - - path = request.pull.path - if not path: - return adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION, - error_message='Pull.path is empty.') - - # Issue `adb pull` command to copy it to a temporary file. - with tempfile.NamedTemporaryFile(delete=False) as f: - fname = f.name - logging.info('Downloading %r to %r.', path, fname) - response, _ = self._execute_command(['pull', path, fname], - timeout=timeout) - # Read the content of the file. - with open(fname, 'rb') as f: - response.pull.content = f.read() - # Delete it. - os.remove(fname) - - return response - - def _input_text( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Inserts text as keyboard events. - - Args: - request: The external request. - timeout: Optional time limit in seconds. - - Returns: - An AdbResponse - """ - - text = request.input_text.text - if not text: - return adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION, - error_message='InputText.text is empty.') - - response, _ = self._execute_command(['shell', 'input', 'text', text], - timeout=timeout) - return response - - def _tap( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Taps the device screen. - - Args: - request: The request with information on where to tap the screen. - timeout: Optional time limit in seconds. - - Returns: - An AdbResponse - """ - - x = request.tap.x - y = request.tap.y - # Check for negative coordinates. - # Notice that zero coordinates are valid coordinates (i.e. the first - # column/row of the screen). - if x < 0 or y < 0: - return adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION, - error_message=( - f'Tap coordinates must be non-negative. Got: {request.tap}.')) - - response, _ = self._execute_command( - ['shell', 'input', 'tap', str(x), - str(y)], timeout=timeout) - - return response - - def _handle_settings( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Handles SettingsRequest messages. - - Args: - request: The specification of what to do with settings. - timeout: Optional time limit in seconds. - - Returns: - An AdbResponse - """ - - request = request.settings - response = adb_pb2.AdbResponse() - # Every SettingsRequest should have a namespace. - if request.name_space == adb_pb2.AdbRequest.SettingsRequest.Namespace.UNKNOWN: - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ( - f'Unknown SettingsRequest.name_space. Got: {request}.') - return response - - namespace = adb_pb2.AdbRequest.SettingsRequest.Namespace.Name( - request.name_space).lower() - - match request.WhichOneof('verb'): - case 'get': - get = request.get - if not get.key: - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ( - f'Empty SettingsRequest.get.key. Got: {request}.' - ) - return response - response, command_output = self._execute_command( - ['shell', 'settings', 'get', namespace, get.key], timeout=timeout - ) - response.settings.output = command_output - case 'put': - put = request.put - if not put.key or not put.value: - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ( - f'Empty SettingsRequest.put key or value. Got: {request}.' - ) - return response - response, command_output = self._execute_command( - ['shell', 'settings', 'put', namespace, put.key, put.value], - timeout=timeout, - ) - response.settings.output = command_output - case 'delete_key': - delete = request.delete_key - if not delete.key: - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ( - f'Empty SettingsRequest.delete_key.key. Got: {request}.' - ) - return response - response, command_output = self._execute_command( - ['shell', 'settings', 'delete', namespace, delete.key], - timeout=timeout, - ) - response.settings.output = command_output - case 'reset': - reset = request.reset - # At least one of `package_name` or `mode` should be given. - if ( - not reset.package_name - and reset.mode - == adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNKNOWN - ): - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ( - 'At least one of SettingsRequest.reset package_name or mode' - f' should be given. Got: {request}.' - ) - return response - - mode = adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.Name( - reset.mode - ).lower() - arg = reset.package_name or mode - response, command_output = self._execute_command( - ['shell', 'settings', 'reset', namespace, arg], timeout=timeout - ) - response.settings.output = command_output - case 'list': - response, command_output = self._execute_command( - ['shell', 'settings', 'list', namespace], timeout=timeout - ) - response.settings.output = command_output - case _: - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ( - f'Unknown SettingsRequest.verb. Got: {request}.' - ) - - return response - - def _handle_generic( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Handles GenericRequest messages. - - Args: - request: The request with the `.generic` field set indicating what `adb` - shell command to issue - timeout: Optional time limit in seconds. - - Returns: - An AdbResponse - """ - - response, command_output = self._execute_command( - list(request.generic.args), timeout) - response.generic.output = command_output - return response - - def _handle_package_manager( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Handles PackageManagerRequest messages. - - Args: - request: The request with the `.package_manager` field set containing the - sub-commands to issue to `adb pm`. - timeout: Optional time limit in seconds. - - Returns: - An AdbResponse. - """ - - request = request.package_manager - response = adb_pb2.AdbResponse() - - match request.WhichOneof('verb'): - case 'list': - what = request.list.WhichOneof('what') - response, output = self._execute_command( - ['shell', 'pm', 'list', what], timeout=timeout - ) - - if output: - items = output.decode('utf-8').split() - # Remove prefix for each item. - prefix = { - 'features': 'feature:', - 'libraries': 'library:', - 'packages': 'package:', - }[what] - items = [x[len(prefix) :] for x in items if x.startswith(prefix)] - response.package_manager.list.items.extend(items) - response.package_manager.output = output - case 'clear': - package_name = request.clear.package_name - if not package_name: - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ( - f'Empty PackageManagerRequest.clear.package_name. Got: {request}.' - ) - return response - - args = ['shell', 'pm', 'clear', package_name] - if request.clear.user_id: - args.insert(3, '-f') - args.insert(4, request.clear.user_id) - response, response.package_manager.output = self._execute_command( - args, timeout=timeout - ) - case 'grant': - grant = request.grant - if not grant.package_name: - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = '`grant.package_name` cannot be empty.' - return response - - if not grant.permissions: - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = '`grant.permissions` cannot be empty.' - return response - - for permission in grant.permissions: - logging.info('Granting permission: %r', permission) - response, response.package_manager.output = self._execute_command( - ['shell', 'pm', 'grant', grant.package_name, permission], - timeout=timeout, - ) - - return response - - def _handle_dumpsys( - self, request: adb_pb2.AdbRequest, timeout: float | None = None - ) -> adb_pb2.AdbResponse: - """Handles DumpsysRequest messages. - - Args: - request: The request with the `.dumpsys` field set containing - sub-commands to `adb dumpsys` shell command.. - timeout: Optional time limit in seconds. - - Returns: - An AdbResponse. - """ - - request = request.dumpsys - cmd = ['shell', 'dumpsys'] - - if request.timeout_sec < 0 or request.timeout_ms < 0: - response = adb_pb2.AdbResponse() - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ( - 'DumpsysRequest.timeout_{sec, ms} should be non-negative. ' - f'Got: {request}.') - return response - - if request.list_only: - # `-l` cannot be combined with the following options. - if request.service or request.args or request.skip_services: - response = adb_pb2.AdbResponse() - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ( - 'DumpsysRequest.list_only cannot be combined with other options. ' - f'Got: {request}.') - return response - - cmd.append('-l') - - if request.timeout_sec > 0: - cmd.append('-t') - cmd.append(str(request.timeout_sec)) - elif request.timeout_ms > 0: - cmd.append('-T') - cmd.append(str(request.timeout_ms)) - - if request.priority != adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.UNSET: - cmd.append('--priority') - cmd.append(adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.Name( - request.priority)) - - if request.skip_services: - if request.service: - response = adb_pb2.AdbResponse() - response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - response.error_message = ( - 'DumpsysRequest.skip_services cannot be combined with `service`. ' - f'Got: {request}.') - return response - - cmd.append('--skip') - cmd.append(','.join(request.skip_services)) - - if request.service: - cmd.append(request.service) - - if request.args: - cmd += list(request.args) - - if request.proto: - cmd.append('--proto') - - response, response.dumpsys.output = self._execute_command( - cmd, timeout=timeout) - - return response diff --git a/android_env/components/adb_call_parser_test.py b/android_env/components/adb_call_parser_test.py deleted file mode 100644 index 7b178d71..00000000 --- a/android_env/components/adb_call_parser_test.py +++ /dev/null @@ -1,1215 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import builtins -import os -import subprocess -import sys -import tempfile -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -from android_env.components import adb_call_parser -from android_env.components import adb_controller -from android_env.proto import adb_pb2 - - -class AdbCallParserTest(parameterized.TestCase): - - def test_unknown_command(self): - """Gets UNKNOWN_COMMAND for an empty request.""" - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - response = parser.parse(request) - self.assertEqual( - response.status, adb_pb2.AdbResponse.Status.UNKNOWN_COMMAND - ) - - def test_invalid_timeout(self): - """AdbRequest.timeout_sec must be positive.""" - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.tap.x = 123 - request.timeout_sec = -5 - response = parser.parse(request) - self.assertEqual( - response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - ) - - @mock.patch.object(os.path, 'exists', autospec=True) - def test_install_apk_file_not_found(self, mock_exists): - """Should fail installing APK when it is not found.""" - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.install_apk.filesystem.path = '/my/home/game.apk' - mock_exists.return_value = False - - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_not_called() - - @mock.patch.object(os.path, 'exists', autospec=True) - def test_install_apk_successful(self, mock_exists): - """Should succeed installing an arbitrary APK.""" - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.install_apk.filesystem.path = '/my/home/game.apk' - mock_exists.return_value = True - - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with( - ['install', '-r', '-t', '-g', '/my/home/game.apk'], None) - - @mock.patch.object(tempfile, 'NamedTemporaryFile', autospec=True) - def test_install_apk_from_blob(self, mock_tempfile): - """Should succeed installing APK from blob.""" - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - blob_content = b'A fake blob content' - request.install_apk.blob.contents = blob_content - mock_tempfile.return_value.__enter__.return_value.name = '/my/home/test.apk' - mock_tempfile.return_value.__enter__.return_value.write.return_value = None - - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with( - ['install', '-r', '-t', '-g', '/my/home/test.apk'], None - ) - # pytype: disable=attribute-error - expected_tempfile_kwargs = ( - {'suffix': '.apk', 'delete_on_close': False} - if sys.version_info > (3, 12) - else {'suffix': '.apk'} - ) - mock_tempfile.assert_has_calls([ - mock.call(**expected_tempfile_kwargs), # Constructor - mock.call().__enter__(), # Enter context - mock.call().__enter__().write(blob_content), # Call write function - mock.call().__exit__(None, None, None), # Exit context - ]) - # pytype: enable=attribute-error - - def test_start_activity_empty_full_activity(self): - """A start_activity command should always have a nonempty activity.""" - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.start_activity.extra_args.extend(['blah']) - response = parser.parse(request) - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) - self.assertNotEmpty(response.error_message) - - def test_start_activity_successful(self): - adb = mock.create_autospec(adb_controller.AdbController) - command_output = (b'Stopping: my.project.SplashActivity\n' - b'Starting: Intent { cmp=my.project.SplashActivity }\n') - adb.execute_command.return_value = command_output - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.start_activity.full_activity = 'my.project.SplashActivity' - request.start_activity.extra_args.extend(['blah']) - request.start_activity.force_stop = True - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_has_calls([ - mock.call([ - 'shell', 'am', 'start', '-S', '-W', '-n', - 'my.project.SplashActivity', 'blah' - ], - timeout=None), - ]) - - def test_start_activity_successful_no_force_stop(self): - adb = mock.create_autospec(adb_controller.AdbController) - command_output = (b'Stopping: my.project.SplashActivity\n' - b'Starting: Intent { cmp=my.project.SplashActivity }\n') - adb.execute_command.return_value = command_output - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.start_activity.full_activity = 'my.project.SplashActivity' - request.start_activity.extra_args.extend(['blah']) - request.start_activity.force_stop = False - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_has_calls([ - mock.call([ - 'shell', 'am', 'start', '', '-W', '-n', 'my.project.SplashActivity', - 'blah' - ], - timeout=None), - ]) - - def test_start_activity_error(self): - adb = mock.create_autospec(adb_controller.AdbController) - command_output = (b'Stopping: my.project.SplashActivity\n' - b'Starting: Intent { cmp=my.project.SplashActivity }\n' - b'Error: Activity not started, unknown error code 101\n') - adb.execute_command.return_value = command_output - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.start_activity.full_activity = 'my.project.SplashActivity' - request.start_activity.extra_args.extend(['blah']) - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR) - self.assertEqual( - response.error_message, - f'start_activity failed with error: {str(command_output)}') - - def test_force_stop(self): - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.force_stop.package_name = 'my.project' - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with( - ['shell', 'am', 'force-stop', 'my.project'], None) - - def test_grant_permissions_empty_package_name(self): - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.package_manager.grant.permissions.extend(['perm1', 'perm2']) - response = parser.parse(request) - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) - self.assertNotEmpty(response.error_message) - - def test_grant_permissions_empty_permissions(self): - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.package_manager.grant.package_name = 'my.project' - response = parser.parse(request) - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) - self.assertNotEmpty(response.error_message) - - def test_grant_permissions_successful(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.package_manager.grant.package_name = 'my.project' - request.package_manager.grant.permissions.extend(['perm1', 'perm2']) - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_has_calls([ - mock.call(['shell', 'pm', 'grant', 'my.project', 'perm1'], None), - mock.call(['shell', 'pm', 'grant', 'my.project', 'perm2'], None), - ]) - - def test_press_button_invalid_button(self): - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.press_button.button = 99999 - response = parser.parse(request) - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) - self.assertNotEmpty(response.error_message) - - def test_press_button_successful(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'' - parser = adb_call_parser.AdbCallParser(adb) - # HOME. - request = adb_pb2.AdbRequest() - request.press_button.button = adb_pb2.AdbRequest.PressButton.Button.HOME - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_with( - ['shell', 'input', 'keyevent', 'KEYCODE_HOME'], None) - # BACK. - request = adb_pb2.AdbRequest() - request.press_button.button = adb_pb2.AdbRequest.PressButton.Button.BACK - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_with( - ['shell', 'input', 'keyevent', 'KEYCODE_BACK'], None) - # ENTER. - request = adb_pb2.AdbRequest() - request.press_button.button = adb_pb2.AdbRequest.PressButton.Button.ENTER - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_with( - ['shell', 'input', 'keyevent', 'KEYCODE_ENTER'], None) - - def test_start_screen_pinning_package_not_found(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = ( - b' taskId=12345: my.project.AnotherActivity visible=true' - b' topActivity=ComponentInfo{my.project.AnotherActivity}') - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.start_screen_pinning.full_activity = 'my.project.AmazingActivity' - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_called_once_with( - ['shell', 'am', 'stack', 'list'], None) - - def test_start_screen_pinning_successful(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = ( - b' taskId=12345: my.project.AmazingActivity visible=true' - b' topActivity=ComponentInfo{my.project.AmazingActivity}') - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.start_screen_pinning.full_activity = 'my.project.AmazingActivity' - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_has_calls([ - mock.call(['shell', 'am', 'stack', 'list'], None), - mock.call(['shell', 'am', 'task', 'lock', '12345'], None), - ]) - - def test_start_screen_pinning_base_activity(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = ( - b' taskId=12345: my.project.MainActivity visible=true' - b' topActivity=ComponentInfo{my.project.TopActivity}') - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.start_screen_pinning.full_activity = 'my.project.MainActivity' - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_has_calls([ - mock.call(['shell', 'am', 'stack', 'list'], None), - mock.call(['shell', 'am', 'task', 'lock', '12345'], None), - ]) - - def test_start_screen_pinning_top_activity(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = ( - b' taskId=12345: my.project.MainActivity visible=true' - b' topActivity=ComponentInfo{my.project.TopActivity}') - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.start_screen_pinning.full_activity = 'my.project.TopActivity' - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_has_calls([ - mock.call(['shell', 'am', 'stack', 'list'], None), - mock.call(['shell', 'am', 'task', 'lock', '12345'], None), - ]) - - def test_send_broadcast_empty_action(self): - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - send_broadcast=adb_pb2.AdbRequest.SendBroadcast()) - response = parser.parse(request) - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) - self.assertNotEmpty(response.error_message) - - def test_send_broadcast_successful(self): - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.send_broadcast.action = 'SOME-ACTION' - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - - def test_send_broadcast_with_component_successful(self): - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.send_broadcast.action = 'SOME-ACTION' - request.send_broadcast.component = 'SOME-COMPONENT' - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - - def test_uninstall_package_empty_package_name(self): - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.uninstall_package.package_name = '' - response = parser.parse(request) - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) - self.assertNotEmpty(response.error_message) - - def test_uninstall_package_successful(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'package:my.package' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest() - request.uninstall_package.package_name = 'my.package' - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - - def test_get_current_activity_no_visible_task(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = None - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity()) - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_has_calls([ - mock.call( - ['shell', 'am', 'stack', 'list', '|', 'grep', '-E', 'visible=true'], - None), - mock.call(['shell', 'am', 'stack', 'list'], None), - ]) - - def test_get_orientation_empty_dumpsys(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - get_orientation=adb_pb2.AdbRequest.GetOrientationRequest()) - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_called_once_with(['shell', 'dumpsys', 'input'], - None) - - def test_get_orientation_invalid_device_no_surface_orientation(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b' PhysicalWidth: -123px' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - get_orientation=adb_pb2.AdbRequest.GetOrientationRequest()) - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_called_once_with(['shell', 'dumpsys', 'input'], - None) - - @parameterized.named_parameters( - ('rotation_0', b""" SurfaceOrientation: 0""", 0), - ('rotation_90', b""" SurfaceOrientation: 1""", 1), - ('rotation_180', b""" SurfaceOrientation: 2""", 2), - ('rotation_270', b""" SurfaceOrientation: 3""", 3), - ('rotation_0_new', b""" InputDeviceOrientation: 0""", 0), - ('rotation_90_new', b""" InputDeviceOrientation: 1""", 1), - ('rotation_180_new', b""" InputDeviceOrientation: 2""", 2), - ('rotation_270_new', b""" InputDeviceOrientation: 3""", 3), - ) - def test_get_orientation_success( - self, orientation: bytes, expected_orientation: int - ): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = ( - b"""SomeRandomKey: 12345\n""" + orientation + b""" - MoreRandomStuff: awesome_value -""" - ) - - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - get_orientation=adb_pb2.AdbRequest.GetOrientationRequest()) - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - self.assertEqual(response.get_orientation.orientation, expected_orientation) - adb.execute_command.assert_called_once_with(['shell', 'dumpsys', 'input'], - None) - - def test_get_current_activity_no_matches(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity()) - for platform in ['win32', 'linux']: - with mock.patch.object( - sys, 'platform', autospec=True, return_value=platform): - response = parser.parse(request) - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.INTERNAL_ERROR) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_has_calls([ - mock.call([ - 'shell', 'am', 'stack', 'list', '|', 'grep', '-E', - 'visible=true' - ], None), - mock.call(['shell', 'am', 'stack', 'list'], None), - ]) - - def test_get_current_activity_successful(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'{MyAwesomeActivity}' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity()) - for platform in ['win32', 'linux']: - with mock.patch.object( - sys, 'platform', autospec=True, return_value=platform): - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - # `execute_command` will be called once for each platform. - adb.execute_command.assert_called_with( - ['shell', 'am', 'stack', 'list', '|', 'grep', '-E', 'visible=true'], - None) - self.assertEqual(response.get_current_activity.full_activity, - 'MyAwesomeActivity') - - def test_push_no_path(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - push=adb_pb2.AdbRequest.Push(content=b'Has content but no path')) - response = parser.parse(request) - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_not_called() - - def test_push_successful(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - push=adb_pb2.AdbRequest.Push( - content=b'My text.', path='/sdcard/my_file.txt')) - - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once() - args, kwargs = adb.execute_command.call_args - self.assertLen(args, 1) - cmd_args = args[0] - self.assertLen(cmd_args, 3) - self.assertEqual(cmd_args[0], 'push') - self.assertEqual(cmd_args[2], '/sdcard/my_file.txt') - self.assertIn('timeout', kwargs) - self.assertIsNone(kwargs['timeout']) - - def test_pull_no_path(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest(pull=adb_pb2.AdbRequest.Pull()) - response = parser.parse(request) - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_not_called() - - @mock.patch.object(builtins, 'open', autospec=True) - def test_pull_successful(self, mock_open): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - mock_open.return_value.__enter__ = mock_open - mock_open.return_value.read.return_value = b'S3cR3t. dO nOt TeLl ANYONE' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - pull=adb_pb2.AdbRequest.Pull(path='/sdcard/my_file.txt')) - - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - self.assertEqual(response.pull.content, b'S3cR3t. dO nOt TeLl ANYONE') - adb.execute_command.assert_called_once() - args, kwargs = adb.execute_command.call_args - self.assertLen(args, 1) - cmd_args = args[0] - self.assertLen(cmd_args, 3) - self.assertEqual(cmd_args[0], 'pull') - self.assertEqual(cmd_args[1], '/sdcard/my_file.txt') - self.assertIn('timeout', kwargs) - self.assertIsNone(kwargs['timeout']) - - def test_input_text_no_text(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest(input_text=adb_pb2.AdbRequest.InputText()) - response = parser.parse(request) - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_not_called() - - def test_input_text_successful(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - input_text=adb_pb2.AdbRequest.InputText( - text='The Greatest Text of All Time')) - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with( - ['shell', 'input', 'text', 'The Greatest Text of All Time'], None) - - @parameterized.named_parameters( - ('negative_x_and_negative_y', - adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=-1, y=-1))), - ('negative_x', - adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=-1, y=123))), - ('negative_y', - adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=456, y=-1))), - ) - def test_tap_failed(self, request: adb_pb2.AdbRequest): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - response = parser.parse(request) - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_not_called() - - def test_tap_successful(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=135, y=246)) - response = parser.parse(request) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with( - ['shell', 'input', 'tap', '135', '246'], None) - - @parameterized.named_parameters( - ('empty_request', adb_pb2.AdbRequest.SettingsRequest()), - ('no_namespace', - adb_pb2.AdbRequest.SettingsRequest( - get=adb_pb2.AdbRequest.SettingsRequest.Get(key='my_key'))), - ('get_no_key', - adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, - get=adb_pb2.AdbRequest.SettingsRequest.Get())), - ('put_no_key', - adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, - put=adb_pb2.AdbRequest.SettingsRequest.Put())), - ('put_no_value', - adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, - put=adb_pb2.AdbRequest.SettingsRequest.Put(key='another_key'))), - ('delete_no_key', - adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, - delete_key=adb_pb2.AdbRequest.SettingsRequest.Delete())), - ('reset_no_package_name_and_no_mode', - adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, - reset=adb_pb2.AdbRequest.SettingsRequest.Reset())), - ) - def test_settings_failures(self, request): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest(settings=request) - response = parser.parse(request) - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_not_called() - - def test_settings_success_get(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'here it is!' - parser = adb_call_parser.AdbCallParser(adb) - - request = adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, - get=adb_pb2.AdbRequest.SettingsRequest.Get(key='some_key')) - request = adb_pb2.AdbRequest(settings=request) - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - self.assertEqual(response.settings.output, b'here it is!') - adb.execute_command.assert_called_once_with( - ['shell', 'settings', 'get', 'system', 'some_key'], None) - - def test_settings_success_put(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'Done for ya!' - parser = adb_call_parser.AdbCallParser(adb) - - request = adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SECURE, - put=adb_pb2.AdbRequest.SettingsRequest.Put(key='key1', value='val2')) - request = adb_pb2.AdbRequest(settings=request) - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - self.assertEqual(response.settings.output, b'Done for ya!') - adb.execute_command.assert_called_once_with( - ['shell', 'settings', 'put', 'secure', 'key1', 'val2'], None) - - def test_settings_success_delete(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'Key deleted.' - parser = adb_call_parser.AdbCallParser(adb) - - request = adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL, - delete_key=adb_pb2.AdbRequest.SettingsRequest.Delete(key='useless_key')) - request = adb_pb2.AdbRequest(settings=request) - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - self.assertEqual(response.settings.output, b'Key deleted.') - adb.execute_command.assert_called_once_with( - ['shell', 'settings', 'delete', 'global', 'useless_key'], None) - - @parameterized.named_parameters( - ('mode_untrusted_defaults', - adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_DEFAULTS, '', - 'untrusted_defaults'), - ('mode_untrusted_clear', - adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_CLEAR, '', - 'untrusted_clear'), - ('mode_trusted_defaults', - adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.TRUSTED_DEFAULTS, '', - 'trusted_defaults'), - # If `package_name` is given, it takes precedence over `mode`. - ('mode_unknown_package_given', - adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNKNOWN, 'great.package', - 'great.package'), - ('mode_untrusted_defaults_package_given', - adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_DEFAULTS, - 'great.package', 'great.package'), - ('mode_untrusted_clear_package_given', - adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_CLEAR, - 'great.package', 'great.package'), - ('mode_trusted_defaults_package_given', - adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.TRUSTED_DEFAULTS, - 'great.package', 'great.package'), - ) - def test_settings_success_reset(self, mode, package_name, expected_arg): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'Pkg reset.' - parser = adb_call_parser.AdbCallParser(adb) - - request = adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL, - reset=adb_pb2.AdbRequest.SettingsRequest.Reset( - package_name=package_name, mode=mode)) - request = adb_pb2.AdbRequest(settings=request) - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - self.assertEqual(response.settings.output, b'Pkg reset.') - adb.execute_command.assert_called_once_with( - ['shell', 'settings', 'reset', 'global', expected_arg], None) - - def test_settings_success_list(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'volume_ring=5\nvolume_system=7' - parser = adb_call_parser.AdbCallParser(adb) - - request = adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, - list=adb_pb2.AdbRequest.SettingsRequest.List()) - request = adb_pb2.AdbRequest(settings=request) - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - self.assertEqual(response.settings.output, - b'volume_ring=5\nvolume_system=7') - adb.execute_command.assert_called_once_with( - ['shell', 'settings', 'list', 'system'], None) - - def test_generic_command(self): - adb = mock.create_autospec(adb_controller.AdbController) - expected_output = b'generic_output' - args = ['shell', 'am', 'broadcast', '-n', 'receiver', '-a', 'action'] - adb.execute_command.return_value = expected_output - parser = adb_call_parser.AdbCallParser(adb) - - generic_request = adb_pb2.AdbRequest.GenericRequest(args=args) - request = adb_pb2.AdbRequest(generic=generic_request) - response = parser.parse(request) - - self.assertEqual(adb_pb2.AdbResponse.Status.OK, response.status) - self.assertEmpty(response.error_message) - self.assertEqual(response.generic.output, expected_output) - adb.execute_command.assert_called_once_with(args, None) - - def test_generic_command_adb_error(self): - adb = mock.create_autospec(adb_controller.AdbController) - args = ['shell', 'am', 'broadcast', '-n', 'receiver', '-a', 'action'] - adb.execute_command.side_effect = subprocess.CalledProcessError( - cmd='cmd', output='adb_error', returncode=-1) - parser = adb_call_parser.AdbCallParser(adb) - - generic_request = adb_pb2.AdbRequest.GenericRequest(args=args) - request = adb_pb2.AdbRequest(generic=generic_request) - response = parser.parse(request) - - self.assertEqual(adb_pb2.AdbResponse.Status.ADB_ERROR, response.status) - self.assertEqual('adb_error', response.error_message) - self.assertEmpty(response.generic.output) - adb.execute_command.assert_called_once_with(args, None) - - def test_generic_command_timeout(self): - adb = mock.create_autospec(adb_controller.AdbController) - args = ['shell', 'am', 'broadcast', '-n', 'receiver', '-a', 'action'] - adb.execute_command.side_effect = subprocess.TimeoutExpired( - cmd='cmd', timeout=10) - parser = adb_call_parser.AdbCallParser(adb) - - generic_request = adb_pb2.AdbRequest.GenericRequest(args=args) - request = adb_pb2.AdbRequest(generic=generic_request) - response = parser.parse(request) - - self.assertEqual(adb_pb2.AdbResponse.Status.TIMEOUT, response.status) - self.assertEqual('Timeout', response.error_message) - self.assertEmpty(response.generic.output) - adb.execute_command.assert_called_once_with(args, None) - - @parameterized.named_parameters( - ('features', - adb_pb2.AdbRequest( - package_manager=adb_pb2.AdbRequest.PackageManagerRequest( - list=adb_pb2.AdbRequest.PackageManagerRequest.List( - features=adb_pb2.AdbRequest.PackageManagerRequest.List - .Features())))), - ('libraries', - adb_pb2.AdbRequest( - package_manager=adb_pb2.AdbRequest.PackageManagerRequest( - list=adb_pb2.AdbRequest.PackageManagerRequest.List( - libraries=adb_pb2.AdbRequest.PackageManagerRequest.List - .Libraries())))), - ('packages', - adb_pb2.AdbRequest( - package_manager=adb_pb2.AdbRequest.PackageManagerRequest( - list=adb_pb2.AdbRequest.PackageManagerRequest.List( - packages=adb_pb2.AdbRequest.PackageManagerRequest.List - .Packages())))), - ) - def test_package_manager_list_bad_output(self, request): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b"""Something irrelevant.""" - parser = adb_call_parser.AdbCallParser(adb) - response = parser.parse(request) - response.package_manager.output = b"""Something irrelevant.""" - self.assertEmpty(response.package_manager.list.items) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once() - - def test_package_manager_list_features(self): - adb = mock.create_autospec(adb_controller.AdbController) - output = b""" -feature:android.hardware.audio.output -feature:android.hardware.bluetooth -feature:android.hardware.camera -feature:android.hardware.fingerprint -feature:android.software.autofill -feature:android.software.backup -feature:android.software.webview -""" - adb.execute_command.return_value = output - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - package_manager=adb_pb2.AdbRequest.PackageManagerRequest( - list=adb_pb2.AdbRequest.PackageManagerRequest.List( - features=adb_pb2.AdbRequest.PackageManagerRequest.List.Features( - )))) - response = parser.parse(request) - self.assertEqual(response.package_manager.output, output) - self.assertEqual(response.package_manager.list.items, [ - 'android.hardware.audio.output', - 'android.hardware.bluetooth', - 'android.hardware.camera', - 'android.hardware.fingerprint', - 'android.software.autofill', - 'android.software.backup', - 'android.software.webview', - ]) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with( - ['shell', 'pm', 'list', 'features'], None) - - def test_package_manager_list_libraries(self): - adb = mock.create_autospec(adb_controller.AdbController) - output = b""" -library:android.ext.shared -library:android.hidl.base-V1.0-java -library:android.hidl.manager-V1.0-java -library:android.net.ipsec.ike -library:android.test.base -library:android.test.mock -library:android.test.runner -library:androidx.window.sidecar -library:com.android.future.usb.accessory -library:com.android.location.provider -library:com.android.media.remotedisplay -library:com.android.mediadrm.signer -library:com.android.nfc_extras -library:com.google.android.gms -library:com.google.android.trichromelibrary -library:javax.obex -library:org.apache.http.legacy -""" - adb.execute_command.return_value = output - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - package_manager=adb_pb2.AdbRequest.PackageManagerRequest( - list=adb_pb2.AdbRequest.PackageManagerRequest.List( - libraries=adb_pb2.AdbRequest.PackageManagerRequest.List - .Libraries()))) - response = parser.parse(request) - self.assertEqual(response.package_manager.output, output) - self.assertEqual(response.package_manager.list.items, [ - 'android.ext.shared', - 'android.hidl.base-V1.0-java', - 'android.hidl.manager-V1.0-java', - 'android.net.ipsec.ike', - 'android.test.base', - 'android.test.mock', - 'android.test.runner', - 'androidx.window.sidecar', - 'com.android.future.usb.accessory', - 'com.android.location.provider', - 'com.android.media.remotedisplay', - 'com.android.mediadrm.signer', - 'com.android.nfc_extras', - 'com.google.android.gms', - 'com.google.android.trichromelibrary', - 'javax.obex', - 'org.apache.http.legacy', - ]) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with( - ['shell', 'pm', 'list', 'libraries'], None) - - def test_package_manager_list_packages(self): - adb = mock.create_autospec(adb_controller.AdbController) - output = b""" -package:com.android.phone -package:com.awesome.company -package:com.another.great.thingie -""" - adb.execute_command.return_value = output - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - package_manager=adb_pb2.AdbRequest.PackageManagerRequest( - list=adb_pb2.AdbRequest.PackageManagerRequest.List( - packages=adb_pb2.AdbRequest.PackageManagerRequest.List.Packages( - )))) - response = parser.parse(request) - self.assertEqual(response.package_manager.output, output) - self.assertEqual(response.package_manager.list.items, [ - 'com.android.phone', - 'com.awesome.company', - 'com.another.great.thingie', - ]) - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with( - ['shell', 'pm', 'list', 'packages'], None) - - def test_package_manager_clear_no_package_name(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b"""Something irrelevant.""" - parser = adb_call_parser.AdbCallParser(adb) - - request = adb_pb2.AdbRequest( - package_manager=adb_pb2.AdbRequest.PackageManagerRequest( - clear=adb_pb2.AdbRequest.PackageManagerRequest.Clear( - package_name=''))) - response = parser.parse(request) - - self.assertEmpty(response.package_manager.output) - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_not_called() - - def test_package_manager_clear_successful_no_user_id(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b"""Some successful message.""" - parser = adb_call_parser.AdbCallParser(adb) - - request = adb_pb2.AdbRequest( - package_manager=adb_pb2.AdbRequest.PackageManagerRequest( - clear=adb_pb2.AdbRequest.PackageManagerRequest.Clear( - package_name='my.package'))) - response = parser.parse(request) - - self.assertEqual(response.package_manager.output, - b"""Some successful message.""") - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with( - ['shell', 'pm', 'clear', 'my.package'], None) - - def test_package_manager_clear_successful_with_user_id(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b"""Some successful message.""" - parser = adb_call_parser.AdbCallParser(adb) - - request = adb_pb2.AdbRequest( - package_manager=adb_pb2.AdbRequest.PackageManagerRequest( - clear=adb_pb2.AdbRequest.PackageManagerRequest.Clear( - package_name='my.package', user_id='mrawesome'))) - response = parser.parse(request) - - self.assertEqual(response.package_manager.output, - b"""Some successful message.""") - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with( - ['shell', 'pm', 'clear', '-f', 'mrawesome', 'my.package'], None) - - def test_dumpsys_empty_request(self): - """An empty `DumpsysRequest` is a valid request.""" - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest(dumpsys=adb_pb2.AdbRequest.DumpsysRequest()) - - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with(['shell', 'dumpsys'], - timeout=None) - - @parameterized.named_parameters( - ('negative_timeout_sec', - adb_pb2.AdbRequest( - dumpsys=adb_pb2.AdbRequest.DumpsysRequest(timeout_sec=-1))), - ('negative_timeout_ms', - adb_pb2.AdbRequest( - dumpsys=adb_pb2.AdbRequest.DumpsysRequest(timeout_ms=-2))), - ) - def test_dumpsys_negative_timeouts(self, request): - """`DumpsysRequest.timeout_{sec, ms}` if passed, should be positive.""" - adb = mock.create_autospec(adb_controller.AdbController) - parser = adb_call_parser.AdbCallParser(adb) - - response = parser.parse(request) - - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_not_called() - - @parameterized.named_parameters( - ('both_timeouts_zero', 0, 0, ['shell', 'dumpsys']), - ('sec_takes_precedence_zero', 123, 0, ['shell', 'dumpsys', '-t', '123']), - ('sec_takes_precedence', 123, 456, ['shell', 'dumpsys', '-t', '123']), - ('ms_if_no_sec', 0, 456, ['shell', 'dumpsys', '-T', '456']), - ) - def test_dumpsys_timeout_successful(self, timeout_sec, timeout_ms, expected): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - dumpsys=adb_pb2.AdbRequest.DumpsysRequest( - timeout_sec=timeout_sec, timeout_ms=timeout_ms)) - - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with(expected, timeout=None) - - @parameterized.named_parameters( - ('priority_undefined', - adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.UNSET, - ['shell', 'dumpsys']), - ('priority_normal', - adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.NORMAL, - ['shell', 'dumpsys', '--priority', 'NORMAL']), - ('priority_high', adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.HIGH, - ['shell', 'dumpsys', '--priority', 'HIGH']), - ('priority_critical', - adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.CRITICAL, - ['shell', 'dumpsys', '--priority', 'CRITICAL']), - ) - def test_dumpsys_priority_timeout_successful(self, priority, expected): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - dumpsys=adb_pb2.AdbRequest.DumpsysRequest(priority=priority)) - - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with(expected, timeout=None) - - @parameterized.named_parameters( - ( - 'window_service', - adb_pb2.AdbRequest.DumpsysRequest(list_only=True, service='window'), - ), - ( - 'arbitrary_args', - adb_pb2.AdbRequest.DumpsysRequest( - list_only=True, args=['myoption', 'anotheroption'] - ), - ), - ( - 'skip_usb', - adb_pb2.AdbRequest.DumpsysRequest( - list_only=True, skip_services=['usb'] - ), - ), - ) - def test_dumpsys_list_only_cannot_be_combined( - self, dumpsys_request: adb_pb2.AdbRequest.DumpsysRequest - ): - """When `list_only==True`, the request cannot contain a few fields.""" - - # Arrange. - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest(dumpsys=dumpsys_request) - - # Act. - response = parser.parse(request) - - # Assert. - self.assertEqual( - response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION - ) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_not_called() - - def test_dumpsys_list_only_success(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - dumpsys=adb_pb2.AdbRequest.DumpsysRequest(list_only=True)) - - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with(['shell', 'dumpsys', '-l'], - timeout=None) - - def test_dumpsys_skip_services_cannot_combine_with_service(self): - """When using `DumpsysRequest.skip_service`, it cannot contain `.service`.""" - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - dumpsys=adb_pb2.AdbRequest.DumpsysRequest( - service='wifi', skip_services=['window', 'usb'])) - - response = parser.parse(request) - - self.assertEqual(response.status, - adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) - self.assertNotEmpty(response.error_message) - adb.execute_command.assert_not_called() - - def test_dumpsys_skip_services(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - dumpsys=adb_pb2.AdbRequest.DumpsysRequest( - skip_services=['window', 'usb'])) - - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with( - ['shell', 'dumpsys', '--skip', 'window,usb'], timeout=None) - - def test_dumpsys_single_service(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - dumpsys=adb_pb2.AdbRequest.DumpsysRequest(service='window')) - - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with(['shell', 'dumpsys', 'window'], - timeout=None) - - def test_dumpsys_single_service_with_args(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'whatever' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - dumpsys=adb_pb2.AdbRequest.DumpsysRequest( - service='window', args=['arg1', 'arg2'])) - - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with( - ['shell', 'dumpsys', 'window', 'arg1', 'arg2'], timeout=None) - - def test_dumpsys_single_service_with_proto(self): - adb = mock.create_autospec(adb_controller.AdbController) - adb.execute_command.return_value = b'some binary output' - parser = adb_call_parser.AdbCallParser(adb) - request = adb_pb2.AdbRequest( - dumpsys=adb_pb2.AdbRequest.DumpsysRequest(service='window', proto=True)) - - response = parser.parse(request) - - self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) - self.assertEmpty(response.error_message) - adb.execute_command.assert_called_once_with( - ['shell', 'dumpsys', 'window', '--proto'], timeout=None) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/adb_controller.py b/android_env/components/adb_controller.py deleted file mode 100644 index a8a1b9d3..00000000 --- a/android_env/components/adb_controller.py +++ /dev/null @@ -1,151 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A class to manage and control an external ADB process.""" - -import os -import subprocess -import time - -from absl import logging -from android_env.components import config_classes -from android_env.components import errors - - -class AdbController: - """Manages communication with adb.""" - - def __init__(self, config: config_classes.AdbControllerConfig): - """Instantiates an AdbController object.""" - - self._config = config - logging.info('config: %r', self._config) - - # Unset problematic environment variables. ADB commands will fail if these - # are set. They are normally exported by AndroidStudio. - if 'ANDROID_HOME' in os.environ: - del os.environ['ANDROID_HOME'] - if 'ANDROID_ADB_SERVER_PORT' in os.environ: - del os.environ['ANDROID_ADB_SERVER_PORT'] - - # Explicitly expand the $HOME environment variable. - self._os_env_vars = dict(os.environ).copy() - self._os_env_vars.update( - {'HOME': os.path.expandvars(self._os_env_vars.get('HOME', ''))} - ) - logging.info('self._os_env_vars: %r', self._os_env_vars) - - def command_prefix(self, include_device_name: bool = True) -> list[str]: - """The command for instantiating an adb client to this server.""" - command_prefix = [ - self._config.adb_path, - '-P', - str(self._config.adb_server_port), - ] - if include_device_name: - command_prefix.extend(['-s', self._config.device_name]) - return command_prefix - - def init_server(self, timeout: float | None = None): - """Initialize the ADB server deamon on the given port. - - This function should be called immediately after initializing the first - adb_controller, and before launching the simulator. - - Args: - timeout: A timeout to use for this operation. If not set the default - timeout set on the constructor will be used. - """ - # Make an initial device-independent call to ADB to start the deamon. - self.execute_command(['devices'], timeout, device_specific=False) - time.sleep(0.2) - - def _restart_server(self, timeout: float | None = None): - """Kills and restarts the adb server. - - Args: - timeout: A timeout to use for this operation. If not set the default - timeout set on the constructor will be used. - """ - logging.info('Restarting adb server.') - self.execute_command( - ['kill-server'], timeout=timeout, device_specific=False) - time.sleep(0.2) - cmd_output = self.execute_command( - ['start-server'], timeout=timeout, device_specific=False) - logging.info('start-server output: %r', cmd_output.decode('utf-8')) - time.sleep(2.0) - self.execute_command( - ['devices'], timeout=timeout, device_specific=False) - time.sleep(0.2) - - def execute_command( - self, - args: list[str], - timeout: float | None = None, - device_specific: bool = True, - ) -> bytes: - """Executes an adb command. - - Args: - args: A list of strings representing each adb argument. - For example: ['install', '/my/app.apk'] - timeout: A timeout to use for this operation. If not set the default - timeout set on the constructor will be used. - device_specific: Whether the call is device-specific or independent. - - Returns: - The output of running such command as a binary string. - """ - timeout = self._config.default_timeout if timeout is None else timeout - command = self.command_prefix(include_device_name=device_specific) + args - command_str = 'adb ' + ' '.join(command[1:]) - - n_retries = 2 - n_tries = 1 - latest_error = None - while n_tries <= n_retries: - try: - logging.info('Executing ADB command: [%s]', command_str) - cmd_output = subprocess.check_output( - command, - stderr=subprocess.STDOUT, - timeout=timeout, - env=self._os_env_vars, - ) - logging.debug('ADB command output: %s', cmd_output) - return cmd_output - except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e: - logging.exception( - 'Failed to execute ADB command (try %r of 3): [%s]', - n_tries, command_str) - if e.stdout is not None: - logging.error('**stdout**:') - for line in e.stdout.splitlines(): - logging.error(' %s', line) - if e.stderr is not None: - logging.error('**stderr**:') - for line in e.stderr.splitlines(): - logging.error(' %s', line) - n_tries += 1 - latest_error = e - if device_specific and n_tries <= n_retries: - self._restart_server(timeout=timeout) - - raise errors.AdbControllerError( - f'Error executing adb command: [{command_str}]\n' - f'Caused by: {latest_error}\n' - f'adb stdout: [{latest_error.stdout}]\n' - f'adb stderr: [{latest_error.stderr}]') from latest_error diff --git a/android_env/components/adb_controller_test.py b/android_env/components/adb_controller_test.py deleted file mode 100644 index dddd04d3..00000000 --- a/android_env/components/adb_controller_test.py +++ /dev/null @@ -1,234 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.components.adb_controller.""" - -import os -import subprocess -import time -from unittest import mock - -from absl.testing import absltest -from android_env.components import adb_controller as adb_controller_lib -from android_env.components import config_classes -from android_env.components import errors - -# Timeout to be used by default in tests below. Set to a small value to avoid -# hanging on a failed test. -_TIMEOUT = 2 - - -class AdbControllerTest(absltest.TestCase): - - def setUp(self): - super().setUp() - # Set two env vars. - os.environ['MY_ENV_VAR'] = '/some/path/' - os.environ['HOME'] = '$MY_ENV_VAR' - self._env_before = os.environ - self._adb_controller = adb_controller_lib.AdbController( - config_classes.AdbControllerConfig( - adb_path='my_adb', - device_name='awesome_device', - adb_server_port=9999, - ) - ) - - @mock.patch.object(subprocess, 'check_output', autospec=True) - @mock.patch.object(time, 'sleep', autospec=True) - def test_init_server(self, mock_sleep, mock_check_output): - # Arrange. - adb_controller = adb_controller_lib.AdbController( - config_classes.AdbControllerConfig( - adb_path='my_adb', - device_name='awesome_device', - adb_server_port=9999, - ) - ) - - # Act. - adb_controller.init_server(timeout=_TIMEOUT) - - # Assert. - expected_env = self._env_before - expected_env['HOME'] = '/some/path/' - mock_check_output.assert_called_once_with( - ['my_adb', '-P', '9999', 'devices'], - stderr=subprocess.STDOUT, - timeout=_TIMEOUT, - env=expected_env, - ) - mock_sleep.assert_called_once() - - @mock.patch.object(subprocess, 'check_output', autospec=True) - @mock.patch.object(time, 'sleep', autospec=True) - def test_restart_server(self, mock_sleep, mock_check_output): - # Arrange. - mock_check_output.side_effect = [ - subprocess.CalledProcessError(returncode=1, cmd='blah'), - ] + ['fake_output'.encode('utf-8')] * 4 - adb_controller = adb_controller_lib.AdbController( - config_classes.AdbControllerConfig( - adb_path='my_adb', - device_name='awesome_device', - adb_server_port=9999, - ) - ) - - # Act. - adb_controller.execute_command(['my_command'], timeout=_TIMEOUT) - - # Assert. - expected_env = self._env_before - expected_env['HOME'] = '/some/path/' - mock_check_output.assert_has_calls([ - mock.call( - ['my_adb', '-P', '9999', '-s', 'awesome_device', 'my_command'], - stderr=subprocess.STDOUT, - timeout=_TIMEOUT, - env=expected_env, - ), - mock.call( - ['my_adb', '-P', '9999', 'kill-server'], - stderr=subprocess.STDOUT, - timeout=_TIMEOUT, - env=expected_env, - ), - mock.call( - ['my_adb', '-P', '9999', 'start-server'], - stderr=subprocess.STDOUT, - timeout=_TIMEOUT, - env=expected_env, - ), - mock.call( - ['my_adb', '-P', '9999', 'devices'], - stderr=subprocess.STDOUT, - timeout=_TIMEOUT, - env=expected_env, - ), - mock.call( - ['my_adb', '-P', '9999', '-s', 'awesome_device', 'my_command'], - stderr=subprocess.STDOUT, - timeout=_TIMEOUT, - env=expected_env, - ), - ]) - mock_sleep.assert_has_calls( - [mock.call(0.2), mock.call(2.0), mock.call(0.2)]) - - @mock.patch.object(subprocess, 'check_output', autospec=True) - @mock.patch.object(time, 'sleep', autospec=True) - def test_invalid_command(self, mock_sleep, mock_check_output): - # Arrange. - restart_sequence = ['fake_output'.encode('utf-8')] * 3 - mock_check_output.side_effect = ( - [ - subprocess.CalledProcessError(returncode=1, cmd='blah'), - ] - + restart_sequence - + [subprocess.CalledProcessError(returncode=1, cmd='blah')] - # Don't restart if last call fails. - ) - adb_controller = adb_controller_lib.AdbController( - config_classes.AdbControllerConfig( - adb_path='my_adb', - device_name='awesome_device', - adb_server_port=9999, - ) - ) - - # Act. - with self.assertRaises(errors.AdbControllerError): - adb_controller.execute_command(['my_command'], timeout=_TIMEOUT) - - # Assert. - expected_env = self._env_before - expected_env['HOME'] = '/some/path/' - mock_check_output.assert_has_calls( - [ - mock.call( - ['my_adb', '-P', '9999', '-s', 'awesome_device', 'my_command'], - stderr=subprocess.STDOUT, - timeout=_TIMEOUT, - env=expected_env, - ), - mock.call( - ['my_adb', '-P', '9999', 'kill-server'], - stderr=subprocess.STDOUT, - timeout=_TIMEOUT, - env=expected_env, - ), - mock.call( - ['my_adb', '-P', '9999', 'start-server'], - stderr=subprocess.STDOUT, - timeout=_TIMEOUT, - env=expected_env, - ), - mock.call( - ['my_adb', '-P', '9999', 'devices'], - stderr=subprocess.STDOUT, - timeout=_TIMEOUT, - env=expected_env, - ), - mock.call( - ['my_adb', '-P', '9999', '-s', 'awesome_device', 'my_command'], - stderr=subprocess.STDOUT, - timeout=_TIMEOUT, - env=expected_env, - ), - ], - any_order=False, - ) - mock_sleep.assert_has_calls( - [mock.call(0.2), mock.call(2.0), mock.call(0.2)] - ) - - @mock.patch.object(subprocess, 'check_output', autospec=True) - @mock.patch.object(time, 'sleep', autospec=True) - def test_avoid_infinite_recursion(self, mock_sleep, mock_check_output): - del mock_sleep - mock_check_output.side_effect = subprocess.CalledProcessError( - returncode=1, cmd='blah') - adb_controller = adb_controller_lib.AdbController( - config_classes.AdbControllerConfig( - adb_path='my_adb', - device_name='awesome_device', - adb_server_port=9999, - ) - ) - self.assertRaises( - errors.AdbControllerError, - adb_controller.execute_command, ['my_command'], timeout=_TIMEOUT) - - -class AdbControllerInitTest(absltest.TestCase): - - def test_deletes_problem_env_vars(self): - os.environ['ANDROID_HOME'] = '/usr/local/Android/Sdk' - os.environ['ANDROID_ADB_SERVER_PORT'] = '1337' - adb_controller_lib.AdbController( - config_classes.AdbControllerConfig( - adb_path='my_adb', - device_name='awesome_device', - adb_server_port=9999, - default_timeout=_TIMEOUT, - ) - ) - self.assertNotIn('ANDROID_HOME', os.environ) - self.assertNotIn('ANDROID_ADB_SERVER_PORT', os.environ) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/adb_log_stream.py b/android_env/components/adb_log_stream.py deleted file mode 100644 index 17ad4eb8..00000000 --- a/android_env/components/adb_log_stream.py +++ /dev/null @@ -1,57 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Class for a stream of logs output by a locally running emulator.""" - -import subprocess - -from absl import logging -from android_env.components import log_stream - - -_LOGCAT_COMMAND = ['logcat', '-v', 'epoch'] - - -class AdbLogStream(log_stream.LogStream): - """Manages adb logcat process for a locally running emulator.""" - - def __init__(self, adb_command_prefix: list[str], verbose: bool = False): - super().__init__(verbose=verbose) - self._adb_command_prefix = adb_command_prefix - - def _get_stream_output(self): - - # Before spawning a long-lived process, we issue `logcat -b all -c` to clear - # all buffers to avoid interference from previous runs. - clear_buffer_output = subprocess.check_output( - self._adb_command_prefix + ['logcat', '-b', 'all', '-c'], - stderr=subprocess.STDOUT, - timeout=100) - logging.info('clear_buffer_output: %r', clear_buffer_output) - cmd = self._adb_command_prefix + _LOGCAT_COMMAND + self._filters - self._adb_subprocess = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - bufsize=1, - universal_newlines=True) - return self._adb_subprocess.stdout - - def stop_stream(self): - if not hasattr(self, '_adb_subprocess') or self._adb_subprocess is None: - logging.error('`stop_stream()` called before `get_stream_output()`. ' - 'This violates the `LogStream` API.') - else: - self._adb_subprocess.kill() diff --git a/android_env/components/adb_log_stream_test.py b/android_env/components/adb_log_stream_test.py deleted file mode 100644 index 98ea2120..00000000 --- a/android_env/components/adb_log_stream_test.py +++ /dev/null @@ -1,69 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for adb_log_stream.""" - -import subprocess -from unittest import mock - -from absl.testing import absltest -from android_env.components import adb_log_stream - - -class FakeAdbSubprocess: - - @property - def stdout(self): - return [f'line_{i}' for i in range(100)] - - def kill(self): - pass - - -class AdbLogStreamTest(absltest.TestCase): - - @mock.patch.object(subprocess, 'check_output', return_value=b'') - @mock.patch.object(subprocess, 'Popen', return_value=FakeAdbSubprocess()) - def test_get_stream_output(self, mock_popen, unused_mock_check_output): - stream = adb_log_stream.AdbLogStream(adb_command_prefix=['foo']) - stream.set_log_filters(['bar']) - stream_output = stream.get_stream_output() - - for i, line in enumerate(stream_output): - self.assertEqual(line, f'line_{i}') - - mock_popen.assert_called_with( - ['foo', 'logcat', '-v', 'epoch', 'bar', '*:S'], - stderr=subprocess.STDOUT, - stdout=subprocess.PIPE, - bufsize=1, - universal_newlines=True) - - def test_stop_stream_before_get_stream_output(self): - """Calling `stop_stream()` before `get_stream_output()` should not crash.""" - - # Arrange. - stream = adb_log_stream.AdbLogStream(adb_command_prefix=['foo']) - - # Act. - stream.stop_stream() - - # Assert. - # Nothing to assert. The test should just finish without raising an - # exception. - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/app_screen_checker.py b/android_env/components/app_screen_checker.py deleted file mode 100644 index 742e481c..00000000 --- a/android_env/components/app_screen_checker.py +++ /dev/null @@ -1,269 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Determines if the current app screen matches an expected app screen.""" - -from collections.abc import Callable, Sequence -import enum -import re -import time -from typing import Self - -from absl import logging -from android_env.components import adb_call_parser as adb_call_parser_lib -from android_env.components import errors -from android_env.proto import adb_pb2 -from android_env.proto import task_pb2 - - -class _DumpsysNode: - """A node in a dumpsys tree.""" - - def __init__(self, data: str): - self._children = [] - self._data = data - - @property - def data(self) -> str: - return self._data - - @property - def children(self) -> list[Self]: - return self._children - - def find_child( - self, predicate: Callable[[Self], bool], max_levels: int = 0 - ) -> Self | None: - """Returns the first direct child that matches `predicate`, None otherwise. - - Args: - predicate: Function-like that accepts a _DumpsysNode and returns boolean. - max_levels: Maximum number of levels down the tree to search for a child. - If non-positive, only direct children will be searched for. - - Returns: - A _DumpsysNode or None. - """ - if not self.children: - return None - - try: - return next(x for x in self.children if predicate(x)) - except StopIteration: - logging.info('Failed to find child. max_levels: %i.', max_levels) - # Search children. - if max_levels: - for child in self.children: - child_result = child.find_child(predicate, max_levels - 1) - if child_result is not None: - return child_result - - return None - - def __repr__(self): - return self._data - - def print_tree(self, indent: int = 2): - """Prints this tree in logging.info().""" - logging.info(' ' * indent + self.data) - for c in self.children: - c.print_tree(indent + 2) - - -def build_tree_from_dumpsys_output(dumpsys_output: str) -> _DumpsysNode: - """Constructs a tree from a dumpsys string output. - - Args: - dumpsys_output: string Verbatim output from adb dumpsys. The expected format - is a list where each line is a node and the indentation marks the - relationship with its parent or sibling. - - Returns: - _DumpsysNode The root of the tree. - """ - lines = dumpsys_output.split('\n') # Split by lines. - lines = [x.rstrip(' \r') for x in lines] - lines = [x for x in lines if len(x)] # Remove empty lines. - - root = _DumpsysNode('___root___') # The root of all nodes. - parents_stack = [root] - for line in lines: - stripped_line = line.lstrip(' ') - indent = len(line) - len(stripped_line) # Number of indent spaces. - new_node = _DumpsysNode(stripped_line) # Create a node without indentation. - - parent = parents_stack.pop() - if parent.data == '___root___': # The root is an exception for indentation. - parent_indent = -2 - else: - parent_indent = (len(parents_stack) - 1) * 2 - - if indent == parent_indent: # `new_node` is a sibiling. - parent = parents_stack.pop() - elif indent < parent_indent: # Indentation reduced (i.e. a block finished) - num_levels = (indent // 2) + 1 - parents_stack = parents_stack[:num_levels] - parent = parents_stack.pop() - elif indent > parent_indent: # `new_node` is a child. - pass # No need to change the current parent. - - parent.children.append(new_node) - parents_stack.append(parent) - parents_stack.append(new_node) - - return root - - -def matches_path( - dumpsys_activity_output: str, - expected_view_hierarchy_path: Sequence[re.Pattern[str]], - max_levels: int = 0, -) -> bool: - """Returns True if the current dumpsys output matches the expected path. - - Args: - dumpsys_activity_output: The output of running `dumpsys activity ...`. - expected_view_hierarchy_path: [regex] A list of regular expressions to be - tested at each level of the tree. - max_levels: How many levels to search from root for View Hierarchy. - - Returns: - True if the dumpsys tree contains one path that matches all regexes. - """ - root = build_tree_from_dumpsys_output(dumpsys_activity_output) - - # Find the View Hierarchy. - view_hierarchy = root.find_child( - lambda x: x.data.startswith('View Hierarchy'), max_levels) - if view_hierarchy is None: - logging.error( - 'view_hierarchy is None. Dumpsys activity output: %s. tree: %r', - str(dumpsys_activity_output), root.print_tree()) - logging.error('Tree root: %s', str(root)) - return False - - current_node = view_hierarchy - for i, regex in enumerate(expected_view_hierarchy_path): - - def regex_predicate(node, expr=regex): - matches = expr.match(node.data) - return matches is not None - - child = current_node.find_child(regex_predicate) - if child is None: - logging.error('Mismatched regex (%i, %s). current_node: %s', i, - regex.pattern, current_node) - logging.error('Dumpsys activity output: %s', str(dumpsys_activity_output)) - logging.error('Tree root: %s', str(root)) - return False - else: - current_node = child - return True - - -class AppScreenChecker: - """Checks that the current app screen matches an expected screen.""" - - class Outcome(enum.IntEnum): - """Possible return vales from checking the current app screen.""" - # The current app screen matches the expected app screen. - SUCCESS = 0 - # There's no activity to check. - EMPTY_EXPECTED_ACTIVITY = 1 - # We were unable to determine the current activity. - FAILED_ACTIVITY_EXTRACTION = 2 - # The current activity does not match the expected activity. - UNEXPECTED_ACTIVITY = 3 - # The current view hierarchy does not match the expected view hierarchy. - UNEXPECTED_VIEW_HIERARCHY = 4 - - def __init__(self, adb_call_parser: adb_call_parser_lib.AdbCallParser, - expected_app_screen: task_pb2.AppScreen): - self._adb_call_parser = adb_call_parser - self._expected_app_screen = expected_app_screen - self._expected_activity = expected_app_screen.activity - self._expected_view_hierarchy_path = [ - re.compile(regex) for regex in expected_app_screen.view_hierarchy_path - ] - - # Return type is AppScreenChecker.Outcome, but pytype doesn't understand that. - def matches_current_app_screen(self) -> enum.IntEnum: - """Determines whether the current app screen matches `expected_app_screen`.""" - if not self._expected_activity: - return AppScreenChecker.Outcome.EMPTY_EXPECTED_ACTIVITY - - # Check if we are still on the expected Activity. - response = self._adb_call_parser.parse( - adb_pb2.AdbRequest( - get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity())) - if response.status != adb_pb2.AdbResponse.OK: - return AppScreenChecker.Outcome.FAILED_ACTIVITY_EXTRACTION - - current_activity = response.get_current_activity.full_activity - if current_activity != self._expected_activity: - logging.error('current_activity: %s, expected_activity: %s', - current_activity, self._expected_activity) - return AppScreenChecker.Outcome.UNEXPECTED_ACTIVITY - - # Extract just the package name from the full activity name. - package_name = self._expected_activity.split('/')[0] - - # Check if we are in the expected view hierarchy path. - if self._expected_view_hierarchy_path: - dumpsys_response = self._adb_call_parser.parse( - adb_pb2.AdbRequest( - dumpsys=adb_pb2.AdbRequest.DumpsysRequest( - service='activity', args=[package_name, package_name]))) - if dumpsys_response.status != adb_pb2.AdbResponse.OK: - return AppScreenChecker.Outcome.FAILED_ACTIVITY_EXTRACTION - - if dumpsys_response.dumpsys.output: - if not matches_path( - dumpsys_response.dumpsys.output.decode('utf-8'), - self._expected_view_hierarchy_path, - max_levels=3): - return AppScreenChecker.Outcome.UNEXPECTED_VIEW_HIERARCHY - - return AppScreenChecker.Outcome.SUCCESS - - def wait_for_app_screen(self, timeout_sec: float) -> float: - """Waits for `self._expected_app_screen` to be the current screen. - - Args: - timeout_sec: Maximum total time to wait for the screen to pop up. - - Returns: - The total amount of time in seconds spent waiting for the screen to pop - up. - Raises: - errors.WaitForAppScreenError if the screen does not pop up within - `timeout_sec`. - """ - - logging.info('Waiting for app screen...') - start_time = time.time() - while time.time() - start_time < timeout_sec: - if self.matches_current_app_screen() == AppScreenChecker.Outcome.SUCCESS: - wait_time = time.time() - start_time - logging.info('Successfully waited for app screen in %r seconds: [%r]', - wait_time, self._expected_app_screen) - return wait_time - time.sleep(0.1) - - wait_time = time.time() - start_time - logging.error('Failed to wait for app screen in %r seconds: [%r].', - wait_time, self._expected_app_screen) - - raise errors.WaitForAppScreenError() diff --git a/android_env/components/app_screen_checker_test.py b/android_env/components/app_screen_checker_test.py deleted file mode 100644 index d9ec6aeb..00000000 --- a/android_env/components/app_screen_checker_test.py +++ /dev/null @@ -1,266 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.components.app_screen_checker.""" - -import re -from unittest import mock - -from absl.testing import absltest -from android_env.components import adb_call_parser -from android_env.components import app_screen_checker -from android_env.components import errors -from android_env.proto import adb_pb2 -from android_env.proto import task_pb2 - - -def _flatten_tree( - tree: app_screen_checker._DumpsysNode, flat_tree: list[str], indent: int = 2 -): - """Appends a list of strings to `flat_tree` from `tree`.""" - flat_tree.append(' ' * indent + tree.data) - for c in tree.children: - _flatten_tree(c, flat_tree, indent + 2) - - -class AppScreenCheckerTest(absltest.TestCase): - - # Ensures that build_tree_from_dumpsys_output produces a node whose flat - # representation matches our expectation from an arbitrary hierarchy. - def test_build_tree_from_dumpsys_output(self): - dumpsys_output = """ -Queen Elizabeth II - Charles - William - George - Charlotte - Louis - Harry - Archie - Anne - Peter - Savannah - Isla - Zara - Mia - Lena - Andrew - Beatrice - Eugenie - Edward - Louise - James -""" - tree = app_screen_checker.build_tree_from_dumpsys_output(dumpsys_output) - flat_tree = [] - _flatten_tree(tree, flat_tree, indent=2) - self.assertEqual(flat_tree, [ - ' ___root___', - ' Queen Elizabeth II', - ' Charles', - ' William', - ' George', - ' Charlotte', - ' Louis', - ' Harry', - ' Archie', - ' Anne', - ' Peter', - ' Savannah', - ' Isla', - ' Zara', - ' Mia', - ' Lena', - ' Andrew', - ' Beatrice', - ' Eugenie', - ' Edward', - ' Louise', - ' James', - ]) - - # Ensures that build_tree_from_dumpsys_output produces a node whose flat - # representation matches our expectation from an arbitrary hierarchy. - def test_build_forest_from_dumpsys_output(self): - dumpsys_output = """ -Tree1 - Branch1 - Leaf1 - Leaf2 - Branch2 - Leaf3 - Leaf4 - Leaf5 -Tree2 - Branch3 - Leaf6 - Leaf7 - Branch4 - Leaf8 - Leaf9 - Leaf10 - Leaf11 -""" - tree = app_screen_checker.build_tree_from_dumpsys_output(dumpsys_output) - flat_tree = [] - _flatten_tree(tree, flat_tree, indent=2) - self.assertEqual(flat_tree, [ - ' ___root___', - ' Tree1', - ' Branch1', - ' Leaf1', - ' Leaf2', - ' Branch2', - ' Leaf3', - ' Leaf4', - ' Leaf5', - ' Tree2', - ' Branch3', - ' Leaf6', - ' Leaf7', - ' Branch4', - ' Leaf8', - ' Leaf9', - ' Leaf10', - ' Leaf11', - ]) - - def test_no_view_hierarchy_matches_path(self): - dumpsys_output = """ -TASK - ACTIVITY - Missing View Hierarchy - A - B - C - D - E - F -""" - expected_path = ['^A$', 'B$'] - expected_view_hierarchy_path = [ - re.compile(regex) for regex in expected_path - ] - self.assertFalse( - app_screen_checker.matches_path(dumpsys_output, - expected_view_hierarchy_path)) - - def test_matches_path(self): - dumpsys_output = """ -TASK - ACTIVITY - Some node we don't care - Blah - - View Hierarchy - Hirohito - Akihito - Naruhito - Aiko - Fumihito - Mako - Kako - Hisahito - Masahito -""" - expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kako$'] - expected_view_hierarchy_path = [ - re.compile(regex) for regex in expected_path - ] - self.assertTrue( - app_screen_checker.matches_path( - dumpsys_output, expected_view_hierarchy_path, max_levels=2)) - - # Also check that the following path does not match anything in the tree. - expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kenji$'] - expected_view_hierarchy_path = [ - re.compile(regex) for regex in expected_path - ] - self.assertFalse( - app_screen_checker.matches_path(dumpsys_output, - expected_view_hierarchy_path)) - - def test_matches_path_one_level_deep(self): - dumpsys_output = """ -TASK - ACTIVITY - Some node we don't care - Blah - - Some intermediate node - View Hierarchy - Hirohito - Akihito - Naruhito - Aiko - Fumihito - Mako - Kako - Hisahito - Masahito -""" - expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kako$'] - expected_view_hierarchy_path = [ - re.compile(regex) for regex in expected_path - ] - self.assertTrue( - app_screen_checker.matches_path( - dumpsys_output, expected_view_hierarchy_path, max_levels=3)) - - # Also check that the view hierarchy is not found when searching only grand - # children of TASK. - expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kako$'] - expected_view_hierarchy_path = [ - re.compile(regex) for regex in expected_path - ] - self.assertFalse( - app_screen_checker.matches_path( - dumpsys_output, expected_view_hierarchy_path, max_levels=2)) - - def test_wait_for_app_screen_zero_timeout(self): - """Ensures that an exception is raised if the timeout is passed.""" - app_screen = task_pb2.AppScreen(activity='whatever.MyActivity') - call_parser = mock.create_autospec(adb_call_parser.AdbCallParser) - screen_checker = app_screen_checker.AppScreenChecker( - adb_call_parser=call_parser, - expected_app_screen=app_screen) - # With a zero timeout, the method should never be able to wait for the - # screen to pop up and an exception should be raised. - self.assertRaises( - errors.WaitForAppScreenError, - screen_checker.wait_for_app_screen, - timeout_sec=0.0) - - def test_wait_for_app_screen_successful(self): - """Ensures that with the right conditions, the app screen should pop up.""" - app_screen = task_pb2.AppScreen(activity='my.favorite.AwesomeActivity') - call_parser = mock.create_autospec(adb_call_parser.AdbCallParser) - call_parser.parse.return_value = adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.OK, - get_current_activity=adb_pb2.AdbResponse.GetCurrentActivityResponse( - full_activity='my.favorite.AwesomeActivity')) - - screen_checker = app_screen_checker.AppScreenChecker( - call_parser, app_screen) - timeout = 1.0 - wait_time = screen_checker.wait_for_app_screen(timeout_sec=timeout) - - # The call should not generate an exception and the return value should be - # less than the timeout given. - self.assertLess(wait_time, timeout) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/config_classes.py b/android_env/components/config_classes.py deleted file mode 100644 index b4ed956f..00000000 --- a/android_env/components/config_classes.py +++ /dev/null @@ -1,203 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Dataclass definitions used for instantiating AndroidEnv components.""" - -import dataclasses - - -@dataclasses.dataclass -class AdbControllerConfig: - """Settings for instatiating an `AdbController` instance.""" - - # Filesystem path to the `adb` binary. - # NOTE: This must be a full path and must not contain environment variables - # or user folder shorthands (e.g. `~/some/path/to/adb`) since they will not be - # expanded internally by AndroidEnv. - adb_path: str = '~/Android/Sdk/platform-tools/adb' - # Port for adb server. - adb_server_port: int = 5037 - # Default timeout in seconds for internal commands. - default_timeout: float = 120.0 - # Name of the device to communicate with. - device_name: str = '' - - -@dataclasses.dataclass -class DeviceSettingsConfig: - """Config class for DeviceSettings.""" - - # Whether to show circles on the screen indicating touch position. - show_touches: bool = True - # Whether to show blue lines on the screen indicating touch position. - show_pointer_location: bool = True - # Whether or not to show the status (top) bar. - show_status_bar: bool = False - # Whether or not to show the navigation (bottom) bar. - show_navigation_bar: bool = False - - -@dataclasses.dataclass -class CoordinatorConfig: - """Config class for Coordinator.""" - - # Number of virtual "fingers" of the agent. - num_fingers: int = 1 - # Whether to enable keyboard key events. - enable_key_events: bool = False - # Time between periodic restarts in minutes. If > 0, will trigger - # a simulator restart at the beginning of the next episode once the time has - # been reached. - periodic_restart_time_min: float = 0.0 - # General Android settings. - device_settings: DeviceSettingsConfig = dataclasses.field( - default_factory=DeviceSettingsConfig - ) - - -@dataclasses.dataclass -class SimulatorConfig: - """Base class for all simulator configs.""" - - # If true, the log stream of the simulator will be verbose. - verbose_logs: bool = False - # How often to (asynchronously) grab the screenshot from the simulator. - # If <= 0, stepping the environment blocks on fetching the screenshot (the - # environment is synchronous). - interaction_rate_sec: float = 0.0 - - -@dataclasses.dataclass -class EmulatorLauncherConfig: - """Config class for EmulatorLauncher.""" - - # NOTE: If `adb_port`, `emulator_console_port` and `grpc_port` are defined - # (i.e. not all equal to 0), it is assumed that the emulator they point to - # exists already and EmulatorLauncher will be skipped. - - # Filesystem path to the `emulator` binary. - emulator_path: str = '~/Android/Sdk/emulator/emulator' - # Filesystem path to the Android SDK root. - android_sdk_root: str = '~/Android/Sdk' - # Name of the AVD. - avd_name: str = '' - # Local directory for AVDs. - android_avd_home: str = '~/.android/avd' - # Name of the snapshot to load. - snapshot_name: str = '' - # Path to the KVM device. - kvm_device: str = '/dev/kvm' - # Path to directory which will hold temporary files. - tmp_dir: str = '/tmp/android_env/simulator/' - # GPU mode override. - # Please see - # https://developer.android.com/studio/run/emulator-acceleration#accel-graphics. - gpu_mode: str = 'swangle_indirect' # Alternative: swiftshader_indirect, host - # Whether to run in headless mode (i.e. without a graphical window). - run_headless: bool = True - # Whether to restrict network access. - # If True, will disable networking on the device. This option is only - # available for emulator version > 31.3.9 (June 2022). - restrict_network: bool = False - # Whether to set `SHOW_PERF_STATS=1` when launching the emulator to display - # performance and memory statistics. - show_perf_stats: bool = False - - # ADB port for the Android device. - adb_port: int = 0 - # Port for telnet communication with the emulator. - emulator_console_port: int = 0 - # Port for gRPC communication with the emulator. - grpc_port: int = 0 - - -@dataclasses.dataclass -class EmulatorConfig(SimulatorConfig): - """Config class for EmulatorSimulator.""" - - # Configuration for launching the Android Emulator. - emulator_launcher: EmulatorLauncherConfig = dataclasses.field( - default_factory=EmulatorLauncherConfig - ) - # Configuration for talking to adb. - adb_controller: AdbControllerConfig = dataclasses.field( - default_factory=AdbControllerConfig - ) - # Path to file which holds emulator logs. If not provided, it will be - # determined by the EmulatorLauncher. - logfile_path: str = '' - # The number of times to try launching the emulator before rebooting (reboot - # on the n+1-st try). - launch_n_times_without_reboot: int = 1 - # The number of times to try launching the emulator before reinstalling - # (reinstall on the n+1-st try). - launch_n_times_without_reinstall: int = 2 - - -@dataclasses.dataclass -class FakeSimulatorConfig(SimulatorConfig): - """Config class for FakeSimulator.""" - - # The dimensions in pixels of the device screen (HxW). - screen_dimensions: tuple[int, int] = (0, 0) - - -@dataclasses.dataclass -class TaskManagerConfig: - """Config class for TaskManager.""" - - # If max_bad_states episodes finish in a bad state in a row, restart - # the simulation. - max_bad_states: int = 3 - # The frequency to check for the current activity and view hierarchy. - # The unit is raw observation (i.e. each call to AndroidEnv.step()). - dumpsys_check_frequency: int = 150 - # The maximum number of tries for extracting the current activity before - # forcing the episode to restart. - max_failed_current_activity: int = 10 - # The maximum number of extras elements to store. If this number is exceeded, - # elements are dropped in the order they were received. - extras_max_buffer_size: int = 100 - - -@dataclasses.dataclass -class TaskConfig: - """Base config class for loading tasks.""" - - # The directory for temporary task-related resources. - tmp_dir: str = '' - - -@dataclasses.dataclass -class FilesystemTaskConfig(TaskConfig): - """Config for protobuf files stored in the local filesystem.""" - - # Filesystem path to `.binarypb` or `.textproto` protobuf Task. - path: str = '' - - -@dataclasses.dataclass -class AndroidEnvConfig: - """Config class for AndroidEnv.""" - - # Configs for main components. - task: TaskConfig = dataclasses.field(default_factory=TaskConfig) - task_manager: TaskManagerConfig = dataclasses.field( - default_factory=TaskManagerConfig - ) - coordinator: CoordinatorConfig = dataclasses.field( - default_factory=CoordinatorConfig - ) - simulator: SimulatorConfig = dataclasses.field(default_factory=EmulatorConfig) diff --git a/android_env/components/coordinator.py b/android_env/components/coordinator.py deleted file mode 100644 index bca636ec..00000000 --- a/android_env/components/coordinator.py +++ /dev/null @@ -1,282 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Coordinator handles interaction between internal components of AndroidEnv.""" - -import copy -import time -from typing import Any - -from absl import logging -from android_env.components import action_fns -from android_env.components import action_type as action_type_lib -from android_env.components import adb_call_parser -from android_env.components import config_classes -from android_env.components import device_settings as device_settings_lib -from android_env.components import errors -from android_env.components import pixel_fns -from android_env.components import specs -from android_env.components import task_manager as task_manager_lib -from android_env.components.simulators import base_simulator -from android_env.proto import adb_pb2 -import dm_env -import numpy as np - - -class Coordinator: - """Handles interaction between internal components of AndroidEnv.""" - - def __init__( - self, - simulator: base_simulator.BaseSimulator, - task_manager: task_manager_lib.TaskManager, - device_settings: device_settings_lib.DeviceSettings, - config: config_classes.CoordinatorConfig | None = None, - ): - """Handles communication between AndroidEnv and its components. - - Args: - simulator: A BaseSimulator instance. - task_manager: The TaskManager, responsible for coordinating RL tasks. - config: Settings to customize this Coordinator. - """ - self._simulator = simulator - self._task_manager = task_manager - self._config = config or config_classes.CoordinatorConfig() - self._device_settings = device_settings - self._adb_call_parser: adb_call_parser.AdbCallParser = None - - # Initialize stats. - self._stats = { - 'relaunch_count': 0, - 'relaunch_count_periodic': 0, - 'relaunch_count_setup_steps': 0, - 'relaunch_count_reset_steps': 0, - 'relaunch_count_simulator_launch': 0, - 'relaunch_count_simulator_reset': 0, - 'relaunch_count_execute_action': 0, - 'relaunch_count_fetch_observation': 0, - 'relaunch_count_update_settings': 0, - 'failed_task_updates': 0, - } - - # Initialize counters. - self._simulator_healthy = False - self._latest_observation_time = 0 - self._simulator_start_time = None - - logging.info('Starting the simulator...') - self._launch_simulator() - - def action_spec(self) -> dict[str, dm_env.specs.Array]: - return specs.base_action_spec( - num_fingers=self._config.num_fingers, - enable_key_events=self._config.enable_key_events, - ) - - def observation_spec(self) -> dict[str, dm_env.specs.Array]: - return specs.base_observation_spec( - height=self._device_settings.screen_height(), - width=self._device_settings.screen_width(), - ) - - def _should_periodic_relaunch(self) -> bool: - """Checks if it is time to restart the simulator. - - If a periodic restart time was specified, the Coordinator will re-launch - the simulator at regular time intervals. This helps to make sure that the - simulator is not in a stale state even if the environment has been running - for a significant amount of time. - - Returns: - Boolean indicating if it is time to restart the simulator. - """ - - if self._config.periodic_restart_time_min and self._simulator_start_time: - sim_alive_time = (time.time() - self._simulator_start_time) / 60.0 - logging.info('Simulator has been running for %f mins', sim_alive_time) - if sim_alive_time > self._config.periodic_restart_time_min: - logging.info('Maximum alive time reached. Restarting simulator.') - self._stats['relaunch_count_periodic'] += 1 - return True - return False - - def _launch_simulator(self, max_retries: int = 3): - """Launches the simulator. - - Sets up the simulator and other task-related settings. - - Args: - max_retries: Number of times to attempt a restart before raising an error. - """ - - self._simulator_healthy = False - - # Attempt to restart the system a given number of times. - num_tries = 1 - latest_error = None - while True: - if num_tries > max_retries: - raise errors.TooManyRestartsError( - 'Maximum number of restart attempts reached.' - ) from latest_error - logging.info('Simulator launch attempt %d of %d', num_tries, max_retries) - - self._task_manager.stop() - - # Launch the simulator. - self._simulator.launch() - self._simulator_start_time = time.time() - - # From here on, the simulator is assumed to be up and running. - self._adb_call_parser = self._create_adb_call_parser() - try: - self._device_settings.update(self._config.device_settings) - except errors.AdbControllerError as e: - logging.exception('device_settings.update() failed.') - self._stats['relaunch_count_update_settings'] += 1 - self._latest_error = e - num_tries += 1 - continue - - # Start the task. - self._task_manager.start( - adb_call_parser_factory=self._create_adb_call_parser, - log_stream=self._simulator.create_log_stream(), - ) - try: - self._task_manager.setup_task() - except errors.StepCommandError as error: - logging.exception('Failed to set up the task. Restarting simulator.') - self._stats['relaunch_count_setup_steps'] += 1 - latest_error = error - num_tries += 1 - continue - - # Restart was successful. - self._simulator_healthy = True - self._stats['relaunch_count'] += 1 - break - - def _create_adb_call_parser(self): - """Creates a new AdbCallParser instance.""" - return adb_call_parser.AdbCallParser( - adb_controller=self._simulator.create_adb_controller() - ) - - def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse: - return self._adb_call_parser.parse(call) - - def rl_reset(self) -> dm_env.TimeStep: - """Resets the RL episode.""" - - # Relaunch the simulator if necessary. - if not self._simulator_healthy or self._should_periodic_relaunch(): - self._launch_simulator() - - # Reset counters. - self._latest_observation_time = 0 - for key in self._stats: - if key.startswith('episode'): - self._stats[key] = 0.0 - - # Execute a lift action before resetting the task. - if not action_fns.send_action_to_simulator( - action_fns.lift_all_fingers_action(self._config.num_fingers), - self._simulator, - self._device_settings.screen_width(), - self._device_settings.screen_height(), - self._config.num_fingers, - ): - self._stats['relaunch_count_execute_action'] += 1 - self._simulator_healthy = False - - # Reset the task. - self._task_manager.reset_task() - self._device_settings.get_orientation() - - # Get data from the simulator. - simulator_signals = self._gather_simulator_signals() - - return self._task_manager.rl_reset(simulator_signals) - - def rl_step(self, agent_action: dict[str, np.ndarray]) -> dm_env.TimeStep: - """Executes the selected action and returns a timestep. - - Args: - agent_action: Selected action to perform on the simulated Android device. - If `agent_action` is `None` it means that this is an RL reset (to start - a new episode). - - Returns: - An RL timestep. - """ - - if not action_fns.send_action_to_simulator( - agent_action, - self._simulator, - self._device_settings.screen_width(), - self._device_settings.screen_height(), - self._config.num_fingers, - ): - self._stats['relaunch_count_execute_action'] += 1 - self._simulator_healthy = False - - # Get data from the simulator. - try: - simulator_signals = self._gather_simulator_signals() - except errors.ReadObservationError: - logging.exception('Unable to fetch observation. Restarting simulator.') - self._stats['relaunch_count_fetch_observation'] += 1 - self._simulator_healthy = False - - if not self._simulator_healthy: - return dm_env.truncation(reward=0.0, observation=None) - - return self._task_manager.rl_step(simulator_signals) - - def _gather_simulator_signals(self) -> dict[str, np.ndarray]: - """Gathers data from various sources to assemble the RL observation.""" - - # Get current timestamp and update the delta. - now = time.time() - timestamp_delta = ( - 0 - if self._latest_observation_time == 0 - else (now - self._latest_observation_time) * 1e6 - ) - self._latest_observation_time = now - - return { - 'pixels': self._simulator.get_screenshot(), - 'orientation': self._device_settings.get_orientation(), - 'timedelta': np.array(timestamp_delta, dtype=np.int64), - } - - def __del__(self): - self.close() - - def stats(self) -> dict[str, Any]: - """Returns various statistics.""" - - return copy.deepcopy(self._stats) - - def close(self): - """Cleans up the state of this Coordinator.""" - - if hasattr(self, '_task_manager'): - self._task_manager.stop() - if hasattr(self, '_simulator'): - self._simulator.close() diff --git a/android_env/components/coordinator_test.py b/android_env/components/coordinator_test.py deleted file mode 100644 index 78d3a4fe..00000000 --- a/android_env/components/coordinator_test.py +++ /dev/null @@ -1,283 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.components.coordinator.""" - -import tempfile -import time -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -from android_env.components import action_type -from android_env.components import adb_call_parser -from android_env.components import config_classes -from android_env.components import coordinator as coordinator_lib -from android_env.components import device_settings as device_settings_lib -from android_env.components import errors -from android_env.components import task_manager -from android_env.components.simulators import base_simulator -from android_env.proto import adb_pb2 -from android_env.proto import state_pb2 -from android_env.proto import task_pb2 -import dm_env -import numpy as np - - -class CoordinatorTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.addCleanup(mock.patch.stopall) # Disable previous patches. - - self._simulator = mock.create_autospec(base_simulator.BaseSimulator) - self._random_screenshot = np.random.randint( - low=0, high=255, size=(800, 600, 3), dtype=np.uint8) - self._simulator.get_screenshot.return_value = self._random_screenshot - self._task_manager = mock.create_autospec(task_manager.TaskManager) - self._adb_call_parser = mock.create_autospec(adb_call_parser.AdbCallParser) - self.enter_context( - mock.patch.object( - adb_call_parser, - 'AdbCallParser', - autospec=True, - return_value=self._adb_call_parser)) - self._coordinator = coordinator_lib.Coordinator( - simulator=self._simulator, - task_manager=self._task_manager, - device_settings=device_settings_lib.DeviceSettings(self._simulator), - ) - - def tearDown(self): - super().tearDown() - self._coordinator.close() - - @mock.patch.object(time, 'sleep', autospec=True) - def test_relaunch_simulator(self, unused_mock_sleep): - relaunch_count = self._coordinator.stats()['relaunch_count'] - self._coordinator._launch_simulator() - self.assertEqual(self._coordinator.stats()['relaunch_count'], - relaunch_count + 1) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_reset(self, unused_mock_sleep): - """'relaunch_count_execute_action' should be zero if there's no error.""" - - self._coordinator.rl_reset() - stats = self._coordinator.stats() - self.assertIn('relaunch_count_execute_action', stats) - self.assertEqual(stats['relaunch_count_execute_action'], 0) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_reset_error_sending_action(self, unused_mock_sleep): - """'relaunch_count_execute_action' should be positive if there's an error.""" - - self._simulator.send_touch.side_effect = errors.SendActionError() - self._coordinator.rl_reset() - stats = self._coordinator.stats() - self.assertIn('relaunch_count_execute_action', stats) - self.assertEqual(stats['relaunch_count_execute_action'], 1) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_lift_all_fingers(self, unused_mock_sleep): - self._coordinator = coordinator_lib.Coordinator( - simulator=self._simulator, - task_manager=self._task_manager, - device_settings=device_settings_lib.DeviceSettings(self._simulator), - config=config_classes.CoordinatorConfig(num_fingers=3), - ) - self._coordinator.rl_reset() - expected_actions = [ - # (x, y, is_down, identifier). - (0, 0, False, 0), - (0, 0, False, 1), - (0, 0, False, 2), - ] - actual_actions = self._simulator.send_touch.call_args[0][0] - for actual, expected in zip(actual_actions, expected_actions): - np.testing.assert_array_equal(actual, expected) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_process_action(self, unused_mock_sleep): - - def fake_rl_step(simulator_signals): - return dm_env.transition( - reward=10.0, - observation={ - 'pixels': simulator_signals['pixels'], - 'orientation': simulator_signals['orientation'], - 'timedelta': simulator_signals['timedelta'], - 'extras': { - 'extra': [0.0] - } - }) - - self._task_manager.rl_step.side_effect = fake_rl_step - timestep = self._coordinator.rl_step( - agent_action={ - 'action_type': np.array(action_type.ActionType.LIFT), - 'touch_position': np.array([0.5, 0.5]), - }) - obs = timestep.observation - self.assertEqual(obs['pixels'].shape, (800, 600, 3)) - np.testing.assert_equal(obs['orientation'], - np.array([0, 0, 0, 0], dtype=np.uint8)) - self.assertEqual(timestep.reward, 10.0) - self.assertEqual(obs['extras'], {'extra': [0.0]}) - self.assertFalse(timestep.last()) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_process_action_error(self, unused_mock_sleep): - - def fake_rl_step(simulator_signals): - self.assertFalse(simulator_signals['simulator_healthy']) - return dm_env.truncation(reward=0.0, observation=None) - - self._task_manager.rl_step.side_effect = fake_rl_step - self._simulator.get_screenshot.side_effect = errors.ReadObservationError() - timestep = self._coordinator.rl_step( - agent_action={ - 'action_type': np.array(action_type.ActionType.LIFT), - 'touch_position': np.array([0.5, 0.5]), - }) - self.assertIsNone(timestep.observation) - self.assertEqual(timestep.reward, 0.0) - self.assertTrue(timestep.last()) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_execute_action_touch(self, unused_mock_sleep): - - def fake_rl_step(simulator_signals): - return dm_env.transition( - reward=123.0, - observation={ - 'pixels': simulator_signals['pixels'], - 'orientation': simulator_signals['orientation'], - 'timedelta': simulator_signals['timedelta'], - 'extras': { - 'extra': [0.0] - } - }) - - self._task_manager.rl_step.side_effect = fake_rl_step - timestep = self._coordinator.rl_step( - agent_action={ - 'action_type': np.array(action_type.ActionType.TOUCH), - 'touch_position': np.array([0.5, 0.5]) - }) - self.assertEqual(timestep.reward, 123.0) - np.testing.assert_equal(timestep.observation['pixels'], - self._random_screenshot) - self._simulator.send_touch.assert_called_once_with([(300, 400, True, 0)]) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_execute_multitouch_action(self, unused_mock_sleep): - self._coordinator = coordinator_lib.Coordinator( - simulator=self._simulator, - task_manager=self._task_manager, - device_settings=device_settings_lib.DeviceSettings(self._simulator), - config=config_classes.CoordinatorConfig(num_fingers=3), - ) - - def fake_rl_step(simulator_signals): - return dm_env.transition( - reward=456.0, - observation={ - 'pixels': simulator_signals['pixels'], - 'orientation': simulator_signals['orientation'], - 'timedelta': simulator_signals['timedelta'], - 'extras': { - 'extra': [0.0] - } - }) - - self._task_manager.rl_step.side_effect = fake_rl_step - action = { - 'action_type': np.array([action_type.ActionType.TOUCH]), - 'touch_position': np.array([0.25, 0.75]), - 'action_type_2': np.array([action_type.ActionType.TOUCH]), - 'touch_position_2': np.array([0.75, 0.25]), - 'action_type_3': np.array([action_type.ActionType.LIFT]), - 'touch_position_3': np.array([0.5, 0.5]), - } - timestep = self._coordinator.rl_step(action) - self._simulator.send_touch.assert_called_once_with([(150, 600, True, 0), - (450, 200, True, 1), - (300, 400, False, 2)]) - self.assertEqual(timestep.reward, 456.0) - np.testing.assert_equal(timestep.observation['pixels'], - self._random_screenshot) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_execute_action_repeat(self, unused_mock_sleep): - def fake_rl_step(simulator_signals): - return dm_env.transition( - reward=10.0, - observation={ - 'pixels': simulator_signals['pixels'], - 'orientation': simulator_signals['orientation'], - 'timedelta': simulator_signals['timedelta'], - 'extras': { - 'extra': [0.0] - } - }) - - self._task_manager.rl_step.side_effect = fake_rl_step - timestep = self._coordinator.rl_step( - {'action_type': np.array(action_type.ActionType.REPEAT)}) - self._simulator.send_touch.assert_not_called() - np.testing.assert_equal(timestep.observation['pixels'], - self._random_screenshot) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_execute_action_error(self, unused_mock_sleep): - def fake_rl_step(simulator_signals): - self.assertFalse(simulator_signals['simulator_healthy']) - return dm_env.truncation(reward=0.0, observation=None) - - self._task_manager.rl_step.side_effect = fake_rl_step - self._simulator.send_touch.side_effect = errors.SendActionError - timestep = self._coordinator.rl_step({ - 'action_type': np.array(action_type.ActionType.TOUCH), - 'touch_position': np.array([0.3, 0.8]) - }) - self.assertIsNone(timestep.observation) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_max_restarts_setup_steps(self, unused_mock_sleep): - init_fn_call = self._task_manager.setup_task.call_count - self._task_manager.setup_task.side_effect = errors.StepCommandError - self.assertRaises(errors.TooManyRestartsError, - self._coordinator._launch_simulator) - # The method was called three more times when attempting to relaunch. - self.assertEqual(init_fn_call + 3, - self._task_manager.setup_task.call_count) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_execute_adb_call(self, unused_mock_sleep): - call = adb_pb2.AdbRequest( - force_stop=adb_pb2.AdbRequest.ForceStop(package_name='blah')) - expected_response = adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.OK) - self._adb_call_parser.parse.side_effect = [expected_response] - - response = self._coordinator.execute_adb_call(call) - - self.assertEqual(response, expected_response) - self._adb_call_parser.parse.assert_called_with(call) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/device_settings.py b/android_env/components/device_settings.py deleted file mode 100644 index 8b5b7a5e..00000000 --- a/android_env/components/device_settings.py +++ /dev/null @@ -1,174 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Sets and gets some global settings on an Android device.""" - -from typing import Final -from unittest import mock - -from absl import logging -from android_env.components import adb_call_parser -from android_env.components import config_classes -from android_env.components.simulators import base_simulator -from android_env.proto import adb_pb2 -import numpy as np - - -# The internal `AdbCallParser` instance is lazily instantiated within -# `DeviceSettings`. If we make it optional (i.e. `| None`), pytype will think -# that it could be `None`, requiring either explicit runtime checks or escape -# hatches in every actual call, even if it's never actually `None` if reached -# via the public API. -# The trick here is to create this dummy instance of the right type that's used -# as a sentinel to indicate that it hasn't been initialized yet. -_PLACEHOLDER_ADB_CALL_PARSER: Final[adb_call_parser.AdbCallParser] = ( - mock.create_autospec(adb_call_parser.AdbCallParser) -) - - -class DeviceSettings: - """An abstraction for general properties and settings of an Android device.""" - - def __init__(self, simulator: base_simulator.BaseSimulator): - self._simulator = simulator - self._adb_call_parser = _PLACEHOLDER_ADB_CALL_PARSER - - # The size of the device screen in pixels. - self._screen_width: int = 0 - self._screen_height: int = 0 - # The device orientation. - self._orientation = np.zeros(4, dtype=np.uint8) - - def update(self, config: config_classes.DeviceSettingsConfig) -> None: - """Sets the configuration of the device according to `config`.""" - - if self._adb_call_parser is _PLACEHOLDER_ADB_CALL_PARSER: - self._adb_call_parser = adb_call_parser.AdbCallParser( - adb_controller=self._simulator.create_adb_controller() - ) - - self._update_screen_size() - self._set_show_touches(config.show_touches) - self._set_show_pointer_location(config.show_pointer_location) - self._set_status_navigation_bars( - config.show_navigation_bar, config.show_status_bar - ) - - def screen_width(self) -> int: - """The screen width in pixels. Only valid after `update()` is called.""" - - return self._screen_width - - def screen_height(self) -> int: - """The screen height in pixels. Only valid after `update()` is called.""" - - return self._screen_height - - def get_orientation(self) -> np.ndarray: - """Returns the device orientation. Please see specs.py for details.""" - - if self._adb_call_parser is _PLACEHOLDER_ADB_CALL_PARSER: - self._adb_call_parser = adb_call_parser.AdbCallParser( - adb_controller=self._simulator.create_adb_controller() - ) - - self._update_orientation() - return self._orientation - - def _update_screen_size(self) -> None: - """Sets the screen size from a screenshot ignoring the color channel.""" - - screenshot = self._simulator.get_screenshot() - self._screen_height = screenshot.shape[0] - self._screen_width = screenshot.shape[1] - - def _set_show_touches(self, show: bool) -> None: - """Whether to display circles indicating the touch position.""" - - self._adb_call_parser.parse( - adb_pb2.AdbRequest( - settings=adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, - put=adb_pb2.AdbRequest.SettingsRequest.Put( - key='show_touches', value='1' if show else '0' - ), - ) - ) - ) - - def _set_show_pointer_location(self, show: bool) -> None: - """Whether to display blue lines on the screen indicating touch position.""" - - self._adb_call_parser.parse( - adb_pb2.AdbRequest( - settings=adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, - put=adb_pb2.AdbRequest.SettingsRequest.Put( - key='pointer_location', value='1' if show else '0' - ), - ) - ) - ) - - def _set_status_navigation_bars( - self, show_navigation: bool, show_status: bool - ) -> None: - """Whether to display the status (top) and navigation (bottom) bars.""" - - if show_navigation and show_status: - policy_control_value = 'null*' - elif show_navigation and not show_status: - policy_control_value = 'immersive.status=*' - elif not show_navigation and show_status: - policy_control_value = 'immersive.navigation=*' - else: - policy_control_value = 'immersive.full=*' - - self._adb_call_parser.parse( - adb_pb2.AdbRequest( - settings=adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL, - put=adb_pb2.AdbRequest.SettingsRequest.Put( - key='policy_control', value=policy_control_value - ), - ) - ) - ) - - def _update_orientation(self) -> None: - """Updates the current device orientation.""" - - # Skip fetching the orientation if we already have it. - if not np.all(self._orientation == np.zeros(4)): - return - - orientation_response = self._adb_call_parser.parse( - adb_pb2.AdbRequest( - get_orientation=adb_pb2.AdbRequest.GetOrientationRequest() - ) - ) - if orientation_response.status != adb_pb2.AdbResponse.Status.OK: - logging.error('Got bad orientation: %r', orientation_response) - return - - orientation = orientation_response.get_orientation.orientation - if orientation not in {0, 1, 2, 3}: - logging.error('Got bad orientation: %r', orientation) - return - - # Transform into one-hot format. - orientation_onehot = np.zeros([4], dtype=np.uint8) - orientation_onehot[orientation] = 1 - self._orientation = orientation_onehot diff --git a/android_env/components/device_settings_test.py b/android_env/components/device_settings_test.py deleted file mode 100644 index 83eb5f55..00000000 --- a/android_env/components/device_settings_test.py +++ /dev/null @@ -1,228 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -from android_env.components import config_classes -from android_env.components import device_settings as device_settings_lib -from android_env.components.simulators import base_simulator -import numpy as np - - -class DeviceSettingsTest(parameterized.TestCase): - - def test_screen_size_before_update(self): - """The screen size should be 0x0 without calling `update()`.""" - - # Arrange. - simulator = mock.create_autospec(base_simulator.BaseSimulator) - device_settings = device_settings_lib.DeviceSettings(simulator) - - # Act. - height = device_settings.screen_height() - width = device_settings.screen_width() - - # Assert. - self.assertEqual(height, 0) - self.assertEqual(width, 0) - - def test_screen_size_after_update(self): - """The screen size should be set after calling `update()`.""" - - # Arrange. - simulator = mock.create_autospec(base_simulator.BaseSimulator) - simulator.get_screenshot.return_value = np.random.randint( - low=0, high=255, size=(123, 456, 3), dtype=np.uint8 - ) - adb_controller = simulator.create_adb_controller.return_value - adb_controller.execute_command.return_value = b'' - device_settings = device_settings_lib.DeviceSettings(simulator) - - # Act. - device_settings.update(config_classes.DeviceSettingsConfig()) - height = device_settings.screen_height() - width = device_settings.screen_width() - - # Assert. - self.assertEqual(height, 123) - self.assertEqual(width, 456) - - @parameterized.named_parameters( - ( - 'show_touches', - config_classes.DeviceSettingsConfig(show_touches=True), - mock.call( - ['shell', 'settings', 'put', 'system', 'show_touches', '1'], - timeout=None, - ), - ), - ( - 'show_touches_false', - config_classes.DeviceSettingsConfig(show_touches=False), - mock.call( - ['shell', 'settings', 'put', 'system', 'show_touches', '0'], - timeout=None, - ), - ), - ( - 'show_pointer_location', - config_classes.DeviceSettingsConfig(show_pointer_location=True), - mock.call( - ['shell', 'settings', 'put', 'system', 'pointer_location', '1'], - timeout=None, - ), - ), - ( - 'show_pointer_location_false', - config_classes.DeviceSettingsConfig(show_pointer_location=False), - mock.call( - ['shell', 'settings', 'put', 'system', 'pointer_location', '0'], - timeout=None, - ), - ), - ( - 'show_navigation_and_status', - config_classes.DeviceSettingsConfig( - show_navigation_bar=True, show_status_bar=True - ), - mock.call( - ['shell', 'settings', 'put', 'global', 'policy_control', 'null*'], - timeout=None, - ), - ), - ( - 'show_navigation_and_no_status', - config_classes.DeviceSettingsConfig( - show_navigation_bar=True, show_status_bar=False - ), - mock.call( - [ - 'shell', - 'settings', - 'put', - 'global', - 'policy_control', - 'immersive.status=*', - ], - timeout=None, - ), - ), - ( - 'show_no_navigation_and_status', - config_classes.DeviceSettingsConfig( - show_navigation_bar=False, show_status_bar=True - ), - mock.call( - [ - 'shell', - 'settings', - 'put', - 'global', - 'policy_control', - 'immersive.navigation=*', - ], - timeout=None, - ), - ), - ( - 'show_no_navigation_and_no_status', - config_classes.DeviceSettingsConfig( - show_navigation_bar=False, show_status_bar=False - ), - mock.call( - [ - 'shell', - 'settings', - 'put', - 'global', - 'policy_control', - 'immersive.full=*', - ], - timeout=None, - ), - ), - ) - def test_update( - self, settings: config_classes.DeviceSettingsConfig, expected_call - ): - """We expect the right call for each setting.""" - - # Arrange. - simulator = mock.create_autospec(base_simulator.BaseSimulator) - adb_controller = simulator.create_adb_controller.return_value - adb_controller.execute_command.return_value = b'' - device_settings = device_settings_lib.DeviceSettings(simulator) - - # Act. - device_settings.update(settings) - - # Assert. - adb_controller.execute_command.assert_has_calls( - [expected_call], any_order=True - ) - - def test_get_orientation_bad_response(self): - """The orientation should be unset if the underlying response is bad.""" - - # Arrange. - simulator = mock.create_autospec(base_simulator.BaseSimulator) - adb_controller = simulator.create_adb_controller.return_value - adb_controller.execute_command.return_value = b'' - device_settings = device_settings_lib.DeviceSettings(simulator) - - # Act. - orientation = device_settings.get_orientation() - - # Assert. - np.testing.assert_array_equal(orientation, np.zeros(4)) - - def test_get_orientation_bad_orientation(self): - """The orientation should be unset if the underlying orientation is bad.""" - - # Arrange. - simulator = mock.create_autospec(base_simulator.BaseSimulator) - adb_controller = simulator.create_adb_controller.return_value - adb_controller.execute_command.return_value = b' InputDeviceOrientation: 9' - device_settings = device_settings_lib.DeviceSettings(simulator) - - # Act. - orientation = device_settings.get_orientation() - - # Assert. - np.testing.assert_array_equal(orientation, np.zeros(4)) - - def test_get_orientation_success(self): - """Checks that the orientation comes back as expected.""" - - # Arrange. - simulator = mock.create_autospec(base_simulator.BaseSimulator) - adb_controller = simulator.create_adb_controller.return_value - adb_controller.execute_command.return_value = b' InputDeviceOrientation: 3' - device_settings = device_settings_lib.DeviceSettings(simulator) - - # Act. - orientation = device_settings.get_orientation() - # The output should be idempotent if the underlying system did not change. - orientation_again = device_settings.get_orientation() - - # Assert. - np.testing.assert_array_equal(orientation, np.array([0, 0, 0, 1])) - np.testing.assert_array_equal(orientation, orientation_again) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/dumpsys_thread.py b/android_env/components/dumpsys_thread.py deleted file mode 100644 index 3466307a..00000000 --- a/android_env/components/dumpsys_thread.py +++ /dev/null @@ -1,125 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A ThreadFunction that runs and parses adb dumpsys.""" - -import concurrent.futures - -from absl import logging -from android_env.components import app_screen_checker as app_screen_checker_lib - -_Outcome = app_screen_checker_lib.AppScreenChecker.Outcome - - -class DumpsysThread: - """A thread that checks if the user is in the expected app screen.""" - - def __init__( - self, - app_screen_checker: app_screen_checker_lib.AppScreenChecker, - check_frequency: int = 10, - max_failed_current_activity: int = 10, - ): - """Initializes the dumpsys reader thread. - - This loops forever checking if the user is in the expected screen dictated - by `app_screen_checker`. These analyses are too expensive to be in the - critical path of AndroidEnv::step() so we consume them async from this - separate thread. - - Args: - app_screen_checker: The class that actually determines if the current - screen matches the expected screen. - check_frequency: Integer. We only call dumpsys 1/check_frequency times in - each iteration of the while loop below. - max_failed_current_activity: Integer. We try to fetch the current activity - but sometimes it fails. If it fails more than - `max_failed_current_activity` consecutive times, we declare that the - user has exited `expected_activity`. - """ - - self._app_screen_checker = app_screen_checker - self._main_loop_counter = 0 - self._check_frequency = check_frequency - self._max_failed_activity_extraction = max_failed_current_activity - self._num_failed_activity_extraction = 0 - self._latest_check: concurrent.futures.Future | None = None - - def check_user_exited(self, timeout: float | None = None) -> bool: - """Returns True if the user is not in the expected screen. - - Args: - timeout: An optional time in seconds to block waiting for the result of - the (expensive) checking operation. If None, the function will return - immediately with `False`. - - Returns: - Whether the user of the Android device has exited the expected screen - determined by `AppScreenChecker` given at __init__(). - """ - - # Update and check loop_counter against check_frequency. - self._main_loop_counter += 1 - if (self._check_frequency <= 0 or - self._main_loop_counter < self._check_frequency): - return False - self._main_loop_counter = 0 - - # If the latest check is None, perform a check and return. - if self._latest_check is None: - with concurrent.futures.ThreadPoolExecutor() as executor: - self._latest_check = executor.submit(self._check_impl) - return False - - # If there's a check in flight, continue only if it's finished. - if not timeout and not self._latest_check.done(): - return False - - v = self._latest_check.result(timeout=timeout) - self._latest_check = None # Reset the check. - return v - - def _check_impl(self) -> bool: - """The synchronous implementation of Dumpsys.""" - - outcome = self._app_screen_checker.matches_current_app_screen() - - # We were unable to determine the current activity. - if outcome == _Outcome.FAILED_ACTIVITY_EXTRACTION: - self._num_failed_activity_extraction += 1 - logging.info('self._num_failed_activity_extraction: %s', - self._num_failed_activity_extraction) - if (self._num_failed_activity_extraction >= - self._max_failed_activity_extraction): - logging.error('Maximum number of failed activity extraction reached.') - self._num_failed_activity_extraction = 0 - return True - else: - self._num_failed_activity_extraction = 0 - - # The current app screen matches all expectations. - if (outcome == _Outcome.SUCCESS or - outcome == _Outcome.EMPTY_EXPECTED_ACTIVITY): - return False - - # Player has exited the app. Terminate the episode. - elif outcome == _Outcome.UNEXPECTED_ACTIVITY: - return True - - # Player has exited the main game. Terminate the episode. - elif outcome == _Outcome.UNEXPECTED_VIEW_HIERARCHY: - return True - - return False diff --git a/android_env/components/dumpsys_thread_test.py b/android_env/components/dumpsys_thread_test.py deleted file mode 100644 index fad76656..00000000 --- a/android_env/components/dumpsys_thread_test.py +++ /dev/null @@ -1,88 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.components.dumpsys_thread.""" - -from unittest import mock - -from absl.testing import absltest -from android_env.components import app_screen_checker as screen_checker -from android_env.components import dumpsys_thread - - -class DumpsysThreadTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self._app_screen_checker = mock.create_autospec( - screen_checker.AppScreenChecker) - - def test_unexpected_activity(self): - dumpsys = dumpsys_thread.DumpsysThread( - app_screen_checker=self._app_screen_checker, check_frequency=1) - outcome = screen_checker.AppScreenChecker.Outcome.UNEXPECTED_ACTIVITY - self._app_screen_checker.matches_current_app_screen.return_value = outcome - # The first time that `check_user_exited()` is called, it'll only trigger - # the processing, but it should return immediately. - self.assertFalse(dumpsys.check_user_exited(timeout=1.0)) - # The second time it should then wait for the result. - self.assertTrue(dumpsys.check_user_exited(timeout=1.0)) - - def test_unexpected_view_hierarchy(self): - dumpsys = dumpsys_thread.DumpsysThread( - app_screen_checker=self._app_screen_checker, check_frequency=1) - outcome = screen_checker.AppScreenChecker.Outcome.UNEXPECTED_VIEW_HIERARCHY - self._app_screen_checker.matches_current_app_screen.return_value = outcome - self.assertFalse(dumpsys.check_user_exited(timeout=1.0)) - self.assertTrue(dumpsys.check_user_exited(timeout=1.0)) - - def test_success(self): - dumpsys = dumpsys_thread.DumpsysThread( - app_screen_checker=self._app_screen_checker, check_frequency=1) - outcome = screen_checker.AppScreenChecker.Outcome.SUCCESS - self._app_screen_checker.matches_current_app_screen.return_value = outcome - self.assertFalse(dumpsys.check_user_exited(timeout=1.0)) - self.assertFalse(dumpsys.check_user_exited(timeout=1.0)) - - def test_skipped(self): - dumpsys = dumpsys_thread.DumpsysThread( - app_screen_checker=self._app_screen_checker, check_frequency=5) - self._app_screen_checker.matches_current_app_screen.side_effect = [ - screen_checker.AppScreenChecker.Outcome.SUCCESS, - screen_checker.AppScreenChecker.Outcome.FAILED_ACTIVITY_EXTRACTION - ] - - for _ in range(17): - self.assertFalse(dumpsys.check_user_exited(timeout=1.0)) - - # The first 4 calls will hit the early exit from `check_frequency`. - # The 5th call will trigger the processing (increasing the call count to - # matches_current_app_screen() by 1), but it should return early. - # The 10th call will find a result of the previous processing, and it should - # be SUCCESS. - # The next 4 calls (11, 12, 13, 14) will hit the early exit from - # `check_frequency`. - # The 15th call should trigger the processing again (increasing the call - # count to matches_current_app_screen() by 1), but it should return early. - # The next 2 calls (16, 17) will hit the early exit from `check_frequency`. - # In total there should be only two calls to `matches_current_app_screen()`. - expected_call_count = 2 - self.assertEqual( - self._app_screen_checker.matches_current_app_screen.call_count, - expected_call_count) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/errors.py b/android_env/components/errors.py deleted file mode 100644 index 50ef8b4f..00000000 --- a/android_env/components/errors.py +++ /dev/null @@ -1,98 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Definitions of exceptions used by AndroidEnv.""" - - -class AndroidEnvError(Exception): - """Base class for all known errors generated by AndroidEnv.""" - - # An integer that identifies this class of error. - # Subclasses should use a different value. - ERROR_CODE: int = 0 - - -class ReadObservationError(AndroidEnvError): - """When the environment is unable to obtain an observation from a simulator.""" - - ERROR_CODE = 1 - - -class CoordinatorError(AndroidEnvError): - """Error raised by the Coordinator.""" - - ERROR_CODE = 2 - - -class TooManyRestartsError(CoordinatorError): - """The number of restarts has exceeded _MAX_RESTART_TRIES.""" - - ERROR_CODE = 3 - - -class AdbControllerError(AndroidEnvError): - """Errors that can be raised by ADBController.""" - - ERROR_CODE = 4 - - -class SimulatorError(AndroidEnvError): - """Errors that can be raised by a simulator.""" - - ERROR_CODE = 5 - - -class SendActionError(AndroidEnvError): - """Raised when action couldn't be sent successfully.""" - - ERROR_CODE = 6 - - -class StepCommandError(AndroidEnvError): - """Raised when setup step interpreter cannot process a command.""" - - ERROR_CODE = 7 - - -class WaitForAppScreenError(StepCommandError): - """Raised when the wait_for_app_screen success check is not met.""" - - ERROR_CODE = 8 - - -class CheckInstallError(StepCommandError): - """Raised when the check_install success check is not met.""" - - ERROR_CODE = 9 - - -def from_code(code: int, msg: str = '') -> AndroidEnvError | None: - """Returns an AndroidEnvError instance from the given arguments.""" - - code_to_error = { - 0: AndroidEnvError, - 1: ReadObservationError, - 2: CoordinatorError, - 3: TooManyRestartsError, - 4: AdbControllerError, - 5: SimulatorError, - 6: SendActionError, - 7: StepCommandError, - 8: WaitForAppScreenError, - 9: CheckInstallError, - } - - if code in code_to_error: - return code_to_error[code](msg) diff --git a/android_env/components/errors_test.py b/android_env/components/errors_test.py deleted file mode 100644 index df0deca6..00000000 --- a/android_env/components/errors_test.py +++ /dev/null @@ -1,110 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for errors.py.""" - -from absl.testing import absltest -from absl.testing import parameterized -from android_env.components import errors - - -class ErrorsTest(parameterized.TestCase): - - @parameterized.parameters( - (errors.ReadObservationError, 1), - (errors.CoordinatorError, 2), - (errors.TooManyRestartsError, 3), - (errors.AdbControllerError, 4), - (errors.SimulatorError, 5), - (errors.SendActionError, 6), - (errors.StepCommandError, 7), - (errors.WaitForAppScreenError, 8), - (errors.CheckInstallError, 9), - ) - def test_error_codes(self, error, expected_error_code): - with self.assertRaises(error) as context: - raise error() - self.assertEqual(context.exception.ERROR_CODE, expected_error_code) - - def test_error_codes_unique(self): - error_codes = set() - errors_list = ( - errors.ReadObservationError, - errors.CoordinatorError, - errors.TooManyRestartsError, - errors.AdbControllerError, - errors.SimulatorError, - errors.SendActionError, - errors.StepCommandError, - errors.WaitForAppScreenError, - errors.CheckInstallError, - ) - for error in errors_list: - self.assertNotIn(error.ERROR_CODE, error_codes) - error_codes.add(error.ERROR_CODE) - - @parameterized.parameters([ - errors.ReadObservationError(), - errors.CoordinatorError(), - errors.TooManyRestartsError(), - errors.AdbControllerError(), - errors.SimulatorError(), - errors.SendActionError(), - errors.StepCommandError(), - errors.WaitForAppScreenError(), - errors.CheckInstallError(), - ]) - def test_all_errors_are_androidenv_errors(self, error): - self.assertIsInstance(error, errors.AndroidEnvError) - - @parameterized.named_parameters([ - ('less_than_zero', -1), - # The largest `ERROR_CODE` is currently `CheckInstallError == 10`. - ('greater_than_all_errors', 10 + 1), - ('less_than_zero_float', -3.14159265), - ('greater_than_all_errors_float', 123.456), - ]) - def test_from_code_unsupported_code(self, code: int): - """Unsupported errors should raise `RuntimeError`.""" - - self.assertIsNone(errors.from_code(code)) - - @parameterized.parameters([ - (-1, None, 'No such error code.'), - (0, errors.AndroidEnvError, 'hello'), - (0, errors.AndroidEnvError, ''), - (1, errors.ReadObservationError, 'Could not read obs.'), - (2, errors.CoordinatorError, 'Some error'), - (3, errors.TooManyRestartsError, 'Too many already...'), - (4, errors.AdbControllerError, 'Some adb error...'), - (5, errors.SimulatorError, 'Simulator is not coping.'), - (6, errors.SendActionError, 'Could not send action.'), - (7, errors.StepCommandError, 'Some issue setting up the task.'), - (8, errors.WaitForAppScreenError, 'Waited for too long!'), - (9, errors.CheckInstallError, 'App did not install correctly.'), - ]) - def test_from_code(self, code: int, expected_class: errors.AndroidEnvError, - msg: str): - """`from_code` should produce consistent outputs for known errors.""" - - error = errors.from_code(code, msg) - if error is not None: - self.assertIsInstance(error, expected_class) - self.assertEqual(error.ERROR_CODE, code) - self.assertEqual(str(error), msg) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/log_stream.py b/android_env/components/log_stream.py deleted file mode 100644 index d9a047b5..00000000 --- a/android_env/components/log_stream.py +++ /dev/null @@ -1,64 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Abstract class for handling a stream of logs from a simulator.""" - -import abc -from collections.abc import Generator, Sequence -import threading -from absl import logging - - -class LogStream(metaclass=abc.ABCMeta): - """Manages the stream of logs output by a simulator.""" - - def __init__(self, verbose: bool = False): - self._verbose = verbose - self._filters = [] - self._should_stream = threading.Event() - - def get_stream_output(self) -> Generator[str, None, None]: - """Starts log process and returns the stream of logs.""" - for line in self._get_stream_output(): - if self._verbose: - logging.info('line: %r', line) - if self._should_stream.is_set(): - yield line - - @abc.abstractmethod - def _get_stream_output(self): - """Starts log process and returns the stream of logs.""" - pass - - @abc.abstractmethod - def stop_stream(self) -> None: - """Terminates the log stream. - - NOTE: This should only be called _after_ `get_stream_output()`. - """ - - def pause_stream(self) -> None: - """No lines are yielded while the event is not set.""" - logging.info('Pausing LogStream.') - self._should_stream.clear() - - def resume_stream(self) -> None: - """The stream will continue yielding lines if the event is set.""" - logging.info('Resuming LogStream.') - self._should_stream.set() - - def set_log_filters(self, log_filters: Sequence[str]): - """Sets the filters for the log stream.""" - self._filters = list(log_filters) + ['*:S'] diff --git a/android_env/components/log_stream_test.py b/android_env/components/log_stream_test.py deleted file mode 100644 index 90fd7dcf..00000000 --- a/android_env/components/log_stream_test.py +++ /dev/null @@ -1,109 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for log_stream.""" - -from absl.testing import absltest -from android_env.components import log_stream - - -class FakeLogStream(log_stream.LogStream): - - def __init__(self, filter_name: str): - super().__init__() - self._filter_name = filter_name - - def _get_stream_output(self): - """Starts a log process and returns the stream of logs.""" - lines = [ - f'{self._filter_name} fake_line_1', - 'fake_line_2', - f'{self._filter_name} fake_line_3', - f'{self._filter_name} fake_line_4', - 'fake_line_5', - 'fake_line_6', - ] - for line in lines: - if f'{self._filter_name}:V' in self._filters: - if self._filter_name in line: - yield line - else: - yield line - - def stop_stream(self): - """Stops the log stream from the simulator.""" - pass - - -class LogStreamTest(absltest.TestCase): - - def test_get_stream_output(self): - filter_name = 'AndroidRLTask' - stream = FakeLogStream(filter_name=filter_name) - stream.resume_stream() - stream_output = stream.get_stream_output() - expected_lines = [ - f'{filter_name} fake_line_1', - 'fake_line_2', - f'{filter_name} fake_line_3', - f'{filter_name} fake_line_4', - 'fake_line_5', - 'fake_line_6', - ] - for line, expected_line in zip(stream_output, expected_lines): - self.assertEqual(line, expected_line) - - def test_set_log_filters(self): - filter_name = 'AndroidRLTask' - stream = FakeLogStream(filter_name=filter_name) - stream.set_log_filters([f'{filter_name}:V']) - stream.resume_stream() - stream_output = stream.get_stream_output() - expected_lines = [ - f'{filter_name} fake_line_1', - f'{filter_name} fake_line_3', - f'{filter_name} fake_line_4', - ] - for line, expected_line in zip(stream_output, expected_lines): - self.assertEqual(line, expected_line) - - def test_pause_resume_stream(self): - filter_name = 'AndroidRLTask' - stream = FakeLogStream(filter_name=filter_name) - stream.resume_stream() - stream_output = stream.get_stream_output() - expected_lines = [ - f'{filter_name} fake_line_1', - 'fake_line_2', - f'{filter_name} fake_line_3', - f'{filter_name} fake_line_4', - 'fake_line_5', - 'fake_line_6', - ] - for line, expected_line in zip(stream_output, expected_lines): - self.assertEqual(line, expected_line) - # If the stream is paused, we expect no lines to be yielded. - stream.pause_stream() - stream_output = list(stream.get_stream_output()) - self.assertEmpty(stream_output) - # If the stream is resumed, we expect to see all lines yielded. - stream.resume_stream() - stream_output = stream.get_stream_output() - for line, expected_line in zip(stream_output, expected_lines): - self.assertEqual(line, expected_line) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/logcat_thread.py b/android_env/components/logcat_thread.py deleted file mode 100644 index fffadbcb..00000000 --- a/android_env/components/logcat_thread.py +++ /dev/null @@ -1,131 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A class that launches a thread to read Android log outputs.""" - -from collections.abc import Callable -import dataclasses -import re -import threading - -from absl import logging -from android_env.components import log_stream as log_stream_lib - - -@dataclasses.dataclass -class EventListener: - """A function that's called when an event is triggered.""" - - regexp: re.Pattern[str] - handler_fn: Callable[[re.Pattern[str], re.Match[str]], None] - - -class LogcatThread: - """Reads log entries in a separate thread.""" - - def __init__(self, log_stream: log_stream_lib.LogStream): - """Initializes this LogcatThread with optional filters. - - Please see https://developer.android.com/studio/command-line/logcat for more - info on `logcat`. - - Args: - log_stream: Stream of logs from simulator. - """ - - self._log_stream = log_stream - self._listeners = {} - self._line_ready = threading.Event() - self._line_ready.set() - self._should_stop = threading.Event() - self._thread = threading.Thread(target=self._process_logs) - self._thread.daemon = True - self._thread.start() - - def add_event_listener(self, event_listener: EventListener) -> None: - """Adds `fn` to the list of handlers to call when `event` occurs.""" - event_regexp = event_listener.regexp - if event_regexp not in self._listeners: - self._listeners[event_regexp] = [] - self._listeners[event_regexp].append(event_listener.handler_fn) - - def remove_event_listener(self, event_listener: EventListener) -> None: - """Removes `fn` from the list of handlers to call when `event` occurs.""" - event_regexp = event_listener.regexp - if event_regexp not in self._listeners: - logging.error('Event: %r is not registered.', event_regexp) - return - self._listeners[event_regexp].remove(event_listener.handler_fn) - - def line_ready(self) -> threading.Event: - """Indicates whether all listeners have been notified for a given line.""" - return self._line_ready - - def pause(self): - self._log_stream.pause_stream() - - def resume(self): - self._log_stream.resume_stream() - - def kill(self): - self._should_stop.set() - self._log_stream.stop_stream() - self._thread.join(timeout=3.0) - - def _process_logs(self) -> None: - """A loop that runs until `self._should_stop` is set().""" - - # pylint: disable=g-line-too-long - # Format is: "TIME_SEC PID TID PRIORITY TAG: MESSAGE" - # - # Example: - # ' 1553110400.424 5583 5658 D NostalgicRacer: com.google.example.games.nostalgicracer.views.renderers.OpenGLRenderDriver@912fb8.onSurfaceChanged 480x320' # - # pylint: enable=g-line-too-long - - logline_regexp = r""" - ^ # Beginning of the line. - [ ]+(?P[0-9]+\.[0-9]+) # Spaces and a float. - [ ]+(?P[0-9]+) # Spaces and an int. - [ ]+(?P[0-9]+) # Spaces and an int. - [ ]+(?P.) # Spaces and any single character. - [ ]+(?P[^:]*): # Spaces and any char that's not ':'. - [ ](?P.*)$ # The actual log message. - """ - logline_re = re.compile(logline_regexp, re.VERBOSE) - - for line in self._log_stream.get_stream_output(): - if self._should_stop.is_set(): - break - - if not line: # Skip empty lines. - continue - - matches = logline_re.match(line) - if not matches or len(matches.groups()) != 6: - continue - - # Make sure that values are not read until all listeners are notified. - self._line_ready.clear() - - # We're currently only consuming `message`, but we may use the other - # fields in the future. - content = matches.group('message') - for ev, listeners in self._listeners.items(): - ev_matches = ev.match(content) - if ev_matches: - for listener in listeners: # Notify listeners. - listener(ev, ev_matches) - - self._line_ready.set() # Allow consumers to read values. diff --git a/android_env/components/logcat_thread_test.py b/android_env/components/logcat_thread_test.py deleted file mode 100644 index fadcbf93..00000000 --- a/android_env/components/logcat_thread_test.py +++ /dev/null @@ -1,140 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.components.logcat_thread.""" - -import re -import threading - -from absl.testing import absltest -from android_env.components import log_stream -from android_env.components import logcat_thread -from android_env.proto import task_pb2 - - -class FakeStream: - """This class simulates the logs coming from ADB.""" - - def __init__(self): - self._values = [] - self._kill = False - self._lock = threading.Lock() - - def send_value(self, value): - with self._lock: - self._values.append(value) - - def has_next_value(self): - return bool(self._values) - - def kill(self): - self._kill = True - - def __iter__(self): - while True: - if self._kill: - return - if not self._values: - continue - else: - with self._lock: - next_value = self._values.pop(0) - yield next_value - - -def make_stdout(data): - """Returns a valid log output with given data as message.""" - return ' 1553110400.424 5583 5658 D Tag: %s' % data - - -class FakeLogStream(log_stream.LogStream): - """FakeLogStream class that wraps a FakeStream.""" - - def __init__(self): - super().__init__(verbose=False) - self.logs = FakeStream() - self.stream_is_alive = True - - def _get_stream_output(self): - return self.logs - - def stop_stream(self): - self.stream_is_alive = False - self.logs.kill() - - -class LogcatThreadTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self.fake_log_stream = FakeLogStream() - - def tearDown(self): - self.fake_log_stream.stop_stream() - super().tearDown() - - def test_set_filters(self): - log_parsing_config = task_pb2.LogParsingConfig(filters=['AndroidRLTask:V']) - self.fake_log_stream.set_log_filters(log_parsing_config.filters) - _ = logcat_thread.LogcatThread(log_stream=self.fake_log_stream) - expected_filters = ['AndroidRLTask:V', '*:S'] - self.assertEqual(expected_filters, self.fake_log_stream._filters) - - def test_kill(self): - logcat = logcat_thread.LogcatThread(log_stream=self.fake_log_stream) - self.assertTrue(self.fake_log_stream.stream_is_alive) - logcat.kill() - self.assertFalse(self.fake_log_stream.stream_is_alive) - - def test_listeners(self): - """Ensures that we can wait for a specific message without polling.""" - logcat = logcat_thread.LogcatThread(log_stream=self.fake_log_stream) - # Start yielding lines from LogStream. - logcat.resume() - - # Set up a listener that modifies an arbitrary state. - some_state = threading.Event() - - def my_handler(event: re.Pattern[str], match: re.Match[str]): - del event, match - nonlocal some_state - some_state.set() - - # Create a desired event and hook up the listener. - my_event = re.compile('Hello world') - listener = logcat_thread.EventListener(my_event, my_handler) - logcat.add_event_listener(listener) - self.fake_log_stream.logs.send_value('Hi there!') # This should not match. - self.assertFalse(some_state.is_set()) - self.fake_log_stream.logs.send_value(make_stdout('Hello world')) - some_state.wait(timeout=1.0) - self.assertTrue(some_state.is_set()) - - # Waiting for any events should also trigger the listener. - some_state.clear() - self.fake_log_stream.logs.send_value(make_stdout('Hello world')) - some_state.wait(timeout=1.0) - self.assertTrue(some_state.is_set()) - - # After removing the listener, it should not be called anymore. - some_state.clear() - logcat.remove_event_listener(listener) - self.fake_log_stream.logs.send_value(make_stdout('Hello world')) - some_state.wait(timeout=1.0) - self.assertFalse(some_state.is_set()) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/pixel_fns.py b/android_env/components/pixel_fns.py deleted file mode 100644 index 65bb2eea..00000000 --- a/android_env/components/pixel_fns.py +++ /dev/null @@ -1,70 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utils for AndroidEnv.""" - -from collections.abc import Sequence - -from dm_env import specs -import numpy as np - - -def touch_position_to_pixel_position( - touch_position: np.ndarray, - width_height: Sequence[int], -) -> tuple[int, int]: - """Maps touch position in [0,1] to the corresponding pixel on the screen.""" - touch_pixels = (touch_position * width_height).astype(np.int32) - cap_idx = lambda v, idx_len: min(v, idx_len - 1) - return tuple(map(cap_idx, touch_pixels, width_height)) - - -def transpose_pixels(frame: np.ndarray) -> np.ndarray: - """Converts image from shape (H, W, C) to (W, H, C) and vice-versa.""" - return np.transpose(frame, axes=(1, 0, 2)) - - -def orient_pixels(frame: np.ndarray, orientation: int) -> np.ndarray: - """Rotates screen pixels according to the given orientation.""" - - match orientation: - case 0: # PORTRAIT_90 - return frame - case 1: # LANDSCAPE_90 - return np.rot90(frame, k=3, axes=(0, 1)) - case 2: # PORTRAIT_180 - return np.rot90(frame, k=2, axes=(0, 1)) - case 3: # LANDSCAPE_270 - return np.rot90(frame, k=1, axes=(0, 1)) - case _: - raise ValueError( - 'Orientation must be an integer in [0, 3] but is %r' % orientation - ) - - -def convert_int_to_float(data: np.ndarray, data_spec: specs.Array): - """Converts an array of int values to floats between 0 and 1.""" - - if not np.issubdtype(data.dtype, np.integer): - raise TypeError(f'{data.dtype} is not an integer type') - if isinstance(data_spec, specs.BoundedArray): - value_min = data_spec.minimum - value_max = data_spec.maximum - else: - # We use the int type to figure out the boundaries. - iinfo = np.iinfo(data_spec.dtype) - value_min = iinfo.min - value_max = iinfo.max - return np.float32(1.0 * (data - value_min) / (value_max - value_min)) diff --git a/android_env/components/pixel_fns_test.py b/android_env/components/pixel_fns_test.py deleted file mode 100644 index 82024307..00000000 --- a/android_env/components/pixel_fns_test.py +++ /dev/null @@ -1,107 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for pixel_fns.""" - -from absl.testing import absltest -from absl.testing import parameterized -from android_env.components import pixel_fns -from dm_env import specs -import numpy as np - - -class UtilsTest(parameterized.TestCase): - - @parameterized.parameters( - ([0.5, 0.5], [320, 480], (160, 240)), - ([0.25, 0.75], [320, 480], (80, 360)), - ([0.0, 0.0], [320, 480], (0, 0)), - ([1.0, 1.0], [320, 480], (319, 479)), - ) - def test_touch_position_to_pixel_position( - self, touch_pos, width_height, pixel_pos): - self.assertEqual( - pixel_fns.touch_position_to_pixel_position( - np.array(touch_pos), width_height - ), - pixel_pos, - ) - - def test_transpose_pixels(self): - image = np.reshape(np.array(range(12)), (3, 2, 2)) - expected = [[[0, 1], [4, 5], [8, 9]], [[2, 3], [6, 7], [10, 11]]] - self.assertEqual(pixel_fns.transpose_pixels(image).shape, (2, 3, 2)) - self.assertTrue((pixel_fns.transpose_pixels(image) == expected).all()) - - def test_orient_pixels(self): - image = np.reshape(np.array(range(12)), (3, 2, 2)) - - expected_90 = [[[8, 9], [4, 5], [0, 1]], [[10, 11], [6, 7], [2, 3]]] - rot_90 = 1 # LANDSCAPE_90 - rotated = pixel_fns.orient_pixels(image, rot_90) - self.assertEqual(rotated.shape, (2, 3, 2)) - self.assertTrue((rotated == expected_90).all()) - - expected_180 = [[[10, 11], [8, 9]], [[6, 7], [4, 5]], [[2, 3], [0, 1]]] - rot_180 = 2 # PORTRAIT_180 - rotated = pixel_fns.orient_pixels(image, rot_180) - self.assertEqual(rotated.shape, (3, 2, 2)) - self.assertTrue((rotated == expected_180).all()) - - expected_270 = [[[2, 3], [6, 7], [10, 11]], [[0, 1], [4, 5], [8, 9]]] - rot_270 = 3 # LANDSCAPE_270 - rotated = pixel_fns.orient_pixels(image, rot_270) - self.assertEqual(rotated.shape, (2, 3, 2)) - self.assertTrue((rotated == expected_270).all()) - - rot_0 = 0 # PORTRAIT_0 - rotated = pixel_fns.orient_pixels(image, rot_0) - self.assertEqual(rotated.shape, (3, 2, 2)) - self.assertTrue((rotated == image).all()) - - def test_convert_int_to_float_bounded_array(self): - spec = specs.BoundedArray( - shape=(4,), - dtype=np.int32, - minimum=[0, 1, 10, -2], - maximum=[5, 5, 20, 2], - name='bounded_array') - data = np.array([2, 2, 10, 0], dtype=np.int32) - float_data = pixel_fns.convert_int_to_float(data, spec) - np.testing.assert_equal( - np.array([2.0 / 5.0, 1.0 / 4.0, 0.0, 0.5], dtype=np.float32), float_data - ) - - def test_convert_int_to_float_bounded_array_broadcast(self): - spec = specs.BoundedArray( - shape=(3,), dtype=np.int16, minimum=2, maximum=4, name='bounded_array') - data = np.array([2, 3, 4], dtype=np.int16) - float_data = pixel_fns.convert_int_to_float(data, spec) - np.testing.assert_equal( - np.array([0.0, 0.5, 1.0], dtype=np.float32), float_data) - - def test_convert_int_to_float_no_bounds(self): - spec = specs.Array( - shape=(3,), - dtype=np.int8, # int8 implies min=-128, max=127 - name='bounded_array') - data = np.array([-128, 0, 127], dtype=np.int16) - float_data = pixel_fns.convert_int_to_float(data, spec) - np.testing.assert_equal( - np.array([0.0, 128. / 255., 1.0], dtype=np.float32), float_data) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/setup_step_interpreter.py b/android_env/components/setup_step_interpreter.py deleted file mode 100644 index fb509db6..00000000 --- a/android_env/components/setup_step_interpreter.py +++ /dev/null @@ -1,182 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A component that parses and processes SetupSteps.""" - -from collections.abc import Sequence -import copy -import time -from typing import Any - -from absl import logging -from android_env.components import adb_call_parser as adb_call_parser_lib -from android_env.components import app_screen_checker -from android_env.components import errors -from android_env.proto import adb_pb2 -from android_env.proto import task_pb2 - - -class SetupStepInterpreter: - """An interpreter for SetupSteps.""" - - def __init__(self, adb_call_parser: adb_call_parser_lib.AdbCallParser): - """Initializes this interpreter. - - Args: - adb_call_parser: An object to communicate with Android via ADB. - """ - self._adb_call_parser = adb_call_parser - self._stats = { - 'error_count_adb_request': 0, - 'error_count_wait_for_app_screen': 0, - 'error_count_check_install': 0, - 'error_count_wait_for_message': 0, - 'total_time_waiting_for_app_screen': 0 - } - - def stats(self) -> dict[str, Any]: - return copy.deepcopy(self._stats) - - def interpret(self, setup_steps: Sequence[task_pb2.SetupStep]) -> None: - """Returns True if parsing and processing `setup_steps` is successful.""" - if setup_steps: - logging.info('Executing setup steps: %s', setup_steps) - for step in setup_steps: - self._process_step_command(step) - logging.info('Done executing setup steps.') - - def _process_step_command(self, step_cmd: task_pb2.SetupStep) -> None: - """Processes a single step command from a reset or extra setup.""" - - if not step_cmd: - logging.info('Empty step_cmd') - return - - logging.info('Executing step_cmd: %r', step_cmd) - step_type = step_cmd.WhichOneof('step') - success_condition = step_cmd.success_condition - success_check = success_condition.WhichOneof('check') - assert step_type or success_check, ( - 'At least one of step and success_condition must be defined.') - - num_tries = 0 - max_retries = max(success_condition.num_retries, 3) - latest_error = None - while num_tries < max_retries: - - num_tries += 1 - - try: - unused_adb_response = self._execute_step_cmd(step_cmd, step_type) - time.sleep(0.5) - self._check_success(success_check, success_condition) - return - - except NotImplementedError: - logging.exception('Not implemented error! Skipping this step command.') - return - - except errors.AdbControllerError as error: - latest_error = error - self._stats['error_count_adb_request'] += 1 - logging.exception('ADB call [%r] has failed. Try %d of %d.', - step_cmd.adb_request, num_tries, max_retries) - - except errors.WaitForAppScreenError as error: - latest_error = error - self._stats['error_count_wait_for_app_screen'] += 1 - logging.exception('Failed to wait for app screen. Try %d of %d.', - num_tries, max_retries) - - except errors.CheckInstallError as error: - latest_error = error - self._stats['error_count_check_install'] += 1 - logging.exception('Package [%r] not installed. Try %d of %d.', - success_condition.check_install.package_name, - num_tries, max_retries) - - raise errors.StepCommandError( - f'Step failed: [{step_cmd}]') from latest_error - - def _execute_step_cmd( - self, step_cmd: task_pb2.SetupStep, step_type: str | None - ) -> adb_pb2.AdbResponse | None: - """Executes a step command of given type.""" - - match step_type: - case None: - return None - case 'sleep': - time.sleep(step_cmd.sleep.time_sec) - return None - case 'adb_request': - response = self._adb_call_parser.parse(step_cmd.adb_request) - if response.status != adb_pb2.AdbResponse.Status.OK: - raise errors.AdbControllerError( - f'Failed to execute AdbRequest [{step_cmd.adb_request}].\n' - f'Status: {response.status}\n' - f'Error: {response.error_message}' - ) - return response - case _: - raise NotImplementedError(f'No step command of type [{step_type}].') - - def _check_success( - self, - success_check: str | None, - success_condition: task_pb2.SuccessCondition, - ) -> None: - """Checks whether the given success condition was met.""" - - match success_check: - case None: - return None - case 'wait_for_app_screen': - wait_for_app_screen = success_condition.wait_for_app_screen - screen_checker = app_screen_checker.AppScreenChecker( - adb_call_parser=self._adb_call_parser, - expected_app_screen=wait_for_app_screen.app_screen, - ) - wait_time = screen_checker.wait_for_app_screen( - timeout_sec=wait_for_app_screen.timeout_sec - ) - self._stats['total_time_waiting_for_app_screen'] += wait_time - case 'check_install': - self._check_install(success_condition.check_install) - case _: - raise NotImplementedError(f'No success check called [{success_check}].') - - def _check_install(self, check_install: task_pb2.CheckInstall) -> None: - """Checks that the given package is installed.""" - - package = check_install.package_name - logging.info('Checking if package is installed: [%r]', package) - - request = adb_pb2.AdbRequest( - package_manager=adb_pb2.AdbRequest.PackageManagerRequest( - list=adb_pb2.AdbRequest.PackageManagerRequest.List( - packages=adb_pb2.AdbRequest.PackageManagerRequest.List.Packages( - )))) - - start_time = time.time() - while time.time() - start_time < check_install.timeout_sec: - response = self._adb_call_parser.parse(request) - if package in response.package_manager.list.items: - logging.info('Done confirming that package is installed.') - return - time.sleep(0.1) - - logging.error('Package not found.') - raise errors.CheckInstallError() diff --git a/android_env/components/setup_step_interpreter_test.py b/android_env/components/setup_step_interpreter_test.py deleted file mode 100644 index 26468159..00000000 --- a/android_env/components/setup_step_interpreter_test.py +++ /dev/null @@ -1,342 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.components.setup_step_interpreter.""" - -from unittest import mock - -from absl.testing import absltest -from android_env.components import adb_call_parser -from android_env.components import errors -from android_env.components import setup_step_interpreter -from android_env.proto import adb_pb2 -from android_env.proto import task_pb2 - -from google.protobuf import text_format - - -def _to_proto(proto_class, text): - proto = proto_class() - text_format.Parse(text, proto) - return proto - - -class SetupStepInterpreterTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self._parser = mock.create_autospec( - adb_call_parser.AdbCallParser, instance=True) - - def test_empty_setup_steps(self): - """Simple test where nothing should break, and nothing should be done. - - The test simply expects this test to not crash. - """ - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - interpreter.interpret([]) - - def test_none_setup_steps(self): - """Simple test where nothing should break, and nothing should be done. - - The test simply expects this test to not crash. - """ - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - # Empty setup steps should be ignored. - interpreter.interpret([]) - - def test_invalid_setup_step(self): - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - # Empty setup steps should be ignored. - self.assertRaises(AssertionError, interpreter.interpret, - [task_pb2.SetupStep()]) - - def test_adb_install_apk_filesystem(self): - self._parser.parse.return_value = adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.OK) - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - interpreter.interpret([ - _to_proto( - task_pb2.SetupStep, """ -adb_request: { - install_apk: { - filesystem: { - path: "/my/favorite/dir/my_apk.apk" - } - } -}""") - ]) - self._parser.parse.assert_called_once_with( - adb_pb2.AdbRequest( - install_apk=adb_pb2.AdbRequest.InstallApk( - filesystem=adb_pb2.AdbRequest.InstallApk.Filesystem( - path='/my/favorite/dir/my_apk.apk')))) - - def test_adb_force_stop(self): - self._parser.parse.return_value = adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.OK) - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - interpreter.interpret([ - _to_proto( - task_pb2.SetupStep, """ -adb_request: { force_stop: { package_name: "my.app.Activity" } }""") - ]) - self._parser.parse.assert_called_once_with( - adb_pb2.AdbRequest( - force_stop=adb_pb2.AdbRequest.ForceStop( - package_name='my.app.Activity'))) - - def test_adb_start_activity(self): - self._parser.parse.return_value = adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.OK) - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - interpreter.interpret([ - _to_proto( - task_pb2.SetupStep, """ -adb_request: { - start_activity: { - full_activity: "my.app.Activity" - extra_args: "arg1" - extra_args: "arg2" - } -}""") - ]) - self._parser.parse.assert_called_once_with( - adb_pb2.AdbRequest( - start_activity=adb_pb2.AdbRequest.StartActivity( - full_activity='my.app.Activity', extra_args=['arg1', 'arg2']))) - - def test_adb_single_tap(self): - self._parser.parse.return_value = adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.OK) - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - interpreter.interpret([ - _to_proto(task_pb2.SetupStep, """ -adb_request: { - tap: { - x: 321 - y: 654 - } -}""") - ]) - self._parser.parse.assert_called_once_with( - adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=321, y=654))) - - def test_adb_press_button(self): - self._parser.parse.return_value = adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.OK) - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - interpreter.interpret([ - _to_proto(task_pb2.SetupStep, - """ adb_request: { press_button: { button: HOME } }""") - ]) - self._parser.parse.assert_called_once_with( - adb_pb2.AdbRequest( - press_button=adb_pb2.AdbRequest.PressButton( - button=adb_pb2.AdbRequest.PressButton.Button.HOME))) - - self._parser.reset_mock() - interpreter.interpret([ - _to_proto(task_pb2.SetupStep, - """ adb_request: { press_button: { button: BACK } }""") - ]) - self._parser.parse.assert_called_once_with( - adb_pb2.AdbRequest( - press_button=adb_pb2.AdbRequest.PressButton( - button=adb_pb2.AdbRequest.PressButton.Button.BACK))) - - def test_adb_start_screen_pinning(self): - self._parser.parse.return_value = adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.OK) - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - interpreter.interpret([ - _to_proto( - task_pb2.SetupStep, """ -adb_request: { - start_screen_pinning: { - full_activity: "my.app.HighlanderApp" # "There can be only one". - } -}""") - ]) - self._parser.parse.assert_called_once_with( - adb_pb2.AdbRequest( - start_screen_pinning=adb_pb2.AdbRequest.StartScreenPinning( - full_activity='my.app.HighlanderApp'))) - - @mock.patch('time.sleep') - def test_time_sleep(self, mock_sleep): - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - interpreter.interpret( - [_to_proto(task_pb2.SetupStep, """sleep: { time_sec: 0.875 }""")]) - assert mock_sleep.call_count == 2 - mock_sleep.assert_has_calls([mock.call(0.875), mock.call(0.5)]) - - @mock.patch('time.sleep') - def test_wait_for_app_screen_empty_activity(self, unused_mock_sleep): - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - with self.assertRaises(errors.StepCommandError): - interpreter.interpret([ - _to_proto(task_pb2.SetupStep, - """success_condition: {wait_for_app_screen: { }}""") - ]) - - @mock.patch('time.sleep') - def test_check_install_not_installed(self, unused_mock_sleep): - self._parser.parse.return_value = adb_pb2.AdbResponse( - package_manager=adb_pb2.AdbResponse.PackageManagerResponse( - list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[ - 'com.some.package', - 'not.what.you.are.looking.for', - ]))) - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - with self.assertRaises(errors.StepCommandError): - interpreter.interpret([ - _to_proto( - task_pb2.SetupStep, """ -success_condition: { - check_install: { - package_name: "faz" - timeout_sec: 0.0001 - } -} -""") - ]) - - def test_check_install_installed(self): - self._parser.parse.return_value = adb_pb2.AdbResponse( - package_manager=adb_pb2.AdbResponse.PackageManagerResponse( - list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[ - 'com.some.package', - 'baz', - ]))) - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - # The test checks that this command raises no AssertionError. - interpreter.interpret([ - _to_proto( - task_pb2.SetupStep, """ -success_condition: { - check_install: { - package_name: "baz" - timeout_sec: 0.0001 - } -}""") - ]) - - def test_num_retries_failure(self): - self._parser.parse.side_effect = [ - adb_pb2.AdbResponse( - package_manager=adb_pb2.AdbResponse.PackageManagerResponse( - list=adb_pb2.AdbResponse.PackageManagerResponse.List( - items=[]))), - ] * 3 - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - with self.assertRaises(errors.StepCommandError): - interpreter.interpret([ - _to_proto( - task_pb2.SetupStep, """ -success_condition: { - check_install: { - package_name: "faz" - timeout_sec: 0.0001 - } - num_retries: 3 -}""") - ]) - # We retried 3 times after the first call, so we expect 3+1 calls. - self.assertEqual(self._parser.parse.call_count, 3) - - @mock.patch('time.sleep') - def test_num_retries_success(self, unused_mock_sleep): - self._parser.parse.side_effect = [ - adb_pb2.AdbResponse( - package_manager=adb_pb2.AdbResponse.PackageManagerResponse( - list=adb_pb2.AdbResponse.PackageManagerResponse.List( - items=[]))), - adb_pb2.AdbResponse( - package_manager=adb_pb2.AdbResponse.PackageManagerResponse( - list=adb_pb2.AdbResponse.PackageManagerResponse.List( - items=[]))), - adb_pb2.AdbResponse( - package_manager=adb_pb2.AdbResponse.PackageManagerResponse( - list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[ - 'com.some.package', - 'bar', - ]))), - adb_pb2.AdbResponse( - package_manager=adb_pb2.AdbResponse.PackageManagerResponse( - list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[]))) - ] - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - interpreter.interpret([ - _to_proto( - task_pb2.SetupStep, """ -success_condition: { - check_install: { - package_name: "bar" - timeout_sec: 0.0001 - } - num_retries: 5 -}""") - ]) - # The check should succeed on the third try. - self.assertEqual(self._parser.parse.call_count, 3) - - def test_retry_step(self): - self._parser.parse.side_effect = [ - adb_pb2.AdbResponse( - package_manager=adb_pb2.AdbResponse.PackageManagerResponse( - list=adb_pb2.AdbResponse.PackageManagerResponse.List( - items=[]))), - adb_pb2.AdbResponse( - package_manager=adb_pb2.AdbResponse.PackageManagerResponse( - list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[ - 'com.some.package', - 'bar', - ]))), - ] - interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=self._parser) - interpreter.interpret([ - _to_proto( - task_pb2.SetupStep, """ -success_condition: { - check_install: { - package_name: "bar" - timeout_sec: 0.0001 - } - num_retries: 2 -}""") - ]) - # We expect the check to fail once and succeed on the second pass. - self.assertEqual(self._parser.parse.call_count, 2) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/simulators/__init__.py b/android_env/components/simulators/__init__.py deleted file mode 100644 index 2f66bf75..00000000 --- a/android_env/components/simulators/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/android_env/components/simulators/base_simulator.py b/android_env/components/simulators/base_simulator.py deleted file mode 100644 index 49ab1040..00000000 --- a/android_env/components/simulators/base_simulator.py +++ /dev/null @@ -1,208 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A base class for talking to different types of Android simulators.""" - -import abc -from collections.abc import Callable -import threading -import time - -from absl import logging -from android_env.components import adb_controller -from android_env.components import config_classes -from android_env.components import errors -from android_env.components import log_stream -from android_env.proto import state_pb2 -import numpy as np - - -class BaseSimulator(metaclass=abc.ABCMeta): - """An interface for communicating with an Android simulator.""" - - def __init__(self, config: config_classes.SimulatorConfig): - """Instantiates a BaseSimulator object. - - The simulator may be an emulator, virtual machine or even a physical device. - Each simulator has its own AdbController that is used for internal - bookkeeping. - - Args: - config: Settings for this simulator. - """ - - self._config = config - self._interaction_thread: InteractionThread | None = None - - # An increasing number that tracks the attempt at launching the simulator. - self._num_launch_attempts: int = 0 - - def get_logs(self) -> str: - """Returns logs recorded by the simulator (if provided).""" - return 'No simulator logs provided.' - - @abc.abstractmethod - def adb_device_name(self) -> str: - """Returns the device name that the adb client will connect to.""" - - @abc.abstractmethod - def create_adb_controller(self) -> adb_controller.AdbController: - """Returns an ADB controller which can communicate with this simulator.""" - - @abc.abstractmethod - def create_log_stream(self) -> log_stream.LogStream: - """Creates a stream of logs from the simulator.""" - - def launch(self) -> None: - """Starts the simulator.""" - - # Stop screenshot thread if it's enabled. - if self._interaction_thread is not None: - self._interaction_thread.stop() - self._interaction_thread.join() - - self._num_launch_attempts += 1 - try: - self._launch_impl() - except Exception as error: - for line in self.get_logs().splitlines(): - logging.error(line) - raise errors.SimulatorError( - 'Exception caught in simulator. Please see the simulator logs ' - 'above for more details.' - ) from error - - # Start interaction thread. - if self._config.interaction_rate_sec > 0: - self._interaction_thread = InteractionThread( - self._get_screenshot_impl, self._config.interaction_rate_sec - ) - self._interaction_thread.start() - - @abc.abstractmethod - def _launch_impl(self) -> None: - """Platform specific launch implementation.""" - - @abc.abstractmethod - def send_touch(self, touches: list[tuple[int, int, bool, int]]) -> None: - """Sends a touch event to be executed on the simulator. - - Args: - touches: A list of touch events. Each element in the list corresponds to a - single touch event. Each touch event tuple should have: - 0 x: The horizontal coordinate of this event. - 1 y: The vertical coordinate of this event. - 2 is_down: Whether the finger is touching or not the screen. - 3 identifier: Identifies a particular finger in a multitouch event. - """ - - @abc.abstractmethod - def send_key(self, keycode: np.int32, event_type: str) -> None: - """Sends a keyboard event. - - Args: - keycode: Represents a specific keyboard key. This is platform and - simulator-specific. - event_type: Type of key event to be sent. - """ - - def load_state( - self, request: state_pb2.LoadStateRequest - ) -> state_pb2.LoadStateResponse: - """Loads a state. - - Args: - request: A `LoadStateRequest` containing any parameters necessary to - specify how/what state to load. - - Returns: - A `LoadStateResponse` containing the status, error message (if - applicable), and any other relevant information. - """ - raise NotImplementedError('This simulator does not support load_state()') - - def save_state( - self, request: state_pb2.SaveStateRequest - ) -> state_pb2.SaveStateResponse: - """Saves a state. - - Args: - request: A `SaveStateRequest` containing any parameters necessary to - specify how/what state to save. - - Returns: - A `SaveStateResponse` containing the status, error message (if - applicable), and any other relevant information. - """ - raise NotImplementedError('This simulator does not support save_state()') - - def get_screenshot(self) -> np.ndarray: - """Returns pixels representing the current screenshot of the simulator.""" - - if self._config.interaction_rate_sec > 0: - assert self._interaction_thread is not None - return self._interaction_thread.screenshot() # Async mode. - else: - return self._get_screenshot_impl() # Sync mode. - - @abc.abstractmethod - def _get_screenshot_impl(self) -> np.ndarray: - """Actual implementation of `get_screenshot()`. - - The output numpy array should have shape [height, width, num_channels] and - can be loaded into PIL using Image.fromarray(img, mode='RGB') and be saved - as a PNG file using my_pil.save('/tmp/my_screenshot.png', 'PNG'). - """ - - def close(self): - """Frees up resources allocated by this object.""" - - if self._interaction_thread is not None: - self._interaction_thread.stop() - self._interaction_thread.join() - - -class InteractionThread(threading.Thread): - """A thread that gets screenshot in the background.""" - - def __init__( - self, - get_screenshot_fn: Callable[[], np.ndarray], - interaction_rate_sec: float, - ): - super().__init__() - self._get_screenshot_fn = get_screenshot_fn - self._interaction_rate_sec = interaction_rate_sec - self._should_stop = threading.Event() - self._screenshot = self._get_screenshot_fn() - - def run(self): - last_read = time.time() - while not self._should_stop.is_set(): - self._screenshot = self._get_screenshot_fn() - now = time.time() - elapsed = now - last_read - last_read = now - sleep_time = self._interaction_rate_sec - elapsed - if sleep_time > 0.0: - time.sleep(sleep_time) - logging.info('InteractionThread.run() finished.') - - def stop(self): - logging.info('Stopping InteractionThread.') - self._should_stop.set() - - def screenshot(self) -> np.ndarray: - return self._screenshot diff --git a/android_env/components/simulators/base_simulator_test.py b/android_env/components/simulators/base_simulator_test.py deleted file mode 100644 index 6b5457ad..00000000 --- a/android_env/components/simulators/base_simulator_test.py +++ /dev/null @@ -1,181 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import itertools -import time -from unittest import mock - -from absl.testing import absltest -from android_env.components import config_classes -from android_env.components import errors -# fake_simulator.FakeSimulator inherits from BaseSimulator, so there's no need -# to import it here explicitly. -from android_env.components.simulators import base_simulator -from android_env.components.simulators.fake import fake_simulator -import numpy as np - - -class BaseSimulatorTest(absltest.TestCase): - - def test_launch(self): - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig(screen_dimensions=(640, 480)) - ) - # The simulator should launch and not crash. - simulator.launch() - - def test_launch_close(self): - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig() - ) - # The simulator should launch and not crash. - simulator.launch() - # Closing the simulator should also not crash. - simulator.close() - - def test_get_screenshot(self): - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig(screen_dimensions=(640, 480)) - ) - # The simulator should launch and not crash. - simulator.launch() - - screenshot = simulator.get_screenshot() - np.testing.assert_equal(screenshot.shape, [640, 480, 3]) - - def test_print_logs_on_exception(self): - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig() - ) - with mock.patch.object( - simulator, 'get_logs' - ) as mock_get_logs, mock.patch.object( - simulator, '_launch_impl', autospec=True - ) as mock_launch: - mock_launch.side_effect = ValueError('Oh no!') - self.assertRaises(errors.SimulatorError, simulator.launch) - mock_get_logs.assert_called_once() - - def test_get_screenshot_error_async(self): - """An exception in the underlying interaction thread should bubble up.""" - - # Arrange. - mock_interaction_thread = mock.create_autospec( - base_simulator.InteractionThread - ) - mock_interaction_thread.screenshot.side_effect = ( - errors.ReadObservationError() - ) - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig(interaction_rate_sec=0.5) - ) - with mock.patch.object( - base_simulator, - 'InteractionThread', - autospec=True, - return_value=mock_interaction_thread, - ): - simulator.launch() - - # Act & Assert. - self.assertRaises(errors.ReadObservationError, simulator.get_screenshot) - - # Cleanup. - simulator.close() - - def test_get_screenshot_faster_than_screenshot_impl(self): - """Return same screenshot when step is faster than the interaction rate.""" - - # Arrange. - slow_rate = 0.5 - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig(interaction_rate_sec=slow_rate) - ) - - # Act. - with mock.patch.object( - simulator, '_get_screenshot_impl', autospec=True - ) as mock_get_screenshot_impl: - mock_get_screenshot_impl.side_effect = ( - np.array(i, ndmin=3) for i in itertools.count(0, 1) - ) - simulator.launch() - # Get two screenshots one after the other without pausing. - screenshot1 = simulator.get_screenshot() - screenshot2 = simulator.get_screenshot() - - # Assert. - self.assertAlmostEqual(screenshot1[0][0][0], screenshot2[0][0][0]) - - # Cleanup. - simulator.close() - - def test_get_screenshot_slower_than_screenshot_impl(self): - """Return different screenshots when step slower than the interaction rate.""" - - # Arrange. - fast_rate = 0.01 - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig(interaction_rate_sec=fast_rate) - ) - - # Act. - with mock.patch.object( - simulator, '_get_screenshot_impl', autospec=True - ) as mock_get_screenshot_impl: - mock_get_screenshot_impl.side_effect = ( - np.array(i, ndmin=3) for i in itertools.count(0, 1) - ) - simulator.launch() - # Sleep for 500ms between two screenshots. - screenshot1 = simulator.get_screenshot() - time.sleep(0.5) - screenshot2 = simulator.get_screenshot() - - # Assert. - self.assertNotEqual(screenshot1[0][0][0], screenshot2[0][0][0]) - - # Cleanup. - simulator.close() - - def test_interaction_thread_closes_upon_relaunch(self): - """Async interaction should kill the InteractionThread when relaunching.""" - - # Arrange. - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig(interaction_rate_sec=0.01) - ) - mock_interaction_thread = mock.create_autospec( - base_simulator.InteractionThread - ) - - # Act & Assert. - with mock.patch.object( - base_simulator, - 'InteractionThread', - autospec=True, - return_value=mock_interaction_thread, - ): - simulator.launch() - mock_interaction_thread.stop.assert_not_called() - mock_interaction_thread.join.assert_not_called() - simulator.launch() - mock_interaction_thread.stop.assert_called_once() - mock_interaction_thread.join.assert_called_once() - simulator.close() - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/simulators/emulator/__init__.py b/android_env/components/simulators/emulator/__init__.py deleted file mode 100644 index 2f66bf75..00000000 --- a/android_env/components/simulators/emulator/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/android_env/components/simulators/emulator/emulator_launcher.py b/android_env/components/simulators/emulator/emulator_launcher.py deleted file mode 100644 index a1abe3f3..00000000 --- a/android_env/components/simulators/emulator/emulator_launcher.py +++ /dev/null @@ -1,169 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Prepares and launches an emulator process.""" - -import glob -import os -import subprocess -import tempfile - -from absl import logging -from android_env.components import config_classes - - -class EmulatorLauncher: - """Handles launching an emulator.""" - - def __init__( - self, - config: config_classes.EmulatorLauncherConfig, - adb_controller_config: config_classes.AdbControllerConfig, - ): - """Launches an emulator.""" - - self._config = config - self._adb_controller_config = adb_controller_config - - self._emulator = None - self._emulator_output = None - self._is_closed = False - - # Create directory for tmp files. - # Note: this will be deleted once EmulatorLauncher instance is cleaned up. - os.makedirs(config.tmp_dir, exist_ok=True) - self._local_tmp_dir_handle = tempfile.TemporaryDirectory( - dir=config.tmp_dir, prefix='simulator_instance_' - ) - self._local_tmp_dir = self._local_tmp_dir_handle.name - self._logfile_path = os.path.join(self._local_tmp_dir, 'emulator_output') - logging.info('Simulator local_tmp_dir: %s', self._local_tmp_dir) - - def logfile_path(self) -> str: - return self._logfile_path - - def launch_emulator_process(self) -> None: - """Launches the emulator.""" - - logging.info('Booting new emulator: %s', self._config.emulator_path) - - # Set necessary environment variables. - base_lib_dir = self._config.emulator_path[:-8] + 'lib64/' - ld_library_path = ':'.join([ - base_lib_dir + 'x11/', base_lib_dir + 'qt/lib/', - base_lib_dir + 'gles_swiftshader/', base_lib_dir - ]) - extra_env_vars = { - 'ANDROID_HOME': '', - 'ANDROID_SDK_ROOT': self._config.android_sdk_root, - 'ANDROID_AVD_HOME': self._config.android_avd_home, - 'ANDROID_EMULATOR_KVM_DEVICE': self._config.kvm_device, - 'ANDROID_ADB_SERVER_PORT': str( - self._adb_controller_config.adb_server_port - ), - 'LD_LIBRARY_PATH': ld_library_path, - 'QT_XKB_CONFIG_ROOT': str( - self._config.emulator_path[:-8] + 'qt_config/' - ), - 'ANDROID_EMU_ENABLE_CRASH_REPORTING': '1', - 'SHOW_PERF_STATS': str(1 if self._config.show_perf_stats else 0), - } - logging.info('extra_env_vars: %s', - ' '.join(f'{k}={v}' for k, v in extra_env_vars.items())) - env_vars = dict(os.environ).copy() - env_vars.update(extra_env_vars) - - # Compile command. - grpc_port = ( - ['-grpc', str(self._config.grpc_port)] - if self._config.grpc_port >= 0 - else [] - ) - run_headless = ( - ['-no-skin', '-no-window'] if self._config.run_headless else [] - ) - ports = [ - '-ports', - '%s,%s' % (self._config.emulator_console_port, self._config.adb_port), - ] - snapshot = [ - '-snapshot', - self._config.snapshot_name, - '-feature', - 'AllowSnapshotMigration,MigratableSnapshotSave', - ] - snapshot = snapshot if self._config.snapshot_name else ['-no-snapshot'] - restrict_network_args = [ - '-network-user-mode-options', 'restrict=y', '-wifi-user-mode-options', - 'restrict=y' - ] - network_args = ( - restrict_network_args if self._config.restrict_network else [] - ) - command = ( - [ - self._config.emulator_path, - '-adb-path', - self._adb_controller_config.adb_path, - '-gpu', - self._config.gpu_mode, - '-no-audio', - '-show-kernel', - '-verbose', - '-avd', - self._config.avd_name, - ] - + grpc_port - + run_headless - + ports - + snapshot - + network_args - ) - logging.info('Emulator launch command: %s', ' '.join(command)) - # Prepare logfile. - self._emulator_output = open(self._logfile_path, 'wb') - - # Spawn the emulator process. - self._emulator = subprocess.Popen( - command, - env=env_vars, - stdout=self._emulator_output, - stderr=self._emulator_output) - - def confirm_shutdown(self) -> None: - """Shuts down the emulator process.""" - if self._emulator is not None: - logging.info('Checking if emulator process has finished...') - try: - self._emulator.wait(timeout=30.0) - except subprocess.TimeoutExpired: - logging.exception( - 'The emulator process did not finish after 30s. ' - 'returncode: %s. Will now try to kill() it.', - self._emulator.returncode) - self._emulator.kill() - self._emulator = None - self._emulator_output.close() - logging.info('The emulator process has finished.') - - def close(self): - """Clean up launcher files and processes.""" - if not self._is_closed: - self._local_tmp_dir_handle.cleanup() - self.confirm_shutdown() - self._is_closed = True - - def __del__(self): - self.close() diff --git a/android_env/components/simulators/emulator/emulator_launcher_test.py b/android_env/components/simulators/emulator/emulator_launcher_test.py deleted file mode 100644 index 3bfd44f1..00000000 --- a/android_env/components/simulators/emulator/emulator_launcher_test.py +++ /dev/null @@ -1,294 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.components.emulator_launcher.""" - -import builtins -import os -import subprocess -import tempfile -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -from android_env.components import config_classes -from android_env.components.simulators.emulator import emulator_launcher - - -class EmulatorLauncherTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - - self._emulator_path = 'fake/path/emulator' - self._adb_path = 'fake/path/adb' - self._adb_port = 5554 - self._adb_server_port = 1234 - self._emulator_console_port = 5555 - self._avd_name = 'my_avd_name' - - self._expected_command = [ - self._emulator_path, - '-adb-path', - 'fake/path/adb', - '-gpu', - 'swangle_indirect', - '-no-audio', - '-show-kernel', - '-verbose', - '-avd', - self._avd_name, - ] - self._headless = ['-no-skin', '-no-window'] - self._ports = ['-ports', f'{self._emulator_console_port},{self._adb_port}'] - self._snapshot = ['-no-snapshot'] - - base_lib_dir = self._emulator_path[:-8] + 'lib64/' - ld_library_path = ':'.join([ - base_lib_dir + 'x11/', base_lib_dir + 'qt/lib/', - base_lib_dir + 'gles_swiftshader/', base_lib_dir - ]) - - # Instantiate the config to extract default values. - config = config_classes.EmulatorLauncherConfig() - self._expected_env_vars = { - 'ANDROID_HOME': '', - 'ANDROID_SDK_ROOT': config.android_sdk_root, - 'ANDROID_AVD_HOME': config.android_avd_home, - 'ANDROID_EMULATOR_KVM_DEVICE': '/dev/kvm', - 'ANDROID_ADB_SERVER_PORT': '1234', - 'LD_LIBRARY_PATH': ld_library_path, - 'QT_XKB_CONFIG_ROOT': str(self._emulator_path[:-8] + 'qt_config/'), - 'ANDROID_EMU_ENABLE_CRASH_REPORTING': '1', - } - - @parameterized.named_parameters([ - ('hide_perf_stats', False), - ('show_perf_stats', True), - ]) - @mock.patch.object(os, 'makedirs') - @mock.patch.object(os, 'environ', autospec=True, return_value=dict()) - @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True) - def test_launch( - self, - show_perf_stats: bool, - mock_tmp_dir, - unused_os_environ, - unused_os_makedirs, - ): - mock_tmp_dir.return_value.name.return_value = 'local_tmp_dir' - - config = config_classes.EmulatorLauncherConfig( - adb_port=self._adb_port, - emulator_console_port=self._emulator_console_port, - emulator_path=self._emulator_path, - avd_name=self._avd_name, - grpc_port=-1, - show_perf_stats=show_perf_stats, - ) - adb_controller_config = config_classes.AdbControllerConfig( - adb_path=self._adb_path, - adb_server_port=self._adb_server_port, - ) - launcher = emulator_launcher.EmulatorLauncher( - config=config, adb_controller_config=adb_controller_config - ) - - expected_env_vars = self._expected_env_vars - expected_env_vars['SHOW_PERF_STATS'] = '1' if show_perf_stats else '0' - - with mock.patch.object( - subprocess, 'Popen', autospec=True - ) as emulator_init, mock.patch.object(builtins, 'open', autospec=True) as f: - f.return_value.__enter__ = f() - launcher.launch_emulator_process() - emulator_init.assert_called_once_with( - args=self._expected_command - + self._headless - + self._ports - + self._snapshot, - env=expected_env_vars, - stdout=f(), - stderr=f(), - ) - - @parameterized.named_parameters([ - ('hide_perf_stats', False), - ('show_perf_stats', True), - ]) - @mock.patch.object(os, 'makedirs') - @mock.patch.object(os, 'environ', autospec=True, return_value=dict()) - @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True) - def test_grpc_port( - self, - show_perf_stats: bool, - mock_tmp_dir, - unused_os_environ, - unused_os_makedirs, - ): - mock_tmp_dir.return_value.name.return_value = 'local_tmp_dir' - - config = config_classes.EmulatorLauncherConfig( - adb_port=self._adb_port, - emulator_console_port=self._emulator_console_port, - emulator_path=self._emulator_path, - avd_name=self._avd_name, - grpc_port=8554, - show_perf_stats=show_perf_stats, - ) - adb_controller_config = config_classes.AdbControllerConfig( - adb_path=self._adb_path, - adb_server_port=self._adb_server_port, - ) - launcher = emulator_launcher.EmulatorLauncher( - config=config, adb_controller_config=adb_controller_config - ) - - expected_env_vars = self._expected_env_vars - expected_env_vars['SHOW_PERF_STATS'] = '1' if show_perf_stats else '0' - - with mock.patch.object( - subprocess, 'Popen', autospec=True - ) as emulator_init, mock.patch.object(builtins, 'open', autospec=True) as f: - f.return_value.__enter__ = f() - launcher.launch_emulator_process() - emulator_init.assert_called_once_with( - args=self._expected_command - + ['-grpc', '8554'] - + self._headless - + self._ports - + self._snapshot, - env=expected_env_vars, - stdout=f(), - stderr=f(), - ) - - @parameterized.named_parameters([ - ('hide_perf_stats', False), - ('show_perf_stats', True), - ]) - @mock.patch.object(os, 'makedirs') - @mock.patch.object(os, 'environ', autospec=True, return_value=dict()) - @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True) - def test_snapshot( - self, - show_perf_stats: bool, - mock_tmp_dir, - unused_os_environ, - unused_os_makedirs, - ): - mock_tmp_dir.return_value.name.return_value = 'local_tmp_dir' - - config = config_classes.EmulatorLauncherConfig( - adb_port=self._adb_port, - emulator_console_port=self._emulator_console_port, - emulator_path=self._emulator_path, - avd_name=self._avd_name, - grpc_port=-1, - snapshot_name='my_snapshot', - show_perf_stats=show_perf_stats, - ) - adb_controller_config = config_classes.AdbControllerConfig( - adb_path=self._adb_path, - adb_server_port=self._adb_server_port, - ) - launcher = emulator_launcher.EmulatorLauncher( - config=config, adb_controller_config=adb_controller_config - ) - - expected_snapshot = [ - '-snapshot', 'my_snapshot', '-feature', - 'AllowSnapshotMigration,MigratableSnapshotSave' - ] - - expected_env_vars = self._expected_env_vars - expected_env_vars['SHOW_PERF_STATS'] = '1' if show_perf_stats else '0' - - with mock.patch.object( - subprocess, 'Popen', autospec=True) as emulator_init, \ - mock.patch.object(builtins, 'open', autospec=True) as f: - f.return_value.__enter__ = f() - launcher.launch_emulator_process() - emulator_init.assert_called_once_with( - args=self._expected_command - + self._headless - + self._ports - + expected_snapshot, - env=expected_env_vars, - stdout=f(), - stderr=f(), - ) - - @parameterized.named_parameters([ - ('hide_perf_stats', False), - ('show_perf_stats', True), - ]) - @mock.patch.object(os, 'makedirs') - @mock.patch.object(os, 'environ', autospec=True, return_value=dict()) - @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True) - def test_network_restrict( - self, - show_perf_stats: bool, - mock_tmp_dir, - unused_os_environ, - unused_os_makedirs, - ): - mock_tmp_dir.return_value.name.return_value = 'local_tmp_dir' - - config = config_classes.EmulatorLauncherConfig( - adb_port=self._adb_port, - emulator_console_port=self._emulator_console_port, - emulator_path=self._emulator_path, - avd_name=self._avd_name, - grpc_port=-1, - restrict_network=True, - show_perf_stats=show_perf_stats, - ) - adb_controller_config = config_classes.AdbControllerConfig( - adb_path=self._adb_path, - adb_server_port=self._adb_server_port, - ) - launcher = emulator_launcher.EmulatorLauncher( - config=config, adb_controller_config=adb_controller_config - ) - - expected_snapshot = ['-no-snapshot'] - expected_network_restrict = [ - '-network-user-mode-options', 'restrict=y', '-wifi-user-mode-options', - 'restrict=y' - ] - - expected_env_vars = self._expected_env_vars - expected_env_vars['SHOW_PERF_STATS'] = '1' if show_perf_stats else '0' - - with mock.patch.object( - subprocess, 'Popen', autospec=True) as emulator_init, \ - mock.patch.object(builtins, 'open', autospec=True) as f: - f.return_value.__enter__ = f() - launcher.launch_emulator_process() - emulator_init.assert_called_once_with( - self._expected_command - + self._headless - + self._ports - + expected_snapshot - + expected_network_restrict, - env=expected_env_vars, - stdout=f(), - stderr=f(), - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/simulators/emulator/emulator_simulator.py b/android_env/components/simulators/emulator/emulator_simulator.py deleted file mode 100644 index d7f61874..00000000 --- a/android_env/components/simulators/emulator/emulator_simulator.py +++ /dev/null @@ -1,486 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A class that manages an Android Emulator.""" - -import os -import time -from typing import Any - -from absl import logging -from android_env.components import adb_controller -from android_env.components import adb_log_stream -from android_env.components import config_classes -from android_env.components import errors -from android_env.components import log_stream -from android_env.components.simulators import base_simulator -from android_env.components.simulators.emulator import emulator_launcher -from android_env.proto import state_pb2 -import grpc -import numpy as np -import portpicker - -from android_env.proto import emulator_controller_pb2 -from android_env.proto import emulator_controller_pb2_grpc -from android_env.proto import snapshot_service_pb2 -from android_env.proto import snapshot_service_pb2_grpc -from google.protobuf import empty_pb2 - - -_DEFAULT_SNAPSHOT_NAME: str = 'default_snapshot' - - -def _is_existing_emulator_provided( - launcher_config: config_classes.EmulatorLauncherConfig, -) -> bool: - """Returns true if all necessary args were provided.""" - - return bool( - launcher_config.adb_port - and launcher_config.emulator_console_port - and launcher_config.grpc_port - ) - - -def _pick_adb_port() -> int: - """Tries to pick a port in the recommended range 5555-5585. - - If no such port can be found, will return a random unused port. More info: - https://developer.android.com/studio/command-line/adb#howadbworks. - - Returns: - port: an available port for adb. - """ - - for p in range(5555, 5587, 2): - if portpicker.is_port_free(p): - return p - return portpicker.pick_unused_port() - - -def _pick_emulator_grpc_port() -> int: - """Tries to pick the recommended port for grpc. - - If no such port can be found, will return a random unused port. More info: - https://android.googlesource.com/platform/external/qemu/+/emu-master-dev/android/android-grpc/docs/. - - Returns: - port: an available port for emulator grpc. - """ - - if portpicker.is_port_free(8554): - return 8554 - else: - return portpicker.pick_unused_port() - - -class EmulatorBootError(errors.SimulatorError): - """Raised when an emulator failed to boot.""" - - -class EmulatorCrashError(errors.SimulatorError): - """Raised when a simulator crashed.""" - - -class EmulatorSimulator(base_simulator.BaseSimulator): - """Controls an Android Emulator.""" - - def __init__(self, config: config_classes.EmulatorConfig): - """Instantiates an EmulatorSimulator.""" - - super().__init__(config) - self._config = config - - # If adb_port, console_port and grpc_port are all already provided, - # we assume the emulator already exists and there's no need to launch. - if _is_existing_emulator_provided(self._config.emulator_launcher): - self._existing_emulator_provided = True - logging.info('Connecting to existing emulator "%r"', - self.adb_device_name()) - else: - self._existing_emulator_provided = False - self._config.emulator_launcher.adb_port = _pick_adb_port() - self._config.emulator_launcher.emulator_console_port = ( - portpicker.pick_unused_port() - ) - self._config.emulator_launcher.grpc_port = _pick_emulator_grpc_port() - - self._channel = None - self._emulator_stub: emulator_controller_pb2_grpc.EmulatorControllerStub | None = ( - None - ) - self._snapshot_stub = None - # Set the image format to RGBA. The width and height of the returned - # screenshots will use the device's width and height. - self._image_format = emulator_controller_pb2.ImageFormat( - format=emulator_controller_pb2.ImageFormat.ImgFormat.RGBA8888) - - if ( - self._config.launch_n_times_without_reboot - > self._config.launch_n_times_without_reinstall - ): - raise ValueError( - 'Number of launch attempts before reboot' - f' ({self._config.launch_n_times_without_reboot}) should not be' - ' greater than number of launch attempts before reinstall' - f' ({self._config.launch_n_times_without_reinstall})' - ) - - # Initialize own ADB controller. - self._config.adb_controller.device_name = self.adb_device_name() - self._adb_controller = self.create_adb_controller() - self._adb_controller.init_server() - logging.info( - 'Initialized simulator with ADB server port %r.', - self._config.adb_controller.adb_server_port, - ) - - # If necessary, create EmulatorLauncher. - if self._existing_emulator_provided: - self._logfile_path = self._config.logfile_path or None - self._launcher = None - else: - logging.info( - 'emulator_launcher config: %r', self._config.emulator_launcher - ) - self._launcher = emulator_launcher.EmulatorLauncher( - config=self._config.emulator_launcher, - adb_controller_config=self._config.adb_controller, - ) - self._logfile_path = ( - self._config.logfile_path or self._launcher.logfile_path() - ) - - def _reconnect_on_grpc_error(func): - """Decorator function for reconnecting to emulator upon grpc errors.""" - - def wrapper(self, *args, **kwargs): - try: - return func(self, *args, **kwargs) - except grpc.RpcError: - logging.exception('RpcError caught. Reconnecting to emulator...') - self._emulator_stub, self._snapshot_stub = self._connect_to_emulator( - self._config.emulator_launcher.grpc_port - ) - return func(self, *args, **kwargs) - - return wrapper - - def get_logs(self) -> str: - """Returns logs recorded by the emulator.""" - if self._logfile_path and os.path.exists(self._logfile_path): - with open(self._logfile_path, 'rb') as f: - return f.read().decode('utf-8') - else: - return f'Logfile does not exist: {self._logfile_path}.' - - def adb_device_name(self) -> str: - return 'emulator-%s' % (self._config.emulator_launcher.adb_port - 1) - - def create_adb_controller(self): - """Returns an ADB controller which can communicate with this simulator.""" - return adb_controller.AdbController(self._config.adb_controller) - - def create_log_stream(self) -> log_stream.LogStream: - return adb_log_stream.AdbLogStream( - adb_command_prefix=self._adb_controller.command_prefix(), - verbose=self._config.verbose_logs, - ) - - def _launch_impl(self) -> None: - """Prepares an Android Emulator for RL interaction. - - The behavior depends on `self._num_launch_attempts`'s value: - * <= self._config.launch_n_times_without_reboot -> Normal boot behavior. - * > self._config.launch_n_times_without_reboot but <= - self._config.launch_n_times_without_reinstall -> reboot (i.e. process - is killed and started again). - * > self._config.launch_n_times_without_reinstall -> reinstall (i.e. - process is killed, emulator files are deleted and the process started - again). - """ - - logging.info('Attempt %r at launching the Android Emulator (%r)', - self._num_launch_attempts, self.adb_device_name()) - - if self._launcher is not None: - # If not the first time, then shutdown the emulator first. - if ( - self._emulator_stub is not None - and self._num_launch_attempts - > self._config.launch_n_times_without_reboot - ): - self._shutdown_emulator() - # Subsequent attempts cause the emulator files to be reinstalled. - if ( - self._num_launch_attempts - > self._config.launch_n_times_without_reinstall - ): - logging.info('Closing emulator (%r)', self.adb_device_name()) - self._launcher.close() - self._launcher = emulator_launcher.EmulatorLauncher( - config=self._config.emulator_launcher, - adb_controller_config=self._config.adb_controller, - ) - self._launcher.launch_emulator_process() - # Establish grpc connection to emulator process. - self._emulator_stub, self._snapshot_stub = self._connect_to_emulator( - self._config.emulator_launcher.grpc_port - ) - - # Confirm booted status. - try: - self._confirm_booted() - except EmulatorCrashError: - logging.exception('Failed to confirm booted status of emulator.') - - logging.info('Done booting the Android Emulator.') - - def load_state( - self, request: state_pb2.LoadStateRequest - ) -> state_pb2.LoadStateResponse: - """Loads a state using the emulator's snapshotting mechanism. - - Args: - request: The `LoadStateRequest`. In this case, `args` should be a dict - containing the key 'snapshot_name', representing the name of the - snapshot to load. If `request.args.snapshot_name` is `None`, a default - snapshot name is used. - - Returns: - A response indicating whether the snapshot was successfully loaded. - * If the snapshot was loaded successfully, the status will be `OK`. - * If no snapshot of the given name was found, the status will be - `NOT_FOUND`. - * If an error occurred during the snapshot loading process, the status - will be `ERROR` and the `error_message` field will be filled. - """ - assert self._snapshot_stub is not None - snapshot_name = request.args.get('snapshot_name', _DEFAULT_SNAPSHOT_NAME) - snapshot_list = self._snapshot_stub.ListSnapshots( - snapshot_service_pb2.SnapshotFilter( - statusFilter=snapshot_service_pb2.SnapshotFilter.LoadStatus.All - ) - ) - if any( - snapshot.snapshot_id == snapshot_name - for snapshot in snapshot_list.snapshots - ): - snapshot_result = self._snapshot_stub.LoadSnapshot( - snapshot_service_pb2.SnapshotPackage(snapshot_id=snapshot_name) - ) - if snapshot_result.success: - return state_pb2.LoadStateResponse( - status=state_pb2.LoadStateResponse.Status.OK - ) - else: - return state_pb2.LoadStateResponse( - status=state_pb2.LoadStateResponse.Status.ERROR, - error_message=snapshot_result.err.decode('utf-8'), - ) - - else: - return state_pb2.LoadStateResponse( - status=state_pb2.LoadStateResponse.Status.NOT_FOUND - ) - - def save_state( - self, request: state_pb2.SaveStateRequest - ) -> state_pb2.SaveStateResponse: - """Saves a state using the emulator's snapshotting mechanism. - - Args: - request: The `SaveStateRequest`. In this case, `args` should be a dict - containing the key 'snapshot_name', representing the name of the - snapshot to save. If `request.args.snapshot_name` is `None`, a default - snapshot name is used. - - Returns: - A response indicating whether the snapshot was successfully saved. - * If the snapshot was saved successfully, the status will be `OK`. - * If an error occurred during the snapshot saving process, the status - will be `ERROR` and the `error_message` field will be filled. - """ - assert self._snapshot_stub is not None - snapshot_name = request.args.get('snapshot_name', _DEFAULT_SNAPSHOT_NAME) - snapshot_result = self._snapshot_stub.SaveSnapshot( - snapshot_service_pb2.SnapshotPackage(snapshot_id=snapshot_name) - ) - if snapshot_result.success: - return state_pb2.SaveStateResponse( - status=state_pb2.SaveStateResponse.Status.OK - ) - else: - return state_pb2.SaveStateResponse( - status=state_pb2.SaveStateResponse.Status.ERROR, - error_message=snapshot_result.err.decode('utf-8'), - ) - - def _connect_to_emulator( - self, - grpc_port: int, - timeout_sec: int = 100, - ) -> tuple[ - emulator_controller_pb2_grpc.EmulatorControllerStub, - snapshot_service_pb2_grpc.SnapshotServiceStub, - ]: - """Connects to an emulator and returns a corresponsing stub.""" - - logging.info('Creating gRPC channel to the emulator on port %r', grpc_port) - port = f'localhost:{grpc_port}' - options = [('grpc.max_send_message_length', -1), - ('grpc.max_receive_message_length', -1)] - creds = grpc.local_channel_credentials() - - try: - self._channel = grpc.secure_channel(port, creds, options=options) - grpc.channel_ready_future(self._channel).result(timeout=timeout_sec) - except (grpc.RpcError, grpc.FutureTimeoutError) as grpc_error: - logging.exception('Failed to connect to the emulator.') - raise EmulatorBootError( - 'Failed to connect to the emulator.') from grpc_error - - logging.info('Added gRPC channel for the Emulator on port %s', port) - emulator_controller_stub = ( - emulator_controller_pb2_grpc.EmulatorControllerStub(self._channel) - ) - snapshot_stub = snapshot_service_pb2_grpc.SnapshotServiceStub(self._channel) - return emulator_controller_stub, snapshot_stub - - @_reconnect_on_grpc_error - def _confirm_booted(self, startup_wait_time_sec: int = 300): - """Waits until the emulator is fully booted.""" - - assert ( - self._emulator_stub is not None - ), 'Emulator stub has not been initialized yet.' - start_time = time.time() - deadline = start_time + startup_wait_time_sec - success = False - while time.time() < deadline: - emu_status = self._emulator_stub.getStatus(empty_pb2.Empty()) - logging.info('Waiting for emulator (%r) to start... (%rms)', - self.adb_device_name(), emu_status.uptime) - if emu_status.booted: - success = True - break - time.sleep(5.0) - - elapsed_time = time.time() - start_time - if not success: - raise EmulatorCrashError( - f'The emulator failed to boot after {startup_wait_time_sec} seconds') - - logging.info('Done booting the emulator (in %f seconds).', elapsed_time) - logging.info('********** Emulator logs **********') - for line in self.get_logs().splitlines(): - logging.info(line) - logging.info('******* End of emulator logs *******') - logging.info('See the full emulator logs at %r', self._logfile_path) - - @_reconnect_on_grpc_error - def send_touch(self, touches: list[tuple[int, int, bool, int]]) -> None: - """Sends a touch event to be executed on the simulator. - - Args: - touches: A list of touch events. Each element in the list corresponds to a - single touch event. Each touch event tuple should have: - 0 x: The horizontal coordinate of this event. - 1 y: The vertical coordinate of this event. - 2 is_down: Whether the finger is touching or not the screen. - 3 identifier: Identifies a particular finger in a multitouch event. - """ - - assert ( - self._emulator_stub is not None - ), 'Emulator stub has not been initialized yet.' - touch_events = [ - emulator_controller_pb2.Touch( - x=t[0], y=t[1], pressure=int(t[2]), identifier=t[3]) - for t in touches - ] - self._emulator_stub.sendTouch( - emulator_controller_pb2.TouchEvent(touches=touch_events)) - - @_reconnect_on_grpc_error - def send_key(self, keycode: np.int32, event_type: str) -> None: - """Sends a key event to the emulator. - - Args: - keycode: Code representing the desired key press in XKB format. - See the emulator_controller_pb2 for details. - event_type: Type of key event to be sent. - """ - - event_types = emulator_controller_pb2.KeyboardEvent.KeyEventType.keys() - if event_type not in event_types: - raise ValueError( - f'Event type must be one of {event_types} but is {event_type}.') - - assert ( - self._emulator_stub is not None - ), 'Emulator stub has not been initialized yet.' - self._emulator_stub.sendKey( - emulator_controller_pb2.KeyboardEvent( - codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB, - eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType.Value( - event_type - ), - keyCode=int(keycode), - ) - ) - - @_reconnect_on_grpc_error - def _get_screenshot_impl(self) -> np.ndarray: - """Fetches the latest screenshot from the emulator.""" - - assert ( - self._emulator_stub is not None - ), 'Emulator stub has not been initialized yet.' - assert self._image_format, 'ImageFormat has not been initialized yet.' - image_proto = self._emulator_stub.getScreenshot(self._image_format) - h, w = image_proto.format.height, image_proto.format.width - image = np.frombuffer(image_proto.image, dtype='uint8', count=h * w * 4) - image.shape = (h, w, 4) - return image[:, :, :3] - - @_reconnect_on_grpc_error - def _shutdown_emulator(self): - """Sends a signal to trigger emulator shutdown.""" - - if self._emulator_stub is None: - logging.info('Emulator (%r) is not up.', self.adb_device_name()) - return - - assert self._launcher is not None, 'Launcher is already down.' - - logging.info('Shutting down the emulator (%r)...', self.adb_device_name()) - self._emulator_stub.setVmState( - emulator_controller_pb2.VmRunState( - state=emulator_controller_pb2.VmRunState.RunState.SHUTDOWN)) - self._launcher.confirm_shutdown() - - def close(self): - super().close() - - if self._launcher is not None: - self._shutdown_emulator() - logging.info('Closing emulator (%r)', self.adb_device_name()) - self._launcher.close() - self._emulator_stub = None - self._snapshot_stub = None - if self._channel is not None: - self._channel.close() - super().close() diff --git a/android_env/components/simulators/emulator/emulator_simulator_test.py b/android_env/components/simulators/emulator/emulator_simulator_test.py deleted file mode 100644 index 97e24711..00000000 --- a/android_env/components/simulators/emulator/emulator_simulator_test.py +++ /dev/null @@ -1,543 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.components.emulator_simulator.""" - -import builtins -import os -import time -from unittest import mock - -from absl.testing import absltest -from android_env.components import adb_call_parser -from android_env.components import adb_controller -from android_env.components import config_classes -from android_env.components.simulators.emulator import emulator_launcher -from android_env.components.simulators.emulator import emulator_simulator -from android_env.proto import state_pb2 -import grpc -from PIL import Image -import portpicker - -from android_env.proto import emulator_controller_pb2 -from android_env.proto import emulator_controller_pb2_grpc -from android_env.proto import snapshot_service_pb2 - - -class EmulatorSimulatorTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self.addCleanup(mock.patch.stopall) # Disable previous patches. - - self._adb_controller = mock.create_autospec(adb_controller.AdbController) - self._adb_call_parser = mock.create_autospec(adb_call_parser.AdbCallParser) - self._launcher = mock.create_autospec(emulator_launcher.EmulatorLauncher) - self._launcher.logfile_path.return_value = 'logfile_path' - self._emulator_stub = mock.create_autospec( - emulator_controller_pb2_grpc.EmulatorControllerStub) - - self._grpc_channel = mock.create_autospec(grpc.Channel) - mock.patch.object( - grpc.aio, 'secure_channel', return_value=self._grpc_channel).start() - mock.patch.object( - grpc, 'secure_channel', return_value=self._grpc_channel).start() - mock.patch.object( - grpc, 'local_channel_credentials', - return_value=self._grpc_channel).start() - self._mock_future = mock.create_autospec(grpc.Future) - mock.patch.object( - grpc, 'channel_ready_future', return_value=self._mock_future).start() - mock.patch.object(time, 'time', return_value=12345).start() - - mock.patch.object( - adb_controller, 'AdbController', - return_value=self._adb_controller).start() - mock.patch.object( - adb_call_parser, - 'AdbCallParser', - autospec=True, - return_value=self._adb_call_parser).start() - mock.patch.object( - emulator_launcher, 'EmulatorLauncher', - return_value=self._launcher).start() - - def test_adb_device_name_not_empty(self): - config = config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - grpc_port=1234, tmp_dir=self.create_tempdir().full_path - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - ) - simulator = emulator_simulator.EmulatorSimulator(config) - self.assertNotEmpty(simulator.adb_device_name()) - - @mock.patch.object(os.path, 'exists', autospec=True, return_value=True) - @mock.patch.object(builtins, 'open', autospec=True) - def test_logfile_path(self, mock_open, unused_mock_exists): - config = config_classes.EmulatorConfig( - logfile_path='fake/logfile/path', - emulator_launcher=config_classes.EmulatorLauncherConfig( - grpc_port=1234, tmp_dir=self.create_tempdir().full_path - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - ) - simulator = emulator_simulator.EmulatorSimulator(config) - mock_open.return_value.__enter__.return_value.read.return_value = ( - 'fake_logs'.encode('utf-8')) - logs = simulator.get_logs() - mock_open.assert_called_once_with('fake/logfile/path', 'rb') - self.assertEqual(logs, 'fake_logs') - - @mock.patch.object(portpicker, 'is_port_free', return_value=True) - def test_grpc_port(self, unused_mock_portpicker): - - launcher_config = config_classes.EmulatorLauncherConfig( - tmp_dir=self.create_tempdir().full_path - ) - config = config_classes.EmulatorConfig( - emulator_launcher=launcher_config, - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - ) - simulator = emulator_simulator.EmulatorSimulator(config) - self.assertEqual(launcher_config.grpc_port, 8554) - - @mock.patch.object(portpicker, 'is_port_free', return_value=False) - def test_grpc_port_unavailable(self, unused_mock_portpicker): - - launcher_config = config_classes.EmulatorLauncherConfig( - tmp_dir=self.create_tempdir().full_path - ) - config = config_classes.EmulatorConfig( - emulator_launcher=launcher_config, - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - ) - simulator = emulator_simulator.EmulatorSimulator(config) - self.assertNotEqual(launcher_config.grpc_port, 8554) - - def test_launch_operation_order(self): - """Makes sure that adb_controller is started before Emulator is launched.""" - - # Arrange. - call_order = [] - self._adb_controller.init_server.side_effect = lambda: call_order.append( - 'init_server' - ) - self._launcher.launch_emulator_process.side_effect = ( - lambda: call_order.append('launch_emulator_process') - ) - config = config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - grpc_port=1234, tmp_dir=self.create_tempdir().full_path - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - ) - simulator = emulator_simulator.EmulatorSimulator(config) - - # Act. - simulator.launch() # The simulator should launch and not crash. - - # Assert. - # The adb server should be initialized before launching the emulator. - self.assertEqual(call_order, ['init_server', 'launch_emulator_process']) - - def test_close(self): - config = config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - grpc_port=1234, tmp_dir=self.create_tempdir().full_path - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - ) - simulator = emulator_simulator.EmulatorSimulator(config) - - # The simulator should launch and not crash. - simulator.launch() - - # For whatever reason clients may want to close the EmulatorSimulator. - # We just want to check that the simulator does not crash and/or leak - # resources. - simulator.close() - - def test_value_error_if_launch_attempt_params_incorrect(self): - self.assertRaises( - ValueError, - emulator_simulator.EmulatorSimulator, - config=config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - grpc_port=1234, tmp_dir=self.create_tempdir().full_path - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - launch_n_times_without_reboot=2, - launch_n_times_without_reinstall=1, - ), - ) - - def test_launch_attempt_reboot(self): - config = config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - grpc_port=1234, tmp_dir=self.create_tempdir().full_path - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - launch_n_times_without_reboot=1, - launch_n_times_without_reinstall=2, - ) - simulator = emulator_simulator.EmulatorSimulator(config) - - # The simulator should launch and not crash. - simulator.launch() - - self._launcher.launch_emulator_process.assert_called_once() - self._launcher.reset_mock() - - # Launch attempt 2. - simulator.launch() - self._launcher.confirm_shutdown.assert_called_once() - self._launcher.close.assert_not_called() - self._launcher.launch_emulator_process.assert_called_once() - - def test_launch_attempt_reinstall_after_zero_attempts(self): - config = config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - grpc_port=1234, tmp_dir=self.create_tempdir().full_path - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - launch_n_times_without_reboot=0, - launch_n_times_without_reinstall=0, - ) - simulator = emulator_simulator.EmulatorSimulator(config) - - # The simulator should not reboot or reinstall on its very first launch. - simulator.launch() - self._launcher.launch_emulator_process.assert_called_once() - self._launcher.confirm_shutdown.assert_not_called() - self._launcher.close.assert_not_called() - - # Every subsequent attempt should reboot and reinstall. - self._launcher.reset_mock() - simulator.launch() - self._launcher.confirm_shutdown.assert_called_once() - self._launcher.close.assert_called_once() # Now this should `close()`. - self._launcher.launch_emulator_process.assert_called_once() - - def test_launch_attempt_reinstall(self): - config = config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - grpc_port=1234, tmp_dir=self.create_tempdir().full_path - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - launch_n_times_without_reboot=1, - launch_n_times_without_reinstall=2, - ) - simulator = emulator_simulator.EmulatorSimulator(config) - - # The simulator should launch and not crash. - simulator.launch() - self._launcher.launch_emulator_process.assert_called_once() - - # Launch attempt 2. - self._launcher.reset_mock() - simulator.launch() - self._launcher.confirm_shutdown.assert_called_once() - self._launcher.close.assert_not_called() # Reboots don't `close()`. - self._launcher.launch_emulator_process.assert_called_once() - - # Launch attempt 3. - self._launcher.reset_mock() - simulator.launch() - self._launcher.confirm_shutdown.assert_called_once() - self._launcher.close.assert_called_once() # Now this should `close()`. - self._launcher.launch_emulator_process.assert_called_once() - - def test_get_screenshot(self): - config = config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - grpc_port=1234, tmp_dir=self.create_tempdir().full_path - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - ) - simulator = emulator_simulator.EmulatorSimulator(config) - - # The simulator should launch and not crash. - simulator.launch() - - simulator._emulator_stub.getScreenshot = mock.MagicMock( - return_value=emulator_controller_pb2.Image( - format=emulator_controller_pb2.ImageFormat(width=5678, height=1234), - image=Image.new('RGBA', (1234, 5678)).tobytes(), - timestampUs=123)) - - screenshot = simulator.get_screenshot() - # The screenshot should have the same screen dimensions as reported by ADB - # and it should have 3 channels (RGB). - self.assertEqual(screenshot.shape, (1234, 5678, 3)) - - def test_load_state(self): - config = config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - grpc_port=1234, tmp_dir=self.create_tempdir().full_path - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - ) - simulator = emulator_simulator.EmulatorSimulator(config) - - # The simulator should launch and not crash. - simulator.launch() - - with mock.patch.object( - simulator, '_snapshot_stub', create_autospec=True - ) as mock_snapshot_stub: - snapshot_list = snapshot_service_pb2.SnapshotList() - snapshot_list.snapshots.add(snapshot_id='snapshot_name_foo') - snapshot_list.snapshots.add(snapshot_id='snapshot_name_bar') - mock_snapshot_stub.ListSnapshots.return_value = snapshot_list - mock_snapshot_stub.LoadSnapshot.return_value = ( - snapshot_service_pb2.SnapshotPackage(success=True) - ) - load_response = simulator.load_state( - request=state_pb2.LoadStateRequest( - args={'snapshot_name': 'snapshot_name_foo'} - ) - ) - self.assertEqual( - load_response.status, state_pb2.LoadStateResponse.Status.OK - ) - load_response = simulator.load_state( - request=state_pb2.LoadStateRequest( - args={'snapshot_name': 'snapshot_name_baz'} - ) - ) - self.assertEqual( - load_response.status, state_pb2.LoadStateResponse.Status.NOT_FOUND - ) - mock_snapshot_stub.LoadSnapshot.return_value = ( - snapshot_service_pb2.SnapshotPackage(success=False, err=b'error') - ) - load_response = simulator.load_state( - request=state_pb2.LoadStateRequest( - args={'snapshot_name': 'snapshot_name_bar'} - ) - ) - self.assertEqual( - load_response.status, state_pb2.LoadStateResponse.Status.ERROR - ) - self.assertEqual(load_response.error_message, 'error') - - def test_save_state(self): - config = config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - grpc_port=1234, tmp_dir=self.create_tempdir().full_path - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - ) - simulator = emulator_simulator.EmulatorSimulator(config) - - # The simulator should launch and not crash. - simulator.launch() - - with mock.patch.object( - simulator, '_snapshot_stub', create_autospec=True - ) as mock_snapshot_stub: - mock_snapshot_stub.SaveSnapshot.return_value = ( - snapshot_service_pb2.SnapshotPackage(success=True) - ) - save_response = simulator.save_state( - request=state_pb2.SaveStateRequest( - args={'snapshot_name': 'snapshot_name_foo'} - ) - ) - self.assertEqual( - save_response.status, state_pb2.SaveStateResponse.Status.OK - ) - mock_snapshot_stub.SaveSnapshot.return_value = ( - snapshot_service_pb2.SnapshotPackage(success=False, err=b'error') - ) - save_response = simulator.save_state( - request=state_pb2.SaveStateRequest( - args={'snapshot_name': 'snapshot_name_bar'} - ) - ) - self.assertEqual( - save_response.status, state_pb2.SaveStateResponse.Status.ERROR - ) - self.assertEqual(save_response.error_message, 'error') - - def test_send_touch(self): - config = config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - grpc_port=1234, tmp_dir=self.create_tempdir().full_path - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - ) - simulator = emulator_simulator.EmulatorSimulator(config) - - # The simulator should launch and not crash. - simulator.launch() - - simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None) - - simulator.send_touch([(123, 456, True, 0), (135, 246, True, 1)]) - simulator.send_touch([(1, 2, True, 0), (3, 4, True, 1)]) - simulator.send_touch([(321, 654, False, 0), (531, 642, False, 1)]) - - simulator._emulator_stub.sendTouch.assert_has_calls([ - mock.call( - emulator_controller_pb2.TouchEvent(touches=[{ - 'x': 123, - 'y': 456, - 'pressure': 1 - }, { - 'x': 135, - 'y': 246, - 'pressure': 1, - 'identifier': 1 - }])), - mock.call( - emulator_controller_pb2.TouchEvent(touches=[{ - 'x': 1, - 'y': 2, - 'pressure': 1 - }, { - 'x': 3, - 'y': 4, - 'pressure': 1, - 'identifier': 1 - }])), - mock.call( - emulator_controller_pb2.TouchEvent(touches=[{ - 'x': 321, - 'y': 654, - 'pressure': 0 - }, { - 'x': 531, - 'y': 642, - 'pressure': 0, - 'identifier': 1 - }])), - ]) - - def test_send_key(self): - config = config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - grpc_port=1234, tmp_dir=self.create_tempdir().full_path - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='/my/adb', - adb_server_port=5037, - ), - ) - simulator = emulator_simulator.EmulatorSimulator(config) - - # The simulator should launch and not crash. - simulator.launch() - - simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None) - - simulator.send_key(123, 'keydown') - simulator.send_key(321, 'keydown') - simulator.send_key(321, 'keyup') - simulator.send_key(123, 'keyup') - simulator.send_key(321, 'keypress') - simulator.send_key(123, 'keypress') - - simulator._emulator_stub.sendKey.assert_has_calls([ - mock.call( - emulator_controller_pb2.KeyboardEvent( - codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB, - eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType - .keydown, - keyCode=123, - )), - mock.call( - emulator_controller_pb2.KeyboardEvent( - codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB, - eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType - .keydown, - keyCode=321, - )), - mock.call( - emulator_controller_pb2.KeyboardEvent( - codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB, - eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType - .keyup, - keyCode=321, - )), - mock.call( - emulator_controller_pb2.KeyboardEvent( - codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB, - eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType - .keyup, - keyCode=123, - )), - mock.call( - emulator_controller_pb2.KeyboardEvent( - codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB, - eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType - .keypress, - keyCode=321, - )), - mock.call( - emulator_controller_pb2.KeyboardEvent( - codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB, - eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType - .keypress, - keyCode=123, - )) - ]) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/simulators/fake/__init__.py b/android_env/components/simulators/fake/__init__.py deleted file mode 100644 index 2f66bf75..00000000 --- a/android_env/components/simulators/fake/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/android_env/components/simulators/fake/fake_simulator.py b/android_env/components/simulators/fake/fake_simulator.py deleted file mode 100644 index 79996c8b..00000000 --- a/android_env/components/simulators/fake/fake_simulator.py +++ /dev/null @@ -1,138 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Fake Simulator for testing AndroidEnv infrastructure.""" - -import random -import threading -import time - -from absl import logging -from android_env.components import adb_controller -from android_env.components import config_classes -from android_env.components import log_stream -from android_env.components.simulators import base_simulator -import numpy as np - - -class FakeStream: - """This class simulates the logs coming from ADB.""" - - def __init__(self): - self._values = [ - '', - self._make_stdout('reward: 0.5'), - self._make_stdout('reward: 1.0'), - self._make_stdout('extra: my_extra [1.0]'), - self._make_stdout('episode end'), - ] - self._kill = False - self._lock = threading.Lock() - - def _make_stdout(self, data): - """Returns a valid log output with given data as message.""" - return f' 1553110400.424 5583 5658 D Tag: {data}' - - def kill(self): - self._kill = True - - def __iter__(self): - while True: - if self._kill: - return - else: - with self._lock: - next_value = random.choices( - self._values, weights=[0.49, 0.15, 0.15, 0.15, 0.01], k=1)[0] - time.sleep(0.1) - yield next_value - - -class FakeLogStream(log_stream.LogStream): - """FakeLogStream class that wraps a FakeStream.""" - - def __init__(self): - super().__init__(verbose=False) - self.stream = FakeStream() - - def _get_stream_output(self): - return self.stream - - def stop_stream(self): - self.stream.kill() - - -class FakeAdbController(adb_controller.AdbController): - """Fake adb controller for FakeSimulator.""" - - def execute_command( - self, - args: list[str], - timeout: float | None = None, - device_specific: bool = True, - ) -> bytes: - """Returns fake output for adb commands.""" - - del timeout, device_specific - - # Fake "service is ready" output. - if args[:3] == ['shell', 'service', 'check']: - return f'Service {args[-1]}: found'.encode('utf-8') - - # Fake dumpsys output for getting orientation. - if args == ['shell', 'dumpsys', 'input']: - return b' SurfaceOrientation: 0' - - # app_screen_checker: fake_task expects 'fake_activity'. - if args[:4] == ['shell', 'am', 'stack', 'list']: - return (b'taskId=0 fake_activity visible=true ' - b'topActivity=ComponentInfo{fake_activity}') - - return b'fake output' - - -class FakeSimulator(base_simulator.BaseSimulator): - """FakeSimulator class.""" - - def __init__(self, config: config_classes.FakeSimulatorConfig): - """FakeSimulator class that can replace EmulatorSimulator in AndroidEnv.""" - super().__init__(config) - self._screen_dimensions = np.array(config.screen_dimensions) - logging.info('Created FakeSimulator.') - - def get_logs(self) -> str: - return 'FakeSimulator: fake logs' - - def adb_device_name(self) -> str: - return 'fake_simulator' - - def create_adb_controller(self): - return FakeAdbController(config_classes.AdbControllerConfig()) - - def create_log_stream(self) -> log_stream.LogStream: - return FakeLogStream() - - def _launch_impl(self) -> None: - pass - - def send_touch(self, touches: list[tuple[int, int, bool, int]]) -> None: - del touches - - def send_key(self, keycode: np.int32, event_type: str) -> None: - del keycode, event_type - - def _get_screenshot_impl(self) -> np.ndarray: - return np.random.randint( - low=0, high=255, size=(*self._screen_dimensions, 3), dtype=np.uint8) diff --git a/android_env/components/simulators/fake/fake_simulator_test.py b/android_env/components/simulators/fake/fake_simulator_test.py deleted file mode 100644 index 5231433e..00000000 --- a/android_env/components/simulators/fake/fake_simulator_test.py +++ /dev/null @@ -1,114 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for fake_simulator.""" - -import re -from absl.testing import absltest -from android_env.components import config_classes -from android_env.components.simulators.fake import fake_simulator -import numpy as np - - -class FakeSimulatorTest(absltest.TestCase): - - def test_device_name(self): - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480)) - ) - self.assertEqual(simulator.adb_device_name(), 'fake_simulator') - - def test_launch_close(self): - # The simulator should launch and not crash. - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480)) - ) - simulator.launch() - # Closing the simulator should also not crash. - simulator.close() - - def test_get_screenshot(self): - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480)) - ) - simulator.launch() - - screenshot = simulator.get_screenshot() - np.testing.assert_equal(screenshot.shape, [320, 480, 3]) - np.testing.assert_equal(screenshot.dtype, np.uint8) - - def test_log_stream(self): - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480)) - ) - simulator.launch() - log_stream = simulator.create_log_stream() - # Start yielding lines from LogStream. - log_stream.resume_stream() - lines = [ - '', - ' 1553110400.424 5583 5658 D Tag: reward: 0.5', - ' 1553110400.424 5583 5658 D Tag: reward: 1.0', - ' 1553110400.424 5583 5658 D Tag: extra: my_extra [1.0]', - ' 1553110400.424 5583 5658 D Tag: episode end', - ] - for i, line in enumerate(log_stream.get_stream_output()): - self.assertIn(line, lines) - if i > 10: - break - - def test_adb_output(self): - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480)) - ) - simulator.launch() - adb_controller = simulator.create_adb_controller() - line = adb_controller.execute_command(['shell', 'dumpsys', 'input']) - line = line.decode('utf-8') - matches = re.match(r'\s+SurfaceOrientation:\s+(\d)', line) - self.assertIsNotNone(matches) - orientation = matches.group(1) - self.assertEqual(orientation, '0') - line = adb_controller.execute_command(['shell', 'service', 'check', 'foo']) - line = line.decode('utf-8') - self.assertEqual(line, 'Service foo: found') - line = adb_controller.execute_command(['shell', 'am', 'stack', 'list']) - line = line.decode('utf-8') - self.assertEqual(line, 'taskId=0 fake_activity visible=true ' - 'topActivity=ComponentInfo{fake_activity}') - - def test_send_touch(self): - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480)) - ) - simulator.launch() - simulator.send_touch([(0, 1, True, 0)]) - simulator.send_touch([(0, 1, False, 0)]) - # No assertions, we just want to ensure that `send_touch()` can be called - # without crashing anything. - - def test_send_key(self): - simulator = fake_simulator.FakeSimulator( - config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480)) - ) - simulator.launch() - simulator.send_key(np.int32(123), 'keydown') - simulator.send_key(np.int32(123), 'keyup') - simulator.send_key(np.int32(123), 'keypress') - # No assertions, we just want to ensure that `send_key()` can be called - # without crashing anything. - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/specs.py b/android_env/components/specs.py deleted file mode 100644 index d17aa7c0..00000000 --- a/android_env/components/specs.py +++ /dev/null @@ -1,138 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Base specs for AndroidEnv.""" - -from android_env.components import action_type -from android_env.proto import task_pb2 -from dm_env import specs -import numpy as np - - -_PROTO_DTYPE_TO_NUMPY_DTYPE = { - task_pb2.ArraySpec.DataType.FLOAT: np.float32, - task_pb2.ArraySpec.DataType.DOUBLE: np.float64, - task_pb2.ArraySpec.DataType.INT8: np.int8, - task_pb2.ArraySpec.DataType.INT16: np.int16, - task_pb2.ArraySpec.DataType.INT32: np.int32, - task_pb2.ArraySpec.DataType.INT64: np.int64, - task_pb2.ArraySpec.DataType.UINT8: np.uint8, - task_pb2.ArraySpec.DataType.UINT16: np.uint16, - task_pb2.ArraySpec.DataType.UINT32: np.uint32, - task_pb2.ArraySpec.DataType.UINT64: np.uint64, - task_pb2.ArraySpec.DataType.BOOL: np.bool_, - task_pb2.ArraySpec.DataType.STRING_U1: np.dtype(('U1')), - task_pb2.ArraySpec.DataType.STRING_U16: np.dtype((' dict[str, specs.Array]: - """Default action spec for AndroidEnv. - - Args: - num_fingers: Number of virtual fingers of the agent. - enable_key_events: Whether keyboard key events are enabled. - - Returns: - A dict of action specs, each item corresponding to a virtual finger. - action_type: An integer of type ActionType: TOUCH=0, LIFT=1, REPEAT=2 - touch_position: Position [x, y] of the touch action, where x, y are float - values between 0.0 and 1.0 corresponding to the relative position on the - screen. IGNORED when (action_type != ActionType.TOUCH). - keycode: code representing the desired key press in XKB format. See the - emulator_controller_pb2 for details. - action_type_i: Action type for additional fingers (i>1). - touch_position_i: Touch position for additional fingers (i>1). - """ - - num_actions = len(action_type.ActionType) if enable_key_events else 3 - - action_spec = { - 'action_type': - specs.DiscreteArray(num_values=num_actions, name='action_type'), - 'touch_position': - specs.BoundedArray( - shape=(2,), - dtype=np.float32, - minimum=[0.0, 0.0], - maximum=[1.0, 1.0], - name='touch_position'), - } - - for i in range(2, num_fingers + 1): - action_spec.update({ - f'action_type_{i}': - specs.DiscreteArray( - num_values=len(action_type.ActionType), - name=f'action_type_{i}'), - f'touch_position_{i}': - specs.BoundedArray( - shape=(2,), - dtype=np.float32, - minimum=[0.0, 0.0], - maximum=[1.0, 1.0], - name=f'touch_position_{i}'), - }) - - if enable_key_events: - action_spec['keycode'] = specs.DiscreteArray( - num_values=(1 << 16) - 1, name='keycode') - - return action_spec - - -def base_observation_spec(height: int, width: int) -> dict[str, specs.Array]: - """Default observation spec for AndroidEnv. - - Args: - height: Height of the device screen in pixels. - width: Width of the device screen in pixels. - - Returns: - pixels: Spec for the RGB screenshot of the device. Has shape (H, W, 3) - timedelta: Spec for time delta since the last observation (in microseconds). - The first timestep immediately after reset() will have this value set to - 0. - orientation: Spec for the latest orientation in a one-hot representation: - [1, 0, 0, 0]: PORTRAIT (0 degrees) - [0, 1, 0, 0]: LANDSCAPE (90 degrees clockwise) - [0, 0, 1, 0]: PORTRAIT (180 degrees) ("upside down") - [0, 0, 0, 1]: LANDSCAPE (270 degrees clockwise) - """ - - return { - 'pixels': - specs.BoundedArray( - shape=(height, width, 3), - dtype=np.uint8, - name='pixels', - minimum=0, - maximum=255), - 'timedelta': - specs.Array(shape=(), dtype=np.int64, name='timedelta'), - 'orientation': - specs.BoundedArray( - shape=np.array([4]), - dtype=np.uint8, - name='orientation', - minimum=0, - maximum=1), - } diff --git a/android_env/components/specs_test.py b/android_env/components/specs_test.py deleted file mode 100644 index 389eba5e..00000000 --- a/android_env/components/specs_test.py +++ /dev/null @@ -1,84 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for specs.py.""" - -from absl.testing import absltest -from absl.testing import parameterized -from android_env.components import specs -from android_env.proto import task_pb2 -from dm_env import specs as dm_env_specs -import numpy as np - - -class SpecsTest(parameterized.TestCase): - - def test_base_action_spec(self): - action_spec = specs.base_action_spec(num_fingers=1) - for spec in action_spec.values(): - self.assertIsInstance(spec, dm_env_specs.Array) - self.assertEqual(action_spec['action_type'].shape, ()) - self.assertEqual(action_spec['action_type'].dtype, np.int32) - self.assertEqual(action_spec['touch_position'].shape, (2,)) - self.assertEqual(action_spec['touch_position'].dtype, np.float32) - - def test_base_action_spec_with_key_events(self): - action_spec = specs.base_action_spec(num_fingers=1, enable_key_events=True) - for spec in action_spec.values(): - self.assertIsInstance(spec, dm_env_specs.Array) - self.assertEqual(action_spec['action_type'].shape, ()) - self.assertEqual(action_spec['action_type'].dtype, np.int32) - self.assertEqual(action_spec['touch_position'].shape, (2,)) - self.assertEqual(action_spec['touch_position'].dtype, np.float32) - self.assertEqual(action_spec['keycode'].shape, ()) - self.assertEqual(action_spec['keycode'].dtype, np.int32) - - def test_base_action_spec_multitouch(self): - action_spec = specs.base_action_spec(num_fingers=3) - self.assertLen(action_spec.keys(), 6) - for spec in action_spec.values(): - self.assertIsInstance(spec, dm_env_specs.Array) - self.assertEqual(action_spec['action_type'].shape, ()) - self.assertEqual(action_spec['action_type'].dtype, np.int32) - self.assertEqual(action_spec['touch_position'].shape, (2,)) - self.assertEqual(action_spec['touch_position'].dtype, np.float32) - self.assertEqual(action_spec['action_type_2'].shape, ()) - self.assertEqual(action_spec['action_type_2'].dtype, np.int32) - self.assertEqual(action_spec['touch_position_2'].shape, (2,)) - self.assertEqual(action_spec['touch_position_2'].dtype, np.float32) - self.assertEqual(action_spec['action_type_3'].shape, ()) - self.assertEqual(action_spec['action_type_3'].dtype, np.int32) - self.assertEqual(action_spec['touch_position_3'].shape, (2,)) - self.assertEqual(action_spec['touch_position_3'].dtype, np.float32) - - @parameterized.parameters( - (480, 320), - (100, 100), - (1440, 1960), - ) - def test_base_observation_spec(self, height, width): - observation_spec = specs.base_observation_spec(height, width) - for spec in observation_spec.values(): - self.assertIsInstance(spec, dm_env_specs.Array) - self.assertEqual(observation_spec['pixels'].shape, (height, width, 3)) - self.assertEqual(observation_spec['pixels'].dtype, np.uint8) - self.assertEqual(observation_spec['timedelta'].shape, ()) - self.assertEqual(observation_spec['timedelta'].dtype, np.int64) - self.assertEqual(observation_spec['orientation'].shape, (4,)) - self.assertEqual(observation_spec['orientation'].dtype, np.uint8) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/components/task_manager.py b/android_env/components/task_manager.py deleted file mode 100644 index 81718546..00000000 --- a/android_env/components/task_manager.py +++ /dev/null @@ -1,385 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""TaskManager handles all events and information related to the task.""" - -import ast -from collections.abc import Callable -import copy -import datetime -import json -import re -import threading -from typing import Any - -from absl import logging -from android_env.components import adb_call_parser as adb_call_parser_lib -from android_env.components import app_screen_checker -from android_env.components import config_classes -from android_env.components import dumpsys_thread -from android_env.components import log_stream as log_stream_lib -from android_env.components import logcat_thread -from android_env.components import setup_step_interpreter -from android_env.proto import task_pb2 -import dm_env -import numpy as np - - -class TaskManager: - """Handles all events and information related to the task.""" - - def __init__( - self, - task: task_pb2.Task, - config: config_classes.TaskManagerConfig | None = None, - ): - """Controls task-relevant events and information. - - Args: - task: A task proto defining the RL task. - config: Configuration for this instance. - """ - - self._task = task - self._config = config or config_classes.TaskManagerConfig() - self._lock = threading.Lock() - self._logcat_thread = None - self._dumpsys_thread = None - self._setup_step_interpreter = None - - # Initialize stats. - self._stats = { - 'episode_steps': 0, - 'reset_count_step_timeout': 0, - 'reset_count_user_exited': 0, - 'reset_count_episode_end': 0, - 'reset_count_max_duration_reached': 0, - 'restart_count_max_bad_states': 0, - 'task_updates': 0, - } - - # Initialize internal state - self._task_start_time = None - self._bad_state_counter = 0 - self._is_bad_episode = False - - self._latest_values = { - 'reward': 0.0, - 'score': 0.0, - 'extra': {}, - 'episode_end': False, - } - - logging.info('Task config: %s', self._task) - - def stats(self) -> dict[str, Any]: - """Returns a dictionary of stats. - - This method is expected to be called after setup_task() has been called. - """ - output = copy.deepcopy(self._stats) - if self._setup_step_interpreter is not None: - output.update(self._setup_step_interpreter.stats()) - return output - - def setup_task(self) -> None: - """Performs one-off task setup..""" - self._setup_step_interpreter.interpret(self._task.setup_steps) - - def stop(self) -> None: - """Suspends task processing.""" - self._stop_logcat_thread() - - def start( - self, - adb_call_parser_factory: Callable[[], adb_call_parser_lib.AdbCallParser], - log_stream: log_stream_lib.LogStream) -> None: - """Starts task processing.""" - - self._start_logcat_thread(log_stream=log_stream) - self._logcat_thread.resume() - self._start_dumpsys_thread(adb_call_parser_factory()) - self._start_setup_step_interpreter(adb_call_parser_factory()) - - def reset_task(self) -> None: - """Resets a task for a new run.""" - - self._logcat_thread.pause() - self._setup_step_interpreter.interpret(self._task.reset_steps) - self._logcat_thread.resume() - - # Reset some other variables. - if not self._is_bad_episode: - self._bad_state_counter = 0 - self._is_bad_episode = False - - self._task_start_time = datetime.datetime.now() - with self._lock: - self._latest_values = { - 'reward': 0.0, - 'score': 0.0, - 'extra': {}, - 'episode_end': False, - } - - def rl_reset(self, observation: dict[str, Any]) -> dm_env.TimeStep: - """Performs one RL step.""" - - self._stats['episode_steps'] = 0 - - self._logcat_thread.line_ready().wait() - with self._lock: - extras = self._get_current_extras() - - observation['extras'] = extras - - return dm_env.TimeStep( - step_type=dm_env.StepType.FIRST, - reward=0.0, - discount=0.0, - observation=observation) - - def rl_step(self, observation: dict[str, Any]) -> dm_env.TimeStep: - """Performs one RL step.""" - - self._stats['episode_steps'] += 1 - - self._logcat_thread.line_ready().wait() - with self._lock: - reward = self._get_current_reward() - extras = self._get_current_extras() - transition_fn = self._determine_transition_fn() - - observation['extras'] = extras - - return transition_fn(reward=reward, observation=observation) - - def _get_current_reward(self) -> float: - """Returns total reward accumulated since the last step.""" - reward = self._latest_values['reward'] - self._latest_values['reward'] = 0.0 - return reward - - def _get_current_extras(self) -> dict[str, Any]: - """Returns task extras accumulated since the last step.""" - extras = {} - for name, values in self._latest_values['extra'].items(): - extras[name] = np.stack(values) - self._latest_values['extra'] = {} - return extras - - def _determine_transition_fn(self) -> Callable[..., dm_env.TimeStep]: - """Determines the type of RL transition will be used.""" - - # Check if user existed the task - if self._dumpsys_thread.check_user_exited(): - self._increment_bad_state() - self._stats['reset_count_user_exited'] += 1 - logging.warning('User exited the task. Truncating the episode.') - logging.info('************* END OF EPISODE *************') - return dm_env.truncation - - # Check if episode has ended - if self._latest_values['episode_end']: - self._stats['reset_count_episode_end'] += 1 - logging.info('End of episode from logcat! Ending episode.') - return dm_env.termination - - # Check if step limit or time limit has been reached - if self._task.max_episode_steps > 0: - if self._stats['episode_steps'] > self._task.max_episode_steps: - self._stats['reset_count_max_duration_reached'] += 1 - logging.info('Maximum task duration (%r steps) reached. ' - 'Truncating the episode.', self._task.max_episode_steps) - return dm_env.truncation - - if self._task.max_episode_sec > 0.0: - task_duration = datetime.datetime.now() - self._task_start_time - max_episode_sec = self._task.max_episode_sec - if task_duration > datetime.timedelta(seconds=int(max_episode_sec)): - self._stats['reset_count_max_duration_reached'] += 1 - logging.info('Maximum task duration (%r sec) reached. ' - 'Truncating the episode.', max_episode_sec) - return dm_env.truncation - - return dm_env.transition - - def _start_setup_step_interpreter( - self, adb_call_parser: adb_call_parser_lib.AdbCallParser): - self._setup_step_interpreter = setup_step_interpreter.SetupStepInterpreter( - adb_call_parser=adb_call_parser) - - def _start_logcat_thread(self, log_stream: log_stream_lib.LogStream): - log_stream.set_log_filters(list(self._task.log_parsing_config.filters)) - self._logcat_thread = logcat_thread.LogcatThread(log_stream=log_stream) - - for event_listener in self._logcat_listeners(): - self._logcat_thread.add_event_listener(event_listener) - - def _start_dumpsys_thread(self, - adb_call_parser: adb_call_parser_lib.AdbCallParser): - self._dumpsys_thread = dumpsys_thread.DumpsysThread( - app_screen_checker=app_screen_checker.AppScreenChecker( - adb_call_parser=adb_call_parser, - expected_app_screen=self._task.expected_app_screen, - ), - check_frequency=self._config.dumpsys_check_frequency, - max_failed_current_activity=self._config.max_failed_current_activity, - ) - - def _stop_logcat_thread(self): - if self._logcat_thread is not None: - self._logcat_thread.kill() - self._logcat_thread = None - - def _increment_bad_state(self) -> None: - """Increments the bad state counter. - - Bad states are errors that shouldn't happen and that trigger an - episode reset. If enough bad states have been seen consecutively, - we restart the simulation in the hope of returning the simulation - to a good state. - """ - logging.warning('Bad state detected.') - if self._config.max_bad_states: - self._is_bad_episode = True - self._bad_state_counter += 1 - logging.warning('Bad state counter: %d.', self._bad_state_counter) - if self._bad_state_counter >= self._config.max_bad_states: - logging.error('Too many consecutive bad states. Restarting simulator.') - self._stats['restart_count_max_bad_states'] += 1 - self._should_restart = True - else: - logging.warning('Max bad states not set, bad states will be ignored.') - - def _logcat_listeners(self): - """Creates list of EventListeners for logcat thread.""" - - # Defaults to 'a^' since that regex matches no string by definition. - regexps = self._task.log_parsing_config.log_regexps - listeners = [] - - # Reward listeners - def _reward_handler(event, match): - del event - reward = float(match.group(1)) - with self._lock: - self._latest_values['reward'] += reward - - for regexp in regexps.reward: - listeners.append(logcat_thread.EventListener( - regexp=re.compile(regexp or 'a^'), - handler_fn=_reward_handler)) - - # RewardEvent listeners - for reward_event in regexps.reward_event: - - def get_reward_event_handler(reward): - def _reward_event_handler(event, match): - del event, match - with self._lock: - self._latest_values['reward'] += reward - return _reward_event_handler - - listeners.append(logcat_thread.EventListener( - regexp=re.compile(reward_event.event or 'a^'), - handler_fn=get_reward_event_handler(reward_event.reward))) - - # Score listener - def _score_handler(event, match): - del event - current_score = float(match.group(1)) - with self._lock: - current_reward = current_score - self._latest_values['score'] - self._latest_values['score'] = current_score - self._latest_values['reward'] += current_reward - - listeners.append(logcat_thread.EventListener( - regexp=re.compile(regexps.score or 'a^'), - handler_fn=_score_handler)) - - # Episode end listeners - def _episode_end_handler(event, match): - del event, match - with self._lock: - self._latest_values['episode_end'] = True - - for regexp in regexps.episode_end: - listeners.append(logcat_thread.EventListener( - regexp=re.compile(regexp or 'a^'), - handler_fn=_episode_end_handler)) - - # Extra listeners - def _extras_handler(event, match): - del event - extra_name = match.group('name') - extra = match.group('extra') - if extra: - try: - extra = ast.literal_eval(extra) - except ( - ValueError, - TypeError, - SyntaxError, - MemoryError, - RecursionError, - ): - logging.exception('Could not parse extra: %s', extra) - # Don't try to process the extra as text; that would probably crash. - return - else: - # No extra value provided for boolean extra. Setting value to True. - extra = 1 - _process_extra(extra_name, extra) - - for regexp in regexps.extra: - listeners.append(logcat_thread.EventListener( - regexp=re.compile(regexp or 'a^'), - handler_fn=_extras_handler)) - - # JSON extra listeners - def _json_extras_handler(event, match): - del event - extra_data = match.group('json_extra') - try: - extra = dict(json.loads(extra_data)) - except ValueError: - logging.error('JSON string could not be parsed: %s', extra_data) - return - for extra_name, extra_value in extra.items(): - _process_extra(extra_name, extra_value) - - for regexp in regexps.json_extra: - listeners.append(logcat_thread.EventListener( - regexp=re.compile(regexp or 'a^'), - handler_fn=_json_extras_handler)) - - def _process_extra(extra_name, extra): - extra = np.array(extra) - with self._lock: - latest_extras = self._latest_values['extra'] - if extra_name in latest_extras: - # If latest extra is not flushed, append. - if ( - len(latest_extras[extra_name]) - >= self._config.extras_max_buffer_size - ): - latest_extras[extra_name].pop(0) - latest_extras[extra_name].append(extra) - else: - latest_extras[extra_name] = [extra] - self._latest_values['extra'] = latest_extras - - return listeners diff --git a/android_env/components/task_manager_test.py b/android_env/components/task_manager_test.py deleted file mode 100644 index 471ecfd9..00000000 --- a/android_env/components/task_manager_test.py +++ /dev/null @@ -1,415 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.components.task_manager.py.""" - -import json -from unittest import mock - -from absl.testing import absltest -from android_env.components import adb_call_parser as adb_call_parser_lib -from android_env.components import dumpsys_thread -from android_env.components import log_stream -from android_env.components import logcat_thread -from android_env.components import setup_step_interpreter -from android_env.components import task_manager -from android_env.proto import task_pb2 -import numpy as np - - -class TaskManagerTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self.addCleanup(mock.patch.stopall) # Disable previous patches. - - self._setup_step_interpreter = mock.create_autospec( - setup_step_interpreter.SetupStepInterpreter) - self._dumpsys_thread = mock.create_autospec(dumpsys_thread.DumpsysThread) - self._logcat_thread = mock.create_autospec(logcat_thread.LogcatThread) - self._log_stream = mock.create_autospec(log_stream.LogStream) - - mock.patch.object( - setup_step_interpreter, - 'SetupStepInterpreter', - return_value=self._setup_step_interpreter).start() - mock.patch.object( - dumpsys_thread, 'DumpsysThread', - return_value=self._dumpsys_thread).start() - mock.patch.object( - logcat_thread, 'LogcatThread', - return_value=self._logcat_thread).start() - mock.patch.object( - log_stream, 'LogStream', - return_value=self._log_stream).start() - - def test_start(self): - task_mgr = task_manager.TaskManager(task=task_pb2.Task()) - adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) - task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) - self.assertIsNotNone(task_mgr._logcat_thread) - self.assertIsNotNone(task_mgr._dumpsys_thread) - self.assertIsNotNone(task_mgr._setup_step_interpreter) - - def test_setup_task(self): - task_mgr = task_manager.TaskManager(task=task_pb2.Task()) - adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) - task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) - task_mgr.setup_task() - self._setup_step_interpreter.interpret.assert_called_once() - - def test_step_count(self): - task_mgr = task_manager.TaskManager(task=task_pb2.Task()) - adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) - task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) - task_mgr.setup_task() - task_mgr.rl_reset(observation={}) - self.assertEqual(task_mgr.stats()['episode_steps'], 0) - task_mgr.rl_step(observation={}) - self.assertEqual(task_mgr.stats()['episode_steps'], 1) - task_mgr.rl_step(observation={}) - self.assertEqual(task_mgr.stats()['episode_steps'], 2) - task_mgr.rl_reset(observation={}) - self.assertEqual(task_mgr.stats()['episode_steps'], 0) - - def test_get_current_reward(self): - # Replace `LogcatThread.add_event_listener` with one that simply calls `fn` - # right away. - def my_add_ev_listener(event_listener: logcat_thread.EventListener): - # Check that the event matches what's expected. - match = event_listener.regexp.match('Reward: 123.0') - if match is None: # Ignore events that are not rewards. - return - - event_listener.handler_fn(event_listener.regexp, match) - - task = task_pb2.Task() - task.log_parsing_config.log_regexps.reward.extend([ - '^[Rr]eward: ([-+]?[0-9]*\\.?[0-9]*)$' - ]) - task_mgr = task_manager.TaskManager(task=task) - self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener - adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) - task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) - task_mgr.setup_task() - timestep = task_mgr.rl_step( - observation={ - 'pixels': np.array([1, 2, 3]), - }) - self.assertEqual(timestep.reward, 123.0) - np.testing.assert_equal(timestep.observation['pixels'], np.array([1, 2, 3])) - - def test_reward_event(self): - # Replace `LogcatThread.add_event_listener` with one that simply calls `fn` - # right away. - def my_add_ev_listener(event_listener: logcat_thread.EventListener): - # Check that the event matches what's expected. - match_1 = event_listener.regexp.match('foo_1') - match_2 = event_listener.regexp.match('foo_2') - match_3 = event_listener.regexp.match('Reward: 2.0') - if match_1: - event_listener.handler_fn(event_listener.regexp, match_1) - if match_2: - event_listener.handler_fn(event_listener.regexp, match_2) - if match_3: - event_listener.handler_fn(event_listener.regexp, match_3) - - task = task_pb2.Task() - reward_event_1 = task_pb2.LogParsingConfig.LogRegexps.RewardEvent( - event='foo_1', reward=5.0) - reward_event_2 = task_pb2.LogParsingConfig.LogRegexps.RewardEvent( - event='foo_2', reward=-1.0) - task.log_parsing_config.log_regexps.reward_event.extend( - [reward_event_1, reward_event_2]) - task.log_parsing_config.log_regexps.reward.extend( - ['^[Rr]eward: ([-+]?[0-9]*\\.?[0-9]*)$']) - task_mgr = task_manager.TaskManager(task=task) - self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener - adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) - task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) - task_mgr.setup_task() - timestep = task_mgr.rl_step( - observation={ - 'pixels': np.array([1, 2, 3]), - }) - self.assertEqual(timestep.reward, 6.0) - - def test_get_current_reward_via_score(self): - # Replace `LogcatThread.add_event_listener` with one that simply calls `fn` - # right away. - def my_add_ev_listener(event_listener: logcat_thread.EventListener): - # Check that the event matches what's expected. - event = event_listener.regexp - match = event.match('score: 200.0') - if match is None: # Ignore events that are not scores. - return - - event_listener.handler_fn(event, match) - - # Scores are accumulated by their differences, so a subsequent lower score - # means that the final reward decreases. - match = event.match('score: 185') - event_listener.handler_fn(event, match) - - task = task_pb2.Task() - task.log_parsing_config.log_regexps.score = ( - '^score: ([-+]?[0-9]*\\.?[0-9]*)$') - task_mgr = task_manager.TaskManager(task=task) - self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener - adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) - task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) - task_mgr.setup_task() - timestep = task_mgr.rl_step( - observation={ - 'pixels': np.array([1, 2, 3]), - }) - self.assertEqual(timestep.reward, 185.0) - - def test_get_current_extras(self): - # Replace `LogcatThread.add_event_listener` with one that simply calls `fn` - # right away. - def my_add_ev_listener(event_listener: logcat_thread.EventListener): - # Check that the event matches what's expected. - event = event_listener.regexp - match = event.match('extra: some_extra [1, 2]') - if match is None: # Ignore events that are not extras. - return - - # Emit events. - fn = event_listener.handler_fn - fn(event, event.match('extra: an_extra [1, 2, 3]')) - fn(event, event.match('extra: an_extra [4, 5, 6]')) - fn(event, event.match('extra: another_extra 0.5')) - fn(event, event.match('extra: multi_dimension_extra [[9,8,7],[6,5,4]]')) - fn(event, event.match('extra: boolean_extra')) - - # Setup the task and trigger the listener. - task = task_pb2.Task() - task.log_parsing_config.log_regexps.extra.extend([ - '^extra: (?P[^ ]*)[ ]?(?P.*)$' - ]) - task_mgr = task_manager.TaskManager(task=task) - self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener - adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) - task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) - task_mgr.setup_task() - - timestep = task_mgr.rl_step( - observation={ - 'pixels': np.array([1, 2, 3]), - }) - - # Check expectations. - self.assertIn('extras', timestep.observation) - extras = timestep.observation['extras'] - np.testing.assert_almost_equal([[1, 2, 3], [4, 5, 6]], - extras.get('an_extra')) - np.testing.assert_almost_equal([0.5], extras.get('another_extra')) - np.testing.assert_almost_equal([[[9, 8, 7], [6, 5, 4]]], - extras.get('multi_dimension_extra')) - np.testing.assert_equal([1], extras.get('boolean_extra')) - - def test_get_current_extras_json_format(self): - # Replace `LogcatThread.add_event_listener` with one that simply calls `fn` - # right away. - def my_add_ev_listener(event_listener: logcat_thread.EventListener): - # Check that the event matches what's expected. - event = event_listener.regexp - match = event.match('json_extra: {}') - if match is None: # Ignore events that are not extras. - return - - # Emit events. - extra = { - 'extra_scalar': 0, - 'extra_list': [1, 2, 3, 4], - 'extra_dict': { - 'foo': 'bar' - }, - 'extra_string': 'a_string' - } - extra_update = {'extra_string': 'a_new_string', 'extra_float': 0.6} - fn = event_listener.handler_fn - fn(event, event.match(f'json_extra: {json.dumps(extra)}')) - fn(event, event.match(f'json_extra: {json.dumps(extra_update)}')) - - # Setup the task and trigger the listener. - task = task_pb2.Task() - task.log_parsing_config.log_regexps.json_extra.extend([ - '^json_extra: (?P.*)$' - ]) - task_mgr = task_manager.TaskManager(task=task) - self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener - adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) - task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) - task_mgr.setup_task() - - timestep = task_mgr.rl_step( - observation={ - 'pixels': np.array([1, 2, 3]), - }) - - # Check expectations. - self.assertIn('extras', timestep.observation) - extras = timestep.observation['extras'] - expected_extra = { - 'extra_scalar': [0], - 'extra_list': [[1, 2, 3, 4]], - 'extra_dict': [{ - 'foo': 'bar' - }], - 'extra_string': ['a_string', 'a_new_string'], - 'extra_float': [0.6] - } - np.testing.assert_almost_equal( - expected_extra.get('extra_scalar'), extras.get('extra_scalar')) - np.testing.assert_almost_equal( - expected_extra.get('extra_list'), extras.get('extra_list')) - np.testing.assert_equal( - expected_extra.get('extra_string'), extras.get('extra_string')) - np.testing.assert_almost_equal( - expected_extra.get('extra_float'), extras.get('extra_float')) - np.testing.assert_equal( - expected_extra.get('extra_dict'), extras.get('extra_dict')) - - def test_get_current_extras_failed_to_parse(self): - # Replace `LogcatThread.add_event_listener` with one that simply calls `fn` - # right away. - def my_add_ev_listener(event_listener: logcat_thread.EventListener): - # Check that the event matches what's expected. - event = event_listener.regexp - match = event.match('extra: some_extra [1, 2]') - if match is None: # Ignore events that are not extras. - return - - # Emit events. - fn = event_listener.handler_fn - fn(event, event.match('extra: extra_with_malformed_1 [1]')) - fn(event, event.match('extra: extra_with_malformed_1 [\'this is \\ bad]')) - fn(event, event.match('extra: extra_with_malformed_1 [2]')) - fn(event, event.match('extra: extra_with_malformed_2 [\'this is bad]')) - fn(event, event.match('extra: extra_with_malformed_2 [1]')) - fn(event, event.match('extra: extra_malformed_only [_very_bad_news]')) - - # Setup the task and trigger the listener. - task = task_pb2.Task() - task.log_parsing_config.log_regexps.extra.extend([ - '^extra: (?P[^ ]*)[ ]?(?P.*)$' - ]) - task_mgr = task_manager.TaskManager(task=task) - self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener - adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) - task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) - task_mgr.setup_task() - - timestep = task_mgr.rl_step( - observation={ - 'pixels': np.array([1, 2, 3]), - }) - - # Check expectations. - self.assertIn('extras', timestep.observation) - extras = timestep.observation['extras'] - np.testing.assert_almost_equal(extras.get('extra_with_malformed_1'), - [[1], [2]]) - np.testing.assert_almost_equal(extras.get('extra_with_malformed_2'), [[1]]) - self.assertNotIn('extra_malformed_only', extras) - - def test_multi_log_regexp(self): - # Replace `LogcatThread.add_event_listener` with one that simply calls `fn` - # right away. - def my_add_ev_listener(event_listener: logcat_thread.EventListener): - # Check that the event matches what's expected. - match = event_listener.regexp.match('Reward_2: 123.0') - if match is None: # Ignore events that are not rewards. - return - - event_listener.handler_fn(event_listener.regexp, match) - - task = task_pb2.Task() - task.log_parsing_config.log_regexps.reward.extend([ - '^[Rr]eward_1: ([-+]?[0-9]*\\.?[0-9]*)$', - '^[Rr]eward_2: ([-+]?[0-9]*\\.?[0-9]*)$' - ]) - task_mgr = task_manager.TaskManager(task=task) - self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener - adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) - task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) - task_mgr.setup_task() - timestep = task_mgr.rl_step( - observation={ - 'pixels': np.array([1, 2, 3]), - }) - self.assertEqual(timestep.reward, 123.0) - - def test_multi_reward_regexp(self): - # Replace `LogcatThread.add_event_listener` with one that simply calls `fn` - # right away.' - - def my_add_ev_listener(event_listener: logcat_thread.EventListener): - # Check that the event matches what's expected. - match_1 = event_listener.regexp.match('Reward_1: 5.0') - match_2 = event_listener.regexp.match('Reward_2: 10.0') - - if match_1: - event_listener.handler_fn(event_listener.regexp, match_1) - - if match_2: - event_listener.handler_fn(event_listener.regexp, match_2) - - task = task_pb2.Task() - task.log_parsing_config.log_regexps.reward.extend([ - '^[Rr]eward_1: ([-+]?[0-9]*\\.?[0-9]*)$', - '^[Rr]eward_2: ([-+]?[0-9]*\\.?[0-9]*)$', - ]) - task_mgr = task_manager.TaskManager(task=task) - self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener - adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) - task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) - task_mgr.setup_task() - timestep = task_mgr.rl_step( - observation={ - 'pixels': np.array([1, 2, 3]), - }) - self.assertEqual(timestep.reward, 15.0) - - def test_determine_transition_fn(self): - # Replace `LogcatThread.add_event_listener` with one that simply calls `fn` - # right away. - def my_add_ev_listener(event_listener: logcat_thread.EventListener): - # Check that the event matches what's expected. - event = event_listener.regexp - match = event.match('I am done!') - if match is None: # Ignore events that are not episode end. - return - - event_listener.handler_fn(event, match) - - task = task_pb2.Task() - task.log_parsing_config.log_regexps.episode_end.extend(['I am done!']) - task_mgr = task_manager.TaskManager(task=task) - self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener - adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) - task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) - task_mgr.setup_task() - timestep = task_mgr.rl_step( - observation={ - 'pixels': np.array([1, 2, 3]), - }) - self.assertTrue(timestep.last()) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/env_interface.py b/android_env/env_interface.py deleted file mode 100644 index 45e8faf3..00000000 --- a/android_env/env_interface.py +++ /dev/null @@ -1,109 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Abstract AndroidEnv interface. - -AndroidEnv is a standard dm_env.Environment instance, but it also offers a few -extra methods that clients may use for extended functionality. -""" - -import abc -from typing import Any - -from android_env.proto import adb_pb2 -from android_env.proto import state_pb2 -import dm_env -import numpy as np - - -class AndroidEnvInterface(dm_env.Environment, metaclass=abc.ABCMeta): - """Pure virtual interface for AndroidEnv implementations.""" - - # Methods required by dm_env.Environment. - - @abc.abstractmethod - def action_spec(self) -> dict[str, dm_env.specs.Array]: - """Returns the action specification.""" - - @abc.abstractmethod - def observation_spec(self) -> dict[str, dm_env.specs.Array]: - """Returns the observation specification.""" - - @abc.abstractmethod - def reset(self) -> dm_env.TimeStep: - """Resets the current episode.""" - - @abc.abstractmethod - def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep: - """Executes `action` and returns a `TimeStep`.""" - - @abc.abstractmethod - def close(self) -> None: - """Frees up resources.""" - - # Extensions provided by AndroidEnv. - - def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]: - """Returns extra info provided by tasks.""" - - return {} - - @property - def raw_action(self): - """Returns the latest action.""" - - @property - def raw_observation(self): - """Returns the latest observation.""" - - def stats(self) -> dict[str, Any]: - """Returns information generated inside the implementation.""" - - return {} - - def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse: - """Executes `call` and returns its response.""" - - return adb_pb2.AdbResponse() - - def load_state( - self, request: state_pb2.LoadStateRequest - ) -> state_pb2.LoadStateResponse: - """Loads a state. - - Args: - request: A `LoadStateRequest` containing any parameters necessary to - specify how/what state to load. - - Returns: - A `LoadStateResponse` containing the status, error message (if - applicable), and any other relevant information. - """ - raise NotImplementedError('This environment does not support loading state') - - def save_state( - self, request: state_pb2.SaveStateRequest - ) -> state_pb2.SaveStateResponse: - """Saves a state. - - Args: - request: A `SaveStateRequest` containing any parameters necessary to - specify how/what state to save. - - Returns: - A `SaveStateResponse` containing the status, error message (if - applicable), and any other relevant information. - """ - raise NotImplementedError('This environment does not support saving state') diff --git a/android_env/environment.py b/android_env/environment.py deleted file mode 100644 index 85638e68..00000000 --- a/android_env/environment.py +++ /dev/null @@ -1,190 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Android environment implementation.""" - -from typing import Any - -from absl import logging -from android_env import env_interface -from android_env.components import adb_call_parser -from android_env.components import coordinator as coordinator_lib -from android_env.components import task_manager as task_manager_lib -from android_env.components.simulators import base_simulator -from android_env.proto import adb_pb2 -from android_env.proto import state_pb2 -import dm_env -import numpy as np - - -class AndroidEnv(env_interface.AndroidEnvInterface): - """An RL environment that interacts with Android apps.""" - - def __init__( - self, - simulator: base_simulator.BaseSimulator, - coordinator: coordinator_lib.Coordinator, - task_manager: task_manager_lib.TaskManager, - ): - """Initializes the state of this AndroidEnv object.""" - - self._simulator = simulator - self._coordinator = coordinator - self._task_manager = task_manager - self._latest_action = {} - self._latest_observation = {} - self._latest_extras = {} - self._reset_next_step = True - self._is_closed = False - - logging.info('Action spec: %s', self.action_spec()) - logging.info('Observation spec: %s', self.observation_spec()) - - def __del__(self) -> None: - self.close() - - # Methods required by dm_env.Environment. - - def action_spec(self) -> dict[str, dm_env.specs.Array]: - return self._coordinator.action_spec() - - def observation_spec(self) -> dict[str, dm_env.specs.Array]: - return self._coordinator.observation_spec() - - def reset(self) -> dm_env.TimeStep: - """Resets the environment for a new RL episode.""" - - logging.info('Resetting AndroidEnv...') - - # Execute a reset. Timestep will be of type FIRST. - timestep = self._coordinator.rl_reset() - - # Process relevant information. - if timestep.observation is not None: - self._latest_extras = timestep.observation.pop('extras') - self._latest_observation = timestep.observation.copy() - else: - # If the observation is None, we return the latest observation again. - timestep = timestep._replace(observation=self._latest_observation.copy()) - - self._latest_action = {} - self._reset_next_step = False - - logging.info('Done resetting AndroidEnv.') - logging.info('************* NEW EPISODE *************') - - return timestep - - def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep: - """Takes a step in the environment.""" - - # Check if it's time to reset the episode. - if self._reset_next_step: - return self.reset() - - # Execute selected action. - timestep = self._coordinator.rl_step(action) - - # Process relevant information. - if timestep.observation is not None: - self._latest_extras = timestep.observation.pop('extras') - self._latest_observation = timestep.observation.copy() - else: - # If the observation is None, we return the latest observation again. - timestep = timestep._replace(observation=self._latest_observation.copy()) - - self._latest_action = action.copy() - - if timestep.last(): - self._reset_next_step = True - logging.info('************* END OF EPISODE *************') - - return timestep - - def close(self) -> None: - """Cleans up running processes, threads and local files.""" - if not self._is_closed: - logging.info('Cleaning up AndroidEnv...') - if hasattr(self, '_coordinator'): - self._coordinator.close() - logging.info('Done cleaning up AndroidEnv.') - self._is_closed = True - - # Extensions provided by AndroidEnv. - - def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]: - """Returns latest task extras.""" - - task_extras = {} # Build a copy to avoid reusing objects. - for k, spec in self._latest_extras.items(): - extra_values = spec.astype(spec.dtype) - task_extras[k] = extra_values[-1] if latest_only else extra_values - return task_extras - - @property - def raw_action(self): - return self._latest_action.copy() - - @property - def raw_observation(self): - return self._latest_observation.copy() - - def stats(self) -> dict[str, Any]: - coordinator_stats = self._coordinator.stats() - task_manager_stats = self._task_manager.stats() - return coordinator_stats | task_manager_stats - - def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse: - return self._coordinator.execute_adb_call(call) - - def load_state( - self, request: state_pb2.LoadStateRequest - ) -> state_pb2.LoadStateResponse: - """Loads a state. - - Args: - request: A `LoadStateRequest` containing any parameters necessary to - specify how/what state to load. - - Returns: - A `LoadStateResponse` containing the status, error message (if - applicable), and any other relevant information. - """ - - self._task_manager.stop() - response = self._simulator.load_state(request) - self._task_manager.start( - adb_call_parser_factory=lambda: adb_call_parser.AdbCallParser( - self._simulator.create_adb_controller() - ), - log_stream=self._simulator.create_log_stream(), - ) - return response - - def save_state( - self, request: state_pb2.SaveStateRequest - ) -> state_pb2.SaveStateResponse: - """Saves a state. - - Args: - request: A `SaveStateRequest` containing any parameters necessary to - specify how/what state to save. - - Returns: - A `SaveStateResponse` containing the status, error message (if - applicable), and any other relevant information. - """ - - return self._simulator.save_state(request) diff --git a/android_env/environment_test.py b/android_env/environment_test.py deleted file mode 100644 index b74c5989..00000000 --- a/android_env/environment_test.py +++ /dev/null @@ -1,273 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for AndroidEnv.""" - -from unittest import mock - -from absl.testing import absltest -from android_env import environment -from android_env.components import config_classes -from android_env.components import coordinator as coordinator_lib -from android_env.components import task_manager as task_manager_lib -from android_env.components.simulators import base_simulator -from android_env.components.simulators.fake import fake_simulator -from android_env.proto import adb_pb2 -from android_env.proto import state_pb2 -import dm_env -import numpy as np - - -def _create_mock_coordinator() -> coordinator_lib.Coordinator: - coordinator = mock.create_autospec(coordinator_lib.Coordinator) - coordinator.action_spec.return_value = { - 'action_type': - dm_env.specs.DiscreteArray(num_values=3), - 'touch_position': - dm_env.specs.BoundedArray( - shape=(2,), dtype=np.float32, minimum=0.0, maximum=1.0), - } - coordinator.observation_spec.return_value = { - 'pixels': dm_env.specs.Array(shape=(123, 456, 3), dtype=np.uint8), - 'timedelta': dm_env.specs.Array(shape=(), dtype=np.int64), - 'orientation': dm_env.specs.Array(shape=(4,), dtype=np.uint8), - } - return coordinator - - -def _create_fake_simulator() -> fake_simulator.FakeSimulator: - return fake_simulator.FakeSimulator( - config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456)) - ) - - -class AndroidEnvTest(absltest.TestCase): - - def test_specs(self): - simulator = _create_fake_simulator() - coordinator = _create_mock_coordinator() - task_manager = mock.create_autospec(task_manager_lib.TaskManager) - env = environment.AndroidEnv( - simulator=simulator, coordinator=coordinator, task_manager=task_manager - ) - - # Check action spec. - self.assertNotEmpty(env.action_spec()) - self.assertIn('action_type', env.action_spec()) - self.assertIsInstance(env.action_spec()['action_type'], - dm_env.specs.DiscreteArray) - self.assertIn('touch_position', env.action_spec()) - self.assertIsInstance(env.action_spec()['touch_position'], - dm_env.specs.BoundedArray) - - # Check observation spec. - self.assertNotEmpty(env.observation_spec()) - self.assertIn('pixels', env.observation_spec()) - self.assertIsInstance(env.observation_spec()['pixels'], dm_env.specs.Array) - # The `pixels` entry in the observation spec should match the screen size of - # the simulator with three color channels (RGB). - self.assertEqual(env.observation_spec()['pixels'].shape, (123, 456, 3)) - self.assertIn('timedelta', env.observation_spec()) - self.assertIsInstance(env.observation_spec()['timedelta'], - dm_env.specs.Array) - # The `timedelta` should be a scalar. - self.assertEqual(env.observation_spec()['timedelta'].shape, ()) - self.assertIn('orientation', env.observation_spec()) - # The `orientation` should be a one-hot vector with four dimensions. - self.assertIsInstance(env.observation_spec()['orientation'], - dm_env.specs.Array) - self.assertEqual(env.observation_spec()['orientation'].shape, (4,)) - - def test_reset_and_step(self): - simulator = _create_fake_simulator() - coordinator = _create_mock_coordinator() - task_manager = mock.create_autospec(task_manager_lib.TaskManager) - coordinator.action_spec.return_value = { - 'action_type': - dm_env.specs.DiscreteArray(num_values=3), - 'touch_position': - dm_env.specs.BoundedArray( - shape=(2,), dtype=np.float32, minimum=0.0, maximum=1.0), - } - coordinator.observation_spec.return_value = { - 'pixels': dm_env.specs.Array(shape=(123, 456, 3), dtype=np.uint8), - 'timedelta': dm_env.specs.Array(shape=(), dtype=np.int64), - 'orientation': dm_env.specs.Array(shape=(4,), dtype=np.uint8), - } - env = environment.AndroidEnv( - simulator=simulator, coordinator=coordinator, task_manager=task_manager - ) - coordinator.rl_reset.return_value = dm_env.TimeStep( - step_type=dm_env.StepType.FIRST, - reward=0.0, - discount=0.0, - observation={ - 'pixels': np.random.rand(987, 654, 3), - 'timedelta': 123456, - 'orientation': np.array((1, 0, 0, 0)), - 'extras': { - 'click': np.array([[246]], dtype=np.int64) - } - }, - ) - - ts = env.reset() - self.assertIsInstance(ts, dm_env.TimeStep) - # After a `reset()` the TimeStep should follow some expectations. - self.assertTrue(ts.first()) - self.assertEqual(ts.reward, 0.0) - self.assertEqual(ts.discount, 0.0) - obs = ts.observation - self.assertIn('pixels', obs) - self.assertEqual(obs['pixels'].shape, (987, 654, 3)) - self.assertIn('timedelta', obs) - self.assertEqual(obs['timedelta'], 123456) - self.assertIn('orientation', obs) - self.assertEqual(obs['orientation'].shape, (4,)) - np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0)) - - # Extras should also be provided. - extras = env.task_extras() - self.assertIn('click', extras) - self.assertEqual(extras['click'], np.array([246], dtype=np.int64)) - - coordinator.stats.return_value = {'my_measurement': 135} - task_manager.stats.return_value = {'another_measurement': 79} - - # Step again in the environment and check expectations again. - pixels = np.random.rand(987, 654, 3) - latest_observation = { - 'pixels': pixels, - 'timedelta': 123456, - 'orientation': np.array((1, 0, 0, 0)), - 'extras': { - 'click': np.array([[246]], dtype=np.int64) - } - } - coordinator.rl_step.return_value = dm_env.transition( - reward=0.0, - discount=0.0, - observation=latest_observation, - ) - ts = env.step({'action_type': 1, 'touch_position': (10, 20)}) - self.assertIsInstance(ts, dm_env.TimeStep) - # The StepType now should NOT be FIRST. - self.assertFalse(ts.first()) - self.assertEqual(ts.reward, 0.0) - self.assertEqual(ts.discount, 0.0) - obs = ts.observation - self.assertIn('pixels', obs) - self.assertEqual(obs['pixels'].shape, (987, 654, 3)) - self.assertIn('timedelta', obs) - self.assertEqual(obs['timedelta'], 123456) - self.assertIn('orientation', obs) - self.assertEqual(obs['orientation'].shape, (4,)) - np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0)) - - # Extras should still be provided. - extras = env.task_extras() - self.assertIn('click', extras) - self.assertEqual(extras['click'], np.array([246], dtype=np.int64)) - - # At this point these methods and properties should return something. - self.assertNotEmpty(env.stats()) - self.assertNotEmpty(env.raw_observation) - self.assertNotIn('extras', env.raw_observation) - self.assertNotEmpty(env.raw_action) - - # If the observation is None, we want to return the latest observation. - coordinator.rl_step.return_value = dm_env.truncation( - reward=0.0, - observation=None, - ) - ts = env.step({'action_type': 1, 'touch_position': (10, 20)}) - self.assertIsInstance(ts, dm_env.TimeStep) - # Assert the observation matches the latest observation. - obs = ts.observation - self.assertIn('pixels', obs) - self.assertEqual(obs['pixels'].shape, (987, 654, 3)) - np.testing.assert_equal(obs['pixels'], pixels) - self.assertIn('timedelta', obs) - self.assertEqual(obs['timedelta'], 123456) - self.assertIn('orientation', obs) - self.assertEqual(obs['orientation'].shape, (4,)) - np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0)) - - def test_adb_call(self): - simulator = _create_fake_simulator() - coordinator = _create_mock_coordinator() - task_manager = mock.create_autospec(task_manager_lib.TaskManager) - env = environment.AndroidEnv( - simulator=simulator, coordinator=coordinator, task_manager=task_manager - ) - call = adb_pb2.AdbRequest( - force_stop=adb_pb2.AdbRequest.ForceStop(package_name='blah')) - expected_response = adb_pb2.AdbResponse( - status=adb_pb2.AdbResponse.Status.OK) - coordinator.execute_adb_call.return_value = expected_response - - response = env.execute_adb_call(call) - - self.assertEqual(response, expected_response) - coordinator.execute_adb_call.assert_called_once_with(call) - - def test_load_state(self): - simulator = mock.create_autospec(base_simulator.BaseSimulator) - coordinator = _create_mock_coordinator() - task_manager = mock.create_autospec(task_manager_lib.TaskManager) - env = environment.AndroidEnv( - simulator=simulator, coordinator=coordinator, task_manager=task_manager - ) - expected_response = state_pb2.LoadStateResponse( - status=state_pb2.LoadStateResponse.Status.OK - ) - request = state_pb2.LoadStateRequest(args={'foo': 'bar'}) - simulator.load_state.return_value = expected_response - response = env.load_state(request) - self.assertEqual(response, expected_response) - simulator.load_state.assert_called_once_with(request) - task_manager.stop.assert_called_once() - task_manager.start.assert_called_once() - - def test_save_state(self): - simulator = mock.create_autospec(base_simulator.BaseSimulator) - coordinator = _create_mock_coordinator() - task_manager = mock.create_autospec(task_manager_lib.TaskManager) - env = environment.AndroidEnv( - simulator=simulator, coordinator=coordinator, task_manager=task_manager - ) - expected_response = state_pb2.SaveStateResponse( - status=state_pb2.SaveStateResponse.Status.OK - ) - request = state_pb2.SaveStateRequest(args={'foo': 'bar'}) - simulator.save_state.return_value = expected_response - response = env.save_state(request) - self.assertEqual(response, expected_response) - simulator.save_state.assert_called_once_with(request) - - def test_double_close(self): - simulator = _create_fake_simulator() - coordinator = _create_mock_coordinator() - task_manager = mock.create_autospec(task_manager_lib.TaskManager) - env = environment.AndroidEnv( - simulator=simulator, coordinator=coordinator, task_manager=task_manager - ) - env.close() - env.close() - coordinator.close.assert_called_once() - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/loader.py b/android_env/loader.py deleted file mode 100644 index 0588ff33..00000000 --- a/android_env/loader.py +++ /dev/null @@ -1,89 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Function for loading AndroidEnv.""" - -import os - -from absl import logging -from android_env import environment -from android_env.components import config_classes -from android_env.components import coordinator as coordinator_lib -from android_env.components import device_settings as device_settings_lib -from android_env.components import task_manager as task_manager_lib -from android_env.components.simulators.emulator import emulator_simulator -from android_env.components.simulators.fake import fake_simulator -from android_env.proto import task_pb2 - -from google.protobuf import text_format - - -def _load_task(task_config: config_classes.TaskConfig) -> task_pb2.Task: - """Returns the task according to `task_config`.""" - - task = task_pb2.Task() - match task_config: - case config_classes.FilesystemTaskConfig(): - with open(task_config.path, 'r') as proto_file: - text_format.Parse(proto_file.read(), task) - case _: - logging.error('Unsupported TaskConfig: %r', task_config) - - return task - - -def load(config: config_classes.AndroidEnvConfig) -> environment.AndroidEnv: - """Loads an AndroidEnv instance.""" - - task = _load_task(config.task) - task_manager = task_manager_lib.TaskManager(task) - - match config.simulator: - case config_classes.EmulatorConfig(): - _process_emulator_launcher_config(config.simulator) - simulator = emulator_simulator.EmulatorSimulator(config=config.simulator) - case config_classes.FakeSimulatorConfig(): - simulator = fake_simulator.FakeSimulator(config=config.simulator) - case _: - raise ValueError('Unsupported simulator config: {config.simulator}') - - device_settings = device_settings_lib.DeviceSettings(simulator) - coordinator = coordinator_lib.Coordinator( - simulator, task_manager, device_settings - ) - return environment.AndroidEnv( - simulator=simulator, coordinator=coordinator, task_manager=task_manager - ) - - -def _process_emulator_launcher_config( - emulator_config: config_classes.EmulatorConfig, -) -> None: - """Adjusts the configuration of the emulator depending on some conditions.""" - - # Expand the user directory if specified. - launcher_config = emulator_config.emulator_launcher - launcher_config.android_avd_home = os.path.expanduser( - launcher_config.android_avd_home - ) - launcher_config.android_sdk_root = os.path.expanduser( - launcher_config.android_sdk_root - ) - launcher_config.emulator_path = os.path.expanduser( - launcher_config.emulator_path - ) - emulator_config.adb_controller.adb_path = os.path.expanduser( - emulator_config.adb_controller.adb_path - ) diff --git a/android_env/loader_test.py b/android_env/loader_test.py deleted file mode 100644 index 280fa4ab..00000000 --- a/android_env/loader_test.py +++ /dev/null @@ -1,182 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for loader.""" - -import builtins -import os -from unittest import mock - -from absl.testing import absltest -from android_env import env_interface -from android_env import loader -from android_env.components import config_classes -from android_env.components import coordinator as coordinator_lib -from android_env.components import device_settings as device_settings_lib -from android_env.components import task_manager as task_manager_lib -from android_env.components.simulators.emulator import emulator_simulator -from android_env.components.simulators.fake import fake_simulator -from android_env.proto import task_pb2 - - -class LoaderTest(absltest.TestCase): - - @mock.patch.object(task_manager_lib, 'TaskManager', autospec=True) - @mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True) - @mock.patch.object(device_settings_lib, 'DeviceSettings', autospec=True) - @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True) - @mock.patch.object(builtins, 'open', autospec=True) - def test_load_emulator( - self, - mock_open, - mock_coordinator, - mock_device_settings, - mock_simulator_class, - mock_task_manager, - ): - - # Arrange. - mock_open.return_value.__enter__ = mock_open - mock_open.return_value.read.return_value = '' - config = config_classes.AndroidEnvConfig( - task=config_classes.FilesystemTaskConfig(path='some/path/'), - simulator=config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - avd_name='my_avd', - android_avd_home='~/.android/avd', - android_sdk_root='~/Android/Sdk', - emulator_path='~/Android/Sdk/emulator/emulator', - run_headless=False, - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='~/Android/Sdk/platform-tools/adb', - ), - ), - ) - - # Act. - env = loader.load(config) - - # Assert. - self.assertIsInstance(env, env_interface.AndroidEnvInterface) - mock_simulator_class.assert_called_with( - config=config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - avd_name='my_avd', - android_avd_home=os.path.expanduser('~/.android/avd'), - android_sdk_root=os.path.expanduser('~/Android/Sdk'), - emulator_path=os.path.expanduser( - '~/Android/Sdk/emulator/emulator' - ), - run_headless=False, - gpu_mode='swangle_indirect', - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path=os.path.expanduser('~/Android/Sdk/platform-tools/adb'), - adb_server_port=5037, - ), - ) - ) - mock_coordinator.assert_called_with( - mock_simulator_class.return_value, - mock_task_manager.return_value, - mock_device_settings.return_value, - ) - - @mock.patch.object(task_manager_lib, 'TaskManager', autospec=True) - @mock.patch.object(fake_simulator, 'FakeSimulator', autospec=True) - @mock.patch.object(device_settings_lib, 'DeviceSettings', autospec=True) - @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True) - @mock.patch.object(builtins, 'open', autospec=True) - def test_load_fake_simulator( - self, - mock_open, - mock_coordinator, - mock_device_settings, - mock_simulator_class, - mock_task_manager, - ): - - # Arrange. - mock_open.return_value.__enter__ = mock_open - mock_open.return_value.read.return_value = '' - config = config_classes.AndroidEnvConfig( - task=config_classes.FilesystemTaskConfig(path='some/path/'), - simulator=config_classes.FakeSimulatorConfig( - screen_dimensions=(1234, 5678) - ), - ) - - # Act. - env = loader.load(config) - - # Assert. - self.assertIsInstance(env, env_interface.AndroidEnvInterface) - mock_simulator_class.assert_called_with( - config=config_classes.FakeSimulatorConfig( - screen_dimensions=(1234, 5678) - ) - ) - mock_coordinator.assert_called_with( - mock_simulator_class.return_value, - mock_task_manager.return_value, - mock_device_settings.return_value, - ) - - @mock.patch.object(task_manager_lib, 'TaskManager', autospec=True) - @mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True) - @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True) - @mock.patch.object(builtins, 'open', autospec=True) - def test_task( - self, mock_open, mock_coordinator, mock_simulator, mock_task_manager - ): - - # Arrange. - del mock_coordinator, mock_simulator - mock_open.return_value.__enter__ = mock_open - mock_open.return_value.read.return_value = r''' -id: "fake_task" -name: "Fake Task" -description: "Task for testing loader." -max_episode_sec: 0 -''' - config = config_classes.AndroidEnvConfig( - task=config_classes.FilesystemTaskConfig(path='some/path/'), - simulator=config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - avd_name='my_avd' - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path='~/Android/Sdk/platform-tools/adb', - ), - ), - ) - - # Act. - env = loader.load(config) - - # Assert. - expected_task = task_pb2.Task() - expected_task.id = 'fake_task' - expected_task.name = 'Fake Task' - expected_task.description = 'Task for testing loader.' - expected_task.max_episode_sec = 0 - - mock_task_manager.assert_called_with(expected_task) - self.assertIsInstance(env, env_interface.AndroidEnvInterface) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/proto/__init__.py b/android_env/proto/__init__.py deleted file mode 100644 index 2f66bf75..00000000 --- a/android_env/proto/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/android_env/proto/a11y/__init__.py b/android_env/proto/a11y/__init__.py deleted file mode 100644 index 2f66bf75..00000000 --- a/android_env/proto/a11y/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/android_env/proto/a11y/a11y.proto b/android_env/proto/a11y/a11y.proto deleted file mode 100644 index 90c266bb..00000000 --- a/android_env/proto/a11y/a11y.proto +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package android_env; - -import "android_env/proto/a11y/android_accessibility_forest.proto"; - -option java_multiple_files = true; -option java_package = "com.google.androidenv.accessibilityforwarder"; - -// A service to send Accessibility information to a remote server. -// -// The client is assumed to be running inside an Android device (e.g. emulator -// or real device) while the server is assumed to be running outside (e.g. in a -// Python process). -service A11yService { - // Sends a forest of Accessibility trees to a server. - rpc SendForest(AndroidAccessibilityForest) returns (ForestResponse) {} - // Sends an a11y event to a server. - rpc SendEvent(EventRequest) returns (EventResponse) {} - - // Long-lived bidirection communication between the client and the server. - rpc Bidi(stream ClientToServer) returns (stream ServerToClient) {} -} - -// TODO(b/334952387): Remove `ForestResponse`, `EventRequest` and -// `EventResponse` once bidi communication is in-place. -message ForestResponse { - // The error if anything. - string error = 1; -} - -// An Accessibility event. -message EventRequest { - // A single event as a dictionary. - map event = 1; -} - -message EventResponse { - // The error if anything. - string error = 1; -} - -// The message sent from the Android device to the server running outside of the -// device. -message ClientToServer { - oneof payload { - EventRequest event = 1; - AndroidAccessibilityForest forest = 2; - } -} - -// The message sent from the server running outside of the device to the Android -// device. -message ServerToClient { - // A request to obtain the Accessibility forest. - message GetA11yForest {} - - oneof payload { - GetA11yForest get_forest = 1; - } -} diff --git a/android_env/proto/a11y/android_accessibility_action.proto b/android_env/proto/a11y/android_accessibility_action.proto deleted file mode 100644 index 5273d7d9..00000000 --- a/android_env/proto/a11y/android_accessibility_action.proto +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package android_env; - -option java_multiple_files = true; -option java_package = "com.google.androidenv.accessibilityforwarder"; - -// An Android Accessibility Action. -// Next index: 3 -message AndroidAccessibilityAction { - // Required ID that uniquely identifies the action for this node. - // Can be one of the standard action IDs listed in the documentation. - // https://developer.android.com/reference/android/view/accessibility/AccessibilityNodeInfo.AccessibilityAction - int32 id = 1; - - // Optional label describing what the action is. - string label = 2; -} diff --git a/android_env/proto/a11y/android_accessibility_forest.proto b/android_env/proto/a11y/android_accessibility_forest.proto deleted file mode 100644 index 63ddb8c5..00000000 --- a/android_env/proto/a11y/android_accessibility_forest.proto +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package android_env; - -import "android_env/proto/a11y/android_accessibility_window_info.proto"; - -option java_multiple_files = true; -option java_package = "com.google.androidenv.accessibilityforwarder"; - -// A forest of Android accessibility trees. Each tree belongs to a single -// window. Next index: 2 -message AndroidAccessibilityForest { - // All of the windows present on screen. - repeated AndroidAccessibilityWindowInfo windows = 1; -} diff --git a/android_env/proto/a11y/android_accessibility_node_info.proto b/android_env/proto/a11y/android_accessibility_node_info.proto deleted file mode 100644 index 3c904c86..00000000 --- a/android_env/proto/a11y/android_accessibility_node_info.proto +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package android_env; - -import "android_env/proto/a11y/android_accessibility_action.proto"; -import "android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto"; -import "android_env/proto/a11y/rect.proto"; - -option java_multiple_files = true; -option java_package = "com.google.androidenv.accessibilityforwarder"; - -// An Android AccessibilityNodeInfo. -// Next index: 32 -message AndroidAccessibilityNodeInfo { - // Unique monotonically-increasing ID. - int32 unique_id = 1; - - // The bounds of this node within the device's screen. - ProtoRect bounds_in_screen = 2; - - // The name of the View class that created this node. - string class_name = 3; - - // The content description of the node. - string content_description = 4; - - // The hint text of the node. - string hint_text = 5; - - // The name of the package this node comes from. - string package_name = 6; - - // The text of this node. - string text = 7; - - // The start index of the text selection. - int64 text_selection_start = 8; - - // The end index of the text selection. - int64 text_selection_end = 9; - - // The view ID resource name of the node. - string view_id_resource_name = 10; - - // The ID of the window this node belongs to. - int32 window_id = 11; - - // If true, this node can be checked. - bool is_checkable = 12; - - // If true, this node is currently checked. - bool is_checked = 13; - - // If true, this node (probably) responds to being clicked. - bool is_clickable = 14; - - // If true, this node's text can be edited by the user. - bool is_editable = 15; - - // If true, this node is enabled (e.g., if it is a button). - bool is_enabled = 16; - - // If true, this node can be focused (e.g., a text input). - bool is_focusable = 17; - - // If true, this node is currently focused. - bool is_focused = 18; - - // If true, this node (probably) responds to being long pressed. - bool is_long_clickable = 19; - - // If true, this node is a password input. - bool is_password = 20; - - // If true, this node can be scrolled. - bool is_scrollable = 21; - - // If true, this node is currently selected. - bool is_selected = 22; - - // If true, this node is (probably) visible to the user. - bool is_visible_to_user = 23; - - // List of actions that can be performed on this node. - repeated AndroidAccessibilityAction actions = 24; - - // Ordered list of child IDs (i.e., unique_id). - repeated int32 child_ids = 25 [packed = true]; - - // List of clickable spans present in the node's text or content description. - repeated AndroidAccessibilityNodeInfoClickableSpan clickable_spans = 26; - - // The depth of this node in the accessibility tree. - int32 depth = 27; - - // Unique ID of the node that this node is declaring itself to be labeled by. - int32 labeled_by_id = 28; - - // Unique ID of the node that this is node is declaring itself to be a label - // for. - int32 label_for_id = 29; - - // The drawing order for the node. - int32 drawing_order = 30; - - // The tooltip text of the node. - string tooltip_text = 31; -} diff --git a/android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto b/android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto deleted file mode 100644 index f20d0bfd..00000000 --- a/android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package android_env; - -option java_multiple_files = true; -option java_package = "com.google.androidenv.accessibilityforwarder"; - -// A single clickable span found in the accessibility node's text. -// Next index: 6 -message AndroidAccessibilityNodeInfoClickableSpan { - // The source of the span (so the client can find the correct spannable string - // in the node). - // Next index: 3 - enum SpanSource { - UNKNOWN_TYPE = 0; // Catch all type for forward compatibility. - TEXT = 1; // The span is from node#getText - CONTENT_DESCRIPTION = 2; // The span is from node#getContentDescription. - } - - // The text of the span (a substring of the spannable string). - string text = 1; - - // The URL attached to the span if specified. - string url = 2; - - // The source of the span. - SpanSource source = 3; - - // The index of the first character of the span in the spannable string. - // The end of the span would be a sum of span_start and text.length(). - int32 start = 4; - - // The unique_id from the corresponding AndroidAccessibilityNodeInfo. - int32 node_id = 5; -} diff --git a/android_env/proto/a11y/android_accessibility_tree.proto b/android_env/proto/a11y/android_accessibility_tree.proto deleted file mode 100644 index 4bc48ef9..00000000 --- a/android_env/proto/a11y/android_accessibility_tree.proto +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package android_env; - -import "android_env/proto/a11y/android_accessibility_node_info.proto"; - -option java_multiple_files = true; -option java_package = "com.google.androidenv.accessibilityforwarder"; - -// A tree (actually a graph) of Android accessibility nodes. -// Next index: 3 -message AndroidAccessibilityTree { - // All of the nodes in the graph. The root node is the node whose ID is 0. - repeated AndroidAccessibilityNodeInfo nodes = 1; -} diff --git a/android_env/proto/a11y/android_accessibility_window_info.proto b/android_env/proto/a11y/android_accessibility_window_info.proto deleted file mode 100644 index 2e0baeec..00000000 --- a/android_env/proto/a11y/android_accessibility_window_info.proto +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package android_env; - -import "android_env/proto/a11y/android_accessibility_tree.proto"; -import "android_env/proto/a11y/rect.proto"; - -option java_multiple_files = true; -option java_package = "com.google.androidenv.accessibilityforwarder"; - -// An Android AccessibilityWindowInfo. -// Next index: 12 -message AndroidAccessibilityWindowInfo { - // Type of the window. - // Next index: 6 - enum WindowType { - // The window type is an unknown value. - UNKNOWN_TYPE = 0; - - // A standard application window. - TYPE_APPLICATION = 1; - - // An IME window (e.g. GBoard). - TYPE_INPUT_METHOD = 2; - - // A system window (e.g., a notification). - TYPE_SYSTEM = 3; - - // An accessibility overlay. - TYPE_ACCESSIBILITY_OVERLAY = 4; - - // A system window used to divide the screen in split-screen mode. This type - // of window is present only in split-screen mode. - TYPE_SPLIT_SCREEN_DIVIDER = 5; - } - - // Bounds of this window in the device's screen. - ProtoRect bounds_in_screen = 1; - - // A unique ID identifying the display in which this window is shown. - int32 display_id = 2; - - // Unique ID as defined by the Android platform. - int32 id = 3; - - // Z-index of the window. Windows with a greater z-index appear in front of - // those with a lesser z-index. - int32 layer = 4; - - // The title of the window, if set. - string title = 5; - - // The type of the window. - WindowType window_type = 6; - - // If true, the window is currently accessibility-focused. - bool is_accessibility_focused = 7; - - // If true, the window is currently active. - bool is_active = 8; - - // If true, the window is currently focused. - bool is_focused = 9; - - // If true, the window is in Picture in Picture mode. - bool is_in_picture_in_picture_mode = 10; - - // The associated accessibility tree for this window. - AndroidAccessibilityTree tree = 11; -} diff --git a/android_env/proto/a11y/rect.proto b/android_env/proto/a11y/rect.proto deleted file mode 100644 index 58167b44..00000000 --- a/android_env/proto/a11y/rect.proto +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package android_env; - -option java_multiple_files = true; -option java_package = "com.google.androidenv.accessibilityforwarder"; - -// Proto representation of Android Rect. -// https://developer.android.com/reference/android/graphics/Rect -// Next index: 5 -message ProtoRect { - int32 left = 1; - int32 top = 2; - int32 right = 3; - int32 bottom = 4; -} diff --git a/android_env/proto/adb.proto b/android_env/proto/adb.proto deleted file mode 100644 index 86e65709..00000000 --- a/android_env/proto/adb.proto +++ /dev/null @@ -1,433 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package android_env; - -message AdbRequest { - // Installs an APK into the simulator. - message InstallApk { - - // A location in the filesystem. - message Filesystem { - string path = 1; - } - - // A byte sequence of a single APK file. - message Blob { - // The serialized file as bytes. - bytes contents = 1; - } - - oneof location { - Filesystem filesystem = 2; - Blob blob = 6; - } - } - - message StartActivity { - string full_activity = 1; - repeated string extra_args = 2; - // Whether to stop the current app before starting the activity. - // Notice that if this option is `true`, the activity probably needs the - // `android:launchMode="singleTop"` attribute in its `AndroidManifest.xml`, - // otherwise intents may not be received by `onNewIntent()`. Please see more - // info on `android:launchMode` at - // https://developer.android.com/guide/topics/manifest/activity-element. - bool force_stop = 3; - } - - message SendBroadcast { - // Action to send during the broadcast event. - string action = 1; - - // Specify the component name with package name prefix to create an explicit - // intent, such as com.example.app/.ExampleActivity (see -n specification at - // https://developer.android.com/tools/adb#IntentSpec). - string component = 2; - } - - message UninstallPackage { - string package_name = 1; - } - - message ForceStop { - string package_name = 1; - } - - message Tap { - // NOTE: These are absolute coordinates in the range of the screen - // resolution. They are NOT floats in [0,1]. - // Precondition: `x` and `y` must be non-negative. - int32 x = 1; - int32 y = 2; - } - - message PressButton { - enum Button { - HOME = 0; - BACK = 1; - ENTER = 2; - } - Button button = 1; - } - - // Pins the given activity to the screen. - // This essentially locks the user into a single app mode (aka "Kiosk mode"). - message StartScreenPinning { - string full_activity = 1; - } - - // Returns the full activity name that is currently opened to the user. - // If successful, a GetCurrentActivityResponse is returned. - message GetCurrentActivity {} - - // Returns the orientation of the device. - message GetOrientationRequest {} - - // Performs `adb push`. - // Please see https://developer.android.com/studio/command-line/adb#copyfiles. - // - // Notice that a source destination path for the file is not sent, but raw - // bytes in `content` instead. Obviously, the `content` can be set from a real - // file, but this is done to ensure Task definitions are as hermetic as - // possible, without depending on the environment from where they're run. - message Push { - // The contents of the file. - bytes content = 1; - - // Destination path _inside_ Android. E.g. /sdcard/my_file.txt. - string path = 2; - } - - // Performs `adb pull`. - // Please see https://developer.android.com/studio/command-line/adb#copyfiles. - // - // Notice that a local destination for the copied file is not sent, as raw - // bytes are returned instead (please see PullResponse). Obviously, these - // bytes can be written to disk by the caller of this command. - message Pull { - // Path _inside_ Android. E.g. /sdcard/my_file.txt. - string path = 1; - } - - // Inserts text into the current text field (if any). - // Essentially `adb shell input text `. - message InputText { - string text = 1; - } - - // Issues an `adb shell settings` command. - message SettingsRequest { - // Each request has an associated namespace. - enum Namespace { - UNKNOWN = 0; - SYSTEM = 1; - SECURE = 2; - GLOBAL = 3; - } - - // Retrieves the current value for `key`. - message Get { - string key = 1; - } - - // Changes the contents `key` to `value`. - message Put { - string key = 1; - string value = 2; - } - - // Deletes the entry for `key`. - message Delete { - string key = 1; - } - - // Resets the global/secure table for a package with the given mode. - message Reset { - enum Mode { - UNKNOWN = 0; - UNTRUSTED_DEFAULTS = 1; - UNTRUSTED_CLEAR = 2; - TRUSTED_DEFAULTS = 3; - } - - string package_name = 1; - Mode mode = 2; - } - - // Prints all defined keys in the given namespace. - message List {} - - // The part of the system where this command will take place. - // NOTE: We avoid the identifier `namespace` because it's a keyword in C++. - Namespace name_space = 1; - - // The subcommand to issue to `adb settings`. - // NOTE: We avoid the identifiers `delete` and `del` because they're - // keywords in C++ and Python respectively. - oneof verb { - Get get = 2; - Put put = 3; - Delete delete_key = 4; - Reset reset = 5; - List list = 6; - } - } - - // Generic ADB command. Use this for commands that are not - // explicitly implemented. - // Calls `adb [args...]`. - message GenericRequest { - repeated string args = 1; - } - - message PackageManagerRequest { - message List { - // Lists all features of the system. - message Features {} - - // Lists all system libraries. - message Libraries {} - - // Lists all packages; optionally only those whose name contains the text - // in `filter`. - message Packages { - string filter = 1; - - // Extra options that control the output. Please see `pm help` for - // details. - repeated string options = 2; - } - - oneof what { - Features features = 1; - Libraries libraries = 2; - Packages packages = 3; - } - } - - // Deletes all data associated with a package. - message Clear { - // The package name to clear its cache. - string package_name = 1; - - // Optional USER_ID. - string user_id = 2; - } - - message Grant { - string package_name = 1; - - // Possible values listed at - // https://developer.android.com/reference/android/Manifest.permission - // To query an app's required permissions, use the following adb command: - // > adb shell dumpsys package - // The output will contain things like - // android.permission.WRITE_SECURE_SETTINGS - repeated string permissions = 2; - } - - // The subcommand to issue to `pm`. - oneof verb { - List list = 1; - Clear clear = 2; - Grant grant = 3; - } - } - - // For executing `dumpsys` commands. - message DumpsysRequest { - enum PriorityLevel { - UNSET = 0; - NORMAL = 1; - HIGH = 2; - CRITICAL = 3; - } - - // The service to dump. If empty, all services will be dumped. - string service = 1; - - // Optional arguments to pass to the specific service dump. - repeated string args = 2; - - // Lists services, does not dump them. - // This effectively disables dumping information about any particular - // service. - bool list_only = 3; - - // Timeouts natively supported by `dumpsys`. - int32 timeout_sec = 4; - int32 timeout_ms = 5; - - // Whether to dump the process ID instead of the usual dump. - bool pid = 6; - - // Whether dumps will be in proto format. Only works for services that - // support dumping data in proto format. - bool proto = 7; - - // Filters services based on specified priority. - PriorityLevel priority = 8; - - // Excludes services from the dump. - repeated string skip_services = 9; - } - - oneof command { - InstallApk install_apk = 1; - StartActivity start_activity = 2; - ForceStop force_stop = 3; - Tap tap = 6; - PressButton press_button = 7; - StartScreenPinning start_screen_pinning = 10; - UninstallPackage uninstall_package = 16; - GetCurrentActivity get_current_activity = 17; - GetOrientationRequest get_orientation = 24; - Push push = 18; - Pull pull = 19; - InputText input_text = 20; - SettingsRequest settings = 21; - GenericRequest generic = 22; - PackageManagerRequest package_manager = 23; - DumpsysRequest dumpsys = 26; - SendBroadcast send_broadcast = 25; - } - - // Optional (soft) deadline in seconds for completing this command. - // Expected to be >0. If ==0 (the default), it's ignored. - // Notice that not all commands accept timeouts, but because it's such a - // common parameter, we include it here instead of in each separate command. - float timeout_sec = 100; -} - -message AdbResponse { - enum Status { - // Reserved value for unset statuses. - UNDEFINED = 0; - // Returned when everything goes well. - OK = 1; - // Returned when handling unknown AdbRequest commands. - UNKNOWN_COMMAND = 2; - // Returned when an argument does not respect a precondition. - FAILED_PRECONDITION = 3; - // Returned when something internal did not work as expected. - INTERNAL_ERROR = 4; - // Returned when the adb command failed. - ADB_ERROR = 5; - // Returned when the adb command timed out. - TIMEOUT = 6; - } - Status status = 1; - - // `error_message` is only populated in case of errors. - string error_message = 2; - - // General stats that components may optionally report. - map stats = 3; - - // Response for GetCurrentActivity requests. - message GetCurrentActivityResponse { - // The format of the output is `package/package.ActivityName', for example: - // "com.example.vokram/com.example.vokram.MainActivity" - string full_activity = 1; - } - - // Response for GetOrientationRequests. - message GetOrientationResponse { - // Possible values are {0, 1, 2, 3} corresponding to {0, 90, 180, 270} - // degrees respectively. - // Please see https://developer.android.com/reference/android/view/Surface. - int32 orientation = 1; - } - - // Response for StartActivity requests. - message StartActivityResponse { - // The activity that was actually started. On a failed request, this will be - // empty. - string full_activity = 1; - bytes output = 2; - } - - // Response for PressButton requests. - message PressButtonResponse { - // The output, if any, by `adb` after sending a key press. - // This is intentionally left as `bytes` instead of `string` so that content - // other than `UTF-8` can be transmitted. - bytes output = 1; - } - - // Response for Push requests. - message PushResponse {} - - // Response for Pull requests. - message PullResponse { - // The contents of the file. - // This is intentionally left as `bytes` instead of `string` so that content - // other than `UTF-8` can be transmitted. - bytes content = 1; - } - - // Response for InputText requests. - message InputTextResponse {} - - // Response for SettingsRequests. - message SettingsResponse { - // The output, if any, of the `adb shell settings` command. - bytes output = 1; - } - - // Response for GenericRequests. - message GenericResponse { - // The output, if any, of the generic adb command. - bytes output = 1; - } - - // Response for PackageManagerRequests. - message PackageManagerResponse { - // The output, if any, of the `adb shell pm` command. - bytes output = 1; - - message List { - // A list of items. The actual content depends on the request, but it - // could be things like features, libraries or package names. - repeated string items = 1; - } - - oneof verb { - List list = 2; - } - } - - // Response for DumpsysRequests. - message DumpsysResponse { - // The output, if any, of the `dumpsys` command. - bytes output = 1; - } - - oneof payload { - GetCurrentActivityResponse get_current_activity = 10; - StartActivityResponse start_activity = 11; - PressButtonResponse press_button = 12; - PushResponse push = 13; - PullResponse pull = 14; - InputTextResponse input_text = 15; - SettingsResponse settings = 16; - GenericResponse generic = 17; - PackageManagerResponse package_manager = 18; - GetOrientationResponse get_orientation = 19; - DumpsysResponse dumpsys = 21; - } -} diff --git a/android_env/proto/emulator_controller.proto b/android_env/proto/emulator_controller.proto deleted file mode 100644 index 563b2f1f..00000000 --- a/android_env/proto/emulator_controller.proto +++ /dev/null @@ -1,1132 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright (C) 2018 The Android Open Source Project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Note that if you add/remove methods in this file you must update -// the metrics sql as well ./android/scripts/gen-grpc-sql.py -// -// Please group deleted methods in a block including the date (MM/DD/YY) -// it was removed. This enables us to easily keep metrics around after removal -// -// list of deleted methods -// rpc iWasDeleted (03/12/12) -// ... - -// LINT: LEGACY_NAMES - -syntax = "proto3"; - -package android.emulation.control; - -import "google/protobuf/empty.proto"; - -option java_multiple_files = true; -option java_package = "com.android.emulator.control"; -option objc_class_prefix = "AEC"; - -// An EmulatorController service lets you control the emulator. -// Note that this is currently an experimental feature, and that the -// service definition might change without notice. Use at your own risk! -// -// We use the following rough conventions: -// -// streamXXX --> streams values XXX (usually for emulator lifetime). Values -// are updated as soon as they become available. -// getXXX --> gets a single value XXX -// setXXX --> sets a single value XXX, does not returning state, these -// usually have an observable lasting side effect. -// sendXXX --> send a single event XXX, possibly returning state information. -// android usually responds to these events. -service EmulatorController { - // Set the sensor data - rpc streamSensor(SensorValue) returns (stream SensorValue) {} - // Get the sensor data - rpc getSensor(SensorValue) returns (SensorValue) {} - // Stream the sensor data - rpc setSensor(SensorValue) returns (google.protobuf.Empty) {} - - // Set the physical model, this is likely the one you are - // looking for when you wish to modify the device state. - rpc setPhysicalModel(PhysicalModelValue) returns (google.protobuf.Empty) {} - // Get the physical model - rpc getPhysicalModel(PhysicalModelValue) returns (PhysicalModelValue) {} - // Stream the physical model - rpc streamPhysicalModel(PhysicalModelValue) - returns (stream PhysicalModelValue) {} - - // Atomically set the current primary clipboard data. - rpc setClipboard(ClipData) returns (google.protobuf.Empty) {} - // Atomically get the current primary clipboard data. - rpc getClipboard(google.protobuf.Empty) returns (ClipData) {} - - // Streams the current data on the clipboard. This will immediately produce - // a result with the current state of the clipboard after which the stream - // will block and wait until a new clip event is available from the guest. - // Calling the setClipboard method above will not result in generating a clip - // event. It is possible to lose clipboard events if the clipboard updates - // very rapidly. - rpc streamClipboard(google.protobuf.Empty) returns (stream ClipData) {} - - // Set the battery to the given state. - rpc setBattery(BatteryState) returns (google.protobuf.Empty) {} - // Get the battery to the given state. - rpc getBattery(google.protobuf.Empty) returns (BatteryState) {} - - // Set the state of the gps, gps support will only work - // properly if: - // - // - no location ui is active. That is the emulator - // is launched in headless mode (-no-window) or the location - // ui is disabled (-no-location-ui). - // - the passiveUpdate is set to false. Setting this to false - // will disable/break the LocationUI. - // - // Keep in mind that android usually only samples the gps at 1 hz. - rpc setGps(GpsState) returns (google.protobuf.Empty) {} - - // Gets the latest gps state as delivered by the setGps call, or location ui - // if active. - // - // Note: this is not necessarily the actual gps coordinate visible at the - // time, due to gps sample frequency (usually 1hz). - rpc getGps(google.protobuf.Empty) returns (GpsState) {} - - // Simulate a touch event on the finger print sensor. - rpc sendFingerprint(Fingerprint) returns (google.protobuf.Empty) {} - - // Send a keyboard event. Translating the event. - rpc sendKey(KeyboardEvent) returns (google.protobuf.Empty) {} - - // Send touch events. Note that mouse events can be simulated by touch events. - rpc sendTouch(TouchEvent) returns (google.protobuf.Empty) {} - // Send mouse events. - rpc sendMouse(MouseEvent) returns (google.protobuf.Empty) {} - - // Make a phone call. - rpc sendPhone(PhoneCall) returns (PhoneResponse) {} - - // Sends an sms message to the emulator. - rpc sendSms(SmsMessage) returns (PhoneResponse) {} - - // Retrieve the status of the emulator. This will contain general - // hardware information, and whether the device has booted or not. - rpc getStatus(google.protobuf.Empty) returns (EmulatorStatus) {} - - // Gets an individual screenshot in the desired format. - // - // The image will be scaled to the desired ImageFormat, while maintaining - // the aspect ratio. The returned image will never exceed the provided width - // and height. Not setting the width or height (i.e. they are 0) will result - // in using the device width and height. - // - // The resulting image will be properly oriented and can be displayed - // directly without post processing. For example, if the device has a - // 1080x1920 screen and is in landscape mode and called with no width or - // height parameter, it will return an 1920x1080 image. - // - // This method will return an empty image if the display is not visible. - rpc getScreenshot(ImageFormat) returns (Image) {} - - // Streams a series of screenshots in the desired format. - // A new frame will be delivered whenever the device produces a new frame. - // (Beware that this can produce a significant amount of data, and that - // certain translations are (png transform) can be costly). - // - // If the requested display is not visible it will send a single empty image - // and wait start producing images once the display becomes active, again - // producing a single empty image when the display becomes inactive. - rpc streamScreenshot(ImageFormat) returns (stream Image) {} - - // Streams a series of audio packets in the desired format. - // A new frame will be delivered whenever the emulated device - // produces a new audio frame. You can expect packets to be - // delivered in intervals of 20-30ms. - // - // Be aware that this can block when the emulator does not - // produce any audio whatsoever! - rpc streamAudio(AudioFormat) returns (stream AudioPacket) {} - - // Injects a series of audio packets to the android microphone. - // A new frame will be delivered whenever the emulated device - // requests a new audio frame. Audio is usually delivered at a rate - // that the emulator is requesting frames. Audio will be stored in a - // temporary buffer that can hold 500ms of audio. - // - // Note: Currently the emulator will downsample to 16khz. - // - // - INVALID_ARGUMENT (code 3) The sampling rate was too high - // - INVALID_ARGUMENT (code 3) The audio packet was too large to handle. - // - FAILED_PRECONDITION (code 9) If there was a microphone registered - // already. - rpc injectAudio(stream AudioPacket) returns (google.protobuf.Empty) {} - - // Returns the last 128Kb of logcat output from the emulator - // Note that parsed logcat messages are only available after L (Api >23). - // it is possible that the logcat buffer gets overwritten, or falls behind. - rpc getLogcat(LogMessage) returns (LogMessage) {} - - // Streams the logcat output from the emulator. The first call - // can retrieve up to 128Kb. This call will not return. - // Note that parsed logcat messages are only available after L (Api >23) - // it is possible that the logcat buffer gets overwritten, or falls behind. - rpc streamLogcat(LogMessage) returns (stream LogMessage) {} - - // Transition the virtual machine to the desired state. Note that - // some states are only observable. For example you cannot transition - // to the error state. - rpc setVmState(VmRunState) returns (google.protobuf.Empty) {} - - // Gets the state of the virtual machine. - rpc getVmState(google.protobuf.Empty) returns (VmRunState) {} - - // Atomically changes the current multi-display configuration. - // After this call the given display configurations will be activated. You - // can only update secondary displays. Displays with id 0 will be ignored. - // - // This call can result in the removal or addition of secondary displays, the - // final display state can be observed by the returned configuration. - // - // The following gRPC error codes can be returned: - // - FAILED_PRECONDITION (code 9) if the AVD does not support a configurable - // secondary display. - // - INVALID_ARGUMENT (code 3) if: - // - The same display id is defined multiple times. - // - The display configurations are outside valid ranges - // (see DisplayConfiguration) - // - INTERNAL (code 13) if there was an internal emulator failure. - rpc setDisplayConfigurations(DisplayConfigurations) - returns (DisplayConfigurations) {} - - // Returns all currently valid logical displays. - // The gRPC error code FAILED_PRECONDITION (code 9) is returned if the AVD - // does not support a configurable secondary display. - rpc getDisplayConfigurations(google.protobuf.Empty) - returns (DisplayConfigurations) {} - - // Notifies client of the following changes: - // - // - Virtual scene camera status change. - // - Display configuration changes from extended ui. This will only be fired - // if the user makes modifications the extended displays through the - // extended control tab. - // - // Note that this method will send the initial virtual scene state - // immediately. - rpc streamNotification(google.protobuf.Empty) returns (stream Notification) {} - - // RotationRadian is relative to the camera's current orientation. - rpc rotateVirtualSceneCamera(RotationRadian) returns (google.protobuf.Empty) { - } - // Velocity is absolute - rpc setVirtualSceneCameraVelocity(Velocity) returns (google.protobuf.Empty) {} - // Set foldable posture - rpc setPosture(Posture) returns (google.protobuf.Empty) {} -} - -// A Run State that describes the state of the Virtual Machine. -message VmRunState { - enum RunState { - // The emulator is in an unknown state. You cannot transition to this state. - UNKNOWN = 0; - // Guest is actively running. You can transition to this state from the - // paused state. - RUNNING = 1; - // Guest is paused to load a snapshot. You cannot transition to this state. - RESTORE_VM = 2; - // Guest has been paused. Transitioning to this state will pause the - // emulator the guest will not be consuming any cpu cycles. - PAUSED = 3; - // Guest is paused to take or export a snapshot. You cannot - // transition to this state. - SAVE_VM = 4; - // System shutdown, note that it is similar to power off. It tries to set - // the system status and notify guest. The system is likely going to - // disappear soon and do proper cleanup of resources, possibly taking - // a snapshot. This is the same behavior as closing the emulator by clicking - // the X (close) in the user interface. - SHUTDOWN = 5; - // Immediately terminate the emulator. No resource cleanup will take place. - // There is a good change to corrupt the system. - TERMINATE = 7; - // Will cause the emulator to reset. This is not a state you can observe. - RESET = 9; - // Guest experienced some error state, you cannot transition to this state. - INTERNAL_ERROR = 10; - } - - RunState state = 1; -} - -message ParameterValue { - repeated float data = 1 [packed = true]; -} - -message PhysicalModelValue { - enum State { - OK = 0; - NO_SERVICE = -3; // qemud service is not available/initiated. - DISABLED = -2; // Sensor is disabled. - UNKNOWN = -1; // Unknown sensor (should not happen) - } - - // Details on the sensors documentation can be found here: - // https://developer.android.com/reference/android/hardware/Sensor.html#TYPE_ - // The types must follow the order defined in - // "external/qemu/android/hw-sensors.h" - enum PhysicalType { - POSITION = 0; - - // All values are angles in degrees. - // values = [x,y,z] - ROTATION = 1; - - MAGNETIC_FIELD = 2; - - // Temperature in °C - TEMPERATURE = 3; - - // Proximity sensor distance measured in centimeters - PROXIMITY = 4; - - // Ambient light level in SI lux units - LIGHT = 5; - - // Atmospheric pressure in hPa (millibar) - PRESSURE = 6; - - // Relative ambient air humidity in percent - HUMIDITY = 7; - - VELOCITY = 8; - AMBIENT_MOTION = 9; - - // Describing a hinge angle sensor in degrees. - HINGE_ANGLE0 = 10; - HINGE_ANGLE1 = 11; - HINGE_ANGLE2 = 12; - - ROLLABLE0 = 13; - ROLLABLE1 = 14; - ROLLABLE2 = 15; - } - PhysicalType target = 1; - - // [Output Only] - State status = 2; - - // Value interpretation depends on sensor, will contain at most 3 values. - ParameterValue value = 3; -} - -// A single sensor value. -message SensorValue { - enum State { - OK = 0; - NO_SERVICE = -3; // qemud service is not available/initiated. - DISABLED = -2; // Sensor is disabled. - UNKNOWN = -1; // Unknown sensor (should not happen) - } - - // These are the various sensors that can be available in an emulated - // devices. - enum SensorType { - // Measures the acceleration force in m/s2 that is applied to a device - // on all three physical axes (x, y, and z), including the force of - // gravity. - ACCELERATION = 0; - // Measures a device's rate of rotation in rad/s around each of the - // three physical axes (x, y, and z). - GYROSCOPE = 1; - // Measures the ambient geomagnetic field for all three physical axes - // (x, y, z) in μT. - MAGNETIC_FIELD = 2; - // Measures degrees of rotation that a device makes around all three - // physical axes (x, y, z) - ORIENTATION = 3; - // Measures the temperature of the device in degrees Celsius (°C). - TEMPERATURE = 4; - // Measures the proximity of an object in cm relative to the view screen - // of a device. This sensor is typically used to determine whether a - // handset is being held up to a person's ear. - PROXIMITY = 5; - // Measures the ambient light level (illumination) in lx. - LIGHT = 6; - // Measures the ambient air pressure in hPa or mbar. - PRESSURE = 7; - // Measures the relative ambient humidity in percent (%). - HUMIDITY = 8; - MAGNETIC_FIELD_UNCALIBRATED = 9; - GYROSCOPE_UNCALIBRATED = 10; - } - - // Type of sensor - SensorType target = 1; - - // [Output Only] - State status = 2; - - // Value interpretation depends on sensor enum, will contain at most 3 - // values. - ParameterValue value = 3; -} - -message LogMessage { - // [Output Only] The contents of the log output. - string contents = 1; - // The starting byte position of the output that was returned. This - // should match the start parameter sent with the request. If the serial - // console output exceeds the size of the buffer, older output will be - // overwritten by newer content and the start values will be mismatched. - int64 start = 2; - //[Output Only] The position of the next byte of content from the serial - // console output. Use this value in the next request as the start - // parameter. - int64 next = 3; - - // Set the sort of response you are interested it in. - // It the type is "Parsed" the entries field will contain the parsed - // results. otherwise the contents field will be set. - LogType sort = 4; - - // [Output Only] The parsed logcat entries so far. Only set if sort is - // set to Parsed - repeated LogcatEntry entries = 5; - - enum LogType { - Text = 0; - Parsed = 1; - } -} - -// A parsed logcat entry. -message LogcatEntry { - // The possible log levels. - enum LogLevel { - UNKNOWN = 0; - DEFAULT = 1; - VERBOSE = 2; - DEBUG = 3; - INFO = 4; - WARN = 5; - ERR = 6; - FATAL = 7; - SILENT = 8; - } - - // A Unix timestamps in milliseconds (The number of milliseconds that - // have elapsed since January 1, 1970 (midnight UTC/GMT), not counting - // leap seconds) - uint64 timestamp = 1; - - // Process id. - uint32 pid = 2; - - // Thread id. - uint32 tid = 3; - LogLevel level = 4; - string tag = 5; - string msg = 6; -} - -// Information about the hypervisor that is currently in use. -message VmConfiguration { - enum VmHypervisorType { - // An unknown hypervisor - UNKNOWN = 0; - - // No hypervisor is in use. This usually means that the guest is - // running on a different CPU than the host, or you are using a - // platform where no hypervisor is available. - NONE = 1; - - // The Kernel based Virtual Machine - // (https://www.linux-kvm.org/page/Main_Page) - KVM = 2; - - // Intel® Hardware Accelerated Execution Manager (Intel® HAXM) - // https://github.com/intel/haxm - HAXM = 3; - - // Hypervisor Framework. - // https://developer.apple.com/documentation/hypervisor - HVF = 4; - - // Window Hypervisor Platform - // https://docs.microsoft.com/en-us/virtualization/api/ - WHPX = 5; - - GVM = 6; - } - - VmHypervisorType hypervisorType = 1; - int32 numberOfCpuCores = 2; - int64 ramSizeBytes = 3; -} - -// Representation of a clipped data object on the clipboard. -message ClipData { - // UTF-8 Encoded text. - string text = 1; -} - -// The Touch interface represents a single contact point on a -// touch-sensitive device. The contact point is commonly a finger or stylus -// and the device may be a touchscreen or trackpad. -message Touch { - // The horizontal coordinate. This is the physical location on the - // screen For example 0 indicates the leftmost coordinate. - int32 x = 1; - - // The vertical coordinate. This is the physical location on the screen - // For example 0 indicates the top left coordinate. - int32 y = 2; - - // The identifier is an arbitrary non-negative integer that is used to - // identify and track each tool independently when multiple tools are - // active. For example, when multiple fingers are touching the device, - // each finger should be assigned a distinct tracking id that is used as - // long as the finger remains in contact. Tracking ids may be reused - // when their associated tools move out of range. - // - // The emulator currently supports up to 10 concurrent touch events. The - // identifier can be any uninque value and will be mapped to the next - // available internal identifier. - int32 identifier = 3; - - // Reports the physical pressure applied to the tip of the tool or the - // signal strength of the touch contact. - // - // The values reported must be non-zero when the tool is touching the - // device and zero otherwise to indicate that the touch event is - // completed. - // - // Make sure to deliver a pressure of 0 for the given identifier when - // the touch event is completed, otherwise the touch identifier will not - // be unregistered! - int32 pressure = 4; - - // Optionally reports the cross-sectional area of the touch contact, or - // the length of the longer dimension of the touch contact. - int32 touch_major = 5; - - // Optionally reports the length of the shorter dimension of the touch - // contact. This axis will be ignored if touch_major is reporting an - // area measurement greater than 0. - int32 touch_minor = 6; - - enum EventExpiration { - // The system will use the default time of 120s to track - // the touch event with the given identifier. If no update happens - // within this timeframe the identifier is considered expired - // and can be made available for re-use. This means that a touch event - // with pressure 0 for this identifier will be send to the emulator. - EVENT_EXPIRATION_UNSPECIFIED = 0; - - // Never expire the given slot. You must *ALWAYS* close the identifier - // by sending a touch event with 0 pressure. - NEVER_EXPIRE = 1; - } - - EventExpiration expiration = 7; -} - -// A TouchEvent contains a list of Touch objects that are in contact with -// the touch surface. -// -// Touch events are delivered in sequence as specified in the touchList. -// -// TouchEvents are delivered to the emulated devices using ["Protocol -// B"](https://www.kernel.org/doc/Documentation/input/multi-touch-protocol.txt) -message TouchEvent { - // The list of Touch objects, note that these do not need to be unique - repeated Touch touches = 1; - - // The display device where the touch event occurred. - // Omitting or using the value 0 indicates the main display. - // - // Touch events cannot be send to displays other than 0, due to - // https://issuetracker.google.com/issues/150699691 - int32 display = 2; -} - -// The MouseEvent interface represents events that occur due to the user -// interacting with a pointing device (such as a mouse). -message MouseEvent { - // The horizontal coordinate. This is the physical location on the - // screen For example 0 indicates the leftmost coordinate. - int32 x = 1; - - // The vertical coordinate. This is the physical location on the screen - // For example 0 indicates the top left coordinate. - int32 y = 2; - - // Indicates which buttons are pressed. - // 0: No button was pressed - // 1: Primary button (left) - // 2: Secondary button (right) - int32 buttons = 3; - - // The display device where the mouse event occurred. - // Omitting or using the value 0 indicates the main display. - int32 display = 4; -} - -// KeyboardEvent objects describe a user interaction with the keyboard; each -// event describes a single interaction between the user and a key (or -// combination of a key with modifier keys) on the keyboard. -// This follows the pattern as set by -// (javascript)[https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent] -// -// Note: that only keyCode, key, or text can be set and that the semantics -// will slightly vary. -message KeyboardEvent { - // Code types that the emulator can receive. Note that the emulator - // will do its best to translate the code to an evdev value that - // will be send to the emulator. This translation is based on - // the chromium translation tables. See - // (this)[https://android.googlesource.com/platform/external/qemu/+/refs/heads/emu-master-dev/android/android-grpc/android/emulation/control/keyboard/keycode_converter_data.inc] - // for details on the translation. - enum KeyCodeType { - Usb = 0; - Evdev = 1; - XKB = 2; - Win = 3; - Mac = 4; - } - - enum KeyEventType { - // Indicates that this keyevent should be send to the emulator - // as a key down event. Meaning that the key event will be - // translated to an EvDev event type and bit 11 (0x400) will be - // set before it is sent to the emulator. - keydown = 0; - - // Indicates that the keyevent should be send to the emulator - // as a key up event. Meaning that the key event will be - // translated to an EvDev event type and - // sent to the emulator. - keyup = 1; - - // Indicates that the keyevent will be send to the emulator - // as e key down event and immediately followed by a keyup event. - keypress = 2; - } - - // Type of keycode contained in the keyCode field. - KeyCodeType codeType = 1; - - // The type of keyboard event that should be sent to the emulator - KeyEventType eventType = 2; - - // This property represents a physical key on the keyboard (as opposed - // to the character generated by pressing the key). In other words, this - // property is a value which isn't altered by keyboard layout or the - // state of the modifier keys. This value will be interpreted by the - // emulator depending on the KeyCodeType. The incoming key code will be - // translated to an evdev code type and send to the emulator. - // The values in key and text will be ignored. - int32 keyCode = 3; - - // The value of the key pressed by the user, taking into consideration - // the state of modifier keys such as Shift as well as the keyboard - // locale and layout. This follows the w3c standard used in browsers. - // You can find an accurate description of valid values - // [here](https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key/Key_Values) - // - // Note that some keys can result in multiple evdev events that are - // delivered to the emulator. for example the Key "A" will result in a - // sequence: - // ["Shift", "a"] -> [0x2a, 0x1e] whereas "a" results in ["a"] -> [0x1e]. - // - // Not all documented keys are understood by android, and only printable - // ASCII [32-127) characters are properly translated. - // - // Keep in mind that there are a set of key values that result in android - // specific behavior - // [see](https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key/Key_Values#Phone_keys): - // - // - "AppSwitch": Behaves as the "Overview" button in android. - // - "GoBack": The Back button. - // - "GoHome": The Home button, which takes the user to the phone's main - // screen (usually an application launcher). - // - "Power": The Power button. - string key = 4; - - // Series of utf8 encoded characters to send to the emulator. An attempt - // will be made to translate every character will an EvDev event type and - // send to the emulator as a keypress event. The values in keyCode, - // eventType, codeType and key will be ignored. - // - // Note that most printable ASCII characters (range [32-127) can be send - // individually with the "key" param. Do not expect arbitrary UTF symbols to - // arrive in the emulator (most will be ignored). - // - // Note that it is possible to overrun the keyboard buffer by slamming this - // endpoint with large quantities of text (>1kb). The clipboard api is better - // suited for transferring large quantities of text. - string text = 5; -} - -message Fingerprint { - // True when the fingprint is touched. - bool isTouching = 1; - - // The identifier of the registered fingerprint. - int32 touchId = 2; -} - -message GpsState { - // Setting this to false will disable auto updating from the LocationUI, - // otherwise the location UI will override the location at a frequency of 1hz. - // - // - This is unused if the emulator is launched with -no-window, or when he - // location ui is disabled. - // - This will BREAK the location ui experience if it is set to false. For - // example routing will no longer function. - bool passiveUpdate = 1; - - // The latitude, in degrees. - double latitude = 2; - - // The longitude, in degrees. - double longitude = 3; - - // The speed if it is available, in meters/second over ground - double speed = 4; - - // gets the horizontal direction of travel of this device, and is not - // related to the device orientation. It is guaranteed to be in the - // range [0.0, 360.0] if the device has a bearing. 0=North, 90=East, - // 180=South, etc.. - double bearing = 5; - - // The altitude if available, in meters above the WGS 84 reference - // ellipsoid. - double altitude = 6; - - // The number of satellites used to derive the fix - int32 satellites = 7; -} - -message BatteryState { - enum BatteryStatus { - UNKNOWN = 0; - CHARGING = 1; - DISCHARGING = 2; - NOT_CHARGING = 3; - FULL = 4; - } - - enum BatteryCharger { - NONE = 0; - AC = 1; - USB = 2; - WIRELESS = 3; - } - - enum BatteryHealth { - GOOD = 0; - FAILED = 1; - DEAD = 2; - OVERVOLTAGE = 3; - OVERHEATED = 4; - } - - bool hasBattery = 1; - bool isPresent = 2; - BatteryCharger charger = 3; - int32 chargeLevel = 4; - BatteryHealth health = 5; - BatteryStatus status = 6; -} - -// An ImageTransport allows for specifying a side channel for -// delivering image frames versus using the standard bytes array that is -// returned with the gRPC request. -message ImageTransport { - enum TransportChannel { - // Return full frames over the gRPC transport - TRANSPORT_CHANNEL_UNSPECIFIED = 0; - - // Write images to the a file/shared memory handle. - MMAP = 1; - } - - // The desired transport channel used for delivering image frames. Only - // relevant when streaming screenshots. - TransportChannel channel = 1; - - // Handle used for writing image frames if transport is mmap. The client sets - // and owns this handle. It can be either a shm region, or a mmap. A mmap - // should be a url that starts with `file:///` - // Note: the mmap can result in tearing. - string handle = 2; -} - -// The aspect ratio (width/height) will be different from the one -// where the device is unfolded. -message FoldedDisplay { - uint32 width = 1; - uint32 height = 2; - // It is possible for the screen to be folded in different ways depending - // on which surface is shown to the user. So xOffset and yOffset indicate - // the top left corner of the folded screen within the original unfolded - // screen. - uint32 xOffset = 3; - uint32 yOffset = 4; -} - -message ImageFormat { - enum ImgFormat { - // Portable Network Graphics format - // (https://en.wikipedia.org/wiki/Portable_Network_Graphics) - PNG = 0; - - // Three-channel RGB color model supplemented with a fourth alpha - // channel. https://en.wikipedia.org/wiki/RGBA_color_model - // Each pixel consists of 4 bytes. - RGBA8888 = 1; - - // Three-channel RGB color model, each pixel consists of 3 bytes - RGB888 = 2; - } - - // The (desired) format of the resulting bytes. - ImgFormat format = 1; - - // [Output Only] The rotation of the image. The image will be rotated - // based upon the coarse grained orientation of the device. - Rotation rotation = 2; - - // The (desired) width of the image. When passed as input - // the image will be scaled to match the given - // width, while maintaining the aspect ratio of the device. - // The returned image will never exceed the given width, but can be less. - // Omitting this value (or passing in 0) will result in no scaling, - // and the width of the actual device will be used. - uint32 width = 3; - - // The (desired) height of the image. When passed as input - // the image will be scaled to match the given - // height, while maintaining the aspect ratio of the device. - // The returned image will never exceed the given height, but can be less. - // Omitting this value (or passing in 0) will result in no scaling, - // and the height of the actual device will be used. - uint32 height = 4; - - // The (desired) display id of the device. Setting this to 0 (or omitting) - // indicates the main display. - uint32 display = 5; - - // Set this if you wish to use a different transport channel to deliver image - // frames. - ImageTransport transport = 6; - - // [Output Only] Display configuration when screen is folded. The value is the - // original configuration before scaling. - FoldedDisplay foldedDisplay = 7; -} - -message Image { - ImageFormat format = 1; - - uint32 width = 2 [deprecated = true]; // width is contained in format. - uint32 height = 3 [deprecated = true]; // height is contained in format. - - // The organization of the pixels in the image buffer is from left to - // right and bottom up. This will be empty if an alternative image transport - // is requested in the image format. In that case the side channel should - // be used to obtain the image data. - bytes image = 4; - - // [Output Only] Monotonically increasing sequence number in a stream of - // screenshots. The first screenshot will have a sequence of 0. A single - // screenshot will always have a sequence number of 0. The sequence is not - // necessarily contiguous, and can be used to detect how many frames were - // dropped. An example sequence could be: [0, 3, 5, 7, 9, 11]. - uint32 seq = 5; - - // [Output Only] Unix timestamp in microseconds when the emulator estimates - // the frame was generated. The timestamp is before the actual frame is - // copied and transformed. This can be used to calculate variance between - // frame production time, and frame depiction time. - uint64 timestampUs = 6; -} - -message Rotation { - enum SkinRotation { - PORTRAIT = 0; // 0 degrees - LANDSCAPE = 1; // 90 degrees - REVERSE_PORTRAIT = 2; // -180 degrees - REVERSE_LANDSCAPE = 3; // -90 degrees - } - - // The rotation of the device, derived from the sensor state - // of the emulator. The derivation reflects how android observes - // the rotation state. - SkinRotation rotation = 1; - - // Specifies the angle of rotation, in degrees [-180, 180] - double xAxis = 2; - double yAxis = 3; - double zAxis = 4; -} - -message PhoneCall { - enum Operation { - InitCall = 0; - AcceptCall = 1; - RejectCallExplicit = 2; - RejectCallBusy = 3; - DisconnectCall = 4; - PlaceCallOnHold = 5; - TakeCallOffHold = 6; - } - Operation operation = 1; - string number = 2; -} - -message PhoneResponse { - enum Response { - OK = 0; - BadOperation = 1; // Enum out of range - BadNumber = 2; // Mal-formed telephone number - InvalidAction = 3; // E.g., disconnect when no call is in progress - ActionFailed = 4; // Internal error - RadioOff = 5; // Radio power off - } - Response response = 1; -} - -message Entry { - string key = 1; - string value = 2; -} - -message EntryList { - repeated Entry entry = 1; -} - -message EmulatorStatus { - // The emulator version string. - string version = 1; - - // The time the emulator has been active in .ms - uint64 uptime = 2; - - // True if the device has completed booting. - // For P and later this information will accurate, - // for older images we rely on adb. - bool booted = 3; - - // The current vm configuration - VmConfiguration vmConfig = 4; - - // The hardware configuration of the running emulator as - // key valure pairs. - EntryList hardwareConfig = 5; -} - -message AudioFormat { - enum SampleFormat { - AUD_FMT_U8 = 0; // Unsigned 8 bit - AUD_FMT_S16 = 1; // Signed 16 bit (little endian) - } - - enum Channels { - Mono = 0; - Stereo = 1; - } - - // Sampling rate to use, defaulting to 44100 if this is not set. - // Note, that android devices typically will not use a sampling - // rate higher than 48kHz. See https://developer.android.com/ndk/guides/audio. - uint64 samplingRate = 1; - Channels channels = 2; - SampleFormat format = 3; -} - -message AudioPacket { - AudioFormat format = 1; - - // Unix epoch in us when this frame was captured. - uint64 timestamp = 2; - - // Contains a sample in the given audio format. - bytes audio = 3; -} - -message SmsMessage { - // The source address where this message came from. - // - // The address should be a valid GSM-formatted address as specified by - // 3GPP 23.040 Sec 9.1.2.5. - // - // For example: +3106225412 or (650) 555-1221 - string srcAddress = 1; - - // A utf8 encoded text message that should be delivered. - string text = 2; -} - -// A DisplayConfiguration describes a primary or secondary -// display available to the emulator. The screen aspect ratio -// cannot be longer (or wider) than 21:9 (or 9:21). Screen sizes -// larger than 4k will be rejected. -// -// Common configurations (w x h) are: -// - 480p (480x720) 142 dpi -// - 720p (720x1280) 213 dpi -// - 1080p (1080x1920) 320 dpi -// - 4K (2160x3840) 320 dpi -// - 4K (2160x3840) 640 dpi (upscaled) -// -// The behavior of the virtual display depends on the flags that are provided to -// this method. By default, virtual displays are created to be private, -// non-presentation and unsecure. -message DisplayConfiguration { - // These are the set of known android flags and their respective values. - // you can combine the int values to (de)construct the flags field below. - enum DisplayFlags { - DISPLAYFLAGS_UNSPECIFIED = 0; - - // When this flag is set, the virtual display is public. - // A public virtual display behaves just like most any other display - // that is connected to the system such as an external or wireless - // display. Applications can open windows on the display and the system - // may mirror the contents of other displays onto it. see: - // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_PUBLIC - VIRTUAL_DISPLAY_FLAG_PUBLIC = 1; - - // When this flag is set, the virtual display is registered as a - // presentation display in the presentation display category. - // Applications may automatically project their content to presentation - // displays to provide richer second screen experiences. - // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_PRESENTATION - VIRTUAL_DISPLAY_FLAG_PRESENTATION = 2; - - // When this flag is set, the virtual display is considered secure as - // defined by the Display#FLAG_SECURE display flag. The caller promises - // to take reasonable measures, such as over-the-air encryption, to - // prevent the contents of the display from being intercepted or - // recorded on a persistent medium. - // see: - // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_SECURE - VIRTUAL_DISPLAY_FLAG_SECURE = 4; - - // This flag is used in conjunction with VIRTUAL_DISPLAY_FLAG_PUBLIC. - // Ordinarily public virtual displays will automatically mirror the - // content of the default display if they have no windows of their own. - // When this flag is specified, the virtual display will only ever show - // its own content and will be blanked instead if it has no windows. See - // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_OWN_CONTENT_ONLY - VIRTUAL_DISPLAY_FLAG_OWN_CONTENT_ONLY = 8; - - // Allows content to be mirrored on private displays when no content is - // being shown. - // This flag is mutually exclusive with - // VIRTUAL_DISPLAY_FLAG_OWN_CONTENT_ONLY. If both flags are specified - // then the own-content only behavior will be applied. - // see: - // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_AUTO_MIRROR) - VIRTUAL_DISPLAY_FLAG_AUTO_MIRROR = 16; - } - - // The width of the display, restricted to: - // 320 * (dpi / 160) <= width - uint32 width = 1; - - // The heigh of the display, restricted to: - // * 320 * (dpi / 160) <= height - uint32 height = 2; - - // The pixel density (dpi). - // See https://developer.android.com/training/multiscreen/screendensities - // for details. This value should be in the range [120, ..., 640] - uint32 dpi = 3; - - // A combination of virtual display flags. These flags can be constructed - // by combining the DisplayFlags enum described above. - // - // The behavior of the virtual display depends on the flags. By default - // virtual displays are created to be private, non-presentation and - // unsecure. - uint32 flags = 4; - - // The id of the display. - // The primary (default) display has the display ID of 0. - // A secondary display has a display ID not 0. - // - // The id can be used to get or stream a screenshot. - uint32 display = 5; -} - -message DisplayConfigurations { - repeated DisplayConfiguration displays = 1; -} - -message Notification { - enum EventType { - VIRTUAL_SCENE_CAMERA_INACTIVE = 0; - VIRTUAL_SCENE_CAMERA_ACTIVE = 1; - - // Fired when an update to a display event has been fired through - // the extended ui. This does not fire events when the display - // is changed through the console or gRPC endpoint. - DISPLAY_CONFIGURATIONS_CHANGED_UI = 2; - // Keep adding more for other event types - } - - EventType event = 1; -} - -message RotationRadian { - float x = 1; // x axis is horizontal and orthogonal to the view direction. - float y = 2; // y axis points up and is perpendicular to the floor. - float z = 3; // z axis is the view direction and is set to 0.0 in - // rotateVirtualSceneCamera call. -} - -message Velocity { - float x = 1; // x axis is horizontal and orthogonal to the view direction. - float y = 2; // y axis points up and is perpendicular to the floor. - float z = 3; // z axis is the view direction -} - -// must follow the definition in "external/qemu/android/hw-sensors.h" -message Posture { - enum PostureValue { - POSTURE_UNKNOWN = 0; - POSTURE_CLOSED = 1; - POSTURE_HALF_OPENED = 2; - POSTURE_OPENED = 3; - POSTURE_FLIPPED = 4; - POSTURE_TENT = 5; - POSTURE_MAX = 6; - } - PostureValue value = 3; -} diff --git a/android_env/proto/snapshot.proto b/android_env/proto/snapshot.proto deleted file mode 100644 index 0b0abc12..00000000 --- a/android_env/proto/snapshot.proto +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright (C) 2018 The Android Open Source Project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto2"; - -// This file must be synchronized between -// Emulator (branch aosp/emu-master-dev): -// external/qemu/android/android-emu/android/snapshot/proto/snapshot.proto -// -// Android Studio (branch goog/studio-master-dev): -// tools/adt/idea/android/src/com/android/emulator/snapshot.proto -// -// If you modify one, please modify the other. - -package emulator_snapshot; - -option java_package = "com.android.emulator.snapshot"; - -message Image { - enum Type { - IMAGE_TYPE_UNKNOWN = 0; - IMAGE_TYPE_KERNEL = 1; - IMAGE_TYPE_KERNEL_RANCHU = 2; - IMAGE_TYPE_SYSTEM = 3; - IMAGE_TYPE_SYSTEM_COPY = 4; - IMAGE_TYPE_DATA = 5; - IMAGE_TYPE_DATA_COPY = 6; - IMAGE_TYPE_RAMDISK = 7; - IMAGE_TYPE_SDCARD = 8; - IMAGE_TYPE_CACHE = 9; - IMAGE_TYPE_VENDOR = 10; - IMAGE_TYPE_ENCRYPTION_KEY = 11; - } - - optional Type type = 1; - optional string path = 2; - optional bool present = 3; - optional int64 size = 4; - optional int64 modification_time = 5; -} - -message Host { - optional string gpu_driver = 4; - optional int32 hypervisor = 5; -} - -message Config { - // Features are int32, not enums here to make sure we don't have to update - // one more protobuf definition with every single new feature flag, even - // when the code doesn't really care about the actual meaning for them, - // only for the values. - repeated int32 enabled_features = 1; - - // This holds the renderer; int32 for the same reason as |enabled_features|. - optional int32 selected_renderer = 2; - - optional int32 cpu_core_count = 3; - optional int64 ram_size_bytes = 4; -} - -message SaveStats { - // Type of save - // 0: non-incremental - // 1: incremental - optional uint32 incremental = 1; - // Time taken to save. - optional uint64 duration = 2; - // How many changed bytes in RAM. - optional uint64 ram_changed_bytes = 3; -} - -message Snapshot { - // Update every time when introducing some breaking changes that make the - // previous loading code break when trying to load the new snapshot. - // NOTE: if the old code is fine with just skipping the new fields or not - // getting the meaning of new values, |version| should remain - // unchanged. - optional int32 version = 1; - - // Purely informative: when this snapshot was created, Unix timestamp. - optional int64 creation_time = 2; - - // list of mounted disk images used during the snapshot creation. - repeated Image images = 3; - - // Description of the host machine properties needed to load this snapshot. - optional Host host = 4; - - // Description of the emulator configuration needed for this snapshot. - // NOTE: try not to duplicate the configuration that's already in - // hardware-qemu.ini; only add what's either not there or what - // could've been overridden during process initialization. - optional Config config = 5; - - // Set if the snapshot failed to load during the last attempt. - // Code is up to the application to define, with 0 meaning 'not failed' just - // in case. - optional int64 failed_to_load_reason_code = 7; - - // Set if data image is mounted. - // User build and userdebug build mount data partition at different time. - // But it should be done before boot finished, so this field is very likely - // to be true. - // We snapshot it here just in case someday we support snapshot during - // booting. - optional bool guest_data_partition_mounted = 8; - - // Emulator rotation angle, in right angles (e.g. 1 is 90 degrees, 2 is 180 - // etc). - optional int32 rotation = 9; - - // Number of invalid loads / crashes that happened under this snapshot. - optional int32 invalid_loads = 10; - - // Number of successful loads. - optional int32 successful_loads = 11; - - // The name given to the snapshot by the user. Independent of the - // file name. - optional string logical_name = 12; - - // The file name of this snapshot's parent. The parent is the - // snapshot that was loaded into the AVD prior to this snapshot - // being taken - optional string parent = 13; - - // Arbitrary description added by the user - optional string description = 14; - - // Record of save stats. - repeated SaveStats save_stats = 15; - - // Folded state. - optional bool folded = 16; - - // Emulator boot parameters - repeated string launch_parameters = 17; - - // Emulator build ID - optional string emulator_build_id = 18; - - // System image build ID - optional string system_image_build_id = 19; -} diff --git a/android_env/proto/snapshot_service.proto b/android_env/proto/snapshot_service.proto deleted file mode 100644 index 1cbf72ec..00000000 --- a/android_env/proto/snapshot_service.proto +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright (C) 2018 The Android Open Source Project -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Note that if you add/remove methods in this file you must update -// the metrics sql as well by running ./android/scripts/gen-grpc-sql.py -// -// Please group deleted methods in a block including the date (MM/DD/YY) -// it was removed. This enables us to easily keep metrics around after removal -// -// list of deleted methods -// rpc iWasDeleted (03/12/12) -// ... -syntax = "proto3"; - -package android.emulation.control; - -import "android_env/proto/snapshot.proto"; - -option java_multiple_files = true; -option java_package = "com.android.emulator.control"; -option objc_class_prefix = "AEC"; - -// The SnapshotService enables you to list, insert, store, and retrieve -// snapshots. -// -// Currently there are two types of snapshots: -// -// - Local (default): These are snapshots that are created locally. They are -// stored internally inside qcow2 files and are very efficient. These are -// the snapshots usually created by interacting with the UI. -// -// - Remote: These are snapshots that have been exported at a certain point. -// an exported snapshot is normalized (completely self contained) and -// can be imported into an emulator with a similar hardware configuration. -// -// Currently the emulator has limited support for importing snapshots: -// - Once an imported snapshot has been loaded into an emulator it is no longer -// possible to create new snapshots. -// - The hardware configuration of the emulator your are pushing a snapshot to -// must match (or be very similar) to the one you pulled the snapshot from. -// -// For example do not expect to be able to restore a snapshot on created on an -// Intel cpu on an AMD cpu. -service SnapshotService { - // Lists all the snapshots, filtered by the given query, that are stored - // locally for the currently running avd. This includes all the snapshots that - // were imported (pushed) into this emulator. - // - // Returns a list of snapshot_id's and associated details that describes - // the hardware configuration, logical name, etc of the snapshot. - rpc ListSnapshots(SnapshotFilter) returns (SnapshotList) {} - - // Pulls down the snapshot stored inside the AVD as a tar.gz/tar stream - // This will normalize the snapshot, all relevant data to push a snapshot - // into a similar emulator will be placed inside the tar file. - // - // Pulling down a snapshot will pause the emulator until the snapshots - // are rebased and ready for exporting. Once the snapshot is rebased - // the emulator will continue and downloading should commence. - // - // Note that pulling .gz stream is slow. - // - // You must provide the snapshot_id and (desired) format. - // - // If SnapshotPackage.path is set, the gRPC service will directly write the - // exported snapshot to SnapshotPackage.path without streaming, which is - // usually significantly faster. It would require emulator to have direct - // access to SnapshotPackage.path, which usually means it can only be used - // when pulling from a local emulator. - rpc PullSnapshot(SnapshotPackage) returns (stream SnapshotPackage) {} - - // Push a tar.gz stream contain the snapshot. The tar file should - // be a snapshot that was exported through the PullSnapshot in the past. - // The emulator will try to import the snapshot. The hardware configuration - // of the current emulator should match the one used for pulling. - // - // A detailed description of the snapshot (emulator_snapshot.Snapshot) - // is stored in the snapshot.pb file inside the tar. - // - // You must provide the snapshot_id and format in the first message. - // Will return success and a possible error message when a failure occurs. - // - // If SnapshotPackage.path is set, the gRPC service will directly unzip the - // exported snapshot from SnapshotPackage.path without streaming, which is - // usually significantly faster. It would require emulator to have direct - // access to SnapshotPackage.path, which usually means it can only be used - // when pushing to a local emulator. - rpc PushSnapshot(stream SnapshotPackage) returns (SnapshotPackage) {} - - // Loads the given snapshot inside the emulator and activates it. - // The device will be in the state as it was when the snapshot was created. - // - // You will no longer be able to call Save if this was an imported - // snapshot that was pushed into this emulator. - // - // You must provide the snapshot_id to indicate which snapshot to load - // Will return success and a possible error message when a failure occurs. - rpc LoadSnapshot(SnapshotPackage) returns (SnapshotPackage) {} - - // Creates as a snapshot of the current state of the emulator. - // You can only save a snapshot if you never activated (Load) an imported - // snapshot (Push). - // - // For example: - // - PushSnapshot("some_snap.tar.gz"); - // - LoadSnapshot("some_snap"); - // - SaveSnapshot("same_newer_snap"); // <--- Will currently fail. - // - // You can provide the snapshot_id to indicate the name used for storing. - // Will return success and a possible error message when a failure occurs. - rpc SaveSnapshot(SnapshotPackage) returns (SnapshotPackage) {} - - // Deletes the snapshot with the given snapshot_id from the avd. - // - // You must provide the snapshot_id to indicate which snapshot to delete. - // Will return success and a possible error message when a failure occurs. - rpc DeleteSnapshot(SnapshotPackage) returns (SnapshotPackage) {} - - // Tracks the given process for automated snapshot creation in case of - // assert failures. - // - // Will return success and a possible error message when a failure occurs. - // The snapshot_id field will contain the name of the snapshot that - // will be created. The pid field will contain the process id that is - // being tracked. - rpc TrackProcess(IceboxTarget) returns (IceboxTarget) {} -} - -// Sets options for SnapshotService. Used for both request and response -// messages. -message SnapshotPackage { - enum Format { - TARGZ = 0; - TAR = 1; - DIRECTORY = 2; - } - // The identifier to the snapshot, only required for request messages. For - // streaming service, only used in the first stream message of a gRPC call - // (would be ignored in consequent stream messages of the same call). - string snapshot_id = 1; - - // A stream of bytes. Encoded as a tar (possibly gzipped) file pendinf on the - // value of format. - bytes payload = 2; - - // [response only] status fields, usually indicates end of transmission. - bool success = 3; - bytes err = 4; - - // [request only] Format of the payload. Only used in request messages. For - // streaming service, only used in the first stream message of a gRPC call - // (would be ignored in consequent stream messages of the same call). - Format format = 5; - - // [request only] Path to the snapshot package file. Only used in request - // messages. - // - // When set in a request, the PullSnapshot/PushSnapshot operation will - // directly write/read the exported snapshot in path without streaming, which - // is usually significantly faster. It would require emulator to have direct - // access to path, which usually means it can only be used with a local - // emulator. - string path = 6; -} - -// A snapshot filter can be used to filter the results produced by ListSnapshots -message SnapshotFilter { - enum LoadStatus { - // Only return compatible snapshots - CompatibleOnly = 0; - - // Return all snapshots. - All = 1; - } - - // Filter snapshots by load status. - LoadStatus statusFilter = 1; -} - -// Provides detailed information regarding the snapshot. -message SnapshotDetails { - enum LoadStatus { - // The emulator believes that the snapshot is compatible with the emulator - // that provided this information. The emulator will attempt to load this - // snapshot when requested. - // - // A snapshot is usually compatible when the following statements are true: - // - The snapshot was taken by the current emulator version. i.e. - // emulator_build_id in the details field matches the build_id of the - // emulator that provided this information. - // - // - The snapshot was taken on the current running machine, and no hardware - // changes have taken place between taking and loading the snapshot. - // - // - The avd configuration has not changed between when this snapshot was - // taken and when the snapshot was loaded. - // - // - The system images on which the avd is based have not changed. - Compatible = 0; - - // The emulator will not allow loading of the snapshot, as it deems the - // snapshot to be incompatible. Loading of snapshots can be forced by - // launching the emulator with the feature "AllowSnapshotMigration" enabled. - Incompatible = 1; - - // This snapshot was successfully loaded in the emulator, and was used at - // the starting point of the current running emulator. The following holds: - // - // A loaded snapshot is a compatible snapshot - // There is at most one snapshot_id that is in the "Loaded" state - Loaded = 2; - } - - // The id of this snapshot. Use this id to load/delete/pull the - // snapshot. - string snapshot_id = 1; - - // Detailed information about this snapshot. This contains a detailed - // hardware description of the snapshot. These details are the same - // as the "snapshot.pb" file found in an exported snapshot. - // Look at the import file for a detailed description of the available - // fields. - emulator_snapshot.Snapshot details = 2; - - // Provides information about the ability to restore this snapshot. - LoadStatus status = 3; - - // The size of the folder that stores required information to load a snapshot. - uint64 size = 4; -} - -// A list of on snapshot details. -message SnapshotList { - repeated SnapshotDetails snapshots = 1; -} - -message IceboxTarget { - // This is the process id to attach to, if this value is not set (0) - // The process name will be used instead. - int64 pid = 1; - - // The process name to attach to if any, if this is not set the pid will - // be used. This is usually the application name of your application under - // test, that is passed in to the am instrument command. It is likely - // what you will find in your AndroidManifest.xml - string package_name = 2; - - // The name of the snapshot that icebox will create if a snapshot is - // generated. - string snapshot_id = 3; - - // [Output Only] True if icebox failed to track the given target. - bool failed = 4; - - // [Output Only] Detailed error message that might provide more information. - string err = 5; - - // Maximum number of snapshots the emulator can take during one Icebox run. - // Set to -1 for unlimited number of snapshots. - int32 max_snapshot_number = 6; -} - -// list of deleted methods: -// diff --git a/android_env/proto/state.proto b/android_env/proto/state.proto deleted file mode 100644 index c562fe52..00000000 --- a/android_env/proto/state.proto +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package android_env; - -option java_multiple_files = true; - -message SaveStateRequest { - map args = 1; -} - -message LoadStateRequest { - map args = 1; -} - -message SaveStateResponse { - enum Status { - // Reserved value for unset statuses. - UNDEFINED = 0; - // Returned when everything goes well. - OK = 1; - // Returned when something internal did not work as expected. - ERROR = 2; - } - Status status = 1; - // `error_message` is only populated in case of errors. - string error_message = 2; - - // Any additional info returned during the request; e.g., file paths or sizes. - map additional_info = 3; -} - -message LoadStateResponse { - enum Status { - // Reserved value for unset statuses. - UNDEFINED = 0; - // Returned when everything goes well. - OK = 1; - // Returned when there is no state to load. - NOT_FOUND = 2; - // Returned when something internal did not work as expected. - ERROR = 3; - } - Status status = 1; - // `error_message` is only populated in case of errors. - string error_message = 2; - - // Any additional info returned during the request; e.g., file paths or sizes. - map additional_info = 3; -} \ No newline at end of file diff --git a/android_env/proto/task.proto b/android_env/proto/task.proto deleted file mode 100644 index b39fd3e7..00000000 --- a/android_env/proto/task.proto +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright 2024 DeepMind Technologies Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package android_env; - -import "android_env/proto/adb.proto"; - -// An AppScreen identifies a unique configuration that we can observe on the -// screen of a device. -message AppScreen { - // Fully-qualified name of the activity. - string activity = 1; - - // A list of regexes to match at each level of the current view hierarchy. - // The environment uses this list to determine whether the agent has "exited" - // this current task. - // Example: [ - // "^DecorView@.*\[MainActivity\]$", - // "^android.widget.LinearLayout\{.*\}$", - // "^android.widget.FrameLayout\{.*android\:id\/content\}", - // "^android.widget.RelativeLayout\{.*\}", - // "^android.widget.FrameLayout\{.*app\:id\/fragment_holder\}", - // "^android.widget.RelativeLayout\{.*\}", - // "^com.google.example.games.nostalgicracer.views.RaceView3D\{.*app\:id\/gameplay_screen_3d\}", - // ], - repeated string view_hierarchy_path = 2; -} - -// Waits for `app_screen` to be the current app screen shown to the user. -message WaitForAppScreen { - AppScreen app_screen = 1; - // Maximum time in seconds to wait for the activity to become the current one. - float timeout_sec = 2; -} - -message CheckInstall { - string package_name = 1; - // Maximum time in seconds to wait. - float timeout_sec = 2; -} - -message Sleep { - float time_sec = 1; -} - -message SuccessCondition { - int32 num_retries = 1; - - oneof check { - WaitForAppScreen wait_for_app_screen = 2; - CheckInstall check_install = 3; - } -} - -message SetupStep { - SuccessCondition success_condition = 1; - - oneof step { - AdbRequest adb_request = 2; - Sleep sleep = 3; - } -} - -// A specification of structured observations -// Analogous to dm_env.specs.Array() - -message ArraySpec { - // An identifier for this ArraySpec. - string name = 1; - - // The shape of the multi-dimensional values associated with this ArraySpec, - repeated int32 shape = 2; - - enum DataType { - INVALID_DATA_TYPE = 0; - FLOAT = 1; - DOUBLE = 2; - INT8 = 3; - INT16 = 4; - INT32 = 5; - INT64 = 6; - UINT8 = 7; - UINT16 = 8; - UINT32 = 9; - UINT64 = 10; - BOOL = 11; - STRING_U1 = 12; - STRING_U16 = 13; - STRING_U25 = 14; - STRING_U250 = 15; - STRING = 16; // String without max length - OBJECT = 17; - } - - // Data type of elements we expect to see in an array of this spec. - DataType dtype = 3; -} - -message LogParsingConfig { - // `filters` are tags used by the app's logging system so that we can - // identify them in logcat's output. It's the first argument to logging calls - // such as Log.e("ActivityManager", "My message"). - // Example: "ActivityManager" - repeated string filters = 1; - - // Regular expressions that define how we can extract RL information such as - // score, extras and episode end from raw logcat messages. - message LogRegexps { - // Regexp expected to match: - // ...a floating point value which gets accumulated over time. - // A delta in 'score' corresponds to the reward. - string score = 1; - - // Regexp expected to match: - // ...a floating point value directly forwarded by the environment. - repeated string reward = 2; - - // Regexp expected to match: - // ...a signal marking the end of an episode. - repeated string episode_end = 3; - - // Regexp expected to match: - // ...a string representing pairs of extra names and values. - repeated string extra = 4; - - // Regexp expected to match: - // ...a dict of extra names and values in json format. - repeated string json_extra = 5; - - // Attaches rewards to arbitrary log messages, for example: - // {event: "coin_collected" reward: 2.3} - // {event: "car_crashed" reward: -1.4} - message RewardEvent { - // If `event` is matched, the environment will give `reward`. - string event = 1; - - // Numerical value to give as reward if `event` is matched. - float reward = 2; - } - - repeated RewardEvent reward_event = 6; - } - - LogRegexps log_regexps = 2; -} - -// Description of a reinforcement learning task to be solved by an agent. -message Task { - // A globally unique identifier for this task. - string id = 1; - - // A human readable name for this task. - string name = 2; - - // A description of the task. - string description = 3; - - repeated SetupStep setup_steps = 4; - repeated SetupStep reset_steps = 5; - - AppScreen expected_app_screen = 6; - - // AndroidEnv resets the episode after `max_episode_sec` is passed since the - // last reset(). Recommended for time sensitive tasks (e.g. reactive games). - // Note that this is real time as measured by AndroidEnv and is independent of - // the speed of simulation of Android. - // If <= 0.0, this logic is disabled. - float max_episode_sec = 7; - - // The maximum number of interactions in a single episode between the - // environment and an agent. - // This setting is appropriate for tasks that are not time-dependent or when - // the performance of the simulation varies dramatically between runs. - // If <= 0, this logic is disabled. - int32 max_episode_steps = 8; - - // Defines parameters for parsing messages from logcat. - LogParsingConfig log_parsing_config = 9; - - // NOTE: This field is deprecated and will be removed from this Task - // definition soon. - // - // (Optional): The task may also define extras to help the RL agent. - // An Extra in AndroidEnv is any information that apps may send to aid the - // understanding of the task. The type of information sent through this - // channel is usually something difficult to obtain from raw pixels and may - // include things such as: - // - // - The current board configuration (e.g. of a chess game or a tetris game) - // - The position of the avatar in a map - // - Events (e.g. whether a button was pressed or a checkpoint was achieved) - // - // Notice that these are entirely optional and may not be available at all. - // This specification ensures that only extras specified in the Task - // definition will be passed to the agent, everything else is excluded. - // The name of an extra must be unique across all extras. - repeated ArraySpec extras_spec = 10; -} diff --git a/android_env/wrappers/__init__.py b/android_env/wrappers/__init__.py deleted file mode 100644 index 2f66bf75..00000000 --- a/android_env/wrappers/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/android_env/wrappers/a11y/__init__.py b/android_env/wrappers/a11y/__init__.py deleted file mode 100644 index 2f66bf75..00000000 --- a/android_env/wrappers/a11y/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/android_env/wrappers/a11y/a11y_events.py b/android_env/wrappers/a11y/a11y_events.py deleted file mode 100644 index 2df3e855..00000000 --- a/android_env/wrappers/a11y/a11y_events.py +++ /dev/null @@ -1,118 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tools for accessing accessibility events.""" - -from collections.abc import Mapping -from typing import Any - -from absl import logging -from android_env.proto.a11y import a11y_pb2 -import numpy as np - -from google.protobuf import any_pb2 - - -_A11Y_EVENT_KEY = 'full_event' - - -def package_events_to_task_extras( - events: list[a11y_pb2.EventRequest], -) -> Mapping[str, np.ndarray]: - if not events: - return {} - events = np.stack(events, axis=0) - return {_A11Y_EVENT_KEY: events} - - -def extract_events_from_task_extras( - task_extras: Mapping[str, Any] | None = None, -) -> list[Mapping[str, str]]: - """Inspects task_extras and extracts all accessibility events detected. - - Args: - task_extras: Task extras forwarded by AndroidEnv. If 'full_event' is not a - key in task_extras, then this function returns an empty string. Otherwise, - full_event is expected to be list to be a numpy array with one dimension, - and contains a list of dictionary describing accessibility events that are - present in the given task extras. e.g. 'event_type: - TYPE_WINDOW_CONTENT_CHANGED // event_package_name: - com.google.android.deskclock // source_class_name: - android.widget.ImageView'. - - Returns: - List of all events detected - """ - if task_extras is None or _A11Y_EVENT_KEY not in task_extras: - return [] - - if ( - not isinstance(task_extras[_A11Y_EVENT_KEY], np.ndarray) - or task_extras[_A11Y_EVENT_KEY].ndim != 1 - ): - raise ValueError( - f'{_A11Y_EVENT_KEY} task extra should be a numpy array with one' - ' dimension.' - ) - - if task_extras[_A11Y_EVENT_KEY].size == 0: - return [] - - events = [] - for e in task_extras[_A11Y_EVENT_KEY]: - if isinstance(e, a11y_pb2.EventRequest): - events.append(dict(e.event)) - elif isinstance(e, dict): - events.append(e) - logging.warning( - 'The event should come only from the a11y_grpc_wrapper. ' - 'Please verify that the upacking operation has not been ' - 'called twice. See here for full task_extras: %s', - task_extras, - ) - elif isinstance(e, any_pb2.Any): - ev = a11y_pb2.EventRequest() - new_any = any_pb2.Any() - new_any.CopyFrom(e) - new_any.Unpack(ev) - events.append(dict(ev.event)) - - else: - raise TypeError( - f'Unexpected event type: {type(e)}. See here for full ' - f'task_extras: {task_extras}.' - ) - - return events - - -def keep_latest_event_only(task_extras: dict[str, Any]): - """Removes all a11y events except the last one observed.""" - if task_extras is None or 'full_event' not in task_extras: - return - - if ( - not isinstance(task_extras[_A11Y_EVENT_KEY], np.ndarray) - or task_extras[_A11Y_EVENT_KEY].ndim != 1 - ): - raise ValueError( - f'{_A11Y_EVENT_KEY} task extra should be a numpy array with one' - ' dimension.' - ) - - if task_extras[_A11Y_EVENT_KEY].size == 0: - return [] - - task_extras[_A11Y_EVENT_KEY] = task_extras[_A11Y_EVENT_KEY][-1:] diff --git a/android_env/wrappers/a11y/a11y_events_test.py b/android_env/wrappers/a11y/a11y_events_test.py deleted file mode 100644 index 400fd801..00000000 --- a/android_env/wrappers/a11y/a11y_events_test.py +++ /dev/null @@ -1,173 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for a11y_events.""" - -from absl.testing import absltest -from absl.testing import parameterized -from android_env.proto.a11y import a11y_pb2 -from android_env.wrappers.a11y import a11y_events -import numpy as np - -from google.protobuf import any_pb2 - - -def _event_request(d: dict[str, str]) -> a11y_pb2.EventRequest: - event_request = a11y_pb2.EventRequest() - for k, v in d.items(): - event_request.event[k] = v - return event_request - - -def _event_request_as_any(d: dict[str, str]) -> any_pb2.Any: - event_request = _event_request(d) - response = any_pb2.Any() - response.Pack(event_request) - return response - - -class A11yEventsTest(parameterized.TestCase): - - @parameterized.parameters( - dict(task_extras={}), - dict( - task_extras={'no_full_event': [{'1': '1'}, {'2': '2'}, {'3': '3'}]}, - ), - dict( - task_extras={'full_event': np.array([])}, - ), - dict( - task_extras={}, - ), - ) - def test_no_events_in_task_extras(self, task_extras): - events = a11y_events.extract_events_from_task_extras(task_extras) - self.assertEmpty(events) - - @parameterized.parameters( - dict( - task_extras={'full_event': [{'1': '1'}, {'2': '2'}]}, - expected_events=[{'1': '1'}, {'2': '2'}], - ), - dict( - task_extras={'full_event': [{}]}, - expected_events=[{}], - ), - dict( - task_extras={ - 'full_event_wrong_key': [1, 2, 3], - 'full_event': [{'1': '1'}, {'2': '2'}, {'3': '3'}], - }, - expected_events=[{'1': '1'}, {'2': '2'}, {'3': '3'}], - ), - ) - def test_task_extras(self, task_extras, expected_events): - event_requests = [_event_request(e) for e in task_extras['full_event']] - task_extras['full_event'] = np.stack(event_requests, axis=0) - events = a11y_events.extract_events_from_task_extras(task_extras) - self.assertEqual(len(events), len(expected_events)) - for i, event in enumerate(expected_events): - self.assertEqual(len(event), len(expected_events[i])) - for k, v in event.items(): - self.assertIn(k, expected_events[i]) - self.assertEqual(v, expected_events[i][k]) - - def test_events_key_has_dict_event_requrests(self): - event_requests = [ - _event_request({'1': '1'}), - {'2': '2'}, - _event_request({'3': '3'}), - ] - expected_events = [ - {'1': '1'}, - {'2': '2'}, - {'3': '3'}, - ] - task_extras = {'full_event': np.stack(event_requests, axis=0)} - events = a11y_events.extract_events_from_task_extras(task_extras) - self.assertEqual(len(events), len(expected_events)) - for i, event in enumerate(expected_events): - self.assertEqual(len(event), len(expected_events[i])) - for k, v in event.items(): - self.assertIn(k, expected_events[i]) - self.assertEqual(v, expected_events[i][k]) - - def test_events_key_has__event_requrests_packed_as_any(self): - event_requests = [ - _event_request_as_any({'1': '1'}), - {'2': '2'}, - _event_request_as_any({'3': '3'}), - ] - expected_events = [ - {'1': '1'}, - {'2': '2'}, - {'3': '3'}, - ] - task_extras = {'full_event': np.stack(event_requests, axis=0)} - events = a11y_events.extract_events_from_task_extras(task_extras) - self.assertEqual(len(events), len(expected_events)) - for i, event in enumerate(expected_events): - self.assertEqual(len(event), len(expected_events[i])) - for k, v in event.items(): - self.assertIn(k, expected_events[i]) - self.assertEqual(v, expected_events[i][k]) - - def test_events_key_has_non_event_requrests(self): - event_requests = [ - _event_request({'1': '1'}), - 3, # Not an even and not a dict. - _event_request({'3': '3'}), - ] - task_extras = {'full_event': np.stack(event_requests, axis=0)} - with self.assertRaises(TypeError): - _ = a11y_events.extract_events_from_task_extras(task_extras) - - @parameterized.parameters( - dict(task_extras={}, expected_extras={}), - dict( - task_extras={ - 'no_full_event': 42, - }, - expected_extras={ - 'no_full_event': 42, - }, - ), - dict( - task_extras={'full_event': np.array([1, 2]), 'no_full_event': 43}, - expected_extras={'full_event': np.array([2]), 'no_full_event': 43}, - ), - dict( - task_extras={'full_event': np.array([1, 2, 3])}, - expected_extras={'full_event': np.array([3])}, - ), - dict( - task_extras={'full_event': np.array([]), 'no_full_event': 44}, - expected_extras={'full_event': np.array([]), 'no_full_event': 44}, - ), - ) - def test_keep_latest_only(self, task_extras, expected_extras): - a11y_events.keep_latest_event_only(task_extras) - self.assertEqual(len(task_extras), len(expected_extras)) - for k, v in task_extras.items(): - self.assertIn(k, expected_extras) - if k == 'full_event': - np.testing.assert_array_equal(v, expected_extras['full_event']) - else: - self.assertEqual(v, expected_extras[k]) - pass - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/wrappers/a11y/a11y_forests.py b/android_env/wrappers/a11y/a11y_forests.py deleted file mode 100644 index 1cd8ef2d..00000000 --- a/android_env/wrappers/a11y/a11y_forests.py +++ /dev/null @@ -1,128 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tools for accessing accessibility events.""" - -from collections.abc import Mapping -from typing import Any - -from android_env.proto.a11y import android_accessibility_forest_pb2 -import numpy as np - -from google.protobuf import any_pb2 - - -_A11Y_FORESTS_KEY = 'accessibility_tree' - - -def package_forests_to_task_extras( - forests: list[android_accessibility_forest_pb2.AndroidAccessibilityForest], -) -> Mapping[str, np.ndarray]: - if not forests: - return {} - forests = np.stack(forests, axis=0) - return {_A11Y_FORESTS_KEY: forests} - - -def task_extras_has_forests(task_extras: Mapping[str, Any]) -> bool: - """Checks that the task_extras has any a11y forest information.""" - if _A11Y_FORESTS_KEY not in task_extras: - return False - - payload = task_extras[_A11Y_FORESTS_KEY] - if not isinstance(payload, np.ndarray) or payload.ndim != 1: - raise ValueError( - f'{_A11Y_FORESTS_KEY} task extra should be a numpy array with one' - f' dimension. payload: {payload}' - ) - - if payload.size == 0: - return False - - if any(isinstance(f, any_pb2.Any) for f in payload): - # Forests were packed as Any. - return True - - return any( - isinstance(f, android_accessibility_forest_pb2.AndroidAccessibilityForest) - for f in payload - ) - - -def convert_to_forest( - forest: android_accessibility_forest_pb2.AndroidAccessibilityForest - | any_pb2.Any - | None, -) -> android_accessibility_forest_pb2.AndroidAccessibilityForest | None: - """Takes an object and attempts to convert it to a forest.""" - if forest is None: - return None - - if isinstance(forest, any_pb2.Any): - output = android_accessibility_forest_pb2.AndroidAccessibilityForest() - new_any = any_pb2.Any() - new_any.CopyFrom(forest) - new_any.Unpack(output) - return output - elif isinstance( - forest, android_accessibility_forest_pb2.AndroidAccessibilityForest - ): - return forest - else: - return None - - -def extract_forests_from_task_extras( - task_extras: Mapping[str, Any] | None = None, -) -> list[android_accessibility_forest_pb2.AndroidAccessibilityForest]: - """Inspects task_extras and extracts all accessibility forests detected. - - Args: - task_extras: Task extras forwarded by AndroidEnv. If 'full_event' is not a - key in task_extras, then this function returns an empty string. Otherwise, - full_event is expected to be list to be a numpy array with one dimension, - and contains a list of dictionary describing accessibility forests that - are present in the given task extras. - - Returns: - List of all forests detected - """ - if task_extras is None or not task_extras_has_forests(task_extras): - return [] - - forests = [] - for f in task_extras[_A11Y_FORESTS_KEY]: - f = convert_to_forest(f) - if f is not None: - forests.append(f) - return forests - - -def keep_latest_forest_only(task_extras: dict[str, Any]): - """Removes all a11y forests except the last one observed.""" - if _A11Y_FORESTS_KEY not in task_extras.keys(): - return - - payload = task_extras[_A11Y_FORESTS_KEY] - if not isinstance(payload, np.ndarray) or payload.ndim != 1: - raise ValueError( - f'{_A11Y_FORESTS_KEY} task extra should be a numpy array with one' - f' dimension. payload: {payload}' - ) - - if payload.size == 0: - return - - task_extras[_A11Y_FORESTS_KEY] = payload[-1:] diff --git a/android_env/wrappers/a11y/a11y_forests_test.py b/android_env/wrappers/a11y/a11y_forests_test.py deleted file mode 100644 index b57f30aa..00000000 --- a/android_env/wrappers/a11y/a11y_forests_test.py +++ /dev/null @@ -1,237 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for a11y_forests.""" - -from absl.testing import absltest -from absl.testing import parameterized -from android_env.proto.a11y import android_accessibility_forest_pb2 -from android_env.wrappers.a11y import a11y_forests -import numpy as np - -from google.protobuf import any_pb2 - - -def _pack_any(proto_message) -> any_pb2.Any: - response = any_pb2.Any() - response.Pack(proto_message) - return response - - -def _empty_forest() -> ( - android_accessibility_forest_pb2.AndroidAccessibilityForest -): - return android_accessibility_forest_pb2.AndroidAccessibilityForest() - - -def _one_empty_window_forest() -> ( - android_accessibility_forest_pb2.AndroidAccessibilityForest -): - forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() - forest.windows.add() - return forest - - -def _two_window_forest() -> ( - android_accessibility_forest_pb2.AndroidAccessibilityForest -): - forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() - window = forest.windows.add() - window.tree.nodes.add( - class_name='foo', is_clickable=True, hint_text='Foo hint' - ) - forest.windows.add() - return forest - - -class A11YForestsTest(parameterized.TestCase): - - @parameterized.parameters( - dict(task_extras={}, expected_forests=[], convert_to_np=[]), - dict( - task_extras={'accessibility_tree': []}, - convert_to_np=['accessibility_tree'], - expected_forests=[], - ), - dict( - task_extras={ - 'not_accessibility_tree': [ - _empty_forest(), - _one_empty_window_forest(), - _two_window_forest(), - ], - }, - convert_to_np=['not_accessibility_tree'], - expected_forests=[], - ), - dict( - task_extras={ - 'accessibility_tree': [ - _empty_forest(), - {'not_a_forest_key': 'nor_a_forest_value'}, - _two_window_forest(), - ] - }, - convert_to_np=['accessibility_tree'], - expected_forests=[_empty_forest(), _two_window_forest()], - ), - dict( - task_extras={ - 'accessibility_tree': [ - {'not_a_forest_key': 'nor_a_forest_value'}, - 3, - 4, - {'not_a_forest_key': _empty_forest()}, - ], - }, - convert_to_np=['accessibility_tree'], - expected_forests=[], - ), - dict( - task_extras={'accessibility_tree': []}, - convert_to_np=['accessibility_tree'], - expected_forests=[], - ), - dict( - task_extras={ - 'accessibility_tree_wrong_key': [1, 2, 3], - 'accessibility_tree': [ - _empty_forest(), - None, - None, - _one_empty_window_forest(), - _two_window_forest(), - ], - }, - convert_to_np=['accessibility_tree', 'accessibility_tree_wrong_key'], - expected_forests=[ - _empty_forest(), - _one_empty_window_forest(), - _two_window_forest(), - ], - ), - dict( - task_extras={ - 'accessibility_tree_wrong_key': [1, 2, 3], - 'accessibility_tree': [ - None, - _pack_any(_empty_forest()), - _pack_any(_one_empty_window_forest()), - _pack_any(_two_window_forest()), - ], - }, - convert_to_np=['accessibility_tree', 'accessibility_tree_wrong_key'], - expected_forests=[ - _empty_forest(), - _one_empty_window_forest(), - _two_window_forest(), - ], - ), - dict( - task_extras={ - 'accessibility_tree': [ - _pack_any(_empty_forest()), - {'not_a_forest_key': 'nor_a_forest_value'}, - None, - _two_window_forest(), - None, - ] - }, - convert_to_np=['accessibility_tree'], - expected_forests=[_empty_forest(), _two_window_forest()], - ), - ) - def test_task_extras(self, task_extras, expected_forests, convert_to_np): - for k in convert_to_np: - if task_extras[k]: - task_extras[k] = np.stack(task_extras[k], axis=0) - else: - task_extras[k] = np.array([]) - forests = a11y_forests.extract_forests_from_task_extras(task_extras) - self.assertEqual(len(forests), len(expected_forests)) - for idx, f in enumerate(forests): - self.assertEqual(f, expected_forests[idx]) - - @parameterized.parameters( - dict(task_extras={}, expected_extras={}), - dict( - task_extras={ - 'no_accessibility_tree': 42, - }, - expected_extras={ - 'no_accessibility_tree': 42, - }, - ), - dict( - task_extras={'accessibility_tree': []}, - expected_extras={'accessibility_tree': []}, - ), - dict( - task_extras={ - 'accessibility_tree': [ - _empty_forest(), - _one_empty_window_forest(), - ], - 'no_accessibility_tree': 43, - }, - expected_extras={ - 'accessibility_tree': [_one_empty_window_forest()], - 'no_accessibility_tree': 43, - }, - ), - dict( - task_extras={ - 'accessibility_tree': [ - _empty_forest(), - _one_empty_window_forest(), - _two_window_forest(), - ] - }, - expected_extras={'accessibility_tree': [_two_window_forest()]}, - ), - dict( - task_extras={ - 'accessibility_tree': [], - 'no_accessibility_tree': 44, - }, - expected_extras={ - 'accessibility_tree': [], - 'no_accessibility_tree': 44, - }, - ), - ) - def test_keep_latest_only(self, task_extras, expected_extras): - if 'accessibility_tree' in task_extras: - if task_extras['accessibility_tree']: - task_extras['accessibility_tree'] = np.stack( - task_extras['accessibility_tree'], axis=0 - ) - else: - task_extras['accessibility_tree'] = np.array([]) - - a11y_forests.keep_latest_forest_only(task_extras) - self.assertSameElements(task_extras.keys(), expected_extras.keys()) - for k in task_extras.keys(): - if k == 'accessibility_tree': - self.assertEqual(len(task_extras[k]), len(expected_extras[k])) - for idx, f in enumerate(task_extras[k]): - self.assertEqual(f, expected_extras[k][idx]) - else: - self.assertEqual(task_extras[k], expected_extras[k]) - pass - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/wrappers/a11y/a11y_servicer.py b/android_env/wrappers/a11y/a11y_servicer.py deleted file mode 100644 index 82e9e253..00000000 --- a/android_env/wrappers/a11y/a11y_servicer.py +++ /dev/null @@ -1,199 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Accessibility Servicer implementation.""" - -import asyncio -from collections.abc import AsyncIterator, Generator, Iterable -import threading - -from absl import logging -from android_env.proto.a11y import a11y_pb2 -from android_env.proto.a11y import a11y_pb2_grpc -from android_env.proto.a11y import android_accessibility_forest_pb2 -import grpc - - -class A11yServicer(a11y_pb2_grpc.A11yServiceServicer): - """Services the A11yService requests.""" - - def __init__(self, latest_forest_only: bool = False): - self._received_forests: list[ - android_accessibility_forest_pb2.AndroidAccessibilityForest - ] = [] - self._received_events: list[a11y_pb2.EventRequest] = [] - self._lock_forests = threading.Lock() - self._lock_events = threading.Lock() - self._latest_forest_only = latest_forest_only - self._paused = True - - # A11y Forest bookkeeping. - self._get_forest = asyncio.Event() # Whether to request a forest. - self._forest_ready = asyncio.Event() # Whether the forest is ready. - self._latest_forest: ( - android_accessibility_forest_pb2.AndroidAccessibilityForest | None - ) = None - - def SendForest( - self, - request: android_accessibility_forest_pb2.AndroidAccessibilityForest, - context: grpc.ServicerContext, - ) -> a11y_pb2.ForestResponse: - self._process_forest(request) - return a11y_pb2.ForestResponse() - - def SendEvent( - self, - request: a11y_pb2.EventRequest, - context: grpc.ServicerContext, - ) -> a11y_pb2.EventResponse: - self._process_event(request) - return a11y_pb2.EventResponse() - - async def Bidi( - self, - request_iterator: AsyncIterator[a11y_pb2.ClientToServer], - context: grpc.aio.ServicerContext, - ) -> AsyncIterator[a11y_pb2.ServerToClient]: - """Processes incoming ClientToServer requests.""" - - logging.info('Starting A11yServicer.Bidi()') - - # Send a dummy message to unblock clients in their loop. - yield a11y_pb2.ServerToClient() - - # This block defines two coroutines: - # - # * `read_client_requests()` - # * `check_forest()` - # - # They cooperate with each other and both populate a queue `q` which is - # consumed in a loop below, which actually yields requests which are sent to - # the client. The processing finishes when the clients "closes" the - # connection, which causes `read_client_requests()` to put a special value, - # `STOP_ITERATION`, in the queue. - - # Queue for communicating from coroutines to `Bidi()`. - q = asyncio.Queue() - - should_run = True - - async def read_client_requests(): - """Coroutine for reading client requests.""" - - nonlocal should_run - async for request in request_iterator: - field_name = request.WhichOneof('payload') - match field_name: - case 'event': - self._process_event(request.event) - case 'forest': - self._latest_forest = request.forest - self._forest_ready.set() - self._get_forest.clear() # Reset the `Event`. - case _: - logging.error('Unknown field %r', field_name) - await q.put(a11y_pb2.ServerToClient()) - - # Send a special value to stop processing this `Bidi` connection. - await q.put('STOP_ITERATION') - should_run = False - - async def check_forest(): - """Coroutine for sending "get forest" requests.""" - - nonlocal should_run - while should_run: - await self._get_forest.wait() - await q.put(a11y_pb2.ServerToClient(get_forest={})) - - tasks = asyncio.gather(read_client_requests(), check_forest()) - - while should_run: - v = await q.get() - if v == 'STOP_ITERATION': - break - else: - yield v - - await tasks - - logging.info('Finishing A11yServicer.Bidi()') - - async def get_forest( - self, - ) -> android_accessibility_forest_pb2.AndroidAccessibilityForest | None: - """Issues a request to get the a11y forest from the client.""" - - self._get_forest.set() # Unblocks coroutine to send a request. - await self._forest_ready.wait() # Wait for forest to be ready. - self._forest_ready.clear() # Reset the `Event`. - return self._latest_forest - - def gather_forests( - self, - ) -> list[android_accessibility_forest_pb2.AndroidAccessibilityForest]: - forests = [] - with self._lock_forests: - forests = self._received_forests - self._received_forests = [] - return forests - - def gather_events(self) -> list[a11y_pb2.EventRequest]: - events = [] - with self._lock_events: - events = self._received_events - self._received_events = [] - return events - - def pause_and_clear(self) -> None: - """Temporarily stop receiving events/forests and clear the queue. - - Used when resetting the environment; in this case: - - all events/forests that have been received since last timestep are things - that happened in the last episode after its `LAST` timestep (so we should - ignore them, done by clearing the lists). - - we're about to receive a bunch of events/forests just as a result of - resetting the environment. We don't want to count these either; thus we - temporarily stop receiving new ones. - """ - self._paused = True - with self._lock_forests: - self._received_forests = [] - with self._lock_events: - self._received_events = [] - - def resume(self) -> None: - """Start receiving events/forests (e.g., after a reset).""" - self._paused = False - - def _process_event(self, event: a11y_pb2.EventRequest) -> None: - """Adds the given event to the internal buffer of events.""" - - if not self._paused: - with self._lock_events: - self._received_events.append(event) - - def _process_forest( - self, forest: android_accessibility_forest_pb2.AndroidAccessibilityForest - ) -> None: - """Adds the given forest to the internal buffer of forests.""" - - if not self._paused: - with self._lock_forests: - if self._latest_forest_only: - self._received_forests = [forest] - else: - self._received_forests.append(forest) diff --git a/android_env/wrappers/a11y/a11y_servicer_test.py b/android_env/wrappers/a11y/a11y_servicer_test.py deleted file mode 100644 index 91be6024..00000000 --- a/android_env/wrappers/a11y/a11y_servicer_test.py +++ /dev/null @@ -1,224 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for a11y_servicer.""" - -import asyncio -from collections.abc import AsyncIterator, Iterable -from typing import TypeVar -from unittest import IsolatedAsyncioTestCase, mock - -from absl.testing import absltest -from absl.testing import parameterized -from android_env.proto.a11y import a11y_pb2 -from android_env.proto.a11y import android_accessibility_forest_pb2 -from android_env.wrappers.a11y import a11y_servicer -import grpc - - -_T = TypeVar('_T') - - -async def _aiter(xs: Iterable[_T]) -> AsyncIterator[_T]: - """Utility to make an AsyncIterator from Iterable.""" - - for x in xs: - yield x - - -def one_window_one_node_forest() -> ( - android_accessibility_forest_pb2.AndroidAccessibilityForest -): - forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() - window = forest.windows.add() - node = window.tree.nodes.add() - node.class_name = 'foo' - node.is_clickable = True - node.hint_text = 'Foo hint' - return forest - - -def one_window_two_nodes_forest() -> ( - android_accessibility_forest_pb2.AndroidAccessibilityForest -): - forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() - window = forest.windows.add() - node = window.tree.nodes.add() - node.class_name = 'bar' - node.is_clickable = True - node.hint_text = 'Bar hint' - node = window.tree.nodes.add() - node.class_name = 'bar' - node.is_clickable = False - node.hint_text = 'Bar hint 2' - return forest - - -def empty_dict() -> dict[str, str]: - return {} - - -def single_item_dict_with_special_chars() -> dict[str, str]: - return {'foo': 'bar\r\t\nbaz'} - - -class A11yServicerTest(parameterized.TestCase, IsolatedAsyncioTestCase): - - def test_servicer_sendforest(self): - mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) - servicer = a11y_servicer.A11yServicer() - servicer.resume() - response = servicer.SendForest(one_window_one_node_forest(), mock_context) - self.assertEqual(response.error, '') - response = servicer.SendForest(one_window_two_nodes_forest(), mock_context) - self.assertEqual(response.error, '') - forests = servicer.gather_forests() - self.assertLen(forests, 2) - self.assertEqual(forests[0], one_window_one_node_forest()) - self.assertEqual(forests[1], one_window_two_nodes_forest()) - - async def test_servicer_bidi_forests(self): - """Checks that the bidirectional interface accepts forests.""" - - # Arrange. - mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) - servicer = a11y_servicer.A11yServicer() - - # Act. - servicer.resume() - responses = [ - x - async for x in servicer.Bidi( - _aiter([ - a11y_pb2.ClientToServer( - event=a11y_pb2.EventRequest( - event=single_item_dict_with_special_chars() - ) - ), - a11y_pb2.ClientToServer(forest=one_window_two_nodes_forest()), - ]), - mock_context, - ) - ] - forest = await servicer.get_forest() - - # Assert. - self.assertEqual(responses[0], a11y_pb2.ServerToClient()) - self.assertEqual(responses[1], a11y_pb2.ServerToClient()) - self.assertIsNotNone(forest) - self.assertEqual(forest, one_window_two_nodes_forest()) - - def test_servicer_sendforest_latest_only(self): - mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) - servicer = a11y_servicer.A11yServicer(latest_forest_only=True) - servicer.resume() - response = servicer.SendForest(one_window_one_node_forest(), mock_context) - self.assertEqual(response.error, '') - response = servicer.SendForest(one_window_two_nodes_forest(), mock_context) - self.assertEqual(response.error, '') - forests = servicer.gather_forests() - self.assertLen(forests, 1) - self.assertEqual(forests[0], one_window_two_nodes_forest()) - - def test_servicer_sendevent(self): - mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) - servicer = a11y_servicer.A11yServicer() - servicer.resume() - response = servicer.SendEvent( - a11y_pb2.EventRequest(event=empty_dict()), mock_context - ) - self.assertEqual(response.error, '') - response = servicer.SendEvent( - a11y_pb2.EventRequest(event=single_item_dict_with_special_chars()), - mock_context, - ) - self.assertEqual(response.error, '') - events = servicer.gather_events() - self.assertLen(events, 2) - self.assertEqual(events[0].event, empty_dict()) - self.assertEqual(events[1].event, single_item_dict_with_special_chars()) - - async def test_servicer_bidi_events(self): - """Checks that the bidirectional interface accepts events.""" - - # Arrange. - mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) - servicer = a11y_servicer.A11yServicer() - - # Act. - servicer.resume() - responses = [ - x - async for x in servicer.Bidi( - _aiter([ - a11y_pb2.ClientToServer( - event=a11y_pb2.EventRequest(event=empty_dict()) - ), - a11y_pb2.ClientToServer( - event=a11y_pb2.EventRequest( - event=single_item_dict_with_special_chars() - ) - ), - ]), - mock_context, - ) - ] - events = servicer.gather_events() - - # Assert. - self.assertEqual(responses[0], a11y_pb2.ServerToClient()) - self.assertEqual(responses[1], a11y_pb2.ServerToClient()) - self.assertLen(events, 2) - self.assertEqual(events[0].event, empty_dict()) - self.assertEqual(events[1].event, single_item_dict_with_special_chars()) - - def test_servicer_pause_and_clear_pauses(self): - mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) - servicer = a11y_servicer.A11yServicer() - servicer.resume() - servicer.pause_and_clear() - response = servicer.SendEvent( - a11y_pb2.EventRequest(event=empty_dict()), mock_context - ) - self.assertEqual(response.error, '') - response = servicer.SendForest(one_window_one_node_forest(), mock_context) - self.assertEqual(response.error, '') - events = servicer.gather_events() - self.assertEmpty(events) - forests = servicer.gather_forests() - self.assertEmpty(forests) - - def test_servicer_pause_and_clear_clears(self): - mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) - servicer = a11y_servicer.A11yServicer() - servicer.resume() - response = servicer.SendEvent( - a11y_pb2.EventRequest(event=empty_dict()), mock_context - ) - self.assertEqual(response.error, '') - response = servicer.SendForest(one_window_one_node_forest(), mock_context) - self.assertEqual( - response.error, - '', - ) - servicer.pause_and_clear() - events = servicer.gather_events() - self.assertEmpty(events) - forests = servicer.gather_forests() - self.assertEmpty(forests) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/wrappers/a11y_grpc_wrapper.py b/android_env/wrappers/a11y_grpc_wrapper.py deleted file mode 100644 index 5b010ad3..00000000 --- a/android_env/wrappers/a11y_grpc_wrapper.py +++ /dev/null @@ -1,500 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Wraps AndroidEnv to retrieve accessibility messages from gRPC.""" - -from concurrent import futures -import time -from typing import Any - -import urllib - -from absl import logging -from android_env import env_interface -from android_env.components import action_type as android_action_type_lib -from android_env.proto import adb_pb2 -from android_env.proto.a11y import a11y_pb2_grpc -from android_env.wrappers import base_wrapper -from android_env.wrappers.a11y import a11y_events -from android_env.wrappers.a11y import a11y_forests -from android_env.wrappers.a11y import a11y_servicer -import dm_env -import grpc -import numpy as np -import portpicker - - -def _get_accessibility_forwarder_apk() -> bytes: - logging.info('Downloading accessibility forwarder apk....') - with urllib.request.urlopen( - 'https://storage.googleapis.com/android_env-tasks/2024.05.13-accessibility_forwarder.apk' - ) as response: - return response.read() - - -class EnableNetworkingError(ValueError): - pass - - -class A11yGrpcWrapper(base_wrapper.BaseWrapper): - """Wrapper which receives A11y events and forests over gRPC. - - A11y forest protobufs and event dicts are sent from the Android emulator via - gRPC from the `AccessibilityForwarder` (for use in developing reward - functions, etc). This wrapper constructs a server which receives these - messages and channels them into `task_extras`. - - The downside of forwarding this information through gRPC is that no messages - will be sent if networking is turned off (e.g., if the AVD is in airplane - mode). To mitigate this problem, the `AccessibilityForwarder` logs an error - message if it fails to contact the server. This wrapper monitors the logs for - such error messages, and attempts (in another thread, to not block environment - transitions) to reconnect the AVD to the network. If this fails to fix the - problem, this wrapper ends the episode. - - This wrapper is implemented to be robust to multiple upstream callers of - `task_extras`, and to ensure they each receive the same extras at every - timestep. Thus, the logic is the following: - * New a11y events/forests are fetched during `reset` and `step`, *not* during - `task_extras()` calls. - * If no one has called `task_extras()` since the last `step` or `reset`, the - extras are accumulated (so that no extras are missed because someone called - `step()` twice without calling `task_extras()`). - * If someone *has* called `task_extras()` since last step, the newly fetched - extras replace the old extras. - """ - - def __init__( - self, - env: env_interface.AndroidEnvInterface, - disable_other_network_traffic: bool = False, - install_a11y_forwarding: bool = False, - start_a11y_service: bool = True, - enable_a11y_tree_info: bool = False, - add_latest_a11y_info_to_obs: bool = False, - a11y_info_timeout: float | None = None, - max_enable_networking_attempts: int = 10, - latest_a11y_info_only: bool = False, - ): - """Initializes wrapper. - - Args: - env: Environment to wrap. - disable_other_network_traffic: When True, all network traffic, other than - the connection to the servicer, is disabled. NOTE: This requires root - access on the device (i.e. it uses the `su` command). An - `AdbControllerError` exception will be raised if the underlying command - fails. - install_a11y_forwarding: If True, the wrapper handles the installation of - all packages required for the servicer to collect a11y information. - start_a11y_service: If True, starts the a11y forwarding services. NOTE: - The packages must be installed beforehand, e.g., using the - install_a11y_forwarding flag. - enable_a11y_tree_info: When False, this wrapper collects only a11y events - and not a11y tree. - add_latest_a11y_info_to_obs: When True, the latest observed a11y forest is - added to the observation. - a11y_info_timeout: When larger than zero and add_latest_a11y_info_to_obs - is set to True, the wrapper will wait the corresponding amount of time, - measured in seconds, to collect the latest a11y forest. - max_enable_networking_attempts: When the a11y gRPC service fails to - provide a11y information, we attempt this many times to re-enable the - networking. If all these attempts fail, fetching task_extras will raise - an EnableNetworkingError. - latest_a11y_info_only: When True, the a11y servicer is setup to save only - the latest tree it has received from the Android app. - """ - self._env = env - if install_a11y_forwarding: - self._install_a11y_forwarding_apk() - time.sleep(10.0) - if start_a11y_service: - self._start_a11y_services() - time.sleep(3.0) - if enable_a11y_tree_info: - self._enable_a11y_tree_logs() - self._relaunch_count = 0 - self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - self._servicer = a11y_servicer.A11yServicer( - latest_forest_only=latest_a11y_info_only - ) - a11y_pb2_grpc.add_A11yServiceServicer_to_server( - self._servicer, self._server - ) - server_credentials = grpc.local_server_credentials() - self._port = portpicker.pick_unused_port() - logging.info('Using port %s', self._port) - uri_address = f'[::]:{self._port}' - self._server.add_secure_port(uri_address, server_credentials) - logging.info('Starting server') - self._server.start() - logging.info('Server now running.') - - self._max_enable_networking_attempts = max_enable_networking_attempts - self._reset_enable_networking_attempts() - - self._disable_other_network_traffic = disable_other_network_traffic - self._should_accumulate = False - self._accumulated_extras = None - self._add_latest_a11y_info_to_obs = add_latest_a11y_info_to_obs - self._a11y_info_timeout = a11y_info_timeout - self._parent_action_spec = self._env.action_spec() - if self._a11y_info_timeout is not None and self._a11y_info_timeout > 0.0: - if 'action_type' not in self._parent_action_spec.keys(): - raise ValueError( - 'action_type not in the parent action spec: ' - f'{self._parent_action_spec}. This is a strong requirement when ' - f'a11y_info_timeout = {a11y_info_timeout} > 0' - ) - - def _start_a11y_services(self) -> None: - """Starts the accessibility forwarder services. - - Raises: - RuntimeError: If accessibility service is not started. - """ - start_service_request = adb_pb2.AdbRequest( - settings=adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SECURE, - put=adb_pb2.AdbRequest.SettingsRequest.Put( - key='enabled_accessibility_services', - value=( - 'com.google.androidenv.accessibilityforwarder/com.google.' - 'androidenv.accessibilityforwarder.AccessibilityForwarder' - ), - ), - ) - ) - start_service_response = self._env.execute_adb_call(start_service_request) - if start_service_response.status != adb_pb2.AdbResponse.Status.OK: - raise RuntimeError( - 'Could not start accessibility forwarder ' - 'service: ' - f'{start_service_response}.' - ) - - def _install_a11y_forwarding_apk(self) -> None: - """Enables accessibility information forwarding.""" - a11y_fwd_apk = _get_accessibility_forwarder_apk() - # Install and setup the Accesssibility Forwarder. - install_request = adb_pb2.AdbRequest( - install_apk=adb_pb2.AdbRequest.InstallApk( - blob=adb_pb2.AdbRequest.InstallApk.Blob(contents=a11y_fwd_apk), - ) - ) - install_response = self._env.execute_adb_call(install_request) - if install_response.status != adb_pb2.AdbResponse.Status.OK: - raise ValueError( - f'Could not install accessibility_forwarder.apk: {install_response}.' - ) - - def _enable_a11y_tree_logs(self) -> None: - enable_tree_logs_request = adb_pb2.AdbRequest( - send_broadcast=adb_pb2.AdbRequest.SendBroadcast( - action=( - 'accessibility_forwarder.intent.action.' - 'ENABLE_ACCESSIBILITY_TREE_LOGS' - ), - component=( - 'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver' - ), - ) - ) - enable_tree_logs_response = self._env.execute_adb_call( - enable_tree_logs_request - ) - if enable_tree_logs_response.status != adb_pb2.AdbResponse.Status.OK: - raise ValueError( - 'Could not enable accessibility tree logging: ' - f'{enable_tree_logs_response}.' - ) - - def _reset_enable_networking_attempts(self) -> None: - self._enable_networking_attempts_left = self._max_enable_networking_attempts - self._enabling_networking_future = None - self._a11y_exception = None - - def get_port(self): - return self._port - - def close(self): - self._server.stop(None) - logging.info('gRPC server stopped') - self._env.close() - - def attempt_enable_networking(self) -> None: - """Attempts to turn on networking within the Android device. - - Attempt to turn on the networking in the Android device, by: - - turning off airplane mode; - - turning on the wifi connection. - """ - self.execute_adb_call( - adb_pb2.AdbRequest( - settings=adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL, - put=adb_pb2.AdbRequest.SettingsRequest.Put( - key='airplane_mode_on', value='0' - ), - ) - ) - ) - time.sleep(1.0) - self.execute_adb_call( - adb_pb2.AdbRequest( - generic=adb_pb2.AdbRequest.GenericRequest( - args=[ - 'shell', - 'svc', - 'wifi', - 'enable', - ] - ) - ) - ) - time.sleep(1.0) - - def _configure_grpc(self) -> None: - """Configure networking and set the gRPC port in the AVD.""" - - if self._disable_other_network_traffic: - self.execute_adb_call( - adb_pb2.AdbRequest( - generic=adb_pb2.AdbRequest.GenericRequest( - args=[ - 'shell', - 'su', - '0', - 'iptables', - '-A', - 'OUTPUT', - '-p', - 'tcp', - '-d', - '10.0.2.2', - '--dport', - str(self._port), - '-j', - 'ACCEPT', - ] - ) - ) - ) - time.sleep(3.0) - self.execute_adb_call( - adb_pb2.AdbRequest( - generic=adb_pb2.AdbRequest.GenericRequest( - args=[ - 'shell', - 'su', - '0', - 'iptables', - '-A', - 'OUTPUT', - '-j', - 'DROP', - ] - ) - ) - ) - time.sleep(3.0) - - self.execute_adb_call( - adb_pb2.AdbRequest( - settings=adb_pb2.AdbRequest.SettingsRequest( - name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL, - put=adb_pb2.AdbRequest.SettingsRequest.Put( - key='no_proxy', value=f'10.0.2.2:{self._port}' - ), - ) - ) - ) - self.attempt_enable_networking() - self.execute_adb_call( - adb_pb2.AdbRequest( - send_broadcast=adb_pb2.AdbRequest.SendBroadcast( - action=( - 'accessibility_forwarder.intent.action.SET_GRPC --ei' - f' "port" {self._port}' - ), - component=( - 'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver' - ), - ) - ) - ) - - def _accumulate_and_return_a11y_info( - self, timer: float | None = None, get_env_observation: bool = True - ) -> dict[str, Any]: - """Accumulates and returns the latest a11y tree info and observation. - - Args: - timer: If larger than 0, the system will wait this long for a11y info to - accumulate before it returns a value. - get_env_observation: If False, the corresponding observation is not - introduced here. - - Returns: - a dict with a11y forest under key 'a11y_forest'. All other fields will - provide the observation, if requested. - """ - timer = timer or 0.0 - if timer > 0.0: - time.sleep(timer) - - if get_env_observation: - # Fetch observation. - new_ts = self._env.step({ - 'action_type': np.array( - android_action_type_lib.ActionType.REPEAT, - dtype=self._parent_action_spec['action_type'].dtype, - ), - }) - observation = new_ts.observation - else: - observation = {} - - extras = self.accumulate_new_extras() - forests = a11y_forests.extract_forests_from_task_extras(extras) - if forests: - observation['a11y_forest'] = forests[-1] - else: - observation['a11y_forest'] = None - return observation - - def _fetch_task_extras_and_update_observation( - self, observation: dict[str, Any], timeout: float = 0.0 - ) -> dict[str, Any]: - if timeout > 0.0: - observation = self._accumulate_and_return_a11y_info( - timeout, get_env_observation=True - ) - if not self._add_latest_a11y_info_to_obs: - observation.pop('a11y_forest') - else: - new_obs = self._accumulate_and_return_a11y_info(get_env_observation=False) - if self._add_latest_a11y_info_to_obs: - observation.update(new_obs) - return observation - - def reset(self) -> dm_env.TimeStep: - self._reset_enable_networking_attempts() - self._servicer.pause_and_clear() - timestep = self._env.reset() - self._servicer.resume() - if self._env.stats()['relaunch_count'] > self._relaunch_count: - self._configure_grpc() - self._relaunch_count = self._env.stats()['relaunch_count'] - self._accumulated_extras = {} - timeout = self._a11y_info_timeout or 0.0 - new_observation = self._fetch_task_extras_and_update_observation( - timestep.observation, timeout - ) - timestep = timestep._replace(observation=new_observation) - return timestep - - def step(self, action: Any) -> dm_env.TimeStep: - timeout = float(action.pop('wait_time', self._a11y_info_timeout or 0.0)) - timestep = self._env.step(action) - new_observation = self._fetch_task_extras_and_update_observation( - timestep.observation, timeout=timeout - ) - timestep = timestep._replace(observation=new_observation) - return timestep - - def accumulate_new_extras(self) -> dict[str, Any]: - new_extras = self._fetch_task_extras() - if self._should_accumulate: - for key in new_extras: - if key in self._accumulated_extras: - self._accumulated_extras[key] = np.concatenate( - (self._accumulated_extras[key], new_extras[key]), axis=0 - ) - else: - self._accumulated_extras[key] = new_extras[key] - else: - self._accumulated_extras = new_extras - self._should_accumulate = True - return self._accumulated_extras - - def _fetch_task_extras(self) -> dict[str, Any]: - """Fetches task_extras from the services. - - NOTE: If you want to access the latest a11y information, please use - accumulate_and_return_a11y_info instead. This function has the side effect - of clearing the content from the servicer, hence all the a11y info returned - here won't be accumulated. - - Returns: - A dict with the corresponding task_extras. - - Raises: - EnableNetworkingError: after a fixed number of attempts to revive the a11y - services by re-enabling the network connection. - """ - base_extras = self._env.task_extras(latest_only=False).copy() - # If the previous future is done, reset it to the initial state. - if ( - self._enabling_networking_future is not None - and self._enabling_networking_future.done() - ): - self._enabling_networking_future = None - self._enable_networking_attempts_left -= 1 - logging.info('Finished enabling networking.') - - if ( - self._enabling_networking_future is None - and 'exception' in base_extras - and base_extras['exception'].shape[0] - ): - self._a11y_exception = base_extras['exception'] - logging.warning( - 'AccessibilityForwarder logged exceptions: %s', self._a11y_exception - ) - if self._enable_networking_attempts_left > 0: - logging.warning( - 'Attempting to enable networking. %s attempts left.', - self._enable_networking_attempts_left - 1, - ) - executor = futures.ThreadPoolExecutor(max_workers=1) - self._enabling_networking_future = executor.submit( - self.attempt_enable_networking - ) - else: - raise EnableNetworkingError( - 'A11y service failed multiple times with' - f' exception.{self._a11y_exception}.' - ) - - forests = self._servicer.gather_forests() - if forests: - base_extras.update(a11y_forests.package_forests_to_task_extras(forests)) - self._reset_enable_networking_attempts() - events = self._servicer.gather_events() - if events: - base_extras.update(a11y_events.package_events_to_task_extras(events)) - self._reset_enable_networking_attempts() - return base_extras - - def task_extras(self, latest_only: bool = False) -> dict[str, Any]: - if self._accumulated_extras is None: - raise RuntimeError('You must call .reset() before calling .task_extras()') - self._should_accumulate = False - extras = self._accumulated_extras.copy() - if latest_only: - a11y_events.keep_latest_event_only(extras) - a11y_forests.keep_latest_forest_only(extras) - return extras diff --git a/android_env/wrappers/a11y_grpc_wrapper_test.py b/android_env/wrappers/a11y_grpc_wrapper_test.py deleted file mode 100644 index 1b81fd6e..00000000 --- a/android_env/wrappers/a11y_grpc_wrapper_test.py +++ /dev/null @@ -1,849 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for a11y_grpc_wrapper.""" - -import time -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -from android_env import env_interface -from android_env.proto import adb_pb2 -from android_env.proto.a11y import a11y_pb2 -from android_env.proto.a11y import a11y_pb2_grpc -from android_env.proto.a11y import android_accessibility_forest_pb2 -from android_env.wrappers import a11y_grpc_wrapper -import dm_env -import grpc -import numpy as np - - -def empty_forest() -> ( - android_accessibility_forest_pb2.AndroidAccessibilityForest -): - return android_accessibility_forest_pb2.AndroidAccessibilityForest() - - -def one_empty_window_forest() -> ( - android_accessibility_forest_pb2.AndroidAccessibilityForest -): - forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() - _ = forest.windows.add() - return forest - - -def one_window_one_node_forest() -> ( - android_accessibility_forest_pb2.AndroidAccessibilityForest -): - forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() - window = forest.windows.add() - node = window.tree.nodes.add() - node.class_name = 'foo' - node.is_clickable = True - node.hint_text = 'Foo hint' - return forest - - -def one_window_two_nodes_forest() -> ( - android_accessibility_forest_pb2.AndroidAccessibilityForest -): - forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() - window = forest.windows.add() - node = window.tree.nodes.add() - node.class_name = 'bar' - node.is_clickable = True - node.hint_text = 'Bar hint' - node = window.tree.nodes.add() - node.class_name = 'bar' - node.is_clickable = False - node.hint_text = 'Bar hint 2' - return forest - - -def three_windows_forest() -> ( - android_accessibility_forest_pb2.AndroidAccessibilityForest -): - forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() - _ = forest.windows.add() - window = forest.windows.add() - node = window.tree.nodes.add() - node.class_name = 'foo' - node.is_clickable = True - node.hint_text = 'hint' - window = forest.windows.add() - node = window.tree.nodes.add() - node.class_name = 'baz' - node.is_clickable = True - node.hint_text = 'hint' - node = window.tree.nodes.add() - node.class_name = 'foobar' - node.is_clickable = False - node.hint_text = 'hint' - return forest - - -def empty_dict() -> dict[str, str]: - return {} - - -def single_item_dict() -> dict[str, str]: - return {'foo': 'bar'} - - -def several_long_items_dict() -> dict[str, str]: - return { - 'first_key': 'Lorem ipsum ' * 100, - 'second_key': 'the beginning is the end is' * 100, - } - - -def single_item_dict_with_special_chars() -> dict[str, str]: - return {'foo': 'bar\r\t\nbaz'} - - -def _ok_response(): - return adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK) - - -class A11yGrpcWrapperTest(parameterized.TestCase): - - def test_server(self): - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - base_env.task_extras.return_value = {} - base_env.stats.return_value = {'relaunch_count': 0} - base_env.execute_adb_call.return_value = _ok_response() - wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) - wrapped_env.reset() - channel_creds = grpc.local_channel_credentials() - with grpc.secure_channel( - f'[::]:{wrapped_env.get_port()}', channel_creds - ) as channel: - grpc.channel_ready_future(channel).result() - stub = a11y_pb2_grpc.A11yServiceStub(channel) - stub.SendForest(one_window_one_node_forest()) - stub.SendForest(one_window_two_nodes_forest()) - wrapped_env.step({}) - extras = wrapped_env.task_extras(latest_only=False) - self.assertIn('accessibility_tree', extras) - self.assertEqual(extras['accessibility_tree'].shape[0], 2) - - # tests of fetch_task_extras: - # exception occurs (ensure attempt to enable networking) and recovers - # exception occurs and enable networking doesn't help - # exception occurs twice but with a forest sent between - - @parameterized.named_parameters( - ('no_events_or_forests', [], []), - ( - 'no_events', - [], - [one_window_one_node_forest(), one_window_two_nodes_forest()], - ), - ('no_forests', [empty_dict(), single_item_dict()], []), - ( - 'events_and_forests', - [empty_dict(), single_item_dict()], - [one_window_one_node_forest(), one_window_two_nodes_forest()], - ), - ) - @mock.patch.object(time, 'sleep', autospec=True) - @mock.patch.object( - a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True - ) - @mock.patch.object(grpc, 'server', autospec=True) - def test_fetch_task_extras( - self, - received_events, - received_forests, - mock_server, - mock_add_servicer, - mock_sleep, - ): - del mock_server, mock_add_servicer, mock_sleep - mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - base_env.task_extras.return_value = { - 'foo': np.array(['bar', 'baz'], dtype='U'), - 'some_key': np.array(['some_value'], dtype='U'), - } - base_env.stats.return_value = {'relaunch_count': 0} - base_env.execute_adb_call.return_value = _ok_response() - wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) - wrapped_env.reset() - for forest in received_forests: - wrapped_env._servicer.SendForest(forest, mock_context) - for event in received_events: - wrapped_env._servicer.SendEvent( - a11y_pb2.EventRequest(event=event), mock_context - ) - with mock.patch.object( - wrapped_env, 'attempt_enable_networking' - ) as mock_attempt_enable_networking: - extras = wrapped_env._fetch_task_extras() - mock_attempt_enable_networking.assert_not_called() - self.assertIn('foo', extras) - np.testing.assert_array_equal(extras['foo'], ['bar', 'baz']) - self.assertIn('some_key', extras) - np.testing.assert_array_equal(extras['some_key'], ['some_value']) - if received_events: - self.assertIn('full_event', extras) - self.assertLen(extras['full_event'], len(received_events)) - for i, event in enumerate(received_events): - event = a11y_pb2.EventRequest(event=event) - np.testing.assert_array_equal(extras['full_event'][i], event) - else: - self.assertNotIn('full_event', extras) - if received_forests: - self.assertIn('accessibility_tree', extras) - self.assertLen(extras['accessibility_tree'], len(received_forests)) - for i, forest in enumerate(received_forests): - np.testing.assert_array_equal(extras['accessibility_tree'][i], forest) - else: - self.assertNotIn('accessibility_tree', extras) - - @mock.patch.object(time, 'sleep', autospec=True) - @mock.patch.object( - a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True - ) - @mock.patch.object(grpc, 'server', autospec=True) - def test_fetch_task_extras_enable_networking( - self, - mock_server, - mock_add_servicer, - mock_sleep, - ): - del mock_server, mock_add_servicer, mock_sleep - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - base_env.task_extras.return_value = { - 'foo': np.array(['bar'], dtype='U'), - 'some_key': np.array(['some_value'], dtype='U'), - 'exception': np.array(['fake exception'], dtype='U'), - } - base_env.stats.return_value = {'relaunch_count': 0} - base_env.execute_adb_call.return_value = _ok_response() - wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) - with mock.patch.object( - wrapped_env, 'attempt_enable_networking' - ) as mock_attempt_enable_networking: - extras = wrapped_env._fetch_task_extras() - self.assertNotIn('accessibility_tree', extras) - self.assertNotIn('full_event', extras) - future = wrapped_env._enabling_networking_future - if future is not None: - future.result() - mock_attempt_enable_networking.assert_called_once() - - @mock.patch.object(time, 'sleep', autospec=True) - @mock.patch.object( - a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True - ) - @mock.patch.object(grpc, 'server', autospec=True) - def test_fetch_task_extras_enable_networking_twice( - self, - mock_server, - mock_add_servicer, - mock_sleep, - ): - del mock_server, mock_add_servicer, mock_sleep - mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - base_env.task_extras.return_value = { - 'foo': np.array(['bar'], dtype='U'), - 'some_key': np.array(['some_value'], dtype='U'), - } - - base_env.stats.return_value = {'relaunch_count': 0} - base_env.execute_adb_call.return_value = _ok_response() - wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) - wrapped_env.reset() - - base_env.task_extras.return_value = { - 'foo': np.array(['bar'], dtype='U'), - 'some_key': np.array(['some_value'], dtype='U'), - 'exception': np.array(['fake exception'], dtype='U'), - } - with mock.patch.object( - wrapped_env, 'attempt_enable_networking' - ) as mock_attempt_enable_networking: - extras = wrapped_env._fetch_task_extras() - self.assertNotIn('accessibility_tree', extras) - self.assertNotIn('full_event', extras) - future = wrapped_env._enabling_networking_future - if future is not None: - future.result() - mock_attempt_enable_networking.assert_called_once() - # Fixed networking; send a forest so the wrapper knows it worked. - wrapped_env._servicer.SendForest(one_window_one_node_forest(), mock_context) - base_env.task_extras.return_value = { - 'foo': np.array(['bar'], dtype='U'), - 'some_key': np.array(['some_value'], dtype='U'), - } - extras = wrapped_env._fetch_task_extras() - self.assertIn('accessibility_tree', extras) - self.assertEqual(extras['accessibility_tree'].shape[0], 1) - self.assertEqual( - extras['accessibility_tree'][0], one_window_one_node_forest() - ) - - base_env.task_extras.return_value = { - 'foo': np.array(['bar'], dtype='U'), - 'some_key': np.array(['some_value'], dtype='U'), - 'exception': np.array(['fake exception'], dtype='U'), - } - with mock.patch.object( - wrapped_env, 'attempt_enable_networking' - ) as mock_attempt_enable_networking: - extras = wrapped_env._fetch_task_extras() - self.assertNotIn('accessibility_tree', extras) - self.assertNotIn('full_event', extras) - future = wrapped_env._enabling_networking_future - if future is not None: - future.result() - mock_attempt_enable_networking.assert_called_once() - - @mock.patch.object(time, 'sleep', autospec=True) - @mock.patch.object( - a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True - ) - @mock.patch.object(grpc, 'server', autospec=True) - def test_task_extras_raises_a11y_info_exception( - self, mock_sleep, mock_add_servicer, mock_server - ): - del mock_server, mock_add_servicer, mock_sleep - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - base_env.task_extras.return_value = { - 'foo': np.array(['bar'], dtype='U'), - 'some_key': np.array(['some_value'], dtype='U'), - } - - base_env.stats.return_value = {'relaunch_count': 0} - base_env.execute_adb_call.return_value = _ok_response() - base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) - base_env.step.return_value = dm_env.transition( - observation={'dummy': 42}, reward=0.0 - ) - wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( - base_env, - add_latest_a11y_info_to_obs=True, - max_enable_networking_attempts=1, - ) - wrapped_env.reset() - - base_env.task_extras.return_value = { - 'foo': np.array(['bar'], dtype='U'), - 'some_key': np.array(['some_value'], dtype='U'), - 'exception': np.array(['fake exception'], dtype='U'), - } - with mock.patch.object( - wrapped_env, 'attempt_enable_networking' - ) as mock_attempt_enable_networking: - extras = wrapped_env._fetch_task_extras() - self.assertNotIn('accessibility_tree', extras) - self.assertNotIn('full_event', extras) - # Wait for the the attempt to finish. - future = wrapped_env._enabling_networking_future - if future is not None: - future.result() - mock_attempt_enable_networking.assert_called_once() - # The _fetch_task_extras() call inside the next step will force a restart - self.assertRaises( - a11y_grpc_wrapper.EnableNetworkingError, wrapped_env.step, {} - ) - - @mock.patch.object( - a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True - ) - @mock.patch.object(grpc, 'server', autospec=True) - def test_configure_grpc( - self, - mock_server, - mock_add_servicer, - ): - del mock_server, mock_add_servicer - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - base_env.task_extras.return_value = { - 'foo': np.array(['bar'], dtype='U'), - 'some_key': np.array(['some_value'], dtype='U'), - } - - base_env.stats.return_value = {'relaunch_count': 1} - base_env.execute_adb_call.return_value = _ok_response() - wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) - with mock.patch.object( - wrapped_env, '_configure_grpc' - ) as mock_configure_grpc: - wrapped_env.reset() - mock_configure_grpc.assert_called_once() - - @mock.patch.object( - a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True - ) - @mock.patch.object(grpc, 'server', autospec=True) - def test_task_extras_raises_before_reset( - self, unused_mock_server, unused_mock_add_servicer - ): - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - base_env.stats.return_value = {'relaunch_count': 0} - base_env.execute_adb_call.return_value = _ok_response() - wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) - with self.assertRaisesRegex( - RuntimeError, - r'You must call \.reset\(\) before calling \.task_extras\(\)', - ): - wrapped_env.task_extras(latest_only=False) - - @mock.patch.object( - a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True - ) - @mock.patch.object(grpc, 'server', autospec=True) - def test_extras_accumulate_between_steps( - self, mock_server, mock_add_servicer - ): - del mock_server, mock_add_servicer - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - base_env.stats.return_value = {'relaunch_count': 0} - base_env.execute_adb_call.return_value = _ok_response() - base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) - base_env.step.return_value = dm_env.transition( - observation={'dummy': 42}, reward=0.0 - ) - wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( - base_env, add_latest_a11y_info_to_obs=True - ) - with mock.patch.object(wrapped_env, '_fetch_task_extras'): - wrapped_env._fetch_task_extras.return_value = { - 'full_event': np.array(single_item_dict(), ndmin=1, dtype=object), - 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), - } - timestep = wrapped_env.reset() - self.assertIn('a11y_forest', timestep.observation) - self.assertEqual(timestep.observation['a11y_forest'], empty_forest()) - wrapped_env._fetch_task_extras.return_value = { - 'full_event': np.array(empty_dict(), ndmin=1, dtype=object), - 'accessibility_tree': np.array( - one_window_two_nodes_forest(), ndmin=1, dtype=object - ), - } - timestep = wrapped_env.step({}) - self.assertIn('a11y_forest', timestep.observation) - self.assertEqual( - timestep.observation['a11y_forest'], one_window_two_nodes_forest() - ) - timestep = wrapped_env.step({}) - self.assertIn('a11y_forest', timestep.observation) - self.assertEqual( - timestep.observation['a11y_forest'], one_window_two_nodes_forest() - ) - wrapped_env._fetch_task_extras.return_value = { - 'full_event': np.array(single_item_dict(), ndmin=1, dtype=object), - } - timestep = wrapped_env.step({}) - self.assertIn('a11y_forest', timestep.observation) - self.assertEqual( - timestep.observation['a11y_forest'], one_window_two_nodes_forest() - ) - expected_task_extras = { - 'full_event': np.array( - [ - single_item_dict(), - empty_dict(), - empty_dict(), - single_item_dict(), - ], - dtype=object, - ), - 'accessibility_tree': np.array( - [ - empty_forest(), - one_window_two_nodes_forest(), - one_window_two_nodes_forest(), - ], - dtype=object, - ), - } - expected_task_extras_latest = { - 'full_event': np.array([single_item_dict()], dtype=object), - 'accessibility_tree': np.array( - [one_window_two_nodes_forest()], dtype=object - ), - } - task_extras = wrapped_env.task_extras(latest_only=False) - np.testing.assert_equal( - task_extras['full_event'], expected_task_extras['full_event'] - ) - np.testing.assert_equal( - task_extras['accessibility_tree'], - expected_task_extras['accessibility_tree'], - ) - - task_extras = wrapped_env.task_extras(latest_only=True) - np.testing.assert_equal( - task_extras['full_event'], expected_task_extras_latest['full_event'] - ) - np.testing.assert_equal( - task_extras['accessibility_tree'], - expected_task_extras_latest['accessibility_tree'], - ) - - @mock.patch.object( - a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True - ) - @mock.patch.object(grpc, 'server', autospec=True) - def test_a11y_info_disabled( - self, - unused_mock_server, - unused_mock_add_servicer, - ): - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - base_env.action_spec.return_value = { - 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32) - } - base_env.stats.return_value = {'relaunch_count': 0} - base_env.execute_adb_call.return_value = _ok_response() - base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) - base_env.step.return_value = dm_env.transition( - observation={'dummy': 42}, reward=0.0 - ) - wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( - base_env, add_latest_a11y_info_to_obs=False, a11y_info_timeout=1.0 - ) - with mock.patch.object(wrapped_env, '_fetch_task_extras'): - wrapped_env._fetch_task_extras.return_value = { - 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), - } - timestep = wrapped_env.reset() - self.assertNotIn('a11y_forest', timestep.observation) - timestep = wrapped_env.step({}) - self.assertNotIn('a11y_forest', timestep.observation) - - @mock.patch.object( - a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True - ) - @mock.patch.object(grpc, 'server', autospec=True) - def test_a11y_info_with_timer_info_present( - self, - unused_mock_server, - unused_mock_add_servicer, - ): - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - base_env.action_spec.return_value = { - 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32) - } - base_env.stats.return_value = {'relaunch_count': 0} - base_env.execute_adb_call.return_value = _ok_response() - base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) - base_env.step.return_value = dm_env.transition( - observation={'dummy': 42}, reward=0.0 - ) - wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( - base_env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=1.0 - ) - with mock.patch.object(wrapped_env, '_fetch_task_extras'): - wrapped_env._fetch_task_extras.side_effect = [{ - 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), - }] - timestep = wrapped_env.reset() - self.assertIn('a11y_forest', timestep.observation) - self.assertEqual(timestep.observation['a11y_forest'], empty_forest()) - - @mock.patch.object( - a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True - ) - @mock.patch.object(grpc, 'server', autospec=True) - @mock.patch.object(time, 'sleep', autospec=True) - def test_a11y_info_with_timer_task_extra_returned( - self, unused_mock_server, unused_mock_add_servicer, unused_mock_sleep - ): - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - base_env.action_spec.return_value = { - 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32) - } - base_env.stats.return_value = {'relaunch_count': 0} - base_env.execute_adb_call.return_value = _ok_response() - base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) - base_env.step.return_value = dm_env.transition( - observation={'dummy': 42}, reward=0.0 - ) - wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( - base_env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=1.0 - ) - with mock.patch.object(wrapped_env, '_fetch_task_extras'): - wrapped_env._fetch_task_extras.side_effect = [ - { - 'accessibility_tree': np.array( - empty_forest(), ndmin=1, dtype=object - ), - }, - ] - timestep = wrapped_env.reset() - self.assertIn('a11y_forest', timestep.observation) - self.assertEqual(timestep.observation['a11y_forest'], empty_forest()) - - @mock.patch.object( - a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True - ) - @mock.patch.object(grpc, 'server', autospec=True) - @mock.patch.object(time, 'sleep', autospec=True) - def test_a11y_info_with_timer_from_action( - self, unused_mock_server, unused_mock_add_servicer, mock_sleep - ): - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - base_env.action_spec.return_value = { - 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32) - } - base_env.stats.return_value = {'relaunch_count': 0} - base_env.execute_adb_call.return_value = _ok_response() - base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) - base_env.step.return_value = dm_env.transition( - observation={'dummy': 42}, reward=0.0 - ) - wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( - base_env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=0.0 - ) - with mock.patch.object(wrapped_env, '_fetch_task_extras'): - wrapped_env._fetch_task_extras.side_effect = [ - { - 'accessibility_tree': np.array( - empty_forest(), ndmin=1, dtype=object - ), - }, - ] - timestep = wrapped_env.step(action={'wait_time': 1.0}) - self.assertIn('a11y_forest', timestep.observation) - mock_sleep.assert_called_once() - self.assertEqual(timestep.observation['a11y_forest'], empty_forest()) - - @mock.patch.object( - a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True - ) - @mock.patch.object(grpc, 'server', autospec=True) - def test_task_extras_same_between_calls(self, mock_server, mock_add_servicer): - del mock_server, mock_add_servicer - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - base_env.stats.return_value = {'relaunch_count': 0} - base_env.execute_adb_call.return_value = _ok_response() - wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) - expected_task_extras = { - 'full_event': np.array(single_item_dict(), ndmin=1, dtype=object), - 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), - } - with mock.patch.object(wrapped_env, '_fetch_task_extras'): - wrapped_env._fetch_task_extras.return_value = expected_task_extras - wrapped_env.reset() - task_extras = wrapped_env.task_extras(latest_only=False) - np.testing.assert_equal( - task_extras['full_event'], expected_task_extras['full_event'] - ) - np.testing.assert_equal( - task_extras['accessibility_tree'], - expected_task_extras['accessibility_tree'], - ) - - task_extras = wrapped_env.task_extras(latest_only=False) - np.testing.assert_equal( - task_extras['full_event'], expected_task_extras['full_event'] - ) - np.testing.assert_equal( - task_extras['accessibility_tree'], - expected_task_extras['accessibility_tree'], - ) - - expected_task_extras = { - 'full_event': np.array(empty_dict(), ndmin=1, dtype=object), - 'accessibility_tree': np.array( - one_window_two_nodes_forest(), ndmin=1, dtype=object - ), - } - with mock.patch.object(wrapped_env, '_fetch_task_extras'): - wrapped_env._fetch_task_extras.return_value = expected_task_extras - wrapped_env.step({}) - task_extras = wrapped_env.task_extras(latest_only=False) - np.testing.assert_equal( - task_extras['full_event'], expected_task_extras['full_event'] - ) - np.testing.assert_equal( - task_extras['accessibility_tree'], - expected_task_extras['accessibility_tree'], - ) - - task_extras = wrapped_env.task_extras(latest_only=False) - np.testing.assert_equal( - task_extras['full_event'], expected_task_extras['full_event'] - ) - np.testing.assert_equal( - task_extras['accessibility_tree'], - expected_task_extras['accessibility_tree'], - ) - - @mock.patch.object( - a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True - ) - @mock.patch.object(grpc, 'server', autospec=True) - def test_task_extras_clear_if_called_between_step( - self, mock_server, mock_add_servicer - ): - del mock_server, mock_add_servicer - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - base_env.stats.return_value = {'relaunch_count': 0} - base_env.execute_adb_call.return_value = _ok_response() - wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) - with mock.patch.object(wrapped_env, '_fetch_task_extras'): - expected_task_extras = { - 'full_event': np.array(empty_dict(), ndmin=1, dtype=object), - 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), - } - wrapped_env._fetch_task_extras.return_value = expected_task_extras - wrapped_env.reset() - task_extras = wrapped_env.task_extras(latest_only=False) - np.testing.assert_equal( - task_extras['full_event'], expected_task_extras['full_event'] - ) - np.testing.assert_equal( - task_extras['accessibility_tree'], - expected_task_extras['accessibility_tree'], - ) - - expected_task_extras = { - 'full_event': np.array(single_item_dict(), ndmin=1, dtype=object), - 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), - } - wrapped_env._fetch_task_extras.return_value = expected_task_extras - wrapped_env.step({}) - task_extras = wrapped_env.task_extras(latest_only=False) - np.testing.assert_equal( - task_extras['full_event'], expected_task_extras['full_event'] - ) - np.testing.assert_equal( - task_extras['accessibility_tree'], - expected_task_extras['accessibility_tree'], - ) - expected_task_extras = { - 'full_event': np.array(empty_dict(), ndmin=1, dtype=object), - 'accessibility_tree': np.array( - one_window_two_nodes_forest(), ndmin=1, dtype=object - ), - } - wrapped_env._fetch_task_extras.return_value = expected_task_extras - wrapped_env.step({}) - task_extras = wrapped_env.task_extras(latest_only=False) - np.testing.assert_equal( - task_extras['full_event'], expected_task_extras['full_event'] - ) - np.testing.assert_equal( - task_extras['accessibility_tree'], - expected_task_extras['accessibility_tree'], - ) - - @parameterized.named_parameters( - ('none_true', False, False, False, 0), - ('only_install', True, False, False, 1), - ('only_start', False, True, False, 1), - ('only_enable_a11y_tree', False, False, True, 1), - ('install_and_start_no_a11y_tree', True, True, False, 2), - ('install_and_a11y_tree', True, False, True, 2), - ('start_and_a11y_tree', False, True, True, 2), - ('all_true', True, True, True, 3), - ) - @mock.patch.object(time, 'sleep', autospec=True) - def test_apk_install_and_start( - self, - install_a11y_forwarding: bool, - start_a11y_service: bool, - enable_a11y_tree_logs: bool, - expected_adb_calls: int, - unused_mock_sleep, - ): - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - - side_effects = [] - if install_a11y_forwarding: - side_effects.append(_ok_response()) # install response - if start_a11y_service: - side_effects.append(_ok_response()) # start service response - if enable_a11y_tree_logs: - side_effects.append(_ok_response()) # enable_tree_request - - base_env.execute_adb_call.side_effect = side_effects - - _ = a11y_grpc_wrapper.A11yGrpcWrapper( - base_env, - install_a11y_forwarding=install_a11y_forwarding, - start_a11y_service=start_a11y_service, - enable_a11y_tree_info=enable_a11y_tree_logs, - ) - self.assertEqual(base_env.execute_adb_call.call_count, expected_adb_calls) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_component_and_start(self, unused_mock_sleep): - base_env = mock.create_autospec( - env_interface.AndroidEnvInterface, instance=True - ) - - side_effects = [] - side_effects.append(_ok_response()) # install response - side_effects.append(_ok_response()) # start service response - side_effects.append(_ok_response()) # enable_tree_request - - base_env.execute_adb_call.side_effect = side_effects - - _ = a11y_grpc_wrapper.A11yGrpcWrapper( - base_env, - install_a11y_forwarding=True, - start_a11y_service=True, - enable_a11y_tree_info=True, - ) - - # call_args returns a tuple of which the first member is a tuple containing - # the most recent args the mock was called with, and execute_adb_call only - # has one arg (so [0][0] to access the AdbRequest). - self.assertEqual( - base_env.execute_adb_call.call_args[0][0].send_broadcast.component, - 'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver', - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/wrappers/base_wrapper.py b/android_env/wrappers/base_wrapper.py deleted file mode 100644 index 68aea7e4..00000000 --- a/android_env/wrappers/base_wrapper.py +++ /dev/null @@ -1,124 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Base class for AndroidEnv wrappers.""" - -from typing import Any - -from absl import logging -from android_env import env_interface -from android_env.proto import adb_pb2 -from android_env.proto import state_pb2 -import dm_env -from dm_env import specs -import numpy as np - - -class BaseWrapper(env_interface.AndroidEnvInterface): - """AndroidEnv wrapper.""" - - def __init__(self, env): - self._env = env - logging.info('Wrapping with %s', self.__class__.__name__) - - def reset(self) -> dm_env.TimeStep: - self._reset_state() - timestep = self._process_timestep(self._env.reset()) - return timestep - - def step(self, action: Any) -> dm_env.TimeStep: - action = self._process_action(action) - return self._process_timestep(self._env.step(action)) - - def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]: - return self._env.task_extras(latest_only=latest_only) - - def _reset_state(self): - pass - - def _process_action(self, action: Any) -> Any: - return action - - def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: - return timestep - - def observation_spec(self) -> dict[str, specs.Array]: - return self._env.observation_spec() - - def action_spec(self) -> dict[str, specs.Array]: - return self._env.action_spec() - - def reward_spec(self) -> specs.Array: - return self._env.reward_spec() - - def discount_spec(self) -> specs.Array: - return self._env.discount_spec() - - def _wrapper_stats(self) -> dict[str, Any]: - """Add wrapper specific logging here.""" - return {} - - def stats(self) -> dict[str, Any]: - info = self._env.stats() - info.update(self._wrapper_stats()) - return info - - def load_state( - self, request: state_pb2.LoadStateRequest - ) -> state_pb2.LoadStateResponse: - """Loads a state.""" - return self._env.load_state(request) - - def save_state( - self, request: state_pb2.SaveStateRequest - ) -> state_pb2.SaveStateResponse: - """Saves a state. - - Args: - request: A `SaveStateRequest` containing any parameters necessary to - specify how/what state to save. - - Returns: - A `SaveStateResponse` containing the status, error message (if - applicable), and any other relevant information. - """ - return self._env.save_state(request) - - def execute_adb_call(self, - adb_call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse: - return self._env.execute_adb_call(adb_call) - - @property - def raw_action(self): - return self._env.raw_action - - @property - def raw_observation(self): - return self._env.raw_observation - - @property - def raw_env(self): - """Recursively unwrap until we reach the true 'raw' env.""" - wrapped = self._env - if hasattr(wrapped, 'raw_env'): - return wrapped.raw_env - return wrapped - - def __getattr__(self, attr): - """Delegate attribute access to underlying environment.""" - return getattr(self._env, attr) - - def close(self): - self._env.close() diff --git a/android_env/wrappers/base_wrapper_test.py b/android_env/wrappers/base_wrapper_test.py deleted file mode 100644 index a690139d..00000000 --- a/android_env/wrappers/base_wrapper_test.py +++ /dev/null @@ -1,154 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.wrappers.base_wrapper.""" - -from unittest import mock - -from absl import logging -from absl.testing import absltest -from android_env import env_interface -from android_env.proto import state_pb2 -from android_env.wrappers import base_wrapper - - -class BaseWrapperTest(absltest.TestCase): - - @mock.patch.object(logging, 'info') - def test_base_function_forwarding(self, mock_info): - base_env = mock.create_autospec(env_interface.AndroidEnvInterface) - wrapped_env = base_wrapper.BaseWrapper(base_env) - mock_info.assert_called_with('Wrapping with %s', 'BaseWrapper') - - fake_ts = 'fake_ts' - base_env.reset.return_value = fake_ts - self.assertEqual(fake_ts, wrapped_env.reset()) - base_env.reset.assert_called_once() - - fake_ts = 'fake_ts' - fake_action = 'fake_action' - base_env.step.return_value = fake_ts - self.assertEqual(fake_ts, wrapped_env.step(fake_action)) - base_env.step.assert_called_once_with(fake_action) - - fake_extras = 'fake_task_extras' - base_env.task_extras.return_value = fake_extras - self.assertEqual(fake_extras, wrapped_env.task_extras(latest_only=True)) - base_env.task_extras.assert_called_once_with(latest_only=True) - - fake_obs_spec = 'fake_obs_spec' - base_env.observation_spec.return_value = fake_obs_spec - self.assertEqual(fake_obs_spec, wrapped_env.observation_spec()) - base_env.observation_spec.assert_called_once() - - fake_action_spec = 'fake_action_spec' - base_env.action_spec.return_value = fake_action_spec - self.assertEqual(fake_action_spec, wrapped_env.action_spec()) - base_env.action_spec.assert_called_once() - - fake_raw_action = 'fake_raw_action' - type(base_env).raw_action = mock.PropertyMock(return_value=fake_raw_action) - self.assertEqual(fake_raw_action, wrapped_env.raw_action) - - fake_raw_observation = 'fake_raw_observation' - type(base_env).raw_observation = mock.PropertyMock( - return_value=fake_raw_observation) - self.assertEqual(fake_raw_observation, wrapped_env.raw_observation) - - load_request = state_pb2.LoadStateRequest(args={}) - expected_response = state_pb2.LoadStateResponse( - status=state_pb2.LoadStateResponse.Status.OK - ) - base_env.load_state.return_value = expected_response - self.assertEqual(wrapped_env.load_state(load_request), expected_response) - base_env.load_state.assert_called_once_with(load_request) - - save_request = state_pb2.SaveStateRequest(args={}) - expected_response = state_pb2.SaveStateResponse( - status=state_pb2.SaveStateResponse.Status.OK - ) - base_env.save_state.return_value = expected_response - self.assertEqual(wrapped_env.save_state(save_request), expected_response) - base_env.save_state.assert_called_once_with(save_request) - - wrapped_env.close() - base_env.close.assert_called_once() - - fake_return_value = 'fake' - # AndroidEnv::some_random_function() does not exist and calling it should - # raise an AttributeError. - with self.assertRaises(AttributeError): - base_env.some_random_function.return_value = fake_return_value - - def test_multiple_wrappers(self): - base_env = mock.create_autospec(env_interface.AndroidEnvInterface) - wrapped_env_1 = base_wrapper.BaseWrapper(base_env) - wrapped_env_2 = base_wrapper.BaseWrapper(wrapped_env_1) - - wrapped_env_2.close() - base_env.close.assert_called_once() - - def test_raw_env(self): - base_env = 'fake_env' - wrapped_env_1 = base_wrapper.BaseWrapper(base_env) - wrapped_env_2 = base_wrapper.BaseWrapper(wrapped_env_1) - self.assertEqual(base_env, wrapped_env_2.raw_env) - - def test_stats(self): - base_env = mock.create_autospec(env_interface.AndroidEnvInterface) - wrapped_env = base_wrapper.BaseWrapper(base_env) - base_stats = {'base': 'stats'} - base_env.stats.return_value = base_stats - self.assertEqual(base_stats, wrapped_env.stats()) - - @mock.patch.object(logging, 'info') - def test_wrapped_stats(self, mock_info): - base_env = mock.create_autospec(env_interface.AndroidEnvInterface) - - class LoggingWrapper1(base_wrapper.BaseWrapper): - - def _wrapper_stats(self): - return { - 'wrapper1': 'stats', - 'shared': 1, - } - - class LoggingWrapper2(base_wrapper.BaseWrapper): - - def _wrapper_stats(self): - return { - 'wrapper2': 'stats', - 'shared': 2, - } - - wrapped_env = LoggingWrapper2(LoggingWrapper1(base_env)) - mock_info.assert_has_calls([ - mock.call('Wrapping with %s', 'LoggingWrapper1'), - mock.call('Wrapping with %s', 'LoggingWrapper2'), - ]) - base_stats = {'base': 'stats'} - base_env.stats.return_value = base_stats - expected_stats = { - 'base': 'stats', - 'wrapper1': 'stats', - 'wrapper2': 'stats', - 'shared': 2, - } - - self.assertEqual(expected_stats, wrapped_env.stats()) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/wrappers/discrete_action_wrapper.py b/android_env/wrappers/discrete_action_wrapper.py deleted file mode 100644 index 7bd483b7..00000000 --- a/android_env/wrappers/discrete_action_wrapper.py +++ /dev/null @@ -1,161 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Wraps the AndroidEnv environment to provide discrete actions.""" - -from collections.abc import Sequence - -from android_env.components import action_type -from android_env.wrappers import base_wrapper -import dm_env -from dm_env import specs -import numpy as np - - -_NOISE_CLIP_VALUE = 0.4999 - - -class DiscreteActionWrapper(base_wrapper.BaseWrapper): - """AndroidEnv with discrete actions.""" - - def __init__( - self, - env: dm_env.Environment, - action_grid: Sequence[int] = (10, 10), - redundant_actions: bool = True, - noise: float = 0.1, - ): - super().__init__(env) - self._parent_action_spec = self._env.action_spec() - self._assert_base_env() - self._action_grid = action_grid # [height, width] - self._grid_size = np.prod(self._action_grid) - self._num_action_types = self._parent_action_spec['action_type'].num_values - self._redundant_actions = redundant_actions - self._noise = noise - - def _assert_base_env(self): - """Checks that the wrapped env has the right action spec format.""" - - assert len(self._parent_action_spec) == 2 - assert not self._parent_action_spec['action_type'].shape - assert self._parent_action_spec['touch_position'].shape == (2,) - - @property - def num_actions(self) -> int: - """Number of discrete actions.""" - - if self._redundant_actions: - return self._grid_size * self._num_action_types - else: - return self._grid_size + self._num_action_types - 1 - - def step(self, action: dict[str, int]) -> dm_env.TimeStep: - """Take a step in the base environment.""" - - return self._env.step(self._process_action(action)) - - def _process_action(self, action: dict[str, int]) -> dict[str, np.ndarray]: - """Transforms action so that it agrees with AndroidEnv's action spec.""" - - return { - 'action_type': - np.array(self._get_action_type(action['action_id']), - dtype=self._parent_action_spec['action_type'].dtype), - 'touch_position': - np.array(self._get_touch_position(action['action_id']), - dtype=self._parent_action_spec['touch_position'].dtype) - } - - def _get_action_type(self, action_id: int) -> action_type.ActionType: - """Compute action type corresponding to the given action_id. - - When `self._redundant_actions` == True the `grid_size` is "broadcast" over - all the possible actions so you end up with `grid_size` discrete actions - of type 0, `grid_size` discrete actions of type 1, etc. for all action - types. - - When `self._redundant_actions` == False the first `grid_size` actions are - reserved for "touch" and the rest are just added (NOT multiplied) to the - total number of discrete actions (exactly one of LIFT and REPEAT). - - Args: - action_id: A discrete action. - Returns: - action_type: The action_type of the action. - """ - - if self._redundant_actions: - assert action_id < self._num_action_types * self._grid_size - return action_id // self._grid_size - - else: - assert action_id <= self._grid_size + 1 - if action_id < self._grid_size: - return action_type.ActionType.TOUCH - elif action_id == self._grid_size: - return action_type.ActionType.LIFT - else: - return action_type.ActionType.REPEAT - - def _get_touch_position(self, action_id: int) -> Sequence[float]: - """Compute the position corresponding to the given action_id. - - Note: in the touch_position (x, y) of an action, x corresponds to the - horizontal axis (width), and y corresponds to the vertical axis (height) - of the screen. BUT, the screen has dimensions (height, width), i.e. the - first coordinate corresponds to y, and the second coordinate corresponds - to x. Pay attention to this mismatch in the calculations below. - - Args: - action_id: A discrete action. - Returns: - touch_position: The [0,1]x[0,1] coordinate of the action. - """ - - position_idx = action_id % self._grid_size - - x_pos_grid = position_idx % self._action_grid[1] # WIDTH - y_pos_grid = position_idx // self._action_grid[1] # HEIGHT - - noise_x = np.random.normal(loc=0.0, scale=self._noise) - noise_y = np.random.normal(loc=0.0, scale=self._noise) - - # Noise is clipped so that the action will strictly stay in the cell. - noise_x = max(min(noise_x, _NOISE_CLIP_VALUE), -_NOISE_CLIP_VALUE) - noise_y = max(min(noise_y, _NOISE_CLIP_VALUE), -_NOISE_CLIP_VALUE) - - x_pos = (x_pos_grid + 0.5 + noise_x) / self._action_grid[1] # WIDTH - y_pos = (y_pos_grid + 0.5 + noise_y) / self._action_grid[0] # HEIGHT - - # Project action space to action_spec ranges. For the default case of - # minimum = [0, 0] and maximum = [1, 1], this will not do anything. - x_min, y_min = self._parent_action_spec['touch_position'].minimum - x_max, y_max = self._parent_action_spec['touch_position'].maximum - - x_pos = x_min + x_pos * (x_max - x_min) - y_pos = y_min + y_pos * (y_max - y_min) - - return [x_pos, y_pos] - - def action_spec(self) -> dict[str, specs.Array]: - """Action spec of the wrapped environment.""" - - return { - 'action_id': - specs.DiscreteArray( - num_values=self.num_actions, - name='action_id') - } diff --git a/android_env/wrappers/discrete_action_wrapper_test.py b/android_env/wrappers/discrete_action_wrapper_test.py deleted file mode 100644 index bc9d4329..00000000 --- a/android_env/wrappers/discrete_action_wrapper_test.py +++ /dev/null @@ -1,386 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.wrappers.discrete_action_wrapper.""" - -from unittest import mock - -from absl.testing import absltest -from android_env import env_interface -from android_env.components import action_type as action_type_lib -from android_env.wrappers import discrete_action_wrapper -from dm_env import specs -import numpy as np - -ActionType = action_type_lib.ActionType - - -def _make_array_spec(shape, dtype, name): - assert len(shape) == 1 - return specs.BoundedArray( - name=name, - shape=shape, - dtype=dtype, - minimum=np.zeros(shape), - maximum=(shape[0] - 1) * np.ones(shape), # maximum is inclusive. - ) - - -def _valid_shape(action): - assert len(action) == 2, action - assert not action['action_type'].shape, ( - 'action: %r, shape: %r' % - (action['action_type'], action['action_type'].shape)) - assert action['touch_position'].shape == ( - 2,), ('action: %r, shape: %r' % - (action['touch_position'], action['touch_position'].shape)) - - -def _valid_types(action, types): - for a, t in zip(action.values(), types): - assert a.dtype == t, '%r is not of dtype %r' % (a, t) - - -class DiscreteActionWrapperTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self._num_action_types = 3 # Only TOUCH, LIFT, REPEAT. - self._base_action_spec = { - 'action_type': specs.DiscreteArray( - num_values=self._num_action_types, name='action_type'), - 'touch_position': _make_array_spec( - shape=(2,), dtype=np.float32, name='touch_position'), - } - self.base_env = mock.create_autospec(env_interface.AndroidEnvInterface) - self.base_env.action_spec.return_value = self._base_action_spec - - def test_num_actions(self): - wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( - self.base_env, action_grid=(3, 3), redundant_actions=True) - # 27 = 3 * 3 * 2 (H * W * self._num_action_types). - self.assertEqual(27, wrapped_env.num_actions) - - def test_num_actions_non_redundant(self): - # Check that with `redundant_actions`==False we get an additive term instead - # of a multiplier in the number of actions. - non_redudant_wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( - self.base_env, action_grid=(3, 3), redundant_actions=False) - # 11 = 3 * 3 + 2 (H * W + (self._num_action_types - 1)). - self.assertEqual(11, non_redudant_wrapped_env.num_actions) - - def test_reset(self): - wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( - self.base_env, redundant_actions=True) - fake_timestep = 'ts' - self.base_env.reset.return_value = fake_timestep - ts = wrapped_env.reset() - self.base_env.reset.assert_called_once() - self.assertEqual(fake_timestep, ts) - - def test_step_no_noise(self): - height = 4 - width = 3 - wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( - self.base_env, - action_grid=(height, width), - noise=0.0, - redundant_actions=True) - self.assertEqual(height * width * self._num_action_types, - wrapped_env.num_actions) - - vertical_half_step = 1. / float(height) / 2. - horizontal_half_step = 1. / float(width) / 2. - - delta = 0.0001 - - # Testing the four corners with each finger position - def get_verifier(expected_action_type, lower_x, lower_y): - - def verifier(x): - _valid_shape(x) - _valid_types(x, [np.int32, np.float32]) - self.assertEqual( - expected_action_type, x['action_type']) - if lower_y: - self.assertAlmostEqual( - vertical_half_step, x['touch_position'][1], delta=delta) - else: - self.assertAlmostEqual( - 1 - vertical_half_step, x['touch_position'][1], delta=delta) - if lower_x: - self.assertAlmostEqual( - horizontal_half_step, x['touch_position'][0], delta=delta) - else: - self.assertAlmostEqual( - 1 - horizontal_half_step, x['touch_position'][0], delta=delta) - return True - - return verifier - - action_tests = { - 0: get_verifier(0, lower_x=True, lower_y=True), - 2: get_verifier(0, lower_x=False, lower_y=True), - 9: get_verifier(0, lower_x=True, lower_y=False), - 11: get_verifier(0, lower_x=False, lower_y=False), - - 12: get_verifier(1, lower_x=True, lower_y=True), - 14: get_verifier(1, lower_x=False, lower_y=True), - 21: get_verifier(1, lower_x=True, lower_y=False), - 23: get_verifier(1, lower_x=False, lower_y=False), - - 24: get_verifier(2, lower_x=True, lower_y=True), - 26: get_verifier(2, lower_x=False, lower_y=True), - 33: get_verifier(2, lower_x=True, lower_y=False), - 35: get_verifier(2, lower_x=False, lower_y=False), - } - - fake_timestep = 'ts' - self.base_env.step.return_value = fake_timestep - - for action_id, verifier in action_tests.items(): - ts = wrapped_env.step({'action_id': action_id}) - verifier(self.base_env.step.call_args[0][0]) - self.assertEqual(fake_timestep, ts) - - def test_step_redundant_actions_invalid_action_id(self): - wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( - self.base_env, - action_grid=(4, 3), - noise=0.0, - redundant_actions=True) - with self.assertRaises(AssertionError): - _ = wrapped_env.step({'action_id': 36}) - - def test_step_no_noise_no_redudant_actions(self): - height = 4 - width = 3 - wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( - self.base_env, - action_grid=(height, width), - noise=0.0, - redundant_actions=False) - self.assertEqual(height * width + (self._num_action_types - 1), - wrapped_env.num_actions) - - vertical_half_step = 1. / float(height) / 2. - horizontal_half_step = 1. / float(width) / 2. - - delta = 0.0001 - - # Testing the four corners with each finger position - def get_verifier(expected_action_type, lower_x, lower_y): - - def verifier(x): - _valid_shape(x) - _valid_types(x, [np.int32, np.float32]) - self.assertEqual(expected_action_type, x['action_type']) - # If the action type == TOUCH, then check the coordinate values. - if x['action_type'] == ActionType.TOUCH: - if lower_y: - self.assertAlmostEqual( - vertical_half_step, x['touch_position'][1], delta=delta) - else: - self.assertAlmostEqual( - 1 - vertical_half_step, x['touch_position'][1], delta=delta) - if lower_x: - self.assertAlmostEqual( - horizontal_half_step, x['touch_position'][0], delta=delta) - else: - self.assertAlmostEqual( - 1 - horizontal_half_step, x['touch_position'][0], delta=delta) - return True - - return verifier - - action_tests = { - # Touch type actions - 0: get_verifier(0, lower_x=True, lower_y=True), - 2: get_verifier(0, lower_x=False, lower_y=True), - 9: get_verifier(0, lower_x=True, lower_y=False), - 11: get_verifier(0, lower_x=False, lower_y=False), - # Actions > grid_size return non-touch actions with (0,0) coordinates. - 12: get_verifier(1, lower_x=False, lower_y=False), - 13: get_verifier(2, lower_x=False, lower_y=False), - } - - fake_timestep = 'ts' - self.base_env.step.return_value = fake_timestep - - for action_id, verifier in action_tests.items(): - ts = wrapped_env.step({'action_id': action_id}) - verifier(self.base_env.step.call_args[0][0]) - self.assertEqual(fake_timestep, ts) - - def test_step_no_redundant_actions_invalid_action_id(self): - wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( - self.base_env, - action_grid=(4, 3), - noise=0.0, - redundant_actions=False) - with self.assertRaises(AssertionError): - _ = wrapped_env.step({'action_id': 14}) - - def test_step_with_noise(self): - height = 4 - width = 3 - wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( - self.base_env, action_grid=(height, width), noise=1.0) - self.assertEqual(height * width * self._num_action_types, - wrapped_env.num_actions) - - vertical_grid_step = 1. / float(height) - horizontal_grid_step = 1. / float(width) - - # Testing the four corners with each finger position - def get_verifier(expected_up_down, lower_x, lower_y): - - def verifier(x): - _valid_shape(x) - _valid_types(x, [np.int32, np.float32]) - self.assertEqual(expected_up_down, x['action_type']) - if lower_y: - self.assertGreater(vertical_grid_step, x['touch_position'][1]) - else: - self.assertLess(1 - vertical_grid_step, x['touch_position'][1]) - if lower_x: - self.assertGreater(horizontal_grid_step, x['touch_position'][0]) - else: - self.assertLess(1 - horizontal_grid_step, x['touch_position'][0]) - return True - - return verifier - - action_tests = { - 0: get_verifier(0, lower_x=True, lower_y=True), - 2: get_verifier(0, lower_x=False, lower_y=True), - 9: get_verifier(0, lower_x=True, lower_y=False), - 11: get_verifier(0, lower_x=False, lower_y=False), - - 12: get_verifier(1, lower_x=True, lower_y=True), - 14: get_verifier(1, lower_x=False, lower_y=True), - 21: get_verifier(1, lower_x=True, lower_y=False), - 23: get_verifier(1, lower_x=False, lower_y=False), - - 24: get_verifier(2, lower_x=True, lower_y=True), - 26: get_verifier(2, lower_x=False, lower_y=True), - 33: get_verifier(2, lower_x=True, lower_y=False), - 35: get_verifier(2, lower_x=False, lower_y=False), - } - - fake_timestep = 'ts' - self.base_env.step.return_value = fake_timestep - - for action_id, verifier in action_tests.items(): - ts = wrapped_env.step({'action_id': action_id}) - verifier(self.base_env.step.call_args[0][0]) - self.assertEqual(fake_timestep, ts) - - def test_parent_spec_type(self): - base_action_spec = { - 'action_type': specs.DiscreteArray( - num_values=self._num_action_types, name='action_type'), - 'touch_position': _make_array_spec( - shape=(2,), dtype=np.float64, name='touch_position'), - } - base_env = mock.create_autospec(env_interface.AndroidEnvInterface) - base_env.action_spec.return_value = base_action_spec - - wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( - base_env, noise=0.0) - - fake_timestep = 'ts' - base_env.step.return_value = fake_timestep - - def verifier(x): - _valid_types(x, [np.int32, np.float64]) - return True - - ts = wrapped_env.step({'action_id': 1}) - verifier(base_env.step.call_args[0][0]) - self.assertEqual(fake_timestep, ts) - - def test_observation_spec(self): - wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( - self.base_env) - fake_obs_spec = 'fake_obs_spec' - self.base_env.observation_spec.return_value = fake_obs_spec - observation_spec = wrapped_env.observation_spec() - self.base_env.observation_spec.assert_called_once() - self.assertEqual(fake_obs_spec, observation_spec) - - def test_action_spec(self): - wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( - self.base_env, action_grid=(4, 5), redundant_actions=True) - expected_action_spec = { - 'action_id': - specs.DiscreteArray( - num_values=4 * 5 * self._num_action_types, name='action_type') - } - self.assertEqual(expected_action_spec, wrapped_env.action_spec()) - - def test_action_spec_non_redundant(self): - wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( - self.base_env, action_grid=(4, 5), redundant_actions=False) - num_non_touch_actions = self._num_action_types - 1 - expected_action_spec = { - 'action_id': - specs.DiscreteArray( - num_values=4 * 5 + num_non_touch_actions, name='action_type') - } - self.assertEqual(expected_action_spec, wrapped_env.action_spec()) - - def test_assert_base_env_action_spec_too_short(self): - self.base_env.action_spec.return_value = { - 'action_type': specs.DiscreteArray( - num_values=self._num_action_types, name='action_type'), - } - with self.assertRaises(AssertionError): - _ = discrete_action_wrapper.DiscreteActionWrapper(self.base_env) - - def test_assert_base_env_action_spec_too_long(self): - self.base_env.action_spec.return_value = { - 'action_type': specs.DiscreteArray( - num_values=self._num_action_types, name='action_type'), - 'touch_position': _make_array_spec( - shape=(2,), dtype=np.float32, name='touch_position'), - 'too_long': _make_array_spec( - shape=(1,), dtype=np.float32, name='too_long'), - } - with self.assertRaises(AssertionError): - _ = discrete_action_wrapper.DiscreteActionWrapper(self.base_env) - - def test_assert_base_env_action_spec_wrong_shapes(self): - self.base_env.action_spec.return_value = { - 'action_type': _make_array_spec( - shape=(2,), dtype=np.float32, name='action_type'), - 'touch_position': _make_array_spec( - shape=(1,), dtype=np.float32, name='touch_position') - } - with self.assertRaises(AssertionError): - _ = discrete_action_wrapper.DiscreteActionWrapper(self.base_env) - - def test_assert_base_env_ok(self): - self.base_env.action_spec.return_value = { - 'action_type': specs.DiscreteArray( - num_values=self._num_action_types, name='action_type'), - 'touch_position': _make_array_spec( - shape=(2,), dtype=np.float32, name='touch_position'), - } - _ = discrete_action_wrapper.DiscreteActionWrapper(self.base_env) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/wrappers/flat_interface_wrapper.py b/android_env/wrappers/flat_interface_wrapper.py deleted file mode 100644 index 868ca65f..00000000 --- a/android_env/wrappers/flat_interface_wrapper.py +++ /dev/null @@ -1,119 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Wraps the AndroidEnv environment to make its interface flat.""" - -from typing import Any - -from android_env.wrappers import base_wrapper -import dm_env -from dm_env import specs -import numpy as np - -RGB_CHANNELS = (0, 1, 2) - - -def _extract_screen_pixels(obs: np.ndarray): - """Get only screen pixels by removing previous action layer.""" - is_grayscale_image = obs.shape[-1] == 2 - if is_grayscale_image: - return np.expand_dims(obs[..., 0], -1) - return obs[..., RGB_CHANNELS] - - -def _get_no_action_observation_spec(obs_spec: specs.BoundedArray): - """Create an observation spec without the action layer.""" - shape = np.array(obs_spec.shape) - shape[2] -= 1 - minimum = obs_spec.minimum - maximum = obs_spec.maximum - is_scalar = lambda x: np.isscalar(x) or np.ndim(x) == 0 - if not is_scalar(minimum): - minimum = _extract_screen_pixels(minimum) - if not is_scalar(maximum): - maximum = _extract_screen_pixels(maximum) - return obs_spec.replace(shape=shape, minimum=minimum, maximum=maximum) - - -class FlatInterfaceWrapper(base_wrapper.BaseWrapper): - """Simple interface for AndroidEnv. - - Removes the structure from observations and actions, keeping only the pixel - observations. Also exposes action as an int32 scalar, making it easier to use - with conventional discrete agents. This wrapper expects a discretized action - space. - """ - - def __init__(self, - env: dm_env.Environment, - flat_actions: bool = True, - flat_observations: bool = True, - keep_action_layer: bool = True): - super().__init__(env) - self._flat_actions = flat_actions - self._flat_observations = flat_observations - self._keep_action_layer = keep_action_layer - self._action_name = list(self._env.action_spec())[0] - self._assert_base_env() - - def _assert_base_env(self): - base_action_spec = self._env.action_spec() - assert len(base_action_spec) == 1, self._env.action_spec() - assert isinstance(base_action_spec, dict) - assert isinstance(base_action_spec[self._action_name], specs.BoundedArray) - - def _process_action(self, action: int | np.ndarray | dict[str, Any]): - if self._flat_actions: - return {self._action_name: action} - else: - return action - - def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: - if self._flat_observations: - step_type, reward, discount, observation = timestep - # Keep only the pixels. - pixels = observation['pixels'] - pixels = pixels if self._keep_action_layer else _extract_screen_pixels( - pixels) - return dm_env.TimeStep( - step_type=step_type, - reward=reward, - discount=discount, - observation=pixels) - else: - return timestep - - def reset(self) -> dm_env.TimeStep: - timestep = self._env.reset() - return self._process_timestep(timestep) - - def step(self, action: int) -> dm_env.TimeStep: - timestep = self._env.step(self._process_action(action)) - return self._process_timestep(timestep) - - def observation_spec(self) -> specs.Array | dict[str, specs.Array]: # pytype: disable=signature-mismatch # overriding-return-type-checks - if self._flat_observations: - pixels_spec = self._env.observation_spec()['pixels'] - if not self._keep_action_layer: - return _get_no_action_observation_spec(pixels_spec) - return pixels_spec - else: - return self._env.observation_spec() - - def action_spec(self) -> specs.BoundedArray | dict[str, specs.Array]: # pytype: disable=signature-mismatch # overriding-return-type-checks - if self._flat_actions: - return self._env.action_spec()[self._action_name] - else: - return self._env.action_spec() diff --git a/android_env/wrappers/flat_interface_wrapper_test.py b/android_env/wrappers/flat_interface_wrapper_test.py deleted file mode 100644 index a650054b..00000000 --- a/android_env/wrappers/flat_interface_wrapper_test.py +++ /dev/null @@ -1,166 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.wrappers.flat_interface_wrapper.""" - -from typing import cast -from unittest import mock - -from absl.testing import absltest -from android_env.wrappers import flat_interface_wrapper -import dm_env -from dm_env import specs -import numpy as np - - -def _make_array_spec(shape, dtype=np.float32, name=None, maximum=3, minimum=0): - return specs.BoundedArray( - shape=shape, - dtype=dtype, - name=name, - maximum=np.ones(shape) * maximum, - minimum=np.ones(shape) * minimum) - - -def _make_timestep(observation): - return dm_env.TimeStep( - step_type='fake_step_type', - reward='fake_reward', - discount='fake_discount', - observation=observation, - ) - - -class FlatInterfaceWrapperTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self.action_shape = (1,) - self.base_action_spec: dict[str, specs.DiscreteArray] = { - 'action_id': specs.DiscreteArray(name='action_id', num_values=4) - } - self.int_obs_shape = (3, 4, 2) - self.float_obs_shape = (2,) - self.base_observation_spec = { - 'pixels': _make_array_spec( - shape=self.int_obs_shape, dtype=np.uint8, name='pixels'), - 'obs1': _make_array_spec( - shape=self.float_obs_shape, dtype=np.float32, name='obs1'), - } - # Expected. - self.expected_observation_spec = _make_array_spec( - shape=self.int_obs_shape, dtype=np.uint8, name='pixels') - self.image_obs = np.ones(self.int_obs_shape, dtype=np.uint8) - self.expected_timestep = _make_timestep(self.image_obs) - - # Expected for no new action layer shape. - expected_new_shape_no_action_layer = (3, 4, 1) - self.expected_observation_spec_no_action_layer = _make_array_spec( - shape=expected_new_shape_no_action_layer, dtype=np.uint8, name='pixels') - self.expected_timestep_no_action_layer = _make_timestep( - np.ones(expected_new_shape_no_action_layer, dtype=np.uint8)) - - # Base environment. - self.other_obs = np.ones(self.float_obs_shape, dtype=np.float32) - self.base_timestep = _make_timestep({ - 'pixels': self.image_obs, - 'obs1': self.other_obs}) - self.base_env = mock.create_autospec(dm_env.Environment) - self.base_env.action_spec.return_value = self.base_action_spec - self.base_env.observation_spec.return_value = self.base_observation_spec - self.base_env.reset.return_value = self.base_timestep - self.base_env.step.return_value = self.base_timestep - - def test_reset(self): - wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env) - ts = wrapped_env.reset() - self.base_env.reset.assert_called_once() - self.assertEqual(self.expected_timestep, ts) - - def test_reset_no_action_layer(self): - wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper( - self.base_env, keep_action_layer=False) - ts = wrapped_env.reset() - self.base_env.reset.assert_called_once() - self.assertEqual( - self.expected_timestep_no_action_layer.observation.tolist(), - ts.observation.tolist()) - - def test_step(self): - wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env) - action = 2 - ts = wrapped_env.step(action) - - def verifier(x): - self.assertIsInstance(x, dict) - self.assertIsInstance(x['action_id'], int) - self.assertEqual(x['action_id'], action) - return True - verifier(self.base_env.step.call_args[0][0]) - self.assertEqual(self.expected_timestep, ts) - - def test_step_no_action_layer(self): - wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper( - self.base_env, keep_action_layer=False) - action = 2 - ts = wrapped_env.step(action) - - def verifier(x): - self.assertIsInstance(x, dict) - self.assertIsInstance(x['action_id'], int) - self.assertEqual(x['action_id'], action) - return True - - verifier(self.base_env.step.call_args[0][0]) - self.assertEqual( - self.expected_timestep_no_action_layer.observation.tolist(), - ts.observation.tolist()) - - def test_observation_spec(self): - wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env) - observation_spec = wrapped_env.observation_spec() - self.base_env.observation_spec.assert_called_once() - self.assertEqual(self.expected_observation_spec, observation_spec) - - def test_observation_spec_no_action_layer(self): - wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper( - self.base_env, keep_action_layer=False) - observation_spec = wrapped_env.observation_spec() - self.base_env.observation_spec.assert_called_once() - self.assertEqual(self.expected_observation_spec_no_action_layer, - observation_spec) - - def test_action_spec(self): - wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env) - action_spec = cast(specs.BoundedArray, wrapped_env.action_spec()) - parent_action_spec = self.base_action_spec['action_id'] - - self.assertEqual(parent_action_spec.name, action_spec.name) - self.assertEqual((), action_spec.shape) - self.assertEqual(np.int32, action_spec.dtype) - self.assertEqual(0, action_spec.minimum) - - def test_bad_action_spec_structured_action(self): - bad_base_env = mock.create_autospec(dm_env.Environment) - bad_base_env.action_spec.return_value = { - 'action_id': _make_array_spec((1,)), - 'too_many': _make_array_spec((1,)) - } - with self.assertRaises(AssertionError): - _ = flat_interface_wrapper.FlatInterfaceWrapper(bad_base_env) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/wrappers/float_pixels_wrapper.py b/android_env/wrappers/float_pixels_wrapper.py deleted file mode 100644 index 43def1f6..00000000 --- a/android_env/wrappers/float_pixels_wrapper.py +++ /dev/null @@ -1,68 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Converts pixel observation to from int to float32 between 0.0 and 1.0.""" - -from android_env.components import pixel_fns -from android_env.wrappers import base_wrapper -import dm_env -from dm_env import specs -import numpy as np - - -class FloatPixelsWrapper(base_wrapper.BaseWrapper): - """Wraps AndroidEnv for Panultimate agent.""" - - def __init__(self, env: dm_env.Environment): - super().__init__(env) - self._input_spec = self._env.observation_spec()['pixels'] - self._should_convert_int_to_float = np.issubdtype(self._input_spec.dtype, - np.integer) - - def _process_observation( - self, observation: dict[str, np.ndarray] - ) -> dict[str, np.ndarray]: - if self._should_convert_int_to_float: - float_pixels = pixel_fns.convert_int_to_float( - observation['pixels'], self._input_spec - ) - observation['pixels'] = float_pixels - return observation - - def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: - step_type, reward, discount, observation = timestep - return dm_env.TimeStep( - step_type=step_type, - reward=reward, - discount=discount, - observation=self._process_observation(observation)) - - def reset(self) -> dm_env.TimeStep: - return self._process_timestep(self._env.reset()) - - def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep: - return self._process_timestep(self._env.step(action)) - - def observation_spec(self) -> dict[str, specs.Array]: - if self._should_convert_int_to_float: - observation_spec = self._env.observation_spec() - observation_spec['pixels'] = specs.BoundedArray( - shape=self._env.observation_spec()['pixels'].shape, - dtype=np.float32, - minimum=0.0, - maximum=1.0, - name=self._env.observation_spec()['pixels'].name) - return observation_spec - return self._env.observation_spec() diff --git a/android_env/wrappers/float_pixels_wrapper_test.py b/android_env/wrappers/float_pixels_wrapper_test.py deleted file mode 100644 index 5868e75b..00000000 --- a/android_env/wrappers/float_pixels_wrapper_test.py +++ /dev/null @@ -1,149 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.wrappers.float_pixels_wrapper.""" - -from unittest import mock - -from absl.testing import absltest -from android_env.wrappers import float_pixels_wrapper -import dm_env -from dm_env import specs -import numpy as np - - -def _make_array_spec(shape, dtype=np.float32, name=None): - return specs.Array( - shape=shape, - dtype=dtype, - name=name, - ) - - -def _make_bounded_array_spec( - shape, dtype=np.float32, name=None, maximum=1.0, minimum=0.0): - return specs.BoundedArray( - shape=shape, - dtype=dtype, - name=name, - maximum=maximum, - minimum=minimum, - ) - - -def _simple_timestep(obs_shape, obs_type): - return dm_env.TimeStep( - step_type=dm_env.StepType.MID, - reward=3.14, - discount=0.9, - observation=(np.ones(shape=obs_shape, dtype=obs_type),), - ) - - -class FloatPixelsWrapperTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self.pixels_shape = (3, 4) - base_pixel_spec = _make_array_spec( - shape=self.pixels_shape, dtype=np.uint8, name='pixels') - self.other_obs_spec = _make_array_spec( - shape=(1,), dtype=np.float32, name='other_obs') - base_observation_spec = { - 'pixels': base_pixel_spec, - 'other_obs': self.other_obs_spec - } - self.base_env = mock.create_autospec(dm_env.Environment) - self.base_env.observation_spec.return_value = base_observation_spec - - self.base_timestep = dm_env.TimeStep( - step_type=dm_env.StepType.MID, - reward=3.14, - discount=0.9, - observation={ - 'pixels': np.ones(shape=self.pixels_shape, dtype=np.uint8), - 'other_obs': [42.2]}) - self.base_env.step.return_value = self.base_timestep - self.base_env.reset.return_value = self.base_timestep - - def test_float_pixels_wrapper_spec(self): - expected_pixel_spec = _make_bounded_array_spec( - shape=self.pixels_shape, - dtype=np.float32, - name='pixels', - minimum=0.0, - maximum=1.0) - - wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(self.base_env) - - self.assertLen(wrapped_env.observation_spec(), 2) - self.assertEqual(expected_pixel_spec, - wrapped_env.observation_spec()['pixels']) - self.assertEqual(self.other_obs_spec, - wrapped_env.observation_spec()['other_obs']) - - def test_float_pixels_wrapper_step(self): - wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(self.base_env) - ts = wrapped_env.step({'fake_action': np.array([1, 2, 3])}) - - self.assertEqual(self.base_timestep.step_type, ts.step_type) - self.assertEqual(self.base_timestep.reward, ts.reward) - self.assertEqual(self.base_timestep.discount, ts.discount) - self.assertEqual(self.base_timestep.observation['other_obs'], - ts.observation['other_obs']) - expected_pixel_value = 1. / 255. # original values are unit8 - expected_pixels = np.ones( - self.pixels_shape, dtype=np.float32) * expected_pixel_value - np.testing.assert_equal(expected_pixels, ts.observation['pixels']) - - def test_float_pixels_wrapper_reset(self): - wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(self.base_env) - ts = wrapped_env.reset() - - self.assertEqual(self.base_timestep.step_type, ts.step_type) - self.assertEqual(self.base_timestep.reward, ts.reward) - self.assertEqual(self.base_timestep.discount, ts.discount) - self.assertEqual(self.base_timestep.observation['other_obs'], - ts.observation['other_obs']) - expected_pixel_value = 1. / 255. # original values are unit8 - expected_pixels = np.ones( - self.pixels_shape, dtype=np.float32) * expected_pixel_value - np.testing.assert_equal(expected_pixels, ts.observation['pixels']) - - def test_float_pixels_wrapper_already_float(self): - base_pixel_spec = _make_array_spec( - shape=self.pixels_shape, dtype=np.float64, name='pixels') - base_observation_spec = { - 'pixels': base_pixel_spec, - 'other_obs': self.other_obs_spec - } - base_env = mock.create_autospec(dm_env.Environment) - base_env.observation_spec.return_value = base_observation_spec - - wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(base_env) - - # If the pixels are already float values, then obs_spec does not change. - self.assertEqual(base_env.observation_spec(), - wrapped_env.observation_spec()) - - # The wrapper should not touch the timestep in this case. - fake_timestep = ('step_type', 'reward', 'discount', 'obs') - base_env.step.return_value = fake_timestep - ts = wrapped_env.step({'fake_action': np.array([1, 2, 3])}) - self.assertEqual(fake_timestep, ts) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/wrappers/gym_wrapper.py b/android_env/wrappers/gym_wrapper.py deleted file mode 100644 index 41ff44ed..00000000 --- a/android_env/wrappers/gym_wrapper.py +++ /dev/null @@ -1,98 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Wraps the AndroidEnv to expose an OpenAI Gym interface.""" - -from typing import Any - -from android_env.wrappers import base_wrapper -import dm_env -from dm_env import specs -import gym -from gym import spaces -import numpy as np - - -class GymInterfaceWrapper(gym.Env): - """AndroidEnv with OpenAI Gym interface.""" - - def __init__(self, env: dm_env.Environment): - self._env = env - self.spec = None - self.action_space = self._spec_to_space(self._env.action_spec()) - self.observation_space = self._spec_to_space(self._env.observation_spec()) - self.metadata = {'render.modes': ['rgb_array']} - self._latest_observation = None - - def _spec_to_space(self, spec: specs.Array) -> spaces.Space: - """Converts dm_env specs to OpenAI Gym spaces.""" - - if isinstance(spec, list): - return spaces.Tuple([self._spec_to_space(s) for s in spec]) - - if isinstance(spec, dict): - return spaces.Dict( - {name: self._spec_to_space(s) for name, s in spec.items()} - ) - - if isinstance(spec, specs.DiscreteArray): - return spaces.Box( - shape=(), - dtype=spec.dtype, - low=0, - high=spec.num_values-1) - - if isinstance(spec, specs.BoundedArray): - return spaces.Box( - shape=spec.shape, - dtype=spec.dtype, - low=spec.minimum, - high=spec.maximum) - - if isinstance(spec, specs.Array): - if spec.dtype == np.uint8: - low = 0 - high = 255 - else: - low = -np.inf - high = np.inf - return spaces.Box(shape=spec.shape, dtype=spec.dtype, low=low, high=high) - - raise ValueError('Unknown type for specs: {}'.format(spec)) - - def render(self, mode='rgb_array'): - """Renders the environment.""" - if mode == 'rgb_array': - if self._latest_observation is None: - return - - return self._latest_observation['pixels'] - else: - raise ValueError('Only supported render mode is rgb_array.') - - def reset(self) -> np.ndarray: - self._latest_observation = None - timestep = self._env.reset() - return timestep.observation - - def step(self, action: dict[str, int]) -> tuple[Any, ...]: - """Take a step in the base environment.""" - timestep = self._env.step(action) - observation = timestep.observation - self._latest_observation = observation - reward = timestep.reward - done = timestep.step_type == dm_env.StepType.LAST - info = {'discount': timestep.discount} - return observation, reward, done, info diff --git a/android_env/wrappers/gym_wrapper_test.py b/android_env/wrappers/gym_wrapper_test.py deleted file mode 100644 index 9c266699..00000000 --- a/android_env/wrappers/gym_wrapper_test.py +++ /dev/null @@ -1,119 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.wrappers.gym_wrapper.""" - -from unittest import mock - -from absl.testing import absltest -from android_env import env_interface -from android_env.wrappers import gym_wrapper -import dm_env -from dm_env import specs -from gym import spaces -import numpy as np - - -class GymInterfaceWrapperTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self._base_env = mock.create_autospec(env_interface.AndroidEnvInterface) - self._base_env.action_spec.return_value = { - 'action_type': - specs.DiscreteArray( - num_values=3, - name='action_type'), - 'touch_position': - specs.BoundedArray( - shape=(2,), - dtype=np.float32, - minimum=[0.0, 0.0], - maximum=[1.0, 1.0], - name='touch_position'), - } - self._base_env.observation_spec.return_value = { - 'pixels': - specs.Array( - shape=(480, 320, 3), - dtype=np.uint8, - name='pixels'), - 'timedelta': - specs.Array(shape=(), dtype=np.int64, name='timedelta'), - 'orientation': - specs.Array( - shape=np.array([4]), - dtype=np.uint8, - name='orientation'), - } - self._wrapped_env = gym_wrapper.GymInterfaceWrapper(self._base_env) - self._fake_ts = dm_env.TimeStep( - step_type=dm_env.StepType.MID, - observation={'pixels': np.ones(shape=(2, 3))}, - reward=10.0, - discount=1.0) - - def test_render(self): - self._base_env.step.return_value = self._fake_ts - _ = self._wrapped_env.step(action=np.zeros(shape=(1,))) - image = self._wrapped_env.render(mode='rgb_array') - self.assertTrue(np.array_equal(image, np.ones(shape=(2, 3)))) - - def test_render_error(self): - with self.assertRaises(ValueError): - _ = self._wrapped_env.render(mode='human') - - def test_reset(self): - self._base_env.reset.return_value = dm_env.TimeStep( - step_type=dm_env.StepType.FIRST, - observation={'pixels': np.ones(shape=(2, 3))}, - reward=10.0, - discount=1.0) - obs = self._wrapped_env.reset() - self._base_env.reset.assert_called_once() - self.assertTrue(np.array_equal(obs['pixels'], np.ones(shape=(2, 3)))) - - def test_step(self): - self._base_env.step.return_value = self._fake_ts - obs, _, _, _ = self._wrapped_env.step(action=np.zeros(shape=(1,))) - self._base_env.step.assert_called_once() - self.assertTrue(np.array_equal(obs['pixels'], np.ones(shape=(2, 3)))) - - def test_spec_to_space(self): - - spec = specs.Array( - shape=(2, 3), - dtype=np.float32) - space = self._wrapped_env._spec_to_space(spec) - self.assertEqual(space, spaces.Box( - low=-np.inf, high=np.inf, shape=spec.shape, dtype=spec.dtype)) - - spec = specs.BoundedArray( - shape=(), - dtype=np.float32, - minimum=4, - maximum=5) - space = self._wrapped_env._spec_to_space(spec) - self.assertEqual(space, spaces.Box( - low=4, high=5, shape=spec.shape, dtype=spec.dtype)) - - spec = specs.DiscreteArray(num_values=4) - space = self._wrapped_env._spec_to_space(spec) - self.assertEqual(space, spaces.Box( - low=0, high=3, shape=(), dtype=np.int32)) - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/wrappers/image_rescale_wrapper.py b/android_env/wrappers/image_rescale_wrapper.py deleted file mode 100644 index 6faeebca..00000000 --- a/android_env/wrappers/image_rescale_wrapper.py +++ /dev/null @@ -1,109 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Wraps the AndroidEnv environment to rescale the observations.""" - -from collections.abc import Sequence - -from android_env.wrappers import base_wrapper -import dm_env -from dm_env import specs -import numpy as np -from PIL import Image - - -# Taken from https://pillow.readthedocs.io/en/3.2.x/reference/Image.html#PIL.Image.Image.convert -# -# This array maps an RGB image to a grayscale image using the ITU-R 709 -# specification which is good for computer displays and HDTV. -RGB_TO_GRAYSCALE_COEFFICIENTS = [0.2126, 0.7152, 0.0722] - - -class ImageRescaleWrapper(base_wrapper.BaseWrapper): - """AndroidEnv with rescaled observations.""" - - def __init__( - self, - env: dm_env.Environment, - zoom_factors: Sequence[float] | None = (0.5, 0.5), - grayscale: bool = False, - ): - super().__init__(env) - assert 'pixels' in self._env.observation_spec() - assert self._env.observation_spec()['pixels'].shape[-1] in [1, 3], ( - 'Number of pixel channels should be 1 or 3.') - self._grayscale = grayscale - if zoom_factors is None: - zoom_factors = (1.0, 1.0) - # We only zoom the width and height of each layer, and we explicitly do not - # want to zoom the number of channels so we just multiply it by 1.0. - self._zoom_factors = tuple(zoom_factors) + (1.0,) - - def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: - observation = timestep.observation - processed_observation = observation.copy() - processed_observation['pixels'] = self._process_pixels( - observation['pixels']) - return timestep._replace(observation=processed_observation) - - def _process_pixels(self, raw_observation: np.ndarray) -> np.ndarray: - # We expect `raw_observation` to have shape (W, H, 3) - 3 for RGB - new_shape = np.array( - self._zoom_factors[0:2] * np.array(raw_observation.shape[0:2]), - dtype=np.int32)[::-1] - if self._grayscale: - # When self._grayscale == True, we squash the RGB into a single layer - image = np.dot(raw_observation, RGB_TO_GRAYSCALE_COEFFICIENTS) - else: - image = raw_observation - return self._resize_image_array(image, new_shape) - - def _resize_image_array( - self, grayscale_or_rbg_array: np.ndarray, new_shape: np.ndarray - ) -> np.ndarray: - """Resize color or grayscale/action_layer array to new_shape.""" - assert new_shape.ndim == 1 - assert len(new_shape) == 2 - resized_array = np.array( - Image.fromarray(grayscale_or_rbg_array.astype('uint8')).resize( - tuple(new_shape) - ) - ) - if resized_array.ndim == 2: - return np.expand_dims(resized_array, axis=-1) - return resized_array - - def reset(self) -> dm_env.TimeStep: - timestep = self._env.reset() - return self._process_timestep(timestep) - - def step(self, action) -> dm_env.TimeStep: - timestep = self._env.step(action) - return self._process_timestep(timestep) - - def observation_spec(self) -> dict[str, specs.Array]: - parent_spec = self._env.observation_spec().copy() - out_shape = np.multiply(parent_spec['pixels'].shape, - self._zoom_factors).astype(np.int32) - if self._grayscale: - # In grayscale mode we want the output shape to be [W, H, 1] - out_shape[-1] = 1 - parent_spec['pixels'] = specs.BoundedArray( - shape=out_shape, - dtype=parent_spec['pixels'].dtype, - name=parent_spec['pixels'].name, - minimum=parent_spec['pixels'].minimum, - maximum=parent_spec['pixels'].maximum) - return parent_spec diff --git a/android_env/wrappers/image_rescale_wrapper_test.py b/android_env/wrappers/image_rescale_wrapper_test.py deleted file mode 100644 index c55101f1..00000000 --- a/android_env/wrappers/image_rescale_wrapper_test.py +++ /dev/null @@ -1,104 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.wrappers.image_rescale_wrapper.""" - -from typing import Any -from unittest import mock - -from absl.testing import absltest -from android_env import env_interface -from android_env.wrappers import image_rescale_wrapper -import dm_env -from dm_env import specs -import numpy as np - - -def _simple_spec(): - return specs.BoundedArray( - shape=np.array([300, 300, 3]), - dtype=np.uint8, - name='pixels', - minimum=0, - maximum=255) - - -def _simple_timestep(): - observation = np.ones(shape=[300, 300, 3]) - return dm_env.TimeStep( - step_type=dm_env.StepType.MID, - reward=3.14, - discount=0.9, - observation={'pixels': observation}) - - -class ImageRescaleWrapperTest(absltest.TestCase): - - def test_100x50_grayscale(self): - fake_timestep = _simple_timestep() - fake_env = mock.create_autospec(env_interface.AndroidEnvInterface) - fake_env.observation_spec.return_value = {'pixels': _simple_spec()} - fake_env.reset.return_value = fake_timestep - fake_env.step.return_value = fake_timestep - - wrapper = image_rescale_wrapper.ImageRescaleWrapper( - fake_env, zoom_factors=(1.0 / 3, 1.0 / 6.0), grayscale=True) - self.assertIsNotNone(wrapper) - self.assertEqual(wrapper.observation_spec()['pixels'].shape, (100, 50, 1)) - reset_timestep = wrapper.reset() - reset_image = reset_timestep.observation['pixels'] - self.assertEqual(reset_image.shape, (100, 50, 1)) - step_timestep = wrapper.step(action='fake_action') - step_image = step_timestep.observation['pixels'] - self.assertEqual(step_image.shape, (100, 50, 1)) - - def test_150x60_full_channels(self): - fake_timestep = _simple_timestep() - fake_env = mock.create_autospec(env_interface.AndroidEnvInterface) - fake_env.observation_spec.return_value = {'pixels': _simple_spec()} - fake_env.reset.return_value = fake_timestep - fake_env.step.return_value = fake_timestep - - wrapper = image_rescale_wrapper.ImageRescaleWrapper( - fake_env, zoom_factors=(1.0 / 2.0, 1.0 / 5.0)) - self.assertIsNotNone(wrapper) - self.assertEqual(wrapper.observation_spec()['pixels'].shape, (150, 60, 3)) - reset_timestep = wrapper.reset() - reset_image = reset_timestep.observation['pixels'] - self.assertEqual(reset_image.shape, (150, 60, 3)) - step_timestep = wrapper.step(action='fake_action') - step_image = step_timestep.observation['pixels'] - self.assertEqual(step_image.shape, (150, 60, 3)) - - def test_list_zoom_factor(self): - fake_timestep = _simple_timestep() - fake_env = mock.create_autospec(env_interface.AndroidEnvInterface) - fake_env.observation_spec.return_value = {'pixels': _simple_spec()} - fake_env.reset.return_value = fake_timestep - fake_env.step.return_value = fake_timestep - - wrapper = image_rescale_wrapper.ImageRescaleWrapper( - fake_env, zoom_factors=[0.5, 0.2]) - self.assertIsNotNone(wrapper) - self.assertEqual(wrapper.observation_spec()['pixels'].shape, (150, 60, 3)) - reset_timestep = wrapper.reset() - reset_image = reset_timestep.observation['pixels'] - self.assertEqual(reset_image.shape, (150, 60, 3)) - step_timestep = wrapper.step(action='fake_action') - step_image = step_timestep.observation['pixels'] - self.assertEqual(step_image.shape, (150, 60, 3)) - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/wrappers/last_action_wrapper.py b/android_env/wrappers/last_action_wrapper.py deleted file mode 100644 index a09633c6..00000000 --- a/android_env/wrappers/last_action_wrapper.py +++ /dev/null @@ -1,114 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Extends Android observation with the latest action taken.""" - -from android_env.components import action_type -from android_env.components import pixel_fns -from android_env.wrappers import base_wrapper -import dm_env -from dm_env import specs -import numpy as np - - -class LastActionWrapper(base_wrapper.BaseWrapper): - """Extends Android observations with information about the last action taken. - - The position of the last action is denoted by a single white pixel (with a - value of 255) in a channel of all black pixels (with a value of 0). - As this wrapper makes use of temporarily stored information about the - last action taken, it is important to apply on the environment side rather - than the agent side. Recommended not to apply before an ImageRescaleWrapper, - to avoid distortion of the single pixel denoting the action position. - """ - - def __init__(self, - env: dm_env.Environment, - concat_to_pixels: bool = True): - """Initializes the internal state of this wrapper. - - Args: - env: the environment to wrap. - concat_to_pixels: If True, will add a channel to the pixel observation. - If False, will pass the action as an extra observation. - """ - super().__init__(env) - self._concat_to_pixels = concat_to_pixels - self._screen_dimensions = self._env.observation_spec()['pixels'].shape[:2] - - def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: - observation = timestep.observation.copy() - processed_observation = self._process_observation(observation) - return timestep._replace(observation=processed_observation) - - def _process_observation( - self, observation: dict[str, np.ndarray] - ) -> dict[str, np.ndarray]: - """Extends observation with last_action data.""" - processed_observation = observation.copy() - last_action_layer = self._get_last_action_layer(observation['pixels']) - if self._concat_to_pixels: - pixels = observation['pixels'].copy() - processed_pixels = np.dstack((pixels, last_action_layer)) - processed_observation['pixels'] = processed_pixels - else: - processed_observation['last_action'] = last_action_layer - return processed_observation - - def _get_last_action_layer(self, pixels: np.ndarray) -> np.ndarray: - """Makes sure the rescaling doesn't distort the last_action layer.""" - - last_action = self._env.raw_action - last_action_layer = np.zeros(self._screen_dimensions, dtype=pixels.dtype) - - if ('action_type' in last_action and - last_action['action_type'] == action_type.ActionType.TOUCH): - touch_position = last_action['touch_position'] - x, y = pixel_fns.touch_position_to_pixel_position( - touch_position, width_height=self._screen_dimensions[::-1] - ) - last_action_layer[y, x] = 255 - - return last_action_layer - - def reset(self) -> dm_env.TimeStep: - timestep = self._env.reset() - return self._process_timestep(timestep) - - def step(self, action) -> dm_env.TimeStep: - timestep = self._env.step(action) - return self._process_timestep(timestep) - - def observation_spec(self) -> dict[str, specs.Array]: - parent_spec = self._env.observation_spec().copy() - shape = parent_spec['pixels'].shape - if self._concat_to_pixels: - parent_spec['pixels'] = specs.BoundedArray( - shape=(shape[0], shape[1], shape[2] + 1), - dtype=parent_spec['pixels'].dtype, - name=parent_spec['pixels'].name, - minimum=parent_spec['pixels'].minimum, - maximum=parent_spec['pixels'].maximum) - else: - parent_spec.update({ - 'last_action': - specs.BoundedArray( - shape=(shape[0], shape[1]), - dtype=parent_spec['pixels'].dtype, - name='last_action', - minimum=parent_spec['pixels'].minimum, - maximum=parent_spec['pixels'].maximum) - }) - return parent_spec diff --git a/android_env/wrappers/last_action_wrapper_test.py b/android_env/wrappers/last_action_wrapper_test.py deleted file mode 100644 index 77307c00..00000000 --- a/android_env/wrappers/last_action_wrapper_test.py +++ /dev/null @@ -1,162 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for android_env.wrappers.last_action_wrapper.""" - -from typing import Any -from unittest import mock - -from absl.testing import absltest -from android_env import env_interface -from android_env.components import action_type -from android_env.wrappers import last_action_wrapper -import dm_env -from dm_env import specs -import numpy as np - - -def _simple_spec(): - return specs.BoundedArray( - shape=np.array([120, 80, 3]), - dtype=np.uint8, - name='pixels', - minimum=0, - maximum=255) - - -def _simple_timestep(): - observation = np.ones(shape=[120, 80, 3]) - return dm_env.TimeStep( - step_type=dm_env.StepType.MID, - reward=3.14, - discount=0.9, - observation={'pixels': observation}) - - -class LastActionWrapperTest(absltest.TestCase): - - def test_concat_to_pixels(self): - fake_timestep = _simple_timestep() - fake_env = mock.create_autospec(env_interface.AndroidEnvInterface) - fake_env.observation_spec.return_value = {'pixels': _simple_spec()} - fake_env.reset.return_value = fake_timestep - fake_env.step.return_value = fake_timestep - - wrapper = last_action_wrapper.LastActionWrapper( - fake_env, concat_to_pixels=True) - self.assertIsNotNone(wrapper) - self.assertEqual(wrapper.observation_spec()['pixels'].shape, (120, 80, 4)) - - reset_timestep = wrapper.reset() - reset_image = reset_timestep.observation['pixels'] - self.assertEqual(reset_image.shape, (120, 80, 4)) - last_action_layer = reset_image[:, :, -1] - self.assertEqual(np.sum(last_action_layer), 0) - - action1 = { - 'action_type': action_type.ActionType.TOUCH, - 'touch_position': np.array([0.25, 0.75], dtype=np.float32), # (W x H) - } - type(fake_env).raw_action = mock.PropertyMock(return_value=action1) - step_timestep = wrapper.step(action=action1) - step_image = step_timestep.observation['pixels'] - self.assertEqual(step_image.shape, (120, 80, 4)) # (H x W) - last_action_layer = step_image[:, :, -1] - self.assertEqual(np.sum(last_action_layer), 255) - y, x = np.where(last_action_layer == 255) - self.assertEqual((y.item(), x.item()), (90, 20)) - - action2 = { - 'action_type': action_type.ActionType.LIFT, - 'touch_position': np.array([0.25, 0.75], dtype=np.float32), - } - type(fake_env).raw_action = mock.PropertyMock(return_value=action2) - step_timestep = wrapper.step(action=action2) - step_image = step_timestep.observation['pixels'] - self.assertEqual(step_image.shape, (120, 80, 4)) - last_action_layer = step_image[:, :, -1] - self.assertEqual(np.sum(last_action_layer), 0) - - action3 = { - 'action_type': action_type.ActionType.TOUCH, - 'touch_position': np.array([0.25, 1.0], dtype=np.float32), - } - type(fake_env).raw_action = mock.PropertyMock(return_value=action3) - step_timestep = wrapper.step(action=action3) - step_image = step_timestep.observation['pixels'] - self.assertEqual(step_image.shape, (120, 80, 4)) - last_action_layer = step_image[:, :, -1] - self.assertEqual(np.sum(last_action_layer), 255) - y, x = np.where(last_action_layer == 255) - self.assertEqual((y.item(), x.item()), (119, 20)) - - def test_no_concat_to_pixels(self): - fake_timestep = _simple_timestep() - fake_env = mock.create_autospec(env_interface.AndroidEnvInterface) - fake_env.observation_spec.return_value = {'pixels': _simple_spec()} - fake_env.reset.return_value = fake_timestep - fake_env.step.return_value = fake_timestep - - wrapper = last_action_wrapper.LastActionWrapper( - fake_env, concat_to_pixels=False) - self.assertIsNotNone(wrapper) - self.assertEqual(wrapper.observation_spec()['pixels'].shape, (120, 80, 3)) - self.assertEqual(wrapper.observation_spec()['last_action'].shape, (120, 80)) - - reset_timestep = wrapper.reset() - reset_image = reset_timestep.observation['pixels'] - self.assertEqual(reset_image.shape, (120, 80, 3)) - last_action_layer = reset_timestep.observation['last_action'] - self.assertEqual(np.sum(last_action_layer), 0) - - action1 = { - 'action_type': action_type.ActionType.TOUCH, - 'touch_position': np.array([0.25, 0.75], dtype=np.float32), - } - type(fake_env).raw_action = mock.PropertyMock(return_value=action1) - step_timestep = wrapper.step(action=action1) - step_image = step_timestep.observation['pixels'] - self.assertEqual(step_image.shape, (120, 80, 3)) - last_action_layer = step_timestep.observation['last_action'] - self.assertEqual(np.sum(last_action_layer), 255) - y, x = np.where(last_action_layer == 255) - self.assertEqual((y.item(), x.item()), (90, 20)) - - action2 = { - 'action_type': action_type.ActionType.LIFT, - 'touch_position': np.array([0.25, 0.75], dtype=np.float32), - } - type(fake_env).raw_action = mock.PropertyMock(return_value=action2) - step_timestep = wrapper.step(action=action2) - step_image = step_timestep.observation['pixels'] - self.assertEqual(step_image.shape, (120, 80, 3)) - last_action_layer = step_timestep.observation['last_action'] - self.assertEqual(np.sum(last_action_layer), 0) - - action3 = { - 'action_type': action_type.ActionType.TOUCH, - 'touch_position': np.array([1.0, 0.75], dtype=np.float32), - } - type(fake_env).raw_action = mock.PropertyMock(return_value=action3) - step_timestep = wrapper.step(action=action3) - step_image = step_timestep.observation['pixels'] - self.assertEqual(step_image.shape, (120, 80, 3)) - last_action_layer = step_timestep.observation['last_action'] - self.assertEqual(np.sum(last_action_layer), 255) - y, x = np.where(last_action_layer == 255) - self.assertEqual((y.item(), x.item()), (90, 79)) - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/wrappers/rate_limit_wrapper.py b/android_env/wrappers/rate_limit_wrapper.py deleted file mode 100644 index 2439be5b..00000000 --- a/android_env/wrappers/rate_limit_wrapper.py +++ /dev/null @@ -1,117 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Limits interactions with the environment to a given rate.""" - -import enum -import time - -from android_env import env_interface -from android_env.components import action_type -from android_env.wrappers import base_wrapper -import dm_env -import numpy as np - - -class RateLimitWrapper(base_wrapper.BaseWrapper): - """Limits interactions with the environment to a given rate.""" - - class SleepType(enum.IntEnum): - """Determines how the wrapper interacts with the underlying environment.""" - - # The wrapper sleeps before calling `step()` on the underlying environment. - BEFORE = 0 - - # The wrapper sleeps after calling `step()` on the underlying environment. - AFTER = 1 - - # The wrapper first calls `step()`, obtaining a TimeStep which is ignored, - # then it sleeps, and then it calls `step(REPEAT)` to obtain a TimeStep - # that's as fresh as possible. - # - # Note that for both BEFORE and AFTER_WITH_REPEAT, the _total_ amount of - # time inside this wrapper may go beyond the rate specified in `rate` - # because the sleep does not account for the time taken by step(). - AFTER_WITH_REPEAT = 2 - - def __init__(self, - env: env_interface.AndroidEnvInterface, - rate: float, - sleep_type: SleepType = SleepType.AFTER_WITH_REPEAT): - """Initializes this wrapper. - - Args: - env: The underlying environment to which this wrapper is applied. - rate: The desired rate in Hz to interact with the environment. If <=0.0, - this wrapper will be disabled. - sleep_type: This determines how the wrapper will interact with the - underlying AndroidEnv environment. - """ - super().__init__(env) - self._assert_base_env() - self._last_step_time = None - self._max_wait = 1.0 / rate if rate > 0.0 else 0.0 - self._sleep_type = sleep_type - - def _assert_base_env(self): - """Checks that the wrapped env has the right action spec format.""" - parent_action_spec = self._env.action_spec() - assert len(parent_action_spec) == 2 - assert not parent_action_spec['action_type'].shape - assert parent_action_spec['touch_position'].shape == (2,) - - def reset(self): - timestep = self._env.reset() - self._last_step_time = time.time() - return timestep - - def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep: - """Takes a step while maintaining a steady interaction rate.""" - - # If max_wait is non-positive, the wrapper has no effect. - if self._max_wait <= 0.0: - return self._env.step(action) - - if self._sleep_type == RateLimitWrapper.SleepType.BEFORE: - self._wait() - - timestep = self._env.step(action) - if timestep.last(): - return timestep - - if self._sleep_type == RateLimitWrapper.SleepType.AFTER_WITH_REPEAT: - for k in action.keys(): - if k.startswith('action_type'): - action[k] = np.array(action_type.ActionType.REPEAT, dtype=np.uint8) - self._wait() - first_reward = timestep.reward or 0.0 - timestep = self._env.step(action) - second_reward = timestep.reward or 0.0 - # Accumulate rewards over the two steps taken. - timestep = timestep._replace(reward=first_reward + second_reward) - - elif self._sleep_type == RateLimitWrapper.SleepType.AFTER: - self._wait() - - self._last_step_time = time.time() - - return timestep - - def _wait(self) -> None: - if self._max_wait > 0.0 and self._last_step_time is not None: - time_since_step = time.time() - self._last_step_time - sec_to_wait = self._max_wait - time_since_step - if sec_to_wait > 0.0: - time.sleep(sec_to_wait) diff --git a/android_env/wrappers/rate_limit_wrapper_test.py b/android_env/wrappers/rate_limit_wrapper_test.py deleted file mode 100644 index d108de92..00000000 --- a/android_env/wrappers/rate_limit_wrapper_test.py +++ /dev/null @@ -1,270 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for rate_limit_wrapper.""" - -import time -from typing import Any, Protocol -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -from android_env import env_interface -from android_env.components import action_type -from android_env.wrappers import rate_limit_wrapper -import dm_env -from dm_env import specs -import numpy as np - - -def _get_base_env(): - env = mock.create_autospec(env_interface.AndroidEnvInterface) - env.action_spec.return_value = { - 'action_type': - specs.DiscreteArray( - num_values=len(action_type.ActionType), - name='action_type'), - 'touch_position': - specs.BoundedArray( - shape=(2,), - dtype=np.float32, - minimum=[0.0, 0.0], - maximum=[1.0, 1.0], - name='touch_position'), - } - return env - - -class _FnWithTimestamps(Protocol): - """A function with `timestamp` and `timestamps` attributes.""" - - timestamp: float - timestamps: list[float] - - -def _with_timestamp(fn: Any) -> _FnWithTimestamps: - return fn - - -class RateLimitWrapperTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('zero_rate', 0), - ('negative_rate', -50), - ) - @mock.patch.object(time, 'sleep', autospec=True) - def test_disabled(self, rate, mock_sleep): - """With a non-positive rate, this wrapper should do nothing.""" - env = _get_base_env() - wrapper = rate_limit_wrapper.RateLimitWrapper(env, rate=rate) - _ = wrapper.reset() - mock_sleep.assert_not_called() - _ = wrapper.step({ - 'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8), - 'touch_position': np.array([0.123, 0.456]) - }) - mock_sleep.assert_not_called() - # When the wrapper is disabled, base step should only be called once. - env.step.assert_called_once() - - @mock.patch.object(time, 'sleep', autospec=True) - def test_enabled(self, mock_sleep): - """When enabled, the wrapper should sleep for a period in [0, 1/rate].""" - - env = _get_base_env() - env.step.return_value = dm_env.transition(reward=None, observation=None) - wrapper = rate_limit_wrapper.RateLimitWrapper(env, rate=1/33.33) - - _ = wrapper.reset() - mock_sleep.assert_not_called() # It should never sleep during reset(). - - # Step for 100 steps. - for _ in range(100): - _ = wrapper.step({ - 'action_type': - np.array(action_type.ActionType.LIFT, dtype=np.uint8), - 'touch_position': - np.array([0.123, 0.456]) - }) - - # Check that there are 100 calls and that they're all within [0, 1/rate]. - self.assertLen(mock_sleep.call_args_list, 100) - for call in mock_sleep.call_args_list: - args, unused_kwargs = call - sleep_time = args[0] - self.assertBetween(sleep_time, 0.0, 33.33) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_enabled_sleep_type_before(self, mock_sleep): - """When sleep_type==BEFORE, sleep should come before step().""" - - env = _get_base_env() - wrapper = rate_limit_wrapper.RateLimitWrapper( - env, - rate=1/33.33, - sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType.BEFORE) - - _ = wrapper.reset() - mock_sleep.assert_not_called() # It should never sleep during reset(). - - @_with_timestamp - def _sleep_fn(sleep_time): - _sleep_fn.timestamp = time.time() - self.assertBetween(sleep_time, 0.0, 33.33) - - mock_sleep.side_effect = _sleep_fn - - def _step_fn(action): - self.assertEqual( - action['action_type'], - np.array(action_type.ActionType.LIFT, dtype=np.uint8)) - _step_fn.timestamps.append(time.time()) - return dm_env.transition(reward=None, observation=None) - - _step_fn.timestamps = [] - - env.step.side_effect = _step_fn - - _ = wrapper.step({ - 'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8), - 'touch_position': np.array([0.123, 0.456]) - }) - - self.assertLen(_step_fn.timestamps, 1) - # We expect sleep to have been executed BEFORE a single `step()`. - self.assertGreaterEqual(_step_fn.timestamps[0], _sleep_fn.timestamp) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_enabled_sleep_type_after(self, mock_sleep): - """When sleep_type==AFTER, sleep should come after step().""" - - env = _get_base_env() - wrapper = rate_limit_wrapper.RateLimitWrapper( - env, - rate=1/33.33, - sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType.AFTER) - _ = wrapper.reset() - mock_sleep.assert_not_called() # It should never sleep during reset(). - - @_with_timestamp - def _sleep_fn(sleep_time): - _sleep_fn.timestamp = time.time() - self.assertBetween(sleep_time, 0.0, 33.33) - - mock_sleep.side_effect = _sleep_fn - - def _step_fn(action): - self.assertEqual( - action['action_type'], - np.array(action_type.ActionType.LIFT, dtype=np.uint8)) - _step_fn.timestamps.append(time.time()) - return dm_env.transition(reward=None, observation=None) - - _step_fn.timestamps = [] - - env.step.side_effect = _step_fn - - _ = wrapper.step({ - 'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8), - 'touch_position': np.array([0.123, 0.456]) - }) - - # We expect sleep to have been executed AFTER a single `step()`. - self.assertLen(_step_fn.timestamps, 1) - self.assertLessEqual(_step_fn.timestamps[0], _sleep_fn.timestamp) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_enabled_sleep_type_after_with_repeat(self, mock_sleep): - """When sleep_type==AFTER_WITH_REPEAT, sleep should be between 2 steps().""" - - env = _get_base_env() - wrapper = rate_limit_wrapper.RateLimitWrapper( - env, - rate=1/33.33, - sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType - .AFTER_WITH_REPEAT) - - _ = wrapper.reset() - mock_sleep.assert_not_called() # It should never sleep during reset(). - - @_with_timestamp - def _sleep_fn(sleep_time): - _sleep_fn.timestamp = time.time() - self.assertBetween(sleep_time, 0.0, 33.33) - - mock_sleep.side_effect = _sleep_fn - - @_with_timestamp - def _step_fn(action): - # On even calls the action should be the actual agent action, but on odd - # calls they should be REPEATs. - if len(_step_fn.timestamps) % 2 == 0: - self.assertEqual( - action['action_type'], - np.array(action_type.ActionType.LIFT, dtype=np.uint8)) - else: - self.assertEqual( - action['action_type'], - np.array(action_type.ActionType.REPEAT, dtype=np.uint8)) - _step_fn.timestamps.append(time.time()) - return dm_env.transition(reward=1.0, observation=None) - - _step_fn.timestamps = [] - - env.step.side_effect = _step_fn - - timestep = wrapper.step({ - 'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8), - 'touch_position': np.array([0.123, 0.456]) - }) - - # When the wrapper is enabled, base step should be called twice. - self.assertEqual(env.step.call_count, 2) - - # `step()` should be called twice: before `sleep()` and after it. - self.assertLen(_step_fn.timestamps, 2) - self.assertGreaterEqual(_sleep_fn.timestamp, _step_fn.timestamps[0]) - self.assertLessEqual(_sleep_fn.timestamp, _step_fn.timestamps[1]) - # Rewards should accumulate over the two step() calls - self.assertEqual(timestep.reward, 2.0) - - @mock.patch.object(time, 'sleep', autospec=True) - def test_enabled_sleep_type_after_with_repeat_last(self, mock_sleep): - """If the first step is a LAST, second step should not be taken.""" - - env = _get_base_env() - wrapper = rate_limit_wrapper.RateLimitWrapper( - env, - rate=1/33.33, - sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType - .AFTER_WITH_REPEAT) - - _ = wrapper.reset() - mock_sleep.assert_not_called() # It should never sleep during reset(). - - env.step.return_value = dm_env.termination(reward=None, observation=None) - - _ = wrapper.step({ - 'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8), - 'touch_position': np.array([0.123, 0.456]) - }) - - # Second step call should be skipped. - env.step.assert_called_once() - mock_sleep.assert_not_called() - - -if __name__ == '__main__': - absltest.main() diff --git a/android_env/wrappers/tap_action_wrapper.py b/android_env/wrappers/tap_action_wrapper.py deleted file mode 100644 index 5f0ee666..00000000 --- a/android_env/wrappers/tap_action_wrapper.py +++ /dev/null @@ -1,104 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Wraps the AndroidEnv environment to provide tap actions of a given duration.""" - -from collections.abc import Sequence - -from android_env.components import action_type -from android_env.wrappers import base_wrapper -import dm_env -import numpy as np - - -class TapActionWrapper(base_wrapper.BaseWrapper): - """AndroidEnv with tap actions.""" - - def __init__(self, - env: dm_env.Environment, - num_frames: int = 5, - touch_only: bool = False): - super().__init__(env) - assert 'action_type' in env.action_spec() - self._touch_only = touch_only - self._num_frames = num_frames - self._env_steps = 0 - - def stats(self): - """Returns a dictionary of metrics logged by the environment.""" - logs = self._env.stats() - logs.update({'env_steps': self._env_steps}) - return logs - - def _process_action( - self, action: dict[str, np.ndarray] - ) -> Sequence[dict[str, np.ndarray]]: - if self._touch_only: - assert action['action_type'] == 0 - touch_action = action.copy() - touch_action['action_type'] = np.array( - action_type.ActionType.TOUCH - ).astype(self.action_spec()['action_type'].dtype) - actions = [touch_action] * self._num_frames - lift_action = action.copy() - lift_action['action_type'] = np.array(action_type.ActionType.LIFT).astype( - self.action_spec()['action_type'].dtype - ) - actions.append(lift_action) - - else: - if action['action_type'] == action_type.ActionType.TOUCH: - actions = [action] * self._num_frames - lift_action = action.copy() - lift_action['action_type'] = np.array( - action_type.ActionType.LIFT - ).astype(self.action_spec()['action_type'].dtype) - actions.append(lift_action) - else: - actions = [action] * (self._num_frames + 1) - - return actions - - def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep: - """Takes a step in the environment.""" - self._env_steps += self._num_frames + 1 - actions = self._process_action(action) - total_reward = 0.0 - for idx in range(len(actions)): - step_type, reward, discount, observation = self._env.step(actions[idx]) - if reward: - total_reward += reward - if step_type == dm_env.StepType.LAST: - return dm_env.TimeStep( - step_type=step_type, - reward=total_reward, - discount=discount, - observation=observation) - return dm_env.TimeStep( - step_type=step_type, - reward=total_reward, - discount=discount, - observation=observation) - - def action_spec(self) -> dict[str, dm_env.specs.Array]: - if self._touch_only: - return { - 'action_type': - dm_env.specs.DiscreteArray(num_values=1, name='action_type'), - 'touch_position': - self._env.action_spec()['touch_position'], - } - else: - return self._env.action_spec() diff --git a/android_env/wrappers/tap_action_wrapper_test.py b/android_env/wrappers/tap_action_wrapper_test.py deleted file mode 100644 index b6f9ceb3..00000000 --- a/android_env/wrappers/tap_action_wrapper_test.py +++ /dev/null @@ -1,168 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for tap_action_wrapper.""" - -from unittest import mock - -from absl.testing import absltest -from android_env import env_interface -from android_env.components import action_type -from android_env.wrappers import tap_action_wrapper -import dm_env -from dm_env import specs -import numpy as np - - -def _make_array_spec(shape, dtype, name): - return specs.BoundedArray( - name=name, - shape=shape, - dtype=dtype, - minimum=np.zeros(shape), - maximum=np.ones(shape), # maximum is inclusive. - ) - - -class TapActionWrapperTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self._base_action_spec = { - 'action_type': specs.DiscreteArray( - num_values=3, name='action_type'), - 'touch_position': _make_array_spec( - shape=(2,), dtype=np.float32, name='touch_position'), - } - self.base_env = mock.create_autospec(env_interface.AndroidEnvInterface) - self.base_env.action_spec.return_value = self._base_action_spec - - def test_process_action_repeat(self): - wrapped_env = tap_action_wrapper.TapActionWrapper( - self.base_env, num_frames=3) - action = { - 'action_type': np.array(action_type.ActionType.REPEAT, dtype=np.int32), - 'touch_position': np.array([0.5, 0.5], dtype=np.float32), - } - actions = wrapped_env._process_action(action) - self.assertLen(actions, wrapped_env._num_frames + 1) - self.assertEqual(action, actions[-1]) - - def test_process_action_lift(self): - wrapped_env = tap_action_wrapper.TapActionWrapper( - self.base_env, num_frames=3) - action = { - 'action_type': np.array(action_type.ActionType.LIFT, dtype=np.int32), - 'touch_position': np.array([0.5, 0.5], dtype=np.float32), - } - actions = wrapped_env._process_action(action) - self.assertLen(actions, wrapped_env._num_frames + 1) - self.assertEqual(action, actions[-1]) - - def test_process_action_touch(self): - wrapped_env = tap_action_wrapper.TapActionWrapper( - self.base_env, num_frames=3) - action = { - 'action_type': np.array(action_type.ActionType.TOUCH, dtype=np.int32), - 'touch_position': np.array([0.5, 0.5], dtype=np.float32), - } - actions = wrapped_env._process_action(action) - self.assertLen(actions, wrapped_env._num_frames + 1) - self.assertEqual( - actions[-1]['action_type'], np.array(action_type.ActionType.LIFT) - ) - - def test_reset(self): - wrapped_env = tap_action_wrapper.TapActionWrapper( - self.base_env, num_frames=5) - fake_timestep = 'ts' - self.base_env.reset.return_value = fake_timestep - ts = wrapped_env.reset() - self.base_env.reset.assert_called_once() - self.assertEqual(fake_timestep, ts) - - def test_step(self): - # Arrange. - wrapped_env = tap_action_wrapper.TapActionWrapper( - self.base_env, num_frames=5) - fake_timestep = dm_env.TimeStep( - step_type='fake_type', - reward=0.0, - discount=1.0, - observation='fake_obs') - self.base_env.step.return_value = fake_timestep - self.base_env.stats.return_value = {} - - # Act. - ts = wrapped_env.step({ - 'action_type': np.array(action_type.ActionType.REPEAT, dtype=np.int32), - 'touch_position': np.array([0.5, 0.5], dtype=np.float32), - }) - stats = wrapped_env.stats() - - # Assert. - self.assertEqual(wrapped_env._num_frames+1, self.base_env.step.call_count) - self.assertIsInstance(ts, dm_env.TimeStep) - self.assertIsInstance(stats, dict) - self.assertIn('env_steps', stats) - self.assertEqual(stats['env_steps'], 6) - - def test_observation_spec(self): - wrapped_env = tap_action_wrapper.TapActionWrapper( - self.base_env, num_frames=5) - fake_obs_spec = 'fake_obs_spec' - self.base_env.observation_spec.return_value = fake_obs_spec - observation_spec = wrapped_env.observation_spec() - self.base_env.observation_spec.assert_called_once() - self.assertEqual(fake_obs_spec, observation_spec) - - def test_action_spec(self): - wrapped_env = tap_action_wrapper.TapActionWrapper( - self.base_env, num_frames=5) - self.base_env.action_spec.return_value = self._base_action_spec - action_spec = wrapped_env.action_spec() - self.base_env.action_spec.assert_called() - self.assertEqual(self.base_env.action_spec(), - action_spec) - - def test_stats(self): - """Checks that returned stats have expected properties.""" - - # Arrange. - self.base_env.stats.return_value = { - 'some_key': 12345, - 'another_key': 5.4321, - } - wrapped_env = tap_action_wrapper.TapActionWrapper( - self.base_env, num_frames=5 - ) - - # Act. - stats = wrapped_env.stats() - - # Assert. - self.assertIsInstance(stats, dict) - # Original entries should still be present. - self.assertIn('some_key', stats) - self.assertEqual(stats['some_key'], 12345) - self.assertIn('another_key', stats) - self.assertEqual(stats['another_key'], 5.4321) - # TapActionWrapper inserts its own `env_steps`. - self.assertIn('env_steps', stats) - self.assertEqual(stats['env_steps'], 0) - - -if __name__ == '__main__': - absltest.main() diff --git a/examples/__init__.py b/examples/__init__.py deleted file mode 100644 index 2f66bf75..00000000 --- a/examples/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/examples/run_acme_agent.py b/examples/run_acme_agent.py deleted file mode 100644 index 754a0436..00000000 --- a/examples/run_acme_agent.py +++ /dev/null @@ -1,98 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Acme DQN agent interacting with AndroidEnv.""" - -from absl import app -from absl import flags -from absl import logging -import acme -from acme import specs -from acme import wrappers as acme_wrappers -from acme.agents.tf import dqn -from acme.tf import networks -from android_env import loader -from android_env.components import config_classes -from android_env.wrappers import discrete_action_wrapper -from android_env.wrappers import float_pixels_wrapper -from android_env.wrappers import image_rescale_wrapper - -# Simulator args -flags.DEFINE_string('avd_name', None, 'Name of AVD to use.') -flags.DEFINE_string('android_avd_home', '~/.android/avd', 'Path to AVD.') -flags.DEFINE_string('android_sdk_root', '~/Android/Sdk', 'Path to SDK.') -flags.DEFINE_string('emulator_path', - '~/Android/Sdk/emulator/emulator', 'Path to emulator.') -flags.DEFINE_string('adb_path', - '~/Android/Sdk/platform-tools/adb', 'Path to ADB.') - -# Environment args -flags.DEFINE_string('task_path', None, 'Path to task textproto file.') - -# Experiment args -flags.DEFINE_integer('num_episodes', 100, 'Number of episodes.') - -FLAGS = flags.FLAGS - - -def apply_wrappers(env): - """Applies a series of wrappers to the environment.""" - env = discrete_action_wrapper.DiscreteActionWrapper(env, action_grid=(10, 10)) - env = image_rescale_wrapper.ImageRescaleWrapper( - env, zoom_factors=(0.25, 0.25)) - env = float_pixels_wrapper.FloatPixelsWrapper(env) - env = acme_wrappers.SinglePrecisionWrapper(env) - return env - - -def main(_): - - config = config_classes.AndroidEnvConfig( - task=config_classes.FilesystemTaskConfig(path=FLAGS.task_path), - simulator=config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - emulator_path=FLAGS.emulator_path, - android_sdk_root=FLAGS.android_sdk_root, - android_avd_home=FLAGS.android_avd_home, - avd_name=FLAGS.avd_name, - run_headless=FLAGS.run_headless, - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path=FLAGS.adb_path - ), - ), - ) - with loader.load(config) as env: - - env = apply_wrappers(env) - env_spec = specs.make_environment_spec(env) - - agent = dqn.DQN( - environment_spec=env_spec, - network=networks.DQNAtariNetwork( - num_actions=env_spec.actions.num_values), - batch_size=10, - samples_per_insert=2, - min_replay_size=10) - - loop = acme.EnvironmentLoop(env, agent) - loop.run(num_episodes=FLAGS.num_episodes) - - -if __name__ == '__main__': - logging.set_verbosity('info') - logging.set_stderrthreshold('info') - flags.mark_flags_as_required(['task_path', 'avd_name']) - app.run(main) diff --git a/examples/run_human_agent.py b/examples/run_human_agent.py deleted file mode 100644 index 505a729b..00000000 --- a/examples/run_human_agent.py +++ /dev/null @@ -1,209 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Loads an interactive session where a human acts on behalf of an agent.""" - -import time -from typing import Any - -from absl import app -from absl import flags -from absl import logging -from android_env import loader -from android_env.components import action_type -from android_env.components import config_classes -from android_env.components import pixel_fns -import dm_env -import numpy as np -import pygame - -# Simulator args. -flags.DEFINE_string('avd_name', None, 'Name of AVD to use.') -flags.DEFINE_string('android_avd_home', '~/.android/avd', 'Path to AVD.') -flags.DEFINE_string('android_sdk_root', '~/Android/Sdk', 'Path to SDK.') -flags.DEFINE_string('emulator_path', - '~/Android/Sdk/emulator/emulator', 'Path to emulator.') -flags.DEFINE_string('adb_path', - '~/Android/Sdk/platform-tools/adb', 'Path to ADB.') -flags.DEFINE_boolean('run_headless', True, 'Optionally turn off display.') - -# Environment args. -flags.DEFINE_string('task_path', None, 'Path to task textproto file.') - -# Pygame args. -flags.DEFINE_list('screen_size', '480,720', 'Screen width, height in pixels.') -flags.DEFINE_float('frame_rate', 1.0/30.0, 'Frame rate in seconds.') - -FLAGS = flags.FLAGS - - -def _get_action_from_event( - event: pygame.event.Event, screen: pygame.Surface, orientation: int -) -> dict[str, Any]: - """Returns the current action by reading data from a pygame Event object.""" - - act_type = action_type.ActionType.LIFT - if event.type == pygame.MOUSEBUTTONDOWN: - act_type = action_type.ActionType.TOUCH - - return { - 'action_type': - np.array(act_type, dtype=np.int32), - 'touch_position': - _scale_position(event.pos, screen, orientation), - } - - -def _get_action_from_mouse( - screen: pygame.Surface, orientation: int -) -> dict[str, Any]: - """Returns the current action by reading data from the mouse.""" - - act_type = action_type.ActionType.LIFT - if pygame.mouse.get_pressed()[0]: - act_type = action_type.ActionType.TOUCH - - return { - 'action_type': - np.array(act_type, dtype=np.int32), - 'touch_position': - _scale_position(pygame.mouse.get_pos(), screen, orientation), - } - - -def _scale_position(position: np.ndarray, screen: pygame.Surface, - orientation: int) -> np.ndarray: - """AndroidEnv accepts mouse inputs as floats so we need to scale it.""" - - scaled_pos = np.divide(position, screen.get_size(), dtype=np.float32) - if orientation == 1: # LANDSCAPE_90 - scaled_pos = scaled_pos[::-1] - scaled_pos[0] = 1 - scaled_pos[0] - return scaled_pos - - -def _accumulate_reward( - timestep: dm_env.TimeStep, - episode_return: float) -> float: - """Accumulates rewards collected over the course of an episode.""" - - if timestep.reward and timestep.reward != 0: - logging.info('Reward: %s', timestep.reward) - episode_return += timestep.reward - - if timestep.first(): - episode_return = 0 - elif timestep.last(): - logging.info('Episode return: %s', episode_return) - - return episode_return - - -def _render_pygame_frame(surface: pygame.Surface, screen: pygame.Surface, - orientation: int, timestep: dm_env.TimeStep) -> None: - """Displays latest observation on pygame surface.""" - - frame = timestep.observation['pixels'][:, :, :3] # (H x W x C) (RGB) - frame = pixel_fns.transpose_pixels(frame) # (W x H x C) - frame = pixel_fns.orient_pixels(frame, orientation) - - pygame.surfarray.blit_array(surface, frame) - pygame.transform.smoothscale(surface, screen.get_size(), screen) - - pygame.display.flip() - - -def main(_): - - pygame.init() - pygame.display.set_caption('android_human_agent') - - config = config_classes.AndroidEnvConfig( - task=config_classes.FilesystemTaskConfig(path=FLAGS.task_path), - simulator=config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - emulator_path=FLAGS.emulator_path, - android_sdk_root=FLAGS.android_sdk_root, - android_avd_home=FLAGS.android_avd_home, - avd_name=FLAGS.avd_name, - run_headless=FLAGS.run_headless, - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path=FLAGS.adb_path - ), - ), - ) - with loader.load(config) as env: - - # Reset environment. - first_timestep = env.reset() - orientation = np.argmax(first_timestep.observation['orientation']) - - # Create pygame canvas. - screen_size = list(map(int, FLAGS.screen_size)) # (W x H) - obs_shape = env.observation_spec()['pixels'].shape[:2] # (H x W) - - if (orientation == 1 or orientation == 3): # LANDSCAPE_90 | LANDSCAPE_270 - screen_size = screen_size[::-1] - obs_shape = obs_shape[::-1] - - screen = pygame.display.set_mode(screen_size) # takes (W x H) - surface = pygame.Surface(obs_shape[::-1]) # takes (W x H) - - # Start game loop. - prev_frame = time.time() - episode_return = 0 - - while True: - if pygame.key.get_pressed()[pygame.K_ESCAPE]: - return - - all_events = pygame.event.get() - for event in all_events: - if event.type == pygame.QUIT: - return - - # Filter event queue for mouse click events. - mouse_click_events = [ - event for event in all_events - if event.type in [pygame.MOUSEBUTTONDOWN, pygame.MOUSEBUTTONUP] - ] - - # Process all mouse click events. - for event in mouse_click_events: - action = _get_action_from_event(event, screen, orientation) - timestep = env.step(action) - episode_return = _accumulate_reward(timestep, episode_return) - _render_pygame_frame(surface, screen, orientation, timestep) - - # Sample the current position of the mouse either way. - action = _get_action_from_mouse(screen, orientation) - timestep = env.step(action) - episode_return = _accumulate_reward(timestep, episode_return) - _render_pygame_frame(surface, screen, orientation, timestep) - - # Limit framerate. - now = time.time() - frame_time = now - prev_frame - if frame_time < FLAGS.frame_rate: - time.sleep(FLAGS.frame_rate - frame_time) - prev_frame = now - - -if __name__ == '__main__': - logging.set_verbosity('info') - logging.set_stderrthreshold('info') - flags.mark_flags_as_required(['avd_name', 'task_path']) - app.run(main) diff --git a/examples/run_random_agent.py b/examples/run_random_agent.py deleted file mode 100644 index fc8017b4..00000000 --- a/examples/run_random_agent.py +++ /dev/null @@ -1,90 +0,0 @@ -# coding=utf-8 -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Example script demonstrating usage of AndroidEnv.""" - -from absl import app -from absl import flags -from absl import logging -from android_env import loader -from android_env.components import config_classes -from dm_env import specs -import numpy as np - -FLAGS = flags.FLAGS - -# Simulator args. -flags.DEFINE_string('avd_name', None, 'Name of AVD to use.') -flags.DEFINE_string('android_avd_home', '~/.android/avd', 'Path to AVD.') -flags.DEFINE_string('android_sdk_root', '~/Android/Sdk', 'Path to SDK.') -flags.DEFINE_string('emulator_path', - '~/Android/Sdk/emulator/emulator', 'Path to emulator.') -flags.DEFINE_string('adb_path', - '~/Android/Sdk/platform-tools/adb', 'Path to ADB.') -flags.DEFINE_bool('run_headless', False, - 'Whether to display the emulator window.') - -# Environment args. -flags.DEFINE_string('task_path', None, 'Path to task textproto file.') - -# Experiment args. -flags.DEFINE_integer('num_steps', 1000, 'Number of steps to take.') - - -def main(_): - - config = config_classes.AndroidEnvConfig( - task=config_classes.FilesystemTaskConfig(path=FLAGS.task_path), - simulator=config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - emulator_path=FLAGS.emulator_path, - android_sdk_root=FLAGS.android_sdk_root, - android_avd_home=FLAGS.android_avd_home, - avd_name=FLAGS.avd_name, - run_headless=FLAGS.run_headless, - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path=FLAGS.adb_path - ), - ), - ) - with loader.load(config) as env: - - action_spec = env.action_spec() - - def get_random_action() -> dict[str, np.ndarray]: - """Returns a random AndroidEnv action.""" - action = {} - for k, v in action_spec.items(): - if isinstance(v, specs.DiscreteArray): - action[k] = np.random.randint(low=0, high=v.num_values, dtype=v.dtype) - else: - action[k] = np.random.random(size=v.shape).astype(v.dtype) - return action - - _ = env.reset() - - for step in range(FLAGS.num_steps): - action = get_random_action() - timestep = env.step(action=action) - reward = timestep.reward - logging.info('Step %r, action: %r, reward: %r', step, action, reward) - - -if __name__ == '__main__': - logging.set_verbosity('info') - logging.set_stderrthreshold('info') - flags.mark_flags_as_required(['avd_name', 'task_path']) - app.run(main) diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 2504653f..00000000 --- a/pyproject.toml +++ /dev/null @@ -1,39 +0,0 @@ -[build-system] -requires = [ - "setuptools", - "wheel" -] -build-backend = "setuptools.build_meta" - -[project] -name = "android-env" -version = "1.2.2" -description = "AndroidEnv environment and library for training agents." -authors = [{name = "DeepMind"}] -license = {file = "LICENSE"} -readme = {text = "Read the README at https://github.com/deepmind/android_env for more information.", content-type = "text/plain"} -keywords = ["Android", "OS", "reinforcement-learning"] -requires-python = ">=3.10" -dependencies = [ - "absl-py>=0.1.0", - "dm_env", - "grpcio", - "numpy>=1.21", - "portpicker>=1.2.0", - "protobuf>=2.6", - "pygame", -] - -[project.optional-dependencies] -testing = [ - "gym", - "pillow", - "pytype", -] -acme = ["dm-acme"] -gym = ["gym"] - -[project.urls] -repository = "https://github.com/deepmind/android_env" -deepmind = "https://www.deepmind.com/publications/androidenv-the-android-learning-environment" -arxiv = "https://arxiv.org/abs/2105.13231" diff --git a/setup.py b/setup.py deleted file mode 100644 index 2c201d0b..00000000 --- a/setup.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2024 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Simple package definition for using with `pip`.""" - -import importlib -import os - -import setuptools -from setuptools import find_packages -from setuptools import setup -from setuptools.command.build_ext import build_ext -from setuptools.command.build_py import build_py - -_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) - -# Tuple of proto message definitions to build Python bindings for. Paths must -# be relative to root directory. -_ANDROID_ENV_PROTOS = ( - 'android_env/proto/adb.proto', - 'android_env/proto/emulator_controller.proto', - 'android_env/proto/snapshot.proto', - 'android_env/proto/snapshot_service.proto', - 'android_env/proto/state.proto', - 'android_env/proto/task.proto', - 'android_env/proto/a11y/a11y.proto', - 'android_env/proto/a11y/android_accessibility_action.proto', - 'android_env/proto/a11y/android_accessibility_forest.proto', - 'android_env/proto/a11y/android_accessibility_node_info.proto', - 'android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto', - 'android_env/proto/a11y/android_accessibility_tree.proto', - 'android_env/proto/a11y/android_accessibility_window_info.proto', - 'android_env/proto/a11y/rect.proto', -) - - -class _GenerateProtoFiles(setuptools.Command): - """Command to generate protobuf bindings for AndroidEnv protos.""" - - descriptions = 'Generates Python protobuf bindings for AndroidEnv protos.' - user_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - # Import grpc_tools here, after setuptools has installed setup_requires - # dependencies. - from grpc_tools import protoc # pylint: disable=g-import-not-at-top - - with importlib.resources.as_file( - importlib.resources.files('grpc_tools').joinpath('_proto') - ) as path: - grpc_protos_include = str(path) - - for proto_path in _ANDROID_ENV_PROTOS: - proto_args = [ - 'grpc_tools.protoc', - '--proto_path={}'.format(grpc_protos_include), - '--proto_path={}'.format(_ROOT_DIR), - '--python_out={}'.format(_ROOT_DIR), - '--pyi_out={}'.format(_ROOT_DIR), - '--grpc_python_out={}'.format(_ROOT_DIR), - os.path.join(_ROOT_DIR, proto_path), - ] - if protoc.main(proto_args) != 0: - raise RuntimeError('ERROR: {}'.format(proto_args)) - - -class _BuildExt(build_ext): - """Generate protobuf bindings in build_ext stage.""" - - def run(self): - self.run_command('generate_protos') - build_ext.run(self) - - -class _BuildPy(build_py): - """Generate protobuf bindings in build_py stage.""" - - def run(self): - self.run_command('generate_protos') - build_py.run(self) - -setup( - packages=find_packages(exclude=['examples']), - package_data={'': ['proto/*.proto']}, # Copy protobuf files. - include_package_data=True, - setup_requires=['grpcio-tools'], - cmdclass={ - 'build_ext': _BuildExt, - 'build_py': _BuildPy, - 'generate_protos': _GenerateProtoFiles, - }, -)