rustical_dav/header/
overwrite.rs1use axum::{body::Body, extract::FromRequestParts, response::IntoResponse};
2use thiserror::Error;
3
4#[derive(Error, Debug)]
5#[error("Invalid Overwrite header")]
6pub struct InvalidOverwriteHeader;
7
8impl IntoResponse for InvalidOverwriteHeader {
9 fn into_response(self) -> axum::response::Response {
10 axum::response::Response::builder()
11 .status(axum::http::StatusCode::BAD_REQUEST)
12 .body(Body::new("Invalid Overwrite header".to_string()))
13 .expect("this always works")
14 }
15}
16
/// Parsed value of the WebDAV `Overwrite` request header (RFC 4918, section 10.6).
///
/// `Overwrite(true)` corresponds to `T` (the destination may be overwritten),
/// `Overwrite(false)` to `F`.
///
/// The wrapper is a plain flag, so it is `Copy`: handlers can pass it around
/// by value without moving it out of their arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Overwrite(pub bool);
19
20impl Default for Overwrite {
21 fn default() -> Self {
22 Self(true)
23 }
24}
25
26impl<S: Send + Sync> FromRequestParts<S> for Overwrite {
27 type Rejection = InvalidOverwriteHeader;
28
29 async fn from_request_parts(
30 parts: &mut axum::http::request::Parts,
31 _state: &S,
32 ) -> Result<Self, Self::Rejection> {
33 parts.headers.get("Overwrite").map_or_else(
34 || Ok(Self::default()),
35 |overwrite_header| overwrite_header.as_bytes().try_into(),
36 )
37 }
38}
39
40impl TryFrom<&[u8]> for Overwrite {
41 type Error = InvalidOverwriteHeader;
42
43 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
44 match value {
45 b"T" => Ok(Self(true)),
46 b"F" => Ok(Self(false)),
47 _ => Err(InvalidOverwriteHeader),
48 }
49 }
50}
51
#[cfg(test)]
mod tests {
    use axum::{extract::FromRequestParts, response::IntoResponse};
    use http::{Request, StatusCode};

    use crate::header::Overwrite;

    /// Without an `Overwrite` header the extractor falls back to `T`.
    #[tokio::test]
    async fn test_overwrite_default() {
        let request = Request::put("asd").body(()).unwrap();
        let (mut parts, ()) = request.into_parts();
        let overwrite = Overwrite::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(
            Overwrite(true),
            overwrite,
            "By default we want to overwrite!"
        );
    }

    /// An explicit `Overwrite: F` header is honoured by the extractor.
    #[tokio::test]
    async fn test_overwrite_header_false() {
        let request = Request::put("asd")
            .header("Overwrite", "F")
            .body(())
            .unwrap();
        let (mut parts, ()) = request.into_parts();
        let overwrite = Overwrite::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(Overwrite(false), overwrite);
    }

    /// A garbage header value makes the extractor reject the request with 400.
    #[tokio::test]
    async fn test_overwrite_header_invalid() {
        let request = Request::put("asd")
            .header("Overwrite", "maybe")
            .body(())
            .unwrap();
        let (mut parts, ()) = request.into_parts();
        let rejection = Overwrite::from_request_parts(&mut parts, &())
            .await
            .expect_err("an unparseable Overwrite header must be rejected");
        assert_eq!(rejection.into_response().status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_overwrite() {
        assert_eq!(
            Overwrite(true),
            Overwrite::try_from(b"T".as_slice()).unwrap()
        );
        assert_eq!(
            Overwrite(false),
            Overwrite::try_from(b"F".as_slice()).unwrap()
        );
        let err = Overwrite::try_from(b"aslkdjlad".as_slice())
            .expect_err("garbage must not parse as an Overwrite value");
        assert_eq!(err.into_response().status(), StatusCode::BAD_REQUEST);
    }
}
89}