diff --git a/.github/instructions/*.instructions.md b/.github/instructions/*.instructions.md new file mode 100644 index 000000000..76a36da75 --- /dev/null +++ b/.github/instructions/*.instructions.md @@ -0,0 +1,7 @@ +# Instructions for Large PR Reviews + +## General Approach for Big Changesets +- Prioritize high-impact issues: security, performance, errors. +- Review files individually; skip minor style if >1000 lines total diff. +- Group feedback by file; limit to 5-10 key comments per file. +- If diff too large, comment on architecture/summary + top 3 files. diff --git a/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py b/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py index 73be02c1d..947bbba5c 100644 --- a/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py +++ b/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py @@ -21,7 +21,10 @@ from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, ) -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.utils.queries import ( create_group, @@ -38,7 +41,7 @@ depends_on: None = None -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: """Upgrade.""" @@ -46,6 +49,7 @@ async def _add_domain_computers_group(connection: AsyncConnection) -> None: # n async with container(scope=Scope.REQUEST) as cnt: session = await cnt.get(AsyncSession) entity_type_dao = await cnt.get(EntityTypeDAO) + entity_type_use_case = await cnt.get(EntityTypeUseCase) role_use_case = await cnt.get(RoleUseCase) base_dn_list = await get_base_directories(session) @@ -104,7 +108,10 @@ async def _add_domain_computers_group(connection: AsyncConnection) -> None: # n attribute_names=["attributes"], with_for_update=None, ) - await entity_type_dao.attach_entity_type_to_directory(dir_, False) + await entity_type_use_case.attach_entity_type_to_directory( + dir_, + False, + ) await role_use_case.inherit_parent_aces( parent_directory=parent, directory=dir_, @@ -169,7 +176,7 @@ async def _add_primary_group_id(connection: AsyncConnection) -> None: # noqa: A op.run_async(_add_primary_group_id) -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def downgrade(container: AsyncContainer) -> None: """Downgrade.""" bind = op.get_bind() diff --git a/app/alembic/versions/05ddc0bd562a_add_roles.py b/app/alembic/versions/05ddc0bd562a_add_roles.py index aff773460..aceae92e1 100644 --- a/app/alembic/versions/05ddc0bd562a_add_roles.py +++ b/app/alembic/versions/05ddc0bd562a_add_roles.py @@ -24,7 +24,7 @@ depends_on: None = None -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: """Upgrade.""" op.create_table( diff --git a/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py b/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py index 3cfbca4a2..f935824b5 100644 --- a/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py +++ b/app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py @@ -25,7 +25,7 @@ depends_on: None | list[str] = None -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" bind = op.get_bind() @@ -72,7 +72,7 @@ def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 session.close() -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" bind = op.get_bind() diff --git a/app/alembic/versions/275222846605_initial_ldap_schema.py b/app/alembic/versions/275222846605_initial_ldap_schema.py index 6994b0c77..e9bcf6344 100644 --- a/app/alembic/versions/275222846605_initial_ldap_schema.py +++ b/app/alembic/versions/275222846605_initial_ldap_schema.py @@ -12,15 +12,26 @@ from alembic import op from dishka import AsyncContainer, Scope from ldap3.protocol.schemas.ad2012R2 import ad_2012_r2_schema -from sqlalchemy import delete, or_, select +from sqlalchemy import delete, or_ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession -from sqlalchemy.orm import Session, selectinload +from sqlalchemy.orm import Session -from entities import Attribute, AttributeType, ObjectClass +from entities import Attribute from extra.alembic_utils import temporary_stub_column -from ldap_protocol.ldap_schema.attribute_type_dao import AttributeTypeDAO +from ldap_protocol.ldap_schema._legacy.attribute_type.attribute_type_dao import ( # noqa: E501 + AttributeTypeDAOLegacy, +) +from ldap_protocol.ldap_schema._legacy.attribute_type.attribute_type_use_case import ( # noqa: E501 + AttributeTypeUseCaseLegacy, +) +from ldap_protocol.ldap_schema._legacy.object_class.object_class_dao import ( + ObjectClassDAOLegacy, +) +from ldap_protocol.ldap_schema._legacy.object_class.object_class_use_case import ( # noqa: E501 + ObjectClassUseCaseLegacy, +) from ldap_protocol.ldap_schema.dto import AttributeTypeDTO -from ldap_protocol.utils.raw_definition_parser import ( +from ldap_protocol.ldap_schema.raw_definition_parser import ( RawDefinitionParser as RDParser, ) from repo.pg.tables import queryable_attr as qa @@ -35,7 +46,7 @@ ad_2012_r2_schema_json = json.loads(ad_2012_r2_schema) -@temporary_stub_column("entity_type_id", sa.Integer()) +@temporary_stub_column("Directory", "entity_type_id", sa.Integer()) def upgrade(container: AsyncContainer) -> None: """Upgrade.""" bind = op.get_bind() @@ -184,91 +195,122 @@ def upgrade(container: AsyncContainer) -> None: # NOTE: catalog is a non-existent object class session.execute( - delete(Attribute).where( + delete(Attribute) + .where( or_( qa(Attribute.name) == "objectClass", qa(Attribute.name) == "objectclass", ), qa(Attribute.value) == "catalog", ), - ) + ) # fmt: skip - # NOTE: Load attributeTypes into the database - at_raw_definitions: list[str] = ad_2012_r2_schema_json["raw"][ - "attributeTypes" - ] - at_raw_definitions.extend( - [ - "( 1.2.840.113556.1.4.9999 NAME 'entityTypeName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE NO-USER-MODIFICATION )", # noqa: E501 - # - # Kerberos schema: https://github.com/krb5/krb5/blob/master/src/plugins/kdb/ldap/libkdb_ldap/kerberos.schema - "( 2.16.840.1.113719.1.301.4.1.1 NAME 'krbPrincipalName' EQUALITY caseExactIA5Match SUBSTR caseExactSubstringsMatch SYNTAX 1.3.6.1.4.1.1466.115.121.1.26)", # noqa: E501 - "( 1.2.840.113554.1.4.1.6.1 NAME 'krbCanonicalName' EQUALITY caseExactIA5Match SUBSTR caseExactSubstringsMatch SYNTAX 1.3.6.1.4.1.1466.115.121.1.26 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.3.1 NAME 'krbPrincipalType' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.5.1 NAME 'krbUPEnabled' DESC 'Boolean' SYNTAX 1.3.6.1.4.1.1466.115.121.1.7 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.6.1 NAME 'krbPrincipalExpiration' EQUALITY generalizedTimeMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.8.1 NAME 'krbTicketFlags' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.9.1 NAME 'krbMaxTicketLife' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.10.1 NAME 'krbMaxRenewableAge' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.14.1 NAME 'krbRealmReferences' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.15.1 NAME 'krbLdapServers' EQUALITY caseIgnoreMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.15)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.17.1 NAME 'krbKdcServers' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.18.1 NAME 'krbPwdServers' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.24.1 NAME 'krbHostServer' EQUALITY caseExactIA5Match SYNTAX 1 3.6.1.4.1.1466.115.121.1.26)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.25.1 NAME 'krbSearchScope' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.26.1 NAME 'krbPrincipalReferences' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.28.1 NAME 'krbPrincNamingAttr' EQUALITY caseIgnoreMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.15 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.29.1 NAME 'krbAdmServers' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.30.1 NAME 'krbMaxPwdLife' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.31.1 NAME 'krbMinPwdLife' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.32.1 NAME 'krbPwdMinDiffChars' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.33.1 NAME 'krbPwdMinLength' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.34.1 NAME 'krbPwdHistoryLength' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 1.3.6.1.4.1.5322.21.2.1 NAME 'krbPwdMaxFailure' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 1.3.6.1.4.1.5322.21.2.2 NAME 'krbPwdFailureCountInterval' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 1.3.6.1.4.1.5322.21.2.3 NAME 'krbPwdLockoutDuration' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 1.2.840.113554.1.4.1.6.2 NAME 'krbPwdAttributes' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 1.2.840.113554.1.4.1.6.3 NAME 'krbPwdMaxLife' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 1.2.840.113554.1.4.1.6.4 NAME 'krbPwdMaxRenewableLife' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 1.2.840.113554.1.4.1.6.5 NAME 'krbPwdAllowedKeysalts' EQUALITY caseIgnoreIA5Match SYNTAX 1 3.6.1.4.1.1466.115.121.1.26 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.36.1 NAME 'krbPwdPolicyReference' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.37.1 NAME 'krbPasswordExpiration' EQUALITY generalizedTimeMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.39.1 NAME 'krbPrincipalKey' EQUALITY octetStringMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.40)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.40.1 NAME 'krbTicketPolicyReference' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.41.1 NAME 'krbSubTrees' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.42.1 NAME 'krbDefaultEncSaltTypes' EQUALITY caseIgnoreMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.15)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.43.1 NAME 'krbSupportedEncSaltTypes' EQUALITY caseIgnoreMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.15)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.44.1 NAME 'krbPwdHistory' EQUALITY octetStringMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.40)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.45.1 NAME 'krbLastPwdChange' EQUALITY generalizedTimeMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE)", # noqa: E501 - "( 1.3.6.1.4.1.5322.21.2.5 NAME 'krbLastAdminUnlock' EQUALITY generalizedTimeMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.46.1 NAME 'krbMKey' EQUALITY octetStringMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.40)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.47.1 NAME 'krbPrincipalAliases' EQUALITY caseExactIA5Match SYNTAX 1 3.6.1.4.1.1466.115.121.1.26)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.48.1 NAME 'krbLastSuccessfulAuth' EQUALITY generalizedTimeMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.49.1 NAME 'krbLastFailedAuth' EQUALITY generalizedTimeMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.50.1 NAME 'krbLoginFailedCount' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.51.1 NAME 'krbExtraData' EQUALITY octetStringMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.40)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.52.1 NAME 'krbObjectReferences' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 - "( 2.16.840.1.113719.1.301.4.53.1 NAME 'krbPrincContainerRef' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 - "( 2.16.840.1.113730.3.8.15.2.1 NAME 'krbPrincipalAuthInd' EQUALITY caseExactMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.15)", # noqa: E501 - "( 1.3.6.1.4.1.5322.21.2.4 NAME 'krbAllowedToDelegateTo' EQUALITY caseExactIA5Match SUBSTR caseExactSubstringsMatch SYNTAX 1.3.6.1.4.1.1466.115.121.1.26)", # noqa: E501 - ], - ) - at_raw_definitions_filtered = [ - definition - for definition in at_raw_definitions - if "name 'ms" not in definition.lower() - ] - for at_raw_definition in at_raw_definitions_filtered: - attribute_type = RDParser.create_attribute_type_by_raw( - raw_definition=at_raw_definition, + async def _create_attribute_types(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + at_type_use_case = await cnt.get(AttributeTypeUseCaseLegacy) + + for oid, name in ( + ("2.16.840.1.113730.3.1.610", "nsAccountLock"), + ("1.3.6.1.4.1.99999.1.1", "posixEmail"), + ): + await at_type_use_case.create( + AttributeTypeDTO( + oid=oid, + name=name, + syntax="1.3.6.1.4.1.1466.115.121.1.15", + single_value=True, + no_user_modification=False, + is_system=True, + system_flags=0, + is_included_anr=False, + ), + ) + + await session.flush() + + # NOTE: Load attributeTypes into the database + at_raw_definitions: list[str] = ad_2012_r2_schema_json["raw"][ + "attributeTypes" + ] + at_raw_definitions.extend( + [ + "( 1.2.840.113556.1.4.9999 NAME 'entityTypeName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE NO-USER-MODIFICATION )", # noqa: E501 + # + # Kerberos schema: https://github.com/krb5/krb5/blob/master/src/plugins/kdb/ldap/libkdb_ldap/kerberos.schema + "( 2.16.840.1.113719.1.301.4.1.1 NAME 'krbPrincipalName' EQUALITY caseExactIA5Match SUBSTR caseExactSubstringsMatch SYNTAX 1.3.6.1.4.1.1466.115.121.1.26)", # noqa: E501 + "( 1.2.840.113554.1.4.1.6.1 NAME 'krbCanonicalName' EQUALITY caseExactIA5Match SUBSTR caseExactSubstringsMatch SYNTAX 1.3.6.1.4.1.1466.115.121.1.26 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.3.1 NAME 'krbPrincipalType' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.5.1 NAME 'krbUPEnabled' DESC 'Boolean' SYNTAX 1.3.6.1.4.1.1466.115.121.1.7 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.6.1 NAME 'krbPrincipalExpiration' EQUALITY generalizedTimeMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.8.1 NAME 'krbTicketFlags' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.9.1 NAME 'krbMaxTicketLife' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.10.1 NAME 'krbMaxRenewableAge' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.14.1 NAME 'krbRealmReferences' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.15.1 NAME 'krbLdapServers' EQUALITY caseIgnoreMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.15)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.17.1 NAME 'krbKdcServers' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.18.1 NAME 'krbPwdServers' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.24.1 NAME 'krbHostServer' EQUALITY caseExactIA5Match SYNTAX 1 3.6.1.4.1.1466.115.121.1.26)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.25.1 NAME 'krbSearchScope' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.26.1 NAME 'krbPrincipalReferences' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.28.1 NAME 'krbPrincNamingAttr' EQUALITY caseIgnoreMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.15 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.29.1 NAME 'krbAdmServers' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.30.1 NAME 'krbMaxPwdLife' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.31.1 NAME 'krbMinPwdLife' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.32.1 NAME 'krbPwdMinDiffChars' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.33.1 NAME 'krbPwdMinLength' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.34.1 NAME 'krbPwdHistoryLength' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 1.3.6.1.4.1.5322.21.2.1 NAME 'krbPwdMaxFailure' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 1.3.6.1.4.1.5322.21.2.2 NAME 'krbPwdFailureCountInterval' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 1.3.6.1.4.1.5322.21.2.3 NAME 'krbPwdLockoutDuration' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 1.2.840.113554.1.4.1.6.2 NAME 'krbPwdAttributes' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 1.2.840.113554.1.4.1.6.3 NAME 'krbPwdMaxLife' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 1.2.840.113554.1.4.1.6.4 NAME 'krbPwdMaxRenewableLife' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 1.2.840.113554.1.4.1.6.5 NAME 'krbPwdAllowedKeysalts' EQUALITY caseIgnoreIA5Match SYNTAX 1 3.6.1.4.1.1466.115.121.1.26 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.36.1 NAME 'krbPwdPolicyReference' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.37.1 NAME 'krbPasswordExpiration' EQUALITY generalizedTimeMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.39.1 NAME 'krbPrincipalKey' EQUALITY octetStringMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.40)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.40.1 NAME 'krbTicketPolicyReference' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.41.1 NAME 'krbSubTrees' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.42.1 NAME 'krbDefaultEncSaltTypes' EQUALITY caseIgnoreMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.15)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.43.1 NAME 'krbSupportedEncSaltTypes' EQUALITY caseIgnoreMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.15)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.44.1 NAME 'krbPwdHistory' EQUALITY octetStringMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.40)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.45.1 NAME 'krbLastPwdChange' EQUALITY generalizedTimeMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE)", # noqa: E501 + "( 1.3.6.1.4.1.5322.21.2.5 NAME 'krbLastAdminUnlock' EQUALITY generalizedTimeMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.46.1 NAME 'krbMKey' EQUALITY octetStringMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.40)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.47.1 NAME 'krbPrincipalAliases' EQUALITY caseExactIA5Match SYNTAX 1 3.6.1.4.1.1466.115.121.1.26)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.48.1 NAME 'krbLastSuccessfulAuth' EQUALITY generalizedTimeMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.49.1 NAME 'krbLastFailedAuth' EQUALITY generalizedTimeMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.50.1 NAME 'krbLoginFailedCount' EQUALITY integerMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.27 SINGLE-VALUE)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.51.1 NAME 'krbExtraData' EQUALITY octetStringMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.40)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.52.1 NAME 'krbObjectReferences' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 + "( 2.16.840.1.113719.1.301.4.53.1 NAME 'krbPrincContainerRef' EQUALITY distinguishedNameMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.12)", # noqa: E501 + "( 2.16.840.1.113730.3.8.15.2.1 NAME 'krbPrincipalAuthInd' EQUALITY caseExactMatch SYNTAX 1 3.6.1.4.1.1466.115.121.1.15)", # noqa: E501 + "( 1.3.6.1.4.1.5322.21.2.4 NAME 'krbAllowedToDelegateTo' EQUALITY caseExactIA5Match SUBSTR caseExactSubstringsMatch SYNTAX 1.3.6.1.4.1.1466.115.121.1.26)", # noqa: E501 + ], ) - session.add(attribute_type) - session.commit() + + at_raw_definitions_filtered = [ + definition + for definition in at_raw_definitions + if "name 'ms" not in definition.lower() + ] + + for at_raw_definition in at_raw_definitions_filtered: + attribute_type_dto = RDParser.collect_attribute_type_dto_from_raw( + raw_definition=at_raw_definition, + ) + await at_type_use_case.create(attribute_type_dto) + + await session.commit() + + op.run_async(_create_attribute_types) # NOTE: Load objectClasses into the database async def _create_object_classes(connection: AsyncConnection) -> None: # noqa: ARG001 async with container(scope=Scope.REQUEST) as cnt: session = await cnt.get(AsyncSession) + oc_use_case = await cnt.get(ObjectClassUseCaseLegacy) oc_already_created_oids = set() oc_first_priority_raw_definitions = ( @@ -308,11 +350,12 @@ async def _create_object_classes(connection: AsyncConnection) -> None: # noqa: ) oc_already_created_oids.add(object_class_info.oid) - object_class = await RDParser.create_object_class_by_info( - session=session, - object_class_info=object_class_info, + object_class_dto = ( + await RDParser.collect_object_class_dto_from_info( + object_class_info=object_class_info, + ) ) - session.add(object_class) + await oc_use_case.create(object_class_dto) oc_raw_definitions: list[str] = ad_2012_r2_schema_json["raw"][ "objectClasses" @@ -330,46 +373,29 @@ async def _create_object_classes(connection: AsyncConnection) -> None: # noqa: if object_class_info.oid in oc_already_created_oids: continue - object_class = await RDParser.create_object_class_by_info( - session=session, - object_class_info=object_class_info, + object_class_dto = ( + await RDParser.collect_object_class_dto_from_info( + object_class_info=object_class_info, + ) ) - session.add(object_class) + await oc_use_case.create(object_class_dto) await session.commit() - await session.close() op.run_async(_create_object_classes) - async def _create_attribute_types(connection: AsyncConnection) -> None: # noqa: ARG001 - async with container(scope=Scope.REQUEST) as cnt: - session = await cnt.get(AsyncSession) - attribute_type_dao = await cnt.get(AttributeTypeDAO) - - for oid, name in ( - ("2.16.840.1.113730.3.1.610", "nsAccountLock"), - ("1.3.6.1.4.1.99999.1.1", "posixEmail"), - ): - await attribute_type_dao.create( - AttributeTypeDTO( - oid=oid, - name=name, - syntax="1.3.6.1.4.1.1466.115.121.1.15", - single_value=True, - no_user_modification=False, - is_system=True, - system_flags=0, - is_included_anr=False, - ), - ) - - await session.commit() - - op.run_async(_create_attribute_types) - async def _modify_object_classes(connection: AsyncConnection) -> None: # noqa: ARG001 async with container(scope=Scope.REQUEST) as cnt: session = await cnt.get(AsyncSession) + attribute_type_dao_legacy = AttributeTypeDAOLegacy(session=session) + object_class_dao_legacy = ObjectClassDAOLegacy(session=session) + attribute_type_use_case = AttributeTypeUseCaseLegacy( + attribute_type_dao_legacy=attribute_type_dao_legacy, + ) + object_class_use_case = ObjectClassUseCaseLegacy( + attribute_type_dao_legacy=attribute_type_dao_legacy, + object_class_dao_legacy=object_class_dao_legacy, + ) for oc_name, at_names in ( ("user", ["nsAccountLock", "shadowExpire"]), @@ -377,22 +403,18 @@ async def _modify_object_classes(connection: AsyncConnection) -> None: # noqa: ("posixAccount", ["posixEmail"]), ("organizationalUnit", ["title", "jpegPhoto"]), ): - object_class = await session.scalar( - select(ObjectClass) - .filter_by(name=oc_name) - .options(selectinload(qa(ObjectClass.attribute_types_may))), - ) + object_class = await object_class_use_case.get_raw_by_name(oc_name) if not object_class: continue - attribute_types = await session.scalars( - select(AttributeType) - .where(qa(AttributeType.name).in_(at_names), - ), - ) # fmt: skip + attribute_types = ( + await attribute_type_use_case.get_all_raw_by_names( + at_names, + ) + ) - object_class.attribute_types_may.extend(attribute_types.all()) + object_class.attribute_types_may.extend(attribute_types) await session.commit() @@ -442,4 +464,3 @@ def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 op.drop_index("ix_AttributeTypes_name", table_name="AttributeTypes") op.drop_index("ix_AttributeTypes_oid", table_name="AttributeTypes") op.drop_table("AttributeTypes") - # ### end Alembic commands ### diff --git a/app/alembic/versions/2dadf40c026a_add_system_flags_to_attribute_types.py b/app/alembic/versions/2dadf40c026a_add_system_flags_to_attribute_types.py index b819c1c86..a06b1af95 100644 --- a/app/alembic/versions/2dadf40c026a_add_system_flags_to_attribute_types.py +++ b/app/alembic/versions/2dadf40c026a_add_system_flags_to_attribute_types.py @@ -6,19 +6,15 @@ """ -import contextlib - import sqlalchemy as sa from alembic import op from dishka import AsyncContainer, Scope from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from sqlalchemy.orm import Session -from entities import AttributeType -from ldap_protocol.ldap_schema.attribute_type_use_case import ( - AttributeTypeUseCase, +from ldap_protocol.ldap_schema._legacy.attribute_type.attribute_type_use_case import ( # noqa: E501 + AttributeTypeUseCaseLegacy, ) -from ldap_protocol.ldap_schema.exceptions import AttributeTypeNotFoundError # revision identifiers, used by Alembic. revision: None | str = "2dadf40c026a" @@ -27,7 +23,7 @@ depends_on: None | list[str] = None -_NON_REPLICATED_ATTRIBUTES_TYPE_NAMES = ( +_NON_REPLICATED_ATTRIBUTES_TYPE_NAMES: tuple[str, ...] = ( "badPasswordTime", "badPwdCount", "bridgeheadServerListBL", @@ -144,23 +140,28 @@ def upgrade(container: AsyncContainer) -> None: ), ) - session.execute(sa.update(AttributeType).values({"system_flags": 0})) + async def _zero_all_replicated_flags(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + at_type_use_case = await cnt.get(AttributeTypeUseCaseLegacy) + + await at_type_use_case.zero_all_replicated_flags() + await session.commit() + + op.run_async(_zero_all_replicated_flags) - async def _set_attr_replication_flag(connection: AsyncConnection) -> None: # noqa: ARG001 + async def _set_false_replication_flag(connection: AsyncConnection) -> None: # noqa: ARG001 async with container(scope=Scope.REQUEST) as cnt: session = await cnt.get(AsyncSession) - at_type_use_case = await cnt.get(AttributeTypeUseCase) + at_type_use_case = await cnt.get(AttributeTypeUseCaseLegacy) - for name in _NON_REPLICATED_ATTRIBUTES_TYPE_NAMES: - with contextlib.suppress(AttributeTypeNotFoundError): - await at_type_use_case.set_attr_replication_flag( - name, - need_to_replicate=False, - ) + await at_type_use_case.set_false_replication_flag( + _NON_REPLICATED_ATTRIBUTES_TYPE_NAMES, + ) await session.commit() - op.run_async(_set_attr_replication_flag) + op.run_async(_set_false_replication_flag) op.alter_column("AttributeTypes", "system_flags", nullable=False) diff --git a/app/alembic/versions/4442d1d982a4_remove_krb_policy.py b/app/alembic/versions/4442d1d982a4_remove_krb_policy.py index 5673da6a8..0dfd80697 100644 --- a/app/alembic/versions/4442d1d982a4_remove_krb_policy.py +++ b/app/alembic/versions/4442d1d982a4_remove_krb_policy.py @@ -22,7 +22,7 @@ depends_on: None | str = None -@temporary_stub_column("entity_type_id", sa.Integer()) +@temporary_stub_column("Directory", "entity_type_id", sa.Integer()) def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" bind = op.get_bind() diff --git a/app/alembic/versions/6303f5c706ec_update_krbadmin_useraccountcontrol_.py b/app/alembic/versions/6303f5c706ec_update_krbadmin_useraccountcontrol_.py index 6d101e95a..b67735386 100644 --- a/app/alembic/versions/6303f5c706ec_update_krbadmin_useraccountcontrol_.py +++ b/app/alembic/versions/6303f5c706ec_update_krbadmin_useraccountcontrol_.py @@ -26,7 +26,7 @@ depends_on: None | list[str] = None -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: """Upgrade.""" @@ -93,7 +93,7 @@ async def _change_uid_admin(connection: AsyncConnection) -> None: # noqa: ARG00 op.run_async(_change_uid_admin) -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def downgrade(container: AsyncContainer) -> None: """Downgrade.""" diff --git a/app/alembic/versions/6c858cc05da7_add_default_admin_name.py b/app/alembic/versions/6c858cc05da7_add_default_admin_name.py index 7b1a59f1e..f7e7e8e3f 100644 --- a/app/alembic/versions/6c858cc05da7_add_default_admin_name.py +++ b/app/alembic/versions/6c858cc05da7_add_default_admin_name.py @@ -22,7 +22,7 @@ depends_on: None | list[str] = None -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" bind = op.get_bind() diff --git a/app/alembic/versions/6f8fe2548893_fix_read_only.py b/app/alembic/versions/6f8fe2548893_fix_read_only.py index 8d0f87874..f28264704 100644 --- a/app/alembic/versions/6f8fe2548893_fix_read_only.py +++ b/app/alembic/versions/6f8fe2548893_fix_read_only.py @@ -24,8 +24,8 @@ depends_on: None = None -@temporary_stub_column("entity_type_id", sa.Integer()) -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "entity_type_id", sa.Integer()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" bind = op.get_bind() diff --git a/app/alembic/versions/708b01eaf025_convert_schema_to_ldap.py b/app/alembic/versions/708b01eaf025_convert_schema_to_ldap.py new file mode 100644 index 000000000..66a5006c8 --- /dev/null +++ b/app/alembic/versions/708b01eaf025_convert_schema_to_ldap.py @@ -0,0 +1,195 @@ +"""Migrate LDAP Schema data to Directory (like as LDAP). + +Revision ID: 759d196145ae +Revises: 19d86e660cf2 +Create Date: 2026-02-24 13:18:06.715730 + +""" + +from alembic import op +from dishka import AsyncContainer, Scope +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession + +from constants import ENTITY_TYPE_DTOS_V2 +from ldap_protocol.ldap_schema._legacy.attribute_type.attribute_type_use_case import ( # noqa: E501 + AttributeTypeUseCaseLegacy, +) +from ldap_protocol.ldap_schema._legacy.object_class.object_class_use_case import ( # noqa: E501 + ObjectClassUseCaseLegacy, +) +from ldap_protocol.ldap_schema.attribute_type.attribute_type_use_case import ( + AttributeTypeUseCase, +) +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) +from ldap_protocol.ldap_schema.object_class.object_class_use_case import ( + ObjectClassUseCase, +) +from ldap_protocol.roles.migrations_ace_dao import ( + AccessControlEntryMigrationsDAO, +) +from ldap_protocol.utils.queries import get_base_directories + +# revision identifiers, used by Alembic. +revision: None | str = "708b01eaf025" +down_revision: None | str = "df4287898910" +branch_labels: None | list[str] = None +depends_on: None | list[str] = None + + +def upgrade(container: AsyncContainer) -> None: + """Upgrade.""" + op.drop_constraint( + op.f("AccessControlEntries_attributeTypeId_fkey"), + "AccessControlEntries", + type_="foreignkey", + ) + + async def _update_entity_types(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + entity_type_use_case = await cnt.get(EntityTypeUseCase) + + if not await get_base_directories(session): + return + + for entity_type_dto in ENTITY_TYPE_DTOS_V2: + await entity_type_use_case.create(entity_type_dto) + + await session.commit() + + async def _create_ldap_attributes(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + attribute_type_use_case = await cnt.get(AttributeTypeUseCase) + attribute_type_use_case_legacy = await cnt.get(AttributeTypeUseCaseLegacy) # noqa: E501 # fmt: skip + + if not await get_base_directories(session): + return + + attr_type_dtos = await attribute_type_use_case_legacy.get_all() + for attr_type_dto in attr_type_dtos: + await attribute_type_use_case.create(attr_type_dto) + + await session.commit() + + async def _create_ldap_object_classes(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + obj_cls_use_case_legacy = await cnt.get(ObjectClassUseCaseLegacy) + object_class_use_case = await cnt.get(ObjectClassUseCase) + + if not await get_base_directories(session): + return + + obj_class_dtos = await obj_cls_use_case_legacy.get_all() + for obj_class_dto in obj_class_dtos: + obj_class_dto.attribute_types_may = [ + _.name # type: ignore + for _ in obj_class_dto.attribute_types_may + ] + obj_class_dto.attribute_types_must = [ + _.name # type: ignore + for _ in obj_class_dto.attribute_types_must + ] + await object_class_use_case.create(obj_class_dto) # type: ignore + + await session.commit() + + async def _rebind_ace_attribute_types_to_directories( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + ace_dao = await cnt.get(AccessControlEntryMigrationsDAO) + + if not await get_base_directories(session): + return + + await ace_dao.upgrade() + + op.run_async(_update_entity_types) + op.run_async(_create_ldap_attributes) + op.run_async(_create_ldap_object_classes) + op.run_async(_rebind_ace_attribute_types_to_directories) + + op.create_foreign_key( + op.f("AccessControlEntries_directoryAttributeTypeId_fkey"), + "AccessControlEntries", + "Directory", + ["attributeTypeId"], + ["id"], + ondelete="CASCADE", + ) + + +def downgrade(container: AsyncContainer) -> None: + """Downgrade.""" + + async def _rebind_ace_attribute_types_to_legacy( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + ace_dao = await cnt.get(AccessControlEntryMigrationsDAO) + + if not await get_base_directories(session): + return + + await ace_dao.downgrade() + + async def _delete_ldap_attributes(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + attribute_type_use_case_legacy = await cnt.get(AttributeTypeUseCaseLegacy) # noqa: E501 # fmt: skip + + if not await get_base_directories(session): + return + + await attribute_type_use_case_legacy.delete_all_dirs() + await session.commit() + + async def _delete_ldap_object_classes(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + obj_cls_use_case_legacy = await cnt.get(ObjectClassUseCaseLegacy) + + if not await get_base_directories(session): + return + + await obj_cls_use_case_legacy.delete_all_dirs() + await session.commit() + + async def _delete_entity_types(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + entity_type_use_case = await cnt.get(EntityTypeUseCase) + + if not await get_base_directories(session): + return + + entity_type_names = [dto.name for dto in ENTITY_TYPE_DTOS_V2] + + await entity_type_use_case.delete_all_by_names(entity_type_names) + await session.commit() + + op.drop_constraint( + op.f("AccessControlEntries_directoryAttributeTypeId_fkey"), + "AccessControlEntries", + type_="foreignkey", + ) + + op.run_async(_rebind_ace_attribute_types_to_legacy) + op.run_async(_delete_ldap_attributes) + op.run_async(_delete_ldap_object_classes) + op.run_async(_delete_entity_types) + + op.create_foreign_key( + op.f("AccessControlEntries_attributeTypeId_fkey"), + "AccessControlEntries", + "AttributeTypes", + ["attributeTypeId"], + ["id"], + ondelete="CASCADE", + ) diff --git a/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py b/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py index a4eb7297c..46d4e1334 100644 --- a/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py +++ b/app/alembic/versions/8164b4a9e1f1_add_ou_computers.py @@ -34,7 +34,7 @@ } -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: """Upgrade.""" from ldap_protocol.auth.setup_gateway import SetupGateway @@ -83,7 +83,7 @@ async def _create_ou_computers(connection: AsyncConnection) -> None: # noqa: AR op.run_async(_create_ou_computers) -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def downgrade(container: AsyncContainer) -> None: """Downgrade.""" diff --git a/app/alembic/versions/ba78cef9700a_initial_entity_type.py b/app/alembic/versions/ba78cef9700a_initial_entity_type.py index 0e6744919..90458e19f 100644 --- a/app/alembic/versions/ba78cef9700a_initial_entity_type.py +++ b/app/alembic/versions/ba78cef9700a_initial_entity_type.py @@ -13,12 +13,12 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession -from constants import ENTITY_TYPE_DATAS +from constants import ENTITY_TYPE_DTOS_V1 from entities import Attribute, Directory, User from extra.alembic_utils import temporary_stub_column -from ldap_protocol.ldap_schema.dto import EntityTypeDTO -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO -from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import queryable_attr as qa @@ -29,8 +29,8 @@ depends_on: None | str = None -@temporary_stub_column("entity_type_id", sa.Integer()) -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "entity_type_id", sa.Integer()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: """Upgrade database schema and data, creating Entity Types.""" op.create_table( @@ -105,14 +105,8 @@ async def _create_entity_types(connection: AsyncConnection) -> None: # noqa: AR if not await get_base_directories(session): return - for entity_type_data in ENTITY_TYPE_DATAS: - await entity_type_use_case.create( - EntityTypeDTO( - name=entity_type_data["name"], - object_class_names=entity_type_data["object_class_names"], - is_system=True, - ), - ) + for entity_type_dto in ENTITY_TYPE_DTOS_V1: + await entity_type_use_case.create(entity_type_dto) await session.commit() @@ -159,12 +153,12 @@ async def _attach_entity_type_to_directories( ) -> None: async with container(scope=Scope.REQUEST) as cnt: session = await cnt.get(AsyncSession) - entity_type_dao = await cnt.get(EntityTypeDAO) + entity_type_use_case = await cnt.get(EntityTypeUseCase) if not await get_base_directories(session): return - await entity_type_dao.attach_entity_type_to_directories() + await entity_type_use_case.attach_entity_type_to_directories() await session.commit() diff --git a/app/alembic/versions/bf435bbd95ff_add_rdn_attr_name.py b/app/alembic/versions/bf435bbd95ff_add_rdn_attr_name.py index 88eaf4581..e59ca94b4 100644 --- a/app/alembic/versions/bf435bbd95ff_add_rdn_attr_name.py +++ b/app/alembic/versions/bf435bbd95ff_add_rdn_attr_name.py @@ -22,8 +22,8 @@ depends_on: None | str = None -@temporary_stub_column("entity_type_id", sa.Integer()) -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "entity_type_id", sa.Integer()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" op.add_column("Directory", sa.Column("rdname", sa.String(length=64))) @@ -58,8 +58,8 @@ def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 op.alter_column("Directory", "rdname", nullable=False) -@temporary_stub_column("entity_type_id", sa.Integer()) -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "entity_type_id", sa.Integer()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def downgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Downgrade.""" bind = op.get_bind() diff --git a/app/alembic/versions/bv546ccd35fa_fix_krbadmin_attrs.py b/app/alembic/versions/bv546ccd35fa_fix_krbadmin_attrs.py index dfaa36aa0..37a49a988 100644 --- a/app/alembic/versions/bv546ccd35fa_fix_krbadmin_attrs.py +++ b/app/alembic/versions/bv546ccd35fa_fix_krbadmin_attrs.py @@ -22,8 +22,8 @@ depends_on: None | str = None -@temporary_stub_column("entity_type_id", sa.Integer()) -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "entity_type_id", sa.Integer()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" bind = op.get_bind() diff --git a/app/alembic/versions/c4888c68e221_fix_admin_attr_and_policy.py b/app/alembic/versions/c4888c68e221_fix_admin_attr_and_policy.py index dbaa321be..0f17a6965 100644 --- a/app/alembic/versions/c4888c68e221_fix_admin_attr_and_policy.py +++ b/app/alembic/versions/c4888c68e221_fix_admin_attr_and_policy.py @@ -14,7 +14,9 @@ from entities import Attribute, Directory, NetworkPolicy from extra.alembic_utils import temporary_stub_column -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) from ldap_protocol.utils.helpers import create_integer_hash from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import queryable_attr as qa @@ -26,7 +28,7 @@ depends_on: None | list[str] = None -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: """Upgrade.""" @@ -35,12 +37,12 @@ async def _attach_entity_type_to_directories( ) -> None: async with container(scope=Scope.REQUEST) as cnt: session = await cnt.get(AsyncSession) - entity_type_dao = await cnt.get(EntityTypeDAO) + entity_type_use_case = await cnt.get(EntityTypeUseCase) if not await get_base_directories(session): return - await entity_type_dao.attach_entity_type_to_directories() + await entity_type_use_case.attach_entity_type_to_directories() await session.commit() async def _change_uid_admin(connection: AsyncConnection) -> None: # noqa: ARG001 diff --git a/app/alembic/versions/c5a9b3f2e8d7_add_contact_object_class.py b/app/alembic/versions/c5a9b3f2e8d7_add_contact_object_class.py index 33bd3a433..a629c1ffa 100644 --- a/app/alembic/versions/c5a9b3f2e8d7_add_contact_object_class.py +++ b/app/alembic/versions/c5a9b3f2e8d7_add_contact_object_class.py @@ -14,7 +14,9 @@ from entities import EntityType from enums import EntityTypeNames from ldap_protocol.ldap_schema.dto import EntityTypeDTO -from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import queryable_attr as qa diff --git a/app/alembic/versions/dafg3a4b22ab_add_preauth_princ.py b/app/alembic/versions/dafg3a4b22ab_add_preauth_princ.py index 38d982694..576fb2b84 100644 --- a/app/alembic/versions/dafg3a4b22ab_add_preauth_princ.py +++ b/app/alembic/versions/dafg3a4b22ab_add_preauth_princ.py @@ -23,8 +23,8 @@ depends_on: None | str = None -@temporary_stub_column("entity_type_id", sa.Integer()) -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "entity_type_id", sa.Integer()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 """Upgrade.""" bind = op.get_bind() diff --git a/app/alembic/versions/e4d6d99d32bd_add_audit_policies.py b/app/alembic/versions/e4d6d99d32bd_add_audit_policies.py index 0a64e7fb7..e332bb950 100644 --- a/app/alembic/versions/e4d6d99d32bd_add_audit_policies.py +++ b/app/alembic/versions/e4d6d99d32bd_add_audit_policies.py @@ -28,7 +28,7 @@ depends_on: None | str = None -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: """Upgrade.""" diff --git a/app/alembic/versions/f1abf7ef2443_add_container_object_class.py b/app/alembic/versions/f1abf7ef2443_add_container_object_class.py index cf1c80e1b..199d6b107 100644 --- a/app/alembic/versions/f1abf7ef2443_add_container_object_class.py +++ b/app/alembic/versions/f1abf7ef2443_add_container_object_class.py @@ -24,7 +24,7 @@ depends_on: None | str = None -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: """Upgrade.""" @@ -110,7 +110,7 @@ async def _migrate_ou_to_cn_containers( op.run_async(_migrate_ou_to_cn_containers) -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def downgrade(container: AsyncContainer) -> None: """Downgrade.""" diff --git a/app/alembic/versions/f24ed0e49df2_add_filter_anr.py b/app/alembic/versions/f24ed0e49df2_add_filter_anr.py index b6ec3ee1a..cb6eab110 100644 --- a/app/alembic/versions/f24ed0e49df2_add_filter_anr.py +++ b/app/alembic/versions/f24ed0e49df2_add_filter_anr.py @@ -8,12 +8,15 @@ import sqlalchemy as sa from alembic import op -from dishka import AsyncContainer +from dishka import AsyncContainer, Scope from sqlalchemy.dialects import postgresql +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from sqlalchemy.orm import Session -from entities import AttributeType -from repo.pg.tables import queryable_attr as qa +from extra.alembic_utils import temporary_stub_column +from ldap_protocol.ldap_schema._legacy.attribute_type.attribute_type_use_case import ( # noqa: E501 + AttributeTypeUseCaseLegacy, +) # revision identifiers, used by Alembic. revision: None | str = "f24ed0e49df2" @@ -35,7 +38,8 @@ ) -def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 +@temporary_stub_column("AttributeTypes", "system_flags", sa.Integer()) +def upgrade(container: AsyncContainer) -> None: """Upgrade.""" bind = op.get_bind() session = Session(bind=bind) @@ -44,9 +48,17 @@ def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 "AttributeTypes", sa.Column("is_included_anr", sa.Boolean(), nullable=True), ) - session.execute( - sa.update(AttributeType).values({"is_included_anr": False}), - ) + + async def _false_all_is_included_anr(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + attribute_type_use_case = await cnt.get(AttributeTypeUseCaseLegacy) + + await attribute_type_use_case.false_all_is_included_anr() + await session.flush() + + op.run_async(_false_all_is_included_anr) + op.alter_column("AttributeTypes", "is_included_anr", nullable=False) op.alter_column( @@ -56,14 +68,22 @@ def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 nullable=True, ) - updated_attrs = session.execute( - sa.update(AttributeType) - .where(qa(AttributeType.name).in_(_DEFAULT_ANR_ATTRIBUTE_TYPE_NAMES)) - .values({"is_included_anr": True}) - .returning(qa(AttributeType.name)), - ) - if len(updated_attrs.all()) != len(_DEFAULT_ANR_ATTRIBUTE_TYPE_NAMES): - raise ValueError("Not all expected attributes were found in the DB.") + async def _mark_anr_included(connection: AsyncConnection) -> None: # noqa: ARG001 + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + attribute_type_use_case = await cnt.get(AttributeTypeUseCaseLegacy) + + len_updated_attrs = len( + await attribute_type_use_case.mark_anr_included_by_attr_names( + _DEFAULT_ANR_ATTRIBUTE_TYPE_NAMES, + ), + ) + if len_updated_attrs != len(_DEFAULT_ANR_ATTRIBUTE_TYPE_NAMES): + raise ValueError("Not all expected attributes were found") + + await session.flush() + + op.run_async(_mark_anr_included) session.commit() diff --git a/app/alembic/versions/fafc3d0b11ec_.py b/app/alembic/versions/fafc3d0b11ec_.py index 1ab09b6b3..2f7c733b5 100644 --- a/app/alembic/versions/fafc3d0b11ec_.py +++ b/app/alembic/versions/fafc3d0b11ec_.py @@ -32,8 +32,8 @@ depends_on: None | str = None -@temporary_stub_column("entity_type_id", sa.Integer()) -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "entity_type_id", sa.Integer()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def upgrade(container: AsyncContainer) -> None: """Upgrade.""" @@ -76,8 +76,8 @@ async def _create_readonly_grp_and_plcy( op.run_async(_create_readonly_grp_and_plcy) -@temporary_stub_column("entity_type_id", sa.Integer()) -@temporary_stub_column("is_system", sa.Boolean()) +@temporary_stub_column("Directory", "entity_type_id", sa.Integer()) +@temporary_stub_column("Directory", "is_system", sa.Boolean()) def downgrade(container: AsyncContainer) -> None: """Downgrade.""" diff --git a/app/api/ldap_schema/adapters/attribute_type.py b/app/api/ldap_schema/adapters/attribute_type.py index 73e5f32bc..9a12debf9 100644 --- a/app/api/ldap_schema/adapters/attribute_type.py +++ b/app/api/ldap_schema/adapters/attribute_type.py @@ -17,19 +17,19 @@ from api.ldap_schema.adapters.base_ldap_schema_adapter import ( BaseLDAPSchemaAdapter, ) +from api.ldap_schema.constants import ( + DEFAULT_ATTRIBUTE_TYPE_IS_SYSTEM, + DEFAULT_ATTRIBUTE_TYPE_NO_USER_MOD, + DEFAULT_ATTRIBUTE_TYPE_SYNTAX, +) from api.ldap_schema.schema import ( AttributeTypePaginationSchema, AttributeTypeSchema, AttributeTypeUpdateSchema, ) -from ldap_protocol.ldap_schema.attribute_type_use_case import ( +from ldap_protocol.ldap_schema.attribute_type.attribute_type_use_case import ( AttributeTypeUseCase, ) -from ldap_protocol.ldap_schema.constants import ( - DEFAULT_ATTRIBUTE_TYPE_IS_SYSTEM, - DEFAULT_ATTRIBUTE_TYPE_NO_USER_MOD, - DEFAULT_ATTRIBUTE_TYPE_SYNTAX, -) from ldap_protocol.ldap_schema.dto import AttributeTypeDTO @@ -37,7 +37,7 @@ def _convert_update_uschema_to_dto( request: AttributeTypeUpdateSchema, ) -> AttributeTypeDTO[None]: """Convert AttributeTypeUpdateSchema to AttributeTypeDTO for update.""" - return AttributeTypeDTO( + return AttributeTypeDTO[None]( oid="", name="", syntax=request.syntax, diff --git a/app/api/ldap_schema/adapters/base_ldap_schema_adapter.py b/app/api/ldap_schema/adapters/base_ldap_schema_adapter.py index ebbe17192..f431f58e4 100644 --- a/app/api/ldap_schema/adapters/base_ldap_schema_adapter.py +++ b/app/api/ldap_schema/adapters/base_ldap_schema_adapter.py @@ -68,10 +68,7 @@ class BaseLDAPSchemaAdapter( _converter_update_sch_to_dto: staticmethod[[UpdateSchemaT], DtoT] async def create(self, data: SchemaT) -> None: - """Create a new entity. - - :param request_data: Data for creating entity. - """ + """Create a new entity.""" dto = self._converter_to_dto(data) await self._service.create(dto) @@ -79,11 +76,7 @@ async def get( self, name: str, ) -> SchemaT: - """Get a single entity by name. - - :param str name: Name of the entity. - :return: Entity schema. - """ + """Get a single entity by name.""" attribute_type = await self._service.get(name) return self._converter_to_schema(attribute_type) @@ -91,11 +84,7 @@ async def get_list_paginated( self, params: PaginationParams, ) -> PaginationSchemaT: - """Get a list of entities with pagination. - - :param PaginationParams params: Pagination parameters. - :return: Paginated result schema. - """ + """Get a list of entities with pagination.""" pagination_result = await self._service.get_paginator(params) items: list[SchemaT] = [ @@ -112,11 +101,7 @@ async def update( name: str, data: UpdateSchemaT, ) -> None: - """Modify an entity. - - :param str name: Name of the entity to modify. - :param data: Updated data. - """ + """Modify an entity.""" dto = self._converter_update_sch_to_dto(data) await self._service.update(name, dto) @@ -124,8 +109,5 @@ async def delete_bulk( self, names: LimitedListType, ) -> None: - """Delete multiple entities. - - :param LimitedListType names: Names of entities to delete. - """ + """Delete multiple entities.""" await self._service.delete_all_by_names(names) diff --git a/app/api/ldap_schema/adapters/entity_type.py b/app/api/ldap_schema/adapters/entity_type.py index 03199b634..598683253 100644 --- a/app/api/ldap_schema/adapters/entity_type.py +++ b/app/api/ldap_schema/adapters/entity_type.py @@ -10,14 +10,16 @@ from api.ldap_schema.adapters.base_ldap_schema_adapter import ( BaseLDAPSchemaAdapter, ) +from api.ldap_schema.constants import DEFAULT_ENTITY_TYPE_IS_SYSTEM from api.ldap_schema.schema import ( EntityTypePaginationSchema, EntityTypeSchema, EntityTypeUpdateSchema, ) -from ldap_protocol.ldap_schema.constants import DEFAULT_ENTITY_TYPE_IS_SYSTEM from ldap_protocol.ldap_schema.dto import EntityTypeDTO -from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) def _convert_update_chema_to_dto( diff --git a/app/api/ldap_schema/adapters/object_class.py b/app/api/ldap_schema/adapters/object_class.py index 7c0199a88..2d611a760 100644 --- a/app/api/ldap_schema/adapters/object_class.py +++ b/app/api/ldap_schema/adapters/object_class.py @@ -11,15 +11,17 @@ from api.ldap_schema.adapters.base_ldap_schema_adapter import ( BaseLDAPSchemaAdapter, ) +from api.ldap_schema.constants import DEFAULT_OBJECT_CLASS_IS_SYSTEM from api.ldap_schema.schema import ( ObjectClassPaginationSchema, ObjectClassSchema, ObjectClassUpdateSchema, ) from enums import KindType -from ldap_protocol.ldap_schema.constants import DEFAULT_OBJECT_CLASS_IS_SYSTEM -from ldap_protocol.ldap_schema.dto import AttributeTypeDTO, ObjectClassDTO -from ldap_protocol.ldap_schema.object_class_use_case import ObjectClassUseCase +from ldap_protocol.ldap_schema.dto import ObjectClassDTO +from ldap_protocol.ldap_schema.object_class.object_class_use_case import ( + ObjectClassUseCase, +) def _convert_update_schema_to_dto( @@ -57,20 +59,20 @@ def _convert_update_schema_to_dto( ], ) -_convert_dto_to_schema = get_converter( - ObjectClassDTO[int, AttributeTypeDTO], - ObjectClassSchema[int], - recipe=[ - link_function( - lambda dto: [attr.name for attr in dto.attribute_types_must], - P[ObjectClassSchema].attribute_type_names_must, - ), - link_function( - lambda dto: [attr.name for attr in dto.attribute_types_may], - P[ObjectClassSchema].attribute_type_names_may, - ), - ], -) + +def _convert_dto_to_schema(dto: ObjectClassDTO) -> ObjectClassSchema[int]: + """Map DTO object to API schema with explicit attribute name fields.""" + return ObjectClassSchema( + oid=dto.oid, + name=dto.name, + superior_name=dto.superior_name, + kind=dto.kind, + is_system=dto.is_system, + attribute_type_names_must=dto.attribute_types_must, + attribute_type_names_may=dto.attribute_types_may, + id=dto.id, + entity_type_names=dto.entity_type_names, + ) class ObjectClassFastAPIAdapter( diff --git a/app/ldap_protocol/ldap_schema/constants.py b/app/api/ldap_schema/constants.py similarity index 89% rename from app/ldap_protocol/ldap_schema/constants.py rename to app/api/ldap_schema/constants.py index fe53fe58a..12b81f1b4 100644 --- a/app/ldap_protocol/ldap_schema/constants.py +++ b/app/api/ldap_schema/constants.py @@ -4,8 +4,6 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -import re as re - DEFAULT_ATTRIBUTE_TYPE_SYNTAX = "1.3.6.1.4.1.1466.115.121.1.15" DEFAULT_ATTRIBUTE_TYPE_NO_USER_MOD = False DEFAULT_ATTRIBUTE_TYPE_IS_SYSTEM = False @@ -16,4 +14,3 @@ # NOTE: Domain value object # RFC 4512: OID = number 1*( "." number ) OID_REGEX_PATTERN = r"^[0-9]+(\.[0-9]+)+$" -OID_REGEX = re.compile(OID_REGEX_PATTERN) diff --git a/app/api/ldap_schema/schema.py b/app/api/ldap_schema/schema.py index b3dabefb6..d4c07c827 100644 --- a/app/api/ldap_schema/schema.py +++ b/app/api/ldap_schema/schema.py @@ -9,12 +9,10 @@ from pydantic import BaseModel, Field from enums import EntityTypeNames, KindType -from ldap_protocol.ldap_schema.constants import ( - DEFAULT_ENTITY_TYPE_IS_SYSTEM, - OID_REGEX_PATTERN, -) from ldap_protocol.utils.pagination import BasePaginationSchema +from .constants import DEFAULT_ENTITY_TYPE_IS_SYSTEM, OID_REGEX_PATTERN + _IdT = TypeVar("_IdT", int, None) diff --git a/app/constants.py b/app/constants.py index 5086dfad1..514f40557 100644 --- a/app/constants.py +++ b/app/constants.py @@ -4,10 +4,10 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from typing import TypedDict - from enums import EntityTypeNames, SamAccountTypeCodes +from ldap_protocol.ldap_schema.dto import EntityTypeDTO +CONFIGURATION_DIR_NAME = "Configuration" GROUPS_CONTAINER_NAME = "Groups" COMPUTERS_CONTAINER_NAME = "Computers" USERS_CONTAINER_NAME = "Users" @@ -223,36 +223,36 @@ ] -class EntityTypeData(TypedDict): - """Entity Type data.""" - - name: EntityTypeNames - object_class_names: list[str] - - -ENTITY_TYPE_DATAS: tuple[EntityTypeData, ...] = ( - EntityTypeData( +# NOTE: First time load +ENTITY_TYPE_DTOS_V1: tuple[EntityTypeDTO, ...] = ( + EntityTypeDTO( name=EntityTypeNames.DOMAIN, + is_system=True, object_class_names=["top", "domain", "domainDNS"], ), - EntityTypeData( + EntityTypeDTO( name=EntityTypeNames.COMPUTER, + is_system=True, object_class_names=["top", "computer"], ), - EntityTypeData( + EntityTypeDTO( name=EntityTypeNames.CONTAINER, + is_system=True, object_class_names=["top", "container"], ), - EntityTypeData( + EntityTypeDTO( name=EntityTypeNames.ORGANIZATIONAL_UNIT, + is_system=True, object_class_names=["top", "container", "organizationalUnit"], ), - EntityTypeData( + EntityTypeDTO( name=EntityTypeNames.GROUP, + is_system=True, object_class_names=["top", "group", "posixGroup"], ), - EntityTypeData( + EntityTypeDTO( name=EntityTypeNames.USER, + is_system=True, object_class_names=[ "top", "user", @@ -263,8 +263,9 @@ class EntityTypeData(TypedDict): "inetOrgPerson", ], ), - EntityTypeData( + EntityTypeDTO( name=EntityTypeNames.CONTACT, + is_system=True, object_class_names=[ "top", "person", @@ -273,28 +274,57 @@ class EntityTypeData(TypedDict): "mailRecipient", ], ), - EntityTypeData( + EntityTypeDTO( name=EntityTypeNames.KRB_CONTAINER, + is_system=True, object_class_names=["krbContainer"], ), - EntityTypeData( + EntityTypeDTO( name=EntityTypeNames.KRB_PRINCIPAL, + is_system=True, object_class_names=[ "krbprincipal", "krbprincipalaux", "krbTicketPolicyAux", ], ), - EntityTypeData( + EntityTypeDTO( name=EntityTypeNames.KRB_REALM_CONTAINER, + is_system=True, object_class_names=["top", "krbrealmcontainer", "krbticketpolicyaux"], ), ) +# NOTE: Second time load +ENTITY_TYPE_DTOS_V2: tuple[EntityTypeDTO, ...] = ( + EntityTypeDTO( + name=EntityTypeNames.CONFIGURATION, + is_system=True, + object_class_names=["top", "container", "configuration"], + ), + EntityTypeDTO( + name=EntityTypeNames.ATTRIBUTE_TYPE, + is_system=True, + object_class_names=["top", "attributeSchema"], + ), + EntityTypeDTO( + name=EntityTypeNames.OBJECT_CLASS, + is_system=True, + object_class_names=["top", "classSchema"], + ), +) + FIRST_SETUP_DATA = [ + { + "name": CONFIGURATION_DIR_NAME, + "entity_type_name": EntityTypeNames.CONFIGURATION, + "object_class": "container", + "attributes": {"objectClass": ["top", "configuration"]}, + }, { "name": GROUPS_CONTAINER_NAME, + "entity_type_name": EntityTypeNames.CONTAINER, "object_class": "container", "attributes": { "objectClass": ["top"], @@ -303,6 +333,7 @@ class EntityTypeData(TypedDict): "children": [ { "name": DOMAIN_ADMIN_GROUP_NAME, + "entity_type_name": EntityTypeNames.GROUP, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], @@ -318,6 +349,7 @@ class EntityTypeData(TypedDict): }, { "name": DOMAIN_USERS_GROUP_NAME, + "entity_type_name": EntityTypeNames.GROUP, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], @@ -333,6 +365,7 @@ class EntityTypeData(TypedDict): }, { "name": READ_ONLY_GROUP_NAME, + "entity_type_name": EntityTypeNames.GROUP, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], @@ -348,6 +381,7 @@ class EntityTypeData(TypedDict): }, { "name": DOMAIN_COMPUTERS_GROUP_NAME, + "entity_type_name": EntityTypeNames.GROUP, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], @@ -365,6 +399,7 @@ class EntityTypeData(TypedDict): }, { "name": COMPUTERS_CONTAINER_NAME, + "entity_type_name": EntityTypeNames.CONTAINER, "object_class": "container", "attributes": {"objectClass": ["top"]}, "children": [], diff --git a/app/entities.py b/app/entities.py index 9b4d70e16..56c125571 100644 --- a/app/entities.py +++ b/app/entities.py @@ -18,7 +18,6 @@ AuditDestinationServiceType, AuditSeverity, AuthorizationRules, - KindType, MFAFlags, RoleScope, ) @@ -53,96 +52,6 @@ class EntityType: def object_class_names_set(self) -> set[str]: return set(self.object_class_names) - @classmethod - def generate_entity_type_name(cls, directory: Directory) -> str: - return f"{directory.name}_entity_type_{directory.id}" - - -@dataclass -class AttributeType: - """LDAP attribute type definition (schema element).""" - - id: int | None = field(init=False, default=None) - oid: str = "" - name: str = "" - syntax: str = "" - single_value: bool = False - no_user_modification: bool = False - is_system: bool = False - system_flags: int = 0 - # NOTE: ms-adts/cf133d47-b358-4add-81d3-15ea1cff9cd9 - # see section 3.1.1.2.3 `searchFlags` (fANR) for details - is_included_anr: bool = False - - def get_raw_definition(self) -> str: - if not self.oid or not self.name or not self.syntax: - raise ValueError( - f"{self}: Fields 'oid', 'name', " - "and 'syntax' are required for LDAP definition.", - ) - chunks = [ - "(", - self.oid, - f"NAME '{self.name}'", - f"SYNTAX '{self.syntax}'", - ] - if self.single_value: - chunks.append("SINGLE-VALUE") - if self.no_user_modification: - chunks.append("NO-USER-MODIFICATION") - chunks.append(")") - return " ".join(chunks) - - -@dataclass -class ObjectClass: - """LDAP object class definition with MUST/MAY attribute sets.""" - - id: int = field(init=False) - oid: str = "" - name: str = "" - superior_name: str | None = None - kind: KindType | None = None - is_system: bool = False - superior: ObjectClass | None = field(default=None, repr=False) - attribute_types_must: list[AttributeType] = field( - default_factory=list, - repr=False, - ) - attribute_types_may: list[AttributeType] = field( - default_factory=list, - repr=False, - ) - - def get_raw_definition(self) -> str: - if not self.oid or not self.name or not self.kind: - raise ValueError( - f"{self}: Fields 'oid', 'name', and 'kind'" - " are required for LDAP definition.", - ) - chunks = ["(", self.oid, f"NAME '{self.name}'"] - if self.superior_name: - chunks.append(f"SUP {self.superior_name}") - chunks.append(self.kind) - if self.attribute_type_names_must: - chunks.append( - f"MUST ({' $ '.join(self.attribute_type_names_must)} )", - ) - if self.attribute_type_names_may: - chunks.append( - f"MAY ({' $ '.join(self.attribute_type_names_may)} )", - ) - chunks.append(")") - return " ".join(chunks) - - @property - def attribute_type_names_must(self) -> list[str]: - return [a.name for a in self.attribute_types_must] - - @property - def attribute_type_names_may(self) -> list[str]: - return [a.name for a in self.attribute_types_may] - @dataclass class PasswordPolicy: @@ -466,7 +375,7 @@ class AccessControlEntry: is_allow: bool = False role: Role | None = field(init=False, default=None, repr=False) - attribute_type: AttributeType | None = field( + attribute_type: Directory | None = field( init=False, default=None, repr=False, diff --git a/app/entities_legacy.py b/app/entities_legacy.py new file mode 100644 index 000000000..9ff1d5ff1 --- /dev/null +++ b/app/entities_legacy.py @@ -0,0 +1,45 @@ +"""Legacy entities.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from enums import KindType + + +@dataclass +class AttributeTypeLegacy: + """LDAP attribute type definition (schema element).""" + + id: int | None = field(init=False, default=None) + oid: str = "" + name: str = "" + syntax: str = "" + single_value: bool = False + no_user_modification: bool = False + is_system: bool = False + system_flags: int = 0 + # NOTE: ms-adts/cf133d47-b358-4add-81d3-15ea1cff9cd9 + # see section 3.1.1.2.3 `searchFlags` (fANR) for details + is_included_anr: bool = False + + +@dataclass +class ObjectClassLegacy: + """LDAP object class definition with MUST/MAY attribute sets.""" + + id: int = field(init=False) + oid: str = "" + name: str = "" + superior_name: str | None = None + kind: KindType | None = None + is_system: bool = False + superior: ObjectClassLegacy | None = field(default=None, repr=False) + attribute_types_must: list[AttributeTypeLegacy] = field( + default_factory=list, + repr=False, + ) + attribute_types_may: list[AttributeTypeLegacy] = field( + default_factory=list, + repr=False, + ) diff --git a/app/enums.py b/app/enums.py index 2c991d9f4..a57912073 100644 --- a/app/enums.py +++ b/app/enums.py @@ -60,6 +60,9 @@ class EntityTypeNames(StrEnum): """ DOMAIN = "Domain" + CONFIGURATION = "Configuration" + ATTRIBUTE_TYPE = "Attribute Type" + OBJECT_CLASS = "Object Class" COMPUTER = "Computer" CONTAINER = "Container" ORGANIZATIONAL_UNIT = "Organizational Unit" @@ -157,7 +160,6 @@ class AuthorizationRules(IntFlag): ATTRIBUTE_TYPE_GET_PAGINATOR = auto() ATTRIBUTE_TYPE_UPDATE = auto() ATTRIBUTE_TYPE_DELETE_ALL_BY_NAMES = auto() - ATTRIBUTE_TYPE_SET_ATTR_REPLICATION_FLAG = auto() ENTITY_TYPE_GET = auto() ENTITY_TYPE_CREATE = auto() diff --git a/app/extra/alembic_utils.py b/app/extra/alembic_utils.py index ac8cfffd8..01c5b27d8 100644 --- a/app/extra/alembic_utils.py +++ b/app/extra/alembic_utils.py @@ -6,12 +6,16 @@ from alembic import op -def temporary_stub_column(column_name: str, type_: Any) -> Callable: - """Add and drop a temporary column in the 'Directory' table. +def temporary_stub_column( + table_name: str, + column_name: str, + type_: Any, +) -> Callable: + """Add and drop a temporary column in the table. State of the database at the time of migration - doesn't contain the specified column in the 'Directory' table, - but 'Directory' model has the column. + doesn't contain the specified column in the table, + but model has the column. Before starting the migration, add the specified column. Then migration completed, delete the column. @@ -27,11 +31,11 @@ def temporary_stub_column(column_name: str, type_: Any) -> Callable: def decorator(func: Callable) -> Callable: def wrapper(*args: tuple, **kwargs: dict) -> None: op.add_column( - "Directory", + table_name, sa.Column(column_name, type_, nullable=True), ) func(*args, **kwargs) - op.drop_column("Directory", column_name) + op.drop_column(table_name, column_name) return None return wrapper diff --git a/app/extra/scripts/add_domain_controller.py b/app/extra/scripts/add_domain_controller.py index 3f700328a..21d3bbaed 100644 --- a/app/extra/scripts/add_domain_controller.py +++ b/app/extra/scripts/add_domain_controller.py @@ -12,7 +12,9 @@ from constants import DOMAIN_CONTROLLERS_OU_NAME from entities import Attribute, Directory from enums import SamAccountTypeCodes -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) from ldap_protocol.objects import UserAccountControlFlag from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.utils.helpers import create_object_sid @@ -23,7 +25,7 @@ async def _add_domain_controller( session: AsyncSession, role_use_case: RoleUseCase, - entity_type_dao: EntityTypeDAO, + entity_type_use_case: EntityTypeUseCase, settings: Settings, domain: Directory, dc_ou_dir: Directory, @@ -88,7 +90,7 @@ async def _add_domain_controller( parent_directory=dc_ou_dir, directory=dc_directory, ) - await entity_type_dao.attach_entity_type_to_directory( + await entity_type_use_case.attach_entity_type_to_directory( directory=dc_directory, is_system_entity_type=False, object_class_names={"top", "computer"}, @@ -100,7 +102,7 @@ async def add_domain_controller( session: AsyncSession, settings: Settings, role_use_case: RoleUseCase, - entity_type_dao: EntityTypeDAO, + entity_type_use_case: EntityTypeUseCase, ) -> None: logger.info("Adding domain controller.") @@ -136,7 +138,7 @@ async def add_domain_controller( await _add_domain_controller( session=session, role_use_case=role_use_case, - entity_type_dao=entity_type_dao, + entity_type_use_case=entity_type_use_case, settings=settings, domain=domains[0], dc_ou_dir=domain_controllers_ou, diff --git a/app/ioc.py b/app/ioc.py index 1a87389d4..9df00448f 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -85,20 +85,45 @@ LDAPSearchRequestContext, LDAPUnbindRequestContext, ) -from ldap_protocol.ldap_schema.attribute_type_dao import AttributeTypeDAO -from ldap_protocol.ldap_schema.attribute_type_system_flags_use_case import ( +from ldap_protocol.ldap_schema._legacy.attribute_type.attribute_type_dao import ( # noqa: E501 + AttributeTypeDAOLegacy, +) +from ldap_protocol.ldap_schema._legacy.attribute_type.attribute_type_use_case import ( # noqa: E501 + AttributeTypeUseCaseLegacy, +) +from ldap_protocol.ldap_schema._legacy.object_class.object_class_dao import ( + ObjectClassDAOLegacy, +) +from ldap_protocol.ldap_schema._legacy.object_class.object_class_use_case import ( # noqa: E501 + ObjectClassUseCaseLegacy, +) +from ldap_protocol.ldap_schema.attribute_dao import AttributeDAO +from ldap_protocol.ldap_schema.attribute_type.attribute_type_dao import ( + AttributeTypeDAO, +) +from ldap_protocol.ldap_schema.attribute_type.attribute_type_system_flags_use_case import ( # noqa: E501 AttributeTypeSystemFlagsUseCase, ) -from ldap_protocol.ldap_schema.attribute_type_use_case import ( +from ldap_protocol.ldap_schema.attribute_type.attribute_type_use_case import ( AttributeTypeUseCase, ) from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, ) -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO -from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase -from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO -from ldap_protocol.ldap_schema.object_class_use_case import ObjectClassUseCase +from ldap_protocol.ldap_schema.directory_dao import DirectoryDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) +from ldap_protocol.ldap_schema.object_class.object_class_dao import ( + ObjectClassDAO, +) +from ldap_protocol.ldap_schema.object_class.object_class_use_case import ( + ObjectClassUseCase, +) +from ldap_protocol.ldap_schema.schema_create_use_case import ( + SchemaLikeAsDirectoryCreateUseCase, +) from ldap_protocol.master_check_use_case import ( MasterCheckUseCase, MasterGatewayProtocol, @@ -155,6 +180,9 @@ ) from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.ace_dao import AccessControlEntryDAO +from ldap_protocol.roles.migrations_ace_dao import ( + AccessControlEntryMigrationsDAO, +) from ldap_protocol.roles.role_dao import RoleDAO from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.rootdse.gateway import SADomainGateway @@ -511,18 +539,58 @@ def get_dhcp_mngr( scope=Scope.RUNTIME, ) attribute_type_dao = provide(AttributeTypeDAO, scope=Scope.REQUEST) + attribute_type_dao_legacy = provide( + AttributeTypeDAOLegacy, + scope=Scope.REQUEST, + ) + attribute_dao = provide(AttributeDAO, scope=Scope.REQUEST) attribute_type_system_flags_use_case = provide( AttributeTypeSystemFlagsUseCase, scope=Scope.REQUEST, ) object_class_dao = provide(ObjectClassDAO, scope=Scope.REQUEST) + object_class_dao_legacy = provide( + ObjectClassDAOLegacy, + scope=Scope.REQUEST, + ) + + directory_dao = provide(DirectoryDAO, scope=Scope.REQUEST) entity_type_dao = provide(EntityTypeDAO, scope=Scope.REQUEST) attribute_type_use_case = provide( AttributeTypeUseCase, scope=Scope.REQUEST, ) + + @provide(scope=Scope.REQUEST) + def get_attribute_type_use_case_legacy( + self, + session: AsyncSession, + ) -> AttributeTypeUseCaseLegacy: + """Legacy attribute type use case on a single session.""" + at_dao_legacy = AttributeTypeDAOLegacy(session) + return AttributeTypeUseCaseLegacy( + attribute_type_dao_legacy=at_dao_legacy, + ) + + schema_create_use_case = provide( + SchemaLikeAsDirectoryCreateUseCase, + scope=Scope.REQUEST, + ) object_class_use_case = provide(ObjectClassUseCase, scope=Scope.REQUEST) + @provide(scope=Scope.REQUEST) + def get_object_class_use_case_legacy( + self, + session: AsyncSession, + ) -> ObjectClassUseCaseLegacy: + """Legacy object class use case sharing one session for all DAOs.""" + at_dao_legacy = AttributeTypeDAOLegacy(session) + oc_dao_legacy = ObjectClassDAOLegacy(session) + return ObjectClassUseCaseLegacy( + object_class_dao_legacy=oc_dao_legacy, + attribute_type_dao_legacy=at_dao_legacy, + ) + user_password_history_use_cases = provide( UserPasswordHistoryUseCases, scope=Scope.REQUEST, @@ -550,6 +618,10 @@ def get_dhcp_mngr( access_manager = provide(AccessManager, scope=Scope.RUNTIME) role_dao = provide(RoleDAO, scope=Scope.REQUEST) ace_dao = provide(AccessControlEntryDAO, scope=Scope.REQUEST) + ace_migrations_dao = provide( + AccessControlEntryMigrationsDAO, + scope=Scope.REQUEST, + ) role_use_case = provide(RoleUseCase, scope=Scope.REQUEST) session_repository = provide(SessionRepository, scope=Scope.REQUEST) entity_type_use_case = provide(EntityTypeUseCase, scope=Scope.REQUEST) diff --git a/app/ldap_protocol/auth/setup_gateway.py b/app/ldap_protocol/auth/setup_gateway.py index 6cbad0ea1..9d79c80f8 100644 --- a/app/ldap_protocol/auth/setup_gateway.py +++ b/app/ldap_protocol/auth/setup_gateway.py @@ -12,10 +12,14 @@ from sqlalchemy.ext.asyncio import AsyncSession from entities import Attribute, Directory, Group, NetworkPolicy, User +from enums import EntityTypeNames from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, ) -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.directory_dao import DirectoryDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) from ldap_protocol.utils.async_cache import base_directories_cache from ldap_protocol.utils.helpers import create_object_sid, generate_domain_sid from ldap_protocol.utils.queries import get_domain_object_class @@ -30,8 +34,9 @@ def __init__( self, session: AsyncSession, password_utils: PasswordUtils, - entity_type_dao: EntityTypeDAO, + entity_type_use_case: EntityTypeUseCase, attribute_value_validator: AttributeValueValidator, + directory_dao: DirectoryDAO, ) -> None: """Initialize Setup use case. @@ -41,8 +46,9 @@ def __init__( """ self._session = session self._password_utils = password_utils - self._entity_type_dao = entity_type_dao + self._entity_type_use_case = entity_type_use_case self._attribute_value_validator = attribute_value_validator + self._directory_dao = directory_dao async def is_setup(self) -> bool: """Check if setup is performed. @@ -96,9 +102,13 @@ async def setup_enviroment( attribute_names=["attributes"], with_for_update=None, ) - await self._entity_type_dao.attach_entity_type_to_directory( - directory=domain, - is_system_entity_type=True, + + entity_type = await self._entity_type_use_case.get( + EntityTypeNames.DOMAIN, + ) + await self._directory_dao.bind_entity_type( + domain, + entity_type.id if entity_type else None, ) if not self._attribute_value_validator.is_directory_valid(domain): raise ValueError( @@ -216,9 +226,17 @@ async def create_dir( attribute_names=["attributes", "user"], with_for_update=None, ) - await self._entity_type_dao.attach_entity_type_to_directory( + + entity_type = None + if entity_type_name := data.get("entity_type_name"): + entity_type = await self._entity_type_use_case.get( + entity_type_name, + ) + entity_type_id = entity_type.id if entity_type else None + await self._entity_type_use_case.attach_entity_type_to_directory( directory=dir_, is_system_entity_type=True, + entity_type_id=entity_type_id, ) if not self._attribute_value_validator.is_directory_valid(dir_): raise ValueError("Invalid directory attribute values") diff --git a/app/ldap_protocol/auth/use_cases.py b/app/ldap_protocol/auth/use_cases.py index ca063bcd7..d73ab4bf5 100644 --- a/app/ldap_protocol/auth/use_cases.py +++ b/app/ldap_protocol/auth/use_cases.py @@ -5,6 +5,7 @@ """ import copy +from itertools import chain from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession @@ -13,17 +14,33 @@ from constants import ( DOMAIN_ADMIN_GROUP_NAME, DOMAIN_CONTROLLERS_OU_NAME, + ENTITY_TYPE_DTOS_V1, + ENTITY_TYPE_DTOS_V2, FIRST_SETUP_DATA, USERS_CONTAINER_NAME, ) -from enums import SamAccountTypeCodes +from enums import EntityTypeNames, SamAccountTypeCodes from ldap_protocol.auth.dto import SetupDTO from ldap_protocol.auth.setup_gateway import SetupGateway from ldap_protocol.identity.exceptions import ( AlreadyConfiguredError, ForbiddenError, ) -from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase +from ldap_protocol.ldap_schema._legacy.attribute_type.attribute_type_use_case import ( # noqa: E501 + AttributeTypeUseCaseLegacy, +) +from ldap_protocol.ldap_schema._legacy.object_class.object_class_use_case import ( # noqa: E501 + ObjectClassUseCaseLegacy, +) +from ldap_protocol.ldap_schema.attribute_type.attribute_type_use_case import ( + AttributeTypeUseCase, +) +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) +from ldap_protocol.ldap_schema.object_class.object_class_use_case import ( + ObjectClassUseCase, +) from ldap_protocol.objects import UserAccountControlFlag from ldap_protocol.policies.audit.audit_use_case import AuditUseCase from ldap_protocol.policies.password import PasswordPolicyUseCases @@ -36,6 +53,10 @@ class SetupUseCase: def __init__( self, + attribute_type_use_case_legacy: AttributeTypeUseCaseLegacy, + attribute_type_use_case: AttributeTypeUseCase, + object_class_use_case_legacy: ObjectClassUseCaseLegacy, + object_class_use_case: ObjectClassUseCase, setup_gateway: SetupGateway, entity_type_use_case: EntityTypeUseCase, password_use_cases: PasswordPolicyUseCases, @@ -56,6 +77,10 @@ def __init__( self._role_use_case = role_use_case self._audit_use_case = audit_use_case self._session = session + self._attribute_type_use_case_legacy = attribute_type_use_case_legacy + self._attribute_type_use_case = attribute_type_use_case + self._object_class_use_case_legacy = object_class_use_case_legacy + self._object_class_use_case = object_class_use_case self._settings = settings async def setup(self, dto: SetupDTO) -> None: @@ -68,7 +93,9 @@ async def setup(self, dto: SetupDTO) -> None: """ if await self.is_setup(): raise AlreadyConfiguredError("Setup already performed") - await self._entity_type_use_case.create_for_first_setup() + + for entity_type_dto in chain(ENTITY_TYPE_DTOS_V1, ENTITY_TYPE_DTOS_V2): + await self._entity_type_use_case.create_not_safe(entity_type_dto) data = copy.deepcopy(FIRST_SETUP_DATA) data.append(self._create_user_data(dto)) @@ -86,6 +113,7 @@ async def is_setup(self) -> bool: def _create_domain_controller_data(self) -> dict: return { "name": DOMAIN_CONTROLLERS_OU_NAME, + "entity_type_name": EntityTypeNames.ORGANIZATIONAL_UNIT, "object_class": "organizationalUnit", "attributes": { "objectClass": ["top", "container"], @@ -93,6 +121,7 @@ def _create_domain_controller_data(self) -> dict: "children": [ { "name": self._settings.HOST_MACHINE_SHORT_NAME, + "entity_type_name": EntityTypeNames.COMPUTER, "object_class": "computer", "attributes": { "objectClass": ["top"], @@ -121,11 +150,13 @@ def _create_user_data(self, dto: SetupDTO) -> dict: """ return { "name": USERS_CONTAINER_NAME, + "entity_type_name": EntityTypeNames.CONTAINER, "object_class": "container", "attributes": {"objectClass": ["top"]}, "children": [ { "name": dto.username, + "entity_type_name": EntityTypeNames.USER, "object_class": "user", "organizationalPerson": { "sam_account_name": dto.username, @@ -173,6 +204,28 @@ async def _create(self, dto: SetupDTO, data: list) -> None: dn=dto.domain, is_system=True, ) + + attrs = await self._attribute_type_use_case_legacy.get_all() + for attr in attrs: + await self._attribute_type_use_case.create(attr) + + obj_classes = await self._object_class_use_case_legacy.get_all() + for obj_class in obj_classes: + obj_class.attribute_types_may = [ + _.name # type: ignore + for _ in obj_class.attribute_types_may + ] + obj_class.attribute_types_must = [ + _.name # type: ignore + for _ in obj_class.attribute_types_must + ] + await self._object_class_use_case.create(obj_class) # type: ignore + + await self._attribute_type_use_case_legacy.delete_table() + await self._object_class_use_case_legacy.delete_may_table() + await self._object_class_use_case_legacy.delete_must_table() + await self._object_class_use_case_legacy.delete_main_table() + await self._password_use_cases.create_default_domain_policy() errors = await ( diff --git a/app/ldap_protocol/filter_interpreter.py b/app/ldap_protocol/filter_interpreter.py index ce0b301c5..c7da74319 100644 --- a/app/ldap_protocol/filter_interpreter.py +++ b/app/ldap_protocol/filter_interpreter.py @@ -22,14 +22,8 @@ ) from sqlalchemy.sql.expression import false as sql_false -from entities import ( - Attribute, - AttributeType, - Directory, - EntityType, - Group, - User, -) +from entities import Attribute, Directory, EntityType, Group, User +from enums import EntityTypeNames from ldap_protocol.utils.helpers import ft_to_dt from ldap_protocol.utils.queries import get_path_filter, get_search_path from repo.pg.tables import ( @@ -114,11 +108,18 @@ def _get_anr_filter(self, val: str) -> ColumnElement[bool]: if is_first_char_equal: vl = normalized.replace("=", "") + attributes_expr.append( and_( qa(Attribute.name).in_( - select(qa(AttributeType.name)) - .where(qa(AttributeType.is_included_anr).is_(True)), + select(qa(Directory.name)) + .join(qa(Directory.entity_type)) + .join(qa(Directory.attributes)) + .where( + qa(Attribute.name) == "aNR", + qa(Attribute.value) == "True", + qa(EntityType.name) == EntityTypeNames.ATTRIBUTE_TYPE, # noqa: E501 + ), ), func.lower(Attribute.value) == vl, ), @@ -144,8 +145,14 @@ def _get_anr_filter(self, val: str) -> ColumnElement[bool]: attributes_expr.append( and_( qa(Attribute.name).in_( - select(qa(AttributeType.name)) - .where(qa(AttributeType.is_included_anr).is_(True)), + select(qa(Directory.name)) + .join(qa(Directory.entity_type)) + .join(qa(Directory.attributes)) + .where( + qa(Attribute.name) == "aNR", + qa(Attribute.value) == "True", + qa(EntityType.name) == EntityTypeNames.ATTRIBUTE_TYPE, # noqa: E501 + ), ), qa(Attribute.value).ilike(vl), ), @@ -207,9 +214,14 @@ def _get_anr_filter(self, val: str) -> ColumnElement[bool]: attributes_expr.append( and_( qa(Attribute.name).in_( - select(qa(AttributeType.name)).where( - qa(AttributeType.name) == "legacyExchangeDN", - qa(AttributeType.is_included_anr).is_(True), + select(qa(Directory.name)) + .join(qa(Directory.entity_type)) + .join(qa(Directory.attributes)) + .where( + qa(Directory.name) == "legacyExchangeDN", + qa(Attribute.name) == "aNR", + qa(Attribute.value) == "True", + qa(EntityType.name) == EntityTypeNames.ATTRIBUTE_TYPE, ), ), qa(Attribute.value) == normalized.replace("=", ""), diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index d6e6e8078..0dcdc8384 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -160,12 +160,15 @@ async def handle( # noqa: C901 yield AddResponse(result_code=LDAPCodes.NO_SUCH_OBJECT) return - entity_type = ( - await ctx.entity_type_dao.get_entity_type_by_object_class_names( - object_class_names=self.object_class_names, - ) + entity_type = await ctx.entity_type_use_case.get_entity_type_by_object_class_names( # noqa: E501 + object_class_names=self.object_class_names, ) - if entity_type and entity_type.name == EntityTypeNames.CONTAINER: + if entity_type and entity_type.name in ( + EntityTypeNames.CONTAINER, + EntityTypeNames.ATTRIBUTE_TYPE, + EntityTypeNames.OBJECT_CLASS, + EntityTypeNames.CONFIGURATION, + ): yield AddResponse(result_code=LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS) return @@ -477,10 +480,11 @@ async def handle( # noqa: C901 ctx.session.add_all(items_to_add) await ctx.session.flush() - await ctx.entity_type_dao.attach_entity_type_to_directory( + entity_type_id = entity_type.id if entity_type else None + await ctx.entity_type_use_case.attach_entity_type_to_directory( directory=new_dir, is_system_entity_type=False, - entity_type=entity_type, + entity_type_id=entity_type_id, object_class_names=self.object_class_names, ) await ctx.role_use_case.inherit_parent_aces( diff --git a/app/ldap_protocol/ldap_requests/contexts.py b/app/ldap_protocol/ldap_requests/contexts.py index 98f6e1a9b..d94b92af8 100644 --- a/app/ldap_protocol/ldap_requests/contexts.py +++ b/app/ldap_protocol/ldap_requests/contexts.py @@ -11,10 +11,19 @@ from config import Settings from ldap_protocol.dialogue import LDAPSession from ldap_protocol.kerberos import AbstractKadmin +from ldap_protocol.ldap_schema.attribute_type.attribute_type_use_case import ( + AttributeTypeUseCase, +) from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, ) -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) +from ldap_protocol.ldap_schema.object_class.object_class_use_case import ( + ObjectClassUseCase, +) from ldap_protocol.multifactor import LDAPMultiFactorAPI from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from ldap_protocol.policies.password import PasswordPolicyUseCases @@ -32,7 +41,7 @@ class LDAPAddRequestContext: session: AsyncSession ldap_session: LDAPSession kadmin: AbstractKadmin - entity_type_dao: EntityTypeDAO + entity_type_use_case: EntityTypeUseCase password_use_cases: PasswordPolicyUseCases password_utils: PasswordUtils access_manager: AccessManager @@ -49,7 +58,7 @@ class LDAPModifyRequestContext: session_storage: SessionStorage kadmin: AbstractKadmin settings: Settings - entity_type_dao: EntityTypeDAO + entity_type_use_case: EntityTypeUseCase access_manager: AccessManager password_use_cases: PasswordPolicyUseCases password_utils: PasswordUtils @@ -79,6 +88,8 @@ class LDAPSearchRequestContext: settings: Settings access_manager: AccessManager rootdse_rd: RootDSEReader + attribute_type_use_case: AttributeTypeUseCase + object_class_use_case: ObjectClassUseCase @dataclass diff --git a/app/ldap_protocol/ldap_requests/delete.py b/app/ldap_protocol/ldap_requests/delete.py index b8ad639d8..cf7e637e0 100644 --- a/app/ldap_protocol/ldap_requests/delete.py +++ b/app/ldap_protocol/ldap_requests/delete.py @@ -10,7 +10,7 @@ from sqlalchemy.orm import joinedload, selectinload from entities import Directory, Group -from enums import AceType +from enums import AceType, EntityTypeNames from ldap_protocol.asn1parser import ASN1Row from ldap_protocol.kerberos.exceptions import ( KRBAPIConnectionError, @@ -106,6 +106,17 @@ async def handle( # noqa: C901 ) return + entity_type = directory.entity_type if directory else None + if entity_type and entity_type.name in ( + EntityTypeNames.ATTRIBUTE_TYPE, + EntityTypeNames.OBJECT_CLASS, + EntityTypeNames.CONFIGURATION, + ): + yield DeleteResponse( + result_code=LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS, + ) + return + self.set_event_data( {"before_attrs": self.get_directory_attrs(directory)}, ) diff --git a/app/ldap_protocol/ldap_requests/modify.py b/app/ldap_protocol/ldap_requests/modify.py index e6ccb3b89..981bd70ef 100644 --- a/app/ldap_protocol/ldap_requests/modify.py +++ b/app/ldap_protocol/ldap_requests/modify.py @@ -157,7 +157,7 @@ async def _update_password_expiration( now = datetime.now(timezone.utc) + timedelta(days=max_age_days) change.modification.vals[0] = now.strftime("%Y%m%d%H%M%SZ") - async def handle( + async def handle( # noqa: C901 self, ctx: LDAPModifyRequestContext, ) -> AsyncGenerator[ModifyResponse, None]: @@ -223,8 +223,18 @@ async def handle( yield ModifyResponse(result_code=LDAPCodes.NOT_ALLOWED_ON_RDN) return - before_attrs = self.get_directory_attrs(directory) entity_type = directory.entity_type + if entity_type and entity_type.name in ( + EntityTypeNames.ATTRIBUTE_TYPE, + EntityTypeNames.OBJECT_CLASS, + EntityTypeNames.CONFIGURATION, + ): + yield ModifyResponse( + result_code=LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS, + ) + return + + before_attrs = self.get_directory_attrs(directory) try: for change in self.changes: if change.l_type in Directory.ro_fields: @@ -300,7 +310,7 @@ async def handle( ) if "objectclass" in names: - await ctx.entity_type_dao.attach_entity_type_to_directory( + await ctx.entity_type_use_case.attach_entity_type_to_directory( directory=directory, is_system_entity_type=False, ) diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index c9ab0bd57..6674adaf7 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -23,14 +23,7 @@ from sqlalchemy.sql.elements import ColumnElement, UnaryExpression from sqlalchemy.sql.expression import Select -from entities import ( - Attribute, - AttributeType, - Directory, - Group, - ObjectClass, - User, -) +from entities import Attribute, Directory, Group, User from enums import AceType from ldap_protocol.asn1parser import ASN1Row from ldap_protocol.dialogue import UserSchema @@ -46,6 +39,18 @@ SearchResultEntry, SearchResultReference, ) +from ldap_protocol.ldap_schema.attribute_type.attribute_type_raw_display import ( # noqa: E501 + AttributeTypeRawDisplay, +) +from ldap_protocol.ldap_schema.attribute_type.attribute_type_use_case import ( + AttributeTypeUseCase, +) +from ldap_protocol.ldap_schema.object_class.object_class_raw_display import ( + ObjectClassRawDisplay, +) +from ldap_protocol.ldap_schema.object_class.object_class_use_case import ( + ObjectClassUseCase, +) from ldap_protocol.objects import DerefAliases, ProtocolRequests, Scope from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.rootdse.netlogon import NetLogonAttributeHandler @@ -194,28 +199,27 @@ def from_data(cls, data: dict[str, list[ASN1Row]]) -> "SearchRequest": attributes=[field.value for field in attributes.value], ) - async def _get_subschema(self, session: AsyncSession) -> SearchResultEntry: + async def _get_subschema( + self, + attribute_type_use_case: AttributeTypeUseCase, + object_class_use_case: ObjectClassUseCase, + ) -> SearchResultEntry: attrs: dict[str, list[str]] = defaultdict(list) attrs["name"].append("Schema") attrs["objectClass"].append("subSchema") attrs["objectClass"].append("top") - attribute_types = await session.scalars(select(AttributeType)) + attribute_type_dtos = await attribute_type_use_case.get_all() attrs["attributeTypes"] = [ - attribute_type.get_raw_definition() - for attribute_type in attribute_types + AttributeTypeRawDisplay.get_raw_definition(attribute_type_dto) + for attribute_type_dto in attribute_type_dtos ] - object_classes = await session.scalars( - select(ObjectClass).options( - selectinload(qa(ObjectClass.attribute_types_must)), - selectinload(qa(ObjectClass.attribute_types_may)), - ), - ) + object_class_dtos = await object_class_use_case.get_all() attrs["objectClasses"] = [ - object_class.get_raw_definition() - for object_class in object_classes + ObjectClassRawDisplay.get_raw_definition(object_class_dto) + for object_class_dto in object_class_dtos ] return SearchResultEntry( @@ -278,7 +282,10 @@ async def get_result( if self.scope == Scope.BASE_OBJECT and (is_root_dse or is_schema): if is_schema: - yield await self._get_subschema(ctx.session) + yield await self._get_subschema( + ctx.attribute_type_use_case, + ctx.object_class_use_case, + ) elif is_netlogon: nl_attr = await self._get_netlogon(ctx) yield SearchResultEntry( diff --git a/app/ldap_protocol/ldap_schema/_legacy/attribute_type/__init__.py b/app/ldap_protocol/ldap_schema/_legacy/attribute_type/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/ldap_protocol/ldap_schema/_legacy/attribute_type/attribute_type_dao.py b/app/ldap_protocol/ldap_schema/_legacy/attribute_type/attribute_type_dao.py new file mode 100644 index 000000000..069560fb7 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/_legacy/attribute_type/attribute_type_dao.py @@ -0,0 +1,154 @@ +"""Attribute Type DAO. + +Copyright (c) 2024 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from adaptix import P +from adaptix.conversion import ( + allow_unlinked_optional, + get_converter, + link_function, +) +from entities_legacy import AttributeTypeLegacy, ObjectClassLegacy +from sqlalchemy import delete, or_, select, text, update +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from entities import Directory, EntityType +from enums import EntityTypeNames +from ldap_protocol.ldap_schema.dto import AttributeTypeDTO +from ldap_protocol.ldap_schema.exceptions import ( + AttributeTypeAlreadyExistsError, + AttributeTypeNotFoundError, +) +from repo.pg.tables import queryable_attr as qa + +_convert_model_to_dto = get_converter( + AttributeTypeLegacy, + AttributeTypeDTO, + recipe=[ + allow_unlinked_optional(P[AttributeTypeDTO].object_class_names), + ], +) +_convert_dto_to_model = get_converter( + AttributeTypeDTO, + AttributeTypeLegacy, + recipe=[ + link_function( + lambda _: None, + P[AttributeTypeLegacy].id, + ), + ], +) + + +class AttributeTypeDAOLegacy: + """Attribute Type DAO.""" + + __session: AsyncSession + + def __init__(self, session: AsyncSession) -> None: + """Initialize Attribute Type DAO with session.""" + self.__session = session + + async def delete_all_dirs(self) -> None: + attr_subq = ( + select(qa(EntityType.id)) + .where(qa(EntityType.name) == EntityTypeNames.ATTRIBUTE_TYPE) + .scalar_subquery(), + ) + await self.__session.execute( + delete(Directory) + .where(qa(Directory.entity_type_id).in_(attr_subq)), + ) # fmt: skip + + async def delete_table(self) -> None: + await self.__session.execute( + text('DROP TABLE IF EXISTS "AttributeTypes" CASCADE'), + ) + + async def get_object_class_names_include_attribute_type( + self, + attribute_type_name: str, + ) -> set[str]: + """Get all Object Class names include Attribute Type name.""" + result = await self.__session.execute( + select(qa(ObjectClassLegacy.name)) + .where( + or_( + qa(ObjectClassLegacy.attribute_types_must).any(name=attribute_type_name), + qa(ObjectClassLegacy.attribute_types_may).any(name=attribute_type_name), + ), + ), + ) # fmt: skip + return set(row[0] for row in result.fetchall()) + + async def get_all(self) -> list[AttributeTypeDTO[int]]: + """Get all Attribute Types.""" + res = await self.__session.scalars(select(AttributeTypeLegacy)) + return list(map(_convert_model_to_dto, res.all())) + + async def create(self, dto: AttributeTypeDTO[None]) -> None: + """Create Attribute Type.""" + try: + self.__session.add(_convert_dto_to_model(dto)) + await self.__session.flush() + + except IntegrityError: + raise AttributeTypeAlreadyExistsError( + f"Attribute Type with oid '{dto.oid}' and name" + + f" '{dto.name}' already exists.", + ) + + async def zero_all_replicated_flags(self) -> None: + """Set replication flag to False for all Attribute Types.""" + await self.__session.execute(update(AttributeTypeLegacy).values({"system_flags": 0})) # fmt: skip # noqa: E501 + + async def set_false_replication_flag(self, names: tuple[str, ...]) -> None: + """Set replication flag in systemFlags.""" + await self.__session.execute( + update(AttributeTypeLegacy) + .where(qa(AttributeTypeLegacy.name).in_(names)) + .values({"system_flags": 0}), + ) + + async def false_all_is_included_anr(self) -> None: + """Set is_included_anr to False for all Attribute Types.""" + await self.__session.execute(update(AttributeTypeLegacy).values({"is_included_anr": False})) # fmt: skip # noqa: E501 + + async def mark_anr_included_by_attr_names( + self, + names: tuple[str, ...], + ) -> list[str]: + """Update Attribute Types and return updated AttrType names.""" + result = await self.__session.scalars( + update(AttributeTypeLegacy) + .where(qa(AttributeTypeLegacy.name).in_(names)) + .values({"is_included_anr": True}) + .returning(qa(AttributeTypeLegacy.name)), + ) + return list(result.all()) + + async def get(self, name: str) -> AttributeTypeDTO[int]: + attribute_type = await self.__session.scalar( + select(AttributeTypeLegacy) + .filter_by(name=name), + ) # fmt: skip + + if not attribute_type: + raise AttributeTypeNotFoundError( + f"Attribute Type with name '{name}' not found.", + ) + return _convert_model_to_dto(attribute_type) + + async def get_all_raw_by_names( + self, + names: list[str], + ) -> list[AttributeTypeLegacy]: + """Get list of Attribute Types by names.""" + res = await self.__session.scalars( + select(AttributeTypeLegacy) + .where(qa(AttributeTypeLegacy.name).in_(names)), + ) # fmt: skip + return list(res.all()) diff --git a/app/ldap_protocol/ldap_schema/_legacy/attribute_type/attribute_type_use_case.py b/app/ldap_protocol/ldap_schema/_legacy/attribute_type/attribute_type_use_case.py new file mode 100644 index 000000000..b49a3d5c2 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/_legacy/attribute_type/attribute_type_use_case.py @@ -0,0 +1,87 @@ +"""Attribute Type Use Case. + +Copyright (c) 2024 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import ClassVar, Sequence + +from entities_legacy import AttributeTypeLegacy + +from abstract_service import AbstractService +from enums import AuthorizationRules +from ldap_protocol.ldap_schema._legacy.attribute_type.attribute_type_dao import ( # noqa: E501 + AttributeTypeDAOLegacy, +) +from ldap_protocol.ldap_schema.dto import AttributeTypeDTO + + +class AttributeTypeUseCaseLegacy(AbstractService): + """AttributeTypeUseCase.""" + + __attribute_type_dao_legacy: AttributeTypeDAOLegacy + + def __init__( + self, + attribute_type_dao_legacy: AttributeTypeDAOLegacy, + ) -> None: + """Init AttributeTypeUseCase.""" + self.__attribute_type_dao_legacy = attribute_type_dao_legacy + + async def get(self, name: str) -> AttributeTypeDTO[int]: + """Get Attribute Type by name.""" + dto = await self.__attribute_type_dao_legacy.get(name) + dto.object_class_names = await self.__attribute_type_dao_legacy.get_object_class_names_include_attribute_type(dto.name) # noqa: E501 # fmt: skip + return dto + + async def get_all(self) -> list[AttributeTypeDTO[int]]: + """Get all Attribute Types.""" + return await self.__attribute_type_dao_legacy.get_all() + + async def create(self, dto: AttributeTypeDTO[None]) -> None: + """Create Attribute Type.""" + await self.__attribute_type_dao_legacy.create(dto) + + async def delete_all_dirs(self) -> None: + """Delete all Attribute Type directories.""" + await self.__attribute_type_dao_legacy.delete_all_dirs() + + async def delete_table(self) -> None: + await self.__attribute_type_dao_legacy.delete_table() + + async def zero_all_replicated_flags(self) -> None: + """Set replication flag to False for all Attribute Types.""" + await self.__attribute_type_dao_legacy.zero_all_replicated_flags() + + async def set_false_replication_flag( + self, + names: tuple[str, ...], + ) -> None: + """Set replication flag in systemFlags.""" + await self.__attribute_type_dao_legacy.set_false_replication_flag( + names, + ) + + async def mark_anr_included_by_attr_names( + self, + names: tuple[str, ...], + ) -> list[str]: + """Update Attribute Types and return updated DTOs.""" + return await self.__attribute_type_dao_legacy.mark_anr_included_by_attr_names( # noqa: E501 + names, + ) + + async def false_all_is_included_anr(self) -> None: + """Set is_included_anr to False for all Attribute Types.""" + await self.__attribute_type_dao_legacy.false_all_is_included_anr() + + async def get_all_raw_by_names( + self, + names: list[str], + ) -> Sequence[AttributeTypeLegacy]: + """Get list of Attribute Types by names.""" + return await self.__attribute_type_dao_legacy.get_all_raw_by_names( + names, + ) + + PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = {} diff --git a/app/ldap_protocol/ldap_schema/_legacy/object_class/__init__.py b/app/ldap_protocol/ldap_schema/_legacy/object_class/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/ldap_protocol/ldap_schema/_legacy/object_class/object_class_dao.py b/app/ldap_protocol/ldap_schema/_legacy/object_class/object_class_dao.py new file mode 100644 index 000000000..f13fc34a7 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/_legacy/object_class/object_class_dao.py @@ -0,0 +1,141 @@ +"""Object Class DAO. + +Copyright (c) 2024 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from dataclasses import dataclass + +from adaptix import P +from adaptix.conversion import ( + allow_unlinked_optional, + get_converter, + link_function, +) +from entities_legacy import AttributeTypeLegacy, ObjectClassLegacy +from sqlalchemy import delete, select, text +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from entities import Directory, EntityType +from enums import EntityTypeNames, KindType +from ldap_protocol.ldap_schema.dto import AttributeTypeDTO, ObjectClassDTO +from ldap_protocol.ldap_schema.exceptions import ( + ObjectClassAlreadyExistsError, + ObjectClassNotFoundError, +) +from repo.pg.tables import queryable_attr as qa + +_convert_model_to_dto = get_converter( + ObjectClassLegacy, + ObjectClassDTO[int, AttributeTypeDTO], + recipe=[ + allow_unlinked_optional(P[ObjectClassDTO].id), + allow_unlinked_optional(P[ObjectClassDTO].entity_type_names), + allow_unlinked_optional(P[AttributeTypeDTO].object_class_names), + link_function(lambda x: x.kind, P[ObjectClassDTO].kind), + ], +) + + +@dataclass +class ObjectClassCreateDTO: + """Object Class DTO.""" + + oid: str + name: str + kind: KindType + is_system: bool + attribute_types_must: list[AttributeTypeLegacy] + attribute_types_may: list[AttributeTypeLegacy] + superior: ObjectClassLegacy | None = None + + +class ObjectClassDAOLegacy: + """Object Class DAO.""" + + __session: AsyncSession + + def __init__(self, session: AsyncSession) -> None: + """Initialize Object Class DAO with session.""" + self.__session = session + + async def get_all(self) -> list[ObjectClassDTO[int, AttributeTypeDTO]]: + """Get all Object Classes.""" + obj_classes = await self.__session.scalars( + select(ObjectClassLegacy) + .options( + selectinload(qa(ObjectClassLegacy.attribute_types_may)), + selectinload(qa(ObjectClassLegacy.attribute_types_must)), + ), + ) # fmt: skip + return list(map(_convert_model_to_dto, obj_classes.all())) + + async def create( + self, + dto: ObjectClassCreateDTO, + ) -> None: + """Create a new Object Class.""" + try: + object_class = ObjectClassLegacy( + oid=dto.oid, + name=dto.name, + superior=dto.superior, + kind=dto.kind, + is_system=dto.is_system, + attribute_types_must=dto.attribute_types_must, + attribute_types_may=dto.attribute_types_may, + ) + self.__session.add(object_class) + await self.__session.flush() + except IntegrityError: + raise ObjectClassAlreadyExistsError( + f"Object Class with oid '{dto.oid}' and name" + + f" '{dto.name}' already exists.", + ) + + async def get_raw_by_name(self, name: str) -> ObjectClassLegacy: + """Get single Object Class by name.""" + object_class = await self.__session.scalar( + select(ObjectClassLegacy) + .filter_by(name=name) + .options(selectinload(qa(ObjectClassLegacy.attribute_types_may))) + .options(selectinload(qa(ObjectClassLegacy.attribute_types_must))), + ) # fmt: skip + + if not object_class: + raise ObjectClassNotFoundError( + f"Object Class with name '{name}' not found.", + ) + return object_class + + async def delete_all_dirs(self) -> None: + objcls_subq = ( + select(qa(EntityType.id)) + .where(qa(EntityType.name) == EntityTypeNames.OBJECT_CLASS) + .scalar_subquery() + ) + await self.__session.execute( + delete(Directory) + .where(qa(Directory.entity_type_id).in_(objcls_subq)), + ) # fmt: skip + + async def delete_main_table(self) -> None: + await self.__session.execute( + text('DROP TABLE IF EXISTS "ObjectClasses" CASCADE'), + ) + + async def delete_may_table(self) -> None: + await self.__session.execute( + text('DROP TABLE IF EXISTS "ObjectClassAttributeTypeMayMemberships" CASCADE'), # noqa: E501 + ) # fmt: skip + + async def delete_must_table(self) -> None: + await self.__session.execute( + text('DROP TABLE IF EXISTS "ObjectClassAttributeTypeMustMemberships" CASCADE'), # noqa: E501 + ) # fmt: skip + + async def get(self, name: str) -> ObjectClassDTO[int, AttributeTypeDTO]: + """Get single Object Class by name.""" + return _convert_model_to_dto(await self.get_raw_by_name(name)) diff --git a/app/ldap_protocol/ldap_schema/_legacy/object_class/object_class_use_case.py b/app/ldap_protocol/ldap_schema/_legacy/object_class/object_class_use_case.py new file mode 100644 index 000000000..d381576a1 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/_legacy/object_class/object_class_use_case.py @@ -0,0 +1,114 @@ +"""Object Class Use Case. + +Copyright (c) 2024 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import ClassVar + +from entities_legacy import ObjectClassLegacy + +from abstract_service import AbstractService +from enums import AuthorizationRules +from ldap_protocol.ldap_schema._legacy.attribute_type.attribute_type_dao import ( # noqa: E501 + AttributeTypeDAOLegacy, +) +from ldap_protocol.ldap_schema._legacy.object_class.object_class_dao import ( + ObjectClassCreateDTO, + ObjectClassDAOLegacy, +) +from ldap_protocol.ldap_schema.dto import AttributeTypeDTO, ObjectClassDTO +from ldap_protocol.ldap_schema.exceptions import ObjectClassNotFoundError + + +class ObjectClassUseCaseLegacy(AbstractService): + """ObjectClassUseCase.""" + + __attribute_type_dao_legacy: AttributeTypeDAOLegacy + __object_class_dao_legacy: ObjectClassDAOLegacy + + def __init__( + self, + object_class_dao_legacy: ObjectClassDAOLegacy, + attribute_type_dao_legacy: AttributeTypeDAOLegacy, + ) -> None: + """Init ObjectClassUseCase.""" + self.__attribute_type_dao_legacy = attribute_type_dao_legacy + self.__object_class_dao_legacy = object_class_dao_legacy + + async def get(self, name: str) -> ObjectClassDTO[int, AttributeTypeDTO]: + """Get Object Class by name.""" + dto = await self.__object_class_dao_legacy.get(name) + return dto + + async def get_all(self) -> list[ObjectClassDTO[int, AttributeTypeDTO]]: + """Get all Object Classes.""" + return await self.__object_class_dao_legacy.get_all() + + async def create(self, dto: ObjectClassDTO[None, str]) -> None: + """Create a new Object Class.""" + create_dto = ObjectClassCreateDTO( + oid=dto.oid, + name=dto.name, + kind=dto.kind, + is_system=dto.is_system, + attribute_types_must=[], + attribute_types_may=[], + ) + + if dto.superior_name: + create_dto.superior = ( + await self.__object_class_dao_legacy.get_raw_by_name( + dto.superior_name, + ) + ) + + if dto.superior_name and not create_dto.superior: + raise ObjectClassNotFoundError( + f"Superior (parent) Object class {dto.superior_name} " + "not found in schema.", + ) + + attribute_types_may_filtered = [ + name + for name in dto.attribute_types_may + if name not in dto.attribute_types_must + ] + + if dto.attribute_types_must: + create_dto.attribute_types_must = ( + await self.__attribute_type_dao_legacy.get_all_raw_by_names( + dto.attribute_types_must, + ) + ) + + if attribute_types_may_filtered: + create_dto.attribute_types_may = ( + await self.__attribute_type_dao_legacy.get_all_raw_by_names( + attribute_types_may_filtered, + ) + ) + + await self.__object_class_dao_legacy.create(create_dto) + + async def get_raw_by_name(self, name: str) -> ObjectClassLegacy: + """Get Object Class by name without related data.""" + return await self.__object_class_dao_legacy.get_raw_by_name(name) + + async def delete_all_dirs(self) -> None: + """Delete all Object Class directories.""" + await self.__object_class_dao_legacy.delete_all_dirs() + + async def delete_main_table(self) -> None: + """Delete Object Class table.""" + await self.__object_class_dao_legacy.delete_main_table() + + async def delete_may_table(self) -> None: + """Delete Object Class May membership table.""" + await self.__object_class_dao_legacy.delete_may_table() + + async def delete_must_table(self) -> None: + """Delete Object Class Must membership table.""" + await self.__object_class_dao_legacy.delete_must_table() + + PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = {} diff --git a/app/ldap_protocol/ldap_schema/attribute_dao.py b/app/ldap_protocol/ldap_schema/attribute_dao.py new file mode 100644 index 000000000..5d7f743a6 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/attribute_dao.py @@ -0,0 +1,54 @@ +"""Attribute DAO. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy.ext.asyncio import AsyncSession + +from entities import Attribute +from ldap_protocol.ldap_schema.dto import AttributeDTO + + +class AttributeDAO: + """Attribute DAO.""" + + __session: AsyncSession + + def __init__(self, session: AsyncSession) -> None: + """Initialize Attribute DAO with session.""" + self.__session = session + + async def add_directory_name_attribute( + self, + directory_id: int, + attribute_dto: AttributeDTO, + ) -> None: + """Add the RDN attribute for a Directory.""" + self.__session.add( + Attribute( + name=attribute_dto.name, + value=attribute_dto.values[0] if attribute_dto.values else "", + directory_id=directory_id, + ), + ) + + async def add_attributes_from_dto( + self, + directory_id: int, + attributes: tuple[AttributeDTO, ...], + ) -> None: + """Add Attributes from a CreateDirDTO payload.""" + for attribute_dto in attributes: + for value in attribute_dto.values: + if not isinstance(value, str): + raise ValueError("Only string values are supported.") + + self.__session.add( + Attribute( + directory_id=directory_id, + name=attribute_dto.name, + value=value, + bvalue=None, + ), + ) diff --git a/app/ldap_protocol/ldap_schema/attribute_type/attribute_type_dao.py b/app/ldap_protocol/ldap_schema/attribute_type/attribute_type_dao.py new file mode 100644 index 000000000..8b24ff403 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/attribute_type/attribute_type_dao.py @@ -0,0 +1,177 @@ +"""Attribute Type DAO. + +Copyright (c) 2024 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from entities import Directory, EntityType +from enums import EntityTypeNames +from ldap_protocol.ldap_schema.attribute_type.constants import ( + AttributeTypeAttributeNames as Names, +) +from ldap_protocol.ldap_schema.dto import AttributeTypeDTO +from ldap_protocol.ldap_schema.exceptions import AttributeTypeNotFoundError +from ldap_protocol.utils.pagination import PaginationParams, PaginationResult +from repo.pg.tables import queryable_attr as qa + + +def _convert_model_to_dto(directory: Directory) -> AttributeTypeDTO[int]: + return AttributeTypeDTO[int]( + id=directory.id, + name=directory.name, + oid=directory.attributes_dict[Names.OID][0], + syntax=directory.attributes_dict[Names.SYNTAX][0], + single_value=directory.attributes_dict[Names.SINGLE_VALUE][0] == "True", # noqa: E501 + no_user_modification=directory.attributes_dict[Names.NO_USER_MODIFICATION][0] == "True", # noqa: E501 + is_system=directory.is_system, + system_flags=int(directory.attributes_dict[Names.SYSTEM_FLAGS][0]), + is_included_anr=directory.attributes_dict[Names.IS_INCLUDED_ANR][0] == "True", # noqa: E501 + object_class_names=set(), + ) # fmt: skip + + +class AttributeTypeDAO: + """Attribute Type DAO.""" + + __session: AsyncSession + + def __init__(self, session: AsyncSession) -> None: + """Initialize Attribute Type DAO with session.""" + self.__session = session + + async def _get_dir(self, name: str) -> Directory | None: + res = await self.__session.scalars( + select(Directory) + .join(qa(Directory.entity_type)) + .options(selectinload(qa(Directory.attributes))) + .where( + qa(EntityType.name) == EntityTypeNames.ATTRIBUTE_TYPE, + qa(Directory.name) == name, + ), + ) + dir_ = res.first() + return dir_ + + async def get_all_names_by_names(self, names: list[str]) -> list[str]: + res = await self.__session.scalars( + select(qa(Directory.name)) + .join(qa(Directory.entity_type)) + .where( + qa(EntityType.name) == EntityTypeNames.ATTRIBUTE_TYPE, + qa(Directory.name).in_(names), + ), + ) + return list(res.all()) + + async def get_all(self) -> list[AttributeTypeDTO]: + res = await self.__session.scalars( + select(Directory) + .join(qa(Directory.entity_type)) + .options(selectinload(qa(Directory.attributes))) + .where(qa(EntityType.name) == EntityTypeNames.ATTRIBUTE_TYPE), + ) + return list(map(_convert_model_to_dto, res.all())) + + async def get(self, name: str) -> AttributeTypeDTO: + """Get Attribute Type by name.""" + dir_ = await self._get_dir(name) + if not dir_: + raise AttributeTypeNotFoundError( + f"Attribute Type with name '{name}' not found.", + ) + + return _convert_model_to_dto(dir_) + + async def update(self, name: str, dto: AttributeTypeDTO) -> None: + """Update Attribute Type. + + Docs: + ANR (Ambiguous Name Resolution) inclusion can be modified for + all attributes, including system ones, as it's a search + optimization setting that doesn't affect the LDAP schema + structure or data integrity. + + Other properties (`syntax`, `single_value`, `no_user_modification`) + can only be modified for non-system attributes to preserve + LDAP schema integrity. + """ + dir_ = await self._get_dir(name) + if not dir_: + raise AttributeTypeNotFoundError( + f"Attribute Type with name '{name}' not found.", + ) + + for attr in dir_.attributes: + if not dir_.is_system: + if attr.name == Names.SYNTAX: + attr.value = dto.syntax + elif attr.name == Names.SINGLE_VALUE: + attr.value = str(dto.single_value) + elif attr.name == Names.NO_USER_MODIFICATION: + attr.value = str(dto.no_user_modification) + else: + if attr.name == Names.IS_INCLUDED_ANR: + attr.value = str(dto.is_included_anr) + break + + await self.__session.flush() + + async def update_sys_flags(self, name: str, dto: AttributeTypeDTO) -> None: + """Update system flags of Attribute Type.""" + dir_ = await self._get_dir(name) + if not dir_: + raise AttributeTypeNotFoundError( + f"Attribute Type with name '{name}' not found.", + ) + + for attr in dir_.attributes: + if attr.name == Names.SYSTEM_FLAGS: + attr.value = str(dto.system_flags) + break + + await self.__session.flush() + + async def get_paginator( + self, + params: PaginationParams, + ) -> PaginationResult[Directory, AttributeTypeDTO]: + """Retrieve paginated Attribute Types.""" + filters = [qa(EntityType.name) == EntityTypeNames.ATTRIBUTE_TYPE] + + if params.query: + filters.append(qa(Directory.name).like(f"%{params.query}%")) + + query = ( + select(Directory) + .join(qa(Directory.entity_type)) + .where(*filters) + .options(selectinload(qa(Directory.attributes))) + .order_by(qa(Directory.id)) + ) + + return await PaginationResult[Directory, AttributeTypeDTO].get( + params=params, + query=query, + converter=_convert_model_to_dto, + session=self.__session, + ) + + async def delete_all_by_names(self, names: list[str]) -> None: + """Delete not system Attribute Types by names.""" + if not names: + return + + await self.__session.execute( + delete(Directory) + .where( + qa(Directory.entity_type) + .has(qa(EntityType.name) == EntityTypeNames.ATTRIBUTE_TYPE), + qa(Directory.name).in_(names), + qa(Directory.is_system).is_(False), + ), + ) # fmt: skip + await self.__session.flush() diff --git a/app/ldap_protocol/ldap_schema/attribute_type/attribute_type_raw_display.py b/app/ldap_protocol/ldap_schema/attribute_type/attribute_type_raw_display.py new file mode 100644 index 000000000..08d5f3e62 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/attribute_type/attribute_type_raw_display.py @@ -0,0 +1,25 @@ +"""AttributeTypeRawDisplay.""" + +from ldap_protocol.ldap_schema.dto import AttributeTypeDTO + + +class AttributeTypeRawDisplay: + @staticmethod + def get_raw_definition(dto: AttributeTypeDTO) -> str: + if not dto.oid or not dto.name or not dto.syntax: + raise ValueError( + f"{dto}: Fields 'oid', 'name', " + "and 'syntax' are required for LDAP definition.", + ) + chunks = [ + "(", + dto.oid, + f"NAME '{dto.name}'", + f"SYNTAX '{dto.syntax}'", + ] + if dto.single_value: + chunks.append("SINGLE-VALUE") + if dto.no_user_modification: + chunks.append("NO-USER-MODIFICATION") + chunks.append(")") + return " ".join(chunks) diff --git a/app/ldap_protocol/ldap_schema/attribute_type_system_flags_use_case.py b/app/ldap_protocol/ldap_schema/attribute_type/attribute_type_system_flags_use_case.py similarity index 98% rename from app/ldap_protocol/ldap_schema/attribute_type_system_flags_use_case.py rename to app/ldap_protocol/ldap_schema/attribute_type/attribute_type_system_flags_use_case.py index a903028a8..f01e68c1e 100644 --- a/app/ldap_protocol/ldap_schema/attribute_type_system_flags_use_case.py +++ b/app/ldap_protocol/ldap_schema/attribute_type/attribute_type_system_flags_use_case.py @@ -44,7 +44,7 @@ def is_attr_replicated( & AttributeTypeSystemFlags.ATTR_NOT_REPLICATED, ) - def set_attr_replication_flag( + def set_attr_replication( self, attribute_type_dto: AttributeTypeDTO, need_to_replicate: bool, diff --git a/app/ldap_protocol/ldap_schema/attribute_type/attribute_type_use_case.py b/app/ldap_protocol/ldap_schema/attribute_type/attribute_type_use_case.py new file mode 100644 index 000000000..227a45d80 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/attribute_type/attribute_type_use_case.py @@ -0,0 +1,150 @@ +"""Attribute Type Use Case. + +Copyright (c) 2024 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import ClassVar + +from sqlalchemy.exc import IntegrityError + +from abstract_service import AbstractService +from enums import AuthorizationRules, EntityTypeNames +from ldap_protocol.ldap_schema.attribute_type.attribute_type_dao import ( + AttributeTypeDAO, +) +from ldap_protocol.ldap_schema.attribute_type.attribute_type_system_flags_use_case import ( # noqa: E501 + AttributeTypeSystemFlagsUseCase, +) +from ldap_protocol.ldap_schema.attribute_type.constants import ( + AttributeTypeAttributeNames as Names, +) +from ldap_protocol.ldap_schema.dto import ( + AttributeDTO, + AttributeTypeDTO, + CreateDirDTO, +) +from ldap_protocol.ldap_schema.exceptions import ( + AttributeTypeAlreadyExistsError, +) +from ldap_protocol.ldap_schema.object_class.object_class_dao import ( + ObjectClassDAO, +) +from ldap_protocol.ldap_schema.schema_create_use_case import ( + SchemaLikeAsDirectoryCreateUseCase, +) +from ldap_protocol.utils.pagination import PaginationParams, PaginationResult + + +class AttributeTypeUseCase(AbstractService): + """AttributeTypeUseCase.""" + + __attribute_type_dao: AttributeTypeDAO + __attribute_type_system_flags_use_case: AttributeTypeSystemFlagsUseCase + __object_class_dao: ObjectClassDAO + __schema_create_use_case: SchemaLikeAsDirectoryCreateUseCase + + def __init__( + self, + attribute_type_dao: AttributeTypeDAO, + attribute_type_system_flags_use_case: AttributeTypeSystemFlagsUseCase, + object_class_dao: ObjectClassDAO, + schema_create_use_case: SchemaLikeAsDirectoryCreateUseCase, + ) -> None: + """Init AttributeTypeUseCase.""" + self.__attribute_type_dao = attribute_type_dao + self.__attribute_type_system_flags_use_case = attribute_type_system_flags_use_case # noqa: E501 # fmt: skip + self.__object_class_dao = object_class_dao + self.__schema_create_use_case = schema_create_use_case + + async def get(self, name: str) -> AttributeTypeDTO[int]: + """Get Attribute Type by name.""" + dto = await self.__attribute_type_dao.get(name) + dto.object_class_names = await self.__object_class_dao.get_object_class_names_include_attribute_type( # noqa: E501 + dto.name, + ) + return dto + + async def get_all(self) -> list[AttributeTypeDTO[int]]: + """Get all Attribute Types.""" + return await self.__attribute_type_dao.get_all() + + async def create(self, dto: AttributeTypeDTO) -> None: + """Create Attribute Type.""" + _dto = CreateDirDTO( + name=dto.name, + entity_type_name=EntityTypeNames.ATTRIBUTE_TYPE, + attributes=( + AttributeDTO(name=Names.OID, values=[str(dto.oid)]), + AttributeDTO(name=Names.NAME, values=[str(dto.name)]), + AttributeDTO(name=Names.SYNTAX, values=[str(dto.syntax)]), + AttributeDTO( + name=Names.SINGLE_VALUE, + values=[str(dto.single_value)], + ), + AttributeDTO( + name=Names.NO_USER_MODIFICATION, + values=[str(dto.no_user_modification)], + ), + AttributeDTO( + name=Names.SYSTEM_FLAGS, + values=[str(dto.system_flags)], + ), + AttributeDTO( + name=Names.IS_INCLUDED_ANR, + values=[str(dto.is_included_anr)], + ), + ), + is_system=dto.is_system, + ) + try: + await self.__schema_create_use_case.create_dir(dto=_dto) + except IntegrityError: + raise AttributeTypeAlreadyExistsError( + f"Attribute Type with oid '{dto.oid}' and name" + + f" '{dto.name}' already exists.", + ) + + async def update(self, name: str, dto: AttributeTypeDTO) -> None: + """Update Attribute Type.""" + await self.__attribute_type_dao.update(name, dto) + + async def get_paginator( + self, + params: PaginationParams, + ) -> PaginationResult: + """Retrieve paginated Attribute Types.""" + return await self.__attribute_type_dao.get_paginator(params) + + async def delete_all_by_names(self, names: list[str]) -> None: + """Delete not system Attribute Types by names.""" + return await self.__attribute_type_dao.delete_all_by_names(names) + + async def is_attr_replicated(self, name: str) -> bool: + """Check if attribute is replicated based on systemFlags.""" + dto = await self.__attribute_type_dao.get(name) + return self.__attribute_type_system_flags_use_case.is_attr_replicated(dto) # noqa: E501 # fmt: skip + + async def set_attr_replication_flag( + self, + name: str, + need_to_replicate: bool, + ) -> None: + """Set replication flag in systemFlags.""" + dto = await self.get(name) + dto = self.__attribute_type_system_flags_use_case.set_attr_replication( + dto, + need_to_replicate, + ) + await self.__attribute_type_dao.update_sys_flags( + dto.name, + dto, + ) + + PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = { + get.__name__: AuthorizationRules.ATTRIBUTE_TYPE_GET, + create.__name__: AuthorizationRules.ATTRIBUTE_TYPE_CREATE, + get_paginator.__name__: AuthorizationRules.ATTRIBUTE_TYPE_GET_PAGINATOR, # noqa: E501 + update.__name__: AuthorizationRules.ATTRIBUTE_TYPE_UPDATE, + delete_all_by_names.__name__: AuthorizationRules.ATTRIBUTE_TYPE_DELETE_ALL_BY_NAMES, # noqa: E501 + } diff --git a/app/ldap_protocol/ldap_schema/attribute_type/constants.py b/app/ldap_protocol/ldap_schema/attribute_type/constants.py new file mode 100644 index 000000000..a432570b3 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/attribute_type/constants.py @@ -0,0 +1,15 @@ +"""Constants for attribute type property names.""" + +from enum import StrEnum + + +class AttributeTypeAttributeNames(StrEnum): + """Attribute Type attribute names.""" + + OID = "attributeID" + NAME = "name" + SYNTAX = "attributeSyntax" + SINGLE_VALUE = "isSingleValued" + NO_USER_MODIFICATION = "systemOnly" + SYSTEM_FLAGS = "systemFlags" + IS_INCLUDED_ANR = "aNR" diff --git a/app/ldap_protocol/ldap_schema/attribute_type_dao.py b/app/ldap_protocol/ldap_schema/attribute_type_dao.py deleted file mode 100644 index 63b795e0a..000000000 --- a/app/ldap_protocol/ldap_schema/attribute_type_dao.py +++ /dev/null @@ -1,190 +0,0 @@ -"""Attribute Type DAO. - -Copyright (c) 2024 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -from adaptix import P -from adaptix.conversion import ( - allow_unlinked_optional, - get_converter, - link_function, -) -from sqlalchemy import delete, select -from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession - -from abstract_dao import AbstractDAO -from entities import AttributeType -from ldap_protocol.ldap_schema.dto import AttributeTypeDTO -from ldap_protocol.ldap_schema.exceptions import ( - AttributeTypeAlreadyExistsError, - AttributeTypeNotFoundError, -) -from ldap_protocol.utils.pagination import ( - PaginationParams, - PaginationResult, - build_paginated_search_query, -) -from repo.pg.tables import queryable_attr as qa - -_convert_model_to_dto = get_converter( - AttributeType, - AttributeTypeDTO, - recipe=[ - allow_unlinked_optional(P[AttributeTypeDTO].object_class_names), - ], -) -_convert_dto_to_model = get_converter( - AttributeTypeDTO, - AttributeType, - recipe=[ - link_function( - lambda _: None, - P[AttributeType].id, - ), - ], -) - - -class AttributeTypeDAO(AbstractDAO[AttributeTypeDTO, str]): - """Attribute Type DAO.""" - - __session: AsyncSession - - def __init__(self, session: AsyncSession) -> None: - """Initialize Attribute Type DAO with session.""" - self.__session = session - - async def get(self, name: str) -> AttributeTypeDTO: - """Get Attribute Type by name.""" - return _convert_model_to_dto(await self._get_one_raw_by_name(name)) - - async def get_all(self) -> list[AttributeTypeDTO]: - """Get all Attribute Types.""" - return [ - _convert_model_to_dto(attribute_type) - for attribute_type in await self.__session.scalars( - select(AttributeType), - ) - ] - - async def create(self, dto: AttributeTypeDTO) -> None: - """Create Attribute Type.""" - try: - attribute_type = _convert_dto_to_model(dto) - self.__session.add(attribute_type) - await self.__session.flush() - - except IntegrityError: - raise AttributeTypeAlreadyExistsError( - f"Attribute Type with oid '{dto.oid}' and name" - + f" '{dto.name}' already exists.", - ) - - async def update(self, name: str, dto: AttributeTypeDTO) -> None: - """Update Attribute Type. - - Docs: - ANR (Ambiguous Name Resolution) inclusion can be modified for - all attributes, including system ones, as it's a search - optimization setting that doesn't affect the LDAP schema - structure or data integrity. - - Other properties (`syntax`, `single_value`, `no_user_modification`) - can only be modified for non-system attributes to preserve - LDAP schema integrity. - """ - obj = await self._get_one_raw_by_name(name) - - obj.is_included_anr = dto.is_included_anr - - if not obj.is_system: - obj.syntax = dto.syntax - obj.single_value = dto.single_value - obj.no_user_modification = dto.no_user_modification - - await self.__session.flush() - - async def update_sys_flags(self, name: str, dto: AttributeTypeDTO) -> None: - """Update system flags of Attribute Type.""" - obj = await self._get_one_raw_by_name(name) - obj.system_flags = dto.system_flags - await self.__session.flush() - - async def delete(self, name: str) -> None: - """Delete Attribute Type.""" - attribute_type = await self._get_one_raw_by_name(name) - await self.__session.delete(attribute_type) - await self.__session.flush() - - async def get_paginator( - self, - params: PaginationParams, - ) -> PaginationResult[AttributeType, AttributeTypeDTO]: - """Retrieve paginated Attribute Types. - - :param PaginationParams params: page_size and page_number. - :return PaginationResult: Chunk of Attribute Types and metadata. - """ - query = build_paginated_search_query( - model=AttributeType, - order_by_field=qa(AttributeType.id), - params=params, - search_field=qa(AttributeType.name), - ) - - return await PaginationResult[AttributeType, AttributeTypeDTO].get( - params=params, - query=query, - converter=_convert_model_to_dto, - session=self.__session, - ) - - async def _get_one_raw_by_name(self, name: str) -> AttributeType: - attribute_type = await self.__session.scalar( - select(AttributeType) - .filter_by(name=name), - ) # fmt: skip - - if not attribute_type: - raise AttributeTypeNotFoundError( - f"Attribute Type with name '{name}' not found.", - ) - return attribute_type - - async def get_all_by_names( - self, - names: list[str] | set[str], - ) -> list[AttributeTypeDTO[int]]: - """Get list of Attribute Types by names. - - :param list[str] names: Attribute Type names. - :return list[AttributeTypeDTO]: List of Attribute Types. - """ - if not names: - return [] - - query = await self.__session.scalars( - select(AttributeType) - .where(qa(AttributeType.name).in_(names)), - ) # fmt: skip - return list(map(_convert_model_to_dto, query.all())) - - async def delete_all_by_names(self, names: list[str]) -> None: - """Delete not system Attribute Types by names. - - :param list[str] names: List of Attribute Types names. - :return None: None. - """ - if not names: - return - - await self.__session.execute( - delete(AttributeType) - .where( - qa(AttributeType.name).in_(names), - qa(AttributeType.is_system).is_(False), - ), - ) # fmt: skip - await self.__session.flush() diff --git a/app/ldap_protocol/ldap_schema/attribute_type_use_case.py b/app/ldap_protocol/ldap_schema/attribute_type_use_case.py deleted file mode 100644 index 95f5425fc..000000000 --- a/app/ldap_protocol/ldap_schema/attribute_type_use_case.py +++ /dev/null @@ -1,103 +0,0 @@ -"""Attribute Type Use Case. - -Copyright (c) 2024 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -from typing import ClassVar - -from abstract_service import AbstractService -from enums import AuthorizationRules -from ldap_protocol.ldap_schema.attribute_type_dao import AttributeTypeDAO -from ldap_protocol.ldap_schema.attribute_type_system_flags_use_case import ( - AttributeTypeSystemFlagsUseCase, -) -from ldap_protocol.ldap_schema.dto import AttributeTypeDTO -from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO -from ldap_protocol.utils.pagination import PaginationParams, PaginationResult - - -class AttributeTypeUseCase(AbstractService): - """AttributeTypeUseCase.""" - - def __init__( - self, - attribute_type_dao: AttributeTypeDAO, - attribute_type_system_flags_use_case: AttributeTypeSystemFlagsUseCase, - object_class_dao: ObjectClassDAO, - ) -> None: - """Init AttributeTypeUseCase.""" - self._attribute_type_dao = attribute_type_dao - self._attribute_type_system_flags_use_case = ( - attribute_type_system_flags_use_case - ) - self._object_class_dao = object_class_dao - - async def get(self, name: str) -> AttributeTypeDTO: - """Get Attribute Type by name.""" - dto = await self._attribute_type_dao.get(name) - dto.object_class_names = await self._object_class_dao.get_object_class_names_include_attribute_type( # noqa: E501 - dto.name, - ) - return dto - - async def get_all(self) -> list[AttributeTypeDTO]: - """Get all Attribute Types.""" - return await self._attribute_type_dao.get_all() - - async def create(self, dto: AttributeTypeDTO) -> None: - """Create Attribute Type.""" - await self._attribute_type_dao.create(dto) - - async def update(self, name: str, dto: AttributeTypeDTO) -> None: - """Update Attribute Type.""" - await self._attribute_type_dao.update(name, dto) - - async def delete(self, name: str) -> None: - """Delete Attribute Type.""" - await self._attribute_type_dao.delete(name) - - async def get_paginator( - self, - params: PaginationParams, - ) -> PaginationResult: - """Retrieve paginated Attribute Types.""" - return await self._attribute_type_dao.get_paginator(params) - - async def get_all_by_names( - self, - names: list[str] | set[str], - ) -> list[AttributeTypeDTO]: - """Get list of Attribute Types by names.""" - return await self._attribute_type_dao.get_all_by_names(names) - - async def delete_all_by_names(self, names: list[str]) -> None: - """Delete not system Attribute Types by names.""" - return await self._attribute_type_dao.delete_all_by_names(names) - - async def is_attr_replicated(self, name: str) -> bool: - """Check if attribute is replicated based on systemFlags.""" - dto = await self.get(name) - return self._attribute_type_system_flags_use_case.is_attr_replicated(dto) # noqa: E501 # fmt: skip - - async def set_attr_replication_flag( - self, - name: str, - need_to_replicate: bool, - ) -> None: - """Set replication flag in systemFlags.""" - dto = await self.get(name) - dto = self._attribute_type_system_flags_use_case.set_attr_replication_flag( # noqa: E501 - dto, - need_to_replicate, - ) - await self._attribute_type_dao.update_sys_flags(dto.name, dto) - - PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = { - get.__name__: AuthorizationRules.ATTRIBUTE_TYPE_GET, - create.__name__: AuthorizationRules.ATTRIBUTE_TYPE_CREATE, - get_paginator.__name__: AuthorizationRules.ATTRIBUTE_TYPE_GET_PAGINATOR, # noqa: E501 - update.__name__: AuthorizationRules.ATTRIBUTE_TYPE_UPDATE, - delete_all_by_names.__name__: AuthorizationRules.ATTRIBUTE_TYPE_DELETE_ALL_BY_NAMES, # noqa: E501 - set_attr_replication_flag.__name__: AuthorizationRules.ATTRIBUTE_TYPE_SET_ATTR_REPLICATION_FLAG, # noqa: E501 - } diff --git a/app/ldap_protocol/ldap_schema/directory_dao.py b/app/ldap_protocol/ldap_schema/directory_dao.py new file mode 100644 index 000000000..456c1c0b7 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/directory_dao.py @@ -0,0 +1,105 @@ +"""Directory DAO. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from constants import CONFIGURATION_DIR_NAME +from entities import Directory, EntityType +from ldap_protocol.utils.queries import get_base_directories +from repo.pg.tables import queryable_attr as qa + + +class DirectoryDAO: + """Directory DAO.""" + + __session: AsyncSession + + def __init__(self, session: AsyncSession) -> None: + """Initialize Directory DAO with session.""" + self.__session = session + + async def create_directory( + self, + name: str, + is_system: bool, + parent_dir: Directory, + parent_dir_id: int, + ) -> Directory: + """Create a Directory and return it with id populated.""" + directory = Directory( + is_system=is_system, + object_class="", + name=name, + ) + directory.groups = [] + directory.create_path(parent_dir, directory.get_dn_prefix()) + self.__session.add(directory) + await self.__session.flush() + + directory.parent_id = parent_dir_id + await self.__session.refresh(directory, ["id"]) + return directory + + async def get_all_without_entity_type(self) -> list[Directory]: + """Get all Directories without Entity Type.""" + result = await self.__session.scalars( + select(Directory) + .where(qa(Directory.entity_type_id).is_(None)) + .options( + selectinload(qa(Directory.attributes)), + selectinload(qa(Directory.entity_type)), + ), + ) + return list(result.all()) + + async def get_base_directory_paths_with_sid(self) -> list[tuple[str, str]]: + """Get all base directory paths.""" + base_dirs = await get_base_directories(self.__session) + return [ + (base_dir.path_dn, base_dir.object_sid) for base_dir in base_dirs + ] + + def get_object_sid(self, base_dn_sid: str, rid: int) -> str: + return f"{base_dn_sid}-{rid}" + + def is_dn_in_base_directory(self, path_dn: str, entry: str) -> bool: + """Check if an entry in a base dn.""" + return entry.lower().endswith(path_dn.lower()) + + async def get_configuration_dir(self) -> Directory: + """Get configuration directory.""" + result = await self.__session.execute( + select(Directory) + .where(qa(Directory.name) == CONFIGURATION_DIR_NAME), + ) # fmt: skip + return result.scalar_one() + + async def get_all_dir_ids_by_entity_type_name( + self, + name: str, + ) -> list[int]: + """Get all Directory IDs by Entity Type name.""" + result = await self.__session.scalars( + select(qa(Directory.id)) + .join(qa(Directory.entity_type)) + .where(qa(EntityType.name) == name), + ) + return list(result.all()) + + async def bind_entity_type( + self, + directory: Directory, + entity_type_id: int | None, + ) -> None: + """Ensure the Directory.entity_type relationship is loaded.""" + directory.entity_type_id = entity_type_id + await self.__session.flush() + await self.__session.refresh( + directory, + attribute_names=["entity_type"], + ) diff --git a/app/ldap_protocol/ldap_schema/dto.py b/app/ldap_protocol/ldap_schema/dto.py index 7699b6966..ae4eef8f3 100644 --- a/app/ldap_protocol/ldap_schema/dto.py +++ b/app/ldap_protocol/ldap_schema/dto.py @@ -54,3 +54,17 @@ class EntityTypeDTO(Generic[_IdT]): is_system: bool object_class_names: list[str] id: _IdT = None # type: ignore + + +@dataclass +class AttributeDTO: + name: str + values: list[str] + + +@dataclass +class CreateDirDTO: + name: str + entity_type_name: EntityTypeNames + attributes: tuple[AttributeDTO, ...] + is_system: bool diff --git a/app/ldap_protocol/ldap_schema/entity_type_dao.py b/app/ldap_protocol/ldap_schema/entity_type/entity_type_dao.py similarity index 54% rename from app/ldap_protocol/ldap_schema/entity_type_dao.py rename to app/ldap_protocol/ldap_schema/entity_type/entity_type_dao.py index 1a708d711..06683349a 100644 --- a/app/ldap_protocol/ldap_schema/entity_type_dao.py +++ b/app/ldap_protocol/ldap_schema/entity_type/entity_type_dao.py @@ -4,7 +4,6 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -import contextlib from typing import Iterable from adaptix import P @@ -12,21 +11,19 @@ from sqlalchemy import delete, func, or_, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload -from abstract_dao import AbstractDAO -from entities import Attribute, Directory, EntityType, ObjectClass +from entities import Attribute, Directory, EntityType from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, AttributeValueValidatorError, ) +from ldap_protocol.ldap_schema.directory_dao import DirectoryDAO from ldap_protocol.ldap_schema.dto import EntityTypeDTO from ldap_protocol.ldap_schema.exceptions import ( EntityTypeAlreadyExistsError, EntityTypeCantModifyError, EntityTypeNotFoundError, ) -from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO from ldap_protocol.utils.pagination import ( PaginationParams, PaginationResult, @@ -43,32 +40,31 @@ ) -class EntityTypeDAO(AbstractDAO[EntityTypeDTO, str]): +class EntityTypeDAO: """Entity Type DAO.""" __session: AsyncSession - __object_class_dao: ObjectClassDAO __attribute_value_validator: AttributeValueValidator + __directory_dao: DirectoryDAO def __init__( self, session: AsyncSession, - object_class_dao: ObjectClassDAO, attribute_value_validator: AttributeValueValidator, + directory_dao: DirectoryDAO, ) -> None: """Initialize Entity Type DAO with a database session.""" self.__session = session - self.__object_class_dao = object_class_dao self.__attribute_value_validator = attribute_value_validator + self.__directory_dao = directory_dao + + def generate_entity_type_name(self, directory: Directory) -> str: + return f"{directory.name}_entity_type_{directory.id}" async def get_all(self) -> list[EntityTypeDTO[int]]: """Get all Entity Types.""" - return [ - _convert(entity_type) - for entity_type in await self.__session.scalars( - select(EntityType), - ) - ] + res = await self.__session.scalars(select(EntityType)) + return list(map(_convert, res)) async def create(self, dto: EntityTypeDTO[None]) -> None: """Create a new Entity Type.""" @@ -90,26 +86,8 @@ async def update(self, name: str, dto: EntityTypeDTO[int]) -> None: entity_type = await self._get_one_raw_by_name(name) try: - await self.__object_class_dao.is_all_object_classes_exists( - dto.object_class_names, - ) - entity_type.name = dto.name - # Sort object_class_names to ensure a - # consistent order for database operations - # and to facilitate duplicate detection. - - entity_type.object_class_names = sorted( - dto.object_class_names, - ) - result = await self.__session.execute( - select(Directory) - .join(qa(Directory.entity_type)) - .filter(qa(EntityType.name) == entity_type.name) - .options(selectinload(qa(Directory.attributes))), - ) # fmt: skip - await self.__session.execute( delete(Attribute) .where( @@ -125,7 +103,19 @@ async def update(self, name: str, dto: EntityTypeDTO[int]) -> None: ), ) # fmt: skip - for directory in result.scalars(): + # Sort object_class_names to ensure a + # consistent order for database operations + # and to facilitate duplicate detection. + + entity_type.object_class_names = sorted( + dto.object_class_names, + ) + directory_ids = ( + await self.__directory_dao.get_all_dir_ids_by_entity_type_name( + entity_type.name, + ) + ) + for directory_id in directory_ids: for object_class_name in entity_type.object_class_names: if not self.__attribute_value_validator.is_value_valid( entity_type.name, @@ -138,7 +128,7 @@ async def update(self, name: str, dto: EntityTypeDTO[int]) -> None: self.__session.add( Attribute( - directory_id=directory.id, + directory_id=directory_id, name="objectClass", value=object_class_name, ), @@ -163,11 +153,7 @@ async def get_paginator( self, params: PaginationParams, ) -> PaginationResult[EntityType, EntityTypeDTO]: - """Retrieve paginated Entity Types. - - :param PaginationParams params: page_size and page_number. - :return PaginationResult: Chunk of Entity Types and metadata. - """ + """Retrieve paginated Entity Types.""" query = build_paginated_search_query( model=EntityType, order_by_field=qa(EntityType.name), @@ -183,12 +169,7 @@ async def get_paginator( ) async def _get_one_raw_by_name(self, name: str) -> EntityType: - """Get single Entity Type by name. - - :param str name: Entity Type name. - :raise EntityTypeNotFoundError: If Entity Type not found. - :return EntityType: Instance of Entity Type. - """ + """Get single Entity Type by name.""" entity_type = await self.__session.scalar( select(EntityType) .filter_by(name=name), @@ -201,37 +182,25 @@ async def _get_one_raw_by_name(self, name: str) -> EntityType: return entity_type async def get(self, name: str) -> EntityTypeDTO: - """Get single Entity Type by name. - - :param str name: Entity Type name. - :raise EntityTypeNotFoundError: If Entity Type not found. - :return EntityType: Instance of Entity Type. - """ + """Get single Entity Type by name.""" return _convert(await self._get_one_raw_by_name(name)) async def get_entity_type_by_object_class_names( self, object_class_names: Iterable[str], - ) -> EntityType | None: - """Get single Entity Type by object class names. - - :param Iterable[str] object_class_names: object class names. - :return EntityType | None: Instance of Entity Type or None. - """ + ) -> EntityTypeDTO | None: + """Get single Entity Type by object class names.""" list_object_class_names = [name.lower() for name in object_class_names] result = await self.__session.execute( select(EntityType) .where( - func.array_lowercase(EntityType.object_class_names).op("@>")( - list_object_class_names, - ), - func.array_lowercase(EntityType.object_class_names).op("<@")( - list_object_class_names, - ), + func.array_lowercase(EntityType.object_class_names).op("@>")(list_object_class_names), + func.array_lowercase(EntityType.object_class_names).op("<@")(list_object_class_names), ), ) # fmt: skip - return result.scalars().first() + entity_type = result.scalars().first() + return _convert(entity_type) if entity_type else None async def get_entity_type_names_include_oc_name( self, @@ -244,48 +213,11 @@ async def get_entity_type_names_include_oc_name( ) # fmt: skip return set(row[0] for row in result.fetchall()) - async def get_entity_type_attributes(self, name: str) -> list[str]: - """Get all attribute names for an Entity Type. - - :param str entity_type_name: Entity Type name. - :return list[str]: List of attribute names. - """ - entity_type = await self._get_one_raw_by_name(name) - - if not entity_type.object_class_names: - return [] - - object_classes_query = await self.__session.scalars( - select(ObjectClass) - .where( - qa(ObjectClass.name).in_( - entity_type.object_class_names, - ), - ) - .options( - selectinload(qa(ObjectClass.attribute_types_must)), - selectinload(qa(ObjectClass.attribute_types_may)), - ), - ) - object_classes = list(object_classes_query.all()) - - attribute_names = set() - for object_class in object_classes: - for attr in object_class.attribute_types_must: - attribute_names.add(attr.name) - for attr in object_class.attribute_types_may: - attribute_names.add(attr.name) - - return sorted(list(attribute_names)) - async def delete_all_by_names(self, names: list[str]) -> None: - """Delete not system and not used Entity Type by their names. - - :param list[str] names: Entity Type names. - :return None. - """ + """Delete not system and not used Entity Type by their names.""" await self.__session.execute( - delete(EntityType).where( + delete(EntityType) + .where( qa(EntityType.name).in_(names), qa(EntityType.is_system).is_(False), qa(EntityType.id).not_in( @@ -295,74 +227,3 @@ async def delete_all_by_names(self, names: list[str]) -> None: ), ) # fmt: skip await self.__session.flush() - - async def attach_entity_type_to_directories(self) -> None: - """Find all Directories without an Entity Type and attach it to them. - - :return None. - """ - result = await self.__session.execute( - select(Directory) - .where(qa(Directory.entity_type_id).is_(None)) - .options( - selectinload(qa(Directory.attributes)), - selectinload(qa(Directory.entity_type)), - ), - ) - - for directory in result.scalars(): - await self.attach_entity_type_to_directory( - directory=directory, - is_system_entity_type=False, - ) - - await self.__session.flush() - - async def attach_entity_type_to_directory( - self, - directory: Directory, - is_system_entity_type: bool, - entity_type: EntityType | None = None, - object_class_names: set[str] | None = None, - ) -> None: - """Try to find the Entity Type, attach it to the Directory. - - :param Directory directory: Directory to attach Entity Type. - :param bool is_system_entity_type: Is system Entity Type. - :param EntityType | None entity_type: Predefined Entity Type. - :param set[str] | None object_class_names: Predefined object - class names. - :return None. - """ - if entity_type: - directory.entity_type = entity_type - return - - if object_class_names is None: - object_class_names = directory.object_class_names_set - - await self.__object_class_dao.is_all_object_classes_exists( - object_class_names, - ) - - entity_type = await self.get_entity_type_by_object_class_names( - object_class_names, - ) - if not entity_type: - entity_type_name = EntityType.generate_entity_type_name( - directory=directory, - ) - with contextlib.suppress(EntityTypeAlreadyExistsError): - await self.create( - EntityTypeDTO[None]( - name=entity_type_name, - object_class_names=list(object_class_names), - is_system=is_system_entity_type, - ), - ) - - entity_type = await self.get_entity_type_by_object_class_names( - object_class_names, - ) - - directory.entity_type = entity_type diff --git a/app/ldap_protocol/ldap_schema/entity_type/entity_type_use_case.py b/app/ldap_protocol/ldap_schema/entity_type/entity_type_use_case.py new file mode 100644 index 000000000..620560688 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/entity_type/entity_type_use_case.py @@ -0,0 +1,195 @@ +"""Entity Use Case. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import contextlib +from typing import TYPE_CHECKING, ClassVar, Iterable + +from abstract_service import AbstractService +from enums import AuthorizationRules, EntityTypeNames +from ldap_protocol.ldap_schema.directory_dao import DirectoryDAO +from ldap_protocol.ldap_schema.dto import EntityTypeDTO +from ldap_protocol.ldap_schema.entity_type.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.exceptions import ( + EntityTypeAlreadyExistsError, + EntityTypeCantModifyError, + EntityTypeNotFoundError, +) +from ldap_protocol.ldap_schema.object_class.object_class_dao import ( + ObjectClassDAO, +) +from ldap_protocol.utils.pagination import PaginationParams, PaginationResult + +if TYPE_CHECKING: + from entities import Directory + + +class EntityTypeUseCase(AbstractService): + """Entity Use Case.""" + + __entity_type_dao: EntityTypeDAO + __object_class_dao: ObjectClassDAO + __directory_dao: DirectoryDAO + + def __init__( + self, + entity_type_dao: EntityTypeDAO, + object_class_dao: ObjectClassDAO, + directory_dao: DirectoryDAO, + ) -> None: + """Initialize Entity Use Case.""" + self.__entity_type_dao = entity_type_dao + self.__object_class_dao = object_class_dao + self.__directory_dao = directory_dao + + async def create(self, dto: EntityTypeDTO) -> None: + """Create Entity Type.""" + await self.__object_class_dao.is_all_object_classes_exists( + dto.object_class_names, + ) + + await self.__entity_type_dao.create(dto) + + async def create_not_safe(self, dto: EntityTypeDTO) -> None: + """Create Entity Type.""" + await self.__entity_type_dao.create(dto) + + async def update(self, name: str, dto: EntityTypeDTO) -> None: + """Update Entity Type.""" + try: + entity_type = await self.get(name) + + except EntityTypeNotFoundError: + raise EntityTypeCantModifyError + if entity_type.is_system: + raise EntityTypeCantModifyError( + f"Entity Type '{dto.name}' is system and cannot be modified.", + ) + if name != dto.name: + await self._validate_name(name=dto.name) + + await self.__object_class_dao.is_all_object_classes_exists( + dto.object_class_names, + ) + + await self.__entity_type_dao.update(entity_type.name, dto) + + async def get(self, name: str) -> EntityTypeDTO: + """Get Entity Type by name.""" + return await self.__entity_type_dao.get(name) + + async def _validate_name( + self, + name: str, + ) -> None: + if name in EntityTypeNames: + raise EntityTypeCantModifyError( + f"Can't change entity type name {name}", + ) + + async def get_paginator( + self, + params: PaginationParams, + ) -> PaginationResult: + """Get paginated Entity Types.""" + return await self.__entity_type_dao.get_paginator(params) + + async def get_entity_type_attributes(self, name: str) -> list[str]: + """Get entity type attributes.""" + entity_type = await self.__entity_type_dao.get(name) + + if not entity_type.object_class_names: + return [] + + object_class_dirs = await self.__object_class_dao.get_all_by_names( + entity_type.object_class_names, + ) + + attribute_names: set[str] = set() + for object_class_dir in object_class_dirs: + attribute_names.update(object_class_dir.attribute_types_may) + attribute_names.update(object_class_dir.attribute_types_must) + + return sorted(attribute_names) + + async def get_entity_type_by_object_class_names( + self, + object_class_names: Iterable[str], + ) -> EntityTypeDTO | None: + """Get Entity Type by object class names.""" + return ( + await self.__entity_type_dao.get_entity_type_by_object_class_names( + object_class_names, + ) + ) + + async def delete_all_by_names(self, names: list[str]) -> None: + """Delete all Entity Types by names.""" + await self.__entity_type_dao.delete_all_by_names(names) + + async def attach_entity_type_to_directories(self) -> None: + """Find all Directories without an Entity Type and attach it to them.""" # noqa: E501 + directories = await self.__directory_dao.get_all_without_entity_type() + + for directory in directories: + await self.attach_entity_type_to_directory( + directory=directory, + is_system_entity_type=False, + ) + + async def attach_entity_type_to_directory( + self, + directory: "Directory", + is_system_entity_type: bool, + entity_type_id: int | None = None, + object_class_names: set[str] | None = None, + ) -> None: + """Try to find the Entity Type, attach it to the Directory.""" + if entity_type_id: + await self.__directory_dao.bind_entity_type( + directory, + entity_type_id, + ) + return + + if object_class_names is None: + object_class_names = directory.object_class_names_set + + await self.__object_class_dao.is_all_object_classes_exists( + object_class_names, + ) + + entity_type = ( + await self.__entity_type_dao.get_entity_type_by_object_class_names( + object_class_names, + ) + ) + if not entity_type: + entity_type_name = ( + self.__entity_type_dao.generate_entity_type_name(directory) + ) + with contextlib.suppress(EntityTypeAlreadyExistsError): + await self.create( + EntityTypeDTO[None]( + name=entity_type_name, + object_class_names=list(object_class_names), + is_system=is_system_entity_type, + ), + ) + + entity_type = await self.__entity_type_dao.get_entity_type_by_object_class_names( # noqa: E501 + object_class_names, + ) + + await self.__directory_dao.bind_entity_type(directory, entity_type.id) # type: ignore + + PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = { + get.__name__: AuthorizationRules.ENTITY_TYPE_GET, + create.__name__: AuthorizationRules.ENTITY_TYPE_CREATE, + get_paginator.__name__: AuthorizationRules.ENTITY_TYPE_GET_PAGINATOR, + update.__name__: AuthorizationRules.ENTITY_TYPE_UPDATE, + delete_all_by_names.__name__: AuthorizationRules.ENTITY_TYPE_DELETE_ALL_BY_NAMES, # noqa: E501 + get_entity_type_attributes.__name__: AuthorizationRules.ENTITY_TYPE_GET_ATTRIBUTES, # noqa: E501 + } diff --git a/app/ldap_protocol/ldap_schema/entity_type_use_case.py b/app/ldap_protocol/ldap_schema/entity_type_use_case.py deleted file mode 100644 index e7589c3f4..000000000 --- a/app/ldap_protocol/ldap_schema/entity_type_use_case.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Entity Use Case. - -Copyright (c) 2025 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -from typing import ClassVar - -from abstract_service import AbstractService -from constants import ENTITY_TYPE_DATAS -from enums import AuthorizationRules, EntityTypeNames -from ldap_protocol.ldap_schema.dto import EntityTypeDTO -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO -from ldap_protocol.ldap_schema.exceptions import ( - EntityTypeCantModifyError, - EntityTypeNotFoundError, -) -from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO -from ldap_protocol.utils.pagination import PaginationParams, PaginationResult - - -class EntityTypeUseCase(AbstractService): - """Entity Use Case.""" - - def __init__( - self, - entity_type_dao: EntityTypeDAO, - object_class_dao: ObjectClassDAO, - ) -> None: - """Initialize Entity Use Case. - - :param EntityTypeDAO entity_type_dao: Entity Type DAO. - :param ObjectClassDAO object_class_dao: Object Class DAO. - """ - self._entity_type_dao = entity_type_dao - self._object_class_dao = object_class_dao - - async def create(self, dto: EntityTypeDTO) -> None: - """Create Entity Type.""" - await self._object_class_dao.is_all_object_classes_exists( - dto.object_class_names, - ) - await self._entity_type_dao.create(dto) - - async def update(self, name: str, dto: EntityTypeDTO) -> None: - """Update Entity Type.""" - try: - entity_type = await self.get(name) - - except EntityTypeNotFoundError: - raise EntityTypeCantModifyError - if entity_type.is_system: - raise EntityTypeCantModifyError( - f"Entity Type '{dto.name}' is system and cannot be modified.", - ) - if name != dto.name: - await self._validate_name(name=dto.name) - await self._entity_type_dao.update(entity_type.name, dto) - - async def get(self, name: str) -> EntityTypeDTO: - """Get Entity Type by name.""" - return await self._entity_type_dao.get(name) - - async def _validate_name( - self, - name: str, - ) -> None: - if name in EntityTypeNames: - raise EntityTypeCantModifyError( - f"Can't change entity type name {name}", - ) - - async def get_paginator( - self, - params: PaginationParams, - ) -> PaginationResult: - """Get paginated Entity Types.""" - return await self._entity_type_dao.get_paginator(params) - - async def get_entity_type_attributes(self, name: str) -> list[str]: - """Get entity type attributes.""" - return await self._entity_type_dao.get_entity_type_attributes(name) - - async def delete_all_by_names(self, names: list[str]) -> None: - """Delete all Entity Types by names.""" - await self._entity_type_dao.delete_all_by_names(names) - - async def create_for_first_setup(self) -> None: - """Create Entity Types for first setup. - - :return: None. - """ - for entity_type_data in ENTITY_TYPE_DATAS: - await self.create( - EntityTypeDTO( - name=entity_type_data["name"], - object_class_names=list( - entity_type_data["object_class_names"], - ), - is_system=True, - ), - ) - - PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = { - get.__name__: AuthorizationRules.ENTITY_TYPE_GET, - create.__name__: AuthorizationRules.ENTITY_TYPE_CREATE, - get_paginator.__name__: AuthorizationRules.ENTITY_TYPE_GET_PAGINATOR, - update.__name__: AuthorizationRules.ENTITY_TYPE_UPDATE, - delete_all_by_names.__name__: AuthorizationRules.ENTITY_TYPE_DELETE_ALL_BY_NAMES, # noqa: E501 - get_entity_type_attributes.__name__: AuthorizationRules.ENTITY_TYPE_GET_ATTRIBUTES, # noqa: E501 - } diff --git a/app/ldap_protocol/ldap_schema/exceptions.py b/app/ldap_protocol/ldap_schema/exceptions.py index 02a9d43f3..1f5273519 100644 --- a/app/ldap_protocol/ldap_schema/exceptions.py +++ b/app/ldap_protocol/ldap_schema/exceptions.py @@ -22,6 +22,7 @@ class ErrorCodes(IntEnum): ENTITY_TYPE_NOT_FOUND_ERROR = 7 ENTITY_TYPE_CANT_MODIFY_ERROR = 8 ENTITY_TYPE_ALREADY_EXISTS_ERROR = 9 + CANT_CREATE_DIRECTORY_WITH_SCHEMA_LIKE_AS_DIRECTORY = 10 class LdapSchemaError(BaseDomainException): @@ -30,6 +31,12 @@ class LdapSchemaError(BaseDomainException): code: ErrorCodes = ErrorCodes.BASE_ERROR +class CantCreateDirectoryWithSchemaLikeAsDirectoryError(LdapSchemaError): + """Raised when trying to create directory with schema like as directory.""" + + code = ErrorCodes.CANT_CREATE_DIRECTORY_WITH_SCHEMA_LIKE_AS_DIRECTORY + + class AttributeTypeNotFoundError(LdapSchemaError): """Raised when an attribute type is not found.""" diff --git a/app/ldap_protocol/ldap_schema/object_class/constants.py b/app/ldap_protocol/ldap_schema/object_class/constants.py new file mode 100644 index 000000000..d77d1ec13 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/object_class/constants.py @@ -0,0 +1,15 @@ +"""Constants for object class property names.""" + +from enum import StrEnum + + +class ObjectClassAttributeNames(StrEnum): + """Attribute Type attribute names.""" + + OID = "governsID" + NAME = "name" + OBJECT_CLASS = "objectClass" + SUPERIOR_NAME = "subClassOf" + KIND = "objectClassCategory" + ATTRIBUTE_TYPES_MUST = "mustContain" + ATTRIBUTE_TYPES_MAY = "mayContain" diff --git a/app/ldap_protocol/ldap_schema/object_class/object_class_dao.py b/app/ldap_protocol/ldap_schema/object_class/object_class_dao.py new file mode 100644 index 000000000..bd8c55a9f --- /dev/null +++ b/app/ldap_protocol/ldap_schema/object_class/object_class_dao.py @@ -0,0 +1,222 @@ +"""Object Class DAO. + +Copyright (c) 2024 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import Iterable, Literal + +from sqlalchemy import delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from entities import Attribute, Directory, EntityType +from enums import EntityTypeNames +from ldap_protocol.ldap_schema.object_class.constants import ( + ObjectClassAttributeNames as Names, +) +from ldap_protocol.utils.pagination import PaginationParams, PaginationResult +from repo.pg.tables import queryable_attr as qa + +from ..dto import ObjectClassDTO +from ..exceptions import ObjectClassCantModifyError, ObjectClassNotFoundError + + +def _convert_model_to_dto(dir_: Directory) -> ObjectClassDTO[int, str]: + return ObjectClassDTO( + oid=dir_.attributes_dict.get(Names.OID)[0], # type: ignore + name=dir_.name, + superior_name=dir_.attributes_dict.get(Names.SUPERIOR_NAME)[0], # type: ignore + kind=dir_.attributes_dict.get(Names.KIND)[0], # type: ignore + is_system=dir_.is_system, + attribute_types_must=dir_.attributes_dict.get(Names.ATTRIBUTE_TYPES_MUST, []), # noqa: E501 + attribute_types_may=dir_.attributes_dict.get(Names.ATTRIBUTE_TYPES_MAY, []), # noqa: E501 + id=dir_.id, + entity_type_names=set(), + ) # fmt: skip + + +class ObjectClassDAO: + """Object Class DAO.""" + + __session: AsyncSession + + def __init__( + self, + session: AsyncSession, + ) -> None: + """Initialize Object Class DAO with session.""" + self.__session = session + + async def get_all(self) -> list[ObjectClassDTO[int, str]]: + """Get all Object Classes.""" + result = await self.__session.scalars( + select(Directory) + .join(qa(Directory.entity_type)) + .where(qa(EntityType.name) == EntityTypeNames.OBJECT_CLASS) + .options(selectinload(qa(Directory.attributes))), + ) + return list(map(_convert_model_to_dto, result)) + + async def get_object_class_names_include_attribute_type( + self, + attribute_type_name: str, + ) -> set[str]: + """Get all Object Class names include Attribute Type name.""" + result = await self.__session.scalars( + select(qa(Directory.name)) + .select_from(qa(Directory)) + .join(qa(Directory.entity_type)) + .join(qa(Directory.attributes)) + .where( + qa(EntityType.name) == EntityTypeNames.OBJECT_CLASS, + qa(Attribute.name).in_((Names.ATTRIBUTE_TYPES_MUST, Names.ATTRIBUTE_TYPES_MAY)), # noqa: E501 + func.lower(qa(Attribute.value)) == attribute_type_name.lower(), + ), + ) # fmt: skip + return set(result.all()) + + async def delete(self, name: str) -> None: + """Delete Object Class.""" + object_class = await self._get_dir(name) + await self.__session.delete(object_class) + await self.__session.flush() + + async def get_paginator( + self, + params: PaginationParams, + ) -> PaginationResult[Directory, ObjectClassDTO]: + """Retrieve paginated Object Classes.""" + filters = [qa(EntityType.name) == EntityTypeNames.OBJECT_CLASS] + + query = ( + select(Directory) + .join(qa(Directory.entity_type)) + .where(*filters) + .options(selectinload(qa(Directory.attributes))) + .order_by(qa(Directory.id)) + ) + + return await PaginationResult[Directory, ObjectClassDTO].get( + params=params, + query=query, + converter=_convert_model_to_dto, + session=self.__session, + ) + + async def is_all_object_classes_exists( + self, + names: Iterable[str], + ) -> Literal[True]: + """Check if all Object Classes exist.""" + names = set(object_class.lower() for object_class in names) + + count_query = ( + select(func.count()) + .select_from(Directory) + .join(qa(Directory.entity_type)) + .where( + qa(EntityType.name) == EntityTypeNames.OBJECT_CLASS, + func.lower(qa(Directory.name)).in_(names), + ) + ) + + result = await self.__session.scalar(count_query) + count_ = int(result or 0) + + if count_ != len(names): + raise ObjectClassNotFoundError( + f"Not all Object Classes with names {names} ( != {count_} ) found.", # noqa: E501 + ) + + return True + + async def get(self, name: str) -> ObjectClassDTO: + dir_ = await self._get_dir(name) + if not dir_: + raise ObjectClassNotFoundError( + f"Object Class with name '{name}' not found.", + ) + + return _convert_model_to_dto(dir_) + + async def _get_dir(self, name: str) -> Directory | None: + res = await self.__session.scalars( + select(Directory) + .join(qa(Directory.entity_type)) + .where( + qa(EntityType.name) == EntityTypeNames.OBJECT_CLASS, + qa(Directory.name) == name, + ) + .options(selectinload(qa(Directory.attributes))), + ) + return res.first() + + async def get_all_by_names( + self, + names: list[str] | set[str], + ) -> list[ObjectClassDTO[int, str]]: + """Get list of Object Classes by names.""" + query = await self.__session.scalars( + select(Directory) + .join(qa(Directory.entity_type)) + .where( + qa(Directory.name).in_(names), + qa(EntityType.name) == EntityTypeNames.OBJECT_CLASS, + ) + .options(selectinload(qa(Directory.attributes))), + ) + return list(map(_convert_model_to_dto, query.all())) + + async def update(self, name: str, dto: ObjectClassDTO[None, str]) -> None: + """Update Object Class.""" + obj = await self.get(name) + if obj.is_system: + raise ObjectClassCantModifyError( + "System Object Class cannot be modified.", + ) + + await self.__session.execute( + delete(Attribute) + .where( + qa(Attribute.directory_id) == obj.id, + qa(Attribute.name).in_((Names.ATTRIBUTE_TYPES_MUST, Names.ATTRIBUTE_TYPES_MAY)), # noqa: E501 + ), + ) # fmt: skip + + for value in dto.attribute_types_may: + self.__session.add( + Attribute( + directory_id=obj.id, + name=Names.ATTRIBUTE_TYPES_MAY, + value=value, + ), + ) + + for value in dto.attribute_types_must: + self.__session.add( + Attribute( + directory_id=obj.id, + name=Names.ATTRIBUTE_TYPES_MUST, + value=value, + ), + ) + + await self.__session.flush() + + async def delete_all_by_names(self, names: list[str]) -> None: + """Delete not system Object Classes by Names.""" + subq = ( + select(func.unnest(qa(EntityType.object_class_names))) + .where(qa(EntityType.object_class_names).isnot(None)) + ) # fmt: skip + + await self.__session.execute( + delete(Directory) + .where( + qa(Directory.entity_type).has(qa(EntityType.name) == EntityTypeNames.OBJECT_CLASS), # noqa: E501 + qa(Directory.name).in_(names), + qa(Directory.is_system).is_(False), + ~qa(Directory.name).in_(subq), + ), + ) # fmt: skip diff --git a/app/ldap_protocol/ldap_schema/object_class/object_class_raw_display.py b/app/ldap_protocol/ldap_schema/object_class/object_class_raw_display.py new file mode 100644 index 000000000..4f2514531 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/object_class/object_class_raw_display.py @@ -0,0 +1,27 @@ +"""ObjectClassRawDisplay.""" + +from ldap_protocol.ldap_schema.dto import ObjectClassDTO + + +class ObjectClassRawDisplay: + @staticmethod + def get_raw_definition(dto: ObjectClassDTO) -> str: + if not dto.oid or not dto.name or not dto.kind: + raise ValueError( + f"{dto}: Fields 'oid', 'name', and 'kind'" + " are required for LDAP definition.", + ) + chunks = ["(", dto.oid, f"NAME '{dto.name}'"] + if dto.superior_name: + chunks.append(f"SUP {dto.superior_name}") + chunks.append(dto.kind) + if dto.attribute_types_must: + chunks.append( + f"MUST ({' $ '.join(dto.attribute_types_must)} )", + ) + if dto.attribute_types_may: + chunks.append( + f"MAY ({' $ '.join(dto.attribute_types_may)} )", + ) + chunks.append(")") + return " ".join(chunks) diff --git a/app/ldap_protocol/ldap_schema/object_class/object_class_use_case.py b/app/ldap_protocol/ldap_schema/object_class/object_class_use_case.py new file mode 100644 index 000000000..6c086df64 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/object_class/object_class_use_case.py @@ -0,0 +1,178 @@ +"""Object Class Use Case. + +Copyright (c) 2024 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import ClassVar + +from sqlalchemy.exc import IntegrityError + +from abstract_service import AbstractService +from enums import AuthorizationRules, EntityTypeNames +from ldap_protocol.ldap_schema.attribute_type.attribute_type_dao import ( + AttributeTypeDAO, +) +from ldap_protocol.ldap_schema.dto import ( + AttributeDTO, + CreateDirDTO, + ObjectClassDTO, +) +from ldap_protocol.ldap_schema.entity_type.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.exceptions import ( + ObjectClassAlreadyExistsError, + ObjectClassNotFoundError, +) +from ldap_protocol.ldap_schema.object_class.constants import ( + ObjectClassAttributeNames as Names, +) +from ldap_protocol.ldap_schema.object_class.object_class_dao import ( + ObjectClassDAO, +) +from ldap_protocol.ldap_schema.schema_create_use_case import ( + SchemaLikeAsDirectoryCreateUseCase, +) +from ldap_protocol.utils.pagination import PaginationParams, PaginationResult + + +class ObjectClassUseCase(AbstractService): + """ObjectClassUseCase.""" + + __attribute_type_dao: AttributeTypeDAO + __object_class_dao: ObjectClassDAO + __entity_type_dao: EntityTypeDAO + __schema_create_use_case: SchemaLikeAsDirectoryCreateUseCase + + def __init__( + self, + attribute_type_dao: AttributeTypeDAO, + object_class_dao: ObjectClassDAO, + entity_type_dao: EntityTypeDAO, + schema_create_use_case: SchemaLikeAsDirectoryCreateUseCase, + ) -> None: + """Init ObjectClassUseCase.""" + self.__attribute_type_dao = attribute_type_dao + self.__object_class_dao = object_class_dao + self.__entity_type_dao = entity_type_dao + self.__schema_create_use_case = schema_create_use_case + + async def get_all(self) -> list[ObjectClassDTO[int, str]]: + """Get all Object Classes.""" + return await self.__object_class_dao.get_all() + + async def delete(self, name: str) -> None: + """Delete Object Class.""" + await self.__object_class_dao.delete(name) + + async def get_paginator( + self, + params: PaginationParams, + ) -> PaginationResult: + """Retrieve paginated Object Classes.""" + return await self.__object_class_dao.get_paginator(params) + + async def create(self, dto: ObjectClassDTO[None, str]) -> None: + """Create a new Object Class.""" + attribute_types_may_filtered = [ + name + for name in dto.attribute_types_may + if name not in dto.attribute_types_must + ] + + if dto.attribute_types_must: + dto.attribute_types_must = ( + await self.__attribute_type_dao.get_all_names_by_names( + dto.attribute_types_must, + ) + ) + + if attribute_types_may_filtered: + dto.attribute_types_may = ( + await self.__attribute_type_dao.get_all_names_by_names( + attribute_types_may_filtered, + ) + ) + + superior = None + if dto.superior_name: + superior = await self.__object_class_dao.get( + dto.superior_name, + ) + + if not superior: + raise ObjectClassNotFoundError( + f"Superior (parent) Object class {dto.superior_name} " + "not found in schema.", + ) + + _dto = CreateDirDTO( + name=dto.name, + entity_type_name=EntityTypeNames.OBJECT_CLASS, + attributes=( + AttributeDTO( + name=Names.OBJECT_CLASS, + values=["top", "classSchema"], + ), + AttributeDTO(name=Names.OID, values=[str(dto.oid)]), + AttributeDTO(name=Names.NAME, values=[str(dto.name)]), + AttributeDTO( + name=Names.SUPERIOR_NAME, + values=[str(dto.superior_name)], + ), + AttributeDTO(name=Names.KIND, values=[str(dto.kind)]), + AttributeDTO( + name=Names.ATTRIBUTE_TYPES_MUST, + values=dto.attribute_types_must, + ), + AttributeDTO( + name=Names.ATTRIBUTE_TYPES_MAY, + values=dto.attribute_types_may, + ), + ), + is_system=dto.is_system, + ) + try: + await self.__schema_create_use_case.create_dir(dto=_dto) + except IntegrityError: + raise ObjectClassAlreadyExistsError( + f"Object Class with oid '{dto.oid}' and name" + + f" '{dto.name}' already exists.", + ) + + async def get(self, name: str) -> ObjectClassDTO: + """Get Object Class by name.""" + dto = await self.__object_class_dao.get(name) + dto.entity_type_names = ( + await self.__entity_type_dao.get_entity_type_names_include_oc_name( + dto.name, + ) + ) + return dto + + async def update(self, name: str, dto: ObjectClassDTO[None, str]) -> None: + """Modify Object Class.""" + dto.attribute_types_must = ( + await self.__attribute_type_dao.get_all_names_by_names( + dto.attribute_types_must, + ) + ) + dto.attribute_types_may = [ + name + for name in await self.__attribute_type_dao.get_all_names_by_names( + dto.attribute_types_may, + ) + if name not in dto.attribute_types_must + ] + await self.__object_class_dao.update(name, dto) + + async def delete_all_by_names(self, names: list[str]) -> None: + """Delete not system Object Classes by Names.""" + await self.__object_class_dao.delete_all_by_names(names) + + PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = { + get.__name__: AuthorizationRules.OBJECT_CLASS_GET, + create.__name__: AuthorizationRules.OBJECT_CLASS_CREATE, + get_paginator.__name__: AuthorizationRules.OBJECT_CLASS_GET_PAGINATOR, + update.__name__: AuthorizationRules.OBJECT_CLASS_UPDATE, + delete_all_by_names.__name__: AuthorizationRules.OBJECT_CLASS_DELETE_ALL_BY_NAMES, # noqa: E501 + } diff --git a/app/ldap_protocol/ldap_schema/object_class_dao.py b/app/ldap_protocol/ldap_schema/object_class_dao.py deleted file mode 100644 index 83bcd7eef..000000000 --- a/app/ldap_protocol/ldap_schema/object_class_dao.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Object Class DAO. - -Copyright (c) 2024 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -from typing import Iterable, Literal - -from adaptix import P -from adaptix.conversion import ( - allow_unlinked_optional, - get_converter, - link_function, -) -from sqlalchemy import delete, func, or_, select -from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload - -from abstract_dao import AbstractDAO -from entities import AttributeType, EntityType, ObjectClass -from ldap_protocol.utils.pagination import ( - PaginationParams, - PaginationResult, - build_paginated_search_query, -) -from repo.pg.tables import queryable_attr as qa - -from .dto import AttributeTypeDTO, ObjectClassDTO -from .exceptions import ( - ObjectClassAlreadyExistsError, - ObjectClassCantModifyError, - ObjectClassNotFoundError, -) - -_converter = get_converter( - ObjectClass, - ObjectClassDTO[int, AttributeTypeDTO], - recipe=[ - allow_unlinked_optional(P[ObjectClassDTO].id), - allow_unlinked_optional(P[ObjectClassDTO].entity_type_names), - allow_unlinked_optional(P[AttributeTypeDTO].object_class_names), - link_function(lambda x: x.kind, P[ObjectClassDTO].kind), - ], -) - - -class ObjectClassDAO(AbstractDAO[ObjectClassDTO, str]): - """Object Class DAO.""" - - def __init__(self, session: AsyncSession) -> None: - """Initialize Object Class DAO with session.""" - self.__session = session - - async def get_all(self) -> list[ObjectClassDTO[int, AttributeTypeDTO]]: - """Get all Object Classes.""" - return [ - _converter(object_class) - for object_class in await self.__session.scalars( - select(ObjectClass), - ) - ] - - async def get_object_class_names_include_attribute_type( - self, - attribute_type_name: str, - ) -> set[str]: - """Get all Object Class names include Attribute Type name.""" - result = await self.__session.execute( - select(qa(ObjectClass.name)) - .where( - or_( - qa(ObjectClass.attribute_types_must).any(name=attribute_type_name), - qa(ObjectClass.attribute_types_may).any(name=attribute_type_name), - ), - ), - ) # fmt: skip - return set(row[0] for row in result.fetchall()) - - async def delete(self, name: str) -> None: - """Delete Object Class.""" - object_class = await self._get_one_raw_by_name(name) - await self.__session.delete(object_class) - await self.__session.flush() - - async def get_paginator( - self, - params: PaginationParams, - ) -> PaginationResult[ObjectClass, ObjectClassDTO]: - """Retrieve paginated Object Classes. - - :param PaginationParams params: page_size and page_number. - :return PaginationResult: Chunk of Object Classes and metadata. - """ - query = build_paginated_search_query( - model=ObjectClass, - order_by_field=qa(ObjectClass.id), - params=params, - search_field=qa(ObjectClass.name), - load_params=( - selectinload(qa(ObjectClass).attribute_types_may), - selectinload(qa(ObjectClass).attribute_types_must), - ), - ) - - return await PaginationResult[ObjectClass, ObjectClassDTO].get( - params=params, - query=query, - converter=_converter, - session=self.__session, - ) - - async def create( - self, - dto: ObjectClassDTO[None, str], - ) -> None: - """Create a new Object Class. - - :param str oid: OID. - :param str name: Name. - :param str | None superior_name: Parent Object Class. - :param KindType kind: Kind. - :param bool is_system: Object Class is system. - :param list[str] attribute_type_names_must: Attribute Types must. - :param list[str] attribute_type_names_may: Attribute Types may. - :raise ObjectClassNotFoundError: If superior Object Class not found. - :return None. - """ - try: - superior = None - if dto.superior_name: - superior = await self.__session.scalar( - select(ObjectClass) - .filter_by(name=dto.superior_name), - ) # fmt: skip - - if dto.superior_name and not superior: - raise ObjectClassNotFoundError( - f"Superior (parent) Object class {dto.superior_name} " - "not found in schema.", - ) - - attribute_types_may_filtered = [ - name - for name in dto.attribute_types_may - if name not in dto.attribute_types_must - ] - - if dto.attribute_types_must: - res = await self.__session.scalars( - select(AttributeType) - .where(qa(AttributeType.name).in_(dto.attribute_types_must)), - ) # fmt: skip - attribute_types_must = list(res.all()) - - else: - attribute_types_must = [] - - if attribute_types_may_filtered: - res = await self.__session.scalars( - select(AttributeType) - .where( - qa(AttributeType.name).in_(attribute_types_may_filtered), - ), - ) # fmt: skip - attribute_types_may = list(res.all()) - else: - attribute_types_may = [] - - object_class = ObjectClass( - oid=dto.oid, - name=dto.name, - superior=superior, - kind=dto.kind, - is_system=dto.is_system, - attribute_types_must=attribute_types_must, - attribute_types_may=attribute_types_may, - ) - self.__session.add(object_class) - await self.__session.flush() - except IntegrityError: - raise ObjectClassAlreadyExistsError( - f"Object Class with oid '{dto.oid}' and name" - + f" '{dto.name}' already exists.", - ) - - async def _count_exists_object_class_by_names( - self, - names: Iterable[str], - ) -> int: - """Count exists Object Class by names. - - :param list[str] names: Object Class names. - :return int. - """ - count_query = ( - select(func.count()) - .select_from(ObjectClass) - .where(func.lower(ObjectClass.name).in_(names)) - ) - result = await self.__session.scalars(count_query) - return result.one() - - async def is_all_object_classes_exists( - self, - names: Iterable[str], - ) -> Literal[True]: - """Check if all Object Classes exist. - - :param list[str] names: Object Class names. - :raise ObjectClassNotFoundError: If Object Class not found. - :return bool. - """ - names = set(object_class.lower() for object_class in names) - - count_ = await self._count_exists_object_class_by_names( - names, - ) - - if count_ != len(names): - raise ObjectClassNotFoundError( - f"Not all Object Classes\ - with names {names} found.", - ) - - return True - - async def _get_one_raw_by_name(self, name: str) -> ObjectClass: - """Get single Object Class by name. - - :param str name: Object Class name. - :raise ObjectClassNotFoundError: If Object Class not found. - :return ObjectClass: Instance of Object Class. - """ - object_class = await self.__session.scalar( - select(ObjectClass) - .filter_by(name=name) - .options(selectinload(qa(ObjectClass.attribute_types_may))) - .options(selectinload(qa(ObjectClass.attribute_types_must))), - ) # fmt: skip - - if not object_class: - raise ObjectClassNotFoundError( - f"Object Class with name '{name}' not found.", - ) - return object_class - - async def get(self, name: str) -> ObjectClassDTO: - """Get single Object Class by name. - - :param str name: Object Class name. - :raise ObjectClassNotFoundError: If Object Class not found. - :return ObjectClass: Instance of Object Class. - """ - return _converter(await self._get_one_raw_by_name(name)) - - async def get_all_by_names( - self, - names: list[str] | set[str], - ) -> list[ObjectClassDTO]: - """Get list of Object Classes by names. - - :param list[str] names: Object Classes names. - :return list[ObjectClassDTO]: List of Object Classes. - """ - query = await self.__session.scalars( - select(ObjectClass) - .where(qa(ObjectClass.name).in_(names)) - .options( - selectinload(qa(ObjectClass.attribute_types_must)), - selectinload(qa(ObjectClass.attribute_types_may)), - ), - ) # fmt: skip - return list(map(_converter, query.all())) - - async def update(self, name: str, dto: ObjectClassDTO[None, str]) -> None: - """Update Object Class.""" - obj = await self._get_one_raw_by_name(name) - if obj.is_system: - raise ObjectClassCantModifyError( - "System Object Class cannot be modified.", - ) - - obj.attribute_types_must.clear() - obj.attribute_types_may.clear() - - if dto.attribute_types_must: - must_query = await self.__session.scalars( - select(AttributeType).where( - qa(AttributeType.name).in_( - dto.attribute_types_must, - ), - ), - ) - obj.attribute_types_must.extend(must_query.all()) - - attribute_types_may_filtered = [ - name - for name in dto.attribute_types_may - if name not in dto.attribute_types_must - ] - - if attribute_types_may_filtered: - may_query = await self.__session.scalars( - select(AttributeType) - .where(qa(AttributeType.name).in_(attribute_types_may_filtered)), - ) # fmt: skip - obj.attribute_types_may.extend(list(may_query.all())) - - await self.__session.flush() - - async def delete_all_by_names(self, names: list[str]) -> None: - """Delete not system Object Classes by Names. - - :param list[str] names: Object Classes names. - :return None. - """ - subq = ( - select(func.unnest(qa(EntityType.object_class_names))) - .where(qa(EntityType.object_class_names).isnot(None)) - ) # fmt: skip - - await self.__session.execute( - delete(ObjectClass) - .where( - qa(ObjectClass.name).in_(names), - qa(ObjectClass.is_system).is_(False), - ~qa(ObjectClass.name).in_(subq), - ), - ) # fmt: skip diff --git a/app/ldap_protocol/ldap_schema/object_class_use_case.py b/app/ldap_protocol/ldap_schema/object_class_use_case.py deleted file mode 100644 index 11c171a58..000000000 --- a/app/ldap_protocol/ldap_schema/object_class_use_case.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Object Class Use Case. - -Copyright (c) 2024 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -from typing import ClassVar - -from abstract_service import AbstractService -from enums import AuthorizationRules -from ldap_protocol.ldap_schema.dto import AttributeTypeDTO, ObjectClassDTO -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO -from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO -from ldap_protocol.utils.pagination import PaginationParams, PaginationResult - - -class ObjectClassUseCase(AbstractService): - """ObjectClassUseCase.""" - - def __init__( - self, - object_class_dao: ObjectClassDAO, - entity_type_dao: EntityTypeDAO, - ) -> None: - """Init ObjectClassUseCase.""" - self._object_class_dao = object_class_dao - self._entity_type_dao = entity_type_dao - - async def get_all(self) -> list[ObjectClassDTO[int, AttributeTypeDTO]]: - """Get all Object Classes.""" - return await self._object_class_dao.get_all() - - async def delete(self, name: str) -> None: - """Delete Object Class.""" - await self._object_class_dao.delete(name) - - async def get_paginator( - self, - params: PaginationParams, - ) -> PaginationResult: - """Retrieve paginated Object Classes.""" - return await self._object_class_dao.get_paginator(params) - - async def create(self, dto: ObjectClassDTO[None, str]) -> None: - """Create a new Object Class.""" - await self._object_class_dao.create(dto) - - async def get(self, name: str) -> ObjectClassDTO: - """Get Object Class by name.""" - dto = await self._object_class_dao.get(name) - dto.entity_type_names = ( - await self._entity_type_dao.get_entity_type_names_include_oc_name( - dto.name, - ) - ) - return dto - - async def get_all_by_names( - self, - names: list[str] | set[str], - ) -> list[ObjectClassDTO]: - """Get list of Object Classes by names.""" - return await self._object_class_dao.get_all_by_names(names) - - async def update(self, name: str, dto: ObjectClassDTO[None, str]) -> None: - """Modify Object Class.""" - await self._object_class_dao.update(name, dto) - - async def delete_all_by_names(self, names: list[str]) -> None: - """Delete not system Object Classes by Names.""" - await self._object_class_dao.delete_all_by_names(names) - - PERMISSIONS: ClassVar[dict[str, AuthorizationRules]] = { - get.__name__: AuthorizationRules.OBJECT_CLASS_GET, - create.__name__: AuthorizationRules.OBJECT_CLASS_CREATE, - get_paginator.__name__: AuthorizationRules.OBJECT_CLASS_GET_PAGINATOR, - update.__name__: AuthorizationRules.OBJECT_CLASS_UPDATE, - delete_all_by_names.__name__: AuthorizationRules.OBJECT_CLASS_DELETE_ALL_BY_NAMES, # noqa: E501 - } diff --git a/app/ldap_protocol/ldap_schema/raw_definition_parser.py b/app/ldap_protocol/ldap_schema/raw_definition_parser.py new file mode 100644 index 000000000..1cbc29a36 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/raw_definition_parser.py @@ -0,0 +1,78 @@ +"""Raw definition parser. + +Copyright (c) 2024 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import Iterable + +from ldap3.protocol.rfc4512 import AttributeTypeInfo, ObjectClassInfo + +from ldap_protocol.ldap_schema.dto import AttributeTypeDTO, ObjectClassDTO + + +class RawDefinitionParser: + """Parser for ObjectClass and AttributeType raw definition.""" + + @staticmethod + def _list_to_string(data: Iterable[str]) -> str | None: + if not data: + return None + + data = list(data) + if len(data) == 1: + return data[0] + + raise ValueError("Data is not a single element list") + + @staticmethod + def _get_attribute_type_info(raw_definition: str) -> AttributeTypeInfo: + tmp = AttributeTypeInfo.from_definition(definitions=[raw_definition]) + return RawDefinitionParser._list_to_string(tmp.values()) + + @staticmethod + def get_object_class_info(raw_definition: str) -> ObjectClassInfo: + tmp = ObjectClassInfo.from_definition(definitions=[raw_definition]) + return RawDefinitionParser._list_to_string(tmp.values()) + + @staticmethod + def collect_attribute_type_dto_from_raw( + raw_definition: str, + ) -> AttributeTypeDTO[None]: + attribute_type_info = RawDefinitionParser._get_attribute_type_info( + raw_definition=raw_definition, + ) + + name = RawDefinitionParser._list_to_string(attribute_type_info.name) + if not name: + raise ValueError("Attribute Type name is required") + + return AttributeTypeDTO( + oid=attribute_type_info.oid, + name=name, + syntax=attribute_type_info.syntax, + single_value=attribute_type_info.single_value, + no_user_modification=attribute_type_info.no_user_modification, + is_system=True, + system_flags=0, + is_included_anr=False, + ) + + @staticmethod + async def collect_object_class_dto_from_info( + object_class_info: ObjectClassInfo, + ) -> ObjectClassDTO: + """Create Object Class by ObjectClassInfo.""" + name = RawDefinitionParser._list_to_string(object_class_info.name) + if not name: + raise ValueError("Attribute Type name is required") + + return ObjectClassDTO( + oid=object_class_info.oid, + name=name, + superior_name=RawDefinitionParser._list_to_string(object_class_info.superior), + kind=object_class_info.kind, + is_system=True, + attribute_types_must=object_class_info.must_contain, + attribute_types_may=object_class_info.may_contain, + ) # fmt: skip diff --git a/app/ldap_protocol/ldap_schema/schema_create_use_case.py b/app/ldap_protocol/ldap_schema/schema_create_use_case.py new file mode 100644 index 000000000..247598cc3 --- /dev/null +++ b/app/ldap_protocol/ldap_schema/schema_create_use_case.py @@ -0,0 +1,118 @@ +"""Identity use cases. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import TYPE_CHECKING + +from sqlalchemy.ext.asyncio import AsyncSession + +from ldap_protocol.ldap_schema.attribute_dao import AttributeDAO +from ldap_protocol.ldap_schema.directory_dao import DirectoryDAO +from ldap_protocol.ldap_schema.dto import AttributeDTO, CreateDirDTO +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) +from ldap_protocol.ldap_schema.exceptions import ( + CantCreateDirectoryWithSchemaLikeAsDirectoryError, +) +from ldap_protocol.roles.role_use_case import RoleUseCase + +if TYPE_CHECKING: + from entities import Directory + + +class SchemaLikeAsDirectoryCreateUseCase: + """Setup use case.""" + + __session: AsyncSession + __entity_type_use_case: EntityTypeUseCase + __role_use_case: RoleUseCase + __directory_dao: DirectoryDAO + __attribute_dao: AttributeDAO + __parent_dir: "Directory | None" + + def __init__( + self, + session: AsyncSession, + entity_type_use_case: EntityTypeUseCase, + role_use_case: RoleUseCase, + directory_dao: DirectoryDAO, + attribute_dao: AttributeDAO, + ) -> None: + """Initialize.""" + self.__session = session + self.__entity_type_use_case = entity_type_use_case + self.__role_use_case = role_use_case + self.__directory_dao = directory_dao + self.__attribute_dao = attribute_dao + self.__parent_dir = None + + async def create_dir(self, dto: CreateDirDTO) -> None: + """Create.""" + if not self.__parent_dir: + self.__parent_dir = ( + await self.__directory_dao.get_configuration_dir() + ) + + base_directory_paths_and_sids = ( + await self.__directory_dao.get_base_directory_paths_with_sid() + ) + + dir_ = await self.__directory_dao.create_directory( + name=dto.name, + is_system=dto.is_system, + parent_dir=self.__parent_dir, + parent_dir_id=self.__parent_dir.id, + ) + + for _path, _sid in base_directory_paths_and_sids: + if self.__directory_dao.is_dn_in_base_directory( + _path, + dir_.path_dn, + ): + base_dn_sid = _sid + break + else: + raise CantCreateDirectoryWithSchemaLikeAsDirectoryError( + "Cannot create a directory with schema like as directory.", + ) + + dir_.object_sid = self.__directory_dao.get_object_sid( + base_dn_sid, + dir_.id, + ) + + attr_dto = AttributeDTO(name=dir_.rdname, values=[dir_.name]) + await self.__attribute_dao.add_directory_name_attribute( + dir_.id, + attr_dto, + ) + + await self.__attribute_dao.add_attributes_from_dto( + directory_id=dir_.id, + attributes=dto.attributes, + ) + + await self.__session.flush() + + await self.__session.refresh( + instance=dir_, + attribute_names=["attributes"], + ) + + entity_type = await self.__entity_type_use_case.get( + dto.entity_type_name, + ) + await self.__directory_dao.bind_entity_type( + dir_, + entity_type.id if entity_type else None, + ) + await self.__session.flush() + + await self.__role_use_case.inherit_parent_aces( + parent_directory=self.__parent_dir, + directory=dir_, + ) + await self.__session.flush() diff --git a/app/ldap_protocol/roles/migrations_ace_dao.py b/app/ldap_protocol/roles/migrations_ace_dao.py new file mode 100644 index 000000000..7d332a272 --- /dev/null +++ b/app/ldap_protocol/roles/migrations_ace_dao.py @@ -0,0 +1,112 @@ +"""Access control entry DAO. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import Sequence + +from entities_legacy import AttributeTypeLegacy +from sqlalchemy import Row, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from entities import AccessControlEntry, Directory, EntityType +from enums import EntityTypeNames +from repo.pg.tables import queryable_attr as qa + + +class AccessControlEntryMigrationsDAO: + """Access control entry DAO.""" + + __session: AsyncSession + + def __init__(self, session: AsyncSession) -> None: + """Initialize Access Control Entry DAO with a database session.""" + self.__session = session + + async def upgrade(self) -> None: + ace_rows = await self._get_all_raw_aces_legacy() + if not ace_rows: + return + + attribute_names = {row.name for row in ace_rows} + directory_rows_q = await self.__session.execute( + select(qa(Directory.name), qa(Directory.id)) + .join( + EntityType, + qa(EntityType.id) == qa(Directory.entity_type_id), + ) + .where(qa(EntityType.name) == EntityTypeNames.ATTRIBUTE_TYPE) + .where(qa(Directory.name).in_(attribute_names)), + ) + directory_by_name = {row.name: row.id for row in directory_rows_q} + + updates = [ + {"ace_id": row.id, "directory_id": directory_by_name[row.name]} + for row in ace_rows + if row.name in directory_by_name + ] + if not updates: + return + + for item in updates: + await self.__session.execute( + update(AccessControlEntry) + .where(qa(AccessControlEntry.id) == item["ace_id"]) + .values(attribute_type_id=item["directory_id"]), + ) + + async def _get_all_raw_aces_legacy(self) -> Sequence[Row[tuple[int, str]]]: + ace_rows_q = await self.__session.execute( + select(qa(AccessControlEntry.id), qa(AttributeTypeLegacy.name)) + .join( + AttributeTypeLegacy, + qa(AccessControlEntry.attribute_type_id) + == qa(AttributeTypeLegacy.id), + ) + .where(qa(AccessControlEntry.attribute_type_id).is_not(None)), + ) + return ace_rows_q.all() + + async def downgrade(self) -> None: + ace_rows = await self._get_all_raw_aces() + if not ace_rows: + return + + attribute_names = {row.name for row in ace_rows} + legacy_rows_q = await self.__session.execute( + select(qa(AttributeTypeLegacy.name), qa(AttributeTypeLegacy.id)) + .where(qa(AttributeTypeLegacy.name).in_(attribute_names)), + ) # fmt: skip + legacy_by_name = {row.name: row.id for row in legacy_rows_q} + + updates = [ + {"ace_id": row.id, "legacy_id": legacy_by_name[row.name]} + for row in ace_rows + if row.name in legacy_by_name + ] + if not updates: + return + + for item in updates: + await self.__session.execute( + update(AccessControlEntry) + .where(qa(AccessControlEntry.id) == item["ace_id"]) + .values(attribute_type_id=item["legacy_id"]), + ) + + async def _get_all_raw_aces(self) -> Sequence[Row[tuple[int, str]]]: + ace_rows_q = await self.__session.execute( + select(qa(AccessControlEntry.id), qa(Directory.name)) + .join( + Directory, + qa(AccessControlEntry.attribute_type_id) == qa(Directory.id), + ) + .join( + EntityType, + qa(EntityType.id) == qa(Directory.entity_type_id), + ) + .where(qa(EntityType.name) == EntityTypeNames.ATTRIBUTE_TYPE) + .where(qa(AccessControlEntry.attribute_type_id).is_not(None)), + ) + return ace_rows_q.all() diff --git a/app/ldap_protocol/utils/raw_definition_parser.py b/app/ldap_protocol/utils/raw_definition_parser.py deleted file mode 100644 index 4fa7361e0..000000000 --- a/app/ldap_protocol/utils/raw_definition_parser.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Raw definition parser. - -Copyright (c) 2024 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -from ldap3.protocol.rfc4512 import AttributeTypeInfo, ObjectClassInfo -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from entities import AttributeType, ObjectClass -from repo.pg.tables import queryable_attr as qa - - -class RawDefinitionParser: - """Parser for ObjectClass and AttributeType raw definition.""" - - @staticmethod - def _list_to_string(data: list[str]) -> str | None: - if not data: - return None - if len(data) == 1: - return data[0] - raise ValueError("Data is not a single element list") - - @staticmethod - def _get_attribute_type_info(raw_definition: str) -> AttributeTypeInfo: - tmp = AttributeTypeInfo.from_definition(definitions=[raw_definition]) - return list(tmp.values())[0] - - @staticmethod - def get_object_class_info(raw_definition: str) -> ObjectClassInfo: - tmp = ObjectClassInfo.from_definition(definitions=[raw_definition]) - return list(tmp.values())[0] - - @staticmethod - async def _get_attribute_types_by_names( - session: AsyncSession, - names: list[str], - ) -> list[AttributeType]: - query = await session.execute( - select(AttributeType) - .where(qa(AttributeType.name).in_(names)), - ) # fmt: skip - return list(query.scalars().all()) - - @staticmethod - def create_attribute_type_by_raw( - raw_definition: str, - ) -> AttributeType: - attribute_type_info = RawDefinitionParser._get_attribute_type_info( - raw_definition=raw_definition, - ) - - return AttributeType( - oid=attribute_type_info.oid, - name=RawDefinitionParser._list_to_string(attribute_type_info.name), # type: ignore[arg-type] - syntax=attribute_type_info.syntax, - single_value=attribute_type_info.single_value, - no_user_modification=attribute_type_info.no_user_modification, - is_system=True, - system_flags=0, - is_included_anr=False, - ) - - @staticmethod - async def _get_object_class_by_name( - object_class_name: str | None, - session: AsyncSession, - ) -> ObjectClass | None: - if not object_class_name: - return None - - return await session.scalar( - select(ObjectClass) - .filter_by(name=object_class_name), - ) # fmt: skip - - @staticmethod - async def create_object_class_by_info( - session: AsyncSession, - object_class_info: ObjectClassInfo, - ) -> ObjectClass: - """Create Object Class by ObjectClassInfo.""" - superior_name = RawDefinitionParser._list_to_string( - object_class_info.superior, - ) - - superior_object_class = ( - await RawDefinitionParser._get_object_class_by_name( - superior_name, - session, - ) - ) - - object_class = ObjectClass( - oid=object_class_info.oid, - name=RawDefinitionParser._list_to_string(object_class_info.name), # type: ignore[arg-type] - superior=superior_object_class, - kind=object_class_info.kind, - is_system=True, - ) - if object_class_info.must_contain: - object_class.attribute_types_must.extend( - await RawDefinitionParser._get_attribute_types_by_names( - session, - object_class_info.must_contain, - ), - ) - if object_class_info.may_contain: - object_class.attribute_types_may.extend( - await RawDefinitionParser._get_attribute_types_by_names( - session, - object_class_info.may_contain, - ), - ) - - return object_class diff --git a/app/repo/pg/tables.py b/app/repo/pg/tables.py index a13db43ae..987cbdba8 100644 --- a/app/repo/pg/tables.py +++ b/app/repo/pg/tables.py @@ -8,6 +8,7 @@ import uuid from typing import Literal, TypeVar, cast +from entities_legacy import AttributeTypeLegacy, ObjectClassLegacy from sqlalchemy import ( Boolean, CheckConstraint, @@ -36,7 +37,6 @@ from entities import ( AccessControlEntry, Attribute, - AttributeType, AuditDestination, AuditPolicy, AuditPolicyTrigger, @@ -46,7 +46,6 @@ EntityType, Group, NetworkPolicy, - ObjectClass, PasswordBanWord, PasswordPolicy, Role, @@ -523,7 +522,7 @@ def _compile_create_uc( Column( "attributeTypeId", Integer, - ForeignKey("AttributeTypes.id", ondelete="CASCADE"), + ForeignKey("Directory.id", ondelete="CASCADE"), nullable=True, key="attribute_type_id", ), @@ -951,7 +950,7 @@ def _compile_create_uc( lazy="raise", ), "attribute_type": relationship( - AttributeType, + Directory, lazy="raise", uselist=False, ), @@ -966,26 +965,26 @@ def _compile_create_uc( ) mapper_registry.map_imperatively( - AttributeType, + AttributeTypeLegacy, attribute_types_table, ) mapper_registry.map_imperatively( - ObjectClass, + ObjectClassLegacy, object_classes_table, properties={ "superior": relationship( - ObjectClass, + ObjectClassLegacy, remote_side=[object_classes_table.c.name], lazy="raise", ), "attribute_types_must": relationship( - AttributeType, + AttributeTypeLegacy, secondary=object_class_attr_must_table, lazy="raise", ), "attribute_types_may": relationship( - AttributeType, + AttributeTypeLegacy, secondary=object_class_attr_may_table, lazy="raise", ), diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 2c0ac63d2..afc4b1550 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -24,7 +24,7 @@ services: POSTGRES_HOST: postgres # PYTHONTRACEMALLOC: 1 PYTHONDONTWRITEBYTECODE: 1 - command: sh -c "python -B -m pytest -n auto -x -W ignore::DeprecationWarning -W ignore::coverage.exceptions.CoverageWarning -vv" + command: sh -c "python -B -m pytest -n auto -x -W ignore::DeprecationWarning -W ignore::coverage.exceptions.CoverageWarning -W ignore::SyntaxWarning -vv" tty: true postgres: diff --git a/interface b/interface index 5d5a80ee7..13767b9a4 160000 --- a/interface +++ b/interface @@ -1 +1 @@ -Subproject commit 5d5a80ee7e9ea073338cac26a57be5f91a8d47f7 +Subproject commit 13767b9a41052532e697850374894731d5eeb7a1 diff --git a/tests/conftest.py b/tests/conftest.py index 9c4337450..524dc2ab4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ import uuid from contextlib import suppress from dataclasses import dataclass +from itertools import chain from typing import AsyncGenerator, AsyncIterator, Generator, Iterator from unittest.mock import AsyncMock, Mock @@ -62,8 +63,7 @@ from api.shadow.adapter import ShadowAdapter from authorization_provider_protocol import AuthorizationProviderProtocol from config import Settings -from constants import ENTITY_TYPE_DATAS -from entities import AttributeType +from constants import ENTITY_TYPE_DTOS_V1, ENTITY_TYPE_DTOS_V2 from enums import AuthorizationRules from ioc import AuditRedisClient, MFACredsProvider, SessionStorageClient from ldap_protocol.auth import AuthManager, MFAManager @@ -96,21 +96,46 @@ LDAPSearchRequestContext, LDAPUnbindRequestContext, ) -from ldap_protocol.ldap_schema.attribute_type_dao import AttributeTypeDAO -from ldap_protocol.ldap_schema.attribute_type_system_flags_use_case import ( +from ldap_protocol.ldap_schema._legacy.attribute_type.attribute_type_dao import ( # noqa: E501 + AttributeTypeDAOLegacy, +) +from ldap_protocol.ldap_schema._legacy.attribute_type.attribute_type_use_case import ( # noqa: E501 + AttributeTypeUseCaseLegacy, +) +from ldap_protocol.ldap_schema._legacy.object_class.object_class_dao import ( + ObjectClassDAOLegacy, +) +from ldap_protocol.ldap_schema._legacy.object_class.object_class_use_case import ( # noqa: E501 + ObjectClassUseCaseLegacy, +) +from ldap_protocol.ldap_schema.attribute_dao import AttributeDAO +from ldap_protocol.ldap_schema.attribute_type.attribute_type_dao import ( + AttributeTypeDAO, +) +from ldap_protocol.ldap_schema.attribute_type.attribute_type_system_flags_use_case import ( # noqa: E501 AttributeTypeSystemFlagsUseCase, ) -from ldap_protocol.ldap_schema.attribute_type_use_case import ( +from ldap_protocol.ldap_schema.attribute_type.attribute_type_use_case import ( AttributeTypeUseCase, ) from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, ) -from ldap_protocol.ldap_schema.dto import EntityTypeDTO -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO -from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase -from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO -from ldap_protocol.ldap_schema.object_class_use_case import ObjectClassUseCase +from ldap_protocol.ldap_schema.directory_dao import DirectoryDAO +from ldap_protocol.ldap_schema.dto import AttributeTypeDTO +from ldap_protocol.ldap_schema.entity_type.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) +from ldap_protocol.ldap_schema.object_class.object_class_dao import ( + ObjectClassDAO, +) +from ldap_protocol.ldap_schema.object_class.object_class_use_case import ( + ObjectClassUseCase, +) +from ldap_protocol.ldap_schema.schema_create_use_case import ( + SchemaLikeAsDirectoryCreateUseCase, +) from ldap_protocol.master_check_use_case import ( MasterCheckUseCase, MasterGatewayProtocol, @@ -152,6 +177,9 @@ from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.dataclasses import RoleDTO +from ldap_protocol.roles.migrations_ace_dao import ( + AccessControlEntryMigrationsDAO, +) from ldap_protocol.roles.role_dao import RoleDAO from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.rootdse.gateway import SADomainGateway @@ -163,7 +191,12 @@ from ldap_protocol.utils.queries import get_user from password_utils import PasswordUtils from repo.pg.master_gateway import PGMasterGateway -from tests.constants import TEST_DATA +from tests.constants import ( + TEST_DATA, + admin_user_data_dict, + user_data_dict, + user_with_login_perm_data_dict, +) class TestProvider(Provider): @@ -296,8 +329,24 @@ async def get_dns_mngr_settings( domain.name, ) + schema_create_use_case = provide( + SchemaLikeAsDirectoryCreateUseCase, + scope=Scope.REQUEST, + ) attribute_type_dao = provide(AttributeTypeDAO, scope=Scope.REQUEST) + attribute_type_dao_legacy = provide( + AttributeTypeDAOLegacy, + scope=Scope.REQUEST, + ) + object_class_dao = provide(ObjectClassDAO, scope=Scope.REQUEST) + object_class_dao_legacy = provide( + ObjectClassDAOLegacy, + scope=Scope.REQUEST, + ) + attribute_dao = provide(AttributeDAO, scope=Scope.REQUEST) + directory_dao = provide(DirectoryDAO, scope=Scope.REQUEST) + entity_type_dao = provide(EntityTypeDAO, scope=Scope.REQUEST) attribute_type_system_flags_use_case = provide( AttributeTypeSystemFlagsUseCase, @@ -307,8 +356,33 @@ async def get_dns_mngr_settings( AttributeTypeUseCase, scope=Scope.REQUEST, ) + + @provide(scope=Scope.REQUEST) + def get_attribute_type_use_case_legacy( + self, + session: AsyncSession, + ) -> AttributeTypeUseCaseLegacy: + """Legacy attribute type use case bound to a single session.""" + at_dao_legacy = AttributeTypeDAOLegacy(session=session) + return AttributeTypeUseCaseLegacy( + attribute_type_dao_legacy=at_dao_legacy, + ) + object_class_use_case = provide(ObjectClassUseCase, scope=Scope.REQUEST) + @provide(scope=Scope.REQUEST) + def get_object_class_use_case_legacy( + self, + session: AsyncSession, + ) -> ObjectClassUseCaseLegacy: + """Legacy object class use case bound to a single session for all DAOs.""" # noqa: E501 + at_dao_legacy = AttributeTypeDAOLegacy(session=session) + oc_dao_legacy = ObjectClassDAOLegacy(session=session) + return ObjectClassUseCaseLegacy( + attribute_type_dao_legacy=at_dao_legacy, + object_class_dao_legacy=oc_dao_legacy, + ) + user_password_history_use_cases = provide( UserPasswordHistoryUseCases, scope=Scope.REQUEST, @@ -373,7 +447,7 @@ def get_session_factory( autocommit=False, ) - @provide(scope=Scope.APP, cache=False) + @provide(scope=Scope.APP) async def get_session( self, engine: AsyncEngine, @@ -489,6 +563,10 @@ async def get_session_storage( role_dao = provide(RoleDAO, scope=Scope.REQUEST, cache=False) ace_dao = provide(AccessControlEntryDAO, scope=Scope.REQUEST) + ace_migrations_dao = provide( + AccessControlEntryMigrationsDAO, + scope=Scope.REQUEST, + ) access_manager = provide(AccessManager, scope=Scope.REQUEST) role_use_case = provide(RoleUseCase, scope=Scope.REQUEST) @@ -944,24 +1022,55 @@ async def setup_session( password_utils: PasswordUtils, ) -> None: """Get session and acquire after completion.""" - object_class_dao = ObjectClassDAO(session) + role_dao = RoleDAO(session) + ace_dao = AccessControlEntryDAO(session) + role_use_case = RoleUseCase(role_dao, ace_dao) attribute_value_validator = AttributeValueValidator() + attribute_type_dao = AttributeTypeDAO(session) + attribute_type_system_flags_use_case = AttributeTypeSystemFlagsUseCase() + object_class_dao_legacy = ObjectClassDAOLegacy(session=session) + attribute_type_dao_legacy = AttributeTypeDAOLegacy(session=session) + object_class_use_case_legacy = ObjectClassUseCaseLegacy( + attribute_type_dao_legacy=attribute_type_dao_legacy, + object_class_dao_legacy=object_class_dao_legacy, + ) + attribute_type_use_case_legacy = AttributeTypeUseCaseLegacy( + attribute_type_dao_legacy=attribute_type_dao_legacy, + ) + + object_class_dao = ObjectClassDAO(session) + directory_dao = DirectoryDAO(session) + attribute_dao = AttributeDAO(session) entity_type_dao = EntityTypeDAO( session, - object_class_dao=object_class_dao, attribute_value_validator=attribute_value_validator, + directory_dao=directory_dao, + ) + entity_type_use_case = EntityTypeUseCase( + entity_type_dao=entity_type_dao, + object_class_dao=object_class_dao, + directory_dao=directory_dao, + ) + schema_create_use_case = SchemaLikeAsDirectoryCreateUseCase( + session=session, + entity_type_use_case=entity_type_use_case, + role_use_case=role_use_case, + directory_dao=directory_dao, + attribute_dao=attribute_dao, + ) + object_class_use_case = ObjectClassUseCase( + attribute_type_dao=attribute_type_dao, + object_class_dao=object_class_dao, + entity_type_dao=entity_type_dao, + schema_create_use_case=schema_create_use_case, ) - for entity_type_data in ENTITY_TYPE_DATAS: - await entity_type_dao.create( - dto=EntityTypeDTO( - id=None, - name=entity_type_data["name"], - object_class_names=entity_type_data["object_class_names"], - is_system=True, - ), - ) - await session.flush() + attribute_type_use_case = AttributeTypeUseCase( + attribute_type_dao=attribute_type_dao, + attribute_type_system_flags_use_case=attribute_type_system_flags_use_case, + object_class_dao=object_class_dao, + schema_create_use_case=schema_create_use_case, + ) audit_policy_dao = AuditPoliciesDAO(session) audit_destination_dao = AuditDestinationDAO(session) @@ -987,9 +1096,15 @@ async def setup_session( setup_gateway = SetupGateway( session, password_utils, - entity_type_dao, + entity_type_use_case=entity_type_use_case, attribute_value_validator=attribute_value_validator, + directory_dao=directory_dao, ) + + for entity_type_dto in chain(ENTITY_TYPE_DTOS_V1, ENTITY_TYPE_DTOS_V2): + await entity_type_use_case.create_not_safe(entity_type_dto) + await session.flush() + await audit_use_case.create_policies() await setup_gateway.setup_enviroment( dn="md.test", @@ -997,44 +1112,91 @@ async def setup_session( is_system=False, ) - # NOTE: after setup environment we need base DN to be created - await password_use_cases.create_default_domain_policy() - - role_dao = RoleDAO(session) - ace_dao = AccessControlEntryDAO(session) - role_use_case = RoleUseCase(role_dao, ace_dao) - await role_use_case.create_domain_admins_role() - - await role_use_case._role_dao.create( # noqa: SLF001 - dto=RoleDTO( - name="TEST ONLY LOGIN ROLE", - creator_upn=None, - is_system=True, - groups=["cn=admin login only,cn=Groups,dc=md,dc=test"], - permissions=AuthorizationRules.AUTH_LOGIN, - ), - ) - - session.add( - AttributeType( + for _at_dto in ( + AttributeTypeDTO[None]( oid="1.2.3.4.5.6.7.8", name="attr_with_bvalue", syntax="1.3.6.1.4.1.1466.115.121.1.40", # Octet String single_value=True, no_user_modification=False, is_system=True, + system_flags=0, + is_included_anr=False, ), - ) - session.add( - AttributeType( + AttributeTypeDTO[None]( oid="1.2.3.4.5.6.7.8.9", name="testing_attr", syntax="1.3.6.1.4.1.1466.115.121.1.15", single_value=True, no_user_modification=False, is_system=True, + system_flags=0, + is_included_anr=False, + ), + ): + await attribute_type_use_case.create(_at_dto) + + for attr_type_name in ( + "description", + "posixEmail", + "userPrincipalName", + "userAccountControl", + "cn", + "objectClass", + ): + _at = await attribute_type_use_case_legacy.get( + attr_type_name, + ) + if not _at: + raise ValueError( + f"setup_session:: AttributeType {attr_type_name} not found", + ) + await attribute_type_use_case.create(_at) + + for _obj_class_name in ( + "top", + "person", + "organizationalPerson", + "user", + "domain", + "container", + "organization", + "domainDNS", + "group", + "inetOrgPerson", + "posixAccount", + ): + _oc_dto = await object_class_use_case_legacy.get(_obj_class_name) + _oc_dto.attribute_types_may = [ + _.name # type: ignore + for _ in _oc_dto.attribute_types_may + ] + _oc_dto.attribute_types_must = [ + _.name # type: ignore + for _ in _oc_dto.attribute_types_must + ] + await object_class_use_case.create(_oc_dto) # type: ignore + + await attribute_type_use_case_legacy.delete_table() + await object_class_use_case_legacy.delete_may_table() + await object_class_use_case_legacy.delete_must_table() + await object_class_use_case_legacy.delete_main_table() + + # NOTE: after setup environment we need base DN to be created + await password_use_cases.create_default_domain_policy() + + await role_use_case.create_domain_admins_role() + + await role_use_case._role_dao.create( # noqa: SLF001 + dto=RoleDTO( + name="TEST ONLY LOGIN ROLE", + creator_upn=None, + is_system=True, + groups=["cn=admin login only,cn=Groups,dc=md,dc=test"], + permissions=AuthorizationRules.AUTH_LOGIN, ), ) + await session.commit() @@ -1106,19 +1268,28 @@ async def entity_type_dao( container: AsyncContainer, ) -> AsyncIterator[EntityTypeDAO]: """Get session and acquire after completion.""" - async with container(scope=Scope.APP) as container: + async with container(scope=Scope.REQUEST) as container: session = await container.get(AsyncSession) - object_class_dao = ObjectClassDAO(session) attribute_value_validator = await container.get( AttributeValueValidator, ) + directory_dao = await container.get(DirectoryDAO) yield EntityTypeDAO( session, - object_class_dao, attribute_value_validator=attribute_value_validator, + directory_dao=directory_dao, ) +@pytest_asyncio.fixture(scope="function") +async def entity_type_use_case( + container: AsyncContainer, +) -> AsyncIterator[EntityTypeUseCase]: + """Get entity type use case.""" + async with container(scope=Scope.REQUEST) as container: + yield await container.get(EntityTypeUseCase) + + @pytest_asyncio.fixture(scope="function") async def password_policy_dao( container: AsyncContainer, @@ -1198,7 +1369,7 @@ async def attribute_type_dao( container: AsyncContainer, ) -> AsyncIterator[AttributeTypeDAO]: """Get session and acquire after completion.""" - async with container(scope=Scope.APP) as container: + async with container(scope=Scope.REQUEST) as container: session = await container.get(AsyncSession) yield AttributeTypeDAO(session) @@ -1394,12 +1565,6 @@ def creds(user: dict) -> TestCreds: return TestCreds(user["sam_account_name"], user["password"]) -@pytest.fixture -def user() -> dict: - """Get user data.""" - return TEST_DATA[1]["children"][0]["organizationalPerson"] # type: ignore - - @pytest.fixture def creds_with_login_perm(user_with_login_perm: dict) -> TestCreds: """Get creds from test data.""" @@ -1419,15 +1584,21 @@ def admin_creds(admin_user: dict) -> TestAdminCreds: @pytest.fixture -def user_with_login_perm() -> dict: +def user() -> dict: """Get user data.""" - return TEST_DATA[1]["children"][2]["organizationalPerson"] # type: ignore + return user_data_dict @pytest.fixture def admin_user() -> dict: """Get admin user data.""" - return TEST_DATA[1]["children"][1]["organizationalPerson"] # type: ignore + return admin_user_data_dict + + +@pytest.fixture +def user_with_login_perm() -> dict: + """Get user data.""" + return user_with_login_perm_data_dict @pytest.fixture diff --git a/tests/constants.py b/tests/constants.py index ab5ffb954..abc511790 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -5,26 +5,57 @@ """ from constants import ( + CONFIGURATION_DIR_NAME, DOMAIN_ADMIN_GROUP_NAME, DOMAIN_COMPUTERS_GROUP_NAME, DOMAIN_USERS_GROUP_NAME, GROUPS_CONTAINER_NAME, USERS_CONTAINER_NAME, ) -from enums import SamAccountTypeCodes +from enums import EntityTypeNames, SamAccountTypeCodes from ldap_protocol.objects import UserAccountControlFlag +user_data_dict = { + "sam_account_name": "user0", + "user_principal_name": "user0", + "mail": "user0@mail.com", + "display_name": "user0", + "password": "password", + "groups": [DOMAIN_ADMIN_GROUP_NAME], +} + +admin_user_data_dict = { + "sam_account_name": "user_admin", + "user_principal_name": "user_admin", + "mail": "user_admin@mail.com", + "display_name": "user_admin", + "password": "password", + "groups": [DOMAIN_ADMIN_GROUP_NAME], +} + +user_with_login_perm_data_dict = { + "sam_account_name": "user_admin_for_roles", + "user_principal_name": "user_admin_for_roles", + "mail": "user_admin_for_roles@mail.com", + "display_name": "user_admin_for_roles", + "password": "password", + "groups": ["admin login only"], +} + + TEST_DATA = [ { "name": GROUPS_CONTAINER_NAME, + "entity_type_name": EntityTypeNames.CONTAINER, "object_class": "container", "attributes": { - "objectClass": ["top"], + "objectClass": ["top", "container"], "sAMAccountName": ["groups"], }, "children": [ { "name": DOMAIN_ADMIN_GROUP_NAME, + "entity_type_name": EntityTypeNames.GROUP, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], @@ -39,6 +70,7 @@ }, { "name": "developers", + "entity_type_name": EntityTypeNames.GROUP, "object_class": "group", "groups": [DOMAIN_ADMIN_GROUP_NAME], "attributes": { @@ -53,6 +85,7 @@ }, { "name": "admin login only", + "entity_type_name": EntityTypeNames.GROUP, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], @@ -66,6 +99,7 @@ }, { "name": DOMAIN_USERS_GROUP_NAME, + "entity_type_name": EntityTypeNames.GROUP, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], @@ -79,6 +113,7 @@ }, { "name": DOMAIN_COMPUTERS_GROUP_NAME, + "entity_type_name": EntityTypeNames.GROUP, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], @@ -94,20 +129,15 @@ }, { "name": USERS_CONTAINER_NAME, + "entity_type_name": EntityTypeNames.CONTAINER, "object_class": "container", "attributes": {"objectClass": ["top"]}, "children": [ { "name": "user0", + "entity_type_name": EntityTypeNames.USER, "object_class": "user", - "organizationalPerson": { - "sam_account_name": "user0", - "user_principal_name": "user0", - "mail": "user0@mail.com", - "display_name": "user0", - "password": "password", - "groups": [DOMAIN_ADMIN_GROUP_NAME], - }, + "organizationalPerson": user_data_dict, "attributes": { "givenName": ["John"], "surname": ["Lennon"], @@ -129,15 +159,9 @@ }, { "name": "user_admin", + "entity_type_name": EntityTypeNames.USER, "object_class": "user", - "organizationalPerson": { - "sam_account_name": "user_admin", - "user_principal_name": "user_admin", - "mail": "user_admin@mail.com", - "display_name": "user_admin", - "password": "password", - "groups": [DOMAIN_ADMIN_GROUP_NAME], - }, + "organizationalPerson": admin_user_data_dict, "attributes": { "objectClass": [ "top", @@ -156,15 +180,9 @@ }, { "name": "user_admin_for_roles", + "entity_type_name": EntityTypeNames.USER, "object_class": "user", - "organizationalPerson": { - "sam_account_name": "user_admin_for_roles", - "user_principal_name": "user_admin_for_roles", - "mail": "user_admin_for_roles@mail.com", - "display_name": "user_admin_for_roles", - "password": "password", - "groups": ["admin login only"], - }, + "organizationalPerson": user_with_login_perm_data_dict, "attributes": { "objectClass": [ "top", @@ -183,6 +201,7 @@ }, { "name": "user_non_admin", + "entity_type_name": EntityTypeNames.USER, "object_class": "user", "organizationalPerson": { "sam_account_name": "user_non_admin", @@ -211,6 +230,7 @@ }, { "name": "russia", + "entity_type_name": EntityTypeNames.CONTAINER, "object_class": "container", "attributes": { "objectClass": ["top"], @@ -219,6 +239,7 @@ "children": [ { "name": "moscow", + "entity_type_name": EntityTypeNames.CONTAINER, "object_class": "container", "attributes": { "objectClass": ["top"], @@ -227,6 +248,7 @@ "children": [ { "name": "user1", + "entity_type_name": EntityTypeNames.USER, "object_class": "user", "organizationalPerson": { "sam_account_name": "user1", @@ -262,11 +284,13 @@ }, { "name": "test_bit_rules", + "entity_type_name": EntityTypeNames.ORGANIZATIONAL_UNIT, "object_class": "organizationalUnit", "attributes": {"objectClass": ["top", "container"]}, "children": [ { "name": "user_admin_1", + "entity_type_name": EntityTypeNames.USER, "object_class": "user", "organizationalPerson": { "sam_account_name": "user_admin_1", @@ -299,6 +323,7 @@ }, { "name": "user_admin_2", + "entity_type_name": EntityTypeNames.USER, "object_class": "user", "organizationalPerson": { "sam_account_name": "user_admin_2", @@ -329,6 +354,7 @@ }, { "name": "user_admin_3", + "entity_type_name": EntityTypeNames.USER, "object_class": "user", "organizationalPerson": { "sam_account_name": "user_admin_3", @@ -358,6 +384,7 @@ }, { "name": "testModifyDn1", + "entity_type_name": EntityTypeNames.ORGANIZATIONAL_UNIT, "object_class": "organizationalUnit", "attributes": { "objectClass": ["top", "container"], @@ -366,6 +393,7 @@ "children": [ { "name": "testModifyDn2", + "entity_type_name": EntityTypeNames.ORGANIZATIONAL_UNIT, "object_class": "organizationalUnit", "attributes": { "objectClass": ["top", "container"], @@ -374,6 +402,7 @@ "children": [ { "name": "testGroup1", + "entity_type_name": EntityTypeNames.GROUP, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], @@ -391,6 +420,7 @@ }, { "name": "testGroup2", + "entity_type_name": EntityTypeNames.GROUP, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], @@ -406,6 +436,7 @@ }, { "name": "testModifyDn3", + "entity_type_name": EntityTypeNames.ORGANIZATIONAL_UNIT, "object_class": "organizationalUnit", "attributes": { "objectClass": ["top", "container"], @@ -414,6 +445,7 @@ "children": [ { "name": "testGroup3", + "entity_type_name": EntityTypeNames.GROUP, "object_class": "group", "attributes": { "objectClass": ["top", "posixGroup"], @@ -427,10 +459,18 @@ }, ], }, + { + "name": CONFIGURATION_DIR_NAME, + "entity_type_name": EntityTypeNames.CONFIGURATION, + "object_class": "container", + "attributes": {"objectClass": ["top", "configuration"]}, + "children": [], + }, ] TEST_SYSTEM_ADMIN_DATA = { "name": "System Administrator", + "entity_type_name": EntityTypeNames.USER, "object_class": "user", "organizationalPerson": { "sam_account_name": "system_admin", diff --git a/tests/test_api/test_ldap_schema/test_attribute_type_router.py b/tests/test_api/test_ldap_schema/test_attribute_type_router.py index bc9018948..b31790507 100644 --- a/tests/test_api/test_ldap_schema/test_attribute_type_router.py +++ b/tests/test_api/test_ldap_schema/test_attribute_type_router.py @@ -82,7 +82,7 @@ async def test_get_list_attribute_types_with_pagination( ) -> None: """Test retrieving a list of attribute types.""" page_number = 1 - page_size = 50 + page_size = 3 response = await http_client.get( f"/schema/attribute_types?page_number={page_number}&page_size={page_size}", ) @@ -133,7 +133,7 @@ async def test_modify_one_attribute_type( response = await http_client.patch( f"/schema/attribute_type/{attribute_type_name}", - json=dataset["attribute_type_changes"], + json=dataset["attribute_type_changes"].model_dump(), ) assert response.status_code == dataset["status_code"] @@ -142,7 +142,9 @@ async def test_modify_one_attribute_type( f"/schema/attribute_type/{attribute_type_name}", ) attribute_type_json = response.json() - for field_name, value in dataset["attribute_type_changes"].items(): + for field_name, value in ( + dataset["attribute_type_changes"].model_dump().items() + ): assert attribute_type_json.get(field_name) == value diff --git a/tests/test_api/test_ldap_schema/test_attribute_type_router_datasets.py b/tests/test_api/test_ldap_schema/test_attribute_type_router_datasets.py index e04eecc8d..18dcac0d5 100644 --- a/tests/test_api/test_ldap_schema/test_attribute_type_router_datasets.py +++ b/tests/test_api/test_ldap_schema/test_attribute_type_router_datasets.py @@ -2,7 +2,10 @@ from fastapi import status -from api.ldap_schema.schema import AttributeTypeSchema +from api.ldap_schema.schema import ( + AttributeTypeSchema, + AttributeTypeUpdateSchema, +) test_modify_one_attribute_type_dataset = [ { @@ -16,12 +19,12 @@ is_system=False, is_included_anr=False, ), - "attribute_type_changes": { - "syntax": "1.3.6.1.4.1.1466.115.121.1.15", - "single_value": True, - "no_user_modification": False, - "is_included_anr": False, - }, + "attribute_type_changes": AttributeTypeUpdateSchema( + syntax="1.3.6.1.4.1.1466.115.121.1.15", + single_value=True, + no_user_modification=False, + is_included_anr=False, + ), "status_code": status.HTTP_200_OK, }, { @@ -35,12 +38,12 @@ is_system=False, is_included_anr=False, ), - "attribute_type_changes": { - "syntax": "1.3.6.1.4.1.1466.115.121.1.15", - "single_value": True, - "no_user_modification": False, - "is_included_anr": False, - }, + "attribute_type_changes": AttributeTypeUpdateSchema( + syntax="1.3.6.1.4.1.1466.115.121.1.15", + single_value=True, + no_user_modification=False, + is_included_anr=False, + ), "status_code": status.HTTP_400_BAD_REQUEST, }, { @@ -54,12 +57,12 @@ is_system=True, is_included_anr=False, ), - "attribute_type_changes": { - "syntax": "1.3.6.1.4.1.1466.115.121.1.15", - "single_value": True, - "no_user_modification": False, - "is_included_anr": False, - }, + "attribute_type_changes": AttributeTypeUpdateSchema( + syntax="1.3.6.1.4.1.1466.115.121.1.15", + single_value=True, + no_user_modification=False, + is_included_anr=False, + ), "status_code": status.HTTP_200_OK, }, ] diff --git a/tests/test_api/test_ldap_schema/test_entity_type_router.py b/tests/test_api/test_ldap_schema/test_entity_type_router.py index b7a40c66e..3fb0bbddb 100644 --- a/tests/test_api/test_ldap_schema/test_entity_type_router.py +++ b/tests/test_api/test_ldap_schema/test_entity_type_router.py @@ -4,7 +4,8 @@ from fastapi import status from httpx import AsyncClient -from constants import ENTITY_TYPE_DATAS +from api.ldap_schema.schema import AttributeTypeSchema, EntityTypeSchema +from constants import ENTITY_TYPE_DTOS_V1 from enums import EntityTypeNames from .test_entity_type_router_datasets import ( @@ -123,6 +124,73 @@ async def test_get_list_entity_types_with_pagination( assert len(response.json().get("items")) == page_size +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +async def test_get_entity_type_attributes(http_client: AsyncClient) -> None: + """Test retrieving attribute names for an entity type.""" + attribute_types = [ + AttributeTypeSchema( + oid="1.2.3.100", + name="testEntityTypeAttr1", + syntax="1.3.6.1.4.1.1466.115.121.1.15", + single_value=True, + no_user_modification=False, + is_system=False, + is_included_anr=False, + ), + AttributeTypeSchema( + oid="1.2.3.101", + name="testEntityTypeAttr2", + syntax="1.3.6.1.4.1.1466.115.121.1.15", + single_value=True, + no_user_modification=False, + is_system=False, + is_included_anr=False, + ), + ] + for attribute_type in attribute_types: + response = await http_client.post( + "/schema/attribute_type", + json=attribute_type.model_dump(), + ) + assert response.status_code == status.HTTP_201_CREATED + + object_class_name = "testEntityTypeObjectClass" + response = await http_client.post( + "/schema/object_class", + json={ + "oid": "1.2.3.102", + "name": object_class_name, + "superior_name": None, + "kind": "STRUCTURAL", + "is_system": False, + "attribute_type_names_must": ["testEntityTypeAttr1"], + "attribute_type_names_may": ["testEntityTypeAttr2"], + }, + ) + assert response.status_code == status.HTTP_201_CREATED + + entity_type_name = "testEntityTypeWithAttrs" + response = await http_client.post( + "/schema/entity_type", + json=EntityTypeSchema( + name=entity_type_name, + object_class_names=[object_class_name], + is_system=False, + ).model_dump(), + ) + assert response.status_code == status.HTTP_201_CREATED + + response = await http_client.get( + f"/schema/entity_type/{entity_type_name}/attrs", + ) + assert response.status_code == status.HTTP_200_OK + assert set(response.json()) == { + "testEntityTypeAttr1", + "testEntityTypeAttr2", + } + + @pytest.mark.parametrize( "dataset", test_modify_entity_type_with_duplicates_dataset, @@ -215,19 +283,19 @@ async def test_modify_primary_entity_type_name( ) -> None: """Test modifying a primary entity type name.""" new_statement = "TestEntityTypeName" - entity_type_data = ENTITY_TYPE_DATAS[0] + entity_type_dto = ENTITY_TYPE_DTOS_V1[0] response = await http_client.patch( - f"/schema/entity_type/{entity_type_data['name']}", - json={ - "name": new_statement, - "is_system": True, - "object_class_names": entity_type_data["object_class_names"], - }, + f"/schema/entity_type/{entity_type_dto.name}", + json=EntityTypeSchema( + name=new_statement, + object_class_names=entity_type_dto.object_class_names, + is_system=entity_type_dto.is_system, + ).model_dump(), ) assert response.status_code == status.HTTP_400_BAD_REQUEST response = await http_client.get( - f"/schema/entity_type/{entity_type_data['name']}", + f"/schema/entity_type/{entity_type_dto.name}", ) assert response.status_code == status.HTTP_200_OK assert isinstance(response.json(), dict) diff --git a/tests/test_api/test_ldap_schema/test_object_class_router.py b/tests/test_api/test_ldap_schema/test_object_class_router.py index 6e04cecdc..1359c7d4e 100644 --- a/tests/test_api/test_ldap_schema/test_object_class_router.py +++ b/tests/test_api/test_ldap_schema/test_object_class_router.py @@ -124,7 +124,7 @@ async def test_get_list_object_classes_with_pagination( ) -> None: """Test retrieving a list of object classes.""" page_number = 1 - page_size = 25 + page_size = 7 response = await http_client.get( f"/schema/object_classes?page_number={page_number}&page_size={page_size}", ) @@ -170,6 +170,7 @@ async def test_modify_one_object_class( assert response.status_code == status.HTTP_200_OK assert isinstance(response.json(), dict) object_class = response.json() + assert set(object_class.get("attribute_type_names_must")) == set( new_statement.get("attribute_type_names_must"), ) diff --git a/tests/test_api/test_main/test_router/conftest.py b/tests/test_api/test_main/test_router/conftest.py index 5ec37b884..dc26b0577 100644 --- a/tests/test_api/test_main/test_router/conftest.py +++ b/tests/test_api/test_main/test_router/conftest.py @@ -11,8 +11,14 @@ from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, ) -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO -from ldap_protocol.ldap_schema.object_class_dao import ObjectClassDAO +from ldap_protocol.ldap_schema.directory_dao import DirectoryDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) +from ldap_protocol.ldap_schema.object_class.object_class_dao import ( + ObjectClassDAO, +) from ldap_protocol.utils.queries import get_base_directories from password_utils import PasswordUtils from tests.constants import TEST_SYSTEM_ADMIN_DATA @@ -25,19 +31,26 @@ async def add_system_administrator( setup_session: None, # noqa: ARG001 ) -> None: """Create system administrator user for tests that require it.""" - object_class_dao = ObjectClassDAO(session) attribute_value_validator = AttributeValueValidator() + object_class_dao = ObjectClassDAO(session) + directory_dao = DirectoryDAO(session) entity_type_dao = EntityTypeDAO( - session, - object_class_dao=object_class_dao, + session=session, attribute_value_validator=attribute_value_validator, + directory_dao=directory_dao, + ) + entity_type_use_case = EntityTypeUseCase( + entity_type_dao=entity_type_dao, + object_class_dao=object_class_dao, + directory_dao=directory_dao, ) setup_gateway = SetupGateway( session, password_utils, - entity_type_dao, + entity_type_use_case, attribute_value_validator=attribute_value_validator, + directory_dao=directory_dao, ) domain = (await get_base_directories(session))[0] diff --git a/tests/test_api/test_main/test_router/test_search.py b/tests/test_api/test_main/test_router/test_search.py index 1c591bd17..c4c604ed1 100644 --- a/tests/test_api/test_main/test_router/test_search.py +++ b/tests/test_api/test_main/test_router/test_search.py @@ -131,6 +131,7 @@ async def test_api_search(http_client: AsyncClient) -> None: sub_dirs = { "cn=Groups,dc=md,dc=test", + "cn=Configuration,dc=md,dc=test", "cn=Users,dc=md,dc=test", "ou=testModifyDn1,dc=md,dc=test", "ou=testModifyDn3,dc=md,dc=test", diff --git a/tests/test_ldap/test_ldap_schema/conftest.py b/tests/test_ldap/test_ldap_schema/conftest.py index 75b13a356..49621627a 100644 --- a/tests/test_ldap/test_ldap_schema/conftest.py +++ b/tests/test_ldap/test_ldap_schema/conftest.py @@ -9,7 +9,7 @@ import pytest_asyncio from dishka import AsyncContainer, Scope -from ldap_protocol.ldap_schema.attribute_type_use_case import ( +from ldap_protocol.ldap_schema.attribute_type.attribute_type_use_case import ( AttributeTypeUseCase, ) diff --git a/tests/test_ldap/test_ldap_schema/test_attribute_type_system_flags_use_case.py b/tests/test_ldap/test_ldap_schema/test_attribute_type_system_flags_use_case.py new file mode 100644 index 000000000..e91b7ba8c --- /dev/null +++ b/tests/test_ldap/test_ldap_schema/test_attribute_type_system_flags_use_case.py @@ -0,0 +1,65 @@ +"""Test AttributeTypeUseCase. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import pytest + +from ldap_protocol.ldap_schema.attribute_type.attribute_type_use_case import ( + AttributeTypeUseCase, +) +from ldap_protocol.ldap_schema.dto import AttributeTypeDTO + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("setup_session") +async def test_attribute_type_system_flags_use_case_is_not_replicated( + attribute_type_use_case: AttributeTypeUseCase, +) -> None: + """Test AttributeType is not replicated.""" + await attribute_type_use_case.create( + AttributeTypeDTO( + oid="1.2.3.4", + name="objectClass123", + syntax="1.3.6.1.4.1.1466.115.121.1.15", + single_value=True, + no_user_modification=False, + is_system=False, + system_flags=0x00000001, # ATTR_NOT_REPLICATED + is_included_anr=False, + ), + ) + assert not await attribute_type_use_case.is_attr_replicated( + "objectClass123", + ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("setup_session") +async def test_attribute_type_system_flags_use_case_is_replicated( + attribute_type_use_case: AttributeTypeUseCase, +) -> None: + """Test AttributeType is replicated.""" + await attribute_type_use_case.create( + AttributeTypeDTO( + oid="1.2.3.4", + name="objectClass123", + syntax="1.3.6.1.4.1.1466.115.121.1.15", + single_value=True, + no_user_modification=False, + is_system=False, + system_flags=0x00000000, # ATTR_NOT_REPLICATED + is_included_anr=False, + ), + ) + assert await attribute_type_use_case.is_attr_replicated("objectClass123") + await attribute_type_use_case.set_attr_replication_flag( + "objectClass123", + False, + ) + assert not await attribute_type_use_case.is_attr_replicated( + "objectClass123", + ) diff --git a/tests/test_ldap/test_ldap_schema/test_attribute_type_use_case.py b/tests/test_ldap/test_ldap_schema/test_attribute_type_use_case.py deleted file mode 100644 index 0c359351a..000000000 --- a/tests/test_ldap/test_ldap_schema/test_attribute_type_use_case.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Test AttributeTypeUseCase. - -Copyright (c) 2026 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -import pytest - -from ldap_protocol.ldap_schema.attribute_type_use_case import ( - AttributeTypeUseCase, -) - - -@pytest.mark.asyncio -@pytest.mark.usefixtures("session") -@pytest.mark.usefixtures("setup_session") -async def test_attribute_type_system_flags_use_case_is_not_replicated( - attribute_type_use_case: AttributeTypeUseCase, -) -> None: - """Test AttributeType is not replicated.""" - assert not await attribute_type_use_case.is_attr_replicated("netbootSCPBL") - - -@pytest.mark.asyncio -@pytest.mark.usefixtures("session") -@pytest.mark.usefixtures("setup_session") -async def test_attribute_type_system_flags_use_case_is_replicated( - attribute_type_use_case: AttributeTypeUseCase, -) -> None: - """Test AttributeType is replicated.""" - assert await attribute_type_use_case.is_attr_replicated("objectClass") - await attribute_type_use_case.set_attr_replication_flag( - "objectClass", - False, - ) - assert not await attribute_type_use_case.is_attr_replicated("objectClass") diff --git a/tests/test_ldap/test_ldap3_definition_parse.py b/tests/test_ldap/test_ldap_schema/test_ldap3_definition_parse.py similarity index 80% rename from tests/test_ldap/test_ldap3_definition_parse.py rename to tests/test_ldap/test_ldap_schema/test_ldap3_definition_parse.py index 14ce27f7a..20621ce1a 100644 --- a/tests/test_ldap/test_ldap3_definition_parse.py +++ b/tests/test_ldap/test_ldap_schema/test_ldap3_definition_parse.py @@ -5,10 +5,14 @@ """ import pytest -from sqlalchemy.ext.asyncio import AsyncSession -from entities import AttributeType, ObjectClass -from ldap_protocol.utils.raw_definition_parser import ( +from ldap_protocol.ldap_schema.attribute_type.attribute_type_raw_display import ( # noqa: E501 + AttributeTypeRawDisplay, +) +from ldap_protocol.ldap_schema.object_class.object_class_raw_display import ( + ObjectClassRawDisplay, +) +from ldap_protocol.ldap_schema.raw_definition_parser import ( RawDefinitionParser as RDParser, ) @@ -38,11 +42,12 @@ async def test_ldap3_parse_attribute_types(test_dataset: list[str]) -> None: """Test parse ldap3 attribute types.""" for raw_definition in test_dataset: - attribute_type: AttributeType = RDParser.create_attribute_type_by_raw( + attribute_type_dto = RDParser.collect_attribute_type_dto_from_raw( raw_definition, ) - - assert raw_definition == attribute_type.get_raw_definition() + assert raw_definition == AttributeTypeRawDisplay.get_raw_definition( + attribute_type_dto, + ) test_ldap3_parse_object_classes_dataset = [ @@ -60,7 +65,6 @@ async def test_ldap3_parse_attribute_types(test_dataset: list[str]) -> None: ) @pytest.mark.asyncio async def test_ldap3_parse_object_classes( - session: AsyncSession, test_dataset: list[str], ) -> None: """Test parse ldap3 object classes.""" @@ -68,9 +72,10 @@ async def test_ldap3_parse_object_classes( object_class_info = RDParser.get_object_class_info( raw_definition=raw_definition, ) - object_class: ObjectClass = await RDParser.create_object_class_by_info( - session=session, + object_class_dto = await RDParser.collect_object_class_dto_from_info( object_class_info=object_class_info, ) - assert raw_definition == object_class.get_raw_definition() + assert raw_definition == ObjectClassRawDisplay.get_raw_definition( + object_class_dto, + ) diff --git a/tests/test_ldap/test_roles/test_multiple_access.py b/tests/test_ldap/test_roles/test_multiple_access.py index 4691ba0fb..73a902bb0 100644 --- a/tests/test_ldap/test_roles/test_multiple_access.py +++ b/tests/test_ldap/test_roles/test_multiple_access.py @@ -14,8 +14,10 @@ from config import Settings from entities import Directory from enums import AceType, EntityTypeNames, RoleScope -from ldap_protocol.ldap_schema.attribute_type_dao import AttributeTypeDAO -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.attribute_type.attribute_type_dao import ( + AttributeTypeDAO, +) +from ldap_protocol.ldap_schema.entity_type.entity_type_dao import EntityTypeDAO from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.dataclasses import AccessControlEntryDTO, RoleDTO from ldap_protocol.utils.queries import get_filter_from_path diff --git a/tests/test_ldap/test_roles/test_search.py b/tests/test_ldap/test_roles/test_search.py index 0795be89b..4e632a5f8 100644 --- a/tests/test_ldap/test_roles/test_search.py +++ b/tests/test_ldap/test_roles/test_search.py @@ -8,8 +8,10 @@ from config import Settings from enums import AceType, EntityTypeNames, RoleScope -from ldap_protocol.ldap_schema.attribute_type_dao import AttributeTypeDAO -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.attribute_type.attribute_type_dao import ( + AttributeTypeDAO, +) +from ldap_protocol.ldap_schema.entity_type.entity_type_dao import EntityTypeDAO from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.dataclasses import AccessControlEntryDTO, RoleDTO from tests.conftest import TestCreds @@ -102,6 +104,7 @@ async def test_role_search_3( creds=creds, search_base=BASE_DN, expected_dn=[ + "dn: cn=Configuration,dc=md,dc=test", "dn: cn=Groups,dc=md,dc=test", "dn: cn=Users,dc=md,dc=test", "dn: cn=user_non_admin,cn=Users,dc=md,dc=test", diff --git a/tests/test_ldap/test_util/test_modify.py b/tests/test_ldap/test_util/test_modify.py index b5eadf172..4a51c8cf2 100644 --- a/tests/test_ldap/test_util/test_modify.py +++ b/tests/test_ldap/test_util/test_modify.py @@ -21,7 +21,7 @@ from enums import AceType, EntityTypeNames, RoleScope from ldap_protocol.kerberos.base import AbstractKadmin from ldap_protocol.ldap_codes import LDAPCodes -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_dao import EntityTypeDAO from ldap_protocol.objects import Operation from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.dataclasses import AccessControlEntryDTO, RoleDTO diff --git a/tests/test_shedule.py b/tests/test_shedule.py index fa293902a..a952b94e1 100644 --- a/tests/test_shedule.py +++ b/tests/test_shedule.py @@ -14,7 +14,9 @@ from extra.scripts.uac_sync import disable_accounts from extra.scripts.update_krb5_config import update_krb5_config from ldap_protocol.kerberos import AbstractKadmin -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) from ldap_protocol.roles.role_use_case import RoleUseCase @@ -85,12 +87,12 @@ async def test_add_domain_controller( session: AsyncSession, settings: Settings, role_use_case: RoleUseCase, - entity_type_dao: EntityTypeDAO, + entity_type_use_case: EntityTypeUseCase, ) -> None: """Test add domain controller.""" await add_domain_controller( settings=settings, session=session, role_use_case=role_use_case, - entity_type_dao=entity_type_dao, + entity_type_use_case=entity_type_use_case, )