1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2#![allow(clippy::missing_errors_doc)]
3mod extension;
4mod prop;
5pub mod register;
6use base64::Engine;
7use derive_more::Constructor;
8pub use extension::*;
9use http::{HeaderValue, Method, header};
10pub use prop::*;
11use reqwest::{Body, Url};
12use rustical_store::{
13 CollectionOperation, CollectionOperationInfo, Subscription, SubscriptionStore,
14};
15use rustical_xml::{XmlRootTag, XmlSerialize, XmlSerializeRoot};
16use std::{collections::HashMap, sync::Arc, time::Duration};
17use tokio::sync::mpsc::Receiver;
18use tracing::{error, warn};
19
20mod endpoints;
21pub use endpoints::subscription_service;
22
/// Content-change payload embedded in a [`PushMessage`].
///
/// Serialized through `XmlSerialize`; its only field is emitted in the
/// `rustical_dav::namespace::NS_DAV` namespace.
#[derive(XmlSerialize, Debug)]
pub struct ContentUpdate {
    /// Sync token describing the collection state after the change. When this
    /// module builds a push message for a content operation it always sets `Some`.
    #[xml(ns = "rustical_dav::namespace::NS_DAV")]
    sync_token: Option<String>,
}
28
/// Root `push-message` document that is serialized, encrypted and sent to every
/// subscriber of a topic.
///
/// `NS_DAVPUSH` is declared as the default (unprefixed) namespace and `NS_DAV`
/// gets the `D` prefix.
/// NOTE(review): element order in the output presumably follows field order;
/// keep the fields in this order.
#[derive(XmlSerialize, XmlRootTag, Debug)]
#[xml(root = "push-message", ns = "rustical_dav::namespace::NS_DAVPUSH")]
#[xml(ns_prefix(
    rustical_dav::namespace::NS_DAVPUSH = "",
    rustical_dav::namespace::NS_DAV = "D",
))]
struct PushMessage {
    /// Topic of the collection operation this message reports.
    #[xml(ns = "rustical_dav::namespace::NS_DAVPUSH")]
    topic: String,
    /// Present for content changes; carries the new sync token.
    #[xml(ns = "rustical_dav::namespace::NS_DAVPUSH")]
    content_update: Option<ContentUpdate>,
}
41
/// Sends push notifications for collection operations to the subscriptions stored
/// in a [`SubscriptionStore`].
///
/// The derived `new` constructor takes the fields positionally in declaration order
/// (`allowed_push_servers`, then `sub_store`), so do not reorder them.
#[derive(Debug, Constructor)]
pub struct DavPushController<S: SubscriptionStore> {
    /// Optional allow-list of push-server origins. When `Some`, subscriptions whose
    /// push-resource origin (as produced by `Url::origin().unicode_serialization()`,
    /// e.g. `https://push.example.com`) is not listed, or whose push resource is not
    /// a valid URL, are deleted. `None` disables the check.
    allowed_push_servers: Option<Vec<String>>,
    /// Store used to look up subscriptions by topic and to delete unusable ones.
    sub_store: Arc<S>,
}
47
48impl<S: SubscriptionStore> DavPushController<S> {
49 pub async fn notifier(&self, mut recv: Receiver<CollectionOperation>) {
50 loop {
51 tokio::time::sleep(Duration::from_secs(10)).await;
53 let mut messages = vec![];
54 recv.recv_many(&mut messages, 100).await;
55
56 let mut latest_messages = HashMap::new();
60 for message in messages {
61 if matches!(message.data, CollectionOperationInfo::Content { .. }) {
62 latest_messages.insert(message.topic.clone(), message);
63 }
64 }
65 let messages = latest_messages.into_values();
66
67 for message in messages {
68 self.send_message(message).await;
69 }
70 }
71 }
72
73 #[allow(clippy::cognitive_complexity)]
74 async fn send_message(&self, message: CollectionOperation) {
75 let subscriptions = match self.sub_store.get_subscriptions(&message.topic).await {
76 Ok(subs) => subs,
77 Err(err) => {
78 error!("{err}");
79 return;
80 }
81 };
82
83 if subscriptions.is_empty() {
84 return;
85 }
86
87 if matches!(message.data, CollectionOperationInfo::Delete) {
88 return;
90 }
91
92 let content_update = if let CollectionOperationInfo::Content { sync_token } = message.data {
93 Some(ContentUpdate {
94 sync_token: Some(sync_token),
95 })
96 } else {
97 None
98 };
99
100 let push_message = PushMessage {
101 topic: message.topic,
102 content_update,
103 };
104
105 let payload = match push_message.serialize_to_string() {
106 Ok(payload) => payload,
107 Err(err) => {
108 error!("Could not serialize push message: {}", err);
109 return;
110 }
111 };
112
113 for subsciption in subscriptions {
114 if let Some(allowed_push_servers) = &self.allowed_push_servers {
115 if let Ok(url) = Url::parse(&subsciption.push_resource) {
116 let origin = url.origin().unicode_serialization();
117 if !allowed_push_servers.contains(&origin) {
118 warn!(
119 "Deleting subscription {} on topic {} because the endpoint is not in the list of allowed push servers",
120 subsciption.id, subsciption.topic
121 );
122 self.try_delete_subscription(&subsciption.id).await;
123 }
124 } else {
125 warn!(
126 "Deleting subscription {} on topic {} because of invalid URL",
127 subsciption.id, subsciption.topic
128 );
129 self.try_delete_subscription(&subsciption.id).await;
130 }
131 }
132
133 if let Err(err) = self.send_payload(&payload, &subsciption).await {
134 error!("An error occured sending out a push notification: {err}");
135 if err.is_permament_error() {
136 warn!(
137 "Deleting subscription {} on topic {}",
138 subsciption.id, subsciption.topic
139 );
140 self.try_delete_subscription(&subsciption.id).await;
141 }
142 }
143 }
144 }
145
146 async fn try_delete_subscription(&self, sub_id: &str) {
147 if let Err(err) = self.sub_store.delete_subscription(sub_id).await {
148 error!("Error deleting subsciption: {err}");
149 }
150 }
151
152 async fn send_payload(
153 &self,
154 payload: &str,
155 subsciption: &Subscription,
156 ) -> Result<(), NotifierError> {
157 if subsciption.public_key_type != "p256dh" {
158 return Err(NotifierError::InvalidPublicKeyType(
159 subsciption.public_key_type.clone(),
160 ));
161 }
162 let endpoint = subsciption
163 .push_resource
164 .parse()
165 .map_err(|_| NotifierError::InvalidEndpointUrl(subsciption.push_resource.clone()))?;
166 let ua_public = base64::engine::general_purpose::URL_SAFE_NO_PAD
167 .decode(&subsciption.public_key)
168 .map_err(|_| NotifierError::InvalidKeyEncoding)?;
169 let auth_secret = base64::engine::general_purpose::URL_SAFE_NO_PAD
170 .decode(&subsciption.auth_secret)
171 .map_err(|_| NotifierError::InvalidKeyEncoding)?;
172
173 let client = reqwest::ClientBuilder::new()
174 .build()
175 .map_err(NotifierError::from)?;
176
177 let payload = ece::encrypt(&ua_public, &auth_secret, payload.as_bytes())?;
178
179 let mut request = reqwest::Request::new(Method::POST, endpoint);
180 *request.body_mut() = Some(Body::from(payload));
181 let hdrs = request.headers_mut();
182 hdrs.insert(
183 header::CONTENT_ENCODING,
184 HeaderValue::from_static("aes128gcm"),
185 );
186 hdrs.insert(
187 header::CONTENT_TYPE,
188 HeaderValue::from_static("application/octet-stream"),
189 );
190 hdrs.insert("TTL", HeaderValue::from(60));
191 client.execute(request).await?;
192
193 Ok(())
194 }
195}
196
197#[derive(Debug, thiserror::Error)]
198enum NotifierError {
199 #[error("Invalid public key type: {0}")]
200 InvalidPublicKeyType(String),
201 #[error("Invalid endpoint URL: {0}")]
202 InvalidEndpointUrl(String),
203 #[error("Invalid key encoding")]
204 InvalidKeyEncoding,
205 #[error(transparent)]
206 EceError(#[from] ece::Error),
207 #[error(transparent)]
208 ReqwestError(#[from] reqwest::Error),
209}
210
211impl NotifierError {
212 pub const fn is_permament_error(&self) -> bool {
214 match self {
215 Self::InvalidPublicKeyType(_)
216 | Self::InvalidEndpointUrl(_)
217 | Self::InvalidKeyEncoding => true,
218 Self::EceError(err) => matches!(
219 err,
220 ece::Error::InvalidAuthSecret | ece::Error::InvalidKeyLength
221 ),
222 Self::ReqwestError(_) => false,
223 }
224 }
225}