Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 115 additions & 9 deletions invokeai/app/api/routers/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastapi.responses import FileResponse
from PIL import Image

from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
Expand Down Expand Up @@ -33,16 +34,25 @@
},
)
async def get_workflow(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to get"),
) -> WorkflowRecordWithThumbnailDTO:
"""Gets a workflow"""
try:
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")

config = ApiDependencies.invoker.services.configuration
if config.multiuser:
is_default = workflow.workflow.meta.category is WorkflowCategory.Default
is_owner = workflow.user_id == current_user.user_id
if not (is_default or is_owner or workflow.is_public or current_user.is_admin):
raise HTTPException(status_code=403, detail="Not authorized to access this workflow")

thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())


@workflows_router.patch(
"/i/{workflow_id}",
Expand All @@ -52,9 +62,18 @@ async def get_workflow(
},
)
async def update_workflow(
current_user: CurrentUserOrDefault,
workflow: Workflow = Body(description="The updated workflow", embed=True),
) -> WorkflowRecordDTO:
"""Updates a workflow"""
config = ApiDependencies.invoker.services.configuration
if config.multiuser:
try:
existing = ApiDependencies.invoker.services.workflow_records.get(workflow.id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
if not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow)


Expand All @@ -63,9 +82,18 @@ async def update_workflow(
operation_id="delete_workflow",
)
async def delete_workflow(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to delete"),
) -> None:
"""Deletes a workflow"""
config = ApiDependencies.invoker.services.configuration
if config.multiuser:
try:
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
if not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to delete this workflow")
try:
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
except WorkflowThumbnailFileNotFoundException:
Expand All @@ -82,10 +110,11 @@ async def delete_workflow(
},
)
async def create_workflow(
current_user: CurrentUserOrDefault,
workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True),
) -> WorkflowRecordDTO:
"""Creates a workflow"""
return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow)
return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow, user_id=current_user.user_id)


