rustical_oidc/
lib.rs

1use actix_session::Session;
2use actix_web::{
3    HttpRequest, HttpResponse, Responder, ResponseError,
4    http::StatusCode,
5    web::{self, Data, Form, Query, Redirect, ServiceConfig},
6};
7pub use config::OidcConfig;
8use config::UserIdClaim;
9use error::OidcError;
10use openidconnect::{
11    AuthenticationFlow, AuthorizationCode, CsrfToken, EndpointMaybeSet, EndpointNotSet,
12    EndpointSet, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier,
13    RedirectUrl, TokenResponse, UserInfoClaims,
14    core::{CoreClient, CoreGenderClaim, CoreProviderMetadata, CoreResponseType},
15};
16use serde::{Deserialize, Serialize};
17use std::sync::Arc;
18pub use user_store::UserStore;
19
20mod config;
21mod error;
22mod user_store;
23
/// Route name of the endpoint that starts the OIDC login flow (registered in `configure_oidc`).
pub const ROUTE_NAME_OIDC_LOGIN: &str = "oidc_login";
/// Route name of the provider callback endpoint; used to build the redirect URL sent to the IdP.
const ROUTE_NAME_OIDC_CALLBACK: &str = "oidc_callback";
/// Session key under which the pending login state (`OidcState`) is kept between the login
/// request and the provider callback.
const SESSION_KEY_OIDC_STATE: &str = "oidc_state";
27
/// Settings supplied by the embedding application for the OIDC routes.
#[derive(Debug, Clone)]
pub struct OidcServiceConfig {
    /// Name of the route to redirect to after login when no usable `redirect_uri` was requested.
    pub default_redirect_route_name: &'static str,
    /// Session key under which the authenticated user's id is stored after a successful login.
    pub session_key_user_id: &'static str,
}
33
/// Per-login state stored in the session by `route_post_oidc` and consumed by the callback.
#[derive(Debug, Deserialize, Serialize)]
struct OidcState {
    /// CSRF token sent to the provider as `state`; compared with the `state` the callback receives.
    state: CsrfToken,
    /// Nonce sent in the authorization request; used when verifying the ID token claims.
    nonce: Nonce,
    /// PKCE verifier matching the challenge sent in the authorization request.
    pkce_verifier: PkceCodeVerifier,
    /// Optional page to return to after login, as submitted in the login form.
    redirect_uri: Option<String>,
}
41
/// Non-standard user-info claims read by this crate: the user's group memberships.
#[derive(Debug, Deserialize, Serialize)]
struct GroupAdditionalClaims {
    /// Contents of the `groups` claim; empty when the provider does not send it.
    #[serde(default)]
    pub groups: Vec<String>,
}

