diff options
author | molotkov-and <molotkov-and@yandex-team.com> | 2022-09-28 14:37:25 +0300 |
---|---|---|
committer | molotkov-and <molotkov-and@yandex-team.com> | 2022-09-28 14:37:25 +0300 |
commit | e49a6d731192528f21228bf8bbbdeb78b35a5dca (patch) | |
tree | b0520a175ebdfca0726111a2787cee1239c66834 | |
parent | b4a5e309b31f8c90e165d2388586544a3926ca93 (diff) | |
download | ydb-e49a6d731192528f21228bf8bbbdeb78b35a5dca.tar.gz |
Extract InitTokenRecord to base class
-rw-r--r-- | ydb/core/security/ticket_parser.cpp | 114 | ||||
-rw-r--r-- | ydb/core/security/ticket_parser_impl.h | 129 |
2 files changed, 125 insertions, 118 deletions
diff --git a/ydb/core/security/ticket_parser.cpp b/ydb/core/security/ticket_parser.cpp index b4439d1e124..937036e1cea 100644 --- a/ydb/core/security/ticket_parser.cpp +++ b/ydb/core/security/ticket_parser.cpp @@ -38,8 +38,6 @@ class TTicketParser : public TTicketParserImpl<TTicketParser> { struct TTokenRecord : TBase::TTokenRecordBase { using TBase::TTokenRecordBase::TTokenRecordBase; - ETokenType TokenType = ETokenType::Unknown; - TString GetSubject() const { return Subject; } @@ -60,104 +58,14 @@ class TTicketParser : public TTicketParserImpl<TTicketParser> { THashMap<TString, TTokenRecord> UserTokens; - TTokenRecord* GetUserToken(const TString& key) { - auto it = UserTokens.find(key); - return it != UserTokens.end() ? &it->second : nullptr; - } - - TTokenRecord* InsertUserToken(const TString& key, const TStringBuf ticket) { - auto it = UserTokens.emplace(key, ticket).first; - return &it->second; + THashMap<TString, TTokenRecord>& GetUserTokens() { + return UserTokens; } static TStringStream GetKey(TEvTicketParser::TEvAuthorizeTicket* request) { return request->Ticket; } - void InitTokenRecord(const TString& key, TTokenRecord& record, const TActorContext& ctx) { - TInstant now = ctx.Now(); - record.InitTime = now; - record.AccessTime = now; - record.ExpireTime = GetExpireTime(now); - - if (record.Error) { - return; - } - - if (record.TokenType == ETokenType::Unknown || record.TokenType == ETokenType::Builtin) { - if(record.Ticket.EndsWith("@" BUILTIN_ACL_DOMAIN)) { - record.TokenType = ETokenType::Builtin; - SetToken(key, record, new NACLib::TUserToken({ - .OriginalUserToken = record.Ticket, - .UserSID = record.Ticket, - .AuthType = record.GetAuthType() - }), ctx); - CounterTicketsBuiltin->Inc(); - return; - } - - if(record.Ticket.EndsWith("@" BUILTIN_ERROR_DOMAIN)) { - record.TokenType = ETokenType::Builtin; - SetError(key, record, {"Builtin error simulation"}, ctx); - CounterTicketsBuiltin->Inc(); - return; - } - } - - if (UseLoginProvider && (record.TokenType == ETokenType::Unknown || record.TokenType == ETokenType::Login)) { - TString database = Config.GetDomainLoginOnly() ? DomainName : record.Database; - auto itLoginProvider = LoginProviders.find(database); - if (itLoginProvider != LoginProviders.end()) { - NLogin::TLoginProvider& loginProvider(itLoginProvider->second); - auto response = loginProvider.ValidateToken({.Token = record.Ticket}); - if (response.Error) { - if (!response.TokenUnrecognized || record.TokenType != ETokenType::Unknown) { - record.TokenType = ETokenType::Login; - TEvTicketParser::TError error; - error.Message = response.Error; - error.Retryable = response.ErrorRetryable; - SetError(key, record, error, ctx); - CounterTicketsLogin->Inc(); - return; - } - } else { - record.TokenType = ETokenType::Login; - TVector<NACLib::TSID> groups; - if (response.Groups.has_value()) { - const std::vector<TString>& tokenGroups = response.Groups.value(); - groups.assign(tokenGroups.begin(), tokenGroups.end()); - } else { - const std::vector<TString> providerGroups = loginProvider.GetGroupsMembership(response.User); - groups.assign(providerGroups.begin(), providerGroups.end()); - } - record.ExpireTime = ToInstant(response.ExpiresAt); - SetToken(key, record, new NACLib::TUserToken({ - .OriginalUserToken = record.Ticket, - .UserSID = response.User, - .GroupSIDs = groups, - .AuthType = record.GetAuthType() - }), ctx); - CounterTicketsLogin->Inc(); - return; - } - } else { - if (record.TokenType == ETokenType::Login) { - TEvTicketParser::TError error; - error.Message = "Login state is not available yet"; - error.Retryable = false; - SetError(key, record, error, ctx); - CounterTicketsLogin->Inc(); - return; - } - } - } - - if (record.TokenType == ETokenType::Unknown && record.ResponsesLeft == 0) { - record.Error.Message = "Could not find correct token validator"; - record.Error.Retryable = false; - } - } - void SetToken(const TString& key, TTokenRecord& record, TIntrusivePtr<NACLib::TUserToken> token, const TActorContext& ctx) { TInstant now = ctx.Now(); record.Error.clear(); @@ -199,24 +107,6 @@ class TTicketParser : public TTicketParserImpl<TTicketParser> { return ticket.empty(); } - void SetTokenType(TTokenRecord& record, TStringBuf&, const TStringBuf ticketType) { - if (ticketType) { - record.TokenType = ParseTokenType(ticketType); - switch (record.TokenType) { - case ETokenType::Unsupported: - record.Error.Message = "Token is not supported"; - record.Error.Retryable = false; - break; - case ETokenType::Unknown: - record.Error.Message = "Unknown token"; - record.Error.Retryable = false; - break; - default: - break; - } - } - } - void Handle(TEvTicketParser::TEvRefreshTicket::TPtr& ev, const TActorContext&) { UserTokens.erase(ev->Get()->Ticket); } diff --git a/ydb/core/security/ticket_parser_impl.h b/ydb/core/security/ticket_parser_impl.h index c9e0f384a1b..576a40021d6 100644 --- a/ydb/core/security/ticket_parser_impl.h +++ b/ydb/core/security/ticket_parser_impl.h @@ -40,7 +40,9 @@ protected: TTokenRecordBase(const TTokenRecordBase&) = delete; TTokenRecordBase& operator =(const TTokenRecordBase&) = delete; + TString Ticket; + typename TDerived::ETokenType TokenType = TDerived::ETokenType::Unknown; TString Subject; // login TEvTicketParser::TError Error; TIntrusivePtr<NACLib::TUserToken> Token; @@ -123,6 +125,101 @@ protected: } } + template <typename TTokenRecord> + void SetTime(TTokenRecord& record, TInstant now) { + record.InitTime = now; + record.AccessTime = now; + record.ExpireTime = GetExpireTime(now); + } + + template <typename TTokenRecord> + void InitTokenRecord(const TString&, TTokenRecord& record, const TActorContext&, TInstant) { + if (record.TokenType == TDerived::ETokenType::Unknown && record.ResponsesLeft == 0) { + record.Error.Message = "Could not find correct token validator"; + record.Error.Retryable = false; + } + } + + template <typename TTokenRecord> + void InitTokenRecord(const TString& key, TTokenRecord& record, const TActorContext& ctx) { + TInstant now = ctx.Now(); + GetDerived()->SetTime(record, now); + + if (record.Error) { + return; + } + + if (record.TokenType == TDerived::ETokenType::Unknown || record.TokenType == TDerived::ETokenType::Builtin) { + if(record.Ticket.EndsWith("@" BUILTIN_ACL_DOMAIN)) { + record.TokenType = TDerived::ETokenType::Builtin; + GetDerived()->SetToken(key, record, new NACLib::TUserToken({ + .OriginalUserToken = record.Ticket, + .UserSID = record.Ticket, + .AuthType = record.GetAuthType() + }), ctx); + CounterTicketsBuiltin->Inc(); + return; + } + + if(record.Ticket.EndsWith("@" BUILTIN_ERROR_DOMAIN)) { + record.TokenType = TDerived::ETokenType::Builtin; + GetDerived()->SetError(key, record, {"Builtin error simulation"}, ctx); + CounterTicketsBuiltin->Inc(); + return; + } + } + + if (UseLoginProvider && (record.TokenType == TDerived::ETokenType::Unknown || record.TokenType == TDerived::ETokenType::Login)) { + TString database = Config.GetDomainLoginOnly() ? DomainName : record.Database; + auto itLoginProvider = LoginProviders.find(database); + if (itLoginProvider != LoginProviders.end()) { + NLogin::TLoginProvider& loginProvider(itLoginProvider->second); + auto response = loginProvider.ValidateToken({.Token = record.Ticket}); + if (response.Error) { + if (!response.TokenUnrecognized || record.TokenType != TDerived::ETokenType::Unknown) { + record.TokenType = TDerived::ETokenType::Login; + TEvTicketParser::TError error; + error.Message = response.Error; + error.Retryable = response.ErrorRetryable; + GetDerived()->SetError(key, record, error, ctx); + CounterTicketsLogin->Inc(); + return; + } + } else { + record.TokenType = TDerived::ETokenType::Login; + TVector<NACLib::TSID> groups; + if (response.Groups.has_value()) { + const std::vector<TString>& tokenGroups = response.Groups.value(); + groups.assign(tokenGroups.begin(), tokenGroups.end()); + } else { + const std::vector<TString> providerGroups = loginProvider.GetGroupsMembership(response.User); + groups.assign(providerGroups.begin(), providerGroups.end()); + } + record.ExpireTime = ToInstant(response.ExpiresAt); + GetDerived()->SetToken(key, record, new NACLib::TUserToken({ + .OriginalUserToken = record.Ticket, + .UserSID = response.User, + .GroupSIDs = groups, + .AuthType = record.GetAuthType() + }), ctx); + CounterTicketsLogin->Inc(); + return; + } + } else { + if (record.TokenType == TDerived::ETokenType::Login) { + TEvTicketParser::TError error; + error.Message = "Login state is not available yet"; + error.Retryable = false; + GetDerived()->SetError(key, record, error, ctx); + CounterTicketsLogin->Inc(); + return; + } + } + } + + GetDerived()->InitTokenRecord(key, record, ctx, now); + } + void Respond(TTokenRecordBase& record, const TActorContext& ctx) { if (record.IsTokenReady()) { for (const auto& request : record.AuthorizeRequests) { @@ -151,6 +248,25 @@ protected: record.Database = std::move(ev->Get()->Database); } + template <typename TTokenRecord> + void SetTokenType(TTokenRecord& record, TStringBuf&, const TStringBuf ticketType) { + if (ticketType) { + record.TokenType = GetDerived()->ParseTokenType(ticketType); + switch (record.TokenType) { + case TDerived::ETokenType::Unsupported: + record.Error.Message = "Token is not supported"; + record.Error.Retryable = false; + break; + case TDerived::ETokenType::Unknown: + record.Error.Message = "Unknown token"; + record.Error.Retryable = false; + break; + default: + break; + } + } + } + void Handle(TEvTicketParser::TEvAuthorizeTicket::TPtr& ev, const TActorContext& ctx) { TStringBuf ticket; TStringBuf ticketType; @@ -169,9 +285,10 @@ protected: ctx.Send(sender, new TEvTicketParser::TEvAuthorizeTicketResult(ev->Get()->Ticket, error), 0, cookie); return; } - auto recordPtr = GetDerived()->GetUserToken(key); - if (recordPtr) { - auto& record = *recordPtr; + auto& UserTokens = GetDerived()->GetUserTokens(); + auto it = UserTokens.find(key); + if (it != UserTokens.end()) { + auto& record = it->second; // we know about token if (record.IsTokenReady()) { // token already have built @@ -188,16 +305,16 @@ protected: CounterTicketsCacheHit->Inc(); return; } else { - recordPtr = GetDerived()->InsertUserToken(key, ticket); + it = UserTokens.emplace(key, ticket).first; CounterTicketsCacheMiss->Inc(); } - auto& record = *recordPtr; + auto& record = it->second; GetDerived()->TokenRecordSetup(record, ev); GetDerived()->SetTokenType(record, ticket, ticketType); - GetDerived()->InitTokenRecord(key, record, ctx); + InitTokenRecord(key, record, ctx); if (record.Error) { LOG_ERROR_S(ctx, NKikimrServices::TICKET_PARSER, "Ticket " << MaskTicket(ticket) << ": " << record.Error); ctx.Send(sender, new TEvTicketParser::TEvAuthorizeTicketResult(ev->Get()->Ticket, record.Error), 0, cookie); |