1use std::borrow::Cow;
12
13use base64ct::{Base64UrlUnpadded, Encoding};
14use mas_iana::oauth::PkceCodeChallengeMethod;
15use serde::{Deserialize, Serialize};
16use sha2::{Digest, Sha256};
17use thiserror::Error;
18
19#[derive(Debug, Error, PartialEq, Eq)]
21pub enum CodeChallengeError {
22 #[error("code_verifier should be at least 43 characters long")]
24 TooShort,
25
26 #[error("code_verifier should be at most 128 characters long")]
28 TooLong,
29
30 #[error("code_verifier contains invalid characters")]
32 InvalidCharacters,
33
34 #[error("challenge verification failed")]
36 VerificationFailed,
37
38 #[error("unknown challenge method")]
40 UnknownChallengeMethod,
41}
42
43fn validate_verifier(verifier: &str) -> Result<(), CodeChallengeError> {
44 if verifier.len() < 43 {
45 return Err(CodeChallengeError::TooShort);
46 }
47
48 if verifier.len() > 128 {
49 return Err(CodeChallengeError::TooLong);
50 }
51
52 if !verifier
53 .chars()
54 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~')
55 {
56 return Err(CodeChallengeError::InvalidCharacters);
57 }
58
59 Ok(())
60}
61
62pub trait CodeChallengeMethodExt {
64 fn compute_challenge<'a>(&self, verifier: &'a str) -> Result<Cow<'a, str>, CodeChallengeError>;
71
72 fn verify(&self, challenge: &str, verifier: &str) -> Result<(), CodeChallengeError>
80 where
81 Self: Sized,
82 {
83 if self.compute_challenge(verifier)? == challenge {
84 Ok(())
85 } else {
86 Err(CodeChallengeError::VerificationFailed)
87 }
88 }
89}
90
91impl CodeChallengeMethodExt for PkceCodeChallengeMethod {
92 fn compute_challenge<'a>(&self, verifier: &'a str) -> Result<Cow<'a, str>, CodeChallengeError> {
93 validate_verifier(verifier)?;
94
95 let challenge = match self {
96 Self::Plain => verifier.into(),
97 Self::S256 => {
98 let mut hasher = Sha256::new();
99 hasher.update(verifier.as_bytes());
100 let hash = hasher.finalize();
101 let verifier = Base64UrlUnpadded::encode_string(&hash);
102 verifier.into()
103 }
104 _ => return Err(CodeChallengeError::UnknownChallengeMethod),
105 };
106
107 Ok(challenge)
108 }
109}
110
/// The PKCE parameters carried by an authorization request.
#[derive(Clone, Serialize, Deserialize)]
pub struct AuthorizationRequest {
    /// The method used to derive `code_challenge` from the verifier.
    pub code_challenge_method: PkceCodeChallengeMethod,

    /// The code challenge derived from the client's verifier.
    // NOTE(review): with the `Plain` method this value equals the verifier itself,
    // so treat it as sensitive.
    pub code_challenge: String,
}
120
/// The PKCE parameter carried by a token request.
// NOTE(review): this type does not derive `Debug`, presumably so the verifier
// cannot leak into logs. Keep it that way unless that is confirmed unnecessary.
#[derive(Clone, Serialize, Deserialize)]
pub struct TokenRequest {
    /// The code verifier the client sends in the token request.
    // NOTE(review): no serde rename is visible here, so this (de)serializes under
    // the name `code_challenge_verifier` rather than RFC 7636's `code_verifier`.
    // Confirm that matches the callers.
    pub code_challenge_verifier: String,
}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Code verifier from the RFC 7636 Appendix B example.
    const RFC_VERIFIER: &str = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";

    /// S256 code challenge for [`RFC_VERIFIER`] from the RFC 7636 Appendix B example.
    const RFC_CHALLENGE: &str = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";

    #[test]
    fn test_pkce_verification() {
        use PkceCodeChallengeMethod::{Plain, S256};
        let challenge = RFC_CHALLENGE;

        assert!(S256.verify(challenge, RFC_VERIFIER).is_ok());

        assert!(Plain.verify(challenge, challenge).is_ok());

        assert_eq!(
            S256.verify(challenge, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"),
            Err(CodeChallengeError::VerificationFailed),
        );

        assert_eq!(
            S256.verify(challenge, "tooshort"),
            Err(CodeChallengeError::TooShort),
        );

        assert_eq!(
            S256.verify(challenge, "toolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolong"),
            Err(CodeChallengeError::TooLong),
        );

        assert_eq!(
            S256.verify(
                challenge,
                "this is long enough but has invalid characters in it"
            ),
            Err(CodeChallengeError::InvalidCharacters),
        );
    }

    #[test]
    fn s256_matches_rfc_test_vector() {
        let computed = PkceCodeChallengeMethod::S256
            .compute_challenge(RFC_VERIFIER)
            .unwrap();
        assert_eq!(computed, RFC_CHALLENGE);
    }

    #[test]
    fn plain_challenge_is_the_verifier() {
        let computed = PkceCodeChallengeMethod::Plain
            .compute_challenge(RFC_VERIFIER)
            .unwrap();
        assert_eq!(computed, RFC_VERIFIER);
    }

    #[test]
    fn verifier_length_boundaries() {
        use PkceCodeChallengeMethod::S256;

        // 42 and 129 are just outside the accepted range; 43 and 128 are the
        // inclusive bounds and pass validation, failing only at comparison.
        assert_eq!(
            S256.verify(RFC_CHALLENGE, &"a".repeat(42)),
            Err(CodeChallengeError::TooShort),
        );
        assert_eq!(
            S256.verify(RFC_CHALLENGE, &"a".repeat(43)),
            Err(CodeChallengeError::VerificationFailed),
        );
        assert_eq!(
            S256.verify(RFC_CHALLENGE, &"a".repeat(128)),
            Err(CodeChallengeError::VerificationFailed),
        );
        assert_eq!(
            S256.verify(RFC_CHALLENGE, &"a".repeat(129)),
            Err(CodeChallengeError::TooLong),
        );
    }

    #[test]
    fn all_unreserved_characters_are_accepted() {
        let verifier = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~";
        assert_eq!(
            PkceCodeChallengeMethod::Plain.verify(verifier, verifier),
            Ok(()),
        );
    }

    #[test]
    fn non_ascii_characters_are_rejected() {
        let verifier = "\u{e9}".repeat(43);
        assert_eq!(
            PkceCodeChallengeMethod::Plain.verify(&verifier, &verifier),
            Err(CodeChallengeError::InvalidCharacters),
        );
    }
}