1use actix_session::Session;
2use actix_web::{
3 HttpRequest, HttpResponse, Responder, ResponseError,
4 http::StatusCode,
5 web::{self, Data, Form, Query, Redirect, ServiceConfig},
6};
7pub use config::OidcConfig;
8use config::UserIdClaim;
9use error::OidcError;
10use openidconnect::{
11 AuthenticationFlow, AuthorizationCode, CsrfToken, EndpointMaybeSet, EndpointNotSet,
12 EndpointSet, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier,
13 RedirectUrl, TokenResponse, UserInfoClaims,
14 core::{CoreClient, CoreGenderClaim, CoreProviderMetadata, CoreResponseType},
15};
16use serde::{Deserialize, Serialize};
17use std::sync::Arc;
18pub use user_store::UserStore;
19
20mod config;
21mod error;
22mod user_store;
23
/// Route name of the OIDC login endpoint (the `POST` resource registered by [`configure_oidc`]).
pub const ROUTE_NAME_OIDC_LOGIN: &str = "oidc_login";
/// Route name of the provider callback endpoint; resolved with `url_for_static` to build the
/// `redirect_uri` sent to the OpenID provider.
const ROUTE_NAME_OIDC_CALLBACK: &str = "oidc_callback";
/// Session key under which the in-flight [`OidcState`] is kept between login and callback.
const SESSION_KEY_OIDC_STATE: &str = "oidc_state";
27
/// Settings supplied by the host application for the OIDC routes.
#[derive(Debug, Clone)]
pub struct OidcServiceConfig {
    /// Route name resolved with `url_for_static` to obtain the post-login redirect target
    /// when the login request did not provide a usable `redirect_uri`.
    pub default_redirect_route_name: &'static str,
    /// Session key under which the authenticated user id is stored after a successful callback.
    pub session_key_user_id: &'static str,
}
33
/// Per-login state stored in the session by [`route_post_oidc`] and consumed by the callback.
///
/// Field names are the serialized session format; renaming them invalidates in-flight logins.
#[derive(Debug, Deserialize, Serialize)]
struct OidcState {
    /// CSRF token sent as the `state` parameter; compared with the callback's `state`.
    state: CsrfToken,
    /// Nonce bound into the authorization request; checked when verifying ID token claims.
    nonce: Nonce,
    /// PKCE verifier matching the challenge that was sent in the authorization URL.
    pkce_verifier: PkceCodeVerifier,
    /// Optional post-login destination taken from the login form.
    redirect_uri: Option<String>,
}
41
/// Non-standard user-info claims read by this module: the `groups` list checked against
/// `require_group` in the callback.
#[derive(Debug, Deserialize, Serialize)]
struct GroupAdditionalClaims {
    // A missing `groups` claim deserializes to an empty list instead of failing the request.
    #[serde(default)]
    pub groups: Vec<String>,
}

