Skip to content

Commit cb6a62f

Browse files
feat: auth improvements (#217)
* feat: improvement to oauth resources resolution * feat: improvements to refreshToken procedure and updates to unit tests * feat: improvements to documentation of clien_id
1 parent 6132fd4 commit cb6a62f

5 files changed

Lines changed: 144 additions & 31 deletions

File tree

Sources/MCP/Base/Authorization/OAuthAuthorizer.swift

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@ public final class OAuthAuthorizer: HTTPClientAuthorizer, @unchecked Sendable {
217217
return nil
218218
}
219219
if accessToken.isExpired() {
220-
tokenStorage.clear()
221220
return nil
222221
}
223222
return "\(OAuthTokenType.bearer) \(accessToken.value)"
@@ -367,7 +366,6 @@ public final class OAuthAuthorizer: HTTPClientAuthorizer, @unchecked Sendable {
367366
public func prepareAuthorization(for endpoint: URL, session: URLSession) async throws {
368367
guard configuration.proactiveRefreshWindowSeconds > 0 else { return }
369368
guard let token = tokenStorage.load() else { return }
370-
guard !token.isExpired() else { return }
371369
guard token.isExpired(skewSeconds: configuration.proactiveRefreshWindowSeconds) else {
372370
return
373371
}
@@ -441,8 +439,10 @@ public final class OAuthAuthorizer: HTTPClientAuthorizer, @unchecked Sendable {
441439
candidates.append(fallback)
442440
}
443441

442+
let fallbackIssuer = try? discoveryClient.metadataDiscovery
443+
.authorizationServerFallbackIssuer(from: endpoint)
444444
let metadata = try await discoveryClient.fetchProtectedResourceMetadata(
445-
candidates: candidates, session: session)
445+
candidates: candidates, fallbackIssuer: fallbackIssuer, session: session)
446446
try validateProtectedResource(metadata: metadata, endpoint: endpoint)
447447

448448
self.protectedResourceMetadata = metadata
@@ -716,10 +716,17 @@ public final class OAuthAuthorizer: HTTPClientAuthorizer, @unchecked Sendable {
716716
expiresAt: expiresAt,
717717
scopes: scopeSet,
718718
authorizationServer: selectedAuthorizationServer,
719-
refreshToken: decoded.refreshToken
719+
refreshToken: decoded.refreshToken,
720+
clientID: nonEmptyClientID()
720721
))
721722
}
722723

724+
/// Returns the configured `client_id` or `nil` if the authorizer has not yet been assigned one.
725+
private func nonEmptyClientID() -> String? {
726+
let id = configuration.authentication.clientID
727+
return id.isEmpty ? nil : id
728+
}
729+
723730
private func resolveTokenEndpoint(
724731
asMetadata: OAuthAuthorizationServerMetadata
725732
) throws -> URL {
@@ -823,7 +830,8 @@ public final class OAuthAuthorizer: HTTPClientAuthorizer, @unchecked Sendable {
823830
expiresAt: nil,
824831
scopes: requestedScopes ?? [],
825832
authorizationServer: authorizationServer,
826-
refreshToken: nil
833+
refreshToken: nil,
834+
clientID: nonEmptyClientID()
827835
))
828836
}
829837

Sources/MCP/Base/Authorization/OAuthDiscoveryClient.swift

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import Foundation
77
/// Internal protocol for fetching OAuth discovery metadata.
88
protocol OAuthDiscoveryFetching: Sendable {
99
var metadataDiscovery: any OAuthMetadataDiscovering { get }
10-
func fetchProtectedResourceMetadata(candidates: [URL], session: URLSession) async throws -> OAuthProtectedResourceMetadata
10+
func fetchProtectedResourceMetadata(candidates: [URL], fallbackIssuer: URL?, session: URLSession) async throws -> OAuthProtectedResourceMetadata
1111
func fetchAuthorizationServerMetadata(candidates: [URL], session: URLSession) async throws -> (server: URL, metadata: OAuthAuthorizationServerMetadata)
1212
}
1313

@@ -29,8 +29,13 @@ struct OAuthDiscoveryClient: Sendable {
2929
}
3030

