1use std::collections::HashSet;
8
9use mas_iana::jose::{JsonWebKeyType, JsonWebKeyUse, JsonWebSignatureAlg};
10
11use crate::jwt::JsonWebSignatureHeader;
12
13#[derive(Debug, Clone, PartialEq, Eq, Hash)]
14pub enum Constraint<'a> {
15 Alg {
16 constraint_alg: &'a JsonWebSignatureAlg,
17 },
18
19 Algs {
20 constraint_algs: &'a [JsonWebSignatureAlg],
21 },
22
23 Kid {
24 constraint_kid: &'a str,
25 },
26
27 Use {
28 constraint_use: &'a JsonWebKeyUse,
29 },
30
31 Kty {
32 constraint_kty: &'a JsonWebKeyType,
33 },
34}
35
36impl<'a> Constraint<'a> {
37 #[must_use]
38 pub fn alg(constraint_alg: &'a JsonWebSignatureAlg) -> Self {
39 Constraint::Alg { constraint_alg }
40 }
41
42 #[must_use]
43 pub fn algs(constraint_algs: &'a [JsonWebSignatureAlg]) -> Self {
44 Constraint::Algs { constraint_algs }
45 }
46
47 #[must_use]
48 pub fn kid(constraint_kid: &'a str) -> Self {
49 Constraint::Kid { constraint_kid }
50 }
51
52 #[must_use]
53 pub fn use_(constraint_use: &'a JsonWebKeyUse) -> Self {
54 Constraint::Use { constraint_use }
55 }
56
57 #[must_use]
58 pub fn kty(constraint_kty: &'a JsonWebKeyType) -> Self {
59 Constraint::Kty { constraint_kty }
60 }
61}
62
/// Outcome of checking one constraint against one candidate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstraintDecision {
    /// The candidate explicitly satisfies the constraint.
    Positive,
    /// The candidate neither confirms nor contradicts the constraint.
    Neutral,
    /// The candidate contradicts the constraint.
    Negative,
}
69
70pub trait Constrainable {
71 fn alg(&self) -> Option<&JsonWebSignatureAlg> {
72 None
73 }
74
75 fn algs(&self) -> &[JsonWebSignatureAlg] {
77 &[]
78 }
79
80 fn kid(&self) -> Option<&str> {
82 None
83 }
84
85 fn use_(&self) -> Option<&JsonWebKeyUse> {
87 None
88 }
89
90 fn kty(&self) -> JsonWebKeyType;
92}
93
94impl Constraint<'_> {
95 fn decide<T: Constrainable>(&self, constrainable: &T) -> ConstraintDecision {
96 match self {
97 Constraint::Alg { constraint_alg } => {
98 if let Some(alg) = constrainable.alg() {
100 if alg == *constraint_alg {
101 ConstraintDecision::Positive
102 } else {
103 ConstraintDecision::Negative
104 }
105 } else if constrainable.algs().contains(constraint_alg) {
108 ConstraintDecision::Neutral
109 } else {
110 ConstraintDecision::Negative
111 }
112 }
113 Constraint::Algs { constraint_algs } => {
114 if let Some(alg) = constrainable.alg() {
115 if constraint_algs.contains(alg) {
116 ConstraintDecision::Positive
117 } else {
118 ConstraintDecision::Negative
119 }
120 } else if constrainable
121 .algs()
122 .iter()
123 .any(|alg| constraint_algs.contains(alg))
124 {
125 ConstraintDecision::Neutral
126 } else {
127 ConstraintDecision::Negative
128 }
129 }
130 Constraint::Kid { constraint_kid } => {
131 if let Some(kid) = constrainable.kid() {
132 if kid == *constraint_kid {
133 ConstraintDecision::Positive
134 } else {
135 ConstraintDecision::Negative
136 }
137 } else {
138 ConstraintDecision::Neutral
139 }
140 }
141 Constraint::Use { constraint_use } => {
142 if let Some(use_) = constrainable.use_() {
143 if use_ == *constraint_use {
144 ConstraintDecision::Positive
145 } else {
146 ConstraintDecision::Negative
147 }
148 } else {
149 ConstraintDecision::Neutral
150 }
151 }
152 Constraint::Kty { constraint_kty } => {
153 if **constraint_kty == constrainable.kty() {
154 ConstraintDecision::Positive
155 } else {
156 ConstraintDecision::Negative
157 }
158 }
159 }
160 }
161}
162
163#[derive(Default)]
164pub struct ConstraintSet<'a> {
165 constraints: HashSet<Constraint<'a>>,
166}
167
168impl<'a> FromIterator<Constraint<'a>> for ConstraintSet<'a> {
169 fn from_iter<T: IntoIterator<Item = Constraint<'a>>>(iter: T) -> Self {
170 Self {
171 constraints: HashSet::from_iter(iter),
172 }
173 }
174}
175
176#[allow(dead_code)]
177impl<'a> ConstraintSet<'a> {
178 pub fn new(constraints: impl IntoIterator<Item = Constraint<'a>>) -> Self {
179 constraints.into_iter().collect()
180 }
181
182 pub fn filter<'b, T: Constrainable, I: IntoIterator<Item = &'b T>>(
183 &self,
184 constrainables: I,
185 ) -> Vec<&'b T> {
186 let mut selected = Vec::new();
187
188 'outer: for constrainable in constrainables {
189 let mut score = 0;
190
191 for constraint in &self.constraints {
192 match constraint.decide(constrainable) {
193 ConstraintDecision::Positive => score += 1,
194 ConstraintDecision::Neutral => {}
195 ConstraintDecision::Negative => continue 'outer,
197 }
198 }
199
200 selected.push((score, constrainable));
201 }
202
203 selected.sort_by_key(|(score, _)| *score);
204
205 selected
206 .into_iter()
207 .map(|(_score, constrainable)| constrainable)
208 .collect()
209 }
210
211 #[must_use]
212 pub fn alg(mut self, constraint_alg: &'a JsonWebSignatureAlg) -> Self {
213 self.constraints.insert(Constraint::alg(constraint_alg));
214 self
215 }
216
217 #[must_use]
218 pub fn algs(mut self, constraint_algs: &'a [JsonWebSignatureAlg]) -> Self {
219 self.constraints.insert(Constraint::algs(constraint_algs));
220 self
221 }
222
223 #[must_use]
224 pub fn kid(mut self, constraint_kid: &'a str) -> Self {
225 self.constraints.insert(Constraint::kid(constraint_kid));
226 self
227 }
228
229 #[must_use]
230 pub fn use_(mut self, constraint_use: &'a JsonWebKeyUse) -> Self {
231 self.constraints.insert(Constraint::use_(constraint_use));
232 self
233 }
234
235 #[must_use]
236 pub fn kty(mut self, constraint_kty: &'a JsonWebKeyType) -> Self {
237 self.constraints.insert(Constraint::kty(constraint_kty));
238 self
239 }
240}
241
242impl<'a> From<&'a JsonWebSignatureHeader> for ConstraintSet<'a> {
243 fn from(header: &'a JsonWebSignatureHeader) -> Self {
244 let mut constraints = Self::default().alg(header.alg());
245
246 if let Some(kid) = header.kid() {
247 constraints = constraints.kid(kid);
248 }
249
250 constraints
251 }
252}