diff --git a/controller/DB.hpp b/controller/DB.hpp index 10adbec10..64bd83af0 100644 --- a/controller/DB.hpp +++ b/controller/DB.hpp @@ -53,6 +53,7 @@ public: , ssoNonce() , ssoState() , ssoClientID() + , ssoProvider("default") {} bool enabled; @@ -64,6 +65,7 @@ public: std::string ssoNonce; std::string ssoState; std::string ssoClientID; + std::string ssoProvider; }; /** diff --git a/controller/PostgreSQL.cpp b/controller/PostgreSQL.cpp index 1c12980f1..bf3acfe4c 100644 --- a/controller/PostgreSQL.cpp +++ b/controller/PostgreSQL.cpp @@ -442,24 +442,29 @@ AuthInfo PostgreSQL::getSSOAuthInfo(const nlohmann::json &member, const std::str exit(7); } - r = w.exec_params("SELECT oc.client_id, oc.authorization_endpoint, oc.issuer, oc.sso_impl_version " - "FROM ztc_network n " - "INNER JOIN ztc_network_oidc_config noc " - " ON noc.network_id = n.id " - "INNER JOIN ztc_oidc_config oc " - " ON noc.client_id = oc.client_id " - "WHERE n.id = $1 AND n.sso_enabled = true", networkId); + r = w.exec_params( + "SELECT oc.client_id, oc.authorization_endpoint, oc.issuer, oc.provider, oc.sso_impl_version " + "FROM ztc_network AS n " + "INNER JOIN ztc_org o " + " ON o.owner_id = n.owner_id " + "LEFT OUTER JOIN ztc_network_oidc_config noc " + " ON noc.network_id = n.id " + "LEFT OUTER JOIN ztc_oidc_config oc " + " ON noc.client_id = oc.client_id AND noc.org_id = o.org_id " + "WHERE n.id = $1 AND n.sso_enabled = true", networkId); std::string client_id = ""; std::string authorization_endpoint = ""; std::string issuer = ""; + std::string provider = ""; uint64_t sso_version = 0; if (r.size() == 1) { client_id = r.at(0)[0].as(); authorization_endpoint = r.at(0)[1].as(); issuer = r.at(0)[2].as(); - sso_version = r.at(0)[3].as(); + provider = r.at(0)[3].as(); + sso_version = r.at(0)[4].as(); } else if (r.size() > 1) { fprintf(stderr, "ERROR: More than one auth endpoint for an organization?!?!? NetworkID: %s\n", networkId.c_str()); } else { @@ -489,18 +494,20 @@ AuthInfo PostgreSQL::getSSOAuthInfo(const nlohmann::json &member, const std::str } else if (info.version == 1) { info.ssoClientID = client_id; info.issuerURL = issuer; + info.ssoProvider = provider; info.ssoNonce = nonce; info.ssoState = std::string(state_hex) + "_" +networkId; info.centralAuthURL = redirectURL; #ifdef ZT_DEBUG fprintf( stderr, - "ssoClientID: %s\nissuerURL: %s\nssoNonce: %s\nssoState: %s\ncentralAuthURL: %s\n", + "ssoClientID: %s\nissuerURL: %s\nssoNonce: %s\nssoState: %s\ncentralAuthURL: %s\nprovider: %s\n", info.ssoClientID.c_str(), info.issuerURL.c_str(), info.ssoNonce.c_str(), info.ssoState.c_str(), - info.centralAuthURL.c_str()); + info.centralAuthURL.c_str(), + provider.c_str()); #endif } } else { @@ -539,10 +546,12 @@ void PostgreSQL::initializeNetworks() std::unordered_set networkSet; char qbuf[2048] = {0}; - sprintf(qbuf, "SELECT n.id, (EXTRACT(EPOCH FROM n.creation_time AT TIME ZONE 'UTC')*1000)::bigint as creation_time, n.capabilities, " + sprintf(qbuf, + "SELECT n.id, (EXTRACT(EPOCH FROM n.creation_time AT TIME ZONE 'UTC')*1000)::bigint as creation_time, n.capabilities, " "n.enable_broadcast, (EXTRACT(EPOCH FROM n.last_modified AT TIME ZONE 'UTC')*1000)::bigint AS last_modified, n.mtu, n.multicast_limit, n.name, n.private, n.remote_trace_level, " - "n.remote_trace_target, n.revision, n.rules, n.tags, n.v4_assign_mode, n.v6_assign_mode, n.sso_enabled, (CASE WHEN n.sso_enabled THEN o.client_id ELSE NULL END) as client_id, " - "(CASE WHEN n.sso_enabled THEN o.authorization_endpoint ELSE NULL END) as authorization_endpoint, d.domain, d.servers, " + "n.remote_trace_target, n.revision, n.rules, n.tags, n.v4_assign_mode, n.v6_assign_mode, n.sso_enabled, (CASE WHEN n.sso_enabled THEN noc.client_id ELSE NULL END) as client_id, " + "(CASE WHEN n.sso_enabled THEN oc.authorization_endpoint ELSE NULL END) as authorization_endpoint, " + "(CASE WHEN n.sso_enabled THEN oc.provider ELSE NULL END) as provider, d.domain, d.servers, " "ARRAY(SELECT CONCAT(host(ip_range_start),'|', host(ip_range_end)) FROM ztc_network_assignment_pool WHERE network_id = n.id) AS assignment_pool, " "ARRAY(SELECT CONCAT(host(address),'/',bits::text,'|',COALESCE(host(via), 'NULL'))FROM ztc_network_route WHERE network_id = n.id) AS routes " "FROM ztc_network n " @@ -551,7 +560,7 @@ void PostgreSQL::initializeNetworks() "LEFT OUTER JOIN ztc_network_oidc_config noc " " ON noc.network_id = n.id " "LEFT OUTER JOIN ztc_oidc_config oc " - " ON noc.client_id = oc.client_id AND o.org_id = oc.org_id " + " ON noc.client_id = oc.client_id AND oc.org_id = o.org_id " "LEFT OUTER JOIN ztc_network_dns d " " ON d.network_id = n.id " "WHERE deleted = false AND controller_id = '%s'", _myAddressStr.c_str()); @@ -582,6 +591,7 @@ void PostgreSQL::initializeNetworks() , std::optional // ssoEnabled , std::optional // clientId , std::optional // authorizationEndpoint + , std::optional // ssoProvider , std::optional // domain , std::optional // servers , std::string // assignmentPoolString @@ -618,10 +628,11 @@ void PostgreSQL::initializeNetworks() std::optional ssoEnabled = std::get<16>(row); std::optional clientId = std::get<17>(row); std::optional authorizationEndpoint = std::get<18>(row); - std::optional dnsDomain = std::get<19>(row); - std::optional dnsServers = std::get<20>(row); - std::string assignmentPoolString = std::get<21>(row); - std::string routesString = std::get<22>(row); + std::optional ssoProvider = std::get<19>(row); + std::optional dnsDomain = std::get<20>(row); + std::optional dnsServers = std::get<21>(row); + std::string assignmentPoolString = std::get<22>(row); + std::string routesString = std::get<23>(row); config["id"] = nwid; config["nwid"] = nwid; @@ -646,6 +657,7 @@ void PostgreSQL::initializeNetworks() config["routes"] = json::array(); config["clientId"] = clientId.value_or(""); config["authorizationEndpoint"] = authorizationEndpoint.value_or(""); + config["provider"] = ssoProvider.value_or(""); networkSet.insert(nwid); diff --git a/node/IncomingPacket.cpp b/node/IncomingPacket.cpp index 9080128b6..a120a208a 100644 --- a/node/IncomingPacket.cpp +++ b/node/IncomingPacket.cpp @@ -217,6 +217,7 @@ bool IncomingPacket::_doERROR(const RuntimeEnvironment *RR,void *tPtr,const Shar char ssoNonce[64] = { 0 }; char ssoState[128] = {0}; char ssoClientID[256] = { 0 }; + char ssoProvider[64] = { 0 }; if (authInfo.get(ZT_AUTHINFO_DICT_KEY_ISSUER_URL, issuerURL, sizeof(issuerURL)) > 0) { issuerURL[sizeof(issuerURL) - 1] = 0; @@ -233,8 +234,13 @@ bool IncomingPacket::_doERROR(const RuntimeEnvironment *RR,void *tPtr,const Shar if (authInfo.get(ZT_AUTHINFO_DICT_KEY_CLIENT_ID, ssoClientID, sizeof(ssoClientID)) > 0) { ssoClientID[sizeof(ssoClientID) - 1] = 0; } + if (authInfo.get(ZT_AUTHINFO_DICT_KEY_SSO_PROVIDER, ssoProvider, sizeof(ssoProvider)) > 0 ) { + ssoProvider[sizeof(ssoProvider) - 1] = 0; + } else { + strncpy(ssoProvider, "default", sizeof(ssoProvider)); + } - network->setAuthenticationRequired(tPtr, issuerURL, centralAuthURL, ssoClientID, ssoNonce, ssoState); + network->setAuthenticationRequired(tPtr, issuerURL, centralAuthURL, ssoClientID, ssoProvider, ssoNonce, ssoState); } } } else { diff --git a/node/Network.cpp b/node/Network.cpp index a3810162b..41e5186b1 100644 --- a/node/Network.cpp +++ b/node/Network.cpp @@ -1556,7 +1556,7 @@ Membership &Network::_membership(const Address &a) return _memberships[a]; } -void Network::setAuthenticationRequired(void *tPtr, const char* issuerURL, const char* centralEndpoint, const char* clientID, const char* nonce, const char* state) +void Network::setAuthenticationRequired(void *tPtr, const char* issuerURL, const char* centralEndpoint, const char* clientID, const char *ssoProvider, const char* nonce, const char* state) { Mutex::Lock _l(_lock); _netconfFailure = NETCONF_FAILURE_AUTHENTICATION_REQUIRED; @@ -1568,6 +1568,7 @@ void Network::setAuthenticationRequired(void *tPtr, const char* issuerURL, const Utils::scopy(_config.ssoClientID, sizeof(_config.ssoClientID), clientID); Utils::scopy(_config.ssoNonce, sizeof(_config.ssoNonce), nonce); Utils::scopy(_config.ssoState, sizeof(_config.ssoState), state); + Utils::scopy(_config.ssoProvider, sizeof(_config.ssoProvider), ssoProvider); _sendUpdateEvent(tPtr); } diff --git a/node/Network.hpp b/node/Network.hpp index b427a83d6..275e82f02 100644 --- a/node/Network.hpp +++ b/node/Network.hpp @@ -241,7 +241,7 @@ public: * set netconf failure to 'authentication required' along with info needed * for sso full flow authentication. */ - void setAuthenticationRequired(void *tPtr, const char* issuerURL, const char* centralEndpoint, const char* clientID, const char* nonce, const char* state); + void setAuthenticationRequired(void *tPtr, const char* issuerURL, const char* centralEndpoint, const char* clientID, const char *ssoProvider, const char* nonce, const char* state); /** * Causes this network to request an updated configuration from its master node now diff --git a/node/NetworkConfig.cpp b/node/NetworkConfig.cpp index 13a9313aa..a200ba8ca 100644 --- a/node/NetworkConfig.cpp +++ b/node/NetworkConfig.cpp @@ -201,6 +201,7 @@ bool NetworkConfig::toDictionary(Dictionary &d,b if (!d.add(ZT_NETWORKCONFIG_DICT_KEY_NONCE, this->ssoNonce)) return false; if (!d.add(ZT_NETWORKCONFIG_DICT_KEY_STATE, this->ssoState)) return false; if (!d.add(ZT_NETWORKCONFIG_DICT_KEY_CLIENT_ID, this->ssoClientID)) return false; + if (!d.add(ZT_NETWORKCONFIG_DICT_KEY_SSO_PROVIDER, this->ssoProvider)) return false; } delete tmp; @@ -424,6 +425,11 @@ bool NetworkConfig::fromDictionary(const DictionaryssoClientID, (unsigned int)sizeof(this->ssoClientID)) > 0) { this->ssoClientID[sizeof(this->ssoClientID) - 1] = 0; } + if (d.get(ZT_NETWORKCONFIG_DICT_KEY_SSO_PROVIDER, this->ssoProvider, (unsigned int)(sizeof(this->ssoProvider))) > 0) { + this->ssoProvider[sizeof(this->ssoProvider) - 1] = 0; + } else { + strncpy(this->ssoProvider, "default", sizeof(this->ssoProvider)); + } } else { this->authenticationURL[0] = 0; this->authenticationExpiryTime = 0; @@ -432,6 +438,7 @@ bool NetworkConfig::fromDictionary(const DictionaryssoState[0] = 0; this->ssoClientID[0] = 0; this->issuerURL[0] = 0; + this->ssoProvider[0] = 0; } } } diff --git a/node/NetworkConfig.hpp b/node/NetworkConfig.hpp index 0161b4fa9..cd713dde8 100644 --- a/node/NetworkConfig.hpp +++ b/node/NetworkConfig.hpp @@ -195,6 +195,8 @@ namespace ZeroTier { #define ZT_NETWORKCONFIG_DICT_KEY_STATE "ssos" // client ID #define ZT_NETWORKCONFIG_DICT_KEY_CLIENT_ID "ssocid" +// SSO Provider +#define ZT_NETWORKCONFIG_DICT_KEY_SSO_PROVIDER "ssop" // AuthInfo fields -- used by ncSendError for sso @@ -212,6 +214,8 @@ namespace ZeroTier { #define ZT_AUTHINFO_DICT_KEY_STATE "aS" // Client ID #define ZT_AUTHINFO_DICT_KEY_CLIENT_ID "aCID" +// SSO Provider +#define ZT_AUTHINFO_DICT_KEY_SSO_PROVIDER "aSSOp" // Legacy fields -- these are obsoleted but are included when older clients query @@ -289,6 +293,7 @@ public: memset(ssoNonce, 0, sizeof(ssoNonce)); memset(ssoState, 0, sizeof(ssoState)); memset(ssoClientID, 0, sizeof(ssoClientID)); + strncpy(ssoProvider, "default", sizeof(ssoProvider)); } /** @@ -699,6 +704,15 @@ public: * oidc client id */ char ssoClientID[256]; + + /** + * oidc provider + * + * because certain providers require specific scopes to be requested + * and others to be not requested in order to make everything work + * correctly + **/ + char ssoProvider[64]; }; } // namespace ZeroTier