mas_config/sections/
database.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::{num::NonZeroU32, time::Duration};
8
9use camino::Utf8PathBuf;
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use serde_with::serde_as;
13
14use super::ConfigurationSection;
15use crate::schema;
16
/// Default connection URI: a bare `postgresql://` with no host, user or
/// database component.
// The `Option` wrapper is deliberate: the `uri` field this function provides
// the default for is an `Option<String>`.
#[allow(clippy::unnecessary_wraps)]
fn default_connection_string() -> Option<String> {
    Some(String::from("postgresql://"))
}
21
/// Default upper bound on the number of pooled connections: 10.
fn default_max_connections() -> NonZeroU32 {
    // The literal is non-zero, so this can never actually panic.
    NonZeroU32::new(10).expect("10 is non-zero")
}
25
/// Default time allowed for establishing a database connection: 30 seconds.
fn default_connect_timeout() -> Duration {
    const THIRTY_SECONDS: Duration = Duration::from_secs(30);
    THIRTY_SECONDS
}
29
/// Default maximum idle duration of a pooled connection: ten minutes.
// The `Option` wrapper is deliberate: the `idle_timeout` field this function
// provides the default for is an `Option<Duration>`.
#[allow(clippy::unnecessary_wraps)]
fn default_idle_timeout() -> Option<Duration> {
    const TEN_MINUTES: Duration = Duration::from_secs(10 * 60);
    Some(TEN_MINUTES)
}
34
/// Default maximum lifetime of a pooled connection: thirty minutes.
// The `Option` wrapper is deliberate: the `max_lifetime` field this function
// provides the default for is an `Option<Duration>`.
#[allow(clippy::unnecessary_wraps)]
fn default_max_lifetime() -> Option<Duration> {
    const THIRTY_MINUTES: Duration = Duration::from_secs(30 * 60);
    Some(THIRTY_MINUTES)
}
39
40impl Default for DatabaseConfig {
41    fn default() -> Self {
42        Self {
43            uri: default_connection_string(),
44            host: None,
45            port: None,
46            socket: None,
47            username: None,
48            password: None,
49            database: None,
50            ssl_mode: None,
51            ssl_ca: None,
52            ssl_ca_file: None,
53            ssl_certificate: None,
54            ssl_certificate_file: None,
55            ssl_key: None,
56            ssl_key_file: None,
57            max_connections: default_max_connections(),
58            min_connections: Default::default(),
59            connect_timeout: default_connect_timeout(),
60            idle_timeout: default_idle_timeout(),
61            max_lifetime: default_max_lifetime(),
62        }
63    }
64}
65
/// Options for controlling the level of protection provided for PostgreSQL SSL
/// connections.
// With `rename_all = "kebab-case"`, the variants are written in configuration
// as `disable`, `allow`, `prefer`, `require`, `verify-ca` and `verify-full`.
// NOTE(review): the `///` comments here are presumably picked up by the
// `JsonSchema` derive as schema descriptions — confirm before rewording them.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "kebab-case")]
pub enum PgSslMode {
    /// Only try a non-SSL connection.
    Disable,

    /// First try a non-SSL connection; if that fails, try an SSL connection.
    Allow,

    /// First try an SSL connection; if that fails, try a non-SSL connection.
    Prefer,

    /// Only try an SSL connection. If a root CA file is present, verify the
    /// connection in the same way as if `VerifyCa` was specified.
    Require,

    /// Only try an SSL connection, and verify that the server certificate is
    /// issued by a trusted certificate authority (CA).
    VerifyCa,

