diff --git a/include/ZeroTierOne.h b/include/ZeroTierOne.h index 23f97b388..3f5e9506f 100644 --- a/include/ZeroTierOne.h +++ b/include/ZeroTierOne.h @@ -1246,6 +1246,11 @@ typedef struct * oidc client id */ char ssoClientID[256]; + + /** + * sso provider + **/ + char ssoProvider[64]; } ZT_VirtualNetworkConfig; /** diff --git a/node/Network.cpp b/node/Network.cpp index 41e5186b1..b03f4b3d0 100644 --- a/node/Network.cpp +++ b/node/Network.cpp @@ -1450,6 +1450,7 @@ void Network::_externalConfig(ZT_VirtualNetworkConfig *ec) const Utils::scopy(ec->ssoNonce, sizeof(ec->ssoNonce), _config.ssoNonce); Utils::scopy(ec->ssoState, sizeof(ec->ssoState), _config.ssoState); Utils::scopy(ec->ssoClientID, sizeof(ec->ssoClientID), _config.ssoClientID); + Utils::scopy(ec->ssoProvider, sizeof(ec->ssoProvider), _config.ssoProvider); } void Network::_sendUpdatesToMembers(void *tPtr,const MulticastGroup *const newMulticastGroup) diff --git a/service/OneService.cpp b/service/OneService.cpp index 5984b8b86..32e648e85 100644 --- a/service/OneService.cpp +++ b/service/OneService.cpp @@ -302,10 +302,12 @@ public: assert(_config.issuerURL != nullptr); assert(_config.ssoClientID != nullptr); assert(_config.centralAuthURL != nullptr); + assert(_config.ssoProvider != nullptr); _idc = zeroidc::zeroidc_new( _config.issuerURL, _config.ssoClientID, + _config.ssoProvider, _config.centralAuthURL, _webPort ); diff --git a/zeroidc/src/ext.rs b/zeroidc/src/ext.rs index dfb25bd1a..d87724a78 100644 --- a/zeroidc/src/ext.rs +++ b/zeroidc/src/ext.rs @@ -28,6 +28,7 @@ pub extern "C" fn zeroidc_new( issuer: *const c_char, client_id: *const c_char, auth_endpoint: *const c_char, + provider: *const c_char, web_listen_port: u16, ) -> *mut ZeroIDC { if issuer.is_null() { @@ -40,6 +41,11 @@ pub extern "C" fn zeroidc_new( return std::ptr::null_mut(); } + if provider.is_null() { + println!("provider is null"); + return std::ptr::null_mut(); + } + if auth_endpoint.is_null() { println!("auth_endpoint is null"); return std::ptr::null_mut(); @@ -47,10 +53,12 @@ pub extern "C" fn zeroidc_new( let issuer = unsafe { CStr::from_ptr(issuer) }; let client_id = unsafe { CStr::from_ptr(client_id) }; + let provider = unsafe { CStr::from_ptr(provider) }; let auth_endpoint = unsafe { CStr::from_ptr(auth_endpoint) }; match ZeroIDC::new( issuer.to_str().unwrap(), client_id.to_str().unwrap(), + provider.to_str().unwrap(), auth_endpoint.to_str().unwrap(), web_listen_port, ) { diff --git a/zeroidc/src/lib.rs b/zeroidc/src/lib.rs index 0da3cd71e..74c371345 100644 --- a/zeroidc/src/lib.rs +++ b/zeroidc/src/lib.rs @@ -61,6 +61,7 @@ struct Inner { running: bool, issuer: String, auth_endpoint: String, + provider: String, oidc_thread: Option>, oidc_client: Option, access_token: Option, @@ -115,6 +116,7 @@ impl ZeroIDC { pub fn new( issuer: &str, client_id: &str, + provider: &str, auth_ep: &str, local_web_port: u16, ) -> Result { @@ -122,6 +124,7 @@ impl ZeroIDC { inner: Arc::new(Mutex::new(Inner { running: false, issuer: issuer.to_string(), + provider: provider.to_string(), auth_endpoint: auth_ep.to_string(), oidc_thread: None, oidc_client: None, @@ -444,36 +447,53 @@ impl ZeroIDC { if need_verifier || csrf_diff || nonce_diff { let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let r = i.oidc_client.as_ref().map(|c| { - if i.issuer.contains("okta") { - let (auth_url, csrf_token, nonce) = c - .authorize_url( - AuthenticationFlow::::AuthorizationCode, - csrf_func(csrf_token), - nonce_func(nonce), - ) - .add_scope(Scope::new("profile".to_string())) - .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("offline_access".to_string())) - .add_scope(Scope::new("groups".to_string())) - .set_pkce_challenge(pkce_challenge) - .url(); - - (auth_url, csrf_token, nonce) - } else { - let (auth_url, csrf_token, nonce) = c - .authorize_url( - AuthenticationFlow::::AuthorizationCode, - csrf_func(csrf_token), - nonce_func(nonce), - ) - .add_scope(Scope::new("profile".to_string())) - .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("offline_access".to_string())) - .set_pkce_challenge(pkce_challenge) - .url(); - - (auth_url, csrf_token, nonce) + let mut auth_builder = c + .authorize_url( + AuthenticationFlow::::AuthorizationCode, + csrf_func(csrf_token), + nonce_func(nonce), + ) + .set_pkce_challenge(pkce_challenge); + match i.provider.as_str() { + "auth0" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("offline_access".to_string())); + } + "okta" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("groups".to_string())) + .add_scope(Scope::new("offline_access".to_string())); + } + "keycloak" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())); + } + "onelogin" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("groups".to_string())) + } + "default" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("offline_access".to_string())); + } + _ => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("offline_access".to_string())); + } } + + auth_builder.url() }); if let Some(r) = r {