rustical_store/auth/
middleware.rs

1use super::AuthenticationProvider;
2use axum::{extract::Request, response::Response};
3use futures_core::future::BoxFuture;
4use headers::{Authorization, HeaderMapExt, authorization::Basic};
5use std::{
6    sync::Arc,
7    task::{Context, Poll},
8};
9use tower::{Layer, Service};
10use tower_sessions::Session;
11use tracing::{Instrument, info_span};
12
13pub struct AuthenticationLayer<AP: AuthenticationProvider> {
14    auth_provider: Arc<AP>,
15}
16
17impl<AP: AuthenticationProvider> Clone for AuthenticationLayer<AP> {
18    fn clone(&self) -> Self {
19        Self {
20            auth_provider: self.auth_provider.clone(),
21        }
22    }
23}
24
25impl<AP: AuthenticationProvider> AuthenticationLayer<AP> {
26    pub const fn new(auth_provider: Arc<AP>) -> Self {
27        Self { auth_provider }
28    }
29}
30
31impl<S, AP: AuthenticationProvider> Layer<S> for AuthenticationLayer<AP> {
32    type Service = AuthenticationMiddleware<S, AP>;
33
34    fn layer(&self, inner: S) -> Self::Service {
35        Self::Service {
36            inner,
37            auth_provider: self.auth_provider.clone(),
38        }
39    }
40}
41
42pub struct AuthenticationMiddleware<S, AP: AuthenticationProvider> {
43    inner: S,
44    auth_provider: Arc<AP>,
45}
46
47impl<S: Clone, AP: AuthenticationProvider> Clone for AuthenticationMiddleware<S, AP> {
48    fn clone(&self) -> Self {
49        Self {
50            inner: self.inner.clone(),
51            auth_provider: self.auth_provider.clone(),
52        }
53    }
54}
55
56impl<S, AP: AuthenticationProvider> Service<Request> for AuthenticationMiddleware<S, AP>
57where
58    S: Service<Request, Response = Response> + Send + Clone + 'static,
59    S::Future: Send + 'static,
60{
61    type Response = S::Response;
62    type Error = S::Error;
63    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
64
65    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
66        self.inner.poll_ready(cx)
67    }
68
69    fn call(&mut self, mut request: Request) -> Self::Future {
70        let auth_header: Option<Authorization<Basic>> = request.headers().typed_get();
71        let ap = self.auth_provider.clone();
72        let mut inner = self.inner.clone();
73
74        Box::pin(async move {
75            if let Some(session) = request.extensions().get::<Session>()
76                && let Ok(Some(user_id)) = session.get::<String>("user").await
77                && let Ok(Some(user)) = ap.get_principal(&user_id).await
78            {
79                request.extensions_mut().insert(user);
80            }
81
82            if let Some(auth) = auth_header {
83                let user_id = auth.username();
84                let password = auth.password();
85                if let Ok(Some(user)) = ap
86                    .validate_app_token(user_id, password)
87                    .instrument(info_span!("validate_user_token"))
88                    .await
89                {
90                    request.extensions_mut().insert(user);
91                }
92            }
93
94            let response = inner.call(request).await?;
95            Ok(response)
96        })
97    }
98}