From 9c5afe94c2c067a69a37e1ee5fd282fee6682a4c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 02:15:18 +0000 Subject: [PATCH 1/2] Add per-user workflow isolation: migration 28, service updates, router ownership checks, is_public endpoint, schema regeneration, frontend UI Co-authored-by: lstein <111189+lstein@users.noreply.github.com> --- invokeai/app/api/routers/workflows.py | 124 +++++++- .../app/services/shared/sqlite/sqlite_util.py | 2 + .../migrations/migration_28.py | 45 +++ .../workflow_records/workflow_records_base.py | 16 +- .../workflow_records_common.py | 6 + .../workflow_records_sqlite.py | 73 ++++- invokeai/frontend/web/openapi.json | 177 ++++++++++- invokeai/frontend/web/public/locales/en.json | 2 + .../WorkflowLibrarySideNav.tsx | 1 + .../workflow/WorkflowLibrary/WorkflowList.tsx | 10 + .../WorkflowLibrary/WorkflowListItem.tsx | 53 +++- .../WorkflowLibrary/WorkflowSortControl.tsx | 3 +- .../nodes/store/workflowLibrarySlice.ts | 15 +- .../components/SaveWorkflowAsDialog.tsx | 25 +- .../hooks/useCreateNewWorkflow.ts | 4 +- .../src/services/api/endpoints/workflows.ts | 16 + .../frontend/web/src/services/api/schema.ts | 104 ++++++- .../frontend/web/src/services/api/types.ts | 2 +- tests/app/routers/test_workflows_multiuser.py | 274 ++++++++++++++++++ 19 files changed, 920 insertions(+), 32 deletions(-) create mode 100644 invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py create mode 100644 tests/app/routers/test_workflows_multiuser.py diff --git a/invokeai/app/api/routers/workflows.py b/invokeai/app/api/routers/workflows.py index 72d50a416b4..7e34660a1df 100644 --- a/invokeai/app/api/routers/workflows.py +++ b/invokeai/app/api/routers/workflows.py @@ -6,6 +6,7 @@ from fastapi.responses import FileResponse from PIL import Image +from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection @@ -33,16 +34,25 @@ }, ) async def get_workflow( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to get"), ) -> WorkflowRecordWithThumbnailDTO: """Gets a workflow""" try: - thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id) workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id) - return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump()) except WorkflowNotFoundError: raise HTTPException(status_code=404, detail="Workflow not found") + config = ApiDependencies.invoker.services.configuration + if config.multiuser: + is_default = workflow.workflow.meta.category is WorkflowCategory.Default + is_owner = workflow.user_id == current_user.user_id + if not (is_default or is_owner or workflow.is_public or current_user.is_admin): + raise HTTPException(status_code=403, detail="Not authorized to access this workflow") + + thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id) + return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump()) + @workflows_router.patch( "/i/{workflow_id}", @@ -52,9 +62,18 @@ async def get_workflow( }, ) async def update_workflow( + current_user: CurrentUserOrDefault, workflow: Workflow = Body(description="The updated workflow", embed=True), ) -> WorkflowRecordDTO: """Updates a workflow""" + config = ApiDependencies.invoker.services.configuration + if config.multiuser: + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow.id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + if not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow) @@ -63,9 +82,18 @@ async def update_workflow( operation_id="delete_workflow", ) async def delete_workflow( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to delete"), ) -> None: """Deletes a workflow""" + config = ApiDependencies.invoker.services.configuration + if config.multiuser: + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + if not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to delete this workflow") try: ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id) except WorkflowThumbnailFileNotFoundException: @@ -82,10 +110,11 @@ async def delete_workflow( }, ) async def create_workflow( + current_user: CurrentUserOrDefault, workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True), ) -> WorkflowRecordDTO: """Creates a workflow""" - return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow) + return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow, user_id=current_user.user_id) @workflows_router.get( @@ -96,6 +125,7 @@ async def create_workflow( }, ) async def list_workflows( + current_user: CurrentUserOrDefault, page: int = Query(default=0, description="The page to get"), per_page: Optional[int] = Query(default=None, description="The number of workflows per page"), order_by: WorkflowRecordOrderBy = Query( @@ -106,8 +136,19 @@ async def list_workflows( tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"), query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"), has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]: """Gets a page of workflows""" + config = ApiDependencies.invoker.services.configuration + + # In multiuser mode, scope user-category workflows to the current user unless fetching shared workflows + user_id_filter: Optional[str] = None + if config.multiuser: + # Only filter 'user' category results by user_id when not explicitly listing public workflows + has_user_category = not categories or WorkflowCategory.User in categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id + workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = [] workflows = ApiDependencies.invoker.services.workflow_records.get_many( order_by=order_by, @@ -118,6 +159,8 @@ async def list_workflows( categories=categories, tags=tags, has_been_opened=has_been_opened, + user_id=user_id_filter, + is_public=is_public, ) for workflow in workflows.items: workflows_with_thumbnails.append( @@ -143,15 +186,20 @@ async def list_workflows( }, ) async def set_workflow_thumbnail( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to update"), image: UploadFile = File(description="The image file to upload"), ): """Sets a workflow's thumbnail image""" try: - ApiDependencies.invoker.services.workflow_records.get(workflow_id) + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) except WorkflowNotFoundError: raise HTTPException(status_code=404, detail="Workflow not found") + config = ApiDependencies.invoker.services.configuration + if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + if not image.content_type or not image.content_type.startswith("image"): raise HTTPException(status_code=415, detail="Not an image") @@ -177,14 +225,19 @@ async def set_workflow_thumbnail( }, ) async def delete_workflow_thumbnail( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to update"), ): """Removes a workflow's thumbnail image""" try: - ApiDependencies.invoker.services.workflow_records.get(workflow_id) + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) except WorkflowNotFoundError: raise HTTPException(status_code=404, detail="Workflow not found") + config = ApiDependencies.invoker.services.configuration + if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + try: ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id) except ValueError as e: @@ -223,37 +276,90 @@ async def get_workflow_thumbnail( raise HTTPException(status_code=404) +@workflows_router.patch( + "/i/{workflow_id}/is_public", + operation_id="update_workflow_is_public", + responses={ + 200: {"model": WorkflowRecordDTO}, + }, +) +async def update_workflow_is_public( + current_user: CurrentUserOrDefault, + workflow_id: str = Path(description="The workflow to update"), + is_public: bool = Body(description="Whether the workflow should be shared publicly", embed=True), +) -> WorkflowRecordDTO: + """Updates whether a workflow is shared publicly""" + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + + config = ApiDependencies.invoker.services.configuration + if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + + return ApiDependencies.invoker.services.workflow_records.update_is_public( + workflow_id=workflow_id, is_public=is_public + ) + + @workflows_router.get("/tags", operation_id="get_all_tags") async def get_all_tags( + current_user: CurrentUserOrDefault, categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> list[str]: """Gets all unique tags from workflows""" - - return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories) + config = ApiDependencies.invoker.services.configuration + user_id_filter: Optional[str] = None + if config.multiuser: + has_user_category = not categories or WorkflowCategory.User in categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id + + return ApiDependencies.invoker.services.workflow_records.get_all_tags( + categories=categories, user_id=user_id_filter, is_public=is_public + ) @workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag") async def get_counts_by_tag( + current_user: CurrentUserOrDefault, tags: list[str] = Query(description="The tags to get counts for"), categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"), has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> dict[str, int]: """Counts workflows by tag""" + config = ApiDependencies.invoker.services.configuration + user_id_filter: Optional[str] = None + if config.multiuser: + has_user_category = not categories or WorkflowCategory.User in categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id return ApiDependencies.invoker.services.workflow_records.counts_by_tag( - tags=tags, categories=categories, has_been_opened=has_been_opened + tags=tags, categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public ) @workflows_router.get("/counts_by_category", operation_id="counts_by_category") async def counts_by_category( + current_user: CurrentUserOrDefault, categories: list[WorkflowCategory] = Query(description="The categories to include"), has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> dict[str, int]: """Counts workflows by category""" + config = ApiDependencies.invoker.services.configuration + user_id_filter: Optional[str] = None + if config.multiuser: + has_user_category = WorkflowCategory.User in categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id return ApiDependencies.invoker.services.workflow_records.counts_by_category( - categories=categories, has_been_opened=has_been_opened + categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public ) diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 645509f1dde..2478e8cdcae 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -30,6 +30,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_26 import build_migration_26 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_27 import build_migration_27 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_28 import build_migration_28 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -77,6 +78,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_25(app_config=config, logger=logger)) migrator.register_migration(build_migration_26(app_config=config, logger=logger)) migrator.register_migration(build_migration_27()) + migrator.register_migration(build_migration_28()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py new file mode 100644 index 00000000000..0cbd683ab5e --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py @@ -0,0 +1,45 @@ +"""Migration 28: Add per-user workflow isolation columns to workflow_library. + +This migration adds the database columns required for multiuser workflow isolation +to the workflow_library table: +- user_id: the owner of the workflow (defaults to 'system' for existing workflows) +- is_public: whether the workflow is shared with all users +""" + +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration28Callback: + """Migration to add user_id and is_public to the workflow_library table.""" + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._update_workflow_library_table(cursor) + + def _update_workflow_library_table(self, cursor: sqlite3.Cursor) -> None: + """Add user_id and is_public columns to workflow_library table.""" + cursor.execute("PRAGMA table_info(workflow_library);") + columns = [row[1] for row in cursor.fetchall()] + + if "user_id" not in columns: + cursor.execute("ALTER TABLE workflow_library ADD COLUMN user_id TEXT DEFAULT 'system';") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflow_library_user_id ON workflow_library(user_id);") + + if "is_public" not in columns: + cursor.execute("ALTER TABLE workflow_library ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflow_library_is_public ON workflow_library(is_public);") + + +def build_migration_28() -> Migration: + """Builds the migration object for migrating from version 27 to version 28. + + This migration adds per-user workflow isolation to the workflow_library table: + - user_id column: identifies the owner of each workflow + - is_public column: controls whether a workflow is shared with all users + """ + return Migration( + from_version=27, + to_version=28, + callback=Migration28Callback(), + ) diff --git a/invokeai/app/services/workflow_records/workflow_records_base.py b/invokeai/app/services/workflow_records/workflow_records_base.py index d5cf319594b..8da1e97daf7 100644 --- a/invokeai/app/services/workflow_records/workflow_records_base.py +++ b/invokeai/app/services/workflow_records/workflow_records_base.py @@ -4,6 +4,7 @@ from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.services.workflow_records.workflow_records_common import ( + WORKFLOW_LIBRARY_DEFAULT_USER_ID, Workflow, WorkflowCategory, WorkflowRecordDTO, @@ -22,7 +23,7 @@ def get(self, workflow_id: str) -> WorkflowRecordDTO: pass @abstractmethod - def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO: + def create(self, workflow: WorkflowWithoutID, user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID) -> WorkflowRecordDTO: """Creates a workflow.""" pass @@ -47,6 +48,8 @@ def get_many( query: Optional[str], tags: Optional[list[str]], has_been_opened: Optional[bool], + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> PaginatedResults[WorkflowRecordListItemDTO]: """Gets many workflows.""" pass @@ -56,6 +59,8 @@ def counts_by_category( self, categories: list[WorkflowCategory], has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: """Gets a dictionary of counts for each of the provided categories.""" pass @@ -66,6 +71,8 @@ def counts_by_tag( tags: list[str], categories: Optional[list[WorkflowCategory]] = None, has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: """Gets a dictionary of counts for each of the provided tags.""" pass @@ -79,6 +86,13 @@ def update_opened_at(self, workflow_id: str) -> None: def get_all_tags( self, categories: Optional[list[WorkflowCategory]] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> list[str]: """Gets all unique tags from workflows.""" pass + + @abstractmethod + def update_is_public(self, workflow_id: str, is_public: bool) -> WorkflowRecordDTO: + """Updates the is_public field of a workflow.""" + pass diff --git a/invokeai/app/services/workflow_records/workflow_records_common.py b/invokeai/app/services/workflow_records/workflow_records_common.py index e0cea37468d..9c505530c90 100644 --- a/invokeai/app/services/workflow_records/workflow_records_common.py +++ b/invokeai/app/services/workflow_records/workflow_records_common.py @@ -9,6 +9,9 @@ __workflow_meta_version__ = semver.Version.parse("1.0.0") +WORKFLOW_LIBRARY_DEFAULT_USER_ID = "system" +"""Default user_id for workflows created in single-user mode or migrated from pre-multiuser databases.""" + class ExposedField(BaseModel): nodeId: str @@ -26,6 +29,7 @@ class WorkflowRecordOrderBy(str, Enum, metaclass=MetaEnum): UpdatedAt = "updated_at" OpenedAt = "opened_at" Name = "name" + IsPublic = "is_public" class WorkflowCategory(str, Enum, metaclass=MetaEnum): @@ -100,6 +104,8 @@ class WorkflowRecordDTOBase(BaseModel): opened_at: Optional[Union[datetime.datetime, str]] = Field( default=None, description="The opened timestamp of the workflow." ) + user_id: str = Field(description="The id of the user who owns this workflow.") + is_public: bool = Field(description="Whether this workflow is shared with all users.") class WorkflowRecordDTO(WorkflowRecordDTOBase): diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py index 0f72f7cd92c..eac45fb18e4 100644 --- a/invokeai/app/services/workflow_records/workflow_records_sqlite.py +++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py @@ -7,6 +7,7 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase from invokeai.app.services.workflow_records.workflow_records_common import ( + WORKFLOW_LIBRARY_DEFAULT_USER_ID, Workflow, WorkflowCategory, WorkflowNotFoundError, @@ -36,7 +37,7 @@ def get(self, workflow_id: str) -> WorkflowRecordDTO: with self._db.transaction() as cursor: cursor.execute( """--sql - SELECT workflow_id, workflow, name, created_at, updated_at, opened_at + SELECT workflow_id, workflow, name, created_at, updated_at, opened_at, user_id, is_public FROM workflow_library WHERE workflow_id = ?; """, @@ -47,7 +48,7 @@ def get(self, workflow_id: str) -> WorkflowRecordDTO: raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found") return WorkflowRecordDTO.from_dict(dict(row)) - def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO: + def create(self, workflow: WorkflowWithoutID, user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID) -> WorkflowRecordDTO: if workflow.meta.category is WorkflowCategory.Default: raise ValueError("Default workflows cannot be created via this method") @@ -57,11 +58,12 @@ def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO: """--sql INSERT OR IGNORE INTO workflow_library ( workflow_id, - workflow + workflow, + user_id ) - VALUES (?, ?); + VALUES (?, ?, ?); """, - (workflow_with_id.id, workflow_with_id.model_dump_json()), + (workflow_with_id.id, workflow_with_id.model_dump_json(), user_id), ) return self.get(workflow_with_id.id) @@ -94,6 +96,19 @@ def delete(self, workflow_id: str) -> None: ) return None + def update_is_public(self, workflow_id: str, is_public: bool) -> WorkflowRecordDTO: + """Updates the is_public field of a workflow.""" + with self._db.transaction() as cursor: + cursor.execute( + """--sql + UPDATE workflow_library + SET is_public = ? + WHERE workflow_id = ? AND category = 'user'; + """, + (is_public, workflow_id), + ) + return self.get(workflow_id) + def get_many( self, order_by: WorkflowRecordOrderBy, @@ -104,6 +119,8 @@ def get_many( query: Optional[str] = None, tags: Optional[list[str]] = None, has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> PaginatedResults[WorkflowRecordListItemDTO]: with self._db.transaction() as cursor: # sanitize! @@ -122,7 +139,9 @@ def get_many( created_at, updated_at, opened_at, - tags + tags, + user_id, + is_public FROM workflow_library """ count_query = "SELECT COUNT(*) FROM workflow_library" @@ -177,6 +196,15 @@ def get_many( conditions.append(query_condition) params.extend([wildcard_query, wildcard_query, wildcard_query]) + if user_id is not None: + conditions.append("user_id = ?") + params.append(user_id) + + if is_public is True: + conditions.append("is_public = TRUE") + elif is_public is False: + conditions.append("is_public = FALSE") + if conditions: # If there are conditions, add a WHERE clause and then join the conditions main_query += " WHERE " @@ -226,6 +254,8 @@ def counts_by_tag( tags: list[str], categories: Optional[list[WorkflowCategory]] = None, has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: if not tags: return {} @@ -248,6 +278,15 @@ def counts_by_tag( elif has_been_opened is False: base_conditions.append("opened_at IS NULL") + if user_id is not None: + base_conditions.append("user_id = ?") + base_params.append(user_id) + + if is_public is True: + base_conditions.append("is_public = TRUE") + elif is_public is False: + base_conditions.append("is_public = FALSE") + # For each tag to count, run a separate query for tag in tags: # Start with the base conditions @@ -277,6 +316,8 @@ def counts_by_category( self, categories: list[WorkflowCategory], has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: with self._db.transaction() as cursor: result: dict[str, int] = {} @@ -296,6 +337,15 @@ def counts_by_category( elif has_been_opened is False: base_conditions.append("opened_at IS NULL") + if user_id is not None: + base_conditions.append("user_id = ?") + base_params.append(user_id) + + if is_public is True: + base_conditions.append("is_public = TRUE") + elif is_public is False: + base_conditions.append("is_public = FALSE") + # For each category to count, run a separate query for category in categories: # Start with the base conditions @@ -335,6 +385,8 @@ def update_opened_at(self, workflow_id: str) -> None: def get_all_tags( self, categories: Optional[list[WorkflowCategory]] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> list[str]: with self._db.transaction() as cursor: conditions: list[str] = [] @@ -349,6 +401,15 @@ def get_all_tags( conditions.append(f"category IN ({placeholders})") params.extend([category.value for category in categories]) + if user_id is not None: + conditions.append("user_id = ?") + params.append(user_id) + + if is_public is True: + conditions.append("is_public = TRUE") + elif is_public is False: + conditions.append("is_public = FALSE") + stmt = """--sql SELECT DISTINCT tags FROM workflow_library diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index af8476528d6..19e5a3a68e9 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -6463,6 +6463,23 @@ "title": "Has Been Opened" }, "description": "Whether to include/exclude recent workflows" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -6655,6 +6672,23 @@ "title": "Categories" }, "description": "The categories to include" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -6744,6 +6778,23 @@ "title": "Has Been Opened" }, "description": "Whether to include/exclude recent workflows" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -6812,6 +6863,23 @@ "title": "Has Been Opened" }, "description": "Whether to include/exclude recent workflows" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -7352,6 +7420,67 @@ } } } + }, + "/api/v1/workflows/i/{workflow_id}/is_public": { + "patch": { + "tags": ["workflows"], + "summary": "Update Workflow Is Public", + "description": "Updates whether a workflow is shared publicly", + "operationId": "update_workflow_is_public", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + }, + "description": "The workflow to update" + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "properties": { + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether the workflow should be shared publicly" + } + }, + "type": "object", + "required": ["is_public"], + "title": "Body_update_workflow_is_public" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WorkflowRecordDTO" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } } }, "components": { @@ -59137,10 +59266,20 @@ "workflow": { "$ref": "#/components/schemas/Workflow", "description": "The workflow." + }, + "user_id": { + "type": "string", + "title": "User Id", + "description": "The id of the user who owns this workflow." + }, + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether this workflow is shared with all users." } }, "type": "object", - "required": ["workflow_id", "name", "created_at", "updated_at", "workflow"], + "required": ["workflow_id", "name", "created_at", "updated_at", "workflow", "user_id", "is_public"], "title": "WorkflowRecordDTO" }, "WorkflowRecordListItemWithThumbnailDTO": { @@ -59222,15 +59361,35 @@ ], "title": "Thumbnail Url", "description": "The URL of the workflow thumbnail." + }, + "user_id": { + "type": "string", + "title": "User Id", + "description": "The id of the user who owns this workflow." + }, + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether this workflow is shared with all users." } }, "type": "object", - "required": ["workflow_id", "name", "created_at", "updated_at", "description", "category", "tags"], + "required": [ + "workflow_id", + "name", + "created_at", + "updated_at", + "description", + "category", + "tags", + "user_id", + "is_public" + ], "title": "WorkflowRecordListItemWithThumbnailDTO" }, "WorkflowRecordOrderBy": { "type": "string", - "enum": ["created_at", "updated_at", "opened_at", "name"], + "enum": ["created_at", "updated_at", "opened_at", "name", "is_public"], "title": "WorkflowRecordOrderBy", "description": "The order by options for workflow records" }, @@ -59303,10 +59462,20 @@ ], "title": "Thumbnail Url", "description": "The URL of the workflow thumbnail." + }, + "user_id": { + "type": "string", + "title": "User Id", + "description": "The id of the user who owns this workflow." + }, + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether this workflow is shared with all users." } }, "type": "object", - "required": ["workflow_id", "name", "created_at", "updated_at", "workflow"], + "required": ["workflow_id", "name", "created_at", "updated_at", "workflow", "user_id", "is_public"], "title": "WorkflowRecordWithThumbnailDTO" }, "WorkflowWithoutID": { diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 2db971d06a6..225c54d458e 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2158,6 +2158,8 @@ "tags": "Tags", "yourWorkflows": "Your Workflows", "recentlyOpened": "Recently Opened", + "sharedWorkflows": "Shared Workflows", + "shareWorkflow": "Share Workflow", "noRecentWorkflows": "No Recent Workflows", "private": "Private", "shared": "Shared", diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibrarySideNav.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibrarySideNav.tsx index 73b046c83a9..501b8365db5 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibrarySideNav.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibrarySideNav.tsx @@ -41,6 +41,7 @@ export const WorkflowLibrarySideNav = () => { {t('workflows.recentlyOpened')} + {t('workflows.sharedWorkflows')} diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowList.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowList.tsx index 79dff535b05..e6605d2076a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowList.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowList.tsx @@ -32,6 +32,8 @@ const getCategories = (view: WorkflowLibraryView): WorkflowCategory[] => { return ['user', 'default']; case 'yours': return ['user']; + case 'shared': + return ['user']; default: assert>(false); } @@ -44,6 +46,13 @@ const getHasBeenOpened = (view: WorkflowLibraryView): boolean | undefined => { return undefined; }; +const getIsPublic = (view: WorkflowLibraryView): boolean | undefined => { + if (view === 'shared') { + return true; + } + return undefined; +}; + const useInfiniteQueryAry = () => { const orderBy = useAppSelector(selectWorkflowLibraryOrderBy); const direction = useAppSelector(selectWorkflowLibraryDirection); @@ -62,6 +71,7 @@ const useInfiniteQueryAry = () => { query: debouncedSearchTerm, tags: view === 'defaults' || view === 'yours' ? selectedTags : [], has_been_opened: getHasBeenOpened(view), + is_public: getIsPublic(view), } satisfies Parameters[0]; }, [orderBy, direction, view, debouncedSearchTerm, selectedTags]); diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx index a1767765c93..7dcc85cc2a3 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx @@ -1,13 +1,15 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library'; -import { Badge, Flex, Icon, Image, Spacer, Text } from '@invoke-ai/ui-library'; +import { Badge, Flex, Icon, Image, Spacer, Switch, Text, Tooltip } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { selectCurrentUser } from 'features/auth/store/authSlice'; import { selectWorkflowId } from 'features/nodes/store/selectors'; import { workflowModeChanged } from 'features/nodes/store/workflowLibrarySlice'; import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog'; import InvokeLogo from 'public/assets/images/invoke-symbol-wht-lrg.svg'; -import { memo, useCallback, useMemo } from 'react'; +import { type ChangeEvent, memo, type MouseEvent, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiImage } from 'react-icons/pi'; +import { useUpdateWorkflowIsPublicMutation } from 'services/api/endpoints/workflows'; import type { WorkflowRecordListItemWithThumbnailDTO } from 'services/api/types'; import { DeleteWorkflow } from './WorkflowLibraryListItemActions/DeleteWorkflow'; @@ -33,12 +35,17 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi const { t } = useTranslation(); const dispatch = useAppDispatch(); const workflowId = useAppSelector(selectWorkflowId); + const currentUser = useAppSelector(selectCurrentUser); const loadWorkflowWithDialog = useLoadWorkflowWithDialog(); const isActive = useMemo(() => { return workflowId === workflow.workflow_id; }, [workflowId, workflow.workflow_id]); + const isOwner = useMemo(() => { + return currentUser !== null && workflow.user_id === currentUser.user_id; + }, [currentUser, workflow.user_id]); + const tags = useMemo(() => { if (!workflow.tags) { return []; @@ -102,6 +109,18 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi {t('workflows.opened')} )} + {workflow.is_public && workflow.category !== 'default' && ( + + {t('workflows.shared')} + + )} {workflow.category === 'default' && ( )} + {isOwner && } {workflow.category === 'default' && } {workflow.category !== 'default' && ( <> @@ -152,6 +172,35 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi }); WorkflowListItem.displayName = 'WorkflowListItem'; +const ShareWorkflowToggle = memo(({ workflow }: { workflow: WorkflowRecordListItemWithThumbnailDTO }) => { + const { t } = useTranslation(); + const [updateIsPublic, { isLoading }] = useUpdateWorkflowIsPublicMutation(); + + const handleChange = useCallback( + (e: ChangeEvent) => { + e.stopPropagation(); + updateIsPublic({ workflow_id: workflow.workflow_id, is_public: e.target.checked }); + }, + [updateIsPublic, workflow.workflow_id] + ); + + const handleClick = useCallback((e: MouseEvent) => { + e.stopPropagation(); + }, []); + + return ( + + + + {t('workflows.shared')} + + + + + ); +}); +ShareWorkflowToggle.displayName = 'ShareWorkflowToggle'; + const UserThumbnailFallback = memo(() => { return ( ; const isOrderBy = (v: unknown): v is OrderBy => zOrderBy.safeParse(v).success; @@ -32,6 +32,7 @@ export const WorkflowSortControl = () => { created_at: t('workflows.created'), updated_at: t('workflows.updated'), name: t('workflows.name'), + is_public: t('workflows.shared'), }), [t] ); diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowLibrarySlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowLibrarySlice.ts index ee85a03c18f..1d5d8554aeb 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowLibrarySlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowLibrarySlice.ts @@ -11,7 +11,7 @@ import { } from 'services/api/types'; import z from 'zod'; -const zWorkflowLibraryView = z.enum(['recent', 'yours', 'defaults']); +const zWorkflowLibraryView = z.enum(['recent', 'yours', 'shared', 'defaults']); export type WorkflowLibraryView = z.infer; const zWorkflowLibraryState = z.object({ @@ -55,6 +55,9 @@ const slice = createSlice({ if (action.payload === 'recent') { state.orderBy = 'opened_at'; state.direction = 'DESC'; + } else if (action.payload === 'shared') { + state.orderBy = 'name'; + state.direction = 'ASC'; } }, workflowLibraryTagToggled: (state, action: PayloadAction) => { @@ -121,5 +124,11 @@ export const WORKFLOW_LIBRARY_TAG_CATEGORIES: WorkflowTagCategory[] = [ ]; export const WORKFLOW_LIBRARY_TAGS = WORKFLOW_LIBRARY_TAG_CATEGORIES.flatMap(({ tags }) => tags); -type WorkflowSortOption = 'opened_at' | 'created_at' | 'updated_at' | 'name'; -export const WORKFLOW_LIBRARY_SORT_OPTIONS: WorkflowSortOption[] = ['opened_at', 'created_at', 'updated_at', 'name']; +type WorkflowSortOption = 'opened_at' | 'created_at' | 'updated_at' | 'name' | 'is_public'; +export const WORKFLOW_LIBRARY_SORT_OPTIONS: WorkflowSortOption[] = [ + 'opened_at', + 'created_at', + 'updated_at', + 'name', + 'is_public', +]; diff --git a/invokeai/frontend/web/src/features/workflowLibrary/components/SaveWorkflowAsDialog.tsx b/invokeai/frontend/web/src/features/workflowLibrary/components/SaveWorkflowAsDialog.tsx index 72ca9c309b3..e29ca82fa2b 100644 --- a/invokeai/frontend/web/src/features/workflowLibrary/components/SaveWorkflowAsDialog.tsx +++ b/invokeai/frontend/web/src/features/workflowLibrary/components/SaveWorkflowAsDialog.tsx @@ -5,6 +5,7 @@ import { AlertDialogFooter, AlertDialogHeader, Button, + Checkbox, Flex, FormControl, FormLabel, @@ -19,6 +20,7 @@ import { t } from 'i18next'; import { atom, computed } from 'nanostores'; import type { ChangeEvent, RefObject } from 'react'; import { memo, useCallback, useRef, useState } from 'react'; +import { useUpdateWorkflowIsPublicMutation } from 'services/api/endpoints/workflows'; import { assert } from 'tsafe'; /** @@ -87,8 +89,10 @@ const Content = memo(({ workflow, cancelRef }: { workflow: WorkflowV3; cancelRef } return ''; }); + const [isPublic, setIsPublic] = useState(false); const { createNewWorkflow } = useCreateLibraryWorkflow(); + const [updateIsPublic] = useUpdateWorkflowIsPublicMutation(); const inputRef = useRef(null); @@ -96,6 +100,10 @@ const Content = memo(({ workflow, cancelRef }: { workflow: WorkflowV3; cancelRef setName(e.target.value); }, []); + const onChangeIsPublic = useCallback((e: ChangeEvent) => { + setIsPublic(e.target.checked); + }, []); + const onClose = useCallback(() => { $workflowToSave.set(null); }, []); @@ -110,10 +118,19 @@ const Content = memo(({ workflow, cancelRef }: { workflow: WorkflowV3; cancelRef await createNewWorkflow({ workflow, - onSuccess: onClose, + onSuccess: async (workflowId?: string) => { + if (isPublic && workflowId) { + try { + await updateIsPublic({ workflow_id: workflowId, is_public: true }).unwrap(); + } catch { + // Sharing failed silently - workflow was saved, just not shared + } + } + onClose(); + }, onError: onClose, }); - }, [workflow, name, createNewWorkflow, onClose]); + }, [workflow, name, isPublic, createNewWorkflow, updateIsPublic, onClose]); return ( @@ -126,6 +143,10 @@ const Content = memo(({ workflow, cancelRef }: { workflow: WorkflowV3; cancelRef {t('workflows.workflowName')} + + + {t('workflows.shareWorkflow')} + diff --git a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useCreateNewWorkflow.ts b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useCreateNewWorkflow.ts index 543283c779c..37fe48726e0 100644 --- a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useCreateNewWorkflow.ts +++ b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useCreateNewWorkflow.ts @@ -29,7 +29,7 @@ export const isDraftWorkflow = (workflow: WorkflowV3): workflow is DraftWorkflow type CreateLibraryWorkflowArg = { workflow: DraftWorkflow; - onSuccess?: () => void; + onSuccess?: (workflowId?: string) => void; onError?: () => void; }; @@ -70,7 +70,7 @@ export const useCreateLibraryWorkflow = (): CreateLibraryWorkflowReturn => { // When a workflow is saved, the form field initial values are updated to the current form field values dispatch(formFieldInitialValuesChanged({ formFieldInitialValues: getFormFieldInitialValues() })); updateOpenedAt({ workflow_id: id }); - onSuccess?.(); + onSuccess?.(id); toast.update(toastRef.current, { title: t('workflows.workflowSaved'), status: 'success', diff --git a/invokeai/frontend/web/src/services/api/endpoints/workflows.ts b/invokeai/frontend/web/src/services/api/endpoints/workflows.ts index f58d3281a26..176546c90fd 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/workflows.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/workflows.ts @@ -157,6 +157,21 @@ export const workflowsApi = api.injectEndpoints({ }), invalidatesTags: (result, error, workflow_id) => [{ type: 'Workflow', id: workflow_id }], }), + updateWorkflowIsPublic: build.mutation< + paths['/api/v1/workflows/i/{workflow_id}/is_public']['patch']['responses']['200']['content']['application/json'], + { workflow_id: string; is_public: boolean } + >({ + query: ({ workflow_id, is_public }) => ({ + url: buildWorkflowsUrl(`i/${workflow_id}/is_public`), + method: 'PATCH', + body: { is_public }, + }), + invalidatesTags: (result, error, { workflow_id }) => [ + { type: 'Workflow', id: workflow_id }, + { type: 'Workflow', id: LIST_TAG }, + 'WorkflowCategoryCounts', + ], + }), }), }); @@ -173,4 +188,5 @@ export const { useListWorkflowsInfiniteInfiniteQuery, useSetWorkflowThumbnailMutation, useDeleteWorkflowThumbnailMutation, + useUpdateWorkflowIsPublicMutation, } = workflowsApi; diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index b605413787b..898423246a9 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -2001,6 +2001,26 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v1/workflows/i/{workflow_id}/is_public": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get?: never; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + /** + * Update Workflow Is Public + * @description Updates whether a workflow is shared publicly + */ + patch: operations["update_workflow_is_public"]; + trace?: never; + }; "/api/v1/workflows/tags": { parameters: { query?: never; @@ -3166,6 +3186,14 @@ export type components = { /** @description The updated workflow */ workflow: components["schemas"]["Workflow"]; }; + /** Body_update_workflow_is_public */ + Body_update_workflow_is_public: { + /** + * Is Public + * @description Whether the workflow should be shared publicly + */ + is_public: boolean; + }; /** Body_upload_image */ Body_upload_image: { /** @@ -27450,6 +27478,16 @@ export type components = { * @description The opened timestamp of the workflow. */ opened_at?: string | null; + /** + * User Id + * @description The id of the user who owns this workflow. + */ + user_id: string; + /** + * Is Public + * @description Whether this workflow is shared with all users. + */ + is_public: boolean; /** @description The workflow. */ workflow: components["schemas"]["Workflow"]; }; @@ -27480,6 +27518,16 @@ export type components = { * @description The opened timestamp of the workflow. */ opened_at?: string | null; + /** + * User Id + * @description The id of the user who owns this workflow. + */ + user_id: string; + /** + * Is Public + * @description Whether this workflow is shared with all users. + */ + is_public: boolean; /** * Description * @description The description of the workflow. @@ -27503,7 +27551,7 @@ export type components = { * @description The order by options for workflow records * @enum {string} */ - WorkflowRecordOrderBy: "created_at" | "updated_at" | "opened_at" | "name"; + WorkflowRecordOrderBy: "created_at" | "updated_at" | "opened_at" | "name" | "is_public"; /** WorkflowRecordWithThumbnailDTO */ WorkflowRecordWithThumbnailDTO: { /** @@ -27531,6 +27579,16 @@ export type components = { * @description The opened timestamp of the workflow. */ opened_at?: string | null; + /** + * User Id + * @description The id of the user who owns this workflow. + */ + user_id: string; + /** + * Is Public + * @description Whether this workflow is shared with all users. + */ + is_public: boolean; /** @description The workflow. */ workflow: components["schemas"]["Workflow"]; /** @@ -32380,6 +32438,8 @@ export interface operations { query?: string | null; /** @description Whether to include/exclude recent workflows */ has_been_opened?: boolean | null; + /** @description Filter by public/shared status */ + is_public?: boolean | null; }; header?: never; path?: never; @@ -32554,11 +32614,49 @@ export interface operations { }; }; }; + update_workflow_is_public: { + parameters: { + query?: never; + header?: never; + path: { + /** @description The workflow to update */ + workflow_id: string; + }; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["Body_update_workflow_is_public"]; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["WorkflowRecordDTO"]; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; get_all_tags: { parameters: { query?: { /** @description The categories to include */ categories?: components["schemas"]["WorkflowCategory"][] | null; + /** @description Filter by public/shared status */ + is_public?: boolean | null; }; header?: never; path?: never; @@ -32595,6 +32693,8 @@ export interface operations { categories?: components["schemas"]["WorkflowCategory"][] | null; /** @description Whether to include/exclude recent workflows */ has_been_opened?: boolean | null; + /** @description Filter by public/shared status */ + is_public?: boolean | null; }; header?: never; path?: never; @@ -32631,6 +32731,8 @@ export interface operations { categories: components["schemas"]["WorkflowCategory"][]; /** @description Whether to include/exclude recent workflows */ has_been_opened?: boolean | null; + /** @description Filter by public/shared status */ + is_public?: boolean | null; }; header?: never; path?: never; diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 5d56c346f87..80264f792c4 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -337,7 +337,7 @@ export type ModelInstallStatus = S['InstallStatus']; export type Graph = S['Graph']; export type NonNullableGraph = SetRequired; export type Batch = S['Batch']; -export const zWorkflowRecordOrderBy = z.enum(['name', 'created_at', 'updated_at', 'opened_at']); +export const zWorkflowRecordOrderBy = z.enum(['name', 'created_at', 'updated_at', 'opened_at', 'is_public']); export type WorkflowRecordOrderBy = z.infer; assert>(); diff --git a/tests/app/routers/test_workflows_multiuser.py b/tests/app/routers/test_workflows_multiuser.py new file mode 100644 index 00000000000..dc32b0752d9 --- /dev/null +++ b/tests/app/routers/test_workflows_multiuser.py @@ -0,0 +1,274 @@ +"""Tests for multiuser workflow library functionality.""" + +from typing import Any +from unittest.mock import MagicMock + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +from invokeai.app.api.dependencies import ApiDependencies +from invokeai.app.api_app import app +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.users.users_common import UserCreateRequest + + +class MockApiDependencies(ApiDependencies): + invoker: Invoker + + def __init__(self, invoker: Invoker) -> None: + self.invoker = invoker + + +WORKFLOW_BODY = { + "name": "Test Workflow", + "author": "", + "description": "A test workflow", + "version": "1.0.0", + "contact": "", + "tags": "", + "notes": "", + "nodes": {}, + "edges": [], + "exposedFields": [], + "meta": {"version": "3.0.0", "category": "user"}, + "id": None, + "form_fields": [], +} + + +@pytest.fixture +def setup_jwt_secret(): + from invokeai.app.services.auth.token_service import set_jwt_secret + + set_jwt_secret("test-secret-key-for-unit-tests-only-do-not-use-in-production") + + +@pytest.fixture +def client(): + return TestClient(app) + + +def create_test_user(mock_invoker: Invoker, email: str, display_name: str, is_admin: bool = False) -> str: + user_service = mock_invoker.services.users + user_data = UserCreateRequest(email=email, display_name=display_name, password="TestPass123", is_admin=is_admin) + user = user_service.create(user_data) + return user.user_id + + +def get_user_token(client: TestClient, email: str) -> str: + response = client.post( + "/api/v1/auth/login", + json={"email": email, "password": "TestPass123", "remember_me": False}, + ) + assert response.status_code == 200 + return response.json()["token"] + + +@pytest.fixture +def enable_multiuser(monkeypatch: Any, mock_invoker: Invoker): + mock_invoker.services.configuration.multiuser = True + mock_workflow_thumbnails = MagicMock() + mock_workflow_thumbnails.get_url.return_value = None + mock_invoker.services.workflow_thumbnails = mock_workflow_thumbnails + + mock_deps = MockApiDependencies(mock_invoker) + monkeypatch.setattr("invokeai.app.api.routers.auth.ApiDependencies", mock_deps) + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", mock_deps) + monkeypatch.setattr("invokeai.app.api.routers.workflows.ApiDependencies", mock_deps) + yield + + +@pytest.fixture +def admin_token(setup_jwt_secret: None, enable_multiuser: Any, mock_invoker: Invoker, client: TestClient): + create_test_user(mock_invoker, "admin@test.com", "Admin", is_admin=True) + return get_user_token(client, "admin@test.com") + + +@pytest.fixture +def user1_token(enable_multiuser: Any, mock_invoker: Invoker, client: TestClient, admin_token: str): + create_test_user(mock_invoker, "user1@test.com", "User One", is_admin=False) + return get_user_token(client, "user1@test.com") + + +@pytest.fixture +def user2_token(enable_multiuser: Any, mock_invoker: Invoker, client: TestClient, admin_token: str): + create_test_user(mock_invoker, "user2@test.com", "User Two", is_admin=False) + return get_user_token(client, "user2@test.com") + + +def create_workflow(client: TestClient, token: str) -> str: + response = client.post( + "/api/v1/workflows/", + json={"workflow": WORKFLOW_BODY}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 200, response.text + return response.json()["workflow_id"] + + +# --------------------------------------------------------------------------- +# Auth tests +# --------------------------------------------------------------------------- + + +def test_list_workflows_requires_auth(enable_multiuser: Any, client: TestClient): + response = client.get("/api/v1/workflows/") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +def test_create_workflow_requires_auth(enable_multiuser: Any, client: TestClient): + response = client.post("/api/v1/workflows/", json={"workflow": WORKFLOW_BODY}) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +# --------------------------------------------------------------------------- +# Ownership isolation +# --------------------------------------------------------------------------- + + +def test_workflows_are_isolated_between_users(client: TestClient, user1_token: str, user2_token: str): + """Users should only see their own workflows in list.""" + # user1 creates a workflow + create_workflow(client, user1_token) + + # user1 can see it + r1 = client.get("/api/v1/workflows/?categories=user", headers={"Authorization": f"Bearer {user1_token}"}) + assert r1.status_code == 200 + assert r1.json()["total"] == 1 + + # user2 cannot see user1's workflow + r2 = client.get("/api/v1/workflows/?categories=user", headers={"Authorization": f"Bearer {user2_token}"}) + assert r2.status_code == 200 + assert r2.json()["total"] == 0 + + +def test_user_cannot_delete_another_users_workflow(client: TestClient, user1_token: str, user2_token: str): + workflow_id = create_workflow(client, user1_token) + response = client.delete( + f"/api/v1/workflows/i/{workflow_id}", + headers={"Authorization": f"Bearer {user2_token}"}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_user_cannot_update_another_users_workflow(client: TestClient, user1_token: str, user2_token: str): + workflow_id = create_workflow(client, user1_token) + updated = {**WORKFLOW_BODY, "id": workflow_id, "name": "Hijacked"} + response = client.patch( + f"/api/v1/workflows/i/{workflow_id}", + json={"workflow": updated}, + headers={"Authorization": f"Bearer {user2_token}"}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_owner_can_delete_own_workflow(client: TestClient, user1_token: str): + workflow_id = create_workflow(client, user1_token) + response = client.delete( + f"/api/v1/workflows/i/{workflow_id}", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert response.status_code == 200 + + +def test_admin_can_delete_any_workflow(client: TestClient, admin_token: str, user1_token: str): + workflow_id = create_workflow(client, user1_token) + response = client.delete( + f"/api/v1/workflows/i/{workflow_id}", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + + +# --------------------------------------------------------------------------- +# Shared workflow (is_public) +# --------------------------------------------------------------------------- + + +def test_update_is_public_owner_succeeds(client: TestClient, user1_token: str): + workflow_id = create_workflow(client, user1_token) + response = client.patch( + f"/api/v1/workflows/i/{workflow_id}/is_public", + json={"is_public": True}, + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert response.status_code == 200 + assert response.json()["is_public"] is True + + +def test_update_is_public_other_user_forbidden(client: TestClient, user1_token: str, user2_token: str): + workflow_id = create_workflow(client, user1_token) + response = client.patch( + f"/api/v1/workflows/i/{workflow_id}/is_public", + json={"is_public": True}, + headers={"Authorization": f"Bearer {user2_token}"}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_public_workflow_visible_to_other_users(client: TestClient, user1_token: str, user2_token: str): + """A shared (is_public=True) workflow should appear when filtering with is_public=true.""" + workflow_id = create_workflow(client, user1_token) + # Make it public + client.patch( + f"/api/v1/workflows/i/{workflow_id}/is_public", + json={"is_public": True}, + headers={"Authorization": f"Bearer {user1_token}"}, + ) + + # user2 can see it through is_public=true filter + response = client.get( + "/api/v1/workflows/?categories=user&is_public=true", + headers={"Authorization": f"Bearer {user2_token}"}, + ) + assert response.status_code == 200 + ids = [w["workflow_id"] for w in response.json()["items"]] + assert workflow_id in ids + + +def test_private_workflow_not_visible_to_other_users(client: TestClient, user1_token: str, user2_token: str): + """A private (is_public=False) user workflow should NOT appear for another user.""" + workflow_id = create_workflow(client, user1_token) + + # user2 lists 'yours' style (their own workflows) + response = client.get( + "/api/v1/workflows/?categories=user", + headers={"Authorization": f"Bearer {user2_token}"}, + ) + assert response.status_code == 200 + ids = [w["workflow_id"] for w in response.json()["items"]] + assert workflow_id not in ids + + +def test_public_workflow_still_in_owners_list(client: TestClient, user1_token: str): + """A shared workflow should still appear in the owner's own workflow list.""" + workflow_id = create_workflow(client, user1_token) + client.patch( + f"/api/v1/workflows/i/{workflow_id}/is_public", + json={"is_public": True}, + headers={"Authorization": f"Bearer {user1_token}"}, + ) + + # owner's 'yours' list (no is_public filter) + response = client.get( + "/api/v1/workflows/?categories=user", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert response.status_code == 200 + ids = [w["workflow_id"] for w in response.json()["items"]] + assert workflow_id in ids + + +def test_workflow_has_user_id_and_is_public_fields(client: TestClient, user1_token: str): + """Created workflow should return user_id and is_public fields.""" + response = client.post( + "/api/v1/workflows/", + json={"workflow": WORKFLOW_BODY}, + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert response.status_code == 200 + data = response.json() + assert "user_id" in data + assert "is_public" in data + assert data["is_public"] is False From ef5b610a8738a055a1bc6337997a8a725d3bc407 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 04:48:45 +0000 Subject: [PATCH 2/2] feat: add shared workflow checkbox to Details panel, auto-tag, gate edit/delete, fix tests Co-authored-by: lstein <111189+lstein@users.noreply.github.com> --- .../workflow_records_sqlite.py | 18 +++++- invokeai/frontend/web/public/locales/en.json | 2 +- .../sidePanel/workflow/WorkflowGeneralTab.tsx | 54 +++++++++++++++- .../WorkflowLibrary/WorkflowListItem.tsx | 8 ++- tests/app/routers/test_workflows_multiuser.py | 62 ++++++++++++++++++- 5 files changed, 135 insertions(+), 9 deletions(-) diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py index eac45fb18e4..0e6dfe1b700 100644 --- a/invokeai/app/services/workflow_records/workflow_records_sqlite.py +++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py @@ -97,15 +97,27 @@ def delete(self, workflow_id: str) -> None: return None def update_is_public(self, workflow_id: str, is_public: bool) -> WorkflowRecordDTO: - """Updates the is_public field of a workflow.""" + """Updates the is_public field of a workflow and manages the 'shared' tag automatically.""" + record = self.get(workflow_id) + workflow = record.workflow + + # Manage "shared" tag: add when public, remove when private + tags_list = [t.strip() for t in workflow.tags.split(",") if t.strip()] if workflow.tags else [] + if is_public and "shared" not in tags_list: + tags_list.append("shared") + elif not is_public and "shared" in tags_list: + tags_list.remove("shared") + updated_tags = ", ".join(tags_list) + updated_workflow = workflow.model_copy(update={"tags": updated_tags}) + with self._db.transaction() as cursor: cursor.execute( """--sql UPDATE workflow_library - SET is_public = ? + SET workflow = ?, is_public = ? WHERE workflow_id = ? AND category = 'user'; """, - (is_public, workflow_id), + (updated_workflow.model_dump_json(), is_public, workflow_id), ) return self.get(workflow_id) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 225c54d458e..3b7f1c2b5d9 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2159,7 +2159,7 @@ "yourWorkflows": "Your Workflows", "recentlyOpened": "Recently Opened", "sharedWorkflows": "Shared Workflows", - "shareWorkflow": "Share Workflow", + "shareWorkflow": "Shared workflow", "noRecentWorkflows": "No Recent Workflows", "private": "Private", "shared": "Shared", diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowGeneralTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowGeneralTab.tsx index c1094abf86d..11d27335352 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowGeneralTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowGeneralTab.tsx @@ -1,8 +1,19 @@ import type { FormControlProps } from '@invoke-ai/ui-library'; -import { Box, Flex, FormControl, FormControlGroup, FormLabel, Image, Input, Textarea } from '@invoke-ai/ui-library'; +import { + Box, + Checkbox, + Flex, + FormControl, + FormControlGroup, + FormLabel, + Image, + Input, + Textarea, +} from '@invoke-ai/ui-library'; import { skipToken } from '@reduxjs/toolkit/query'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; +import { selectCurrentUser } from 'features/auth/store/authSlice'; import { workflowAuthorChanged, workflowContactChanged, @@ -25,7 +36,8 @@ import { import type { ChangeEvent } from 'react'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { useGetWorkflowQuery } from 'services/api/endpoints/workflows'; +import { useGetSetupStatusQuery } from 'services/api/endpoints/auth'; +import { useGetWorkflowQuery, useUpdateWorkflowIsPublicMutation } from 'services/api/endpoints/workflows'; import { WorkflowThumbnailEditor } from './WorkflowThumbnail/WorkflowThumbnailEditor'; @@ -95,6 +107,7 @@ const WorkflowGeneralTab = () => { {t('nodes.workflowName')} + {t('nodes.workflowVersion')} @@ -187,3 +200,40 @@ const Thumbnail = ({ id }: { id?: string | null }) => { // This is a default workflow and it does not have a thumbnail set. Users may not edit the thumbnail. return null; }; + +const ShareWorkflowCheckbox = ({ id }: { id?: string | null }) => { + const { t } = useTranslation(); + const currentUser = useAppSelector(selectCurrentUser); + const { data: setupStatus } = useGetSetupStatusQuery(); + const { data } = useGetWorkflowQuery(id ?? skipToken); + const [updateIsPublic, { isLoading }] = useUpdateWorkflowIsPublicMutation(); + + const handleChange = useCallback( + (e: ChangeEvent) => { + if (!id) { + return; + } + updateIsPublic({ workflow_id: id, is_public: e.target.checked }); + }, + [id, updateIsPublic] + ); + + // Only show for saved user workflows in multiuser mode when the current user is the owner or admin + if (!data || !id || data.workflow.meta.category !== 'user') { + return null; + } + if (setupStatus?.multiuser_enabled) { + const isOwner = currentUser !== null && data.user_id === currentUser.user_id; + const isAdmin = currentUser?.is_admin ?? false; + if (!isOwner && !isAdmin) { + return null; + } + } + + return ( + + + {t('workflows.shareWorkflow')} + + ); +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx index 7dcc85cc2a3..a184f04039a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx @@ -46,6 +46,10 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi return currentUser !== null && workflow.user_id === currentUser.user_id; }, [currentUser, workflow.user_id]); + const canEditOrDelete = useMemo(() => { + return isOwner || (currentUser?.is_admin ?? false); + }, [isOwner, currentUser]); + const tags = useMemo(() => { if (!workflow.tags) { return []; @@ -160,9 +164,9 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi {workflow.category === 'default' && } {workflow.category !== 'default' && ( <> - + {canEditOrDelete && } - + {canEditOrDelete && } )} diff --git a/tests/app/routers/test_workflows_multiuser.py b/tests/app/routers/test_workflows_multiuser.py index dc32b0752d9..28b301e18e3 100644 --- a/tests/app/routers/test_workflows_multiuser.py +++ b/tests/app/routers/test_workflows_multiuser.py @@ -1,5 +1,6 @@ """Tests for multiuser workflow library functionality.""" +import logging from typing import Any from unittest.mock import MagicMock @@ -9,8 +10,13 @@ from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api_app import app +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invoker import Invoker from invokeai.app.services.users.users_common import UserCreateRequest +from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage +from invokeai.backend.util.logging import InvokeAILogger +from tests.fixtures.sqlite_database import create_mock_sqlite_database class MockApiDependencies(ApiDependencies): @@ -28,7 +34,7 @@ def __init__(self, invoker: Invoker) -> None: "contact": "", "tags": "", "notes": "", - "nodes": {}, + "nodes": [], "edges": [], "exposedFields": [], "meta": {"version": "3.0.0", "category": "user"}, @@ -49,6 +55,60 @@ def client(): return TestClient(app) +@pytest.fixture +def mock_services() -> InvocationServices: + from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage + from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage + from invokeai.app.services.boards.boards_default import BoardService + from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService + from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ( + ClientStatePersistenceSqlite, + ) + from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage + from invokeai.app.services.images.images_default import ImageService + from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache + from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService + from invokeai.app.services.users.users_default import UserService + from tests.test_nodes import TestEventService + + configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) + logger = InvokeAILogger.get_logger() + db = create_mock_sqlite_database(configuration, logger) + + return InvocationServices( + board_image_records=SqliteBoardImageRecordStorage(db=db), + board_images=None, # type: ignore + board_records=SqliteBoardRecordStorage(db=db), + boards=BoardService(), + bulk_download=BulkDownloadService(), + configuration=configuration, + events=TestEventService(), + image_files=None, # type: ignore + image_records=SqliteImageRecordStorage(db=db), + images=ImageService(), + invocation_cache=MemoryInvocationCache(max_cache_size=0), + logger=logging, # type: ignore + model_images=None, # type: ignore + model_manager=None, # type: ignore + download_queue=None, # type: ignore + names=None, # type: ignore + performance_statistics=InvocationStatsService(), + session_processor=None, # type: ignore + session_queue=None, # type: ignore + urls=None, # type: ignore + workflow_records=SqliteWorkflowRecordsStorage(db=db), + tensors=None, # type: ignore + conditioning=None, # type: ignore + style_preset_records=None, # type: ignore + style_preset_image_files=None, # type: ignore + workflow_thumbnails=None, # type: ignore + model_relationship_records=None, # type: ignore + model_relationships=None, # type: ignore + client_state_persistence=ClientStatePersistenceSqlite(db=db), + users=UserService(db), + ) + + def create_test_user(mock_invoker: Invoker, email: str, display_name: str, is_admin: bool = False) -> str: user_service = mock_invoker.services.users user_data = UserCreateRequest(email=email, display_name=display_name, password="TestPass123", is_admin=is_admin)