plumbing full flow from controller -> client network

This commit is contained in:
Grant Limberg 2021-11-04 15:40:08 -07:00
commit 8d39c9a861
No known key found for this signature in database
GPG key ID: 2BA62CCABBB4095A
14 changed files with 400 additions and 70 deletions

View file

@ -40,6 +40,30 @@
namespace ZeroTier
{
struct AuthInfo
{
public:
AuthInfo()
: enabled(false)
, version(0)
, authenticationURL()
, authenticationExpiryTime(0)
, centralAuthURL()
, ssoNonce()
, ssoState()
, ssoClientID()
{}
bool enabled;
uint64_t version;
std::string authenticationURL;
uint64_t authenticationExpiryTime;
std::string centralAuthURL;
std::string ssoNonce;
std::string ssoState;
std::string ssoClientID;
};
/**
* Base class with common infrastructure for all controller DB implementations
*/
@ -108,7 +132,7 @@ public:
virtual void eraseMember(const uint64_t networkId,const uint64_t memberId) = 0;
virtual void nodeIsOnline(const uint64_t networkId,const uint64_t memberId,const InetAddress &physicalAddress) = 0;
virtual std::string getSSOAuthURL(const nlohmann::json &member, const std::string &redirectURL) { return ""; }
virtual AuthInfo getSSOAuthInfo(const nlohmann::json &member, const std::string &redirectURL) { return AuthInfo(); }
virtual void networkMemberSSOHasExpired(uint64_t nwid, int64_t ts);
inline void addListener(DB::ChangeListener *const listener)

View file

@ -125,16 +125,16 @@ bool DBMirrorSet::get(const uint64_t networkId,nlohmann::json &network,std::vect
return false;
}
std::string DBMirrorSet::getSSOAuthURL(const nlohmann::json &member, const std::string &redirectURL)
AuthInfo DBMirrorSet::getSSOAuthInfo(const nlohmann::json &member, const std::string &redirectURL)
{
std::lock_guard<std::mutex> l(_dbs_l);
for(auto d=_dbs.begin();d!=_dbs.end();++d) {
std::string url = (*d)->getSSOAuthURL(member, redirectURL);
if (!url.empty()) {
return url;
AuthInfo info = (*d)->getSSOAuthInfo(member, redirectURL);
if (info.enabled) {
return info;
}
}
return "";
return AuthInfo();
}
void DBMirrorSet::networkMemberSSOHasExpired(uint64_t nwid, int64_t ts)

View file

@ -51,7 +51,7 @@ public:
virtual void onNetworkMemberUpdate(const void *db,uint64_t networkId,uint64_t memberId,const nlohmann::json &member);
virtual void onNetworkMemberDeauthorize(const void *db,uint64_t networkId,uint64_t memberId);
std::string getSSOAuthURL(const nlohmann::json &member, const std::string &redirectURL);
AuthInfo getSSOAuthInfo(const nlohmann::json &member, const std::string &redirectURL);
void networkMemberSSOHasExpired(uint64_t nwid, int64_t ts);
inline void addDB(const std::shared_ptr<DB> &db)

View file

@ -1360,27 +1360,53 @@ void EmbeddedNetworkController::_request(
// Otherwise no, we use standard auth logic.
bool networkSSOEnabled = OSUtils::jsonBool(network["ssoEnabled"], false);
bool memberSSOExempt = OSUtils::jsonBool(member["ssoExempt"], false);
std::string authenticationURL;
if (networkSSOEnabled && !memberSSOExempt) {
authenticationURL = _db.getSSOAuthURL(member, _ssoRedirectURL);
AuthInfo info;
if (networkSSOEnabled && ! memberSSOExempt) {
info = _db.getSSOAuthInfo(member, _ssoRedirectURL);
assert(info.enabled == networkSSOEnabled);
std::string memberId = member["id"];
//fprintf(stderr, "ssoEnabled && !ssoExempt %s-%s\n", nwids, memberId.c_str());
uint64_t authenticationExpiryTime = (int64_t)OSUtils::jsonInt(member["authenticationExpiryTime"], 0);
//fprintf(stderr, "authExpiryTime: %lld\n", authenticationExpiryTime);
if (authenticationExpiryTime < now) {
if (!authenticationURL.empty()) {
_db.networkMemberSSOHasExpired(nwid, now);
onNetworkMemberDeauthorize(&_db, nwid, identity.address().toInt());
if (info.version == 0) {
if (!info.authenticationURL.empty()) {
_db.networkMemberSSOHasExpired(nwid, now);
onNetworkMemberDeauthorize(&_db, nwid, identity.address().toInt());
Dictionary<3072> authInfo;
authInfo.add("aU", authenticationURL.c_str());
//fprintf(stderr, "sending auth URL: %s\n", authenticationURL.c_str());
Dictionary<4096> authInfo;
authInfo.add(ZT_AUTHINFO_DICT_KEY_VERSION, 0ULL);
authInfo.add(ZT_AUTHINFO_DICT_KEY_AUTHENTICATION_URL, info.authenticationURL.c_str());
//fprintf(stderr, "sending auth URL: %s\n", authenticationURL.c_str());
DB::cleanMember(member);
_db.save(member,true);
DB::cleanMember(member);
_db.save(member,true);
_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_AUTHENTICATION_REQUIRED, authInfo.data(), authInfo.sizeBytes());
return;
_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_AUTHENTICATION_REQUIRED, authInfo.data(), authInfo.sizeBytes());
return;
}
} else if (info.version == 1) {
if (!info.authenticationURL.empty()) {
_db.networkMemberSSOHasExpired(nwid, now);
onNetworkMemberDeauthorize(&_db, nwid, identity.address().toInt());
Dictionary<8192> authInfo;
authInfo.add(ZT_AUTHINFO_DICT_KEY_VERSION, info.version);
authInfo.add(ZT_AUTHINFO_DICT_KEY_AUTHENTICATION_URL, info.authenticationURL.c_str());
authInfo.add(ZT_AUTHINFO_DICT_KEY_CENTRAL_ENDPOINT_URL, info.centralAuthURL.c_str());
authInfo.add(ZT_AUTHINFO_DICT_KEY_NONCE, info.ssoNonce.c_str());
authInfo.add(ZT_AUTHINFO_DICT_KEY_STATE, info.ssoState.c_str());
authInfo.add(ZT_AUTHINFO_DICT_KEY_CLIENT_ID, info.ssoClientID.c_str());
DB::cleanMember(member);
_db.save(member, true);
_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_AUTHENTICATION_REQUIRED, authInfo.data(), authInfo.sizeBytes());
return;
}
} else {
fprintf(stderr, "invalid sso info.version %llu\n", info.version);
}
} else if (authorized) {
_db.memberWillExpire(authenticationExpiryTime, nwid, identity.address().toInt());
@ -1452,9 +1478,32 @@ void EmbeddedNetworkController::_request(
nc->multicastLimit = (unsigned int)OSUtils::jsonInt(network["multicastLimit"],32ULL);
nc->ssoEnabled = OSUtils::jsonBool(network["ssoEnabled"], false);
nc->authenticationExpiryTime = OSUtils::jsonInt(member["authenticationExpiryTime"], 0LL);
if (!authenticationURL.empty())
Utils::scopy(nc->authenticationURL, sizeof(nc->authenticationURL), authenticationURL.c_str());
nc->ssoVersion = info.version;
if (info.version == 0) {
nc->authenticationExpiryTime = OSUtils::jsonInt(member["authenticationExpiryTime"], 0LL);
if (!info.authenticationURL.empty()) {
Utils::scopy(nc->authenticationURL, sizeof(nc->authenticationURL), info.authenticationURL.c_str());
}
}
else if (info.version == 1) {
nc->authenticationExpiryTime = OSUtils::jsonInt(member["authenticationExpiryTime"], 0LL);
if (!info.authenticationURL.empty()) {
Utils::scopy(nc->authenticationURL, sizeof(nc->authenticationURL), info.authenticationURL.c_str());
}
if (!info.centralAuthURL.empty()) {
Utils::scopy(nc->centralAuthURL, sizeof(nc->centralAuthURL), info.centralAuthURL.c_str());
}
if (!info.ssoNonce.empty()) {
Utils::scopy(nc->ssoNonce, sizeof(nc->ssoNonce), info.ssoNonce.c_str());
}
if (!info.ssoState.empty()) {
Utils::scopy(nc->ssoState, sizeof(nc->ssoState), info.ssoState.c_str());
}
if (!info.ssoClientID.empty()) {
Utils::scopy(nc->ssoClientID, sizeof(nc->ssoClientID), info.ssoClientID.c_str());
}
}
std::string rtt(OSUtils::jsonString(member["remoteTraceTarget"],""));
if (rtt.length() == 10) {

View file

@ -336,7 +336,7 @@ void PostgreSQL::nodeIsOnline(const uint64_t networkId, const uint64_t memberId,
}
}
std::string PostgreSQL::getSSOAuthURL(const nlohmann::json &member, const std::string &redirectURL)
AuthInfo PostgreSQL::getSSOAuthInfo(const nlohmann::json &member, const std::string &redirectURL)
{
// NONCE is just a random character string. no semantic meaning
// state = HMAC SHA384 of Nonce based on shared sso key
@ -347,10 +347,12 @@ std::string PostgreSQL::getSSOAuthURL(const nlohmann::json &member, const std::s
// how do we tell when a nonce is used? if auth_expiration_time is set
std::string networkId = member["nwid"];
std::string memberId = member["id"];
char authenticationURL[4096] = {0};
//fprintf(stderr, "PostgreSQL::updateMemberOnLoad: %s-%s\n", networkId.c_str(), memberId.c_str());
bool have_auth = false;
char authenticationURL[4096] = {0};
AuthInfo info;
info.enabled = true;
// fprintf(stderr, "PostgreSQL::updateMemberOnLoad: %s-%s\n", networkId.c_str(), memberId.c_str());
try {
auto c = _pool->borrow();
pqxx::work w(*c->c);
@ -390,38 +392,51 @@ std::string PostgreSQL::getSSOAuthURL(const nlohmann::json &member, const std::s
exit(6);
}
r = w.exec_params("SELECT org.client_id, org.authorization_endpoint "
r = w.exec_params("SELECT org.client_id, org.authorization_endpoint, org.sso_version "
"FROM ztc_network AS nw, ztc_org AS org "
"WHERE nw.id = $1 AND nw.sso_enabled = true AND org.owner_id = nw.owner_id", networkId);
std::string client_id = "";
std::string authorization_endpoint = "";
uint64_t sso_version = 0;
if (r.size() == 1) {
client_id = r.at(0)[0].as<std::string>();
authorization_endpoint = r.at(0)[1].as<std::string>();
sso_version = r.at(0)[2].as<uint64_t>();
} else if (r.size() > 1) {
fprintf(stderr, "ERROR: More than one auth endpoint for an organization?!?!? NetworkID: %s\n", networkId.c_str());
} else {
fprintf(stderr, "No client or auth endpoint?!?\n");
}
info.version = sso_version;
// no catch all else because we don't actually care if no records exist here. just continue as normal.
if ((!client_id.empty())&&(!authorization_endpoint.empty())) {
have_auth = true;
uint8_t state[48];
HMACSHA384(_ssoPsk, nonceBytes, sizeof(nonceBytes), state);
char state_hex[256];
Utils::hex(state, 48, state_hex);
OSUtils::ztsnprintf(authenticationURL, sizeof(authenticationURL),
"%s?response_type=id_token&response_mode=form_post&scope=openid+email+profile&redirect_uri=%s&nonce=%s&state=%s&client_id=%s",
authorization_endpoint.c_str(),
redirectURL.c_str(),
nonce.c_str(),
state_hex,
client_id.c_str());
if (info.version == 0) {
char url[2048] = {0};
OSUtils::ztsnprintf(url, sizeof(authenticationURL),
"%s?response_type=id_token&response_mode=form_post&scope=openid+email+profile&redirect_uri=%s&nonce=%s&state=%s&client_id=%s",
authorization_endpoint.c_str(),
redirectURL.c_str(),
nonce.c_str(),
state_hex,
client_id.c_str());
info.authenticationURL = std::string(url);
} else if (info.version == 1) {
info.ssoClientID = client_id;
info.authenticationURL = authorization_endpoint;
info.ssoNonce = nonce;
info.ssoState = std::string(state_hex);
info.centralAuthURL = redirectURL;
}
} else {
fprintf(stderr, "client_id: %s\nauthorization_endpoint: %s\n", client_id.c_str(), authorization_endpoint.c_str());
}
@ -432,7 +447,7 @@ std::string PostgreSQL::getSSOAuthURL(const nlohmann::json &member, const std::s
fprintf(stderr, "ERROR: Error updating member on load: %s\n", e.what());
}
return std::string(authenticationURL);
return info; //std::string(authenticationURL);
}
void PostgreSQL::initializeNetworks()

View file

@ -107,7 +107,7 @@ public:
virtual void eraseNetwork(const uint64_t networkId);
virtual void eraseMember(const uint64_t networkId, const uint64_t memberId);
virtual void nodeIsOnline(const uint64_t networkId, const uint64_t memberId, const InetAddress &physicalAddress);
virtual std::string getSSOAuthURL(const nlohmann::json &member, const std::string &redirectURL);
virtual AuthInfo getSSOAuthInfo(const nlohmann::json &member, const std::string &redirectURL);
protected:
struct _PairHasher