1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2#![allow(clippy::missing_errors_doc)]
3mod extension;
4mod prop;
5pub mod register;
6use base64::Engine;
7use chrono::Utc;
8use derive_more::Constructor;
9pub use extension::*;
10use http::{HeaderValue, Method, header};
11pub use prop::*;
12use reqwest::{Body, Url};
13use rustical_store::{
14 CollectionOperation, CollectionOperationInfo, Subscription, SubscriptionStore,
15};
16use rustical_xml::{XmlRootTag, XmlSerialize, XmlSerializeRoot};
17use std::{collections::HashMap, sync::Arc, time::Duration};
18use tokio::sync::mpsc::Receiver;
19use tracing::{error, info, warn};
20
21mod endpoints;
22pub use endpoints::subscription_service;
23
/// Payload of the `content_update` field of a [`PushMessage`], sent when a collection's
/// content changed.
#[derive(XmlSerialize, Debug)]
pub struct ContentUpdate {
    /// Sync token of the collection after the change; serialized in the DAV namespace.
    #[xml(ns = "rustical_dav::namespace::NS_DAV")]
    sync_token: Option<String>,
}
29
/// Root `push-message` document that is encrypted and delivered to push subscribers.
///
/// The DAV-Push namespace is emitted as the default namespace and the DAV namespace under
/// the `D` prefix. Field order here determines the order of the serialized child elements.
#[derive(XmlSerialize, XmlRootTag, Debug)]
#[xml(root = "push-message", ns = "rustical_dav::namespace::NS_DAVPUSH")]
#[xml(ns_prefix(
    rustical_dav::namespace::NS_DAVPUSH = "",
    rustical_dav::namespace::NS_DAV = "D",
))]
struct PushMessage {
    /// Topic of the collection operation; subscriptions are looked up by this value.
    #[xml(ns = "rustical_dav::namespace::NS_DAVPUSH")]
    topic: String,
    /// `Some` for content changes (carrying the new sync token), `None` otherwise.
    #[xml(ns = "rustical_dav::namespace::NS_DAVPUSH")]
    content_update: Option<ContentUpdate>,
}
42
/// Sends WebDAV-Push notifications for collection operations to registered subscriptions.
///
/// Created through the derived `new(allowed_push_servers, sub_store)`; the field order
/// below defines that argument order.
#[derive(Debug, Constructor)]
pub struct DavPushController<S: SubscriptionStore> {
    /// Origins a subscription's push resource may point at, compared against
    /// `Url::origin().unicode_serialization()` of that URL. `None` allows any push server.
    allowed_push_servers: Option<Vec<String>>,
    /// Store used to look up subscriptions by topic and to delete unusable ones.
    sub_store: Arc<S>,
}
48
49impl<S: SubscriptionStore> DavPushController<S> {
50 pub async fn notifier(&self, mut recv: Receiver<CollectionOperation>) {
51 loop {
52 tokio::time::sleep(Duration::from_secs(10)).await;
54 let mut messages = vec![];
55 recv.recv_many(&mut messages, 100).await;
56
57 let mut latest_messages = HashMap::new();
61 for message in messages {
62 if matches!(message.data, CollectionOperationInfo::Content { .. }) {
63 latest_messages.insert(message.topic.clone(), message);
64 }
65 }
66 let messages = latest_messages.into_values();
67
68 for message in messages {
69 self.send_message(message).await;
70 }
71 }
72 }
73
74 #[allow(clippy::cognitive_complexity)]
75 async fn send_message(&self, message: CollectionOperation) {
76 let subscriptions = match self.sub_store.get_subscriptions(&message.topic).await {
77 Ok(subs) => subs,
78 Err(err) => {
79 error!("{err}");
80 return;
81 }
82 };
83
84 if subscriptions.is_empty() {
85 return;
86 }
87
88 if matches!(message.data, CollectionOperationInfo::Delete) {
89 return;
91 }
92
93 let content_update = if let CollectionOperationInfo::Content { sync_token } = message.data {
94 Some(ContentUpdate {
95 sync_token: Some(sync_token),
96 })
97 } else {
98 None
99 };
100
101 let push_message = PushMessage {
102 topic: message.topic,
103 content_update,
104 };
105
106 let payload = match push_message.serialize_to_string() {
107 Ok(payload) => payload,
108 Err(err) => {
109 error!("Could not serialize push message: {}", err);
110 return;
111 }
112 };
113
114 for subsciption in subscriptions {
115 if subsciption.is_expired(&Utc::now()) {
116 info!(
117 "Deleting subscription {} on topic {} because it is expired",
118 subsciption.id, subsciption.topic
119 );
120 self.try_delete_subscription(&subsciption.id).await;
121 continue;
122 }
123
124 if let Some(allowed_push_servers) = &self.allowed_push_servers {
125 if let Ok(url) = Url::parse(&subsciption.push_resource) {
126 let origin = url.origin().unicode_serialization();
127 if !allowed_push_servers.contains(&origin) {
128 warn!(
129 "Deleting subscription {} on topic {} because the endpoint is not in the list of allowed push servers",
130 subsciption.id, subsciption.topic
131 );
132 self.try_delete_subscription(&subsciption.id).await;
133 continue;
134 }
135 } else {
136 warn!(
137 "Deleting subscription {} on topic {} because of invalid URL",
138 subsciption.id, subsciption.topic
139 );
140 self.try_delete_subscription(&subsciption.id).await;
141 continue;
142 }
143 }
144
145 if let Err(err) = self.send_payload(&payload, &subsciption).await {
146 error!("An error occured sending out a push notification: {err}");
147 if err.is_permament_error() {
148 warn!(
149 "Deleting subscription {} on topic {}",
150 subsciption.id, subsciption.topic
151 );
152 self.try_delete_subscription(&subsciption.id).await;
153 }
154 }
155 }
156 }
157
158 async fn try_delete_subscription(&self, sub_id: &str) {
159 if let Err(err) = self.sub_store.delete_subscription(sub_id).await {
160 error!("Error deleting subsciption: {err}");
161 }
162 }
163
164 async fn send_payload(
165 &self,
166 payload: &str,
167 subsciption: &Subscription,
168 ) -> Result<(), NotifierError> {
169 if subsciption.public_key_type != "p256dh" {
170 return Err(NotifierError::InvalidPublicKeyType(
171 subsciption.public_key_type.clone(),
172 ));
173 }
174 let endpoint = subsciption
175 .push_resource
176 .parse()
177 .map_err(|_| NotifierError::InvalidEndpointUrl(subsciption.push_resource.clone()))?;
178 let ua_public = base64::engine::general_purpose::URL_SAFE_NO_PAD
179 .decode(&subsciption.public_key)
180 .map_err(|_| NotifierError::InvalidKeyEncoding)?;
181 let auth_secret = base64::engine::general_purpose::URL_SAFE_NO_PAD
182 .decode(&subsciption.auth_secret)
183 .map_err(|_| NotifierError::InvalidKeyEncoding)?;
184
185 let client = reqwest::ClientBuilder::new()
186 .build()
187 .map_err(NotifierError::from)?;
188
189 let payload = ece::encrypt(&ua_public, &auth_secret, payload.as_bytes())?;
190
191 let mut request = reqwest::Request::new(Method::POST, endpoint);
192 *request.body_mut() = Some(Body::from(payload));
193 let hdrs = request.headers_mut();
194 hdrs.insert(
195 header::CONTENT_ENCODING,
196 HeaderValue::from_static("aes128gcm"),
197 );
198 hdrs.insert(
199 header::CONTENT_TYPE,
200 HeaderValue::from_static("application/octet-stream"),
201 );
202 hdrs.insert("TTL", HeaderValue::from(60));
203 client.execute(request).await?;
204
205 Ok(())
206 }
207}
208
209#[derive(Debug, thiserror::Error)]
210enum NotifierError {
211 #[error("Invalid public key type: {0}")]
212 InvalidPublicKeyType(String),
213 #[error("Invalid endpoint URL: {0}")]
214 InvalidEndpointUrl(String),
215 #[error("Invalid key encoding")]
216 InvalidKeyEncoding,
217 #[error(transparent)]
218 EceError(#[from] ece::Error),
219 #[error(transparent)]
220 ReqwestError(#[from] reqwest::Error),
221}
222
223impl NotifierError {
224 pub const fn is_permament_error(&self) -> bool {
226 match self {
227 Self::InvalidPublicKeyType(_)
228 | Self::InvalidEndpointUrl(_)
229 | Self::InvalidKeyEncoding => true,
230 Self::EceError(err) => matches!(
231 err,
232 ece::Error::InvalidAuthSecret | ece::Error::InvalidKeyLength
233 ),
234 Self::ReqwestError(_) => false,
235 }
236 }
237}