3131
/// Fetches Protected Resource Metadata from the first candidate that returns a valid response.
32+
///
33+
/// If all candidates fail and `fallbackIssuer` is provided, returns synthetic metadata
34+
/// using that issuer as the authorization server — for servers that do not expose a
35+
/// PRM document at any well-known path.
3236
func fetchProtectedResourceMetadata(
3337
candidates: [URL],
38+
fallbackIssuer: URL?,
3439
session: URLSession
3540
) async throws -> OAuthProtectedResourceMetadata {
3641
let decoder = JSONDecoder()
@@ -56,15 +61,28 @@ struct OAuthDiscoveryClient: Sendable {
5661
continue
5762
}
5863
}
64+
if let fallbackIssuer {
65+
return OAuthProtectedResourceMetadata(
66+
resource: nil,
67+
authorizationServers: [fallbackIssuer],
68+
scopesSupported: nil
69+
)
70+
}
5971
throw OAuthAuthorizationError.metadataDiscoveryFailed
6072
}
6173

62-
/// Fetches Authorization Server Metadata from the first candidate that returns a valid response.
74+
/// Fetches Authorization Server Metadata from candidates, preferring a response whose
75+
/// `issuer` matches the candidate URL (RFC 8414 §3). If no candidate produces a matching
76+
/// issuer, the first valid response is accepted and its own `issuer` is used as the
77+
/// server identity — accommodating servers that serve metadata at one host but advertise
78+
/// a different issuer.
6379
func fetchAuthorizationServerMetadata(
6480
candidates: [URL],
6581
session: URLSession
6682
) async throws -> (server: URL, metadata: OAuthAuthorizationServerMetadata) {
6783
let decoder = JSONDecoder()
84+
var firstValid: (server: URL, metadata: OAuthAuthorizationServerMetadata)?
85+
6886
for candidateServer in candidates {
6987
guard (try? urlValidator.validateAuthorizationServer(
7088
candidateServer, context: "Authorization server issuer")) != nil
@@ -95,22 +113,28 @@ struct OAuthDiscoveryClient: Sendable {
95113
let asMetadata = try decoder.decode(
96114
OAuthAuthorizationServerMetadata.self, from: data)
97115

98-
// RFC 8414 §3: issuer field must match the candidate server URL.
99-
// Absent issuer is tolerated (some servers omit it).
100-
if let metadataIssuer = asMetadata.issuer {
101-
guard metadataIssuer.absoluteString.lowercased()
116+
// Prefer metadata whose issuer matches the candidate URL (RFC 8414 §3).
117+
let issuerMatches =
118+
asMetadata.issuer == nil
119+
|| asMetadata.issuer?.absoluteString.lowercased()
102120
== candidateServer.absoluteString.lowercased()
103-
else {
104-
continue
105-
}
121+
if issuerMatches {
122+
return (server: candidateServer, metadata: asMetadata)
123+
}
124+
// Keep as fallback in case no issuer-matching candidate is found.
125+
if firstValid == nil {
126+
let server = asMetadata.issuer ?? candidateServer
127+
firstValid = (server: server, metadata: asMetadata)
106128
}
107-
108-
return (server: candidateServer, metadata: asMetadata)
109129
} catch {
110130
continue
111131
}
112132
}
113133
}
134+
135+
if let firstValid {
136+
return firstValid
137+
}
114138
throw OAuthAuthorizationError.authorizationServerMetadataDiscoveryFailed
115139
}
116140
}

