diff --git a/README.md b/README.md index 9fa0e6e..1404896 100644 --- a/README.md +++ b/README.md @@ -319,7 +319,54 @@ end ## Middleware -Add custom middleware to the Rack stack: +### Invocation Middleware + +Invocation middleware wraps handler execution at the SDK level. This is different from Rack middleware — it runs inside the Restate invocation lifecycle and has access to the handler and context. + +```ruby +Restate.endpoint.define do + # Inbound: wraps handler execution + use MyInboundMiddleware.new + + # Outbound: wraps service-to-service calls + use_outbound MyOutboundMiddleware.new + + mount MyService +end +``` + +Write your own middleware by implementing `#call` and yielding: + +```ruby +class TimingMiddleware + def call(handler, context) + start = Process.clock_gettime(Process::CLOCK_MONOTONIC) + result = yield + elapsed = Process.clock_gettime(Process::CLOCK_MONOTONIC) - start + puts "#{handler.name} took #{elapsed.round(3)}s" + result + end +end +``` + +### Built-in: Deadlock Detection + +The SDK ships with deadlock detection middleware that catches re-entrant VirtualObject calls that would otherwise block forever: + +```ruby +Restate.endpoint.define do + use Restate::Middleware::DeadlockDetection::Inbound.new + use_outbound Restate::Middleware::DeadlockDetection::Outbound.new + + mount MyVirtualObject +end +``` + +If an exclusive handler on VO key "x" calls another exclusive handler on the same VO key "x", the middleware raises a `DeadlockError` (409) immediately instead of hanging. See `examples/middleware/` for a complete example. + +### Rack Middleware + +For HTTP-level middleware, use the Rack middleware stack: ```ruby middleware = Restate::MiddlewareStack.new diff --git a/examples/middleware/deadlock_detection.rb b/examples/middleware/deadlock_detection.rb new file mode 100644 index 0000000..4f6fc3e --- /dev/null +++ b/examples/middleware/deadlock_detection.rb @@ -0,0 +1,108 @@ +# frozen_string_literal: true + +# Deadlock Detection Middleware Example +# +# This example demonstrates how the DeadlockDetection middleware catches +# re-entrant VirtualObject calls that would otherwise block forever. +# +# == The Problem +# +# Restate VirtualObjects serialize exclusive handler access per key. If an +# exclusive handler on key "x" calls another exclusive handler on the same +# VO key "x", the second call waits for the first to finish — which never +# happens because the first is waiting for the second. Deadlock. +# +# == The Solution +# +# The DeadlockDetection middleware tracks which VO keys are held by the +# current call chain and raises immediately when a call would deadlock, +# giving you a clear error instead of a silent hang. +# +# == Running +# +# falcon serve -c examples/middleware/deadlock_detection.rb +# +# Then trigger the deadlock: +# +# curl -X POST http://localhost:8080/Account/my-account/transfer \ +# -H 'content-type: application/json' \ +# -d '{"to_account": "my-account", "amount": 100}' +# +# Without the middleware, this call would hang forever. +# With it, you get an immediate 409 error explaining the deadlock. + +require "restate" + +class Account < Restate::VirtualObject + state :balance, Integer, default: 0 + + handler :deposit + def deposit(input) + amount = input["amount"] + self.balance += amount + { balance: balance } + end + + handler :withdraw + def withdraw(input) + amount = input["amount"] + raise Restate::TerminalError.new("Insufficient funds", 400) if balance < amount + + self.balance -= amount + { balance: balance } + end + + # This handler demonstrates a potential deadlock. If `to_account` is the same + # as this object's key, the call to Account.call(to_account).deposit(...) + # would deadlock — we already hold the exclusive lock on this key. + # + # The DeadlockDetection middleware catches this and raises immediately. + handler :transfer + def transfer(input) + to_account = input["to_account"] + amount = input["amount"] + + raise Restate::TerminalError.new("Insufficient funds", 400) if balance < amount + + self.balance -= amount + + # If to_account == key, this call targets the same VO key we're holding. + # Without deadlock detection, it hangs forever. + # With deadlock detection, it raises DeadlockError immediately. + Account.call(to_account).deposit(amount: amount) + + { balance: balance } + end + + shared :get_balance + def get_balance + { balance: balance } + end +end + +# Configure with deadlock detection middleware +Restate.configure do |config| + config.bind = "http://localhost:4100" + config.ingress_url = ENV.fetch("RESTATE_INGRESS_URL", "http://localhost:8080") + config.admin_url = ENV.fetch("RESTATE_ADMIN_URL", "http://localhost:9070") +end + +Restate.endpoint.define do + # Register both inbound and outbound deadlock detection + use Restate::Middleware::DeadlockDetection::Inbound.new + use_outbound Restate::Middleware::DeadlockDetection::Outbound.new + + mount Account +end + +# For Falcon: falcon serve -c examples/middleware/deadlock_detection.rb +if $0 == __FILE__ + require "falcon/environment/server" + + service "restate" do + include Falcon::Environment::Server + count 1 + url { Restate.config.bind } + middleware { Falcon::Server.middleware(Restate.endpoint.to_rack_app, verbose:, cache:) } + end +end diff --git a/lib/restate.rb b/lib/restate.rb index 774cfce..fe74810 100644 --- a/lib/restate.rb +++ b/lib/restate.rb @@ -99,6 +99,10 @@ def initialize(message, status_code = 500) autoload :Client, "restate/client" autoload :InvocationMiddleware, "restate/invocation_middleware" + module Middleware + autoload :DeadlockDetection, "restate/middleware/deadlock_detection" + end + @config = nil class << self diff --git a/lib/restate/endpoint.rb b/lib/restate/endpoint.rb index cd9a850..5d22777 100644 --- a/lib/restate/endpoint.rb +++ b/lib/restate/endpoint.rb @@ -27,6 +27,8 @@ def initialize(protocol: nil, identity_keys: nil) @services = {} @protocol = protocol @identity_keys = identity_keys || [] + @inbound_middleware = [] + @outbound_middleware = [] end # @return [Symbol, nil] The protocol mode @@ -35,6 +37,12 @@ def initialize(protocol: nil, identity_keys: nil) # @return [Array] Identity verification keys attr_reader :identity_keys + # @return [Array] Inbound invocation middleware + attr_reader :inbound_middleware + + # @return [Array] Outbound invocation middleware + attr_reader :outbound_middleware + # Returns all registered services. # # @return [Array] Array of service classes @@ -51,6 +59,30 @@ def handler?(service_name, handler_name) !!(@services[service_name]&.handler?(handler_name)) end + # Adds inbound invocation middleware. + # + # Inbound middleware wraps handler execution. It receives the handler and + # context, and must yield to continue the chain. + # + # @param middleware [Object] A middleware instance responding to #call(handler, context) + # @return [self] + def use(middleware) + @inbound_middleware << middleware + self + end + + # Adds outbound invocation middleware. + # + # Outbound middleware wraps service-to-service calls. It receives the target + # service, handler, and a mutable headers hash, and must yield to continue. + # + # @param middleware [Object] A middleware instance responding to #call(service, handler, headers) + # @return [self] + def use_outbound(middleware) + @outbound_middleware << middleware + self + end + # Adds a service to the endpoint. # # @param service [Class] Service class to add @@ -82,6 +114,8 @@ def add(service, as: nil) # end def define(&) @services.clear + @inbound_middleware.clear + @outbound_middleware.clear Mapper.new(self).instance_exec(&) self end @@ -147,7 +181,10 @@ def invoke(service_name, handler_name, connection) Context.class_for(service.service_kind, handler.kind).new(connection).wrap do |context| serialized_input = connection.request.body - handler.call(context, serialized_input) + + InvocationMiddleware.invoke(@inbound_middleware, handler, context) do + handler.call(context, serialized_input) + end end end @@ -184,6 +221,22 @@ def initialize(endpoint) @endpoint = endpoint end + # Adds inbound invocation middleware. + # + # @param middleware [Object] A middleware instance + # @return [void] + def use(middleware) + @endpoint.use(middleware) + end + + # Adds outbound invocation middleware. + # + # @param middleware [Object] A middleware instance + # @return [void] + def use_outbound(middleware) + @endpoint.use_outbound(middleware) + end + # Mounts a service class. # # @param service_class [Class] Service class to mount diff --git a/lib/restate/invocation_middleware.rb b/lib/restate/invocation_middleware.rb new file mode 100644 index 0000000..a5a4e23 --- /dev/null +++ b/lib/restate/invocation_middleware.rb @@ -0,0 +1,82 @@ +# frozen_string_literal: true + +module Restate + # Middleware executed around each Restate handler invocation. + # + # Invocation middleware wraps the actual handler execution, giving you hooks + # to inspect or modify the invocation before and after the handler runs. + # This is distinct from Rack middleware, which wraps the HTTP layer. + # + # Middleware must implement +#call(handler, context)+ and yield to invoke the + # next middleware (or the handler itself). The return value of +yield+ is the + # handler's encoded output. + # + # @example Implementing a middleware + # class LoggingMiddleware + # def call(handler, context) + # puts "Before #{handler.name}" + # result = yield + # puts "After #{handler.name}" + # result + # end + # end + # + # @example Registering middleware on an endpoint + # Restate.endpoint.define do + # use LoggingMiddleware.new + # mount MyService + # end + # + # == Outbound middleware + # + # Outbound middleware wraps calls from one service to another. It receives + # the target service name, handler name, and a mutable headers hash. Yield + # to continue the call. + # + # @example Implementing outbound middleware + # class HeaderInjector + # def call(service, handler, headers) + # headers["x-custom"] = "value" + # yield + # end + # end + module InvocationMiddleware + # Builds a callable chain from an ordered list of middleware instances + # and a terminal block (the actual handler invocation). + # + # @param middlewares [Array] Middleware instances responding to #call + # @param handler [Restate::Handler] The handler being invoked + # @param context [Restate::Context] The invocation context + # @yield The terminal action (handler invocation) + # @return [Object] The result of the middleware chain + def self.invoke(middlewares, handler, context, &terminal) + if middlewares.empty? + terminal.call + else + chain = middlewares.reverse.reduce(terminal) do |next_step, mw| + proc { mw.call(handler, context) { next_step.call } } + end + chain.call + end + end + + # Builds a callable chain for outbound (service-to-service) middleware. + # + # @param middlewares [Array] Middleware instances responding to #call + # @param service [String] Target service name + # @param handler [String] Target handler name + # @param headers [Hash] Mutable headers hash + # @yield The terminal action (the actual outbound call) + # @return [Object] The result of the middleware chain + def self.invoke_outbound(middlewares, service, handler, headers, &terminal) + if middlewares.empty? + terminal.call + else + chain = middlewares.reverse.reduce(terminal) do |next_step, mw| + proc { mw.call(service, handler, headers) { next_step.call } } + end + chain.call + end + end + end +end diff --git a/lib/restate/middleware/deadlock_detection.rb b/lib/restate/middleware/deadlock_detection.rb new file mode 100644 index 0000000..9efd61e --- /dev/null +++ b/lib/restate/middleware/deadlock_detection.rb @@ -0,0 +1,174 @@ +# frozen_string_literal: true + +module Restate + module Middleware + # Detects VirtualObject deadlocks caused by re-entrant calls to a VO whose + # exclusive handler is still running higher up the call chain. + # + # == How it works + # + # Restate VirtualObjects serialize exclusive handler access per key. If handler A + # on VO key "x" calls handler B on the same VO key "x", the call will block + # forever — the key is already locked by A. This is a deadlock. + # + # This middleware detects that pattern by propagating a set of "held locks" + # (ServiceName:key pairs) via a header on every outbound call. + # + # === Inbound side + # + # 1. Reads the held-locks header from the incoming request. + # 2. If the current handler targets a VO+key already in the set → deadlock. + # Raises a {DeadlockError} immediately rather than blocking forever. + # 3. If this handler is an exclusive VO handler, appends its lock to the set + # so further downstream calls propagate it. + # + # === Outbound side + # + # Injects the held-locks header into every outbound service call, and also + # detects same-service deadlocks on the outbound side (calling the same VO + # service while holding its lock). + # + # == Journal determinism + # + # The held-locks header is deterministic across replays: its value depends only + # on the execution path, which Restate's journal guarantees is identical on + # every replay. + # + # == Usage + # + # Restate.endpoint.define do + # use Restate::Middleware::DeadlockDetection::Inbound.new + # use_outbound Restate::Middleware::DeadlockDetection::Outbound.new + # + # mount MyVirtualObject + # end + # + class DeadlockDetection + HEADER = "x-restate-held-locks" + SEPARATOR = "," + DEADLOCK_STATUS_CODE = 409 + + # Fiber-local storage key for the current set of held locks. + FIBER_KEY = :restate_held_exclusive_locks + + class << self + # Returns the current set of held exclusive locks for this fiber. + # + # @return [Set] Lock identifiers in the form "ServiceName:key" + def held_locks + Fiber[FIBER_KEY] || Set.new + end + + # Sets the held locks for the current fiber. + # + # @param locks [Set] The lock set + # @return [void] + def held_locks=(locks) + Fiber[FIBER_KEY] = locks + end + end + + # Error raised when a deadlock is detected. + # + # Uses status code 409 (Conflict) to signal that retrying won't help. + class DeadlockError < Restate::TerminalError + def initialize(message) + super(message, DEADLOCK_STATUS_CODE) + end + end + + # Inbound middleware that checks for and tracks VO locks. + # + # Wrap handler execution to detect deadlocks on the inbound side. + # + # @example + # endpoint.use(Restate::Middleware::DeadlockDetection::Inbound.new) + class Inbound + # @param handler [Restate::Handler] The handler being invoked + # @param context [Restate::Context] The invocation context + # @yield Invokes the next middleware or the handler + # @return [Object] The handler result + # @raise [DeadlockError] If the call would deadlock + def call(handler, context) + previous = DeadlockDetection.held_locks + + incoming = parse_locks(context) + is_vo = handler.respond_to?(:service_class) && + handler.service_class.respond_to?(:service_kind) && + handler.service_class.service_kind == :VIRTUAL_OBJECT + key = context.respond_to?(:key) ? context.key : nil + + if is_vo && key && handler.kind == :EXCLUSIVE + svc = handler.service_class.service_name + lock_id = "#{svc}:#{key}" + + if incoming.include?(lock_id) + raise DeadlockError, + "Deadlock detected: #{svc}##{handler.name} on key '#{key}' " \ + "called while an exclusive handler holds the same VO key. " \ + "Held locks: #{incoming.to_a.join(', ')}. " \ + "This call will never complete." + end + + incoming << lock_id + end + + DeadlockDetection.held_locks = incoming + yield + ensure + DeadlockDetection.held_locks = previous + end + + private + + def parse_locks(context) + headers = context.request.headers rescue nil + return Set.new unless headers + + raw = if headers.is_a?(Array) + headers.find { |name, _| name == HEADER }&.last + elsif headers.respond_to?(:[]) + headers[HEADER] + end + + return Set.new if raw.nil? || raw.to_s.empty? + + Set.new(raw.to_s.split(SEPARATOR).map(&:strip).reject(&:empty?)) + end + end + + # Outbound middleware that propagates held locks via headers. + # + # Injects the held-locks header into outbound calls and raises early + # if the outbound call targets a VO service whose lock is already held. + # + # @example + # endpoint.use_outbound(Restate::Middleware::DeadlockDetection::Outbound.new) + class Outbound + # @param service [String] Target service name + # @param handler [String] Target handler name + # @param headers [Hash] Mutable headers hash for the outbound call + # @yield Continues the outbound call + # @return [Object] The call result + # @raise [DeadlockError] If the call would deadlock + def call(service, handler, headers) + locks = DeadlockDetection.held_locks + if locks.any? + headers[HEADER] = locks.to_a.join(SEPARATOR) + + prefix = "#{service}:" + held_lock = locks.find { |l| l.start_with?(prefix) } + if held_lock + raise DeadlockError, + "Deadlock detected: outbound call to #{service}##{handler} " \ + "while exclusive lock held on #{held_lock}. " \ + "This call will block forever." + end + end + + yield + end + end + end + end +end diff --git a/test/middleware/deadlock_detection_test.rb b/test/middleware/deadlock_detection_test.rb new file mode 100644 index 0000000..a0bb46d --- /dev/null +++ b/test/middleware/deadlock_detection_test.rb @@ -0,0 +1,216 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class TestVO < Restate::VirtualObject + state :count, Integer, default: 0 + + handler :exclusive_work + def exclusive_work(input) + input + end + + shared :shared_read + def shared_read + count + end +end + +class TestService < Restate::BasicService + handler :process + def process(input) + input + end +end + +describe "Restate::Middleware::DeadlockDetection" do + let(:inbound) { Restate::Middleware::DeadlockDetection::Inbound.new } + let(:outbound) { Restate::Middleware::DeadlockDetection::Outbound.new } + let(:handler) { TestVO.find_handler("exclusive_work") } + let(:shared_handler) { TestVO.find_handler("shared_read") } + let(:service_handler) { TestService.find_handler("process") } + + before do + Restate::Middleware::DeadlockDetection.held_locks = Set.new + end + + describe "Inbound" do + it "allows calls with no held locks" do + context = stub_context(headers: []) + + result = inbound.call(handler, context) { :ok } + assert_equal :ok, result + end + + it "allows calls to a different VO key" do + context = stub_context( + headers: [["x-restate-held-locks", "TestVO:other-key"]], + key: "my-key" + ) + + result = inbound.call(handler, context) { :ok } + assert_equal :ok, result + end + + it "raises DeadlockError when calling same VO key" do + context = stub_context( + headers: [["x-restate-held-locks", "TestVO:my-key"]], + key: "my-key" + ) + + error = assert_raises(Restate::Middleware::DeadlockDetection::DeadlockError) do + inbound.call(handler, context) { :ok } + end + + assert_includes error.message, "Deadlock detected" + assert_includes error.message, "TestVO" + assert_includes error.message, "my-key" + assert_equal 409, error.status_code + end + + it "allows shared handlers on the same VO key" do + context = stub_context( + headers: [["x-restate-held-locks", "TestVO:my-key"]], + key: "my-key" + ) + + # Shared handlers don't hold an exclusive lock, but they also don't + # trigger deadlock detection since they can run concurrently + result = inbound.call(shared_handler, context) { :ok } + assert_equal :ok, result + end + + it "tracks exclusive locks through the chain" do + context = stub_context(headers: [], key: "my-key") + + inbound.call(handler, context) do + locks = Restate::Middleware::DeadlockDetection.held_locks + assert_includes locks, "TestVO:my-key" + :ok + end + end + + it "does not add locks for shared handlers" do + context = stub_context(headers: [], key: "my-key") + + inbound.call(shared_handler, context) do + locks = Restate::Middleware::DeadlockDetection.held_locks + refute_includes locks, "TestVO:my-key" + :ok + end + end + + it "restores previous locks after handler completes" do + previous = Set.new(["OtherVO:other-key"]) + Restate::Middleware::DeadlockDetection.held_locks = previous + + context = stub_context(headers: [], key: "my-key") + + inbound.call(handler, context) { :ok } + + assert_equal previous, Restate::Middleware::DeadlockDetection.held_locks + end + + it "restores previous locks even on error" do + previous = Set.new(["OtherVO:other-key"]) + Restate::Middleware::DeadlockDetection.held_locks = previous + + context = stub_context(headers: [], key: "my-key") + + assert_raises(RuntimeError) do + inbound.call(handler, context) { raise "boom" } + end + + assert_equal previous, Restate::Middleware::DeadlockDetection.held_locks + end + + it "skips detection for basic services" do + context = stub_context(headers: [["x-restate-held-locks", "TestService:something"]]) + + result = inbound.call(service_handler, context) { :ok } + assert_equal :ok, result + end + + it "accumulates locks from incoming header" do + context = stub_context( + headers: [["x-restate-held-locks", "OtherVO:other-key"]], + key: "my-key" + ) + + inbound.call(handler, context) do + locks = Restate::Middleware::DeadlockDetection.held_locks + assert_includes locks, "OtherVO:other-key" + assert_includes locks, "TestVO:my-key" + :ok + end + end + end + + describe "Outbound" do + it "injects held locks header" do + Restate::Middleware::DeadlockDetection.held_locks = Set.new(["SomeVO:some-key"]) + headers = {} + + outbound.call("OtherVO", "some_handler", headers) { :ok } + + assert_equal "SomeVO:some-key", headers["x-restate-held-locks"] + end + + it "does not inject header when no locks held" do + headers = {} + + outbound.call("SomeVO", "some_handler", headers) { :ok } + + refute headers.key?("x-restate-held-locks") + end + + it "raises DeadlockError when calling same service that holds lock" do + Restate::Middleware::DeadlockDetection.held_locks = Set.new(["MyVO:my-key"]) + headers = {} + + error = assert_raises(Restate::Middleware::DeadlockDetection::DeadlockError) do + outbound.call("MyVO", "some_handler", headers) { :ok } + end + + assert_includes error.message, "Deadlock detected" + assert_includes error.message, "MyVO" + assert_equal 409, error.status_code + end + + it "allows calls to a different service" do + Restate::Middleware::DeadlockDetection.held_locks = Set.new(["MyVO:my-key"]) + headers = {} + + result = outbound.call("OtherService", "some_handler", headers) { :ok } + assert_equal :ok, result + end + end + + describe "InvocationMiddleware.invoke integration" do + it "chains inbound middleware correctly" do + context = stub_context(headers: [], key: "test-key") + called = false + + Restate::InvocationMiddleware.invoke([inbound], handler, context) do + called = true + locks = Restate::Middleware::DeadlockDetection.held_locks + assert_includes locks, "TestVO:test-key" + :result + end + + assert called + end + end + + private + + def stub_context(headers: [], key: nil) + request = Struct.new(:headers).new(headers) + context = if key + Struct.new(:request, :key).new(request, key) + else + Struct.new(:request).new(request) + end + context + end +end