Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions external-storage/codec-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,14 @@ func newCORSHTTPHandler(origin string, next http.Handler) http.Handler {
})
}

func main() {
var port int
flag.IntVar(&port, "port", 8081, "Port to listen on")
flag.Parse()

ctx := context.Background()
driver, err := externalstorage.NewS3Driver(ctx)
if err != nil {
log.Fatalf("new s3 driver: %v", err)
}

// Build the payload handler with the same codec + external storage that
// the worker and starter use. PreStorageCodecs runs before storage on
// encode and after retrieval on decode, mirroring what the client-side
// DataConverter does.
// newCodecServerHandler builds the codec server's HTTP handler stack against the given
// external storage driver.
//
// PreStorageCodecs runs before storage on encode and after retrieval on
// decode, mirroring what the client-side DataConverter does. The handler must
// use the same codec + external storage configuration as the worker and
// starter so each side can read what the other wrote.
func newCodecServerHandler(driver converter.StorageDriver) (http.Handler, error) {
defaultNamespaceHandler, err := converter.NewPayloadHTTPHandler(converter.PayloadHTTPHandlerOptions{
PreStorageCodecs: []converter.PayloadCodec{
converter.NewZlibCodec(converter.ZlibCodecOptions{AlwaysEncode: true}),
Expand All @@ -107,15 +100,32 @@ func main() {
},
})
if err != nil {
log.Fatalf("new payload handler: %v", err)
return nil, err
}

// Per-namespace map: extend this to host additional namespaces with their
// own codec chain and/or storage backend.
handler := newPayloadNamespacesHTTPHandler(map[string]http.Handler{
h := newPayloadNamespacesHTTPHandler(map[string]http.Handler{
"default": defaultNamespaceHandler,
})
handler = newCORSHTTPHandler(webUIOrigin, handler)
return newCORSHTTPHandler(webUIOrigin, h), nil
}

func main() {
var port int
flag.IntVar(&port, "port", 8081, "Port to listen on")
flag.Parse()

ctx := context.Background()
driver, err := externalstorage.NewS3Driver(ctx)
if err != nil {
log.Fatalf("new s3 driver: %v", err)
}

handler, err := newCodecServerHandler(driver)
if err != nil {
log.Fatalf("new handler: %v", err)
}
handler = newLoggingHTTPHandler(handler)

srv := &http.Server{
Expand Down
195 changes: 195 additions & 0 deletions external-storage/codec-server/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package main

import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/johannesboyne/gofakes3"
"github.com/johannesboyne/gofakes3/backend/s3mem"
"github.com/stretchr/testify/require"
externalstorage "github.com/temporalio/samples-go/external-storage"
commonpb "go.temporal.io/api/common/v1"
"go.temporal.io/sdk/contrib/aws/s3driver"
"go.temporal.io/sdk/contrib/aws/s3driver/awssdkv2"
"go.temporal.io/sdk/converter"
"google.golang.org/protobuf/encoding/protojson"
)

const (
testBucket = "test-bucket"
testNamespace = "default"
)

// newCodecServer builds the full codec server middleware stack against an
// in-memory gofakes3, mirroring what main() does. Returned httptest.Server is
// closed by the caller.
func newCodecServer(t *testing.T) *httptest.Server {
t.Helper()

backend := s3mem.New()
require.NoError(t, backend.CreateBucket(testBucket))
s3Server := httptest.NewServer(gofakes3.New(backend).Server())
t.Cleanup(s3Server.Close)

cfg, err := config.LoadDefaultConfig(context.Background(),
config.WithRegion("us-east-1"),
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("test", "test", "")),
)
require.NoError(t, err)
s3Client := s3.NewFromConfig(cfg, func(o *s3.Options) {
o.BaseEndpoint = aws.String(s3Server.URL)
o.UsePathStyle = true
})

driver, err := s3driver.NewDriver(s3driver.Options{
Client: awssdkv2.NewClient(s3Client),
Bucket: s3driver.StaticBucket(testBucket),
})
require.NoError(t, err)

// Exercise the same handler stack main() builds.
handler, err := newCodecServerHandler(driver)
require.NoError(t, err)

srv := httptest.NewServer(handler)
t.Cleanup(srv.Close)
return srv
}

func smallPayload() *commonpb.Payload {
return &commonpb.Payload{
Metadata: map[string][]byte{"encoding": []byte("json/plain")},
Data: []byte(`{"hello":"world"}`),
}
}

// callPayloads POSTs the given payloads to the codec server and returns the
// decoded response.
func callPayloads(t *testing.T, url, namespace string, payloads ...*commonpb.Payload) (int, []*commonpb.Payload) {
t.Helper()
body, err := protojson.Marshal(&commonpb.Payloads{Payloads: payloads})
require.NoError(t, err)

req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")
if namespace != "" {
req.Header.Set("X-Namespace", namespace)
}

resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()

respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
if resp.StatusCode != http.StatusOK {
return resp.StatusCode, nil
}

var out commonpb.Payloads
require.NoError(t, protojson.Unmarshal(respBody, &out))
return resp.StatusCode, out.Payloads
}

func Test_UnknownNamespaceReturns404(t *testing.T) {
srv := newCodecServer(t)

status, _ := callPayloads(t, srv.URL+"/encode", "unregistered-ns", smallPayload())
require.Equal(t, http.StatusNotFound, status)
}

func Test_CORSPreflight(t *testing.T) {
srv := newCodecServer(t)

req, err := http.NewRequest(http.MethodOptions, srv.URL+"/decode", nil)
require.NoError(t, err)
req.Header.Set("Origin", webUIOrigin)

resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()

require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, webUIOrigin, resp.Header.Get("Access-Control-Allow-Origin"))
require.Equal(t, "POST,OPTIONS", resp.Header.Get("Access-Control-Allow-Methods"))
require.Contains(t, resp.Header.Get("Access-Control-Allow-Headers"), "X-Namespace")
}

func Test_CORSRejectsOtherOrigin(t *testing.T) {
srv := newCodecServer(t)

req, err := http.NewRequest(http.MethodOptions, srv.URL+"/decode", nil)
require.NoError(t, err)
req.Header.Set("Origin", "http://evil.example.com")

resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()

// The middleware only sets CORS headers when Origin matches webUIOrigin;
// a request from any other origin gets no CORS headers, which makes the
// browser refuse to use the response.
require.Empty(t, resp.Header.Get("Access-Control-Allow-Origin"))
require.Empty(t, resp.Header.Get("Access-Control-Allow-Methods"))
require.Empty(t, resp.Header.Get("Access-Control-Allow-Headers"))
}

func Test_CORSHeadersOnPostResponse(t *testing.T) {
srv := newCodecServer(t)

// CORS headers must land on the actual POST response too, not just the
// OPTIONS preflight — the browser checks both.
body, err := protojson.Marshal(&commonpb.Payloads{Payloads: []*commonpb.Payload{smallPayload()}})
require.NoError(t, err)
req, err := http.NewRequest(http.MethodPost, srv.URL+"/encode", bytes.NewReader(body))
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Namespace", testNamespace)
req.Header.Set("Origin", webUIOrigin)

resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()

require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, webUIOrigin, resp.Header.Get("Access-Control-Allow-Origin"))
}

// Test_DecodePayloadCompatibility checks compatibility between the codec
// server and the DataConverter that the worker/starter use.
func Test_DecodePayloadCompatibility(t *testing.T) {
srv := newCodecServer(t)

// Use the exact DataConverter the worker/starter use, so a change to
// either side breaks this test rather than passing silently.
workerConv := externalstorage.NewSampleDataConverter()

type sample struct {
Greeting string `json:"greeting"`
Count int `json:"count"`
}
original := sample{Greeting: "hello", Count: 42}

// Encode as the worker would, then send the result to /decode.
encoded, err := workerConv.ToPayloads(original)
require.NoError(t, err)
require.Equal(t, "binary/zlib", string(encoded.Payloads[0].GetMetadata()["encoding"]))

status, decoded := callPayloads(t, srv.URL+"/decode", testNamespace, encoded.Payloads...)
require.Equal(t, http.StatusOK, status)
require.Len(t, decoded, 1)

var got sample
require.NoError(t, converter.GetDefaultDataConverter().FromPayloads(
&commonpb.Payloads{Payloads: decoded}, &got))
require.Equal(t, original, got)
}
25 changes: 12 additions & 13 deletions external-storage/data_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,26 @@ import (
"go.temporal.io/sdk/converter"
)

// NewSampleDataConverter is the single source of truth for the codec chain applied to
// every payload in this sample. The worker, starter, and codec server must all
// agree on it.
func NewSampleDataConverter() converter.DataConverter {
return converter.NewCodecDataConverter(
converter.GetDefaultDataConverter(),
converter.NewZlibCodec(converter.ZlibCodecOptions{AlwaysEncode: true}),
)
}

// NewClient dials Temporal with the data converter and external storage
// configuration shared by every process in this sample (worker, starter,
// codec server). They must all agree on:
//
// - the codec chain that wraps each payload (zlib here), and
// - the external storage driver and threshold that decides when a payload
// is offloaded instead of stored inline.
//
// If any one of them diverges, payloads written by one side will not be
// readable by the other.
// configuration shared by every process in this sample.
func NewClient(ctx context.Context, options client.Options) (client.Client, error) {
driver, err := NewS3Driver(ctx)
if err != nil {
return nil, fmt.Errorf("new s3 driver: %w", err)
}

if options.DataConverter == nil {
options.DataConverter = converter.NewCodecDataConverter(
converter.GetDefaultDataConverter(),
converter.NewZlibCodec(converter.ZlibCodecOptions{AlwaysEncode: true}),
)
options.DataConverter = NewSampleDataConverter()
}
options.ExternalStorage = converter.ExternalStorage{
Drivers: []converter.StorageDriver{driver},
Expand Down
19 changes: 14 additions & 5 deletions external-storage/s3-mock/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,29 @@ func newLoggingHTTPHandler(next http.Handler) http.Handler {
})
}

// newHandler builds an in-memory S3 handler with a single bucket pre-created.
// Callers (main, tests) typically wrap the result with newLoggingHTTPHandler.
func newHandler(bucket string) (http.Handler, error) {
backend := s3mem.New()
if err := backend.CreateBucket(bucket); err != nil {
return nil, err
}
return gofakes3.New(backend).Server(), nil
}

func main() {
var port int
flag.IntVar(&port, "port", 5000, "Port to listen on")
flag.Parse()

backend := s3mem.New()
if err := backend.CreateBucket(externalstorage.S3Bucket); err != nil {
log.Fatalf("create bucket: %v", err)
handler, err := newHandler(externalstorage.S3Bucket)
if err != nil {
log.Fatalf("new handler: %v", err)
}
faker := gofakes3.New(backend)

srv := &http.Server{
Addr: "localhost:" + strconv.Itoa(port),
Handler: newLoggingHTTPHandler(faker.Server()),
Handler: newLoggingHTTPHandler(handler),
}

errCh := make(chan error, 1)
Expand Down
19 changes: 19 additions & 0 deletions external-storage/sample_data_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package externalstorage

import (
"testing"

"github.com/stretchr/testify/require"
)

func Test_GenerateOrders_DeterministicForSameBatchID(t *testing.T) {
first := generateOrders("BATCH-DETERMINISM", 5)
second := generateOrders("BATCH-DETERMINISM", 5)
require.Equal(t, first, second, "same batch ID must produce identical orders")
}

func Test_GenerateOrders_DifferentBatchIDsDiffer(t *testing.T) {
a := generateOrders("BATCH-A", 5)
b := generateOrders("BATCH-B", 5)
require.NotEqual(t, a, b, "different batch IDs must produce different orders")
}
File renamed without changes.
Loading
Loading