Sources/MCP/Base/Authorization/OAuthModels.swift

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ struct OAuthTokenErrorResponse: Decodable {
113113
///
114114
/// Stored by ``TokenStorage`` and produced by ``OAuthAuthorizer`` after a successful
115115
/// token request. Use ``isExpired(now:skewSeconds:)`` to check validity before use.
116-
public struct OAuthAccessToken: Sendable {
116+
public struct OAuthAccessToken: Sendable, Codable {
117117
/// The raw bearer token string for use in the `Authorization` header.
118118
public let value: String
119119

@@ -135,21 +135,31 @@ public struct OAuthAccessToken: Sendable {
135135
/// The refresh token, if the authorization server issued one alongside the access token.
136136
public let refreshToken: String?
137137

138+
/// The OAuth `client_id` this token was issued to, if the authorizer had one at storage time.
139+
///
140+
/// Captures the `client_id` held by the authorizer's ``OAuthConfiguration/authentication`` when
141+
/// the token was saved — including identifiers assigned by Dynamic Client Registration
142+
/// (RFC 7591). `nil` when no `client_id` was configured (for example, the placeholder used
143+
/// before registration completes).
144+
public let clientID: String?
145+
138146
/// Creates a new access token record.
139147
public init(
140148
value: String,
141149
tokenType: String,
142150
expiresAt: Date?,
143151
scopes: Set<String>,
144152
authorizationServer: URL?,
145-
refreshToken: String?
153+
refreshToken: String?,
154+
clientID: String? = nil
146155
) {
147156
self.value = value
148157
self.tokenType = tokenType
149158
self.expiresAt = expiresAt
150159
self.scopes = scopes
151160
self.authorizationServer = authorizationServer
152161
self.refreshToken = refreshToken
162+
self.clientID = clientID
153163
}
154164

155165
/// Returns `true` if the token has expired or will expire within the skew window.

Tests/MCPTests/OAuthAuthorizerTests.swift

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ final class MockDiscoveryClient: OAuthDiscoveryFetching, @unchecked Sendable {
6464
)
6565
}
6666

67-
func fetchProtectedResourceMetadata(candidates: [URL], session: URLSession) async throws -> OAuthProtectedResourceMetadata {
67+
func fetchProtectedResourceMetadata(candidates: [URL], fallbackIssuer: URL?, session: URLSession) async throws -> OAuthProtectedResourceMetadata {
6868
fetchProtectedResourceMetadataCallCount += 1
6969
return protectedResourceMetadataResult
7070
}
@@ -100,6 +100,10 @@ final class MockTokenClient: OAuthTokenRequesting, @unchecked Sendable {
100100

101101
final class MockClientRegistrar: OAuthClientRegistering, @unchecked Sendable {
102102
var registerCallCount = 0
103+
var registrationResult: (
104+
response: OAuthClientRegistrationResponse,
105+
updatedAuthentication: OAuthConfiguration.TokenEndpointAuthentication
106+
)?
103107

104108
func register(
105109
configuration: OAuthConfiguration,
@@ -110,7 +114,7 @@ final class MockClientRegistrar: OAuthClientRegistering, @unchecked Sendable {
110114
updatedAuthentication: OAuthConfiguration.TokenEndpointAuthentication
111115
)? {
112116
registerCallCount += 1
113-
return nil
117+
return registrationResult
114118
}
115119
}
116120

@@ -381,4 +385,45 @@ struct OAuthAuthorizerTests {
381385

382386
#expect(registrar.registerCallCount == 0)
383387
}
388+
389+
@Test("handleChallenge persists the DCR-assigned clientID on the saved token")
390+
func testHandleChallengePersistsDCRClientIDOnToken() async throws {
391+
let assignedClientID = "dcr-assigned-client-id"
392+
let tokenStorage = InMemoryTokenStorage()
393+
let registrar = MockClientRegistrar()
394+
registrar.registrationResult = (
395+
response: OAuthClientRegistrationResponse(
396+
clientID: assignedClientID,
397+
clientSecret: nil,
398+
tokenEndpointAuthMethod: nil,
399+
clientSecretExpiresAt: nil
400+
),
401+
updatedAuthentication: .none(clientID: assignedClientID)
402+
)
403+
404+
let config = OAuthConfiguration(
405+
authentication: .none(clientID: "")
406+
)
407+
let authorizer = OAuthAuthorizer(
408+
configuration: config,
409+
tokenStorage: tokenStorage,
410+
urlValidator: MockURLValidator(),
411+
discoveryClient: MockDiscoveryClient(),
412+
tokenEndpointClient: MockTokenClient(),
413+
clientRegistrar: registrar,
414+
authCodeFlow: MockAuthCodeFlow()
415+
)
416+
417+
let handled = try await authorizer.handleChallenge(
418+
statusCode: 401,
419+
headers: headers401,
420+
endpoint: endpoint,
421+
operationKey: nil,
422+
session: .shared
423+
)
424+
425+
#expect(handled == true)
426+
#expect(registrar.registerCallCount == 1)
427+
#expect(tokenStorage.load()?.clientID == assignedClientID)
428+
}
384429
}

Tests/MCPTests/OAuthDiscoveryClientTests.swift

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ import Testing
4949

5050
let metadata = try await makeClient().fetchProtectedResourceMetadata(
5151
candidates: [URL(string: "https://example.com/.well-known/oauth-protected-resource")!],
52+
fallbackIssuer: nil,
5253
session: session
5354
)
5455
let expected = OAuthProtectedResourceMetadata(
@@ -76,6 +77,7 @@ import Testing
7677
URL(string: "https://example.com/.well-known/oauth-protected-resource/mcp")!,
7778
URL(string: "https://example.com/.well-known/oauth-protected-resource")!,
7879
],
80+
fallbackIssuer: nil,
7981
session: session
8082
)
8183
let expected = OAuthProtectedResourceMetadata(
@@ -103,6 +105,7 @@ import Testing
103105
URL(string: "https://example.com/.well-known/oauth-protected-resource/mcp")!,
104106
URL(string: "https://example.com/.well-known/oauth-protected-resource")!,
105107
],
108+
fallbackIssuer: nil,
106109
session: session
107110
)
108111
let expected = OAuthProtectedResourceMetadata(
@@ -112,7 +115,7 @@ import Testing
112115
#expect(metadata == expected)
113116
}
114117

115-
@Test("Throws metadataDiscoveryFailed when all candidates fail")
118+
@Test("Throws metadataDiscoveryFailed when all candidates fail and no fallback issuer")
116119
func testFetchProtectedResourceMetadataThrowsWhenAllFail() async throws {
117120
let (session, key) = makeIsolatedSession()
118121
await IsolatedMockURLProtocol.setHandler(key: key) { request in
@@ -126,11 +129,34 @@ import Testing
126129
candidates: [
127130
URL(string: "https://example.com/.well-known/oauth-protected-resource")!
128131
],
132+
fallbackIssuer: nil,
129133
session: session
130134
)
131135
}
132136
}
133137

138+
@Test("Returns synthetic metadata with fallback issuer when all candidates fail")
139+
func testFetchProtectedResourceMetadataUsesFallbackIssuer() async throws {
140+
let fallback = URL(string: "https://example.com")!
141+
let (session, key) = makeIsolatedSession()
142+
await IsolatedMockURLProtocol.setHandler(key: key) { request in
143+
let response = HTTPURLResponse(
144+
url: request.url!, statusCode: 404, httpVersion: nil, headerFields: nil)!
145+
return (response, Data())
146+
}
147+
148+
let metadata = try await makeClient().fetchProtectedResourceMetadata(
149+
candidates: [URL(string: "https://example.com/.well-known/oauth-protected-resource")!],
150+
fallbackIssuer: fallback,
151+
session: session
152+
)
153+
let expected = OAuthProtectedResourceMetadata(
154+
resource: nil,
155+
authorizationServers: [fallback],
156+
scopesSupported: nil)
157+
#expect(metadata == expected)
158+
}
159+
134160
// MARK: - fetchAuthorizationServerMetadata
135161

136162
@Test("Returns server and metadata when issuer matches")
@@ -162,23 +188,23 @@ import Testing
162188
#expect(metadata == expectedMetadata)
163189
}
164190

165-
@Test("Skips candidate when issuer field does not match")
166-
func testFetchAuthorizationServerMetadataSkipsIssuerMismatch() async throws {
167-
let wrongIssuerBody = try makeASMetadataBody(issuer: "https://other.example.com")
191+
@Test("Uses metadata issuer as server identity when it differs from candidate URL")
192+
func testFetchAuthorizationServerMetadataUsesMetadataIssuer() async throws {
193+
let metadataIssuer = "https://other.example.com"
194+
let body = try makeASMetadataBody(issuer: metadataIssuer)
168195
let (session, key) = makeIsolatedSession()
169196
await IsolatedMockURLProtocol.setHandler(key: key) { _ in
170197
let response = HTTPURLResponse(
171198
url: URL(string: "https://auth.example.com")!,
172199
statusCode: 200, httpVersion: nil, headerFields: nil)!
173-
return (response, wrongIssuerBody)
200+
return (response, body)
174201
}
175202

176-
await #expect(throws: OAuthAuthorizationError.self) {
177-
try await makeClient().fetchAuthorizationServerMetadata(
178-
candidates: [URL(string: "https://auth.example.com")!],
179-
session: session
180-
)
181-
}
203+
let (server, _) = try await makeClient().fetchAuthorizationServerMetadata(
204+
candidates: [URL(string: "https://auth.example.com")!],
205+
session: session
206+
)
207+
#expect(server == URL(string: metadataIssuer)!)
182208
}
183209

184210
@Test("Skips private IP candidates without making HTTP calls")

0 commit comments

Comments
 (0)