rustical_oidc/
lib.rs

1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2#![allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
3use axum::{
4    Extension, Form,
5    extract::Query,
6    response::{IntoResponse, Redirect, Response},
7};
8use axum_extra::extract::Host;
9pub use config::OidcConfig;
10use config::UserIdClaim;
11use error::OidcError;
12use openidconnect::{
13    AuthenticationFlow, AuthorizationCode, CsrfToken, EndpointMaybeSet, EndpointNotSet,
14    EndpointSet, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier,
15    RedirectUrl, TokenResponse, UserInfoClaims,
16    core::{CoreClient, CoreGenderClaim, CoreProviderMetadata, CoreResponseType},
17};
18use reqwest::{StatusCode, Url};
19use serde::{Deserialize, Serialize};
20use tower_sessions::Session;
21pub use user_store::UserStore;
22
23mod config;
24mod error;
25mod user_store;
26
/// Session key that holds the pending `OidcState` between the login redirect and the
/// provider callback.
const SESSION_KEY_OIDC_STATE: &str = "oidc_state";

/// Service-level settings used by the OIDC routes: where to send the user after login
/// and where to record who logged in.
#[derive(Clone, Debug)]
pub struct OidcServiceConfig {
    /// Path to redirect to after login when no usable `redirect_uri` was requested.
    pub default_redirect_path: &'static str,
    /// Session key under which the logged-in user's ID is stored.
    pub session_key_user_id: &'static str,
}
34
35#[derive(Debug, Deserialize, Serialize)]
36struct OidcState {
37    state: CsrfToken,
38    nonce: Nonce,
39    pkce_verifier: PkceCodeVerifier,
40    redirect_uri: Option<String>,
41}
42
43#[derive(Debug, Deserialize, Serialize)]
44struct GroupAdditionalClaims {
45    #[serde(default)]
46    groups: Option<Vec<String>>,
47}
48
49impl openidconnect::AdditionalClaims for GroupAdditionalClaims {}
50
51fn get_http_client() -> reqwest::Client {
52    reqwest::ClientBuilder::new()
53        // Following redirects opens the client up to SSRF vulnerabilities.
54        .redirect(reqwest::redirect::Policy::none())
55        .build()
56        .expect("Something went wrong :(")
57}
58
59async fn get_oidc_client(
60    OidcConfig {
61        issuer,
62        client_id,
63        client_secret,
64        ..
65    }: OidcConfig,
66    http_client: &reqwest::Client,
67    redirect_uri: RedirectUrl,
68) -> Result<
69    CoreClient<
70        EndpointSet,
71        EndpointNotSet,
72        EndpointNotSet,
73        EndpointNotSet,
74        EndpointMaybeSet,
75        EndpointMaybeSet,
76    >,
77    OidcError,
78> {
79    let provider_metadata = CoreProviderMetadata::discover_async(issuer, http_client)
80        .await
81        .map_err(|err| {
82            tracing::error!("An error occured trying to discover OpenID provider: {err}");
83            OidcError::Other("Failed to discover OpenID provider")
84        })?;
85
86    Ok(CoreClient::from_provider_metadata(
87        provider_metadata,
88        client_id.clone(),
89        client_secret.clone(),
90    )
91    .set_redirect_uri(redirect_uri))
92}
93
94#[derive(Debug, Deserialize)]
95pub struct GetOidcForm {
96    redirect_uri: Option<String>,
97}
98
99/// Endpoint that redirects to the authorize endpoint of the OIDC service
100pub async fn route_post_oidc(
101    Extension(oidc_config): Extension<OidcConfig>,
102    session: Session,
103    Host(host): Host,
104    Form(GetOidcForm { redirect_uri }): Form<GetOidcForm>,
105) -> Result<Response, OidcError> {
106    let callback_uri = format!("https://{host}/frontend/login/oidc/callback");
107
108    let http_client = get_http_client();
109    let oidc_client = get_oidc_client(
110        oidc_config.clone(),
111        &http_client,
112        RedirectUrl::new(callback_uri)?,
113    )
114    .await?;
115
116    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
117
118    let (auth_url, csrf_token, nonce) = oidc_client
119        .authorize_url(
120            AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
121            CsrfToken::new_random,
122            Nonce::new_random,
123        )
124        .add_scopes(oidc_config.scopes.clone())
125        .set_pkce_challenge(pkce_challenge)
126        .url();
127
128    session
129        .insert(
130            SESSION_KEY_OIDC_STATE,
131            OidcState {
132                state: csrf_token,
133                nonce,
134                pkce_verifier,
135                redirect_uri,
136            },
137        )
138        .await?;
139
140    Ok(Redirect::to(auth_url.as_str()).into_response())
141}
142
143#[derive(Debug, Clone, Deserialize)]
144pub struct AuthCallbackQuery {
145    code: AuthorizationCode,
146    // RFC 9207
147    iss: Option<IssuerUrl>,
148    state: String,
149}
150
151// Handle callback from IdP page
152pub async fn route_get_oidc_callback<US: UserStore + Clone>(
153    Extension(oidc_config): Extension<OidcConfig>,
154    Extension(user_store): Extension<US>,
155    Extension(service_config): Extension<OidcServiceConfig>,
156    session: Session,
157    Query(AuthCallbackQuery { code, iss, state }): Query<AuthCallbackQuery>,
158    Host(host): Host,
159) -> Result<Response, OidcError> {
160    let callback_uri = format!("https://{host}/frontend/login/oidc/callback");
161
162    if let Some(iss) = iss {
163        assert_eq!(iss, oidc_config.issuer);
164    }
165    let oidc_state = session
166        .remove::<OidcState>(SESSION_KEY_OIDC_STATE)
167        .await?
168        .ok_or(OidcError::Other("No local OIDC state"))?;
169
170    assert_eq!(oidc_state.state.secret(), &state);
171
172    let http_client = get_http_client();
173    let oidc_client = get_oidc_client(
174        oidc_config.clone(),
175        &http_client,
176        RedirectUrl::new(callback_uri)?,
177    )
178    .await?;
179
180    let token_response = oidc_client
181        .exchange_code(code)?
182        .set_pkce_verifier(oidc_state.pkce_verifier)
183        .request_async(&http_client)
184        .await
185        .map_err(|_| OidcError::Other("Error requesting token"))?;
186    let id_claims = token_response
187        .id_token()
188        .ok_or(OidcError::Other("OIDC provider did not return an ID token"))?
189        .claims(&oidc_client.id_token_verifier(), &oidc_state.nonce)?;
190
191    let user_info_claims: UserInfoClaims<GroupAdditionalClaims, CoreGenderClaim> = oidc_client
192        .user_info(
193            token_response.access_token().clone(),
194            Some(id_claims.subject().clone()),
195        )?
196        .request_async(&http_client)
197        .await
198        .map_err(|e| OidcError::UserInfo(e.to_string()))?;
199
200    if let Some(require_group) = &oidc_config.require_group
201        && !user_info_claims
202            .additional_claims()
203            .groups
204            .clone()
205            .unwrap_or_default()
206            .contains(require_group)
207    {
208        return Ok((
209            StatusCode::UNAUTHORIZED,
210            "User is not in an authorized group to use RustiCal",
211        )
212            .into_response());
213    }
214
215    let user_id = match oidc_config.claim_userid {
216        UserIdClaim::Sub => user_info_claims.subject().to_string(),
217        UserIdClaim::PreferredUsername => user_info_claims
218            .preferred_username()
219            .ok_or(OidcError::Other("Missing preferred_username claim"))?
220            .to_string(),
221    };
222
223    match user_store.user_exists(&user_id).await {
224        Ok(false) => {
225            // User does not exist
226            if !oidc_config.allow_sign_up {
227                return Ok((StatusCode::UNAUTHORIZED, "User signup is disabled").into_response());
228            }
229            // Create new user
230            if let Err(err) = user_store.insert_user(&user_id).await {
231                return Ok(err.into_response());
232            }
233        }
234        Ok(true) => {}
235        Err(err) => {
236            return Ok(err.into_response());
237        }
238    }
239
240    let default_redirect = service_config.default_redirect_path.to_owned();
241    let base_url: Url = format!("https://{host}").parse().unwrap();
242    let redirect_uri = if let Some(redirect_uri) = oidc_state.redirect_uri {
243        if let Ok(redirect_url) = base_url.join(&redirect_uri) {
244            if redirect_url.origin() == base_url.origin() {
245                redirect_url.path().to_owned()
246            } else {
247                default_redirect
248            }
249        } else {
250            default_redirect
251        }
252    } else {
253        default_redirect
254    };
255
256    // Complete login flow
257    session
258        .insert(service_config.session_key_user_id, user_id.clone())
259        .await?;
260
261    Ok(Redirect::to(&redirect_uri).into_response())
262}