mas_handlers/upstream_oauth2/
cache.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::{collections::HashMap, sync::Arc};
8
9use mas_data_model::{
10    UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode,
11};
12use mas_iana::oauth::PkceCodeChallengeMethod;
13use mas_oidc_client::error::DiscoveryError;
14use mas_storage::{RepositoryAccess, upstream_oauth2::UpstreamOAuthProviderRepository};
15use oauth2_types::oidc::VerifiedProviderMetadata;
16use tokio::sync::RwLock;
17use url::Url;
18
19/// A high-level layer over metadata cache and provider configuration, which
20/// resolves endpoint overrides and discovery modes.
21pub struct LazyProviderInfos<'a> {
22    cache: &'a MetadataCache,
23    provider: &'a UpstreamOAuthProvider,
24    client: &'a reqwest::Client,
25    loaded_metadata: Option<Arc<VerifiedProviderMetadata>>,
26}
27
28impl<'a> LazyProviderInfos<'a> {
29    pub fn new(
30        cache: &'a MetadataCache,
31        provider: &'a UpstreamOAuthProvider,
32        client: &'a reqwest::Client,
33    ) -> Self {
34        Self {
35            cache,
36            provider,
37            client,
38            loaded_metadata: None,
39        }
40    }
41
42    /// Trigger the discovery process and return the metadata if discovery is
43    /// enabled.
44    pub async fn maybe_discover(
45        &mut self,
46    ) -> Result<Option<&VerifiedProviderMetadata>, DiscoveryError> {
47        match self.load().await {
48            Ok(metadata) => Ok(Some(metadata)),
49            Err(DiscoveryError::Disabled) => Ok(None),
50            Err(e) => Err(e),
51        }
52    }
53
54    async fn load(&mut self) -> Result<&VerifiedProviderMetadata, DiscoveryError> {
55        if self.loaded_metadata.is_none() {
56            let verify = match self.provider.discovery_mode {
57                UpstreamOAuthProviderDiscoveryMode::Oidc => true,
58                UpstreamOAuthProviderDiscoveryMode::Insecure => false,
59                UpstreamOAuthProviderDiscoveryMode::Disabled => {
60                    return Err(DiscoveryError::Disabled);
61                }
62            };
63
64            let Some(issuer) = &self.provider.issuer else {
65                return Err(DiscoveryError::MissingIssuer);
66            };
67
68            let metadata = self.cache.get(self.client, issuer, verify).await?;
69
70            self.loaded_metadata = Some(metadata);
71        }
72
73        Ok(self.loaded_metadata.as_ref().unwrap())
74    }
75
76    /// Get the JWKS URI for the provider.
77    ///
78    /// Uses [`UpstreamOAuthProvider.jwks_uri_override`] if set, otherwise uses
79    /// the one from discovery.
80    pub async fn jwks_uri(&mut self) -> Result<&Url, DiscoveryError> {
81        if let Some(jwks_uri) = &self.provider.jwks_uri_override {
82            return Ok(jwks_uri);
83        }
84
85        Ok(self.load().await?.jwks_uri())
86    }
87
88    /// Get the authorization endpoint for the provider.
89    ///
90    /// Uses [`UpstreamOAuthProvider.authorization_endpoint_override`] if set,
91    /// otherwise uses the one from discovery.
92    pub async fn authorization_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
93        if let Some(authorization_endpoint) = &self.provider.authorization_endpoint_override {
94            return Ok(authorization_endpoint);
95        }
96
97        Ok(self.load().await?.authorization_endpoint())
98    }
99
100    /// Get the token endpoint for the provider.
101    ///
102    /// Uses [`UpstreamOAuthProvider.token_endpoint_override`] if set, otherwise
103    /// uses the one from discovery.
104    pub async fn token_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
105        if let Some(token_endpoint) = &self.provider.token_endpoint_override {
106            return Ok(token_endpoint);
107        }
108
109        Ok(self.load().await?.token_endpoint())
110    }
111
112    /// Get the userinfo endpoint for the provider.
113    ///
114    /// Uses [`UpstreamOAuthProvider.userinfo_endpoint_override`] if set,
115    /// otherwise uses the one from discovery.
116    pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
117        if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
118            return Ok(userinfo_endpoint);
119        }
120
121        Ok(self.load().await?.userinfo_endpoint())
122    }
123
124    /// Get the PKCE methods supported by the provider.
125    ///
126    /// If the mode is set to auto, it will use the ones from discovery,
127    /// defaulting to none if discovery is disabled.
128    pub async fn pkce_methods(
129        &mut self,
130    ) -> Result<Option<Vec<PkceCodeChallengeMethod>>, DiscoveryError> {
131        let methods = match self.provider.pkce_mode {
132            UpstreamOAuthProviderPkceMode::Auto => self
133                .maybe_discover()
134                .await?
135                .and_then(|metadata| metadata.code_challenge_methods_supported.clone()),
136            UpstreamOAuthProviderPkceMode::S256 => Some(vec![PkceCodeChallengeMethod::S256]),
137            UpstreamOAuthProviderPkceMode::Disabled => None,
138        };
139
140        Ok(methods)
141    }
142}
143
144/// A simple OIDC metadata cache
145///
146/// It never evicts entries, does not cache failures and has no locking.
147/// It can also be refreshed in the background, and warmed up on startup.
148/// It is good enough for our use case.
149#[allow(clippy::module_name_repetitions)]
150#[derive(Debug, Clone, Default)]
151pub struct MetadataCache {
152    cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
153    insecure_cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
154}
155
156impl MetadataCache {
157    #[must_use]
158    pub fn new() -> Self {
159        Self::default()
160    }
161
162    /// Warm up the cache by fetching all the known providers from the database
163    /// and inserting them into the cache.
164    ///
165    /// This spawns a background task that will refresh the cache at the given
166    /// interval.
167    #[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all, err)]
168    pub async fn warm_up_and_run<R: RepositoryAccess>(
169        &self,
170        client: &reqwest::Client,
171        interval: std::time::Duration,
172        repository: &mut R,
173    ) -> Result<tokio::task::JoinHandle<()>, R::Error> {
174        let providers = repository.upstream_oauth_provider().all_enabled().await?;
175
176        for provider in providers {
177            let verify = match provider.discovery_mode {
178                UpstreamOAuthProviderDiscoveryMode::Oidc => true,
179                UpstreamOAuthProviderDiscoveryMode::Insecure => false,
180                UpstreamOAuthProviderDiscoveryMode::Disabled => continue,
181            };
182
183            let Some(issuer) = &provider.issuer else {
184                tracing::error!(%provider.id, "Provider doesn't have an issuer set, but discovery is enabled!");
185                continue;
186            };
187
188            if let Err(e) = self.fetch(client, issuer, verify).await {
189                tracing::error!(%issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
190            }
191        }
192
193        // Spawn a background task to refresh the cache regularly
194        let cache = self.clone();
195        let client = client.clone();
196        Ok(tokio::spawn(async move {
197            loop {
198                // Re-fetch the known metadata at the given interval
199                tokio::time::sleep(interval).await;
200                cache.refresh_all(&client).await;
201            }
202        }))
203    }
204
205    #[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all, err)]
206    async fn fetch(
207        &self,
208        client: &reqwest::Client,
209        issuer: &str,
210        verify: bool,
211    ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
212        if verify {
213            let metadata = mas_oidc_client::requests::discovery::discover(client, issuer).await?;
214            let metadata = Arc::new(metadata);
215
216            self.cache
217                .write()
218                .await
219                .insert(issuer.to_owned(), metadata.clone());
220
221            Ok(metadata)
222        } else {
223            let metadata =
224                mas_oidc_client::requests::discovery::insecure_discover(client, issuer).await?;
225            let metadata = Arc::new(metadata);
226
227            self.insecure_cache
228                .write()
229                .await
230                .insert(issuer.to_owned(), metadata.clone());
231
232            Ok(metadata)
233        }
234    }
235
236    /// Get the metadata for the given issuer.
237    #[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all, err)]
238    pub async fn get(
239        &self,
240        client: &reqwest::Client,
241        issuer: &str,
242        verify: bool,
243    ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
244        let cache = if verify {
245            self.cache.read().await
246        } else {
247            self.insecure_cache.read().await
248        };
249
250        if let Some(metadata) = cache.get(issuer) {
251            return Ok(Arc::clone(metadata));
252        }
253        // Drop the cache guard so that we don't deadlock when we try to fetch
254        drop(cache);
255
256        let metadata = self.fetch(client, issuer, verify).await?;
257        Ok(metadata)
258    }
259
260    #[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)]
261    async fn refresh_all(&self, client: &reqwest::Client) {
262        // Grab all the keys first to avoid locking the cache for too long
263        let keys: Vec<String> = {
264            let cache = self.cache.read().await;
265            cache.keys().cloned().collect()
266        };
267
268        for issuer in keys {
269            if let Err(e) = self.fetch(client, &issuer, true).await {
270                tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
271            }
272        }
273
274        // Do the same for the insecure cache
275        let keys: Vec<String> = {
276            let cache = self.insecure_cache.read().await;
277            cache.keys().cloned().collect()
278        };
279
280        for issuer in keys {
281            if let Err(e) = self.fetch(client, &issuer, false).await {
282                tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
283            }
284        }
285    }
286}
287
#[cfg(test)]
mod tests {
    #![allow(clippy::too_many_lines)]

    // XXX: sadly, we can't test HTTPS requests with wiremock, so we can only test
    // 'insecure' discovery

    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::{Clock, clock::MockClock};
    use oauth2_types::scope::{OPENID, Scope};
    use ulid::Ulid;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;
    use crate::test_utils::setup;

    /// The discovery document served by the mock server for `issuer`.
    fn discovery_document(issuer: &str) -> serde_json::Value {
        serde_json::json!({
            "issuer": issuer,
            "authorization_endpoint": "https://example.com/authorize",
            "token_endpoint": "https://example.com/token",
            "jwks_uri": "https://example.com/jwks",
            "userinfo_endpoint": "https://example.com/userinfo",
            "scopes_supported": ["openid"],
            "response_types_supported": ["code"],
            "response_modes_supported": ["query", "fragment"],
            "grant_types_supported": ["authorization_code"],
            "subject_types_supported": ["public"],
            "id_token_signing_alg_values_supported": ["RS256"],
        })
    }

    #[tokio::test]
    async fn test_metadata_cache() {
        setup();
        let mock_server = MockServer::start().await;
        let http_client = mas_http::reqwest_client();
        let issuer = mock_server.uri();

        let cache = MetadataCache::new();

        // A nonexistent issuer should fail: no discovery document is served yet
        cache.get(&http_client, &issuer, false).await.unwrap_err();

        let expected_calls = 3;
        let mut calls = 0;
        let _mock_guard = Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(discovery_document(&issuer)))
            .expect(expected_calls)
            .mount(&mock_server)
            .await;

        // A valid issuer should succeed
        cache.get(&http_client, &issuer, false).await.unwrap();
        calls += 1;

        // Calling again should be served from the cache, without a new fetch
        cache.get(&http_client, &issuer, false).await.unwrap();
        calls += 0;

        // A secure discovery should call but fail because the issuer is insecure
        cache.get(&http_client, &issuer, true).await.unwrap_err();
        calls += 1;

        // Calling refresh should refresh all the known issuers
        cache.refresh_all(&http_client).await;
        calls += 1;

        assert_eq!(calls, expected_calls);
    }

    #[tokio::test]
    async fn test_lazy_provider_infos() {
        setup();

        let mock_server = MockServer::start().await;
        let http_client = mas_http::reqwest_client();
        let issuer = mock_server.uri();

        let expected_calls = 2;
        let mut calls = 0;
        let _mock_guard = Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(discovery_document(&issuer)))
            .expect(expected_calls)
            .mount(&mock_server)
            .await;

        let clock = MockClock::default();
        // Baseline provider: insecure discovery, no override. Each scenario
        // below derives its own variant from it.
        let provider = UpstreamOAuthProvider {
            id: Ulid::nil(),
            issuer: Some(issuer),
            human_name: Some("Example Ltd.".to_owned()),
            brand_name: None,
            discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
            pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            jwks_uri_override: None,
            authorization_endpoint_override: None,
            scope: Scope::from_iter([OPENID]),
            userinfo_endpoint_override: None,
            token_endpoint_override: None,
            client_id: "client_id".to_owned(),
            encrypted_client_secret: None,
            token_endpoint_signing_alg: None,
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            response_mode: None,
            created_at: clock.now(),
            disabled_at: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            additional_authorization_parameters: Vec::new(),
        };

        // Without any override, it should just use discovery
        {
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            lazy_metadata.maybe_discover().await.unwrap();

            let authorization_endpoint = lazy_metadata.authorization_endpoint().await.unwrap();
            assert_eq!(
                authorization_endpoint.as_str(),
                "https://example.com/authorize"
            );
            calls += 1;
        }

        // Test overriding endpoints
        {
            let overridden = UpstreamOAuthProvider {
                jwks_uri_override: Some(Url::parse("https://example.com/jwks_override").unwrap()),
                authorization_endpoint_override: Some(
                    Url::parse("https://example.com/authorize_override").unwrap(),
                ),
                token_endpoint_override: Some(
                    Url::parse("https://example.com/token_override").unwrap(),
                ),
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &overridden, &http_client);

            let jwks_uri = lazy_metadata.jwks_uri().await.unwrap();
            assert_eq!(jwks_uri.as_str(), "https://example.com/jwks_override");

            let authorization_endpoint = lazy_metadata.authorization_endpoint().await.unwrap();
            assert_eq!(
                authorization_endpoint.as_str(),
                "https://example.com/authorize_override"
            );

            let token_endpoint = lazy_metadata.token_endpoint().await.unwrap();
            assert_eq!(token_endpoint.as_str(), "https://example.com/token_override");

            // This shouldn't trigger a new fetch as the endpoints are overridden
            calls += 0;
        }

        // Loading an insecure provider with secure discovery should fail
        {
            let secure = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &secure, &http_client);
            lazy_metadata.authorization_endpoint().await.unwrap_err();
            // This triggered a fetch, even though it failed
            calls += 1;
        }

        // Getting endpoints when discovery is disabled only works for overridden ones
        {
            let without_discovery = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Disabled,
                authorization_endpoint_override: Some(
                    Url::parse("https://example.com/authorize_override").unwrap(),
                ),
                token_endpoint_override: None,
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata =
                LazyProviderInfos::new(&cache, &without_discovery, &http_client);

            // This should not fail, but also does nothing
            assert!(lazy_metadata.maybe_discover().await.unwrap().is_none());

            let authorization_endpoint = lazy_metadata.authorization_endpoint().await.unwrap();
            assert_eq!(
                authorization_endpoint.as_str(),
                "https://example.com/authorize_override"
            );

            assert!(matches!(
                lazy_metadata.token_endpoint().await,
                Err(DiscoveryError::Disabled),
            ));
            // This did not trigger a fetch
            calls += 0;
        }

        assert_eq!(calls, expected_calls);
    }
}