Skip to main content

rustical_dav_push/
lib.rs

1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2#![allow(clippy::missing_errors_doc)]
3mod extension;
4mod prop;
5pub mod register;
6use base64::Engine;
7use chrono::Utc;
8use derive_more::Constructor;
9pub use extension::*;
10use http::{HeaderValue, Method, header};
11pub use prop::*;
12use reqwest::{Body, Url};
13use rustical_store::{
14    CollectionOperation, CollectionOperationInfo, Subscription, SubscriptionStore,
15};
16use rustical_xml::{XmlRootTag, XmlSerialize, XmlSerializeRoot};
17use std::{collections::HashMap, sync::Arc, time::Duration};
18use tokio::sync::mpsc::Receiver;
19use tracing::{error, info, warn};
20
21mod endpoints;
22pub use endpoints::subscription_service;
23
24#[derive(XmlSerialize, Debug)]
25pub struct ContentUpdate {
26    #[xml(ns = "rustical_dav::namespace::NS_DAV")]
27    sync_token: Option<String>,
28}
29
30#[derive(XmlSerialize, XmlRootTag, Debug)]
31#[xml(root = "push-message", ns = "rustical_dav::namespace::NS_DAVPUSH")]
32#[xml(ns_prefix(
33    rustical_dav::namespace::NS_DAVPUSH = "",
34    rustical_dav::namespace::NS_DAV = "D",
35))]
36struct PushMessage {
37    #[xml(ns = "rustical_dav::namespace::NS_DAVPUSH")]
38    topic: String,
39    #[xml(ns = "rustical_dav::namespace::NS_DAVPUSH")]
40    content_update: Option<ContentUpdate>,
41}
42
43#[derive(Debug, Constructor)]
44pub struct DavPushController<S: SubscriptionStore> {
45    allowed_push_servers: Option<Vec<String>>,
46    sub_store: Arc<S>,
47}
48
49impl<S: SubscriptionStore> DavPushController<S> {
50    pub async fn notifier(&self, mut recv: Receiver<CollectionOperation>) {
51        loop {
52            // Make sure we don't flood the subscribers
53            tokio::time::sleep(Duration::from_secs(10)).await;
54            let mut messages = vec![];
55            recv.recv_many(&mut messages, 100).await;
56
57            // Right now we just have to show the latest content update by topic
58            // This might become more complicated in the future depending on what kind of updates
59            // we add
60            let mut latest_messages = HashMap::new();
61            for message in messages {
62                if matches!(message.data, CollectionOperationInfo::Content { .. }) {
63                    latest_messages.insert(message.topic.clone(), message);
64                }
65            }
66            let messages = latest_messages.into_values();
67
68            for message in messages {
69                self.send_message(message).await;
70            }
71        }
72    }
73
74    #[allow(clippy::cognitive_complexity)]
75    async fn send_message(&self, message: CollectionOperation) {
76        let subscriptions = match self.sub_store.get_subscriptions(&message.topic).await {
77            Ok(subs) => subs,
78            Err(err) => {
79                error!("{err}");
80                return;
81            }
82        };
83
84        if subscriptions.is_empty() {
85            return;
86        }
87
88        if matches!(message.data, CollectionOperationInfo::Delete) {
89            // Collection has been deleted, but we cannot handle that
90            return;
91        }
92
93        let content_update = if let CollectionOperationInfo::Content { sync_token } = message.data {
94            Some(ContentUpdate {
95                sync_token: Some(sync_token),
96            })
97        } else {
98            None
99        };
100
101        let push_message = PushMessage {
102            topic: message.topic,
103            content_update,
104        };
105
106        let payload = match push_message.serialize_to_string() {
107            Ok(payload) => payload,
108            Err(err) => {
109                error!("Could not serialize push message: {}", err);
110                return;
111            }
112        };
113
114        for subsciption in subscriptions {
115            if subsciption.is_expired(&Utc::now()) {
116                info!(
117                    "Deleting subscription {} on topic {} because it is expired",
118                    subsciption.id, subsciption.topic
119                );
120                self.try_delete_subscription(&subsciption.id).await;
121                continue;
122            }
123
124            if let Some(allowed_push_servers) = &self.allowed_push_servers {
125                if let Ok(url) = Url::parse(&subsciption.push_resource) {
126                    let origin = url.origin().unicode_serialization();
127                    if !allowed_push_servers.contains(&origin) {
128                        warn!(
129                            "Deleting subscription {} on topic {} because the endpoint is not in the list of allowed push servers",
130                            subsciption.id, subsciption.topic
131                        );
132                        self.try_delete_subscription(&subsciption.id).await;
133                        continue;
134                    }
135                } else {
136                    warn!(
137                        "Deleting subscription {} on topic {} because of invalid URL",
138                        subsciption.id, subsciption.topic
139                    );
140                    self.try_delete_subscription(&subsciption.id).await;
141                    continue;
142                }
143            }
144
145            if let Err(err) = self.send_payload(&payload, &subsciption).await {
146                error!("An error occured sending out a push notification: {err}");
147                if err.is_permament_error() {
148                    warn!(
149                        "Deleting subscription {} on topic {}",
150                        subsciption.id, subsciption.topic
151                    );
152                    self.try_delete_subscription(&subsciption.id).await;
153                }
154            }
155        }
156    }
157
158    async fn try_delete_subscription(&self, sub_id: &str) {
159        if let Err(err) = self.sub_store.delete_subscription(sub_id).await {
160            error!("Error deleting subsciption: {err}");
161        }
162    }
163
164    async fn send_payload(
165        &self,
166        payload: &str,
167        subsciption: &Subscription,
168    ) -> Result<(), NotifierError> {
169        if subsciption.public_key_type != "p256dh" {
170            return Err(NotifierError::InvalidPublicKeyType(
171                subsciption.public_key_type.clone(),
172            ));
173        }
174        let endpoint = subsciption
175            .push_resource
176            .parse()
177            .map_err(|_| NotifierError::InvalidEndpointUrl(subsciption.push_resource.clone()))?;
178        let ua_public = base64::engine::general_purpose::URL_SAFE_NO_PAD
179            .decode(&subsciption.public_key)
180            .map_err(|_| NotifierError::InvalidKeyEncoding)?;
181        let auth_secret = base64::engine::general_purpose::URL_SAFE_NO_PAD
182            .decode(&subsciption.auth_secret)
183            .map_err(|_| NotifierError::InvalidKeyEncoding)?;
184
185        let client = reqwest::ClientBuilder::new()
186            .build()
187            .map_err(NotifierError::from)?;
188
189        let payload = ece::encrypt(&ua_public, &auth_secret, payload.as_bytes())?;
190
191        let mut request = reqwest::Request::new(Method::POST, endpoint);
192        *request.body_mut() = Some(Body::from(payload));
193        let hdrs = request.headers_mut();
194        hdrs.insert(
195            header::CONTENT_ENCODING,
196            HeaderValue::from_static("aes128gcm"),
197        );
198        hdrs.insert(
199            header::CONTENT_TYPE,
200            HeaderValue::from_static("application/octet-stream"),
201        );
202        hdrs.insert("TTL", HeaderValue::from(60));
203        client.execute(request).await?;
204
205        Ok(())
206    }
207}
208
209#[derive(Debug, thiserror::Error)]
210enum NotifierError {
211    #[error("Invalid public key type: {0}")]
212    InvalidPublicKeyType(String),
213    #[error("Invalid endpoint URL: {0}")]
214    InvalidEndpointUrl(String),
215    #[error("Invalid key encoding")]
216    InvalidKeyEncoding,
217    #[error(transparent)]
218    EceError(#[from] ece::Error),
219    #[error(transparent)]
220    ReqwestError(#[from] reqwest::Error),
221}
222
223impl NotifierError {
224    // Decide whether the error should cause the subscription to be removed
225    pub const fn is_permament_error(&self) -> bool {
226        match self {
227            Self::InvalidPublicKeyType(_)
228            | Self::InvalidEndpointUrl(_)
229            | Self::InvalidKeyEncoding => true,
230            Self::EceError(err) => matches!(
231                err,
232                ece::Error::InvalidAuthSecret | ece::Error::InvalidKeyLength
233            ),
234            Self::ReqwestError(_) => false,
235        }
236    }
237}