From 63aad20d932d597809643dc1a045e6c000141fb4 Mon Sep 17 00:00:00 2001 From: "Haihui.Wang" Date: Fri, 8 May 2026 16:40:13 +0800 Subject: [PATCH] Dataflow use csghub hardware for integration --- .gitignore | 2 +- .mockery.yaml | 1 + .../builder/deploy/imagerunner/mock_Runner.go | 214 +++++ .../builder/deploy/mock_Deployer.go | 168 +++- .../store/database/mock_ArgoWorkFlowStore.go | 18 +- .../mock_PlatformDataflowComponent.go | 377 ++++++++ .../component/mock_DataflowComponent.go | 234 +++++ api/handler/platform_dataflow.go | 220 +++++ api/handler/platform_dataflow_test.go | 250 +++++ api/router/api.go | 17 +- builder/deploy/cluster/cluster_manager.go | 4 +- builder/deploy/deployer.go | 18 +- builder/deploy/deployer_dataflow.go | 51 + builder/deploy/deployer_dataflow_test.go | 228 +++++ builder/deploy/deployer_test.go | 8 +- builder/deploy/imagerunner/local_runner.go | 17 + builder/deploy/imagerunner/remote_runner.go | 33 + .../deploy/imagerunner/remote_runner_test.go | 208 +++++ builder/deploy/imagerunner/runner.go | 5 + builder/store/database/argo_workflow.go | 19 +- builder/store/database/argo_workflow_test.go | 6 +- ...dd_column_deletedat_dagtasks_argo.down.sql | 9 + ..._add_column_deletedat_dagtasks_argo.up.sql | 9 + common/types/argo_workflow.go | 2 + common/types/dataflow.go | 113 +++ common/types/finetune.go | 4 +- common/types/webhook.go | 6 + component/accounting.go | 26 +- component/accounting_test.go | 6 + .../executors/webhook_executor_dataflow.go | 100 ++ .../webhook_executor_dataflow_pod.go | 96 ++ .../webhook_executor_dataflow_pod_test.go | 417 +++++++++ .../webhook_executor_dataflow_test.go | 395 ++++++++ component/finetune.go | 16 +- component/finetune_test.go | 6 +- component/platform_dataflow.go | 287 ++++++ component/platform_dataflow_test.go | 883 ++++++++++++++++++ component/webhook.go | 18 + runner/component/dataflow.go | 716 ++++++++++++++ runner/component/dataflow_test.go | 363 +++++++ runner/component/workflow.go | 22 +- 
runner/component/workflow_test.go | 12 +- runner/handler/dataflow.go | 87 ++ runner/handler/dataflow_test.go | 210 +++++ runner/router/api.go | 12 + 45 files changed, 5808 insertions(+), 105 deletions(-) create mode 100644 _mocks/opencsg.com/csghub-server/component/mock_PlatformDataflowComponent.go create mode 100644 _mocks/opencsg.com/csghub-server/runner/component/mock_DataflowComponent.go create mode 100644 api/handler/platform_dataflow.go create mode 100644 api/handler/platform_dataflow_test.go create mode 100644 builder/deploy/deployer_dataflow.go create mode 100644 builder/deploy/deployer_dataflow_test.go create mode 100644 builder/store/database/migrations/20260430054132_add_column_deletedat_dagtasks_argo.down.sql create mode 100644 builder/store/database/migrations/20260430054132_add_column_deletedat_dagtasks_argo.up.sql create mode 100644 common/types/dataflow.go create mode 100644 component/executors/webhook_executor_dataflow.go create mode 100644 component/executors/webhook_executor_dataflow_pod.go create mode 100644 component/executors/webhook_executor_dataflow_pod_test.go create mode 100644 component/executors/webhook_executor_dataflow_test.go create mode 100644 component/platform_dataflow.go create mode 100644 component/platform_dataflow_test.go create mode 100644 runner/component/dataflow.go create mode 100644 runner/component/dataflow_test.go create mode 100644 runner/handler/dataflow.go create mode 100644 runner/handler/dataflow_test.go diff --git a/.gitignore b/.gitignore index 4654e4f9f..0a02cf51c 100644 --- a/.gitignore +++ b/.gitignore @@ -49,4 +49,4 @@ docs/swagger.json docs/swagger.yaml docs/docs.go pgdata15/ - +.prd/ diff --git a/.mockery.yaml b/.mockery.yaml index a6a525e4a..87d34783d 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -266,6 +266,7 @@ packages: WorkFlowComponent: ServiceComponent: ClusterComponent: + DataflowComponent: opencsg.com/csghub-server/logcollector/component: config: all: true diff --git 
a/_mocks/opencsg.com/csghub-server/builder/deploy/imagerunner/mock_Runner.go b/_mocks/opencsg.com/csghub-server/builder/deploy/imagerunner/mock_Runner.go index 2f0fb3b0b..5a64c6e4f 100644 --- a/_mocks/opencsg.com/csghub-server/builder/deploy/imagerunner/mock_Runner.go +++ b/_mocks/opencsg.com/csghub-server/builder/deploy/imagerunner/mock_Runner.go @@ -9,6 +9,8 @@ import ( mock "github.com/stretchr/testify/mock" + runnertypes "opencsg.com/csghub-server/runner/types" + types "opencsg.com/csghub-server/common/types" ) @@ -25,6 +27,65 @@ func (_m *MockRunner) EXPECT() *MockRunner_Expecter { return &MockRunner_Expecter{mock: &_m.Mock} } +// CreateDataflowWorkflow provides a mock function with given fields: ctx, req +func (_m *MockRunner) CreateDataflowWorkflow(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateDataflowWorkflow") + } + + var r0 *types.DataflowArgoJobResp + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowArgoJobReq) *types.DataflowArgoJobResp); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.DataflowArgoJobResp) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.DataflowArgoJobReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_CreateDataflowWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateDataflowWorkflow' +type MockRunner_CreateDataflowWorkflow_Call struct { + *mock.Call +} + +// CreateDataflowWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - req *types.DataflowArgoJobReq +func (_e *MockRunner_Expecter) CreateDataflowWorkflow(ctx interface{}, req 
interface{}) *MockRunner_CreateDataflowWorkflow_Call { + return &MockRunner_CreateDataflowWorkflow_Call{Call: _e.mock.On("CreateDataflowWorkflow", ctx, req)} +} + +func (_c *MockRunner_CreateDataflowWorkflow_Call) Run(run func(ctx context.Context, req *types.DataflowArgoJobReq)) *MockRunner_CreateDataflowWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.DataflowArgoJobReq)) + }) + return _c +} + +func (_c *MockRunner_CreateDataflowWorkflow_Call) Return(_a0 *types.DataflowArgoJobResp, _a1 error) *MockRunner_CreateDataflowWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_CreateDataflowWorkflow_Call) RunAndReturn(run func(context.Context, *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error)) *MockRunner_CreateDataflowWorkflow_Call { + _c.Call.Return(run) + return _c +} + // CreateRevisions provides a mock function with given fields: _a0, _a1 func (_m *MockRunner) CreateRevisions(_a0 context.Context, _a1 *types.CreateRevisionReq) error { ret := _m.Called(_a0, _a1) @@ -72,6 +133,112 @@ func (_c *MockRunner_CreateRevisions_Call) RunAndReturn(run func(context.Context return _c } +// CreateSandbox provides a mock function with given fields: ctx, req +func (_m *MockRunner) CreateSandbox(ctx context.Context, req *runnertypes.SandboxRequest) (*runnertypes.Sandbox, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateSandbox") + } + + var r0 *runnertypes.Sandbox + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *runnertypes.SandboxRequest) (*runnertypes.Sandbox, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *runnertypes.SandboxRequest) *runnertypes.Sandbox); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*runnertypes.Sandbox) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *runnertypes.SandboxRequest) error); ok { + r1 = rf(ctx, 
req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRunner_CreateSandbox_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateSandbox' +type MockRunner_CreateSandbox_Call struct { + *mock.Call +} + +// CreateSandbox is a helper method to define mock.On call +// - ctx context.Context +// - req *runnertypes.SandboxRequest +func (_e *MockRunner_Expecter) CreateSandbox(ctx interface{}, req interface{}) *MockRunner_CreateSandbox_Call { + return &MockRunner_CreateSandbox_Call{Call: _e.mock.On("CreateSandbox", ctx, req)} +} + +func (_c *MockRunner_CreateSandbox_Call) Run(run func(ctx context.Context, req *runnertypes.SandboxRequest)) *MockRunner_CreateSandbox_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*runnertypes.SandboxRequest)) + }) + return _c +} + +func (_c *MockRunner_CreateSandbox_Call) Return(_a0 *runnertypes.Sandbox, _a1 error) *MockRunner_CreateSandbox_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRunner_CreateSandbox_Call) RunAndReturn(run func(context.Context, *runnertypes.SandboxRequest) (*runnertypes.Sandbox, error)) *MockRunner_CreateSandbox_Call { + _c.Call.Return(run) + return _c +} + +// DeleteDataflowWorkflow provides a mock function with given fields: ctx, req +func (_m *MockRunner) DeleteDataflowWorkflow(ctx context.Context, req *types.DataflowArgoReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for DeleteDataflowWorkflow") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowArgoReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockRunner_DeleteDataflowWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteDataflowWorkflow' +type MockRunner_DeleteDataflowWorkflow_Call struct { + *mock.Call +} + +// DeleteDataflowWorkflow is a helper method to 
define mock.On call +// - ctx context.Context +// - req *types.DataflowArgoReq +func (_e *MockRunner_Expecter) DeleteDataflowWorkflow(ctx interface{}, req interface{}) *MockRunner_DeleteDataflowWorkflow_Call { + return &MockRunner_DeleteDataflowWorkflow_Call{Call: _e.mock.On("DeleteDataflowWorkflow", ctx, req)} +} + +func (_c *MockRunner_DeleteDataflowWorkflow_Call) Run(run func(ctx context.Context, req *types.DataflowArgoReq)) *MockRunner_DeleteDataflowWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.DataflowArgoReq)) + }) + return _c +} + +func (_c *MockRunner_DeleteDataflowWorkflow_Call) Return(_a0 error) *MockRunner_DeleteDataflowWorkflow_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRunner_DeleteDataflowWorkflow_Call) RunAndReturn(run func(context.Context, *types.DataflowArgoReq) error) *MockRunner_DeleteDataflowWorkflow_Call { + _c.Call.Return(run) + return _c +} + // DeleteKsvcVersion provides a mock function with given fields: ctx, clusterID, svcName, commitID func (_m *MockRunner) DeleteKsvcVersion(ctx context.Context, clusterID string, svcName string, commitID string) error { ret := _m.Called(ctx, clusterID, svcName, commitID) @@ -121,6 +288,53 @@ func (_c *MockRunner_DeleteKsvcVersion_Call) RunAndReturn(run func(context.Conte return _c } +// DeleteSandbox provides a mock function with given fields: ctx, req +func (_m *MockRunner) DeleteSandbox(ctx context.Context, req *runnertypes.SandboxDeleteRequest) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for DeleteSandbox") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *runnertypes.SandboxDeleteRequest) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockRunner_DeleteSandbox_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteSandbox' +type MockRunner_DeleteSandbox_Call struct { + 
*mock.Call +} + +// DeleteSandbox is a helper method to define mock.On call +// - ctx context.Context +// - req *runnertypes.SandboxDeleteRequest +func (_e *MockRunner_Expecter) DeleteSandbox(ctx interface{}, req interface{}) *MockRunner_DeleteSandbox_Call { + return &MockRunner_DeleteSandbox_Call{Call: _e.mock.On("DeleteSandbox", ctx, req)} +} + +func (_c *MockRunner_DeleteSandbox_Call) Run(run func(ctx context.Context, req *runnertypes.SandboxDeleteRequest)) *MockRunner_DeleteSandbox_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*runnertypes.SandboxDeleteRequest)) + }) + return _c +} + +func (_c *MockRunner_DeleteSandbox_Call) Return(_a0 error) *MockRunner_DeleteSandbox_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRunner_DeleteSandbox_Call) RunAndReturn(run func(context.Context, *runnertypes.SandboxDeleteRequest) error) *MockRunner_DeleteSandbox_Call { + _c.Call.Return(run) + return _c +} + // DeleteWorkFlow provides a mock function with given fields: _a0, _a1 func (_m *MockRunner) DeleteWorkFlow(_a0 context.Context, _a1 types.ArgoWorkFlowDeleteReq) (*httpbase.R, error) { ret := _m.Called(_a0, _a1) diff --git a/_mocks/opencsg.com/csghub-server/builder/deploy/mock_Deployer.go b/_mocks/opencsg.com/csghub-server/builder/deploy/mock_Deployer.go index 8df1422c5..edd0ce375 100644 --- a/_mocks/opencsg.com/csghub-server/builder/deploy/mock_Deployer.go +++ b/_mocks/opencsg.com/csghub-server/builder/deploy/mock_Deployer.go @@ -156,6 +156,112 @@ func (_c *MockDeployer_CheckResourceAvailable_Call) RunAndReturn(run func(contex return _c } +// CreateDataflowJob provides a mock function with given fields: ctx, req +func (_m *MockDeployer) CreateDataflowJob(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateDataflowJob") + } + + var r0 *types.DataflowArgoJobResp + var r1 error + if rf, 
ok := ret.Get(0).(func(context.Context, *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowArgoJobReq) *types.DataflowArgoJobResp); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.DataflowArgoJobResp) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.DataflowArgoJobReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDeployer_CreateDataflowJob_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateDataflowJob' +type MockDeployer_CreateDataflowJob_Call struct { + *mock.Call +} + +// CreateDataflowJob is a helper method to define mock.On call +// - ctx context.Context +// - req *types.DataflowArgoJobReq +func (_e *MockDeployer_Expecter) CreateDataflowJob(ctx interface{}, req interface{}) *MockDeployer_CreateDataflowJob_Call { + return &MockDeployer_CreateDataflowJob_Call{Call: _e.mock.On("CreateDataflowJob", ctx, req)} +} + +func (_c *MockDeployer_CreateDataflowJob_Call) Run(run func(ctx context.Context, req *types.DataflowArgoJobReq)) *MockDeployer_CreateDataflowJob_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.DataflowArgoJobReq)) + }) + return _c +} + +func (_c *MockDeployer_CreateDataflowJob_Call) Return(_a0 *types.DataflowArgoJobResp, _a1 error) *MockDeployer_CreateDataflowJob_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDeployer_CreateDataflowJob_Call) RunAndReturn(run func(context.Context, *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error)) *MockDeployer_CreateDataflowJob_Call { + _c.Call.Return(run) + return _c +} + +// DeleteDataflowJob provides a mock function with given fields: ctx, req +func (_m *MockDeployer) DeleteDataflowJob(ctx context.Context, req *types.DataflowArgoReq) error { + ret := _m.Called(ctx, req) + + if 
len(ret) == 0 { + panic("no return value specified for DeleteDataflowJob") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowArgoReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockDeployer_DeleteDataflowJob_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteDataflowJob' +type MockDeployer_DeleteDataflowJob_Call struct { + *mock.Call +} + +// DeleteDataflowJob is a helper method to define mock.On call +// - ctx context.Context +// - req *types.DataflowArgoReq +func (_e *MockDeployer_Expecter) DeleteDataflowJob(ctx interface{}, req interface{}) *MockDeployer_DeleteDataflowJob_Call { + return &MockDeployer_DeleteDataflowJob_Call{Call: _e.mock.On("DeleteDataflowJob", ctx, req)} +} + +func (_c *MockDeployer_DeleteDataflowJob_Call) Run(run func(ctx context.Context, req *types.DataflowArgoReq)) *MockDeployer_DeleteDataflowJob_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.DataflowArgoReq)) + }) + return _c +} + +func (_c *MockDeployer_DeleteDataflowJob_Call) Return(_a0 error) *MockDeployer_DeleteDataflowJob_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockDeployer_DeleteDataflowJob_Call) RunAndReturn(run func(context.Context, *types.DataflowArgoReq) error) *MockDeployer_DeleteDataflowJob_Call { + _c.Call.Return(run) + return _c +} + // DeleteEvaluation provides a mock function with given fields: ctx, req func (_m *MockDeployer) DeleteEvaluation(ctx context.Context, req types.ArgoWorkFlowDeleteReq) error { ret := _m.Called(ctx, req) @@ -601,9 +707,9 @@ func (_c *MockDeployer_GetSharedModeResourceName_Call) RunAndReturn(run func(*co return _c } -// GetWorkflowLogsInStream provides a mock function with given fields: ctx, req -func (_m *MockDeployer) GetWorkflowLogsInStream(ctx context.Context, req types.FinetuneLogReq) (*deploy.MultiLogReader, error) { - ret := _m.Called(ctx, req) +// 
GetWorkflowLogsInStream provides a mock function with given fields: ctx, req, labels +func (_m *MockDeployer) GetWorkflowLogsInStream(ctx context.Context, req types.WorkflowLogReq, labels map[string]string) (*deploy.MultiLogReader, error) { + ret := _m.Called(ctx, req, labels) if len(ret) == 0 { panic("no return value specified for GetWorkflowLogsInStream") @@ -611,19 +717,19 @@ func (_m *MockDeployer) GetWorkflowLogsInStream(ctx context.Context, req types.F var r0 *deploy.MultiLogReader var r1 error - if rf, ok := ret.Get(0).(func(context.Context, types.FinetuneLogReq) (*deploy.MultiLogReader, error)); ok { - return rf(ctx, req) + if rf, ok := ret.Get(0).(func(context.Context, types.WorkflowLogReq, map[string]string) (*deploy.MultiLogReader, error)); ok { + return rf(ctx, req, labels) } - if rf, ok := ret.Get(0).(func(context.Context, types.FinetuneLogReq) *deploy.MultiLogReader); ok { - r0 = rf(ctx, req) + if rf, ok := ret.Get(0).(func(context.Context, types.WorkflowLogReq, map[string]string) *deploy.MultiLogReader); ok { + r0 = rf(ctx, req, labels) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*deploy.MultiLogReader) } } - if rf, ok := ret.Get(1).(func(context.Context, types.FinetuneLogReq) error); ok { - r1 = rf(ctx, req) + if rf, ok := ret.Get(1).(func(context.Context, types.WorkflowLogReq, map[string]string) error); ok { + r1 = rf(ctx, req, labels) } else { r1 = ret.Error(1) } @@ -638,14 +744,15 @@ type MockDeployer_GetWorkflowLogsInStream_Call struct { // GetWorkflowLogsInStream is a helper method to define mock.On call // - ctx context.Context -// - req types.FinetuneLogReq -func (_e *MockDeployer_Expecter) GetWorkflowLogsInStream(ctx interface{}, req interface{}) *MockDeployer_GetWorkflowLogsInStream_Call { - return &MockDeployer_GetWorkflowLogsInStream_Call{Call: _e.mock.On("GetWorkflowLogsInStream", ctx, req)} +// - req types.WorkflowLogReq +// - labels map[string]string +func (_e *MockDeployer_Expecter) GetWorkflowLogsInStream(ctx interface{}, req 
interface{}, labels interface{}) *MockDeployer_GetWorkflowLogsInStream_Call { + return &MockDeployer_GetWorkflowLogsInStream_Call{Call: _e.mock.On("GetWorkflowLogsInStream", ctx, req, labels)} } -func (_c *MockDeployer_GetWorkflowLogsInStream_Call) Run(run func(ctx context.Context, req types.FinetuneLogReq)) *MockDeployer_GetWorkflowLogsInStream_Call { +func (_c *MockDeployer_GetWorkflowLogsInStream_Call) Run(run func(ctx context.Context, req types.WorkflowLogReq, labels map[string]string)) *MockDeployer_GetWorkflowLogsInStream_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(types.FinetuneLogReq)) + run(args[0].(context.Context), args[1].(types.WorkflowLogReq), args[2].(map[string]string)) }) return _c } @@ -655,14 +762,14 @@ func (_c *MockDeployer_GetWorkflowLogsInStream_Call) Return(_a0 *deploy.MultiLog return _c } -func (_c *MockDeployer_GetWorkflowLogsInStream_Call) RunAndReturn(run func(context.Context, types.FinetuneLogReq) (*deploy.MultiLogReader, error)) *MockDeployer_GetWorkflowLogsInStream_Call { +func (_c *MockDeployer_GetWorkflowLogsInStream_Call) RunAndReturn(run func(context.Context, types.WorkflowLogReq, map[string]string) (*deploy.MultiLogReader, error)) *MockDeployer_GetWorkflowLogsInStream_Call { _c.Call.Return(run) return _c } -// GetWorkflowLogsNonStream provides a mock function with given fields: ctx, req -func (_m *MockDeployer) GetWorkflowLogsNonStream(ctx context.Context, req types.FinetuneLogReq) (*loki.LokiQueryResponse, error) { - ret := _m.Called(ctx, req) +// GetWorkflowLogsNonStream provides a mock function with given fields: ctx, req, labels +func (_m *MockDeployer) GetWorkflowLogsNonStream(ctx context.Context, req types.WorkflowLogReq, labels map[string]string) (*loki.LokiQueryResponse, error) { + ret := _m.Called(ctx, req, labels) if len(ret) == 0 { panic("no return value specified for GetWorkflowLogsNonStream") @@ -670,19 +777,19 @@ func (_m *MockDeployer) GetWorkflowLogsNonStream(ctx 
context.Context, req types. var r0 *loki.LokiQueryResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context, types.FinetuneLogReq) (*loki.LokiQueryResponse, error)); ok { - return rf(ctx, req) + if rf, ok := ret.Get(0).(func(context.Context, types.WorkflowLogReq, map[string]string) (*loki.LokiQueryResponse, error)); ok { + return rf(ctx, req, labels) } - if rf, ok := ret.Get(0).(func(context.Context, types.FinetuneLogReq) *loki.LokiQueryResponse); ok { - r0 = rf(ctx, req) + if rf, ok := ret.Get(0).(func(context.Context, types.WorkflowLogReq, map[string]string) *loki.LokiQueryResponse); ok { + r0 = rf(ctx, req, labels) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*loki.LokiQueryResponse) } } - if rf, ok := ret.Get(1).(func(context.Context, types.FinetuneLogReq) error); ok { - r1 = rf(ctx, req) + if rf, ok := ret.Get(1).(func(context.Context, types.WorkflowLogReq, map[string]string) error); ok { + r1 = rf(ctx, req, labels) } else { r1 = ret.Error(1) } @@ -697,14 +804,15 @@ type MockDeployer_GetWorkflowLogsNonStream_Call struct { // GetWorkflowLogsNonStream is a helper method to define mock.On call // - ctx context.Context -// - req types.FinetuneLogReq -func (_e *MockDeployer_Expecter) GetWorkflowLogsNonStream(ctx interface{}, req interface{}) *MockDeployer_GetWorkflowLogsNonStream_Call { - return &MockDeployer_GetWorkflowLogsNonStream_Call{Call: _e.mock.On("GetWorkflowLogsNonStream", ctx, req)} +// - req types.WorkflowLogReq +// - labels map[string]string +func (_e *MockDeployer_Expecter) GetWorkflowLogsNonStream(ctx interface{}, req interface{}, labels interface{}) *MockDeployer_GetWorkflowLogsNonStream_Call { + return &MockDeployer_GetWorkflowLogsNonStream_Call{Call: _e.mock.On("GetWorkflowLogsNonStream", ctx, req, labels)} } -func (_c *MockDeployer_GetWorkflowLogsNonStream_Call) Run(run func(ctx context.Context, req types.FinetuneLogReq)) *MockDeployer_GetWorkflowLogsNonStream_Call { +func (_c *MockDeployer_GetWorkflowLogsNonStream_Call) Run(run 
func(ctx context.Context, req types.WorkflowLogReq, labels map[string]string)) *MockDeployer_GetWorkflowLogsNonStream_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(types.FinetuneLogReq)) + run(args[0].(context.Context), args[1].(types.WorkflowLogReq), args[2].(map[string]string)) }) return _c } @@ -714,7 +822,7 @@ func (_c *MockDeployer_GetWorkflowLogsNonStream_Call) Return(_a0 *loki.LokiQuery return _c } -func (_c *MockDeployer_GetWorkflowLogsNonStream_Call) RunAndReturn(run func(context.Context, types.FinetuneLogReq) (*loki.LokiQueryResponse, error)) *MockDeployer_GetWorkflowLogsNonStream_Call { +func (_c *MockDeployer_GetWorkflowLogsNonStream_Call) RunAndReturn(run func(context.Context, types.WorkflowLogReq, map[string]string) (*loki.LokiQueryResponse, error)) *MockDeployer_GetWorkflowLogsNonStream_Call { _c.Call.Return(run) return _c } diff --git a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_ArgoWorkFlowStore.go b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_ArgoWorkFlowStore.go index 116c6d312..fedac79ba 100644 --- a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_ArgoWorkFlowStore.go +++ b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_ArgoWorkFlowStore.go @@ -188,22 +188,24 @@ func (_c *MockArgoWorkFlowStore_FindByID_Call) RunAndReturn(run func(context.Con } // FindByTaskID provides a mock function with given fields: ctx, id -func (_m *MockArgoWorkFlowStore) FindByTaskID(ctx context.Context, id string) (database.ArgoWorkflow, error) { +func (_m *MockArgoWorkFlowStore) FindByTaskID(ctx context.Context, id string) (*database.ArgoWorkflow, error) { ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for FindByTaskID") } - var r0 database.ArgoWorkflow + var r0 *database.ArgoWorkflow var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (database.ArgoWorkflow, error)); ok { + if rf, ok := 
ret.Get(0).(func(context.Context, string) (*database.ArgoWorkflow, error)); ok { return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(context.Context, string) database.ArgoWorkflow); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) *database.ArgoWorkflow); ok { r0 = rf(ctx, id) } else { - r0 = ret.Get(0).(database.ArgoWorkflow) + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.ArgoWorkflow) + } } if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { @@ -234,12 +236,12 @@ func (_c *MockArgoWorkFlowStore_FindByTaskID_Call) Run(run func(ctx context.Cont return _c } -func (_c *MockArgoWorkFlowStore_FindByTaskID_Call) Return(WorkFlow database.ArgoWorkflow, err error) *MockArgoWorkFlowStore_FindByTaskID_Call { - _c.Call.Return(WorkFlow, err) +func (_c *MockArgoWorkFlowStore_FindByTaskID_Call) Return(_a0 *database.ArgoWorkflow, _a1 error) *MockArgoWorkFlowStore_FindByTaskID_Call { + _c.Call.Return(_a0, _a1) return _c } -func (_c *MockArgoWorkFlowStore_FindByTaskID_Call) RunAndReturn(run func(context.Context, string) (database.ArgoWorkflow, error)) *MockArgoWorkFlowStore_FindByTaskID_Call { +func (_c *MockArgoWorkFlowStore_FindByTaskID_Call) RunAndReturn(run func(context.Context, string) (*database.ArgoWorkflow, error)) *MockArgoWorkFlowStore_FindByTaskID_Call { _c.Call.Return(run) return _c } diff --git a/_mocks/opencsg.com/csghub-server/component/mock_PlatformDataflowComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_PlatformDataflowComponent.go new file mode 100644 index 000000000..cc05397f4 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/component/mock_PlatformDataflowComponent.go @@ -0,0 +1,377 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. 
+ +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + deploy "opencsg.com/csghub-server/builder/deploy" + + types "opencsg.com/csghub-server/common/types" +) + +// MockPlatformDataflowComponent is an autogenerated mock type for the PlatformDataflowComponent type +type MockPlatformDataflowComponent struct { + mock.Mock +} + +type MockPlatformDataflowComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockPlatformDataflowComponent) EXPECT() *MockPlatformDataflowComponent_Expecter { + return &MockPlatformDataflowComponent_Expecter{mock: &_m.Mock} +} + +// CheckUserPermission provides a mock function with given fields: ctx, req +func (_m *MockPlatformDataflowComponent) CheckUserPermission(ctx context.Context, req types.DataflowLogReq) (bool, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CheckUserPermission") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.DataflowLogReq) (bool, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.DataflowLogReq) bool); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, types.DataflowLogReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPlatformDataflowComponent_CheckUserPermission_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckUserPermission' +type MockPlatformDataflowComponent_CheckUserPermission_Call struct { + *mock.Call +} + +// CheckUserPermission is a helper method to define mock.On call +// - ctx context.Context +// - req types.DataflowLogReq +func (_e *MockPlatformDataflowComponent_Expecter) CheckUserPermission(ctx interface{}, req interface{}) *MockPlatformDataflowComponent_CheckUserPermission_Call { + return &MockPlatformDataflowComponent_CheckUserPermission_Call{Call: 
_e.mock.On("CheckUserPermission", ctx, req)} +} + +func (_c *MockPlatformDataflowComponent_CheckUserPermission_Call) Run(run func(ctx context.Context, req types.DataflowLogReq)) *MockPlatformDataflowComponent_CheckUserPermission_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.DataflowLogReq)) + }) + return _c +} + +func (_c *MockPlatformDataflowComponent_CheckUserPermission_Call) Return(_a0 bool, _a1 error) *MockPlatformDataflowComponent_CheckUserPermission_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPlatformDataflowComponent_CheckUserPermission_Call) RunAndReturn(run func(context.Context, types.DataflowLogReq) (bool, error)) *MockPlatformDataflowComponent_CheckUserPermission_Call { + _c.Call.Return(run) + return _c +} + +// CreateJob provides a mock function with given fields: ctx, req +func (_m *MockPlatformDataflowComponent) CreateJob(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateJob") + } + + var r0 *types.DataflowArgoJobResp + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowArgoJobReq) *types.DataflowArgoJobResp); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.DataflowArgoJobResp) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.DataflowArgoJobReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPlatformDataflowComponent_CreateJob_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateJob' +type MockPlatformDataflowComponent_CreateJob_Call struct { + *mock.Call +} + +// CreateJob is a helper method to define mock.On call +// - ctx 
context.Context +// - req *types.DataflowArgoJobReq +func (_e *MockPlatformDataflowComponent_Expecter) CreateJob(ctx interface{}, req interface{}) *MockPlatformDataflowComponent_CreateJob_Call { + return &MockPlatformDataflowComponent_CreateJob_Call{Call: _e.mock.On("CreateJob", ctx, req)} +} + +func (_c *MockPlatformDataflowComponent_CreateJob_Call) Run(run func(ctx context.Context, req *types.DataflowArgoJobReq)) *MockPlatformDataflowComponent_CreateJob_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.DataflowArgoJobReq)) + }) + return _c +} + +func (_c *MockPlatformDataflowComponent_CreateJob_Call) Return(_a0 *types.DataflowArgoJobResp, _a1 error) *MockPlatformDataflowComponent_CreateJob_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPlatformDataflowComponent_CreateJob_Call) RunAndReturn(run func(context.Context, *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error)) *MockPlatformDataflowComponent_CreateJob_Call { + _c.Call.Return(run) + return _c +} + +// DeleteJob provides a mock function with given fields: ctx, req +func (_m *MockPlatformDataflowComponent) DeleteJob(ctx context.Context, req *types.DataflowDeleteReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for DeleteJob") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowDeleteReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockPlatformDataflowComponent_DeleteJob_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteJob' +type MockPlatformDataflowComponent_DeleteJob_Call struct { + *mock.Call +} + +// DeleteJob is a helper method to define mock.On call +// - ctx context.Context +// - req *types.DataflowDeleteReq +func (_e *MockPlatformDataflowComponent_Expecter) DeleteJob(ctx interface{}, req interface{}) *MockPlatformDataflowComponent_DeleteJob_Call { + 
return &MockPlatformDataflowComponent_DeleteJob_Call{Call: _e.mock.On("DeleteJob", ctx, req)} +} + +func (_c *MockPlatformDataflowComponent_DeleteJob_Call) Run(run func(ctx context.Context, req *types.DataflowDeleteReq)) *MockPlatformDataflowComponent_DeleteJob_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.DataflowDeleteReq)) + }) + return _c +} + +func (_c *MockPlatformDataflowComponent_DeleteJob_Call) Return(_a0 error) *MockPlatformDataflowComponent_DeleteJob_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPlatformDataflowComponent_DeleteJob_Call) RunAndReturn(run func(context.Context, *types.DataflowDeleteReq) error) *MockPlatformDataflowComponent_DeleteJob_Call { + _c.Call.Return(run) + return _c +} + +// GetJob provides a mock function with given fields: ctx, req +func (_m *MockPlatformDataflowComponent) GetJob(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for GetJob") + } + + var r0 *types.DataflowArgoJobResp + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowArgoJobReq) *types.DataflowArgoJobResp); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.DataflowArgoJobResp) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.DataflowArgoJobReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPlatformDataflowComponent_GetJob_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetJob' +type MockPlatformDataflowComponent_GetJob_Call struct { + *mock.Call +} + +// GetJob is a helper method to define mock.On call +// - ctx context.Context +// - req 
*types.DataflowArgoJobReq +func (_e *MockPlatformDataflowComponent_Expecter) GetJob(ctx interface{}, req interface{}) *MockPlatformDataflowComponent_GetJob_Call { + return &MockPlatformDataflowComponent_GetJob_Call{Call: _e.mock.On("GetJob", ctx, req)} +} + +func (_c *MockPlatformDataflowComponent_GetJob_Call) Run(run func(ctx context.Context, req *types.DataflowArgoJobReq)) *MockPlatformDataflowComponent_GetJob_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.DataflowArgoJobReq)) + }) + return _c +} + +func (_c *MockPlatformDataflowComponent_GetJob_Call) Return(_a0 *types.DataflowArgoJobResp, _a1 error) *MockPlatformDataflowComponent_GetJob_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPlatformDataflowComponent_GetJob_Call) RunAndReturn(run func(context.Context, *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error)) *MockPlatformDataflowComponent_GetJob_Call { + _c.Call.Return(run) + return _c +} + +// ReadJobLogsInStream provides a mock function with given fields: ctx, req +func (_m *MockPlatformDataflowComponent) ReadJobLogsInStream(ctx context.Context, req types.DataflowLogReq) (*deploy.MultiLogReader, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for ReadJobLogsInStream") + } + + var r0 *deploy.MultiLogReader + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.DataflowLogReq) (*deploy.MultiLogReader, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.DataflowLogReq) *deploy.MultiLogReader); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*deploy.MultiLogReader) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.DataflowLogReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPlatformDataflowComponent_ReadJobLogsInStream_Call is a *mock.Call that shadows Run/Return methods with 
type explicit version for method 'ReadJobLogsInStream' +type MockPlatformDataflowComponent_ReadJobLogsInStream_Call struct { + *mock.Call +} + +// ReadJobLogsInStream is a helper method to define mock.On call +// - ctx context.Context +// - req types.DataflowLogReq +func (_e *MockPlatformDataflowComponent_Expecter) ReadJobLogsInStream(ctx interface{}, req interface{}) *MockPlatformDataflowComponent_ReadJobLogsInStream_Call { + return &MockPlatformDataflowComponent_ReadJobLogsInStream_Call{Call: _e.mock.On("ReadJobLogsInStream", ctx, req)} +} + +func (_c *MockPlatformDataflowComponent_ReadJobLogsInStream_Call) Run(run func(ctx context.Context, req types.DataflowLogReq)) *MockPlatformDataflowComponent_ReadJobLogsInStream_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.DataflowLogReq)) + }) + return _c +} + +func (_c *MockPlatformDataflowComponent_ReadJobLogsInStream_Call) Return(_a0 *deploy.MultiLogReader, _a1 error) *MockPlatformDataflowComponent_ReadJobLogsInStream_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPlatformDataflowComponent_ReadJobLogsInStream_Call) RunAndReturn(run func(context.Context, types.DataflowLogReq) (*deploy.MultiLogReader, error)) *MockPlatformDataflowComponent_ReadJobLogsInStream_Call { + _c.Call.Return(run) + return _c +} + +// ReadJobLogsNonStream provides a mock function with given fields: ctx, req +func (_m *MockPlatformDataflowComponent) ReadJobLogsNonStream(ctx context.Context, req types.DataflowLogReq) (string, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for ReadJobLogsNonStream") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.DataflowLogReq) (string, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, types.DataflowLogReq) string); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := 
ret.Get(1).(func(context.Context, types.DataflowLogReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPlatformDataflowComponent_ReadJobLogsNonStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadJobLogsNonStream' +type MockPlatformDataflowComponent_ReadJobLogsNonStream_Call struct { + *mock.Call +} + +// ReadJobLogsNonStream is a helper method to define mock.On call +// - ctx context.Context +// - req types.DataflowLogReq +func (_e *MockPlatformDataflowComponent_Expecter) ReadJobLogsNonStream(ctx interface{}, req interface{}) *MockPlatformDataflowComponent_ReadJobLogsNonStream_Call { + return &MockPlatformDataflowComponent_ReadJobLogsNonStream_Call{Call: _e.mock.On("ReadJobLogsNonStream", ctx, req)} +} + +func (_c *MockPlatformDataflowComponent_ReadJobLogsNonStream_Call) Run(run func(ctx context.Context, req types.DataflowLogReq)) *MockPlatformDataflowComponent_ReadJobLogsNonStream_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.DataflowLogReq)) + }) + return _c +} + +func (_c *MockPlatformDataflowComponent_ReadJobLogsNonStream_Call) Return(_a0 string, _a1 error) *MockPlatformDataflowComponent_ReadJobLogsNonStream_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPlatformDataflowComponent_ReadJobLogsNonStream_Call) RunAndReturn(run func(context.Context, types.DataflowLogReq) (string, error)) *MockPlatformDataflowComponent_ReadJobLogsNonStream_Call { + _c.Call.Return(run) + return _c +} + +// NewMockPlatformDataflowComponent creates a new instance of MockPlatformDataflowComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockPlatformDataflowComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockPlatformDataflowComponent { + mock := &MockPlatformDataflowComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/runner/component/mock_DataflowComponent.go b/_mocks/opencsg.com/csghub-server/runner/component/mock_DataflowComponent.go new file mode 100644 index 000000000..e1c35d50f --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/runner/component/mock_DataflowComponent.go @@ -0,0 +1,234 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + types "opencsg.com/csghub-server/common/types" +) + +// MockDataflowComponent is an autogenerated mock type for the DataflowComponent type +type MockDataflowComponent struct { + mock.Mock +} + +type MockDataflowComponent_Expecter struct { + mock *mock.Mock +} + +func (_m *MockDataflowComponent) EXPECT() *MockDataflowComponent_Expecter { + return &MockDataflowComponent_Expecter{mock: &_m.Mock} +} + +// CreateWorkflow provides a mock function with given fields: ctx, req +func (_m *MockDataflowComponent) CreateWorkflow(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateWorkflow") + } + + var r0 *types.DataflowArgoJobResp + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowArgoJobReq) *types.DataflowArgoJobResp); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.DataflowArgoJobResp) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.DataflowArgoJobReq) error); ok { + r1 = rf(ctx, req) + } 
else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataflowComponent_CreateWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateWorkflow' +type MockDataflowComponent_CreateWorkflow_Call struct { + *mock.Call +} + +// CreateWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - req *types.DataflowArgoJobReq +func (_e *MockDataflowComponent_Expecter) CreateWorkflow(ctx interface{}, req interface{}) *MockDataflowComponent_CreateWorkflow_Call { + return &MockDataflowComponent_CreateWorkflow_Call{Call: _e.mock.On("CreateWorkflow", ctx, req)} +} + +func (_c *MockDataflowComponent_CreateWorkflow_Call) Run(run func(ctx context.Context, req *types.DataflowArgoJobReq)) *MockDataflowComponent_CreateWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.DataflowArgoJobReq)) + }) + return _c +} + +func (_c *MockDataflowComponent_CreateWorkflow_Call) Return(_a0 *types.DataflowArgoJobResp, _a1 error) *MockDataflowComponent_CreateWorkflow_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataflowComponent_CreateWorkflow_Call) RunAndReturn(run func(context.Context, *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error)) *MockDataflowComponent_CreateWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// DeleteWorkflow provides a mock function with given fields: ctx, req +func (_m *MockDataflowComponent) DeleteWorkflow(ctx context.Context, req *types.DataflowArgoReq) error { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for DeleteWorkflow") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowArgoReq) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockDataflowComponent_DeleteWorkflow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteWorkflow' +type 
MockDataflowComponent_DeleteWorkflow_Call struct { + *mock.Call +} + +// DeleteWorkflow is a helper method to define mock.On call +// - ctx context.Context +// - req *types.DataflowArgoReq +func (_e *MockDataflowComponent_Expecter) DeleteWorkflow(ctx interface{}, req interface{}) *MockDataflowComponent_DeleteWorkflow_Call { + return &MockDataflowComponent_DeleteWorkflow_Call{Call: _e.mock.On("DeleteWorkflow", ctx, req)} +} + +func (_c *MockDataflowComponent_DeleteWorkflow_Call) Run(run func(ctx context.Context, req *types.DataflowArgoReq)) *MockDataflowComponent_DeleteWorkflow_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.DataflowArgoReq)) + }) + return _c +} + +func (_c *MockDataflowComponent_DeleteWorkflow_Call) Return(_a0 error) *MockDataflowComponent_DeleteWorkflow_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockDataflowComponent_DeleteWorkflow_Call) RunAndReturn(run func(context.Context, *types.DataflowArgoReq) error) *MockDataflowComponent_DeleteWorkflow_Call { + _c.Call.Return(run) + return _c +} + +// GetStatus provides a mock function with given fields: ctx, req +func (_m *MockDataflowComponent) GetStatus(ctx context.Context, req *types.DataflowArgoReq) (*types.DataflowArgoJobResp, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for GetStatus") + } + + var r0 *types.DataflowArgoJobResp + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowArgoReq) (*types.DataflowArgoJobResp, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *types.DataflowArgoReq) *types.DataflowArgoJobResp); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.DataflowArgoJobResp) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *types.DataflowArgoReq) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// 
MockDataflowComponent_GetStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatus' +type MockDataflowComponent_GetStatus_Call struct { + *mock.Call +} + +// GetStatus is a helper method to define mock.On call +// - ctx context.Context +// - req *types.DataflowArgoReq +func (_e *MockDataflowComponent_Expecter) GetStatus(ctx interface{}, req interface{}) *MockDataflowComponent_GetStatus_Call { + return &MockDataflowComponent_GetStatus_Call{Call: _e.mock.On("GetStatus", ctx, req)} +} + +func (_c *MockDataflowComponent_GetStatus_Call) Run(run func(ctx context.Context, req *types.DataflowArgoReq)) *MockDataflowComponent_GetStatus_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*types.DataflowArgoReq)) + }) + return _c +} + +func (_c *MockDataflowComponent_GetStatus_Call) Return(_a0 *types.DataflowArgoJobResp, _a1 error) *MockDataflowComponent_GetStatus_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataflowComponent_GetStatus_Call) RunAndReturn(run func(context.Context, *types.DataflowArgoReq) (*types.DataflowArgoJobResp, error)) *MockDataflowComponent_GetStatus_Call { + _c.Call.Return(run) + return _c +} + +// RunInformer provides a mock function with no fields +func (_m *MockDataflowComponent) RunInformer() { + _m.Called() +} + +// MockDataflowComponent_RunInformer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RunInformer' +type MockDataflowComponent_RunInformer_Call struct { + *mock.Call +} + +// RunInformer is a helper method to define mock.On call +func (_e *MockDataflowComponent_Expecter) RunInformer() *MockDataflowComponent_RunInformer_Call { + return &MockDataflowComponent_RunInformer_Call{Call: _e.mock.On("RunInformer")} +} + +func (_c *MockDataflowComponent_RunInformer_Call) Run(run func()) *MockDataflowComponent_RunInformer_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + 
+func (_c *MockDataflowComponent_RunInformer_Call) Return() *MockDataflowComponent_RunInformer_Call { + _c.Call.Return() + return _c +} + +func (_c *MockDataflowComponent_RunInformer_Call) RunAndReturn(run func()) *MockDataflowComponent_RunInformer_Call { + _c.Run(run) + return _c +} + +// NewMockDataflowComponent creates a new instance of MockDataflowComponent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockDataflowComponent(t interface { + mock.TestingT + Cleanup(func()) +}) *MockDataflowComponent { + mock := &MockDataflowComponent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/api/handler/platform_dataflow.go b/api/handler/platform_dataflow.go new file mode 100644 index 000000000..43769f7e1 --- /dev/null +++ b/api/handler/platform_dataflow.go @@ -0,0 +1,220 @@ +package handler + +import ( + "errors" + "fmt" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "opencsg.com/csghub-server/api/httpbase" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" + "opencsg.com/csghub-server/component" +) + +type PlatformDataflowHandler struct { + component component.PlatformDataflowComponent +} + +func NewPlatformDataflowHandler(config *config.Config) (*PlatformDataflowHandler, error) { + c, err := component.NewPlatformDataflowComponent(config) + if err != nil { + return nil, err + } + return &PlatformDataflowHandler{ + component: c, + }, nil +} + +func (h *PlatformDataflowHandler) CreateJob(ctx *gin.Context) { + currentUserName := httpbase.GetCurrentUser(ctx) + currentUserUUID := httpbase.GetCurrentUserUUID(ctx) + nsUUID := ctx.Param("uuid") + + if len(nsUUID) < 1 { + err := fmt.Errorf("ns_uuid is required") + httpbase.BadRequest(ctx, err.Error()) + return + } + + var req types.DataflowArgoJobReq + if err := 
ctx.ShouldBindJSON(&req); err != nil { + slog.ErrorContext(ctx.Request.Context(), "failed to bind request body", slog.Any("error", err)) + httpbase.BadRequest(ctx, err.Error()) + return + } + + req.OpUserUUID = currentUserUUID + req.Username = currentUserName + req.NSUUID = nsUUID + + resp, err := h.component.CreateJob(ctx.Request.Context(), &req) + if err != nil { + slog.ErrorContext(ctx.Request.Context(), "failed to create dataflow workflow job", + slog.Any("error", err), slog.Any("req", req)) + httpbase.ServerError(ctx, err) + return + } + + httpbase.Created(ctx, resp) +} + +func (h *PlatformDataflowHandler) DeleteJob(ctx *gin.Context) { + currentUserName := httpbase.GetCurrentUser(ctx) + currentUserUUID := httpbase.GetCurrentUserUUID(ctx) + taskID := ctx.Param("task_id") + nsUUID := ctx.Param("uuid") + if len(taskID) < 1 { + err := fmt.Errorf("task_id is required") + httpbase.BadRequest(ctx, err.Error()) + return + } + + if len(nsUUID) < 1 { + err := fmt.Errorf("ns_uuid is required") + httpbase.BadRequest(ctx, err.Error()) + return + } + + req := &types.DataflowDeleteReq{ + OpUserUUID: currentUserUUID, + Username: currentUserName, + ArgoTaskID: taskID, + NSUUID: nsUUID, + } + + err := h.component.DeleteJob(ctx.Request.Context(), req) + if err != nil { + slog.ErrorContext(ctx.Request.Context(), "failed to delete dataflow workflow job", + slog.Any("error", err), slog.String("taskid", taskID), slog.String("nsuuid", nsUUID)) + httpbase.ServerError(ctx, err) + return + } + + httpbase.OK(ctx, nil) +} + +func (h *PlatformDataflowHandler) GetJob(ctx *gin.Context) { + currentUserName := httpbase.GetCurrentUser(ctx) + currentUserUUID := httpbase.GetCurrentUserUUID(ctx) + taskID := ctx.Param("task_id") + nsUUID := ctx.Param("uuid") + + if len(taskID) < 1 { + err := fmt.Errorf("task_id is required") + httpbase.BadRequest(ctx, err.Error()) + return + } + + if len(nsUUID) < 1 { + err := fmt.Errorf("ns_uuid is required") + httpbase.BadRequest(ctx, err.Error()) + return + } + 
+ req := &types.DataflowArgoJobReq{ + OpUserUUID: currentUserUUID, + Username: currentUserName, + NSUUID: nsUUID, + ArgoTaskID: taskID, + } + + resp, err := h.component.GetJob(ctx.Request.Context(), req) + if err != nil { + slog.ErrorContext(ctx.Request.Context(), "failed to get dataflow workflow job", + slog.Any("error", err), slog.String("taskid", taskID)) + httpbase.ServerError(ctx, err) + return + } + + httpbase.OK(ctx, resp) +} + +func (h *PlatformDataflowHandler) GetLogs(ctx *gin.Context) { + currentUser := httpbase.GetCurrentUser(ctx) + + since := ctx.Query("since") + stream := ctx.Query("stream") + taskID := ctx.Param("task_id") + dagTaskID := ctx.Query("dag_task_id") + + if len(taskID) < 1 { + err := fmt.Errorf("task_id is required") + httpbase.BadRequest(ctx, err.Error()) + return + } + + req := types.DataflowLogReq{ + CurrentUser: currentUser, + Since: since, + TaskId: taskID, + DagTaskId: dagTaskID, + } + + allow, err := h.component.CheckUserPermission(ctx.Request.Context(), req) + if !allow { + slog.ErrorContext(ctx.Request.Context(), "user not allowed to read dataflow job logs", slog.Any("error", err), slog.Any("req", req)) + httpbase.ForbiddenError(ctx, errors.New("user not allowed to read dataflow job logs")) + return + } + + if strings.Trim(stream, " ") == "true" { + h.readLogInStream(ctx, req) + } else { + h.readLogNonStream(ctx, req) + } +} + +func (h *PlatformDataflowHandler) readLogNonStream(ctx *gin.Context, req types.DataflowLogReq) { + logs, err := h.component.ReadJobLogsNonStream(ctx.Request.Context(), req) + if err != nil { + slog.ErrorContext(ctx.Request.Context(), "failed to get dataflow job non-stream logs", slog.Any("error", err), slog.Any("req", req)) + httpbase.ServerError(ctx, err) + return + } + httpbase.OK(ctx, logs) +} + +func (h *PlatformDataflowHandler) readLogInStream(ctx *gin.Context, req types.DataflowLogReq) { + logReader, err := h.component.ReadJobLogsInStream(ctx.Request.Context(), req) + if err != nil { + slog.ErrorContext(ctx.Request.Context(), "failed to get dataflow job in-stream logs", slog.Any("error",
err), slog.Any("req", req)) + httpbase.ServerError(ctx, err) + return + } + + if logReader.RunLog() == nil { + httpbase.ServerError(ctx, errors.New("don't find any dataflow job log")) + return + } + + ctx.Writer.Header().Set("Content-Type", "text/event-stream") + ctx.Writer.Header().Set("Cache-Control", "no-cache") + ctx.Writer.Header().Set("Connection", "keep-alive") + ctx.Writer.Header().Set("Transfer-Encoding", "chunked") + + ctx.Writer.WriteHeader(http.StatusOK) + ctx.Writer.Flush() + + heartbeatTicker := time.NewTicker(30 * time.Second) + defer heartbeatTicker.Stop() + for { + select { + case <-ctx.Request.Context().Done(): + return + case data, ok := <-logReader.RunLog(): + if !ok { + // log channel closed: the stream is finished, stop instead of spinning on a closed channel + return + } + ctx.SSEvent("Container", string(data)) + ctx.Writer.Flush() + case <-heartbeatTicker.C: + ctx.SSEvent("Heartbeat", "keep-alive") + ctx.Writer.Flush() + } + } +} diff --git a/api/handler/platform_dataflow_test.go b/api/handler/platform_dataflow_test.go new file mode 100644 index 000000000..6028ff75e --- /dev/null +++ b/api/handler/platform_dataflow_test.go @@ -0,0 +1,250 @@ +package handler + +import ( + "bytes" + "errors" + "io" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/mock" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/builder/testutil" + "opencsg.com/csghub-server/common/types" +) + +type PlatformDataflowTester struct { + *testutil.GinTester + handler *PlatformDataflowHandler + mockComp *mockcomponent.MockPlatformDataflowComponent +} + +func NewPlatformDataflowTester(t *testing.T) *PlatformDataflowTester { + tester := &PlatformDataflowTester{GinTester: testutil.NewGinTester()} + tester.mockComp = mockcomponent.NewMockPlatformDataflowComponent(t) + + tester.handler = &PlatformDataflowHandler{ + component: tester.mockComp, + } + tester.WithParam("uuid", "test-ns-uuid") + return tester +} + +func (t *PlatformDataflowTester)
WithHandleFunc(fn func(h *PlatformDataflowHandler) gin.HandlerFunc) *PlatformDataflowTester { + t.Handler(fn(t.handler)) + return t +} + +func (t *PlatformDataflowTester) WithUserAndUUID() *PlatformDataflowTester { + t.Gctx().Set("currentUser", "testuser") + t.Gctx().Set("currentUserUUID", "test-user-uuid") + return t +} + +func TestPlatformDataflowHandler_CreateJob(t *testing.T) { + t.Run("success", func(t *testing.T) { + tester := NewPlatformDataflowTester(t).WithHandleFunc(func(h *PlatformDataflowHandler) gin.HandlerFunc { + return h.CreateJob + }) + tester.WithUserAndUUID() + + req := &types.DataflowArgoJobReq{ + ResourceId: 1, + JobID: "job-1", + JobName: "test-job", + StorageSize: "10Gi", + Entrypoint: "python script.py", + DagTasks: []types.ArgoDagTask{{ID: "task1", Name: "task1", Template: "main"}}, + } + resp := &types.DataflowArgoJobResp{ + ID: 1, + ArgoTaskID: "argo-task-1", + JobID: "job-1", + JobName: "test-job", + Status: "Pending", + } + + tester.mockComp.EXPECT().CreateJob(tester.Ctx(), &types.DataflowArgoJobReq{ + OpUserUUID: "test-user-uuid", + Username: "testuser", + NSUUID: "test-ns-uuid", + ResourceId: 1, + JobID: "job-1", + JobName: "test-job", + StorageSize: "10Gi", + Entrypoint: "python script.py", + DagTasks: []types.ArgoDagTask{{ID: "task1", Name: "task1", Template: "main"}}, + }).Return(resp, nil) + + tester.WithBody(t, req).Execute() + tester.ResponseEq(t, 201, "Created", resp) + }) + + t.Run("missing_ns_uuid", func(t *testing.T) { + tester := NewPlatformDataflowTester(t).WithHandleFunc(func(h *PlatformDataflowHandler) gin.HandlerFunc { + return h.CreateJob + }) + tester.WithUserAndUUID() + tester.WithParam("uuid", "") + + req := &types.DataflowArgoJobReq{} + tester.WithBody(t, req).Execute() + + tester.ResponseEqCode(t, 400) + }) + + t.Run("invalid_request_body", func(t *testing.T) { + tester := NewPlatformDataflowTester(t).WithHandleFunc(func(h *PlatformDataflowHandler) gin.HandlerFunc { + return h.CreateJob + }) + 
tester.WithUserAndUUID() + + tester.Gctx().Request.Body = io.NopCloser(bytes.NewBuffer([]byte("{invalid json"))) + tester.Gctx().Request.Header = map[string][]string{"Content-Type": {"application/json"}} + + tester.Execute() + + tester.ResponseEqCode(t, 400) + }) + + t.Run("component_error", func(t *testing.T) { + tester := NewPlatformDataflowTester(t).WithHandleFunc(func(h *PlatformDataflowHandler) gin.HandlerFunc { + return h.CreateJob + }) + tester.WithUserAndUUID() + + req := &types.DataflowArgoJobReq{ + ResourceId: 1, + JobID: "job-1", + JobName: "test-job", + StorageSize: "10Gi", + Entrypoint: "python script.py", + DagTasks: []types.ArgoDagTask{{ID: "task1", Name: "task1", Template: "main"}}, + } + + tester.mockComp.EXPECT().CreateJob(tester.Ctx(), mock.Anything).Return(nil, errors.New("some error")) + + tester.WithBody(t, req).Execute() + tester.ResponseEqCode(t, 500) + }) +} + +func TestPlatformDataflowHandler_DeleteJob(t *testing.T) { + t.Run("success", func(t *testing.T) { + tester := NewPlatformDataflowTester(t).WithHandleFunc(func(h *PlatformDataflowHandler) gin.HandlerFunc { + return h.DeleteJob + }) + tester.WithUserAndUUID() + tester.WithParam("task_id", "argo-task-1") + + tester.mockComp.EXPECT().DeleteJob(tester.Ctx(), &types.DataflowDeleteReq{ + OpUserUUID: "test-user-uuid", + Username: "testuser", + NSUUID: "test-ns-uuid", + ArgoTaskID: "argo-task-1", + }).Return(nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) + }) + + t.Run("missing_task_id", func(t *testing.T) { + tester := NewPlatformDataflowTester(t).WithHandleFunc(func(h *PlatformDataflowHandler) gin.HandlerFunc { + return h.DeleteJob + }) + tester.WithUserAndUUID() + tester.WithParam("task_id", "") + + tester.Execute() + tester.ResponseEqCode(t, 400) + }) + + t.Run("missing_ns_uuid", func(t *testing.T) { + tester := NewPlatformDataflowTester(t).WithHandleFunc(func(h *PlatformDataflowHandler) gin.HandlerFunc { + return h.DeleteJob + }) + tester.WithUserAndUUID() + 
tester.WithParam("task_id", "argo-task-1") + tester.WithParam("uuid", "") + + tester.Execute() + tester.ResponseEqCode(t, 400) + }) + + t.Run("component_error", func(t *testing.T) { + tester := NewPlatformDataflowTester(t).WithHandleFunc(func(h *PlatformDataflowHandler) gin.HandlerFunc { + return h.DeleteJob + }) + tester.WithUserAndUUID() + tester.WithParam("task_id", "argo-task-1") + + tester.mockComp.EXPECT().DeleteJob(tester.Ctx(), mock.Anything).Return(errors.New("some error")) + + tester.Execute() + tester.ResponseEqCode(t, 500) + }) +} + +func TestPlatformDataflowHandler_GetJob(t *testing.T) { + t.Run("success", func(t *testing.T) { + tester := NewPlatformDataflowTester(t).WithHandleFunc(func(h *PlatformDataflowHandler) gin.HandlerFunc { + return h.GetJob + }) + tester.WithUserAndUUID() + tester.WithParam("task_id", "argo-task-1") + + resp := &types.DataflowArgoJobResp{ + ID: 1, + ArgoTaskID: "argo-task-1", + JobID: "job-1", + JobName: "test-job", + Status: "Running", + } + + tester.mockComp.EXPECT().GetJob(tester.Ctx(), &types.DataflowArgoJobReq{ + OpUserUUID: "test-user-uuid", + Username: "testuser", + NSUUID: "test-ns-uuid", + ArgoTaskID: "argo-task-1", + }).Return(resp, nil) + + tester.Execute() + tester.ResponseEq(t, 200, tester.OKText, resp) + }) + + t.Run("missing_task_id", func(t *testing.T) { + tester := NewPlatformDataflowTester(t).WithHandleFunc(func(h *PlatformDataflowHandler) gin.HandlerFunc { + return h.GetJob + }) + tester.WithUserAndUUID() + tester.WithParam("task_id", "") + + tester.Execute() + tester.ResponseEqCode(t, 400) + }) + + t.Run("missing_ns_uuid", func(t *testing.T) { + tester := NewPlatformDataflowTester(t).WithHandleFunc(func(h *PlatformDataflowHandler) gin.HandlerFunc { + return h.GetJob + }) + tester.WithUserAndUUID() + tester.WithParam("task_id", "argo-task-1") + tester.WithParam("uuid", "") + + tester.Execute() + tester.ResponseEqCode(t, 400) + }) + + t.Run("component_error", func(t *testing.T) { + tester := 
NewPlatformDataflowTester(t).WithHandleFunc(func(h *PlatformDataflowHandler) gin.HandlerFunc { + return h.GetJob + }) + tester.WithUserAndUUID() + tester.WithParam("task_id", "argo-task-1") + + tester.mockComp.EXPECT().GetJob(tester.Ctx(), mock.Anything).Return(nil, errors.New("some error")) + + tester.Execute() + tester.ResponseEqCode(t, 500) + }) +} diff --git a/api/router/api.go b/api/router/api.go index a1327c2ff..0341c8643 100644 --- a/api/router/api.go +++ b/api/router/api.go @@ -518,7 +518,12 @@ func NewRouter(config *config.Config, enableSwagger bool) (*gin.Engine, error) { if err != nil { return nil, fmt.Errorf("error creating data flow proxy handler:%w", err) } - createDataflowRoutes(apiGroup, dataflowHandler) + // platform dataflow + platformDFHandler, err := handler.NewPlatformDataflowHandler(config) + if err != nil { + return nil, fmt.Errorf("error creating platform dataflow handler:%w", err) + } + createDataflowRoutes(apiGroup, dataflowHandler, platformDFHandler) err = createAdvancedRoutes(apiGroup, adminGroup, middlewareCollection, config, mqFactory) if err != nil { @@ -1270,10 +1275,18 @@ func createPromptRoutes( promptInfoGrp.GET("/:namespace/:name", promptHandler.PromptDetail) } -func createDataflowRoutes(apiGroup *gin.RouterGroup, dataflowHandler *handler.DataflowProxyHandler) { +func createDataflowRoutes(apiGroup *gin.RouterGroup, dataflowHandler *handler.DataflowProxyHandler, platformDFHandler *handler.PlatformDataflowHandler) { dataflowGrp := apiGroup.Group("/dataflow") dataflowGrp.Use(middleware.MustLogin()) dataflowGrp.Any("/*any", dataflowHandler.Proxy) + + platformGrp := apiGroup.Group("/platform") + platformDFGrp := platformGrp.Group("/dataflow") + platformDFGrp.Use(middleware.MustLogin()) + platformDFGrp.POST("/:uuid/jobs", platformDFHandler.CreateJob) + platformDFGrp.DELETE("/:uuid/jobs/:task_id", platformDFHandler.DeleteJob) + platformDFGrp.GET("/:uuid/jobs/:task_id", platformDFHandler.GetJob) + 
platformDFGrp.GET("/:uuid/jobs/:task_id/logs", platformDFHandler.GetLogs) } func createMemoryRoutes(apiGroup *gin.RouterGroup, middlewareCollection middleware.MiddlewareCollection, memoryHandler *handler.MemoryHandler) { diff --git a/builder/deploy/cluster/cluster_manager.go b/builder/deploy/cluster/cluster_manager.go index b417c5126..d68bc9c0b 100644 --- a/builder/deploy/cluster/cluster_manager.go +++ b/builder/deploy/cluster/cluster_manager.go @@ -309,6 +309,8 @@ func buildCluster(kubeconfig *rest.Config, id string, index int, connectMode typ return nil, fmt.Errorf("failed to add cluster info to db error: %w", err) } if !cluster.Enable { + slog.Info("cluster is disabled, will not be used", slog.String("cluster_id", cluster.ClusterID), + slog.String("config", cluster.ClusterConfig), slog.String("region", region)) return nil, nil } @@ -620,7 +622,7 @@ func getXPULabel(labels map[string]string, config *config.Config) (string, strin return "enflame.com/gcu.count", "enflame.com/gcu.model", []string{"enflame.com/gcu.mem"} } if _, found := labels["amd.com/gpu"]; found { - //for enflame gcu + //for amd gpu return "amd.com/gpu", "amd.com/gpu.product-name", []string{"amd.com/gpu.vram"} } //check custom gpu model label diff --git a/builder/deploy/deployer.go b/builder/deploy/deployer.go index 644d4839f..48d6716f9 100644 --- a/builder/deploy/deployer.go +++ b/builder/deploy/deployer.go @@ -58,11 +58,13 @@ type Deployer interface { CheckHeartbeatTimeout(ctx context.Context, clusterId string) (bool, error) SubmitFinetuneJob(ctx context.Context, req types.FinetuneReq) (*types.ArgoWorkFlowRes, error) DeleteFinetuneJob(ctx context.Context, req types.ArgoWorkFlowDeleteReq) error - GetWorkflowLogsInStream(ctx context.Context, req types.FinetuneLogReq) (*MultiLogReader, error) - GetWorkflowLogsNonStream(ctx context.Context, req types.FinetuneLogReq) (*loki.LokiQueryResponse, error) + GetWorkflowLogsInStream(ctx context.Context, req types.WorkflowLogReq, labels map[string]string) 
(*MultiLogReader, error) + GetWorkflowLogsNonStream(ctx context.Context, req types.WorkflowLogReq, labels map[string]string) (*loki.LokiQueryResponse, error) IsDefaultScheduler() bool GetSharedModeResourceName(config *config.Config) string LabelNode(ctx context.Context, req *types.NodeLabel) error + CreateDataflowJob(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) + DeleteDataflowJob(ctx context.Context, req *types.DataflowArgoReq) error } func (d *deployer) GenerateUniqueSvcName(dr types.DeployRequest) string { @@ -992,11 +994,7 @@ func (d *deployer) DeleteFinetuneJob(ctx context.Context, req types.ArgoWorkFlow return nil } -func (d *deployer) GetWorkflowLogsInStream(ctx context.Context, req types.FinetuneLogReq) (*MultiLogReader, error) { - slog.Info("GetWorkflowLogsInStream", slog.Any("req", req)) - labels := map[string]string{ - types.StreamKeyInstanceName: req.PodName, - } +func (d *deployer) GetWorkflowLogsInStream(ctx context.Context, req types.WorkflowLogReq, labels map[string]string) (*MultiLogReader, error) { var startTime = req.SubmitTime if len(req.Since) > 0 { @@ -1016,10 +1014,8 @@ func (d *deployer) GetWorkflowLogsInStream(ctx context.Context, req types.Finetu return NewMultiLogReader(nil, runLog), nil } -func (d *deployer) GetWorkflowLogsNonStream(ctx context.Context, req types.FinetuneLogReq) (*loki.LokiQueryResponse, error) { - labels := map[string]string{ - types.StreamKeyInstanceName: req.PodName, - } +func (d *deployer) GetWorkflowLogsNonStream(ctx context.Context, req types.WorkflowLogReq, labels map[string]string) (*loki.LokiQueryResponse, error) { + query := d.lokiClient.GenerateLabelQuery(labels) var startTime = req.SubmitTime if req.Since != "" { diff --git a/builder/deploy/deployer_dataflow.go b/builder/deploy/deployer_dataflow.go new file mode 100644 index 000000000..215f96f04 --- /dev/null +++ b/builder/deploy/deployer_dataflow.go @@ -0,0 +1,51 @@ +package deploy + +import ( + "context" + "fmt" + 
+ "opencsg.com/csghub-server/common/types" +) + +func (d *deployer) CreateDataflowJob(ctx context.Context, req *types.DataflowCreateReq) (*types.DataflowArgoJobResp, error) { + runnerReq := &types.DataflowArgoJobReq{ + ID: req.ID, + ClusterID: req.ClusterID, + ArgoTaskID: req.ArgoTaskID, + ResourceName: req.ResourceName, + OpUserUUID: req.OpUserUUID, + Username: req.Username, + NSUUID: req.NSUUID, + // dataflow specific + RepoIds: req.RepoIds, + ResourceId: req.ResourceId, + JobID: req.JobID, + JobName: req.JobName, + JobDesc: req.JobDesc, + StorageSize: req.StorageSize, + Entrypoint: req.Entrypoint, + Template: req.Template, + DagTasks: req.DagTasks, + // extra + Nodes: req.Nodes, + Scheduler: d.kubeScheduler, + DeployExtend: types.DeployExtend{ + NodeAffinity: req.NodeAffinity, + Tolerations: req.Tolerations, + }, + } + + resp, err := d.imageRunner.CreateDataflowWorkflow(ctx, runnerReq) + if err != nil { + return nil, fmt.Errorf("failed to create dataflow job %s workflow error: %w", req.JobID, err) + } + return resp, nil +} + +func (d *deployer) DeleteDataflowJob(ctx context.Context, req *types.DataflowArgoReq) error { + err := d.imageRunner.DeleteDataflowWorkflow(ctx, req) + if err != nil { + return fmt.Errorf("failed to delete dataflow %s workflow error: %w", req.ArgoTaskID, err) + } + return nil +} diff --git a/builder/deploy/deployer_dataflow_test.go b/builder/deploy/deployer_dataflow_test.go new file mode 100644 index 000000000..f9d6d147a --- /dev/null +++ b/builder/deploy/deployer_dataflow_test.go @@ -0,0 +1,228 @@ +package deploy + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + mockrunner "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy/imagerunner" + "opencsg.com/csghub-server/common/types" +) + +func TestDeployer_CreateDataflowWorkflow(t *testing.T) { + ctx := context.TODO() + + t.Run("success", func(t *testing.T) { + req := &types.DataflowCreateReq{ 
+ ID: 1, + ClusterID: "cluster-1", + ArgoTaskID: "task-123", + ResourceName: "gpu-resource", + OpUserUUID: "user-uuid-1", + Username: "testuser", + NSUUID: "ns-uuid-1", + RepoIds: []string{"repo1", "repo2"}, + ResourceId: 100, + JobID: "job-1", + JobName: "test-job", + JobDesc: "test description", + StorageSize: "10Gi", + Entrypoint: "main.py", + Template: types.ArgoFlowTemplate{ + Name: "template-1", + }, + DagTasks: []types.ArgoDagTask{ + { + ID: "task-1", + Name: "dag-task-1", + }, + }, + Nodes: []types.Node{ + { + Name: "node-1", + }, + }, + DeployExtend: types.DeployExtend{ + NodeAffinity: nil, + Tolerations: nil, + }, + } + + expectedResp := &types.DataflowArgoJobResp{ + ID: 1, + ArgoTaskID: "task-123", + JobID: "job-1", + JobName: "test-job", + Status: "Running", + } + + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().CreateDataflowWorkflow(ctx, mock.Anything).RunAndReturn( + func(ctx context.Context, r *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) { + require.Equal(t, req.ID, r.ID) + require.Equal(t, req.ClusterID, r.ClusterID) + require.Equal(t, req.ArgoTaskID, r.ArgoTaskID) + require.Equal(t, req.ResourceName, r.ResourceName) + require.Equal(t, req.OpUserUUID, r.OpUserUUID) + require.Equal(t, req.Username, r.Username) + require.Equal(t, req.NSUUID, r.NSUUID) + require.Equal(t, req.RepoIds, r.RepoIds) + require.Equal(t, req.ResourceId, r.ResourceId) + require.Equal(t, req.JobID, r.JobID) + require.Equal(t, req.JobName, r.JobName) + require.Equal(t, req.JobDesc, r.JobDesc) + require.Equal(t, req.StorageSize, r.StorageSize) + require.Equal(t, req.Entrypoint, r.Entrypoint) + require.Equal(t, req.Template, r.Template) + require.Equal(t, req.DagTasks, r.DagTasks) + require.Equal(t, req.Nodes, r.Nodes) + require.Equal(t, req.NodeAffinity, r.NodeAffinity) + require.Equal(t, req.Tolerations, r.Tolerations) + return expectedResp, nil + }, + ) + + d := &deployer{ + imageRunner: mockRunner, + } + + resp, err := 
d.CreateDataflowJob(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, expectedResp, resp) + }) + + t.Run("runner returns error", func(t *testing.T) { + req := &types.DataflowCreateReq{ + ID: 1, + ClusterID: "cluster-1", + ArgoTaskID: "task-123", + ResourceId: 100, + JobID: "job-1", + JobName: "test-job", + StorageSize: "10Gi", + Entrypoint: "main.py", + } + + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().CreateDataflowWorkflow(ctx, mock.Anything).Return(nil, errors.New("runner error")) + + d := &deployer{ + imageRunner: mockRunner, + } + + resp, err := d.CreateDataflowJob(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "failed to create dataflow job job-1 workflow error") + require.Contains(t, err.Error(), "runner error") + }) + + t.Run("with scheduler", func(t *testing.T) { + req := &types.DataflowCreateReq{ + ID: 1, + ClusterID: "cluster-1", + ArgoTaskID: "task-123", + ResourceName: "gpu-resource", + OpUserUUID: "user-uuid-1", + Username: "testuser", + NSUUID: "ns-uuid-1", + ResourceId: 100, + JobID: "job-1", + JobName: "test-job", + StorageSize: "10Gi", + Entrypoint: "main.py", + } + + scheduler := &types.Scheduler{ + Volcano: &types.VolcanoConfig{ + SchedulerName: "custom-scheduler", + }, + } + + expectedResp := &types.DataflowArgoJobResp{ + ID: 1, + ArgoTaskID: "task-123", + JobID: "job-1", + JobName: "test-job", + Status: "Running", + } + + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().CreateDataflowWorkflow(ctx, mock.Anything).RunAndReturn( + func(ctx context.Context, r *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) { + require.NotNil(t, r.Scheduler) + require.Equal(t, scheduler, r.Scheduler) + return expectedResp, nil + }, + ) + + d := &deployer{ + imageRunner: mockRunner, + kubeScheduler: scheduler, + } + + resp, err := d.CreateDataflowJob(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + }) +} + +func 
TestDeployer_DeleteDataflowWorkflow(t *testing.T) { + ctx := context.TODO() + + t.Run("success", func(t *testing.T) { + req := &types.DataflowArgoReq{ + ArgoTaskID: "task-123", + ClusterID: "cluster-1", + } + + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().DeleteDataflowWorkflow(ctx, req).Return(nil) + + d := &deployer{ + imageRunner: mockRunner, + } + + err := d.DeleteDataflowJob(ctx, req) + require.NoError(t, err) + }) + + t.Run("runner returns error", func(t *testing.T) { + req := &types.DataflowArgoReq{ + ArgoTaskID: "task-456", + ClusterID: "cluster-2", + } + + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().DeleteDataflowWorkflow(ctx, req).Return(errors.New("delete failed")) + + d := &deployer{ + imageRunner: mockRunner, + } + + err := d.DeleteDataflowJob(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to delete dataflow task-456 workflow error") + require.Contains(t, err.Error(), "delete failed") + }) + + t.Run("empty cluster id", func(t *testing.T) { + req := &types.DataflowArgoReq{ + ArgoTaskID: "task-789", + ClusterID: "", + } + + mockRunner := mockrunner.NewMockRunner(t) + mockRunner.EXPECT().DeleteDataflowWorkflow(ctx, req).Return(nil) + + d := &deployer{ + imageRunner: mockRunner, + } + + err := d.DeleteDataflowJob(ctx, req) + require.NoError(t, err) + }) +} diff --git a/builder/deploy/deployer_test.go b/builder/deploy/deployer_test.go index 82faeeb32..496d8e4fc 100644 --- a/builder/deploy/deployer_test.go +++ b/builder/deploy/deployer_test.go @@ -938,7 +938,7 @@ func TestDeployer_DeleteEvaluation_Error(t *testing.T) { func TestDeployer_GetWorkflowLogsInStream(t *testing.T) { now := time.Now() - req := types.FinetuneLogReq{ + req := types.WorkflowLogReq{ CurrentUser: "test-user", PodName: "pod1", SubmitTime: now, @@ -954,7 +954,7 @@ func TestDeployer_GetWorkflowLogsInStream(t *testing.T) { deployTaskStore: mockDeployTaskStore, lokiClient: sender, } - lreader, err := 
d.GetWorkflowLogsInStream(context.TODO(), req) + lreader, err := d.GetWorkflowLogsInStream(context.TODO(), req, nil) require.Nil(t, err) require.Nil(t, lreader.buildLogs) require.NotNil(t, lreader.RunLog()) @@ -962,7 +962,7 @@ func TestDeployer_GetWorkflowLogsInStream(t *testing.T) { func TestDeployer_GetWorkflowLogsNonStream(t *testing.T) { now := time.Now() - req := types.FinetuneLogReq{ + req := types.WorkflowLogReq{ CurrentUser: "test-user", PodName: "pod1", SubmitTime: now, @@ -979,7 +979,7 @@ func TestDeployer_GetWorkflowLogsNonStream(t *testing.T) { deployTaskStore: mockDeployTaskStore, lokiClient: sender, } - resp, err := d.GetWorkflowLogsNonStream(context.TODO(), req) + resp, err := d.GetWorkflowLogsNonStream(context.TODO(), req, nil) require.Nil(t, err) require.NotNil(t, resp) } diff --git a/builder/deploy/imagerunner/local_runner.go b/builder/deploy/imagerunner/local_runner.go index 158a16fda..af23f2875 100644 --- a/builder/deploy/imagerunner/local_runner.go +++ b/builder/deploy/imagerunner/local_runner.go @@ -6,6 +6,7 @@ import ( "opencsg.com/csghub-server/api/httpbase" "opencsg.com/csghub-server/builder/deploy/common" "opencsg.com/csghub-server/common/types" + runnerTypes "opencsg.com/csghub-server/runner/types" ) var _ Runner = (*LocalRunner)(nil) @@ -118,3 +119,19 @@ func (h *LocalRunner) ListKsvcVersions(ctx context.Context, clusterID, svcName s func (h *LocalRunner) DeleteKsvcVersion(ctx context.Context, clusterID, svcName, commitID string) error { return nil } + +func (h *LocalRunner) CreateSandbox(ctx context.Context, req *runnerTypes.SandboxRequest) (*runnerTypes.Sandbox, error) { + return nil, nil +} + +func (h *LocalRunner) DeleteSandbox(ctx context.Context, req *runnerTypes.SandboxDeleteRequest) error { + return nil +} + +func (h *LocalRunner) CreateDataflowWorkflow(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) { + return nil, nil +} + +func (h *LocalRunner) DeleteDataflowWorkflow(ctx context.Context, 
req *types.DataflowArgoReq) error { + return nil +} diff --git a/builder/deploy/imagerunner/remote_runner.go b/builder/deploy/imagerunner/remote_runner.go index 5ee2e5cdc..b5444147d 100644 --- a/builder/deploy/imagerunner/remote_runner.go +++ b/builder/deploy/imagerunner/remote_runner.go @@ -578,3 +578,36 @@ func (h *RemoteRunner) DeleteSandbox(ctx context.Context, req *runnerTypes.Sandb defer response.Body.Close() return nil } + +func (h *RemoteRunner) CreateDataflowWorkflow(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) { + remote, err := h.GetRemoteRunnerHost(ctx, req.ClusterID) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/api/v1/dataflow/jobs", remote) + response, err := h.doRequest(ctx, http.MethodPost, url, req) + if err != nil { + return nil, fmt.Errorf("failed to create dataflow workflow, %w", err) + } + defer response.Body.Close() + + var res types.DataflowArgoJobResp + if err := json.NewDecoder(response.Body).Decode(&res); err != nil { + return nil, errorx.InternalServerError(err, nil) + } + return &res, nil +} + +func (h *RemoteRunner) DeleteDataflowWorkflow(ctx context.Context, req *types.DataflowArgoReq) error { + remote, err := h.GetRemoteRunnerHost(ctx, req.ClusterID) + if err != nil { + return err + } + url := fmt.Sprintf("%s/api/v1/dataflow/jobs/%s?cluster_id=%s", remote, req.ArgoTaskID, req.ClusterID) + response, err := h.doRequest(ctx, http.MethodDelete, url, nil) + if err != nil { + return fmt.Errorf("failed to delete dataflow workflow, %w", err) + } + defer response.Body.Close() + return nil +} diff --git a/builder/deploy/imagerunner/remote_runner_test.go b/builder/deploy/imagerunner/remote_runner_test.go index 5f590fea2..093130f04 100644 --- a/builder/deploy/imagerunner/remote_runner_test.go +++ b/builder/deploy/imagerunner/remote_runner_test.go @@ -659,3 +659,211 @@ func TestRemoteRunner_SetVersionsTraffic_HTTPError(t *testing.T) { t.Errorf("expected error message to contain 'failed 
to update traffic', got %v", err) } } + +func TestRemoteRunner_CreateDataflowWorkflow_Success(t *testing.T) { + req := &types.DataflowArgoJobReq{ + ClusterID: "test-cluster", + ArgoTaskID: "test-argo-task", + ResourceName: "test-resource", + JobID: "test-job", + JobName: "test-job-name", + } + + expectedResp := &types.DataflowArgoJobResp{ + ID: 1, + ArgoTaskID: "test-argo-task", + JobID: "test-job", + JobName: "test-job-name", + Status: "Running", + Message: "Workflow created successfully", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/dataflow/jobs" { + t.Errorf("expected path /api/v1/dataflow/jobs, got %s", r.URL.Path) + } + if r.Method != http.MethodPost { + t.Errorf("expected method POST, got %s", r.Method) + } + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(expectedResp) + require.Nil(t, err) + })) + defer server.Close() + + mockClusterStore := mockdb.NewMockClusterInfoStore(t) + mockClusterStore.EXPECT().ByClusterID(mock.Anything, req.ClusterID).Return(database.ClusterInfo{ + Mode: types.ConnectModeInCluster, + RunnerEndpoint: server.URL, + }, nil).Once() + + remoteURL, _ := url.Parse(server.URL) + runner := &RemoteRunner{ + remote: remoteURL, + client: server.Client(), + clusterStore: mockClusterStore, + } + + got, err := runner.CreateDataflowWorkflow(context.Background(), req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(got, expectedResp) { + t.Errorf("expected response %v, got %v", expectedResp, got) + } +} + +func TestRemoteRunner_CreateDataflowWorkflow_ClusterStoreError(t *testing.T) { + req := &types.DataflowArgoJobReq{ + ClusterID: "test-cluster", + ArgoTaskID: "test-argo-task", + } + + expectedErr := errors.New("database error") + mockClusterStore := mockdb.NewMockClusterInfoStore(t) + mockClusterStore.EXPECT().ByClusterID(mock.Anything, req.ClusterID).Return(database.ClusterInfo{}, 
expectedErr).Once() + + remoteURL, _ := url.Parse("http://default.runner") + runner := &RemoteRunner{ + remote: remoteURL, + client: &http.Client{}, + clusterStore: mockClusterStore, + } + + _, err := runner.CreateDataflowWorkflow(context.Background(), req) + if err == nil { + t.Fatal("expected an error, but got nil") + } + if !errors.Is(err, expectedErr) { + t.Errorf("expected error %v, got %v", expectedErr, err) + } +} + +func TestRemoteRunner_CreateDataflowWorkflow_HTTPError(t *testing.T) { + req := &types.DataflowArgoJobReq{ + ClusterID: "test-cluster", + ArgoTaskID: "test-argo-task", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + mockClusterStore := mockdb.NewMockClusterInfoStore(t) + mockClusterStore.EXPECT().ByClusterID(mock.Anything, req.ClusterID).Return(database.ClusterInfo{ + Mode: types.ConnectModeInCluster, + RunnerEndpoint: server.URL, + }, nil).Once() + + remoteURL, _ := url.Parse(server.URL) + runner := &RemoteRunner{ + remote: remoteURL, + client: server.Client(), + clusterStore: mockClusterStore, + } + + _, err := runner.CreateDataflowWorkflow(context.Background(), req) + if err == nil { + t.Fatal("expected an error, but got nil") + } + if !strings.Contains(err.Error(), "failed to create dataflow workflow") { + t.Errorf("expected error message to contain 'failed to create dataflow workflow', got %v", err) + } +} + +func TestRemoteRunner_DeleteDataflowWorkflow_Success(t *testing.T) { + req := &types.DataflowArgoReq{ + ClusterID: "test-cluster", + ArgoTaskID: "test-argo-task", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedPath := "/api/v1/dataflow/jobs/test-argo-task" + if r.URL.Path != expectedPath { + t.Errorf("expected path %s, got %s", expectedPath, r.URL.Path) + } + if r.Method != http.MethodDelete { + t.Errorf("expected method DELETE, got %s", 
r.Method) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + mockClusterStore := mockdb.NewMockClusterInfoStore(t) + mockClusterStore.EXPECT().ByClusterID(mock.Anything, req.ClusterID).Return(database.ClusterInfo{ + Mode: types.ConnectModeInCluster, + RunnerEndpoint: server.URL, + }, nil).Once() + + remoteURL, _ := url.Parse(server.URL) + runner := &RemoteRunner{ + remote: remoteURL, + client: server.Client(), + clusterStore: mockClusterStore, + } + + err := runner.DeleteDataflowWorkflow(context.Background(), req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestRemoteRunner_DeleteDataflowWorkflow_ClusterStoreError(t *testing.T) { + req := &types.DataflowArgoReq{ + ClusterID: "test-cluster", + ArgoTaskID: "test-argo-task", + } + + expectedErr := errors.New("database error") + mockClusterStore := mockdb.NewMockClusterInfoStore(t) + mockClusterStore.EXPECT().ByClusterID(mock.Anything, req.ClusterID).Return(database.ClusterInfo{}, expectedErr).Once() + + remoteURL, _ := url.Parse("http://default.runner") + runner := &RemoteRunner{ + remote: remoteURL, + client: &http.Client{}, + clusterStore: mockClusterStore, + } + + err := runner.DeleteDataflowWorkflow(context.Background(), req) + if err == nil { + t.Fatal("expected an error, but got nil") + } + if !errors.Is(err, expectedErr) { + t.Errorf("expected error %v, got %v", expectedErr, err) + } +} + +func TestRemoteRunner_DeleteDataflowWorkflow_HTTPError(t *testing.T) { + req := &types.DataflowArgoReq{ + ClusterID: "test-cluster", + ArgoTaskID: "test-argo-task", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + mockClusterStore := mockdb.NewMockClusterInfoStore(t) + mockClusterStore.EXPECT().ByClusterID(mock.Anything, req.ClusterID).Return(database.ClusterInfo{ + Mode: types.ConnectModeInCluster, + RunnerEndpoint: server.URL, + }, nil).Once() + + 
remoteURL, _ := url.Parse(server.URL) + runner := &RemoteRunner{ + remote: remoteURL, + client: server.Client(), + clusterStore: mockClusterStore, + } + + err := runner.DeleteDataflowWorkflow(context.Background(), req) + if err == nil { + t.Fatal("expected an error, but got nil") + } + if !strings.Contains(err.Error(), "failed to delete dataflow workflow") { + t.Errorf("expected error message to contain 'failed to delete dataflow workflow', got %v", err) + } +} diff --git a/builder/deploy/imagerunner/runner.go b/builder/deploy/imagerunner/runner.go index 870adaa3e..ab49a6b79 100644 --- a/builder/deploy/imagerunner/runner.go +++ b/builder/deploy/imagerunner/runner.go @@ -5,6 +5,7 @@ import ( "opencsg.com/csghub-server/api/httpbase" "opencsg.com/csghub-server/common/types" + runnerTypes "opencsg.com/csghub-server/runner/types" ) type Runner interface { @@ -28,4 +29,8 @@ type Runner interface { ListKsvcVersions(ctx context.Context, clusterID, svcName string) ([]types.KsvcRevisionInfo, error) DeleteKsvcVersion(ctx context.Context, clusterID, svcName, commitID string) error LabelNode(ctx context.Context, req *types.NodeLabel) error + CreateSandbox(ctx context.Context, req *runnerTypes.SandboxRequest) (*runnerTypes.Sandbox, error) + DeleteSandbox(ctx context.Context, req *runnerTypes.SandboxDeleteRequest) error + CreateDataflowWorkflow(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) + DeleteDataflowWorkflow(ctx context.Context, req *types.DataflowArgoReq) error } diff --git a/builder/store/database/argo_workflow.go b/builder/store/database/argo_workflow.go index fb64b7ce1..9189ddabb 100644 --- a/builder/store/database/argo_workflow.go +++ b/builder/store/database/argo_workflow.go @@ -16,7 +16,7 @@ type argoWorkFlowStoreImpl struct { type ArgoWorkFlowStore interface { FindByID(ctx context.Context, id int64) (WorkFlow ArgoWorkflow, err error) - FindByTaskID(ctx context.Context, id string) (WorkFlow ArgoWorkflow, err error) + 
FindByTaskID(ctx context.Context, id string) (*ArgoWorkflow, error) FindByUsername(ctx context.Context, username string, taskType types.TaskType, per, page int) (WorkFlows []ArgoWorkflow, total int, err error) CreateWorkFlow(ctx context.Context, workFlow ArgoWorkflow) (*ArgoWorkflow, error) UpdateWorkFlowByTaskID(ctx context.Context, workFlow ArgoWorkflow) (*ArgoWorkflow, error) @@ -68,22 +68,27 @@ type ArgoWorkflow struct { FailuresURL string `bun:"," json:"failures_url"` ClusterNode string `bun:"," json:"cluster_node"` QueueName string `bun:"," json:"queue_name"` + DagTasks string `bun:"," json:"dag_tasks"` + DeletedAt time.Time `bun:",soft_delete,nullzero" json:"deleted_at"` } func (s *argoWorkFlowStoreImpl) FindByID(ctx context.Context, id int64) (WorkFlow ArgoWorkflow, err error) { - err = s.db.Operator.Core.NewSelect().Model(&WorkFlow).Where("id = ?", id).Scan(ctx, &WorkFlow) + err = s.db.Operator.Core.NewSelect().Model(&WorkFlow).WhereAllWithDeleted().Where("id = ?", id).Scan(ctx, &WorkFlow) if err != nil { return } return } -func (s *argoWorkFlowStoreImpl) FindByTaskID(ctx context.Context, id string) (WorkFlow ArgoWorkflow, err error) { - err = s.db.Operator.Core.NewSelect().Model(&WorkFlow).Where("task_id = ?", id).Scan(ctx, &WorkFlow) +func (s *argoWorkFlowStoreImpl) FindByTaskID(ctx context.Context, id string) (*ArgoWorkflow, error) { + var err error + workFlow := &ArgoWorkflow{} + q := s.db.Operator.Core.NewSelect().Model(workFlow).WhereAllWithDeleted() + err = q.Where("task_id = ?", id).Scan(ctx, workFlow) if err != nil { - return + return nil, err } - return + return workFlow, nil } func (s *argoWorkFlowStoreImpl) FindByUsername(ctx context.Context, username string, taskType types.TaskType, per, page int) (WorkFlows []ArgoWorkflow, total int, err error) { @@ -112,7 +117,7 @@ func (s *argoWorkFlowStoreImpl) CreateWorkFlow(ctx context.Context, workFlow Arg wf, err := s.FindByTaskID(ctx, workFlow.TaskId) if err == nil && wf.ID != 0 { // already exists - 
return &wf, nil + return wf, nil } res, err := s.db.Core.NewInsert().Model(&workFlow).Exec(ctx, &workFlow) if err := assertAffectedOneRow(res, err); err != nil { diff --git a/builder/store/database/argo_workflow_test.go b/builder/store/database/argo_workflow_test.go index 9cd230bf9..ebf5bf0da 100644 --- a/builder/store/database/argo_workflow_test.go +++ b/builder/store/database/argo_workflow_test.go @@ -49,9 +49,9 @@ func TestArgoWorkflowStore_CRUD(t *testing.T) { require.Nil(t, err) require.Equal(t, "task-new", flowfind.TaskName) - flowfind, err = store.FindByTaskID(ctx, "tid") + flowfind2, err := store.FindByTaskID(ctx, "tid") require.Nil(t, err) - require.Equal(t, "task-new", flowfind.TaskName) + require.Equal(t, "task-new", flowfind2.TaskName) _, err = store.CreateWorkFlow(ctx, database.ArgoWorkflow{ Username: "user", @@ -74,7 +74,7 @@ func TestArgoWorkflowStore_CRUD(t *testing.T) { err = store.DeleteWorkFlow(ctx, dbflow.ID) require.Nil(t, err) _, err = store.FindByID(ctx, dbflow.ID) - require.NotNil(t, err) + require.Nil(t, err) } diff --git a/builder/store/database/migrations/20260430054132_add_column_deletedat_dagtasks_argo.down.sql b/builder/store/database/migrations/20260430054132_add_column_deletedat_dagtasks_argo.down.sql new file mode 100644 index 000000000..80ff6c343 --- /dev/null +++ b/builder/store/database/migrations/20260430054132_add_column_deletedat_dagtasks_argo.down.sql @@ -0,0 +1,9 @@ +SET statement_timeout = 0; + +--bun:split + +ALTER TABLE argo_workflows DROP COLUMN IF EXISTS dag_tasks; + +--bun:split + +ALTER TABLE argo_workflows DROP COLUMN IF EXISTS deleted_at; diff --git a/builder/store/database/migrations/20260430054132_add_column_deletedat_dagtasks_argo.up.sql b/builder/store/database/migrations/20260430054132_add_column_deletedat_dagtasks_argo.up.sql new file mode 100644 index 000000000..60375e574 --- /dev/null +++ b/builder/store/database/migrations/20260430054132_add_column_deletedat_dagtasks_argo.up.sql @@ -0,0 +1,9 @@ +SET 
statement_timeout = 0; + +--bun:split + +ALTER TABLE argo_workflows ADD COLUMN IF NOT EXISTS dag_tasks VARCHAR; + +--bun:split + +ALTER TABLE argo_workflows ADD COLUMN IF NOT EXISTS deleted_at TIMESTAMP; diff --git a/common/types/argo_workflow.go b/common/types/argo_workflow.go index 666656265..da9f11f03 100644 --- a/common/types/argo_workflow.go +++ b/common/types/argo_workflow.go @@ -41,6 +41,7 @@ const ( TaskTypeComparison TaskType = "comparison" TaskTypeLeaderBoard TaskType = "leaderboard" TaskTypeFinetune TaskType = "finetune" + TaskTypeDataflow TaskType = "dataflow" ) type EvaluationReq struct { @@ -84,6 +85,7 @@ type ArgoFlowTemplate struct { HardWare HardWare `json:"hardware,omitempty"` Env map[string]string `json:"env,omitempty"` Annotation map[string]string `json:"annotation,omitempty"` + Parameters []string `json:"parameters"` } type ArgoWorkFlowReq struct { diff --git a/common/types/dataflow.go b/common/types/dataflow.go new file mode 100644 index 000000000..13bd0a084 --- /dev/null +++ b/common/types/dataflow.go @@ -0,0 +1,113 @@ +package types + +import "time" + +const DFPVCNamePrefix = "df-" +const DFParamDagTaskIDKey = "bld_dag_task_id" +const DFParamDagTaskNameKey = "bld_dag_task_name" + +const DFLabelTagKey = "workflow-scope" +const DFLabelTagValue = "csghub-dataflow" +const DFLabelDagTaskIDKey = "csghub_df_dag_task_id" +const DFLabelDagTaskNameKey = "csghub_df_dag_task_name" + +const DFUniqueIDKey = "csghub_df_unique_id" +const DFArgoTaskIDKey = "csghub_df_argo_task_id" +const DFOpUserUUIDKey = "csghub_df_op_user_uuid" +const DFOpUserNameKey = "csghub_df_op_user_name" +const DFNSUUIDKey = "csghub_df_ns_uuid" +const DFClusterIDKey = "csghub_df_cluster_id" +const DFResourceIDKey = "csghub_df_res_id" +const DFResourceNameKey = "csghub_df_res_name" +const DFJobIDKey = "csghub_df_job_id" +const DFJobNameKey = "csghub_df_job_name" +const DFJobDescKey = "csghub_df_job_desc" +const DFImageKey = "csghub_df_image" +const DFStorageSizeKey = 
"csghub_df_storage_size" + +const DFCancelled = "Canceled" + +// DataflowArgoJobReq - Request body for creating dataflow job +type DataflowArgoJobReq struct { + ID int64 `json:"id"` + ClusterID string `json:"cluster_id"` + ArgoTaskID string `json:"argo_task_id"` // db task id + ResourceName string `json:"resource_name"` + OpUserUUID string `json:"op_user_uuid"` + Username string `json:"username"` + NSUUID string `json:"ns_uuid"` + + RepoIds []string `json:"repo_ids"` // dataset repo ids from dataflow + ResourceId int64 `json:"resource_id" binding:"required"` // from dataflow + JobID string `json:"job_id" binding:"required"` // job id from dataflow + JobName string `json:"job_name" binding:"required"` // job name from dataflow + JobDesc string `json:"job_desc"` // job description from dataflow + StorageSize string `json:"storage_size" binding:"required"` // from dataflow + Entrypoint string `json:"entrypoint" binding:"required"` // from dataflow + Template ArgoFlowTemplate `json:"template" binding:"required"` // from dataflow + DagTasks []ArgoDagTask `json:"dag_tasks" binding:"required"` // from dataflow + + Nodes []Node `json:"nodes"` + Scheduler *Scheduler `json:"scheduler,omitempty"` + DeployExtend +} + +// DataflowCreateReq is an alias for DataflowArgoJobReq (used in deployer) +type DataflowCreateReq = DataflowArgoJobReq + +// ArgoDagTask - DAG task definition with dependencies +type ArgoDagTask struct { + ID string `json:"id" binding:"required"` + Name string `json:"name" binding:"required"` + Deps []string `json:"deps"` + Template string `json:"template" binding:"required"` + Parameters []ArgoDagTaskParam `json:"parameters" binding:"required"` +} + +type ArgoDagTaskParam struct { + Name string `json:"name" binding:"required"` + Value string `json:"value" binding:"required"` +} + +// DataflowArgoJobResp - Response describing a dataflow job and its status +type DataflowArgoJobResp struct { + ID int64 `json:"id"` + ArgoTaskID string `json:"argo_task_id"` + JobID string `json:"job_id"` 
+ JobName string `json:"job_name"` + Status string `json:"status"` + Message string `json:"message"` + CreatedAt int64 `json:"created_at"` + DagTasks string `json:"dag_tasks"` + DeleteAt int64 `json:"delete_at"` +} + +// DataflowArgoReq - Request identifying a dataflow workflow by Argo task ID and cluster ID +type DataflowArgoReq struct { + ArgoTaskID string `json:"argo_task_id"` + ClusterID string `json:"cluster_id"` +} + +// DataflowDeleteReq - Request body for deleting dataflow job +type DataflowDeleteReq struct { + OpUserUUID string `json:"-"` + Username string `json:"-"` + NSUUID string `json:"ns_uuid"` + ArgoTaskID string `json:"argo_task_id"` +} + +type DataflowDagTask struct { + Name string `json:"name"` + Status string `json:"status"` + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` +} + +// DataflowLogReq - Request for querying dataflow job logs +type DataflowLogReq struct { + CurrentUser string // Current user for permission check + Since string // Query parameter: "10mins", "30mins", "1hour", "6hours", "1day", "2days", "1week" + TaskId string // ArgoTaskID from path (workflow name) + DagTaskId string // Query parameter: specific DAG task id + SubmitTime time.Time // Job submit time - populated from workflow +} diff --git a/common/types/finetune.go b/common/types/finetune.go index 491e2e157..4ac1e2b09 100644 --- a/common/types/finetune.go +++ b/common/types/finetune.go @@ -85,7 +85,9 @@ type FinetuneRes struct { ResultURL string `json:"result_url"` } -type FinetuneLogReq struct { +type FinetuneLogReq = WorkflowLogReq + +type WorkflowLogReq struct { CurrentUser string Since string ID int64 diff --git a/common/types/webhook.go b/common/types/webhook.go index 513474383..159e109d6 100644 --- a/common/types/webhook.go +++ b/common/types/webhook.go @@ -21,6 +21,12 @@ const ( RunnerWorkflowCreate WebHookEventType = "runner.evaluation.create" RunnerWorkflowChange WebHookEventType = "runner.evaluation.change" + + RunnerDataflowChange WebHookEventType = "runner.dataflow.change" + 
RunnerDataflowDelete WebHookEventType = "runner.dataflow.delete" + + RunnerDataflowPodUpdate WebHookEventType = "runner.dataflow.pod.update" + RunnerDataflowPodDelete WebHookEventType = "runner.dataflow.pod.delete" ) type WebHookDataType string diff --git a/component/accounting.go b/component/accounting.go index 8509d4006..5f052bc77 100644 --- a/component/accounting.go +++ b/component/accounting.go @@ -75,7 +75,7 @@ func NewAccountingComponent(config *config.Config) (AccountingComponent, error) } func (ac *accountingComponentImpl) ListMeteringsByUserIDAndTime(ctx context.Context, req types.ActStatementsReq) (interface{}, error) { - if err := ac.allowQueryData(ctx, req.CurrentUser, req.UserUUID); err != nil { + if _, err := checkOwnerOrOrgMemberPermission(ctx, ac.userSvcClient, req.CurrentUser, req.UserUUID); err != nil { return nil, errorx.Forbidden(err, map[string]any{ "user": req.CurrentUser, }) @@ -88,30 +88,30 @@ func (ac *accountingComponentImpl) ListMeteringsByUserIDAndTime(ctx context.Cont // 1. Current user is the same as the target user (querying own data) // 2. Current user is an admin // 3. 
Current user is a member of the organization that owns the target user's namespace -func (ac *accountingComponentImpl) allowQueryData(ctx context.Context, currentUser, targetUUID string) error { - user, err := ac.userSvcClient.GetUserByName(ctx, currentUser) +func checkOwnerOrOrgMemberPermission(ctx context.Context, userSvcClient rpc.UserSvcClient, currentUser, targetUUID string) (*rpc.Namespace, error) { + user, err := userSvcClient.GetUserByName(ctx, currentUser) if err != nil { - return fmt.Errorf("current user not found: %w", err) + return nil, fmt.Errorf("current user not found: %w", err) } - if user.IsAdmin() || user.UUID == targetUUID { - return nil + ns, err := userSvcClient.GetNameSpaceInfoByUUID(ctx, targetUUID) + if err != nil { + return ns, fmt.Errorf("target namespace not found: %w", err) } - ns, err := ac.userSvcClient.GetNameSpaceInfoByUUID(ctx, targetUUID) - if err != nil { - return fmt.Errorf("target namespace not found: %w", err) + if user.IsAdmin() || user.UUID == targetUUID { + return ns, nil } if ns.NSType != string(database.OrgNamespace) { - return fmt.Errorf("do not have permission to query the target org's data: %w", err) + return ns, fmt.Errorf("do not have permission to query the target org's data: %w", err) } // Check if current user is member of org that owns target user's namespace - role, err := ac.userSvcClient.GetMemberRoleByUUID(ctx, ns.UUID, currentUser) + role, err := userSvcClient.GetMemberRoleByUUID(ctx, ns.UUID, currentUser) if err != nil || role == membership.RoleUnknown { - return fmt.Errorf("do not have permission to query the target org's data: %w", err) + return ns, fmt.Errorf("do not have permission to query the target org's data: %w", err) } - return nil + return ns, nil } diff --git a/component/accounting_test.go b/component/accounting_test.go index bff897d6c..b96852d27 100644 --- a/component/accounting_test.go +++ b/component/accounting_test.go @@ -5,6 +5,8 @@ import ( "testing" "github.com/stretchr/testify/require" + 
"opencsg.com/csghub-server/builder/rpc" + "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/types" ) @@ -21,6 +23,10 @@ func TestAccountingComponent_ListMeteringsByUserIDAndTime(t *testing.T) { UUID: "uuid", Roles: []string{}, }, nil) + ac.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, "uuid").Return(&rpc.Namespace{ + UUID: "uuid", + NSType: string(database.UserNamespace), + }, nil) ac.mocks.accountingClient.EXPECT().ListMeteringsByUserIDAndTime(req).Return( "", nil, ) diff --git a/component/executors/webhook_executor_dataflow.go b/component/executors/webhook_executor_dataflow.go new file mode 100644 index 000000000..4ce60c541 --- /dev/null +++ b/component/executors/webhook_executor_dataflow.go @@ -0,0 +1,100 @@ +package executors + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + + "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +type DataflowExecutor interface { +} + +type dataflowExecutorImpl struct { + store database.ArgoWorkFlowStore +} + +var _ DataflowExecutor = (*dataflowExecutorImpl)(nil) +var _ WebHookExecutor = (*dataflowExecutorImpl)(nil) + +func NewDataflowExecutor(config *config.Config) (DataflowExecutor, error) { + executor := &dataflowExecutorImpl{ + store: database.NewArgoWorkFlowStore(), + } + + err := RegisterWebHookExecutor(types.RunnerDataflowChange, executor) + if err != nil { + return nil, fmt.Errorf("failed to register dataflow change executor: %w", err) + } + + err = RegisterWebHookExecutor(types.RunnerDataflowDelete, executor) + if err != nil { + return nil, fmt.Errorf("failed to register dataflow delete executor: %w", err) + } + + return executor, nil +} + +func (h *dataflowExecutorImpl) ProcessEvent(ctx context.Context, event *types.WebHookRecvEvent) error { + var newWF database.ArgoWorkflow + err := 
json.Unmarshal(event.Data, &newWF) + if err != nil { + return fmt.Errorf("failed to unmarshal dataflow event data: %w", err) + } + + slog.InfoContext(ctx, "dataflow_webhook_event", slog.Any("event_type", event.EventType), slog.Any("newWF", newWF)) + oldwf, err := h.store.FindByTaskID(ctx, newWF.TaskId) + if err != nil { + slog.WarnContext(ctx, "dataflow workflow not exists and skip update", slog.Any("task_id", newWF.TaskId)) + return nil + } + if len(newWF.Status) > 0 { + oldwf.Status = newWF.Status + } + if len(newWF.Reason) > 0 { + oldwf.Reason = newWF.Reason + } + if len(newWF.Namespace) > 0 { + oldwf.Namespace = newWF.Namespace + } + if !newWF.StartTime.IsZero() { + oldwf.StartTime = newWF.StartTime + } + if !newWF.EndTime.IsZero() { + oldwf.EndTime = newWF.EndTime + } + if len(newWF.QueueName) > 0 { + oldwf.QueueName = newWF.QueueName + } + if len(newWF.ClusterNode) > 0 { + oldwf.ClusterNode = newWF.ClusterNode + } + switch event.EventType { + case types.RunnerDataflowChange: + _, err = h.store.UpdateWorkFlowByTaskID(ctx, *oldwf) + if err != nil { + slog.ErrorContext(ctx, "failed to update dataflow workflow", slog.Any("oldwf", oldwf), slog.Any("err", err)) + } + case types.RunnerDataflowDelete: + if oldwf.Status == v1alpha1.WorkflowPending || oldwf.Status == v1alpha1.WorkflowRunning { + oldwf.Status = types.DFCancelled + _, err = h.store.UpdateWorkFlowByTaskID(ctx, *oldwf) + if err != nil { + slog.WarnContext(ctx, "failed to update dataflow workflow status", slog.Any("oldwf", oldwf), slog.Any("err", err)) + } + } + err = h.store.DeleteWorkFlow(ctx, oldwf.ID) + if err != nil { + slog.ErrorContext(ctx, "failed to delete dataflow workflow", slog.Any("oldwf", oldwf), slog.Any("err", err)) + } + default: + return fmt.Errorf("unknown dataflow event type: %s", event.EventType) + } + + return nil +} diff --git a/component/executors/webhook_executor_dataflow_pod.go b/component/executors/webhook_executor_dataflow_pod.go new file mode 100644 index 000000000..9dae2dcac 
--- /dev/null +++ b/component/executors/webhook_executor_dataflow_pod.go @@ -0,0 +1,96 @@ +package executors + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +type DataflowPodExecutor interface { +} + +type dataflowPodExecutorImpl struct { + store database.ArgoWorkFlowStore +} + +var _ DataflowPodExecutor = (*dataflowPodExecutorImpl)(nil) +var _ WebHookExecutor = (*dataflowPodExecutorImpl)(nil) + +func NewDataflowPodExecutor(config *config.Config) (DataflowPodExecutor, error) { + executor := &dataflowPodExecutorImpl{ + store: database.NewArgoWorkFlowStore(), + } + + err := RegisterWebHookExecutor(types.RunnerDataflowPodUpdate, executor) + if err != nil { + return nil, fmt.Errorf("failed to register dataflow pod update executor: %w", err) + } + + err = RegisterWebHookExecutor(types.RunnerDataflowPodDelete, executor) + if err != nil { + return nil, fmt.Errorf("failed to register dataflow pod delete executor: %w", err) + } + + return executor, nil +} + +func (h *dataflowPodExecutorImpl) ProcessEvent(ctx context.Context, event *types.WebHookRecvEvent) error { + var newWF database.ArgoWorkflow + err := json.Unmarshal(event.Data, &newWF) + if err != nil { + return fmt.Errorf("failed to unmarshal dataflow pod event data: %w", err) + } + + slog.InfoContext(ctx, "dataflow_pod_webhook_event", slog.Any("event_type", event.EventType), slog.Any("newWF", newWF)) + + oldwf, err := h.store.FindByTaskID(ctx, newWF.TaskId) + if err != nil { + slog.WarnContext(ctx, "dataflow workflow not exists and skip pod update", slog.Any("task_id", newWF.TaskId)) + return nil + } + + if len(newWF.ClusterNode) > 0 { + oldwf.ClusterNode = newWF.ClusterNode + } + + if len(newWF.DagTasks) > 0 { + var existingMap map[string]interface{} + if oldwf.DagTasks != "" { + err := json.Unmarshal([]byte(oldwf.DagTasks), &existingMap) + if err != nil { + 
return fmt.Errorf("failed to unmarshal existing dag_tasks map string %s to map error: %w", oldwf.DagTasks, err) + } + } else { + existingMap = make(map[string]interface{}) + } + var newMap map[string]interface{} + err = json.Unmarshal([]byte(newWF.DagTasks), &newMap) + if err != nil { + return fmt.Errorf("failed to unmarshal new dag_tasks map string %s to map error: %w", newWF.DagTasks, err) + } + for k, v := range newMap { + existingMap[k] = v + } + merged, err := json.Marshal(existingMap) + if err != nil { + return fmt.Errorf("failed to marshal merged dag_tasks map string: %w", err) + } + oldwf.DagTasks = string(merged) + } + + _, err = h.store.UpdateWorkFlowByTaskID(ctx, *oldwf) + if err != nil { + slog.ErrorContext(ctx, "failed to update dataflow workflow dag_tasks", + slog.Any("task_id", newWF.TaskId), + slog.Any("dag_tasks", oldwf.DagTasks), + slog.Any("err", err)) + return fmt.Errorf("failed to update dataflow workflow dag_tasks: %w", err) + } + + return nil +} diff --git a/component/executors/webhook_executor_dataflow_pod_test.go b/component/executors/webhook_executor_dataflow_pod_test.go new file mode 100644 index 000000000..ceeb07cb3 --- /dev/null +++ b/component/executors/webhook_executor_dataflow_pod_test.go @@ -0,0 +1,417 @@ +package executors + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + mockdb "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +func NewTestDataflowPodExecutor(store database.ArgoWorkFlowStore) WebHookExecutor { + executor := &dataflowPodExecutorImpl{ + store: store, + } + return executor +} + +func TestDataflowPodExecutor_ProcessEvent(t *testing.T) { + ctx := context.TODO() + now := time.Now() + + t.Run("update cluster node and dag tasks", func(t 
*testing.T) { + existingDagTasks := `{"task1": {"status": "running", "start_time": "2024-01-01T00:00:00Z"}}` + oldWF := &database.ArgoWorkflow{ + ID: 1, + TaskId: "task-123", + ClusterNode: "old-node", + DagTasks: existingDagTasks, + } + + newDagTasks := `{"task2": {"status": "succeeded", "start_time": "2024-01-01T01:00:00Z"}}` + newWF := database.ArgoWorkflow{ + TaskId: "task-123", + ClusterNode: "new-node", + DagTasks: newDagTasks, + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowPodUpdate, + EventTime: now.Unix(), + }, + Data: data, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-123").Return(oldWF, nil) + mockStore.EXPECT().UpdateWorkFlowByTaskID(ctx, mock.MatchedBy(func(wf database.ArgoWorkflow) bool { + return wf.ID == 1 && + wf.TaskId == "task-123" && + wf.ClusterNode == "new-node" && + len(wf.DagTasks) > 0 + })).Return(oldWF, nil) + + executor := NewTestDataflowPodExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) + + t.Run("update cluster node only", func(t *testing.T) { + oldWF := &database.ArgoWorkflow{ + ID: 2, + TaskId: "task-456", + ClusterNode: "old-node", + DagTasks: "", + } + + newWF := database.ArgoWorkflow{ + TaskId: "task-456", + ClusterNode: "new-node", + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowPodUpdate, + EventTime: now.Unix(), + }, + Data: data, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-456").Return(oldWF, nil) + mockStore.EXPECT().UpdateWorkFlowByTaskID(ctx, mock.MatchedBy(func(wf database.ArgoWorkflow) bool { + return wf.ID == 2 && + wf.TaskId == "task-456" && + wf.ClusterNode == "new-node" && + wf.DagTasks == "" + })).Return(oldWF, 
nil) + + executor := NewTestDataflowPodExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) + + t.Run("update dag tasks with empty existing", func(t *testing.T) { + oldWF := &database.ArgoWorkflow{ + ID: 3, + TaskId: "task-789", + ClusterNode: "node-1", + DagTasks: "", + } + + newDagTasks := `{"task1": {"status": "running"}}` + newWF := database.ArgoWorkflow{ + TaskId: "task-789", + DagTasks: newDagTasks, + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowPodUpdate, + EventTime: now.Unix(), + }, + Data: data, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-789").Return(oldWF, nil) + mockStore.EXPECT().UpdateWorkFlowByTaskID(ctx, mock.MatchedBy(func(wf database.ArgoWorkflow) bool { + var dagMap map[string]interface{} + err := json.Unmarshal([]byte(wf.DagTasks), &dagMap) + if err != nil { + return false + } + task1, ok := dagMap["task1"] + return ok && + wf.ID == 3 && + wf.TaskId == "task-789" && + wf.ClusterNode == "node-1" && + task1.(map[string]interface{})["status"] == "running" + })).Return(oldWF, nil) + + executor := NewTestDataflowPodExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) + + t.Run("merge dag tasks", func(t *testing.T) { + existingDagTasks := `{"task1": {"status": "running"}, "task2": {"status": "pending"}}` + oldWF := &database.ArgoWorkflow{ + ID: 4, + TaskId: "task-merge", + DagTasks: existingDagTasks, + } + + newDagTasks := `{"task2": {"status": "succeeded"}, "task3": {"status": "running"}}` + newWF := database.ArgoWorkflow{ + TaskId: "task-merge", + DagTasks: newDagTasks, + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowPodUpdate, + EventTime: now.Unix(), + }, + 
Data: data, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-merge").Return(oldWF, nil) + mockStore.EXPECT().UpdateWorkFlowByTaskID(ctx, mock.MatchedBy(func(wf database.ArgoWorkflow) bool { + var dagMap map[string]interface{} + err := json.Unmarshal([]byte(wf.DagTasks), &dagMap) + if err != nil { + return false + } + task1, ok1 := dagMap["task1"] + task2, ok2 := dagMap["task2"] + task3, ok3 := dagMap["task3"] + return ok1 && ok2 && ok3 && + task1.(map[string]interface{})["status"] == "running" && + task2.(map[string]interface{})["status"] == "succeeded" && + task3.(map[string]interface{})["status"] == "running" + })).Return(oldWF, nil) + + executor := NewTestDataflowPodExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) + + t.Run("workflow not found", func(t *testing.T) { + newWF := database.ArgoWorkflow{ + TaskId: "task-not-exist", + ClusterNode: "node-1", + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowPodUpdate, + EventTime: now.Unix(), + }, + Data: data, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-not-exist").Return(nil, errors.New("not found")) + + executor := NewTestDataflowPodExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) + + t.Run("unmarshal error", func(t *testing.T) { + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowPodUpdate, + EventTime: now.Unix(), + }, + Data: []byte("invalid json"), + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + executor := NewTestDataflowPodExecutor(mockStore) + err := executor.ProcessEvent(ctx, event) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to unmarshal dataflow pod event data") + }) + + t.Run("invalid existing 
dag_tasks json", func(t *testing.T) { + oldWF := &database.ArgoWorkflow{ + ID: 5, + TaskId: "task-bad-existing", + DagTasks: "invalid json", + } + + newDagTasks := `{"task1": {"status": "running"}}` + newWF := database.ArgoWorkflow{ + TaskId: "task-bad-existing", + DagTasks: newDagTasks, + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowPodUpdate, + EventTime: now.Unix(), + }, + Data: data, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-bad-existing").Return(oldWF, nil) + + executor := NewTestDataflowPodExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to unmarshal existing dag_tasks map string") + }) + + t.Run("invalid new dag_tasks json", func(t *testing.T) { + oldWF := &database.ArgoWorkflow{ + ID: 6, + TaskId: "task-bad-new", + DagTasks: `{"task1": {"status": "running"}}`, + } + + newWF := database.ArgoWorkflow{ + TaskId: "task-bad-new", + DagTasks: "invalid new json", + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowPodUpdate, + EventTime: now.Unix(), + }, + Data: data, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-bad-new").Return(oldWF, nil) + + executor := NewTestDataflowPodExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to unmarshal new dag_tasks map string") + }) + + t.Run("update error", func(t *testing.T) { + oldWF := &database.ArgoWorkflow{ + ID: 7, + TaskId: "task-update-error", + ClusterNode: "old-node", + DagTasks: "", + } + + newWF := database.ArgoWorkflow{ + TaskId: "task-update-error", + ClusterNode: "new-node", + } + + data, err := 
json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowPodUpdate, + EventTime: now.Unix(), + }, + Data: data, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-update-error").Return(oldWF, nil) + mockStore.EXPECT().UpdateWorkFlowByTaskID(ctx, mock.Anything).Return(nil, errors.New("update failed")) + + executor := NewTestDataflowPodExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to update dataflow workflow dag_tasks") + }) + + t.Run("RunnerDataflowPodDelete event", func(t *testing.T) { + oldWF := &database.ArgoWorkflow{ + ID: 8, + TaskId: "task-pod-delete", + ClusterNode: "node-1", + DagTasks: "", + } + + newWF := database.ArgoWorkflow{ + TaskId: "task-pod-delete", + ClusterNode: "node-2", + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowPodDelete, + EventTime: now.Unix(), + }, + Data: data, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-pod-delete").Return(oldWF, nil) + mockStore.EXPECT().UpdateWorkFlowByTaskID(ctx, mock.Anything).Return(oldWF, nil) + + executor := NewTestDataflowPodExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) + + t.Run("no fields to update", func(t *testing.T) { + oldWF := &database.ArgoWorkflow{ + ID: 9, + TaskId: "task-no-update", + ClusterNode: "node-1", + DagTasks: "", + } + + newWF := database.ArgoWorkflow{ + TaskId: "task-no-update", + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowPodUpdate, + EventTime: now.Unix(), + }, + Data: data, + } + + mockStore := 
mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-no-update").Return(oldWF, nil) + mockStore.EXPECT().UpdateWorkFlowByTaskID(ctx, mock.MatchedBy(func(wf database.ArgoWorkflow) bool { + return wf.ClusterNode == "node-1" && wf.DagTasks == "" + })).Return(oldWF, nil) + + executor := NewTestDataflowPodExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) +} + +func TestNewDataflowPodExecutor(t *testing.T) { + cfg := &config.Config{} + + t.Run("success", func(t *testing.T) { + executor, err := NewDataflowPodExecutor(cfg) + require.NoError(t, err) + require.NotNil(t, executor) + }) +} diff --git a/component/executors/webhook_executor_dataflow_test.go b/component/executors/webhook_executor_dataflow_test.go new file mode 100644 index 000000000..6f2dc0389 --- /dev/null +++ b/component/executors/webhook_executor_dataflow_test.go @@ -0,0 +1,395 @@ +package executors + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + mockdb "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +func NewTestDataflowExecutor(store database.ArgoWorkFlowStore) WebHookExecutor { + executor := &dataflowExecutorImpl{ + store: store, + } + return executor +} + +func TestDataflowExecutor_ProcessEvent(t *testing.T) { + ctx := context.TODO() + now := time.Now() + + t.Run("RunnerDataflowChange success", func(t *testing.T) { + newWF := database.ArgoWorkflow{ + TaskId: "task-123", + Status: v1alpha1.WorkflowRunning, + Reason: "test reason", + Namespace: "default", + StartTime: now, + EndTime: now.Add(1 * time.Hour), + QueueName: "queue-1", + ClusterNode: "node-1", + } + + data, err := 
json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowChange, + EventTime: now.Unix(), + }, + Data: data, + } + + oldWF := &database.ArgoWorkflow{ + ID: 1, + TaskId: "task-123", + Status: v1alpha1.WorkflowPending, + Reason: "", + Namespace: "", + StartTime: time.Time{}, + EndTime: time.Time{}, + QueueName: "", + ClusterNode: "", + } + + updatedWF := &database.ArgoWorkflow{ + ID: 1, + TaskId: "task-123", + Status: newWF.Status, + Reason: newWF.Reason, + Namespace: newWF.Namespace, + StartTime: newWF.StartTime, + EndTime: newWF.EndTime, + QueueName: newWF.QueueName, + ClusterNode: newWF.ClusterNode, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-123").Return(oldWF, nil) + mockStore.EXPECT().UpdateWorkFlowByTaskID(ctx, mock.MatchedBy(func(wf database.ArgoWorkflow) bool { + return wf.TaskId == "task-123" && + wf.Status == v1alpha1.WorkflowRunning && + wf.Reason == "test reason" && + wf.Namespace == "default" && + wf.QueueName == "queue-1" && + wf.ClusterNode == "node-1" + })).Return(updatedWF, nil) + + executor := NewTestDataflowExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) + + t.Run("RunnerDataflowChange partial update", func(t *testing.T) { + newWF := database.ArgoWorkflow{ + TaskId: "task-456", + Status: v1alpha1.WorkflowRunning, + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowChange, + EventTime: now.Unix(), + }, + Data: data, + } + + oldWF := &database.ArgoWorkflow{ + ID: 2, + TaskId: "task-456", + Status: v1alpha1.WorkflowPending, + Reason: "old reason", + Namespace: "old-namespace", + } + + updatedWF := &database.ArgoWorkflow{ + ID: 2, + TaskId: "task-456", + Status: v1alpha1.WorkflowRunning, + Reason: "old reason", + Namespace: 
"old-namespace", + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-456").Return(oldWF, nil) + mockStore.EXPECT().UpdateWorkFlowByTaskID(ctx, *updatedWF).Return(updatedWF, nil) + + executor := NewTestDataflowExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) + + t.Run("RunnerDataflowDelete with pending status", func(t *testing.T) { + newWF := database.ArgoWorkflow{ + TaskId: "task-789", + Status: v1alpha1.WorkflowPending, + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowDelete, + EventTime: now.Unix(), + }, + Data: data, + } + + oldWF := &database.ArgoWorkflow{ + ID: 3, + TaskId: "task-789", + Status: v1alpha1.WorkflowPending, + } + + cancelledWF := &database.ArgoWorkflow{ + ID: 3, + TaskId: "task-789", + Status: types.DFCancelled, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-789").Return(oldWF, nil) + mockStore.EXPECT().UpdateWorkFlowByTaskID(ctx, *cancelledWF).Return(cancelledWF, nil) + mockStore.EXPECT().DeleteWorkFlow(ctx, oldWF.ID).Return(nil) + + executor := NewTestDataflowExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) + + t.Run("RunnerDataflowDelete with running status", func(t *testing.T) { + newWF := database.ArgoWorkflow{ + TaskId: "task-running", + Status: v1alpha1.WorkflowRunning, + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowDelete, + EventTime: now.Unix(), + }, + Data: data, + } + + oldWF := &database.ArgoWorkflow{ + ID: 4, + TaskId: "task-running", + Status: v1alpha1.WorkflowRunning, + } + + cancelledWF := &database.ArgoWorkflow{ + ID: 4, + TaskId: "task-running", + Status: types.DFCancelled, + } + + 
mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-running").Return(oldWF, nil) + mockStore.EXPECT().UpdateWorkFlowByTaskID(ctx, *cancelledWF).Return(cancelledWF, nil) + mockStore.EXPECT().DeleteWorkFlow(ctx, oldWF.ID).Return(nil) + + executor := NewTestDataflowExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) + + t.Run("RunnerDataflowDelete with succeeded status", func(t *testing.T) { + newWF := database.ArgoWorkflow{ + TaskId: "task-succeeded", + Status: v1alpha1.WorkflowSucceeded, + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowDelete, + EventTime: now.Unix(), + }, + Data: data, + } + + oldWF := &database.ArgoWorkflow{ + ID: 5, + TaskId: "task-succeeded", + Status: v1alpha1.WorkflowSucceeded, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-succeeded").Return(oldWF, nil) + mockStore.EXPECT().DeleteWorkFlow(ctx, oldWF.ID).Return(nil) + + executor := NewTestDataflowExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) + + t.Run("workflow not found", func(t *testing.T) { + newWF := database.ArgoWorkflow{ + TaskId: "task-not-exist", + Status: v1alpha1.WorkflowRunning, + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowChange, + EventTime: now.Unix(), + }, + Data: data, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-not-exist").Return(nil, errors.New("not found")) + + executor := NewTestDataflowExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) + + t.Run("unmarshal error", func(t *testing.T) { + event := &types.WebHookRecvEvent{ + WebHookHeader: 
types.WebHookHeader{ + EventType: types.RunnerDataflowChange, + EventTime: now.Unix(), + }, + Data: []byte("invalid json"), + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + executor := NewTestDataflowExecutor(mockStore) + err := executor.ProcessEvent(ctx, event) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to unmarshal dataflow event data") + }) + + t.Run("unknown event type", func(t *testing.T) { + newWF := database.ArgoWorkflow{ + TaskId: "task-unknown", + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.WebHookEventType("unknown"), + EventTime: now.Unix(), + }, + Data: data, + } + + oldWF := &database.ArgoWorkflow{ + ID: 6, + TaskId: "task-unknown", + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-unknown").Return(oldWF, nil) + + executor := NewTestDataflowExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.Error(t, err) + require.Contains(t, err.Error(), "unknown dataflow event type") + }) + + t.Run("RunnerDataflowChange update error", func(t *testing.T) { + newWF := database.ArgoWorkflow{ + TaskId: "task-update-error", + Status: v1alpha1.WorkflowRunning, + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowChange, + EventTime: now.Unix(), + }, + Data: data, + } + + oldWF := &database.ArgoWorkflow{ + ID: 7, + TaskId: "task-update-error", + Status: v1alpha1.WorkflowPending, + } + + updatedWF := &database.ArgoWorkflow{ + ID: 7, + TaskId: "task-update-error", + Status: v1alpha1.WorkflowRunning, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-update-error").Return(oldWF, nil) + mockStore.EXPECT().UpdateWorkFlowByTaskID(ctx, *updatedWF).Return(nil, errors.New("update failed")) + + 
executor := NewTestDataflowExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) + + t.Run("RunnerDataflowDelete delete error", func(t *testing.T) { + newWF := database.ArgoWorkflow{ + TaskId: "task-delete-error", + Status: v1alpha1.WorkflowSucceeded, + } + + data, err := json.Marshal(newWF) + require.NoError(t, err) + + event := &types.WebHookRecvEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: types.RunnerDataflowDelete, + EventTime: now.Unix(), + }, + Data: data, + } + + oldWF := &database.ArgoWorkflow{ + ID: 8, + TaskId: "task-delete-error", + Status: v1alpha1.WorkflowSucceeded, + } + + mockStore := mockdb.NewMockArgoWorkFlowStore(t) + mockStore.EXPECT().FindByTaskID(ctx, "task-delete-error").Return(oldWF, nil) + mockStore.EXPECT().DeleteWorkFlow(ctx, oldWF.ID).Return(errors.New("delete failed")) + + executor := NewTestDataflowExecutor(mockStore) + err = executor.ProcessEvent(ctx, event) + require.NoError(t, err) + }) +} + +func TestNewDataflowExecutor(t *testing.T) { + cfg := &config.Config{} + + t.Run("success", func(t *testing.T) { + executor, err := NewDataflowExecutor(cfg) + require.NoError(t, err) + require.NotNil(t, executor) + }) +} diff --git a/component/finetune.go b/component/finetune.go index ecd5f10ea..93bd359cc 100644 --- a/component/finetune.go +++ b/component/finetune.go @@ -310,15 +310,19 @@ func (c *finetuneComponentImpl) CheckUserPermission(ctx context.Context, req typ func (c *finetuneComponentImpl) ReadJobLogsNonStream(ctx context.Context, req types.FinetuneLogReq) (string, error) { wf, err := c.workflowStore.FindByID(ctx, req.ID) if err != nil { - return "", fmt.Errorf("fail to find finetune workflow by id %d error: %w", req.ID, err) + return "", fmt.Errorf("fail to find argo workflow by id %d error: %w", req.ID, err) } req.PodName = wf.TaskId req.SubmitTime = wf.SubmitTime - lokiResp, err := c.deployer.GetWorkflowLogsNonStream(ctx, req) + labels := map[string]string{ + 
types.StreamKeyInstanceName: req.PodName, + } + + lokiResp, err := c.deployer.GetWorkflowLogsNonStream(ctx, req, labels) if err != nil { - return "", fmt.Errorf("failed to read finetune job logs, error:%w", err) + return "", fmt.Errorf("failed to read workflow job logs, error:%w", err) } return c.formatLogs(lokiResp), nil @@ -333,7 +337,11 @@ func (c *finetuneComponentImpl) ReadJobLogsInStream(ctx context.Context, req typ req.PodName = wf.TaskId req.SubmitTime = wf.SubmitTime - return c.deployer.GetWorkflowLogsInStream(ctx, req) + labels := map[string]string{ + types.StreamKeyInstanceName: req.PodName, + } + + return c.deployer.GetWorkflowLogsInStream(ctx, req, labels) } func (c *finetuneComponentImpl) formatLogs(lokiLog *loki.LokiQueryResponse) string { diff --git a/component/finetune_test.go b/component/finetune_test.go index 5d108f62d..74adede89 100644 --- a/component/finetune_test.go +++ b/component/finetune_test.go @@ -322,7 +322,7 @@ func TestFinetuneComponent_ReadJobLogsNonStream(t *testing.T) { }, nil) expectedLogs := &loki.LokiQueryResponse{} - mockDeployer.EXPECT().GetWorkflowLogsNonStream(ctx, mock.Anything).Return(expectedLogs, nil) + mockDeployer.EXPECT().GetWorkflowLogsNonStream(ctx, mock.Anything, mock.Anything).Return(expectedLogs, nil) logs, err := c.ReadJobLogsNonStream(ctx, req) require.NoError(t, err) @@ -364,7 +364,7 @@ func TestFinetuneComponent_ReadJobLogsNonStream(t *testing.T) { }, nil) expectedErr := errors.New("failed to get logs") - mockDeployer.EXPECT().GetWorkflowLogsNonStream(ctx, mock.Anything).Return(nil, expectedErr) + mockDeployer.EXPECT().GetWorkflowLogsNonStream(ctx, mock.Anything, mock.Anything).Return(nil, expectedErr) _, err := c.ReadJobLogsNonStream(ctx, req) require.NotNil(t, err) @@ -393,7 +393,7 @@ func TestFinetuneComponent_ReadJobLogsInStream(t *testing.T) { }, nil) expectedReader := &deploy.MultiLogReader{} - mockDeployer.EXPECT().GetWorkflowLogsInStream(ctx, mock.Anything).Return(expectedReader, nil) + 
mockDeployer.EXPECT().GetWorkflowLogsInStream(ctx, mock.Anything, mock.Anything).Return(expectedReader, nil) reader, err := c.ReadJobLogsInStream(ctx, req) require.NoError(t, err) diff --git a/component/platform_dataflow.go b/component/platform_dataflow.go new file mode 100644 index 000000000..9dd2959ff --- /dev/null +++ b/component/platform_dataflow.go @@ -0,0 +1,287 @@ +package component + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + "github.com/bwmarrin/snowflake" + "opencsg.com/csghub-server/builder/deploy" + "opencsg.com/csghub-server/builder/loki" + "opencsg.com/csghub-server/builder/rpc" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +type PlatformDataflowComponent interface { + CreateJob(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) + DeleteJob(ctx context.Context, req *types.DataflowDeleteReq) error + GetJob(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) + ReadJobLogsInStream(ctx context.Context, req types.DataflowLogReq) (*deploy.MultiLogReader, error) + ReadJobLogsNonStream(ctx context.Context, req types.DataflowLogReq) (string, error) + CheckUserPermission(ctx context.Context, req types.DataflowLogReq) (bool, error) +} + +type platformDataflowComponentImpl struct { + deployer deploy.Deployer + workflowStore database.ArgoWorkFlowStore + userSvcClient rpc.UserSvcClient + clusterStore database.ClusterInfoStore + spaceResourceStore database.SpaceResourceStore + repoComponent RepoComponent + snowflakeNode *snowflake.Node + config *config.Config +} + +func NewPlatformDataflowComponent(cfg *config.Config) (PlatformDataflowComponent, error) { + var err error + c := &platformDataflowComponentImpl{} + c.config = cfg + c.deployer = deploy.NewDeployer() + c.workflowStore = 
database.NewArgoWorkFlowStore() + c.userSvcClient = rpc.NewUserSvcHttpClient( + fmt.Sprintf("%s:%d", cfg.User.Host, cfg.User.Port), + rpc.AuthWithApiKey(cfg.APIToken), + ) + c.clusterStore = database.NewClusterInfoStore() + c.spaceResourceStore = database.NewSpaceResourceStore() + c.repoComponent, err = NewRepoComponent(cfg) + if err != nil { + return nil, fmt.Errorf("failed to create repo component, error: %w", err) + } + node, err := snowflake.NewNode(1) + if err != nil || node == nil { + return nil, fmt.Errorf("failed to create snowflake node, error: %w", err) + } + c.snowflakeNode = node + return c, nil +} + +func (c *platformDataflowComponentImpl) CreateJob(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) { + // Check user or org permission + ns, err := checkOwnerOrOrgMemberPermission(ctx, c.userSvcClient, req.Username, req.NSUUID) + if err != nil { + return nil, err + } + + var hardware types.HardWare + + resource, err := c.spaceResourceStore.FindByID(ctx, req.ResourceId) + if err != nil { + return nil, fmt.Errorf("cannot find resource %d error: %w", req.ResourceId, err) + } + + err = json.Unmarshal([]byte(resource.Resources), &hardware) + if err != nil { + return nil, fmt.Errorf("invalid hardware setting error: %w", err) + } + + // check resource available + exclusiveResp, err := c.repoComponent.CheckAccountAndResource(ctx, ns.Path, resource.ClusterID, 0, resource) + if err != nil { + return nil, fmt.Errorf("failed to check account and resource, error: %w", err) + } + + req.ClusterID = resource.ClusterID + req.ResourceName = resource.Name + req.NodeAffinity = exclusiveResp.NodeAffinity + req.Tolerations = exclusiveResp.Tolerations + + clusterNodes, err := c.clusterStore.FindNodeByClusterID(ctx, req.ClusterID) + if err != nil { + return nil, fmt.Errorf("failed to find nodes by clusterID %s, error: %w", req.ClusterID, err) + } + + uniqueJobID := c.snowflakeNode.Generate().Base36() + + now := time.Now() + workflow := 
database.ArgoWorkflow{ + Username: req.Username, + UserUUID: req.NSUUID, + TaskName: req.JobName, + TaskId: fmt.Sprintf("df%s", uniqueJobID), + TaskType: types.TaskTypeDataflow, + ClusterID: req.ClusterID, + RepoIds: req.RepoIds, + RepoType: string(types.DatasetRepo), + TaskDesc: req.JobDesc, + Status: v1alpha1.WorkflowPending, + Image: req.Template.Image, + Datasets: req.RepoIds, + ResourceId: req.ResourceId, + ResourceName: req.ResourceName, + SubmitTime: now, + } + + for _, node := range clusterNodes { + req.Nodes = append(req.Nodes, types.Node{ + Name: node.Name, + EnableVXPU: node.EnableVXPU, + HasXPU: node.Hardware.HasXPU() || node.EnableVXPU, + }) + } + + createdWorkflow, err := c.workflowStore.CreateWorkFlow(ctx, workflow) + if err != nil { + return nil, fmt.Errorf("failed to create ArgoWorkflow record, error: %w", err) + } + + req.ID = createdWorkflow.ID + req.ArgoTaskID = createdWorkflow.TaskId + + resp, err := c.deployer.CreateDataflowJob(ctx, req) + if err != nil { + // Delete ArgoWorkflow record + delErr := c.workflowStore.DeleteWorkFlow(ctx, createdWorkflow.ID) + if delErr != nil { + slog.ErrorContext(ctx, "failed to delete ArgoWorkflow record due to create dataflow workflow failed", + slog.Any("error", delErr)) + } + return nil, fmt.Errorf("failed to create dataflow workflow, error: %w", err) + } + + resp.ID = createdWorkflow.ID + return resp, nil +} + +func (c *platformDataflowComponentImpl) DeleteJob(ctx context.Context, req *types.DataflowDeleteReq) error { + wf, err := c.workflowStore.FindByTaskID(ctx, req.ArgoTaskID) + if err != nil { + return fmt.Errorf("failed to find dataflow workflow by task_id %s error: %w", req.ArgoTaskID, err) + } + + if wf.UserUUID != req.NSUUID { + return fmt.Errorf("do not have permission to operate the target namespace's data") // err is nil at this point, so there is nothing to wrap with %w + } + + // Check owner or org permission + _, err = checkOwnerOrOrgMemberPermission(ctx, c.userSvcClient, req.Username, req.NSUUID) + if err != nil { + return err + } + + deleteReq := 
&types.DataflowArgoReq{ + ArgoTaskID: wf.TaskId, + ClusterID: wf.ClusterID, + } + err = c.deployer.DeleteDataflowJob(ctx, deleteReq) + if err != nil { + return fmt.Errorf("failed to delete dataflow workflow %s error: %w", req.ArgoTaskID, err) + } + + err = c.workflowStore.DeleteWorkFlow(ctx, wf.ID) + if err != nil { + return fmt.Errorf("failed to delete dataflow workflow record %d error: %w", wf.ID, err) + } + + return nil +} + +func (c *platformDataflowComponentImpl) GetJob(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) { + wf, err := c.workflowStore.FindByTaskID(ctx, req.ArgoTaskID) + if err != nil { + return nil, fmt.Errorf("failed to find dataflow workflow by task_id %s: %w", req.ArgoTaskID, err) + } + + _, err = checkOwnerOrOrgMemberPermission(ctx, c.userSvcClient, req.Username, req.NSUUID) // NOTE(review): checked against the caller-supplied req.NSUUID; wf.UserUUID is never compared here, unlike DeleteJob and CheckUserPermission - confirm a caller cannot read another namespace's job + if err != nil { + return nil, err + } + + resp := &types.DataflowArgoJobResp{ + ID: wf.ID, + ArgoTaskID: wf.TaskId, + JobID: wf.TaskId, + JobName: wf.TaskName, + Status: string(wf.Status), + Message: wf.Reason, + CreatedAt: wf.SubmitTime.Unix(), + DagTasks: wf.DagTasks, + DeleteAt: wf.DeletedAt.Unix(), // NOTE(review): converted unconditionally; presumably DeletedAt is unset for a live job, in which case this is not 0 - confirm what clients expect + } + return resp, nil +} + +func (c *platformDataflowComponentImpl) CheckUserPermission(ctx context.Context, req types.DataflowLogReq) (bool, error) { + wf, err := c.workflowStore.FindByTaskID(ctx, req.TaskId) + if err != nil { + return false, fmt.Errorf("failed to find dataflow workflow by task_id %s error: %w", req.TaskId, err) + } + + _, err = checkOwnerOrOrgMemberPermission(ctx, c.userSvcClient, req.CurrentUser, wf.UserUUID) + if err != nil { + return false, err + } + + return true, nil +} + +func (c *platformDataflowComponentImpl) ReadJobLogsNonStream(ctx context.Context, req types.DataflowLogReq) (string, error) { + wf, err := c.workflowStore.FindByTaskID(ctx, req.TaskId) + if err != nil { + return "", fmt.Errorf("failed to find dataflow workflow by task_id %s error: %w", req.TaskId, err) + } + + logReq := 
types.WorkflowLogReq{ + Since: req.Since, + SubmitTime: wf.SubmitTime, + } + + labels := map[string]string{ + types.DFArgoTaskIDKey: req.TaskId, + } + if len(req.DagTaskId) > 0 { + labels[types.DFLabelDagTaskIDKey] = req.DagTaskId + } + + lokiResp, err := c.deployer.GetWorkflowLogsNonStream(ctx, logReq, labels) + if err != nil { + return "", fmt.Errorf("failed to read dataflow job logs, error:%w", err) + } + + return c.formatLogs(lokiResp), nil +} + +func (c *platformDataflowComponentImpl) ReadJobLogsInStream(ctx context.Context, req types.DataflowLogReq) (*deploy.MultiLogReader, error) { + wf, err := c.workflowStore.FindByTaskID(ctx, req.TaskId) + if err != nil { + return nil, fmt.Errorf("fail to find dataflow workflow by task_id %s error: %w", req.TaskId, err) + } + + logReq := types.WorkflowLogReq{ + CurrentUser: req.CurrentUser, + Since: req.Since, + PodName: req.TaskId, + SubmitTime: wf.SubmitTime, + } + + labels := map[string]string{ + types.DFArgoTaskIDKey: req.TaskId, + } + if len(req.DagTaskId) > 0 { + labels[types.DFLabelDagTaskIDKey] = req.DagTaskId + } + + return c.deployer.GetWorkflowLogsInStream(ctx, logReq, labels) +} + +func (c *platformDataflowComponentImpl) formatLogs(lokiLog *loki.LokiQueryResponse) string { + var bulkLog strings.Builder + for _, item := range lokiLog.Data.Result { + for _, valuePair := range item.Values { + for _, log := range strings.Split(valuePair[1], "\n") { + if log == "" { + continue + } + bulkLog.WriteString(log) + bulkLog.WriteString(c.config.LogCollector.LineSeparator) + } + } + } + return strings.TrimSuffix(bulkLog.String(), c.config.LogCollector.LineSeparator) +} diff --git a/component/platform_dataflow_test.go b/component/platform_dataflow_test.go new file mode 100644 index 000000000..176b028e6 --- /dev/null +++ b/component/platform_dataflow_test.go @@ -0,0 +1,883 @@ +package component + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + 
"github.com/bwmarrin/snowflake" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + mockdeploy "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy" + mockrpc "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/rpc" + mockdb "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/database" + mockcomp "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/builder/git/membership" + "opencsg.com/csghub-server/builder/rpc" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +type testPlatformDataflowComponent struct { + *platformDataflowComponentImpl + mocks struct { + deployer *mockdeploy.MockDeployer + workflowStore *mockdb.MockArgoWorkFlowStore + userSvcClient *mockrpc.MockUserSvcClient + clusterStore *mockdb.MockClusterInfoStore + spaceResourceStore *mockdb.MockSpaceResourceStore + repoComponent *mockcomp.MockRepoComponent + } +} + +func newTestPlatformDataflowComponent(t *testing.T) *testPlatformDataflowComponent { + node, err := snowflake.NewNode(1) + require.NoError(t, err) + + c := &testPlatformDataflowComponent{ + platformDataflowComponentImpl: &platformDataflowComponentImpl{ + snowflakeNode: node, + config: &config.Config{}, + }, + } + + c.mocks.deployer = mockdeploy.NewMockDeployer(t) + c.deployer = c.mocks.deployer + + c.mocks.workflowStore = mockdb.NewMockArgoWorkFlowStore(t) + c.workflowStore = c.mocks.workflowStore + + c.mocks.userSvcClient = mockrpc.NewMockUserSvcClient(t) + c.userSvcClient = c.mocks.userSvcClient + + c.mocks.clusterStore = mockdb.NewMockClusterInfoStore(t) + c.clusterStore = c.mocks.clusterStore + + c.mocks.spaceResourceStore = mockdb.NewMockSpaceResourceStore(t) + c.spaceResourceStore = c.mocks.spaceResourceStore + + c.mocks.repoComponent = mockcomp.NewMockRepoComponent(t) + c.repoComponent = c.mocks.repoComponent + + 
return c +} + +func TestPlatformDataflowComponent_CreateJob(t *testing.T) { + ctx := context.TODO() + now := time.Now() + + t.Run("success", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowArgoJobReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + RepoIds: []string{"repo1", "repo2"}, + ResourceId: 100, + JobName: "test-job", + JobDesc: "test description", + Template: types.ArgoFlowTemplate{ + Image: "test-image:latest", + }, + } + + user := &types.User{ + UUID: "user-uuid-1", + Username: "testuser", + } + ns := &rpc.Namespace{ + Path: "testuser", + UUID: "user-uuid-1", + NSType: string(database.UserNamespace), + } + + resource := &database.SpaceResource{ + ID: 100, + Name: "test-resource", + ClusterID: "cluster-1", + Resources: `{"cpu": {"num": "2"}, "memory": "4Gi"}`, + } + + clusterNodes := []database.ClusterNode{ + { + Name: "node-1", + }, + } + + createdWF := &database.ArgoWorkflow{ + ID: 1, + TaskId: "df-abc123", + TaskName: req.JobName, + UserUUID: req.NSUUID, + Username: req.Username, + Status: v1alpha1.WorkflowPending, + SubmitTime: now, + } + + deployResp := &types.DataflowArgoJobResp{ + ArgoTaskID: "df-abc123", + JobID: "df-abc123", + JobName: req.JobName, + Status: "Pending", + } + + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, req.NSUUID).Return(ns, nil) + c.mocks.spaceResourceStore.EXPECT().FindByID(ctx, req.ResourceId).Return(resource, nil) + c.mocks.repoComponent.EXPECT().CheckAccountAndResource(ctx, ns.Path, resource.ClusterID, int64(0), resource).Return(&types.CheckExclusiveResp{}, nil) + c.mocks.clusterStore.EXPECT().FindNodeByClusterID(ctx, resource.ClusterID).Return(clusterNodes, nil) + c.mocks.workflowStore.EXPECT().CreateWorkFlow(ctx, mock.MatchedBy(func(wf database.ArgoWorkflow) bool { + return wf.TaskName == req.JobName && + wf.UserUUID == req.NSUUID && + wf.Username == req.Username + 
})).Return(createdWF, nil) + c.mocks.deployer.EXPECT().CreateDataflowJob(ctx, mock.MatchedBy(func(r *types.DataflowArgoJobReq) bool { + return r.ID == createdWF.ID && r.ArgoTaskID == createdWF.TaskId + })).Return(deployResp, nil) + + resp, err := c.CreateJob(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, createdWF.ID, resp.ID) + require.Equal(t, createdWF.TaskId, resp.ArgoTaskID) + }) + + t.Run("user not found", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowArgoJobReq{ + Username: "nonexistent", + NSUUID: "user-uuid-1", + } + + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(nil, errors.New("user not found")) + + resp, err := c.CreateJob(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "current user not found") + }) + + t.Run("namespace not found", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowArgoJobReq{ + Username: "testuser", + NSUUID: "nonexistent-uuid", + } + + user := &types.User{ + UUID: "user-uuid-1", + Username: "testuser", + } + + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, req.NSUUID).Return(nil, errors.New("namespace not found")) + + resp, err := c.CreateJob(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "target namespace not found") + }) + + t.Run("resource not found", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowArgoJobReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + ResourceId: 999, + } + + user := &types.User{ + UUID: "user-uuid-1", + Username: "testuser", + } + ns := &rpc.Namespace{ + Path: "testuser", + UUID: "user-uuid-1", + NSType: string(database.UserNamespace), + } + + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(user, nil) + 
c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, req.NSUUID).Return(ns, nil) + c.mocks.spaceResourceStore.EXPECT().FindByID(ctx, req.ResourceId).Return(nil, errors.New("resource not found")) + + resp, err := c.CreateJob(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "cannot find resource") + }) + + t.Run("check account and resource failed", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowArgoJobReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + ResourceId: 100, + } + + user := &types.User{ + UUID: "user-uuid-1", + Username: "testuser", + } + ns := &rpc.Namespace{ + Path: "testuser", + UUID: "user-uuid-1", + NSType: string(database.UserNamespace), + } + + resource := &database.SpaceResource{ + ID: 100, + ClusterID: "cluster-1", + Resources: `{"cpu": {"num": "2"}}`, + } + + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, req.NSUUID).Return(ns, nil) + c.mocks.spaceResourceStore.EXPECT().FindByID(ctx, req.ResourceId).Return(resource, nil) + c.mocks.repoComponent.EXPECT().CheckAccountAndResource(ctx, ns.Path, resource.ClusterID, int64(0), resource).Return(nil, errors.New("resource unavailable")) + + resp, err := c.CreateJob(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "failed to check account and resource") + }) + + t.Run("cluster nodes not found", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowArgoJobReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + ResourceId: 100, + } + + user := &types.User{ + UUID: "user-uuid-1", + Username: "testuser", + } + ns := &rpc.Namespace{ + Path: "testuser", + UUID: "user-uuid-1", + NSType: string(database.UserNamespace), + } + + resource := &database.SpaceResource{ + ID: 100, + ClusterID: "cluster-1", + Resources: `{"cpu": {"num": "2"}}`, + } + + 
c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, req.NSUUID).Return(ns, nil) + c.mocks.spaceResourceStore.EXPECT().FindByID(ctx, req.ResourceId).Return(resource, nil) + c.mocks.repoComponent.EXPECT().CheckAccountAndResource(ctx, ns.Path, resource.ClusterID, int64(0), resource).Return(&types.CheckExclusiveResp{}, nil) + c.mocks.clusterStore.EXPECT().FindNodeByClusterID(ctx, resource.ClusterID).Return(nil, errors.New("cluster not found")) + + resp, err := c.CreateJob(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "failed to find nodes by clusterID") + }) + + t.Run("create workflow failed", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowArgoJobReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + ResourceId: 100, + JobName: "test-job", + } + + user := &types.User{ + UUID: "user-uuid-1", + Username: "testuser", + } + ns := &rpc.Namespace{ + Path: "testuser", + UUID: "user-uuid-1", + NSType: string(database.UserNamespace), + } + + resource := &database.SpaceResource{ + ID: 100, + ClusterID: "cluster-1", + Resources: `{"cpu": {"num": "2"}}`, + } + + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, req.NSUUID).Return(ns, nil) + c.mocks.spaceResourceStore.EXPECT().FindByID(ctx, req.ResourceId).Return(resource, nil) + c.mocks.repoComponent.EXPECT().CheckAccountAndResource(ctx, ns.Path, resource.ClusterID, int64(0), resource).Return(&types.CheckExclusiveResp{}, nil) + c.mocks.clusterStore.EXPECT().FindNodeByClusterID(ctx, resource.ClusterID).Return([]database.ClusterNode{}, nil) + c.mocks.workflowStore.EXPECT().CreateWorkFlow(ctx, mock.Anything).Return(nil, errors.New("db error")) + + resp, err := c.CreateJob(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), 
"failed to create ArgoWorkflow record") + }) + + t.Run("deployer create workflow failed", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowArgoJobReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + ResourceId: 100, + JobName: "test-job", + } + + user := &types.User{ + UUID: "user-uuid-1", + Username: "testuser", + } + ns := &rpc.Namespace{ + Path: "testuser", + UUID: "user-uuid-1", + NSType: string(database.UserNamespace), + } + + resource := &database.SpaceResource{ + ID: 100, + ClusterID: "cluster-1", + Resources: `{"cpu": {"num": "2"}}`, + } + + createdWF := &database.ArgoWorkflow{ + ID: 1, + TaskId: "df-abc123", + TaskName: req.JobName, + } + + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, req.NSUUID).Return(ns, nil) + c.mocks.spaceResourceStore.EXPECT().FindByID(ctx, req.ResourceId).Return(resource, nil) + c.mocks.repoComponent.EXPECT().CheckAccountAndResource(ctx, ns.Path, resource.ClusterID, int64(0), resource).Return(&types.CheckExclusiveResp{}, nil) + c.mocks.clusterStore.EXPECT().FindNodeByClusterID(ctx, resource.ClusterID).Return([]database.ClusterNode{}, nil) + c.mocks.workflowStore.EXPECT().CreateWorkFlow(ctx, mock.Anything).Return(createdWF, nil) + c.mocks.deployer.EXPECT().CreateDataflowJob(ctx, mock.Anything).Return(nil, errors.New("deployer error")) + c.mocks.workflowStore.EXPECT().DeleteWorkFlow(ctx, createdWF.ID).Return(nil) + + resp, err := c.CreateJob(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "failed to create dataflow workflow") + }) + + t.Run("org member has permission", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowArgoJobReq{ + Username: "orgmember", + NSUUID: "org-uuid-1", + ResourceId: 100, + JobName: "test-job", + } + + user := &types.User{ + UUID: "user-uuid-member", + Username: "orgmember", + } + ns := 
&rpc.Namespace{ + Path: "testorg", + UUID: "org-uuid-1", + NSType: string(database.OrgNamespace), + } + + resource := &database.SpaceResource{ + ID: 100, + Name: "test-resource", + ClusterID: "cluster-1", + Resources: `{"cpu": {"num": "2"}}`, + } + + createdWF := &database.ArgoWorkflow{ + ID: 1, + TaskId: "df-abc123", + TaskName: req.JobName, + } + + deployResp := &types.DataflowArgoJobResp{ + ArgoTaskID: "df-abc123", + } + + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, req.NSUUID).Return(ns, nil) + c.mocks.userSvcClient.EXPECT().GetMemberRoleByUUID(ctx, ns.UUID, req.Username).Return(membership.RoleAdmin, nil) + c.mocks.spaceResourceStore.EXPECT().FindByID(ctx, req.ResourceId).Return(resource, nil) + c.mocks.repoComponent.EXPECT().CheckAccountAndResource(ctx, ns.Path, resource.ClusterID, int64(0), resource).Return(&types.CheckExclusiveResp{}, nil) + c.mocks.clusterStore.EXPECT().FindNodeByClusterID(ctx, resource.ClusterID).Return([]database.ClusterNode{}, nil) + c.mocks.workflowStore.EXPECT().CreateWorkFlow(ctx, mock.Anything).Return(createdWF, nil) + c.mocks.deployer.EXPECT().CreateDataflowJob(ctx, mock.Anything).Return(deployResp, nil) + + resp, err := c.CreateJob(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + }) +} + +func TestPlatformDataflowComponent_DeleteJob(t *testing.T) { + ctx := context.TODO() + + t.Run("success", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowDeleteReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + ArgoTaskID: "df-abc123", + } + + wf := &database.ArgoWorkflow{ + ID: 1, + TaskId: req.ArgoTaskID, + UserUUID: req.NSUUID, + ClusterID: "cluster-1", + } + + user := &types.User{ + UUID: "user-uuid-1", + Username: "testuser", + } + ns := &rpc.Namespace{ + Path: "testuser", + UUID: "user-uuid-1", + NSType: string(database.UserNamespace), + } + + 
c.mocks.workflowStore.EXPECT().FindByTaskID(ctx, req.ArgoTaskID).Return(wf, nil) + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, req.NSUUID).Return(ns, nil) + c.mocks.deployer.EXPECT().DeleteDataflowJob(ctx, mock.MatchedBy(func(r *types.DataflowArgoReq) bool { + return r.ArgoTaskID == wf.TaskId && r.ClusterID == wf.ClusterID + })).Return(nil) + c.mocks.workflowStore.EXPECT().DeleteWorkFlow(ctx, wf.ID).Return(nil) + + err := c.DeleteJob(ctx, req) + require.NoError(t, err) + }) + + t.Run("workflow not found", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowDeleteReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + ArgoTaskID: "nonexistent", + } + + c.mocks.workflowStore.EXPECT().FindByTaskID(ctx, req.ArgoTaskID).Return(nil, errors.New("not found")) + + err := c.DeleteJob(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to find dataflow workflow") + }) + + t.Run("permission denied - different user", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowDeleteReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + ArgoTaskID: "df-abc123", + } + + wf := &database.ArgoWorkflow{ + ID: 1, + TaskId: req.ArgoTaskID, + UserUUID: "different-user-uuid", + ClusterID: "cluster-1", + } + + c.mocks.workflowStore.EXPECT().FindByTaskID(ctx, req.ArgoTaskID).Return(wf, nil) + + err := c.DeleteJob(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "do not have permission") + }) + + t.Run("deployer delete failed", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowDeleteReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + ArgoTaskID: "df-abc123", + } + + wf := &database.ArgoWorkflow{ + ID: 1, + TaskId: req.ArgoTaskID, + UserUUID: req.NSUUID, + ClusterID: "cluster-1", + } + + user := &types.User{ + UUID: "user-uuid-1", + 
Username: "testuser", + } + ns := &rpc.Namespace{ + Path: "testuser", + UUID: "user-uuid-1", + NSType: string(database.UserNamespace), + } + + c.mocks.workflowStore.EXPECT().FindByTaskID(ctx, req.ArgoTaskID).Return(wf, nil) + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, req.NSUUID).Return(ns, nil) + c.mocks.deployer.EXPECT().DeleteDataflowJob(ctx, mock.Anything).Return(errors.New("deployer error")) + + err := c.DeleteJob(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to delete dataflow workflow") + }) + + t.Run("delete workflow record failed", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowDeleteReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + ArgoTaskID: "df-abc123", + } + + wf := &database.ArgoWorkflow{ + ID: 1, + TaskId: req.ArgoTaskID, + UserUUID: req.NSUUID, + ClusterID: "cluster-1", + } + + user := &types.User{ + UUID: "user-uuid-1", + Username: "testuser", + } + ns := &rpc.Namespace{ + Path: "testuser", + UUID: "user-uuid-1", + NSType: string(database.UserNamespace), + } + + c.mocks.workflowStore.EXPECT().FindByTaskID(ctx, req.ArgoTaskID).Return(wf, nil) + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, req.NSUUID).Return(ns, nil) + c.mocks.deployer.EXPECT().DeleteDataflowJob(ctx, mock.Anything).Return(nil) + c.mocks.workflowStore.EXPECT().DeleteWorkFlow(ctx, wf.ID).Return(errors.New("db error")) + + err := c.DeleteJob(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to delete dataflow workflow record") + }) + + t.Run("user permission check failed", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowDeleteReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + ArgoTaskID: "df-abc123", + } + + wf := &database.ArgoWorkflow{ + 
ID: 1, + TaskId: req.ArgoTaskID, + UserUUID: req.NSUUID, + ClusterID: "cluster-1", + } + + c.mocks.workflowStore.EXPECT().FindByTaskID(ctx, req.ArgoTaskID).Return(wf, nil) + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(nil, errors.New("user not found")) + + err := c.DeleteJob(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "current user not found") + }) +} + +func TestPlatformDataflowComponent_GetJob(t *testing.T) { + ctx := context.TODO() + + t.Run("success", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowArgoJobReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + ArgoTaskID: "df-abc123", + } + + wf := &database.ArgoWorkflow{ + ID: 1, + TaskId: req.ArgoTaskID, + TaskName: "test-job", + UserUUID: req.NSUUID, + Username: req.Username, + Status: v1alpha1.WorkflowRunning, + Reason: "", + SubmitTime: time.Now(), + DagTasks: `{"task1": {"status": "running"}}`, + } + + user := &types.User{ + UUID: "user-uuid-1", + Username: "testuser", + } + ns := &rpc.Namespace{ + Path: "testuser", + UUID: "user-uuid-1", + NSType: string(database.UserNamespace), + } + + c.mocks.workflowStore.EXPECT().FindByTaskID(ctx, req.ArgoTaskID).Return(wf, nil) + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, req.NSUUID).Return(ns, nil) + + resp, err := c.GetJob(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, wf.ID, resp.ID) + require.Equal(t, wf.TaskId, resp.ArgoTaskID) + require.Equal(t, wf.TaskName, resp.JobName) + require.Equal(t, string(wf.Status), resp.Status) + }) + + t.Run("workflow not found", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowArgoJobReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + ArgoTaskID: "nonexistent", + } + + c.mocks.workflowStore.EXPECT().FindByTaskID(ctx, req.ArgoTaskID).Return(nil, errors.New("not 
found")) + + resp, err := c.GetJob(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "failed to find dataflow workflow") + }) + + t.Run("permission check failed", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowArgoJobReq{ + Username: "testuser", + NSUUID: "user-uuid-1", + ArgoTaskID: "df-abc123", + } + + wf := &database.ArgoWorkflow{ + ID: 1, + TaskId: req.ArgoTaskID, + UserUUID: req.NSUUID, + } + + c.mocks.workflowStore.EXPECT().FindByTaskID(ctx, req.ArgoTaskID).Return(wf, nil) + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(nil, errors.New("user not found")) + + resp, err := c.GetJob(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "current user not found") + }) + + t.Run("admin user has permission", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := &types.DataflowArgoJobReq{ + Username: "admin", + NSUUID: "user-uuid-1", + ArgoTaskID: "df-abc123", + } + + wf := &database.ArgoWorkflow{ + ID: 1, + TaskId: req.ArgoTaskID, + TaskName: "test-job", + UserUUID: "different-user-uuid", + Username: "differentuser", + Status: v1alpha1.WorkflowSucceeded, + SubmitTime: time.Now(), + } + + user := &types.User{ + UUID: "admin-uuid", + Username: "admin", + Roles: []string{"admin"}, + } + ns := &rpc.Namespace{ + Path: "admin", + UUID: "admin-uuid", + NSType: string(database.UserNamespace), + } + + c.mocks.workflowStore.EXPECT().FindByTaskID(ctx, req.ArgoTaskID).Return(wf, nil) + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, req.NSUUID).Return(ns, nil) + + resp, err := c.GetJob(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, wf.ID, resp.ID) + }) + + t.Run("org member has permission", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + req := 
&types.DataflowArgoJobReq{ + Username: "orgmember", + NSUUID: "org-uuid-1", + ArgoTaskID: "df-abc123", + } + + wf := &database.ArgoWorkflow{ + ID: 1, + TaskId: req.ArgoTaskID, + TaskName: "test-job", + UserUUID: req.NSUUID, + Status: v1alpha1.WorkflowRunning, + SubmitTime: time.Now(), + } + + user := &types.User{ + UUID: "member-uuid", + Username: "orgmember", + } + ns := &rpc.Namespace{ + Path: "testorg", + UUID: "org-uuid-1", + NSType: string(database.OrgNamespace), + } + + c.mocks.workflowStore.EXPECT().FindByTaskID(ctx, req.ArgoTaskID).Return(wf, nil) + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, req.Username).Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, req.NSUUID).Return(ns, nil) + c.mocks.userSvcClient.EXPECT().GetMemberRoleByUUID(ctx, ns.UUID, req.Username).Return(membership.RoleWrite, nil) + + resp, err := c.GetJob(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + }) +} + +func TestCheckUserOrOrgPermission(t *testing.T) { + ctx := context.TODO() + + t.Run("user is admin", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + user := &types.User{ + UUID: "admin-uuid", + Username: "admin", + Roles: []string{"admin"}, + } + ns := &rpc.Namespace{ + Path: "targetuser", + UUID: "target-uuid", + NSType: string(database.UserNamespace), + } + + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, "admin").Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, "target-uuid").Return(ns, nil) + + result, err := checkOwnerOrOrgMemberPermission(ctx, c.userSvcClient, "admin", "target-uuid") + require.NoError(t, err) + require.NotNil(t, result) + }) + + t.Run("user accessing own namespace", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + user := &types.User{ + UUID: "user-uuid-1", + Username: "testuser", + } + ns := &rpc.Namespace{ + Path: "testuser", + UUID: "user-uuid-1", + NSType: string(database.UserNamespace), + } + + 
c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, "testuser").Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, "user-uuid-1").Return(ns, nil) + + result, err := checkOwnerOrOrgMemberPermission(ctx, c.userSvcClient, "testuser", "user-uuid-1") + require.NoError(t, err) + require.NotNil(t, result) + }) + + t.Run("org member accessing org namespace", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + user := &types.User{ + UUID: "member-uuid", + Username: "orgmember", + } + ns := &rpc.Namespace{ + Path: "testorg", + UUID: "org-uuid-1", + NSType: string(database.OrgNamespace), + } + + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, "orgmember").Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, "org-uuid-1").Return(ns, nil) + c.mocks.userSvcClient.EXPECT().GetMemberRoleByUUID(ctx, "org-uuid-1", "orgmember").Return("member", nil) + + result, err := checkOwnerOrOrgMemberPermission(ctx, c.userSvcClient, "orgmember", "org-uuid-1") + require.NoError(t, err) + require.NotNil(t, result) + }) + + t.Run("user accessing other user namespace denied", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + user := &types.User{ + UUID: "user-uuid-1", + Username: "testuser", + } + ns := &rpc.Namespace{ + Path: "otheruser", + UUID: "other-uuid", + NSType: string(database.UserNamespace), + } + + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, "testuser").Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, "other-uuid").Return(ns, nil) + + result, err := checkOwnerOrOrgMemberPermission(ctx, c.userSvcClient, "testuser", "other-uuid") + require.Error(t, err) + require.Contains(t, err.Error(), "do not have permission") + require.NotNil(t, result) + }) + + t.Run("non-member accessing org namespace denied", func(t *testing.T) { + c := newTestPlatformDataflowComponent(t) + + user := &types.User{ + UUID: "user-uuid-1", + Username: "testuser", + } + ns := &rpc.Namespace{ + Path: 
"testorg", + UUID: "org-uuid-1", + NSType: string(database.OrgNamespace), + } + + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, "testuser").Return(user, nil) + c.mocks.userSvcClient.EXPECT().GetNameSpaceInfoByUUID(ctx, "org-uuid-1").Return(ns, nil) + c.mocks.userSvcClient.EXPECT().GetMemberRoleByUUID(ctx, "org-uuid-1", "testuser").Return(membership.RoleUnknown, nil) + + result, err := checkOwnerOrOrgMemberPermission(ctx, c.userSvcClient, "testuser", "org-uuid-1") + require.Error(t, err) + require.Contains(t, err.Error(), "do not have permission") + require.NotNil(t, result) + }) +} diff --git a/component/webhook.go b/component/webhook.go index d99c8be95..29f15a904 100644 --- a/component/webhook.go +++ b/component/webhook.go @@ -57,6 +57,24 @@ func NewWebHookComponent(config *config.Config, mqFactory bldmq.MessageQueueFact return nil, fmt.Errorf("failed to create kservice executor error: %w", err) } + // init dataflow executor + _, err = executors.NewDataflowExecutor(config) + if err != nil { + return nil, fmt.Errorf("failed to create dataflow executor error: %w", err) + } + + // init dataflow pod executor + _, err = executors.NewDataflowPodExecutor(config) + if err != nil { + return nil, fmt.Errorf("failed to create dataflow pod executor error: %w", err) + } + + // // init sandbox executor + // _, err = executors.NewSandboxExecutor(config) + // if err != nil { + // return nil, fmt.Errorf("failed to create sandbox executor error: %w", err) + // } + mq, err := mqFactory.GetInstance() if err != nil { return nil, fmt.Errorf("failed to get mq instance error: %w", err) diff --git a/runner/component/dataflow.go b/runner/component/dataflow.go new file mode 100644 index 000000000..744c4d10a --- /dev/null +++ b/runner/component/dataflow.go @@ -0,0 +1,716 @@ +package component + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "path" + "strconv" + "sync" + "time" + + 
"github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + "github.com/argoproj/argo-workflows/v3/pkg/client/informers/externalversions" + internalinterfaces "github.com/argoproj/argo-workflows/v3/pkg/client/informers/externalversions/internalinterfaces" + corev1 "k8s.io/api/core/v1" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/resource" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/client-go/informers" + "k8s.io/client-go/tools/cache" + "k8s.io/utils/ptr" + "opencsg.com/csghub-server/builder/deploy/cluster" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" + "opencsg.com/csghub-server/runner/common" + sched "opencsg.com/csghub-server/runner/component/kube_scheduler" + rtypes "opencsg.com/csghub-server/runner/types" + "opencsg.com/csghub-server/runner/utils" +) + +type DataflowComponent interface { + CreateWorkflow(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) + GetStatus(ctx context.Context, req *types.DataflowArgoReq) (*types.DataflowArgoJobResp, error) + DeleteWorkflow(ctx context.Context, req *types.DataflowArgoReq) error + RunInformer() +} + +type dataflowComponentImpl struct { + config *config.Config + clusterPool cluster.Pool + namespace string + wfStore database.ArgoWorkFlowStore +} + +func NewDataflowComponent(config *config.Config, clusterPool cluster.Pool) DataflowComponent { + df := &dataflowComponentImpl{ + config: config, + clusterPool: clusterPool, + namespace: config.Cluster.SpaceNamespace, + wfStore: database.NewArgoWorkFlowStore(), + } + go df.RunInformer() + return df +} + +func (d *dataflowComponentImpl) CreateWorkflow(ctx context.Context, req *types.DataflowArgoJobReq) (*types.DataflowArgoJobResp, error) { + cluster, err := d.clusterPool.GetClusterByID(ctx, req.ClusterID) + if err != nil { + return nil, fmt.Errorf("failed 
to get cluster %s for dataflow job %s error: %w", req.ClusterID, req.JobID, err) + } + + dwf, err := d.buildWorkflow(req) + if err != nil { + return nil, fmt.Errorf("failed to build dataflow workflow job %s error: %w", req.JobID, err) + } + + if err := d.createPVC(ctx, cluster, req); err != nil { + return nil, fmt.Errorf("failed to create pvc for dataflow workflow job %s error: %w", req.JobID, err) + } + + dfWorkflow, err := cluster.ArgoClient.ArgoprojV1alpha1().Workflows(d.namespace).Create(ctx, dwf, v1.CreateOptions{}) + if err != nil { + delErr := d.deletePVC(ctx, cluster, &types.DataflowArgoReq{ClusterID: req.ClusterID, ArgoTaskID: req.ArgoTaskID}) + if delErr != nil { + slog.ErrorContext(ctx, "delete pvc due to create dataflow workflow job %s failed error: %w", req.ArgoTaskID, delErr) + } + return nil, fmt.Errorf("failed to create dataflow workflow job %s error: %w", req.JobID, err) + } + slog.InfoContext(ctx, "create dataflow workflow success", + slog.String("namespace", d.namespace), slog.String("name", dfWorkflow.Name)) + + return &types.DataflowArgoJobResp{ + ID: req.ID, + ArgoTaskID: dfWorkflow.Name, + JobID: req.JobID, + JobName: req.JobName, + Status: string(v1alpha1.WorkflowPending), + Message: dfWorkflow.Status.Message, + CreatedAt: dfWorkflow.CreationTimestamp.Unix(), + }, nil +} + +func genPVCName(taskID string) string { + return types.DFPVCNamePrefix + taskID +} + +func (d *dataflowComponentImpl) createPVC(ctx context.Context, cluster *cluster.Cluster, req *types.DataflowArgoJobReq) error { + pvcName := genPVCName(req.ArgoTaskID) + _, err := cluster.Client.CoreV1().PersistentVolumeClaims(d.namespace).Get(ctx, pvcName, v1.GetOptions{}) + if err == nil { + slog.WarnContext(ctx, "pvc already exists", slog.Any("pvcName", pvcName), + slog.Any("argoTaskID", req.ArgoTaskID), slog.Any("jobid", req.JobID)) + return nil + } + + storageSize, err := resource.ParseQuantity(req.StorageSize) + if err != nil { + return fmt.Errorf("failed to parse storage size %s 
for dataflow job %s, taskid %s, error: %w", req.StorageSize, req.JobID, req.ArgoTaskID, err) + } + + pvc := corev1.PersistentVolumeClaim{ + ObjectMeta: v1.ObjectMeta{ + Namespace: d.namespace, + Name: pvcName, + }, + Spec: corev1.PersistentVolumeClaimSpec{ + AccessModes: []corev1.PersistentVolumeAccessMode{ + corev1.ReadWriteMany, + }, + Resources: corev1.VolumeResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceStorage: storageSize, + }, + }, + StorageClassName: &cluster.StorageClass, + }, + } + + _, err = cluster.Client.CoreV1().PersistentVolumeClaims(d.namespace).Create(ctx, &pvc, v1.CreateOptions{}) + if err != nil { + return fmt.Errorf("failed to create pvc %s for dataflow job %s failed: %w", pvcName, req.ArgoTaskID, err) + } + + return nil +} + +func (d *dataflowComponentImpl) deletePVC(ctx context.Context, cluster *cluster.Cluster, req *types.DataflowArgoReq) error { + pvcName := genPVCName(req.ArgoTaskID) + err := cluster.Client.CoreV1().PersistentVolumeClaims(d.namespace).Delete(ctx, pvcName, v1.DeleteOptions{}) + return err +} + +func (d *dataflowComponentImpl) buildWorkflow(req *types.DataflowArgoJobReq) (*v1alpha1.Workflow, error) { + applier := sched.NewApplier(req.Scheduler) + deployExt := types.DeployExtend{ + NodeAffinity: req.NodeAffinity, + Tolerations: req.Tolerations, + } + genRes := common.GenerateResources(rtypes.ResourceGeneratorParams{ + Hardware: req.Template.HardWare, + Nodes: req.Nodes, + DeployExt: deployExt, + Config: d.config, + }) + resReq, nodeAffinity := genRes.ResourceRequirements, genRes.NodeAffinity + resources := corev1.ResourceRequirements{ + Limits: resReq, + Requests: resReq, + } + annotations := map[string]string{ + types.DFUniqueIDKey: fmt.Sprintf("%d", req.ID), + types.DFJobIDKey: req.JobID, + types.DFJobNameKey: req.JobName, + types.DFArgoTaskIDKey: req.ArgoTaskID, + types.DFOpUserUUIDKey: req.OpUserUUID, + types.DFOpUserNameKey: req.Username, + types.DFNSUUIDKey: req.NSUUID, + types.DFClusterIDKey: 
req.ClusterID, + types.DFResourceIDKey: fmt.Sprintf("%d", req.ResourceId), + types.DFResourceNameKey: req.ResourceName, + types.DFJobDescKey: req.JobDesc, + types.DFImageKey: req.Template.Image, + types.DFStorageSizeKey: req.StorageSize, + } + + templates := []v1alpha1.Template{} + volumeName := "workflow-data" + + volumes := []corev1.Volume{ + { + Name: volumeName, + VolumeSource: corev1.VolumeSource{ + PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ + ClaimName: genPVCName(req.ArgoTaskID), + ReadOnly: false, + }, + }, + }, + } + + runtimeTemp := d.buildRuntimeTemplate(volumeName, annotations, resources, req) + + // merge node affinity + utils.FillAffinity(&runtimeTemp.Affinity, nodeAffinity) + // fill tolerations + if len(genRes.Tolerations) > 0 { + runtimeTemp.Tolerations = genRes.Tolerations + } + if err := applier.ApplyToArgo(runtimeTemp); err != nil { + return nil, fmt.Errorf("failed to apply scheduler to dataflow runtime template: %v", err) + } + + dagTemp := d.buildDAGTemplate(req) + + templates = append(templates, *runtimeTemp) + templates = append(templates, *dagTemp) + + dataflowObject := &v1alpha1.Workflow{ + ObjectMeta: v1.ObjectMeta{ + Namespace: d.namespace, + Name: req.ArgoTaskID, + Annotations: annotations, + Labels: map[string]string{ + types.DFUniqueIDKey: fmt.Sprintf("%d", req.ID), + types.DFLabelTagKey: types.DFLabelTagValue, + types.DFJobIDKey: req.JobID, + types.DFArgoTaskIDKey: req.ArgoTaskID, + }, + }, + Spec: v1alpha1.WorkflowSpec{ + ServiceAccountName: d.config.Argo.ServiceAccountName, + Entrypoint: req.Entrypoint, + Volumes: volumes, + Templates: templates, + TTLStrategy: &v1alpha1.TTLStrategy{ + // Set TTL here + SecondsAfterCompletion: ptr.To(int32(d.config.Argo.JobTTL)), + }, + }, + } + + return dataflowObject, nil +} + +func (d *dataflowComponentImpl) buildRuntimeTemplate( + volumeName string, + annotations map[string]string, + resources corev1.ResourceRequirements, + req *types.DataflowArgoJobReq) 
*v1alpha1.Template { + containerImg := path.Join(d.config.Model.DockerRegBase, req.Template.Image) + + params := []v1alpha1.Parameter{} + for _, param := range req.Template.Parameters { + params = append(params, v1alpha1.Parameter{ + Name: param, + }) + } + params = append(params, + v1alpha1.Parameter{ + Name: types.DFParamDagTaskIDKey, + }, v1alpha1.Parameter{ + Name: types.DFParamDagTaskNameKey, + }, + ) + + runtimeTemp := &v1alpha1.Template{ + Name: req.Template.Name, // "echo" + Inputs: v1alpha1.Inputs{ + Parameters: params, // []v1alpha1.Parameter{ { Name: "cmd" }, { Name: "task_id" } }, + }, + Metadata: v1alpha1.Metadata{ + Annotations: annotations, + Labels: map[string]string{ + types.DFArgoTaskIDKey: req.ArgoTaskID, + types.DFUniqueIDKey: fmt.Sprintf("%d", req.ID), + types.DFLabelTagKey: types.DFLabelTagValue, + types.DFJobIDKey: req.JobID, + types.DFLabelDagTaskIDKey: fmt.Sprintf("{{inputs.parameters.%s}}", types.DFParamDagTaskIDKey), + types.DFLabelDagTaskNameKey: fmt.Sprintf("{{inputs.parameters.%s}}", types.DFParamDagTaskNameKey), + types.StreamKeyDeployID: req.ArgoTaskID, + }, + }, + Container: &corev1.Container{ + // example: "opencsg-registry.cn-beijing.cr.aliyuncs.com/opencsg_public/alpine:latest", + Image: containerImg, + Command: req.Template.Command, // []string{"sh", "-c"}, + Args: req.Template.Args, // []string{"{{inputs.parameters.cmd}}"}, + VolumeMounts: []corev1.VolumeMount{ + {Name: volumeName, MountPath: "/data/dataflow_data"}, + }, + Resources: resources, + ImagePullPolicy: corev1.PullAlways, + }, + } + + return runtimeTemp +} + +func (d *dataflowComponentImpl) buildDAGTemplate(req *types.DataflowArgoJobReq) *v1alpha1.Template { + tasks := []v1alpha1.DAGTask{} + for _, task := range req.DagTasks { + taskParams := []v1alpha1.Parameter{} + taskParams = append(taskParams, + v1alpha1.Parameter{ + Name: types.DFParamDagTaskIDKey, + Value: v1alpha1.AnyStringPtr(task.ID), + }, + v1alpha1.Parameter{ + Name: types.DFParamDagTaskNameKey, + Value: 
v1alpha1.AnyStringPtr(task.Name), + }, + ) + for _, param := range task.Parameters { + taskParams = append(taskParams, v1alpha1.Parameter{ + Name: param.Name, + Value: v1alpha1.AnyStringPtr(param.Value), + }) + } + tasks = append(tasks, v1alpha1.DAGTask{ + Name: task.Name, + Template: task.Template, + Dependencies: task.Deps, + Arguments: v1alpha1.Arguments{ + Parameters: taskParams, + }, + }) + } + + dagTemp := &v1alpha1.Template{ + Name: req.Entrypoint, + DAG: &v1alpha1.DAGTemplate{ + Tasks: tasks, + }, + } + + return dagTemp +} + +func (d *dataflowComponentImpl) GetStatus(ctx context.Context, req *types.DataflowArgoReq) (*types.DataflowArgoJobResp, error) { + cluster, err := d.clusterPool.GetClusterByID(ctx, req.ClusterID) + if err != nil { + return nil, err + } + + workflow, err := cluster.ArgoClient.ArgoprojV1alpha1().Workflows(d.namespace).Get(ctx, req.ArgoTaskID, v1.GetOptions{}) + if err != nil { + return nil, fmt.Errorf("dataflow %s workflow not found error: %w", req.ArgoTaskID, err) + } + + return &types.DataflowArgoJobResp{ + ArgoTaskID: req.ArgoTaskID, + JobID: workflow.Annotations[types.DFJobIDKey], + JobName: workflow.Annotations[types.DFJobNameKey], + Status: string(workflow.Status.Phase), + Message: workflow.Status.Message, + CreatedAt: workflow.CreationTimestamp.Unix(), + }, nil +} + +func (d *dataflowComponentImpl) DeleteWorkflow(ctx context.Context, req *types.DataflowArgoReq) error { + cluster, err := d.clusterPool.GetClusterByID(ctx, req.ClusterID) + if err != nil { + return fmt.Errorf("failed to get cluster %s error: %w", req.ClusterID, err) + } + + err = cluster.ArgoClient.ArgoprojV1alpha1().Workflows(d.namespace).Delete(ctx, req.ArgoTaskID, v1.DeleteOptions{}) + if err != nil { + k8serr := new(k8serrors.StatusError) + if errors.As(err, &k8serr) { + if k8serr.Status().Code != http.StatusNotFound { + return fmt.Errorf("failed to delete dataflow %s workflow error: %w", req.ArgoTaskID, err) + } else { + slog.WarnContext(ctx, "dataflow %s 
workflow not found for delete", slog.Any("task_id", req.ArgoTaskID)) + } + } else { + return fmt.Errorf("failed to delete dataflow %s workflow error: %w", req.ArgoTaskID, err) + } + } + + err = d.deletePVC(ctx, cluster, req) + if err != nil { + slog.ErrorContext(ctx, "failed to delete dataflow pvc", slog.Any("task_id", req.ArgoTaskID), slog.Any("error", err)) + } + + return nil +} + +// RunInformer starts workflow and pod informers for all clusters +func (d *dataflowComponentImpl) RunInformer() { + ctx := context.Background() + + var wg sync.WaitGroup + stopCh := make(chan struct{}) + defer close(stopCh) + defer runtime.HandleCrash() + + clusters := d.clusterPool.GetAllCluster() + for _, cls := range clusters { + _, err := cls.Client.Discovery().ServerVersion() + if err != nil { + slog.ErrorContext(ctx, "cluster is unavailable for dataflow informer", slog.Any("cluster config", cls.CID), slog.Any("error", err)) + continue + } + + wg.Go(func() { + d.runWorkflowInformer(stopCh, cls) + }) + wg.Go(func() { + d.runPodInformer(stopCh, cls) + }) + } + slog.InfoContext(ctx, "dataflow informer started") + // wait for all informers to start + wg.Wait() +} + +// runWorkflowInformer watches Argo Workflow events +func (d *dataflowComponentImpl) runWorkflowInformer(stopCh <-chan struct{}, cluster *cluster.Cluster) { + labelSelector := fmt.Sprintf("%s=%s", types.DFLabelTagKey, types.DFLabelTagValue) + client := cluster.ArgoClient + + f := externalversions.NewFilteredSharedInformerFactory( + client, + 2*time.Minute, + d.namespace, + internalinterfaces.TweakListOptionsFunc(func(list *v1.ListOptions) { + list.LabelSelector = labelSelector + }), + ) + + informer := f.Argoproj().V1alpha1().Workflows().Informer() + + eventHandler := cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { + wf := obj.(*v1alpha1.Workflow) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := d.handleWorkflowEvent(ctx, wf, 
types.RunnerDataflowChange); err != nil { + slog.ErrorContext(ctx, "failed to handle dataflow workflow create event", + slog.Any("error", err), slog.String("workflow", wf.Name)) + } + }, + UpdateFunc: func(oldObj, newObj interface{}) { + // oldWF := oldObj.(*v1alpha1.Workflow) + newWF := newObj.(*v1alpha1.Workflow) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := d.handleWorkflowEvent(ctx, newWF, types.RunnerDataflowChange); err != nil { + slog.ErrorContext(ctx, "failed to handle dataflow workflow update event", + slog.Any("error", err), slog.String("workflow", newWF.Name)) + } + }, + DeleteFunc: func(obj interface{}) { + wf, ok := obj.(*v1alpha1.Workflow) + if !ok { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := d.handleWorkflowEvent(ctx, wf, types.RunnerDataflowDelete); err != nil { + slog.ErrorContext(ctx, "failed to handle dataflow workflow delete event", + slog.Any("error", err), slog.String("workflow", wf.Name)) + } + + pvcReq := &types.DataflowArgoReq{ + ArgoTaskID: wf.Name, + } + err := d.deletePVC(ctx, cluster, pvcReq) + if err != nil { + slog.ErrorContext(ctx, "failed to delete dataflow pvc due workflow delete informer event", + slog.Any("task_id", pvcReq.ArgoTaskID), slog.Any("error", err)) + } + + }, + } + + _, err := informer.AddEventHandler(eventHandler) + if err != nil { + runtime.HandleError(fmt.Errorf("failed to add event handler for dataflow workflow informer: %w", err)) + return + } + + informer.Run(stopCh) + if !cache.WaitForCacheSync(stopCh, informer.HasSynced) { + runtime.HandleError(fmt.Errorf("timed out waiting for dataflow workflow caches to sync")) + } +} + +// handleWorkflowEvent processes workflow events and reports to csghub +func (d *dataflowComponentImpl) handleWorkflowEvent(ctx context.Context, wf *v1alpha1.Workflow, eventType types.WebHookEventType) error { + annotations := wf.Annotations + if len(annotations) 
< 1 { + return fmt.Errorf("workflow %s/%s has no annotations", wf.Namespace, wf.Name) + } + + resID := annotations[types.DFResourceIDKey] + if len(resID) < 1 { + slog.WarnContext(ctx, "workflow has no resource id", slog.Any("dataflow", wf.Name)) + } + resIDInt, err := strconv.ParseInt(resID, 10, 64) + if err != nil { + slog.WarnContext(ctx, "dataflow workflow has invalid resource id", + slog.Any("wf.name", wf.Name), slog.Any("error", err), slog.String("resource_id", resID)) + } + + wfStatus := v1alpha1.WorkflowPending + if len(wf.Status.Phase) > 0 { + wfStatus = wf.Status.Phase + } + + // Extract info from annotations + wfInfo := &database.ArgoWorkflow{ + Username: annotations[types.DFOpUserNameKey], + UserUUID: annotations[types.DFNSUUIDKey], + TaskName: annotations[types.DFJobNameKey], + TaskId: wf.Name, + TaskType: types.TaskTypeDataflow, + ClusterID: annotations[types.DFClusterIDKey], + Namespace: wf.Namespace, + RepoType: string(types.DatasetRepo), + TaskDesc: annotations[types.DFJobDescKey], + Image: annotations[types.DFImageKey], + ResourceId: resIDInt, + ResourceName: annotations[types.DFResourceNameKey], + Status: wfStatus, + Reason: wf.Status.Message, + QueueName: annotations[rtypes.VolcanoAnnoQueue], + } + if !wf.Status.StartedAt.IsZero() { + wfInfo.StartTime = wf.Status.StartedAt.Time + } + if !wf.Status.FinishedAt.IsZero() { + wfInfo.EndTime = wf.Status.FinishedAt.Time + } + + if len(wfInfo.TaskId) < 1 { + return fmt.Errorf("dataflow workflow %s has no task id", wf.Name) + } + + slog.InfoContext(ctx, "handling dataflow workflow event", + slog.String("event_type", string(eventType)), + slog.Any("wf", wfInfo)) + + // find workflow in database or create it + dbWF, err := d.wfStore.FindByTaskID(ctx, wfInfo.TaskId) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + slog.WarnContext(ctx, "failed to find dataflow workflow in db", + slog.Any("error", err), slog.String("task_id", wfInfo.TaskId)) + } + if errors.Is(err, sql.ErrNoRows) { + dbWF, err = 
d.wfStore.CreateWorkFlow(ctx, *wfInfo) + if err != nil { + slog.ErrorContext(ctx, "dataflow workflow failed to create in db", + slog.Any("error", err), slog.Any("wfInfo", wfInfo)) + } + } + if dbWF == nil { + slog.WarnContext(ctx, "dataflow workflow not found in db", slog.Any("wfInfo", wfInfo)) + } else { + // Update workflow status + dbWF.Status = wfStatus + dbWF.Reason = wf.Status.Message + if !wf.Status.StartedAt.IsZero() { + dbWF.StartTime = wf.Status.StartedAt.Time + } + if !wf.Status.FinishedAt.IsZero() { + dbWF.EndTime = wf.Status.FinishedAt.Time + } + dbWF, err := d.wfStore.UpdateWorkFlow(ctx, *dbWF) + if err != nil { + slog.ErrorContext(ctx, "failed to update dataflow workflow in db", + slog.Any("error", err), slog.Any("dbWF", dbWF)) + } + } + + // Report event to csghub + d.reportDataflowEvent(ctx, wfInfo.ClusterID, wfInfo, eventType) + + return nil +} + +// runPodInformer watches Pod events for dataflow workloads +func (d *dataflowComponentImpl) runPodInformer(stopCh <-chan struct{}, cluster *cluster.Cluster) { + labelSelector := fmt.Sprintf("%s=%s", types.DFLabelTagKey, types.DFLabelTagValue) + + factory := informers.NewSharedInformerFactoryWithOptions( + cluster.Client, + 1*time.Hour, + informers.WithNamespace(d.namespace), + informers.WithTweakListOptions(func(options *v1.ListOptions) { + options.LabelSelector = labelSelector + }), + ) + + podInformer := factory.Core().V1().Pods() + + _, err := podInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { + pod, ok := obj.(*corev1.Pod) + if !ok { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := d.handlePodEvent(ctx, pod, types.RunnerDataflowPodUpdate); err != nil { + slog.ErrorContext(ctx, "failed to handle dataflow pod add event", + slog.Any("error", err), slog.String("pod", pod.Name)) + } + }, + UpdateFunc: func(oldObj, newObj interface{}) { + // oldPod := oldObj.(*corev1.Pod) + newPod, 
ok := newObj.(*corev1.Pod) + if !ok { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := d.handlePodEvent(ctx, newPod, types.RunnerDataflowPodUpdate); err != nil { + slog.ErrorContext(ctx, "failed to handle dataflow pod update event", + slog.Any("error", err), slog.String("pod", newPod.Name)) + } + }, + DeleteFunc: func(obj interface{}) { + pod, ok := obj.(*corev1.Pod) + if !ok { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := d.handlePodEvent(ctx, pod, types.RunnerDataflowPodDelete); err != nil { + slog.ErrorContext(ctx, "failed to handle dataflow pod delete event", + slog.Any("error", err), slog.String("pod", pod.Name)) + } + }, + }) + + if err != nil { + runtime.HandleError(fmt.Errorf("failed to add event handler for dataflow pod informer: %w", err)) + return + } + + factory.Start(stopCh) + + if !cache.WaitForCacheSync(stopCh, podInformer.Informer().HasSynced) { + runtime.HandleError(fmt.Errorf("timed out waiting for dataflow pod caches to sync")) + } +} + +// handlePodEvent processes pod events and reports to csghub +func (d *dataflowComponentImpl) handlePodEvent(ctx context.Context, pod *corev1.Pod, eventType types.WebHookEventType) error { + annotations := pod.Annotations + if len(annotations) < 1 { + return fmt.Errorf("dataflow pod %s/%s has no annotations", pod.Namespace, pod.Name) + } + + // Extract info from annotations + clusterID := annotations[types.DFClusterIDKey] + taskID := annotations[types.DFArgoTaskIDKey] + dagTaskID := pod.Labels[types.DFLabelDagTaskIDKey] + dagTaskName := pod.Labels[types.DFLabelDagTaskNameKey] + if len(dagTaskID) < 1 { + return fmt.Errorf("dataflow pod %s/%s has no dag task id", pod.Namespace, pod.Name) + } + + dagTask := types.DataflowDagTask{ + Name: dagTaskName, + Status: string(pod.Status.Phase), + } + + if !pod.Status.StartTime.IsZero() { + dagTask.StartTime = 
pod.Status.StartTime.Format("2006-01-02 15:04:05.000") + } + + if pod.Status.Phase != corev1.PodPending && pod.Status.Phase != corev1.PodRunning { + dagTask.EndTime = time.Now().Format("2006-01-02 15:04:05.000") + } + + podMap := make(map[string]types.DataflowDagTask) + podMap[dagTaskID] = dagTask + + jsonStr, err := json.Marshal(podMap) + if err != nil { + return fmt.Errorf("failed to marshal dag_tasks pod map: %w", err) + } + + wfInfo := &database.ArgoWorkflow{ + TaskId: taskID, + ClusterNode: pod.Spec.NodeName, + DagTasks: string(jsonStr), + } + + slog.InfoContext(ctx, "handling dataflow pod event", slog.Any("wfInfo", wfInfo)) + + // Report event to csghub + d.reportDataflowEvent(ctx, clusterID, wfInfo, eventType) + return nil +} + +// reportDataflowEvent sends event to csghub API +func (d *dataflowComponentImpl) reportDataflowEvent(ctx context.Context, clusterID string, wf *database.ArgoWorkflow, eventType types.WebHookEventType) { + event := &types.WebHookSendEvent{ + WebHookHeader: types.WebHookHeader{ + EventType: eventType, + EventTime: time.Now().Unix(), + ClusterID: clusterID, + DataType: types.WebHookDataTypeObject, + }, + Data: wf, + } + + slog.InfoContext(ctx, "reporting dataflow event", slog.Any("event", event)) + + go func() { + err := common.Push(d.config.Runner.WebHookEndpoint, d.config.APIToken, event) + if err != nil { + slog.ErrorContext(ctx, "failed to push dataflow workflow event", slog.Any("error", err)) + } + }() +} diff --git a/runner/component/dataflow_test.go b/runner/component/dataflow_test.go new file mode 100644 index 000000000..e98db96f1 --- /dev/null +++ b/runner/component/dataflow_test.go @@ -0,0 +1,363 @@ +package component + +import ( + "context" + "errors" + "testing" + + "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + argofake "github.com/argoproj/argo-workflows/v3/pkg/client/clientset/versioned/fake" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 
"k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" + mockCluster "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/deploy/cluster" + mockdb "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/builder/deploy/cluster" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +func newTestDataflowComponent(t *testing.T) (*dataflowComponentImpl, *mockCluster.MockPool) { + pool := mockCluster.NewMockPool(t) + df := &dataflowComponentImpl{ + config: &config.Config{}, + clusterPool: pool, + namespace: "test-ns", + wfStore: mockdb.NewMockArgoWorkFlowStore(t), + } + return df, pool +} + +func TestDataflowComponent_CreateWorkflow(t *testing.T) { + ctx := context.TODO() + + t.Run("success", func(t *testing.T) { + df, pool := newTestDataflowComponent(t) + kubeClient := fake.NewSimpleClientset() + testCluster := &cluster.Cluster{ + CID: "config", + ID: "test-cluster", + Client: kubeClient, + ArgoClient: argofake.NewSimpleClientset(), + } + pool.EXPECT().GetClusterByID(ctx, "test-cluster").Return(testCluster, nil) + + req := &types.DataflowArgoJobReq{ + ID: 1, + ClusterID: "test-cluster", + ArgoTaskID: "df-test-task", + JobID: "df-test-job", + JobName: "test-job", + JobDesc: "test desc", + StorageSize: "10Gi", + Entrypoint: "main", + Template: types.ArgoFlowTemplate{ + Name: "echo", + Image: "alpine:latest", + Command: []string{"echo"}, + Args: []string{"hello"}, + Parameters: []string{"cmd", "task_id"}, + }, + DagTasks: []types.ArgoDagTask{ + {ID: "task-1", Name: "task1", Template: "echo", Deps: []string{}}, + }, + Nodes: []types.Node{ + {Name: "node-1"}, + }, + } + + resp, err := df.CreateWorkflow(ctx, req) + require.NoError(t, err) + require.Equal(t, req.ID, resp.ID) + require.Equal(t, req.ArgoTaskID, resp.ArgoTaskID) + require.Equal(t, req.JobID, resp.JobID) + require.Equal(t, req.JobName, resp.JobName) + require.Equal(t, "Pending", 
resp.Status) + + pvcName := types.DFPVCNamePrefix + req.ArgoTaskID + pvc, err := kubeClient.CoreV1().PersistentVolumeClaims(df.namespace).Get(ctx, pvcName, metav1.GetOptions{}) + require.NoError(t, err) + require.Equal(t, pvcName, pvc.Name) + + wf, err := testCluster.ArgoClient.ArgoprojV1alpha1().Workflows(df.namespace).Get(ctx, req.ArgoTaskID, metav1.GetOptions{}) + require.NoError(t, err) + require.Equal(t, req.ArgoTaskID, wf.Name) + }) + + t.Run("cluster not found", func(t *testing.T) { + df, pool := newTestDataflowComponent(t) + pool.EXPECT().GetClusterByID(ctx, "unknown-cluster").Return(nil, errors.New("cluster not found")) + + req := &types.DataflowArgoJobReq{ + ClusterID: "unknown-cluster", + } + + resp, err := df.CreateWorkflow(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "failed to get cluster") + }) + + t.Run("argo create fails and cleans up pvc", func(t *testing.T) { + df, pool := newTestDataflowComponent(t) + kubeClient := fake.NewSimpleClientset() + argoClient := argofake.NewSimpleClientset() + testCluster := &cluster.Cluster{ + CID: "config", + ID: "test-cluster", + Client: kubeClient, + ArgoClient: argoClient, + } + pool.EXPECT().GetClusterByID(ctx, "test-cluster").Return(testCluster, nil) + + req := &types.DataflowArgoJobReq{ + ID: 1, + ClusterID: "test-cluster", + ArgoTaskID: "df-test-fail", + JobID: "df-test-job", + JobName: "test-job", + JobDesc: "test desc", + StorageSize: "10Gi", + Entrypoint: "main", + Template: types.ArgoFlowTemplate{ + Name: "echo", + Image: "alpine:latest", + Command: []string{"echo"}, + Args: []string{"hello"}, + Parameters: []string{"cmd", "task_id"}, + }, + DagTasks: []types.ArgoDagTask{ + {ID: "task-1", Name: "task1", Template: "echo", Deps: []string{}}, + }, + Nodes: []types.Node{ + {Name: "node-1"}, + }, + } + + existingWF := &v1alpha1.Workflow{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: df.namespace, + Name: req.ArgoTaskID, + }, + } + _, err := 
argoClient.ArgoprojV1alpha1().Workflows(df.namespace).Create(ctx, existingWF, metav1.CreateOptions{}) + require.NoError(t, err) + + resp, err := df.CreateWorkflow(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "failed to create dataflow workflow") + + pvcName := types.DFPVCNamePrefix + req.ArgoTaskID + _, err = kubeClient.CoreV1().PersistentVolumeClaims(df.namespace).Get(ctx, pvcName, metav1.GetOptions{}) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) +} + +func TestDataflowComponent_deletePVC(t *testing.T) { + ctx := context.TODO() + + t.Run("success", func(t *testing.T) { + df, _ := newTestDataflowComponent(t) + kubeClient := fake.NewSimpleClientset() + testCluster := &cluster.Cluster{ + Client: kubeClient, + } + req := &types.DataflowArgoReq{ + ArgoTaskID: "df-test-pvc", + } + + pvcName := types.DFPVCNamePrefix + req.ArgoTaskID + pvc := &corev1.PersistentVolumeClaim{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: df.namespace, + Name: pvcName, + }, + } + _, err := kubeClient.CoreV1().PersistentVolumeClaims(df.namespace).Create(ctx, pvc, metav1.CreateOptions{}) + require.NoError(t, err) + + _, err = kubeClient.CoreV1().PersistentVolumeClaims(df.namespace).Get(ctx, pvcName, metav1.GetOptions{}) + require.NoError(t, err) + + err = df.deletePVC(ctx, testCluster, req) + require.NoError(t, err) + + _, err = kubeClient.CoreV1().PersistentVolumeClaims(df.namespace).Get(ctx, pvcName, metav1.GetOptions{}) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) + + t.Run("delete non-existent pvc returns error", func(t *testing.T) { + df, _ := newTestDataflowComponent(t) + kubeClient := fake.NewSimpleClientset() + testCluster := &cluster.Cluster{ + Client: kubeClient, + } + req := &types.DataflowArgoReq{ + ArgoTaskID: "df-nonexistent", + } + + err := df.deletePVC(ctx, testCluster, req) + require.Error(t, err) + }) +} + +func TestDataflowComponent_GetStatus(t *testing.T) { + ctx := 
context.TODO() + + t.Run("success", func(t *testing.T) { + df, pool := newTestDataflowComponent(t) + argoClient := argofake.NewSimpleClientset() + testCluster := &cluster.Cluster{ + CID: "config", + ID: "test-cluster", + ArgoClient: argoClient, + } + pool.EXPECT().GetClusterByID(ctx, "test-cluster").Return(testCluster, nil) + + wf := &v1alpha1.Workflow{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: df.namespace, + Name: "df-test-status", + }, + Status: v1alpha1.WorkflowStatus{ + Phase: v1alpha1.WorkflowRunning, + }, + } + _, err := argoClient.ArgoprojV1alpha1().Workflows(df.namespace).Create(ctx, wf, metav1.CreateOptions{}) + require.NoError(t, err) + + req := &types.DataflowArgoReq{ + ArgoTaskID: "df-test-status", + ClusterID: "test-cluster", + } + + resp, err := df.GetStatus(ctx, req) + require.NoError(t, err) + require.Equal(t, req.ArgoTaskID, resp.ArgoTaskID) + }) + + t.Run("workflow not found", func(t *testing.T) { + df, pool := newTestDataflowComponent(t) + argoClient := argofake.NewSimpleClientset() + testCluster := &cluster.Cluster{ + CID: "config", + ID: "test-cluster", + ArgoClient: argoClient, + } + pool.EXPECT().GetClusterByID(ctx, "test-cluster").Return(testCluster, nil) + + req := &types.DataflowArgoReq{ + ArgoTaskID: "nonexistent", + ClusterID: "test-cluster", + } + + resp, err := df.GetStatus(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "workflow not found") + }) + + t.Run("cluster not found", func(t *testing.T) { + df, pool := newTestDataflowComponent(t) + pool.EXPECT().GetClusterByID(ctx, "unknown-cluster").Return(nil, errors.New("cluster not found")) + + req := &types.DataflowArgoReq{ + ClusterID: "unknown-cluster", + } + + resp, err := df.GetStatus(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "cluster not found") + }) +} + +func TestDataflowComponent_DeleteWorkflow(t *testing.T) { + ctx := context.TODO() + + t.Run("success", func(t *testing.T) { + df, 
pool := newTestDataflowComponent(t) + kubeClient := fake.NewSimpleClientset() + argoClient := argofake.NewSimpleClientset() + testCluster := &cluster.Cluster{ + CID: "config", + ID: "test-cluster", + Client: kubeClient, + ArgoClient: argoClient, + } + pool.EXPECT().GetClusterByID(ctx, "test-cluster").Return(testCluster, nil) + + wf := &v1alpha1.Workflow{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: df.namespace, + Name: "df-test-delete", + }, + } + _, err := argoClient.ArgoprojV1alpha1().Workflows(df.namespace).Create(ctx, wf, metav1.CreateOptions{}) + require.NoError(t, err) + + pvcName := types.DFPVCNamePrefix + "df-test-delete" + pvc := &corev1.PersistentVolumeClaim{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: df.namespace, + Name: pvcName, + }, + } + _, err = kubeClient.CoreV1().PersistentVolumeClaims(df.namespace).Create(ctx, pvc, metav1.CreateOptions{}) + require.NoError(t, err) + + req := &types.DataflowArgoReq{ + ArgoTaskID: "df-test-delete", + ClusterID: "test-cluster", + } + + err = df.DeleteWorkflow(ctx, req) + require.NoError(t, err) + + _, err = argoClient.ArgoprojV1alpha1().Workflows(df.namespace).Get(ctx, req.ArgoTaskID, metav1.GetOptions{}) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + + _, err = kubeClient.CoreV1().PersistentVolumeClaims(df.namespace).Get(ctx, pvcName, metav1.GetOptions{}) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) + + t.Run("workflow not found handled gracefully", func(t *testing.T) { + df, pool := newTestDataflowComponent(t) + kubeClient := fake.NewSimpleClientset() + argoClient := argofake.NewSimpleClientset() + testCluster := &cluster.Cluster{ + CID: "config", + ID: "test-cluster", + Client: kubeClient, + ArgoClient: argoClient, + } + pool.EXPECT().GetClusterByID(ctx, "test-cluster").Return(testCluster, nil) + + req := &types.DataflowArgoReq{ + ArgoTaskID: "nonexistent", + ClusterID: "test-cluster", + } + + err := df.DeleteWorkflow(ctx, req) + require.NoError(t, 
err) + }) + + t.Run("cluster not found", func(t *testing.T) { + df, pool := newTestDataflowComponent(t) + pool.EXPECT().GetClusterByID(ctx, "unknown-cluster").Return(nil, errors.New("cluster not found")) + + req := &types.DataflowArgoReq{ + ClusterID: "unknown-cluster", + } + + err := df.DeleteWorkflow(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to get cluster") + }) +} diff --git a/runner/component/workflow.go b/runner/component/workflow.go index ab066aa56..01fd79e64 100644 --- a/runner/component/workflow.go +++ b/runner/component/workflow.go @@ -127,7 +127,7 @@ func (wc *workFlowComponentImpl) CreateWorkflow(ctx context.Context, req types.A if findErr != nil { return nil, fmt.Errorf("failed to check workflow %s in db for create action, error: %w", argowf.TaskId, findErr) } - wf = &wfObj + wf = wfObj } wc.addKServiceWithEvent(ctx, types.RunnerWorkflowCreate, argowf) @@ -171,15 +171,15 @@ func (wc *workFlowComponentImpl) GetWorkflow(ctx context.Context, id int64, user // Update workflow func (wc *workFlowComponentImpl) UpdateWorkflow(ctx context.Context, update *v1alpha1.Workflow, cluster *cluster.Cluster) (*database.ArgoWorkflow, error) { oldwf, err := wc.wf.FindByTaskID(ctx, update.Name) - slog.InfoContext(ctx, "get-UpdateWorkflow-from-db", slog.Any("oldwf.TaskId", oldwf.TaskId), slog.Any("result-url", oldwf.ResultURL)) + slog.InfoContext(ctx, "get-UpdateWorkflow-from-db", slog.Any("oldwf", oldwf), slog.Any("err", err)) if errors.Is(err, sql.ErrNoRows) { - oldwf = *wc.getWorkflowFromLabels(ctx, update) - wf, err := wc.wf.CreateWorkFlow(ctx, oldwf) + oldwf = wc.getWorkflowFromLabels(ctx, update) + wf, err := wc.wf.CreateWorkFlow(ctx, *oldwf) if err != nil { slog.ErrorContext(ctx, "failed to create workflow in db", slog.Any("error", err)) return nil, fmt.Errorf("failed to create workflow in db: %v", err) } - oldwf = *wf + oldwf = wf } if err != nil { return nil, err @@ -228,11 +228,11 @@ func (wc *workFlowComponentImpl) 
UpdateWorkflow(ctx context.Context, update *v1a } slog.InfoContext(ctx, "UpdateWorkflow-report", slog.Any("name", oldwf.TaskId), slog.Any("result-url", oldwf.ResultURL)) - wc.addKServiceWithEvent(ctx, types.RunnerWorkflowChange, &oldwf) + wc.addKServiceWithEvent(ctx, types.RunnerWorkflowChange, oldwf) if lastStatus != oldwf.Status { - wc.reportWorFlowLog(types.WorkflowUpdated.String(), &oldwf) + wc.reportWorFlowLog(types.WorkflowUpdated.String(), oldwf) } - return wc.wf.UpdateWorkFlow(ctx, oldwf) + return wc.wf.UpdateWorkFlow(ctx, *oldwf) } // DeleteWorkflowInargo @@ -242,17 +242,17 @@ func (wc *workFlowComponentImpl) DeleteWorkflowInargo(ctx context.Context, delet return fmt.Errorf("failed to get workflow by id: %v", err) } - wc.reportWorFlowLog(types.WorkflowDeleted.String(), &wf) + wc.reportWorFlowLog(types.WorkflowDeleted.String(), wf) // for deleted case,check if the workflow did not finish if wf.Status == v1alpha1.WorkflowPending || wf.Status == v1alpha1.WorkflowRunning { wf.Status = v1alpha1.WorkflowFailed wf.Reason = "deleted by system, please check if your required resources are sufficient or if your account has enough credit" slog.InfoContext(ctx, "DeleteWorkflowInargo-report", slog.Any("name", wf.TaskId), slog.Any("result-url", wf.ResultURL)) - _, err = wc.wf.UpdateWorkFlow(ctx, wf) + _, err = wc.wf.UpdateWorkFlow(ctx, *wf) if err != nil { return err } - wc.addKServiceWithEvent(ctx, types.RunnerWorkflowChange, &wf) + wc.addKServiceWithEvent(ctx, types.RunnerWorkflowChange, wf) return nil } return nil diff --git a/runner/component/workflow_test.go b/runner/component/workflow_test.go index bd421b293..928751206 100644 --- a/runner/component/workflow_test.go +++ b/runner/component/workflow_test.go @@ -125,7 +125,7 @@ func TestArgoComponent_UpdateWorkflow(t *testing.T) { ArgoClient: argofake.NewSimpleClientset(), } - existingWf := database.ArgoWorkflow{ + existingWf := &database.ArgoWorkflow{ ID: 1, TaskId: "test-task", Username: "test-user", @@ -184,7 +184,7 
@@ func TestArgoComponent_UpdateWorkflow(t *testing.T) { ArgoClient: argofake.NewSimpleClientset(), } - existingWf := database.ArgoWorkflow{ + existingWf := &database.ArgoWorkflow{ ID: 1, TaskId: "test-task", Username: "test-user", @@ -254,7 +254,7 @@ func TestArgoComponent_UpdateWorkflow(t *testing.T) { ArgoClient: argofake.NewSimpleClientset(), } - existingWf := database.ArgoWorkflow{ + existingWf := &database.ArgoWorkflow{ ID: 1, TaskId: "test-task", Username: "test-user", @@ -312,7 +312,7 @@ func TestArgoComponent_UpdateWorkflow(t *testing.T) { ArgoClient: argofake.NewSimpleClientset(), } - existingWf := database.ArgoWorkflow{ + existingWf := &database.ArgoWorkflow{ ID: 1, TaskId: "test-task", Username: "test-user", @@ -368,7 +368,7 @@ func TestArgoComponent_UpdateWorkflow(t *testing.T) { ArgoClient: argofake.NewSimpleClientset(), } - existingWf := database.ArgoWorkflow{ + existingWf := &database.ArgoWorkflow{ ID: 1, TaskId: "test-task", Username: "test-user", @@ -457,7 +457,7 @@ func TestArgoComponent_UpdateWorkflow(t *testing.T) { }, } - argoStore.EXPECT().FindByTaskID(ctx, "new-task").Return(database.ArgoWorkflow{}, sql.ErrNoRows) + argoStore.EXPECT().FindByTaskID(ctx, "new-task").Return(&database.ArgoWorkflow{}, sql.ErrNoRows) argoStore.EXPECT().CreateWorkFlow(ctx, mock.Anything).Return(nil, errors.New("create error")) _, err := wfc.UpdateWorkflow(ctx, updateWf, cluster) diff --git a/runner/handler/dataflow.go b/runner/handler/dataflow.go new file mode 100644 index 000000000..1ec646d38 --- /dev/null +++ b/runner/handler/dataflow.go @@ -0,0 +1,87 @@ +package handler + +import ( + "log/slog" + "net/http" + + "github.com/gin-gonic/gin" + "opencsg.com/csghub-server/builder/deploy/cluster" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" + "opencsg.com/csghub-server/runner/component" +) + +// DataflowHandler handles dataflow job requests +type DataflowHandler struct { + clusterPool cluster.Pool + config *config.Config + dfc 
component.DataflowComponent +} + +func NewDataflowHandler(config *config.Config, clusterPool cluster.Pool) (*DataflowHandler, error) { + dfc := component.NewDataflowComponent(config, clusterPool) + return &DataflowHandler{ + clusterPool: clusterPool, + config: config, + dfc: dfc, + }, nil +} + +// CreateDataflowWorkflow creates a new dataflow workflow +func (h *DataflowHandler) CreateDataflowWorkflow(ctx *gin.Context) { + var req types.DataflowArgoJobReq + if err := ctx.ShouldBindJSON(&req); err != nil { + slog.ErrorContext(ctx, "bad request format", "error", err) + ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + resp, err := h.dfc.CreateWorkflow(ctx.Request.Context(), &req) + if err != nil { + slog.ErrorContext(ctx, "failed to create dataflow workflow", slog.Any("error", err), slog.Any("req", req)) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + ctx.JSON(http.StatusOK, resp) +} + +// GetDataflowStatus gets the status of a dataflow workflow +func (h *DataflowHandler) GetDataflowStatus(ctx *gin.Context) { + taskID := ctx.Param("task_id") + clusterID := ctx.Query("cluster_id") + + req := types.DataflowArgoReq{ + ArgoTaskID: taskID, + ClusterID: clusterID, + } + + status, err := h.dfc.GetStatus(ctx.Request.Context(), &req) + if err != nil { + slog.ErrorContext(ctx, "failed to get dataflow status", slog.Any("error", err), slog.Any("req", req)) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + ctx.JSON(http.StatusOK, status) +} + +// DeleteDataflowWorkflow deletes a dataflow workflow +func (h *DataflowHandler) DeleteDataflowWorkflow(ctx *gin.Context) { + taskID := ctx.Param("task_id") + clusterID := ctx.Query("cluster_id") + + req := types.DataflowArgoReq{ + ArgoTaskID: taskID, + ClusterID: clusterID, + } + + err := h.dfc.DeleteWorkflow(ctx.Request.Context(), &req) + if err != nil { + slog.ErrorContext(ctx, "failed to delete dataflow workflow", slog.Any("error", 
err), slog.Any("req", req)) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + ctx.JSON(http.StatusOK, gin.H{"message": "dataflow workflow deleted successfully"}) +} diff --git a/runner/handler/dataflow_test.go b/runner/handler/dataflow_test.go new file mode 100644 index 000000000..26f959f16 --- /dev/null +++ b/runner/handler/dataflow_test.go @@ -0,0 +1,210 @@ +package handler + +import ( + "encoding/json" + "errors" + "net/http" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + mockcom "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/runner/component" + "opencsg.com/csghub-server/builder/testutil" + "opencsg.com/csghub-server/common/types" +) + +type DataflowTester struct { + *testutil.GinTester + handler *DataflowHandler + mocks struct { + dfComp *mockcom.MockDataflowComponent + } +} + +func (t *DataflowTester) WithHandleFunc(fn func(h *DataflowHandler) gin.HandlerFunc) *DataflowTester { + t.Handler(fn(t.handler)) + return t +} + +func NewDataflowTester(t *testing.T) *DataflowTester { + tester := &DataflowTester{GinTester: testutil.NewGinTester()} + tester.mocks.dfComp = mockcom.NewMockDataflowComponent(t) + tester.handler = &DataflowHandler{ + dfc: tester.mocks.dfComp, + } + return tester +} + +func TestDataflowHandler_CreateDataflowWorkflow(t *testing.T) { + t.Run("success", func(t *testing.T) { + tester := NewDataflowTester(t).WithHandleFunc(func(h *DataflowHandler) gin.HandlerFunc { + return h.CreateDataflowWorkflow + }) + + req := types.DataflowArgoJobReq{ + ClusterID: "test-cluster", + ArgoTaskID: "df-task-1", + JobID: "df-job-1", + JobName: "test-job", + ResourceId: 100, + StorageSize: "10Gi", + Entrypoint: "main", + Template: types.ArgoFlowTemplate{ + Name: "echo", + Image: "alpine:latest", + }, + DagTasks: []types.ArgoDagTask{ + {ID: "task-1", Name: "task1", Template: "echo"}, + }, + } + resp := &types.DataflowArgoJobResp{ + ID: 1, + 
ArgoTaskID: "df-task-1", + JobID: "df-job-1", + JobName: "test-job", + Status: "Pending", + } + + tester.mocks.dfComp.EXPECT().CreateWorkflow(mock.Anything, &req).Return(resp, nil) + tester.WithBody(t, req) + tester.Execute() + + assert.Equal(t, http.StatusOK, tester.Response().Code) + + var actual types.DataflowArgoJobResp + err := json.Unmarshal(tester.Response().Body.Bytes(), &actual) + assert.NoError(t, err) + assert.Equal(t, resp.ID, actual.ID) + assert.Equal(t, resp.ArgoTaskID, actual.ArgoTaskID) + }) + + t.Run("bad request", func(t *testing.T) { + tester := NewDataflowTester(t).WithHandleFunc(func(h *DataflowHandler) gin.HandlerFunc { + return h.CreateDataflowWorkflow + }) + + tester.WithBody(t, "invalid json will be parsed differently") + + tester.Execute() + + assert.Equal(t, http.StatusBadRequest, tester.Response().Code) + }) + + t.Run("component error", func(t *testing.T) { + tester := NewDataflowTester(t).WithHandleFunc(func(h *DataflowHandler) gin.HandlerFunc { + return h.CreateDataflowWorkflow + }) + + req := types.DataflowArgoJobReq{ + ClusterID: "test-cluster", + ResourceId: 100, + JobID: "df-job-1", + JobName: "test-job", + StorageSize: "10Gi", + Entrypoint: "main", + Template: types.ArgoFlowTemplate{ + Name: "echo", + Image: "alpine:latest", + }, + DagTasks: []types.ArgoDagTask{ + {ID: "task-1", Name: "task1", Template: "echo"}, + }, + } + + tester.mocks.dfComp.EXPECT().CreateWorkflow(mock.Anything, &req).Return(nil, errors.New("creation failed")) + tester.WithBody(t, req) + tester.Execute() + + assert.Equal(t, http.StatusInternalServerError, tester.Response().Code) + }) +} + +func TestDataflowHandler_GetDataflowStatus(t *testing.T) { + t.Run("success", func(t *testing.T) { + tester := NewDataflowTester(t).WithHandleFunc(func(h *DataflowHandler) gin.HandlerFunc { + return h.GetDataflowStatus + }) + + resp := &types.DataflowArgoJobResp{ + ArgoTaskID: "df-task-1", + JobID: "df-job-1", + JobName: "test-job", + Status: "Running", + } + + 
tester.mocks.dfComp.EXPECT().GetStatus(mock.Anything, &types.DataflowArgoReq{ + ArgoTaskID: "df-task-1", + ClusterID: "cluster-1", + }).Return(resp, nil) + + tester.WithParam("task_id", "df-task-1") + tester.WithQuery("cluster_id", "cluster-1") + tester.Execute() + + assert.Equal(t, http.StatusOK, tester.Response().Code) + + var actual types.DataflowArgoJobResp + err := json.Unmarshal(tester.Response().Body.Bytes(), &actual) + assert.NoError(t, err) + assert.Equal(t, resp.Status, actual.Status) + }) + + t.Run("component error", func(t *testing.T) { + tester := NewDataflowTester(t).WithHandleFunc(func(h *DataflowHandler) gin.HandlerFunc { + return h.GetDataflowStatus + }) + + tester.mocks.dfComp.EXPECT().GetStatus(mock.Anything, &types.DataflowArgoReq{ + ArgoTaskID: "unknown-task", + ClusterID: "cluster-1", + }).Return(nil, errors.New("workflow not found")) + + tester.WithParam("task_id", "unknown-task") + tester.WithQuery("cluster_id", "cluster-1") + tester.Execute() + + assert.Equal(t, http.StatusInternalServerError, tester.Response().Code) + }) +} + +func TestDataflowHandler_DeleteDataflowWorkflow(t *testing.T) { + t.Run("success", func(t *testing.T) { + tester := NewDataflowTester(t).WithHandleFunc(func(h *DataflowHandler) gin.HandlerFunc { + return h.DeleteDataflowWorkflow + }) + + tester.mocks.dfComp.EXPECT().DeleteWorkflow(mock.Anything, &types.DataflowArgoReq{ + ArgoTaskID: "df-task-1", + ClusterID: "cluster-1", + }).Return(nil) + + tester.WithParam("task_id", "df-task-1") + tester.WithQuery("cluster_id", "cluster-1") + tester.Execute() + + assert.Equal(t, http.StatusOK, tester.Response().Code) + + var body map[string]string + err := json.Unmarshal(tester.Response().Body.Bytes(), &body) + assert.NoError(t, err) + assert.Equal(t, "dataflow workflow deleted successfully", body["message"]) + }) + + t.Run("component error", func(t *testing.T) { + tester := NewDataflowTester(t).WithHandleFunc(func(h *DataflowHandler) gin.HandlerFunc { + return 
h.DeleteDataflowWorkflow + }) + + tester.mocks.dfComp.EXPECT().DeleteWorkflow(mock.Anything, &types.DataflowArgoReq{ + ArgoTaskID: "unknown-task", + ClusterID: "cluster-1", + }).Return(errors.New("delete failed")) + + tester.WithParam("task_id", "unknown-task") + tester.WithQuery("cluster_id", "cluster-1") + tester.Execute() + + assert.Equal(t, http.StatusInternalServerError, tester.Response().Code) + }) +} diff --git a/runner/router/api.go b/runner/router/api.go index 2793e5a2e..dcc44c163 100644 --- a/runner/router/api.go +++ b/runner/router/api.go @@ -82,6 +82,18 @@ func NewHttpServer(ctx context.Context, config *config.Config) (*gin.Engine, err workflows.GET("/:id", argoHandler.GetWorkflow) } + // dataflow + dataflowHandler, err := handler.NewDataflowHandler(config, clusterPool) + if err != nil { + return nil, fmt.Errorf("failed to build NewDataflowHandler error: %w", err) + } + dataflowGroup := apiGroup.Group("/dataflow/jobs") + { + dataflowGroup.POST("", dataflowHandler.CreateDataflowWorkflow) + dataflowGroup.GET("/:task_id", dataflowHandler.GetDataflowStatus) + dataflowGroup.DELETE("/:task_id", dataflowHandler.DeleteDataflowWorkflow) + } + // image builder imagebuilderHandler, err := handler.NewImagebuilderHandler(ctx, config, clusterPool, logReporter) if err != nil {