mas_jose/
constraints.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::collections::HashSet;
8
9use mas_iana::jose::{JsonWebKeyType, JsonWebKeyUse, JsonWebSignatureAlg};
10
11use crate::jwt::JsonWebSignatureHeader;
12
/// A single requirement used when selecting a key out of a set of candidates.
///
/// Each variant borrows the value it compares against, so a `Constraint`
/// cannot outlive the data it was built from. How each variant is evaluated
/// against a [`Constrainable`] candidate is implemented by `Constraint::decide`
/// in this module.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Constraint<'a> {
    /// The candidate must be pinned to, or at least support, this exact
    /// signature algorithm.
    Alg {
        constraint_alg: &'a JsonWebSignatureAlg,
    },

    /// The candidate must be pinned to, or at least support, one of these
    /// signature algorithms.
    Algs {
        constraint_algs: &'a [JsonWebSignatureAlg],
    },

    /// The candidate's key ID (`kid`), when it has one, must equal this value.
    Kid {
        constraint_kid: &'a str,
    },

    /// The candidate's declared usage (`use`), when it has one, must equal
    /// this value.
    Use {
        constraint_use: &'a JsonWebKeyUse,
    },

    /// The candidate's key type (`kty`) must equal this value.
    Kty {
        constraint_kty: &'a JsonWebKeyType,
    },
}
35
36impl<'a> Constraint<'a> {
37    #[must_use]
38    pub fn alg(constraint_alg: &'a JsonWebSignatureAlg) -> Self {
39        Constraint::Alg { constraint_alg }
40    }
41
42    #[must_use]
43    pub fn algs(constraint_algs: &'a [JsonWebSignatureAlg]) -> Self {
44        Constraint::Algs { constraint_algs }
45    }
46
47    #[must_use]
48    pub fn kid(constraint_kid: &'a str) -> Self {
49        Constraint::Kid { constraint_kid }
50    }
51
52    #[must_use]
53    pub fn use_(constraint_use: &'a JsonWebKeyUse) -> Self {
54        Constraint::Use { constraint_use }
55    }
56
57    #[must_use]
58    pub fn kty(constraint_kty: &'a JsonWebKeyType) -> Self {
59        Constraint::Kty { constraint_kty }
60    }
61}
62
/// Outcome of evaluating one [`Constraint`] against one candidate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstraintDecision {
    /// The candidate explicitly satisfies the constraint; when ranking with
    /// `ConstraintSet::filter`, each positive decision adds one point.
    Positive,
    /// The candidate neither violates nor explicitly satisfies the constraint
    /// (for example, it has no key ID to compare); it adds no points.
    Neutral,
    /// The candidate violates the constraint; `ConstraintSet::filter` drops it.
    Negative,
}
69
/// A candidate (such as a key) that can be checked against [`Constraint`]s.
///
/// Only [`Constrainable::kty`] must be implemented. The other accessors
/// default to "not specified". Depending on the constraint, `Constraint::decide`
/// treats a missing value as neutral (missing `kid` / `use`) or as
/// incompatible (no pinned `alg` and an empty `algs` list).
pub trait Constrainable {
    /// The single signature algorithm this key is pinned to, if any.
    ///
    /// When this returns `Some`, algorithm constraints compare against it
    /// directly and [`Constrainable::algs`] is not consulted.
    fn alg(&self) -> Option<&JsonWebSignatureAlg> {
        None
    }

    /// List of available algorithms for this key
    ///
    /// Only consulted by algorithm constraints when [`Constrainable::alg`]
    /// returns `None`.
    fn algs(&self) -> &[JsonWebSignatureAlg] {
        &[]
    }

    /// Key ID (`kid`) of this key
    fn kid(&self) -> Option<&str> {
        None
    }

    /// Usage specified for this key
    fn use_(&self) -> Option<&JsonWebKeyUse> {
        None
    }

