1use axum::{
2 Extension, Form,
3 extract::Query,
4 response::{IntoResponse, Redirect, Response},
5};
6use axum_extra::extract::Host;
7pub use config::OidcConfig;
8use config::UserIdClaim;
9use error::OidcError;
10use openidconnect::{
11 AuthenticationFlow, AuthorizationCode, CsrfToken, EndpointMaybeSet, EndpointNotSet,
12 EndpointSet, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier,
13 RedirectUrl, TokenResponse, UserInfoClaims,
14 core::{CoreClient, CoreGenderClaim, CoreProviderMetadata, CoreResponseType},
15};
16use reqwest::{StatusCode, Url};
17use serde::{Deserialize, Serialize};
18use tower_sessions::Session;
19pub use user_store::UserStore;
20
21mod config;
22mod error;
23mod user_store;
24
/// Session key under which the in-flight [`OidcState`] is stored by
/// `route_post_oidc` and removed again by `route_get_oidc_callback`.
const SESSION_KEY_OIDC_STATE: &str = "oidc_state";
26
/// Settings supplied by the embedding service that the OIDC routes need in
/// order to finish a login.
#[derive(Clone, Debug)]
pub struct OidcServiceConfig {
    /// Path the user is redirected to after a successful login when no usable
    /// `redirect_uri` was requested.
    pub default_redirect_path: &'static str,
    /// Session key under which the authenticated user's id is stored.
    pub session_key_user_id: &'static str,
}
32
33#[derive(Debug, Deserialize, Serialize)]
34struct OidcState {
35 state: CsrfToken,
36 nonce: Nonce,
37 pkce_verifier: PkceCodeVerifier,
38 redirect_uri: Option<String>,
39}
40
41#[derive(Debug, Deserialize, Serialize)]
42struct GroupAdditionalClaims {
43 #[serde(default)]
44 groups: Option<Vec<String>>,
45}
46
47impl openidconnect::AdditionalClaims for GroupAdditionalClaims {}
48
49fn get_http_client() -> reqwest::Client {
50 reqwest::ClientBuilder::new()
51 .redirect(reqwest::redirect::Policy::none())
53 .build()
54 .expect("Something went wrong :(")
55}
56
57async fn get_oidc_client(
58 OidcConfig {
59 issuer,
60 client_id,
61 client_secret,
62 ..
63 }: OidcConfig,
64 http_client: &reqwest::Client,
65 redirect_uri: RedirectUrl,
66) -> Result<
67 CoreClient<
68 EndpointSet,
69 EndpointNotSet,
70 EndpointNotSet,
71 EndpointNotSet,
72 EndpointMaybeSet,
73 EndpointMaybeSet,
74 >,
75 OidcError,
76> {
77 let provider_metadata = CoreProviderMetadata::discover_async(issuer, http_client)
78 .await
79 .map_err(|err| {
80 tracing::error!("An error occured trying to discover OpenID provider: {err}");
81 OidcError::Other("Failed to discover OpenID provider")
82 })?;
83
84 Ok(CoreClient::from_provider_metadata(
85 provider_metadata.clone(),
86 client_id.clone(),
87 client_secret.clone(),
88 )
89 .set_redirect_uri(redirect_uri))
90}
91
92#[derive(Debug, Deserialize)]
93pub struct GetOidcForm {
94 redirect_uri: Option<String>,
95}
96
97pub async fn route_post_oidc(
99 Extension(oidc_config): Extension<OidcConfig>,
100 session: Session,
101 Host(host): Host,
102 Form(GetOidcForm { redirect_uri }): Form<GetOidcForm>,
103) -> Result<Response, OidcError> {
104 let callback_uri = format!("https://{host}/frontend/login/oidc/callback");
105
106 let http_client = get_http_client();
107 let oidc_client = get_oidc_client(
108 oidc_config.clone(),
109 &http_client,
110 RedirectUrl::new(callback_uri)?,
111 )
112 .await?;
113
114 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
115
116 let (auth_url, csrf_token, nonce) = oidc_client
117 .authorize_url(
118 AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
119 CsrfToken::new_random,
120 Nonce::new_random,
121 )
122 .add_scopes(oidc_config.scopes.clone())
123 .set_pkce_challenge(pkce_challenge)
124 .url();
125
126 session
127 .insert(
128 SESSION_KEY_OIDC_STATE,
129 OidcState {
130 state: csrf_token,
131 nonce,
132 pkce_verifier,
133 redirect_uri,
134 },
135 )
136 .await?;
137
138 Ok(Redirect::to(auth_url.as_str()).into_response())
139}
140
141#[derive(Debug, Clone, Deserialize)]
142pub struct AuthCallbackQuery {
143 code: AuthorizationCode,
144 iss: Option<IssuerUrl>,
146 state: String,
147}
148
149pub async fn route_get_oidc_callback<US: UserStore + Clone>(
151 Extension(oidc_config): Extension<OidcConfig>,
152 Extension(user_store): Extension<US>,
153 Extension(service_config): Extension<OidcServiceConfig>,
154 session: Session,
155 Query(AuthCallbackQuery { code, iss, state }): Query<AuthCallbackQuery>,
156 Host(host): Host,
157) -> Result<Response, OidcError> {
158 let callback_uri = format!("https://{host}/frontend/login/oidc/callback");
159
160 if let Some(iss) = iss {
161 assert_eq!(iss, oidc_config.issuer);
162 }
163 let oidc_state = session
164 .remove::<OidcState>(SESSION_KEY_OIDC_STATE)
165 .await?
166 .ok_or(OidcError::Other("No local OIDC state"))?;
167
168 assert_eq!(oidc_state.state.secret(), &state);
169
170 let http_client = get_http_client();
171 let oidc_client = get_oidc_client(
172 oidc_config.clone(),
173 &http_client,
174 RedirectUrl::new(callback_uri)?,
175 )
176 .await?;
177
178 let token_response = oidc_client
179 .exchange_code(code)?
180 .set_pkce_verifier(oidc_state.pkce_verifier)
181 .request_async(&http_client)
182 .await
183 .map_err(|_| OidcError::Other("Error requesting token"))?;
184 let id_claims = token_response
185 .id_token()
186 .ok_or(OidcError::Other("OIDC provider did not return an ID token"))?
187 .claims(&oidc_client.id_token_verifier(), &oidc_state.nonce)?;
188
189 let user_info_claims: UserInfoClaims<GroupAdditionalClaims, CoreGenderClaim> = oidc_client
190 .user_info(
191 token_response.access_token().clone(),
192 Some(id_claims.subject().clone()),
193 )?
194 .request_async(&http_client)
195 .await
196 .map_err(|e| OidcError::UserInfo(e.to_string()))?;
197
198 if let Some(require_group) = &oidc_config.require_group
199 && !user_info_claims
200 .additional_claims()
201 .groups
202 .clone()
203 .unwrap_or_default()
204 .contains(require_group)
205 {
206 return Ok((
207 StatusCode::UNAUTHORIZED,
208 "User is not in an authorized group to use RustiCal",
209 )
210 .into_response());
211 }
212
213 let user_id = match oidc_config.claim_userid {
214 UserIdClaim::Sub => user_info_claims.subject().to_string(),
215 UserIdClaim::PreferredUsername => user_info_claims
216 .preferred_username()
217 .ok_or(OidcError::Other("Missing preferred_username claim"))?
218 .to_string(),
219 };
220
221 match user_store.user_exists(&user_id).await {
222 Ok(false) => {
223 if !oidc_config.allow_sign_up {
225 return Ok((StatusCode::UNAUTHORIZED, "User signup is disabled").into_response());
226 }
227 if let Err(err) = user_store.insert_user(&user_id).await {
229 return Ok(err.into_response());
230 }
231 }
232 Ok(true) => {}
233 Err(err) => {
234 return Ok(err.into_response());
235 }
236 }
237
238 let default_redirect = service_config.default_redirect_path.to_owned();
239 let base_url: Url = format!("https://{host}").parse().unwrap();
240 let redirect_uri = if let Some(redirect_uri) = oidc_state.redirect_uri {
241 if let Ok(redirect_url) = base_url.join(&redirect_uri) {
242 if redirect_url.origin() == base_url.origin() {
243 redirect_url.path().to_owned()
244 } else {
245 default_redirect
246 }
247 } else {
248 default_redirect
249 }
250 } else {
251 default_redirect
252 };
253
254 session
256 .insert(service_config.session_key_user_id, user_id.clone())
257 .await?;
258
259 Ok(Redirect::to(&redirect_uri).into_response())
260}