Skip to content

Commit 791f052

Browse files
authored
Merge pull request #828 from superannotateai/user_project_filters
Add filter by project user role
2 parents 22ec567 + aa3dc4e commit 791f052

5 files changed

Lines changed: 106 additions & 25 deletions

File tree

src/superannotate/lib/app/interface/sdk_interface.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,11 @@ def list_users(
561561
- email__starts: str
562562
- email__ends: str
563563
564+
Following params if project is selected::
565+
566+
- role: str
567+
- role__in: List[str]
568+
564569
Following params if project is not selected::
565570
566571
- state: Literal[“Confirmed”, “Pending”]

src/superannotate/lib/core/entities/filters.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,22 @@ class ProjectFilters(BaseFilters):
4040
status__notin: List[Literal["NotStarted", "InProgress", "Completed", "OnHold"]]
4141

4242

43-
class UserFilters(TypedDict, total=False):
43+
class BaseUserFilters(TypedDict, total=False):
4444
id: Optional[int]
4545
id__in: Optional[List[int]]
4646
email: Optional[str]
4747
email__in: Optional[List[str]]
4848
email__contains: Optional[str]
4949
email__starts: Optional[str]
5050
email__ends: Optional[str]
51+
52+
53+
class ProjectUserFilters(BaseUserFilters, total=False):
54+
role: Optional[str]
55+
role__in: Optional[List[str]]
56+
57+
58+
class TeamUserFilters(BaseUserFilters, total=False):
5159
state: Optional[str]
5260
state__in: Optional[List[str]]
5361
role: Optional[str]

src/superannotate/lib/infrastructure/controller.py

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
from lib.core.entities.classes import AnnotationClassEntity
3434
from lib.core.entities.filters import ItemFilters
3535
from lib.core.entities.filters import ProjectFilters
36-
from lib.core.entities.filters import UserFilters
36+
from lib.core.entities.filters import ProjectUserFilters
37+
from lib.core.entities.filters import TeamUserFilters
3738
from lib.core.entities.integrations import IntegrationEntity
3839
from lib.core.entities.items import ProjectCategoryEntity
3940
from lib.core.entities.work_managament import ScoreEntity
@@ -56,8 +57,10 @@
5657
from lib.infrastructure.query_builder import IncludeHandler
5758
from lib.infrastructure.query_builder import ItemFilterHandler
5859
from lib.infrastructure.query_builder import ProjectFilterHandler
60+
from lib.infrastructure.query_builder import ProjectUserRoleFilterHandler
5961
from lib.infrastructure.query_builder import QueryBuilderChain
60-
from lib.infrastructure.query_builder import UserFilterHandler
62+
from lib.infrastructure.query_builder import TeamUserRoleFilterHandler
63+
from lib.infrastructure.query_builder import TeamUserStateFilterHandler
6164
from lib.infrastructure.repositories import S3Repository
6265
from lib.infrastructure.serviceprovider import ServiceProvider
6366
from lib.infrastructure.services.http_client import HttpClient
@@ -205,27 +208,50 @@ def list_users(
205208
if project:
206209
parent_entity = CustomFieldEntityEnum.PROJECT
207210
project_id = context["project_id"] = project.id
211+
valid_fields = generate_schema(
212+
ProjectUserFilters.__annotations__,
213+
self.service_provider.get_custom_fields_templates(
214+
context, CustomFieldEntityEnum.CONTRIBUTOR, parent=parent_entity
215+
),
216+
)
217+
chain = QueryBuilderChain(
218+
[
219+
FieldValidationHandler(valid_fields.keys()),
220+
ProjectUserRoleFilterHandler(
221+
team_id=self.service_provider.client.team_id,
222+
project=project,
223+
service_provider=self.service_provider,
224+
entity=CustomFieldEntityEnum.CONTRIBUTOR,
225+
parent=parent_entity,
226+
),
227+
]
228+
)
208229
else:
209230
parent_entity = CustomFieldEntityEnum.TEAM
210231
project_id = None
211-
valid_fields = generate_schema(
212-
UserFilters.__annotations__,
213-
self.service_provider.get_custom_fields_templates(
214-
context, CustomFieldEntityEnum.CONTRIBUTOR, parent=parent_entity
215-
),
216-
)
217-
chain = QueryBuilderChain(
218-
[
219-
FieldValidationHandler(valid_fields.keys()),
220-
UserFilterHandler(
221-
team_id=self.service_provider.client.team_id,
222-
project_id=project_id,
223-
service_provider=self.service_provider,
224-
entity=CustomFieldEntityEnum.CONTRIBUTOR,
225-
parent=parent_entity,
232+
valid_fields = generate_schema(
233+
TeamUserFilters.__annotations__,
234+
self.service_provider.get_custom_fields_templates(
235+
context, CustomFieldEntityEnum.CONTRIBUTOR, parent=parent_entity
226236
),
227-
]
228-
)
237+
)
238+
chain = QueryBuilderChain(
239+
[
240+
FieldValidationHandler(valid_fields.keys()),
241+
TeamUserRoleFilterHandler(
242+
team_id=self.service_provider.client.team_id,
243+
service_provider=self.service_provider,
244+
entity=CustomFieldEntityEnum.CONTRIBUTOR,
245+
parent=parent_entity,
246+
),
247+
TeamUserStateFilterHandler(
248+
team_id=self.service_provider.client.team_id,
249+
service_provider=self.service_provider,
250+
entity=CustomFieldEntityEnum.CONTRIBUTOR,
251+
parent=parent_entity,
252+
),
253+
]
254+
)
229255
query = chain.handle(filters, EmptyQuery())
230256

231257
if project and include and "categories" in include:

src/superannotate/lib/infrastructure/query_builder.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -164,17 +164,20 @@ class BaseCustomFieldHandler(AbstractQueryHandler):
164164
def __init__(
165165
self,
166166
team_id: int,
167-
project_id: Optional[int],
168167
service_provider: BaseServiceProvider,
169168
entity: CustomFieldEntityEnum,
170169
parent: CustomFieldEntityEnum,
170+
project: Optional[ProjectEntity] = None,
171171
):
172172
self._service_provider = service_provider
173173
self._entity = entity
174174
self._parent = parent
175175
self._team_id = team_id
176-
self._project_id = project_id
177-
self._context = {"team_id": self._team_id, "project_id": self._project_id}
176+
self._project = project
177+
self._context = {
178+
"team_id": self._team_id,
179+
"project_id": project.id if project else None,
180+
}
178181

179182
def _handle_custom_field_key(self, key) -> Tuple[str, str, Optional[str]]:
180183
for custom_field in sorted(
@@ -261,7 +264,7 @@ def handle(self, filters: Dict[str, Any], query: Query = None) -> Query:
261264
return super().handle(filters, query)
262265

263266

264-
class UserFilterHandler(BaseCustomFieldHandler):
267+
class TeamUserRoleFilterHandler(BaseCustomFieldHandler):
265268
def _handle_special_fields(self, keys: List[str], val):
266269
"""
267270
Handle special fields like 'custom_fields__'.
@@ -276,7 +279,36 @@ def _handle_special_fields(self, keys: List[str], val):
276279
raise AppException("Invalid user role provided.")
277280
except (KeyError, AttributeError):
278281
raise AppException("Invalid user role provided.")
279-
elif keys[0] == "state":
282+
return super()._handle_special_fields(keys, val)
283+
284+
285+
class ProjectUserRoleFilterHandler(BaseCustomFieldHandler):
286+
def _handle_special_fields(self, keys: List[str], val):
287+
"""
288+
Handle special fields like 'custom_fields__'.
289+
"""
290+
if keys[0] == "role":
291+
try:
292+
if isinstance(val, list):
293+
val = [
294+
self._service_provider.get_role_id(self._project, i)
295+
for i in val
296+
]
297+
elif isinstance(val, str):
298+
val = self._service_provider.get_role_id(self._project, val)
299+
else:
300+
raise AppException("Invalid user role provided.")
301+
except (KeyError, AttributeError):
302+
raise AppException("Invalid user role provided.")
303+
return super()._handle_special_fields(keys, val)
304+
305+
306+
class TeamUserStateFilterHandler(BaseCustomFieldHandler):
307+
def _handle_special_fields(self, keys: List[str], val):
308+
"""
309+
Handle special fields like 'custom_fields__'.
310+
"""
311+
if keys[0] == "state":
280312
try:
281313
if isinstance(val, list):
282314
val = [WMUserStateEnum[i].value for i in val]

tests/integration/work_management/test_list_users.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,16 @@ def test_pending_users(self):
3333

3434
assert project["contributors"][1]["state"] == "PENDING"
3535

36+
@pytest.mark.skip(reason="For not send real email")
37+
def test_project_role_filter_users(self):
38+
test_email = "test1@superannotate.com"
39+
sa.invite_contributors_to_team(emails=[test_email])
40+
sa.add_contributors_to_project(self.PROJECT_NAME, [test_email], "Annotator")
41+
users = sa.list_users(project=self.PROJECT_NAME, role="QA")
42+
assert len(users) == 0
43+
users = sa.list_users(project=self.PROJECT_NAME, role="Annotator")
44+
assert len(users) == 2
45+
3646
def test_list_users_by_project_name(self):
3747
project_users = sa.list_users(project=self.PROJECT_NAME)
3848
assert len(project_users) == 1

0 commit comments

Comments
 (0)