diff --git a/app/build.gradle.kts b/app/build.gradle.kts
index 6c333e7..e12cf50 100644
--- a/app/build.gradle.kts
+++ b/app/build.gradle.kts
@@ -56,5 +56,6 @@ dependencies {
implementation("androidx.appcompat:appcompat:1.7.0")
implementation("com.google.android.material:material:1.12.0")
implementation("androidx.activity:activity-ktx:1.9.0")
+ implementation("androidx.work:work-runtime-ktx:2.9.1")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.9.0")
}
diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml
index f518a1b..5126f4c 100644
--- a/app/src/main/AndroidManifest.xml
+++ b/app/src/main/AndroidManifest.xml
@@ -2,6 +2,8 @@
+
+
- runOnUiThread {
- statusView.text = "${progress.file} ${progress.completed}/${progress.totalFiles}"
+ // Models download in a foreground worker so the transfer survives
+ // backgrounding the app. Activity just observes progress.
+ val workId = ModelDownloadWorker.enqueue(applicationContext, ModelPrecision.INT8)
+ WorkManager.getInstance(applicationContext)
+ .getWorkInfoByIdLiveData(workId)
+ .observe(this) { info ->
+ if (info == null) return@observe
+ when (info.state) {
+ WorkInfo.State.ENQUEUED,
+ WorkInfo.State.BLOCKED,
+ WorkInfo.State.RUNNING -> {
+ val total = info.progress.getInt(ModelDownloadWorker.KEY_TOTAL, 0)
+ if (total > 0) {
+ val file = info.progress.getString(ModelDownloadWorker.KEY_FILE) ?: ""
+ val done = info.progress.getInt(ModelDownloadWorker.KEY_COMPLETED, 0)
+ statusView.text = "$file $done/$total"
+ }
+ }
+ WorkInfo.State.SUCCEEDED -> {
+ val modelDir = info.outputData.getString(ModelDownloadWorker.KEY_MODEL_DIR)
+ if (modelDir == null) {
+ statusView.text = "worker succeeded but no model dir"
+ return@observe
+ }
+ initPipeline(modelDir)
}
+ WorkInfo.State.FAILED -> {
+ val err = info.outputData.getString(ModelDownloadWorker.KEY_ERROR)
+ ?: "unknown"
+ statusView.text = "download failed: $err"
+ }
+ WorkInfo.State.CANCELLED -> { statusView.text = "cancelled" }
}
+ }
+ }
+ private fun initPipeline(modelDir: String) {
+ lifecycleScope.launch {
+ try {
val config = SpeechConfig(
modelDir = modelDir,
useNnapi = false,
diff --git a/app/src/main/kotlin/com/soniqo/speech/demo/MainActivity.kt b/app/src/main/kotlin/com/soniqo/speech/demo/MainActivity.kt
index 41d654e..6dab0ce 100644
--- a/app/src/main/kotlin/com/soniqo/speech/demo/MainActivity.kt
+++ b/app/src/main/kotlin/com/soniqo/speech/demo/MainActivity.kt
@@ -25,7 +25,9 @@ import androidx.core.app.ActivityCompat
import androidx.core.view.ViewCompat
import androidx.core.view.WindowInsetsCompat
import androidx.lifecycle.lifecycleScope
-import audio.soniqo.speech.ModelManager
+import androidx.work.WorkInfo
+import androidx.work.WorkManager
+import audio.soniqo.speech.ModelDownloadWorker
import audio.soniqo.speech.ModelPrecision
import audio.soniqo.speech.SpeechConfig
import audio.soniqo.speech.SpeechEvent
@@ -218,24 +220,52 @@ class MainActivity : ComponentActivity() {
private fun loadPipeline() {
setStatus("initializing...")
- lifecycleScope.launch {
- try {
- val modelDir = ModelManager.ensureModels(
- this@MainActivity,
- precision = ModelPrecision.INT8,
- ) { progress ->
- val mb = progress.bytesDownloaded / 1_000_000
- setStatus("${progress.file} ${progress.completed}/${progress.totalFiles} (${mb} MB)")
- runOnUiThread {
- downloadProgress.progress =
- (progress.completed * 100 / progress.totalFiles).coerceIn(0, 100)
+ // Models download in a foreground worker so the transfer survives
+ // backgrounding the app. Activity just observes progress.
+ val workId = ModelDownloadWorker.enqueue(applicationContext, ModelPrecision.INT8)
+ WorkManager.getInstance(applicationContext)
+ .getWorkInfoByIdLiveData(workId)
+ .observe(this) { info ->
+ if (info == null) return@observe
+ when (info.state) {
+ WorkInfo.State.ENQUEUED,
+ WorkInfo.State.BLOCKED,
+ WorkInfo.State.RUNNING -> {
+ val total = info.progress.getInt(ModelDownloadWorker.KEY_TOTAL, 0)
+ if (total > 0) {
+ val file = info.progress.getString(ModelDownloadWorker.KEY_FILE) ?: ""
+ val done = info.progress.getInt(ModelDownloadWorker.KEY_COMPLETED, 0)
+ val pct = info.progress.getInt(ModelDownloadWorker.KEY_PERCENT, 0)
+ setStatus("$file $done/$total")
+ downloadProgress.progress = pct
+ }
}
+ WorkInfo.State.SUCCEEDED -> {
+ val modelDir = info.outputData.getString(ModelDownloadWorker.KEY_MODEL_DIR)
+ if (modelDir == null) {
+ addSystemLine("worker succeeded but no model dir")
+ setStatus("error")
+ return@observe
+ }
+ downloadProgress.progress = 100
+ downloadProgress.visibility = View.GONE
+ initPipeline(modelDir)
+ }
+ WorkInfo.State.FAILED -> {
+ val err = info.outputData.getString(ModelDownloadWorker.KEY_ERROR)
+ ?: "unknown"
+ addSystemLine("download failed: $err")
+ setStatus("error — tap to retry")
+ statusView.setOnClickListener { retryInit() }
+ }
+ WorkInfo.State.CANCELLED -> setStatus("cancelled")
}
- runOnUiThread {
- downloadProgress.progress = 100
- downloadProgress.visibility = View.GONE
- }
+ }
+ }
+ private fun initPipeline(modelDir: String) {
+ lifecycleScope.launch {
+ try {
val config = SpeechConfig(
modelDir = modelDir,
useNnapi = !isEmulator,
diff --git a/sdk/build.gradle.kts b/sdk/build.gradle.kts
index a7d81a6..362f4eb 100644
--- a/sdk/build.gradle.kts
+++ b/sdk/build.gradle.kts
@@ -105,6 +105,8 @@ dependencies {
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.9.0")
implementation("com.squareup.okhttp3:okhttp:4.12.0")
implementation("androidx.annotation:annotation:1.8.2")
+ implementation("androidx.work:work-runtime-ktx:2.9.1")
+ implementation("androidx.core:core-ktx:1.13.1")
testImplementation("junit:junit:4.13.2")
testImplementation("com.squareup.okhttp3:mockwebserver:4.12.0")
@@ -113,6 +115,7 @@ dependencies {
testImplementation("androidx.test:core:1.6.1")
testImplementation("androidx.test.ext:junit:1.2.1")
testImplementation("io.mockk:mockk:1.13.13")
+ testImplementation("androidx.work:work-testing:2.9.1")
androidTestImplementation("androidx.test.ext:junit:1.2.1")
androidTestImplementation("androidx.test:runner:1.6.2")
diff --git a/sdk/src/main/AndroidManifest.xml b/sdk/src/main/AndroidManifest.xml
index fbb2ef8..cbeb09d 100644
--- a/sdk/src/main/AndroidManifest.xml
+++ b/sdk/src/main/AndroidManifest.xml
@@ -1,3 +1,22 @@
-
+
+
+
+
+
+
+
+
+
+
diff --git a/sdk/src/main/kotlin/com/soniqo/speech/ModelDownloadWorker.kt b/sdk/src/main/kotlin/com/soniqo/speech/ModelDownloadWorker.kt
new file mode 100644
index 0000000..ffc903b
--- /dev/null
+++ b/sdk/src/main/kotlin/com/soniqo/speech/ModelDownloadWorker.kt
@@ -0,0 +1,177 @@
+package audio.soniqo.speech
+
+import android.app.NotificationChannel
+import android.app.NotificationManager
+import android.content.Context
+import android.content.pm.ServiceInfo
+import android.os.Build
+import androidx.core.app.NotificationCompat
+import androidx.work.CoroutineWorker
+import androidx.work.ExistingWorkPolicy
+import androidx.work.ForegroundInfo
+import androidx.work.OneTimeWorkRequestBuilder
+import androidx.work.WorkManager
+import androidx.work.WorkerParameters
+import androidx.work.workDataOf
+import java.io.IOException
+
+/**
+ * Downloads the speech models in a foreground worker so the transfer survives
+ * app backgrounding and process death. Wraps [ModelManager.ensureModels] —
+ * resumes partial downloads via the same on-disk `.tmp` files, retries on
+ * `IOException`, and reports progress via [setProgress].
+ *
+ * ### Usage
+ *
+ * ```
+ * WorkManager.getInstance(context).enqueueUniqueWork(
+ * ModelDownloadWorker.UNIQUE_NAME,
+ * ExistingWorkPolicy.KEEP,
+ * ModelDownloadWorker.request(ModelPrecision.INT8),
+ * )
+ *
+ * WorkManager.getInstance(context)
+ * .getWorkInfosForUniqueWorkLiveData(ModelDownloadWorker.UNIQUE_NAME)
+ * .observe(this) { infos ->
+ * val info = infos.firstOrNull() ?: return@observe
+ * when (info.state) {
+ * WorkInfo.State.RUNNING -> {
+ * val pct = info.progress.getInt(ModelDownloadWorker.KEY_PERCENT, 0)
+ * ...
+ * }
+ * WorkInfo.State.SUCCEEDED -> {
+ * val dir = info.outputData.getString(ModelDownloadWorker.KEY_MODEL_DIR)
+ * ...
+ * }
+ * else -> Unit
+ * }
+ * }
+ * ```
+ *
+ * Requires the host app to declare `POST_NOTIFICATIONS` (API 33+) for the
+ * progress notification to appear; the worker still runs without it.
+ */
+class ModelDownloadWorker(
+ context: Context,
+ params: WorkerParameters,
+) : CoroutineWorker(context, params) {
+
+ override suspend fun doWork(): Result {
+ val precision = inputData.getString(KEY_PRECISION)
+ ?.let { runCatching { ModelPrecision.valueOf(it) }.getOrNull() }
+ ?: ModelPrecision.INT8
+
+ runCatching { setForeground(buildForegroundInfo(0, 0, "Preparing speech models…")) }
+
+ return try {
+ val modelDir = ModelManager.ensureModels(applicationContext, precision) { p ->
+ val pct = if (p.totalFiles > 0) {
+ (p.completed * 100 / p.totalFiles).coerceIn(0, 100)
+ } else 0
+ setProgressAsync(workDataOf(
+ KEY_FILE to p.file,
+ KEY_COMPLETED to p.completed,
+ KEY_TOTAL to p.totalFiles,
+ KEY_BYTES_DOWNLOADED to p.bytesDownloaded,
+ KEY_PERCENT to pct,
+ ))
+ runCatching {
+ setForegroundAsync(buildForegroundInfo(
+ completed = p.completed,
+ total = p.totalFiles,
+ text = "${p.file} ${p.completed}/${p.totalFiles}",
+ ))
+ }
+ }
+ Result.success(workDataOf(KEY_MODEL_DIR to modelDir))
+ } catch (e: IOException) {
+ // Network / disk hiccup — let WorkManager retry with backoff.
+ Result.retry()
+ } catch (t: Throwable) {
+ Result.failure(workDataOf(KEY_ERROR to (t.message ?: t::class.java.simpleName)))
+ }
+ }
+
+ private fun buildForegroundInfo(completed: Int, total: Int, text: String): ForegroundInfo {
+ ensureChannel()
+ val indeterminate = total <= 0
+ val notif = NotificationCompat.Builder(applicationContext, CHANNEL_ID)
+ .setContentTitle("Speech models")
+ .setContentText(text)
+ .setSmallIcon(android.R.drawable.stat_sys_download)
+ .setProgress(if (indeterminate) 100 else total, completed, indeterminate)
+ .setOngoing(true)
+ .setOnlyAlertOnce(true)
+ .setPriority(NotificationCompat.PRIORITY_LOW)
+ .build()
+ return if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
+ ForegroundInfo(NOTIFICATION_ID, notif, ServiceInfo.FOREGROUND_SERVICE_TYPE_DATA_SYNC)
+ } else {
+ ForegroundInfo(NOTIFICATION_ID, notif)
+ }
+ }
+
+ private fun ensureChannel() {
+ if (Build.VERSION.SDK_INT < Build.VERSION_CODES.O) return
+ val nm = applicationContext.getSystemService(NotificationManager::class.java) ?: return
+ if (nm.getNotificationChannel(CHANNEL_ID) != null) return
+ nm.createNotificationChannel(NotificationChannel(
+ CHANNEL_ID,
+ "Speech model downloads",
+ NotificationManager.IMPORTANCE_LOW,
+ ).apply { description = "Progress for downloading on-device speech models" })
+ }
+
+ companion object {
+ /** Pass to [WorkManager.enqueueUniqueWork] to dedupe concurrent downloads. */
+ const val UNIQUE_NAME = "audio.soniqo.speech.modelDownload"
+
+ // Input keys
+ const val KEY_PRECISION = "precision"
+
+ // Output keys
+ const val KEY_MODEL_DIR = "modelDir"
+ const val KEY_ERROR = "error"
+
+ // Progress keys
+ const val KEY_FILE = "file"
+ const val KEY_COMPLETED = "completed"
+ const val KEY_TOTAL = "totalFiles"
+ const val KEY_BYTES_DOWNLOADED = "bytesDownloaded"
+ const val KEY_PERCENT = "percent"
+
+ private const val CHANNEL_ID = "audio.soniqo.speech.models"
+ // Stable, unlikely-to-collide id (decimal of 0xC0FFEE).
+ private const val NOTIFICATION_ID = 12648430
+
+ /**
+ * Build a one-shot download request. No JobScheduler network
+ * constraint — the underlying OkHttp client surfaces network failures
+ * as `IOException`, which the worker translates into `Result.retry()`.
+ * Avoids JobScheduler's `CONSTRAINT_CONNECTIVITY` waiting on a
+ * `VALIDATED` capability, which can sit unsatisfied for a long time
+ * on flaky or captive networks even when the device has working
+ * internet.
+ */
+ fun request(precision: ModelPrecision = ModelPrecision.INT8) =
+ OneTimeWorkRequestBuilder()
+ .setInputData(workDataOf(KEY_PRECISION to precision.name))
+ .build()
+
+ /**
+ * Convenience: enqueue under the standard unique name with
+ * [ExistingWorkPolicy.KEEP] (a running download is reused; otherwise a
+ * new one starts). Returns the request id so callers can observe it.
+ */
+ fun enqueue(
+ context: Context,
+ precision: ModelPrecision = ModelPrecision.INT8,
+ ): java.util.UUID {
+ val req = request(precision)
+ WorkManager.getInstance(context).enqueueUniqueWork(
+ UNIQUE_NAME, ExistingWorkPolicy.KEEP, req,
+ )
+ return req.id
+ }
+ }
+}
diff --git a/sdk/src/test/kotlin/audio/soniqo/speech/ModelDownloadWorkerTest.kt b/sdk/src/test/kotlin/audio/soniqo/speech/ModelDownloadWorkerTest.kt
new file mode 100644
index 0000000..24dd080
--- /dev/null
+++ b/sdk/src/test/kotlin/audio/soniqo/speech/ModelDownloadWorkerTest.kt
@@ -0,0 +1,143 @@
+package audio.soniqo.speech
+
+import android.content.Context
+import androidx.test.core.app.ApplicationProvider
+import androidx.work.ListenableWorker
+import androidx.work.testing.TestListenableWorkerBuilder
+import androidx.work.workDataOf
+import io.mockk.coEvery
+import io.mockk.coVerify
+import io.mockk.mockkObject
+import io.mockk.unmockkAll
+import kotlinx.coroutines.runBlocking
+import org.junit.After
+import org.junit.Assert.assertEquals
+import org.junit.Assert.assertTrue
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.robolectric.RobolectricTestRunner
+import org.robolectric.annotation.Config
+import java.io.IOException
+
+/**
+ * Robolectric tests for [ModelDownloadWorker].
+ *
+ * Mocks the [ModelManager] singleton via mockk so the worker can be exercised
+ * without touching the network or the file system. Uses
+ * [TestListenableWorkerBuilder] which gives the worker a real `Context` and
+ * stubs out the `setForeground` / `setProgress` plumbing — sufficient to
+ * assert the doWork() result contract.
+ */
+@RunWith(RobolectricTestRunner::class)
+@Config(sdk = [33])
+class ModelDownloadWorkerTest {
+
+ private lateinit var context: Context
+
+ @Before
+ fun setUp() {
+ context = ApplicationProvider.getApplicationContext()
+ mockkObject(ModelManager)
+ }
+
+ @After
+ fun tearDown() {
+ unmockkAll()
+ }
+
+ @Test
+ fun doWork_success_returnsModelDirInOutputData() = runBlocking {
+ coEvery {
+ ModelManager.ensureModels(any(), any(), any())
+ } returns "/fake/model/dir"
+
+ val worker = TestListenableWorkerBuilder(context)
+ .setInputData(workDataOf(ModelDownloadWorker.KEY_PRECISION to "INT8"))
+ .build()
+
+ val result = worker.doWork()
+
+ assertTrue("expected Success, got $result", result is ListenableWorker.Result.Success)
+ val output = (result as ListenableWorker.Result.Success).outputData
+ assertEquals("/fake/model/dir", output.getString(ModelDownloadWorker.KEY_MODEL_DIR))
+ }
+
+ @Test
+ fun doWork_ioException_returnsRetry() = runBlocking {
+ // The worker should bubble transient network/disk failures up to
+ // WorkManager so it reschedules with exponential backoff.
+ coEvery {
+ ModelManager.ensureModels(any(), any(), any())
+ } throws IOException("network down")
+
+ val worker = TestListenableWorkerBuilder(context).build()
+ val result = worker.doWork()
+
+ assertTrue("expected Retry, got $result", result is ListenableWorker.Result.Retry)
+ }
+
+ @Test
+ fun doWork_genericThrowable_returnsFailureWithMessage() = runBlocking {
+ // Non-IO exceptions are not transient (e.g. corrupt manifest, OOM)
+ // — emit Failure with the message in outputData so the host activity
+ // can surface a useful error.
+ coEvery {
+ ModelManager.ensureModels(any(), any(), any())
+ } throws IllegalStateException("models corrupt")
+
+ val worker = TestListenableWorkerBuilder(context).build()
+ val result = worker.doWork()
+
+ assertTrue("expected Failure, got $result", result is ListenableWorker.Result.Failure)
+ val output = (result as ListenableWorker.Result.Failure).outputData
+ assertEquals("models corrupt", output.getString(ModelDownloadWorker.KEY_ERROR))
+ }
+
+ @Test
+ fun doWork_invalidPrecisionInput_defaultsToInt8() = runBlocking {
+ coEvery {
+ ModelManager.ensureModels(any(), any(), any())
+ } returns "/fake"
+
+ val worker = TestListenableWorkerBuilder(context)
+ .setInputData(workDataOf(ModelDownloadWorker.KEY_PRECISION to "NOT_A_PRECISION"))
+ .build()
+
+ worker.doWork()
+
+ coVerify(exactly = 1) {
+ ModelManager.ensureModels(any(), ModelPrecision.INT8, any())
+ }
+ }
+
+ @Test
+ fun doWork_missingPrecisionInput_defaultsToInt8() = runBlocking {
+ coEvery {
+ ModelManager.ensureModels(any(), any(), any())
+ } returns "/fake"
+
+ val worker = TestListenableWorkerBuilder(context).build()
+ worker.doWork()
+
+ coVerify(exactly = 1) {
+ ModelManager.ensureModels(any(), ModelPrecision.INT8, any())
+ }
+ }
+
+ @Test
+ fun request_buildsRequestWithPrecisionInputDataAndNoNetworkConstraint() {
+ val req = ModelDownloadWorker.request(ModelPrecision.INT8)
+
+ assertEquals(
+ "INT8",
+ req.workSpec.input.getString(ModelDownloadWorker.KEY_PRECISION),
+ )
+ // No JobScheduler network constraint — the worker handles network
+ // failures itself via IOException → retry. See KDoc on `request()`.
+ assertEquals(
+ androidx.work.NetworkType.NOT_REQUIRED,
+ req.workSpec.constraints.requiredNetworkType,
+ )
+ }
+}