Skip to content

Commit b9c764e

Browse files
committed
fix: omit OAuth resource on token refresh
1 parent 161834d commit b9c764e

2 files changed

Lines changed: 41 additions & 9 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@
5353
logger = logging.getLogger(__name__)
5454

5555

56+
def _normalize_resource_url(resource: str) -> str:
57+
parsed = urlparse(resource)
58+
if parsed.path == "/" and not parsed.params and not parsed.query and not parsed.fragment:
59+
return f"{parsed.scheme}://{parsed.netloc}"
60+
return resource
61+
62+
5663
class PKCEParameters(BaseModel):
5764
"""PKCE (Proof Key for Code Exchange) parameters."""
5865

@@ -151,7 +158,7 @@ def get_resource_url(self) -> str:
151158

152159
# If PRM provides a resource that's a valid parent, use it
153160
if self.protected_resource_metadata and self.protected_resource_metadata.resource:
154-
prm_resource = str(self.protected_resource_metadata.resource)
161+
prm_resource = _normalize_resource_url(str(self.protected_resource_metadata.resource))
155162
if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
156163
resource = prm_resource
157164

@@ -442,10 +449,6 @@ async def _refresh_token(self) -> httpx.Request:
442449
"client_id": self.context.client_info.client_id,
443450
}
444451

445-
# Only include resource param if conditions are met
446-
if self.context.should_include_resource_param(self.context.protocol_version):
447-
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
448-
449452
# Prepare authentication based on preferred method
450453
headers = {"Content-Type": "application/x-www-form-urlencoded"}
451454
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)

tests/client/test_auth.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ class TestProtectedResourceMetadata:
745745

746746
@pytest.mark.anyio
747747
async def test_resource_param_included_with_recent_protocol_version(self, oauth_provider: OAuthClientProvider):
748-
"""Test resource parameter is included for protocol version >= 2025-06-18."""
748+
"""Test resource parameter is included for initial token requests on newer protocol versions."""
749749
# Set protocol version to 2025-06-18
750750
oauth_provider.context.protocol_version = "2025-06-18"
751751
oauth_provider.context.client_info = OAuthClientInformationFull(
@@ -762,15 +762,16 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_
762762
expected_resource = quote(oauth_provider.context.get_resource_url(), safe="")
763763
assert f"resource={expected_resource}" in content
764764

765-
# Test in refresh token
765+
# Refresh tokens should not resend the resource parameter. Some providers
766+
# reject RFC 8707 resource values on refresh_token grants.
766767
oauth_provider.context.current_tokens = OAuthToken(
767768
access_token="test_access",
768769
token_type="Bearer",
769770
refresh_token="test_refresh",
770771
)
771772
refresh_request = await oauth_provider._refresh_token()
772773
refresh_content = refresh_request.content.decode()
773-
assert "resource=" in refresh_content
774+
assert "resource=" not in refresh_content
774775

775776
@pytest.mark.anyio
776777
async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider):
@@ -800,7 +801,7 @@ async def test_resource_param_excluded_with_old_protocol_version(self, oauth_pro
800801

801802
@pytest.mark.anyio
802803
async def test_resource_param_included_with_protected_resource_metadata(self, oauth_provider: OAuthClientProvider):
803-
"""Test resource parameter is always included when protected resource metadata exists."""
804+
"""Test resource parameter is included in initial token requests when PRM exists."""
804805
# Set old protocol version but with protected resource metadata
805806
oauth_provider.context.protocol_version = "2025-03-26"
806807
oauth_provider.context.protected_resource_metadata = ProtectedResourceMetadata(
@@ -818,6 +819,15 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa
818819
content = request.content.decode()
819820
assert "resource=" in content
820821

822+
oauth_provider.context.current_tokens = OAuthToken(
823+
access_token="test_access",
824+
token_type="Bearer",
825+
refresh_token="test_refresh",
826+
)
827+
refresh_request = await oauth_provider._refresh_token()
828+
refresh_content = refresh_request.content.decode()
829+
assert "resource=" not in refresh_content
830+
821831

822832
@pytest.mark.anyio
823833
async def test_validate_resource_rejects_mismatched_resource(
@@ -949,6 +959,25 @@ async def test_get_resource_url_uses_canonical_when_prm_mismatches(
949959
assert provider.context.get_resource_url() == snapshot("https://api.example.com/v1/mcp")
950960

951961

962+
@pytest.mark.anyio
963+
async def test_get_resource_url_removes_root_prm_trailing_slash(
964+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
965+
) -> None:
966+
"""Bare-domain PRM resources should not pick up Pydantic's root slash."""
967+
provider = OAuthClientProvider(
968+
server_url="https://api.example.com",
969+
client_metadata=client_metadata,
970+
storage=mock_storage,
971+
)
972+
provider._initialized = True
973+
provider.context.protected_resource_metadata = ProtectedResourceMetadata(
974+
resource=AnyHttpUrl("https://api.example.com"),
975+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
976+
)
977+
978+
assert provider.context.get_resource_url() == snapshot("https://api.example.com")
979+
980+
952981
class TestRegistrationResponse:
953982
"""Test client registration response handling."""
954983

0 commit comments

Comments
 (0)