diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 6c333e7..e12cf50 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -56,5 +56,6 @@ dependencies { implementation("androidx.appcompat:appcompat:1.7.0") implementation("com.google.android.material:material:1.12.0") implementation("androidx.activity:activity-ktx:1.9.0") + implementation("androidx.work:work-runtime-ktx:2.9.1") implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.9.0") } diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml index f518a1b..5126f4c 100644 --- a/app/src/main/AndroidManifest.xml +++ b/app/src/main/AndroidManifest.xml @@ -2,6 +2,8 @@ + + - runOnUiThread { - statusView.text = "${progress.file} ${progress.completed}/${progress.totalFiles}" + // Models download in a foreground worker so the transfer survives + // backgrounding the app. Activity just observes progress. + val workId = ModelDownloadWorker.enqueue(applicationContext, ModelPrecision.INT8) + WorkManager.getInstance(applicationContext) + .getWorkInfoByIdLiveData(workId) + .observe(this) { info -> + if (info == null) return@observe + when (info.state) { + WorkInfo.State.ENQUEUED, + WorkInfo.State.BLOCKED, + WorkInfo.State.RUNNING -> { + val total = info.progress.getInt(ModelDownloadWorker.KEY_TOTAL, 0) + if (total > 0) { + val file = info.progress.getString(ModelDownloadWorker.KEY_FILE) ?: "" + val done = info.progress.getInt(ModelDownloadWorker.KEY_COMPLETED, 0) + statusView.text = "$file $done/$total" + } + } + WorkInfo.State.SUCCEEDED -> { + val modelDir = info.outputData.getString(ModelDownloadWorker.KEY_MODEL_DIR) + if (modelDir == null) { + statusView.text = "worker succeeded but no model dir" + return@observe + } + initPipeline(modelDir) } + WorkInfo.State.FAILED -> { + val err = info.outputData.getString(ModelDownloadWorker.KEY_ERROR) + ?: "unknown" + statusView.text = "download failed: $err" + } + WorkInfo.State.CANCELLED -> { statusView.text = "cancelled" } } + } + } + private fun initPipeline(modelDir: String) { + lifecycleScope.launch { + try { val config = SpeechConfig( modelDir = modelDir, useNnapi = false, diff --git a/app/src/main/kotlin/com/soniqo/speech/demo/MainActivity.kt b/app/src/main/kotlin/com/soniqo/speech/demo/MainActivity.kt index 41d654e..6dab0ce 100644 --- a/app/src/main/kotlin/com/soniqo/speech/demo/MainActivity.kt +++ b/app/src/main/kotlin/com/soniqo/speech/demo/MainActivity.kt @@ -25,7 +25,9 @@ import androidx.core.app.ActivityCompat import androidx.core.view.ViewCompat import androidx.core.view.WindowInsetsCompat import androidx.lifecycle.lifecycleScope -import audio.soniqo.speech.ModelManager +import androidx.work.WorkInfo +import androidx.work.WorkManager +import audio.soniqo.speech.ModelDownloadWorker import audio.soniqo.speech.ModelPrecision import audio.soniqo.speech.SpeechConfig import audio.soniqo.speech.SpeechEvent @@ -218,24 +220,52 @@ class MainActivity : ComponentActivity() { private fun loadPipeline() { setStatus("initializing...") - lifecycleScope.launch { - try { - val modelDir = ModelManager.ensureModels( - this@MainActivity, - precision = ModelPrecision.INT8, - ) { progress -> - val mb = progress.bytesDownloaded / 1_000_000 - setStatus("${progress.file} ${progress.completed}/${progress.totalFiles} (${mb} MB)") - runOnUiThread { - downloadProgress.progress = - (progress.completed * 100 / progress.totalFiles).coerceIn(0, 100) + // Models download in a foreground worker so the transfer survives + // backgrounding the app. Activity just observes progress. + val workId = ModelDownloadWorker.enqueue(applicationContext, ModelPrecision.INT8) + WorkManager.getInstance(applicationContext) + .getWorkInfoByIdLiveData(workId) + .observe(this) { info -> + if (info == null) return@observe + when (info.state) { + WorkInfo.State.ENQUEUED, + WorkInfo.State.BLOCKED, + WorkInfo.State.RUNNING -> { + val total = info.progress.getInt(ModelDownloadWorker.KEY_TOTAL, 0) + if (total > 0) { + val file = info.progress.getString(ModelDownloadWorker.KEY_FILE) ?: "" + val done = info.progress.getInt(ModelDownloadWorker.KEY_COMPLETED, 0) + val pct = info.progress.getInt(ModelDownloadWorker.KEY_PERCENT, 0) + setStatus("$file $done/$total") + downloadProgress.progress = pct + } } + WorkInfo.State.SUCCEEDED -> { + val modelDir = info.outputData.getString(ModelDownloadWorker.KEY_MODEL_DIR) + if (modelDir == null) { + addSystemLine("worker succeeded but no model dir") + setStatus("error") + return@observe + } + downloadProgress.progress = 100 + downloadProgress.visibility = View.GONE + initPipeline(modelDir) + } + WorkInfo.State.FAILED -> { + val err = info.outputData.getString(ModelDownloadWorker.KEY_ERROR) + ?: "unknown" + addSystemLine("download failed: $err") + setStatus("error — tap to retry") + statusView.setOnClickListener { retryInit() } + } + WorkInfo.State.CANCELLED -> setStatus("cancelled") } - runOnUiThread { - downloadProgress.progress = 100 - downloadProgress.visibility = View.GONE - } + } + } + private fun initPipeline(modelDir: String) { + lifecycleScope.launch { + try { val config = SpeechConfig( modelDir = modelDir, useNnapi = !isEmulator, diff --git a/sdk/build.gradle.kts b/sdk/build.gradle.kts index a7d81a6..362f4eb 100644 --- a/sdk/build.gradle.kts +++ b/sdk/build.gradle.kts @@ -105,6 +105,8 @@ dependencies { implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.9.0") implementation("com.squareup.okhttp3:okhttp:4.12.0") implementation("androidx.annotation:annotation:1.8.2") + implementation("androidx.work:work-runtime-ktx:2.9.1") + implementation("androidx.core:core-ktx:1.13.1") testImplementation("junit:junit:4.13.2") testImplementation("com.squareup.okhttp3:mockwebserver:4.12.0") @@ -113,6 +115,7 @@ dependencies { testImplementation("androidx.test:core:1.6.1") testImplementation("androidx.test.ext:junit:1.2.1") testImplementation("io.mockk:mockk:1.13.13") + testImplementation("androidx.work:work-testing:2.9.1") androidTestImplementation("androidx.test.ext:junit:1.2.1") androidTestImplementation("androidx.test:runner:1.6.2") diff --git a/sdk/src/main/AndroidManifest.xml b/sdk/src/main/AndroidManifest.xml index fbb2ef8..cbeb09d 100644 --- a/sdk/src/main/AndroidManifest.xml +++ b/sdk/src/main/AndroidManifest.xml @@ -1,3 +1,22 @@ - + + + + + + + + + + diff --git a/sdk/src/main/kotlin/com/soniqo/speech/ModelDownloadWorker.kt b/sdk/src/main/kotlin/com/soniqo/speech/ModelDownloadWorker.kt new file mode 100644 index 0000000..ffc903b --- /dev/null +++ b/sdk/src/main/kotlin/com/soniqo/speech/ModelDownloadWorker.kt @@ -0,0 +1,177 @@ +package audio.soniqo.speech + +import android.app.NotificationChannel +import android.app.NotificationManager +import android.content.Context +import android.content.pm.ServiceInfo +import android.os.Build +import androidx.core.app.NotificationCompat +import androidx.work.CoroutineWorker +import androidx.work.ExistingWorkPolicy +import androidx.work.ForegroundInfo +import androidx.work.OneTimeWorkRequestBuilder +import androidx.work.WorkManager +import androidx.work.WorkerParameters +import androidx.work.workDataOf +import java.io.IOException + +/** + * Downloads the speech models in a foreground worker so the transfer survives + * app backgrounding and process death. Wraps [ModelManager.ensureModels] — + * resumes partial downloads via the same on-disk `.tmp` files, retries on + * `IOException`, and reports progress via [setProgress]. + * + * ### Usage + * + * ``` + * WorkManager.getInstance(context).enqueueUniqueWork( + * ModelDownloadWorker.UNIQUE_NAME, + * ExistingWorkPolicy.KEEP, + * ModelDownloadWorker.request(ModelPrecision.INT8), + * ) + * + * WorkManager.getInstance(context) + * .getWorkInfosForUniqueWorkLiveData(ModelDownloadWorker.UNIQUE_NAME) + * .observe(this) { infos -> + * val info = infos.firstOrNull() ?: return@observe + * when (info.state) { + * WorkInfo.State.RUNNING -> { + * val pct = info.progress.getInt(ModelDownloadWorker.KEY_PERCENT, 0) + * ... + * } + * WorkInfo.State.SUCCEEDED -> { + * val dir = info.outputData.getString(ModelDownloadWorker.KEY_MODEL_DIR) + * ... + * } + * else -> Unit + * } + * } + * ``` + * + * Requires the host app to declare `POST_NOTIFICATIONS` (API 33+) for the + * progress notification to appear; the worker still runs without it. + */ +class ModelDownloadWorker( + context: Context, + params: WorkerParameters, +) : CoroutineWorker(context, params) { + + override suspend fun doWork(): Result { + val precision = inputData.getString(KEY_PRECISION) + ?.let { runCatching { ModelPrecision.valueOf(it) }.getOrNull() } + ?: ModelPrecision.INT8 + + runCatching { setForeground(buildForegroundInfo(0, 0, "Preparing speech models…")) } + + return try { + val modelDir = ModelManager.ensureModels(applicationContext, precision) { p -> + val pct = if (p.totalFiles > 0) { + (p.completed * 100 / p.totalFiles).coerceIn(0, 100) + } else 0 + setProgressAsync(workDataOf( + KEY_FILE to p.file, + KEY_COMPLETED to p.completed, + KEY_TOTAL to p.totalFiles, + KEY_BYTES_DOWNLOADED to p.bytesDownloaded, + KEY_PERCENT to pct, + )) + runCatching { + setForegroundAsync(buildForegroundInfo( + completed = p.completed, + total = p.totalFiles, + text = "${p.file} ${p.completed}/${p.totalFiles}", + )) + } + } + Result.success(workDataOf(KEY_MODEL_DIR to modelDir)) + } catch (e: IOException) { + // Network / disk hiccup — let WorkManager retry with backoff. + Result.retry() + } catch (t: Throwable) { + Result.failure(workDataOf(KEY_ERROR to (t.message ?: t::class.java.simpleName))) + } + } + + private fun buildForegroundInfo(completed: Int, total: Int, text: String): ForegroundInfo { + ensureChannel() + val indeterminate = total <= 0 + val notif = NotificationCompat.Builder(applicationContext, CHANNEL_ID) + .setContentTitle("Speech models") + .setContentText(text) + .setSmallIcon(android.R.drawable.stat_sys_download) + .setProgress(if (indeterminate) 100 else total, completed, indeterminate) + .setOngoing(true) + .setOnlyAlertOnce(true) + .setPriority(NotificationCompat.PRIORITY_LOW) + .build() + return if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) { + ForegroundInfo(NOTIFICATION_ID, notif, ServiceInfo.FOREGROUND_SERVICE_TYPE_DATA_SYNC) + } else { + ForegroundInfo(NOTIFICATION_ID, notif) + } + } + + private fun ensureChannel() { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.O) return + val nm = applicationContext.getSystemService(NotificationManager::class.java) ?: return + if (nm.getNotificationChannel(CHANNEL_ID) != null) return + nm.createNotificationChannel(NotificationChannel( + CHANNEL_ID, + "Speech model downloads", + NotificationManager.IMPORTANCE_LOW, + ).apply { description = "Progress for downloading on-device speech models" }) + } + + companion object { + /** Pass to [WorkManager.enqueueUniqueWork] to dedupe concurrent downloads. */ + const val UNIQUE_NAME = "audio.soniqo.speech.modelDownload" + + // Input keys + const val KEY_PRECISION = "precision" + + // Output keys + const val KEY_MODEL_DIR = "modelDir" + const val KEY_ERROR = "error" + + // Progress keys + const val KEY_FILE = "file" + const val KEY_COMPLETED = "completed" + const val KEY_TOTAL = "totalFiles" + const val KEY_BYTES_DOWNLOADED = "bytesDownloaded" + const val KEY_PERCENT = "percent" + + private const val CHANNEL_ID = "audio.soniqo.speech.models" + // Stable, unlikely-to-collide id (decimal of 0xC0FFEE). + private const val NOTIFICATION_ID = 12648430 + + /** + * Build a one-shot download request. No JobScheduler network + * constraint — the underlying OkHttp client surfaces network failures + * as `IOException`, which the worker translates into `Result.retry()`. + * Avoids JobScheduler's `CONSTRAINT_CONNECTIVITY` waiting on a + * `VALIDATED` capability, which can sit unsatisfied for a long time + * on flaky or captive networks even when the device has working + * internet. + */ + fun request(precision: ModelPrecision = ModelPrecision.INT8) = + OneTimeWorkRequestBuilder() + .setInputData(workDataOf(KEY_PRECISION to precision.name)) + .build() + + /** + * Convenience: enqueue under the standard unique name with + * [ExistingWorkPolicy.KEEP] (a running download is reused; otherwise a + * new one starts). Returns the request id so callers can observe it. + */ + fun enqueue( + context: Context, + precision: ModelPrecision = ModelPrecision.INT8, + ): java.util.UUID { + val req = request(precision) + WorkManager.getInstance(context).enqueueUniqueWork( + UNIQUE_NAME, ExistingWorkPolicy.KEEP, req, + ) + return req.id + } + } +} diff --git a/sdk/src/test/kotlin/audio/soniqo/speech/ModelDownloadWorkerTest.kt b/sdk/src/test/kotlin/audio/soniqo/speech/ModelDownloadWorkerTest.kt new file mode 100644 index 0000000..24dd080 --- /dev/null +++ b/sdk/src/test/kotlin/audio/soniqo/speech/ModelDownloadWorkerTest.kt @@ -0,0 +1,143 @@ +package audio.soniqo.speech + +import android.content.Context +import androidx.test.core.app.ApplicationProvider +import androidx.work.ListenableWorker +import androidx.work.testing.TestListenableWorkerBuilder +import androidx.work.workDataOf +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.mockkObject +import io.mockk.unmockkAll +import kotlinx.coroutines.runBlocking +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import org.robolectric.annotation.Config +import java.io.IOException + +/** + * Robolectric tests for [ModelDownloadWorker]. + * + * Mocks the [ModelManager] singleton via mockk so the worker can be exercised + * without touching the network or the file system. Uses + * [TestListenableWorkerBuilder] which gives the worker a real `Context` and + * stubs out the `setForeground` / `setProgress` plumbing — sufficient to + * assert the doWork() result contract. + */ +@RunWith(RobolectricTestRunner::class) +@Config(sdk = [33]) +class ModelDownloadWorkerTest { + + private lateinit var context: Context + + @Before + fun setUp() { + context = ApplicationProvider.getApplicationContext() + mockkObject(ModelManager) + } + + @After + fun tearDown() { + unmockkAll() + } + + @Test + fun doWork_success_returnsModelDirInOutputData() = runBlocking { + coEvery { + ModelManager.ensureModels(any(), any(), any()) + } returns "/fake/model/dir" + + val worker = TestListenableWorkerBuilder(context) + .setInputData(workDataOf(ModelDownloadWorker.KEY_PRECISION to "INT8")) + .build() + + val result = worker.doWork() + + assertTrue("expected Success, got $result", result is ListenableWorker.Result.Success) + val output = (result as ListenableWorker.Result.Success).outputData + assertEquals("/fake/model/dir", output.getString(ModelDownloadWorker.KEY_MODEL_DIR)) + } + + @Test + fun doWork_ioException_returnsRetry() = runBlocking { + // The worker should bubble transient network/disk failures up to + // WorkManager so it reschedules with exponential backoff. + coEvery { + ModelManager.ensureModels(any(), any(), any()) + } throws IOException("network down") + + val worker = TestListenableWorkerBuilder(context).build() + val result = worker.doWork() + + assertTrue("expected Retry, got $result", result is ListenableWorker.Result.Retry) + } + + @Test + fun doWork_genericThrowable_returnsFailureWithMessage() = runBlocking { + // Non-IO exceptions are not transient (e.g. corrupt manifest, OOM) + // — emit Failure with the message in outputData so the host activity + // can surface a useful error. + coEvery { + ModelManager.ensureModels(any(), any(), any()) + } throws IllegalStateException("models corrupt") + + val worker = TestListenableWorkerBuilder(context).build() + val result = worker.doWork() + + assertTrue("expected Failure, got $result", result is ListenableWorker.Result.Failure) + val output = (result as ListenableWorker.Result.Failure).outputData + assertEquals("models corrupt", output.getString(ModelDownloadWorker.KEY_ERROR)) + } + + @Test + fun doWork_invalidPrecisionInput_defaultsToInt8() = runBlocking { + coEvery { + ModelManager.ensureModels(any(), any(), any()) + } returns "/fake" + + val worker = TestListenableWorkerBuilder(context) + .setInputData(workDataOf(ModelDownloadWorker.KEY_PRECISION to "NOT_A_PRECISION")) + .build() + + worker.doWork() + + coVerify(exactly = 1) { + ModelManager.ensureModels(any(), ModelPrecision.INT8, any()) + } + } + + @Test + fun doWork_missingPrecisionInput_defaultsToInt8() = runBlocking { + coEvery { + ModelManager.ensureModels(any(), any(), any()) + } returns "/fake" + + val worker = TestListenableWorkerBuilder(context).build() + worker.doWork() + + coVerify(exactly = 1) { + ModelManager.ensureModels(any(), ModelPrecision.INT8, any()) + } + } + + @Test + fun request_buildsRequestWithPrecisionInputDataAndNoNetworkConstraint() { + val req = ModelDownloadWorker.request(ModelPrecision.INT8) + + assertEquals( + "INT8", + req.workSpec.input.getString(ModelDownloadWorker.KEY_PRECISION), + ) + // No JobScheduler network constraint — the worker handles network + // failures itself via IOException → retry. See KDoc on `request()`. + assertEquals( + androidx.work.NetworkType.NOT_REQUIRED, + req.workSpec.constraints.requiredNetworkType, + ) + } +}