rustical_store/auth/
middleware.rs1use super::AuthenticationProvider;
2use axum::{extract::Request, response::Response};
3use futures_core::future::BoxFuture;
4use headers::{Authorization, HeaderMapExt, authorization::Basic};
5use std::{
6 sync::Arc,
7 task::{Context, Poll},
8};
9use tower::{Layer, Service};
10use tower_sessions::Session;
11use tracing::{Instrument, info_span};
12
13pub struct AuthenticationLayer<AP: AuthenticationProvider> {
14 auth_provider: Arc<AP>,
15}
16
17impl<AP: AuthenticationProvider> Clone for AuthenticationLayer<AP> {
18 fn clone(&self) -> Self {
19 Self {
20 auth_provider: self.auth_provider.clone(),
21 }
22 }
23}
24
25impl<AP: AuthenticationProvider> AuthenticationLayer<AP> {
26 pub const fn new(auth_provider: Arc<AP>) -> Self {
27 Self { auth_provider }
28 }
29}
30
31impl<S, AP: AuthenticationProvider> Layer<S> for AuthenticationLayer<AP> {
32 type Service = AuthenticationMiddleware<S, AP>;
33
34 fn layer(&self, inner: S) -> Self::Service {
35 Self::Service {
36 inner,
37 auth_provider: self.auth_provider.clone(),
38 }
39 }
40}
41
42pub struct AuthenticationMiddleware<S, AP: AuthenticationProvider> {
43 inner: S,
44 auth_provider: Arc<AP>,
45}
46
47impl<S: Clone, AP: AuthenticationProvider> Clone for AuthenticationMiddleware<S, AP> {
48 fn clone(&self) -> Self {
49 Self {
50 inner: self.inner.clone(),
51 auth_provider: self.auth_provider.clone(),
52 }
53 }
54}
55
56impl<S, AP: AuthenticationProvider> Service<Request> for AuthenticationMiddleware<S, AP>
57where
58 S: Service<Request, Response = Response> + Send + Clone + 'static,
59 S::Future: Send + 'static,
60{
61 type Response = S::Response;
62 type Error = S::Error;
63 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
64
65 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
66 self.inner.poll_ready(cx)
67 }
68
69 fn call(&mut self, mut request: Request) -> Self::Future {
70 let auth_header: Option<Authorization<Basic>> = request.headers().typed_get();
71 let ap = self.auth_provider.clone();
72 let mut inner = self.inner.clone();
73
74 Box::pin(async move {
75 if let Some(session) = request.extensions().get::<Session>()
76 && let Ok(Some(user_id)) = session.get::<String>("user").await
77 && let Ok(Some(user)) = ap.get_principal(&user_id).await
78 {
79 request.extensions_mut().insert(user);
80 }
81
82 if let Some(auth) = auth_header {
83 let user_id = auth.username();
84 let password = auth.password();
85 if let Ok(Some(user)) = ap
86 .validate_app_token(user_id, password)
87 .instrument(info_span!("validate_user_token"))
88 .await
89 {
90 request.extensions_mut().insert(user);
91 }
92 }
93
94 let response = inner.call(request).await?;
95 Ok(response)
96 })
97 }
98}