Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@ dependencies {
implementation("androidx.appcompat:appcompat:1.7.0")
implementation("com.google.android.material:material:1.12.0")
implementation("androidx.activity:activity-ktx:1.9.0")
implementation("androidx.work:work-runtime-ktx:2.9.1")
implementation("androidx.work:work-runtime-ktx:2.11.2")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.9.0")
}
4 changes: 2 additions & 2 deletions sdk/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ dependencies {
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.9.0")
implementation("com.squareup.okhttp3:okhttp:4.12.0")
implementation("androidx.annotation:annotation:1.8.2")
implementation("androidx.work:work-runtime-ktx:2.9.1")
implementation("androidx.work:work-runtime-ktx:2.11.2")
implementation("androidx.core:core-ktx:1.13.1")

testImplementation("junit:junit:4.13.2")
Expand All @@ -115,7 +115,7 @@ dependencies {
testImplementation("androidx.test:core:1.6.1")
testImplementation("androidx.test.ext:junit:1.2.1")
testImplementation("io.mockk:mockk:1.13.13")
testImplementation("androidx.work:work-testing:2.9.1")
testImplementation("androidx.work:work-testing:2.11.2")

androidTestImplementation("androidx.test.ext:junit:1.2.1")
androidTestImplementation("androidx.test:runner:1.6.2")
Expand Down
24 changes: 17 additions & 7 deletions sdk/src/main/kotlin/com/soniqo/speech/ModelManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ object ModelManager {
* and passes [isValidModel] (right ONNX magic, above the per-file size
* floor) and the cached version matches [MODEL_VERSION].
*
* Cheap and side-effect free — does not start a download. Used by paths
* that must answer "are we ready?" without blocking, e.g.
* `SpeechRecognitionService.onCheckRecognitionSupport()`.
* Cheap and side-effect free — does not start a download. Use this from
* `SpeechRecognitionService.onCheckRecognitionSupport()` (or any path
* that must not block) to decide whether to invoke [ensureModels] /
* `ModelDownloadWorker` first.
*/
fun areModelsReady(
context: Context,
Expand All @@ -102,6 +103,10 @@ object ModelManager {
}
}

/** Path to the model directory for [precision], without downloading. */
fun modelDir(context: Context): String =
File(context.filesDir, "models").absolutePath

/** Returns the model directory path, downloading models if needed. */
suspend fun ensureModels(
context: Context,
Expand All @@ -120,8 +125,10 @@ object ModelManager {
dir.resolve("voices").listFiles()?.forEach { it.delete() }
}

// Clean up leftover partial downloads from previous crashes
dir.walk().filter { it.extension == "tmp" }.forEach { it.delete() }
// Note: leftover .tmp files are intentionally preserved here. If a
// previous run was interrupted, downloadFile resumes via Range:
// bytes=N- on the next attempt. Stale .tmp from an old MODEL_VERSION
// is already wiped above.

val fileList = models(precision)
// FP32 encoder needs the external data file
Expand Down Expand Up @@ -235,8 +242,11 @@ object ModelManager {
}
}

// All retries exhausted — clean up partial file and throw
tmp.delete()
// All retries exhausted — preserve the partial .tmp so the next
// ensureModels() call can pick up where this one left off via the
// Range: header. Particularly important when called from
// ModelDownloadWorker, where Result.retry() spins up a fresh
// ensureModels() invocation after WorkManager's backoff window.
throw IOException("Download failed after $MAX_RETRIES attempts: ${lastException?.message}", lastException)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ import android.speech.RecognizerIntent
import android.speech.SpeechRecognizer
import android.util.Log
import androidx.annotation.RequiresApi
import androidx.work.WorkInfo
import androidx.work.WorkManager
import audio.soniqo.speech.ModelDownloadWorker
import audio.soniqo.speech.ModelManager
import audio.soniqo.speech.ModelPrecision
import audio.soniqo.speech.SpeechConfig
Expand All @@ -29,8 +32,11 @@ import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import java.io.IOException
import java.util.concurrent.atomic.AtomicBoolean

/**
Expand Down Expand Up @@ -256,9 +262,39 @@ open class SpeechRecognitionService : RecognitionService() {
protected open fun createPipeline(config: SpeechConfig): SpeechPipeline =
SpeechPipeline(config)

/** Resolve the model directory. Overridden in tests to skip the download. */
protected open suspend fun resolveModelDir(): String =
ModelManager.ensureModels(this, ModelPrecision.INT8)
/**
* Resolve the model directory. If models aren't on disk yet we delegate
* to [ModelDownloadWorker] (which runs as a foreground service so the
* download survives the bind from Gboard timing out) and suspend until
* it reports a terminal state. Suspension is bound to this session's
* coroutine — if the framework cancels the request, the worker keeps
* running on its own and serves the *next* invocation immediately.
*
* Overridden in tests to skip the download.
*/
protected open suspend fun resolveModelDir(): String {
val ctx = applicationContext
if (ModelManager.areModelsReady(ctx, ModelPrecision.INT8)) {
return ModelManager.modelDir(ctx)
}
Log.i(TAG, "models not ready — delegating to ModelDownloadWorker")
val workId = ModelDownloadWorker.enqueue(ctx, ModelPrecision.INT8)
val info = WorkManager.getInstance(ctx)
.getWorkInfoByIdFlow(workId)
.filterNotNull()
.first { it.state.isFinished }
return when (info.state) {
WorkInfo.State.SUCCEEDED -> info.outputData
.getString(ModelDownloadWorker.KEY_MODEL_DIR)
?: throw IllegalStateException("worker succeeded but no model dir")
WorkInfo.State.FAILED -> throw IOException(
info.outputData.getString(ModelDownloadWorker.KEY_ERROR)
?: "model download failed",
)
WorkInfo.State.CANCELLED -> throw IllegalStateException("model download cancelled")
else -> throw IllegalStateException("unexpected worker state: ${info.state}")
}
}

/**
* Open the microphone. Returns null when the format is unsupported on this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class ModelManagerDownloadTest {
}
}

tmp.delete()
// Preserve tmp for resume on next attempt — mirrors production.
throw IOException("Failed after $maxRetries attempts: ${lastException?.message}", lastException)
}

Expand Down Expand Up @@ -155,17 +155,28 @@ class ModelManagerDownloadTest {
}

@Test
fun `cleans up tmp file after all retries fail`() {
server.enqueue(MockResponse().setResponseCode(500))
server.enqueue(MockResponse().setResponseCode(500))
fun `preserves tmp file after all retries fail so next attempt can resume`() {
// Simulate every retry attempt failing mid-stream: server promises 16
// bytes via Content-Length but disconnects during the body. OkHttp
// throws, triggering retry. After all retries exhaust, we expect the
// .tmp file to persist (with whatever partial bytes made it to disk)
// so the worker can resume via Range: bytes=N- on a future invocation.
repeat(2) {
server.enqueue(
MockResponse()
.setBody("ABCDEFGHIJKLMNOP")
.setSocketPolicy(okhttp3.mockwebserver.SocketPolicy.DISCONNECT_DURING_RESPONSE_BODY)
)
}

val dest = File(tmpDir.root, "model.onnx")
try {
downloadFile(server.url("/model.onnx").toString(), dest, maxRetries = 2)
} catch (_: IOException) {}

assertFalse(dest.exists())
assertFalse(File(tmpDir.root, "model.onnx.tmp").exists())
assertFalse("final file should not exist on failure", dest.exists())
val tmp = File(tmpDir.root, "model.onnx.tmp")
assertTrue("partial .tmp should be preserved for resume", tmp.exists())
}

@Test
Expand Down