diff --git a/external-storage/codec-server/main.go b/external-storage/codec-server/main.go index 784cfc8c..c545f4d1 100644 --- a/external-storage/codec-server/main.go +++ b/external-storage/codec-server/main.go @@ -83,21 +83,14 @@ func newCORSHTTPHandler(origin string, next http.Handler) http.Handler { }) } -func main() { - var port int - flag.IntVar(&port, "port", 8081, "Port to listen on") - flag.Parse() - - ctx := context.Background() - driver, err := externalstorage.NewS3Driver(ctx) - if err != nil { - log.Fatalf("new s3 driver: %v", err) - } - - // Build the payload handler with the same codec + external storage that - // the worker and starter use. PreStorageCodecs runs before storage on - // encode and after retrieval on decode, mirroring what the client-side - // DataConverter does. +// newCodecServerHandler builds the codec server's HTTP handler stack against the given +// external storage driver. +// +// PreStorageCodecs runs before storage on encode and after retrieval on +// decode, mirroring what the client-side DataConverter does. The handler must +// use the same codec + external storage configuration as the worker and +// starter so each side can read what the other wrote. +func newCodecServerHandler(driver converter.StorageDriver) (http.Handler, error) { defaultNamespaceHandler, err := converter.NewPayloadHTTPHandler(converter.PayloadHTTPHandlerOptions{ PreStorageCodecs: []converter.PayloadCodec{ converter.NewZlibCodec(converter.ZlibCodecOptions{AlwaysEncode: true}), @@ -107,15 +100,32 @@ func main() { }, }) if err != nil { - log.Fatalf("new payload handler: %v", err) + return nil, err } // Per-namespace map: extend this to host additional namespaces with their // own codec chain and/or storage backend. - handler := newPayloadNamespacesHTTPHandler(map[string]http.Handler{ + h := newPayloadNamespacesHTTPHandler(map[string]http.Handler{ "default": defaultNamespaceHandler, }) - handler = newCORSHTTPHandler(webUIOrigin, handler) + return newCORSHTTPHandler(webUIOrigin, h), nil +} + +func main() { + var port int + flag.IntVar(&port, "port", 8081, "Port to listen on") + flag.Parse() + + ctx := context.Background() + driver, err := externalstorage.NewS3Driver(ctx) + if err != nil { + log.Fatalf("new s3 driver: %v", err) + } + + handler, err := newCodecServerHandler(driver) + if err != nil { + log.Fatalf("new handler: %v", err) + } handler = newLoggingHTTPHandler(handler) srv := &http.Server{ diff --git a/external-storage/codec-server/main_test.go b/external-storage/codec-server/main_test.go new file mode 100644 index 00000000..43da8979 --- /dev/null +++ b/external-storage/codec-server/main_test.go @@ -0,0 +1,195 @@ +package main + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/johannesboyne/gofakes3" + "github.com/johannesboyne/gofakes3/backend/s3mem" + "github.com/stretchr/testify/require" + externalstorage "github.com/temporalio/samples-go/external-storage" + commonpb "go.temporal.io/api/common/v1" + "go.temporal.io/sdk/contrib/aws/s3driver" + "go.temporal.io/sdk/contrib/aws/s3driver/awssdkv2" + "go.temporal.io/sdk/converter" + "google.golang.org/protobuf/encoding/protojson" +) + +const ( + testBucket = "test-bucket" + testNamespace = "default" +) + +// newCodecServer builds the full codec server middleware stack against an +// in-memory gofakes3, mirroring what main() does. Returned httptest.Server is +// closed by the caller. +func newCodecServer(t *testing.T) *httptest.Server { + t.Helper() + + backend := s3mem.New() + require.NoError(t, backend.CreateBucket(testBucket)) + s3Server := httptest.NewServer(gofakes3.New(backend).Server()) + t.Cleanup(s3Server.Close) + + cfg, err := config.LoadDefaultConfig(context.Background(), + config.WithRegion("us-east-1"), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("test", "test", "")), + ) + require.NoError(t, err) + s3Client := s3.NewFromConfig(cfg, func(o *s3.Options) { + o.BaseEndpoint = aws.String(s3Server.URL) + o.UsePathStyle = true + }) + + driver, err := s3driver.NewDriver(s3driver.Options{ + Client: awssdkv2.NewClient(s3Client), + Bucket: s3driver.StaticBucket(testBucket), + }) + require.NoError(t, err) + + // Exercise the same handler stack main() builds. + handler, err := newCodecServerHandler(driver) + require.NoError(t, err) + + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + return srv +} + +func smallPayload() *commonpb.Payload { + return &commonpb.Payload{ + Metadata: map[string][]byte{"encoding": []byte("json/plain")}, + Data: []byte(`{"hello":"world"}`), + } +} + +// callPayloads POSTs the given payloads to the codec server and returns the +// decoded response. +func callPayloads(t *testing.T, url, namespace string, payloads ...*commonpb.Payload) (int, []*commonpb.Payload) { + t.Helper() + body, err := protojson.Marshal(&commonpb.Payloads{Payloads: payloads}) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + if namespace != "" { + req.Header.Set("X-Namespace", namespace) + } + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + if resp.StatusCode != http.StatusOK { + return resp.StatusCode, nil + } + + var out commonpb.Payloads + require.NoError(t, protojson.Unmarshal(respBody, &out)) + return resp.StatusCode, out.Payloads +} + +func Test_UnknownNamespaceReturns404(t *testing.T) { + srv := newCodecServer(t) + + status, _ := callPayloads(t, srv.URL+"/encode", "unregistered-ns", smallPayload()) + require.Equal(t, http.StatusNotFound, status) +} + +func Test_CORSPreflight(t *testing.T) { + srv := newCodecServer(t) + + req, err := http.NewRequest(http.MethodOptions, srv.URL+"/decode", nil) + require.NoError(t, err) + req.Header.Set("Origin", webUIOrigin) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer func() { _ = resp.Body.Close() }() + + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, webUIOrigin, resp.Header.Get("Access-Control-Allow-Origin")) + require.Equal(t, "POST,OPTIONS", resp.Header.Get("Access-Control-Allow-Methods")) + require.Contains(t, resp.Header.Get("Access-Control-Allow-Headers"), "X-Namespace") +} + +func Test_CORSRejectsOtherOrigin(t *testing.T) { + srv := newCodecServer(t) + + req, err := http.NewRequest(http.MethodOptions, srv.URL+"/decode", nil) + require.NoError(t, err) + req.Header.Set("Origin", "http://evil.example.com") + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer func() { _ = resp.Body.Close() }() + + // The middleware only sets CORS headers when Origin matches webUIOrigin; + // a request from any other origin gets no CORS headers, which makes the + // browser refuse to use the response. + require.Empty(t, resp.Header.Get("Access-Control-Allow-Origin")) + require.Empty(t, resp.Header.Get("Access-Control-Allow-Methods")) + require.Empty(t, resp.Header.Get("Access-Control-Allow-Headers")) +} + +func Test_CORSHeadersOnPostResponse(t *testing.T) { + srv := newCodecServer(t) + + // CORS headers must land on the actual POST response too, not just the + // OPTIONS preflight — the browser checks both. + body, err := protojson.Marshal(&commonpb.Payloads{Payloads: []*commonpb.Payload{smallPayload()}}) + require.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, srv.URL+"/encode", bytes.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Namespace", testNamespace) + req.Header.Set("Origin", webUIOrigin) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer func() { _ = resp.Body.Close() }() + + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, webUIOrigin, resp.Header.Get("Access-Control-Allow-Origin")) +} + +// Test_DecodePayloadCompatibility checks compatibility between the codec +// server and the DataConverter that the worker/starter use. +func Test_DecodePayloadCompatibility(t *testing.T) { + srv := newCodecServer(t) + + // Use the exact DataConverter the worker/starter use, so a change to + // either side breaks this test rather than passing silently. + workerConv := externalstorage.NewSampleDataConverter() + + type sample struct { + Greeting string `json:"greeting"` + Count int `json:"count"` + } + original := sample{Greeting: "hello", Count: 42} + + // Encode as the worker would, then send the result to /decode. + encoded, err := workerConv.ToPayloads(original) + require.NoError(t, err) + require.Equal(t, "binary/zlib", string(encoded.Payloads[0].GetMetadata()["encoding"])) + + status, decoded := callPayloads(t, srv.URL+"/decode", testNamespace, encoded.Payloads...) + require.Equal(t, http.StatusOK, status) + require.Len(t, decoded, 1) + + var got sample + require.NoError(t, converter.GetDefaultDataConverter().FromPayloads( + &commonpb.Payloads{Payloads: decoded}, &got)) + require.Equal(t, original, got) +} diff --git a/external-storage/data_converter.go b/external-storage/data_converter.go index 75d6ca72..8cbbc947 100644 --- a/external-storage/data_converter.go +++ b/external-storage/data_converter.go @@ -8,16 +8,18 @@ import ( "go.temporal.io/sdk/converter" ) +// NewSampleDataConverter is the single source of truth for the codec chain applied to +// every payload in this sample. The worker, starter, and codec server must all +// agree on it. +func NewSampleDataConverter() converter.DataConverter { + return converter.NewCodecDataConverter( + converter.GetDefaultDataConverter(), + converter.NewZlibCodec(converter.ZlibCodecOptions{AlwaysEncode: true}), + ) +} + // NewClient dials Temporal with the data converter and external storage -// configuration shared by every process in this sample (worker, starter, -// codec server). They must all agree on: -// -// - the codec chain that wraps each payload (zlib here), and -// - the external storage driver and threshold that decides when a payload -// is offloaded instead of stored inline. -// -// If any one of them diverges, payloads written by one side will not be -// readable by the other. +// configuration shared by every process in this sample. func NewClient(ctx context.Context, options client.Options) (client.Client, error) { driver, err := NewS3Driver(ctx) if err != nil { @@ -25,10 +27,7 @@ func NewClient(ctx context.Context, options client.Options) (client.Client, erro } if options.DataConverter == nil { - options.DataConverter = converter.NewCodecDataConverter( - converter.GetDefaultDataConverter(), - converter.NewZlibCodec(converter.ZlibCodecOptions{AlwaysEncode: true}), - ) + options.DataConverter = NewSampleDataConverter() } options.ExternalStorage = converter.ExternalStorage{ Drivers: []converter.StorageDriver{driver}, diff --git a/external-storage/s3-mock/main.go b/external-storage/s3-mock/main.go index 71a882b4..101822ec 100644 --- a/external-storage/s3-mock/main.go +++ b/external-storage/s3-mock/main.go @@ -41,20 +41,29 @@ func newLoggingHTTPHandler(next http.Handler) http.Handler { }) } +// newHandler builds an in-memory S3 handler with a single bucket pre-created. +// Callers (main, tests) typically wrap the result with newLoggingHTTPHandler. +func newHandler(bucket string) (http.Handler, error) { + backend := s3mem.New() + if err := backend.CreateBucket(bucket); err != nil { + return nil, err + } + return gofakes3.New(backend).Server(), nil +} + func main() { var port int flag.IntVar(&port, "port", 5000, "Port to listen on") flag.Parse() - backend := s3mem.New() - if err := backend.CreateBucket(externalstorage.S3Bucket); err != nil { - log.Fatalf("create bucket: %v", err) + handler, err := newHandler(externalstorage.S3Bucket) + if err != nil { + log.Fatalf("new handler: %v", err) } - faker := gofakes3.New(backend) srv := &http.Server{ Addr: "localhost:" + strconv.Itoa(port), - Handler: newLoggingHTTPHandler(faker.Server()), + Handler: newLoggingHTTPHandler(handler), } errCh := make(chan error, 1) diff --git a/external-storage/sample_data_test.go b/external-storage/sample_data_test.go new file mode 100644 index 00000000..048adebf --- /dev/null +++ b/external-storage/sample_data_test.go @@ -0,0 +1,19 @@ +package externalstorage + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_GenerateOrders_DeterministicForSameBatchID(t *testing.T) { + first := generateOrders("BATCH-DETERMINISM", 5) + second := generateOrders("BATCH-DETERMINISM", 5) + require.Equal(t, first, second, "same batch ID must produce identical orders") +} + +func Test_GenerateOrders_DifferentBatchIDsDiffer(t *testing.T) { + a := generateOrders("BATCH-A", 5) + b := generateOrders("BATCH-B", 5) + require.NotEqual(t, a, b, "different batch IDs must produce different orders") +} diff --git a/external-storage/workflows.go b/external-storage/workflow.go similarity index 100% rename from external-storage/workflows.go rename to external-storage/workflow.go diff --git a/external-storage/workflow_test.go b/external-storage/workflow_test.go new file mode 100644 index 00000000..d21e3275 --- /dev/null +++ b/external-storage/workflow_test.go @@ -0,0 +1,49 @@ +package externalstorage + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "go.temporal.io/sdk/testsuite" +) + +func Test_ProcessOrderBatchWorkflow(t *testing.T) { + testSuite := &testsuite.WorkflowTestSuite{} + env := testSuite.NewTestWorkflowEnvironment() + // Real activities, not mocks: generateOrders is deterministic per batchID + // so the workflow's totals can be derived from the same source of truth. + env.RegisterActivity(FetchOrders) + env.RegisterActivity(ProcessOrders) + + request := OrderBatchRequest{BatchID: "BATCH-TEST", OrderCount: 10} + env.ExecuteWorkflow(ProcessOrderBatchWorkflow, request) + + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + + var summary BatchSummary + require.NoError(t, env.GetWorkflowResult(&summary)) + + // Re-run the activity pipeline directly to compute the expected totals. + // The workflow's BatchSummary must agree with this independent reduction. + orders, err := FetchOrders(context.Background(), request) + require.NoError(t, err) + processed, err := ProcessOrders(context.Background(), orders) + require.NoError(t, err) + + var expectedCost, expectedWeight float64 + var totalDays int + for _, p := range processed { + expectedCost += p.ShippingCostUSD + expectedWeight += p.TotalWeightKg + totalDays += p.EstimatedDeliveryDays + } + expectedAvg := float64(totalDays) / float64(len(processed)) + + require.Equal(t, request.BatchID, summary.BatchID) + require.Equal(t, request.OrderCount, summary.OrderCount) + require.InDelta(t, expectedCost, summary.TotalShippingCostUSD, 0.01) + require.InDelta(t, expectedWeight, summary.TotalWeightKg, 0.01) + require.InDelta(t, expectedAvg, summary.AvgDeliveryDays, 0.1) +}