diff --git a/.metals/metals.lock.db b/.metals/metals.lock.db new file mode 100644 index 000000000..5029a3740 --- /dev/null +++ b/.metals/metals.lock.db @@ -0,0 +1,6 @@ +#FileLock +#Mon Feb 23 23:16:17 CET 2026 +hostName=localhost +id=19c8c931c3ed5da4b5e931ae50555e001101b8d592c +method=file +server=localhost\:65457 diff --git a/.metals/metals.mv.db b/.metals/metals.mv.db new file mode 100644 index 000000000..c5c13e63d Binary files /dev/null and b/.metals/metals.mv.db differ diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..32cfc61d2 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "files.watcherExclude": { + "**/target": true + } +} \ No newline at end of file diff --git a/bytebuddy-proxy-support/bin/main/META-INF/services/dev.restate.common.reflections.ProxyFactory b/bytebuddy-proxy-support/bin/main/META-INF/services/dev.restate.common.reflections.ProxyFactory new file mode 100644 index 000000000..f205102de --- /dev/null +++ b/bytebuddy-proxy-support/bin/main/META-INF/services/dev.restate.common.reflections.ProxyFactory @@ -0,0 +1 @@ +dev.restate.bytebuddy.proxysupport.ByteBuddyProxyFactory \ No newline at end of file diff --git a/client-kotlin/bin/main/dev/restate/client/kotlin/ingress.kt b/client-kotlin/bin/main/dev/restate/client/kotlin/ingress.kt new file mode 100644 index 000000000..cfd2b7254 --- /dev/null +++ b/client-kotlin/bin/main/dev/restate/client/kotlin/ingress.kt @@ -0,0 +1,572 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.client.kotlin + +import dev.restate.client.Client +import dev.restate.client.RequestOptions +import dev.restate.client.Response +import dev.restate.client.ResponseHead +import dev.restate.client.SendResponse +import dev.restate.common.InvocationOptions +import dev.restate.common.Output +import dev.restate.common.Request +import dev.restate.common.Target +import dev.restate.common.WorkflowRequest +import dev.restate.common.reflection.kotlin.RequestCaptureProxy +import dev.restate.common.reflection.kotlin.captureInvocation +import dev.restate.common.reflections.ProxySupport +import dev.restate.common.reflections.ReflectionUtils +import dev.restate.serde.TypeTag +import dev.restate.serde.kotlinx.typeTag +import kotlin.coroutines.Continuation +import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED +import kotlin.coroutines.startCoroutine +import kotlin.time.Duration +import kotlin.time.toJavaDuration +import kotlinx.coroutines.future.await + +// Extension methods for the Client + +/** Request options builder function */ +fun requestOptions(init: RequestOptions.Builder.() -> Unit): RequestOptions { + val builder = RequestOptions.builder() + builder.init() + return builder.build() +} + +/** + * Shorthand for [callSuspend] + * + * @param client the client to use for the call + * @return the response + */ +suspend fun Request.call(client: Client): Response { + return client.callSuspend(this) +} + +/** Call a service and wait for the response. */ +suspend fun Client.callSuspend(request: Request): Response { + return this.callAsync(request).await() +} + +/** + * Shorthand for [sendSuspend] + * + * @param client the client to use for sending + * @param delay optional execution delay + * @return the send response + */ +suspend fun Request.send( + client: Client, + delay: Duration? = null, +): SendResponse { + return client.sendSuspend(this, delay) +} + +/** + * Send a request to a service without waiting for the response, optionally providing an execution + * delay to wait for. + */ +suspend fun Client.sendSuspend( + request: Request, + delay: Duration? = null, +): SendResponse { + return this.sendAsync(request, delay?.toJavaDuration()).await() +} + +/** + * Shorthand for [submitSuspend] + * + * @param client the client to use for submission + * @param delay optional execution delay + * @return the send response + */ +suspend fun WorkflowRequest.submit( + client: Client, + delay: Duration? = null, +): SendResponse { + return client.submitSuspend(this, delay) +} + +/** Submit a workflow, optionally providing an execution delay to wait for. */ +suspend fun Client.submitSuspend( + request: WorkflowRequest, + delay: Duration? = null, +): SendResponse { + return this.submitAsync(request, delay?.toJavaDuration()).await() +} + +/** + * Complete with success the Awakeable. + * + * @param typeTag the type tag for serialization + * @param payload the payload + * @param options request options + */ +suspend fun Client.AwakeableHandle.resolveSuspend( + typeTag: TypeTag, + payload: T, + options: RequestOptions = RequestOptions.DEFAULT, +): Response { + return this.resolveAsync(typeTag, payload, options).await() +} + +/** + * Complete with success the Awakeable. + * + * @param payload the payload + * @param options request options + */ +suspend inline fun Client.AwakeableHandle.resolveSuspend( + payload: T, + options: RequestOptions = RequestOptions.DEFAULT, +): Response { + return this.resolveSuspend(typeTag(), payload, options) +} + +/** + * Complete with failure the Awakeable. + * + * @param reason the rejection reason + * @param options request options + */ +suspend fun Client.AwakeableHandle.rejectSuspend( + reason: String, + options: RequestOptions = RequestOptions.DEFAULT, +): Response { + return this.rejectAsync(reason, options).await() +} + +/** + * Create a new [Client.InvocationHandle] for the provided invocation identifier. + * + * @param invocationId the invocation identifier + * @return the invocation handle + */ +inline fun Client.invocationHandle( + invocationId: String +): Client.InvocationHandle { + return this.invocationHandle(invocationId, typeTag()) +} + +/** + * Suspend version of [Client.InvocationHandle.attach]. + * + * @param options request options + * @return the response + */ +suspend fun Client.InvocationHandle.attachSuspend( + options: RequestOptions = RequestOptions.DEFAULT +): Response { + return this.attachAsync(options).await() +} + +/** + * Suspend version of [Client.InvocationHandle.getOutput]. + * + * @param options request options + * @return the output response + */ +suspend fun Client.InvocationHandle.getOutputSuspend( + options: RequestOptions = RequestOptions.DEFAULT +): Response> { + return this.getOutputAsync(options).await() +} + +/** + * Create a new [Client.IdempotentInvocationHandle] for the provided target and idempotency key. + * + * @param target the target service/method + * @param idempotencyKey the idempotency key + * @return the idempotent invocation handle + */ +inline fun Client.idempotentInvocationHandle( + target: Target, + idempotencyKey: String, +): Client.IdempotentInvocationHandle { + return this.idempotentInvocationHandle(target, idempotencyKey, typeTag()) +} + +/** + * Suspend version of [Client.IdempotentInvocationHandle.attach]. + * + * @param options request options + * @return the response + */ +suspend fun Client.IdempotentInvocationHandle.attachSuspend( + options: RequestOptions = RequestOptions.DEFAULT +): Response { + return this.attachAsync(options).await() +} + +/** + * Suspend version of [Client.IdempotentInvocationHandle.getOutput]. + * + * @param options request options + * @return the output response + */ +suspend fun Client.IdempotentInvocationHandle.getOutputSuspend( + options: RequestOptions = RequestOptions.DEFAULT +): Response> { + return this.getOutputAsync(options).await() +} + +/** + * Create a new [Client.WorkflowHandle] for the provided workflow name and identifier. + * + * @param workflowName the workflow name + * @param workflowId the workflow identifier + * @return the workflow handle + */ +inline fun Client.workflowHandle( + workflowName: String, + workflowId: String, +): Client.WorkflowHandle { + return this.workflowHandle(workflowName, workflowId, typeTag()) +} + +/** + * Suspend version of [Client.WorkflowHandle.attach]. + * + * @param options request options + * @return the response + */ +suspend fun Client.WorkflowHandle.attachSuspend( + options: RequestOptions = RequestOptions.DEFAULT +): Response { + return this.attachAsync(options).await() +} + +/** + * Suspend version of [Client.WorkflowHandle.getOutput]. + * + * @param options request options + * @return the output response + */ +suspend fun Client.WorkflowHandle.getOutputSuspend( + options: RequestOptions = RequestOptions.DEFAULT +): Response> { + return this.getOutputAsync(options).await() +} + +/** @see ResponseHead.statusCode */ +val ResponseHead.status: Int + get() = this.statusCode() + +/** @see ResponseHead.headers */ +val ResponseHead.headers: ResponseHead.Headers + get() = this.headers() + +/** @see Response.response */ +val Response.response: Res + get() = this.response() + +/** @see SendResponse.sendStatus */ +val SendResponse.sendStatus: SendResponse.SendStatus + get() = this.sendStatus() + +/** + * Create a proxy client for a Restate service. + * + * Example usage: + * ```kotlin + * val greeter = client.service() + * val response = greeter.greet("Alice") + * ``` + * + * @param SVC the service class annotated with @Service + * @return a proxy client to invoke the service + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun Client.service(): SVC { + return service(this, SVC::class.java) +} + +/** + * Create a proxy client for a Restate virtual object. + * + * Example usage: + * ```kotlin + * val counter = client.virtualObject("my-key") + * val value = counter.increment() + * ``` + * + * @param SVC the virtual object class annotated with @VirtualObject + * @param key the key identifying the specific virtual object instance + * @return a proxy client to invoke the virtual object + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun Client.virtualObject(key: String): SVC { + return virtualObject(this, SVC::class.java, key) +} + +/** + * Create a proxy client for a Restate workflow. + * + * Example usage: + * ```kotlin + * val wf = client.workflow("wf-123") + * val result = wf.run("input") + * ``` + * + * @param SVC the workflow class annotated with @Workflow + * @param key the key identifying the specific workflow instance + * @return a proxy client to invoke the workflow + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun Client.workflow(key: String): SVC { + return workflow(this, SVC::class.java, key) +} + +/** + * Create a proxy for a service that uses the ingress client to make calls. + * + * @param client the ingress client to use for calls + * @param clazz the service class + * @return a proxy that intercepts method calls and executes them via the client + */ +@PublishedApi +internal fun service(client: Client, clazz: Class): SVC { + ReflectionUtils.mustHaveServiceAnnotation(clazz) + require(ReflectionUtils.isKotlinClass(clazz)) { + "Using Java classes with Kotlin's API is not supported" + } + + val serviceName = ReflectionUtils.extractServiceName(clazz) + return ProxySupport.createProxy(clazz) { invocation -> + val request = invocation.captureInvocation(serviceName, null).toRequest() + @Suppress("UNCHECKED_CAST") val continuation = invocation.arguments.last() as Continuation + + // Start a coroutine that calls the client and resumes the continuation + val suspendBlock: suspend () -> Any? = { client.callAsync(request).await().response() } + suspendBlock.startCoroutine(continuation) + COROUTINE_SUSPENDED + } +} + +/** + * Create a proxy for a virtual object that uses the ingress client to make calls. + * + * @param client the ingress client to use for calls + * @param clazz the virtual object class + * @param key the virtual object key + * @return a proxy that intercepts method calls and executes them via the client + */ +@PublishedApi +internal fun virtualObject(client: Client, clazz: Class, key: String): SVC { + ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz) + require(ReflectionUtils.isKotlinClass(clazz)) { + "Using Java classes with Kotlin's API is not supported" + } + + val serviceName = ReflectionUtils.extractServiceName(clazz) + return ProxySupport.createProxy(clazz) { invocation -> + val request = invocation.captureInvocation(serviceName, key).toRequest() + @Suppress("UNCHECKED_CAST") val continuation = invocation.arguments.last() as Continuation + + // Start a coroutine that calls the client and resumes the continuation + val suspendBlock: suspend () -> Any? = { client.callAsync(request).await().response() } + suspendBlock.startCoroutine(continuation) + COROUTINE_SUSPENDED + } +} + +/** + * Create a proxy for a workflow that uses the ingress client to make calls. + * + * @param client the ingress client to use for calls + * @param clazz the workflow class + * @param key the workflow key + * @return a proxy that intercepts method calls and executes them via the client + */ +@PublishedApi +internal fun workflow(client: Client, clazz: Class, key: String): SVC { + ReflectionUtils.mustHaveWorkflowAnnotation(clazz) + require(ReflectionUtils.isKotlinClass(clazz)) { + "Using Java classes with Kotlin's API is not supported" + } + + val serviceName = ReflectionUtils.extractServiceName(clazz) + return ProxySupport.createProxy(clazz) { invocation -> + val request = invocation.captureInvocation(serviceName, key).toRequest() + @Suppress("UNCHECKED_CAST") val continuation = invocation.arguments.last() as Continuation + + // Start a coroutine that calls the client and resumes the continuation + val suspendBlock: suspend () -> Any? = { client.callAsync(request).await().response() } + suspendBlock.startCoroutine(continuation) + COROUTINE_SUSPENDED + } +} + +/** + * Builder for creating type-safe requests. + * + * This builder allows the response type to be inferred from the lambda passed to [request]. + * + * @param SVC the service/virtual object/workflow class + */ +@org.jetbrains.annotations.ApiStatus.Experimental +class KClientRequestBuilder +@PublishedApi +internal constructor( + private val client: Client, + private val clazz: Class, + private val key: String?, +) { + /** + * Create a request by invoking a method on the target. + * + * The response type is inferred from the return type of the invoked method. + * + * @param Res the response type (inferred from the lambda) + * @param block a suspend lambda that invokes a method on the target + * @return a [KClientRequest] with the correct response type + */ + @Suppress("UNCHECKED_CAST") + suspend fun request(block: suspend SVC.() -> Res): KClientRequest { + return KClientRequestImpl( + client, + RequestCaptureProxy(clazz, key).capture(block as suspend SVC.() -> Any?).toRequest(), + ) + as KClientRequest + } +} + +/** + * Kotlin-idiomatic request for invoking Restate services from an ingress client. + * + * Example usage: + * ```kotlin + * client.toService() + * .request { add(1) } + * .options { idempotencyKey = "123" } + * .call() + * ``` + * + * @param Req the request type + * @param Res the response type + */ +@org.jetbrains.annotations.ApiStatus.Experimental +interface KClientRequest : Request { + + /** + * Configure invocation options using a DSL. + * + * @param block builder block for options + * @return a new request with the configured options + */ + fun options(block: InvocationOptions.Builder.() -> Unit): KClientRequest + + /** + * Call the target handler and wait for the response. + * + * @return the response + */ + suspend fun call(): Response + + /** + * Send the request without waiting for the response. + * + * @param delay optional delay before the invocation is executed + * @return the send response with invocation handle + */ + suspend fun send(delay: Duration? = null): SendResponse +} + +/** + * Create a builder for invoking a Restate service. + * + * Example usage: + * ```kotlin + * val response = client.toService() + * .request { greet("Alice") } + * .call() + * ``` + * + * @param SVC the service class annotated with @Service + * @return a builder for creating typed requests + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun Client.toService(): KClientRequestBuilder { + ReflectionUtils.mustHaveServiceAnnotation(SVC::class.java) + require(ReflectionUtils.isKotlinClass(SVC::class.java)) { + "Using Java classes with Kotlin's API is not supported" + } + return KClientRequestBuilder(this, SVC::class.java, null) +} + +/** + * Create a builder for invoking a Restate virtual object. + * + * Example usage: + * ```kotlin + * val response = client.toVirtualObject("my-counter") + * .request { add(1) } + * .call() + * ``` + * + * @param SVC the virtual object class annotated with @VirtualObject + * @param key the key identifying the specific virtual object instance + * @return a builder for creating typed requests + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun Client.toVirtualObject(key: String): KClientRequestBuilder { + ReflectionUtils.mustHaveVirtualObjectAnnotation(SVC::class.java) + require(ReflectionUtils.isKotlinClass(SVC::class.java)) { + "Using Java classes with Kotlin's API is not supported" + } + return KClientRequestBuilder(this, SVC::class.java, key) +} + +/** + * Create a builder for invoking a Restate workflow. + * + * Example usage: + * ```kotlin + * val response = client.toWorkflow("workflow-123") + * .request { run("input") } + * .call() + * ``` + * + * @param SVC the workflow class annotated with @Workflow + * @param key the key identifying the specific workflow instance + * @return a builder for creating typed requests + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun Client.toWorkflow(key: String): KClientRequestBuilder { + ReflectionUtils.mustHaveWorkflowAnnotation(SVC::class.java) + require(ReflectionUtils.isKotlinClass(SVC::class.java)) { + "Using Java classes with Kotlin's API is not supported" + } + return KClientRequestBuilder(this, SVC::class.java, key) +} + +/** Implementation of [KClientRequest] for ingress client. */ +private class KClientRequestImpl( + private val client: Client, + private val request: Request, +) : KClientRequest, Request by request { + + override fun options(block: InvocationOptions.Builder.() -> Unit): KClientRequest { + val builder = InvocationOptions.builder() + builder.block() + return KClientRequestImpl( + client, + this.toBuilder().headers(builder.headers).idempotencyKey(builder.idempotencyKey).build(), + ) + } + + override suspend fun call(): Response { + return client.callSuspend(request) + } + + override suspend fun send(delay: Duration?): SendResponse { + return client.sendSuspend(request, delay) + } +} diff --git a/common-kotlin/bin/main/dev/restate/common/reflection/kotlin/RequestCaptureProxy.kt b/common-kotlin/bin/main/dev/restate/common/reflection/kotlin/RequestCaptureProxy.kt new file mode 100644 index 000000000..c50ad3d10 --- /dev/null +++ b/common-kotlin/bin/main/dev/restate/common/reflection/kotlin/RequestCaptureProxy.kt @@ -0,0 +1,52 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common.reflection.kotlin + +import dev.restate.common.reflections.ProxySupport +import dev.restate.common.reflections.ReflectionUtils + +/** + * Captures method invocations on a proxy to extract invocation information. + * + * This class is used to intercept calls on service proxies and extract the method metadata and + * arguments without actually executing the method. The captured information can then be used to + * build requests for remote invocation. + * + * @param SVC the service type + * @property clazz the service class + * @property serviceName the resolved service name + * @property key the virtual object/workflow key (null for stateless services) + */ +class RequestCaptureProxy(private val clazz: Class, private val key: String?) { + + private val serviceName: String = ReflectionUtils.extractServiceName(clazz) + + /** + * Capture a method invocation from the given block. + * + * @param block the suspend lambda that invokes a method on the service proxy + * @return the captured invocation information + */ + suspend fun capture(block: suspend SVC.() -> Any?): CapturedInvocation { + val proxy = + ProxySupport.createProxy(clazz) { invocation -> + throw invocation.captureInvocation(serviceName, key) + } + + try { + proxy.block() + } catch (e: CapturedInvocation) { + return e + } + + error( + "Method invocation was not captured. Make sure to call ONLY a method of the service proxy." + ) + } +} diff --git a/common-kotlin/bin/main/dev/restate/common/reflection/kotlin/reflections.kt b/common-kotlin/bin/main/dev/restate/common/reflection/kotlin/reflections.kt new file mode 100644 index 000000000..0b30577d6 --- /dev/null +++ b/common-kotlin/bin/main/dev/restate/common/reflection/kotlin/reflections.kt @@ -0,0 +1,111 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common.reflection.kotlin + +import dev.restate.common.Request +import dev.restate.common.Target +import dev.restate.common.reflections.ProxyFactory +import dev.restate.common.reflections.ReflectionUtils +import dev.restate.sdk.annotation.Raw +import dev.restate.serde.Serde +import dev.restate.serde.TypeTag +import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory +import kotlin.reflect.KClass +import kotlin.reflect.KType +import kotlin.reflect.full.findAnnotation +import kotlin.reflect.full.valueParameters +import kotlin.reflect.jvm.kotlinFunction +import kotlin.reflect.typeOf + +/** + * Captured information from a method invocation on a proxy. + * + * @property target the target service/handler + * @property inputTypeTag type tag for serializing the input + * @property outputTypeTag type tag for deserializing the output + * @property input the input value (may be null for no-arg methods) + */ +data class CapturedInvocation( + val target: Target, + val inputTypeTag: TypeTag<*>, + val outputTypeTag: TypeTag<*>, + val input: Any?, +) : RuntimeException("CapturedInvocation message should not be used", null, false, false) { + @Suppress("UNCHECKED_CAST") + fun toRequest(): Request<*, *> { + return Request.of(target, inputTypeTag as TypeTag, outputTypeTag as TypeTag, input) + } +} + +fun ProxyFactory.MethodInvocation.captureInvocation( + serviceName: String, + key: String?, +): CapturedInvocation { + val handlerInfo = ReflectionUtils.mustHaveHandlerAnnotation(method) + val handlerName = handlerInfo.name + val kFunction = method.kotlinFunction + require(kFunction != null && kFunction.isSuspend) { + "Method '${method.name}' is not a suspend function, this is not supported." + } + + val parameters = kFunction.valueParameters + val inputTypeTag = + if (parameters.isEmpty()) { + resolveKotlinTypeTag(typeOf(), null) + } else { + parameters[0].let { inputParam -> + resolveKotlinTypeTag( + inputParam.type, + inputParam.findAnnotation(), + ) + } + } + + val outputTypeTag = + resolveKotlinTypeTag( + kFunction.returnType, + kFunction.findAnnotation(), + ) + + val target = + if (key != null) { + Target.virtualObject(serviceName, key, handlerName) + } else { + Target.service(serviceName, handlerName) + } + + // For suspend functions, arguments are: [input?, continuation] + // Extract the input (first argument, excluding continuation) + val input = + if (this.arguments.size > 1) { + this.arguments[0] + } else { + null + } + + return CapturedInvocation(target, inputTypeTag, outputTypeTag, input) +} + +private fun resolveKotlinTypeTag(kType: KType, rawAnnotation: Raw?): TypeTag<*> { + if (kType.classifier == Unit::class) { + return KotlinSerializationSerdeFactory.UNIT + } + + if (rawAnnotation != null && rawAnnotation.contentType != "application/octet-stream") { + return Serde.withContentType(rawAnnotation.contentType, Serde.RAW) + } else if (rawAnnotation != null) { + return Serde.RAW + } + + @Suppress("UNCHECKED_CAST") + return KotlinSerializationSerdeFactory.KtTypeTag( + kType.classifier as KClass<*>, + kType, + ) +} diff --git a/examples/bin/main/log4j2.properties b/examples/bin/main/log4j2.properties new file mode 100644 index 000000000..871f44bc5 --- /dev/null +++ b/examples/bin/main/log4j2.properties @@ -0,0 +1,26 @@ +# Set to debug or trace if log4j initialization is failing +status = warn + +# Console appender configuration +appender.console.type = Console +appender.console.name = consoleLogger +appender.console.layout.type = PatternLayout +appender.console.layout.pattern = %d{yyyy-MM-dd HH:mm:ss} %-5p %notEmpty{[%X{restateInvocationTarget}]}%notEmpty{[%X{restateInvocationId}]} %t %c - %m%n + +# Filter out logging during replay +appender.console.filter.replay.type = ContextMapFilter +appender.console.filter.replay.onMatch = DENY +appender.console.filter.replay.onMismatch = NEUTRAL +appender.console.filter.replay.0.type = KeyValuePair +appender.console.filter.replay.0.key = restateInvocationStatus +appender.console.filter.replay.0.value = REPLAYING + +# Restate logs to info level +logger.app.name = dev.restate +logger.app.level = info +logger.app.additivity = false +logger.app.appenderRef.console.ref = consoleLogger + +# Root logger +rootLogger.level = warn +rootLogger.appenderRef.stdout.ref = consoleLogger \ No newline at end of file diff --git a/examples/bin/main/my/restate/sdk/examples/CounterKt.kt b/examples/bin/main/my/restate/sdk/examples/CounterKt.kt new file mode 100644 index 000000000..6a5eeb67c --- /dev/null +++ b/examples/bin/main/my/restate/sdk/examples/CounterKt.kt @@ -0,0 +1,62 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package my.restate.sdk.examples + +import dev.restate.sdk.annotation.Handler +import dev.restate.sdk.annotation.Shared +import dev.restate.sdk.annotation.VirtualObject +import dev.restate.sdk.http.vertx.RestateHttpServer +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.kotlin.endpoint.* +import kotlinx.serialization.Serializable +import org.apache.logging.log4j.LogManager +import org.apache.logging.log4j.Logger + +@VirtualObject +class CounterKt { + + companion object { + private val TOTAL = stateKey("total") + private val LOG: Logger = LogManager.getLogger(CounterKt::class.java) + } + + @Serializable data class CounterUpdate(var oldValue: Long, val newValue: Long) + + @Handler + suspend fun reset() { + state().clear(TOTAL) + } + + @Handler + suspend fun add(value: Long) { + val currentValue = state().get(TOTAL) ?: 0L + val newValue = currentValue + value + state().set(TOTAL, newValue) + } + + @Handler + @Shared + suspend fun get(): Long? { + return state().get(TOTAL) + } + + @Handler + suspend fun getAndAdd(value: Long): CounterUpdate { + LOG.info("Invoked get and add with $value") + val currentValue = state().get(TOTAL) ?: 0L + val newValue = currentValue + value + state().set(TOTAL, newValue) + return CounterUpdate(currentValue, newValue) + } +} + +fun main() { + val endpoint = endpoint { bind(CounterKt()) } + RestateHttpServer.listen(endpoint) +} diff --git a/sdk-api-gen/bin/main/META-INF/services/javax.annotation.processing.Processor b/sdk-api-gen/bin/main/META-INF/services/javax.annotation.processing.Processor new file mode 100644 index 000000000..477bc44fb --- /dev/null +++ b/sdk-api-gen/bin/main/META-INF/services/javax.annotation.processing.Processor @@ -0,0 +1 @@ +dev.restate.sdk.gen.ServiceProcessor \ No newline at end of file diff --git a/sdk-api-gen/bin/main/templates/Client.hbs b/sdk-api-gen/bin/main/templates/Client.hbs new file mode 100644 index 000000000..81325f4dc --- /dev/null +++ b/sdk-api-gen/bin/main/templates/Client.hbs @@ -0,0 +1,359 @@ +{{#if originalClassPkg}}package {{originalClassPkg}};{{/if}} + +import dev.restate.sdk.CallDurableFuture; +import dev.restate.sdk.Context; +import dev.restate.sdk.common.StateKey; +import dev.restate.serde.Serde; +import dev.restate.common.Target; +import java.util.Optional; +import java.time.Duration; +import java.util.function.Consumer; + +/** + * Clients for {@link {{originalClassFqcn}} } + * + * @see {{originalClassFqcn}} + */ +public class {{generatedClassSimpleName}} { + + /** + * Create context client for {@link {{originalClassFqcn}} } + * + * @see {{originalClassFqcn}} + */ + public static ContextClient fromContext(Context ctx{{#isKeyed}}, String key{{/isKeyed}}) { + return new ContextClient(ctx{{#isKeyed}}, key{{/isKeyed}}); + } + + /** Create ingress client for {@link {{originalClassFqcn}} } **/ + public static IngressClient fromClient(dev.restate.client.Client client{{#isKeyed}}, String key{{/isKeyed}}) { + return new IngressClient(client{{#isKeyed}}, key{{/isKeyed}}); + } + + /** Create ingress client for {@link {{originalClassFqcn}} } **/ + public static IngressClient connect(String baseUri{{#isKeyed}}, String key{{/isKeyed}}) { + return new IngressClient(dev.restate.client.Client.connect(baseUri, {{metadataClass}}.SERDE_FACTORY){{#isKeyed}}, key{{/isKeyed}}); + } + + /** Create ingress client for {@link {{originalClassFqcn}} } **/ + public static IngressClient connect(String baseUri, dev.restate.client.RequestOptions requestOptions{{#isKeyed}}, String key{{/isKeyed}}) { + return new IngressClient(dev.restate.client.Client.connect(baseUri, {{metadataClass}}.SERDE_FACTORY, requestOptions){{#isKeyed}}, key{{/isKeyed}}); + } + + /** Context client for {@link {{originalClassFqcn}} } **/ + public static class ContextClient { + + private final Context ctx; + {{#isKeyed}}private final String key;{{/isKeyed}} + + public ContextClient(Context ctx{{#isKeyed}}, String key{{/isKeyed}}) { + this.ctx = ctx; + {{#isKeyed}}this.key = key;{{/isKeyed}} + } + + {{#handlers}} + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public CallDurableFuture<{{{boxedOutputFqcn}}}> {{handlersClassMethodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + return this.ctx.call( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}) + ); + } + + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public CallDurableFuture<{{{boxedOutputFqcn}}}> {{handlersClassMethodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Consumer> requestBuilderApplier) { + var reqBuilder = {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}); + if (requestBuilderApplier != null) { + requestBuilderApplier.accept(reqBuilder); + } + return this.ctx.call(reqBuilder); + } + {{/handlers}} + + public Send send() { + return new Send(); + } + + public class Send { + + {{#handlers}} + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public dev.restate.sdk.InvocationHandle<{{{boxedOutputFqcn}}}> {{handlersClassMethodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + return ContextClient.this.ctx.send( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}ContextClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}) + ); + } + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public dev.restate.sdk.InvocationHandle<{{{boxedOutputFqcn}}}> {{handlersClassMethodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Consumer> requestBuilderApplier) { + var reqBuilder = {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}ContextClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}); + if (requestBuilderApplier != null) { + requestBuilderApplier.accept(reqBuilder); + } + return ContextClient.this.ctx.send(reqBuilder); + } + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public dev.restate.sdk.InvocationHandle<{{{boxedOutputFqcn}}}> {{handlersClassMethodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Duration delay) { + return ContextClient.this.ctx.send( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}ContextClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}), delay + ); + } + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public dev.restate.sdk.InvocationHandle<{{{boxedOutputFqcn}}}> {{handlersClassMethodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Duration delay, Consumer> requestBuilderApplier) { + var reqBuilder = {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}ContextClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}); + if (requestBuilderApplier != null) { + requestBuilderApplier.accept(reqBuilder); + } + return ContextClient.this.ctx.send(reqBuilder, delay); + } + {{/handlers}} + } + } + + /** Ingress client for {@link {{originalClassFqcn}} } **/ + public static class IngressClient { + + private final dev.restate.client.Client client; + {{#isKeyed}}private final String key;{{/isKeyed}} + + public IngressClient(dev.restate.client.Client client{{#isKeyed}}, String key{{/isKeyed}}) { + this.client = client; + {{#isKeyed}}this.key = key;{{/isKeyed}} + } + + {{#handlers}}{{#if isWorkflow}} + public dev.restate.client.Client.WorkflowHandle<{{{boxedOutputFqcn}}}> workflowHandle() { + return IngressClient.this.client.workflowHandle( + {{metadataClass}}.SERVICE_NAME, + this.key, + {{outputSerdeRef}}); + } + + /** + * Submit the workflow. + * + * @see {{originalClassFqcn}}#{{name}} + **/ + public dev.restate.client.SendResponse<{{{boxedOutputFqcn}}}> submit({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + return IngressClient.this.client.send( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}) + ); + } + + /** + * Submit the workflow. + * + * @see {{originalClassFqcn}}#{{name}} + **/ + public dev.restate.client.SendResponse<{{{boxedOutputFqcn}}}> submit({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Duration delay) { + return IngressClient.this.client.send( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}), delay + ); + } + + /** + * Submit the workflow. + * + * @see {{originalClassFqcn}}#{{name}} + **/ + public dev.restate.client.SendResponse<{{{boxedOutputFqcn}}}> submit({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Consumer> requestBuilderApplier) { + var reqBuilder = {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}); + if (requestBuilderApplier != null) { + requestBuilderApplier.accept(reqBuilder); + } + return IngressClient.this.client.send(reqBuilder); + } + + /** + * Submit the workflow. + * + * @see {{originalClassFqcn}}#{{name}} + **/ + public dev.restate.client.SendResponse<{{{boxedOutputFqcn}}}> submit({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Duration delay, Consumer> requestBuilderApplier) { + var reqBuilder = {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}); + if (requestBuilderApplier != null) { + requestBuilderApplier.accept(reqBuilder); + } + return IngressClient.this.client.send(reqBuilder, delay); + } + + /** + * Submit the workflow. + * + * @see {{originalClassFqcn}}#{{name}} + **/ + public java.util.concurrent.CompletableFuture> submitAsync({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + return IngressClient.this.client.sendAsync( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}) + ); + } + + /** + * Submit the workflow. + * + * @see {{originalClassFqcn}}#{{name}} + **/ + public java.util.concurrent.CompletableFuture> submitAsync({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}} Duration delay) { + return IngressClient.this.client.sendAsync( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}), delay + ); + } + + /** + * Submit the workflow. + * + * @see {{originalClassFqcn}}#{{name}} + **/ + public java.util.concurrent.CompletableFuture> submitAsync({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Consumer> requestBuilderApplier) { + var reqBuilder = {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}); + if (requestBuilderApplier != null) { + requestBuilderApplier.accept(reqBuilder); + } + return IngressClient.this.client.sendAsync(reqBuilder); + } + + /** + * Submit the workflow. + * + * @see {{originalClassFqcn}}#{{name}} + **/ + public java.util.concurrent.CompletableFuture> submitAsync({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Duration delay, Consumer> requestBuilderApplier) { + var reqBuilder = {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}); + if (requestBuilderApplier != null) { + requestBuilderApplier.accept(reqBuilder); + } + return IngressClient.this.client.sendAsync(reqBuilder, delay); + } + {{else}} + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public {{#if outputEmpty}}void{{else}}{{{outputFqcn}}}{{/if}} {{handlersClassMethodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + {{^outputEmpty}}return {{/outputEmpty}}this.client.call( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}) + ).response(); + } + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public {{#if outputEmpty}}void{{else}}{{{outputFqcn}}}{{/if}} {{handlersClassMethodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Consumer> requestBuilderApplier) { + var reqBuilder = {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}); + if (requestBuilderApplier != null) { + requestBuilderApplier.accept(reqBuilder); + } + {{^outputEmpty}}return {{/outputEmpty}}this.client.call(reqBuilder.build()).response(); + } + + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public {{#if outputEmpty}}java.util.concurrent.CompletableFuture{{else}}java.util.concurrent.CompletableFuture<{{{boxedOutputFqcn}}}>{{/if}} {{handlersClassMethodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + return this.client.callAsync( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}) + ).thenApply(dev.restate.client.Response::response); + } + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public {{#if outputEmpty}}java.util.concurrent.CompletableFuture{{else}}java.util.concurrent.CompletableFuture<{{{boxedOutputFqcn}}}>{{/if}} {{handlersClassMethodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Consumer> requestBuilderApplier) { + var reqBuilder = {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}); + if (requestBuilderApplier != null) { + requestBuilderApplier.accept(reqBuilder); + } + return this.client.callAsync(reqBuilder.build()).thenApply(dev.restate.client.Response::response); + } + {{/if}}{{/handlers}} + + public Send send() { + return new Send(); + } + + public class Send { + + {{#handlers}}{{^isWorkflow}} + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public dev.restate.client.SendResponse<{{{boxedOutputFqcn}}}> {{handlersClassMethodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + return IngressClient.this.client.send( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}) + ); + } + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public dev.restate.client.SendResponse<{{{boxedOutputFqcn}}}> {{handlersClassMethodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Consumer> requestBuilderApplier) { + var reqBuilder = {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}); + if (requestBuilderApplier != null) { + requestBuilderApplier.accept(reqBuilder); + } + return IngressClient.this.client.send(reqBuilder); + } + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public dev.restate.client.SendResponse<{{{boxedOutputFqcn}}}> {{handlersClassMethodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Duration delay) { + return IngressClient.this.client.send( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}), delay + ); + } + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public dev.restate.client.SendResponse<{{{boxedOutputFqcn}}}> {{handlersClassMethodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Duration delay, Consumer> requestBuilderApplier) { + var reqBuilder = {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}); + if (requestBuilderApplier != null) { + requestBuilderApplier.accept(reqBuilder); + } + return IngressClient.this.client.send(reqBuilder, delay); + } + + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public java.util.concurrent.CompletableFuture> {{handlersClassMethodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + return IngressClient.this.client.sendAsync( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}) + ); + } + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public java.util.concurrent.CompletableFuture> {{handlersClassMethodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Consumer> requestBuilderApplier) { + var reqBuilder = {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}); + if (requestBuilderApplier != null) { + requestBuilderApplier.accept(reqBuilder); + } + return IngressClient.this.client.sendAsync(reqBuilder); + } + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public java.util.concurrent.CompletableFuture> {{handlersClassMethodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Duration delay) { + return IngressClient.this.client.sendAsync( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}), delay + ); + } + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public java.util.concurrent.CompletableFuture> {{handlersClassMethodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Duration delay, Consumer> requestBuilderApplier) { + var reqBuilder = {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}); + if (requestBuilderApplier != null) { + requestBuilderApplier.accept(reqBuilder); + } + return IngressClient.this.client.sendAsync(reqBuilder, delay); + }{{/isWorkflow}}{{/handlers}} + } + } +} \ No newline at end of file diff --git a/sdk-api-gen/bin/main/templates/Handlers.hbs b/sdk-api-gen/bin/main/templates/Handlers.hbs new file mode 100644 index 000000000..379ee545d --- /dev/null +++ b/sdk-api-gen/bin/main/templates/Handlers.hbs @@ -0,0 +1,41 @@ +{{#if originalClassPkg}}package {{originalClassPkg}};{{/if}} + +/** Handler request factories for {@link {{originalClassFqcn}} } **/ +@SuppressWarnings("unchecked") +public final class {{generatedClassSimpleName}} { + + private {{generatedClassSimpleName}}() {} + + {{#handlers}} + /** + * @see {{originalClassFqcn}}#{{name}} + **/ + public static {{#if isWorkflow}}dev.restate.common.WorkflowRequestBuilder{{else}}dev.restate.common.RequestBuilder{{/if}}<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}> {{handlersClassMethodName}}({{#if ../isKeyed}}String key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + return {{#if isWorkflow}}(dev.restate.common.WorkflowRequestBuilder){{/if}} dev.restate.common.Request.of( + {{{targetExpr this "key"}}}, + {{inputSerdeRef}}, + {{outputSerdeRef}}, + {{#if inputEmpty}}null{{else}}req{{/if}}); + } + + {{/handlers}} + + /** Metadata for {@link {{originalClassFqcn}} } **/ + public final static class Metadata { + + public static final String SERVICE_NAME = "{{restateServiceName}}"; + public static final dev.restate.serde.SerdeFactory SERDE_FACTORY = {{serdeFactoryDecl}}; + + private Metadata() {} + + public final static class Serde { + {{#handlers}} + public static final dev.restate.serde.Serde<{{{boxedInputFqcn}}}> {{inputSerdeFieldName}} = {{{inputSerdeDecl}}}; + public static final dev.restate.serde.Serde<{{{boxedOutputFqcn}}}> {{outputSerdeFieldName}} = {{{outputSerdeDecl}}}; + {{/handlers}} + + private Serde() {} + } + + } +} \ No newline at end of file diff --git a/sdk-api-gen/bin/main/templates/ServiceDefinitionFactory.hbs b/sdk-api-gen/bin/main/templates/ServiceDefinitionFactory.hbs new file mode 100644 index 000000000..e449a52a3 --- /dev/null +++ b/sdk-api-gen/bin/main/templates/ServiceDefinitionFactory.hbs @@ -0,0 +1,37 @@ +{{#if originalClassPkg}}package {{originalClassPkg}};{{/if}} + +/** Service definition factory to bind {@link {{originalClassFqcn}} } **/ +public class {{generatedClassSimpleName}} implements dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory<{{originalClassFqcn}}> { + + @java.lang.Override + public dev.restate.sdk.endpoint.definition.ServiceDefinition create({{originalClassFqcn}} bindableService, dev.restate.sdk.endpoint.definition.HandlerRunner.Options overrideHandlerOptions) { + dev.restate.sdk.HandlerRunner.Options handlerRunnerOptions = dev.restate.sdk.HandlerRunner.Options.DEFAULT; + if (overrideHandlerOptions != null) { + if (overrideHandlerOptions instanceof dev.restate.sdk.HandlerRunner.Options) { + handlerRunnerOptions = (dev.restate.sdk.HandlerRunner.Options)overrideHandlerOptions; + } else { + throw new IllegalArgumentException("The provided options class MUST be instance of dev.restate.sdk.HandlerRunner.Options, but was " + overrideHandlerOptions.getClass()); + } + } + return dev.restate.sdk.endpoint.definition.ServiceDefinition.of( + {{metadataClass}}.SERVICE_NAME, + {{#if isObject}}dev.restate.sdk.endpoint.definition.ServiceType.VIRTUAL_OBJECT{{else if isWorkflow}}dev.restate.sdk.endpoint.definition.ServiceType.WORKFLOW{{else}}dev.restate.sdk.endpoint.definition.ServiceType.SERVICE{{/if}}, + java.util.List.of( + {{#handlers}} + dev.restate.sdk.endpoint.definition.HandlerDefinition.of( + "{{restateName}}", + {{#if isExclusive}}dev.restate.sdk.endpoint.definition.HandlerType.EXCLUSIVE{{else if isWorkflow}}dev.restate.sdk.endpoint.definition.HandlerType.WORKFLOW{{else}}dev.restate.sdk.endpoint.definition.HandlerType.SHARED{{/if}}, + {{inputSerdeRef}}, + {{outputSerdeRef}}, + dev.restate.sdk.HandlerRunner.of(bindableService::{{name}}, {{serdeFactoryRef}}, handlerRunnerOptions) + ){{#if inputAcceptContentType}}.withAcceptContentType("{{inputAcceptContentType}}"){{/if}}{{#if documentation}}.withDocumentation("{{escapeJava documentation}}"){{/if}}{{#unless @last}},{{/unless}} + {{/handlers}} + ) + ){{#if documentation}}.withDocumentation("{{escapeJava documentation}}"){{/if}}; + } + + @java.lang.Override + public boolean supports(Object serviceObject) { + return serviceObject instanceof {{originalClassFqcn}}; + } +} \ No newline at end of file diff --git a/sdk-api-kotlin-gen/bin/main/META-INF/services/com.google.devtools.ksp.processing.SymbolProcessorProvider b/sdk-api-kotlin-gen/bin/main/META-INF/services/com.google.devtools.ksp.processing.SymbolProcessorProvider new file mode 100644 index 000000000..034b8dc5a --- /dev/null +++ b/sdk-api-kotlin-gen/bin/main/META-INF/services/com.google.devtools.ksp.processing.SymbolProcessorProvider @@ -0,0 +1 @@ +dev.restate.sdk.kotlin.gen.ServiceProcessorProvider \ No newline at end of file diff --git a/sdk-api-kotlin-gen/bin/main/dev/restate/sdk/kotlin/gen/KElementConverter.kt b/sdk-api-kotlin-gen/bin/main/dev/restate/sdk/kotlin/gen/KElementConverter.kt new file mode 100644 index 000000000..06a1bff0d --- /dev/null +++ b/sdk-api-kotlin-gen/bin/main/dev/restate/sdk/kotlin/gen/KElementConverter.kt @@ -0,0 +1,372 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin.gen + +import com.google.devtools.ksp.* +import com.google.devtools.ksp.processing.KSBuiltIns +import com.google.devtools.ksp.processing.KSPLogger +import com.google.devtools.ksp.symbol.* +import com.google.devtools.ksp.visitor.KSDefaultVisitor +import dev.restate.sdk.annotation.* +import dev.restate.sdk.endpoint.definition.ServiceType +import dev.restate.sdk.gen.model.Handler +import dev.restate.sdk.gen.model.HandlerType +import dev.restate.sdk.gen.model.PayloadType +import dev.restate.sdk.gen.model.Service +import dev.restate.sdk.gen.utils.AnnotationUtils.getAnnotationDefaultValue +import dev.restate.sdk.kotlin.* +import java.util.regex.Pattern +import kotlin.reflect.KClass + +class KElementConverter( + private val logger: KSPLogger, + private val builtIns: KSBuiltIns, + private val byteArrayType: KSType, +) : KSDefaultVisitor() { + companion object { + private val SUPPORTED_CLASS_KIND: Set = setOf(ClassKind.CLASS, ClassKind.INTERFACE) + private val EMPTY_PAYLOAD: PayloadType = + PayloadType( + true, + "", + "Unit", + "dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory.UNIT", + ) + private const val RAW_SERDE: String = "dev.restate.serde.Serde.RAW" + } + + override fun defaultHandler(node: KSNode, data: Service.Builder) {} + + override fun visitAnnotated(annotated: KSAnnotated, data: Service.Builder) { + if (annotated !is KSClassDeclaration) { + logger.error( + "Only classes or interfaces can be annotated with @Service or @VirtualObject or @Workflow" + ) + } + visitClassDeclaration(annotated as KSClassDeclaration, data) + } + + @OptIn(KspExperimental::class) + override fun visitClassDeclaration(classDeclaration: KSClassDeclaration, data: Service.Builder) { + // Validate class declaration + if (classDeclaration.typeParameters.isNotEmpty()) { + logger.error("The ServiceProcessor doesn't support services with generics", classDeclaration) + } + if (!SUPPORTED_CLASS_KIND.contains(classDeclaration.classKind)) { + logger.error( + "The ServiceProcessor supports only class declarations of kind $SUPPORTED_CLASS_KIND", + classDeclaration, + ) + } + if (classDeclaration.getVisibility() == Visibility.PRIVATE) { + logger.error("The annotated class is private", classDeclaration) + } + + // Infer names + val targetPkg = classDeclaration.packageName.asString() + val targetFqcn = classDeclaration.qualifiedName!!.asString() + // Use simple class name, flattening subclasses names + val inCodeServiceName = + targetFqcn.substring(targetPkg.length).replace(Pattern.quote(".").toRegex(), "") + + classDeclaration.getAnnotationsByType(Name::class).firstOrNull().let { + if (it != null) { + data.withRestateName(it.value) + } + } + + data + .withTargetClassPkg(targetPkg) + .withTargetClassFqcn(targetFqcn) + .withGeneratedClassesNamePrefix(inCodeServiceName) + + // Compute handlersMetadata + classDeclaration + .getAllFunctions() + .filter { + it.isAnnotationPresent(dev.restate.sdk.annotation.Handler::class) || + it.isAnnotationPresent(dev.restate.sdk.annotation.Workflow::class) || + it.isAnnotationPresent(dev.restate.sdk.annotation.Exclusive::class) || + it.isAnnotationPresent(dev.restate.sdk.annotation.Shared::class) + } + .forEach { visitFunctionDeclaration(it, data) } + + if (data.handlers.isEmpty()) { + logger.warn( + "The class declaration $targetFqcn has no methods annotated as handlers", + classDeclaration, + ) + } + + var serdeFactoryDecl = "dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory()" + val customSerdeFactory: CustomSerdeFactory? = + classDeclaration.getAnnotationsByType(CustomSerdeFactory::class).firstOrNull() + if (customSerdeFactory != null) { + serdeFactoryDecl = parseAnnotationClassParameter { customSerdeFactory.value } + "()" + } + data.withSerdeFactoryDecl(serdeFactoryDecl) + } + + @OptIn(KspExperimental::class) + override fun visitFunctionDeclaration(function: KSFunctionDeclaration, data: Service.Builder) { + // Validate function declaration + if (function.typeParameters.isNotEmpty()) { + logger.error("The ServiceProcessor doesn't support methods with generics", function) + } + if (function.functionKind != FunctionKind.MEMBER) { + logger.error("Only member function declarations are supported as Restate handlers") + } + if (function.getVisibility() == Visibility.PRIVATE) { + logger.error("The annotated function is private", function) + } + + val isAnnotatedWithShared = + function.isAnnotationPresent(dev.restate.sdk.annotation.Shared::class) + val isAnnotatedWithExclusive = + function.isAnnotationPresent(dev.restate.sdk.annotation.Exclusive::class) + val isAnnotatedWithWorkflow = + function.isAnnotationPresent(dev.restate.sdk.annotation.Workflow::class) + + // Check there's no more than one annotation + val hasAnyAnnotation = + isAnnotatedWithExclusive || isAnnotatedWithShared || isAnnotatedWithWorkflow + val hasExactlyOneAnnotation = + isAnnotatedWithShared xor isAnnotatedWithExclusive xor isAnnotatedWithWorkflow + if (!(!hasAnyAnnotation || hasExactlyOneAnnotation)) { + logger.error( + "You can have only one annotation between @Shared and @Exclusive and @Workflow to a method", + function, + ) + } + + val handlerBuilder = Handler.builder() + + // Set handler type + val handlerType = + if (isAnnotatedWithShared) HandlerType.SHARED + else if (isAnnotatedWithExclusive) HandlerType.EXCLUSIVE + else if (isAnnotatedWithWorkflow) HandlerType.WORKFLOW + else defaultHandlerType(data.serviceType) + handlerBuilder.withHandlerType(handlerType) + + validateMethodSignature(data.serviceType, handlerType, function) + + try { + data.withHandler( + handlerBuilder + .withName(function.simpleName.asString()) + .withRestateName(function.getAnnotationsByType(Name::class).firstOrNull()?.value) + .withHandlerType(handlerType) + .withInputAccept(inputAcceptFromParameterList(function.parameters)) + .withInputType(inputPayloadFromParameterList(function.parameters)) + .withOutputType(outputPayloadFromExecutableElement(function)) + .validateAndBuild() + ) + } catch (e: Exception) { + logger.error("Error when building handler: $e", function) + } + } + + @OptIn(KspExperimental::class) + private fun inputAcceptFromParameterList(paramList: List): String? { + if (paramList.size <= 1) { + return null + } + + return paramList[1].getAnnotationsByType(Accept::class).firstOrNull()?.value + } + + @OptIn(KspExperimental::class) + private fun inputPayloadFromParameterList(paramList: List): PayloadType { + if (paramList.size <= 1) { + return EMPTY_PAYLOAD + } + + val parameterElement: KSValueParameter = paramList[1] + return payloadFromTypeMirrorAndAnnotations( + parameterElement.type.resolve(), + parameterElement.getAnnotationsByType(Json::class).firstOrNull(), + parameterElement.getAnnotationsByType(Raw::class).firstOrNull(), + parameterElement, + ) + } + + @OptIn(KspExperimental::class) + private fun outputPayloadFromExecutableElement(fn: KSFunctionDeclaration): PayloadType { + return payloadFromTypeMirrorAndAnnotations( + fn.returnType?.resolve() ?: builtIns.unitType, + fn.getAnnotationsByType(Json::class).firstOrNull(), + fn.getAnnotationsByType(Raw::class).firstOrNull(), + fn, + ) + } + + private fun payloadFromTypeMirrorAndAnnotations( + ty: KSType, + jsonAnnotation: Json?, + rawAnnotation: Raw?, + relatedNode: KSNode, + ): PayloadType { + if (ty == builtIns.unitType) { + if (rawAnnotation != null || jsonAnnotation != null) { + logger.error("Unexpected annotation for void type.", relatedNode) + } + return EMPTY_PAYLOAD + } + // Some validation + if (rawAnnotation != null && jsonAnnotation != null) { + logger.error("A parameter cannot be annotated both with @Raw and @Json.", relatedNode) + } + if (rawAnnotation != null && ty != byteArrayType) { + logger.error("A parameter annotated with @Raw MUST be of type byte[], was $ty", relatedNode) + } + if (ty.isFunctionType || ty.isSuspendFunctionType) { + logger.error("Cannot use fun as parameter or return type", relatedNode) + } + + val qualifiedTypeName = qualifiedTypeName(ty) + var serdeDecl: String = + if (rawAnnotation != null) RAW_SERDE else jsonSerdeDecl(ty, qualifiedTypeName) + if ( + rawAnnotation != null && + rawAnnotation.contentType != getAnnotationDefaultValue(Raw::class.java, "contentType") + ) { + serdeDecl = contentTypeDecoratedSerdeDecl(serdeDecl, rawAnnotation.contentType) + } + if ( + jsonAnnotation != null && + jsonAnnotation.contentType != getAnnotationDefaultValue(Json::class.java, "contentType") + ) { + serdeDecl = contentTypeDecoratedSerdeDecl(serdeDecl, jsonAnnotation.contentType) + } + + return PayloadType(false, qualifiedTypeName, boxedType(ty, qualifiedTypeName), serdeDecl) + } + + private fun contentTypeDecoratedSerdeDecl(serdeDecl: String, contentType: String): String { + return ("dev.restate.serde.Serde.withContentType(\"" + contentType + "\", " + serdeDecl + ")") + } + + private fun defaultHandlerType(serviceType: ServiceType): HandlerType { + when (serviceType) { + ServiceType.SERVICE -> return HandlerType.STATELESS + ServiceType.VIRTUAL_OBJECT -> return HandlerType.EXCLUSIVE + ServiceType.WORKFLOW -> return HandlerType.SHARED + } + } + + private fun validateMethodSignature( + serviceType: ServiceType, + handlerType: HandlerType, + function: KSFunctionDeclaration, + ) { + if (function.parameters.isEmpty()) { + logger.error( + "The annotated method has no parameters. There must be at least the context parameter as first parameter", + function, + ) + } + when (handlerType) { + HandlerType.SHARED -> + if (serviceType == ServiceType.VIRTUAL_OBJECT) { + validateFirstParameterType(SharedObjectContext::class, function) + } else if (serviceType == ServiceType.WORKFLOW) { + validateFirstParameterType(SharedWorkflowContext::class, function) + } else { + logger.error( + "The annotation @Shared is not supported by the service type $serviceType", + function, + ) + } + HandlerType.EXCLUSIVE -> + if (serviceType == ServiceType.VIRTUAL_OBJECT) { + validateFirstParameterType(ObjectContext::class, function) + } else { + logger.error( + "The annotation @Exclusive is not supported by the service type $serviceType", + function, + ) + } + HandlerType.STATELESS -> validateFirstParameterType(Context::class, function) + HandlerType.WORKFLOW -> + if (serviceType == ServiceType.WORKFLOW) { + validateFirstParameterType(WorkflowContext::class, function) + } else { + logger.error( + "The annotation @Workflow is not supported by the service type $serviceType", + function, + ) + } + } + } + + private fun validateFirstParameterType(clazz: KClass<*>, function: KSFunctionDeclaration) { + if ( + function.parameters[0].type.resolve().declaration.qualifiedName!!.asString() != + clazz.qualifiedName + ) { + logger.error( + "The method ${function.qualifiedName?.asString()} signature must have ${clazz.qualifiedName} as first parameter, was ${function.parameters[0].type.resolve().declaration.qualifiedName!!.asString()}", + function, + ) + } + } + + private fun jsonSerdeDecl(ty: KSType, qualifiedTypeName: String): String { + return when (ty) { + builtIns.unitType -> EMPTY_PAYLOAD.serdeDecl + else -> + "SERDE_FACTORY.create(dev.restate.serde.kotlinx.typeTag<${boxedType(ty, qualifiedTypeName)}>())" + } + } + + private fun boxedType(ty: KSType, qualifiedTypeName: String): String { + return when (ty) { + builtIns.unitType -> "Unit" + else -> qualifiedTypeName + } + } + + private fun qualifiedTypeName(ksType: KSType): String { + var typeName = ksType.declaration.qualifiedName?.asString() ?: ksType.toString() + + if (ksType.arguments.isNotEmpty()) { + typeName = + "$typeName<${ + ksType.arguments.joinToString(separator = ", ") { + if (it.variance == Variance.STAR) { + it.variance.label + } else { + "${it.variance.label} ${qualifiedTypeName(it.type!!.resolve())}" + } + } + }>" + } + + if (ksType.isMarkedNullable) { + typeName = "$typeName?" + } + + return typeName + } + + @OptIn(KspExperimental::class) + private fun parseAnnotationClassParameter(block: () -> KClass<*>): String? { + return try { // KSTypeNotPresentException will be thrown + block.invoke().qualifiedName + } catch (e: KSTypeNotPresentException) { + var res: String? = null + val declaration = e.ksType.declaration + if (declaration is KSClassDeclaration) { + declaration.qualifiedName?.asString()?.let { res = it } + } + res + } + } +} diff --git a/sdk-api-kotlin-gen/bin/main/dev/restate/sdk/kotlin/gen/MetaRestateAnnotation.kt b/sdk-api-kotlin-gen/bin/main/dev/restate/sdk/kotlin/gen/MetaRestateAnnotation.kt new file mode 100644 index 000000000..a7cffabef --- /dev/null +++ b/sdk-api-kotlin-gen/bin/main/dev/restate/sdk/kotlin/gen/MetaRestateAnnotation.kt @@ -0,0 +1,25 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin.gen + +import com.google.devtools.ksp.symbol.KSAnnotated +import com.google.devtools.ksp.symbol.KSName +import dev.restate.sdk.endpoint.definition.ServiceType + +internal data class MetaRestateAnnotation( + val annotationName: KSName, + val serviceType: ServiceType, +) { + fun resolveName(annotated: KSAnnotated): String? = + annotated.annotations + .find { it.annotationType.resolve().declaration.qualifiedName == annotationName } + ?.arguments + ?.firstOrNull { it -> it.name?.getShortName() == "name" } + ?.value as String? +} diff --git a/sdk-api-kotlin-gen/bin/main/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt b/sdk-api-kotlin-gen/bin/main/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt new file mode 100644 index 000000000..16947cf7d --- /dev/null +++ b/sdk-api-kotlin-gen/bin/main/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt @@ -0,0 +1,259 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin.gen + +import com.github.jknack.handlebars.io.ClassPathTemplateLoader +import com.google.devtools.ksp.KspExperimental +import com.google.devtools.ksp.containingFile +import com.google.devtools.ksp.getClassDeclarationByName +import com.google.devtools.ksp.getKotlinClassByName +import com.google.devtools.ksp.processing.* +import com.google.devtools.ksp.symbol.ClassKind +import com.google.devtools.ksp.symbol.KSAnnotated +import com.google.devtools.ksp.symbol.KSClassDeclaration +import com.google.devtools.ksp.symbol.Origin +import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory +import dev.restate.sdk.endpoint.definition.ServiceType +import dev.restate.sdk.gen.model.AnnotationProcessingOptions +import dev.restate.sdk.gen.model.Service +import dev.restate.sdk.gen.template.HandlebarsTemplateEngine +import java.io.BufferedWriter +import java.io.IOException +import java.io.Writer +import java.nio.charset.Charset + +class ServiceProcessor( + private val logger: KSPLogger, + private val codeGenerator: CodeGenerator, + private val options: AnnotationProcessingOptions, +) : SymbolProcessor { + + companion object { + private val RESERVED_METHOD_NAMES: Set = setOf("send", "submit", "workflowHandle") + } + + private val bindableServiceFactoryCodegen: HandlebarsTemplateEngine = + HandlebarsTemplateEngine( + "ServiceDefinitionFactory", + ClassPathTemplateLoader(), + mapOf( + ServiceType.SERVICE to "templates/ServiceDefinitionFactory", + ServiceType.WORKFLOW to "templates/ServiceDefinitionFactory", + ServiceType.VIRTUAL_OBJECT to "templates/ServiceDefinitionFactory", + ), + RESERVED_METHOD_NAMES, + ) + private val clientCodegen: HandlebarsTemplateEngine = + HandlebarsTemplateEngine( + "Client", + ClassPathTemplateLoader(), + mapOf( + ServiceType.SERVICE to "templates/Client", + ServiceType.WORKFLOW to "templates/Client", + ServiceType.VIRTUAL_OBJECT to "templates/Client", + ), + RESERVED_METHOD_NAMES, + ) + private val handlersCodegen: HandlebarsTemplateEngine = + HandlebarsTemplateEngine( + "Handlers", + ClassPathTemplateLoader(), + mapOf( + ServiceType.SERVICE to "templates/Handlers", + ServiceType.WORKFLOW to "templates/Handlers", + ServiceType.VIRTUAL_OBJECT to "templates/Handlers", + ), + RESERVED_METHOD_NAMES, + ) + + @OptIn(KspExperimental::class) + override fun process(resolver: Resolver): List { + val converter = + KElementConverter( + logger, + resolver.builtIns, + resolver.getKotlinClassByName(ByteArray::class.qualifiedName!!)!!.asType(listOf()), + ) + + val discovered = discoverRestateAnnotatedOrMetaAnnotatedServices(resolver) + + val services = + discovered + .map { + val serviceBuilder = Service.builder() + serviceBuilder.withServiceType(it.first.serviceType) + serviceBuilder.withRestateName(it.first.resolveName(it.second)) + + converter.visitAnnotated(it.second, serviceBuilder) + + var serviceModel: Service? = null + try { + serviceModel = serviceBuilder.validateAndBuild() + } catch (e: Exception) { + logger.error("Unable to build service: $e", it.second) + } + (it.second to serviceModel!!) + } + .toList() + + // Run code generation + for (service in services) { + try { + val fileCreator: (String) -> Writer = { name: String -> + codeGenerator + .createNewFile( + Dependencies(false, service.first.containingFile!!), + service.second.targetPkg.toString(), + name, + ) + .writer(Charset.defaultCharset()) + } + this.bindableServiceFactoryCodegen.generate(fileCreator, service.second) + this.handlersCodegen.generate(fileCreator, service.second) + if (!options.isClientGenDisabled(service.second.targetFqcn.toString())) { + this.clientCodegen.generate(fileCreator, service.second) + } + } catch (ex: Throwable) { + throw RuntimeException(ex) + } + } + + // META-INF + if (services.isNotEmpty()) { + generateMetaINF(services) + } + + return emptyList() + } + + private fun discoverRestateAnnotatedOrMetaAnnotatedServices( + resolver: Resolver + ): Set> { + val discoveredAnnotatedElements = mutableSetOf>() + + val metaAnnotationsToProcess = + mutableListOf( + MetaRestateAnnotation( + resolver + .getClassDeclarationByName()!! + .qualifiedName!!, + ServiceType.SERVICE, + ), + MetaRestateAnnotation( + resolver + .getClassDeclarationByName()!! + .qualifiedName!!, + ServiceType.VIRTUAL_OBJECT, + ), + MetaRestateAnnotation( + resolver + .getClassDeclarationByName()!! + .qualifiedName!!, + ServiceType.WORKFLOW, + ), + ) + + // Add spring annotations, if available + resolver.getClassDeclarationByName("dev.restate.sdk.springboot.RestateService")?.let { + metaAnnotationsToProcess.add(MetaRestateAnnotation(it.qualifiedName!!, ServiceType.SERVICE)) + } + resolver.getClassDeclarationByName("dev.restate.sdk.springboot.RestateVirtualObject")?.let { + metaAnnotationsToProcess.add( + MetaRestateAnnotation(it.qualifiedName!!, ServiceType.VIRTUAL_OBJECT) + ) + } + resolver.getClassDeclarationByName("dev.restate.sdk.springboot.RestateWorkflow")?.let { + metaAnnotationsToProcess.add(MetaRestateAnnotation(it.qualifiedName!!, ServiceType.WORKFLOW)) + } + + val discoveredAnnotations = mutableSetOf() + + var metaAnnotation = metaAnnotationsToProcess.removeFirstOrNull() + while (metaAnnotation != null) { + if (!discoveredAnnotations.add(metaAnnotation.annotationName.asString())) { + // We already discovered it, skip + continue + } + for (annotatedElement in + resolver.getSymbolsWithAnnotation(metaAnnotation.annotationName.asString())) { + if (annotatedElement !is KSClassDeclaration) { + continue + } + when (annotatedElement.classKind) { + ClassKind.INTERFACE, + ClassKind.CLASS -> { + if ( + annotatedElement.containingFile!!.origin != Origin.KOTLIN || + options.isClassDisabled(annotatedElement.qualifiedName!!.asString()) + ) { + // Skip if it's not kotlin + continue + } + discoveredAnnotatedElements.add(metaAnnotation to annotatedElement) + } + ClassKind.ANNOTATION_CLASS -> { + metaAnnotationsToProcess.add( + MetaRestateAnnotation(annotatedElement.qualifiedName!!, metaAnnotation.serviceType) + ) + } + else -> + logger.error( + "The ServiceProcessor supports only interfaces or classes declarations", + annotatedElement, + ) + } + } + metaAnnotation = metaAnnotationsToProcess.removeFirstOrNull() + } + + val knownAnnotations = discoveredAnnotations.toSet() + + // Check annotated elements are annotated with only one of the given annotations. + discoveredAnnotatedElements.forEach { it -> + val forbiddenAnnotations = knownAnnotations - setOf(it.first.annotationName.asString()) + val elementAnnotations = + it.second.annotations + .mapNotNull { it.annotationType.resolve().declaration.qualifiedName?.asString() } + .toSet() + if (forbiddenAnnotations.intersect(elementAnnotations).isNotEmpty()) { + logger.error("The type is annotated with more than one Restate annotation", it.second) + } + } + + return discoveredAnnotatedElements.toSet() + } + + private fun generateMetaINF(services: List>) { + val resourceFile = "META-INF/services/${ServiceDefinitionFactory::class.java.canonicalName}" + val dependencies = + Dependencies(true, *(services.map { it.first.containingFile!! }.toTypedArray())) + + val writer: BufferedWriter = + try { + codeGenerator.createNewFileByPath(dependencies, resourceFile, "").bufferedWriter() + } catch (e: FileSystemException) { + val existingFile = e.file + val currentValues = existingFile.readText() + val newWriter = e.file.bufferedWriter() + newWriter.write(currentValues) + newWriter + } + + try { + writer.use { + for (service in services) { + it.write("${service.second.fqcnGeneratedNamePrefix}ServiceDefinitionFactory") + it.newLine() + } + } + } catch (e: IOException) { + logger.error("Unable to create $resourceFile: $e") + } + } +} diff --git a/sdk-api-kotlin-gen/bin/main/dev/restate/sdk/kotlin/gen/ServiceProcessorProvider.kt b/sdk-api-kotlin-gen/bin/main/dev/restate/sdk/kotlin/gen/ServiceProcessorProvider.kt new file mode 100644 index 000000000..efa7fe7e0 --- /dev/null +++ b/sdk-api-kotlin-gen/bin/main/dev/restate/sdk/kotlin/gen/ServiceProcessorProvider.kt @@ -0,0 +1,25 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin.gen + +import com.google.devtools.ksp.processing.SymbolProcessor +import com.google.devtools.ksp.processing.SymbolProcessorEnvironment +import com.google.devtools.ksp.processing.SymbolProcessorProvider +import dev.restate.sdk.gen.model.AnnotationProcessingOptions + +class ServiceProcessorProvider : SymbolProcessorProvider { + + override fun create(environment: SymbolProcessorEnvironment): SymbolProcessor { + return ServiceProcessor( + logger = environment.logger, + codeGenerator = environment.codeGenerator, + options = AnnotationProcessingOptions(environment.options), + ) + } +} diff --git a/sdk-api-kotlin-gen/bin/main/templates/Client.hbs b/sdk-api-kotlin-gen/bin/main/templates/Client.hbs new file mode 100644 index 000000000..d5597f1c5 --- /dev/null +++ b/sdk-api-kotlin-gen/bin/main/templates/Client.hbs @@ -0,0 +1,83 @@ +{{#if originalClassPkg}}package {{originalClassPkg}};{{/if}} + +import dev.restate.sdk.kotlin.CallDurableFuture +import dev.restate.sdk.kotlin.InvocationHandle +import dev.restate.sdk.kotlin.Context +import dev.restate.serde.Serde +import dev.restate.common.Target +import kotlin.time.Duration +import dev.restate.client.kotlin.* + +object {{generatedClassSimpleName}} { + + fun fromContext(ctx: Context{{#isKeyed}}, key: String{{/isKeyed}}): ContextClient { + return ContextClient(ctx{{#isKeyed}}, key{{/isKeyed}}) + } + + fun fromClient(client: dev.restate.client.Client{{#isKeyed}}, key: String{{/isKeyed}}): IngressClient { + return IngressClient(client{{#isKeyed}}, key{{/isKeyed}}); + } + + fun connect(baseUri: String, {{#isKeyed}}key: String, {{/isKeyed}}requestOptions: dev.restate.client.RequestOptions = dev.restate.client.RequestOptions.DEFAULT): IngressClient { + return IngressClient(dev.restate.client.Client.connect(baseUri, {{metadataClass}}.SERDE_FACTORY, requestOptions){{#isKeyed}}, key{{/isKeyed}}); + } + + class ContextClient(private val ctx: Context{{#isKeyed}}, private val key: String{{/isKeyed}}){ + {{#handlers}} + suspend fun {{handlersClassMethodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.RequestBuilder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): CallDurableFuture<{{{boxedOutputFqcn}}}> { + return this.ctx.call( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) + ) + }{{/handlers}} + + fun send(): Send { + return Send() + } + + inner class Send internal constructor() { + {{#handlers}} + suspend fun {{handlersClassMethodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}delay: Duration? = null, init: dev.restate.common.RequestBuilder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): InvocationHandle<{{{boxedOutputFqcn}}}> { + return this@ContextClient.ctx.send( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this@ContextClient.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init), delay + ); + }{{/handlers}} + } + } + + class IngressClient(private val client: dev.restate.client.Client{{#isKeyed}}, private val key: String{{/isKeyed}}) { + + {{#handlers}}{{#if isWorkflow}} + fun workflowHandle(): dev.restate.client.Client.WorkflowHandle<{{{boxedOutputFqcn}}}> { + return this@IngressClient.client.workflowHandle( + {{metadataClass}}.SERVICE_NAME, + this.key, + {{outputSerdeRef}}); + } + + suspend fun submit({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}delay: Duration? = null, init: dev.restate.common.RequestBuilder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): dev.restate.client.SendResponse<{{{boxedOutputFqcn}}}> { + return this@IngressClient.client.sendSuspend( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init), delay + ) + } + {{else}} + suspend fun {{handlersClassMethodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.RequestBuilder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): {{{boxedOutputFqcn}}} { + return this@IngressClient.client.callSuspend( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) + ).response(); + } + {{/if}}{{/handlers}} + + fun send(): Send { + return Send() + } + + inner class Send() { + {{#handlers}}{{^isWorkflow}} + suspend fun {{handlersClassMethodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}delay: Duration? = null, init: dev.restate.common.RequestBuilder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): dev.restate.client.SendResponse<{{{boxedOutputFqcn}}}> { + return this@IngressClient.client.sendSuspend( + {{../handlersClass}}.{{handlersClassMethodName}}({{#if ../isKeyed}}this@IngressClient.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init), delay + ) + }{{/isWorkflow}}{{/handlers}} + } + } +} \ No newline at end of file diff --git a/sdk-api-kotlin-gen/bin/main/templates/Handlers.hbs b/sdk-api-kotlin-gen/bin/main/templates/Handlers.hbs new file mode 100644 index 000000000..aa889fb7d --- /dev/null +++ b/sdk-api-kotlin-gen/bin/main/templates/Handlers.hbs @@ -0,0 +1,29 @@ +{{#if originalClassPkg}}package {{originalClassPkg}}{{/if}} + +object {{generatedClassSimpleName}} { + + {{#handlers}} + fun {{handlersClassMethodName}}({{#if ../isKeyed}}key: String, {{/if}}{{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.RequestBuilder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): {{#if isWorkflow}}dev.restate.common.WorkflowRequest{{else}}dev.restate.common.Request{{/if}}<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}> { + val builder = dev.restate.common.Request.of<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>( + {{{targetExpr this "key"}}}, + {{inputSerdeRef}}, + {{outputSerdeRef}}, + {{#if inputEmpty}}null{{else}}req{{/if}}); + builder.init() + return builder.build() {{#if isWorkflow}} as dev.restate.common.WorkflowRequest{{/if}} + } + + {{/handlers}} + + object Metadata { + const val SERVICE_NAME: String = "{{restateServiceName}}" + val SERDE_FACTORY: dev.restate.serde.SerdeFactory = {{serdeFactoryDecl}} + + object Serde { + {{#handlers}} + val {{inputSerdeFieldName}}: dev.restate.serde.Serde<{{{boxedInputFqcn}}}> = {{{inputSerdeDecl}}} + val {{outputSerdeFieldName}}: dev.restate.serde.Serde<{{{boxedOutputFqcn}}}> = {{{outputSerdeDecl}}} + {{/handlers}} + } + } +} \ No newline at end of file diff --git a/sdk-api-kotlin-gen/bin/main/templates/ServiceDefinitionFactory.hbs b/sdk-api-kotlin-gen/bin/main/templates/ServiceDefinitionFactory.hbs new file mode 100644 index 000000000..43bb3e563 --- /dev/null +++ b/sdk-api-kotlin-gen/bin/main/templates/ServiceDefinitionFactory.hbs @@ -0,0 +1,33 @@ +{{#if originalClassPkg}}package {{originalClassPkg}}{{/if}} + +class {{generatedClassSimpleName}}: dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory<{{originalClassFqcn}}> { + + override fun create(bindableService: {{originalClassFqcn}}, overrideHandlerOptions: dev.restate.sdk.endpoint.definition.HandlerRunner.Options?): dev.restate.sdk.endpoint.definition.ServiceDefinition { + val handlerRunnerOptions = if (overrideHandlerOptions != null) { + check(overrideHandlerOptions is dev.restate.sdk.kotlin.HandlerRunner.Options) + overrideHandlerOptions as dev.restate.sdk.kotlin.HandlerRunner.Options + } else { + dev.restate.sdk.kotlin.HandlerRunner.Options.DEFAULT + } + + return dev.restate.sdk.endpoint.definition.ServiceDefinition.of( + {{metadataClass}}.SERVICE_NAME, + {{#if isObject}}dev.restate.sdk.endpoint.definition.ServiceType.VIRTUAL_OBJECT{{else if isWorkflow}}dev.restate.sdk.endpoint.definition.ServiceType.WORKFLOW{{else}}dev.restate.sdk.endpoint.definition.ServiceType.SERVICE{{/if}}, + listOf( + {{#handlers}} + dev.restate.sdk.endpoint.definition.HandlerDefinition.of( + "{{restateName}}", + {{#if isExclusive}}dev.restate.sdk.endpoint.definition.HandlerType.EXCLUSIVE{{else if isWorkflow}}dev.restate.sdk.endpoint.definition.HandlerType.WORKFLOW{{else}}dev.restate.sdk.endpoint.definition.HandlerType.SHARED{{/if}}, + {{inputSerdeRef}}, + {{outputSerdeRef}}, + dev.restate.sdk.kotlin.HandlerRunner.{{#if outputEmpty}}ofEmptyReturn{{else}}of{{/if}}({{serdeFactoryRef}}, handlerRunnerOptions, bindableService::{{name}}) + ){{#if inputAcceptContentType}}.withAcceptContentType("{{inputAcceptContentType}}"){{/if}}{{#unless @last}},{{/unless}} + {{/handlers}} + ) + ) + } + + override fun supports(serviceObject: Any?): Boolean { + return serviceObject is {{originalClassFqcn}}; + } +} \ No newline at end of file diff --git a/sdk-api-kotlin/bin/main/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory b/sdk-api-kotlin/bin/main/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory new file mode 100644 index 000000000..cba8a5774 --- /dev/null +++ b/sdk-api-kotlin/bin/main/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory @@ -0,0 +1 @@ +dev.restate.sdk.kotlin.internal.ReflectionServiceDefinitionFactory diff --git a/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/ContextImpl.kt b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/ContextImpl.kt new file mode 100644 index 000000000..f5b911821 --- /dev/null +++ b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/ContextImpl.kt @@ -0,0 +1,240 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin + +import dev.restate.common.Output +import dev.restate.common.Request +import dev.restate.common.Slice +import dev.restate.sdk.common.DurablePromiseKey +import dev.restate.sdk.common.HandlerRequest +import dev.restate.sdk.common.StateKey +import dev.restate.sdk.common.TerminalException +import dev.restate.sdk.endpoint.definition.HandlerContext +import dev.restate.serde.Serde +import dev.restate.serde.SerdeFactory +import dev.restate.serde.TypeTag +import java.util.concurrent.CompletableFuture +import kotlin.jvm.optionals.getOrNull +import kotlin.time.Duration +import kotlin.time.toJavaDuration +import kotlinx.coroutines.* +import kotlinx.coroutines.future.await + +internal class ContextImpl +internal constructor( + internal val handlerContext: HandlerContext, + internal val contextSerdeFactory: SerdeFactory, +) : WorkflowContext { + override fun key(): String { + return this.handlerContext.objectKey() + } + + override fun request(): HandlerRequest { + return this.handlerContext.request() + } + + override suspend fun get(key: StateKey): T? = + resolveSerde(key.serdeInfo()) + .let { serde -> + SingleDurableFutureImpl(handlerContext.get(key.name()).await()).simpleMap { + it.getOrNull()?.let { serde.deserialize(it) } + } + } + .await() + + override suspend fun stateKeys(): Collection = + SingleDurableFutureImpl(handlerContext.getKeys().await()).await() + + override suspend fun set(key: StateKey, value: T) { + handlerContext.set(key.name(), resolveAndSerialize(key.serdeInfo(), value)).await() + } + + override suspend fun clear(key: StateKey<*>) { + handlerContext.clear(key.name()).await() + } + + override suspend fun clearAll() { + handlerContext.clearAll().await() + } + + override suspend fun timer(duration: Duration, name: String?): DurableFuture = + SingleDurableFutureImpl(handlerContext.timer(duration.toJavaDuration(), name).await()).map {} + + override suspend fun call( + request: Request + ): CallDurableFuture = + resolveSerde(request.getResponseTypeTag()).let { responseSerde -> + val callHandle = + handlerContext + .call( + request.getTarget(), + resolveAndSerialize(request.getRequestTypeTag(), request.getRequest()), + request.getIdempotencyKey(), + request.getHeaders()?.entries, + ) + .await() + + val callAsyncResult = + callHandle.callAsyncResult.map { + CompletableFuture.completedFuture(responseSerde.deserialize(it)) + } + + return@let CallDurableFutureImpl(callAsyncResult, callHandle.invocationIdAsyncResult) + } + + override suspend fun send( + request: Request, + delay: Duration?, + ): InvocationHandle = + resolveSerde(request.getResponseTypeTag()).let { responseSerde -> + val invocationIdAsyncResult = + handlerContext + .send( + request.getTarget(), + resolveAndSerialize(request.getRequestTypeTag(), request.getRequest()), + request.getIdempotencyKey(), + request.getHeaders()?.entries, + delay?.toJavaDuration(), + ) + .await() + + object : BaseInvocationHandle(handlerContext, responseSerde) { + override suspend fun invocationId(): String = invocationIdAsyncResult.poll().await() + } + } + + override fun invocationHandle( + invocationId: String, + responseTypeTag: TypeTag, + ): InvocationHandle = + resolveSerde(responseTypeTag).let { responseSerde -> + object : BaseInvocationHandle(handlerContext, responseSerde) { + override suspend fun invocationId(): String = invocationId + } + } + + override suspend fun runAsync( + typeTag: TypeTag, + name: String, + retryPolicy: RetryPolicy?, + block: suspend () -> T, + ): DurableFuture { + val serde: Serde = resolveSerde(typeTag) + val coroutineCtx = currentCoroutineContext() + val javaRetryPolicy = + retryPolicy?.let { + dev.restate.sdk.common.RetryPolicy.exponential( + it.initialDelay.toJavaDuration(), + it.exponentiationFactor, + ) + .setMaxAttempts(it.maxAttempts) + .setMaxDelay(it.maxDelay?.toJavaDuration()) + .setMaxDuration(it.maxDuration?.toJavaDuration()) + } + + val scope = CoroutineScope(coroutineCtx + CoroutineName("restate-run-$name")) + + val asyncResult = + handlerContext + .submitRun(name) { completer -> + scope.launch { + val result: Slice? + try { + result = serde.serialize(block()) + } catch (e: Throwable) { + completer.proposeFailure(e, javaRetryPolicy) + return@launch + } + completer.proposeSuccess(result) + } + } + .await() + return SingleDurableFutureImpl(asyncResult).map { serde.deserialize(it) } + } + + override suspend fun awakeable(typeTag: TypeTag): Awakeable { + val serde: Serde = resolveSerde(typeTag) + val awk = handlerContext.awakeable().await() + return AwakeableImpl(awk.asyncResult, serde, awk.id) + } + + override fun awakeableHandle(id: String): AwakeableHandle { + return AwakeableHandleImpl(this, id) + } + + override fun random(): RestateRandom { + return RestateRandom(handlerContext.request().invocationId().toRandomSeed()) + } + + override fun promise(key: DurablePromiseKey): DurablePromise { + return DurablePromiseImpl(key) + } + + override fun promiseHandle(key: DurablePromiseKey): DurablePromiseHandle { + return DurablePromiseHandleImpl(key) + } + + inner class DurablePromiseImpl(private val key: DurablePromiseKey) : + DurablePromise { + val serde: Serde = resolveSerde(key.serdeInfo()) + + override suspend fun future(): DurableFuture = + SingleDurableFutureImpl(handlerContext.promise(key.name()).await()).simpleMap { + serde.deserialize(it) + } + + override suspend fun peek(): Output = + SingleDurableFutureImpl(handlerContext.peekPromise(key.name()).await()) + .simpleMap { it.map { serde.deserialize(it) } } + .await() + } + + inner class DurablePromiseHandleImpl(private val key: DurablePromiseKey) : + DurablePromiseHandle { + val serde: Serde = resolveSerde(key.serdeInfo()) + + override suspend fun resolve(payload: T) { + SingleDurableFutureImpl( + handlerContext + .resolvePromise( + key.name(), + serde.serializeWrappingException(handlerContext, payload), + ) + .await() + ) + .await() + } + + override suspend fun reject(reason: String) { + SingleDurableFutureImpl( + handlerContext.rejectPromise(key.name(), TerminalException(reason)).await() + ) + .await() + } + } + + internal fun resolveAndSerialize(typeTag: TypeTag, value: T): Slice { + return try { + val serde = contextSerdeFactory.create(typeTag) + serde.serialize(value) + } catch (e: Exception) { + handlerContext.fail(e) + throw CancellationException("Failed serialization", e) + } + } + + private fun resolveSerde(typeTag: TypeTag): Serde { + return try { + contextSerdeFactory.create(typeTag)!! + } catch (e: Exception) { + handlerContext.fail(e) + throw CancellationException("Cannot resolve serde", e) + } + } +} diff --git a/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/HandlerRunner.kt b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/HandlerRunner.kt new file mode 100644 index 000000000..d2f997667 --- /dev/null +++ b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/HandlerRunner.kt @@ -0,0 +1,173 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin + +import dev.restate.common.Slice +import dev.restate.sdk.common.TerminalException +import dev.restate.sdk.endpoint.definition.HandlerContext +import dev.restate.sdk.kotlin.internal.RestateContextElement +import dev.restate.serde.Serde +import dev.restate.serde.SerdeFactory +import io.opentelemetry.extension.kotlin.asContextElement +import java.util.concurrent.CompletableFuture +import java.util.concurrent.atomic.AtomicReference +import kotlin.coroutines.CoroutineContext +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.asContextElement +import kotlinx.coroutines.launch +import org.apache.logging.log4j.LogManager + +/** Adapter class for [dev.restate.sdk.endpoint.definition.HandlerRunner] to use the Kotlin API. */ +class HandlerRunner +internal constructor( + private val runner: suspend (CTX, REQ) -> RES, + private val contextSerdeFactory: SerdeFactory, + private val options: Options, +) : dev.restate.sdk.endpoint.definition.HandlerRunner { + + companion object { + private val LOG = LogManager.getLogger(HandlerRunner::class.java) + + /** + * Factory method for [dev.restate.sdk.kotlin.HandlerRunner], used by codegen. Please note this + * may be subject to breaking changes. + */ + fun of( + contextSerdeFactory: SerdeFactory, + options: Options = Options.DEFAULT, + runner: suspend (CTX, REQ) -> RES, + ): HandlerRunner { + return HandlerRunner(runner, contextSerdeFactory, options) + } + + /** + * Factory method for [dev.restate.sdk.kotlin.HandlerRunner], used by codegen. Please note this + * may be subject to breaking changes. + */ + fun of( + contextSerdeFactory: SerdeFactory, + options: Options = Options.DEFAULT, + runner: suspend (CTX) -> RES, + ): HandlerRunner { + return HandlerRunner({ ctx: CTX, _: Unit -> runner(ctx) }, contextSerdeFactory, options) + } + + /** + * Factory method for [dev.restate.sdk.kotlin.HandlerRunner], used by codegen. Please note this + * may be subject to breaking changes. + */ + fun ofEmptyReturn( + contextSerdeFactory: SerdeFactory, + options: Options = Options.DEFAULT, + runner: suspend (CTX, REQ) -> Unit, + ): HandlerRunner { + return HandlerRunner( + { ctx: CTX, req: REQ -> + runner(ctx, req) + Unit + }, + contextSerdeFactory, + options, + ) + } + + /** + * Factory method for [dev.restate.sdk.kotlin.HandlerRunner], used by codegen. Please note this + * may be subject to breaking changes. + */ + fun ofEmptyReturn( + contextSerdeFactory: SerdeFactory, + options: Options = Options.DEFAULT, + runner: suspend (CTX) -> Unit, + ): HandlerRunner { + return HandlerRunner( + { ctx: CTX, _: Unit -> + runner(ctx) + Unit + }, + contextSerdeFactory, + options, + ) + } + } + + override fun run( + handlerContext: HandlerContext, + requestSerde: Serde, + responseSerde: Serde, + onClosedInvocationStreamHook: AtomicReference, + ): CompletableFuture { + val ctx: Context = ContextImpl(handlerContext, contextSerdeFactory) + + val scope = + CoroutineScope( + options.coroutineContext + + RestateContextElement(ctx) + + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL + .asContextElement(handlerContext) + + handlerContext.request().openTelemetryContext()!!.asContextElement() + ) + + val completableFuture = CompletableFuture() + val job = + scope.launch { + val serializedResult: Slice + + try { + // Parse input + val req: REQ + try { + req = requestSerde.deserialize(handlerContext.request().body()) + } catch (e: Throwable) { + LOG.warn("Error deserializing request", e) + completableFuture.completeExceptionally( + throw TerminalException( + TerminalException.BAD_REQUEST_CODE, + "Cannot deserialize request: " + e.message, + ) + ) + return@launch + } + + // Execute user code + @Suppress("UNCHECKED_CAST") val res: RES = runner(ctx as CTX, req) + + // Serialize output + try { + serializedResult = responseSerde.serialize(res) + } catch (e: Throwable) { + LOG.warn("Error when serializing response", e) + completableFuture.completeExceptionally(e) + return@launch + } + } catch (e: Throwable) { + completableFuture.completeExceptionally(e) + return@launch + } + + // Complete callback + completableFuture.complete(serializedResult) + } + onClosedInvocationStreamHook.set { job.cancel() } + + return completableFuture + } + + /** + * [dev.restate.sdk.kotlin.HandlerRunner] options. You can override the default options to + * configure the [CoroutineContext] to run the handler. + */ + data class Options(val coroutineContext: CoroutineContext) : + dev.restate.sdk.endpoint.definition.HandlerRunner.Options { + companion object { + val DEFAULT: Options = Options(Dispatchers.Default) + } + } +} diff --git a/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/KtSerdes.kt b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/KtSerdes.kt new file mode 100644 index 000000000..da50d1463 --- /dev/null +++ b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/KtSerdes.kt @@ -0,0 +1,179 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin + +import dev.restate.common.Slice +import dev.restate.sdk.common.DurablePromiseKey +import dev.restate.sdk.common.StateKey +import dev.restate.serde.Serde +import dev.restate.serde.Serde.Schema +import java.nio.charset.StandardCharsets +import kotlin.reflect.typeOf +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.KSerializer +import kotlinx.serialization.Serializable +import kotlinx.serialization.builtins.ListSerializer +import kotlinx.serialization.builtins.serializer +import kotlinx.serialization.descriptors.PrimitiveKind +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.StructureKind +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonTransformingSerializer +import kotlinx.serialization.serializer + +@Deprecated("Use stateKey() instead") +object KtStateKey { + + /** Creates a json [StateKey]. */ + @Deprecated("Use stateKey() instead", replaceWith = ReplaceWith(expression = "stateKey()")) + inline fun json(name: String): StateKey { + return StateKey.of(name, KtSerdes.json()) + } +} + +@Deprecated("Use durablePromiseKey() instead") +object KtDurablePromiseKey { + + /** Creates a json [StateKey]. */ + @Deprecated( + "Use durablePromiseKey() instead", + replaceWith = ReplaceWith(expression = "durablePromiseKey()"), + ) + inline fun json(name: String): DurablePromiseKey { + return DurablePromiseKey.of(name, KtSerdes.json()) + } +} + +@Deprecated("Moved to dev.restate.serde.kotlinx") +object KtSerdes { + + @Deprecated("Moved to dev.restate.serde.kotlinx") + inline fun json(): Serde { + @Suppress("UNCHECKED_CAST") + return when (typeOf()) { + typeOf() -> UNIT as Serde + else -> json(serializer()) + } + } + + @Deprecated("Moved to dev.restate.serde.kotlinx") + val UNIT: Serde = + object : Serde { + override fun serialize(value: Unit?): Slice { + return Slice.EMPTY + } + + override fun deserialize(value: Slice) { + return + } + + override fun contentType(): String? { + return null + } + } + + @Deprecated("Moved to dev.restate.serde.kotlinx") + inline fun json(serializer: KSerializer): Serde { + return object : Serde { + override fun serialize(value: T?): Slice { + if (value == null) { + return Slice.wrap( + Json.encodeToString(JsonNull.serializer(), JsonNull).encodeToByteArray() + ) + } + + return Slice.wrap(Json.encodeToString(serializer, value).encodeToByteArray()) + } + + override fun deserialize(value: Slice): T { + return Json.decodeFromString( + serializer, + String(value.toByteArray(), StandardCharsets.UTF_8), + ) + } + + override fun contentType(): String { + return "application/json" + } + + override fun jsonSchema(): Schema { + val schema: JsonSchema = serializer.descriptor.jsonSchema() + return Serde.StringifiedJsonSchema(Json.encodeToString(schema)) + } + } + } + + @Deprecated("Moved to dev.restate.serde.kotlinx") + @Serializable + @PublishedApi + internal data class JsonSchema( + @Serializable(with = StringListSerializer::class) val type: List? = null, + val format: String? = null, + ) { + companion object { + val INT = JsonSchema(type = listOf("number"), format = "int32") + + val LONG = JsonSchema(type = listOf("number"), format = "int64") + + val DOUBLE = JsonSchema(type = listOf("number"), format = "double") + + val FLOAT = JsonSchema(type = listOf("number"), format = "float") + + val STRING = JsonSchema(type = listOf("string")) + + val BOOLEAN = JsonSchema(type = listOf("boolean")) + + val OBJECT = JsonSchema(type = listOf("object")) + + val LIST = JsonSchema(type = listOf("array")) + + val ANY = JsonSchema() + } + } + + object StringListSerializer : + JsonTransformingSerializer>(ListSerializer(String.Companion.serializer())) { + override fun transformSerialize(element: JsonElement): JsonElement { + require(element is JsonArray) + return element.singleOrNull() ?: element + } + } + + @Deprecated("Moved to dev.restate.serde.kotlinx") + @OptIn(ExperimentalSerializationApi::class) + @PublishedApi + internal fun SerialDescriptor.jsonSchema(): JsonSchema { + var schema = + when (this.kind) { + PrimitiveKind.BOOLEAN -> JsonSchema.BOOLEAN + PrimitiveKind.BYTE -> JsonSchema.INT + PrimitiveKind.CHAR -> JsonSchema.STRING + PrimitiveKind.DOUBLE -> JsonSchema.DOUBLE + PrimitiveKind.FLOAT -> JsonSchema.FLOAT + PrimitiveKind.INT -> JsonSchema.INT + PrimitiveKind.LONG -> JsonSchema.LONG + PrimitiveKind.SHORT -> JsonSchema.INT + PrimitiveKind.STRING -> JsonSchema.STRING + StructureKind.LIST -> JsonSchema.LIST + StructureKind.MAP -> JsonSchema.OBJECT + else -> JsonSchema.ANY + } + + // Add nullability constraint + if (this.isNullable && schema.type != null) { + schema = schema.copy(type = schema.type.plus("null")) + } + + return schema + } +} diff --git a/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/RetryPolicy.kt b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/RetryPolicy.kt new file mode 100644 index 000000000..b720ac8c4 --- /dev/null +++ b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/RetryPolicy.kt @@ -0,0 +1,70 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin + +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds + +/** Retry policy configuration. */ +data class RetryPolicy( + /** Initial retry delay for the first retry attempt. */ + val initialDelay: Duration, + /** Exponentiation factor to use when computing the next retry delay. */ + val exponentiationFactor: Float, + /** Maximum delay between retries. */ + val maxDelay: Duration? = null, + /** + * Maximum number of attempts, including the initial, before giving up retrying. + * + * The policy gives up retrying when either at least the given number of attempts is reached, or + * the [maxDuration] (if set) is reached first. If both [maxAttempts] and [maxDuration] are + * `null`, the policy will retry indefinitely. + * + * **Note:** The number of actual retries may be higher than the provided value. This is due to + * the nature of the `run` operation, which executes the closure on the service and sends the + * result afterward to Restate. + */ + val maxAttempts: Int? = null, + /** + * Maximum duration of the retry loop. + * + * The policy gives up retrying when either the retry loop lasted at least for this given max + * duration, or the [maxAttempts] (if set) is reached first. If both [maxAttempts] and + * [maxDuration] are `null`, the policy will retry indefinitely. + * + * **Note:** The real retry loop duration may be higher than the given duration. TThis is due to + * the nature of the `run` operation, which executes the closure on the service and sends the + * result afterward to Restate. + */ + val maxDuration: Duration? = null, +) { + + data class Builder( + var initialDelay: Duration = 100.milliseconds, + var exponentiationFactor: Float = 2.0f, + var maxDelay: Duration? = null, + var maxAttempts: Int? = null, + var maxDuration: Duration? = null, + ) { + fun build() = + RetryPolicy( + initialDelay = initialDelay, + exponentiationFactor = exponentiationFactor, + maxDelay = maxDelay, + maxDuration = maxDuration, + maxAttempts = maxAttempts, + ) + } +} + +fun retryPolicy(init: RetryPolicy.Builder.() -> Unit): RetryPolicy { + val builder = RetryPolicy.Builder() + builder.init() + return builder.build() +} diff --git a/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/UsePreviewContext.kt b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/UsePreviewContext.kt new file mode 100644 index 000000000..36d12791e --- /dev/null +++ b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/UsePreviewContext.kt @@ -0,0 +1,20 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin + +/** + * Opt-in annotation to use the preview of new context features. + * + * In order to use these methods, you **MUST enable the preview context**, through the endpoint + * builders using `enablePreviewContext()`. + */ +@RequiresOptIn +@Retention(AnnotationRetention.BINARY) +@Target(AnnotationTarget.CLASS, AnnotationTarget.FUNCTION) +annotation class UsePreviewContext diff --git a/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/Util.kt b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/Util.kt new file mode 100644 index 000000000..b8341d699 --- /dev/null +++ b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/Util.kt @@ -0,0 +1,26 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin + +import dev.restate.common.Slice +import dev.restate.sdk.endpoint.definition.HandlerContext +import dev.restate.serde.Serde +import kotlinx.coroutines.CancellationException + +internal fun Serde.serializeWrappingException( + handlerContext: HandlerContext, + value: T?, +): Slice { + return try { + this.serialize(value) + } catch (e: Exception) { + handlerContext.fail(e) + throw CancellationException("Failed serialization", e) + } +} diff --git a/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/api.kt b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/api.kt new file mode 100644 index 000000000..45b8830f3 --- /dev/null +++ b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/api.kt @@ -0,0 +1,1525 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin + +import dev.restate.common.InvocationOptions +import dev.restate.common.Output +import dev.restate.common.Request +import dev.restate.common.Slice +import dev.restate.common.reflection.kotlin.RequestCaptureProxy +import dev.restate.common.reflection.kotlin.captureInvocation +import dev.restate.common.reflections.ProxySupport +import dev.restate.common.reflections.ReflectionUtils +import dev.restate.sdk.common.DurablePromiseKey +import dev.restate.sdk.common.HandlerRequest +import dev.restate.sdk.common.InvocationId +import dev.restate.sdk.common.StateKey +import dev.restate.sdk.common.TerminalException +import dev.restate.serde.TypeTag +import dev.restate.serde.kotlinx.* +import java.nio.ByteBuffer +import java.util.* +import kotlin.coroutines.Continuation +import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED +import kotlin.coroutines.startCoroutine +import kotlin.random.Random +import kotlin.time.Clock +import kotlin.time.Duration +import kotlin.time.ExperimentalTime +import kotlin.time.Instant +import kotlinx.coroutines.currentCoroutineContext + +/** + * This interface exposes the Restate functionalities to Restate services. It can be used to + * interact with other Restate services, record non-deterministic closures, execute timers and + * synchronize with external systems. + * + * All methods of this interface, and related interfaces, throws either [TerminalException] or + * cancels the coroutine. [TerminalException] can be caught and acted upon. + * + * NOTE: This interface MUST NOT be accessed concurrently since it can lead to different orderings + * of user actions, corrupting the execution of the invocation. + */ +sealed interface Context { + + fun request(): HandlerRequest + + /** + * Causes the current execution of the function invocation to sleep for the given duration. + * + * @param duration for which to sleep. + */ + suspend fun sleep(duration: Duration) { + timer(duration).await() + } + + /** + * Causes the start of a timer for the given duration. You can await on the timer end by invoking + * [DurableFuture.await]. + * + * @param duration for which to sleep. + * @param name name to be used for the timer + */ + suspend fun timer(duration: Duration, name: String? = null): DurableFuture + + /** + * Invoke another Restate handler. + * + * @param request Request object. For each service, a class called `Handlers` is + * generated containing the request builders. + * @return a [CallDurableFuture] that wraps the result. + */ + suspend fun call(request: Request): CallDurableFuture + + /** + * Invoke another Restate handler without waiting for the response. + * + * @param request Request object. For each service, a class called `Handlers` is + * generated containing the request builders. + * @param delay The delay to send the request, if any. + * @return an [InvocationHandle] to interact with the sent request. + */ + suspend fun send( + request: Request, + delay: Duration? = null, + ): InvocationHandle + + /** + * Get an [InvocationHandle] for an already existing invocation. This will let you interact with a + * running invocation, for example to cancel it or retrieve its result. + * + * @param invocationId The invocation to interact with. + * @param responseClazz The response class. + */ + fun invocationHandle( + invocationId: String, + responseTypeTag: TypeTag, + ): InvocationHandle + + /** + * Execute a closure, recording the result value in the journal. The result value will be + * re-played in case of re-invocation (e.g. because of failure recovery or suspension point) + * without re-executing the closure. + * + * You can name this closure using the `name` parameter. This name will be available in the + * observability tools. + * + *

The closure should tolerate retries, that is Restate might re-execute the closure multiple + * times until it records a result. To control and limit the amount of retries, pass a + * [RetryPolicy] to this function. + * + *

Error handling

+ * + * Errors occurring within this closure won't be propagated to the caller, unless they are + * [TerminalException]. Consider the following code: + * ``` + * // Bad usage of try-catch outside the runBlock + * try { + * ctx.runBlock { + * throw IllegalStateException(); + * }; + * } catch (e: IllegalStateException) { + * // This will never be executed, + * // but the error will be retried by Restate, + * // following the invocation retry policy. + * } + * + * // Good usage of try-catch outside the runBlock + * try { + * ctx.runBlock { + * throw TerminalException("my error"); + * }; + * } catch (e: TerminalException) { + * // This is invoked + * } + * ``` + * + * To propagate failures to the run call-site, make sure to wrap them in [TerminalException]. + * + * @param typeTag the type tag of the return value, used to serialize/deserialize it. + * @param name the name of the side effect. + * @param block closure to execute. + * @param T type of the return value. + * @return value of the runBlock operation. + */ + suspend fun runBlock( + typeTag: TypeTag, + name: String = "", + retryPolicy: RetryPolicy? = null, + block: suspend () -> T, + ): T { + return runAsync(typeTag, name, retryPolicy, block).await() + } + + /** + * Execute a closure asynchronously. This is like [runBlock], but it returns a [DurableFuture] + * that you can combine and select. + * + * ``` + * // Fan out the subtasks - run them in parallel + * val futures = subTasks.map { subTask -> + * ctx.runAsync { subTask.execute() } + * } + * + * // Fan in - Await all results and aggregate + * val results = futures.awaitAll() + * ``` + * + * @see runBlock + */ + suspend fun runAsync( + typeTag: TypeTag, + name: String = "", + retryPolicy: RetryPolicy? = null, + block: suspend () -> T, + ): DurableFuture + + /** + * Create an [Awakeable], addressable through [Awakeable.id]. + * + * You can use this feature to implement external asynchronous systems interactions, for example + * you can send a Kafka record including the [Awakeable.id], and then let another service consume + * from Kafka the responses of given external system interaction by using [awakeableHandle]. + * + * @param serde the response type tag to use for deserializing the [Awakeable] result. + * @return the [Awakeable] to await on. + * @see Awakeable + */ + suspend fun awakeable(typeTag: TypeTag): Awakeable + + /** + * Create a new [AwakeableHandle] for the provided identifier. You can use it to + * [AwakeableHandle.resolve] or [AwakeableHandle.reject] the linked [Awakeable]. + * + * @see Awakeable + */ + fun awakeableHandle(id: String): AwakeableHandle + + /** + * Create a [RestateRandom] instance inherently predictable, seeded on the + * [dev.restate.sdk.common.InvocationId], which is not secret. + * + * This instance is useful to generate identifiers, idempotency keys, and for uniform sampling + * from a set of options. If a cryptographically secure value is needed, please generate that + * externally using [runBlock]. + * + * You MUST NOT use this [Random] instance inside a [runBlock]. + * + * @return the [Random] instance. + */ + fun random(): RestateRandom + + /** + * Returns the current time as a deterministic [Instant]. + * + *

This method returns the current timestamp in a way that is consistent across replays. The + * time is captured using [Context.runBlock], ensuring that the same value is returned during + * replay as was returned during the original execution. + * + * @return the recorded [Instant] + * @see Clock.System.now + */ + @ExperimentalTime + suspend fun instantNow(): Instant { + return runBlock(name = "Clock.System.now()", typeTag = typeTag()) { + Clock.System.now() + } + } +} + +/** + * Get an [InvocationHandle] for an already existing invocation. This will let you interact with a + * running invocation, for example to cancel it or retrieve its result. + * + * @param invocationId The invocation to interact with. + */ +inline fun Context.invocationHandle( + invocationId: String +): InvocationHandle { + return this.invocationHandle(invocationId, typeTag()) +} + +/** + * Execute a closure, recording the result value in the journal. The result value will be re-played + * in case of re-invocation (e.g. because of failure recovery or suspension point) without + * re-executing the closure. + * + * You can name this closure using the `name` parameter. This name will be available in the + * observability tools. + * + *

The closure should tolerate retries, that is Restate might re-execute the closure multiple + * times until it records a result. To control and limit the amount of retries, pass a [RetryPolicy] + * to this function. + * + *

Error handling

+ * + * Errors occurring within this closure won't be propagated to the caller, unless they are + * [TerminalException]. Consider the following code: + * ``` + * // Bad usage of try-catch outside the runBlock + * try { + * ctx.runBlock { + * throw IllegalStateException(); + * }; + * } catch (e: IllegalStateException) { + * // This will never be executed, + * // but the error will be retried by Restate, + * // following the invocation retry policy. + * } + * + * // Good usage of try-catch outside the runBlock + * try { + * ctx.runBlock { + * throw TerminalException("my error"); + * }; + * } catch (e: TerminalException) { + * // This is invoked + * } + * ``` + * + * To propagate failures to the run call-site, make sure to wrap them in [TerminalException]. + * + * @param name the name of the side effect. + * @param block closure to execute. + * @param T type of the return value. + * @return value of the runBlock operation. + */ +suspend inline fun Context.runBlock( + name: String = "", + retryPolicy: RetryPolicy? = null, + noinline block: suspend () -> T, +): T { + return this.runBlock(typeTag(), name, retryPolicy, block) +} + +/** + * Execute a closure asynchronously. This is like [runBlock], but it returns a [DurableFuture] that + * you can combine and select. + * + * ``` + * // Fan out the subtasks - run them in parallel + * val futures = subTasks.map { subTask -> + * ctx.runAsync { subTask.execute() } + * } + * + * // Fan in - Await all results and aggregate + * val results = futures.awaitAll() + * ``` + * + * @see runBlock + */ +suspend inline fun Context.runAsync( + name: String = "", + retryPolicy: RetryPolicy? = null, + noinline block: suspend () -> T, +): DurableFuture { + return this.runAsync(typeTag(), name, retryPolicy, block) +} + +/** + * Create an [Awakeable], addressable through [Awakeable.id]. + * + * You can use this feature to implement external asynchronous systems interactions, for example you + * can send a Kafka record including the [Awakeable.id], and then let another service consume from + * Kafka the responses of given external system interaction by using [awakeableHandle]. + * + * @return the [Awakeable] to await on. + * @see Awakeable + */ +suspend inline fun Context.awakeable(): Awakeable { + return this.awakeable(typeTag()) +} + +/** + * This interface can be used only within shared handlers of virtual objects. It extends [Context] + * adding access to the virtual object instance key-value state storage. + */ +sealed interface SharedObjectContext : Context { + + /** @return the key of this object */ + fun key(): String + + /** + * Gets the state stored under key, deserializing the raw value using the [StateKey.serdeInfo]. + * + * @param key identifying the state to get and its type. + * @return the value containing the stored state deserialized. + * @throws RuntimeException when the state cannot be deserialized. + */ + suspend fun get(key: StateKey): T? + + /** + * Gets all the known state keys for this virtual object instance. + * + * @return the immutable collection of known state keys. + */ + suspend fun stateKeys(): Collection +} + +inline fun stateKey(name: String): StateKey { + return StateKey.of(name, typeTag()) +} + +suspend inline fun SharedObjectContext.get(key: String): T? { + return this.get(StateKey.of(key, typeTag())) +} + +/** + * This interface can be used only within exclusive handlers of virtual objects. It extends + * [Context] adding access to the virtual object instance key-value state storage. + */ +sealed interface ObjectContext : SharedObjectContext { + + /** + * Sets the given value under the given key, serializing the value using the [StateKey.serdeInfo]. + * + * @param key identifying the value to store and its type. + * @param value to store under the given key. + */ + suspend fun set(key: StateKey, value: T) + + /** + * Clears the state stored under key. + * + * @param key identifying the state to clear. + */ + suspend fun clear(key: StateKey<*>) + + /** Clears all the state of this virtual object instance key-value state storage */ + suspend fun clearAll() +} + +suspend inline fun ObjectContext.set(key: String, value: T) { + this.set(StateKey.of(key, typeTag()), value) +} + +/** + * This interface can be used only within shared handlers of workflow. It extends [Context] adding + * access to the workflow instance key-value state storage and to the [DurablePromise] API. + * + * NOTE: This interface MUST NOT be accessed concurrently since it can lead to different orderings + * of user actions, corrupting the execution of the invocation. + * + * @see Context + * @see SharedObjectContext + */ +sealed interface SharedWorkflowContext : SharedObjectContext { + /** + * Create a [DurablePromise] for the given key. + * + * You can use this feature to implement interaction between different workflow handlers, e.g. to + * send a signal from a shared handler to the workflow handler. + * + * @see DurablePromise + */ + fun promise(key: DurablePromiseKey): DurablePromise + + /** + * Create a new [DurablePromiseHandle] for the provided key. You can use it to + * [DurablePromiseHandle.resolve] or [DurablePromiseHandle.reject] the given [DurablePromise]. + * + * @see DurablePromise + */ + fun promiseHandle(key: DurablePromiseKey): DurablePromiseHandle +} + +/** + * This interface can be used only within workflow handlers of workflow. It extends [Context] adding + * access to the workflow instance key-value state storage and to the [DurablePromise] API. + * + * NOTE: This interface MUST NOT be accessed concurrently since it can lead to different orderings + * of user actions, corrupting the execution of the invocation. + * + * @see Context + * @see ObjectContext + */ +sealed interface WorkflowContext : SharedWorkflowContext, ObjectContext + +class RestateRandom(seed: Long) : Random() { + private val r = Random(seed) + + override fun nextBits(bitCount: Int): Int { + return r.nextBits(bitCount) + } + + /** Generate a UUID that is stable across retries and replays. */ + fun nextUUID(): UUID { + return UUID(this.nextLong(), this.nextLong()) + } +} + +/** + * A [DurableFuture] allows to await an asynchronous result. Once [await] is called, the execution + * waits until the asynchronous result is available. + * + * The result can be either a success or a failure. In case of a failure, [await] will throw a + * [dev.restate.sdk.common.TerminalException]. + * + * @param T type of this future's result + */ +sealed interface DurableFuture { + + /** + * Wait for this [DurableFuture] to complete. + * + * @throws TerminalException if this future was completed with a failure + */ + suspend fun await(): T + + /** + * Same as [await] but throws a [dev.restate.sdk.common.TimeoutException] if this [DurableFuture] + * doesn't complete before the provided `timeout`. + */ + suspend fun await(duration: Duration): T + + /** + * Creates a [DurableFuture] that throws a [dev.restate.sdk.common.TimeoutException] if this + * future doesn't complete before the provided `timeout`. + */ + suspend fun withTimeout(duration: Duration): DurableFuture + + /** Clause for [select] operator. */ + val onAwait: SelectClause + + /** + * Map the success result of this [DurableFuture]. + * + * @param transform the mapper to execute if this [DurableFuture] completes with success. The + * mapper can throw a [dev.restate.sdk.common.TerminalException], thus failing the returned + * [DurableFuture]. + * @return a new [DurableFuture] with the mapped result, when completed + */ + suspend fun map(transform: suspend (value: T) -> R): DurableFuture + + /** + * Map both the success and the failure result of this [DurableFuture]. + * + * @param transformSuccess the mapper to execute if this [DurableFuture] completes with success. + * The mapper can throw a [dev.restate.sdk.common.TerminalException], thus failing the returned + * [DurableFuture]. + * @param transformFailure the mapper to execute if this [DurableFuture] completes with failure. + * The mapper can throw a [dev.restate.sdk.common.TerminalException], thus failing the returned + * [DurableFuture]. + * @return a new [DurableFuture] with the mapped result, when completed + */ + suspend fun map( + transformSuccess: suspend (value: T) -> R, + transformFailure: suspend (exception: TerminalException) -> R, + ): DurableFuture + + /** + * Map the failure result of this [DurableFuture]. + * + * @param transform the mapper to execute if this [DurableFuture] completes with failure. The + * mapper can throw a [dev.restate.sdk.common.TerminalException], thus failing the returned + * [DurableFuture]. + * @return a new [DurableFuture] with the mapped result, when completed + */ + suspend fun mapFailure(transform: suspend (exception: TerminalException) -> T): DurableFuture + + companion object { + /** @see awaitAll */ + fun all( + first: DurableFuture<*>, + second: DurableFuture<*>, + vararg others: DurableFuture<*>, + ): DurableFuture { + return wrapAllDurableFuture(listOf(first) + listOf(second) + others.asList()) + } + + /** @see awaitAll */ + fun all(durableFutures: List>): DurableFuture { + return wrapAllDurableFuture(durableFutures) + } + + /** @see select */ + fun any( + first: DurableFuture<*>, + second: DurableFuture<*>, + vararg others: DurableFuture<*>, + ): DurableFuture { + return wrapAnyDurableFuture(listOf(first) + listOf(second) + others.asList()) + } + + /** @see select */ + fun any(durableFutures: List>): DurableFuture { + return wrapAnyDurableFuture(durableFutures) + } + } +} + +/** + * Like [kotlinx.coroutines.awaitAll], but for [DurableFuture]. + * + * ``` + * val a1 = ctx.awakeable() + * val a2 = ctx.awakeable() + * + * val result = listOf(a1, a2) + * .awaitAll() + * .joinToString(separator = "-") + * ``` + */ +suspend fun Collection>.awaitAll(): List { + return awaitAll(*toTypedArray()) +} + +/** @see Collection.awaitAll */ +suspend fun awaitAll(vararg durableFutures: DurableFuture): List { + if (durableFutures.isEmpty()) { + return emptyList() + } + if (durableFutures.size == 1) { + return listOf(durableFutures[0].await()) + } + wrapAllDurableFuture(durableFutures.asList()).await() + return durableFutures.map { it.await() }.toList() +} + +/** + * Like [kotlinx.coroutines.selects.select], but for [DurableFuture] + * + * ``` + * val callFuture = ctx.awakeable() + * val timeout = ctx.timer(10.seconds) + * + * val result = select { + * callFuture.onAwait { it.message } + * timeout.onAwait { throw TimeoutException() } + * }.await() + * ``` + */ +suspend inline fun select(crossinline builder: SelectBuilder.() -> Unit): DurableFuture { + val selectImpl = SelectImplementation() + builder.invoke(selectImpl) + return selectImpl.build() +} + +sealed interface SelectBuilder { + /** Registers a clause in this [select] expression. */ + operator fun SelectClause.invoke(block: suspend (T) -> R) +} + +sealed interface SelectClause { + val durableFuture: DurableFuture +} + +/** The [DurableFuture] returned by a [Context.call]. */ +sealed interface CallDurableFuture : DurableFuture { + /** Get the invocation id of this call. */ + suspend fun invocationId(): String +} + +/** An invocation handle, that can be used to interact with a running invocation. */ +sealed interface InvocationHandle { + /** @return the invocation id of this invocation */ + suspend fun invocationId(): String + + /** Cancel this invocation. */ + suspend fun cancel() + + /** Attach to this invocation. This will wait for the invocation to complete */ + suspend fun attach(): DurableFuture + + /** @return the output of this invocation, if present. */ + suspend fun output(): Output +} + +/** + * An [Awakeable] is a special type of [DurableFuture] which can be arbitrarily completed by another + * service, by addressing it with its [id]. + * + * It can be used to let a service wait on a specific condition/result, which is fulfilled by + * another service or by an external system at a later point in time. + * + * For example, you can send a Kafka record including the [Awakeable.id], and then let another + * service consume from Kafka the responses of given external system interaction by using + * [RestateContext.awakeableHandle]. + */ +sealed interface Awakeable : DurableFuture { + /** The unique identifier of this [Awakeable] instance. */ + val id: String +} + +/** This class represents a handle to an [Awakeable] created in another service. */ +sealed interface AwakeableHandle { + /** + * Complete with success the [Awakeable]. + * + * @param typeTag used to serialize the [Awakeable] result payload. + * @param payload the result payload. + * @see Awakeable + */ + suspend fun resolve(typeTag: TypeTag, payload: T) + + /** + * Complete with failure the [Awakeable]. + * + * @param reason the rejection reason. + * @see Awakeable + */ + suspend fun reject(reason: String) +} + +/** + * Complete with success the [Awakeable]. + * + * @param payload the result payload. + * @see Awakeable + */ +suspend inline fun AwakeableHandle.resolve(payload: T) { + return this.resolve(typeTag(), payload) +} + +/** + * A [DurablePromise] is a durable, distributed version of a Kotlin's Deferred, or more commonly of + * a future/promise. Restate keeps track of the [DurablePromise] across restarts/failures. + * + * You can use this feature to implement interaction between different workflow handlers, e.g. to + * send a signal from a shared handler to the workflow handler. + * + * Use [SharedWorkflowContext.promiseHandle] to complete a durable promise, either by + * [DurablePromiseHandle.resolve] or [DurablePromiseHandle.reject]. + * + * A [DurablePromise] is tied to a single workflow execution and can only be resolved or rejected + * while the workflow run is still ongoing. Once the workflow is cleaned up, all its associated + * promises with their completions will be cleaned up as well. + * + * NOTE: This interface MUST NOT be accessed concurrently since it can lead to different orderings + * of user actions, corrupting the execution of the invocation. + */ +sealed interface DurablePromise { + /** @return the future to await the promise result on. */ + suspend fun future(): DurableFuture + + @Deprecated( + message = "Use future() instead", + level = DeprecationLevel.WARNING, + replaceWith = ReplaceWith(expression = "future()"), + ) + suspend fun awaitable(): DurableFuture { + return future() + } + + /** @return the value, if already present, otherwise returns an empty optional. */ + suspend fun peek(): Output +} + +/** This class represents a handle to a [DurablePromise] created in another service. */ +sealed interface DurablePromiseHandle { + /** + * Complete with success the [DurablePromise]. + * + * @param payload the result payload. + * @see DurablePromise + */ + suspend fun resolve(payload: T) + + /** + * Complete with failure the [DurablePromise]. + * + * @param reason the rejection reason. + * @see DurablePromise + */ + suspend fun reject(reason: String) +} + +inline fun durablePromiseKey(name: String): DurablePromiseKey { + return DurablePromiseKey.of(name, typeTag()) +} + +/** Shorthand for [Context.call] */ +suspend fun Request.call( + context: Context +): CallDurableFuture { + return context.call(this) +} + +/** Shorthand for [Context.send] */ +suspend fun Request.send( + context: Context, + delay: Duration? = null, +): InvocationHandle { + return context.send(this, delay) +} + +val HandlerRequest.invocationId: InvocationId + get() = this.invocationId() +val HandlerRequest.openTelemetryContext: io.opentelemetry.context.Context + get() = this.openTelemetryContext() +val HandlerRequest.body: Slice + get() = this.body() +val HandlerRequest.bodyAsByteArray: ByteArray + get() = this.bodyAsByteArray() +val HandlerRequest.bodyAsByteBuffer: ByteBuffer + get() = this.bodyAsBodyBuffer() +val HandlerRequest.headers: Map + get() = this.headers() + +// ============================================================================= +// Free-floating API functions for the reflection-based API +// ============================================================================= + +/** + * Get the current Restate [Context] from within a handler. + * + * This function must be called from within a Restate handler's suspend function. It retrieves the + * context from the coroutine context. + * + * Example usage: + * ```kotlin + * @Service + * class MyService { + * @Handler + * suspend fun myHandler(input: String): String { + * val ctx = context() + * // Use ctx for Restate operations + * return "processed: $input" + * } + * } + * ``` + * + * @throws IllegalStateException if called outside of a Restate handler + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun context(): Context { + val element = + currentCoroutineContext()[dev.restate.sdk.kotlin.internal.RestateContextElement] + ?: error("context() must be called from within a Restate handler") + return element.ctx +} + +/** + * Get the current request information. + * + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.request + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun request(): HandlerRequest { + return context().request() +} + +/** + * Get the deterministic random instance. + * + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.random + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun random(): RestateRandom { + return context().random() +} + +/** + * Get [RestateClock], that deterministically records the time. + * + * @see RestateClock.now + */ +@ExperimentalTime +@org.jetbrains.annotations.ApiStatus.Experimental +fun clock(): RestateClock { + return RestateClockImpl +} + +@ExperimentalTime +@org.jetbrains.annotations.ApiStatus.Experimental +interface RestateClock { + /** + * Returns the current time as a deterministic [Instant]. + * + *

This method returns the current timestamp in a way that is consistent across replays. The + * time is captured using [runBlock], ensuring that the same value is returned during replay as + * was returned during the original execution. + * + * @return the recorded [Instant] + * @throws IllegalStateException if called outside a Restate handler + * @see Clock.System.now + */ + suspend fun now(): Instant +} + +@ExperimentalTime +@org.jetbrains.annotations.ApiStatus.Experimental +private object RestateClockImpl : RestateClock { + override suspend fun now(): Instant { + return context().instantNow() + } +} + +/** + * Get [RestateClock], that deterministically records the time. + * + * @see RestateClock.now + */ +@ExperimentalTime +@get:org.jetbrains.annotations.ApiStatus.Experimental +val Clock.Companion.Restate: RestateClock + get() = clock() + +/** + * Causes the current execution of the function invocation to sleep for the given duration. + * + * @param duration for which to sleep. + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.sleep + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun sleep(duration: Duration) { + context().sleep(duration) +} + +/** + * Causes the start of a timer for the given duration. + * + * @param duration for which to sleep. + * @param name name to be used for the timer + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.timer + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun timer(name: String = "", duration: Duration): DurableFuture { + return context().timer(duration, name) +} + +/** + * Execute a closure, recording the result value in the journal. + * + * @param name the name of the side effect. + * @param retryPolicy optional retry policy. + * @param block closure to execute. + * @return value of the run operation. + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.runBlock + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun runBlock( + name: String = "", + retryPolicy: RetryPolicy? = null, + noinline block: suspend () -> T, +): T { + return context().runBlock(typeTag(), name, retryPolicy, block) +} + +/** + * Execute a closure asynchronously. + * + * @param name the name of the side effect. + * @param retryPolicy optional retry policy. + * @param block closure to execute. + * @return a [DurableFuture] that you can combine and select. + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.runAsync + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun runAsync( + name: String = "", + retryPolicy: RetryPolicy? = null, + noinline block: suspend () -> T, +): DurableFuture { + return context().runAsync(typeTag(), name, retryPolicy, block) +} + +/** + * Create an [Awakeable], addressable through [Awakeable.id]. + * + * @return the [Awakeable] to await on. + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.awakeable + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun awakeable(): Awakeable { + return context().awakeable(typeTag()) +} + +/** + * Create an [Awakeable], addressable through [Awakeable.id]. + * + * You can use this feature to implement external asynchronous systems interactions, for example you + * can send a Kafka record including the [Awakeable.id], and then let another service consume from + * Kafka the responses of given external system interaction by using [awakeableHandle]. + * + * @param typeTag the type tag for deserializing the [Awakeable] result. + * @return the [Awakeable] to await on. + * @see Awakeable + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun awakeable(typeTag: TypeTag): Awakeable { + return context().awakeable(typeTag) +} + +/** + * Create a new [AwakeableHandle] for the provided identifier. + * + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.awakeableHandle + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun awakeableHandle(id: String): AwakeableHandle { + return context().awakeableHandle(id) +} + +/** + * Get an [InvocationHandle] for an already existing invocation. + * + * @param invocationId The invocation to interact with. + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.invocationHandle + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun invocationHandle(invocationId: String): InvocationHandle { + return context().invocationHandle(invocationId, typeTag()) +} + +/** + * Get the key of this Virtual Object. + * + * @return the key of this object + * @throws IllegalStateException if called from a regular Service handler or outside of a Restate + * handler + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun objectKey(): String { + val ctx = context() + val handlerContext = + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get() + ?: error("objectKey() must be called from within a Restate handler") + + if (!handlerContext.canReadState()) { + error( + "objectKey() can be used only within Virtual Object handlers. " + + "Check https://docs.restate.dev/develop/java/services#virtual-objects for more details." + ) + } + + return (ctx as SharedObjectContext).key() +} + +/** + * Get the key of this Workflow. + * + * @return the key of this workflow + * @throws IllegalStateException if called from a regular Service handler, or from a virtual object + * handler, or outside of a Restate handler + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun workflowKey(): String { + val ctx = context() + val handlerContext = + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get() + ?: error("workflowKey() must be called from within a Restate handler") + + if (!handlerContext.canReadPromises()) { + error( + "workflowKey() can be used only within Workflow handlers. " + + "Check https://docs.restate.dev/develop/java/services#workflows for more details." + ) + } + + return (ctx as SharedObjectContext).key() +} + +/** + * Access to this Virtual Object/Workflow state. + * + * @return [KotlinState] for this Virtual Object/Workflow + * @throws IllegalStateException if called from a regular Service handler or outside of a Restate + * handler + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun state(): KotlinState { + val ctx = context() + val handlerContext = + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get() + ?: error("state() must be called from within a Restate handler") + + if (!handlerContext.canReadState()) { + error( + "state() can be used only within Virtual Object or Workflow handlers. " + + "Check https://docs.restate.dev/develop/java/state for more details." + ) + } + + return KotlinStateImpl(ctx as SharedObjectContext, handlerContext) +} + +/** + * Create a [DurablePromise] for the given key. + * + * @throws IllegalStateException if called from a non-Workflow handler or outside of a Restate + * handler + * @see SharedWorkflowContext.promise + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun promise(key: DurablePromiseKey): DurablePromise { + val ctx = context() + val handlerContext = + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get() + ?: error("promise() must be called from within a Restate handler") + + if (!handlerContext.canReadPromises() || !handlerContext.canWritePromises()) { + error( + "promise(key) can be used only within Workflow handlers. " + + "Check https://docs.restate.dev/develop/java/external-events#durable-promises for more details." + ) + } + + return (ctx as SharedWorkflowContext).promise(key) +} + +/** + * Create a new [DurablePromiseHandle] for the provided key. + * + * @throws IllegalStateException if called from a non-Workflow handler or outside of a Restate + * handler + * @see SharedWorkflowContext.promiseHandle + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun promiseHandle(key: DurablePromiseKey): DurablePromiseHandle { + val ctx = context() + val handlerContext = + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get() + ?: error("promiseHandle() must be called from within a Restate handler") + + if (!handlerContext.canReadPromises() || !handlerContext.canWritePromises()) { + error( + "promiseHandle(key) can be used only within Workflow handlers. " + + "Check https://docs.restate.dev/develop/java/external-events#durable-promises for more details." + ) + } + + return (ctx as SharedWorkflowContext).promiseHandle(key) +} + +/** + * Interface for accessing Virtual Object/Workflow state in the reflection-based API. + * + * This interface provides suspend-friendly state operations that can be used from within Restate + * handlers using the free-floating `state()` function. + * + * Example usage: + * ```kotlin + * @VirtualObject + * class Counter { + * companion object { + * private val COUNT = stateKey("count") + * } + * + * @Handler + * suspend fun increment(): Long { + * val current = state().get(COUNT) ?: 0L + * val next = current + 1 + * state().set(COUNT, next) + * return next + * } + * } + * ``` + */ +@org.jetbrains.annotations.ApiStatus.Experimental +interface KotlinState { + /** + * Gets the state stored under key, deserializing the raw value using the [StateKey.serdeInfo]. + * + * @param key identifying the state to get and its type. + * @return the value containing the stored state deserialized, or null if not set. + * @throws RuntimeException when the state cannot be deserialized. + */ + @org.jetbrains.annotations.ApiStatus.Experimental suspend fun get(key: StateKey): T? + + /** + * Sets the given value under the given key, serializing the value using the [StateKey.serdeInfo]. + * + * @param key identifying the value to store and its type. + * @param value to store under the given key. + * @throws IllegalStateException if called from a Shared handler + */ + @org.jetbrains.annotations.ApiStatus.Experimental + suspend fun set(key: StateKey, value: T) + + /** + * Clears the state stored under key. + * + * @param key identifying the state to clear. + * @throws IllegalStateException if called from a Shared handler + */ + @org.jetbrains.annotations.ApiStatus.Experimental suspend fun clear(key: StateKey<*>) + + /** + * Clears all the state of this virtual object instance key-value state storage. + * + * @throws IllegalStateException if called from a Shared handler + */ + @org.jetbrains.annotations.ApiStatus.Experimental suspend fun clearAll() + + /** + * Gets all the known state keys for this virtual object instance. + * + * @return the immutable collection of known state keys. + */ + @org.jetbrains.annotations.ApiStatus.Experimental suspend fun keys(): Collection +} + +/** + * Gets the state stored under key. + * + * @param key the name of the state key. + * @return the value containing the stored state deserialized, or null if not set. + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun KotlinState.get(key: String): T? { + return this.get(StateKey.of(key, typeTag())) +} + +/** + * Sets the given value under the given key. + * + * @param key the name of the state key. + * @param value to store under the given key. + * @throws IllegalStateException if called from a Shared handler + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun KotlinState.set(key: String, value: T) { + this.set(StateKey.of(key, typeTag()), value) +} + +// Internal implementation of KotlinState +private class KotlinStateImpl( + private val ctx: SharedObjectContext, + private val handlerContext: dev.restate.sdk.endpoint.definition.HandlerContext, +) : KotlinState { + override suspend fun get(key: StateKey): T? { + return ctx.get(key) + } + + override suspend fun set(key: StateKey, value: T) { + checkCanWriteState("set") + (ctx as ObjectContext).set(key, value) + } + + override suspend fun clear(key: StateKey<*>) { + checkCanWriteState("clear") + (ctx as ObjectContext).clear(key) + } + + override suspend fun clearAll() { + checkCanWriteState("clearAll") + (ctx as ObjectContext).clearAll() + } + + override suspend fun keys(): Collection { + return ctx.stateKeys() + } + + private fun checkCanWriteState(opName: String) { + if (!handlerContext.canWriteState()) { + error( + "state().$opName() cannot be used in shared handlers. " + + "Check https://docs.restate.dev/develop/java/state for more details." + ) + } + } +} + +/** + * Kotlin-idiomatic request for invoking Restate services from within a handler. + * + * Example usage: + * ```kotlin + * toService() + * .request { add(1) } + * .options { idempotencyKey = "123" } + * .call() + * ``` + * + * @param Req the request type + * @param Res the response type + */ +@org.jetbrains.annotations.ApiStatus.Experimental +interface KRequest : Request { + + /** + * Configure invocation options using a DSL. + * + * @param block builder block for options + * @return a new request with the configured options + */ + @org.jetbrains.annotations.ApiStatus.Experimental + fun options(block: InvocationOptions.Builder.() -> Unit): KRequest + + /** + * Call the target handler and return a [CallDurableFuture] for the result. + * + * @return a [CallDurableFuture] that will contain the response + */ + @org.jetbrains.annotations.ApiStatus.Experimental suspend fun call(): CallDurableFuture + + /** + * Send the request without waiting for the response. + * + * @param delay optional delay before the invocation is executed + * @return an [InvocationHandle] to interact with the sent request + */ + @org.jetbrains.annotations.ApiStatus.Experimental + suspend fun send(delay: Duration? = null): InvocationHandle +} + +/** + * Builder for creating type-safe requests from within a handler. + * + * This builder allows the response type to be inferred from the lambda passed to [request]. + * + * @param SVC the service/virtual object/workflow class + */ +@org.jetbrains.annotations.ApiStatus.Experimental +class KRequestBuilder +@PublishedApi +internal constructor( + private val clazz: Class, + private val key: String?, +) { + /** + * Create a request by invoking a method on the target. + * + * The response type is inferred from the return type of the invoked method. + * + * @param Res the response type (inferred from the lambda) + * @param block a suspend lambda that invokes a method on the target + * @return a [KRequest] with the correct response type + */ + @Suppress("UNCHECKED_CAST") + suspend fun request(block: suspend SVC.() -> Res): KRequest { + return KRequestImpl( + RequestCaptureProxy(clazz, key).capture(block as suspend SVC.() -> Any?).toRequest() + ) + as KRequest + } +} + +/** + * Create a builder for invoking a Restate service from within a handler. + * + * Example usage: + * ```kotlin + * @Handler + * suspend fun myHandler(): String { + * val result = toService() + * .request { greet("Alice") } + * .call() + * .await() + * return result + * } + * ``` + * + * @param SVC the service class annotated with @Service + * @return a builder for creating typed requests + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun toService(): KRequestBuilder { + ReflectionUtils.mustHaveServiceAnnotation(SVC::class.java) + require(ReflectionUtils.isKotlinClass(SVC::class.java)) { + "Using Java classes with Kotlin's API is not supported" + } + return KRequestBuilder(SVC::class.java, null) +} + +/** + * Create a builder for invoking a Restate virtual object from within a handler. + * + * Example usage: + * ```kotlin + * @Handler + * suspend fun myHandler(): Long { + * val result = toVirtualObject("my-counter") + * .request { add(1) } + * .call() + * .await() + * return result + * } + * ``` + * + * @param SVC the virtual object class annotated with @VirtualObject + * @param key the key identifying the specific virtual object instance + * @return a builder for creating typed requests + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun toVirtualObject(key: String): KRequestBuilder { + ReflectionUtils.mustHaveVirtualObjectAnnotation(SVC::class.java) + require(ReflectionUtils.isKotlinClass(SVC::class.java)) { + "Using Java classes with Kotlin's API is not supported" + } + return KRequestBuilder(SVC::class.java, key) +} + +/** + * Create a builder for invoking a Restate workflow from within a handler. + * + * Example usage: + * ```kotlin + * @Handler + * suspend fun myHandler(): String { + * val result = toWorkflow("workflow-123") + * .request { run("input") } + * .call() + * .await() + * return result + * } + * ``` + * + * @param SVC the workflow class annotated with @Workflow + * @param key the key identifying the specific workflow instance + * @return a builder for creating typed requests + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun toWorkflow(key: String): KRequestBuilder { + ReflectionUtils.mustHaveWorkflowAnnotation(SVC::class.java) + require(ReflectionUtils.isKotlinClass(SVC::class.java)) { + "Using Java classes with Kotlin's API is not supported" + } + return KRequestBuilder(SVC::class.java, key) +} + +/** Implementation of [KRequest] for SDK context. */ +private class KRequestImpl(private val request: Request) : + KRequest, Request by request { + override fun options(block: InvocationOptions.Builder.() -> Unit): KRequest { + val builder = InvocationOptions.builder() + builder.block() + return KRequestImpl( + this.toBuilder().headers(builder.headers).idempotencyKey(builder.idempotencyKey).build() + ) + } + + override suspend fun call(): CallDurableFuture { + return context().call(request) + } + + override suspend fun send(delay: Duration?): InvocationHandle { + return context().send(request, delay) + } +} + +/** + * Create a proxy client for a Restate service. + * + * This creates a proxy that allows calling service methods directly. The proxy intercepts method + * calls, converts them to Restate requests, and awaits the result. + * + * Example usage: + * ```kotlin + * @Handler + * suspend fun myHandler(): String { + * val greeter = service() + * val response = greeter.greet("Alice") + * return "Got: $response" + * } + * ``` + * + * @param SVC the service class annotated with @Service + * @return a proxy client to invoke the service + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun service(): SVC { + return service(SVC::class.java) +} + +/** + * Create a proxy client for a Restate virtual object. + * + * Example usage: + * ```kotlin + * @Handler + * suspend fun myHandler(): Long { + * val counter = virtualObject("my-counter") + * return counter.increment() + * } + * ``` + * + * @param SVC the virtual object class annotated with @VirtualObject + * @param key the key identifying the specific virtual object instance + * @return a proxy client to invoke the virtual object + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun virtualObject(key: String): SVC { + return virtualObject(SVC::class.java, key) +} + +/** + * Create a proxy client for a Restate workflow. + * + * @param SVC the workflow class annotated with @Workflow + * @param key the key identifying the specific workflow instance + * @return a proxy client to invoke the workflow + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun workflow(key: String): SVC { + return workflow(SVC::class.java, key) +} + +@PublishedApi +internal fun service(clazz: Class): SVC { + ReflectionUtils.mustHaveServiceAnnotation(clazz) + require(ReflectionUtils.isKotlinClass(clazz)) { + "Using Java classes with Kotlin's API is not supported" + } + val serviceName = ReflectionUtils.extractServiceName(clazz) + + return ProxySupport.createProxy(clazz) { invocation -> + val request = invocation.captureInvocation(serviceName, null).toRequest() + + // Last argument is the continuation for suspend functions + @Suppress("UNCHECKED_CAST") val continuation = invocation.arguments.last() as Continuation + + // Start a coroutine that calls the client and resumes the continuation + val suspendBlock: suspend () -> Any? = { context().call(request).await() } + suspendBlock.startCoroutine(continuation) + COROUTINE_SUSPENDED + } +} + +@PublishedApi +internal fun virtualObject(clazz: Class, key: String): SVC { + ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz) + require(ReflectionUtils.isKotlinClass(clazz)) { + "Using Java classes with Kotlin's API is not supported" + } + val serviceName = ReflectionUtils.extractServiceName(clazz) + + return ProxySupport.createProxy(clazz) { invocation -> + val request = invocation.captureInvocation(serviceName, key).toRequest() + + // Last argument is the continuation for suspend functions + @Suppress("UNCHECKED_CAST") val continuation = invocation.arguments.last() as Continuation + + // Start a coroutine that calls the client and resumes the continuation + val suspendBlock: suspend () -> Any? = { context().call(request).await() } + suspendBlock.startCoroutine(continuation) + COROUTINE_SUSPENDED + } +} + +@PublishedApi +internal fun workflow(clazz: Class, key: String): SVC { + ReflectionUtils.mustHaveWorkflowAnnotation(clazz) + require(ReflectionUtils.isKotlinClass(clazz)) { + "Using Java classes with Kotlin's API is not supported" + } + val serviceName = ReflectionUtils.extractServiceName(clazz) + + return ProxySupport.createProxy(clazz) { invocation -> + val request = invocation.captureInvocation(serviceName, key).toRequest() + + // Last argument is the continuation for suspend functions + @Suppress("UNCHECKED_CAST") val continuation = invocation.arguments.last() as Continuation + + // Start a coroutine that calls the client and resumes the continuation + val suspendBlock: suspend () -> Any? = { context().call(request).await() } + suspendBlock.startCoroutine(continuation) + COROUTINE_SUSPENDED + } +} diff --git a/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/endpoint/endpoint.kt b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/endpoint/endpoint.kt new file mode 100644 index 000000000..f552747b0 --- /dev/null +++ b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/endpoint/endpoint.kt @@ -0,0 +1,410 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin.endpoint + +import dev.restate.sdk.endpoint.Endpoint +import dev.restate.sdk.endpoint.definition.HandlerDefinition +import dev.restate.sdk.endpoint.definition.InvocationRetryPolicy +import dev.restate.sdk.endpoint.definition.ServiceDefinition +import kotlin.time.Duration +import kotlin.time.toJavaDuration +import kotlin.time.toKotlinDuration + +/** Endpoint builder function. */ +fun endpoint(init: Endpoint.Builder.() -> Unit): Endpoint { + val builder = Endpoint.builder() + builder.init() + return builder.build() +} + +/** + * Documentation as shown in the UI, Admin REST API, and the generated OpenAPI documentation of this + * service. + */ +var ServiceDefinition.Configurator.documentation: String? + get() { + return this.documentation() + } + set(value) { + this.documentation(value) + } + +/** Service metadata, as propagated in the Admin REST API. */ +var ServiceDefinition.Configurator.metadata: Map? + get() { + return this.metadata() + } + set(value) { + this.metadata(value) + } + +/** + * This timer guards against stalled invocations. Once it expires, Restate triggers a graceful + * termination by asking the invocation to suspend (which preserves intermediate progress). + * + * The [abortTimeout] is used to abort the invocation, in case it doesn't react to the request to + * suspend. + * + * This overrides the default inactivity timeout configured in the restate-server for all + * invocations to this service. + * + * *NOTE:* You can set this field only if you register this service against restate-server >= 1.4, + * otherwise the service discovery will fail. + */ +var ServiceDefinition.Configurator.inactivityTimeout: Duration? + get() { + return this.inactivityTimeout()?.toKotlinDuration() + } + set(value) { + this.inactivityTimeout(value?.toJavaDuration()) + } + +/** + * This timer guards against stalled service/handler invocations that are supposed to terminate. The + * abort timeout is started after the [inactivityTimeout] has expired and the service/handler + * invocation has been asked to gracefully terminate. Once the timer expires, it will abort the + * service/handler invocation. + * + * This timer potentially *interrupts* user code. If the user code needs longer to gracefully + * terminate, then this value needs to be set accordingly. + * + * This overrides the default abort timeout configured in the restate-server for all invocations to + * this service. + * + * *NOTE:* You can set this field only if you register this service against restate-server >= 1.4, + * otherwise the service discovery will fail. + */ +var ServiceDefinition.Configurator.abortTimeout: Duration? + get() { + return this.abortTimeout()?.toKotlinDuration() + } + set(value) { + this.abortTimeout(value?.toJavaDuration()) + } + +/** + * The retention duration for this workflow. This applies only to workflow services. + * + * *NOTE:* You can set this field only if you register this service against restate-server >= 1.4, + * otherwise the service discovery will fail. + */ +var ServiceDefinition.Configurator.workflowRetention: Duration? + get() { + return this.workflowRetention()?.toKotlinDuration() + } + set(value) { + this.workflowRetention(value?.toJavaDuration()) + } + +/** + * The retention duration of idempotent requests to this service. + * + * *NOTE:* You can set this field only if you register this service against restate-server >= 1.4, + * otherwise the service discovery will fail. + */ +var ServiceDefinition.Configurator.idempotencyRetention: Duration? + get() { + return this.idempotencyRetention()?.toKotlinDuration() + } + set(value) { + this.idempotencyRetention(value?.toJavaDuration()) + } + +/** + * The journal retention. When set, this applies to all requests to all handlers of this service. + * + * In case the request has an idempotency key, the [idempotencyRetention] caps the journal retention + * time. + * + * *NOTE:* You can set this field only if you register this service against restate-server >= 1.4, + * otherwise the service discovery will fail. + * + * @return this + */ +var ServiceDefinition.Configurator.journalRetention: Duration? + get() { + return this.journalRetention()?.toKotlinDuration() + } + set(value) { + this.journalRetention(value?.toJavaDuration()) + } + +/** + * When set to `true`, lazy state will be enabled for all invocations to this service. This is + * relevant only for workflows and virtual objects. + * + * *NOTE:* You can set this field only if you register this service against restate-server >= 1.4, + * otherwise the service discovery will fail. + */ +var ServiceDefinition.Configurator.enableLazyState: Boolean? + get() { + return this.enableLazyState() + } + set(value) { + this.enableLazyState(value) + } + +/** + * When set to `true` this service, with all its handlers, cannot be invoked from the restate-server + * HTTP and Kafka ingress, but only from other services. + * + * *NOTE:* You can set this field only if you register this service against restate-server >= 1.4, + * otherwise the service discovery will fail. + */ +var ServiceDefinition.Configurator.ingressPrivate: Boolean? + get() { + return this.ingressPrivate() + } + set(value) { + this.ingressPrivate(value) + } + +/** + * Retry policy used by Restate when invoking this service. + * + *

NOTE: You can set this field only if you register this service against + * restate-server >= 1.5, otherwise the service discovery will fail. + * + * @see InvocationRetryPolicy + */ +var ServiceDefinition.Configurator.invocationRetryPolicy: InvocationRetryPolicy? + get() { + return this.invocationRetryPolicy() + } + set(value) { + this.invocationRetryPolicy(value) + } + +/** + * Set the acceptable content type when ingesting HTTP requests. Wildcards can be used, e.g. + * `application/*` or `*/*`. + */ +var HandlerDefinition.Configurator.acceptContentType: String? + get() { + return this.acceptContentType() + } + set(value) { + this.acceptContentType(value) + } + +/** + * Documentation as shown in the UI, Admin REST API, and the generated OpenAPI documentation of this + * handler. + */ +var HandlerDefinition.Configurator.documentation: String? + get() { + return this.documentation() + } + set(value) { + this.documentation(value) + } + +/** Handler metadata, as propagated in the Admin REST API. */ +var HandlerDefinition.Configurator.metadata: Map? + get() { + return this.metadata() + } + set(value) { + this.metadata(value) + } + +/** + * This timer guards against stalled invocations. Once it expires, Restate triggers a graceful + * termination by asking the invocation to suspend (which preserves intermediate progress). + * + * The [abortTimeout] is used to abort the invocation, in case it doesn't react to the request to + * suspend. + * + * This overrides the inactivity timeout set for the service and the default set in restate-server. + * + * *NOTE:* You can set this field only if you register this service against restate-server >= 1.4, + * otherwise the service discovery will fail. + */ +var HandlerDefinition.Configurator.inactivityTimeout: Duration? + get() { + return this.inactivityTimeout()?.toKotlinDuration() + } + set(value) { + this.inactivityTimeout(value?.toJavaDuration()) + } + +/** + * This timer guards against stalled invocations that are supposed to terminate. The abort timeout + * is started after the [inactivityTimeout] has expired and the invocation has been asked to + * gracefully terminate. Once the timer expires, it will abort the invocation. + * + * This timer potentially *interrupts* user code. If the user code needs longer to gracefully + * terminate, then this value needs to be set accordingly. + * + * This overrides the abort timeout set for the service and the default set in restate-server. + * + * *NOTE:* You can set this field only if you register this service against restate-server >= 1.4, + * otherwise the service discovery will fail. + */ +var HandlerDefinition.Configurator.abortTimeout: Duration? + get() { + return this.abortTimeout()?.toKotlinDuration() + } + set(value) { + this.abortTimeout(value?.toJavaDuration()) + } + +/** + * The retention duration of idempotent requests to this service. + * + * *NOTE:* You can set this field only if you register this service against restate-server >= 1.4, + * otherwise the service discovery will fail. + */ +var HandlerDefinition.Configurator.idempotencyRetention: Duration? + get() { + return this.idempotencyRetention()?.toKotlinDuration() + } + set(value) { + this.idempotencyRetention(value?.toJavaDuration()) + } + +/** + * The retention duration for this workflow handler. + * + * *NOTE:* You can set this field only if you register this service against restate-server >= 1.4, + * otherwise the service discovery will fail. + */ +var HandlerDefinition.Configurator.workflowRetention: Duration? + get() { + return this.workflowRetention()?.toKotlinDuration() + } + set(value) { + this.workflowRetention(value?.toJavaDuration()) + } + +/** + * The journal retention for invocations to this handler. + * + * In case the request has an idempotency key, the [idempotencyRetention] caps the journal retention + * time. + * + * *NOTE:* You can set this field only if you register this service against restate-server >= 1.4, + * otherwise the service discovery will fail. + */ +var HandlerDefinition.Configurator.journalRetention: Duration? + get() { + return this.journalRetention()?.toKotlinDuration() + } + set(value) { + this.journalRetention(value?.toJavaDuration()) + } + +/** + * When set to `true` this handler cannot be invoked from the restate-server HTTP and Kafka ingress, + * but only from other services. + * + * *NOTE:* You can set this field only if you register this service against restate-server >= 1.4, + * otherwise the service discovery will fail. + */ +var HandlerDefinition.Configurator.ingressPrivate: Boolean? + get() { + return this.ingressPrivate() + } + set(value) { + this.ingressPrivate(value) + } + +/** + * When set to `true`, lazy state will be enabled for all invocations to this handler. This is + * relevant only for workflows and virtual objects. + * + * *NOTE:* You can set this field only if you register this service against restate-server >= 1.4, + * otherwise the service discovery will fail. + */ +var HandlerDefinition.Configurator.enableLazyState: Boolean? + get() { + return this.enableLazyState() + } + set(value) { + this.enableLazyState(value) + } + +/** + * Retry policy used by Restate when invoking this handler. + * + *

NOTE: You can set this field only if you register this service against + * restate-server >= 1.5, otherwise the service discovery will fail. + * + * @see InvocationRetryPolicy + */ +var HandlerDefinition.Configurator.invocationRetryPolicy: InvocationRetryPolicy? + get() { + return this.invocationRetryPolicy() + } + set(value) { + this.invocationRetryPolicy(value) + } + +/** Initial delay before the first retry attempt. If unset, server defaults apply. */ +var InvocationRetryPolicy.Builder.initialInterval: Duration? + get() { + return this.initialInterval()?.toKotlinDuration() + } + set(value) { + this.initialInterval(value?.toJavaDuration()) + } + +/** Exponential backoff multiplier used to compute the next retry delay. */ +var InvocationRetryPolicy.Builder.exponentiationFactor: Double? + get() { + return this.exponentiationFactor() + } + set(value) { + this.exponentiationFactor(value) + } + +/** Upper bound for any computed retry delay. */ +var InvocationRetryPolicy.Builder.maxInterval: Duration? + get() { + return this.maxInterval()?.toKotlinDuration() + } + set(value) { + this.maxInterval(value?.toJavaDuration()) + } + +/** + * Maximum number of attempts before giving up retrying. + * + * The initial call counts as the first attempt; retries increment the count by 1. When giving up, + * the behavior defined with [onMaxAttempts] will be applied. + * + * @see InvocationRetryPolicy.OnMaxAttempts + */ +var InvocationRetryPolicy.Builder.maxAttempts: Int? + get() { + return this.maxAttempts() + } + set(value) { + this.maxAttempts(value) + } + +/** + * Behavior when reaching max attempts. + * + * @see InvocationRetryPolicy.OnMaxAttempts + */ +var InvocationRetryPolicy.Builder.onMaxAttempts: InvocationRetryPolicy.OnMaxAttempts? + get() { + return this.onMaxAttempts() + } + set(value) { + this.onMaxAttempts(value) + } + +/** [InvocationRetryPolicy] builder function. */ +fun invocationRetryPolicy(init: InvocationRetryPolicy.Builder.() -> Unit): InvocationRetryPolicy { + val builder = InvocationRetryPolicy.builder() + builder.init() + return builder.build() +} diff --git a/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/futures.kt b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/futures.kt new file mode 100644 index 000000000..917f930ce --- /dev/null +++ b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/futures.kt @@ -0,0 +1,250 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin + +import dev.restate.common.Output +import dev.restate.common.Slice +import dev.restate.sdk.common.TerminalException +import dev.restate.sdk.common.TimeoutException +import dev.restate.sdk.endpoint.definition.AsyncResult +import dev.restate.sdk.endpoint.definition.HandlerContext +import dev.restate.serde.Serde +import dev.restate.serde.TypeTag +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ExecutionException +import kotlin.time.Duration +import kotlin.time.toJavaDuration +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.currentCoroutineContext +import kotlinx.coroutines.future.await +import kotlinx.coroutines.launch + +internal abstract class BaseDurableFutureImpl : DurableFuture { + abstract fun asyncResult(): AsyncResult + + override val onAwait: SelectClause + get() = SelectClauseImpl(this) + + override suspend fun await(): T { + return asyncResult().poll().await() + } + + override suspend fun await(duration: Duration): T { + return withTimeout(duration).await() + } + + override suspend fun withTimeout(duration: Duration): DurableFuture { + return (DurableFuture.any( + this, + SingleDurableFutureImpl( + asyncResult().ctx().timer(duration.toJavaDuration(), null).await() + ), + ) as BaseDurableFutureImpl<*>) + .simpleMap { + if (it == 1) { + throw TimeoutException("Timed out waiting for durable future after $duration") + } + + try { + @Suppress("UNCHECKED_CAST") + return@simpleMap this.asyncResult().poll().getNow(null) as T + } catch (e: ExecutionException) { + throw e.cause ?: e // unwrap original cause from ExecutionException + } + } + } + + fun simpleMap(transform: (T) -> R): DurableFuture { + return SingleDurableFutureImpl( + this.asyncResult().map { CompletableFuture.completedFuture(transform(it)) } + ) + } + + override suspend fun map(transform: suspend (T) -> R): DurableFuture { + var ctx = currentCoroutineContext() + return SingleDurableFutureImpl( + this.asyncResult().map { t -> + val completableFuture = CompletableFuture() + CoroutineScope(ctx).launch { + val r: R + try { + r = transform(t) + } catch (throwable: Throwable) { + completableFuture.completeExceptionally(throwable) + return@launch + } + completableFuture.complete(r) + } + completableFuture + } + ) + } + + override suspend fun map( + transformSuccess: suspend (T) -> R, + transformFailure: suspend (TerminalException) -> R, + ): DurableFuture { + var ctx = currentCoroutineContext() + return SingleDurableFutureImpl( + this.asyncResult() + .map( + { t -> + val completableFuture = CompletableFuture() + CoroutineScope(ctx).launch { + val r: R + try { + r = transformSuccess(t) + } catch (throwable: Throwable) { + completableFuture.completeExceptionally(throwable) + return@launch + } + completableFuture.complete(r) + } + completableFuture + }, + { t -> + val completableFuture = CompletableFuture() + CoroutineScope(ctx).launch { + val r: R + try { + r = transformFailure(t) + } catch (throwable: Throwable) { + completableFuture.completeExceptionally(throwable) + return@launch + } + completableFuture.complete(r) + } + completableFuture + }, + ) + ) + } + + override suspend fun mapFailure(transform: suspend (TerminalException) -> T): DurableFuture { + var ctx = currentCoroutineContext() + return SingleDurableFutureImpl( + this.asyncResult().mapFailure { t -> + val completableFuture = CompletableFuture() + CoroutineScope(ctx).launch { + val newT: T + try { + newT = transform(t) + } catch (throwable: Throwable) { + completableFuture.completeExceptionally(throwable) + return@launch + } + completableFuture.complete(newT) + } + completableFuture + } + ) + } +} + +internal open class SingleDurableFutureImpl(private val asyncResult: AsyncResult) : + BaseDurableFutureImpl() { + override fun asyncResult(): AsyncResult { + return asyncResult + } +} + +internal fun wrapAllDurableFuture(durableFutures: List>): DurableFuture { + check(durableFutures.isNotEmpty()) { "The durable futures list should be non empty" } + val ctx = (durableFutures.get(0) as BaseDurableFutureImpl<*>).asyncResult().ctx() + return SingleDurableFutureImpl( + ctx.createAllAsyncResult( + durableFutures.map { (it as BaseDurableFutureImpl<*>).asyncResult() } + ) + ) + .simpleMap {} +} + +internal fun wrapAnyDurableFuture( + durableFutures: List> +): BaseDurableFutureImpl { + check(durableFutures.isNotEmpty()) { "The durable futures list should be non empty" } + val ctx = (durableFutures.get(0) as BaseDurableFutureImpl<*>).asyncResult().ctx() + return SingleDurableFutureImpl( + ctx.createAnyAsyncResult( + durableFutures.map { (it as BaseDurableFutureImpl<*>).asyncResult() } + ) + ) +} + +internal class CallDurableFutureImpl +internal constructor( + callAsyncResult: AsyncResult, + private val invocationIdAsyncResult: AsyncResult, +) : SingleDurableFutureImpl(callAsyncResult), CallDurableFuture { + override suspend fun invocationId(): String { + return invocationIdAsyncResult.poll().await() + } +} + +internal abstract class BaseInvocationHandle +internal constructor( + private val handlerContext: HandlerContext, + private val responseSerde: Serde, +) : InvocationHandle { + override suspend fun cancel() { + val ignored = handlerContext.cancelInvocation(invocationId()).await() + } + + override suspend fun attach(): DurableFuture = + SingleDurableFutureImpl( + handlerContext.attachInvocation(invocationId()).await().map { + CompletableFuture.completedFuture(responseSerde.deserialize(it)) + } + ) + + override suspend fun output(): Output = + SingleDurableFutureImpl(handlerContext.getInvocationOutput(invocationId()).await()) + .simpleMap { it.map { responseSerde.deserialize(it) } } + .await() +} + +internal class AwakeableImpl +internal constructor(asyncResult: AsyncResult, serde: Serde, override val id: String) : + SingleDurableFutureImpl( + asyncResult.map { CompletableFuture.completedFuture(serde.deserialize(it)) } + ), + Awakeable + +internal class AwakeableHandleImpl(val contextImpl: ContextImpl, val id: String) : AwakeableHandle { + override suspend fun resolve(typeTag: TypeTag, payload: T) { + contextImpl.handlerContext + .resolveAwakeable(id, contextImpl.resolveAndSerialize(typeTag, payload)) + .await() + } + + override suspend fun reject(reason: String) { + return + contextImpl.handlerContext.rejectAwakeable(id, TerminalException(reason)).await() + } +} + +internal class SelectClauseImpl(override val durableFuture: DurableFuture) : SelectClause + +@PublishedApi +internal class SelectImplementation : SelectBuilder { + + private val clauses: MutableList, suspend (Any?) -> R>> = + mutableListOf() + + @Suppress("UNCHECKED_CAST") + override fun SelectClause.invoke(block: suspend (T) -> R) { + clauses.add(this.durableFuture as BaseDurableFutureImpl<*> to block as suspend (Any?) -> R) + } + + suspend fun build(): DurableFuture { + return wrapAnyDurableFuture(clauses.map { it.first }).map { index -> + clauses[index].let { resolved -> resolved.first.await().let { resolved.second(it) } } + } + } +} diff --git a/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/internal/MalformedRestateServiceException.kt b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/internal/MalformedRestateServiceException.kt new file mode 100644 index 000000000..12587722d --- /dev/null +++ b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/internal/MalformedRestateServiceException.kt @@ -0,0 +1,22 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin.internal + +internal class MalformedRestateServiceException : Exception { + constructor( + serviceName: String, + message: String, + ) : super("Failed to instantiate Restate service '$serviceName'.\nReason: $message") + + constructor( + serviceName: String, + message: String, + cause: Throwable, + ) : super("Failed to instantiate Restate service '$serviceName'.\nReason: $message", cause) +} diff --git a/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt new file mode 100644 index 000000000..4c9815ec2 --- /dev/null +++ b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt @@ -0,0 +1,425 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin.internal + +import dev.restate.common.reflections.ReflectionUtils +import dev.restate.sdk.annotation.* +import dev.restate.sdk.endpoint.definition.* +import dev.restate.sdk.endpoint.definition.HandlerRunner +import dev.restate.sdk.kotlin.* +import dev.restate.serde.Serde +import dev.restate.serde.SerdeFactory +import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory +import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory.KtTypeTag +import dev.restate.serde.provider.DefaultSerdeFactoryProvider +import java.lang.reflect.InvocationTargetException +import java.lang.reflect.Modifier +import java.util.* +import kotlin.reflect.KClass +import kotlin.reflect.KFunction +import kotlin.reflect.KVisibility +import kotlin.reflect.full.callSuspend +import kotlin.reflect.full.findAnnotation +import kotlin.reflect.full.hasAnnotation +import kotlin.reflect.full.memberFunctions +import kotlin.reflect.full.valueParameters +import kotlin.reflect.jvm.javaMethod +import kotlin.reflect.jvm.jvmErasure + +internal class ReflectionServiceDefinitionFactory : ServiceDefinitionFactory { + @Volatile private var cachedDefaultSerdeFactory: SerdeFactory? = null + + override fun create( + serviceInstance: Any, + overrideHandlerOptions: HandlerRunner.Options?, + ): ServiceDefinition { + val handlerRunnerOptions: dev.restate.sdk.kotlin.HandlerRunner.Options? + if ( + overrideHandlerOptions == null || + overrideHandlerOptions is dev.restate.sdk.kotlin.HandlerRunner.Options + ) { + handlerRunnerOptions = overrideHandlerOptions + } else { + throw IllegalArgumentException( + "The provided options class MUST be instance of dev.restate.sdk.kotlin.HandlerRunner.Options, but was " + + overrideHandlerOptions.javaClass + ) + } + + val serviceClazz: Class<*> = serviceInstance.javaClass + + // The behavior of the reflections work as follows: + // * There is one class that has all the restate annotations. That being either the serviceClazz + // itself (concrete class) or some interface in the hierarchy. + // * Then there is the serviceInstance, which is where we call the methods themselves. + val restateAnnotatedClazz = ReflectionUtils.findRestateAnnotatedClass(serviceClazz) + val restateAnnotatedKotlinClazz = restateAnnotatedClazz.kotlin + + val hasServiceAnnotation = ReflectionUtils.hasServiceAnnotation(restateAnnotatedClazz) + val hasVirtualObjectAnnotation = + ReflectionUtils.hasVirtualObjectAnnotation(restateAnnotatedClazz) + val hasWorkflowAnnotation = ReflectionUtils.hasWorkflowAnnotation(restateAnnotatedClazz) + + val hasAnyAnnotation = + hasServiceAnnotation || hasVirtualObjectAnnotation || hasWorkflowAnnotation + if (!hasAnyAnnotation) { + throw MalformedRestateServiceException( + restateAnnotatedClazz.simpleName, + "A restate component MUST be annotated with " + + "exactly one annotation between @Service/@VirtualObject/@Workflow, no annotation was found", + ) + } + val hasExactlyOneAnnotation = + hasServiceAnnotation xor (hasVirtualObjectAnnotation xor hasWorkflowAnnotation) + + if (!hasExactlyOneAnnotation) { + throw MalformedRestateServiceException( + restateAnnotatedClazz.simpleName, + "A restate component MUST be annotated with " + + "exactly one annotation between @Service/@VirtualObject/@Workflow, more than one annotation found", + ) + } + + val serviceName = ReflectionUtils.extractServiceName(restateAnnotatedClazz) + val serviceType = + if (hasServiceAnnotation) ServiceType.SERVICE + else if (hasVirtualObjectAnnotation) ServiceType.VIRTUAL_OBJECT else ServiceType.WORKFLOW + val serdeFactory: SerdeFactory = resolveSerdeFactory(restateAnnotatedKotlinClazz) + + val kFunctions = + restateAnnotatedKotlinClazz.memberFunctions.filter { + it.hasAnnotation() || + it.hasAnnotation() || + it.hasAnnotation() || + it.hasAnnotation() + } + + if (kFunctions.isEmpty()) { + throw MalformedRestateServiceException(serviceName, "No @Handler method found") + } + return ServiceDefinition.of( + serviceName, + serviceType, + kFunctions + .map { + this.createHandlerDefinition( + serviceInstance, + it, + serviceName, + serviceType, + serdeFactory, + handlerRunnerOptions, + ) + } + .toList(), + ) + } + + private fun createHandlerDefinition( + serviceInstance: Any, + kFunction: KFunction<*>, + serviceName: String, + serviceType: ServiceType, + serdeFactory: SerdeFactory, + overrideHandlerOptions: dev.restate.sdk.kotlin.HandlerRunner.Options?, + ): HandlerDefinition<*, *> { + val handlerInfo: ReflectionUtils.HandlerInfo = + ReflectionUtils.mustHaveHandlerAnnotation(kFunction.javaMethod!!) + val handlerName: String? = handlerInfo.name + + // Check if this is a Kotlin suspend function + validateKFunction(kFunction, serviceName) + + val parameters = kFunction.valueParameters + + // Check for old-style context parameter + if ( + (parameters.size == 1 || parameters.size == 2) && + (parameters[0] == Context::class.java || + parameters[0] == SharedObjectContext::class.java || + parameters[0] == ObjectContext::class.java || + parameters[0] == WorkflowContext::class.java || + parameters[0] == SharedWorkflowContext::class.java) + ) { + val ctxTypeName = parameters[0].type.toString() + val returnTypeName = kFunction.returnType.toString() + val actualSignature = + if (parameters.size == 1) "ctx: $ctxTypeName" + else "ctx: $ctxTypeName, input: ${parameters[1].type}" + val expectedSignature = if (parameters.isEmpty()) "" else "input: ${parameters[1].type}" + throw MalformedRestateServiceException( + serviceName, + """ + The service is being loaded with the new Reflection based API, but handler '${handlerName}' contains $ctxTypeName as first parameter. Suggestions: + * If you want to use the new Reflection based API, remove $ctxTypeName from the method definition and use the functions from dev.restate.sdk.kotlin inside the handler: + - suspend fun ${handlerName}(${actualSignature}): $returnTypeName { + - // code + - } + Replace with: + + suspend fun ${handlerName}(${expectedSignature}): $returnTypeName { + + // Use functions from dev.restate.sdk.kotlin.* + + // code + + } + * If you''re still using the KSP based API, make sure the ServiceDefinitionFactory class was correctly generated. + """ + .trimIndent(), + ) + } + + if (parameters.size > 1) { + throw MalformedRestateServiceException( + serviceName, + "More than one parameter found in method ${kFunction.name}. Only zero or one parameter is supported.", + ) + } + + if (serviceType == ServiceType.SERVICE && handlerInfo.shared) { + throw MalformedRestateServiceException( + serviceName, + "@Shared is only supported on virtual objects and workflow handlers", + ) + } + val handlerType = + if (handlerInfo.shared) HandlerType.SHARED + else if (serviceType == ServiceType.VIRTUAL_OBJECT) HandlerType.EXCLUSIVE + else if (serviceType == ServiceType.WORKFLOW) HandlerType.WORKFLOW else null + + val inputSerde = + resolveInputSerde( + kFunction, + serdeFactory, + serviceName, + ) + val outputSerde = resolveOutputSerde(kFunction, serdeFactory, serviceName) + + val runner = + createSuspendHandlerRunner( + serviceInstance, + kFunction, + parameters.size, + serdeFactory, + overrideHandlerOptions, + ) + + var handlerDefinition: HandlerDefinition = + HandlerDefinition.of(handlerName, handlerType, inputSerde, outputSerde, runner) + + // Look for the accept annotation + if (parameters.isNotEmpty()) { + val acceptAnnotation: Accept? = parameters[0].findAnnotation() + if (acceptAnnotation != null) { + handlerDefinition = handlerDefinition.withAcceptContentType(acceptAnnotation.value) + } + } + + return handlerDefinition + } + + private fun createSuspendHandlerRunner( + serviceInstance: Any, + kFunction: KFunction<*>, + parameterCount: Int, + serdeFactory: SerdeFactory, + overrideHandlerOptions: dev.restate.sdk.kotlin.HandlerRunner.Options?, + ): dev.restate.sdk.kotlin.HandlerRunner { + return dev.restate.sdk.kotlin.HandlerRunner.of( + serdeFactory, + overrideHandlerOptions ?: dev.restate.sdk.kotlin.HandlerRunner.Options.DEFAULT, + ) { _, input -> + try { + if (parameterCount == 0) { + kFunction.callSuspend(serviceInstance) + } else { + kFunction.callSuspend(serviceInstance, input) + } + } catch (t: InvocationTargetException) { + throw t.cause!! + } + } + } + + @Suppress("UNCHECKED_CAST") + private fun resolveInputSerde( + kFunction: KFunction<*>, + serdeFactory: SerdeFactory, + serviceName: String, + ): Serde { + if (kFunction.valueParameters.isEmpty()) { + return KotlinSerializationSerdeFactory.UNIT as Serde + } + + val parameter = kFunction.valueParameters[0] + + val rawAnnotation = parameter.findAnnotation() + val jsonAnnotation = parameter.findAnnotation() + + // Validate annotations + if (rawAnnotation != null && jsonAnnotation != null) { + throw MalformedRestateServiceException( + serviceName, + "Parameter in method ${kFunction.name} cannot be annotated with both @Raw and @Json", + ) + } + + if (rawAnnotation != null) { + // Validate parameter type is byte[] + if (parameter.type.jvmErasure != ByteArray::class) { + throw MalformedRestateServiceException( + serviceName, + "Parameter annotated with @Raw in method ${kFunction.name} MUST be of type ByteArray, was ${parameter.type}", + ) + } + var serde: Serde = Serde.RAW as Serde + // Apply content type if not default + if (rawAnnotation.contentType != "application/octet-stream") { + serde = Serde.withContentType(rawAnnotation.contentType, serde) + } + return serde + } + + // Use serdeFactory to create serde + var serde = + serdeFactory.create(KtTypeTag(parameter.type.jvmErasure, parameter.type)) + as Serde + + // Apply custom content-type from @Json if present + if (jsonAnnotation != null && jsonAnnotation.contentType != "application/json") { + serde = Serde.withContentType(jsonAnnotation.contentType, serde) + } + + return serde + } + + @Suppress("UNCHECKED_CAST") + private fun resolveOutputSerde( + kFunction: KFunction<*>, + serdeFactory: SerdeFactory, + serviceName: String, + ): Serde { + val outputType = kFunction.returnType + + // Handle Unit type (Kotlin void equivalent) + if (outputType == Void.TYPE || outputType.jvmErasure == Unit::class) { + return KotlinSerializationSerdeFactory.UNIT as Serde + } + + val rawAnnotation = kFunction.findAnnotation() + val jsonAnnotation = kFunction.findAnnotation() + + // Validate annotations + if (rawAnnotation != null && jsonAnnotation != null) { + throw MalformedRestateServiceException( + serviceName, + "Method ${kFunction.name} cannot be annotated with both @Raw and @Json", + ) + } + + if (rawAnnotation != null) { + // Validate return type is byte[] + if (outputType.jvmErasure != ByteArray::class) { + throw MalformedRestateServiceException( + serviceName, + "Method ${kFunction.name} annotated with @Raw MUST return byte[], was $outputType", + ) + } + var serde: Serde = Serde.RAW as Serde + // Apply content type if not default + if (rawAnnotation.contentType != "application/octet-stream") { + serde = Serde.withContentType(rawAnnotation.contentType, serde) + } + return serde + } + + // Use serdeFactory to create serde + var serde = + serdeFactory.create(KtTypeTag(outputType.jvmErasure, outputType)) as Serde + + // Apply custom content-type from @Json if present + if (jsonAnnotation != null && jsonAnnotation.contentType != "application/json") { + serde = Serde.withContentType(jsonAnnotation.contentType, serde) + } + + return serde + } + + private fun resolveSerdeFactory(serviceClazz: KClass<*>): SerdeFactory { + // Check for CustomSerdeFactory annotation + val customSerdeFactoryAnnotation = serviceClazz.findAnnotation() + + if (customSerdeFactoryAnnotation != null) { + try { + return customSerdeFactoryAnnotation.value.java.getDeclaredConstructor().newInstance() + } catch (e: Exception) { + throw MalformedRestateServiceException( + serviceClazz.simpleName!!, + "Failed to instantiate custom SerdeFactory: ${customSerdeFactoryAnnotation.value.java.name}", + e, + ) + } + } + + // Try DefaultSerdeFactoryProvider -> if there's one, it's an easy pick! + if (this.cachedDefaultSerdeFactory != null) { + return this.cachedDefaultSerdeFactory!! + } + + val loadedFactories: MutableList?> = + ServiceLoader.load(DefaultSerdeFactoryProvider::class.java).stream().toList() + if (loadedFactories.size == 1) { + this.cachedDefaultSerdeFactory = loadedFactories[0]!!.get()!!.create() + return this.cachedDefaultSerdeFactory!! + } + + // Load kotlinx serde factory + try { + val jacksonSerdeFactoryClass = + Class.forName("dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory") + val defaultInstance = jacksonSerdeFactoryClass.getConstructor().newInstance() + this.cachedDefaultSerdeFactory = defaultInstance as SerdeFactory? + return this.cachedDefaultSerdeFactory!! + } catch (e: Exception) { + throw MalformedRestateServiceException( + serviceClazz.simpleName!!, + "Failed to load KotlinSerializationSerdeFactory for Kotlin service. " + + "Make sure sdk-serde-kotlinx is on the classpath.", + e, + ) + } + } + + override fun supports(serviceObject: Any?): Boolean { + return serviceObject?.javaClass?.let { ReflectionUtils.isKotlinClass(it) } ?: false + } + + override fun priority(): Int { + // Run before last - after code-generated factories, before java + return ServiceDefinitionFactory.LOWEST_PRIORITY - 1 + } + + private fun validateKFunction(kFunction: KFunction<*>, serviceName: String) { + if (!kFunction.isSuspend) { + throw MalformedRestateServiceException( + serviceName, + "Method '${kFunction.name}' is not a suspend function, this is not supported.", + ) + } + if (kFunction.visibility != KVisibility.PUBLIC) { + throw MalformedRestateServiceException( + serviceName, + "Method '${kFunction.name}' is not public.", + ) + } + if (Modifier.isStatic(kFunction.javaMethod!!.modifiers)) { + throw MalformedRestateServiceException( + serviceName, + "Method '" + kFunction.name + "' is static, cannot be used as Restate handler", + ) + } + } +} diff --git a/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/internal/RestateContextElement.kt b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/internal/RestateContextElement.kt new file mode 100644 index 000000000..e421c6a22 --- /dev/null +++ b/sdk-api-kotlin/bin/main/dev/restate/sdk/kotlin/internal/RestateContextElement.kt @@ -0,0 +1,24 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin.internal + +import dev.restate.sdk.kotlin.Context +import kotlin.coroutines.AbstractCoroutineContextElement +import kotlin.coroutines.CoroutineContext + +/** + * Coroutine context element that holds the Restate [Context]. + * + * This element is added to the coroutine context when a handler is invoked, allowing free-floating + * API functions like `context()`, `run()`, etc. to access the current context from within suspend + * functions. + */ +internal class RestateContextElement(val ctx: Context) : AbstractCoroutineContextElement(Key) { + companion object Key : CoroutineContext.Key +} diff --git a/sdk-api/bin/main/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory b/sdk-api/bin/main/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory new file mode 100644 index 000000000..9bb86d5ac --- /dev/null +++ b/sdk-api/bin/main/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory @@ -0,0 +1 @@ +dev.restate.sdk.internal.ReflectionServiceDefinitionFactory \ No newline at end of file diff --git a/sdk-api/src/main/java/dev/restate/sdk/Restate.java b/sdk-api/src/main/java/dev/restate/sdk/Restate.java index acfee597f..29d3ace17 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Restate.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Restate.java @@ -19,7 +19,12 @@ import dev.restate.sdk.annotation.Service; import dev.restate.sdk.annotation.VirtualObject; import dev.restate.sdk.annotation.Workflow; -import dev.restate.sdk.common.*; +import dev.restate.sdk.common.AbortedExecutionException; +import dev.restate.sdk.common.DurablePromiseKey; +import dev.restate.sdk.common.HandlerRequest; +import dev.restate.sdk.common.RetryPolicy; +import dev.restate.sdk.common.StateKey; +import dev.restate.sdk.common.TerminalException; import dev.restate.serde.Serde; import dev.restate.serde.TypeTag; import java.time.Duration; @@ -27,6 +32,7 @@ import java.util.Collection; import java.util.Optional; import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; /** * This class exposes the Restate functionalities to Restate services using the reflection-based @@ -445,26 +451,29 @@ public static AwakeableHandle awakeableHandle(String id) { @org.jetbrains.annotations.ApiStatus.Experimental public static SVC service(Class clazz) { ReflectionUtils.mustHaveServiceAnnotation(clazz); - String serviceName = ReflectionUtils.extractServiceName(clazz); - return ProxySupport.createProxy( - clazz, - invocation -> { - var methodInfo = MethodInfo.fromMethod(invocation.getMethod()); + return service(clazz, ReflectionUtils.extractServiceName(clazz)); + } - //noinspection unchecked - return Context.current() - .call( - Request.of( - Target.virtualObject(serviceName, null, methodInfo.getHandlerName()), - (TypeTag) methodInfo.getInputType(), - (TypeTag) methodInfo.getOutputType(), - invocation.getArguments().length == 0 ? null : invocation.getArguments()[0])) - .await(); - }); + /** + * EXPERIMENTAL API: Simple API to invoke a Restate service. + * + *

Like {@link #service(Class)}, but specifying the service name. + * + *

Use this method when you want to use a common interface for multiple service + * implementations, where the service name is not known at compile time or is not defined in the + * interface. + * + * @param clazz the service class or interface + * @param serviceName the name of the service to invoke + * @return a proxy client to invoke the service + */ + @org.jetbrains.annotations.ApiStatus.Experimental + public static SVC service(Class clazz, String serviceName) { + return createProxy(clazz, serviceName, null); } /** - * EXPERIMENTAL API: Advanced API to invoke a Restate service with full control. + * Ok EXPERIMENTAL API: Advanced API to invoke a Restate service with full control. * *

Create a handle that provides advanced invocation capabilities including: * @@ -496,7 +505,25 @@ public static SVC service(Class clazz) { @org.jetbrains.annotations.ApiStatus.Experimental public static ServiceHandle serviceHandle(Class clazz) { ReflectionUtils.mustHaveServiceAnnotation(clazz); - return new ServiceHandleImpl<>(clazz, null); + return serviceHandle(clazz, ReflectionUtils.extractServiceName(clazz)); + } + + /** + * EXPERIMENTAL API: Advanced API to invoke a Restate service with full control. + * + *

Like {@link #serviceHandle(Class)}, but specifying the service name. + * + *

Use this method when you want to use a common interface for multiple service + * implementations, where the service name is not known at compile time or is not defined in the + * interface. + * + * @param clazz the service class or interface + * @param serviceName the name of the service to invoke + * @return a handle to invoke the service with advanced options + */ + @org.jetbrains.annotations.ApiStatus.Experimental + public static ServiceHandle serviceHandle(Class clazz, String serviceName) { + return new ServiceHandleImpl<>(clazz, serviceName, null); } /** @@ -521,22 +548,7 @@ public static ServiceHandle serviceHandle(Class clazz) { @org.jetbrains.annotations.ApiStatus.Experimental public static SVC virtualObject(Class clazz, String key) { ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz); - String serviceName = ReflectionUtils.extractServiceName(clazz); - return ProxySupport.createProxy( - clazz, - invocation -> { - var methodInfo = MethodInfo.fromMethod(invocation.getMethod()); - - //noinspection unchecked - return Context.current() - .call( - Request.of( - Target.virtualObject(serviceName, key, methodInfo.getHandlerName()), - (TypeTag) methodInfo.getInputType(), - (TypeTag) methodInfo.getOutputType(), - invocation.getArguments().length == 0 ? null : invocation.getArguments()[0])) - .await(); - }); + return createProxy(clazz, ReflectionUtils.extractServiceName(clazz), key); } /** @@ -598,22 +610,7 @@ public static ServiceHandle virtualObjectHandle(Class clazz, Str @org.jetbrains.annotations.ApiStatus.Experimental public static SVC workflow(Class clazz, String key) { ReflectionUtils.mustHaveWorkflowAnnotation(clazz); - String serviceName = ReflectionUtils.extractServiceName(clazz); - return ProxySupport.createProxy( - clazz, - invocation -> { - var methodInfo = MethodInfo.fromMethod(invocation.getMethod()); - - //noinspection unchecked - return Context.current() - .call( - Request.of( - Target.virtualObject(serviceName, key, methodInfo.getHandlerName()), - (TypeTag) methodInfo.getInputType(), - (TypeTag) methodInfo.getOutputType(), - invocation.getArguments().length == 0 ? null : invocation.getArguments()[0])) - .await(); - }); + return createProxy(clazz, ReflectionUtils.extractServiceName(clazz), key); } /** @@ -653,6 +650,24 @@ public static ServiceHandle workflowHandle(Class clazz, String k return new ServiceHandleImpl<>(clazz, key); } + private static SVC createProxy(Class clazz, String serviceName, @Nullable String key) { + return ProxySupport.createProxy( + clazz, + invocation -> { + var methodInfo = MethodInfo.fromMethod(invocation.getMethod()); + + //noinspection unchecked + return Context.current() + .call( + Request.of( + Target.virtualObject(serviceName, key, methodInfo.getHandlerName()), + (TypeTag) methodInfo.getInputType(), + (TypeTag) methodInfo.getOutputType(), + invocation.getArguments().length == 0 ? null : invocation.getArguments()[0])) + .await(); + }); + } + /** EXPERIMENTAL API: Interface to interact with this Virtual Object/Workflow state. */ @org.jetbrains.annotations.ApiStatus.Experimental public interface State { diff --git a/sdk-api/src/main/java/dev/restate/sdk/ServiceHandleImpl.java b/sdk-api/src/main/java/dev/restate/sdk/ServiceHandleImpl.java index f4d8eccdd..c17505407 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/ServiceHandleImpl.java +++ b/sdk-api/src/main/java/dev/restate/sdk/ServiceHandleImpl.java @@ -11,7 +11,9 @@ import static dev.restate.common.reflections.RestateUtils.toRequest; import dev.restate.common.InvocationOptions; -import dev.restate.common.reflections.*; +import dev.restate.common.reflections.MethodInfo; +import dev.restate.common.reflections.MethodInfoCollector; +import dev.restate.common.reflections.ReflectionUtils; import dev.restate.serde.Serde; import dev.restate.serde.TypeTag; import java.time.Duration; @@ -31,8 +33,12 @@ final class ServiceHandleImpl implements ServiceHandle { private MethodInfoCollector methodInfoCollector; ServiceHandleImpl(Class clazz, @Nullable String key) { + this(clazz, ReflectionUtils.extractServiceName(clazz), key); + } + + ServiceHandleImpl(Class clazz, String serviceName, @Nullable String key) { this.clazz = clazz; - this.serviceName = ReflectionUtils.extractServiceName(clazz); + this.serviceName = serviceName; this.key = key; } diff --git a/sdk-core/bin/main/META-INF/services/org.apache.logging.log4j.core.util.ContextDataProvider b/sdk-core/bin/main/META-INF/services/org.apache.logging.log4j.core.util.ContextDataProvider new file mode 100644 index 000000000..466fb9027 --- /dev/null +++ b/sdk-core/bin/main/META-INF/services/org.apache.logging.log4j.core.util.ContextDataProvider @@ -0,0 +1 @@ +dev.restate.sdk.core.RestateContextDataProvider \ No newline at end of file diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/AsyncResultTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/AsyncResultTest.kt new file mode 100644 index 000000000..bef5b070f --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/AsyncResultTest.kt @@ -0,0 +1,124 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.sdk.common.StateKey +import dev.restate.sdk.common.TimeoutException +import dev.restate.sdk.core.AsyncResultTestSuite +import dev.restate.sdk.core.TestDefinitions.* +import dev.restate.sdk.core.TestSerdes +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.callGreeterGreetService +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForVirtualObject +import dev.restate.sdk.kotlin.* +import java.util.stream.Stream +import kotlin.time.Duration.Companion.days + +class AsyncResultTest : AsyncResultTestSuite() { + override fun reverseAwaitOrder(): TestInvocationBuilder = + testDefinitionForVirtualObject("ReverseAwaitOrder") { ctx, _: Unit -> + val a1: DurableFuture = callGreeterGreetService(ctx, "Francesco") + val a2: DurableFuture = callGreeterGreetService(ctx, "Till") + + val a2Res: String = a2.await() + ctx.set(StateKey.of("A2", TestSerdes.STRING), a2Res) + + val a1Res: String = a1.await() + return@testDefinitionForVirtualObject "$a1Res-$a2Res" + } + + override fun awaitTwiceTheSameAwaitable(): TestInvocationBuilder = + testDefinitionForVirtualObject("AwaitTwiceTheSameAwaitable") { ctx, _: Unit -> + val a = callGreeterGreetService(ctx, "Francesco") + return@testDefinitionForVirtualObject "${a.await()}-${a.await()}" + } + + override fun awaitAll(): TestInvocationBuilder = + testDefinitionForVirtualObject("AwaitAll") { ctx, _: Unit -> + val a1 = callGreeterGreetService(ctx, "Francesco") + val a2 = callGreeterGreetService(ctx, "Till") + + return@testDefinitionForVirtualObject listOf(a1, a2) + .awaitAll() + .joinToString(separator = "-") + } + + override fun awaitAny(): TestInvocationBuilder = + testDefinitionForVirtualObject("AwaitAny") { ctx, _: Unit -> + val a1 = callGreeterGreetService(ctx, "Francesco") + val a2 = callGreeterGreetService(ctx, "Till") + + return@testDefinitionForVirtualObject DurableFuture.any(a1, a2) + .map { it -> if (it == 0) a1.await() else a2.await() } + .await() + } + + private fun awaitSelect(): TestInvocationBuilder = + testDefinitionForVirtualObject("AwaitSelect") { ctx, _: Unit -> + val a1 = callGreeterGreetService(ctx, "Francesco") + val a2 = callGreeterGreetService(ctx, "Till") + return@testDefinitionForVirtualObject select { + a1.onAwait { it } + a2.onAwait { it } + } + .await() + } + + override fun combineAnyWithAll(): TestInvocationBuilder = + testDefinitionForVirtualObject("CombineAnyWithAll") { ctx, _: Unit -> + val a1 = ctx.awakeable(TestSerdes.STRING) + val a2 = ctx.awakeable(TestSerdes.STRING) + val a3 = ctx.awakeable(TestSerdes.STRING) + val a4 = ctx.awakeable(TestSerdes.STRING) + + val a12 = DurableFuture.any(a1, a2).map { if (it == 0) a1.await() else a2.await() } + val a23 = DurableFuture.any(a2, a3).map { if (it == 0) a2.await() else a3.await() } + val a34 = DurableFuture.any(a3, a4).map { if (it == 0) a3.await() else a4.await() } + DurableFuture.all(a12, a23, a34).await() + + return@testDefinitionForVirtualObject a12.await() + a23.await() + a34.await() + } + + override fun awaitAnyIndex(): TestInvocationBuilder = + testDefinitionForVirtualObject("AwaitAnyIndex") { ctx, _: Unit -> + val a1 = ctx.awakeable(TestSerdes.STRING) + val a2 = ctx.awakeable(TestSerdes.STRING) + val a3 = ctx.awakeable(TestSerdes.STRING) + val a4 = ctx.awakeable(TestSerdes.STRING) + + return@testDefinitionForVirtualObject DurableFuture.any(a1, DurableFuture.all(a2, a3), a4) + .await() + .toString() + } + + override fun awaitOnAlreadyResolvedAwaitables(): TestInvocationBuilder = + testDefinitionForVirtualObject("AwaitOnAlreadyResolvedAwaitables") { ctx, _: Unit -> + val a1 = ctx.awakeable(TestSerdes.STRING) + val a2 = ctx.awakeable(TestSerdes.STRING) + val a12 = DurableFuture.all(a1, a2) + val a12and1 = DurableFuture.all(a12, a1) + val a121and12 = DurableFuture.all(a12and1, a12) + a12and1.await() + a121and12.await() + + return@testDefinitionForVirtualObject a1.await() + a2.await() + } + + override fun awaitWithTimeout(): TestInvocationBuilder = + testDefinitionForVirtualObject("AwaitWithTimeout") { ctx, _: Unit -> + val a1 = callGreeterGreetService(ctx, "Francesco") + return@testDefinitionForVirtualObject try { + a1.await(1.days) + } catch (_: TimeoutException) { + "timeout" + } + } + + override fun definitions(): Stream = + Stream.concat(super.definitions(), super.anyTestDefinitions { awaitSelect() }) +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/AwakeableIdTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/AwakeableIdTest.kt new file mode 100644 index 000000000..2ab5398ec --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/AwakeableIdTest.kt @@ -0,0 +1,23 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.sdk.core.AwakeableIdTestSuite +import dev.restate.sdk.core.TestDefinitions +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService +import dev.restate.sdk.kotlin.* + +class AwakeableIdTest : AwakeableIdTestSuite() { + + override fun returnAwakeableId(): TestDefinitions.TestInvocationBuilder = + testDefinitionForService("ReturnAwakeableId") { ctx, _: Unit -> + val awakeable: Awakeable = ctx.awakeable() + awakeable.id + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/CallTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/CallTest.kt new file mode 100644 index 000000000..9b21532c5 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/CallTest.kt @@ -0,0 +1,40 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.common.Request +import dev.restate.common.Slice +import dev.restate.common.Target +import dev.restate.sdk.core.CallTestSuite +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService +import dev.restate.serde.Serde + +class CallTest : CallTestSuite() { + + override fun oneWayCall( + target: Target, + idempotencyKey: String, + headers: Map, + body: Slice, + ) = + testDefinitionForService("OneWayCall") { ctx, _: Unit -> + val ignored = + ctx.send( + Request.of(target, Serde.SLICE, Serde.RAW, body) + .headers(headers) + .idempotencyKey(idempotencyKey) + ) + } + + override fun implicitCancellation(target: Target, body: Slice) = + testDefinitionForService("ImplicitCancellation") { ctx, _: Unit -> + val ignored = + ctx.call(Request.of(target, Serde.SLICE, Serde.RAW, body)).await() + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/CodegenDiscoveryTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/CodegenDiscoveryTest.kt new file mode 100644 index 000000000..b2783f592 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/CodegenDiscoveryTest.kt @@ -0,0 +1,94 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.sdk.core.AssertUtils.assertThatDiscovery +import dev.restate.sdk.core.generated.manifest.Handler +import dev.restate.sdk.core.generated.manifest.Input +import dev.restate.sdk.core.generated.manifest.Output +import dev.restate.sdk.core.generated.manifest.Service +import dev.restate.sdk.kotlin.endpoint.* +import org.assertj.core.api.Assertions +import org.assertj.core.api.InstanceOfAssertFactories.type +import org.junit.jupiter.api.Test + +class CodegenDiscoveryTest { + + @Test + fun checkCustomInputContentType() { + assertThatDiscovery(CodegenTest.RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawInputWithCustomCt") + .extracting({ it.input }, type(Input::class.java)) + .extracting { it.contentType } + .isEqualTo("application/vnd.my.custom") + } + + @Test + fun checkCustomInputAcceptContentType() { + assertThatDiscovery(CodegenTest.RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawInputWithCustomAccept") + .extracting({ it.input }, type(Input::class.java)) + .extracting { it.contentType } + .isEqualTo("application/*") + } + + @Test + fun checkCustomOutputContentType() { + assertThatDiscovery(CodegenTest.RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawOutputWithCustomCT") + .extracting({ it.output }, type(Output::class.java)) + .extracting { it.contentType } + .isEqualTo("application/vnd.my.custom") + } + + @Test + fun explicitNames() { + assertThatDiscovery( + object : GreeterWithExplicitName { + override fun greet(context: dev.restate.sdk.kotlin.Context, request: String): String { + TODO("Not yet implemented") + } + } + ) + .extractingService("MyExplicitName") + .extractingHandler("my_greeter") + Assertions.assertThat(GreeterWithExplicitNameHandlers.Metadata.SERVICE_NAME) + .isEqualTo("MyExplicitName") + } + + @Test + fun workflowType() { + assertThatDiscovery(CodegenTest.MyWorkflow()) + .extractingService("MyWorkflow") + .returns(Service.Ty.WORKFLOW) { obj -> obj.ty } + .extractingHandler("run") + .returns(Handler.Ty.WORKFLOW) { obj -> obj.ty } + } + + @Test + fun usingTransformer() { + assertThatDiscovery( + endpoint { + bind(CodegenTest.RawInputOutput()) { + it.documentation = "My service documentation" + it.configureHandler("rawInputWithCustomCt") { + it.documentation = "My handler documentation" + } + } + } + ) + .extractingService("RawInputOutput") + .returns("My service documentation", Service::getDocumentation) + .extractingHandler("rawInputWithCustomCt") + .returns("My handler documentation", Handler::getDocumentation) + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/CodegenTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/CodegenTest.kt new file mode 100644 index 000000000..9ee50f076 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/CodegenTest.kt @@ -0,0 +1,447 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.common.Slice +import dev.restate.common.Target +import dev.restate.sdk.annotation.* +import dev.restate.sdk.core.TestDefinitions +import dev.restate.sdk.core.TestDefinitions.TestDefinition +import dev.restate.sdk.core.TestDefinitions.testInvocation +import dev.restate.sdk.core.TestSerdes +import dev.restate.sdk.core.statemachine.ProtoUtils.* +import dev.restate.sdk.kotlin.* +import dev.restate.serde.Serde +import dev.restate.serde.SerdeFactory +import dev.restate.serde.TypeRef +import dev.restate.serde.TypeTag +import dev.restate.serde.kotlinx.* +import java.util.stream.Stream +import kotlinx.serialization.Serializable + +class CodegenTest : TestDefinitions.TestSuite { + @Service + class ServiceGreeter { + @Handler + suspend fun greet(context: Context, request: String): String { + return request + } + } + + @VirtualObject + class ObjectGreeter { + @Exclusive + suspend fun greet(context: ObjectContext, request: String): String { + return request + } + + @Handler + @Shared + suspend fun sharedGreet(context: SharedObjectContext, request: String): String { + return request + } + } + + @VirtualObject + class NestedDataClass { + @Serializable data class Input(val a: String) + + @Serializable data class Output(val a: String) + + @Exclusive + suspend fun greet(context: ObjectContext, request: Input): Output { + return Output(request.a) + } + + @Exclusive + suspend fun complexType( + context: ObjectContext, + request: Map>, + ): Map> { + return mapOf() + } + } + + @VirtualObject + interface GreeterInterface { + @Exclusive suspend fun greet(context: ObjectContext, request: String): String + } + + private class ObjectGreeterImplementedFromInterface : GreeterInterface { + override suspend fun greet(context: ObjectContext, request: String): String { + return request + } + } + + @Service + @Name("Empty") + class Empty { + @Handler + suspend fun emptyInput(context: Context): String { + val client = CodegenTestEmptyClient.fromContext(context) + return client.emptyInput().await() + } + + @Handler + suspend fun emptyOutput(context: Context, request: String) { + val client = CodegenTestEmptyClient.fromContext(context) + client.emptyOutput(request).await() + } + + @Handler + suspend fun emptyInputOutput(context: Context) { + val client = CodegenTestEmptyClient.fromContext(context) + client.emptyInputOutput().await() + } + } + + @Service + @Name("PrimitiveTypes") + class PrimitiveTypes { + @Handler + suspend fun primitiveOutput(context: Context): Int { + val client = CodegenTestPrimitiveTypesClient.fromContext(context) + return client.primitiveOutput().await() + } + + @Handler + suspend fun primitiveInput(context: Context, input: Int) { + val client = CodegenTestPrimitiveTypesClient.fromContext(context) + client.primitiveInput(input).await() + } + } + + @VirtualObject + class CornerCases { + @Exclusive + suspend fun send(context: ObjectContext, request: String): String { + // Just needs to compile + return CodegenTestCornerCasesClient.fromContext(context, request)._send("my_send").await() + } + + @Exclusive + suspend fun returnNull(context: ObjectContext, request: String?): String? { + return CodegenTestCornerCasesClient.fromContext(context, context.key()) + .returnNull(request) {} + .await() + } + + @Exclusive + suspend fun badReturnTypeInferred(context: ObjectContext): Unit { + CodegenTestCornerCasesClient.fromContext(context, context.key()) + .send() + .badReturnTypeInferred() + } + } + + @Workflow + class WorkflowCornerCases { + @Workflow + fun process(context: WorkflowContext, request: String): String { + return "" + } + + @Shared + suspend fun submit(context: SharedWorkflowContext, request: String): String { + // Just needs to compile + val ignored: String = + CodegenTestWorkflowCornerCasesClient.connect("invalid", request)._submit("my_send") + CodegenTestWorkflowCornerCasesClient.connect("invalid", request).submit("my_send") + return CodegenTestWorkflowCornerCasesClient.connect("invalid", request) + .workflowHandle() + .output + .response() + .value + } + } + + @Service + @Name("RawInputOutput") + class RawInputOutput { + @Handler + @Raw + suspend fun rawOutput(context: Context): ByteArray { + val client: CodegenTestRawInputOutputClient.ContextClient = + CodegenTestRawInputOutputClient.fromContext(context) + return client.rawOutput().await() + } + + @Handler + @Raw(contentType = "application/vnd.my.custom") + suspend fun rawOutputWithCustomCT(context: Context): ByteArray { + val client: CodegenTestRawInputOutputClient.ContextClient = + CodegenTestRawInputOutputClient.fromContext(context) + return client.rawOutputWithCustomCT().await() + } + + @Handler + suspend fun rawInput(context: Context, @Raw input: ByteArray) { + val client: CodegenTestRawInputOutputClient.ContextClient = + CodegenTestRawInputOutputClient.fromContext(context) + client.rawInput(input).await() + } + + @Handler + suspend fun rawInputWithCustomCt( + context: Context, + @Raw(contentType = "application/vnd.my.custom") input: ByteArray, + ) { + val client: CodegenTestRawInputOutputClient.ContextClient = + CodegenTestRawInputOutputClient.fromContext(context) + client.rawInputWithCustomCt(input).await() + } + + @Handler + suspend fun rawInputWithCustomAccept( + context: Context, + @Accept("application/*") @Raw(contentType = "application/vnd.my.custom") input: ByteArray, + ) { + val client: CodegenTestRawInputOutputClient.ContextClient = + CodegenTestRawInputOutputClient.fromContext(context) + client.rawInputWithCustomCt(input).await() + } + } + + @Workflow + @Name("MyWorkflow") + class MyWorkflow { + @Workflow + suspend fun run(context: WorkflowContext, myInput: String) { + val client = CodegenTestMyWorkflowClient.fromContext(context, context.key()) + client.send().sharedHandler(myInput) + } + + @Handler + suspend fun sharedHandler(context: SharedWorkflowContext, myInput: String): String { + val client = CodegenTestMyWorkflowClient.fromContext(context, context.key()) + return client.sharedHandler(myInput).await() + } + } + + class MyCustomSerdeFactory : SerdeFactory { + override fun create(typeTag: TypeTag): Serde { + check(typeTag is KotlinSerializationSerdeFactory.KtTypeTag) + check(typeTag.type == Byte::class) + return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde + } + + override fun create(typeRef: TypeRef): Serde { + check(typeRef.type == Byte::class) + return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde + } + + override fun create(clazz: Class?): Serde { + check(clazz == Byte::class.java) + return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde + } + } + + @CustomSerdeFactory(MyCustomSerdeFactory::class) + @Service + @Name("CustomSerdeService") + class CustomSerdeService { + @Handler + suspend fun echo(context: Context, input: Byte): Byte { + return input + } + } + + override fun definitions(): Stream { + return Stream.of( + testInvocation({ ServiceGreeter() }, "greet") + .withInput(startMessage(1), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), + testInvocation({ ObjectGreeter() }, "greet") + .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), + testInvocation({ ObjectGreeter() }, "sharedGreet") + .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), + testInvocation({ NestedDataClass() }, "greet") + .withInput( + startMessage(1, "slinkydeveloper"), + inputCmd(jsonSerde(), NestedDataClass.Input("123")), + ) + .onlyBidiStream() + .expectingOutput( + outputCmd(jsonSerde(), NestedDataClass.Output("123")), + END_MESSAGE, + ), + testInvocation({ ObjectGreeterImplementedFromInterface() }, "greet") + .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), + testInvocation({ Empty() }, "emptyInput") + .withInput(startMessage(1), inputCmd(), callCompletion(2, "Till")) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("Empty", "emptyInput")), + outputCmd("Till"), + END_MESSAGE, + ) + .named("empty output"), + testInvocation({ Empty() }, "emptyOutput") + .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("Empty", "emptyOutput"), "Francesco"), + outputCmd(), + END_MESSAGE, + ) + .named("empty output"), + testInvocation({ Empty() }, "emptyInputOutput") + .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("Empty", "emptyInputOutput")), + outputCmd(), + END_MESSAGE, + ) + .named("empty input and empty output"), + testInvocation({ PrimitiveTypes() }, "primitiveOutput") + .withInput(startMessage(1), inputCmd(), callCompletion(2, TestSerdes.INT, 10)) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("PrimitiveTypes", "primitiveOutput"), + Serde.VOID, + null, + ), + outputCmd(TestSerdes.INT, 10), + END_MESSAGE, + ) + .named("primitive output"), + testInvocation({ PrimitiveTypes() }, "primitiveInput") + .withInput(startMessage(1), inputCmd(10), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("PrimitiveTypes", "primitiveInput"), + TestSerdes.INT, + 10, + ), + outputCmd(), + END_MESSAGE, + ) + .named("primitive input"), + testInvocation({ RawInputOutput() }, "rawInput") + .withInput( + startMessage(1), + inputCmd("{{".toByteArray()), + callCompletion(2, KotlinSerializationSerdeFactory.UNIT, Unit), + ) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("RawInputOutput", "rawInput"), "{{".toByteArray()), + outputCmd(), + END_MESSAGE, + ), + testInvocation({ RawInputOutput() }, "rawInputWithCustomCt") + .withInput( + startMessage(1), + inputCmd("{{".toByteArray()), + callCompletion(2, KotlinSerializationSerdeFactory.UNIT, Unit), + ) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("RawInputOutput", "rawInputWithCustomCt"), + "{{".toByteArray(), + ), + outputCmd(), + END_MESSAGE, + ), + testInvocation({ RawInputOutput() }, "rawOutput") + .withInput( + startMessage(1), + inputCmd(), + callCompletion(2, Serde.RAW, "{{".toByteArray()), + ) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("RawInputOutput", "rawOutput"), + KotlinSerializationSerdeFactory.UNIT, + Unit, + ), + outputCmd("{{".toByteArray()), + END_MESSAGE, + ), + testInvocation({ RawInputOutput() }, "rawOutputWithCustomCT") + .withInput( + startMessage(1), + inputCmd(), + callCompletion(2, Serde.RAW, "{{".toByteArray()), + ) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("RawInputOutput", "rawOutputWithCustomCT"), + KotlinSerializationSerdeFactory.UNIT, + Unit, + ), + outputCmd("{{".toByteArray()), + END_MESSAGE, + ), + testInvocation({ CornerCases() }, "returnNull") + .withInput( + startMessage(1, "mykey"), + inputCmd(jsonSerde(), null), + callCompletion(2, jsonSerde(), null), + ) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.virtualObject("CodegenTestCornerCases", "mykey", "returnNull"), + jsonSerde(), + null, + ), + outputCmd(jsonSerde(), null), + END_MESSAGE, + ), + testInvocation({ CornerCases() }, "badReturnTypeInferred") + .withInput(startMessage(1, "mykey"), inputCmd()) + .onlyBidiStream() + .expectingOutput( + oneWayCallCmd( + 1, + Target.virtualObject( + "CodegenTestCornerCases", + "mykey", + "badReturnTypeInferred", + ), + null, + null, + Slice.EMPTY, + ), + outputCmd(), + END_MESSAGE, + ), + testInvocation({ CustomSerdeService() }, "echo") + .withInput(startMessage(1), inputCmd(byteArrayOf(1))) + .onlyBidiStream() + .expectingOutput(outputCmd(byteArrayOf(1)), END_MESSAGE), + ) + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/EagerStateTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/EagerStateTest.kt new file mode 100644 index 000000000..b163050a1 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/EagerStateTest.kt @@ -0,0 +1,66 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.sdk.common.StateKey +import dev.restate.sdk.core.EagerStateTestSuite +import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder +import dev.restate.sdk.core.TestSerdes +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForVirtualObject +import org.assertj.core.api.AssertionsForClassTypes.assertThat + +class EagerStateTest : EagerStateTestSuite() { + override fun getEmpty(): TestInvocationBuilder = + testDefinitionForVirtualObject("GetEmpty") { ctx, _: Unit -> + val stateIsEmpty = ctx.get(StateKey.of("STATE", TestSerdes.STRING)) == null + stateIsEmpty.toString() + } + + override fun get(): TestInvocationBuilder = + testDefinitionForVirtualObject("GetEmpty") { ctx, _: Unit -> + ctx.get(StateKey.of("STATE", TestSerdes.STRING))!! + } + + override fun getAppendAndGet(): TestInvocationBuilder = + testDefinitionForVirtualObject("GetAppendAndGet") { ctx, name: String -> + val oldState = ctx.get(StateKey.of("STATE", TestSerdes.STRING))!! + ctx.set(StateKey.of("STATE", TestSerdes.STRING), oldState + name) + ctx.get(StateKey.of("STATE", TestSerdes.STRING))!! + } + + override fun getClearAndGet(): TestInvocationBuilder = + testDefinitionForVirtualObject("GetClearAndGet") { ctx, _: Unit -> + val oldState = ctx.get(StateKey.of("STATE", TestSerdes.STRING))!! + ctx.clear(StateKey.of("STATE", TestSerdes.STRING)) + assertThat(ctx.get(StateKey.of("STATE", TestSerdes.STRING))).isNull() + oldState + } + + override fun getClearAllAndGet(): TestInvocationBuilder = + testDefinitionForVirtualObject("GetClearAllAndGet") { ctx, _: Unit -> + val oldState = ctx.get(StateKey.of("STATE", TestSerdes.STRING))!! + + ctx.clearAll() + + assertThat(ctx.get(StateKey.of("STATE", TestSerdes.STRING))).isNull() + assertThat(ctx.get(StateKey.of("ANOTHER_STATE", TestSerdes.STRING))).isNull() + oldState + } + + override fun listKeys(): TestInvocationBuilder = + testDefinitionForVirtualObject("ListKeys") { ctx, _: Unit -> + ctx.stateKeys().joinToString(separator = ",") + } + + override fun consecutiveGetWithEmpty(): TestInvocationBuilder = + testDefinitionForVirtualObject("ConsecutiveGetWithEmpty") { ctx, _: Unit -> + assertThat(ctx.get(StateKey.of("key-0", TestSerdes.STRING))).isNull() + assertThat(ctx.get(StateKey.of("key-0", TestSerdes.STRING))).isNull() + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/GreeterWithExplicitName.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/GreeterWithExplicitName.kt new file mode 100644 index 000000000..a01eaedd3 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/GreeterWithExplicitName.kt @@ -0,0 +1,18 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* + +@Service +@Name("MyExplicitName") +interface GreeterWithExplicitName { + @Handler @Name("my_greeter") fun greet(context: Context, request: String): String +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/InvocationIdTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/InvocationIdTest.kt new file mode 100644 index 000000000..a4e504f52 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/InvocationIdTest.kt @@ -0,0 +1,21 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.sdk.core.InvocationIdTestSuite +import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService + +class InvocationIdTest : InvocationIdTestSuite() { + + override fun returnInvocationId(): TestInvocationBuilder = + testDefinitionForService("ReturnInvocationId") { ctx, _: Unit -> + ctx.request().invocationId().toString() + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt new file mode 100644 index 000000000..e50305c56 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt @@ -0,0 +1,141 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.common.Request +import dev.restate.sdk.core.* +import dev.restate.sdk.core.TestDefinitions.TestExecutor +import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder +import dev.restate.sdk.core.kotlinapi.reflections.ReflectionTest +import dev.restate.sdk.core.statemachine.ProtoUtils +import dev.restate.sdk.endpoint.definition.HandlerDefinition +import dev.restate.sdk.endpoint.definition.HandlerType +import dev.restate.sdk.endpoint.definition.ServiceDefinition +import dev.restate.sdk.endpoint.definition.ServiceType +import dev.restate.sdk.kotlin.* +import dev.restate.serde.kotlinx.* +import java.util.stream.Stream +import kotlinx.coroutines.Dispatchers + +class KotlinAPITests : TestRunner() { + override fun executors(): Stream { + return Stream.of(MockRequestResponse.INSTANCE, MockBidiStream.INSTANCE) + } + + public override fun definitions(): Stream { + return Stream.of( + AwakeableIdTest(), + AsyncResultTest(), + CallTest(), + EagerStateTest(), + StateTest(), + InvocationIdTest(), + OnlyInputAndOutputTest(), + PromiseTest(), + SideEffectTest(), + SleepTest(), + StateMachineFailuresTest(), + UserFailuresTest(), + RandomTest(), + CodegenTest(), + ReflectionTest(), + ) + } + + companion object { + inline fun testDefinitionForService( + name: String, + noinline runner: suspend (Context, REQ) -> RES, + ): TestInvocationBuilder { + return TestDefinitions.testInvocation( + ServiceDefinition.of( + name, + ServiceType.SERVICE, + listOf( + HandlerDefinition.of( + "run", + HandlerType.SHARED, + jsonSerde(), + jsonSerde(), + HandlerRunner.of( + KotlinSerializationSerdeFactory(), + HandlerRunner.Options(Dispatchers.Unconfined), + runner, + ), + ) + ), + ), + "run", + ) + } + + inline fun testDefinitionForVirtualObject( + name: String, + noinline runner: suspend (ObjectContext, REQ) -> RES, + ): TestInvocationBuilder { + return TestDefinitions.testInvocation( + ServiceDefinition.of( + name, + ServiceType.VIRTUAL_OBJECT, + listOf( + HandlerDefinition.of( + "run", + HandlerType.EXCLUSIVE, + jsonSerde(), + jsonSerde(), + HandlerRunner.of( + KotlinSerializationSerdeFactory(), + HandlerRunner.Options(Dispatchers.Unconfined), + runner, + ), + ) + ), + ), + "run", + ) + } + + inline fun testDefinitionForWorkflow( + name: String, + noinline runner: suspend (WorkflowContext, REQ) -> RES, + ): TestInvocationBuilder { + return TestDefinitions.testInvocation( + ServiceDefinition.of( + name, + ServiceType.WORKFLOW, + listOf( + HandlerDefinition.of( + "run", + HandlerType.WORKFLOW, + jsonSerde(), + jsonSerde(), + HandlerRunner.of( + KotlinSerializationSerdeFactory(), + HandlerRunner.Options(Dispatchers.Unconfined), + runner, + ), + ) + ), + ), + "run", + ) + } + + suspend fun callGreeterGreetService(ctx: Context, parameter: String): DurableFuture { + return ctx.call( + Request.of( + ProtoUtils.GREETER_SERVICE_TARGET, + TestSerdes.STRING, + TestSerdes.STRING, + parameter, + ) + ) + } + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/MyMetaServiceAnnotation.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/MyMetaServiceAnnotation.kt new file mode 100644 index 000000000..e8c6606a2 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/MyMetaServiceAnnotation.kt @@ -0,0 +1,13 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.sdk.annotation.Service + +@Service annotation class MyMetaServiceAnnotation(val name: String = "") diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/OnlyInputAndOutputTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/OnlyInputAndOutputTest.kt new file mode 100644 index 000000000..d8bf351ab --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/OnlyInputAndOutputTest.kt @@ -0,0 +1,19 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.sdk.core.OnlyInputAndOutputTestSuite +import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService + +class OnlyInputAndOutputTest : OnlyInputAndOutputTestSuite() { + + override fun noSyscallsGreeter(): TestInvocationBuilder = + testDefinitionForService("NoSyscallsGreeter") { _, name: String -> "Hello $name" } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/PromiseTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/PromiseTest.kt new file mode 100644 index 000000000..4b9989663 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/PromiseTest.kt @@ -0,0 +1,61 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.sdk.common.TerminalException +import dev.restate.sdk.core.PromiseTestSuite +import dev.restate.sdk.core.TestDefinitions +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForWorkflow +import dev.restate.sdk.kotlin.* + +class PromiseTest : PromiseTestSuite() { + override fun awaitPromise(promiseKey: String): TestDefinitions.TestInvocationBuilder = + testDefinitionForWorkflow("AwaitPromise") { ctx, _: Unit -> + ctx.promise(durablePromiseKey(promiseKey)).future().await() + } + + override fun awaitPeekPromise( + promiseKey: String, + emptyCaseReturnValue: String, + ): TestDefinitions.TestInvocationBuilder = + testDefinitionForWorkflow("AwaitPeekPromise") { ctx, _: Unit -> + ctx.promise(durablePromiseKey(promiseKey)).peek().orElse(emptyCaseReturnValue) + } + + override fun awaitIsPromiseCompleted(promiseKey: String): TestDefinitions.TestInvocationBuilder = + testDefinitionForWorkflow("IsCompletedPromise") { ctx, _: Unit -> + ctx.promise(durablePromiseKey(promiseKey)).peek().isReady + } + + override fun awaitResolvePromise( + promiseKey: String, + completionValue: String, + ): TestDefinitions.TestInvocationBuilder = + testDefinitionForWorkflow("ResolvePromise") { ctx, _: Unit -> + try { + ctx.promiseHandle(durablePromiseKey(promiseKey)).resolve(completionValue) + return@testDefinitionForWorkflow true + } catch (e: TerminalException) { + return@testDefinitionForWorkflow false + } + } + + override fun awaitRejectPromise( + promiseKey: String, + rejectReason: String, + ): TestDefinitions.TestInvocationBuilder = + testDefinitionForWorkflow("RejectPromise") { ctx, _: Unit -> + try { + ctx.promiseHandle(durablePromiseKey(promiseKey)).reject(rejectReason) + return@testDefinitionForWorkflow true + } catch (e: TerminalException) { + return@testDefinitionForWorkflow false + } + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/RandomTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/RandomTest.kt new file mode 100644 index 000000000..99b7de8d8 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/RandomTest.kt @@ -0,0 +1,25 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.sdk.core.RandomTestSuite +import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService +import kotlin.random.Random + +class RandomTest : RandomTestSuite() { + override fun randomShouldBeDeterministic(): TestInvocationBuilder = + testDefinitionForService("RandomShouldBeDeterministic") { ctx, _: Unit -> + ctx.random().nextInt() + } + + override fun getExpectedInt(seed: Long): Int { + return Random(seed).nextInt() + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/SideEffectTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/SideEffectTest.kt new file mode 100644 index 000000000..1d448cee0 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/SideEffectTest.kt @@ -0,0 +1,151 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import com.google.protobuf.ByteString +import dev.restate.common.Slice +import dev.restate.sdk.Restate +import dev.restate.sdk.common.RetryPolicy +import dev.restate.sdk.core.SideEffectTestSuite +import dev.restate.sdk.core.TestDefinitions +import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService +import dev.restate.sdk.endpoint.definition.HandlerDefinition +import dev.restate.sdk.endpoint.definition.HandlerType +import dev.restate.sdk.endpoint.definition.ServiceDefinition +import dev.restate.sdk.endpoint.definition.ServiceType +import dev.restate.sdk.kotlin.* +import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory +import dev.restate.serde.kotlinx.jsonSerde +import dev.restate.serde.kotlinx.typeTag +import java.util.* +import kotlin.coroutines.coroutineContext +import kotlin.time.Clock +import kotlin.time.ExperimentalTime +import kotlin.time.Instant +import kotlin.time.toJavaInstant +import kotlin.time.toKotlinDuration +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.Dispatchers +import org.assertj.core.api.Assertions + +class SideEffectTest : SideEffectTestSuite() { + + override fun sideEffect(sideEffectOutput: String): TestInvocationBuilder = + testDefinitionForService("SideEffect") { ctx, _: Unit -> + val result = ctx.runBlock { sideEffectOutput } + "Hello $result" + } + + override fun namedSideEffect(name: String, sideEffectOutput: String): TestInvocationBuilder = + testDefinitionForService("SideEffect") { ctx, _: Unit -> + val result = ctx.runBlock(name) { sideEffectOutput } + "Hello $result" + } + + override fun consecutiveSideEffect(sideEffectOutput: String): TestInvocationBuilder = + testDefinitionForService("ConsecutiveSideEffect") { ctx, _: Unit -> + val firstResult = ctx.runBlock { sideEffectOutput } + val secondResult = ctx.runBlock { firstResult.uppercase(Locale.getDefault()) } + "Hello $secondResult" + } + + override fun checkContextSwitching(): TestInvocationBuilder = + TestDefinitions.testInvocation( + ServiceDefinition.of( + "CheckContextSwitching", + ServiceType.SERVICE, + listOf( + HandlerDefinition.of( + "run", + HandlerType.SHARED, + jsonSerde(), + jsonSerde(), + HandlerRunner.of( + KotlinSerializationSerdeFactory(), + HandlerRunner.Options( + Dispatchers.Unconfined + + CoroutineName("CheckContextSwitchingTestCoroutine") + ), + ) { ctx: Context, _: Unit -> + val sideEffectCoroutine = + ctx.runBlock { coroutineContext[CoroutineName]!!.name } + check(sideEffectCoroutine == "CheckContextSwitchingTestCoroutine") { + "Side effect thread is not running within the same coroutine context of the handler method: $sideEffectCoroutine" + } + "Hello" + }, + ) + ), + ), + "run", + ) + + override fun failingSideEffect(name: String, reason: String) = + testDefinitionForService("FailingSideEffect") { ctx, _: Unit -> + ctx.runBlock(name) { throw IllegalStateException(reason) } + } + + override fun awaitAllSideEffectWithFirstFailing( + firstSideEffect: String, + secondSideEffect: String, + successValue: String, + failureReason: String, + ) = + testDefinitionForService("AwaitAllSideEffectWithFirstFailing") { ctx, _: Unit -> + val fut1 = + ctx.runAsync(firstSideEffect) { throw IllegalStateException(failureReason) } + val fut2 = ctx.runAsync(secondSideEffect) { successValue } + listOf(fut1, fut2).awaitAll() + } + + override fun awaitAllSideEffectWithSecondFailing( + firstSideEffect: String, + secondSideEffect: String, + successValue: String, + failureReason: String, + ) = + testDefinitionForService("AwaitAllSideEffectWithSecondFailing") { ctx, _: Unit -> + val fut1 = ctx.runAsync(firstSideEffect) { successValue } + val fut2 = + ctx.runAsync(secondSideEffect) { throw IllegalStateException(failureReason) } + listOf(fut1, fut2).awaitAll() + } + + override fun failingSideEffectWithRetryPolicy(reason: String, retryPolicy: RetryPolicy?) = + testDefinitionForService("FailingSideEffectWithRetryPolicy") { ctx, _: Unit -> + ctx.runBlock( + retryPolicy = + retryPolicy?.let { + RetryPolicy( + initialDelay = it.initialDelay.toKotlinDuration(), + exponentiationFactor = it.exponentiationFactor, + maxDelay = it.maxDelay?.toKotlinDuration(), + maxDuration = it.maxDuration?.toKotlinDuration(), + maxAttempts = it.maxAttempts, + ) + } + ) { + throw IllegalStateException(reason) + } + } + + @OptIn(ExperimentalTime::class) + override fun instantNow() = + testDefinitionForService("InstantNow") { ctx, _: Unit -> Clock.Restate.now() } + + @OptIn(ExperimentalTime::class) + override fun assertIsInstant(bytes: ByteString) { + val instant = + KotlinSerializationSerdeFactory() + .create(typeTag()) + .deserialize(Slice.wrap(bytes.asReadOnlyByteBuffer())) + Assertions.assertThat(instant.toJavaInstant()).isNotNull().isBefore(java.time.Instant.now()) + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/SleepTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/SleepTest.kt new file mode 100644 index 000000000..58b069f79 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/SleepTest.kt @@ -0,0 +1,33 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.sdk.core.SleepTestSuite +import dev.restate.sdk.core.TestDefinitions +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService +import dev.restate.sdk.kotlin.* +import kotlin.time.Duration.Companion.seconds + +class SleepTest : SleepTestSuite() { + + override fun sleepGreeter(): TestDefinitions.TestInvocationBuilder = + testDefinitionForService("SleepGreeter") { ctx, _: Unit -> + ctx.sleep(1.seconds) + "Hello" + } + + override fun manySleeps(): TestDefinitions.TestInvocationBuilder = + testDefinitionForService("ManySleeps") { ctx, _: Unit -> + val durableFutures = mutableListOf>() + for (i in 0..9) { + durableFutures.add(ctx.timer(1.seconds)) + } + durableFutures.awaitAll() + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/StateMachineFailuresTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/StateMachineFailuresTest.kt new file mode 100644 index 000000000..94c7da305 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/StateMachineFailuresTest.kt @@ -0,0 +1,59 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.sdk.common.AbortedExecutionException +import dev.restate.sdk.common.StateKey +import dev.restate.sdk.common.TerminalException +import dev.restate.sdk.core.StateMachineFailuresTestSuite +import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForVirtualObject +import dev.restate.serde.Serde +import java.nio.charset.StandardCharsets +import java.util.concurrent.atomic.AtomicInteger +import kotlinx.coroutines.CancellationException + +class StateMachineFailuresTest : StateMachineFailuresTestSuite() { + companion object { + private val STATE = + StateKey.of( + "STATE", + Serde.using({ i: Int -> i.toString().toByteArray(StandardCharsets.UTF_8) }) { + b: ByteArray? -> + String(b!!, StandardCharsets.UTF_8).toInt() + }, + ) + } + + override fun getState(nonTerminalExceptionsSeen: AtomicInteger): TestInvocationBuilder = + testDefinitionForVirtualObject("GetState") { ctx, _: Unit -> + try { + ctx.get(STATE) + } catch (e: Throwable) { + // A user should never catch Throwable!!! + if (AbortedExecutionException.INSTANCE == e) { + throw e + } + // A user should never catch Throwable!!! + if (e !is CancellationException && e !is TerminalException) { + nonTerminalExceptionsSeen.addAndGet(1) + } else { + throw e + } + } + "Francesco" + } + + override fun sideEffectFailure(serde: Serde): TestInvocationBuilder = + testDefinitionForService("SideEffectFailure") { ctx, _: Unit -> + ctx.runBlock(serde) { 0 } + "Francesco" + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/StateTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/StateTest.kt new file mode 100644 index 000000000..4f63035f0 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/StateTest.kt @@ -0,0 +1,86 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.sdk.common.StateKey +import dev.restate.sdk.core.StateTestSuite +import dev.restate.sdk.core.TestDefinitions.* +import dev.restate.sdk.core.TestSerdes +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForVirtualObject +import dev.restate.sdk.core.statemachine.ProtoUtils.* +import dev.restate.sdk.kotlin.* +import dev.restate.serde.kotlinx.* +import java.util.stream.Stream +import kotlinx.serialization.Serializable + +class StateTest : StateTestSuite() { + + override fun getState(): TestInvocationBuilder = + testDefinitionForVirtualObject("GetState") { ctx, _: Unit -> + val state = ctx.get(StateKey.of("STATE", TestSerdes.STRING)) ?: "Unknown" + "Hello $state" + } + + override fun getAndSetState(): TestInvocationBuilder = + testDefinitionForVirtualObject("GetAndSetState") { ctx, name: String -> + val state = ctx.get(StateKey.of("STATE", TestSerdes.STRING))!! + ctx.set(StateKey.of("STATE", TestSerdes.STRING), name) + "Hello $state" + } + + override fun setNullState(): TestInvocationBuilder { + return unsupported("The kotlin type system enforces non null state values") + } + + // --- Test using KTSerdes + + @Serializable data class Data(var a: Int, val b: String) + + private companion object { + val DATA = stateKey("STATE") + } + + private fun getAndSetStateUsingKtSerdes(): TestInvocationBuilder = + testDefinitionForVirtualObject("GetAndSetStateUsingKtSerdes") { ctx, _: Unit -> + val state = ctx.get(DATA)!! + state.a += 1 + ctx.set(DATA, state) + + "Hello $state" + } + + override fun definitions(): Stream { + return Stream.concat( + super.definitions(), + Stream.of( + getAndSetStateUsingKtSerdes() + .withInput( + startMessage(3), + inputCmd(), + getEagerStateCmd("STATE", jsonSerde(), Data(1, "Till")), + setStateCmd("STATE", jsonSerde(), Data(2, "Till")), + ) + .expectingOutput(outputCmd("Hello " + Data(2, "Till")), END_MESSAGE) + .named("With GetState and SetState"), + getAndSetStateUsingKtSerdes() + .withInput( + startMessage(2), + inputCmd(), + getEagerStateCmd("STATE", jsonSerde(), Data(1, "Till")), + ) + .expectingOutput( + setStateCmd("STATE", jsonSerde(), Data(2, "Till")), + outputCmd("Hello " + Data(2, "Till")), + END_MESSAGE, + ) + .named("With GetState already completed"), + ), + ) + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/UserFailuresTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/UserFailuresTest.kt new file mode 100644 index 000000000..9d8e2af58 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/UserFailuresTest.kt @@ -0,0 +1,51 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.sdk.common.TerminalException +import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder +import dev.restate.sdk.core.UserFailuresTestSuite +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService +import dev.restate.sdk.kotlin.* +import java.util.concurrent.atomic.AtomicInteger +import kotlin.coroutines.cancellation.CancellationException + +class UserFailuresTest : UserFailuresTestSuite() { + override fun throwIllegalStateException(): TestInvocationBuilder = + testDefinitionForService("ThrowIllegalStateException") { _, _: Unit -> + throw IllegalStateException("Whatever") + } + + override fun sideEffectThrowIllegalStateException( + nonTerminalExceptionsSeen: AtomicInteger + ): TestInvocationBuilder = + testDefinitionForService("SideEffectThrowIllegalStateException") { ctx, _: Unit -> + try { + ctx.runBlock { throw IllegalStateException("Whatever") } + } catch (e: Throwable) { + if (e !is CancellationException && e !is TerminalException) { + nonTerminalExceptionsSeen.addAndGet(1) + } else { + throw e + } + } + throw IllegalStateException("Not expected to reach this point") + } + + override fun throwTerminalException(code: Int, message: String): TestInvocationBuilder = + testDefinitionForService("ThrowTerminalException") { _, _: Unit -> + throw TerminalException(code, message) + } + + override fun sideEffectThrowTerminalException(code: Int, message: String): TestInvocationBuilder = + testDefinitionForService("SideEffectThrowTerminalException") { ctx, _: Unit -> + ctx.runBlock { throw TerminalException(code, message) } + throw IllegalStateException("Not expected to reach this point") + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt new file mode 100644 index 000000000..b8098ba95 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt @@ -0,0 +1,112 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi.reflections + +import dev.restate.sdk.core.AssertUtils.assertThatDiscovery +import dev.restate.sdk.core.generated.manifest.Handler +import dev.restate.sdk.core.generated.manifest.Input +import dev.restate.sdk.core.generated.manifest.Output +import dev.restate.sdk.core.generated.manifest.Service +import dev.restate.sdk.kotlin.endpoint.* +import dev.restate.serde.Serde +import org.assertj.core.api.InstanceOfAssertFactories.type +import org.junit.jupiter.api.Test + +class ReflectionDiscoveryTest { + + @Test + fun checkCustomInputContentType() { + assertThatDiscovery(RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawInputWithCustomCt") + .extracting({ it.input }, type(Input::class.java)) + .extracting { it.contentType } + .isEqualTo("application/vnd.my.custom") + } + + @Test + fun checkCustomInputAcceptContentType() { + assertThatDiscovery(RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawInputWithCustomAccept") + .extracting({ it.input }, type(Input::class.java)) + .extracting { it.contentType } + .isEqualTo("application/*") + } + + @Test + fun checkCustomOutputContentType() { + assertThatDiscovery(RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawOutputWithCustomCT") + .extracting({ it.output }, type(Output::class.java)) + .extracting { it.contentType } + .isEqualTo("application/vnd.my.custom") + } + + @Test + fun checkRawInputContentType() { + assertThatDiscovery(RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawInput") + .extracting({ it.input }, type(Input::class.java)) + .extracting { it.contentType } + .isEqualTo(Serde.RAW.contentType()) + } + + @Test + fun checkRawOutputContentType() { + assertThatDiscovery(RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawOutput") + .extracting({ it.output }, type(Output::class.java)) + .extracting { it.contentType } + .isEqualTo(Serde.RAW.contentType()) + } + + @Test + fun explicitNames() { + assertThatDiscovery( + object : GreeterWithExplicitName { + override suspend fun greet(request: String): String { + TODO("Not yet implemented") + } + } + ) + .extractingService("MyExplicitName") + .extractingHandler("my_greeter") + } + + @Test + fun workflowType() { + assertThatDiscovery(MyWorkflow()) + .extractingService("MyWorkflow") + .returns(Service.Ty.WORKFLOW) { obj -> obj.ty } + .extractingHandler("run") + .returns(Handler.Ty.WORKFLOW) { obj -> obj.ty } + } + + @Test + fun usingTransformer() { + assertThatDiscovery( + endpoint { + bind(RawInputOutput()) { + it.documentation = "My service documentation" + it.configureHandler("rawInputWithCustomCt") { + it.documentation = "My handler documentation" + } + } + } + ) + .extractingService("RawInputOutput") + .returns("My service documentation", Service::getDocumentation) + .extractingHandler("rawInputWithCustomCt") + .returns("My handler documentation", Handler::getDocumentation) + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/reflections/ReflectionTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/reflections/ReflectionTest.kt new file mode 100644 index 000000000..e85b49087 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/reflections/ReflectionTest.kt @@ -0,0 +1,247 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi.reflections + +import dev.restate.common.Slice +import dev.restate.common.Target +import dev.restate.sdk.core.TestDefinitions +import dev.restate.sdk.core.TestDefinitions.TestDefinition +import dev.restate.sdk.core.TestDefinitions.testInvocation +import dev.restate.sdk.core.TestSerdes +import dev.restate.sdk.core.statemachine.ProtoUtils.* +import dev.restate.serde.Serde +import dev.restate.serde.kotlinx.* +import java.util.stream.Stream + +class ReflectionTest : TestDefinitions.TestSuite { + + override fun definitions(): Stream { + return Stream.of( + testInvocation({ ServiceGreeter() }, "greet") + .withInput(startMessage(1), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), + testInvocation({ ObjectGreeter() }, "greet") + .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), + testInvocation({ ObjectGreeter() }, "sharedGreet") + .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), + testInvocation({ NestedDataClass() }, "greet") + .withInput( + startMessage(1, "slinkydeveloper"), + inputCmd(jsonSerde(), NestedDataClass.Input("123")), + ) + .onlyBidiStream() + .expectingOutput( + outputCmd(jsonSerde(), NestedDataClass.Output("123")), + END_MESSAGE, + ), + testInvocation({ ObjectGreeterImplementedFromInterface() }, "greet") + .withInput( + startMessage(1, "slinkydeveloper"), + inputCmd("Francesco"), + callCompletion(2, "Francesco"), + ) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.virtualObject("GreeterInterface", "slinkydeveloper", "greet"), + "Francesco", + ), + outputCmd("Francesco"), + END_MESSAGE, + ), + testInvocation({ Empty() }, "emptyInput") + .withInput(startMessage(1), inputCmd(), callCompletion(2, "Till")) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("Empty", "emptyInput")), + outputCmd("Till"), + END_MESSAGE, + ) + .named("empty output"), + testInvocation({ Empty() }, "emptyOutput") + .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("Empty", "emptyOutput"), "Francesco"), + outputCmd(), + END_MESSAGE, + ) + .named("empty output"), + testInvocation({ Empty() }, "emptyInputOutput") + .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("Empty", "emptyInputOutput")), + outputCmd(), + END_MESSAGE, + ) + .named("empty input and empty output"), + testInvocation({ PrimitiveTypes() }, "primitiveOutput") + .withInput(startMessage(1), inputCmd(), callCompletion(2, TestSerdes.INT, 10)) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("PrimitiveTypes", "primitiveOutput"), + Serde.VOID, + null, + ), + outputCmd(TestSerdes.INT, 10), + END_MESSAGE, + ) + .named("primitive output"), + testInvocation({ PrimitiveTypes() }, "primitiveInput") + .withInput(startMessage(1), inputCmd(10), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("PrimitiveTypes", "primitiveInput"), + TestSerdes.INT, + 10, + ), + outputCmd(), + END_MESSAGE, + ) + .named("primitive input"), + testInvocation({ RawInputOutput() }, "rawInput") + .withInput( + startMessage(1), + inputCmd("{{".toByteArray()), + callCompletion(2, KotlinSerializationSerdeFactory.UNIT, Unit), + ) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("RawInputOutput", "rawInput"), "{{".toByteArray()), + outputCmd(), + END_MESSAGE, + ), + testInvocation({ RawInputOutput() }, "rawInputWithCustomCt") + .withInput( + startMessage(1), + inputCmd("{{".toByteArray()), + callCompletion(2, KotlinSerializationSerdeFactory.UNIT, Unit), + ) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("RawInputOutput", "rawInputWithCustomCt"), + "{{".toByteArray(), + ), + outputCmd(), + END_MESSAGE, + ), + testInvocation({ RawInputOutput() }, "rawOutput") + .withInput( + startMessage(1), + inputCmd(), + callCompletion(2, Serde.RAW, "{{".toByteArray()), + ) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("RawInputOutput", "rawOutput"), + KotlinSerializationSerdeFactory.UNIT, + Unit, + ), + outputCmd("{{".toByteArray()), + END_MESSAGE, + ), + testInvocation({ RawInputOutput() }, "rawOutputWithCustomCT") + .withInput( + startMessage(1), + inputCmd(), + callCompletion(2, Serde.RAW, "{{".toByteArray()), + ) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("RawInputOutput", "rawOutputWithCustomCT"), + KotlinSerializationSerdeFactory.UNIT, + Unit, + ), + outputCmd("{{".toByteArray()), + END_MESSAGE, + ), + testInvocation({ CornerCases() }, "returnNull") + .withInput( + startMessage(1, "mykey"), + inputCmd(jsonSerde(), null), + callCompletion(2, jsonSerde(), null), + ) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.virtualObject("CornerCases", "mykey", "returnNull"), + jsonSerde(), + null, + ), + outputCmd(jsonSerde(), null), + END_MESSAGE, + ), + testInvocation({ CornerCases() }, "badReturnTypeInferred") + .withInput(startMessage(1, "mykey"), inputCmd()) + .onlyBidiStream() + .expectingOutput( + oneWayCallCmd( + 1, + Target.virtualObject( + "CornerCases", + "mykey", + "badReturnTypeInferred", + ), + null, + null, + Slice.EMPTY, + ), + outputCmd(), + END_MESSAGE, + ), + testInvocation({ CornerCases() }, "callSuspendWithinProxy") + .withInput(startMessage(1, "mykey"), inputCmd()) + .onlyBidiStream() + .expectingOutput( + oneWayCallCmd( + 1, + Target.virtualObject( + "CornerCases", + "mykey", + "callSuspendWithinProxy", + ), + null, + null, + Slice.EMPTY, + ), + outputCmd(), + END_MESSAGE, + ), + testInvocation({ CustomSerdeService() }, "echo") + .withInput(startMessage(1), inputCmd(byteArrayOf(1))) + .onlyBidiStream() + .expectingOutput(outputCmd(byteArrayOf(1)), END_MESSAGE), + ) + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt new file mode 100644 index 000000000..4f08348c6 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt @@ -0,0 +1,205 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi.reflections + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* +import dev.restate.serde.Serde +import dev.restate.serde.SerdeFactory +import dev.restate.serde.TypeRef +import dev.restate.serde.TypeTag +import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory +import kotlinx.coroutines.delay +import kotlinx.serialization.Serializable + +@Service +class ServiceGreeter { + @Handler + suspend fun greet(request: String): String { + return request + } +} + +@VirtualObject +class ObjectGreeter { + @Exclusive + suspend fun greet(request: String): String { + return request + } + + @Handler + @Shared + suspend fun sharedGreet(request: String): String { + return request + } +} + +@VirtualObject +class NestedDataClass { + @Serializable data class Input(val a: String) + + @Serializable data class Output(val a: String) + + @Exclusive + suspend fun greet(request: Input): Output { + return Output(request.a) + } + + @Exclusive + suspend fun complexType(request: Map>): Map> { + return mapOf() + } +} + +@VirtualObject +interface GreeterInterface { + @Exclusive suspend fun greet(request: String): String +} + +class ObjectGreeterImplementedFromInterface : GreeterInterface { + override suspend fun greet(request: String): String { + return virtualObject(objectKey()).greet(request) + } +} + +@Service +@Name("Empty") +open class Empty { + @Handler + open suspend fun emptyInput(): String { + return service().emptyInput() + } + + @Handler + open suspend fun emptyOutput(request: String) { + service().emptyOutput(request) + } + + @Handler + open suspend fun emptyInputOutput() { + service().emptyInputOutput() + } +} + +@Service +@Name("PrimitiveTypes") +open class PrimitiveTypes { + @Handler + open suspend fun primitiveOutput(): Int { + return service().primitiveOutput() + } + + @Handler + open suspend fun primitiveInput(input: Int) { + service().primitiveInput(input) + } +} + +@VirtualObject +open class CornerCases { + + @Exclusive + open suspend fun returnNull(request: String?): String? { + return virtualObject(objectKey()).returnNull(request) + } + + @Exclusive + open suspend fun badReturnTypeInferred(): Unit { + toVirtualObject(objectKey()).request { badReturnTypeInferred() }.send() + } + + @Exclusive + open suspend fun callSuspendWithinProxy() { + toVirtualObject(objectKey()) + .request { + // Doing a suspend call within the proxy + delay(1) + callSuspendWithinProxy() + } + .send() + } +} + +@Service +@Name("RawInputOutput") +open class RawInputOutput { + @Handler @Raw open suspend fun rawOutput(): ByteArray = service().rawOutput() + + @Handler + @Raw(contentType = "application/vnd.my.custom") + open suspend fun rawOutputWithCustomCT(): ByteArray = + service().rawOutputWithCustomCT() + + @Handler + open suspend fun rawInput(@Raw input: ByteArray) { + service().rawInput(input) + } + + @Handler + open suspend fun rawInputWithCustomCt( + @Raw(contentType = "application/vnd.my.custom") input: ByteArray + ) { + service().rawInputWithCustomCt(input) + } + + @Handler + open suspend fun rawInputWithCustomAccept( + @Accept("application/*") @Raw(contentType = "application/vnd.my.custom") input: ByteArray + ) { + service().rawInputWithCustomAccept(input) + } +} + +@Workflow +@Name("MyWorkflow") +open class MyWorkflow { + @Workflow + open suspend fun run(myInput: String) { + toWorkflow(workflowKey()).request { sharedHandler(myInput) }.send() + } + + @Handler + open suspend fun sharedHandler(myInput: String): String = + workflow(workflowKey()).sharedHandler(myInput) +} + +@Suppress("UNCHECKED_CAST") +class MyCustomSerdeFactory : SerdeFactory { + override fun create(typeTag: TypeTag): Serde { + check(typeTag is KotlinSerializationSerdeFactory.KtTypeTag) + check(typeTag.type == Byte::class) + return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde + } + + override fun create(typeRef: TypeRef): Serde { + check(typeRef.type == Byte::class) + return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde + } + + override fun create(clazz: Class?): Serde { + check(clazz == Byte::class.java) + return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde + } +} + +@CustomSerdeFactory(MyCustomSerdeFactory::class) +@Service +@Name("CustomSerdeService") +class CustomSerdeService { + @Handler + suspend fun echo(input: Byte): Byte { + return input + } +} + +@Service +@Name("MyExplicitName") +interface GreeterWithExplicitName { + @Handler @Name("my_greeter") suspend fun greet(request: String): String +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/vertx/RestateHttpServerTest.kt b/sdk-core/bin/test/dev/restate/sdk/core/vertx/RestateHttpServerTest.kt new file mode 100644 index 000000000..497f5e958 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/vertx/RestateHttpServerTest.kt @@ -0,0 +1,165 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.vertx + +import com.fasterxml.jackson.databind.ObjectMapper +import com.google.protobuf.MessageLite +import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema +import dev.restate.sdk.core.statemachine.ProtoUtils.* +import dev.restate.sdk.endpoint.definition.HandlerDefinition +import dev.restate.sdk.endpoint.definition.HandlerType +import dev.restate.sdk.endpoint.definition.ServiceDefinition +import dev.restate.sdk.endpoint.definition.ServiceType +import dev.restate.sdk.http.vertx.RestateHttpServer +import dev.restate.sdk.kotlin.HandlerRunner +import dev.restate.sdk.kotlin.ObjectContext +import dev.restate.sdk.kotlin.endpoint.endpoint +import dev.restate.sdk.kotlin.stateKey +import dev.restate.serde.kotlinx.* +import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.HttpResponseStatus +import io.vertx.core.Vertx +import io.vertx.core.buffer.Buffer +import io.vertx.core.http.* +import io.vertx.junit5.VertxExtension +import io.vertx.kotlin.coroutines.coAwait +import io.vertx.kotlin.coroutines.dispatcher +import kotlin.time.Duration.Companion.seconds +import kotlinx.coroutines.runBlocking +import org.apache.logging.log4j.LogManager +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.api.parallel.Isolated + +@Isolated +@ExtendWith(VertxExtension::class) +internal class RestateHttpServerTest { + + companion object { + val HTTP_CLIENT_OPTIONS: HttpClientOptions = + HttpClientOptions() + // Set prior knowledge + .setProtocolVersion(HttpVersion.HTTP_2) + .setHttp2ClearTextUpgrade(false) + + private val LOG = LogManager.getLogger() + private val COUNTER = stateKey("counter") + + const val GREETER_NAME = "Greeter" + + fun greeter(): ServiceDefinition = + ServiceDefinition.of( + GREETER_NAME, + ServiceType.VIRTUAL_OBJECT, + listOf( + HandlerDefinition.of( + "greet", + HandlerType.EXCLUSIVE, + jsonSerde(), + jsonSerde(), + HandlerRunner.of(KotlinSerializationSerdeFactory()) { + ctx: ObjectContext, + request: String -> + LOG.info("Greet invoked!") + + val count = (ctx.get(COUNTER) ?: 0) + 1 + ctx.set(COUNTER, count) + + ctx.sleep(1.seconds) + + "Hello $request. Count: $count" + }, + ) + ), + ) + } + + @Test + fun return404(vertx: Vertx): Unit = + runBlocking(vertx.dispatcher()) { + val endpointPort: Int = + RestateHttpServer.fromEndpoint( + vertx, + endpoint { bind(greeter()) }, + HttpServerOptions().setPort(0), + ) + .listen() + .coAwait() + .actualPort() + + val client = vertx.createHttpClient(HTTP_CLIENT_OPTIONS) + + val request = + client + .request( + HttpMethod.POST, + endpointPort, + "localhost", + "/invoke/$GREETER_NAME/unknownMethod", + ) + .coAwait() + + // Prepare request header + request + .setChunked(true) + .putHeader(HttpHeaders.CONTENT_TYPE, serviceProtocolContentTypeHeader(false)) + .putHeader(HttpHeaders.ACCEPT, serviceProtocolContentTypeHeader(false)) + request.write(encode(startMessage(0).build())) + + val response = request.response().coAwait() + + // Response status should be 404 + assertThat(response.statusCode()).isEqualTo(HttpResponseStatus.NOT_FOUND.code()) + + response.end().coAwait() + } + + @Test + fun serviceDiscovery(vertx: Vertx): Unit = + runBlocking(vertx.dispatcher()) { + val endpointPort: Int = + RestateHttpServer.fromEndpoint( + vertx, + endpoint { bind(greeter()) }, + HttpServerOptions().setPort(0), + ) + .listen() + .coAwait() + .actualPort() + + val client = vertx.createHttpClient(HTTP_CLIENT_OPTIONS) + + // Send request + val request = + client.request(HttpMethod.GET, endpointPort, "localhost", "/discover").coAwait() + request.putHeader(HttpHeaders.ACCEPT, serviceProtocolDiscoveryContentTypeHeader()) + request.end().coAwait() + + // Assert response + val response = request.response().coAwait() + + // Response status and content type header + assertThat(response.statusCode()).isEqualTo(HttpResponseStatus.OK.code()) + assertThat(response.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(serviceProtocolDiscoveryContentTypeHeader()) + + // Parse response + val responseBody = response.body().coAwait() + // Compute response and write it back + val discoveryResponse: EndpointManifestSchema = + ObjectMapper().readValue(responseBody.bytes, EndpointManifestSchema::class.java) + + assertThat(discoveryResponse.services).map { it.name }.containsOnly(GREETER_NAME) + } + + private fun encode(msg: MessageLite): Buffer { + return Buffer.buffer(Unpooled.wrappedBuffer(encodeMessageToByteBuffer(msg))) + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/vertx/RestateHttpServerTestExecutor.kt b/sdk-core/bin/test/dev/restate/sdk/core/vertx/RestateHttpServerTestExecutor.kt new file mode 100644 index 000000000..5f2945bf0 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/vertx/RestateHttpServerTestExecutor.kt @@ -0,0 +1,114 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.vertx + +import dev.restate.sdk.core.TestDefinitions.TestDefinition +import dev.restate.sdk.core.TestDefinitions.TestExecutor +import dev.restate.sdk.core.statemachine.ProtoUtils +import dev.restate.sdk.endpoint.Endpoint +import dev.restate.sdk.endpoint.definition.ServiceDefinition +import dev.restate.sdk.http.vertx.RestateHttpServer +import io.netty.buffer.Unpooled +import io.vertx.core.Vertx +import io.vertx.core.buffer.Buffer +import io.vertx.core.http.HttpHeaders +import io.vertx.core.http.HttpMethod +import io.vertx.core.http.HttpServerOptions +import io.vertx.kotlin.coroutines.coAwait +import io.vertx.kotlin.coroutines.dispatcher +import java.nio.ByteBuffer +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.receiveAsFlow +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.yield + +class RestateHttpServerTestExecutor(private val vertx: Vertx) : TestExecutor { + override fun buffered(): Boolean { + return false + } + + override fun executeTest(definition: TestDefinition) { + runBlocking(vertx.dispatcher()) { + // Build server + val endpointBuilder = + Endpoint.builder() + .bind(definition.serviceDefinition as ServiceDefinition, definition.serviceOptions) + if (definition.isEnablePreviewContext()) { + endpointBuilder.enablePreviewContext() + } + + // Start server + val server = + RestateHttpServer.fromEndpoint( + vertx, + endpointBuilder.build(), + HttpServerOptions().setPort(0), + ) + server.listen().coAwait() + + val client = vertx.createHttpClient(RestateHttpServerTest.Companion.HTTP_CLIENT_OPTIONS) + + val request = + client + .request( + HttpMethod.POST, + server.actualPort(), + "localhost", + "/invoke/${definition.serviceDefinition.serviceName}/${definition.method}", + ) + .coAwait() + + // Prepare request header and send them + request + .setChunked(true) + .putHeader( + HttpHeaders.CONTENT_TYPE, + ProtoUtils.serviceProtocolContentTypeHeader(definition.isEnablePreviewContext), + ) + .putHeader( + HttpHeaders.ACCEPT, + ProtoUtils.serviceProtocolContentTypeHeader(definition.isEnablePreviewContext), + ) + request.sendHead().coAwait() + + launch { + for (msg in definition.input) { + request + .write( + Buffer.buffer(Unpooled.wrappedBuffer(ProtoUtils.invocationInputToByteString(msg))) + ) + .coAwait() + yield() + } + + request.end().coAwait() + } + + val response = request.response().coAwait() + + // Start the response receiver + val inputChannel = Channel() + response.handler { launch(vertx.dispatcher()) { inputChannel.send(it) } } + response.endHandler { inputChannel.close() } + response.resume() + + // Collect all the output messages + val buffers = inputChannel.receiveAsFlow().toList() + + definition.outputAssert.accept( + ProtoUtils.bufferToMessages(buffers.map { ByteBuffer.wrap(it.bytes) }) + ) + + // Close the server + server.close().coAwait() + } + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/vertx/RestateHttpServerTests.kt b/sdk-core/bin/test/dev/restate/sdk/core/vertx/RestateHttpServerTests.kt new file mode 100644 index 000000000..97cda32a6 --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/vertx/RestateHttpServerTests.kt @@ -0,0 +1,45 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.vertx + +import dev.restate.sdk.core.TestDefinitions.TestExecutor +import dev.restate.sdk.core.TestDefinitions.TestSuite +import dev.restate.sdk.core.TestRunner +import dev.restate.sdk.core.javaapi.JavaAPITests +import dev.restate.sdk.core.kotlinapi.KotlinAPITests +import io.vertx.core.Vertx +import java.util.stream.Stream +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll + +class RestateHttpServerTests : TestRunner() { + + lateinit var vertx: Vertx + + @BeforeAll + fun beforeAll() { + vertx = Vertx.vertx() + } + + @AfterAll + fun afterAll() { + vertx.close().toCompletionStage().toCompletableFuture().get() + } + + override fun executors(): Stream { + return Stream.of(RestateHttpServerTestExecutor(vertx)) + } + + override fun definitions(): Stream { + return Stream.concat( + Stream.concat(JavaAPITests().definitions(), KotlinAPITests().definitions()), + Stream.of(ThreadTrampoliningTestSuite()), + ) + } +} diff --git a/sdk-core/bin/test/dev/restate/sdk/core/vertx/ThreadTrampoliningTestSuite.kt b/sdk-core/bin/test/dev/restate/sdk/core/vertx/ThreadTrampoliningTestSuite.kt new file mode 100644 index 000000000..5a4f6eafb --- /dev/null +++ b/sdk-core/bin/test/dev/restate/sdk/core/vertx/ThreadTrampoliningTestSuite.kt @@ -0,0 +1,125 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.vertx + +import dev.restate.sdk.core.TestDefinitions +import dev.restate.sdk.core.TestDefinitions.testInvocation +import dev.restate.sdk.core.statemachine.ProtoUtils.* +import dev.restate.sdk.endpoint.definition.HandlerDefinition +import dev.restate.sdk.endpoint.definition.HandlerType +import dev.restate.sdk.endpoint.definition.ServiceDefinition +import dev.restate.sdk.endpoint.definition.ServiceType +import dev.restate.sdk.kotlin.Context +import dev.restate.sdk.kotlin.HandlerRunner +import dev.restate.sdk.kotlin.runBlock +import dev.restate.serde.Serde +import dev.restate.serde.jackson.JacksonSerdeFactory +import dev.restate.serde.kotlinx.* +import io.vertx.core.Vertx +import java.util.stream.Stream +import kotlin.coroutines.coroutineContext +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.Dispatchers +import org.apache.logging.log4j.LogManager + +class ThreadTrampoliningTestSuite : TestDefinitions.TestSuite { + + private val nonBlockingCoroutineName = CoroutineName("CheckContextSwitchingTestCoroutine") + + companion object { + private val LOG = LogManager.getLogger() + } + + private suspend fun checkNonBlockingComponentTrampolineExecutor(ctx: Context) { + LOG.info("I am on the thread I am before executing side effect") + check(Vertx.currentContext() == null) + check(coroutineContext[CoroutineName] == nonBlockingCoroutineName) + ctx.runBlock { + LOG.info("I am on the thread I am when executing side effect") + check(Vertx.currentContext() == null) + } + LOG.info("I am on the thread I am after executing side effect") + check(coroutineContext[CoroutineName] == nonBlockingCoroutineName) + check(Vertx.currentContext() == null) + } + + private fun checkBlockingComponentTrampolineExecutor( + ctx: dev.restate.sdk.Context, + _unused: Any?, + ): Void? { + val id = Thread.currentThread().id + check(Vertx.currentContext() == null) + ctx.run { check(Vertx.currentContext() == null) } + check(Thread.currentThread().id == id) + check(Vertx.currentContext() == null) + return null + } + + override fun definitions(): Stream { + return Stream.of( + testInvocation( + ServiceDefinition.of( + "CheckNonBlockingComponentTrampolineExecutor", + ServiceType.SERVICE, + listOf( + HandlerDefinition.of( + "do", + HandlerType.SHARED, + KotlinSerializationSerdeFactory.UNIT, + KotlinSerializationSerdeFactory.UNIT, + HandlerRunner.of( + KotlinSerializationSerdeFactory(), + HandlerRunner.Options( + Dispatchers.Default + nonBlockingCoroutineName + ), + ) { ctx: Context, _: Unit -> + checkNonBlockingComponentTrampolineExecutor(ctx) + }, + ) + ), + ), + "do", + ) + .withInput(startMessage(1), inputCmd()) + .onlyBidiStream() + .expectingOutput( + runCmd(1), + proposeRunCompletion(1, Serde.VOID, null), + suspensionMessage(1), + ), + testInvocation( + ServiceDefinition.of( + "CheckBlockingComponentTrampolineExecutor", + ServiceType.SERVICE, + listOf( + HandlerDefinition.of( + "do", + HandlerType.SHARED, + Serde.VOID, + Serde.VOID, + dev.restate.sdk.HandlerRunner.of( + this::checkBlockingComponentTrampolineExecutor, + JacksonSerdeFactory(), + null, + ), + ) + ), + ), + "do", + ) + .withInput(startMessage(1), inputCmd()) + .onlyBidiStream() + .expectingOutput( + runCmd(1), + proposeRunCompletion(1, Serde.VOID, null), + suspensionMessage(1), + ), + ) + } +} diff --git a/sdk-core/bin/test/junit-platform.properties b/sdk-core/bin/test/junit-platform.properties new file mode 100644 index 000000000..3e799af08 --- /dev/null +++ b/sdk-core/bin/test/junit-platform.properties @@ -0,0 +1,3 @@ +junit.jupiter.execution.parallel.enabled = true +junit.jupiter.execution.parallel.config.strategy = dynamic +junit.jupiter.execution.parallel.mode.default = same_thread \ No newline at end of file diff --git a/sdk-core/bin/test/log4j2.properties b/sdk-core/bin/test/log4j2.properties new file mode 100644 index 000000000..5fd081b53 --- /dev/null +++ b/sdk-core/bin/test/log4j2.properties @@ -0,0 +1,8 @@ +rootLogger.level = TRACE +rootLogger.appenderRef.testlogger.ref = TestLogger + +appender.testlogger.name = TestLogger +appender.testlogger.type = CONSOLE +appender.testlogger.target = SYSTEM_ERR +appender.testlogger.layout.type = PatternLayout +appender.testlogger.layout.pattern = %-4r [%t] %-5p %X %c:%L - %m%n \ No newline at end of file diff --git a/sdk-core/build.gradle.kts b/sdk-core/build.gradle.kts index 87ce62e49..02ec1296b 100644 --- a/sdk-core/build.gradle.kts +++ b/sdk-core/build.gradle.kts @@ -118,6 +118,11 @@ tasks { "dev.restate.sdk.core.javaapi.reflections.RawInputOutput", "dev.restate.sdk.core.javaapi.reflections.RawService", "dev.restate.sdk.core.javaapi.reflections.ServiceGreeter", + "dev.restate.sdk.core.javaapi.reflections.ServiceWithInterface", + "dev.restate.sdk.core.javaapi.reflections.ExternalInterface", + "dev.restate.sdk.core.javaapi.reflections.ServiceA", + "dev.restate.sdk.core.javaapi.reflections.ServiceB", + "dev.restate.sdk.core.javaapi.reflections.RouterService", ) options.compilerArgs.addAll( diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ExternalInterface.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ExternalInterface.java new file mode 100644 index 000000000..8ec2b175e --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ExternalInterface.java @@ -0,0 +1,16 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.javaapi.reflections; + +import dev.restate.sdk.annotation.Handler; + +public interface ExternalInterface { + @Handler + String greet(String name); +} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionTest.java index 0ca8e2767..22c83d970 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionTest.java @@ -138,6 +138,36 @@ public Stream definitions() { END_MESSAGE), testInvocation(CustomSerde::new, "greet") .withInput(startMessage(1), inputCmd(MySerdeFactory.SERDE, "input")) - .expectingOutput(outputCmd(MySerdeFactory.SERDE, "OUTPUT"), END_MESSAGE)); + .expectingOutput(outputCmd(MySerdeFactory.SERDE, "OUTPUT"), END_MESSAGE), + testInvocation(ServiceWithInterface::new, "callInterface") + .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, "Hello Francesco")) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("MyGreeter", "greet"), "Francesco"), + outputCmd("Hello Francesco"), + END_MESSAGE), + testInvocation(ServiceWithInterface::new, "callInterfaceHandle") + .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, "Hello Francesco")) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("MyGreeterHandle", "greet"), "Francesco"), + outputCmd("Hello Francesco"), + END_MESSAGE), + testInvocation(RouterService::new, "route") + .withInput( + startMessage(1), inputCmd("ServiceA"), callCompletion(2, "Hello from A, world")) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("ServiceA", "greet"), "world"), + outputCmd("Hello from A, world"), + END_MESSAGE), + testInvocation(RouterService::new, "route") + .withInput( + startMessage(1), inputCmd("ServiceB"), callCompletion(2, "Hello from B, world")) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("ServiceB", "greet"), "world"), + outputCmd("Hello from B, world"), + END_MESSAGE)); } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/RouterService.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/RouterService.java new file mode 100644 index 000000000..a0b3630d4 --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/RouterService.java @@ -0,0 +1,22 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.javaapi.reflections; + +import dev.restate.sdk.Restate; +import dev.restate.sdk.annotation.Handler; +import dev.restate.sdk.annotation.Service; + +@Service +public class RouterService { + + @Handler + public String route(String targetService) { + return Restate.service(ExternalInterface.class, targetService).greet("world"); + } +} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ServiceA.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ServiceA.java new file mode 100644 index 000000000..aae448f6f --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ServiceA.java @@ -0,0 +1,19 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.javaapi.reflections; + +import dev.restate.sdk.annotation.Service; + +@Service +public class ServiceA implements ExternalInterface { + @Override + public String greet(String name) { + return "Hello from A, " + name; + } +} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ServiceB.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ServiceB.java new file mode 100644 index 000000000..c2e837c53 --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ServiceB.java @@ -0,0 +1,19 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.javaapi.reflections; + +import dev.restate.sdk.annotation.Service; + +@Service +public class ServiceB implements ExternalInterface { + @Override + public String greet(String name) { + return "Hello from B, " + name; + } +} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ServiceWithInterface.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ServiceWithInterface.java new file mode 100644 index 000000000..cd563e0ac --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ServiceWithInterface.java @@ -0,0 +1,34 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.javaapi.reflections; + +import dev.restate.sdk.Restate; +import dev.restate.sdk.annotation.Handler; +import dev.restate.sdk.annotation.Service; + +@Service +public class ServiceWithInterface { + + public interface SharedInterface { + @Handler + String greet(String name); + } + + @Handler + public String callInterface(String name) { + return Restate.service(SharedInterface.class, "MyGreeter").greet(name); + } + + @Handler + public String callInterfaceHandle(String name) { + return Restate.serviceHandle(SharedInterface.class, "MyGreeterHandle") + .call(SharedInterface::greet, name) + .await(); + } +} diff --git a/sdk-http-vertx/bin/main/META-INF/services/org.apache.logging.log4j.core.util.ContextDataProvider b/sdk-http-vertx/bin/main/META-INF/services/org.apache.logging.log4j.core.util.ContextDataProvider new file mode 100644 index 000000000..8939c1e04 --- /dev/null +++ b/sdk-http-vertx/bin/main/META-INF/services/org.apache.logging.log4j.core.util.ContextDataProvider @@ -0,0 +1 @@ +io.reactiverse.contextual.logging.VertxContextDataProvider \ No newline at end of file diff --git a/sdk-serde-jackson/bin/main/META-INF/services/dev.restate.serde.provider.DefaultSerdeFactoryProvider b/sdk-serde-jackson/bin/main/META-INF/services/dev.restate.serde.provider.DefaultSerdeFactoryProvider new file mode 100644 index 000000000..9e8dfc1bf --- /dev/null +++ b/sdk-serde-jackson/bin/main/META-INF/services/dev.restate.serde.provider.DefaultSerdeFactoryProvider @@ -0,0 +1 @@ +dev.restate.serde.jackson.JacksonSerdeFactoryProvider \ No newline at end of file diff --git a/sdk-serde-kotlinx/bin/main/META-INF/services/dev.restate.serde.provider.DefaultSerdeFactoryProvider b/sdk-serde-kotlinx/bin/main/META-INF/services/dev.restate.serde.provider.DefaultSerdeFactoryProvider new file mode 100644 index 000000000..6411a949b --- /dev/null +++ b/sdk-serde-kotlinx/bin/main/META-INF/services/dev.restate.serde.provider.DefaultSerdeFactoryProvider @@ -0,0 +1 @@ +dev.restate.serde.kotlinx.KotlinSerializationSerdeFactoryProvider \ No newline at end of file diff --git a/sdk-serde-kotlinx/bin/main/dev/restate/serde/kotlinx/DefaultJsonSchemaFactory.kt b/sdk-serde-kotlinx/bin/main/dev/restate/serde/kotlinx/DefaultJsonSchemaFactory.kt new file mode 100644 index 000000000..f1b1c88b6 --- /dev/null +++ b/sdk-serde-kotlinx/bin/main/dev/restate/serde/kotlinx/DefaultJsonSchemaFactory.kt @@ -0,0 +1,140 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.serde.kotlinx + +import dev.restate.serde.Serde +import io.github.smiley4.schemakenerator.jsonschema.JsonSchemaSteps +import io.github.smiley4.schemakenerator.jsonschema.JsonSchemaSteps.compileReferencing +import io.github.smiley4.schemakenerator.jsonschema.JsonSchemaSteps.generateJsonSchema +import io.github.smiley4.schemakenerator.jsonschema.TitleBuilder +import io.github.smiley4.schemakenerator.jsonschema.data.IntermediateJsonSchemaData +import io.github.smiley4.schemakenerator.jsonschema.data.RefType +import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonArray +import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonNode +import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonObject +import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonTextValue +import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.array +import io.github.smiley4.schemakenerator.serialization.SerializationSteps.analyzeTypeUsingKotlinxSerialization +import io.github.smiley4.schemakenerator.serialization.SerializationSteps.initial +import io.github.smiley4.schemakenerator.serialization.SerializationSteps.renameMembers +import kotlin.collections.set +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.KSerializer +import kotlinx.serialization.json.Json + +object DefaultJsonSchemaFactory : KotlinSerializationSerdeFactory.JsonSchemaFactory { + @OptIn(ExperimentalSerializationApi::class) + override fun generateSchema(json: Json, serializer: KSerializer<*>) = + Serde.StringifiedJsonSchema( + runCatching { + var initialStep = + initial(serializer.descriptor).analyzeTypeUsingKotlinxSerialization { + serializersModule = json.serializersModule + } + + if (json.configuration.namingStrategy != null) { + initialStep = initialStep.renameMembers(json.configuration.namingStrategy!!) + } + + val intermediateStep = + initialStep.generateJsonSchema { + optionalHandling = JsonSchemaSteps.OptionalHandling.NON_REQUIRED + } + intermediateStep.writeTitles() + val compiledSchema = intermediateStep.compileReferencing(RefType.SIMPLE) + + // In case of nested schemas, compileReferencing also contains self schema... + val rootSchemaName = + TitleBuilder.BUILDER_SIMPLE( + compiledSchema.typeData, + intermediateStep.typeDataById, + ) + + // If schema is not json object, then it's boolean, so we're good no need for + // additional manipulation + if (compiledSchema.json !is JsonObject) { + return@runCatching compiledSchema.json + } + + // Assemble the final schema now + val rootNode = compiledSchema.json as JsonObject + // Add $schema + rootNode.properties.put( + "\$schema", + JsonTextValue("https://json-schema.org/draft/2020-12/schema"), + ) + // Add $defs + val definitions = + compiledSchema.definitions.filter { it.key != rootSchemaName }.toMutableMap() + if (definitions.isNotEmpty()) { + rootNode.properties.put("\$defs", JsonObject(definitions)) + } + // Replace all $refs + rootNode.fixRefsPrefix("#/definitions/$rootSchemaName") + // If the root type is nullable, it should be in the schema too + if (serializer.descriptor.isNullable) { + val oldTypeProperty = rootNode.properties["type"] + if (oldTypeProperty is JsonTextValue) { + rootNode.properties["type"] = array { + item(oldTypeProperty.value) + item(JsonTextValue("null")) + } + } else if (oldTypeProperty is JsonArray) { + oldTypeProperty.items.add(JsonTextValue("null")) + } + } + + return@runCatching rootNode + } + .getOrDefault(JsonObject(mutableMapOf())) + .prettyPrint() + ) + + private fun IntermediateJsonSchemaData.writeTitles() { + this.entries.forEach { schema -> + if (schema.json is JsonObject) { + if ( + (schema.typeData.isMap || + schema.typeData.isCollection || + schema.typeData.isEnum || + schema.typeData.isInlineValue || + schema.typeData.typeParameters.isNotEmpty() || + schema.typeData.members.isNotEmpty()) && + (schema.json as JsonObject).properties["title"] == null + ) { + (schema.json as JsonObject).properties["title"] = + JsonTextValue(TitleBuilder.BUILDER_SIMPLE(schema.typeData, this.typeDataById)) + } + } + } + } + + private fun JsonNode.fixRefsPrefix(rootDefinition: String) { + when (this) { + is JsonArray -> this.items.forEach { it.fixRefsPrefix(rootDefinition) } + is JsonObject -> this.fixRefsPrefix(rootDefinition) + else -> {} + } + } + + private fun JsonObject.fixRefsPrefix(rootDefinition: String) { + this.properties.computeIfPresent("\$ref") { key, node -> + if (node is JsonTextValue) { + if (node.value.startsWith(rootDefinition)) { + JsonTextValue("#/" + node.value.removePrefix(rootDefinition)) + } else { + JsonTextValue("#/\$defs/" + node.value.removePrefix("#/definitions/")) + } + } else { + node + } + } + this.properties.values.forEach { it.fixRefsPrefix(rootDefinition) } + } +} diff --git a/sdk-serde-kotlinx/bin/main/dev/restate/serde/kotlinx/KotlinSerializationSerdeFactory.kt b/sdk-serde-kotlinx/bin/main/dev/restate/serde/kotlinx/KotlinSerializationSerdeFactory.kt new file mode 100644 index 000000000..d7d3c5efd --- /dev/null +++ b/sdk-serde-kotlinx/bin/main/dev/restate/serde/kotlinx/KotlinSerializationSerdeFactory.kt @@ -0,0 +1,169 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.serde.kotlinx + +import dev.restate.common.Slice +import dev.restate.serde.Serde +import dev.restate.serde.Serde.Schema +import dev.restate.serde.SerdeFactory +import dev.restate.serde.TypeRef +import dev.restate.serde.TypeTag +import java.nio.charset.StandardCharsets +import kotlin.reflect.KClass +import kotlin.reflect.KType +import kotlinx.serialization.* +import kotlinx.serialization.builtins.nullable +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.modules.SerializersModule + +/** + * This class implements [SerdeFactory] using Kotlinx serialization stack. + * + * If you want to customize the [Json] object used in your service, it is recommended to subclass + * this class, and then register it using the [dev.restate.sdk.annotation.CustomSerdeFactory] + * annotation. + */ +open class KotlinSerializationSerdeFactory +@JvmOverloads +constructor( + private val json: Json = Json.Default, + private val jsonSchemaFactory: JsonSchemaFactory = DefaultJsonSchemaFactory, +) : SerdeFactory { + + /** Factory to generate json schemas. */ + interface JsonSchemaFactory { + fun generateSchema(json: Json, serializer: KSerializer<*>): Schema? + + companion object { + val NOOP = + object : JsonSchemaFactory { + override fun generateSchema(json: Json, serializer: KSerializer<*>): Schema? = null + } + } + } + + class KtTypeTag( + val type: KClass<*>, + /** Reified type */ + val kotlinType: KType?, + ) : TypeTag + + override fun create(typeTag: TypeTag): Serde { + if (typeTag is KtTypeTag) { + return create(typeTag) + } + return super.create(typeTag) + } + + @Suppress("UNCHECKED_CAST") + override fun create(typeRef: TypeRef): Serde { + if (typeRef.type == Unit::class.java) { + return UNIT as Serde + } + val serializer: KSerializer = + json.serializersModule.serializer(typeRef.type) as KSerializer + return jsonSerde(json, jsonSchemaFactory, serializer) + } + + @Suppress("UNCHECKED_CAST") + override fun create(clazz: Class): Serde { + if (clazz == Unit::class.java) { + return UNIT as Serde + } + val serializer: KSerializer = json.serializersModule.serializer(clazz) as KSerializer + return jsonSerde(json, jsonSchemaFactory, serializer) + } + + @Suppress("UNCHECKED_CAST") + @OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) + private fun create(ktSerdeInfo: KtTypeTag): Serde { + if (ktSerdeInfo.type == Unit::class) { + return UNIT as Serde + } + val serializer: KSerializer = + json.serializersModule.serializerForKtTypeInfo(ktSerdeInfo) as KSerializer + return jsonSerde(json, jsonSchemaFactory, serializer) + } + + companion object { + val UNIT: Serde = + object : Serde { + // This is fine, it's less strict + override fun serialize(value: Unit?): Slice { + return Slice.EMPTY + } + + override fun deserialize(value: Slice) { + return + } + + override fun contentType(): String? { + return null + } + } + + /** Creates a [Serde] implementation using the `kotlinx.serialization` json module. */ + fun jsonSerde( + json: Json = Json.Default, + jsonSchemaFactory: JsonSchemaFactory = DefaultJsonSchemaFactory, + serializer: KSerializer, + ): Serde { + val schema = jsonSchemaFactory.generateSchema(json, serializer) + + return object : Serde { + override fun serialize(value: T?): Slice { + if (value == null) { + return Slice.wrap(json.encodeToString(JsonNull.serializer(), JsonNull)) + } + + return Slice.wrap(json.encodeToString(serializer, value)) + } + + override fun deserialize(value: Slice): T { + return json.decodeFromString( + serializer, + String(value.toByteArray(), StandardCharsets.UTF_8), + ) + } + + override fun contentType(): String { + return "application/json" + } + + override fun jsonSchema(): Schema? { + return schema + } + } + } + } + + @InternalSerializationApi + @ExperimentalSerializationApi + /** Copy-pasted from ktor! */ + private fun SerializersModule.serializerForKtTypeInfo( + ktSerdeInfoInfo: KtTypeTag<*> + ): KSerializer<*> { + val module = this + return ktSerdeInfoInfo.kotlinType?.let { type -> + if (type.arguments.isEmpty()) { + null // fallback to a simple case because of + // https://github.com/Kotlin/kotlinx.serialization/issues/1870 + } else { + module.serializerOrNull(type) + } + } + ?: module.getContextual(ktSerdeInfoInfo.type)?.maybeNullable(ktSerdeInfoInfo) + ?: ktSerdeInfoInfo.type.serializer().maybeNullable(ktSerdeInfoInfo) + } + + private fun KSerializer.maybeNullable(typeInfo: KtTypeTag<*>): KSerializer<*> { + return if (typeInfo.kotlinType?.isMarkedNullable == true) this.nullable else this + } +} diff --git a/sdk-serde-kotlinx/bin/main/dev/restate/serde/kotlinx/KotlinSerializationSerdeFactoryProvider.kt b/sdk-serde-kotlinx/bin/main/dev/restate/serde/kotlinx/KotlinSerializationSerdeFactoryProvider.kt new file mode 100644 index 000000000..e9338190c --- /dev/null +++ b/sdk-serde-kotlinx/bin/main/dev/restate/serde/kotlinx/KotlinSerializationSerdeFactoryProvider.kt @@ -0,0 +1,15 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.serde.kotlinx + +import dev.restate.serde.provider.DefaultSerdeFactoryProvider + +public class KotlinSerializationSerdeFactoryProvider : DefaultSerdeFactoryProvider { + override fun create() = KotlinSerializationSerdeFactory() +} diff --git a/sdk-serde-kotlinx/bin/main/dev/restate/serde/kotlinx/api.kt b/sdk-serde-kotlinx/bin/main/dev/restate/serde/kotlinx/api.kt new file mode 100644 index 000000000..75f0fd1ee --- /dev/null +++ b/sdk-serde-kotlinx/bin/main/dev/restate/serde/kotlinx/api.kt @@ -0,0 +1,32 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.serde.kotlinx + +import dev.restate.serde.Serde +import dev.restate.serde.TypeTag +import kotlin.reflect.typeOf +import kotlinx.serialization.json.* +import kotlinx.serialization.serializer + +/** Creates a [Serde] implementation using the `kotlinx.serialization` json module. */ +inline fun jsonSerde( + json: Json = Json.Default, + jsonSchemaFactory: KotlinSerializationSerdeFactory.JsonSchemaFactory = + KotlinSerializationSerdeFactory.JsonSchemaFactory.NOOP, +): Serde { + @Suppress("UNCHECKED_CAST") + return when (typeOf()) { + typeOf() -> KotlinSerializationSerdeFactory.UNIT as Serde + else -> KotlinSerializationSerdeFactory.jsonSerde(json, jsonSchemaFactory, serializer()) + } +} + +/** Kotlin specific [TypeTag], using Kotlin's reified generics. */ +inline fun typeTag(): TypeTag = + KotlinSerializationSerdeFactory.KtTypeTag(T::class, typeOf()) diff --git a/sdk-serde-kotlinx/bin/test/dev/restate/serde/kotlinx/KotlinxSerdeTest.kt b/sdk-serde-kotlinx/bin/test/dev/restate/serde/kotlinx/KotlinxSerdeTest.kt new file mode 100644 index 000000000..4583a389f --- /dev/null +++ b/sdk-serde-kotlinx/bin/test/dev/restate/serde/kotlinx/KotlinxSerdeTest.kt @@ -0,0 +1,189 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.serde.kotlinx + +import dev.restate.serde.Serde +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test + +class KotlinxSerdeTest { + + @Serializable data class Recursive(val rec: Recursive? = null, val value: String) + + @Serializable + data class RecursiveCircular(val rec: RecursiveOtherCircular? = null, val value: String) + + @Serializable + data class RecursiveOtherCircular(val rec: RecursiveCircular? = null, val value: String) + + @Serializable + data class RecursiveTemplateCircular( + val rec: RecursiveTemplateOtherCircular? = null, + val value: V, + ) + + @Serializable + data class RecursiveTemplateOtherCircular( + val rec: RecursiveTemplateCircular? = null, + val value: V, + ) + + @Test + fun schemaGenWithPrimitive() { + testSchemaGen( + """ + { + "${'$'}schema": "https://json-schema.org/draft/2020-12/schema", + "type": "string" + } + """ + .trimIndent() + ) + } + + @Test + fun schemaGenWithNullablePrimitive() { + testSchemaGen( + """ + { + "${'$'}schema": "https://json-schema.org/draft/2020-12/schema", + "type": ["string", "null"] + } + """ + .trimIndent() + ) + } + + @Test + fun schemaGenWithRecursive() { + testSchemaGen( + """ + { + "type": "object", + "required": [ + "value" + ], + "properties": { + "rec": { + "${'$'}ref": "#/" + }, + "value": { + "type": "string" + } + }, + "title": "Recursive", + "${'$'}schema": "https://json-schema.org/draft/2020-12/schema" + } + """ + .trimIndent() + ) + } + + @Test + fun schemaGenWithRecursiveCircular() { + testSchemaGen( + """ + { + "type": "object", + "required": [ + "value" + ], + "properties": { + "rec": { + "${'$'}ref": "#/${'$'}defs/RecursiveOtherCircular" + }, + "value": { + "type": "string" + } + }, + "title": "RecursiveCircular", + "${'$'}schema": "https://json-schema.org/draft/2020-12/schema", + "${'$'}defs": { + "RecursiveOtherCircular": { + "type": "object", + "required": [ + "value" + ], + "properties": { + "rec": { + "${'$'}ref": "#/" + }, + "value": { + "type": "string" + } + }, + "title": "RecursiveOtherCircular" + } + } + } + """ + .trimIndent() + ) + } + + @Test + fun schemaGenWorksWithNestedRecursionTemplated() { + testSchemaGen>( + """ + { + "type": "object", + "required": [ + "value" + ], + "properties": { + "rec": { + "${'$'}ref": "#/${'$'}defs/RecursiveTemplateOtherCircular" + }, + "value": { + "type": "integer", + "minimum": -2147483648, + "maximum": 2147483647 + } + }, + "title": "RecursiveTemplateCircular", + "${'$'}schema": "https://json-schema.org/draft/2020-12/schema", + "${'$'}defs": { + "RecursiveTemplateOtherCircular": { + "type": "object", + "required": [ + "value" + ], + "properties": { + "rec": { + "${'$'}ref": "#/" + }, + "value": { + "type": "integer", + "minimum": -2147483648, + "maximum": 2147483647 + } + }, + "title": "RecursiveTemplateOtherCircular" + } + } + } + """ + .trimIndent() + ) + } + + inline fun testSchemaGen(expectedSchema: String) { + val expectedJsonElement = Json.decodeFromString(expectedSchema) + val actualSchema = + (jsonSerde(jsonSchemaFactory = DefaultJsonSchemaFactory).jsonSchema() + as Serde.StringifiedJsonSchema) + .schema + val actualJsonElement = Json.decodeFromString(actualSchema) + + assertThat(actualJsonElement).isEqualTo(expectedJsonElement) + } +} diff --git a/sdk-spring-boot-kotlin-starter/bin/test/dev/restate/sdk/springboot/kotlin/Greeter.kt b/sdk-spring-boot-kotlin-starter/bin/test/dev/restate/sdk/springboot/kotlin/Greeter.kt new file mode 100644 index 000000000..ca2d0a144 --- /dev/null +++ b/sdk-spring-boot-kotlin-starter/bin/test/dev/restate/sdk/springboot/kotlin/Greeter.kt @@ -0,0 +1,26 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.springboot.kotlin + +import dev.restate.sdk.annotation.Handler +import dev.restate.sdk.annotation.Name +import dev.restate.sdk.kotlin.Context +import dev.restate.sdk.springboot.RestateService +import org.springframework.beans.factory.annotation.Value + +@RestateService +@Name("greeter") +class Greeter { + @Value("\${greetingPrefix}") lateinit var greetingPrefix: String + + @Handler + fun greet(ctx: Context, person: String): String { + return greetingPrefix + person + } +} diff --git a/sdk-spring-boot-kotlin-starter/bin/test/dev/restate/sdk/springboot/kotlin/GreeterNewApi.kt b/sdk-spring-boot-kotlin-starter/bin/test/dev/restate/sdk/springboot/kotlin/GreeterNewApi.kt new file mode 100644 index 000000000..840978445 --- /dev/null +++ b/sdk-spring-boot-kotlin-starter/bin/test/dev/restate/sdk/springboot/kotlin/GreeterNewApi.kt @@ -0,0 +1,26 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.springboot.kotlin + +import dev.restate.sdk.annotation.Handler +import dev.restate.sdk.annotation.Name +import dev.restate.sdk.kotlin.runBlock +import dev.restate.sdk.springboot.RestateService +import org.springframework.beans.factory.annotation.Value + +@RestateService +@Name("greeterNewApi") +open class GreeterNewApi { + @Value($$"${greetingPrefix}") internal lateinit var greetingPrefix: String + + @Handler + open suspend fun greet(person: String): String { + return runBlock { greetingPrefix } + person + } +} diff --git a/sdk-spring-boot-kotlin-starter/bin/test/dev/restate/sdk/springboot/kotlin/RestateHttpEndpointBeanTest.kt b/sdk-spring-boot-kotlin-starter/bin/test/dev/restate/sdk/springboot/kotlin/RestateHttpEndpointBeanTest.kt new file mode 100644 index 000000000..08ba54cbe --- /dev/null +++ b/sdk-spring-boot-kotlin-starter/bin/test/dev/restate/sdk/springboot/kotlin/RestateHttpEndpointBeanTest.kt @@ -0,0 +1,62 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.springboot.kotlin + +import com.fasterxml.jackson.databind.ObjectMapper +import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema +import dev.restate.sdk.springboot.RestateEndpointConfiguration +import dev.restate.sdk.springboot.RestateHttpConfiguration +import dev.restate.sdk.springboot.RestateHttpEndpointBean +import java.io.IOException +import java.net.URI +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpResponse +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.boot.test.context.SpringBootTest + +@SpringBootTest( + classes = + [RestateEndpointConfiguration::class, RestateHttpConfiguration::class, Greeter::class], + properties = ["restate.sdk.http.port=0"], +) +class RestateHttpEndpointBeanTest { + @Autowired lateinit var restateHttpEndpointBean: RestateHttpEndpointBean + + @Test + @Throws(IOException::class, InterruptedException::class) + fun httpEndpointShouldBeRunning() { + assertThat(restateHttpEndpointBean.isRunning).isTrue() + assertThat(restateHttpEndpointBean.actualPort()).isPositive() + + // Check if discovery replies containing the Greeter service + val client = HttpClient.newBuilder().version(HttpClient.Version.HTTP_2).build() + val response = + client.send( + HttpRequest.newBuilder() + .GET() + .version(HttpClient.Version.HTTP_2) + .uri( + URI.create("http://localhost:${restateHttpEndpointBean.actualPort()}/discover") + ) + .header("Accept", "application/vnd.restate.endpointmanifest.v1+json") + .build(), + HttpResponse.BodyHandlers.ofString(), + ) + assertThat(response.version()).isEqualTo(HttpClient.Version.HTTP_2) + assertThat(response.statusCode()).isEqualTo(200) + + val endpointManifest = + ObjectMapper().readValue(response.body(), EndpointManifestSchema::class.java) + + assertThat(endpointManifest.services).map { it?.name }.containsOnly("greeter") + } +} diff --git a/sdk-spring-boot-kotlin-starter/bin/test/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt b/sdk-spring-boot-kotlin-starter/bin/test/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt new file mode 100644 index 000000000..2015f253d --- /dev/null +++ b/sdk-spring-boot-kotlin-starter/bin/test/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt @@ -0,0 +1,56 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.springboot.kotlin + +import dev.restate.client.Client +import dev.restate.client.kotlin.service +import dev.restate.client.kotlin.toService +import dev.restate.sdk.testing.BindService +import dev.restate.sdk.testing.RestateClient +import dev.restate.sdk.testing.RestateTest +import kotlinx.coroutines.test.runTest +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.boot.test.context.SpringBootTest + +@SpringBootTest( + classes = [Greeter::class, GreeterNewApi::class], + properties = ["greetingPrefix=Something something "], +) +@RestateTest(containerImage = "ghcr.io/restatedev/restate:main") +class SdkTestingIntegrationTest { + @Autowired @BindService lateinit var greeter: Greeter + @Autowired @BindService lateinit var greeterNewApi: GreeterNewApi + + @Test + @Timeout(value = 10) + fun greet(@RestateClient ingressClient: Client) = runTest { + val client = GreeterClient.fromClient(ingressClient) + + assertThat(client.greet("Francesco")).isEqualTo("Something something Francesco") + } + + @Test + @Timeout(value = 10) + fun greetNewApi(@RestateClient ingressClient: Client) = runTest { + assertThat(ingressClient.service().greet("Francesco")) + .isEqualTo("Something something Francesco") + } + + @Test + @Timeout(value = 10) + fun greetNewApiWithRequestTo(@RestateClient ingressClient: Client) = runTest { + val response: String = + ingressClient.toService().request { greet("Francesco") }.call().response() + + assertThat(response).isEqualTo("Something something Francesco") + } +} diff --git a/sdk-spring-boot/bin/default/META-INF/spring-configuration-metadata.json b/sdk-spring-boot/bin/default/META-INF/spring-configuration-metadata.json new file mode 100644 index 000000000..58215c3f5 --- /dev/null +++ b/sdk-spring-boot/bin/default/META-INF/spring-configuration-metadata.json @@ -0,0 +1,71 @@ +{ + "groups": [ + { + "name": "restate", + "type": "dev.restate.sdk.springboot.RestateComponentsProperties", + "sourceType": "dev.restate.sdk.springboot.RestateComponentsProperties" + }, + { + "name": "restate.client", + "type": "dev.restate.sdk.springboot.RestateClientProperties", + "sourceType": "dev.restate.sdk.springboot.RestateClientProperties" + }, + { + "name": "restate.sdk", + "type": "dev.restate.sdk.springboot.RestateEndpointProperties", + "sourceType": "dev.restate.sdk.springboot.RestateEndpointProperties" + }, + { + "name": "restate.sdk.http", + "type": "dev.restate.sdk.springboot.RestateHttpServerProperties", + "sourceType": "dev.restate.sdk.springboot.RestateHttpServerProperties" + } + ], + "properties": [ + { + "name": "restate.client.base-uri", + "type": "java.lang.String", + "sourceType": "dev.restate.sdk.springboot.RestateClientProperties", + "defaultValue": "http:\/\/localhost:8080" + }, + { + "name": "restate.client.headers", + "type": "java.util.Map", + "sourceType": "dev.restate.sdk.springboot.RestateClientProperties" + }, + { + "name": "restate.components", + "type": "java.util.Map", + "sourceType": "dev.restate.sdk.springboot.RestateComponentsProperties" + }, + { + "name": "restate.executor", + "type": "java.lang.String", + "sourceType": "dev.restate.sdk.springboot.RestateComponentsProperties" + }, + { + "name": "restate.sdk.enable-preview-context", + "type": "java.lang.Boolean", + "sourceType": "dev.restate.sdk.springboot.RestateEndpointProperties", + "defaultValue": false + }, + { + "name": "restate.sdk.http.disable-bidirectional-streaming", + "type": "java.lang.Boolean", + "sourceType": "dev.restate.sdk.springboot.RestateHttpServerProperties", + "defaultValue": false + }, + { + "name": "restate.sdk.http.port", + "type": "java.lang.Integer", + "sourceType": "dev.restate.sdk.springboot.RestateHttpServerProperties", + "defaultValue": 9080 + }, + { + "name": "restate.sdk.identity-key", + "type": "java.lang.String", + "sourceType": "dev.restate.sdk.springboot.RestateEndpointProperties" + } + ], + "hints": [] +} \ No newline at end of file diff --git a/sdk-spring-boot/bin/main/META-INF/additional-spring-configuration-metadata.json b/sdk-spring-boot/bin/main/META-INF/additional-spring-configuration-metadata.json new file mode 100644 index 000000000..b2ec1e7db --- /dev/null +++ b/sdk-spring-boot/bin/main/META-INF/additional-spring-configuration-metadata.json @@ -0,0 +1,15 @@ +{ + "hints": [ + { + "name": "restate.executor", + "providers": [ + { + "name": "spring-bean-reference", + "parameters": { + "target": "java.util.concurrent.Executor" + } + } + ] + } + ] +} diff --git a/sdk-testing/bin/test/log4j2.properties b/sdk-testing/bin/test/log4j2.properties new file mode 100644 index 000000000..130894e5f --- /dev/null +++ b/sdk-testing/bin/test/log4j2.properties @@ -0,0 +1,18 @@ +# Set to debug or trace if log4j initialization is failing +status = warn + +# Console appender configuration +appender.console.type = Console +appender.console.name = consoleLogger +appender.console.layout.type = PatternLayout +appender.console.layout.pattern = %d{yyyy-MM-dd HH:mm:ss} %-5p %notEmpty{[%X{restateServiceMethod}]}%notEmpty{[%X{restateInvocationId}]} %c - %m%n + +# Restate logs to debug level +logger.app.name = dev.restate +logger.app.level = debug +logger.app.additivity = false +logger.app.appenderRef.console.ref = consoleLogger + +# Root logger +rootLogger.level = info +rootLogger.appenderRef.stdout.ref = consoleLogger \ No newline at end of file diff --git a/test-services/bin/main/dev/restate/sdk/testservices/AwakeableHolderImpl.kt b/test-services/bin/main/dev/restate/sdk/testservices/AwakeableHolderImpl.kt new file mode 100644 index 000000000..c373689a0 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/AwakeableHolderImpl.kt @@ -0,0 +1,33 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices + +import dev.restate.sdk.common.StateKey +import dev.restate.sdk.common.TerminalException +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.testservices.contracts.AwakeableHolder + +class AwakeableHolderImpl : AwakeableHolder { + companion object { + private val ID_KEY: StateKey = stateKey("id") + } + + override suspend fun hold(id: String) { + state().set(ID_KEY, id) + } + + override suspend fun hasAwakeable(): Boolean { + return state().get(ID_KEY) != null + } + + override suspend fun unlock(payload: String) { + val awakeableId = state().get(ID_KEY) ?: throw TerminalException("No awakeable registered") + awakeableHandle(awakeableId).resolve(payload) + } +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/BlockAndWaitWorkflowImpl.kt b/test-services/bin/main/dev/restate/sdk/testservices/BlockAndWaitWorkflowImpl.kt new file mode 100644 index 000000000..5d181ec83 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/BlockAndWaitWorkflowImpl.kt @@ -0,0 +1,43 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices + +import dev.restate.sdk.common.DurablePromiseKey +import dev.restate.sdk.common.StateKey +import dev.restate.sdk.common.TerminalException +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.testservices.contracts.BlockAndWaitWorkflow + +class BlockAndWaitWorkflowImpl : BlockAndWaitWorkflow { + companion object { + private val MY_DURABLE_PROMISE: DurablePromiseKey = durablePromiseKey("durable-promise") + private val MY_STATE: StateKey = stateKey("my-state") + } + + override suspend fun run(input: String): String { + state().set(MY_STATE, input) + + // Wait on unblock + val output: String = promise(MY_DURABLE_PROMISE).future().await() + + if (!promise(MY_DURABLE_PROMISE).peek().isReady) { + throw TerminalException("Durable promise should be completed") + } + + return output + } + + override suspend fun unblock(output: String) { + promiseHandle(MY_DURABLE_PROMISE).resolve(output) + } + + override suspend fun getState(): String? { + return state().get(MY_STATE) + } +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/CancelTestImpl.kt b/test-services/bin/main/dev/restate/sdk/testservices/CancelTestImpl.kt new file mode 100644 index 000000000..628009863 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/CancelTestImpl.kt @@ -0,0 +1,64 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices + +import dev.restate.sdk.common.StateKey +import dev.restate.sdk.common.TerminalException +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.testservices.contracts.AwakeableHolder +import dev.restate.sdk.testservices.contracts.CancelTest +import kotlin.time.Duration.Companion.days + +class CancelTestImpl { + class RunnerImpl : CancelTest.Runner { + companion object { + private val CANCELED_STATE: StateKey = stateKey("canceled") + } + + override suspend fun startTest(operation: CancelTest.BlockingOperation) { + try { + virtualObject(objectKey()).block(operation) + } catch (e: TerminalException) { + if (e.code == TerminalException.CANCELLED_CODE) { + state().set(CANCELED_STATE, true) + } else { + throw e + } + } + } + + override suspend fun verifyTest(): Boolean { + return state().get(CANCELED_STATE) ?: false + } + } + + class BlockingService : CancelTest.BlockingService { + override suspend fun block(operation: CancelTest.BlockingOperation) { + val self = virtualObject(objectKey()) + val awakeableHolder = virtualObject(objectKey()) + + val awakeable = awakeable() + awakeableHolder.hold(awakeable.id) + awakeable.await() + + when (operation) { + CancelTest.BlockingOperation.CALL -> self.block(operation) + CancelTest.BlockingOperation.SLEEP -> sleep(1024.days) + CancelTest.BlockingOperation.AWAKEABLE -> { + val uncompletable: Awakeable = awakeable() + uncompletable.await() + } + } + } + + override suspend fun isUnlocked() { + // no-op + } + } +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/CounterImpl.kt b/test-services/bin/main/dev/restate/sdk/testservices/CounterImpl.kt new file mode 100644 index 000000000..ce61ed24c --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/CounterImpl.kt @@ -0,0 +1,59 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices + +import dev.restate.sdk.common.StateKey +import dev.restate.sdk.common.TerminalException +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.testservices.contracts.Counter +import org.apache.logging.log4j.LogManager +import org.apache.logging.log4j.Logger + +class CounterImpl : Counter { + + companion object { + private val logger: Logger = LogManager.getLogger(CounterImpl::class.java) + + private val COUNTER_KEY: StateKey = stateKey("counter") + } + + override suspend fun reset() { + logger.info("Counter cleaned up") + state().clear(COUNTER_KEY) + } + + override suspend fun addThenFail(value: Long) { + var counter: Long = state().get(COUNTER_KEY) ?: 0L + logger.info("Old counter value: {}", counter) + + counter += value + state().set(COUNTER_KEY, counter) + + logger.info("New counter value: {}", counter) + + throw TerminalException(objectKey()) + } + + override suspend fun get(): Long { + val counter: Long = state().get(COUNTER_KEY) ?: 0L + logger.info("Get counter value: {}", counter) + return counter + } + + override suspend fun add(value: Long): Counter.CounterUpdateResponse { + val oldCount: Long = state().get(COUNTER_KEY) ?: 0L + val newCount = oldCount + value + state().set(COUNTER_KEY, newCount) + + logger.info("Old counter value: {}", oldCount) + logger.info("New counter value: {}", newCount) + + return Counter.CounterUpdateResponse(oldCount, newCount) + } +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/FailingImpl.kt b/test-services/bin/main/dev/restate/sdk/testservices/FailingImpl.kt new file mode 100644 index 000000000..0aa1224d9 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/FailingImpl.kt @@ -0,0 +1,103 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices + +import dev.restate.sdk.common.TerminalException +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.testservices.contracts.Failing +import java.util.concurrent.atomic.AtomicInteger +import kotlin.time.Duration.Companion.milliseconds +import org.apache.logging.log4j.LogManager +import org.apache.logging.log4j.Logger + +class FailingImpl : Failing { + companion object { + private val LOG: Logger = LogManager.getLogger(FailingImpl::class.java) + } + + private val eventualSuccessCalls = AtomicInteger(0) + private val eventualSuccessSideEffectCalls = AtomicInteger(0) + private val eventualFailureSideEffectCalls = AtomicInteger(0) + + override suspend fun terminallyFailingCall(errorMessage: String) { + LOG.info("Invoked fail") + + throw TerminalException(errorMessage) + } + + override suspend fun callTerminallyFailingCall( + errorMessage: String, + ): String { + LOG.info("Invoked failAndHandle") + + virtualObject(random().nextUUID().toString()).terminallyFailingCall(errorMessage) + + throw IllegalStateException("This should be unreachable") + } + + override suspend fun failingCallWithEventualSuccess(): Int { + val currentAttempt = eventualSuccessCalls.incrementAndGet() + + if (currentAttempt >= 4) { + eventualSuccessCalls.set(0) + return currentAttempt + } else { + throw IllegalArgumentException("Failed at attempt: $currentAttempt") + } + } + + override suspend fun terminallyFailingSideEffect(errorMessage: String) { + runBlock { throw TerminalException(errorMessage) } + + throw IllegalStateException("Should not be reached.") + } + + override suspend fun sideEffectSucceedsAfterGivenAttempts( + minimumAttempts: Int, + ): Int = + runBlock( + name = "failing_side_effect", + retryPolicy = + retryPolicy { + initialDelay = 10.milliseconds + exponentiationFactor = 1.0f + }, + ) { + val currentAttempt = eventualSuccessSideEffectCalls.incrementAndGet() + if (currentAttempt >= 4) { + eventualSuccessSideEffectCalls.set(0) + return@runBlock currentAttempt + } else { + throw IllegalArgumentException("Failed at attempt: $currentAttempt") + } + } + + override suspend fun sideEffectFailsAfterGivenAttempts( + retryPolicyMaxRetryCount: Int, + ): Int { + try { + runBlock( + name = "failing_side_effect", + retryPolicy = + retryPolicy { + initialDelay = 10.milliseconds + exponentiationFactor = 1.0f + maxAttempts = retryPolicyMaxRetryCount + }, + ) { + val currentAttempt = eventualFailureSideEffectCalls.incrementAndGet() + throw IllegalArgumentException("Failed at attempt: $currentAttempt") + } + } catch (_: TerminalException) { + return eventualFailureSideEffectCalls.get() + } + // If I reach this point, the side effect succeeded... + throw TerminalException("Expecting the side effect to fail!") + } +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/KillTestImpl.kt b/test-services/bin/main/dev/restate/sdk/testservices/KillTestImpl.kt new file mode 100644 index 000000000..2bfc4ec42 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/KillTestImpl.kt @@ -0,0 +1,41 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices + +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.testservices.contracts.AwakeableHolder +import dev.restate.sdk.testservices.contracts.KillTest +import dev.restate.serde.Serde + +class KillTestImpl { + class RunnerImpl : KillTest.Runner { + // The call tree method invokes the KillSingletonService::recursiveCall which blocks on calling + // itself again. + // This will ensure that we have a call tree that is two calls deep and has a pending invocation + // in the inbox: + // startCallTree --> recursiveCall --> recursiveCall:inboxed + override suspend fun startCallTree() { + virtualObject(objectKey()).recursiveCall() + } + } + + class SingletonImpl : KillTest.Singleton { + override suspend fun recursiveCall() { + val awakeable = awakeable(Serde.RAW) + toVirtualObject(objectKey()).request { hold(awakeable.id) }.send() + awakeable.await() + + virtualObject(objectKey()).recursiveCall() + } + + override suspend fun isUnlocked() { + // no-op + } + } +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/ListObjectImpl.kt b/test-services/bin/main/dev/restate/sdk/testservices/ListObjectImpl.kt new file mode 100644 index 000000000..54440a2d5 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/ListObjectImpl.kt @@ -0,0 +1,37 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices + +import dev.restate.sdk.common.StateKey +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.testservices.contracts.ListObject + +class ListObjectImpl : ListObject { + companion object { + private val LIST_KEY: StateKey> = + stateKey( + "list", + ) + } + + override suspend fun append(value: String) { + val list = state().get(LIST_KEY) ?: emptyList() + state().set(LIST_KEY, list + value) + } + + override suspend fun get(): List { + return state().get(LIST_KEY) ?: emptyList() + } + + override suspend fun clear(): List { + val result = state().get(LIST_KEY) ?: emptyList() + state().clear(LIST_KEY) + return result + } +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/Main.kt b/test-services/bin/main/dev/restate/sdk/testservices/Main.kt new file mode 100644 index 000000000..978cac115 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/Main.kt @@ -0,0 +1,81 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices + +import dev.restate.common.reflections.ReflectionUtils.extractServiceName +import dev.restate.sdk.auth.signing.RestateRequestIdentityVerifier +import dev.restate.sdk.http.vertx.RestateHttpServer +import dev.restate.sdk.kotlin.endpoint.endpoint +import dev.restate.sdk.testservices.contracts.* + +val KNOWN_SERVICES_FACTORIES: Map Any> = + mapOf( + extractServiceName(AwakeableHolder::class.java) to { AwakeableHolderImpl() }, + extractServiceName(BlockAndWaitWorkflow::class.java) to { BlockAndWaitWorkflowImpl() }, + extractServiceName(CancelTest.BlockingService::class.java) to + { + CancelTestImpl.BlockingService() + }, + extractServiceName(CancelTest.Runner::class.java) to { CancelTestImpl.RunnerImpl() }, + extractServiceName(Counter::class.java) to { CounterImpl() }, + extractServiceName(Failing::class.java) to { FailingImpl() }, + extractServiceName(KillTest.Runner::class.java) to { KillTestImpl.RunnerImpl() }, + extractServiceName(KillTest.Singleton::class.java) to { KillTestImpl.SingletonImpl() }, + extractServiceName(ListObject::class.java) to { ListObjectImpl() }, + extractServiceName(MapObject::class.java) to { MapObjectImpl() }, + extractServiceName(NonDeterministic::class.java) to { NonDeterministicImpl() }, + extractServiceName(Proxy::class.java) to { ProxyImpl() }, + extractServiceName(TestUtilsService::class.java) to { TestUtilsServiceImpl() }, + extractServiceName(VirtualObjectCommandInterpreter::class.java) to + { + VirtualObjectCommandInterpreterImpl() + }, + interpreterName(0) to { ObjectInterpreterImpl.getInterpreterDefinition(0) }, + interpreterName(1) to { ObjectInterpreterImpl.getInterpreterDefinition(1) }, + interpreterName(2) to { ObjectInterpreterImpl.getInterpreterDefinition(2) }, + extractServiceName(ServiceInterpreterHelper::class.java) to + { + ServiceInterpreterHelperImpl() + }, + ) + +val NEEDS_EXPERIMENTAL_CONTEXT: Set = setOf() + +fun main(args: Array) { + var env = System.getenv("SERVICES") + if (env == null) { + env = "*" + } + val endpoint = endpoint { + if (env == "*") { + for (svc in KNOWN_SERVICES_FACTORIES.values) { + bind(svc()) + } + } else { + for (svc in env.split(",".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()) { + val fqsn = svc.trim { it <= ' ' } + bind( + KNOWN_SERVICES_FACTORIES[fqsn]?.invoke() + ?: throw IllegalStateException("Service $fqsn not implemented") + ) + } + } + + val requestSigningKey = System.getenv("E2E_REQUEST_SIGNING") + if (requestSigningKey != null) { + withRequestIdentityVerifier(RestateRequestIdentityVerifier.fromKey(requestSigningKey)) + } + + if (env == "*" || NEEDS_EXPERIMENTAL_CONTEXT.any { env.contains(it) }) { + enablePreviewContext() + } + } + + RestateHttpServer.listen(endpoint) +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/MapObjectImpl.kt b/test-services/bin/main/dev/restate/sdk/testservices/MapObjectImpl.kt new file mode 100644 index 000000000..69263d6f2 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/MapObjectImpl.kt @@ -0,0 +1,33 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices + +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.testservices.contracts.MapObject + +class MapObjectImpl : MapObject { + override suspend fun set(entry: MapObject.Entry) { + state().set(stateKey(entry.key), entry.value) + } + + override suspend fun get(key: String): String { + return state().get(stateKey(key)) ?: "" + } + + override suspend fun clearAll(): List { + val keys = state().keys() + // AH AH AH and here I wanna see if you really respect determinism!!! + val result = mutableListOf() + for (k in keys) { + result.add(MapObject.Entry(k, state().get(stateKey(k))!!)) + } + state().clearAll() + return result + } +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/NonDeterministicImpl.kt b/test-services/bin/main/dev/restate/sdk/testservices/NonDeterministicImpl.kt new file mode 100644 index 000000000..20802b3ab --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/NonDeterministicImpl.kt @@ -0,0 +1,76 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices + +import dev.restate.sdk.common.StateKey +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.testservices.contracts.Counter +import dev.restate.sdk.testservices.contracts.NonDeterministic +import java.util.concurrent.ConcurrentHashMap +import kotlin.time.Duration.Companion.milliseconds + +class NonDeterministicImpl : NonDeterministic { + private val invocationCounts: ConcurrentHashMap = ConcurrentHashMap() + private val STATE_A: StateKey = stateKey("a") + private val STATE_B: StateKey = stateKey("b") + + override suspend fun eitherSleepOrCall() { + if (doLeftAction()) { + sleep(100.milliseconds) + } else { + virtualObject("abc").get() + } + // This is required to cause a suspension after the non-deterministic operation + sleep(100.milliseconds) + incrementCounter() + } + + override suspend fun callDifferentMethod() { + if (doLeftAction()) { + virtualObject("abc").get() + } else { + virtualObject("abc").reset() + } + // This is required to cause a suspension after the non-deterministic operation + sleep(100.milliseconds) + incrementCounter() + } + + override suspend fun backgroundInvokeWithDifferentTargets() { + if (doLeftAction()) { + toVirtualObject("abc").request { get() }.send() + } else { + toVirtualObject("abc").request { reset() }.send() + } + // This is required to cause a suspension after the non-deterministic operation + sleep(100.milliseconds) + incrementCounter() + } + + override suspend fun setDifferentKey() { + if (doLeftAction()) { + state().set(STATE_A, "my-state") + } else { + state().set(STATE_B, "my-state") + } + // This is required to cause a suspension after the non-deterministic operation + sleep(100.milliseconds) + incrementCounter() + } + + private suspend fun incrementCounter() { + toVirtualObject("abc").request { add(1) }.send() + } + + private suspend fun doLeftAction(): Boolean { + // Test runner sets an appropriate key here + val countKey = objectKey() + return invocationCounts.merge(countKey, 1) { a: Int, b: Int -> a + b }!! % 2 == 1 + } +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/ProxyImpl.kt b/test-services/bin/main/dev/restate/sdk/testservices/ProxyImpl.kt new file mode 100644 index 000000000..ba2795c6f --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/ProxyImpl.kt @@ -0,0 +1,96 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices + +import dev.restate.common.Request +import dev.restate.common.Target +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.testservices.contracts.Proxy +import dev.restate.serde.Serde +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds + +class ProxyImpl : Proxy { + private fun Proxy.ProxyRequest.toTarget(): Target { + return if (this.virtualObjectKey == null) { + Target.service(this.serviceName, this.handlerName) + } else { + Target.virtualObject(this.serviceName, this.virtualObjectKey, this.handlerName) + } + } + + override suspend fun call(request: Proxy.ProxyRequest): ByteArray { + return context() + .call( + Request.of(request.toTarget(), Serde.RAW, Serde.RAW, request.message).also { + if (request.idempotencyKey != null) { + it.idempotencyKey = request.idempotencyKey + } + } + ) + .await() + } + + override suspend fun oneWayCall(request: Proxy.ProxyRequest): String = + context() + .send( + Request.of(request.toTarget(), Serde.RAW, Serde.SLICE, request.message).also { + if (request.idempotencyKey != null) { + it.idempotencyKey = request.idempotencyKey + } + }, + request.delayMillis?.milliseconds ?: Duration.ZERO, + ) + .invocationId() + + override suspend fun manyCalls(requests: List) { + val toAwait = mutableListOf>() + + for (request in requests) { + if (request.oneWayCall) { + context() + .send( + Request.of( + request.proxyRequest.toTarget(), + Serde.RAW, + Serde.SLICE, + request.proxyRequest.message, + ) + .also { + if (request.proxyRequest.idempotencyKey != null) { + it.idempotencyKey = request.proxyRequest.idempotencyKey + } + }, + request.proxyRequest.delayMillis?.milliseconds ?: Duration.ZERO, + ) + } else { + val fut = + context() + .call( + Request.of( + request.proxyRequest.toTarget(), + Serde.RAW, + Serde.RAW, + request.proxyRequest.message, + ) + .also { + if (request.proxyRequest.idempotencyKey != null) { + it.idempotencyKey = request.proxyRequest.idempotencyKey + } + } + ) + if (request.awaitAtTheEnd) { + toAwait.add(fut) + } + } + } + + toAwait.toList().awaitAll() + } +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/TestUtilsServiceImpl.kt b/test-services/bin/main/dev/restate/sdk/testservices/TestUtilsServiceImpl.kt new file mode 100644 index 000000000..398cdcacb --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/TestUtilsServiceImpl.kt @@ -0,0 +1,54 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices + +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.testservices.contracts.* +import java.util.* +import java.util.concurrent.atomic.AtomicInteger +import kotlin.time.Duration.Companion.milliseconds + +class TestUtilsServiceImpl : TestUtilsService { + override suspend fun echo(input: String): String { + return input + } + + override suspend fun uppercaseEcho(input: String): String { + return input.uppercase(Locale.getDefault()) + } + + override suspend fun echoHeaders(): Map { + return request().headers + } + + override suspend fun rawEcho(input: ByteArray): ByteArray { + check(input.contentEquals(request().bodyAsByteArray)) + return input + } + + override suspend fun sleepConcurrently(millisDuration: List) { + val timers = millisDuration.map { timer("${it.milliseconds}ms", it.milliseconds) }.toList() + + timers.awaitAll() + } + + override suspend fun countExecutedSideEffects(increments: Int): Int { + val invokedSideEffects = AtomicInteger(0) + + for (i in 0..(invocationId).cancel() + } +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/VirtualObjectCommandInterpreterImpl.kt b/test-services/bin/main/dev/restate/sdk/testservices/VirtualObjectCommandInterpreterImpl.kt new file mode 100644 index 000000000..1ebdede17 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/VirtualObjectCommandInterpreterImpl.kt @@ -0,0 +1,137 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices + +import dev.restate.sdk.common.TerminalException +import dev.restate.sdk.common.TimeoutException +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.kotlin.get +import dev.restate.sdk.testservices.contracts.VirtualObjectCommandInterpreter +import kotlin.time.Duration.Companion.milliseconds +import org.apache.logging.log4j.LogManager +import org.apache.logging.log4j.Logger + +class VirtualObjectCommandInterpreterImpl : VirtualObjectCommandInterpreter { + + companion object { + private val LOG: Logger = LogManager.getLogger(VirtualObjectCommandInterpreterImpl::class.java) + } + + override suspend fun interpretCommands( + req: VirtualObjectCommandInterpreter.InterpretRequest, + ): String { + LOG.info("Interpreting commands {}", req) + + var result = "" + + req.commands.forEach { + LOG.info("Start interpreting command {}", it) + when (it) { + is VirtualObjectCommandInterpreter.AwaitAny -> { + val cmds = it.commands.map { it.toAwaitable() } + result = + select { + for (cmd in cmds) { + cmd.onAwait { it } + } + } + .await() + } + is VirtualObjectCommandInterpreter.AwaitAnySuccessful -> { + val cmds = it.commands.map { it.toAwaitable() }.toMutableList() + + while (true) { + @Suppress("UNCHECKED_CAST") + val completed = DurableFuture.any(cmds as List>).await() + + try { + result = cmds[completed].await() + break + } catch (_: TerminalException) { + // Remove the cmd to make sure we don't fail on it again + cmds.removeAt(completed) + } + } + } + is VirtualObjectCommandInterpreter.AwaitOne -> { + result = it.command.toAwaitable().await() + } + is VirtualObjectCommandInterpreter.GetEnvVariable -> { + result = runBlock { System.getenv(it.envName) ?: "" } + } + is VirtualObjectCommandInterpreter.ResolveAwakeable -> { + resolveAwakeable(it) + result = "" + } + is VirtualObjectCommandInterpreter.RejectAwakeable -> { + rejectAwakeable(it) + result = "" + } + is VirtualObjectCommandInterpreter.AwaitAwakeableOrTimeout -> { + val awk = awakeable() + state().set("awk-${it.awakeableKey}", awk.id) + try { + result = awk.await(it.timeoutMillis.milliseconds) + } catch (_: TimeoutException) { + throw TerminalException("await-timeout") + } + } + } + LOG.info("Command result {}", result) + appendResult(result) + } + + return result + } + + override suspend fun resolveAwakeable( + resolveAwakeable: VirtualObjectCommandInterpreter.ResolveAwakeable, + ) { + awakeableHandle( + state().get("awk-${resolveAwakeable.awakeableKey}") + ?: throw TerminalException("awakeable is not registerd yet") + ) + .resolve(resolveAwakeable.value) + } + + override suspend fun rejectAwakeable( + rejectAwakeable: VirtualObjectCommandInterpreter.RejectAwakeable, + ) { + awakeableHandle( + state().get("awk-${rejectAwakeable.awakeableKey}") + ?: throw TerminalException("awakeable is not registerd yet") + ) + .reject(rejectAwakeable.reason) + } + + override suspend fun hasAwakeable(awakeableKey: String): Boolean = + !state().get("awk-$awakeableKey").isNullOrBlank() + + override suspend fun getResults(): List = state().get("results") ?: listOf() + + private suspend fun VirtualObjectCommandInterpreter.AwaitableCommand.toAwaitable(): + DurableFuture { + return when (this) { + is VirtualObjectCommandInterpreter.CreateAwakeable -> { + val awk = awakeable() + state().set("awk-${this.awakeableKey}", awk.id) + awk + } + is VirtualObjectCommandInterpreter.RunThrowTerminalException -> + runAsync("should-fail-with-${this.reason}") { + throw TerminalException(this.reason) + } + is VirtualObjectCommandInterpreter.Sleep -> + timer("command-timer", this.timeoutMillis.milliseconds).map { "sleep" } + } + } + + private suspend fun appendResult(newResult: String) = + state().set("results", (state().get("results") ?: listOf()) + listOf(newResult)) +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/contracts/AwakeableHolder.kt b/test-services/bin/main/dev/restate/sdk/testservices/contracts/AwakeableHolder.kt new file mode 100644 index 000000000..ca165b0f0 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/contracts/AwakeableHolder.kt @@ -0,0 +1,24 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices.contracts + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* + +// This is a generic utility service that can be used in various situations where +// we need to synchronize the services with the test runner using an awakeable. +@VirtualObject +@Name("AwakeableHolder") +interface AwakeableHolder { + @Exclusive suspend fun hold(id: String) + + @Exclusive suspend fun hasAwakeable(): Boolean + + @Exclusive suspend fun unlock(payload: String) +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/contracts/BlockAndWaitWorkflow.kt b/test-services/bin/main/dev/restate/sdk/testservices/contracts/BlockAndWaitWorkflow.kt new file mode 100644 index 000000000..d5cde2da6 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/contracts/BlockAndWaitWorkflow.kt @@ -0,0 +1,22 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices.contracts + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* + +@Workflow +@Name("BlockAndWaitWorkflow") +interface BlockAndWaitWorkflow { + @Workflow suspend fun run(input: String): String + + @Shared suspend fun unblock(output: String) + + @Shared suspend fun getState(): String? +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/contracts/CancelTest.kt b/test-services/bin/main/dev/restate/sdk/testservices/contracts/CancelTest.kt new file mode 100644 index 000000000..d7613bd21 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/contracts/CancelTest.kt @@ -0,0 +1,39 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices.contracts + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* +import kotlinx.serialization.Serializable + +interface CancelTest { + + @Serializable + enum class BlockingOperation { + CALL, + SLEEP, + AWAKEABLE, + } + + @VirtualObject + @Name("CancelTestRunner") + interface Runner { + @Handler suspend fun startTest(operation: BlockingOperation) + + @Handler suspend fun verifyTest(): Boolean + } + + @VirtualObject + @Name("CancelTestBlockingService") + interface BlockingService { + @Handler suspend fun block(operation: BlockingOperation) + + @Handler suspend fun isUnlocked() + } +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/contracts/Counter.kt b/test-services/bin/main/dev/restate/sdk/testservices/contracts/Counter.kt new file mode 100644 index 000000000..34024297e --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/contracts/Counter.kt @@ -0,0 +1,31 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices.contracts + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* +import kotlinx.serialization.Serializable + +@VirtualObject +@Name("Counter") +interface Counter { + @Serializable data class CounterUpdateResponse(val oldValue: Long, val newValue: Long) + + /** Add value to counter */ + @Handler suspend fun add(value: Long): CounterUpdateResponse + + /** Add value to counter, then fail with a Terminal error */ + @Handler suspend fun addThenFail(value: Long) + + /** Get count */ + @Shared suspend fun get(): Long + + /** Reset count */ + @Handler suspend fun reset() +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/contracts/Failing.kt b/test-services/bin/main/dev/restate/sdk/testservices/contracts/Failing.kt new file mode 100644 index 000000000..4d5d1a921 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/contracts/Failing.kt @@ -0,0 +1,47 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices.contracts + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* + +@VirtualObject +@Name("Failing") +interface Failing { + @Handler suspend fun terminallyFailingCall(errorMessage: String) + + @Handler suspend fun callTerminallyFailingCall(errorMessage: String): String + + @Handler suspend fun failingCallWithEventualSuccess(): Int + + @Handler suspend fun terminallyFailingSideEffect(errorMessage: String) + + /** + * `minimumAttempts` should be used to check when to succeed. The retry policy should be + * configured to be infinite. + * + * @return the number of executed attempts. In order to implement this count, an atomic counter in + * the service should be used. + */ + @Handler + suspend fun sideEffectSucceedsAfterGivenAttempts( + minimumAttempts: Int, + ): Int + + /** + * `retryPolicyMaxRetryCount` should be used to configure the retry policy. + * + * @return the number of executed attempts. In order to implement this count, an atomic counter in + * the service should be used. + */ + @Handler + suspend fun sideEffectFailsAfterGivenAttempts( + retryPolicyMaxRetryCount: Int, + ): Int +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/contracts/KillTest.kt b/test-services/bin/main/dev/restate/sdk/testservices/contracts/KillTest.kt new file mode 100644 index 000000000..13916eb05 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/contracts/KillTest.kt @@ -0,0 +1,28 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices.contracts + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* + +interface KillTest { + @VirtualObject + @Name("KillTestRunner") + interface Runner { + @Handler suspend fun startCallTree() + } + + @VirtualObject + @Name("KillTestSingleton") + interface Singleton { + @Handler suspend fun recursiveCall() + + @Handler suspend fun isUnlocked() + } +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/contracts/ListObject.kt b/test-services/bin/main/dev/restate/sdk/testservices/contracts/ListObject.kt new file mode 100644 index 000000000..5398b4692 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/contracts/ListObject.kt @@ -0,0 +1,25 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices.contracts + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* + +@VirtualObject +@Name("ListObject") +interface ListObject { + /** Append a value to the list object */ + @Handler suspend fun append(value: String) + + /** Get current list */ + @Handler suspend fun get(): List + + /** Clear list */ + @Handler suspend fun clear(): List +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/contracts/MapObject.kt b/test-services/bin/main/dev/restate/sdk/testservices/contracts/MapObject.kt new file mode 100644 index 000000000..ca4024c4f --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/contracts/MapObject.kt @@ -0,0 +1,34 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices.contracts + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* +import kotlinx.serialization.Serializable + +@VirtualObject +@Name("MapObject") +interface MapObject { + + @Serializable data class Entry(val key: String, val value: String) + + /** + * Set value in map. + * + * The individual entries should be stored as separate Restate state keys, and not in a single + * state key + */ + @Handler suspend fun set(entry: Entry) + + /** Get value from map. */ + @Handler suspend fun get(key: String): String + + /** Clear all entries */ + @Handler suspend fun clearAll(): List +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/contracts/NonDeterministic.kt b/test-services/bin/main/dev/restate/sdk/testservices/contracts/NonDeterministic.kt new file mode 100644 index 000000000..ea83f4ecc --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/contracts/NonDeterministic.kt @@ -0,0 +1,25 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices.contracts + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* + +@VirtualObject +@Name("NonDeterministic") +interface NonDeterministic { + /** On first invocation sleeps, on second invocation calls */ + @Handler suspend fun eitherSleepOrCall() + + @Handler suspend fun callDifferentMethod() + + @Handler suspend fun backgroundInvokeWithDifferentTargets() + + @Handler suspend fun setDifferentKey() +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/contracts/Proxy.kt b/test-services/bin/main/dev/restate/sdk/testservices/contracts/Proxy.kt new file mode 100644 index 000000000..c7392cb55 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/contracts/Proxy.kt @@ -0,0 +1,48 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices.contracts + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* +import kotlinx.serialization.Serializable + +@Service +@Name("Proxy") +interface Proxy { + @Serializable + data class ProxyRequest( + val serviceName: String, + val virtualObjectKey: String? = null, // If null, the request is to a service + val handlerName: String, + // Bytes are encoded as array of numbers + val message: ByteArray, + val delayMillis: Int? = null, + val idempotencyKey: String? = null, + ) + + @Serializable + data class ManyCallRequest( + val proxyRequest: ProxyRequest, + /** If true, perform a one way call instead of a regular call */ + val oneWayCall: Boolean, + /** + * If await at the end, then perform the call as regular call, and collect all the futures to + * wait at the end, before returning, instead of awaiting them immediately. + */ + val awaitAtTheEnd: Boolean, + ) + + // Bytes are encoded as array of numbers + @Handler suspend fun call(request: ProxyRequest): ByteArray + + // Returns the invocation id of the call + @Handler suspend fun oneWayCall(request: ProxyRequest): String + + @Handler suspend fun manyCalls(requests: List) +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/contracts/TestUtilsService.kt b/test-services/bin/main/dev/restate/sdk/testservices/contracts/TestUtilsService.kt new file mode 100644 index 000000000..ee966f810 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/contracts/TestUtilsService.kt @@ -0,0 +1,43 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices.contracts + +import dev.restate.sdk.annotation.* + +/** Collection of various utilities/corner cases scenarios used by tests */ +@Service +@Name("TestUtilsService") +interface TestUtilsService { + /** Just echo */ + @Handler suspend fun echo(input: String): String + + /** Just echo but with uppercase */ + @Handler suspend fun uppercaseEcho(input: String): String + + /** Echo ingress headers */ + @Handler suspend fun echoHeaders(): Map + + /** Just echo */ + @Handler @Raw suspend fun rawEcho(@Raw input: ByteArray): ByteArray + + /** Create timers and await them all. Durations in milliseconds */ + @Handler suspend fun sleepConcurrently(millisDuration: List) + + /** + * Invoke `ctx.run` incrementing a local variable counter (not a restate state key!). + * + * Returns the count value. + * + * This is used to verify acks will suspend when using the always suspend test-suite + */ + @Handler suspend fun countExecutedSideEffects(increments: Int): Int + + /** Cancel invocation using the context. */ + @Handler suspend fun cancelInvocation(invocationId: String) +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/contracts/VirtualObjectCommandInterpreter.kt b/test-services/bin/main/dev/restate/sdk/testservices/contracts/VirtualObjectCommandInterpreter.kt new file mode 100644 index 000000000..ec962eb08 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/contracts/VirtualObjectCommandInterpreter.kt @@ -0,0 +1,87 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices.contracts + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@VirtualObject +@Name("VirtualObjectCommandInterpreter") +interface VirtualObjectCommandInterpreter { + + @Serializable sealed interface AwaitableCommand + + // This is serialized as `{"type": "createAwakeable", ...}` + @Serializable + @SerialName("createAwakeable") + data class CreateAwakeable(val awakeableKey: String) : AwaitableCommand + + // This is serialized as `{"type": "sleep", ...}` + @Serializable @SerialName("sleep") data class Sleep(val timeoutMillis: Long) : AwaitableCommand + + // This is serialized as `{"type": "runThrowTerminalException", ...}` + @Serializable + @SerialName("runThrowTerminalException") + data class RunThrowTerminalException(val reason: String) : AwaitableCommand + + @Serializable sealed interface Command + + // Returns the index of the one that completed first successfully + @Serializable + @SerialName("awaitAnySuccessful") + data class AwaitAnySuccessful(val commands: List) : Command + + // Returns the index of the one that completed first + @Serializable + @SerialName("awaitAny") + data class AwaitAny(val commands: List) : Command + + // Returns the result + @Serializable @SerialName("awaitOne") data class AwaitOne(val command: AwaitableCommand) : Command + + // This is serialized as `{"type": "awaitAwakeableOrTimeout", ...}` + // The timeout throws a terminal error with "await-timeout" string in it + @Serializable + @SerialName("awaitAwakeableOrTimeout") + data class AwaitAwakeableOrTimeout(val awakeableKey: String, val timeoutMillis: Long) : Command + + @Serializable data class InterpretRequest(val commands: List) + + @Serializable + @SerialName("resolveAwakeable") + data class ResolveAwakeable(val awakeableKey: String, val value: String) : Command + + @Serializable + @SerialName("rejectAwakeable") + data class RejectAwakeable(val awakeableKey: String, val reason: String) : Command + + // This is serialized as `{"type": "getEnvVariable", ...}` + // Reading an environment variable should be done within a side effect! + @Serializable + @SerialName("getEnvVariable") + data class GetEnvVariable(val envName: String) : Command + + /** + * This handler should iterate through the list of commands and execute them. + * + * For each command, the output should be appended to the given list name. Returns the result of + * the last command, or empty string otherwise. + */ + @Handler suspend fun interpretCommands(req: InterpretRequest): String + + @Shared suspend fun resolveAwakeable(resolveAwakeable: ResolveAwakeable) + + @Shared suspend fun rejectAwakeable(rejectAwakeable: RejectAwakeable) + + @Shared suspend fun hasAwakeable(awakeableKey: String): Boolean + + @Shared suspend fun getResults(): List +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/contracts/interpreter.kt b/test-services/bin/main/dev/restate/sdk/testservices/contracts/interpreter.kt new file mode 100644 index 000000000..de9e5cb2c --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/contracts/interpreter.kt @@ -0,0 +1,267 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +@file:OptIn(ExperimentalSerializationApi::class) + +package dev.restate.sdk.testservices.contracts + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* +import kotlinx.serialization.* +import kotlinx.serialization.json.* +import kotlinx.serialization.json.intOrNull +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive + +@Serializable(with = CommandSerializer::class) +sealed class InterpreterCommand { + abstract val kind: Int +} + +@Serializable +class IncrementStateCounter : InterpreterCommand() { + companion object { + const val KIND = 4 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +class RecoverTerminalCall : InterpreterCommand() { + companion object { + const val KIND = 13 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +class RecoverTerminalCallMaybeUnAwaited : InterpreterCommand() { + companion object { + const val KIND = 14 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +class ThrowingSideEffect : InterpreterCommand() { + companion object { + const val KIND = 11 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +class SlowSideEffect : InterpreterCommand() { + companion object { + const val KIND = 12 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +class IncrementStateCounterIndirectly : InterpreterCommand() { + companion object { + const val KIND = 5 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +class ResolveAwakeable : InterpreterCommand() { + companion object { + const val KIND = 16 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +class RejectAwakeable : InterpreterCommand() { + companion object { + const val KIND = 17 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +class IncrementStateCounterViaAwakeable : InterpreterCommand() { + companion object { + const val KIND = 18 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +class CallService : InterpreterCommand() { + companion object { + const val KIND = 7 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +class SideEffect : InterpreterCommand() { + companion object { + const val KIND = 10 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +data class GetState(val key: Int) : InterpreterCommand() { + companion object { + const val KIND = 2 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +data class ClearState(val key: Int) : InterpreterCommand() { + companion object { + const val KIND = 3 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +data class SetState(val key: Int) : InterpreterCommand() { + companion object { + const val KIND = 1 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +data class Sleep(val duration: Int) : InterpreterCommand() { + companion object { + const val KIND = 6 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +data class IncrementViaDelayedCall(val duration: Int) : InterpreterCommand() { + companion object { + const val KIND = 9 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +data class AwaitPromise(val index: Int) : InterpreterCommand() { + companion object { + const val KIND = 15 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +data class CallSlowService(val sleep: Int) : InterpreterCommand() { + companion object { + const val KIND = 8 + } + + @EncodeDefault override val kind: Int = KIND +} + +@Serializable +data class CallObject(val key: Int, val program: Program) : InterpreterCommand() { + companion object { + const val KIND = 19 + } + + @EncodeDefault override val kind: Int = KIND +} + +object CommandSerializer : + JsonContentPolymorphicSerializer(InterpreterCommand::class) { + override fun selectDeserializer( + element: JsonElement + ): DeserializationStrategy { + + return when (val type = element.jsonObject["kind"]?.jsonPrimitive?.intOrNull) { + IncrementStateCounter.KIND -> IncrementStateCounter.serializer() + RecoverTerminalCall.KIND -> RecoverTerminalCall.serializer() + RecoverTerminalCallMaybeUnAwaited.KIND -> RecoverTerminalCallMaybeUnAwaited.serializer() + ThrowingSideEffect.KIND -> ThrowingSideEffect.serializer() + SlowSideEffect.KIND -> SlowSideEffect.serializer() + IncrementStateCounterIndirectly.KIND -> IncrementStateCounterIndirectly.serializer() + ResolveAwakeable.KIND -> ResolveAwakeable.serializer() + RejectAwakeable.KIND -> RejectAwakeable.serializer() + IncrementStateCounterViaAwakeable.KIND -> IncrementStateCounterViaAwakeable.serializer() + CallService.KIND -> CallService.serializer() + SideEffect.KIND -> SideEffect.serializer() + GetState.KIND -> GetState.serializer() + ClearState.KIND -> ClearState.serializer() + SetState.KIND -> SetState.serializer() + Sleep.KIND -> Sleep.serializer() + IncrementViaDelayedCall.KIND -> IncrementViaDelayedCall.serializer() + AwaitPromise.KIND -> AwaitPromise.serializer() + CallSlowService.KIND -> CallSlowService.serializer() + CallObject.KIND -> CallObject.serializer() + else -> error("unknown command kind $type") + } + } +} + +@Serializable data class Program(val commands: List) + +@VirtualObject +@Name("ObjectInterpreter") +interface ObjectInterpreter { + + @Shared suspend fun counter(): Int + + @Handler suspend fun interpret(program: Program) +} + +@Serializable data class EchoLaterRequest(val sleep: Int, val parameter: String) + +@Serializable data class InterpreterId(val layer: Int, val key: String) + +@Serializable +data class IncrementViaAwakeableDanceRequest( + val interpreter: InterpreterId, + val txPromiseId: String, +) + +@Service +@Name("ServiceInterpreterHelper") +interface ServiceInterpreterHelper { + @Handler suspend fun ping() + + @Handler suspend fun echo(param: String): String + + @Handler suspend fun echoLater(req: EchoLaterRequest): String + + @Handler suspend fun terminalFailure() + + @Handler suspend fun incrementIndirectly(id: InterpreterId) + + @Handler suspend fun resolveAwakeable(id: String) + + @Handler suspend fun rejectAwakeable(id: String) + + @Handler suspend fun incrementViaAwakeableDance(req: IncrementViaAwakeableDanceRequest) +} diff --git a/test-services/bin/main/dev/restate/sdk/testservices/interpreter.kt b/test-services/bin/main/dev/restate/sdk/testservices/interpreter.kt new file mode 100644 index 000000000..0d8add183 --- /dev/null +++ b/test-services/bin/main/dev/restate/sdk/testservices/interpreter.kt @@ -0,0 +1,281 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.testservices + +import dev.restate.common.Request +import dev.restate.common.Target +import dev.restate.common.reflections.ReflectionUtils +import dev.restate.sdk.common.StateKey +import dev.restate.sdk.common.TerminalException +import dev.restate.sdk.endpoint.definition.ServiceDefinition +import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactories +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.testservices.contracts.* +import dev.restate.sdk.testservices.contracts.Program +import dev.restate.serde.Serde +import dev.restate.serde.kotlinx.typeTag +import kotlin.random.Random +import kotlin.time.Duration.Companion.milliseconds + +fun interpreterName(layer: Int): String { + return "${ReflectionUtils.extractServiceName(ObjectInterpreter::class.java)}L$layer" +} + +fun interpretTarget(layer: Int, key: String): Target { + return Target.virtualObject(interpreterName(layer), key, "interpret") +} + +suspend fun checkAwaitable( + actual: DurableFuture, + expected: T, + cmdIndex: Int, + interpreterCommand: InterpreterCommand, +) { + val result = actual.await() + if (result != expected) { + throw TerminalException( + "Awaited promise mismatch. got '$result' expected '$expected'; command at index $cmdIndex was $interpreterCommand" + ) + } +} + +suspend fun checkAwaitableFails( + actual: DurableFuture, + cmdIndex: Int, + interpreterCommand: InterpreterCommand, +) { + try { + actual.await() + } catch (e: TerminalException) { + return + } + throw TerminalException( + "Awaited promise mismatch. should fail but instead got ${actual.await()}; command at index $cmdIndex was $interpreterCommand" + ) +} + +fun cmdStateKey(key: Int): StateKey { + return stateKey("key-$key") +} + +class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { + companion object { + private val COUNTER: StateKey = stateKey("counter") + + fun getInterpreterDefinition(layer: Int): ServiceDefinition { + val serviceImpl = ObjectInterpreterImpl(layer) + val originalDefinition = + ServiceDefinitionFactories.discover(serviceImpl).create(serviceImpl, null) + return ServiceDefinition.of( + interpreterName(layer), + originalDefinition.serviceType, + originalDefinition.handlers, + ) + } + } + + private suspend fun interpreterId(): InterpreterId { + return InterpreterId(layer, objectKey()) + } + + override suspend fun counter(): Int { + return state().get(COUNTER) ?: 0 + } + + override suspend fun interpret(program: Program) { + val promises: MutableMap Unit> = mutableMapOf() + for ((i, cmd) in program.commands.withIndex()) { + when (cmd) { + is AwaitPromise -> { + val p = + promises.remove(cmd.index) + ?: throw TerminalException( + "ObjectInterpreterL$layer: can not find a promise for the id ${cmd.index}." + ) + // Await on promise, this will under the hood check the promise result + p() + } + is CallObject -> { + val awaitable = + context() + .call( + Request.of( + interpretTarget(layer + 1, cmd.key.toString()), + typeTag(), + typeTag(), + cmd.program, + ) + ) + promises[i] = { awaitable.await() } + } + is CallService -> { + val expected = "hello-$i" + val awaitable = toService().request { echo(expected) }.call() + promises[i] = { checkAwaitable(awaitable, expected, i, cmd) } + } + is CallSlowService -> { + val expected = "hello-$i" + val awaitable = + toService() + .request { echoLater(EchoLaterRequest(cmd.sleep, expected)) } + .call() + promises[i] = { checkAwaitable(awaitable, expected, i, cmd) } + } + is ClearState -> { + state().clear(cmdStateKey(cmd.key)) + } + is GetState -> { + state().get(cmdStateKey(cmd.key)) + } + is IncrementStateCounter -> { + state().set(COUNTER, (state().get(COUNTER) ?: 0) + 1) + } + is IncrementStateCounterIndirectly -> { + toService() + .request { incrementIndirectly(interpreterId()) } + .send() + } + is IncrementStateCounterViaAwakeable -> { + // Dancing in the mooonlight! + val awakeable = awakeable() + toService() + .request { + incrementViaAwakeableDance( + IncrementViaAwakeableDanceRequest(interpreterId(), awakeable.id) + ) + } + .send() + val theirPromiseIdForUsToResolve = awakeable.await() + awakeableHandle(theirPromiseIdForUsToResolve).resolve("ok") + } + is IncrementViaDelayedCall -> { + toService() + .request { incrementIndirectly(interpreterId()) } + .send(delay = cmd.duration.milliseconds) + } + is RecoverTerminalCall -> { + var caught = false + try { + service().terminalFailure() + } catch (e: TerminalException) { + caught = true + } + if (!caught) { + throw TerminalException( + "Test assertion failed, was expected to get a terminal error. Layer $layer, Command $i" + ) + } + } + is RecoverTerminalCallMaybeUnAwaited -> { + val awaitable = toService().request { terminalFailure() }.call() + promises[i] = { checkAwaitableFails(awaitable, i, cmd) } + } + is RejectAwakeable -> { + val awakeable = awakeable() + promises[i] = { checkAwaitableFails(awakeable, i, cmd) } + toService().request { rejectAwakeable(awakeable.id) }.send() + } + is ResolveAwakeable -> { + val awakeable = awakeable() + promises[i] = { checkAwaitable(awakeable, "ok", i, cmd) } + toService().request { resolveAwakeable(awakeable.id) }.send() + } + is SetState -> { + state().set(cmdStateKey(cmd.key), "value-${cmd.key}") + } + is SideEffect -> { + val expected = "hello-$i" + val result = runBlock { expected } + if (result != expected) { + throw TerminalException("Side effect result don't match: $result != $expected") + } + } + is Sleep -> { + sleep(cmd.duration.milliseconds) + } + is SlowSideEffect -> { + runBlock { kotlinx.coroutines.delay(1.milliseconds) } + } + is ThrowingSideEffect -> { + runBlock { + check(Random.nextBoolean()) { "Random failure caused by a very cool language." } + } + } + } + } + } +} + +class ServiceInterpreterHelperImpl : ServiceInterpreterHelper { + override suspend fun ping() {} + + override suspend fun echo(param: String): String { + return param + } + + override suspend fun echoLater(req: EchoLaterRequest): String { + sleep(req.sleep.milliseconds) + return req.parameter + } + + override suspend fun terminalFailure() { + throw TerminalException("bye") + } + + override suspend fun incrementIndirectly(id: InterpreterId) { + val ignored = + context() + .send( + Request.of( + interpretTarget(id.layer, id.key), + typeTag(), + Serde.SLICE, + Program(listOf(IncrementStateCounter())), + ) + ) + } + + override suspend fun resolveAwakeable(id: String) { + awakeableHandle(id).resolve("ok") + } + + override suspend fun rejectAwakeable(id: String) { + awakeableHandle(id).resolve("error") + } + + override suspend fun incrementViaAwakeableDance( + req: IncrementViaAwakeableDanceRequest, + ) { + // + // 1. create an awakeable that we will be blocked on + // + val awakeable = awakeable() + // + // 2. send our awakeable id to the interpreter via txPromise. + // + awakeableHandle(req.txPromiseId).resolve(awakeable.id) + // + // 3. wait for the interpreter resolve us + // + awakeable.await() + // + // 4. to thank our interpret, let us ask it to inc its state. + // + val ignored = + context() + .send( + Request.of( + interpretTarget(req.interpreter.layer, req.interpreter.key), + typeTag(), + Serde.SLICE, + Program(listOf(IncrementStateCounter())), + ) + ) + } +} diff --git a/test-services/bin/main/log4j2.properties b/test-services/bin/main/log4j2.properties new file mode 100644 index 000000000..7434d169e --- /dev/null +++ b/test-services/bin/main/log4j2.properties @@ -0,0 +1,18 @@ +# Set to debug or trace if log4j initialization is failing +status = warn + +# Console appender configuration +appender.console.type = Console +appender.console.name = consoleLogger +appender.console.layout.type = PatternLayout +appender.console.layout.pattern = %d{yyyy-MM-dd HH:mm:ss} [%tn] %-5p %notEmpty{[%X{restateInvocationTarget}]}%notEmpty{[%X{restateInvocationId}]}%notEmpty{[%X{restateInvocationStatus}]} %c:%L - %m%n + +# Restate logs to debug level +logger.app.name = dev.restate +logger.app.level = ${env:RESTATE_LOGGING:-debug} +logger.app.additivity = false +logger.app.appenderRef.console.ref = consoleLogger + +# Root logger +rootLogger.level = ${env:RESTATE_CORE_LOG:-info} +rootLogger.appenderRef.stdout.ref = consoleLogger \ No newline at end of file