Skip to main content

rustical_dav/header/
overwrite.rs

1use axum::{body::Body, response::IntoResponse};
2use derive_more::{From, Into};
3use headers::Header;
4use http::{HeaderName, HeaderValue};
5use std::str::FromStr;
6use thiserror::Error;
7
8static OVERWRITE: HeaderName = HeaderName::from_static("overwrite");
9
10#[derive(Error, Debug)]
11#[error("Invalid Overwrite header")]
12pub struct InvalidOverwriteHeader;
13
14impl IntoResponse for InvalidOverwriteHeader {
15    fn into_response(self) -> axum::response::Response {
16        axum::response::Response::builder()
17            .status(axum::http::StatusCode::BAD_REQUEST)
18            .body(Body::new("Invalid Overwrite header".to_string()))
19            .expect("this always works")
20    }
21}
22
23#[derive(Debug, PartialEq, Eq, From, Into)]
24pub struct Overwrite(pub bool);
25
26impl Default for Overwrite {
27    fn default() -> Self {
28        Self(true)
29    }
30}
31
32impl From<&Overwrite> for &'static str {
33    fn from(value: &Overwrite) -> Self {
34        if value.0 { "T" } else { "F" }
35    }
36}
37
38impl FromStr for Overwrite {
39    type Err = InvalidOverwriteHeader;
40
41    fn from_str(s: &str) -> Result<Self, Self::Err> {
42        match s {
43            "T" => Ok(Self(true)),
44            "F" => Ok(Self(false)),
45            _ => Err(InvalidOverwriteHeader),
46        }
47    }
48}
49
50impl Header for Overwrite {
51    fn name() -> &'static HeaderName {
52        &OVERWRITE
53    }
54
55    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
56        values.extend(std::iter::once(HeaderValue::from_static(self.into())));
57    }
58
59    fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
60    where
61        Self: Sized,
62        I: Iterator<Item = &'i HeaderValue>,
63    {
64        let Some(val) = values.next() else {
65            return Err(headers::Error::invalid());
66        };
67        if values.next().is_some() {
68            return Err(headers::Error::invalid());
69        }
70        let val = val.to_str().map_err(|_| headers::Error::invalid())?;
71        Self::from_str(val).map_err(|_| headers::Error::invalid())
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::Overwrite;
    use axum::{body::Body, extract::FromRequest};
    use axum_extra::TypedHeader;
    use headers::Header;
    use http::{HeaderValue, Request};

    /// Builds a GET request carrying one `Overwrite` header per entry in `values`.
    fn request_with(values: &[&str]) -> Request<Body> {
        values
            .iter()
            .fold(Request::builder().method("GET"), |builder, value| {
                builder.header("Overwrite", *value)
            })
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    #[rstest::rstest]
    #[case("T", Overwrite(true))]
    #[case("F", Overwrite(false))]
    async fn test_overwrite_header(#[case] input: &str, #[case] header: Overwrite) {
        let TypedHeader(overwrite) =
            TypedHeader::<Overwrite>::from_request(request_with(&[input]), &())
                .await
                .unwrap();
        assert_eq!(overwrite, header);
    }

    #[tokio::test]
    async fn test_invalid_overwrite_header() {
        assert!(
            TypedHeader::<Overwrite>::from_request(request_with(&["asldkj"]), &())
                .await
                .is_err()
        );
    }

    /// Repeated `Overwrite` headers are ambiguous and must be rejected.
    #[tokio::test]
    async fn test_duplicate_overwrite_header() {
        assert!(
            TypedHeader::<Overwrite>::from_request(request_with(&["T", "F"]), &())
                .await
                .is_err()
        );
    }

    /// Encoding yields a single `T`/`F` value that decodes back to the same header.
    #[rstest::rstest]
    #[case(Overwrite(true), "T")]
    #[case(Overwrite(false), "F")]
    fn test_overwrite_encode_roundtrip(#[case] header: Overwrite, #[case] expected: &str) {
        let mut values: Vec<HeaderValue> = Vec::new();
        header.encode(&mut values);
        assert_eq!(values.len(), 1);
        assert_eq!(values[0], expected);
        assert_eq!(Overwrite::decode(&mut values.iter()).unwrap(), header);
    }
}