rustical_oidc/
lib.rs

1use axum::{
2    Extension, Form,
3    extract::Query,
4    response::{IntoResponse, Redirect, Response},
5};
6use axum_extra::extract::Host;
7pub use config::OidcConfig;
8use config::UserIdClaim;
9use error::OidcError;
10use openidconnect::{
11    AuthenticationFlow, AuthorizationCode, CsrfToken, EndpointMaybeSet, EndpointNotSet,
12    EndpointSet, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier,
13    RedirectUrl, TokenResponse, UserInfoClaims,
14    core::{CoreClient, CoreGenderClaim, CoreProviderMetadata, CoreResponseType},
15};
16use reqwest::{StatusCode, Url};
17use serde::{Deserialize, Serialize};
18use tower_sessions::Session;
19pub use user_store::UserStore;
20
21mod config;
22mod error;
23mod user_store;
24
/// Session key under which the pending [`OidcState`] is stored by
/// [`route_post_oidc`] and removed again by [`route_get_oidc_callback`].
const SESSION_KEY_OIDC_STATE: &str = "oidc_state";
26
/// Settings provided by the embedding service that control how the callback
/// handler finishes a login.
#[derive(Clone, Debug)]
pub struct OidcServiceConfig {
    /// Path the callback redirects to when the login request carried no usable
    /// `redirect_uri`.
    pub default_redirect_path: &'static str,
    /// Session key under which the callback stores the id of the logged-in user.
    pub session_key_user_id: &'static str,
}
32
33#[derive(Debug, Deserialize, Serialize)]
34struct OidcState {
35    state: CsrfToken,
36    nonce: Nonce,
37    pkce_verifier: PkceCodeVerifier,
38    redirect_uri: Option<String>,
39}
40
41#[derive(Debug, Deserialize, Serialize)]
42struct GroupAdditionalClaims {
43    #[serde(default)]
44    groups: Option<Vec<String>>,
45}
46
47impl openidconnect::AdditionalClaims for GroupAdditionalClaims {}
48
49fn get_http_client() -> reqwest::Client {
50    reqwest::ClientBuilder::new()
51        // Following redirects opens the client up to SSRF vulnerabilities.
52        .redirect(reqwest::redirect::Policy::none())
53        .build()
54        .expect("Something went wrong :(")
55}
56
57async fn get_oidc_client(
58    OidcConfig {
59        issuer,
60        client_id,
61        client_secret,
62        ..
63    }: OidcConfig,
64    http_client: &reqwest::Client,
65    redirect_uri: RedirectUrl,
66) -> Result<
67    CoreClient<
68        EndpointSet,
69        EndpointNotSet,
70        EndpointNotSet,
71        EndpointNotSet,
72        EndpointMaybeSet,
73        EndpointMaybeSet,
74    >,
75    OidcError,
76> {
77    let provider_metadata = CoreProviderMetadata::discover_async(issuer, http_client)
78        .await
79        .map_err(|_| OidcError::Other("Failed to discover OpenID provider"))?;
80
81    Ok(CoreClient::from_provider_metadata(
82        provider_metadata.clone(),
83        client_id.clone(),
84        client_secret.clone(),
85    )
86    .set_redirect_uri(redirect_uri))
87}
88
89#[derive(Debug, Deserialize)]
90pub struct GetOidcForm {
91    redirect_uri: Option<String>,
92}
93
94/// Endpoint that redirects to the authorize endpoint of the OIDC service
95pub async fn route_post_oidc(
96    Extension(oidc_config): Extension<OidcConfig>,
97    session: Session,
98    Host(host): Host,
99    Form(GetOidcForm { redirect_uri }): Form<GetOidcForm>,
100) -> Result<Response, OidcError> {
101    let callback_uri = format!("https://{host}/frontend/login/oidc/callback");
102
103    let http_client = get_http_client();
104    let oidc_client = get_oidc_client(
105        oidc_config.clone(),
106        &http_client,
107        RedirectUrl::new(callback_uri)?,
108    )
109    .await?;
110
111    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
112
113    let (auth_url, csrf_token, nonce) = oidc_client
114        .authorize_url(
115            AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
116            CsrfToken::new_random,
117            Nonce::new_random,
118        )
119        .add_scopes(oidc_config.scopes.clone())
120        .set_pkce_challenge(pkce_challenge)
121        .url();
122
123    session
124        .insert(
125            SESSION_KEY_OIDC_STATE,
126            OidcState {
127                state: csrf_token,
128                nonce,
129                pkce_verifier,
130                redirect_uri,
131            },
132        )
133        .await?;
134
135    Ok(Redirect::to(auth_url.as_str()).into_response())
136}
137
138#[derive(Debug, Clone, Deserialize)]
139pub struct AuthCallbackQuery {
140    code: AuthorizationCode,
141    // RFC 9207
142    iss: Option<IssuerUrl>,
143    state: String,
144}
145
146// Handle callback from IdP page
147pub async fn route_get_oidc_callback<US: UserStore + Clone>(
148    Extension(oidc_config): Extension<OidcConfig>,
149    Extension(user_store): Extension<US>,
150    Extension(service_config): Extension<OidcServiceConfig>,
151    session: Session,
152    Query(AuthCallbackQuery { code, iss, state }): Query<AuthCallbackQuery>,
153    Host(host): Host,
154) -> Result<Response, OidcError> {
155    let callback_uri = format!("https://{host}/frontend/login/oidc/callback");
156
157    if let Some(iss) = iss {
158        assert_eq!(iss, oidc_config.issuer);
159    }
160    let oidc_state = session
161        .remove::<OidcState>(SESSION_KEY_OIDC_STATE)
162        .await?
163        .ok_or(OidcError::Other("No local OIDC state"))?;
164
165    assert_eq!(oidc_state.state.secret(), &state);
166
167    let http_client = get_http_client();
168    let oidc_client = get_oidc_client(
169        oidc_config.clone(),
170        &http_client,
171        RedirectUrl::new(callback_uri)?,
172    )
173    .await?;
174
175    let token_response = oidc_client
176        .exchange_code(code)?
177        .set_pkce_verifier(oidc_state.pkce_verifier)
178        .request_async(&http_client)
179        .await
180        .map_err(|_| OidcError::Other("Error requesting token"))?;
181    let id_claims = token_response
182        .id_token()
183        .ok_or(OidcError::Other("OIDC provider did not return an ID token"))?
184        .claims(&oidc_client.id_token_verifier(), &oidc_state.nonce)?;
185
186    let user_info_claims: UserInfoClaims<GroupAdditionalClaims, CoreGenderClaim> = oidc_client
187        .user_info(
188            token_response.access_token().clone(),
189            Some(id_claims.subject().clone()),
190        )?
191        .request_async(&http_client)
192        .await
193        .map_err(|e| OidcError::UserInfo(e.to_string()))?;
194
195    if let Some(require_group) = &oidc_config.require_group {
196        if !user_info_claims
197            .additional_claims()
198            .groups
199            .clone()
200            .unwrap_or_default()
201            .contains(require_group)
202        {
203            return Ok((
204                StatusCode::UNAUTHORIZED,
205                "User is not in an authorized group to use RustiCal",
206            )
207                .into_response());
208        }
209    }
210
211    let user_id = match oidc_config.claim_userid {
212        UserIdClaim::Sub => user_info_claims.subject().to_string(),
213        UserIdClaim::PreferredUsername => user_info_claims
214            .preferred_username()
215            .ok_or(OidcError::Other("Missing preferred_username claim"))?
216            .to_string(),
217    };
218
219    match user_store.user_exists(&user_id).await {
220        Ok(false) => {
221            // User does not exist
222            if !oidc_config.allow_sign_up {
223                return Ok((StatusCode::UNAUTHORIZED, "User signup is disabled").into_response());
224            }
225            // Create new user
226            if let Err(err) = user_store.insert_user(&user_id).await {
227                return Ok(err.into_response());
228            }
229        }
230        Ok(true) => {}
231        Err(err) => {
232            return Ok(err.into_response());
233        }
234    }
235
236    let default_redirect = service_config.default_redirect_path.to_owned();
237    let base_url: Url = format!("https://{host}").parse().unwrap();
238    let redirect_uri = if let Some(redirect_uri) = oidc_state.redirect_uri {
239        if let Ok(redirect_url) = base_url.join(&redirect_uri) {
240            if redirect_url.origin() == base_url.origin() {
241                redirect_url.path().to_owned()
242            } else {
243                default_redirect
244            }
245        } else {
246            default_redirect
247        }
248    } else {
249        default_redirect
250    };
251
252    // Complete login flow
253    session
254        .insert(service_config.session_key_user_id, user_id.clone())
255        .await?;
256
257    Ok(Redirect::to(&redirect_uri).into_response())
258}