diff --git a/api/draft_registrations/permissions.py b/api/draft_registrations/permissions.py index 5232ee9d546..fb6829aafa0 100644 --- a/api/draft_registrations/permissions.py +++ b/api/draft_registrations/permissions.py @@ -1,11 +1,13 @@ from rest_framework import permissions +from api.base.exceptions import Conflict from api.base.utils import get_user_auth, assert_resource_type from osf.models import ( DraftRegistration, AbstractNode, DraftRegistrationContributor, OSFUser, + RegistrationProvider, ) from api.nodes.permissions import ContributorDetailPermissions from osf.utils.permissions import WRITE, ADMIN @@ -90,3 +92,32 @@ def has_object_permission(self, request, view, obj): elif isinstance(obj, AbstractNode): return obj.has_permission(auth.user, WRITE) return False + + +class CanSubmitDraftRegistrationToProvider(permissions.BasePermission): + """ + Prevent creating draft registrations for providers that are closed to submissions. + """ + + def has_permission(self, request, view): + if request.method != 'POST': + return True + + provider_id = request.data.get('provider') + + if not provider_id: + try: + provider = RegistrationProvider.get_default() + except RegistrationProvider.DoesNotExist: + return True + else: + try: + provider = RegistrationProvider.objects.get(_id=provider_id) + except RegistrationProvider.DoesNotExist: + # Let existing validation handle bad provider ids. + return True + + if not provider.allow_submissions: + raise Conflict(f"Registry {provider.name} is closed for new submissions. Please start a new registration with a different registry.") + + return True diff --git a/api/draft_registrations/views.py b/api/draft_registrations/views.py index 30c583dd94a..88a266c9ab6 100644 --- a/api/draft_registrations/views.py +++ b/api/draft_registrations/views.py @@ -8,6 +8,7 @@ DraftContributorDetailPermissions, DraftRegistrationPermission, IsAdminContributor, + CanSubmitDraftRegistrationToProvider, ) from api.draft_registrations.serializers import ( DraftRegistrationSerializer, @@ -53,6 +54,7 @@ class DraftRegistrationList(NodeDraftRegistrationsList): drf_permissions.IsAuthenticatedOrReadOnly, base_permissions.TokenHasScope, DraftRegistrationPermission, + CanSubmitDraftRegistrationToProvider, ) view_category = 'draft_registrations' diff --git a/api/nodes/views.py b/api/nodes/views.py index 69c7fa7d25e..51648d365cd 100644 --- a/api/nodes/views.py +++ b/api/nodes/views.py @@ -74,7 +74,10 @@ NodeCommentSerializer, ) from api.draft_registrations.serializers import DraftRegistrationSerializer, DraftRegistrationDetailSerializer -from api.draft_registrations.permissions import DraftRegistrationPermission +from api.draft_registrations.permissions import ( + DraftRegistrationPermission, + CanSubmitDraftRegistrationToProvider, +) from api.files.serializers import FileSerializer, OsfStorageFileSerializer from api.files import annotations as file_annotations from api.identifiers.serializers import NodeIdentifierSerializer @@ -671,6 +674,7 @@ class NodeDraftRegistrationsList(JSONAPIBaseView, generics.ListCreateAPIView, No DraftRegistrationPermission, drf_permissions.IsAuthenticatedOrReadOnly, base_permissions.TokenHasScope, + CanSubmitDraftRegistrationToProvider, ) parser_classes = (JSONAPIMultipleRelationshipsParser, JSONAPIMultipleRelationshipsParserForRegularJSON) diff --git a/api_tests/draft_registrations/views/test_draft_registration_list.py b/api_tests/draft_registrations/views/test_draft_registration_list.py index 4aed087605a..c2c0cce084c 100644 --- a/api_tests/draft_registrations/views/test_draft_registration_list.py +++ b/api_tests/draft_registrations/views/test_draft_registration_list.py @@ -399,6 +399,19 @@ def test_affiliated_institutions_are_copied_from_user(self, app, user, url_draft draft_registration = DraftRegistration.load(res.json['data']['id']) assert list(draft_registration.affiliated_institutions.all()) == list(user.get_affiliated_institutions()) + def test_cannot_create_draft_when_provider_disallows_submissions( + self, app, user, provider, payload, url_draft_registrations): + provider.allow_submissions = False + provider.save() + + res = app.post_json_api( + url_draft_registrations, + payload, + auth=user.auth, + expect_errors=True, + ) + assert res.status_code == 409 + class TestDraftRegistrationCreateWithoutNode(AbstractDraftRegistrationTestCase): @pytest.fixture() @@ -451,6 +464,19 @@ def test_create_draft_with_provider( draft = DraftRegistration.load(data['id']) assert draft.provider == non_default_provider + def test_cannot_create_draft_when_provider_disallows_submissions( + self, app, user, url_draft_registrations, non_default_provider, payload_with_non_default_provider): + non_default_provider.allow_submissions = False + non_default_provider.save() + + res = app.post_json_api( + url_draft_registrations, + payload_with_non_default_provider, + auth=user.auth, + expect_errors=True, + ) + assert res.status_code == 409 + def test_write_contrib(self, app, user, project_public, payload, url_draft_registrations, user_write_contrib): """(no node supplied, so any logged in user can create) """