rustical_dav/header/
depth.rs1use axum::{body::Body, response::IntoResponse};
2use headers::Header;
3use http::{HeaderName, HeaderValue};
4use rustical_xml::{ValueDeserialize, ValueSerialize, XmlError};
5use std::str::FromStr;
6use thiserror::Error;
7
8static DEPTH: HeaderName = HeaderName::from_static("depth");
9
10#[derive(Error, Debug)]
11#[error("Invalid Depth header")]
12pub struct InvalidDepthHeader;
13
14impl IntoResponse for InvalidDepthHeader {
15 fn into_response(self) -> axum::response::Response {
16 axum::response::Response::builder()
17 .status(axum::http::StatusCode::BAD_REQUEST)
18 .body(Body::empty())
19 .expect("this always works")
20 }
21}
22
/// Parsed value of the `Depth` header.
///
/// When no explicit value is chosen, [`Depth::default`] yields [`Depth::Infinity`].
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum Depth {
    /// Written on the wire as `0`.
    Zero,
    /// Written on the wire as `1`.
    One,
    /// Written on the wire as `infinity`; this is the `Default` variant.
    #[default]
    Infinity,
}
30
31impl From<&Depth> for &'static str {
32 fn from(value: &Depth) -> Self {
33 match value {
34 Depth::Zero => "0",
35 Depth::One => "1",
36 Depth::Infinity => "infinity",
37 }
38 }
39}
40
41impl ValueSerialize for Depth {
42 fn serialize(&self) -> String {
43 let string: &str = self.into();
44 string.to_owned()
45 }
46}
47
48impl ValueDeserialize for Depth {
49 fn deserialize(val: &str) -> Result<Self, XmlError> {
50 Self::from_str(val)
51 .map_err(|_| XmlError::InvalidVariant("Invalid value for depth".to_owned()))
52 }
53}
54
55impl FromStr for Depth {
56 type Err = InvalidDepthHeader;
57
58 fn from_str(s: &str) -> Result<Self, Self::Err> {
59 match s {
60 "0" => Ok(Self::Zero),
61 "1" => Ok(Self::One),
62 "Infinity" | "infinity" => Ok(Self::Infinity),
63 _ => Err(InvalidDepthHeader),
64 }
65 }
66}
67
68impl Header for Depth {
69 fn name() -> &'static HeaderName {
70 &DEPTH
71 }
72
73 fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
74 values.extend(std::iter::once(HeaderValue::from_static(self.into())));
75 }
76
77 fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
78 where
79 Self: Sized,
80 I: Iterator<Item = &'i HeaderValue>,
81 {
82 let Some(val) = values.next() else {
83 return Err(headers::Error::invalid());
84 };
85 if values.next().is_some() {
86 return Err(headers::Error::invalid());
87 }
88 let val = val.to_str().map_err(|_| headers::Error::invalid())?;
89 Self::from_str(val).map_err(|_| headers::Error::invalid())
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::Depth;
    use axum::{body::Body, extract::FromRequest};
    use axum_extra::TypedHeader;
    use headers::HeaderMapExt;
    use http::{HeaderMap, Request};

    /// Builds a GET request carrying one `Depth` header per entry of `values`.
    ///
    /// `Request::builder().header` appends, so repeated entries yield repeated headers.
    fn request_with_depth(values: &[&str]) -> Request<Body> {
        let builder = values
            .iter()
            .fold(Request::builder().method("GET"), |builder, value| {
                builder.header("Depth", *value)
            });
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    #[rstest::rstest]
    #[case("0", Depth::Zero)]
    #[case("1", Depth::One)]
    #[case("infinity", Depth::Infinity)]
    #[case("Infinity", Depth::Infinity)]
    async fn test_depth_header(#[case] input: &str, #[case] header: Depth) {
        let request = request_with_depth(&[input]);
        let TypedHeader(depth) = TypedHeader::<Depth>::from_request(request, &())
            .await
            .unwrap();
        assert_eq!(depth, header);
    }

    #[tokio::test]
    async fn test_invalid_depth_header() {
        let request = request_with_depth(&["asldkj"]);
        assert!(
            TypedHeader::<Depth>::from_request(request, &())
                .await
                .is_err()
        );
    }

    /// Two `Depth` headers are ambiguous; `decode` must reject them.
    #[tokio::test]
    async fn test_repeated_depth_header() {
        let request = request_with_depth(&["0", "1"]);
        assert!(
            TypedHeader::<Depth>::from_request(request, &())
                .await
                .is_err()
        );
    }

    /// The typed extractor rejects requests that carry no `Depth` header at all.
    #[tokio::test]
    async fn test_missing_depth_header() {
        let request = request_with_depth(&[]);
        assert!(
            TypedHeader::<Depth>::from_request(request, &())
                .await
                .is_err()
        );
    }

    /// `encode` writes the canonical spelling of each variant.
    #[rstest::rstest]
    #[case(Depth::Zero, "0")]
    #[case(Depth::One, "1")]
    #[case(Depth::Infinity, "infinity")]
    fn test_encode_depth_header(#[case] depth: Depth, #[case] expected: &str) {
        let mut headers = HeaderMap::new();
        headers.typed_insert(depth);
        assert_eq!(headers.get("depth").unwrap().to_str().unwrap(), expected);
    }
}