summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authormolotkov-and <[email protected]>2023-09-12 17:58:46 +0300
committermolotkov-and <[email protected]>2023-09-12 19:25:55 +0300
commitf9df42fa4e9b160a1a3c72589846109cd086b03a (patch)
tree1ce09d87b89a9d26e8af20d4585667921475a46b
parent937ce8de21d92bc61e3acf2ec3f521dcd23440ac (diff)
KIKIMR-19166: Add refresh tokens autnethicated in LDAP
-rw-r--r--ydb/core/security/ticket_parser_impl.h366
-rw-r--r--ydb/core/security/ticket_parser_ut.cpp214
-rw-r--r--ydb/core/tx/schemeshard/schemeshard__login.cpp2
-rw-r--r--ydb/library/login/login.cpp6
-rw-r--r--ydb/library/login/login.h4
-rw-r--r--ydb/library/login/login_ut.cpp4
6 files changed, 384 insertions, 212 deletions
diff --git a/ydb/core/security/ticket_parser_impl.h b/ydb/core/security/ticket_parser_impl.h
index b848150b9c7..dfcd9867df9 100644
--- a/ydb/core/security/ticket_parser_impl.h
+++ b/ydb/core/security/ticket_parser_impl.h
@@ -36,45 +36,14 @@ inline bool IsRetryableGrpcError(const NGrpc::TGrpcStatus& status) {
template <typename TDerived>
class TTicketParserImpl : public TActorBootstrapped<TDerived> {
+private:
using TThis = TTicketParserImpl;
using TBase = TActorBootstrapped<TDerived>;
- TDerived* GetDerived() {
- return static_cast<TDerived*>(this);
- }
-
- static TString GetKey(TEvTicketParser::TEvAuthorizeTicket* request) {
- TStringStream key;
- if (request->Signature.AccessKeyId) {
- const auto& sign = request->Signature;
- key << sign.AccessKeyId << "-" << sign.Signature << ":" << sign.StringToSign << ":"
- << sign.Service << ":" << sign.Region << ":" << sign.SignedAt.NanoSeconds();
- } else {
- key << request->Ticket;
- }
- key << ':';
- if (request->Database) {
- key << request->Database;
- key << ':';
- }
- for (const auto& entry : request->Entries) {
- for (auto it = entry.Attributes.begin(); it != entry.Attributes.end(); ++it) {
- if (it != entry.Attributes.begin()) {
- key << '-';
- }
- key << it->second;
- }
- key << ':';
- for (auto it = entry.Permissions.begin(); it != entry.Permissions.end(); ++it) {
- if (it != entry.Permissions.begin()) {
- key << '-';
- }
- key << it->Permission << "(" << it->Required << ")";
- }
- }
- return key.Str();
- }
-
+ struct TExternalAuthInfo {
+ TString Type;
+ TString Login;
+ };
struct TPermissionRecord {
TString Subject;
@@ -110,6 +79,141 @@ class TTicketParserImpl : public TActorBootstrapped<TDerived> {
using TEvAccessServiceGetUserAccountRequest = TEvRequestWithKey<NCloud::TEvUserAccountService::TEvGetUserAccountRequest>;
using TEvAccessServiceGetServiceAccountRequest = TEvRequestWithKey<NCloud::TEvServiceAccountService::TEvGetServiceAccountRequest>;
+ struct TTokenRefreshRecord {
+ TString Key;
+ TInstant RefreshTime;
+
+ bool operator <(const TTokenRefreshRecord& o) const {
+ return RefreshTime > o.RefreshTime;
+ }
+ };
+
+protected:
+ class TTokenRecordBase {
+ private:
+ TIntrusiveConstPtr<NACLib::TUserToken> Token;
+ public:
+ TTokenRecordBase(const TTokenRecordBase&) = delete;
+ TTokenRecordBase& operator =(const TTokenRecordBase&) = delete;
+
+ TString Ticket;
+ typename TDerived::ETokenType TokenType = TDerived::ETokenType::Unknown;
+ NKikimr::TEvTicketParser::TEvAuthorizeTicket::TAccessKeySignature Signature;
+ THashMap<TString, TPermissionRecord> Permissions;
+ TString Subject; // login
+ TEvTicketParser::TError Error;
+ TDeque<THolder<TEventHandle<TEvTicketParser::TEvAuthorizeTicket>>> AuthorizeRequests;
+ ui64 ResponsesLeft = 0;
+ TInstant InitTime;
+ TInstant RefreshTime;
+ TInstant ExpireTime;
+ TInstant AccessTime;
+ TDuration CurrentDelay = TDuration::Seconds(1);
+ TString PeerName;
+ TString Database;
+ TStackVec<TString> AdditionalSIDs;
+ bool RefreshRetryableErrorImmediately = false;
+ TExternalAuthInfo ExternalAuthInfo;
+
+ TTokenRecordBase(const TStringBuf ticket)
+ : Ticket(ticket)
+ {}
+
+ void SetToken(const TIntrusivePtr<NACLib::TUserToken>& token) {
+ // saving serialization info into the token instance.
+ token->SaveSerializationInfo();
+ Token = token;
+ }
+
+ const TIntrusiveConstPtr<NACLib::TUserToken> GetToken() const {
+ return Token;
+ }
+
+ void UnsetToken() {
+ Token = nullptr;
+ }
+
+ TString GetAttributeValue(const TString& permission, const TString& key) const {
+ if (auto it = Permissions.find(permission); it != Permissions.end()) {
+ for (const auto& pr : it->second.Attributes) {
+ if (pr.first == key) {
+ return pr.second;
+ }
+ }
+ }
+ return TString();
+ }
+
+ template <typename T>
+ void SetErrorRefreshTime(TTicketParserImpl<T>* ticketParser, TInstant now) {
+ if (Error.Retryable) {
+ SetRefreshTime(now, CurrentDelay);
+ if (CurrentDelay < ticketParser->MaxErrorRefreshTime - ticketParser->MinErrorRefreshTime) {
+ static const double scaleFactor = 2.0;
+ CurrentDelay = Min(CurrentDelay * scaleFactor, ticketParser->MaxErrorRefreshTime - ticketParser->MinErrorRefreshTime);
+ }
+ } else {
+ SetRefreshTime(now, ticketParser->RefreshTime - ticketParser->RefreshTime / 2);
+ }
+ }
+
+ template <typename T>
+ void SetOkRefreshTime(TTicketParserImpl<T>* ticketParser, TInstant now) {
+ SetRefreshTime(now, ticketParser->RefreshTime - ticketParser->RefreshTime / 2);
+ }
+
+ void SetRefreshTime(TInstant now, TDuration delay) {
+ const TDuration::TValue half = delay.GetValue() / 2;
+ RefreshTime = now + TDuration::FromValue(half + RandomNumber<TDuration::TValue>(half));
+ }
+
+ TString GetSubject() const {
+ return Subject;
+ }
+
+ TString GetAuthType() const {
+ switch (TokenType) {
+ case TDerived::ETokenType::Unknown:
+ return "Unknown";
+ case TDerived::ETokenType::Unsupported:
+ return "Unsupported";
+ case TDerived::ETokenType::Builtin:
+ return "Builtin";
+ case TDerived::ETokenType::Login:
+ return "Login";
+ case TDerived::ETokenType::AccessService:
+ return "AccessService";
+ case TDerived::ETokenType::ApiKey:
+ return "ApiKey";
+ }
+ }
+
+ bool NeedsRefresh() const {
+ switch (TokenType) {
+ case TDerived::ETokenType::Builtin:
+ return false;
+ case TDerived::ETokenType::Login:
+ return true;
+ default:
+ return Signature.AccessKeyId.empty();
+ }
+ }
+
+ bool IsTokenReady() const {
+ return Token != nullptr;
+ }
+
+ bool IsExternalAuthEnabled() const {
+ return !ExternalAuthInfo.Type.empty();
+ }
+
+ void EnableExternalAuth(const NLogin::TLoginProvider::TValidateTokenResponse& response) {
+ ExternalAuthInfo.Login = response.User;
+ ExternalAuthInfo.Type = response.ExternalAuth;
+ }
+ };
+
+private:
TString DomainName;
::NMonitoring::TDynamicCounters::TCounterPtr CounterTicketsReceived;
::NMonitoring::TDynamicCounters::TCounterPtr CounterTicketsSuccess;
@@ -136,19 +240,46 @@ class TTicketParserImpl : public TActorBootstrapped<TDerived> {
TString AccessServiceDomain;
TString ServiceDomain;
- struct TTokenRefreshRecord {
- TString Key;
- TInstant RefreshTime;
-
- bool operator <(const TTokenRefreshRecord& o) const {
- return RefreshTime > o.RefreshTime;
- }
- };
-
TPriorityQueue<TTokenRefreshRecord> RefreshQueue;
std::unordered_map<TString, NLogin::TLoginProvider> LoginProviders;
bool UseLoginProvider = false;
+ TDerived* GetDerived() {
+ return static_cast<TDerived*>(this);
+ }
+
+ static TString GetKey(TEvTicketParser::TEvAuthorizeTicket* request) {
+ TStringStream key;
+ if (request->Signature.AccessKeyId) {
+ const auto& sign = request->Signature;
+ key << sign.AccessKeyId << "-" << sign.Signature << ":" << sign.StringToSign << ":"
+ << sign.Service << ":" << sign.Region << ":" << sign.SignedAt.NanoSeconds();
+ } else {
+ key << request->Ticket;
+ }
+ key << ':';
+ if (request->Database) {
+ key << request->Database;
+ key << ':';
+ }
+ for (const auto& entry : request->Entries) {
+ for (auto it = entry.Attributes.begin(); it != entry.Attributes.end(); ++it) {
+ if (it != entry.Attributes.begin()) {
+ key << '-';
+ }
+ key << it->second;
+ }
+ key << ':';
+ for (auto it = entry.Permissions.begin(); it != entry.Permissions.end(); ++it) {
+ if (it != entry.Permissions.begin()) {
+ key << '-';
+ }
+ key << it->Permission << "(" << it->Required << ")";
+ }
+ }
+ return key.Str();
+ }
+
TInstant GetExpireTime(TInstant now) const {
return now + ExpireTime;
}
@@ -303,7 +434,8 @@ class TTicketParserImpl : public TActorBootstrapped<TDerived> {
record.TokenType = TDerived::ETokenType::Login;
record.ExpireTime = ToInstant(response.ExpiresAt);
CounterTicketsLogin->Inc();
- if (response.ExternalAuth.has_value()) {
+ if (response.ExternalAuth) {
+ record.EnableExternalAuth(response);
HandleExternalAuthentication(key, record, response);
return true;
}
@@ -336,7 +468,7 @@ class TTicketParserImpl : public TActorBootstrapped<TDerived> {
template <typename TTokenRecord>
void HandleExternalAuthentication(const TString& key, TTokenRecord& record, const NLogin::TLoginProvider::TValidateTokenResponse& loginProviderResponse) {
- if (loginProviderResponse.ExternalAuth.value() == Config.GetLdapAuthenticationDomain()) {
+ if (loginProviderResponse.ExternalAuth == Config.GetLdapAuthenticationDomain()) {
SendRequestToLdap(key, record, loginProviderResponse.User);
} else {
SetError(key, record, {.Message = "Do not have suitable external auth provider"});
@@ -911,120 +1043,6 @@ protected:
return TDerived::ETokenType::Unknown;
}
- class TTokenRecordBase {
- private:
- TIntrusiveConstPtr<NACLib::TUserToken> Token;
- public:
- TTokenRecordBase(const TTokenRecordBase&) = delete;
- TTokenRecordBase& operator =(const TTokenRecordBase&) = delete;
-
- TString Ticket;
- typename TDerived::ETokenType TokenType = TDerived::ETokenType::Unknown;
- NKikimr::TEvTicketParser::TEvAuthorizeTicket::TAccessKeySignature Signature;
- THashMap<TString, TPermissionRecord> Permissions;
- TString Subject; // login
- TEvTicketParser::TError Error;
- TDeque<THolder<TEventHandle<TEvTicketParser::TEvAuthorizeTicket>>> AuthorizeRequests;
- ui64 ResponsesLeft = 0;
- TInstant InitTime;
- TInstant RefreshTime;
- TInstant ExpireTime;
- TInstant AccessTime;
- TDuration CurrentDelay = TDuration::Seconds(1);
- TString PeerName;
- TString Database;
- TStackVec<TString> AdditionalSIDs;
- bool RefreshRetryableErrorImmediately = false;
-
- TTokenRecordBase(const TStringBuf ticket)
- : Ticket(ticket)
- {}
-
- void SetToken(const TIntrusivePtr<NACLib::TUserToken>& token) {
- // saving serialization info into the token instance.
- token->SaveSerializationInfo();
- Token = token;
- }
-
- const TIntrusiveConstPtr<NACLib::TUserToken> GetToken() const {
- return Token;
- }
-
- void UnsetToken() {
- Token = nullptr;
- }
-
- TString GetAttributeValue(const TString& permission, const TString& key) const {
- if (auto it = Permissions.find(permission); it != Permissions.end()) {
- for (const auto& pr : it->second.Attributes) {
- if (pr.first == key) {
- return pr.second;
- }
- }
- }
- return TString();
- }
-
- template <typename T>
- void SetErrorRefreshTime(TTicketParserImpl<T>* ticketParser, TInstant now) {
- if (Error.Retryable) {
- SetRefreshTime(now, CurrentDelay);
- if (CurrentDelay < ticketParser->MaxErrorRefreshTime - ticketParser->MinErrorRefreshTime) {
- static const double scaleFactor = 2.0;
- CurrentDelay = Min(CurrentDelay * scaleFactor, ticketParser->MaxErrorRefreshTime - ticketParser->MinErrorRefreshTime);
- }
- } else {
- SetRefreshTime(now, ticketParser->RefreshTime - ticketParser->RefreshTime / 2);
- }
- }
-
- template <typename T>
- void SetOkRefreshTime(TTicketParserImpl<T>* ticketParser, TInstant now) {
- SetRefreshTime(now, ticketParser->RefreshTime - ticketParser->RefreshTime / 2);
- }
-
- void SetRefreshTime(TInstant now, TDuration delay) {
- const TDuration::TValue half = delay.GetValue() / 2;
- RefreshTime = now + TDuration::FromValue(half + RandomNumber<TDuration::TValue>(half));
- }
-
- TString GetSubject() const {
- return Subject;
- }
-
- TString GetAuthType() const {
- switch (TokenType) {
- case TDerived::ETokenType::Unknown:
- return "Unknown";
- case TDerived::ETokenType::Unsupported:
- return "Unsupported";
- case TDerived::ETokenType::Builtin:
- return "Builtin";
- case TDerived::ETokenType::Login:
- return "Login";
- case TDerived::ETokenType::AccessService:
- return "AccessService";
- case TDerived::ETokenType::ApiKey:
- return "ApiKey";
- }
- }
-
- bool NeedsRefresh() const {
- switch (TokenType) {
- case TDerived::ETokenType::Builtin:
- return false;
- case TDerived::ETokenType::Login:
- return true;
- default:
- return Signature.AccessKeyId.empty();
- }
- }
-
- bool IsTokenReady() const {
- return Token != nullptr;
- }
- };
-
static TStringBuf GetTicketFromKey(const TStringBuf key) {
return key.Before(':');
}
@@ -1246,15 +1264,34 @@ protected:
}
template <typename TTokenRecord>
+ bool RefreshTicketViaExternalAuthProvider(const TString& key, TTokenRecord& record) {
+ const TExternalAuthInfo& externalAuthInfo = record.ExternalAuthInfo;
+ if (externalAuthInfo.Type == Config.GetLdapAuthenticationDomain()) {
+ if (Config.HasLdapAuthentication()) {
+ Send(MakeLdapAuthProviderID(), new TEvLdapAuthProvider::TEvEnrichGroupsRequest(key, externalAuthInfo.Login));
+ return true;
+ }
+ SetError(key, record, {.Message = "LdapAuthProvider is not initialized", .Retryable = false});
+ return false;
+ } else {
+ SetError(key, record, {.Message = "Do not have suitable external auth provider"});
+ return false;
+ }
+ }
+
+ template <typename TTokenRecord>
bool RefreshLoginTicket(const TString& key, TTokenRecord& record) {
GetDerived()->ResetTokenRecord(record);
+ const TString userSID = record.GetToken()->GetUserSID();
+ if (record.IsExternalAuthEnabled()) {
+ return RefreshTicketViaExternalAuthProvider(key, record);
+ }
const TString& database = Config.GetDomainLoginOnly() ? DomainName : record.Database;
auto itLoginProvider = LoginProviders.find(database);
if (itLoginProvider == LoginProviders.end()) {
return false;
}
NLogin::TLoginProvider& loginProvider(itLoginProvider->second);
- const TString userSID = record.GetToken()->GetUserSID();
if (loginProvider.CheckUserExists(userSID)) {
const std::vector<TString> providerGroups = loginProvider.GetGroupsMembership(userSID);
const TVector<NACLib::TSID> groups(providerGroups.begin(), providerGroups.end());
@@ -1265,10 +1302,7 @@ protected:
.AuthType = record.GetAuthType()
}));
} else {
- TEvTicketParser::TError error;
- error.Message = "User not found";
- error.Retryable = false;
- SetError(key, record, error);
+ SetError(key, record, {.Message = "User not found", .Retryable = false});
}
return true;
}
diff --git a/ydb/core/security/ticket_parser_ut.cpp b/ydb/core/security/ticket_parser_ut.cpp
index 25b841583b8..7892e018ac8 100644
--- a/ydb/core/security/ticket_parser_ut.cpp
+++ b/ydb/core/security/ticket_parser_ut.cpp
@@ -91,14 +91,20 @@ private:
ui16 GrpcPort;
};
-TAutoPtr<IEventHandle> LdapAuthenticate(TLdapKikimrServer& server, const TString& login, const TString& password) {
+NLogin::TLoginProvider::TLoginUserResponse GetLoginResponse(TLdapKikimrServer& server, const TString& login, const TString& password) {
TTestActorRuntime* runtime = server.GetRuntime();
NLogin::TLoginProvider provider;
provider.Audience = "/Root";
provider.RotateKeys();
TActorId sender = runtime->AllocateEdgeActor();
runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvUpdateLoginSecurityState(provider.GetSecurityState())), 0);
- auto loginResponse = provider.LoginUser({.User = login, .Password = password, .ExternalAuth = "ldap"});
+ return provider.LoginUser({.User = login, .Password = password, .ExternalAuth = "ldap"});
+}
+
+TAutoPtr<IEventHandle> LdapAuthenticate(TLdapKikimrServer& server, const TString& login, const TString& password) {
+ auto loginResponse = GetLoginResponse(server, login, password);
+ TTestActorRuntime* runtime = server.GetRuntime();
+ TActorId sender = runtime->AllocateEdgeActor();
runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);
TAutoPtr<IEventHandle> handle;
@@ -106,7 +112,51 @@ TAutoPtr<IEventHandle> LdapAuthenticate(TLdapKikimrServer& server, const TString
return handle;
}
+class TCorrectLdapResponse {
+public:
+ static std::vector<TString> Groups;
+ static LdapMock::TLdapMockResponses GetResponses(const TString& login);
+};
+
+std::vector<TString> TCorrectLdapResponse::Groups {
+ "ou=groups,dc=search,dc=yandex,dc=net",
+ "cn=people,ou=groups,dc=search,dc=yandex,dc=net",
+ "cn=developers,ou=groups,dc=search,dc=yandex,dc=net"
+};
+
+LdapMock::TLdapMockResponses TCorrectLdapResponse::GetResponses(const TString& login) {
+ LdapMock::TLdapMockResponses responses;
+ responses.BindResponses.push_back({{{.Login = "cn=robouser,dc=search,dc=yandex,dc=net", .Password = "robouserPassword"}}, {.Status = LdapMock::EStatus::SUCCESS}});
+
+ LdapMock::TSearchRequestInfo fetchGroupsSearchRequestInfo {
+ {
+ .BaseDn = "dc=search,dc=yandex,dc=net",
+ .Scope = 2,
+ .DerefAliases = 0,
+ .Filter = {.Type = LdapMock::EFilterType::LDAP_FILTER_EQUALITY, .Attribute = "uid", .Value = login},
+ .Attributes = {"memberOf"}
+ }
+ };
+
+ std::vector<LdapMock::TSearchEntry> fetchGroupsSearchResponseEntries {
+ {
+ .Dn = "uid=" + login + ",dc=search,dc=yandex,dc=net",
+ .AttributeList = {
+ {"memberOf", TCorrectLdapResponse::Groups}
+ }
+ }
+ };
+
+ LdapMock::TSearchResponseInfo fetchGroupsSearchResponseInfo {
+ .ResponseEntries = fetchGroupsSearchResponseEntries,
+ .ResponseDone = {.Status = LdapMock::EStatus::SUCCESS}
+ };
+ responses.SearchResponses.push_back({fetchGroupsSearchRequestInfo, fetchGroupsSearchResponseInfo});
+ return responses;
+}
+
Y_UNIT_TEST_SUITE(TTicketParserTest) {
+
Y_UNIT_TEST(LoginGood) {
using namespace Tests;
TPortManager tp;
@@ -429,42 +479,8 @@ Y_UNIT_TEST_SUITE(TTicketParserTest) {
TString login = "ldapuser";
TString password = "ldapUserPassword";
- LdapMock::TLdapMockResponses responses;
- responses.BindResponses.push_back({{{.Login = "cn=robouser,dc=search,dc=yandex,dc=net", .Password = "robouserPassword"}}, {.Status = LdapMock::EStatus::SUCCESS}});
-
- LdapMock::TSearchRequestInfo fetchGroupsSearchRequestInfo {
- {
- .BaseDn = "dc=search,dc=yandex,dc=net",
- .Scope = 2,
- .DerefAliases = 0,
- .Filter = {.Type = LdapMock::EFilterType::LDAP_FILTER_EQUALITY, .Attribute = "uid", .Value = login},
- .Attributes = {"memberOf"}
- }
- };
-
- THashSet<TString> expectedGroups {
- "ou=groups,dc=search,dc=yandex,dc=net",
- "cn=people,ou=groups,dc=search,dc=yandex,dc=net",
- "cn=developers,ou=groups,dc=search,dc=yandex,dc=net"
- };
- std::vector<LdapMock::TSearchEntry> fetchGroupsSearchResponseEntries {
- {
- .Dn = "uid=" + login + ",dc=search,dc=yandex,dc=net",
- .AttributeList = {
- {"memberOf", std::vector(expectedGroups.begin(), expectedGroups.end())}
- }
- }
- };
- expectedGroups.insert("all-users@well-known");
-
- LdapMock::TSearchResponseInfo fetchGroupsSearchResponseInfo {
- .ResponseEntries = fetchGroupsSearchResponseEntries,
- .ResponseDone = {.Status = LdapMock::EStatus::SUCCESS}
- };
- responses.SearchResponses.push_back({fetchGroupsSearchRequestInfo, fetchGroupsSearchResponseInfo});
-
TLdapKikimrServer server(InitLdapSettings);
- LdapMock::TLdapSimpleServer ldapServer(server.GetLdapPort(), responses);
+ LdapMock::TLdapSimpleServer ldapServer(server.GetLdapPort(), TCorrectLdapResponse::GetResponses(login));
TAutoPtr<IEventHandle> handle = LdapAuthenticate(server, login, password);
TEvTicketParser::TEvAuthorizeTicketResult* ticketParserResult = handle->Get<TEvTicketParser::TEvAuthorizeTicketResult>();
@@ -473,6 +489,9 @@ Y_UNIT_TEST_SUITE(TTicketParserTest) {
UNIT_ASSERT_VALUES_EQUAL(ticketParserResult->Token->GetUserSID(), login + "@ldap");
const auto& fetchedGroups = ticketParserResult->Token->GetGroupSIDs();
THashSet<TString> groups(fetchedGroups.begin(), fetchedGroups.end());
+
+ THashSet<TString> expectedGroups(TCorrectLdapResponse::Groups.begin(), TCorrectLdapResponse::Groups.end());
+ expectedGroups.insert("all-users@well-known");
UNIT_ASSERT_VALUES_EQUAL(fetchedGroups.size(), expectedGroups.size());
for (const auto& expectedGroup : expectedGroups) {
UNIT_ASSERT_C(groups.contains(expectedGroup), "Can not find " + expectedGroup);
@@ -559,7 +578,6 @@ Y_UNIT_TEST_SUITE(TTicketParserTest) {
TString password = "ldapUserPassword";
LdapMock::TLdapMockResponses responses;
- responses.BindResponses.push_back({{{.Login = "uid=" + login + ",dc=search,dc=yandex,dc=net", .Password = password}}, {LdapMock::EStatus::SUCCESS}});
responses.BindResponses.push_back({{{.Login = "cn=robouser,dc=search,dc=yandex,dc=net", .Password = "robouserPassword"}}, {.Status = LdapMock::EStatus::SUCCESS}});
TLdapKikimrServer server(InitLdapSettingsWithInvalidFilter);
@@ -590,6 +608,126 @@ Y_UNIT_TEST_SUITE(TTicketParserTest) {
ldapServer.Stop();
}
+ Y_UNIT_TEST(LdapRefreshGroupsInfoGood) {
+ TString login = "ldapuser";
+ TString password = "ldapUserPassword";
+
+ TLdapKikimrServer server(InitLdapSettings);
+ auto responses = TCorrectLdapResponse::GetResponses(login);
+ LdapMock::TLdapSimpleServer ldapServer(server.GetLdapPort(), responses);
+
+ auto loginResponse = GetLoginResponse(server, login, password);
+ TTestActorRuntime* runtime = server.GetRuntime();
+ TActorId sender = runtime->AllocateEdgeActor();
+ runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);
+ TAutoPtr<IEventHandle> handle;
+ TEvTicketParser::TEvAuthorizeTicketResult* ticketParserResult = runtime->GrabEdgeEvent<TEvTicketParser::TEvAuthorizeTicketResult>(handle);
+
+ UNIT_ASSERT_C(ticketParserResult->Error.empty(), ticketParserResult->Error);
+ UNIT_ASSERT(ticketParserResult->Token != nullptr);
+ UNIT_ASSERT_VALUES_EQUAL(ticketParserResult->Token->GetUserSID(), login + "@ldap");
+ const auto& fetchedGroups = ticketParserResult->Token->GetGroupSIDs();
+ THashSet<TString> groups(fetchedGroups.begin(), fetchedGroups.end());
+
+ THashSet<TString> expectedGroups(TCorrectLdapResponse::Groups.begin(), TCorrectLdapResponse::Groups.end());
+ expectedGroups.insert("all-users@well-known");
+ UNIT_ASSERT_VALUES_EQUAL(fetchedGroups.size(), expectedGroups.size());
+ for (const auto& expectedGroup : expectedGroups) {
+ UNIT_ASSERT_C(groups.contains(expectedGroup), "Can not find " + expectedGroup);
+ }
+
+ THashSet<TString> newExpectedGroups {
+ "ou=groups,dc=search,dc=yandex,dc=net",
+ "cn=people,ou=groups,dc=search,dc=yandex,dc=net",
+ "cn=desiners,ou=groups,dc=search,dc=yandex,dc=net"
+ };
+ std::vector<LdapMock::TSearchEntry> newFetchGroupsSearchResponseEntries {
+ {
+ .Dn = "uid=" + login + ",dc=search,dc=yandex,dc=net",
+ .AttributeList = {
+ {"memberOf", std::vector(newExpectedGroups.begin(), newExpectedGroups.end())}
+ }
+ }
+ };
+ newExpectedGroups.insert("all-users@well-known");
+
+ LdapMock::TSearchResponseInfo newFetchGroupsSearchResponseInfo {
+ .ResponseEntries = newFetchGroupsSearchResponseEntries,
+ .ResponseDone = {.Status = LdapMock::EStatus::SUCCESS}
+ };
+
+ auto& searchresponse = responses.SearchResponses.front();
+ searchresponse.second = newFetchGroupsSearchResponseInfo;
+ ldapServer.SetSearchReasponse(searchresponse);
+ Sleep(TDuration::Seconds(10));
+
+ runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);
+ ticketParserResult = runtime->GrabEdgeEvent<TEvTicketParser::TEvAuthorizeTicketResult>(handle);
+
+ UNIT_ASSERT_C(ticketParserResult->Error.empty(), ticketParserResult->Error);
+ UNIT_ASSERT(ticketParserResult->Token != nullptr);
+ UNIT_ASSERT_VALUES_EQUAL(ticketParserResult->Token->GetUserSID(), login + "@ldap");
+ const auto& newFetchedGroups = ticketParserResult->Token->GetGroupSIDs();
+ THashSet<TString> newGroups(newFetchedGroups.begin(), newFetchedGroups.end());
+ UNIT_ASSERT_VALUES_EQUAL(newFetchedGroups.size(), newExpectedGroups.size());
+ for (const auto& expectedGroup : newExpectedGroups) {
+ UNIT_ASSERT_C(newGroups.contains(expectedGroup), "Can not find " + expectedGroup);
+ }
+
+ ldapServer.Stop();
+ }
+
+ Y_UNIT_TEST(LdapRefreshRemoveUserBad) {
+ TString login = "ldapuser";
+ TString password = "ldapUserPassword";
+
+ TLdapKikimrServer server(InitLdapSettings);
+ auto responses = TCorrectLdapResponse::GetResponses(login);
+ LdapMock::TLdapSimpleServer ldapServer(server.GetLdapPort(), responses);
+
+
+ auto loginResponse = GetLoginResponse(server, login, password);
+ TTestActorRuntime* runtime = server.GetRuntime();
+ TActorId sender = runtime->AllocateEdgeActor();
+ runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);
+ TAutoPtr<IEventHandle> handle;
+ TEvTicketParser::TEvAuthorizeTicketResult* ticketParserResult = runtime->GrabEdgeEvent<TEvTicketParser::TEvAuthorizeTicketResult>(handle);
+
+ UNIT_ASSERT_C(ticketParserResult->Error.empty(), ticketParserResult->Error);
+ UNIT_ASSERT(ticketParserResult->Token != nullptr);
+ UNIT_ASSERT_VALUES_EQUAL(ticketParserResult->Token->GetUserSID(), login + "@ldap");
+ const auto& fetchedGroups = ticketParserResult->Token->GetGroupSIDs();
+ THashSet<TString> groups(fetchedGroups.begin(), fetchedGroups.end());
+
+ THashSet<TString> expectedGroups(TCorrectLdapResponse::Groups.begin(), TCorrectLdapResponse::Groups.end());
+ expectedGroups.insert("all-users@well-known");
+ UNIT_ASSERT_VALUES_EQUAL(fetchedGroups.size(), expectedGroups.size());
+ for (const auto& expectedGroup : expectedGroups) {
+ UNIT_ASSERT_C(groups.contains(expectedGroup), "Can not find " + expectedGroup);
+ }
+
+ LdapMock::TSearchResponseInfo newFetchGroupsSearchResponseInfo {
+ .ResponseEntries = {}, // User has been removed. Return empty entries list
+ .ResponseDone = {.Status = LdapMock::EStatus::SUCCESS}
+ };
+
+ auto& searchresponse = responses.SearchResponses.front();
+ searchresponse.second = newFetchGroupsSearchResponseInfo;
+ ldapServer.SetSearchReasponse(searchresponse);
+ Sleep(TDuration::Seconds(10));
+
+ runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);
+ ticketParserResult = runtime->GrabEdgeEvent<TEvTicketParser::TEvAuthorizeTicketResult>(handle);
+
+ UNIT_ASSERT_C(!ticketParserResult->Error.empty(), "Expected return error message");
+ UNIT_ASSERT(ticketParserResult->Token == nullptr);
+ UNIT_ASSERT_STRINGS_EQUAL(ticketParserResult->Error.Message, "LDAP user " + login + " does not exist. "
+ "LDAP search for filter uid=" + login + " on server localhost return no entries");
+ UNIT_ASSERT_EQUAL(ticketParserResult->Error.Retryable, false);
+
+ ldapServer.Stop();
+ }
+
Y_UNIT_TEST(AccessServiceAuthenticationOk) {
using namespace Tests;
diff --git a/ydb/core/tx/schemeshard/schemeshard__login.cpp b/ydb/core/tx/schemeshard/schemeshard__login.cpp
index 4915b5d1c8e..85d2f54a489 100644
--- a/ydb/core/tx/schemeshard/schemeshard__login.cpp
+++ b/ydb/core/tx/schemeshard/schemeshard__login.cpp
@@ -22,7 +22,7 @@ struct TSchemeShard::TTxLogin : TSchemeShard::TRwTxBase {
return {
.User = Request->Get()->Record.GetUser(),
.Password = Request->Get()->Record.GetPassword(),
- .ExternalAuth = (Request->Get()->Record.HasExternalAuth() ? std::make_optional(Request->Get()->Record.GetExternalAuth()) : std::nullopt)
+ .ExternalAuth = Request->Get()->Record.GetExternalAuth()
};
}
diff --git a/ydb/library/login/login.cpp b/ydb/library/login/login.cpp
index 29ca5b9e469..f65d753c96a 100644
--- a/ydb/library/login/login.cpp
+++ b/ydb/library/login/login.cpp
@@ -235,7 +235,7 @@ std::vector<TString> TLoginProvider::GetGroupsMembership(const TString& member)
TLoginProvider::TLoginUserResponse TLoginProvider::LoginUser(const TLoginUserRequest& request) {
TLoginUserResponse response;
- if (!request.ExternalAuth.has_value()) {
+ if (!request.ExternalAuth) {
auto itUser = Sids.find(request.User);
if (itUser == Sids.end() || itUser->second.Type != ESidType::USER) {
response.Error = "Invalid user";
@@ -276,8 +276,8 @@ TLoginProvider::TLoginUserResponse TLoginProvider::LoginUser(const TLoginUserReq
token.set_audience(Audience);
}
- if (request.ExternalAuth.has_value()) {
- token.set_payload_claim(EXTERNAL_AUTH_CLAIM_NAME, jwt::claim(request.ExternalAuth.value()));
+ if (request.ExternalAuth) {
+ token.set_payload_claim(EXTERNAL_AUTH_CLAIM_NAME, jwt::claim(request.ExternalAuth));
} else {
if (request.Options.WithUserGroups) {
auto groups = GetGroupsMembership(request.User);
diff --git a/ydb/library/login/login.h b/ydb/library/login/login.h
index b8b0b34c2b9..3cd8d3856b2 100644
--- a/ydb/library/login/login.h
+++ b/ydb/library/login/login.h
@@ -41,7 +41,7 @@ public:
TString User;
TString Password;
TOptions Options;
- std::optional<TString> ExternalAuth;
+ TString ExternalAuth;
};
struct TLoginUserResponse : TBasicResponse {
@@ -58,7 +58,7 @@ public:
TString User;
std::optional<std::vector<TString>> Groups;
std::chrono::system_clock::time_point ExpiresAt;
- std::optional<TString> ExternalAuth;
+ TString ExternalAuth;
};
struct TCreateUserRequest : TBasicRequest {
diff --git a/ydb/library/login/login_ut.cpp b/ydb/library/login/login_ut.cpp
index 03e07aa7ea5..4ace83b1ebe 100644
--- a/ydb/library/login/login_ut.cpp
+++ b/ydb/library/login/login_ut.cpp
@@ -220,7 +220,7 @@ Y_UNIT_TEST_SUITE(Login) {
auto response3 = provider.ValidateToken(request3);
UNIT_ASSERT_VALUES_EQUAL(response3.Error, "");
UNIT_ASSERT(response3.User == request2.User);
- UNIT_ASSERT(response3.ExternalAuth.has_value());
+ UNIT_ASSERT(!response3.ExternalAuth.empty());
UNIT_ASSERT(response3.ExternalAuth == request2.ExternalAuth);
}
{
@@ -234,7 +234,7 @@ Y_UNIT_TEST_SUITE(Login) {
auto response3 = provider.ValidateToken(request3);
UNIT_ASSERT_VALUES_EQUAL(response3.Error, "");
UNIT_ASSERT(response3.User == request1.User);
- UNIT_ASSERT(!response3.ExternalAuth.has_value());
+ UNIT_ASSERT(response3.ExternalAuth.empty());
}
}
}