rustical_oidc/
lib.rs

1use axum::{
2    Extension, Form,
3    extract::Query,
4    response::{IntoResponse, Redirect, Response},
5};
6use axum_extra::extract::Host;
7pub use config::OidcConfig;
8use config::UserIdClaim;
9use error::OidcError;
10use openidconnect::{
11    AuthenticationFlow, AuthorizationCode, CsrfToken, EndpointMaybeSet, EndpointNotSet,
12    EndpointSet, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier,
13    RedirectUrl, TokenResponse, UserInfoClaims,
14    core::{CoreClient, CoreGenderClaim, CoreProviderMetadata, CoreResponseType},
15};
16use reqwest::{StatusCode, Url};
17use serde::{Deserialize, Serialize};
18use tower_sessions::Session;
19pub use user_store::UserStore;
20
21mod config;
22mod error;
23mod user_store;
24
/// Session key under which the in-flight [`OidcState`] of a login attempt is stored between
/// the redirect to the provider and the provider's callback.
const SESSION_KEY_OIDC_STATE: &str = "oidc_state";
26
/// Settings supplied by the embedding service that the OIDC routes need at runtime.
#[derive(Clone, Debug)]
pub struct OidcServiceConfig {
    /// Path the user is redirected to after a successful login when no usable redirect target
    /// was stored for the login attempt.
    pub default_redirect_path: &'static str,
    /// Session key under which the id of the logged-in user is stored.
    pub session_key_user_id: &'static str,
}
32
33#[derive(Debug, Deserialize, Serialize)]
34struct OidcState {
35    state: CsrfToken,
36    nonce: Nonce,
37    pkce_verifier: PkceCodeVerifier,
38    redirect_uri: Option<String>,
39}
40
41#[derive(Debug, Deserialize, Serialize)]
42struct GroupAdditionalClaims {
43    #[serde(default)]
44    groups: Option<Vec<String>>,
45}
46
47impl openidconnect::AdditionalClaims for GroupAdditionalClaims {}
48
49fn get_http_client() -> reqwest::Client {
50    reqwest::ClientBuilder::new()
51        // Following redirects opens the client up to SSRF vulnerabilities.
52        .redirect(reqwest::redirect::Policy::none())
53        .build()
54        .expect("Something went wrong :(")
55}
56
57async fn get_oidc_client(
58    OidcConfig {
59        issuer,
60        client_id,
61        client_secret,
62        ..
63    }: OidcConfig,
64    http_client: &reqwest::Client,
65    redirect_uri: RedirectUrl,
66) -> Result<
67    CoreClient<
68        EndpointSet,
69        EndpointNotSet,
70        EndpointNotSet,
71        EndpointNotSet,
72        EndpointMaybeSet,
73        EndpointMaybeSet,
74    >,
75    OidcError,
76> {
77    let provider_metadata = CoreProviderMetadata::discover_async(issuer, http_client)
78        .await
79        .map_err(|err| {
80            tracing::error!("An error occured trying to discover OpenID provider: {err}");
81            OidcError::Other("Failed to discover OpenID provider")
82        })?;
83
84    Ok(CoreClient::from_provider_metadata(
85        provider_metadata.clone(),
86        client_id.clone(),
87        client_secret.clone(),
88    )
89    .set_redirect_uri(redirect_uri))
90}
91
92#[derive(Debug, Deserialize)]
93pub struct GetOidcForm {
94    redirect_uri: Option<String>,
95}
96
97/// Endpoint that redirects to the authorize endpoint of the OIDC service
98pub async fn route_post_oidc(
99    Extension(oidc_config): Extension<OidcConfig>,
100    session: Session,
101    Host(host): Host,
102    Form(GetOidcForm { redirect_uri }): Form<GetOidcForm>,
103) -> Result<Response, OidcError> {
104    let callback_uri = format!("https://{host}/frontend/login/oidc/callback");
105
106    let http_client = get_http_client();
107    let oidc_client = get_oidc_client(
108        oidc_config.clone(),
109        &http_client,
110        RedirectUrl::new(callback_uri)?,
111    )
112    .await?;
113
114    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
115
116    let (auth_url, csrf_token, nonce) = oidc_client
117        .authorize_url(
118            AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
119            CsrfToken::new_random,
120            Nonce::new_random,
121        )
122        .add_scopes(oidc_config.scopes.clone())
123        .set_pkce_challenge(pkce_challenge)
124        .url();
125
126    session
127        .insert(
128            SESSION_KEY_OIDC_STATE,
129            OidcState {
130                state: csrf_token,
131                nonce,
132                pkce_verifier,
133                redirect_uri,
134            },
135        )
136        .await?;
137
138    Ok(Redirect::to(auth_url.as_str()).into_response())
139}
140
141#[derive(Debug, Clone, Deserialize)]
142pub struct AuthCallbackQuery {
143    code: AuthorizationCode,
144    // RFC 9207
145    iss: Option<IssuerUrl>,
146    state: String,
147}
148
149// Handle callback from IdP page
150pub async fn route_get_oidc_callback<US: UserStore + Clone>(
151    Extension(oidc_config): Extension<OidcConfig>,
152    Extension(user_store): Extension<US>,
153    Extension(service_config): Extension<OidcServiceConfig>,
154    session: Session,
155    Query(AuthCallbackQuery { code, iss, state }): Query<AuthCallbackQuery>,
156    Host(host): Host,
157) -> Result<Response, OidcError> {
158    let callback_uri = format!("https://{host}/frontend/login/oidc/callback");
159
160    if let Some(iss) = iss {
161        assert_eq!(iss, oidc_config.issuer);
162    }
163    let oidc_state = session
164        .remove::<OidcState>(SESSION_KEY_OIDC_STATE)
165        .await?
166        .ok_or(OidcError::Other("No local OIDC state"))?;
167
168    assert_eq!(oidc_state.state.secret(), &state);
169
170    let http_client = get_http_client();
171    let oidc_client = get_oidc_client(
172        oidc_config.clone(),
173        &http_client,
174        RedirectUrl::new(callback_uri)?,
175    )
176    .await?;
177
178    let token_response = oidc_client
179        .exchange_code(code)?
180        .set_pkce_verifier(oidc_state.pkce_verifier)
181        .request_async(&http_client)
182        .await
183        .map_err(|_| OidcError::Other("Error requesting token"))?;
184    let id_claims = token_response
185        .id_token()
186        .ok_or(OidcError::Other("OIDC provider did not return an ID token"))?
187        .claims(&oidc_client.id_token_verifier(), &oidc_state.nonce)?;
188
189    let user_info_claims: UserInfoClaims<GroupAdditionalClaims, CoreGenderClaim> = oidc_client
190        .user_info(
191            token_response.access_token().clone(),
192            Some(id_claims.subject().clone()),
193        )?
194        .request_async(&http_client)
195        .await
196        .map_err(|e| OidcError::UserInfo(e.to_string()))?;
197
198    if let Some(require_group) = &oidc_config.require_group
199        && !user_info_claims
200            .additional_claims()
201            .groups
202            .clone()
203            .unwrap_or_default()
204            .contains(require_group)
205    {
206        return Ok((
207            StatusCode::UNAUTHORIZED,
208            "User is not in an authorized group to use RustiCal",
209        )
210            .into_response());
211    }
212
213    let user_id = match oidc_config.claim_userid {
214        UserIdClaim::Sub => user_info_claims.subject().to_string(),
215        UserIdClaim::PreferredUsername => user_info_claims
216            .preferred_username()
217            .ok_or(OidcError::Other("Missing preferred_username claim"))?
218            .to_string(),
219    };
220
221    match user_store.user_exists(&user_id).await {
222        Ok(false) => {
223            // User does not exist
224            if !oidc_config.allow_sign_up {
225                return Ok((StatusCode::UNAUTHORIZED, "User signup is disabled").into_response());
226            }
227            // Create new user
228            if let Err(err) = user_store.insert_user(&user_id).await {
229                return Ok(err.into_response());
230            }
231        }
232        Ok(true) => {}
233        Err(err) => {
234            return Ok(err.into_response());
235        }
236    }
237
238    let default_redirect = service_config.default_redirect_path.to_owned();
239    let base_url: Url = format!("https://{host}").parse().unwrap();
240    let redirect_uri = if let Some(redirect_uri) = oidc_state.redirect_uri {
241        if let Ok(redirect_url) = base_url.join(&redirect_uri) {
242            if redirect_url.origin() == base_url.origin() {
243                redirect_url.path().to_owned()
244            } else {
245                default_redirect
246            }
247        } else {
248            default_redirect
249        }
250    } else {
251        default_redirect
252    };
253
254    // Complete login flow
255    session
256        .insert(service_config.session_key_user_id, user_id.clone())
257        .await?;
258
259    Ok(Redirect::to(&redirect_uri).into_response())
260}