Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import com.getcode.opencode.controllers.TokenController
import com.getcode.opencode.model.core.ID
import com.getcode.utils.TraceManager
import com.getcode.utils.TraceType
import com.getcode.utils.network.retryable
import com.getcode.utils.trace
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
Expand Down Expand Up @@ -117,17 +118,23 @@ class AuthManager @Inject constructor(
suspend fun presentCredentialStorage(): Result<Unit> {
return credentialManager.presentSaveOption()
.onSuccess {
accountController.getUserFlags().onSuccess { userManager.set(it) }
accountController.getUserFlags().onSuccess { flags ->
userManager.set(flags)
if (flags.isRegistered) {
userManager.set(AuthState.LoggedInWithUser)
}
}
}.map { Unit }
}

suspend fun onAccountPurchased(): Result<Unit> {
return credentialManager.onAccountPurchased()
.fold(
onSuccess = {
userManager.set(AuthState.LoggedInWithUser)
accountController.getUserFlags()
val flagsResult = accountController.getUserFlags()
.onSuccess { userManager.set(it) }
userManager.set(AuthState.LoggedInWithUser)
flagsResult
},
onFailure = { Result.failure(it) }
).onSuccess { savePrefs() }.map { Unit }
Expand Down Expand Up @@ -159,14 +166,17 @@ class AuthManager @Inject constructor(

coroutineScope {
launch {
accountController.getUserFlags()
.onSuccess { flags ->
userManager.set(flags)
userManager.set(if (flags.isRegistered) AuthState.LoggedInWithUser else AuthState.Registered())
}.onFailure {
taggedTrace("Failed to get user flags", type = TraceType.Error, cause = it)
userManager.set(authState = AuthState.Registered())
}
val flags = retryable(maxRetries = 3) {
accountController.getUserFlags().getOrNull()
}

if (flags != null) {
userManager.set(flags)
userManager.set(if (flags.isRegistered) AuthState.LoggedInWithUser else AuthState.Registered())
} else {
taggedTrace("Failed to get user flags after retries", type = TraceType.Error)
userManager.set(authState = AuthState.Registered())
}
}
launch { savePrefs() }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import com.flipcash.app.userflags.UserFlagsCoordinator
import com.flipcash.services.controllers.AccountController
import com.flipcash.services.controllers.PushController
import com.flipcash.services.models.UserFlags
import com.flipcash.services.user.AuthState
import com.flipcash.services.user.UserManager
import io.mockk.coEvery
import io.mockk.coVerify
Expand Down Expand Up @@ -215,4 +216,42 @@ class AuthManagerTest {
val secondRead = authManager.consumePendingSwitchEntropy()
assertNull(secondRead)
}

@Test
fun `login retries getUserFlags on failure then succeeds`() = runTest {
val entropy = "dGVzdGVudHJvcHkxMjM0NQ=="
val accountMetadata: AccountMetadata = mockk(relaxed = true)
val testId = listOf<Byte>(1, 2, 3)
every { accountMetadata.id } returns testId

coEvery { credentialManager.login(entropy, any()) } returns Result.success(accountMetadata)

val flags = UserFlags.Default.copy(isRegistered = true)
coEvery { accountController.getUserFlags() } returnsMany listOf(
Result.failure(RuntimeException("transient failure")),
Result.success(flags)
)

val result = authManager.login(entropyB64 = entropy)

assertTrue(result.isSuccess)
verify { userManager.set(flags) }
verify { userManager.set(authState = AuthState.LoggedInWithUser) }
}

@Test
fun `login falls back to Registered after all retries exhausted`() = runTest {
val entropy = "dGVzdGVudHJvcHkxMjM0NQ=="
val accountMetadata: AccountMetadata = mockk(relaxed = true)
val testId = listOf<Byte>(1, 2, 3)
every { accountMetadata.id } returns testId

coEvery { credentialManager.login(entropy, any()) } returns Result.success(accountMetadata)
coEvery { accountController.getUserFlags() } returns Result.failure(RuntimeException("persistent failure"))

val result = authManager.login(entropyB64 = entropy)

assertTrue(result.isSuccess)
verify { userManager.set(authState = AuthState.Registered()) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class RealSessionController @Inject constructor(
stopPolling()
_state.update { SessionState() }
}
authState.canAccessAuthenticatedApis -> {
authState.isAtLeastRegistered -> {
onAppInForeground()
}
}
Expand Down Expand Up @@ -187,6 +187,17 @@ class RealSessionController @Inject constructor(
.onEach { tokens ->
_state.update { it.copy(tokens = tokens) }
}.launchIn(scope)

// Retry updateUserFlags when network is restored
networkObserver.state
.map { it.connected }
.distinctUntilChanged()
.filter { connected -> connected }
.onEach {
if (userManager.authState.isAtLeastRegistered) {
updateUserFlags()
}
}.launchIn(scope)
}

/**
Expand Down Expand Up @@ -256,10 +267,15 @@ class RealSessionController @Inject constructor(
}

private fun updateUserFlags() {
if (userManager.authState.canAccessAuthenticatedApis) {
if (userManager.authState.isAtLeastRegistered) {
scope.launch {
accountController.getUserFlags()
.onSuccess { userManager.set(it) }
.onSuccess { flags ->
userManager.set(flags)
if (flags.isRegistered && !userManager.authState.canAccessAuthenticatedApis) {
userManager.set(authState = AuthState.LoggedInWithUser)
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,17 @@ class UserManager @Inject constructor(
}

fun set(authState: AuthState) {
val previous = _state.value.authState
_state.update { it.copy(authState = authState) }

when (authState) {
is AuthState.LoggedIn -> {
accountCluster?.let { owner ->
eventBus.send(Events.UpdateLimits(owner = owner, force = true))
// Fire OnLoggedIn only on transition INTO LoggedInWithUser
if (authState is AuthState.LoggedInWithUser && previous !is AuthState.LoggedInWithUser) {
eventBus.send(Events.OnLoggedIn(owner))
}
}
}

Expand All @@ -149,10 +154,6 @@ class UserManager @Inject constructor(
flags = userFlags,
)
}

if (userFlags?.isRegistered == true) {
accountCluster?.let { eventBus.send(Events.OnLoggedIn(accountCluster!!)) }
}
}

fun set(pushToken: String?) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,29 @@ import com.getcode.opencode.internal.manager.VerifiedState
import com.getcode.opencode.model.financial.LocalFiat
import com.getcode.opencode.model.transactions.ExchangeData
import com.getcode.solana.keys.Mint
import kotlin.time.Clock
import kotlin.time.Duration
import kotlin.time.Duration.Companion.minutes
import kotlin.time.Instant

private val DefaultBillExchangeDataTimeout = 15.minutes

fun VerifiedState.exchangeDataFor(
amount: LocalFiat,
mint: Mint,
billExchangeDataTimeout: Duration?
): ExchangeData.Verified? {
if (billExchangeDataTimeout == null) {
val timeout = billExchangeDataTimeout ?: DefaultBillExchangeDataTimeout
if (timeout <= Duration.ZERO) return null

val ts = Instant.fromEpochSeconds(
rateProto.exchangeRate.timestamp.seconds,
rateProto.exchangeRate.timestamp.nanos
)
if (Clock.System.now() - ts > timeout) {
return null
}

return ExchangeData.Verified(
mint = mint,
nativeAmount = amount.nativeAmount.decimalValue,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,48 @@
package com.getcode.opencode.internal.extensions

import com.codeinc.opencode.gen.currency.v1.coreMintFiatExchangeRate
import com.codeinc.opencode.gen.currency.v1.verifiedCoreMintFiatExchangeRate
import com.getcode.opencode.internal.manager.VerifiedState
import com.getcode.opencode.model.financial.CurrencyCode
import com.getcode.opencode.model.financial.Fiat
import com.getcode.opencode.model.financial.LocalFiat
import com.getcode.opencode.model.financial.Rate
import com.getcode.opencode.model.transactions.ExchangeData
import com.getcode.solana.keys.Mint
import io.mockk.mockk
import com.google.protobuf.Timestamp
import org.junit.Test
import kotlin.test.assertEquals
import kotlin.test.assertIs
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.time.Clock
import kotlin.time.Duration
import kotlin.time.Duration.Companion.minutes
import kotlin.time.Duration.Companion.seconds

class VerifiedStateExtTest {

private val verifiedState = VerifiedState(
rateProto = mockk(relaxed = true),
reserveProto = null,
)
private fun verifiedStateAt(epochSeconds: Long): VerifiedState {
return VerifiedState(
rateProto = verifiedCoreMintFiatExchangeRate {
exchangeRate = coreMintFiatExchangeRate {
currencyCode = "usd"
exchangeRate = 1.0
timestamp = Timestamp.newBuilder().setSeconds(epochSeconds).build()
}
},
reserveProto = null,
)
}

private fun freshVerifiedState(): VerifiedState {
return verifiedStateAt(Clock.System.now().epochSeconds)
}

private fun staleVerifiedState(age: Duration): VerifiedState {
val staleEpoch = Clock.System.now().epochSeconds - age.inWholeSeconds
return verifiedStateAt(staleEpoch)
}

private val amount = LocalFiat(
underlyingTokenAmount = Fiat(quarks = 500_000L, currencyCode = CurrencyCode.USD),
Expand All @@ -29,40 +52,80 @@ class VerifiedStateExtTest {
)

@Test
fun `returns null when timeout is null`() {
val result = verifiedState.exchangeDataFor(
fun `returns Verified with correct fields when timeout is provided and rate is fresh`() {
val state = freshVerifiedState()
val result = state.exchangeDataFor(
amount = amount,
mint = Mint.usdf,
billExchangeDataTimeout = null,
billExchangeDataTimeout = 30.seconds,
)

assertNull(result)
assertIs<ExchangeData.Verified>(result)
assertEquals(Mint.usdf, result.mint)
assertEquals(amount.nativeAmount.decimalValue, result.nativeAmount)
assertEquals(amount.underlyingTokenAmount.quarks, result.quarks)
assertEquals(state, result.verifiedState)
}

@Test
fun `returns Verified with correct fields when timeout is provided`() {
val result = verifiedState.exchangeDataFor(
fun `passes through the verifiedState reference`() {
val state = freshVerifiedState()
val result = state.exchangeDataFor(
amount = amount,
mint = Mint.usdf,
billExchangeDataTimeout = 30.seconds,
)

assertIs<ExchangeData.Verified>(result)
assertEquals(Mint.usdf, result.mint)
assertEquals(amount.nativeAmount.decimalValue, result.nativeAmount)
assertEquals(amount.underlyingTokenAmount.quarks, result.quarks)
assertEquals(verifiedState, result.verifiedState)
assert(result.verifiedState === state)
}

@Test
fun `passes through the verifiedState reference`() {
val result = verifiedState.exchangeDataFor(
fun `returns null when timeout is zero`() {
val state = freshVerifiedState()
val result = state.exchangeDataFor(
amount = amount,
mint = Mint.usdf,
billExchangeDataTimeout = Duration.ZERO,
)

assertNull(result)
}

@Test
fun `returns null when rate exceeds timeout`() {
val state = staleVerifiedState(age = 60.seconds)
val result = state.exchangeDataFor(
amount = amount,
mint = Mint.usdf,
billExchangeDataTimeout = 1.seconds,
billExchangeDataTimeout = 30.seconds,
)

assertNull(result)
}

@Test
fun `uses default timeout when null and rate is fresh`() {
val state = freshVerifiedState()
val result = state.exchangeDataFor(
amount = amount,
mint = Mint.usdf,
billExchangeDataTimeout = null,
)

assertNotNull(result)
assertIs<ExchangeData.Verified>(result)
assert(result.verifiedState === verifiedState)
}

@Test
fun `returns null when null timeout and rate exceeds default 15 minutes`() {
val state = staleVerifiedState(age = 16.minutes)
val result = state.exchangeDataFor(
amount = amount,
mint = Mint.usdf,
billExchangeDataTimeout = null,
)

assertNull(result)
}
}
Loading