1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2#![allow(clippy::missing_errors_doc)]
3mod extension;
4mod prop;
5pub mod register;
6use base64::Engine;
7use chrono::Utc;
8use derive_more::Constructor;
9pub use extension::*;
10use http::{HeaderValue, Method, header};
11pub use prop::*;
12use reqwest::{Body, Url};
13use rustical_store::{
14 CollectionOperation, CollectionOperationInfo, Subscription, SubscriptionStore,
15};
16use rustical_xml::{XmlRootTag, XmlSerialize, XmlSerializeRoot};
17use std::{collections::HashMap, sync::Arc, time::Duration};
18use tokio::sync::mpsc::Receiver;
19use tracing::{error, info, warn};
20
21mod endpoints;
22pub use endpoints::subscription_service;
23
/// Content-update part of a push message, included when a collection's content changed.
#[derive(XmlSerialize, Debug)]
pub struct ContentUpdate {
    /// Sync token taken from the `CollectionOperationInfo::Content` operation that
    /// triggered the push; serialized in the `NS_DAV` namespace.
    #[xml(ns = "rustical_dav::namespace::NS_DAV")]
    sync_token: Option<String>,
}
29
/// Root `push-message` XML element that is serialized, encrypted and sent to each
/// subscription's push resource.
///
/// `NS_DAVPUSH` is the default (unprefixed) namespace and `NS_DAV` is bound to the `D` prefix.
#[derive(XmlSerialize, XmlRootTag, Debug)]
#[xml(root = "push-message", ns = "rustical_dav::namespace::NS_DAVPUSH")]
#[xml(ns_prefix(
    rustical_dav::namespace::NS_DAVPUSH = "",
    rustical_dav::namespace::NS_DAV = "D",
))]
struct PushMessage {
    /// Topic of the collection operation that triggered this push.
    #[xml(ns = "rustical_dav::namespace::NS_DAVPUSH")]
    topic: String,
    /// Set for content changes (carries the new sync token); `None` otherwise.
    #[xml(ns = "rustical_dav::namespace::NS_DAVPUSH")]
    content_update: Option<ContentUpdate>,
}
42
/// Delivers push notifications about collection changes to the registered subscriptions.
#[derive(Debug, Constructor)]
pub struct DavPushController<S: SubscriptionStore> {
    /// Origins (scheme, host and port, e.g. `https://push.example.com`) that subscriptions
    /// may push to. `None` disables the check and allows any push server.
    allowed_push_servers: Option<Vec<String>>,
    /// Store used to look up subscriptions by topic and to delete invalid ones.
    sub_store: Arc<S>,
}
48
49impl<S: SubscriptionStore> DavPushController<S> {
50 pub async fn notifier(&self, mut recv: Receiver<CollectionOperation>) {
51 loop {
52 tokio::time::sleep(Duration::from_secs(10)).await;
54 let mut messages = vec![];
55 recv.recv_many(&mut messages, 100).await;
56
57 let mut latest_messages = HashMap::new();
61 for message in messages {
62 if matches!(message.data, CollectionOperationInfo::Content { .. }) {
63 latest_messages.insert(message.topic.clone(), message);
64 }
65 }
66 let messages = latest_messages.into_values();
67
68 for message in messages {
69 self.send_message(message).await;
70 }
71 }
72 }
73
74 #[allow(clippy::cognitive_complexity)]
75 async fn send_message(&self, message: CollectionOperation) {
76 let subscriptions = match self.sub_store.get_subscriptions(&message.topic).await {
77 Ok(subs) => subs,
78 Err(err) => {
79 error!("{err}");
80 return;
81 }
82 };
83
84 if subscriptions.is_empty() {
85 return;
86 }
87
88 if matches!(message.data, CollectionOperationInfo::Delete) {
89 return;
91 }
92
93 let content_update = if let CollectionOperationInfo::Content { sync_token } = message.data {
94 Some(ContentUpdate {
95 sync_token: Some(sync_token),
96 })
97 } else {
98 None
99 };
100
101 let push_message = PushMessage {
102 topic: message.topic,
103 content_update,
104 };
105
106 let payload = match push_message.serialize_to_string() {
107 Ok(payload) => payload,
108 Err(err) => {
109 error!("Could not serialize push message: {}", err);
110 return;
111 }
112 };
113
114 for subsciption in subscriptions {
115 if subsciption.is_expired(&Utc::now()) {
116 info!(
117 "Deleting subscription {} on topic {} because it is expired",
118 subsciption.id, subsciption.topic
119 );
120 self.try_delete_subscription(&subsciption.id).await;
121 continue;
122 }
123
124 if let Some(allowed_push_servers) = &self.allowed_push_servers {
125 if let Ok(url) = Url::parse(&subsciption.push_resource) {
126 let origin = url.origin().unicode_serialization();
127 if !allowed_push_servers.contains(&origin) {
128 warn!(
129 "Deleting subscription {} on topic {} because the endpoint is not in the list of allowed push servers",
130 subsciption.id, subsciption.topic
131 );
132 self.try_delete_subscription(&subsciption.id).await;
133 continue;
134 }
135 } else {
136 warn!(
137 "Deleting subscription {} on topic {} because of invalid URL",
138 subsciption.id, subsciption.topic
139 );
140 self.try_delete_subscription(&subsciption.id).await;
141 continue;
142 }
143 }
144
145 if let Err(err) = send_payload(&payload, &subsciption).await {
146 error!("An error occured sending out a push notification: {err}");
147 if err.is_permament_error() {
148 warn!(
149 "Deleting subscription {} on topic {}",
150 subsciption.id, subsciption.topic
151 );
152 self.try_delete_subscription(&subsciption.id).await;
153 }
154 }
155 }
156 }
157
158 async fn try_delete_subscription(&self, sub_id: &str) {
159 if let Err(err) = self.sub_store.delete_subscription(sub_id).await {
160 error!("Error deleting subsciption: {err}");
161 }
162 }
163}
164
165async fn send_payload(payload: &str, subsciption: &Subscription) -> Result<(), NotifierError> {
166 if subsciption.public_key_type != "p256dh" {
167 return Err(NotifierError::InvalidPublicKeyType(
168 subsciption.public_key_type.clone(),
169 ));
170 }
171 let endpoint = subsciption
172 .push_resource
173 .parse()
174 .map_err(|_| NotifierError::InvalidEndpointUrl(subsciption.push_resource.clone()))?;
175 let ua_public = base64::engine::general_purpose::URL_SAFE_NO_PAD
176 .decode(&subsciption.public_key)
177 .map_err(|_| NotifierError::InvalidKeyEncoding)?;
178 let auth_secret = base64::engine::general_purpose::URL_SAFE_NO_PAD
179 .decode(&subsciption.auth_secret)
180 .map_err(|_| NotifierError::InvalidKeyEncoding)?;
181
182 let client = reqwest::ClientBuilder::new()
183 .build()
184 .map_err(NotifierError::from)?;
185
186 let payload = ece::encrypt(&ua_public, &auth_secret, payload.as_bytes())?;
187
188 let mut request = reqwest::Request::new(Method::POST, endpoint);
189 *request.body_mut() = Some(Body::from(payload));
190 let hdrs = request.headers_mut();
191 hdrs.insert(
192 header::CONTENT_ENCODING,
193 HeaderValue::from_static("aes128gcm"),
194 );
195 hdrs.insert(
196 header::CONTENT_TYPE,
197 HeaderValue::from_static("application/octet-stream"),
198 );
199 hdrs.insert("TTL", HeaderValue::from(60));
200 client.execute(request).await?;
201
202 Ok(())
203}
204
205#[derive(Debug, thiserror::Error)]
206enum NotifierError {
207 #[error("Invalid public key type: {0}")]
208 InvalidPublicKeyType(String),
209 #[error("Invalid endpoint URL: {0}")]
210 InvalidEndpointUrl(String),
211 #[error("Invalid key encoding")]
212 InvalidKeyEncoding,
213 #[error(transparent)]
214 EceError(#[from] ece::Error),
215 #[error(transparent)]
216 ReqwestError(#[from] reqwest::Error),
217}
218
219impl NotifierError {
220 pub const fn is_permament_error(&self) -> bool {
222 match self {
223 Self::InvalidPublicKeyType(_)
224 | Self::InvalidEndpointUrl(_)
225 | Self::InvalidKeyEncoding => true,
226 Self::EceError(err) => matches!(
227 err,
228 ece::Error::InvalidAuthSecret | ece::Error::InvalidKeyLength
229 ),
230 Self::ReqwestError(_) => false,
231 }
232 }
233}
234
#[cfg(test)]
mod tests {
    use crate::send_payload;
    use base64::Engine;
    use chrono::NaiveDateTime;
    use ece::generate_keypair_and_auth_secret;
    use rustical_store::Subscription;

    /// End-to-end check: encrypts a payload with freshly generated keys and posts it
    /// to a public ntfy.sh topic. Needs network access, so it is opt-in.
    #[tokio::test]
    #[ignore = "sends a real HTTP request to ntfy.sh; run with `cargo test -- --ignored`"]
    async fn test_ntfy_request() {
        let (keypair, auth_secret) = generate_keypair_and_auth_secret().unwrap();
        let auth_secret = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(auth_secret);
        let public_key =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(keypair.pub_as_raw().unwrap());

        send_payload(
            "hello",
            &Subscription {
                id: "asd".to_string(),
                topic: "asd".to_string(),
                expiration: NaiveDateTime::MAX,
                push_resource: "https://ntfy.sh/upL00-v4L3SGM2".to_string(),
                public_key,
                public_key_type: "p256dh".to_string(),
                auth_secret,
            },
        )
        .await
        .unwrap();
    }
}
265}