    /// Only try an SSL connection; verify that the server certificate is issued
    /// by a trusted CA and that the requested server host name matches that
    /// in the certificate.
    VerifyFull,
}
93
94/// Database connection configuration
95#[serde_as]
96#[derive(Debug, Serialize, Deserialize, JsonSchema)]
97pub struct DatabaseConfig {
98    /// Connection URI
99    ///
100    /// This must not be specified if `host`, `port`, `socket`, `username`,
101    /// `password`, or `database` are specified.
102    #[serde(skip_serializing_if = "Option::is_none")]
103    #[schemars(url, default = "default_connection_string")]
104    pub uri: Option<String>,
105
106    /// Name of host to connect to
107    ///
108    /// This must not be specified if `uri` is specified.
109    #[serde(skip_serializing_if = "Option::is_none")]
110    #[schemars(with = "Option::<schema::Hostname>")]
111    pub host: Option<String>,
112
113    /// Port number to connect at the server host
114    ///
115    /// This must not be specified if `uri` is specified.
116    #[serde(skip_serializing_if = "Option::is_none")]
117    #[schemars(range(min = 1, max = 65535))]
118    pub port: Option<u16>,
119
120    /// Directory containing the UNIX socket to connect to
121    ///
122    /// This must not be specified if `uri` is specified.
123    #[serde(skip_serializing_if = "Option::is_none")]
124    #[schemars(with = "Option<String>")]
125    pub socket: Option<Utf8PathBuf>,
126
127    /// PostgreSQL user name to connect as
128    ///
129    /// This must not be specified if `uri` is specified.
130    #[serde(skip_serializing_if = "Option::is_none")]
131    pub username: Option<String>,
132
133    /// Password to be used if the server demands password authentication
134    ///
135    /// This must not be specified if `uri` is specified.
136    #[serde(skip_serializing_if = "Option::is_none")]
137    pub password: Option<String>,
138
139    /// The database name
140    ///
141    /// This must not be specified if `uri` is specified.
142    #[serde(skip_serializing_if = "Option::is_none")]
143    pub database: Option<String>,
144
145    /// How to handle SSL connections
146    #[serde(skip_serializing_if = "Option::is_none")]
147    pub ssl_mode: Option<PgSslMode>,
148
149    /// The PEM-encoded root certificate for SSL connections
150    ///
151    /// This must not be specified if the `ssl_ca_file` option is specified.
152    #[serde(skip_serializing_if = "Option::is_none")]
153    pub ssl_ca: Option<String>,
154
155    /// Path to the root certificate for SSL connections
156    ///
157    /// This must not be specified if the `ssl_ca` option is specified.
158    #[serde(skip_serializing_if = "Option::is_none")]
159    #[schemars(with = "Option<String>")]
160    pub ssl_ca_file: Option<Utf8PathBuf>,
161
162    /// The PEM-encoded client certificate for SSL connections
163    ///
164    /// This must not be specified if the `ssl_certificate_file` option is
165    /// specified.
166    #[serde(skip_serializing_if = "Option::is_none")]
167    pub ssl_certificate: Option<String>,
168
169    /// Path to the client certificate for SSL connections
170    ///
171    /// This must not be specified if the `ssl_certificate` option is specified.
172    #[serde(skip_serializing_if = "Option::is_none")]
173    #[schemars(with = "Option<String>")]
174    pub ssl_certificate_file: Option<Utf8PathBuf>,
175
176    /// The PEM-encoded client key for SSL connections
177    ///
178    /// This must not be specified if the `ssl_key_file` option is specified.
179    #[serde(skip_serializing_if = "Option::is_none")]
180    pub ssl_key: Option<String>,
181
182    /// Path to the client key for SSL connections
183    ///
184    /// This must not be specified if the `ssl_key` option is specified.
185    #[serde(skip_serializing_if = "Option::is_none")]
186    #[schemars(with = "Option<String>")]
187    pub ssl_key_file: Option<Utf8PathBuf>,
188
189    /// Set the maximum number of connections the pool should maintain
190    #[serde(default = "default_max_connections")]
191    pub max_connections: NonZeroU32,
192
193    /// Set the minimum number of connections the pool should maintain
194    #[serde(default)]
195    pub min_connections: u32,
196
197    /// Set the amount of time to attempt connecting to the database
198    #[schemars(with = "u64")]
199    #[serde(default = "default_connect_timeout")]
200    #[serde_as(as = "serde_with::DurationSeconds<u64>")]
201    pub connect_timeout: Duration,
202
203    /// Set a maximum idle duration for individual connections
204    #[schemars(with = "Option<u64>")]
205    #[serde(
206        default = "default_idle_timeout",
207        skip_serializing_if = "Option::is_none"
208    )]
209    #[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
210    pub idle_timeout: Option<Duration>,
211
212    /// Set the maximum lifetime of individual connections
213    #[schemars(with = "u64")]
214    #[serde(
215        default = "default_max_lifetime",
216        skip_serializing_if = "Option::is_none"
217    )]
218    #[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
219    pub max_lifetime: Option<Duration>,
220}
221
222impl ConfigurationSection for DatabaseConfig {
223    const PATH: Option<&'static str> = Some("database");
224
225    fn validate(&self, figment: &figment::Figment) -> Result<(), figment::error::Error> {
226        let metadata = figment.find_metadata(Self::PATH.unwrap());
227        let annotate = |mut error: figment::Error| {
228            error.metadata = metadata.cloned();
229            error.profile = Some(figment::Profile::Default);
230            error.path = vec![Self::PATH.unwrap().to_owned()];
231            Err(error)
232        };
233
234        // Check that the user did not specify both `uri` and the split options at the
235        // same time
236        let has_split_options = self.host.is_some()
237            || self.port.is_some()
238            || self.socket.is_some()
239            || self.username.is_some()
240            || self.password.is_some()
241            || self.database.is_some();
242
243        if self.uri.is_some() && has_split_options {
244            return annotate(figment::error::Error::from(
245                "uri must not be specified if host, port, socket, username, password, or database are specified".to_owned(),
246            ));
247        }
248
249        if self.ssl_ca.is_some() && self.ssl_ca_file.is_some() {
250            return annotate(figment::error::Error::from(
251                "ssl_ca must not be specified if ssl_ca_file is specified".to_owned(),
252            ));
253        }
254
255        if self.ssl_certificate.is_some() && self.ssl_certificate_file.is_some() {
256            return annotate(figment::error::Error::from(
257                "ssl_certificate must not be specified if ssl_certificate_file is specified"
258                    .to_owned(),
259            ));
260        }
261
262        if self.ssl_key.is_some() && self.ssl_key_file.is_some() {
263            return annotate(figment::error::Error::from(
264                "ssl_key must not be specified if ssl_key_file is specified".to_owned(),
265            ));
266        }
267
268        if (self.ssl_key.is_some() || self.ssl_key_file.is_some())
269            ^ (self.ssl_certificate.is_some() || self.ssl_certificate_file.is_some())
270        {
271            return annotate(figment::error::Error::from(
272                "both a ssl_certificate and a ssl_key must be set at the same time or none of them"
273                    .to_owned(),
274            ));
275        }
276
277        Ok(())
278    }
279}
#[cfg(test)]
mod tests {
    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };

    use super::*;

    /// Extracts the `database` section from a YAML file and checks that the
    /// configured URI comes through unchanged.
    #[test]
    fn load_config() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r"
                    database:
                      uri: postgresql://user:password@host/database
                ",
            )?;

            let figment = Figment::new().merge(Yaml::file("config.yaml"));
            let loaded: DatabaseConfig = figment.extract_inner("database")?;

            assert_eq!(
                loaded.uri.as_deref(),
                Some("postgresql://user:password@host/database")
            );

            Ok(())
        });
    }
}