rustical_dav/header/
overwrite.rs

1use axum::{body::Body, extract::FromRequestParts, response::IntoResponse};
2use thiserror::Error;
3
4#[derive(Error, Debug)]
5#[error("Invalid Overwrite header")]
6pub struct InvalidOverwriteHeader;
7
8impl IntoResponse for InvalidOverwriteHeader {
9    fn into_response(self) -> axum::response::Response {
10        axum::response::Response::builder()
11            .status(axum::http::StatusCode::BAD_REQUEST)
12            .body(Body::new("Invalid Overwrite header".to_string()))
13            .expect("this always works")
14    }
15}
16
/// Parsed value of the `Overwrite` request header.
///
/// The wrapped flag is `true` for the header value `T` and `false` for `F`.
/// The type is a single `bool`, so it is `Copy` and can be passed around by
/// value freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Overwrite(pub bool);
19
20impl Default for Overwrite {
21    fn default() -> Self {
22        Self(true)
23    }
24}
25
26impl<S: Send + Sync> FromRequestParts<S> for Overwrite {
27    type Rejection = InvalidOverwriteHeader;
28
29    async fn from_request_parts(
30        parts: &mut axum::http::request::Parts,
31        _state: &S,
32    ) -> Result<Self, Self::Rejection> {
33        parts.headers.get("Overwrite").map_or_else(
34            || Ok(Self::default()),
35            |overwrite_header| overwrite_header.as_bytes().try_into(),
36        )
37    }
38}
39
40impl TryFrom<&[u8]> for Overwrite {
41    type Error = InvalidOverwriteHeader;
42
43    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
44        match value {
45            b"T" => Ok(Self(true)),
46            b"F" => Ok(Self(false)),
47            _ => Err(InvalidOverwriteHeader),
48        }
49    }
50}
51
#[cfg(test)]
mod tests {
    use axum::{extract::FromRequestParts, response::IntoResponse};
    use http::Request;

    use crate::header::Overwrite;

    /// A request without an `Overwrite` header must extract as `Overwrite(true)`.
    #[tokio::test]
    async fn test_overwrite_default() {
        let (mut parts, ()) = Request::put("asd").body(()).unwrap().into_parts();
        let extracted = Overwrite::from_request_parts(&mut parts, &()).await;
        assert_eq!(
            Overwrite(true),
            extracted.unwrap(),
            "By default we want to overwrite!"
        );
    }

    /// `T` and `F` parse to their flags; garbage input yields an error that
    /// can be turned into a response.
    #[test]
    fn test_overwrite() {
        let parsed_true = Overwrite::try_from(b"T".as_slice()).unwrap();
        assert_eq!(Overwrite(true), parsed_true);

        let parsed_false = Overwrite::try_from(b"F".as_slice()).unwrap();
        assert_eq!(Overwrite(false), parsed_false);

        match Overwrite::try_from(b"aslkdjlad".as_slice()) {
            Err(err) => {
                // Only checks that building the response does not panic.
                let _ = err.into_response();
            }
            Ok(_) => unreachable!("should return error"),
        }
    }
}