rustical_dav_push/
lib.rs

1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2#![allow(clippy::missing_errors_doc)]
3mod extension;
4mod prop;
5pub mod register;
6use base64::Engine;
7use derive_more::Constructor;
8pub use extension::*;
9use http::{HeaderValue, Method, header};
10pub use prop::*;
11use reqwest::{Body, Url};
12use rustical_store::{
13    CollectionOperation, CollectionOperationInfo, Subscription, SubscriptionStore,
14};
15use rustical_xml::{XmlRootTag, XmlSerialize, XmlSerializeRoot};
16use std::{collections::HashMap, sync::Arc, time::Duration};
17use tokio::sync::mpsc::Receiver;
18use tracing::{error, warn};
19
20mod endpoints;
21pub use endpoints::subscription_service;
22
/// Content-update part of a push message, carrying the new sync token of the changed
/// collection.
#[derive(XmlSerialize, Debug)]
pub struct ContentUpdate {
    // Serialized in the DAV namespace (`NS_DAV`). `None` presumably omits the element —
    // NOTE(review): confirm against rustical_xml's handling of `Option` fields.
    #[xml(ns = "rustical_dav::namespace::NS_DAV")]
    sync_token: Option<String>,
}
28
/// Root `push-message` document that is serialized, encrypted and POSTed to push services.
///
/// The WebDAV-Push namespace (`NS_DAVPUSH`) is bound to the empty prefix and the DAV
/// namespace (`NS_DAV`) to the `D` prefix.
// NOTE(review): field order presumably determines element order in the serialized XML;
// keep `topic` before `content_update` unless the derive is known to be order-independent.
#[derive(XmlSerialize, XmlRootTag, Debug)]
#[xml(root = "push-message", ns = "rustical_dav::namespace::NS_DAVPUSH")]
#[xml(ns_prefix(
    rustical_dav::namespace::NS_DAVPUSH = "",
    rustical_dav::namespace::NS_DAV = "D",
))]
struct PushMessage {
    // Topic of the collection the notification refers to (taken from the collection operation).
    #[xml(ns = "rustical_dav::namespace::NS_DAVPUSH")]
    topic: String,
    // Set for content changes; `None` for any other operation kind.
    #[xml(ns = "rustical_dav::namespace::NS_DAVPUSH")]
    content_update: Option<ContentUpdate>,
}
41
42#[derive(Debug, Constructor)]
43pub struct DavPushController<S: SubscriptionStore> {
44    allowed_push_servers: Option<Vec<String>>,
45    sub_store: Arc<S>,
46}
47
48impl<S: SubscriptionStore> DavPushController<S> {
49    pub async fn notifier(&self, mut recv: Receiver<CollectionOperation>) {
50        loop {
51            // Make sure we don't flood the subscribers
52            tokio::time::sleep(Duration::from_secs(10)).await;
53            let mut messages = vec![];
54            recv.recv_many(&mut messages, 100).await;
55
56            // Right now we just have to show the latest content update by topic
57            // This might become more complicated in the future depending on what kind of updates
58            // we add
59            let mut latest_messages = HashMap::new();
60            for message in messages {
61                if matches!(message.data, CollectionOperationInfo::Content { .. }) {
62                    latest_messages.insert(message.topic.clone(), message);
63                }
64            }
65            let messages = latest_messages.into_values();
66
67            for message in messages {
68                self.send_message(message).await;
69            }
70        }
71    }
72
73    #[allow(clippy::cognitive_complexity)]
74    async fn send_message(&self, message: CollectionOperation) {
75        let subscriptions = match self.sub_store.get_subscriptions(&message.topic).await {
76            Ok(subs) => subs,
77            Err(err) => {
78                error!("{err}");
79                return;
80            }
81        };
82
83        if subscriptions.is_empty() {
84            return;
85        }
86
87        if matches!(message.data, CollectionOperationInfo::Delete) {
88            // Collection has been deleted, but we cannot handle that
89            return;
90        }
91
92        let content_update = if let CollectionOperationInfo::Content { sync_token } = message.data {
93            Some(ContentUpdate {
94                sync_token: Some(sync_token),
95            })
96        } else {
97            None
98        };
99
100        let push_message = PushMessage {
101            topic: message.topic,
102            content_update,
103        };
104
105        let payload = match push_message.serialize_to_string() {
106            Ok(payload) => payload,
107            Err(err) => {
108                error!("Could not serialize push message: {}", err);
109                return;
110            }
111        };
112
113        for subsciption in subscriptions {
114            if let Some(allowed_push_servers) = &self.allowed_push_servers {
115                if let Ok(url) = Url::parse(&subsciption.push_resource) {
116                    let origin = url.origin().unicode_serialization();
117                    if !allowed_push_servers.contains(&origin) {
118                        warn!(
119                            "Deleting subscription {} on topic {} because the endpoint is not in the list of allowed push servers",
120                            subsciption.id, subsciption.topic
121                        );
122                        self.try_delete_subscription(&subsciption.id).await;
123                    }
124                } else {
125                    warn!(
126                        "Deleting subscription {} on topic {} because of invalid URL",
127                        subsciption.id, subsciption.topic
128                    );
129                    self.try_delete_subscription(&subsciption.id).await;
130                }
131            }
132
133            if let Err(err) = self.send_payload(&payload, &subsciption).await {
134                error!("An error occured sending out a push notification: {err}");
135                if err.is_permament_error() {
136                    warn!(
137                        "Deleting subscription {} on topic {}",
138                        subsciption.id, subsciption.topic
139                    );
140                    self.try_delete_subscription(&subsciption.id).await;
141                }
142            }
143        }
144    }
145
146    async fn try_delete_subscription(&self, sub_id: &str) {
147        if let Err(err) = self.sub_store.delete_subscription(sub_id).await {
148            error!("Error deleting subsciption: {err}");
149        }
150    }
151
152    async fn send_payload(
153        &self,
154        payload: &str,
155        subsciption: &Subscription,
156    ) -> Result<(), NotifierError> {
157        if subsciption.public_key_type != "p256dh" {
158            return Err(NotifierError::InvalidPublicKeyType(
159                subsciption.public_key_type.clone(),
160            ));
161        }
162        let endpoint = subsciption
163            .push_resource
164            .parse()
165            .map_err(|_| NotifierError::InvalidEndpointUrl(subsciption.push_resource.clone()))?;
166        let ua_public = base64::engine::general_purpose::URL_SAFE_NO_PAD
167            .decode(&subsciption.public_key)
168            .map_err(|_| NotifierError::InvalidKeyEncoding)?;
169        let auth_secret = base64::engine::general_purpose::URL_SAFE_NO_PAD
170            .decode(&subsciption.auth_secret)
171            .map_err(|_| NotifierError::InvalidKeyEncoding)?;
172
173        let client = reqwest::ClientBuilder::new()
174            .build()
175            .map_err(NotifierError::from)?;
176
177        let payload = ece::encrypt(&ua_public, &auth_secret, payload.as_bytes())?;
178
179        let mut request = reqwest::Request::new(Method::POST, endpoint);
180        *request.body_mut() = Some(Body::from(payload));
181        let hdrs = request.headers_mut();
182        hdrs.insert(
183            header::CONTENT_ENCODING,
184            HeaderValue::from_static("aes128gcm"),
185        );
186        hdrs.insert(
187            header::CONTENT_TYPE,
188            HeaderValue::from_static("application/octet-stream"),
189        );
190        hdrs.insert("TTL", HeaderValue::from(60));
191        client.execute(request).await?;
192
193        Ok(())
194    }
195}
196
197#[derive(Debug, thiserror::Error)]
198enum NotifierError {
199    #[error("Invalid public key type: {0}")]
200    InvalidPublicKeyType(String),
201    #[error("Invalid endpoint URL: {0}")]
202    InvalidEndpointUrl(String),
203    #[error("Invalid key encoding")]
204    InvalidKeyEncoding,
205    #[error(transparent)]
206    EceError(#[from] ece::Error),
207    #[error(transparent)]
208    ReqwestError(#[from] reqwest::Error),
209}
210
211impl NotifierError {
212    // Decide whether the error should cause the subscription to be removed
213    pub const fn is_permament_error(&self) -> bool {
214        match self {
215            Self::InvalidPublicKeyType(_)
216            | Self::InvalidEndpointUrl(_)
217            | Self::InvalidKeyEncoding => true,
218            Self::EceError(err) => matches!(
219                err,
220                ece::Error::InvalidAuthSecret | ece::Error::InvalidKeyLength
221            ),
222            Self::ReqwestError(_) => false,
223        }
224    }
225}