mas_jose/
base64.rs

1//! Transparent base64 encoding / decoding as part of (de)serialization.
2
3use std::{borrow::Cow, fmt, marker::PhantomData, str};
4
5use base64ct::Encoding;
6use serde::{
7    Deserialize, Deserializer, Serialize, Serializer,
8    de::{self, Unexpected, Visitor},
9};
10
11/// A wrapper around `Vec<u8>` that (de)serializes from / to a base64 string.
12///
13/// The generic parameter `C` represents the base64 flavor.
14#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
15pub struct Base64<C = base64ct::Base64> {
16    bytes: Vec<u8>,
17    // Invariant PhantomData, Send + Sync
18    _phantom_conf: PhantomData<fn(C) -> C>,
19}
20
21pub type Base64UrlNoPad = Base64<base64ct::Base64UrlUnpadded>;
22
23impl<C: Encoding> Base64<C> {
24    /// Create a `Base64` instance from raw bytes, to be base64-encoded in
25    /// serialization.
26    #[must_use]
27    pub fn new(bytes: Vec<u8>) -> Self {
28        Self {
29            bytes,
30            _phantom_conf: PhantomData,
31        }
32    }
33
34    /// Get a reference to the raw bytes held by this `Base64` instance.
35    #[must_use]
36    pub fn as_bytes(&self) -> &[u8] {
37        self.bytes.as_ref()
38    }
39
40    /// Encode the bytes contained in this `Base64` instance to unpadded base64.
41    #[must_use]
42    pub fn encode(&self) -> String {
43        C::encode_string(self.as_bytes())
44    }
45
46    /// Get the raw bytes held by this `Base64` instance.
47    #[must_use]
48    pub fn into_inner(self) -> Vec<u8> {
49        self.bytes
50    }
51
52    /// Create a `Base64` instance containing an empty `Vec<u8>`.
53    #[must_use]
54    pub fn empty() -> Self {
55        Self::new(Vec::new())
56    }
57
58    /// Parse some base64-encoded data to create a `Base64` instance.
59    ///
60    /// # Errors
61    ///
62    /// Returns an error if the input is not valid base64.
63    pub fn parse(encoded: &str) -> Result<Self, base64ct::Error> {
64        C::decode_vec(encoded).map(Self::new)
65    }
66}
67
68impl<C: Encoding> fmt::Debug for Base64<C> {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        self.encode().fmt(f)
71    }
72}
73
74impl<C: Encoding> fmt::Display for Base64<C> {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        self.encode().fmt(f)
77    }
78}
79
80impl<'de, C: Encoding> Deserialize<'de> for Base64<C> {
81    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
82    where
83        D: Deserializer<'de>,
84    {
85        let encoded = deserialize_cow_str(deserializer)?;
86        Self::parse(&encoded).map_err(de::Error::custom)
87    }
88}
89
90impl<C: Encoding> Serialize for Base64<C> {
91    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
92    where
93        S: Serializer,
94    {
95        serializer.serialize_str(&self.encode())
96    }
97}
98
99/// Deserialize a `Cow<'de, str>`.
100///
101/// Different from serde's implementation of `Deserialize` for `Cow` since it
102/// borrows from the input when possible.
103pub fn deserialize_cow_str<'de, D>(deserializer: D) -> Result<Cow<'de, str>, D::Error>
104where
105    D: Deserializer<'de>,
106{
107    deserializer.deserialize_string(CowStrVisitor)
108}
109
110struct CowStrVisitor;
111
112impl<'de> Visitor<'de> for CowStrVisitor {
113    type Value = Cow<'de, str>;
114
115    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        formatter.write_str("a string")
117    }
118
119    fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
120    where
121        E: de::Error,
122    {
123        Ok(Cow::Borrowed(v))
124    }
125
126    fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E>
127    where
128        E: de::Error,
129    {
130        match str::from_utf8(v) {
131            Ok(s) => Ok(Cow::Borrowed(s)),
132            Err(_) => Err(de::Error::invalid_value(Unexpected::Bytes(v), &self)),
133        }
134    }
135
136    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
137    where
138        E: de::Error,
139    {
140        Ok(Cow::Owned(v.to_owned()))
141    }
142
143    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
144    where
145        E: de::Error,
146    {
147        Ok(Cow::Owned(v))
148    }
149
150    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
151    where
152        E: de::Error,
153    {
154        match str::from_utf8(v) {
155            Ok(s) => Ok(Cow::Owned(s.to_owned())),
156            Err(_) => Err(de::Error::invalid_value(Unexpected::Bytes(v), &self)),
157        }
158    }
159
160    fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
161    where
162        E: de::Error,
163    {
164        match String::from_utf8(v) {
165            Ok(s) => Ok(Cow::Owned(s)),
166            Err(e) => Err(de::Error::invalid_value(
167                Unexpected::Bytes(&e.into_bytes()),
168                &self,
169            )),
170        }
171    }
172}