Skip to content

Commit 06e0b23

Browse files
committed
fix: omit resource from OAuth refresh grants
1 parent 161834d commit 06e0b23

2 files changed

Lines changed: 34 additions & 8 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def get_resource_url(self) -> str:
151151

152152
# If PRM provides a resource that's a valid parent, use it
153153
if self.protected_resource_metadata and self.protected_resource_metadata.resource:
154-
prm_resource = str(self.protected_resource_metadata.resource)
154+
prm_resource = str(self.protected_resource_metadata.resource).rstrip("/")
155155
if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
156156
resource = prm_resource
157157

@@ -442,10 +442,6 @@ async def _refresh_token(self) -> httpx.Request:
442442
"client_id": self.context.client_info.client_id,
443443
}
444444

445-
# Only include resource param if conditions are met
446-
if self.context.should_include_resource_param(self.context.protocol_version):
447-
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
448-
449445
# Prepare authentication based on preferred method
450446
headers = {"Content-Type": "application/x-www-form-urlencoded"}
451447
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)

tests/client/test_auth.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -762,15 +762,15 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_
762762
expected_resource = quote(oauth_provider.context.get_resource_url(), safe="")
763763
assert f"resource={expected_resource}" in content
764764

765-
# Test in refresh token
765+
# Refresh grants don't carry resource; some OAuth providers reject it.
766766
oauth_provider.context.current_tokens = OAuthToken(
767767
access_token="test_access",
768768
token_type="Bearer",
769769
refresh_token="test_refresh",
770770
)
771771
refresh_request = await oauth_provider._refresh_token()
772772
refresh_content = refresh_request.content.decode()
773-
assert "resource=" in refresh_content
773+
assert "resource=" not in refresh_content
774774

775775
@pytest.mark.anyio
776776
async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider):
@@ -800,7 +800,7 @@ async def test_resource_param_excluded_with_old_protocol_version(self, oauth_pro
800800

801801
@pytest.mark.anyio
802802
async def test_resource_param_included_with_protected_resource_metadata(self, oauth_provider: OAuthClientProvider):
803-
"""Test resource parameter is always included when protected resource metadata exists."""
803+
"""Test resource parameter is included when protected resource metadata exists."""
804804
# Set old protocol version but with protected resource metadata
805805
oauth_provider.context.protocol_version = "2025-03-26"
806806
oauth_provider.context.protected_resource_metadata = ProtectedResourceMetadata(
@@ -818,6 +818,15 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa
818818
content = request.content.decode()
819819
assert "resource=" in content
820820

821+
oauth_provider.context.current_tokens = OAuthToken(
822+
access_token="test_access",
823+
token_type="Bearer",
824+
refresh_token="test_refresh",
825+
)
826+
refresh_request = await oauth_provider._refresh_token()
827+
refresh_content = refresh_request.content.decode()
828+
assert "resource=" not in refresh_content
829+
821830

822831
@pytest.mark.anyio
823832
async def test_validate_resource_rejects_mismatched_resource(
@@ -949,6 +958,27 @@ async def test_get_resource_url_uses_canonical_when_prm_mismatches(
949958
assert provider.context.get_resource_url() == snapshot("https://api.example.com/v1/mcp")
950959

951960

961+
@pytest.mark.anyio
962+
async def test_get_resource_url_omits_pydantic_root_slash(
963+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
964+
) -> None:
965+
"""Bare-domain PRM resources should not inherit Pydantic's trailing slash."""
966+
provider = OAuthClientProvider(
967+
server_url="https://api.example.com",
968+
client_metadata=client_metadata,
969+
storage=mock_storage,
970+
)
971+
provider._initialized = True
972+
973+
provider.context.protected_resource_metadata = ProtectedResourceMetadata(
974+
resource=AnyHttpUrl("https://api.example.com"),
975+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
976+
)
977+
978+
assert str(provider.context.protected_resource_metadata.resource) == "https://api.example.com/"
979+
assert provider.context.get_resource_url() == "https://api.example.com"
980+
981+
952982
class TestRegistrationResponse:
953983
"""Test client registration response handling."""
954984

0 commit comments

Comments
 (0)