rustical_dav/header/
overwrite.rs1use axum::{body::Body, response::IntoResponse};
2use derive_more::{From, Into};
3use headers::Header;
4use http::{HeaderName, HeaderValue};
5use std::str::FromStr;
6use thiserror::Error;
7
8static OVERWRITE: HeaderName = HeaderName::from_static("overwrite");
9
10#[derive(Error, Debug)]
11#[error("Invalid Overwrite header")]
12pub struct InvalidOverwriteHeader;
13
14impl IntoResponse for InvalidOverwriteHeader {
15 fn into_response(self) -> axum::response::Response {
16 axum::response::Response::builder()
17 .status(axum::http::StatusCode::BAD_REQUEST)
18 .body(Body::new("Invalid Overwrite header".to_string()))
19 .expect("this always works")
20 }
21}
22
23#[derive(Debug, PartialEq, Eq, From, Into)]
24pub struct Overwrite(pub bool);
25
26impl Default for Overwrite {
27 fn default() -> Self {
28 Self(true)
29 }
30}
31
32impl From<&Overwrite> for &'static str {
33 fn from(value: &Overwrite) -> Self {
34 if value.0 { "T" } else { "F" }
35 }
36}
37
38impl FromStr for Overwrite {
39 type Err = InvalidOverwriteHeader;
40
41 fn from_str(s: &str) -> Result<Self, Self::Err> {
42 match s {
43 "T" => Ok(Self(true)),
44 "F" => Ok(Self(false)),
45 _ => Err(InvalidOverwriteHeader),
46 }
47 }
48}
49
50impl Header for Overwrite {
51 fn name() -> &'static HeaderName {
52 &OVERWRITE
53 }
54
55 fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
56 values.extend(std::iter::once(HeaderValue::from_static(self.into())));
57 }
58
59 fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
60 where
61 Self: Sized,
62 I: Iterator<Item = &'i HeaderValue>,
63 {
64 let Some(val) = values.next() else {
65 return Err(headers::Error::invalid());
66 };
67 if values.next().is_some() {
68 return Err(headers::Error::invalid());
69 }
70 let val = val.to_str().map_err(|_| headers::Error::invalid())?;
71 Self::from_str(val).map_err(|_| headers::Error::invalid())
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::Overwrite;
    use axum::{body::Body, extract::FromRequest};
    use axum_extra::TypedHeader;
    use headers::Header;
    use http::{HeaderValue, Request};

    /// Valid tokens are extracted through axum's `TypedHeader`.
    #[tokio::test]
    #[rstest::rstest]
    #[case("T", Overwrite(true))]
    #[case("F", Overwrite(false))]
    async fn test_overwrite_header(#[case] input: &str, #[case] header: Overwrite) {
        let request = Request::builder()
            .method("GET")
            .header("Overwrite", input)
            .body(Body::empty())
            .unwrap();
        let TypedHeader(overwrite) = TypedHeader::<Overwrite>::from_request(request, &())
            .await
            .unwrap();
        assert_eq!(overwrite, header);
    }

    /// Anything other than `T`/`F` is rejected.
    #[tokio::test]
    async fn test_invalid_overwrite_header() {
        let request = Request::builder()
            .method("GET")
            .header("Overwrite", "asldkj")
            .body(Body::empty())
            .unwrap();
        assert!(
            TypedHeader::<Overwrite>::from_request(request, &())
                .await
                .is_err()
        );
    }

    /// A request carrying the header twice is rejected rather than resolved
    /// to either value.
    #[tokio::test]
    async fn test_repeated_overwrite_header() {
        let request = Request::builder()
            .method("GET")
            .header("Overwrite", "T")
            .header("Overwrite", "F")
            .body(Body::empty())
            .unwrap();
        assert!(
            TypedHeader::<Overwrite>::from_request(request, &())
                .await
                .is_err()
        );
    }

    /// `encode` emits the expected single token and `decode` reads it back.
    #[rstest::rstest]
    #[case(Overwrite(true), "T")]
    #[case(Overwrite(false), "F")]
    fn test_overwrite_header_roundtrip(#[case] header: Overwrite, #[case] wire: &str) {
        let mut values: Vec<HeaderValue> = Vec::new();
        header.encode(&mut values);
        assert_eq!(values.len(), 1);
        assert_eq!(values[0].to_str().unwrap(), wire);
        let decoded = Overwrite::decode(&mut values.iter()).unwrap();
        assert_eq!(decoded, header);
    }
}