impl openidconnect::AdditionalClaims for GroupAdditionalClaims {}
49
50fn get_http_client() -> reqwest::Client {
51 reqwest::ClientBuilder::new()
52 .redirect(reqwest::redirect::Policy::none())
54 .build()
55 .expect("Something went wrong :(")
56}
57
58async fn get_oidc_client(
59 OidcConfig {
60 issuer,
61 client_id,
62 client_secret,
63 ..
64 }: OidcConfig,
65 http_client: &reqwest::Client,
66 redirect_uri: RedirectUrl,
67) -> Result<
68 CoreClient<
69 EndpointSet,
70 EndpointNotSet,
71 EndpointNotSet,
72 EndpointNotSet,
73 EndpointMaybeSet,
74 EndpointMaybeSet,
75 >,
76 OidcError,
77> {
78 let provider_metadata = CoreProviderMetadata::discover_async(issuer, http_client)
79 .await
80 .map_err(|_| OidcError::Other("Failed to discover OpenID provider"))?;
81
82 Ok(CoreClient::from_provider_metadata(
83 provider_metadata.clone(),
84 client_id.clone(),
85 client_secret.clone(),
86 )
87 .set_redirect_uri(redirect_uri))
88}
89
/// Form body accepted by [`route_post_oidc`].
#[derive(Debug, Deserialize)]
pub struct GetOidcForm {
    /// Optional post-login destination. It is stored in the session and followed by the
    /// callback only if it can be made relative to the callback URL; otherwise the default
    /// route is used.
    redirect_uri: Option<String>,
}
94
95pub async fn route_post_oidc(
97 req: HttpRequest,
98 Form(GetOidcForm { redirect_uri }): Form<GetOidcForm>,
99 oidc_config: Data<OidcConfig>,
100 session: Session,
101) -> Result<HttpResponse, OidcError> {
102 let http_client = get_http_client();
103 let oidc_client = get_oidc_client(
104 oidc_config.as_ref().clone(),
105 &http_client,
106 RedirectUrl::new(req.url_for_static(ROUTE_NAME_OIDC_CALLBACK)?.to_string())?,
107 )
108 .await?;
109
110 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
111
112 let (auth_url, csrf_token, nonce) = oidc_client
113 .authorize_url(
114 AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
115 CsrfToken::new_random,
116 Nonce::new_random,
117 )
118 .add_scopes(oidc_config.scopes.clone())
119 .set_pkce_challenge(pkce_challenge)
120 .url();
121
122 session.insert(
123 SESSION_KEY_OIDC_STATE,
124 OidcState {
125 state: csrf_token,
126 nonce,
127 pkce_verifier,
128 redirect_uri,
129 },
130 )?;
131
132 Ok(Redirect::to(auth_url.to_string())
133 .see_other()
134 .respond_to(&req)
135 .map_into_boxed_body())
136}
137
/// Query parameters the provider sends to the callback route.
#[derive(Debug, Clone, Deserialize)]
pub struct AuthCallbackQuery {
    /// Authorization code to exchange for tokens.
    code: AuthorizationCode,
    /// Issuer identifier of the responding provider; compared with the configured issuer.
    /// This field is not optional, so deserialization fails when the provider omits `iss`.
    iss: IssuerUrl,
    /// Echo of the CSRF `state` value sent in the authorization request.
    state: String,
}
144
145pub async fn route_get_oidc_callback<US: UserStore>(
147 req: HttpRequest,
148 oidc_config: Data<OidcConfig>,
149 session: Session,
150 user_store: Data<US>,
151 Query(AuthCallbackQuery { code, iss, state }): Query<AuthCallbackQuery>,
152 service_config: Data<OidcServiceConfig>,
153) -> Result<HttpResponse, OidcError> {
154 assert_eq!(iss, oidc_config.issuer);
155 let oidc_state = session
156 .remove_as::<OidcState>(SESSION_KEY_OIDC_STATE)
157 .ok_or(OidcError::Other("No local OIDC state"))?
158 .map_err(|_| OidcError::Other("Error parsing OIDC state"))?;
159
160 assert_eq!(oidc_state.state.secret(), &state);
161
162 let http_client = get_http_client();
163 let oidc_client = get_oidc_client(
164 oidc_config.get_ref().clone(),
165 &http_client,
166 RedirectUrl::new(req.url_for_static(ROUTE_NAME_OIDC_CALLBACK)?.to_string())?,
167 )
168 .await?;
169
170 let token_response = oidc_client
171 .exchange_code(code)?
172 .set_pkce_verifier(oidc_state.pkce_verifier)
173 .request_async(&http_client)
174 .await
175 .map_err(|_| OidcError::Other("Error requesting token"))?;
176 let id_claims = token_response
177 .id_token()
178 .ok_or(OidcError::Other("OIDC provider did not return an ID token"))?
179 .claims(&oidc_client.id_token_verifier(), &oidc_state.nonce)?;
180
181 let user_info_claims: UserInfoClaims<GroupAdditionalClaims, CoreGenderClaim> = oidc_client
182 .user_info(
183 token_response.access_token().clone(),
184 Some(id_claims.subject().clone()),
185 )?
186 .request_async(&http_client)
187 .await
188 .map_err(|_| OidcError::Other("Error fetching user info"))?;
189
190 if let Some(require_group) = &oidc_config.require_group {
191 if !user_info_claims
192 .additional_claims()
193 .groups
194 .contains(require_group)
195 {
196 return Ok(HttpResponse::build(StatusCode::UNAUTHORIZED)
197 .body("User is not in an authorized group to use RustiCal"));
198 }
199 }
200
201 let user_id = match oidc_config.claim_userid {
202 UserIdClaim::Sub => user_info_claims.subject().to_string(),
203 UserIdClaim::PreferredUsername => user_info_claims
204 .preferred_username()
205 .ok_or(OidcError::Other("Missing preferred_username claim"))?
206 .to_string(),
207 };
208
209 match user_store.user_exists(&user_id).await {
210 Ok(false) => {
211 if !oidc_config.allow_sign_up {
213 return Ok(HttpResponse::Unauthorized().body("User sign up disabled"));
214 }
215 if let Err(err) = user_store.insert_user(&user_id).await {
217 return Ok(err.error_response());
218 }
219 }
220 Ok(true) => {}
221 Err(err) => {
222 return Ok(err.error_response());
223 }
224 }
225
226 let default_redirect = req
227 .url_for_static(service_config.default_redirect_route_name)?
228 .to_string();
229 let redirect_uri = oidc_state.redirect_uri.unwrap_or(default_redirect.clone());
230 let redirect_uri = req
231 .full_url()
232 .join(&redirect_uri)
233 .ok()
234 .and_then(|uri| req.full_url().make_relative(&uri))
235 .unwrap_or(default_redirect);
236
237 session.insert(service_config.session_key_user_id, user_id.clone())?;
239
240 Ok(Redirect::to(redirect_uri)
241 .temporary()
242 .respond_to(&req)
243 .map_into_boxed_body())
244}
245
246pub fn configure_oidc<US: UserStore>(
247 cfg: &mut ServiceConfig,
248 oidc_config: OidcConfig,
249 service_config: OidcServiceConfig,
250 user_store: Arc<US>,
251) {
252 cfg.app_data(Data::new(oidc_config))
253 .app_data(Data::new(service_config))
254 .app_data(Data::from(user_store))
255 .service(
256 web::resource("")
257 .name(ROUTE_NAME_OIDC_LOGIN)
258 .post(route_post_oidc),
259 )
260 .service(
261 web::resource("/callback")
262 .name(ROUTE_NAME_OIDC_CALLBACK)
263 .get(route_get_oidc_callback::<US>),
264 );
265}