    /// Key type (`kty`) of this key
    fn kty(&self) -> JsonWebKeyType;
}
93
94impl Constraint<'_> {
95    fn decide<T: Constrainable>(&self, constrainable: &T) -> ConstraintDecision {
96        match self {
97            Constraint::Alg { constraint_alg } => {
98                // If the constrainable has one specific alg defined, use that
99                if let Some(alg) = constrainable.alg() {
100                    if alg == *constraint_alg {
101                        ConstraintDecision::Positive
102                    } else {
103                        ConstraintDecision::Negative
104                    }
105                // If not, check that the requested alg is valid for this
106                // constrainable
107                } else if constrainable.algs().contains(constraint_alg) {
108                    ConstraintDecision::Neutral
109                } else {
110                    ConstraintDecision::Negative
111                }
112            }
113            Constraint::Algs { constraint_algs } => {
114                if let Some(alg) = constrainable.alg() {
115                    if constraint_algs.contains(alg) {
116                        ConstraintDecision::Positive
117                    } else {
118                        ConstraintDecision::Negative
119                    }
120                } else if constrainable
121                    .algs()
122                    .iter()
123                    .any(|alg| constraint_algs.contains(alg))
124                {
125                    ConstraintDecision::Neutral
126                } else {
127                    ConstraintDecision::Negative
128                }
129            }
130            Constraint::Kid { constraint_kid } => {
131                if let Some(kid) = constrainable.kid() {
132                    if kid == *constraint_kid {
133                        ConstraintDecision::Positive
134                    } else {
135                        ConstraintDecision::Negative
136                    }
137                } else {
138                    ConstraintDecision::Neutral
139                }
140            }
141            Constraint::Use { constraint_use } => {
142                if let Some(use_) = constrainable.use_() {
143                    if use_ == *constraint_use {
144                        ConstraintDecision::Positive
145                    } else {
146                        ConstraintDecision::Negative
147                    }
148                } else {
149                    ConstraintDecision::Neutral
150                }
151            }
152            Constraint::Kty { constraint_kty } => {
153                if **constraint_kty == constrainable.kty() {
154                    ConstraintDecision::Positive
155                } else {
156                    ConstraintDecision::Negative
157                }
158            }
159        }
160    }
161}
162
163#[derive(Default)]
164pub struct ConstraintSet<'a> {
165    constraints: HashSet<Constraint<'a>>,
166}
167
168impl<'a> FromIterator<Constraint<'a>> for ConstraintSet<'a> {
169    fn from_iter<T: IntoIterator<Item = Constraint<'a>>>(iter: T) -> Self {
170        Self {
171            constraints: HashSet::from_iter(iter),
172        }
173    }
174}
175
176#[allow(dead_code)]
177impl<'a> ConstraintSet<'a> {
178    pub fn new(constraints: impl IntoIterator<Item = Constraint<'a>>) -> Self {
179        constraints.into_iter().collect()
180    }
181
182    pub fn filter<'b, T: Constrainable, I: IntoIterator<Item = &'b T>>(
183        &self,
184        constrainables: I,
185    ) -> Vec<&'b T> {
186        let mut selected = Vec::new();
187
188        'outer: for constrainable in constrainables {
189            let mut score = 0;
190
191            for constraint in &self.constraints {
192                match constraint.decide(constrainable) {
193                    ConstraintDecision::Positive => score += 1,
194                    ConstraintDecision::Neutral => {}
195                    // If any constraint was negative, don't add it to the candidates
196                    ConstraintDecision::Negative => continue 'outer,
197                }
198            }
199
200            selected.push((score, constrainable));
201        }
202
203        selected.sort_by_key(|(score, _)| *score);
204
205        selected
206            .into_iter()
207            .map(|(_score, constrainable)| constrainable)
208            .collect()
209    }
210
211    #[must_use]
212    pub fn alg(mut self, constraint_alg: &'a JsonWebSignatureAlg) -> Self {
213        self.constraints.insert(Constraint::alg(constraint_alg));
214        self
215    }
216
217    #[must_use]
218    pub fn algs(mut self, constraint_algs: &'a [JsonWebSignatureAlg]) -> Self {
219        self.constraints.insert(Constraint::algs(constraint_algs));
220        self
221    }
222
223    #[must_use]
224    pub fn kid(mut self, constraint_kid: &'a str) -> Self {
225        self.constraints.insert(Constraint::kid(constraint_kid));
226        self
227    }
228
229    #[must_use]
230    pub fn use_(mut self, constraint_use: &'a JsonWebKeyUse) -> Self {
231        self.constraints.insert(Constraint::use_(constraint_use));
232        self
233    }
234
235    #[must_use]
236    pub fn kty(mut self, constraint_kty: &'a JsonWebKeyType) -> Self {
237        self.constraints.insert(Constraint::kty(constraint_kty));
238        self
239    }
240}
241
242impl<'a> From<&'a JsonWebSignatureHeader> for ConstraintSet<'a> {
243    fn from(header: &'a JsonWebSignatureHeader) -> Self {
244        let mut constraints = Self::default().alg(header.alg());
245
246        if let Some(kid) = header.kid() {
247            constraints = constraints.kid(kid);
248        }
249
250        constraints
251    }
252}