1use std::{collections::HashMap, sync::Arc};
8
9use mas_data_model::{
10 UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode,
11};
12use mas_iana::oauth::PkceCodeChallengeMethod;
13use mas_oidc_client::error::DiscoveryError;
14use mas_storage::{RepositoryAccess, upstream_oauth2::UpstreamOAuthProviderRepository};
15use oauth2_types::oidc::VerifiedProviderMetadata;
16use tokio::sync::RwLock;
17use url::Url;
18
19pub struct LazyProviderInfos<'a> {
22 cache: &'a MetadataCache,
23 provider: &'a UpstreamOAuthProvider,
24 client: &'a reqwest::Client,
25 loaded_metadata: Option<Arc<VerifiedProviderMetadata>>,
26}
27
28impl<'a> LazyProviderInfos<'a> {
29 pub fn new(
30 cache: &'a MetadataCache,
31 provider: &'a UpstreamOAuthProvider,
32 client: &'a reqwest::Client,
33 ) -> Self {
34 Self {
35 cache,
36 provider,
37 client,
38 loaded_metadata: None,
39 }
40 }
41
42 pub async fn maybe_discover(
45 &mut self,
46 ) -> Result<Option<&VerifiedProviderMetadata>, DiscoveryError> {
47 match self.load().await {
48 Ok(metadata) => Ok(Some(metadata)),
49 Err(DiscoveryError::Disabled) => Ok(None),
50 Err(e) => Err(e),
51 }
52 }
53
54 async fn load(&mut self) -> Result<&VerifiedProviderMetadata, DiscoveryError> {
55 if self.loaded_metadata.is_none() {
56 let verify = match self.provider.discovery_mode {
57 UpstreamOAuthProviderDiscoveryMode::Oidc => true,
58 UpstreamOAuthProviderDiscoveryMode::Insecure => false,
59 UpstreamOAuthProviderDiscoveryMode::Disabled => {
60 return Err(DiscoveryError::Disabled);
61 }
62 };
63
64 let Some(issuer) = &self.provider.issuer else {
65 return Err(DiscoveryError::MissingIssuer);
66 };
67
68 let metadata = self.cache.get(self.client, issuer, verify).await?;
69
70 self.loaded_metadata = Some(metadata);
71 }
72
73 Ok(self.loaded_metadata.as_ref().unwrap())
74 }
75
76 pub async fn jwks_uri(&mut self) -> Result<&Url, DiscoveryError> {
81 if let Some(jwks_uri) = &self.provider.jwks_uri_override {
82 return Ok(jwks_uri);
83 }
84
85 Ok(self.load().await?.jwks_uri())
86 }
87
88 pub async fn authorization_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
93 if let Some(authorization_endpoint) = &self.provider.authorization_endpoint_override {
94 return Ok(authorization_endpoint);
95 }
96
97 Ok(self.load().await?.authorization_endpoint())
98 }
99
100 pub async fn token_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
105 if let Some(token_endpoint) = &self.provider.token_endpoint_override {
106 return Ok(token_endpoint);
107 }
108
109 Ok(self.load().await?.token_endpoint())
110 }
111
112 pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
117 if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
118 return Ok(userinfo_endpoint);
119 }
120
121 Ok(self.load().await?.userinfo_endpoint())
122 }
123
124 pub async fn pkce_methods(
129 &mut self,
130 ) -> Result<Option<Vec<PkceCodeChallengeMethod>>, DiscoveryError> {
131 let methods = match self.provider.pkce_mode {
132 UpstreamOAuthProviderPkceMode::Auto => self
133 .maybe_discover()
134 .await?
135 .and_then(|metadata| metadata.code_challenge_methods_supported.clone()),
136 UpstreamOAuthProviderPkceMode::S256 => Some(vec![PkceCodeChallengeMethod::S256]),
137 UpstreamOAuthProviderPkceMode::Disabled => None,
138 };
139
140 Ok(methods)
141 }
142}
143
144#[allow(clippy::module_name_repetitions)]
150#[derive(Debug, Clone, Default)]
151pub struct MetadataCache {
152 cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
153 insecure_cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
154}
155
156impl MetadataCache {
157 #[must_use]
158 pub fn new() -> Self {
159 Self::default()
160 }
161
162 #[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all, err)]
168 pub async fn warm_up_and_run<R: RepositoryAccess>(
169 &self,
170 client: &reqwest::Client,
171 interval: std::time::Duration,
172 repository: &mut R,
173 ) -> Result<tokio::task::JoinHandle<()>, R::Error> {
174 let providers = repository.upstream_oauth_provider().all_enabled().await?;
175
176 for provider in providers {
177 let verify = match provider.discovery_mode {
178 UpstreamOAuthProviderDiscoveryMode::Oidc => true,
179 UpstreamOAuthProviderDiscoveryMode::Insecure => false,
180 UpstreamOAuthProviderDiscoveryMode::Disabled => continue,
181 };
182
183 let Some(issuer) = &provider.issuer else {
184 tracing::error!(%provider.id, "Provider doesn't have an issuer set, but discovery is enabled!");
185 continue;
186 };
187
188 if let Err(e) = self.fetch(client, issuer, verify).await {
189 tracing::error!(%issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
190 }
191 }
192
193 let cache = self.clone();
195 let client = client.clone();
196 Ok(tokio::spawn(async move {
197 loop {
198 tokio::time::sleep(interval).await;
200 cache.refresh_all(&client).await;
201 }
202 }))
203 }
204
205 #[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all, err)]
206 async fn fetch(
207 &self,
208 client: &reqwest::Client,
209 issuer: &str,
210 verify: bool,
211 ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
212 if verify {
213 let metadata = mas_oidc_client::requests::discovery::discover(client, issuer).await?;
214 let metadata = Arc::new(metadata);
215
216 self.cache
217 .write()
218 .await
219 .insert(issuer.to_owned(), metadata.clone());
220
221 Ok(metadata)
222 } else {
223 let metadata =
224 mas_oidc_client::requests::discovery::insecure_discover(client, issuer).await?;
225 let metadata = Arc::new(metadata);
226
227 self.insecure_cache
228 .write()
229 .await
230 .insert(issuer.to_owned(), metadata.clone());
231
232 Ok(metadata)
233 }
234 }
235
236 #[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all, err)]
238 pub async fn get(
239 &self,
240 client: &reqwest::Client,
241 issuer: &str,
242 verify: bool,
243 ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
244 let cache = if verify {
245 self.cache.read().await
246 } else {
247 self.insecure_cache.read().await
248 };
249
250 if let Some(metadata) = cache.get(issuer) {
251 return Ok(Arc::clone(metadata));
252 }
253 drop(cache);
255
256 let metadata = self.fetch(client, issuer, verify).await?;
257 Ok(metadata)
258 }
259
260 #[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)]
261 async fn refresh_all(&self, client: &reqwest::Client) {
262 let keys: Vec<String> = {
264 let cache = self.cache.read().await;
265 cache.keys().cloned().collect()
266 };
267
268 for issuer in keys {
269 if let Err(e) = self.fetch(client, &issuer, true).await {
270 tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
271 }
272 }
273
274 let keys: Vec<String> = {
276 let cache = self.insecure_cache.read().await;
277 cache.keys().cloned().collect()
278 };
279
280 for issuer in keys {
281 if let Err(e) = self.fetch(client, &issuer, false).await {
282 tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
283 }
284 }
285 }
286}
287
#[cfg(test)]
mod tests {
    // Both tests are long, linear scenarios; keeping each in one function
    // makes the request tally easy to follow.
    #![allow(clippy::too_many_lines)]

    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::{Clock, clock::MockClock};
    use oauth2_types::scope::{OPENID, Scope};
    use ulid::Ulid;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;
    use crate::test_utils::setup;

    /// Exercises `MetadataCache` directly: a miss issues a request, a hit does
    /// not, verified and insecure entries are cached separately, and
    /// `refresh_all` re-fetches whatever was stored.
    #[tokio::test]
    async fn test_metadata_cache() {
        setup();
        let mock_server = MockServer::start().await;
        let http_client = mas_http::reqwest_client();

        let cache = MetadataCache::new();

        // No mock is mounted yet, so this discovery fails and nothing is
        // stored in the cache.
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap_err();

        // `.expect(expected_calls)` makes wiremock check how many requests
        // actually reach the discovery endpoint. `calls` is a hand-kept tally
        // of the requests each step below should make, compared against the
        // same number at the end.
        let expected_calls = 3;
        let mut calls = 0;
        let _mock_guard = Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "issuer": mock_server.uri(),
                "authorization_endpoint": "https://example.com/authorize",
                "token_endpoint": "https://example.com/token",
                "jwks_uri": "https://example.com/jwks",
                "userinfo_endpoint": "https://example.com/userinfo",
                "scopes_supported": ["openid"],
                "response_types_supported": ["code"],
                "response_modes_supported": ["query", "fragment"],
                "grant_types_supported": ["authorization_code"],
                "subject_types_supported": ["public"],
                "id_token_signing_alg_values_supported": ["RS256"],
            })))
            .expect(expected_calls)
            .mount(&mock_server)
            .await;

        // Insecure miss: fetched from the mock and stored.
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap();
        calls += 1;

        // Insecure hit: served from the cache, no request.
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap();
        calls += 0;

        // The verified map is separate, so this is a miss that does hit the
        // mock. It still fails, presumably because strict verification rejects
        // the plain-HTTP mock issuer (TODO confirm), so nothing is stored.
        cache
            .get(&http_client, &mock_server.uri(), true)
            .await
            .unwrap_err();
        calls += 1;

        // Only the insecure entry was stored, so the refresh makes one request.
        cache.refresh_all(&http_client).await;
        calls += 1;

        assert_eq!(calls, expected_calls);
    }

    /// Exercises `LazyProviderInfos`: lazy discovery, endpoint overrides, a
    /// failing strict discovery, and disabled discovery.
    #[tokio::test]
    async fn test_lazy_provider_infos() {
        setup();

        let mock_server = MockServer::start().await;
        let http_client = mas_http::reqwest_client();

        // As in `test_metadata_cache`, `calls` tallies the requests each
        // scenario should make and must match the mock's `.expect(..)`.
        let expected_calls = 2;
        let mut calls = 0;
        let _mock_guard = Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "issuer": mock_server.uri(),
                "authorization_endpoint": "https://example.com/authorize",
                "token_endpoint": "https://example.com/token",
                "jwks_uri": "https://example.com/jwks",
                "userinfo_endpoint": "https://example.com/userinfo",
                "scopes_supported": ["openid"],
                "response_types_supported": ["code"],
                "response_modes_supported": ["query", "fragment"],
                "grant_types_supported": ["authorization_code"],
                "subject_types_supported": ["public"],
                "id_token_signing_alg_values_supported": ["RS256"],
            })))
            .expect(expected_calls)
            .mount(&mock_server)
            .await;

        // Baseline provider: insecure discovery against the mock, automatic
        // PKCE, and no endpoint overrides.
        let clock = MockClock::default();
        let provider = UpstreamOAuthProvider {
            id: Ulid::nil(),
            issuer: Some(mock_server.uri()),
            human_name: Some("Example Ltd.".to_owned()),
            brand_name: None,
            discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
            pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            jwks_uri_override: None,
            authorization_endpoint_override: None,
            scope: Scope::from_iter([OPENID]),
            userinfo_endpoint_override: None,
            token_endpoint_override: None,
            client_id: "client_id".to_owned(),
            encrypted_client_secret: None,
            token_endpoint_signing_alg: None,
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            response_mode: None,
            created_at: clock.now(),
            disabled_at: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            additional_authorization_parameters: Vec::new(),
        };

        // No overrides: `maybe_discover` fetches once, and
        // `authorization_endpoint` reuses the loaded metadata.
        {
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            lazy_metadata.maybe_discover().await.unwrap();
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize"
            );
            calls += 1;
        }

        // Every endpoint queried here is overridden, so no discovery happens.
        {
            let provider = UpstreamOAuthProvider {
                jwks_uri_override: Some("https://example.com/jwks_override".parse().unwrap()),
                authorization_endpoint_override: Some(
                    "https://example.com/authorize_override".parse().unwrap(),
                ),
                token_endpoint_override: Some(
                    "https://example.com/token_override".parse().unwrap(),
                ),
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            assert_eq!(
                lazy_metadata.jwks_uri().await.unwrap().as_str(),
                "https://example.com/jwks_override"
            );
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize_override"
            );
            assert_eq!(
                lazy_metadata.token_endpoint().await.unwrap().as_str(),
                "https://example.com/token_override"
            );
            calls += 0;
        }

        // Strict OIDC discovery issues one request but fails, presumably
        // because the plain-HTTP mock issuer does not pass verification
        // (TODO confirm).
        {
            let provider = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            lazy_metadata.authorization_endpoint().await.unwrap_err();
            calls += 1;
        }

        // Discovery disabled: `maybe_discover` yields `None`, and the
        // overridden authorization endpoint is still served. The token endpoint
        // has no override, so it reports `DiscoveryError::Disabled`. No request
        // is made.
        {
            let provider = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Disabled,
                authorization_endpoint_override: Some(
                    Url::parse("https://example.com/authorize_override").unwrap(),
                ),
                token_endpoint_override: None,
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            assert!(lazy_metadata.maybe_discover().await.unwrap().is_none());
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize_override"
            );
            assert!(matches!(
                lazy_metadata.token_endpoint().await,
                Err(DiscoveryError::Disabled),
            ));
            calls += 0;
        }

        assert_eq!(calls, expected_calls);
    }
}