diff --git a/app/build.gradle.kts b/app/build.gradle.kts index e12cf50..7169337 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -56,6 +56,6 @@ dependencies { implementation("androidx.appcompat:appcompat:1.7.0") implementation("com.google.android.material:material:1.12.0") implementation("androidx.activity:activity-ktx:1.9.0") - implementation("androidx.work:work-runtime-ktx:2.9.1") + implementation("androidx.work:work-runtime-ktx:2.11.2") implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.9.0") } diff --git a/sdk/build.gradle.kts b/sdk/build.gradle.kts index 362f4eb..a86a9da 100644 --- a/sdk/build.gradle.kts +++ b/sdk/build.gradle.kts @@ -105,7 +105,7 @@ dependencies { implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.9.0") implementation("com.squareup.okhttp3:okhttp:4.12.0") implementation("androidx.annotation:annotation:1.8.2") - implementation("androidx.work:work-runtime-ktx:2.9.1") + implementation("androidx.work:work-runtime-ktx:2.11.2") implementation("androidx.core:core-ktx:1.13.1") testImplementation("junit:junit:4.13.2") @@ -115,7 +115,7 @@ dependencies { testImplementation("androidx.test:core:1.6.1") testImplementation("androidx.test.ext:junit:1.2.1") testImplementation("io.mockk:mockk:1.13.13") - testImplementation("androidx.work:work-testing:2.9.1") + testImplementation("androidx.work:work-testing:2.11.2") androidTestImplementation("androidx.test.ext:junit:1.2.1") androidTestImplementation("androidx.test:runner:1.6.2") diff --git a/sdk/src/main/kotlin/com/soniqo/speech/ModelManager.kt b/sdk/src/main/kotlin/com/soniqo/speech/ModelManager.kt index 53ab5f5..773173a 100644 --- a/sdk/src/main/kotlin/com/soniqo/speech/ModelManager.kt +++ b/sdk/src/main/kotlin/com/soniqo/speech/ModelManager.kt @@ -75,9 +75,10 @@ object ModelManager { * and passes [isValidModel] (right ONNX magic, above the per-file size * floor) and the cached version matches [MODEL_VERSION]. * - * Cheap and side-effect free — does not start a download. Used by paths - * that must answer "are we ready?" without blocking, e.g. - * `SpeechRecognitionService.onCheckRecognitionSupport()`. + * Cheap and side-effect free — does not start a download. Use this from + * `SpeechRecognitionService.onCheckRecognitionSupport()` (or any path + * that must not block) to decide whether to invoke [ensureModels] / + * `ModelDownloadWorker` first. */ fun areModelsReady( context: Context, @@ -102,6 +103,10 @@ object ModelManager { } } + /** Path to the model directory for [precision], without downloading. */ + fun modelDir(context: Context): String = + File(context.filesDir, "models").absolutePath + /** Returns the model directory path, downloading models if needed. */ suspend fun ensureModels( context: Context, @@ -120,8 +125,10 @@ object ModelManager { dir.resolve("voices").listFiles()?.forEach { it.delete() } } - // Clean up leftover partial downloads from previous crashes - dir.walk().filter { it.extension == "tmp" }.forEach { it.delete() } + // Note: leftover .tmp files are intentionally preserved here. If a + // previous run was interrupted, downloadFile resumes via Range: + // bytes=N- on the next attempt. Stale .tmp from an old MODEL_VERSION + // is already wiped above. val fileList = models(precision) // FP32 encoder needs the external data file @@ -235,8 +242,11 @@ object ModelManager { } } - // All retries exhausted — clean up partial file and throw - tmp.delete() + // All retries exhausted — preserve the partial .tmp so the next + // ensureModels() call can pick up where this one left off via the + // Range: header. Particularly important when called from + // ModelDownloadWorker, where Result.retry() spins up a fresh + // ensureModels() invocation after WorkManager's backoff window. throw IOException("Download failed after $MAX_RETRIES attempts: ${lastException?.message}", lastException) } diff --git a/sdk/src/main/kotlin/com/soniqo/speech/service/SpeechRecognitionService.kt b/sdk/src/main/kotlin/com/soniqo/speech/service/SpeechRecognitionService.kt index ec08969..116cdd6 100644 --- a/sdk/src/main/kotlin/com/soniqo/speech/service/SpeechRecognitionService.kt +++ b/sdk/src/main/kotlin/com/soniqo/speech/service/SpeechRecognitionService.kt @@ -17,6 +17,9 @@ import android.speech.RecognizerIntent import android.speech.SpeechRecognizer import android.util.Log import androidx.annotation.RequiresApi +import androidx.work.WorkInfo +import androidx.work.WorkManager +import audio.soniqo.speech.ModelDownloadWorker import audio.soniqo.speech.ModelManager import audio.soniqo.speech.ModelPrecision import audio.soniqo.speech.SpeechConfig @@ -29,8 +32,11 @@ import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.filterNotNull +import kotlinx.coroutines.flow.first import kotlinx.coroutines.isActive import kotlinx.coroutines.launch +import java.io.IOException import java.util.concurrent.atomic.AtomicBoolean /** @@ -256,9 +262,39 @@ open class SpeechRecognitionService : RecognitionService() { protected open fun createPipeline(config: SpeechConfig): SpeechPipeline = SpeechPipeline(config) - /** Resolve the model directory. Overridden in tests to skip the download. */ - protected open suspend fun resolveModelDir(): String = - ModelManager.ensureModels(this, ModelPrecision.INT8) + /** + * Resolve the model directory. If models aren't on disk yet we delegate + * to [ModelDownloadWorker] (which runs as a foreground service so the + * download survives the bind from Gboard timing out) and suspend until + * it reports a terminal state. Suspension is bound to this session's + * coroutine — if the framework cancels the request, the worker keeps + * running on its own and serves the *next* invocation immediately. + * + * Overridden in tests to skip the download. + */ + protected open suspend fun resolveModelDir(): String { + val ctx = applicationContext + if (ModelManager.areModelsReady(ctx, ModelPrecision.INT8)) { + return ModelManager.modelDir(ctx) + } + Log.i(TAG, "models not ready — delegating to ModelDownloadWorker") + val workId = ModelDownloadWorker.enqueue(ctx, ModelPrecision.INT8) + val info = WorkManager.getInstance(ctx) + .getWorkInfoByIdFlow(workId) + .filterNotNull() + .first { it.state.isFinished } + return when (info.state) { + WorkInfo.State.SUCCEEDED -> info.outputData + .getString(ModelDownloadWorker.KEY_MODEL_DIR) + ?: throw IllegalStateException("worker succeeded but no model dir") + WorkInfo.State.FAILED -> throw IOException( + info.outputData.getString(ModelDownloadWorker.KEY_ERROR) + ?: "model download failed", + ) + WorkInfo.State.CANCELLED -> throw IllegalStateException("model download cancelled") + else -> throw IllegalStateException("unexpected worker state: ${info.state}") + } + } /** * Open the microphone. Returns null when the format is unsupported on this diff --git a/sdk/src/test/kotlin/audio/soniqo/speech/ModelManagerDownloadTest.kt b/sdk/src/test/kotlin/audio/soniqo/speech/ModelManagerDownloadTest.kt index 0bd6b69..af18ec0 100644 --- a/sdk/src/test/kotlin/audio/soniqo/speech/ModelManagerDownloadTest.kt +++ b/sdk/src/test/kotlin/audio/soniqo/speech/ModelManagerDownloadTest.kt @@ -99,7 +99,7 @@ class ModelManagerDownloadTest { } } - tmp.delete() + // Preserve tmp for resume on next attempt — mirrors production. throw IOException("Failed after $maxRetries attempts: ${lastException?.message}", lastException) } @@ -155,17 +155,28 @@ class ModelManagerDownloadTest { } @Test - fun `cleans up tmp file after all retries fail`() { - server.enqueue(MockResponse().setResponseCode(500)) - server.enqueue(MockResponse().setResponseCode(500)) + fun `preserves tmp file after all retries fail so next attempt can resume`() { + // Simulate every retry attempt failing mid-stream: server promises 16 + // bytes via Content-Length but disconnects during the body. OkHttp + // throws, triggering retry. After all retries exhaust, we expect the + // .tmp file to persist (with whatever partial bytes made it to disk) + // so the worker can resume via Range: bytes=N- on a future invocation. + repeat(2) { + server.enqueue( + MockResponse() + .setBody("ABCDEFGHIJKLMNOP") + .setSocketPolicy(okhttp3.mockwebserver.SocketPolicy.DISCONNECT_DURING_RESPONSE_BODY) + ) + } val dest = File(tmpDir.root, "model.onnx") try { downloadFile(server.url("/model.onnx").toString(), dest, maxRetries = 2) } catch (_: IOException) {} - assertFalse(dest.exists()) - assertFalse(File(tmpDir.root, "model.onnx.tmp").exists()) + assertFalse("final file should not exist on failure", dest.exists()) + val tmp = File(tmpDir.root, "model.onnx.tmp") + assertTrue("partial .tmp should be preserved for resume", tmp.exists()) } @Test