Skip to main content

rustical_dav/header/
depth.rs

1use axum::{body::Body, response::IntoResponse};
2use headers::Header;
3use http::{HeaderName, HeaderValue};
4use rustical_xml::{ValueDeserialize, ValueSerialize, XmlError};
5use std::str::FromStr;
6use thiserror::Error;
7
8static DEPTH: HeaderName = HeaderName::from_static("depth");
9
10#[derive(Error, Debug)]
11#[error("Invalid Depth header")]
12pub struct InvalidDepthHeader;
13
14impl IntoResponse for InvalidDepthHeader {
15    fn into_response(self) -> axum::response::Response {
16        axum::response::Response::builder()
17            .status(axum::http::StatusCode::BAD_REQUEST)
18            .body(Body::empty())
19            .expect("this always works")
20    }
21}
22
/// Value carried by the `Depth` header.
///
/// [`Depth::Infinity`] is the [`Default`] variant.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum Depth {
    /// Written on the wire as `"0"`.
    Zero,
    /// Written on the wire as `"1"`.
    One,
    /// Written on the wire as `"infinity"`.
    #[default]
    Infinity,
}
30
31impl From<&Depth> for &'static str {
32    fn from(value: &Depth) -> Self {
33        match value {
34            Depth::Zero => "0",
35            Depth::One => "1",
36            Depth::Infinity => "infinity",
37        }
38    }
39}
40
41impl ValueSerialize for Depth {
42    fn serialize(&self) -> String {
43        let string: &str = self.into();
44        string.to_owned()
45    }
46}
47
48impl ValueDeserialize for Depth {
49    fn deserialize(val: &str) -> Result<Self, XmlError> {
50        Self::from_str(val)
51            .map_err(|_| XmlError::InvalidVariant("Invalid value for depth".to_owned()))
52    }
53}
54
55impl FromStr for Depth {
56    type Err = InvalidDepthHeader;
57
58    fn from_str(s: &str) -> Result<Self, Self::Err> {
59        match s {
60            "0" => Ok(Self::Zero),
61            "1" => Ok(Self::One),
62            "Infinity" | "infinity" => Ok(Self::Infinity),
63            _ => Err(InvalidDepthHeader),
64        }
65    }
66}
67
68impl Header for Depth {
69    fn name() -> &'static HeaderName {
70        &DEPTH
71    }
72
73    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
74        values.extend(std::iter::once(HeaderValue::from_static(self.into())));
75    }
76
77    fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
78    where
79        Self: Sized,
80        I: Iterator<Item = &'i HeaderValue>,
81    {
82        let Some(val) = values.next() else {
83            return Err(headers::Error::invalid());
84        };
85        if values.next().is_some() {
86            return Err(headers::Error::invalid());
87        }
88        let val = val.to_str().map_err(|_| headers::Error::invalid())?;
89        Self::from_str(val).map_err(|_| headers::Error::invalid())
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::Depth;
    use axum::{body::Body, extract::FromRequest};
    use axum_extra::TypedHeader;
    use http::Request;

    /// Every accepted spelling is extracted into the matching variant.
    #[tokio::test]
    #[rstest::rstest]
    #[case("0", Depth::Zero)]
    #[case("1", Depth::One)]
    #[case("infinity", Depth::Infinity)]
    #[case("Infinity", Depth::Infinity)]
    async fn test_depth_header(#[case] input: &str, #[case] header: Depth) {
        let request = Request::builder()
            .method("GET")
            .header("Depth", input)
            .body(Body::empty())
            .unwrap();
        let TypedHeader(depth) = TypedHeader::<Depth>::from_request(request, &())
            .await
            .unwrap();
        assert_eq!(depth, header);
    }

    /// A value that is not a depth is rejected.
    #[tokio::test]
    async fn test_invalid_depth_header() {
        let request = Request::builder()
            .method("GET")
            .header("Depth", "asldkj")
            .body(Body::empty())
            .unwrap();
        assert!(
            TypedHeader::<Depth>::from_request(request, &())
                .await
                .is_err()
        );
    }

    /// A request without any `Depth` header is rejected by the extractor.
    #[tokio::test]
    async fn test_missing_depth_header() {
        let request = Request::builder()
            .method("GET")
            .body(Body::empty())
            .unwrap();
        assert!(
            TypedHeader::<Depth>::from_request(request, &())
                .await
                .is_err()
        );
    }

    /// `decode` refuses more than one `Depth` header value.
    #[tokio::test]
    async fn test_duplicate_depth_header() {
        let request = Request::builder()
            .method("GET")
            .header("Depth", "0")
            .header("Depth", "1")
            .body(Body::empty())
            .unwrap();
        assert!(
            TypedHeader::<Depth>::from_request(request, &())
                .await
                .is_err()
        );
    }
}