aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authormolotkov-and <molotkov-and@yandex-team.com>2022-09-28 14:37:25 +0300
committermolotkov-and <molotkov-and@yandex-team.com>2022-09-28 14:37:25 +0300
commite49a6d731192528f21228bf8bbbdeb78b35a5dca (patch)
treeb0520a175ebdfca0726111a2787cee1239c66834
parentb4a5e309b31f8c90e165d2388586544a3926ca93 (diff)
downloadydb-e49a6d731192528f21228bf8bbbdeb78b35a5dca.tar.gz
Extract InitTokenRecord to base class
-rw-r--r--ydb/core/security/ticket_parser.cpp114
-rw-r--r--ydb/core/security/ticket_parser_impl.h129
2 files changed, 125 insertions, 118 deletions
diff --git a/ydb/core/security/ticket_parser.cpp b/ydb/core/security/ticket_parser.cpp
index b4439d1e124..937036e1cea 100644
--- a/ydb/core/security/ticket_parser.cpp
+++ b/ydb/core/security/ticket_parser.cpp
@@ -38,8 +38,6 @@ class TTicketParser : public TTicketParserImpl<TTicketParser> {
struct TTokenRecord : TBase::TTokenRecordBase {
using TBase::TTokenRecordBase::TTokenRecordBase;
- ETokenType TokenType = ETokenType::Unknown;
-
TString GetSubject() const {
return Subject;
}
@@ -60,104 +58,14 @@ class TTicketParser : public TTicketParserImpl<TTicketParser> {
THashMap<TString, TTokenRecord> UserTokens;
- TTokenRecord* GetUserToken(const TString& key) {
- auto it = UserTokens.find(key);
- return it != UserTokens.end() ? &it->second : nullptr;
- }
-
- TTokenRecord* InsertUserToken(const TString& key, const TStringBuf ticket) {
- auto it = UserTokens.emplace(key, ticket).first;
- return &it->second;
+ THashMap<TString, TTokenRecord>& GetUserTokens() {
+ return UserTokens;
}
static TStringStream GetKey(TEvTicketParser::TEvAuthorizeTicket* request) {
return request->Ticket;
}
- void InitTokenRecord(const TString& key, TTokenRecord& record, const TActorContext& ctx) {
- TInstant now = ctx.Now();
- record.InitTime = now;
- record.AccessTime = now;
- record.ExpireTime = GetExpireTime(now);
-
- if (record.Error) {
- return;
- }
-
- if (record.TokenType == ETokenType::Unknown || record.TokenType == ETokenType::Builtin) {
- if(record.Ticket.EndsWith("@" BUILTIN_ACL_DOMAIN)) {
- record.TokenType = ETokenType::Builtin;
- SetToken(key, record, new NACLib::TUserToken({
- .OriginalUserToken = record.Ticket,
- .UserSID = record.Ticket,
- .AuthType = record.GetAuthType()
- }), ctx);
- CounterTicketsBuiltin->Inc();
- return;
- }
-
- if(record.Ticket.EndsWith("@" BUILTIN_ERROR_DOMAIN)) {
- record.TokenType = ETokenType::Builtin;
- SetError(key, record, {"Builtin error simulation"}, ctx);
- CounterTicketsBuiltin->Inc();
- return;
- }
- }
-
- if (UseLoginProvider && (record.TokenType == ETokenType::Unknown || record.TokenType == ETokenType::Login)) {
- TString database = Config.GetDomainLoginOnly() ? DomainName : record.Database;
- auto itLoginProvider = LoginProviders.find(database);
- if (itLoginProvider != LoginProviders.end()) {
- NLogin::TLoginProvider& loginProvider(itLoginProvider->second);
- auto response = loginProvider.ValidateToken({.Token = record.Ticket});
- if (response.Error) {
- if (!response.TokenUnrecognized || record.TokenType != ETokenType::Unknown) {
- record.TokenType = ETokenType::Login;
- TEvTicketParser::TError error;
- error.Message = response.Error;
- error.Retryable = response.ErrorRetryable;
- SetError(key, record, error, ctx);
- CounterTicketsLogin->Inc();
- return;
- }
- } else {
- record.TokenType = ETokenType::Login;
- TVector<NACLib::TSID> groups;
- if (response.Groups.has_value()) {
- const std::vector<TString>& tokenGroups = response.Groups.value();
- groups.assign(tokenGroups.begin(), tokenGroups.end());
- } else {
- const std::vector<TString> providerGroups = loginProvider.GetGroupsMembership(response.User);
- groups.assign(providerGroups.begin(), providerGroups.end());
- }
- record.ExpireTime = ToInstant(response.ExpiresAt);
- SetToken(key, record, new NACLib::TUserToken({
- .OriginalUserToken = record.Ticket,
- .UserSID = response.User,
- .GroupSIDs = groups,
- .AuthType = record.GetAuthType()
- }), ctx);
- CounterTicketsLogin->Inc();
- return;
- }
- } else {
- if (record.TokenType == ETokenType::Login) {
- TEvTicketParser::TError error;
- error.Message = "Login state is not available yet";
- error.Retryable = false;
- SetError(key, record, error, ctx);
- CounterTicketsLogin->Inc();
- return;
- }
- }
- }
-
- if (record.TokenType == ETokenType::Unknown && record.ResponsesLeft == 0) {
- record.Error.Message = "Could not find correct token validator";
- record.Error.Retryable = false;
- }
- }
-
void SetToken(const TString& key, TTokenRecord& record, TIntrusivePtr<NACLib::TUserToken> token, const TActorContext& ctx) {
TInstant now = ctx.Now();
record.Error.clear();
@@ -199,24 +107,6 @@ class TTicketParser : public TTicketParserImpl<TTicketParser> {
return ticket.empty();
}
- void SetTokenType(TTokenRecord& record, TStringBuf&, const TStringBuf ticketType) {
- if (ticketType) {
- record.TokenType = ParseTokenType(ticketType);
- switch (record.TokenType) {
- case ETokenType::Unsupported:
- record.Error.Message = "Token is not supported";
- record.Error.Retryable = false;
- break;
- case ETokenType::Unknown:
- record.Error.Message = "Unknown token";
- record.Error.Retryable = false;
- break;
- default:
- break;
- }
- }
- }
-
void Handle(TEvTicketParser::TEvRefreshTicket::TPtr& ev, const TActorContext&) {
UserTokens.erase(ev->Get()->Ticket);
}
diff --git a/ydb/core/security/ticket_parser_impl.h b/ydb/core/security/ticket_parser_impl.h
index c9e0f384a1b..576a40021d6 100644
--- a/ydb/core/security/ticket_parser_impl.h
+++ b/ydb/core/security/ticket_parser_impl.h
@@ -40,7 +40,9 @@ protected:
TTokenRecordBase(const TTokenRecordBase&) = delete;
TTokenRecordBase& operator =(const TTokenRecordBase&) = delete;
+
TString Ticket;
+ typename TDerived::ETokenType TokenType = TDerived::ETokenType::Unknown;
TString Subject; // login
TEvTicketParser::TError Error;
TIntrusivePtr<NACLib::TUserToken> Token;
@@ -123,6 +125,101 @@ protected:
}
}
+ template <typename TTokenRecord>
+ void SetTime(TTokenRecord& record, TInstant now) {
+ record.InitTime = now;
+ record.AccessTime = now;
+ record.ExpireTime = GetExpireTime(now);
+ }
+
+ template <typename TTokenRecord>
+ void InitTokenRecord(const TString&, TTokenRecord& record, const TActorContext&, TInstant) {
+ if (record.TokenType == TDerived::ETokenType::Unknown && record.ResponsesLeft == 0) {
+ record.Error.Message = "Could not find correct token validator";
+ record.Error.Retryable = false;
+ }
+ }
+
+ template <typename TTokenRecord>
+ void InitTokenRecord(const TString& key, TTokenRecord& record, const TActorContext& ctx) {
+ TInstant now = ctx.Now();
+ GetDerived()->SetTime(record, now);
+
+ if (record.Error) {
+ return;
+ }
+
+ if (record.TokenType == TDerived::ETokenType::Unknown || record.TokenType == TDerived::ETokenType::Builtin) {
+ if(record.Ticket.EndsWith("@" BUILTIN_ACL_DOMAIN)) {
+ record.TokenType = TDerived::ETokenType::Builtin;
+ GetDerived()->SetToken(key, record, new NACLib::TUserToken({
+ .OriginalUserToken = record.Ticket,
+ .UserSID = record.Ticket,
+ .AuthType = record.GetAuthType()
+ }), ctx);
+ CounterTicketsBuiltin->Inc();
+ return;
+ }
+
+ if(record.Ticket.EndsWith("@" BUILTIN_ERROR_DOMAIN)) {
+ record.TokenType = TDerived::ETokenType::Builtin;
+ GetDerived()->SetError(key, record, {"Builtin error simulation"}, ctx);
+ CounterTicketsBuiltin->Inc();
+ return;
+ }
+ }
+
+ if (UseLoginProvider && (record.TokenType == TDerived::ETokenType::Unknown || record.TokenType == TDerived::ETokenType::Login)) {
+ TString database = Config.GetDomainLoginOnly() ? DomainName : record.Database;
+ auto itLoginProvider = LoginProviders.find(database);
+ if (itLoginProvider != LoginProviders.end()) {
+ NLogin::TLoginProvider& loginProvider(itLoginProvider->second);
+ auto response = loginProvider.ValidateToken({.Token = record.Ticket});
+ if (response.Error) {
+ if (!response.TokenUnrecognized || record.TokenType != TDerived::ETokenType::Unknown) {
+ record.TokenType = TDerived::ETokenType::Login;
+ TEvTicketParser::TError error;
+ error.Message = response.Error;
+ error.Retryable = response.ErrorRetryable;
+ GetDerived()->SetError(key, record, error, ctx);
+ CounterTicketsLogin->Inc();
+ return;
+ }
+ } else {
+ record.TokenType = TDerived::ETokenType::Login;
+ TVector<NACLib::TSID> groups;
+ if (response.Groups.has_value()) {
+ const std::vector<TString>& tokenGroups = response.Groups.value();
+ groups.assign(tokenGroups.begin(), tokenGroups.end());
+ } else {
+ const std::vector<TString> providerGroups = loginProvider.GetGroupsMembership(response.User);
+ groups.assign(providerGroups.begin(), providerGroups.end());
+ }
+ record.ExpireTime = ToInstant(response.ExpiresAt);
+ GetDerived()->SetToken(key, record, new NACLib::TUserToken({
+ .OriginalUserToken = record.Ticket,
+ .UserSID = response.User,
+ .GroupSIDs = groups,
+ .AuthType = record.GetAuthType()
+ }), ctx);
+ CounterTicketsLogin->Inc();
+ return;
+ }
+ } else {
+ if (record.TokenType == TDerived::ETokenType::Login) {
+ TEvTicketParser::TError error;
+ error.Message = "Login state is not available yet";
+ error.Retryable = false;
+ GetDerived()->SetError(key, record, error, ctx);
+ CounterTicketsLogin->Inc();
+ return;
+ }
+ }
+ }
+
+ GetDerived()->InitTokenRecord(key, record, ctx, now);
+ }
+
void Respond(TTokenRecordBase& record, const TActorContext& ctx) {
if (record.IsTokenReady()) {
for (const auto& request : record.AuthorizeRequests) {
@@ -151,6 +248,25 @@ protected:
record.Database = std::move(ev->Get()->Database);
}
+ template <typename TTokenRecord>
+ void SetTokenType(TTokenRecord& record, TStringBuf&, const TStringBuf ticketType) {
+ if (ticketType) {
+ record.TokenType = GetDerived()->ParseTokenType(ticketType);
+ switch (record.TokenType) {
+ case TDerived::ETokenType::Unsupported:
+ record.Error.Message = "Token is not supported";
+ record.Error.Retryable = false;
+ break;
+ case TDerived::ETokenType::Unknown:
+ record.Error.Message = "Unknown token";
+ record.Error.Retryable = false;
+ break;
+ default:
+ break;
+ }
+ }
+ }
+
void Handle(TEvTicketParser::TEvAuthorizeTicket::TPtr& ev, const TActorContext& ctx) {
TStringBuf ticket;
TStringBuf ticketType;
@@ -169,9 +285,10 @@ protected:
ctx.Send(sender, new TEvTicketParser::TEvAuthorizeTicketResult(ev->Get()->Ticket, error), 0, cookie);
return;
}
- auto recordPtr = GetDerived()->GetUserToken(key);
- if (recordPtr) {
- auto& record = *recordPtr;
+ auto& UserTokens = GetDerived()->GetUserTokens();
+ auto it = UserTokens.find(key);
+ if (it != UserTokens.end()) {
+ auto& record = it->second;
// we know about token
if (record.IsTokenReady()) {
// token already have built
@@ -188,16 +305,16 @@ protected:
CounterTicketsCacheHit->Inc();
return;
} else {
- recordPtr = GetDerived()->InsertUserToken(key, ticket);
+ it = UserTokens.emplace(key, ticket).first;
CounterTicketsCacheMiss->Inc();
}
- auto& record = *recordPtr;
+ auto& record = it->second;
GetDerived()->TokenRecordSetup(record, ev);
GetDerived()->SetTokenType(record, ticket, ticketType);
- GetDerived()->InitTokenRecord(key, record, ctx);
+ InitTokenRecord(key, record, ctx);
if (record.Error) {
LOG_ERROR_S(ctx, NKikimrServices::TICKET_PARSER, "Ticket " << MaskTicket(ticket) << ": " << record.Error);
ctx.Send(sender, new TEvTicketParser::TEvAuthorizeTicketResult(ev->Get()->Ticket, record.Error), 0, cookie);