// Allows `UserInfoClaims<GroupAdditionalClaims, _>` to carry the extra `groups` claim,
// which the callback checks against `require_group`.
impl openidconnect::AdditionalClaims for GroupAdditionalClaims {}
49
50fn get_http_client() -> reqwest::Client {
51    reqwest::ClientBuilder::new()
52        // Following redirects opens the client up to SSRF vulnerabilities.
53        .redirect(reqwest::redirect::Policy::none())
54        .build()
55        .expect("Something went wrong :(")
56}
57
58async fn get_oidc_client(
59    OidcConfig {
60        issuer,
61        client_id,
62        client_secret,
63        ..
64    }: OidcConfig,
65    http_client: &reqwest::Client,
66    redirect_uri: RedirectUrl,
67) -> Result<
68    CoreClient<
69        EndpointSet,
70        EndpointNotSet,
71        EndpointNotSet,
72        EndpointNotSet,
73        EndpointMaybeSet,
74        EndpointMaybeSet,
75    >,
76    OidcError,
77> {
78    let provider_metadata = CoreProviderMetadata::discover_async(issuer, http_client)
79        .await
80        .map_err(|_| OidcError::Other("Failed to discover OpenID provider"))?;
81
82    Ok(CoreClient::from_provider_metadata(
83        provider_metadata.clone(),
84        client_id.clone(),
85        client_secret.clone(),
86    )
87    .set_redirect_uri(redirect_uri))
88}
89
/// Form body accepted by `route_post_oidc`.
#[derive(Debug, Deserialize)]
pub struct GetOidcForm {
    /// Optional page to return to after login. The callback falls back to the default redirect
    /// route when this cannot be resolved relative to the callback URL.
    redirect_uri: Option<String>,
}
94
95/// Endpoint that redirects to the authorize endpoint of the OIDC service
96pub async fn route_post_oidc(
97    req: HttpRequest,
98    Form(GetOidcForm { redirect_uri }): Form<GetOidcForm>,
99    oidc_config: Data<OidcConfig>,
100    session: Session,
101) -> Result<HttpResponse, OidcError> {
102    let http_client = get_http_client();
103    let oidc_client = get_oidc_client(
104        oidc_config.as_ref().clone(),
105        &http_client,
106        RedirectUrl::new(req.url_for_static(ROUTE_NAME_OIDC_CALLBACK)?.to_string())?,
107    )
108    .await?;
109
110    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
111
112    let (auth_url, csrf_token, nonce) = oidc_client
113        .authorize_url(
114            AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
115            CsrfToken::new_random,
116            Nonce::new_random,
117        )
118        .add_scopes(oidc_config.scopes.clone())
119        .set_pkce_challenge(pkce_challenge)
120        .url();
121
122    session.insert(
123        SESSION_KEY_OIDC_STATE,
124        OidcState {
125            state: csrf_token,
126            nonce,
127            pkce_verifier,
128            redirect_uri,
129        },
130    )?;
131
132    Ok(Redirect::to(auth_url.to_string())
133        .see_other()
134        .respond_to(&req)
135        .map_into_boxed_body())
136}
137
/// Query parameters the identity provider appends when redirecting back to the callback.
#[derive(Debug, Clone, Deserialize)]
pub struct AuthCallbackQuery {
    /// Authorization code to exchange for tokens.
    code: AuthorizationCode,
    // NOTE(review): `iss` is mandatory here, so a provider that omits it makes query
    // deserialization fail — confirm every supported provider sends it.
    /// Issuer identifier; the callback requires it to equal the configured issuer.
    iss: IssuerUrl,
    /// Opaque `state` value; the callback requires it to match the CSRF token in the session.
    state: String,
}
144
145/// Handle callback from IdP page
146pub async fn route_get_oidc_callback<US: UserStore>(
147    req: HttpRequest,
148    oidc_config: Data<OidcConfig>,
149    session: Session,
150    user_store: Data<US>,
151    Query(AuthCallbackQuery { code, iss, state }): Query<AuthCallbackQuery>,
152    service_config: Data<OidcServiceConfig>,
153) -> Result<HttpResponse, OidcError> {
154    assert_eq!(iss, oidc_config.issuer);
155    let oidc_state = session
156        .remove_as::<OidcState>(SESSION_KEY_OIDC_STATE)
157        .ok_or(OidcError::Other("No local OIDC state"))?
158        .map_err(|_| OidcError::Other("Error parsing OIDC state"))?;
159
160    assert_eq!(oidc_state.state.secret(), &state);
161
162    let http_client = get_http_client();
163    let oidc_client = get_oidc_client(
164        oidc_config.get_ref().clone(),
165        &http_client,
166        RedirectUrl::new(req.url_for_static(ROUTE_NAME_OIDC_CALLBACK)?.to_string())?,
167    )
168    .await?;
169
170    let token_response = oidc_client
171        .exchange_code(code)?
172        .set_pkce_verifier(oidc_state.pkce_verifier)
173        .request_async(&http_client)
174        .await
175        .map_err(|_| OidcError::Other("Error requesting token"))?;
176    let id_claims = token_response
177        .id_token()
178        .ok_or(OidcError::Other("OIDC provider did not return an ID token"))?
179        .claims(&oidc_client.id_token_verifier(), &oidc_state.nonce)?;
180
181    let user_info_claims: UserInfoClaims<GroupAdditionalClaims, CoreGenderClaim> = oidc_client
182        .user_info(
183            token_response.access_token().clone(),
184            Some(id_claims.subject().clone()),
185        )?
186        .request_async(&http_client)
187        .await
188        .map_err(|_| OidcError::Other("Error fetching user info"))?;
189
190    if let Some(require_group) = &oidc_config.require_group {
191        if !user_info_claims
192            .additional_claims()
193            .groups
194            .contains(require_group)
195        {
196            return Ok(HttpResponse::build(StatusCode::UNAUTHORIZED)
197                .body("User is not in an authorized group to use RustiCal"));
198        }
199    }
200
201    let user_id = match oidc_config.claim_userid {
202        UserIdClaim::Sub => user_info_claims.subject().to_string(),
203        UserIdClaim::PreferredUsername => user_info_claims
204            .preferred_username()
205            .ok_or(OidcError::Other("Missing preferred_username claim"))?
206            .to_string(),
207    };
208
209    match user_store.user_exists(&user_id).await {
210        Ok(false) => {
211            // User does not exist
212            if !oidc_config.allow_sign_up {
213                return Ok(HttpResponse::Unauthorized().body("User sign up disabled"));
214            }
215            // Create new user
216            if let Err(err) = user_store.insert_user(&user_id).await {
217                return Ok(err.error_response());
218            }
219        }
220        Ok(true) => {}
221        Err(err) => {
222            return Ok(err.error_response());
223        }
224    }
225
226    let default_redirect = req
227        .url_for_static(service_config.default_redirect_route_name)?
228        .to_string();
229    let redirect_uri = oidc_state.redirect_uri.unwrap_or(default_redirect.clone());
230    let redirect_uri = req
231        .full_url()
232        .join(&redirect_uri)
233        .ok()
234        .and_then(|uri| req.full_url().make_relative(&uri))
235        .unwrap_or(default_redirect);
236
237    // Complete login flow
238    session.insert(service_config.session_key_user_id, user_id.clone())?;
239
240    Ok(Redirect::to(redirect_uri)
241        .temporary()
242        .respond_to(&req)
243        .map_into_boxed_body())
244}
245
246pub fn configure_oidc<US: UserStore>(
247    cfg: &mut ServiceConfig,
248    oidc_config: OidcConfig,
249    service_config: OidcServiceConfig,
250    user_store: Arc<US>,
251) {
252    cfg.app_data(Data::new(oidc_config))
253        .app_data(Data::new(service_config))
254        .app_data(Data::from(user_store))
255        .service(
256            web::resource("")
257                .name(ROUTE_NAME_OIDC_LOGIN)
258                .post(route_post_oidc),
259        )
260        .service(
261            web::resource("/callback")
262                .name(ROUTE_NAME_OIDC_CALLBACK)
263                .get(route_get_oidc_callback::<US>),
264        );
265}