@workflows_router.get(
Expand All @@ -96,6 +125,7 @@ async def create_workflow(
},
)
async def list_workflows(
current_user: CurrentUserOrDefault,
page: int = Query(default=0, description="The page to get"),
per_page: Optional[int] = Query(default=None, description="The number of workflows per page"),
order_by: WorkflowRecordOrderBy = Query(
Expand All @@ -106,8 +136,19 @@ async def list_workflows(
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
"""Gets a page of workflows"""
config = ApiDependencies.invoker.services.configuration

# In multiuser mode, scope user-category workflows to the current user unless fetching shared workflows
user_id_filter: Optional[str] = None
if config.multiuser:
# Only filter 'user' category results by user_id when not explicitly listing public workflows
has_user_category = not categories or WorkflowCategory.User in categories
if has_user_category and is_public is not True:
user_id_filter = current_user.user_id

workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
order_by=order_by,
Expand All @@ -118,6 +159,8 @@ async def list_workflows(
categories=categories,
tags=tags,
has_been_opened=has_been_opened,
user_id=user_id_filter,
is_public=is_public,
)
for workflow in workflows.items:
workflows_with_thumbnails.append(
Expand All @@ -143,15 +186,20 @@ async def list_workflows(
},
)
async def set_workflow_thumbnail(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to update"),
image: UploadFile = File(description="The image file to upload"),
):
"""Sets a workflow's thumbnail image"""
try:
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")

config = ApiDependencies.invoker.services.configuration
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")

if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")

Expand All @@ -177,14 +225,19 @@ async def set_workflow_thumbnail(
},
)
async def delete_workflow_thumbnail(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to update"),
):
"""Removes a workflow's thumbnail image"""
try:
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")

config = ApiDependencies.invoker.services.configuration
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")

try:
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
except ValueError as e:
Expand Down Expand Up @@ -223,37 +276,90 @@ async def get_workflow_thumbnail(
raise HTTPException(status_code=404)


@workflows_router.patch(
"/i/{workflow_id}/is_public",
operation_id="update_workflow_is_public",
responses={
200: {"model": WorkflowRecordDTO},
},
)
async def update_workflow_is_public(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to update"),
is_public: bool = Body(description="Whether the workflow should be shared publicly", embed=True),
) -> WorkflowRecordDTO:
"""Updates whether a workflow is shared publicly"""
try:
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")

config = ApiDependencies.invoker.services.configuration
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")

return ApiDependencies.invoker.services.workflow_records.update_is_public(
workflow_id=workflow_id, is_public=is_public
)


@workflows_router.get("/tags", operation_id="get_all_tags")
async def get_all_tags(
current_user: CurrentUserOrDefault,
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> list[str]:
"""Gets all unique tags from workflows"""

return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories)
config = ApiDependencies.invoker.services.configuration
user_id_filter: Optional[str] = None
if config.multiuser:
has_user_category = not categories or WorkflowCategory.User in categories
if has_user_category and is_public is not True:
user_id_filter = current_user.user_id

return ApiDependencies.invoker.services.workflow_records.get_all_tags(
categories=categories, user_id=user_id_filter, is_public=is_public
)


@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
async def get_counts_by_tag(
current_user: CurrentUserOrDefault,
tags: list[str] = Query(description="The tags to get counts for"),
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> dict[str, int]:
"""Counts workflows by tag"""
config = ApiDependencies.invoker.services.configuration
user_id_filter: Optional[str] = None
if config.multiuser:
has_user_category = not categories or WorkflowCategory.User in categories
if has_user_category and is_public is not True:
user_id_filter = current_user.user_id

return ApiDependencies.invoker.services.workflow_records.counts_by_tag(
tags=tags, categories=categories, has_been_opened=has_been_opened
tags=tags, categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public
)


@workflows_router.get("/counts_by_category", operation_id="counts_by_category")
async def counts_by_category(
current_user: CurrentUserOrDefault,
categories: list[WorkflowCategory] = Query(description="The categories to include"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> dict[str, int]:
"""Counts workflows by category"""
config = ApiDependencies.invoker.services.configuration
user_id_filter: Optional[str] = None
if config.multiuser:
has_user_category = WorkflowCategory.User in categories
if has_user_category and is_public is not True:
user_id_filter = current_user.user_id

return ApiDependencies.invoker.services.workflow_records.counts_by_category(
categories=categories, has_been_opened=has_been_opened
categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public
)


Expand Down
2 changes: 2 additions & 0 deletions invokeai/app/services/shared/sqlite/sqlite_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_26 import build_migration_26
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_27 import build_migration_27
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_28 import build_migration_28
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


Expand Down Expand Up @@ -77,6 +78,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_25(app_config=config, logger=logger))
migrator.register_migration(build_migration_26(app_config=config, logger=logger))
migrator.register_migration(build_migration_27())
migrator.register_migration(build_migration_28())
migrator.run_migrations()

return db
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Migration 28: Add per-user workflow isolation columns to workflow_library.

This migration adds the database columns required for multiuser workflow isolation
to the workflow_library table:
- user_id: the owner of the workflow (defaults to 'system' for existing workflows)
- is_public: whether the workflow is shared with all users
"""

import sqlite3

from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration28Callback:
"""Migration to add user_id and is_public to the workflow_library table."""

def __call__(self, cursor: sqlite3.Cursor) -> None:
self._update_workflow_library_table(cursor)

def _update_workflow_library_table(self, cursor: sqlite3.Cursor) -> None:
"""Add user_id and is_public columns to workflow_library table."""
cursor.execute("PRAGMA table_info(workflow_library);")
columns = [row[1] for row in cursor.fetchall()]

if "user_id" not in columns:
cursor.execute("ALTER TABLE workflow_library ADD COLUMN user_id TEXT DEFAULT 'system';")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflow_library_user_id ON workflow_library(user_id);")

if "is_public" not in columns:
cursor.execute("ALTER TABLE workflow_library ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflow_library_is_public ON workflow_library(is_public);")


def build_migration_28() -> Migration:
"""Builds the migration object for migrating from version 27 to version 28.

This migration adds per-user workflow isolation to the workflow_library table:
- user_id column: identifies the owner of each workflow
- is_public column: controls whether a workflow is shared with all users
"""
return Migration(
from_version=27,
to_version=28,
callback=Migration28Callback(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.workflow_records.workflow_records_common import (
WORKFLOW_LIBRARY_DEFAULT_USER_ID,
Workflow,
WorkflowCategory,
WorkflowRecordDTO,
Expand All @@ -22,7 +23,7 @@ def get(self, workflow_id: str) -> WorkflowRecordDTO:
pass

@abstractmethod
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
def create(self, workflow: WorkflowWithoutID, user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID) -> WorkflowRecordDTO:
"""Creates a workflow."""
pass

Expand All @@ -47,6 +48,8 @@ def get_many(
query: Optional[str],
tags: Optional[list[str]],
has_been_opened: Optional[bool],
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets many workflows."""
pass
Expand All @@ -56,6 +59,8 @@ def counts_by_category(
self,
categories: list[WorkflowCategory],
has_been_opened: Optional[bool] = None,
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> dict[str, int]:
"""Gets a dictionary of counts for each of the provided categories."""
pass
Expand All @@ -66,6 +71,8 @@ def counts_by_tag(
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
has_been_opened: Optional[bool] = None,
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> dict[str, int]:
"""Gets a dictionary of counts for each of the provided tags."""
pass
Expand All @@ -79,6 +86,13 @@ def update_opened_at(self, workflow_id: str) -> None:
def get_all_tags(
self,
categories: Optional[list[WorkflowCategory]] = None,
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> list[str]:
"""Gets all unique tags from workflows."""
pass

@abstractmethod
def update_is_public(self, workflow_id: str, is_public: bool) -> WorkflowRecordDTO:
"""Updates the is_public field of a workflow."""
pass
Loading