rustical_oidc/
lib.rs

1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2#![allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
3use axum::{
4    Extension, Form,
5    extract::Query,
6    response::{IntoResponse, Redirect, Response},
7};
8use axum_extra::TypedHeader;
9pub use config::OidcConfig;
10use config::UserIdClaim;
11use error::OidcError;
12use headers::Host;
13use openidconnect::{
14    AuthenticationFlow, AuthorizationCode, CsrfToken, EndpointMaybeSet, EndpointNotSet,
15    EndpointSet, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier,
16    RedirectUrl, TokenResponse, UserInfoClaims,
17    core::{CoreClient, CoreGenderClaim, CoreProviderMetadata, CoreResponseType},
18};
19use reqwest::{StatusCode, Url};
20use serde::{Deserialize, Serialize};
21use tower_sessions::Session;
22pub use user_store::UserStore;
23
24mod config;
25mod error;
26mod user_store;
27
/// Session key under which the pending `OidcState` is kept between the login
/// redirect (`route_post_oidc`) and the provider callback
/// (`route_get_oidc_callback`).
const SESSION_KEY_OIDC_STATE: &str = "oidc_state";
29
/// Settings provided by the embedding application; the callback handler reads
/// them from a request extension.
#[derive(Clone, Debug)]
pub struct OidcServiceConfig {
    /// Path the callback handler redirects to when no usable post-login
    /// redirect target was stored for the login attempt.
    pub default_redirect_path: &'static str,
    /// Session key under which the callback handler stores the id of the
    /// logged-in user.
    pub session_key_user_id: &'static str,
}
35
36#[derive(Debug, Deserialize, Serialize)]
37struct OidcState {
38    state: CsrfToken,
39    nonce: Nonce,
40    pkce_verifier: PkceCodeVerifier,
41    redirect_uri: Option<String>,
42}
43
44#[derive(Debug, Deserialize, Serialize)]
45struct GroupAdditionalClaims {
46    #[serde(default)]
47    groups: Option<Vec<String>>,
48}
49
50impl openidconnect::AdditionalClaims for GroupAdditionalClaims {}
51
52fn get_http_client() -> reqwest::Client {
53    reqwest::ClientBuilder::new()
54        // Following redirects opens the client up to SSRF vulnerabilities.
55        .redirect(reqwest::redirect::Policy::none())
56        .build()
57        .expect("Something went wrong :(")
58}
59
60async fn get_oidc_client(
61    OidcConfig {
62        issuer,
63        client_id,
64        client_secret,
65        ..
66    }: OidcConfig,
67    http_client: &reqwest::Client,
68    redirect_uri: RedirectUrl,
69) -> Result<
70    CoreClient<
71        EndpointSet,
72        EndpointNotSet,
73        EndpointNotSet,
74        EndpointNotSet,
75        EndpointMaybeSet,
76        EndpointMaybeSet,
77    >,
78    OidcError,
79> {
80    let provider_metadata = CoreProviderMetadata::discover_async(issuer, http_client)
81        .await
82        .map_err(|err| {
83            tracing::error!("An error occured trying to discover OpenID provider: {err}");
84            OidcError::Other("Failed to discover OpenID provider")
85        })?;
86
87    Ok(CoreClient::from_provider_metadata(
88        provider_metadata,
89        client_id.clone(),
90        client_secret.clone(),
91    )
92    .set_redirect_uri(redirect_uri))
93}
94
95#[derive(Debug, Deserialize)]
96pub struct GetOidcForm {
97    redirect_uri: Option<String>,
98}
99
100/// Endpoint that redirects to the authorize endpoint of the OIDC service
101pub async fn route_post_oidc(
102    Extension(oidc_config): Extension<OidcConfig>,
103    session: Session,
104    TypedHeader(host): TypedHeader<Host>,
105    Form(GetOidcForm { redirect_uri }): Form<GetOidcForm>,
106) -> Result<Response, OidcError> {
107    let callback_uri = format!("https://{host}/frontend/login/oidc/callback");
108
109    let http_client = get_http_client();
110    let oidc_client = get_oidc_client(
111        oidc_config.clone(),
112        &http_client,
113        RedirectUrl::new(callback_uri)?,
114    )
115    .await?;
116
117    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
118
119    let (auth_url, csrf_token, nonce) = oidc_client
120        .authorize_url(
121            AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
122            CsrfToken::new_random,
123            Nonce::new_random,
124        )
125        .add_scopes(oidc_config.scopes.clone())
126        .set_pkce_challenge(pkce_challenge)
127        .url();
128
129    session
130        .insert(
131            SESSION_KEY_OIDC_STATE,
132            OidcState {
133                state: csrf_token,
134                nonce,
135                pkce_verifier,
136                redirect_uri,
137            },
138        )
139        .await?;
140
141    Ok(Redirect::to(auth_url.as_str()).into_response())
142}
143
144#[derive(Debug, Clone, Deserialize)]
145pub struct AuthCallbackQuery {
146    code: AuthorizationCode,
147    // RFC 9207
148    iss: Option<IssuerUrl>,
149    state: String,
150}
151
152// Handle callback from IdP page
153pub async fn route_get_oidc_callback<US: UserStore + Clone>(
154    Extension(oidc_config): Extension<OidcConfig>,
155    Extension(user_store): Extension<US>,
156    Extension(service_config): Extension<OidcServiceConfig>,
157    session: Session,
158    Query(AuthCallbackQuery { code, iss, state }): Query<AuthCallbackQuery>,
159    TypedHeader(host): TypedHeader<Host>,
160) -> Result<Response, OidcError> {
161    let callback_uri = format!("https://{host}/frontend/login/oidc/callback");
162
163    if let Some(iss) = iss {
164        assert_eq!(iss, oidc_config.issuer);
165    }
166    let oidc_state = session
167        .remove::<OidcState>(SESSION_KEY_OIDC_STATE)
168        .await?
169        .ok_or(OidcError::Other("No local OIDC state"))?;
170
171    assert_eq!(oidc_state.state.secret(), &state);
172
173    let http_client = get_http_client();
174    let oidc_client = get_oidc_client(
175        oidc_config.clone(),
176        &http_client,
177        RedirectUrl::new(callback_uri)?,
178    )
179    .await?;
180
181    let token_response = oidc_client
182        .exchange_code(code)?
183        .set_pkce_verifier(oidc_state.pkce_verifier)
184        .request_async(&http_client)
185        .await
186        .map_err(|_| OidcError::Other("Error requesting token"))?;
187    let id_claims = token_response
188        .id_token()
189        .ok_or(OidcError::Other("OIDC provider did not return an ID token"))?
190        .claims(&oidc_client.id_token_verifier(), &oidc_state.nonce)?;
191
192    let user_info_claims: UserInfoClaims<GroupAdditionalClaims, CoreGenderClaim> = oidc_client
193        .user_info(
194            token_response.access_token().clone(),
195            Some(id_claims.subject().clone()),
196        )?
197        .request_async(&http_client)
198        .await
199        .map_err(|e| OidcError::UserInfo(e.to_string()))?;
200
201    if let Some(require_group) = &oidc_config.require_group
202        && !user_info_claims
203            .additional_claims()
204            .groups
205            .clone()
206            .unwrap_or_default()
207            .contains(require_group)
208    {
209        return Ok((
210            StatusCode::UNAUTHORIZED,
211            "User is not in an authorized group to use RustiCal",
212        )
213            .into_response());
214    }
215
216    let user_id = match oidc_config.claim_userid {
217        UserIdClaim::Sub => user_info_claims.subject().to_string(),
218        UserIdClaim::PreferredUsername => user_info_claims
219            .preferred_username()
220            .ok_or(OidcError::Other("Missing preferred_username claim"))?
221            .to_string(),
222    };
223
224    match user_store.user_exists(&user_id).await {
225        Ok(false) => {
226            // User does not exist
227            if !oidc_config.allow_sign_up {
228                return Ok((StatusCode::UNAUTHORIZED, "User signup is disabled").into_response());
229            }
230            // Create new user
231            if let Err(err) = user_store.insert_user(&user_id).await {
232                return Ok(err.into_response());
233            }
234        }
235        Ok(true) => {}
236        Err(err) => {
237            return Ok(err.into_response());
238        }
239    }
240
241    let default_redirect = service_config.default_redirect_path.to_owned();
242    let base_url: Url = format!("https://{host}").parse().unwrap();
243    let redirect_uri = if let Some(redirect_uri) = oidc_state.redirect_uri {
244        if let Ok(redirect_url) = base_url.join(&redirect_uri) {
245            if redirect_url.origin() == base_url.origin() {
246                redirect_url.path().to_owned()
247            } else {
248                default_redirect
249            }
250        } else {
251            default_redirect
252        }
253    } else {
254        default_redirect
255    };
256
257    // Complete login flow
258    session
259        .insert(service_config.session_key_user_id, user_id.clone())
260        .await?;
261
262    Ok(Redirect::to(&redirect_uri).into_response())
263}