diff options
author | hcpp <hcpp@ydb.tech> | 2023-11-08 12:09:41 +0300 |
---|---|---|
committer | hcpp <hcpp@ydb.tech> | 2023-11-08 12:56:14 +0300 |
commit | a361f5b98b98b44ea510d274f6769164640dd5e1 (patch) | |
tree | c47c80962c6e2e7b06798238752fd3da0191a3f6 /library | |
parent | 9478806fde1f4d40bd5a45e7cbe77237dab613e9 (diff) | |
download | ydb-a361f5b98b98b44ea510d274f6769164640dd5e1.tar.gz |
metrics have been added
Diffstat (limited to 'library')
400 files changed, 31360 insertions, 0 deletions
diff --git a/library/cpp/http/cookies/cookies.cpp b/library/cpp/http/cookies/cookies.cpp new file mode 100644 index 0000000000..12b66c7f9d --- /dev/null +++ b/library/cpp/http/cookies/cookies.cpp @@ -0,0 +1,33 @@ +#include "cookies.h" + +#include <library/cpp/string_utils/scan/scan.h> +#include <util/string/strip.h> +#include <util/string/builder.h> + +namespace { + struct TCookiesScanner { + THttpCookies* C; + + inline void operator()(const TStringBuf& key, const TStringBuf& val) { + C->Add(StripString(key), StripString(val)); + } + }; +} + +void THttpCookies::Scan(const TStringBuf& s) { + Clear(); + TCookiesScanner scan = {this}; + ScanKeyValue<true, ';', '='>(s, scan); +} + +/*** https://datatracker.ietf.org/doc/html/rfc6265#section-5.4 ***/ +TString THttpCookies::ToString() const { + TStringBuilder result; + for (const auto& [key, value] : *this) { + if (!result.empty()) { + result << "; "; + } + result << key << "=" << value; + } + return result; +} diff --git a/library/cpp/http/cookies/cookies.h b/library/cpp/http/cookies/cookies.h new file mode 100644 index 0000000000..d7a0030c8b --- /dev/null +++ b/library/cpp/http/cookies/cookies.h @@ -0,0 +1,17 @@ +#pragma once + +#include "lctable.h" + +class THttpCookies: public TLowerCaseTable<TStringBuf> { +public: + inline THttpCookies(const TStringBuf& cookieString) { + Scan(cookieString); + } + + inline THttpCookies() noexcept { + } + + void Scan(const TStringBuf& cookieString); + + TString ToString() const; +}; diff --git a/library/cpp/http/cookies/lctable.h b/library/cpp/http/cookies/lctable.h new file mode 100644 index 0000000000..09c88eafb8 --- /dev/null +++ b/library/cpp/http/cookies/lctable.h @@ -0,0 +1,86 @@ +#pragma once + +#include <library/cpp/digest/lower_case/lchash.h> + +#include <util/generic/hash_multi_map.h> +#include <util/generic/strbuf.h> +#include <util/generic/algorithm.h> +#include <util/generic/singleton.h> + +struct TStrBufHash { + inline size_t operator()(const TStringBuf& s) const noexcept { + return FnvCaseLess<size_t>(s); + } +}; + +struct TStrBufEqualToCaseLess { + inline bool operator()(const TStringBuf& c1, const TStringBuf& c2) const noexcept { + typedef TLowerCaseIterator<const TStringBuf::TChar> TIter; + + return (c1.size() == c2.size()) && std::equal(TIter(c1.begin()), TIter(c1.end()), TIter(c2.begin())); + } +}; + +template <class T> +class TLowerCaseTable: private THashMultiMap<TStringBuf, T, TStrBufHash, TStrBufEqualToCaseLess> { + typedef THashMultiMap<TStringBuf, T, TStrBufHash, TStrBufEqualToCaseLess> TBase; + +public: + typedef typename TBase::const_iterator const_iterator; + typedef std::pair<const_iterator, const_iterator> TConstIteratorPair; + + using TBase::TBase; + using TBase::begin; + using TBase::end; + + inline TConstIteratorPair EqualRange(const TStringBuf& name) const { + return TBase::equal_range(name); + } + + inline const T& Get(const TStringBuf& name, size_t numOfValue = 0) const { + TConstIteratorPair range = EqualRange(name); + + if (range.first == TBase::end()) + return Default<T>(); + + if (numOfValue == 0) + return range.first->second; + + const_iterator next = range.first; + for (size_t c = 0; c < numOfValue; ++c) { + ++next; + if (next == range.second) + return Default<T>(); + } + + return next->second; + } + + inline bool Has(const TStringBuf& name) const { + return TBase::find(name) != TBase::end(); + } + + size_t NumOfValues(const TStringBuf& name) const { + return TBase::count(name); + } + + inline size_t Size() const noexcept { + return TBase::size(); + } + + inline bool Empty() const noexcept { + return TBase::empty(); + } + + inline void Add(const TStringBuf& key, const T& val) { + TBase::insert(typename TBase::value_type(key, val)); + } + + inline void Clear() noexcept { + TBase::clear(); + } + + inline size_t Erase(const TStringBuf& key) { + return TBase::erase(key); + } +}; diff --git a/library/cpp/http/cookies/ya.make b/library/cpp/http/cookies/ya.make new file mode 100644 index 0000000000..70c1e8f250 --- /dev/null +++ b/library/cpp/http/cookies/ya.make @@ -0,0 +1,14 @@ +LIBRARY() + +SRCS( + cookies.cpp +) + +PEERDIR( + library/cpp/digest/lower_case + library/cpp/string_utils/scan +) + +END() + +RECURSE_FOR_TESTS(ut) diff --git a/library/cpp/string_utils/secret_string/secret_string.cpp b/library/cpp/string_utils/secret_string/secret_string.cpp new file mode 100644 index 0000000000..3b68d3cd27 --- /dev/null +++ b/library/cpp/string_utils/secret_string/secret_string.cpp @@ -0,0 +1,68 @@ +#include "secret_string.h" + +#include <util/system/madvise.h> + +namespace NSecretString { + TSecretString::TSecretString(TStringBuf value) { + Init(value); + } + + TSecretString::~TSecretString() { + try { + Clear(); + } catch (...) { + } + } + + TSecretString& TSecretString::operator=(const TSecretString& o) { + if (&o == this) { + return *this; + } + + Init(o.Value_); + + return *this; + } + + /** + * It is not honest "move". Actually it is copy-assignment with cleaning of other instance. + * This way allowes to avoid side effects of string optimizations: + * Copy-On-Write or Short-String-Optimization + */ + TSecretString& TSecretString::operator=(TSecretString&& o) { + if (&o == this) { + return *this; + } + + Init(o.Value_); + o.Clear(); + + return *this; + } + + TSecretString& TSecretString::operator=(const TStringBuf o) { + Init(o); + + return *this; + } + + void TSecretString::Init(TStringBuf value) { + Clear(); + if (value.empty()) { + return; + } + + Value_ = value; + MadviseExcludeFromCoreDump(Value_); + } + + void TSecretString::Clear() { + if (Value_.empty()) { + return; + } + + SecureZero((void*)Value_.data(), Value_.size()); + MadviseIncludeIntoCoreDump(Value_); + Value_.clear(); + } +} diff --git a/library/cpp/string_utils/secret_string/secret_string.h b/library/cpp/string_utils/secret_string/secret_string.h new file mode 100644 index 0000000000..fdb9f6a85c --- /dev/null +++ b/library/cpp/string_utils/secret_string/secret_string.h @@ -0,0 +1,74 @@ +#pragma once + +#include <library/cpp/string_utils/ztstrbuf/ztstrbuf.h> + +#include <util/generic/string.h> + +namespace NSecretString { + /** + * TSecretString allowes to store some long lived secrets in "secure" storage in memory. + * Common usage: + * 1) read secret value from disk/env/etc + * 2) put it into TSecretString + * 3) destory secret copy from 1) + * + * Useful scenerios for TSecretString: + * - in memory only tasks: using key to create crypto signature; + * - rare network cases: db password on connection or OAuth token in background tasks. + * These cases disclosure the secret + * because of sending it over network with some I/O frameworks. + * Usually such frameworks copy input params to provide network protocol: gRPC, for example. + * + * Supported features: + * 1. Exclude secret from core dump. + * madvise(MADV_DONTDUMP) in ctor excludes full memory page from core dump. + * madvise(MADV_DODUMP) in dtor reverts previous action. + * 2. Zero memory before free. + * + * Code dump looks like this: +(gdb) print s +$1 = (const TSecretString &) @0x7fff23c4c560: { + Value_ = {<TStringBase<TBasicString<char, std::__y1::char_traits<char> >, char, std::__y1::char_traits<char> >> = { + static npos = <optimized out>}, Data_ = 0x107c001d8 <error: Cannot access memory at address 0x107c001d8>}} + */ + + class TSecretString { + public: + TSecretString() = default; + TSecretString(TStringBuf value); + ~TSecretString(); + + TSecretString(const TSecretString& o) + : TSecretString(o.Value()) + { + } + + TSecretString(TSecretString&& o) + : TSecretString(o.Value()) + { + o.Clear(); + } + + TSecretString& operator=(const TSecretString& o); + TSecretString& operator=(TSecretString&& o); + + TSecretString& operator=(const TStringBuf o); + + operator TZtStringBuf() const { + return Value(); + } + + // Provides zero terminated string + TZtStringBuf Value() const { + return TZtStringBuf(Value_); + } + + private: + // TStringBuf breaks Copy-On-Write to provide correct copy-ctor and copy-assignment + void Init(TStringBuf value); + void Clear(); + + private: + TString Value_; + }; +} diff --git a/library/cpp/string_utils/secret_string/ya.make b/library/cpp/string_utils/secret_string/ya.make new file mode 100644 index 0000000000..c1d43f7a1d --- /dev/null +++ b/library/cpp/string_utils/secret_string/ya.make @@ -0,0 +1,13 @@ +LIBRARY() + +SRCS( + secret_string.cpp +) + +PEERDIR( + library/cpp/string_utils/ztstrbuf +) + +END() + +RECURSE_FOR_TESTS(ut) diff --git a/library/cpp/string_utils/tskv_format/builder.cpp b/library/cpp/string_utils/tskv_format/builder.cpp new file mode 100644 index 0000000000..ede9074022 --- /dev/null +++ b/library/cpp/string_utils/tskv_format/builder.cpp @@ -0,0 +1 @@ +#include "builder.h" diff --git a/library/cpp/string_utils/tskv_format/builder.h b/library/cpp/string_utils/tskv_format/builder.h new file mode 100644 index 0000000000..40689ddc85 --- /dev/null +++ b/library/cpp/string_utils/tskv_format/builder.h @@ -0,0 +1,67 @@ +#pragma once + +#include "escape.h" + +#include <util/stream/str.h> + +namespace NTskvFormat { + class TLogBuilder { + private: + TStringStream Out; + + public: + TLogBuilder() = default; + + TLogBuilder(TStringBuf logType, ui32 unixtime) { + Begin(logType, unixtime); + } + + TLogBuilder(TStringBuf logType) { + Begin(logType); + } + + TLogBuilder& Add(TStringBuf fieldName, TStringBuf fieldValue) { + if (!Out.Empty()) { + Out << '\t'; + } + Escape(fieldName, Out.Str()); + Out << '='; + Escape(fieldValue, Out.Str()); + + return *this; + } + + TLogBuilder& AddUnescaped(TStringBuf fieldName, TStringBuf fieldValue) { + if (!Out.Empty()) { + Out << '\t'; + } + Out << fieldName << '=' << fieldValue; + return *this; + } + + TLogBuilder& Begin(TStringBuf logType, ui32 unixtime) { + Out << "tskv\ttskv_format=" << logType << "\tunixtime=" << unixtime; + return *this; + } + + TLogBuilder& Begin(TStringBuf logType) { + Out << "tskv\ttskv_format=" << logType; + return *this; + } + + TLogBuilder& End() { + Out << '\n'; + return *this; + } + + TLogBuilder& Clear() { + Out.Clear(); + return *this; + } + + TString& Str() { + return Out.Str(); + } + }; + +} diff --git a/library/cpp/string_utils/tskv_format/escape.cpp b/library/cpp/string_utils/tskv_format/escape.cpp new file mode 100644 index 0000000000..3dc78bec8c --- /dev/null +++ b/library/cpp/string_utils/tskv_format/escape.cpp @@ -0,0 +1,112 @@ +#include <util/generic/yexception.h> +#include "escape.h" + +namespace NTskvFormat { + namespace { + const TStringBuf ESCAPE_CHARS("\t\n\r\\\0=\"", 7); + + TString& EscapeImpl(const char* src, size_t len, TString& dst) { + TStringBuf srcStr(src, len); + size_t noEscapeStart = 0; + + while (noEscapeStart < len) { + size_t noEscapeEnd = srcStr.find_first_of(ESCAPE_CHARS, noEscapeStart); + + if (noEscapeEnd == TStringBuf::npos) { + dst.append(src + noEscapeStart, len - noEscapeStart); + break; + } + + dst.append(src + noEscapeStart, noEscapeEnd - noEscapeStart); + + switch (src[noEscapeEnd]) { + case '\t': + dst.append(TStringBuf("\\t")); + break; + case '\n': + dst.append(TStringBuf("\\n")); + break; + case '\r': + dst.append(TStringBuf("\\r")); + break; + case '\0': + dst.append(TStringBuf("\\0")); + break; + case '\\': + dst.append(TStringBuf("\\\\")); + break; + case '=': + dst.append(TStringBuf("\\=")); + break; + case '"': + dst.append(TStringBuf("\\\"")); + break; + } + + noEscapeStart = noEscapeEnd + 1; + } + + return dst; + } + + TString& UnescapeImpl(const char* src, const size_t len, TString& dst) { + TStringBuf srcStr(src, len); + size_t noEscapeStart = 0; + + while (noEscapeStart < len) { + size_t noEscapeEnd = srcStr.find('\\', noEscapeStart); + + if (noEscapeEnd == TStringBuf::npos) { + dst.append(src + noEscapeStart, len - noEscapeStart); + break; + } + + dst.append(src + noEscapeStart, noEscapeEnd - noEscapeStart); + + if (noEscapeEnd + 1 >= len) { + throw yexception() << "expected (t|n|r|0|\\|=|\"|) after \\. Got end of line."; + } + + switch (src[noEscapeEnd + 1]) { + case 't': + dst.append('\t'); + break; + case 'n': + dst.append('\n'); + break; + case 'r': + dst.append('\r'); + break; + case '0': + dst.append('\0'); + break; + case '\\': + dst.append('\\'); + break; + case '=': + dst.append('='); + break; + case '"': + dst.append('"'); + break; + default: + throw yexception() << "unexpected symbol '" << src[noEscapeEnd + 1] << "' after \\"; + } + + noEscapeStart = noEscapeEnd + 2; + } + + return dst; + } + + } + + TString& Escape(const TStringBuf& src, TString& dst) { + return EscapeImpl(src.data(), src.size(), dst); + } + + TString& Unescape(const TStringBuf& src, TString& dst) { + return UnescapeImpl(src.data(), src.size(), dst); + } + +} diff --git a/library/cpp/string_utils/tskv_format/escape.h b/library/cpp/string_utils/tskv_format/escape.h new file mode 100644 index 0000000000..2e3dd02c98 --- /dev/null +++ b/library/cpp/string_utils/tskv_format/escape.h @@ -0,0 +1,10 @@ +#pragma once + +#include <util/generic/strbuf.h> +#include <util/generic/string.h> + +namespace NTskvFormat { + TString& Escape(const TStringBuf& src, TString& dst); + TString& Unescape(const TStringBuf& src, TString& dst); + +} diff --git a/library/cpp/string_utils/tskv_format/tskv_map.cpp b/library/cpp/string_utils/tskv_format/tskv_map.cpp new file mode 100644 index 0000000000..99e5f19731 --- /dev/null +++ b/library/cpp/string_utils/tskv_format/tskv_map.cpp @@ -0,0 +1,60 @@ +#include "tskv_map.h" + +namespace { + void Split(const TStringBuf& kv, TStringBuf& key, TStringBuf& value, bool& keyHasEscapes) { + size_t delimiter = 0; + keyHasEscapes = false; + for (delimiter = 0; delimiter < kv.size() && kv[delimiter] != '='; ++delimiter) { + if (kv[delimiter] == '\\') { + ++delimiter; + keyHasEscapes = true; + } + } + + if (delimiter < kv.size()) { + key = kv.Head(delimiter); + value = kv.Tail(delimiter + 1); + } else { + throw yexception() << "Incorrect tskv format"; + } + } + + TStringBuf DeserializeTokenToBuffer(const TStringBuf& token, TString& buffer) { + size_t tokenStart = buffer.size(); + NTskvFormat::Unescape(token, buffer); + return TStringBuf(buffer).Tail(tokenStart); + } + + void DeserializeTokenToString(const TStringBuf& token, TString& result, bool unescape) { + if (unescape) { + result.clear(); + NTskvFormat::Unescape(token, result); + } else { + result = token; + } + + } +} + +void NTskvFormat::NDetail::DeserializeKvToStringBufs(const TStringBuf& kv, TStringBuf& key, TStringBuf& value, TString& buffer, bool unescape) { + bool keyHasEscapes = false; + Split(kv, key, value, keyHasEscapes); + if (unescape) { + if (keyHasEscapes) { + key = DeserializeTokenToBuffer(key, buffer); + } + if (value.Contains('\\')) { + value = DeserializeTokenToBuffer(value, buffer); + } + } +} + +void NTskvFormat::NDetail::DeserializeKvToStrings(const TStringBuf& kv, TString& key, TString& value, bool unescape) { + TStringBuf keyBuf, valueBuf; + bool keyHasEscapes = false; + Split(kv, keyBuf, valueBuf, keyHasEscapes); + + Y_UNUSED(keyHasEscapes); + DeserializeTokenToString(keyBuf, key, unescape); + DeserializeTokenToString(valueBuf, value, unescape); +} diff --git a/library/cpp/string_utils/tskv_format/tskv_map.h b/library/cpp/string_utils/tskv_format/tskv_map.h new file mode 100644 index 0000000000..4f4978fcf5 --- /dev/null +++ b/library/cpp/string_utils/tskv_format/tskv_map.h @@ -0,0 +1,62 @@ +#pragma once + +#include "escape.h" +#include <util/string/cast.h> +#include <util/string/split.h> + +namespace NTskvFormat { + namespace NDetail { + void DeserializeKvToStringBufs(const TStringBuf& kv, TStringBuf& key, TStringBuf& value, TString& buffer, bool unescape); + void DeserializeKvToStrings(const TStringBuf& kv, TString& key, TString& value, bool unescape); + } + + template <typename T> + TString& SerializeMap(const T& data, TString& result) { + result.clear(); + for (const auto& kv : data) { + if (result.size() > 0) { + result.push_back('\t'); + } + Escape(ToString(kv.first), result); + result.push_back('='); + Escape(ToString(kv.second), result); + } + return result; + } + + /** + * Deserializing to TStringBuf is faster, just remember that `data' + * must not be invalidated while `result' is still in use. + */ + template <typename T> + void DeserializeMap(const TStringBuf& data, T& result, TString& buffer, bool unescape = true) { + result.clear(); + buffer.clear(); + buffer.reserve(data.size()); + TStringBuf key, value; + + StringSplitter(data.begin(), data.end()).Split('\t').Consume([&](const TStringBuf kv){ + NDetail::DeserializeKvToStringBufs(kv, key, value, buffer, unescape); + result[key] = value; + }); + + Y_ASSERT(buffer.size() <= data.size()); + } + + template <typename T> + void DeserializeMap(const TStringBuf& data, T& result, bool unescape = true) { + if constexpr(std::is_same<typename T::key_type, TStringBuf>::value || + std::is_same<typename T::mapped_type, TStringBuf>::value) + { + DeserializeMap(data, result, result.DeserializeBuffer, unescape); // we can't unescape values w/o buffer + return; + } + result.clear(); + TString key, value; + + StringSplitter(data.begin(), data.end()).Split('\t').Consume([&](const TStringBuf kv){ + NDetail::DeserializeKvToStrings(kv, key, value, unescape); + result[key] = value; + }); + } +} diff --git a/library/cpp/string_utils/tskv_format/ya.make b/library/cpp/string_utils/tskv_format/ya.make new file mode 100644 index 0000000000..1283d748b3 --- /dev/null +++ b/library/cpp/string_utils/tskv_format/ya.make @@ -0,0 +1,9 @@ +LIBRARY() + +SRCS( + builder.cpp + escape.cpp + tskv_map.cpp +) + +END() diff --git a/library/cpp/tvmauth/checked_service_ticket.h b/library/cpp/tvmauth/checked_service_ticket.h new file mode 100644 index 0000000000..71dc48b7cb --- /dev/null +++ b/library/cpp/tvmauth/checked_service_ticket.h @@ -0,0 +1,76 @@ +#pragma once + +#include "ticket_status.h" +#include "type.h" +#include "utils.h" + +#include <util/generic/ptr.h> + +namespace NTvmAuth::NInternal { + class TCanningKnife; +} + +namespace NTvmAuth { + class TCheckedServiceTicket { + public: + class TImpl; + + TCheckedServiceTicket(THolder<TImpl> impl); + TCheckedServiceTicket(TCheckedServiceTicket&& o); + ~TCheckedServiceTicket(); + + TCheckedServiceTicket& operator=(TCheckedServiceTicket&&); + + /*! + * @return True value if ticket parsed and checked successfully + */ + explicit operator bool() const; + + /*! + * @return TTvmId of request destination + */ + TTvmId GetDst() const; + + /*! + * You should check src with your ACL + * @return TvmId of request source + */ + TTvmId GetSrc() const; + + /*! + * @return Ticket check status + */ + ETicketStatus GetStatus() const; + + /*! + * DebugInfo is human readable data for debug purposes + * @return Serialized ticket + */ + TString DebugInfo() const; + + /*! + * IssuerUID is UID of developer who is debuging something, + * so he(she) issued ServiceTicket with his(her) ssh-sign: + * it is grant_type=sshkey in tvm-api. + * https://wiki.yandex-team.ru/passport/tvm2/debug/#sxoditvapizakrytoeserviceticketami + * @return uid + */ + TMaybe<TUid> GetIssuerUid() const; + + public: // for python binding + TCheckedServiceTicket() = default; + + private: + THolder<TImpl> Impl_; + friend class NInternal::TCanningKnife; + }; + + namespace NBlackboxTvmId { + const TStringBuf Prod = "222"; + const TStringBuf Test = "224"; + const TStringBuf ProdYateam = "223"; + const TStringBuf TestYateam = "225"; + const TStringBuf Stress = "226"; + const TStringBuf Mimino = "239"; + } +} diff --git a/library/cpp/tvmauth/checked_user_ticket.h b/library/cpp/tvmauth/checked_user_ticket.h new file mode 100644 index 0000000000..32256de6a7 --- /dev/null +++ b/library/cpp/tvmauth/checked_user_ticket.h @@ -0,0 +1,111 @@ +#pragma once + +#include "ticket_status.h" +#include "type.h" +#include "utils.h" + +#include <util/generic/ptr.h> + +#include <optional> + +namespace NTvmAuth::NInternal { + class TCanningKnife; +} + +namespace NTvmAuth { + /*! + * BlackboxEnv describes environment of Passport: + * https://wiki.yandex-team.ru/passport/tvm2/user-ticket/#0-opredeljaemsjasokruzhenijami + */ + enum class EBlackboxEnv: ui8 { + Prod, + Test, + ProdYateam, + TestYateam, + Stress + }; + + /*! + * UserTicket contains only valid users. + * Details: https://wiki.yandex-team.ru/passport/tvm2/user-ticket/#chtoestvusertickete + */ + class TCheckedUserTicket { + public: + class TImpl; + + TCheckedUserTicket(THolder<TImpl> impl); + TCheckedUserTicket(TCheckedUserTicket&&); + ~TCheckedUserTicket(); + + TCheckedUserTicket& operator=(TCheckedUserTicket&&); + + /*! + * @return True value if ticket parsed and checked successfully + */ + explicit operator bool() const; + + /*! + * Never empty + * @return UIDs of users listed in ticket + */ + const TUids& GetUids() const; + + /*! + * Maybe 0 + * @return Default user in ticket + */ + TUid GetDefaultUid() const; + + /*! + * Never empty + * @return UIDs of users listed in ticket with user extended fields + */ + TUidsExtFieldsMap GetUidsExtFields() const; + + /*! + * Empty if there is no default uid in ticket + * @return Default user in ticket with extended fields + */ + std::optional<TUserExtFields> GetDefaultUidExtFields() const; + + /*! + * Scopes inherited from credential - never empty + * @return Newly constructed vector of scopes + */ + const TScopes& GetScopes() const; + + /*! + * Check if scope presented in ticket + */ + bool HasScope(TStringBuf scopeName) const; + + /*! + * @return Ticket check status + */ + ETicketStatus GetStatus() const; + + /*! + * DebugInfo is human readable data for debug purposes + * @return Serialized ticket + */ + TString DebugInfo() const; + + /*! + * Env of user + */ + EBlackboxEnv GetEnv() const; + + /*! + * @return login_id of user + * empty if ticket does not contain login_id + */ + const TString& GetLoginId() const; + + public: // for python binding + TCheckedUserTicket() = default; + + private: + THolder<TImpl> Impl_; + friend class NInternal::TCanningKnife; + }; +} diff --git a/library/cpp/tvmauth/client/README.md b/library/cpp/tvmauth/client/README.md new file mode 100644 index 0000000000..cda6a22d3c --- /dev/null +++ b/library/cpp/tvmauth/client/README.md @@ -0,0 +1,84 @@ +Overview +=== +This library provides ability to operate with TVM. Library is fast enough to get or check tickets for every request without burning CPU. + +[Home page of project](https://wiki.yandex-team.ru/passport/tvm2/) +You can find some examples in [here](https://a.yandex-team.ru/arc/trunk/arcadia/library/cpp/tvmauth/client/examples). + +You can ask questions: [PASSPORTDUTY](https://st.yandex-team.ru/createTicket?queue=PASSPORTDUTY&_form=77618) + +TvmClient +=== +Don't forget to collect logs from client. +___ +`TvmClient` allowes: +1. `GetServiceTicketFor()` - to fetch ServiceTicket for outgoing request +2. `CheckServiceTicket()` - to check ServiceTicket from incoming request +3. `CheckUserTicket()` - to check UserTicket from incoming request +4. `GetRoles()` - to get roles from IDM + +All methods are thread-safe. + +You should check status of `CheckedServiceTicket` or `CheckedUserTicket` for equality 'Ok'. You can get ticket fields (src/uids/scopes) only for correct ticket. Otherwise exception will be thrown. +___ +You should check status of client with `GetStatus()`: +* `OK` - nothing to do here +* `Warning` - **you should trigger your monitoring alert** + + Normal operation of TvmClient is still possible but there are problems with refreshing cache, so it is expiring. + Is tvm-api.yandex.net accessible? + Have you changed your TVM-secret or your backend (dst) deleted its TVM-client? + +* `Error` - **you should trigger your monitoring alert and close this instance for user-traffic** + + TvmClient's cache is already invalid (expired) or soon will be: you can't check valid ServiceTicket or be authenticated by your backends (dsts) + +___ +Constructor creates system thread for refreshing cache - so do not fork your proccess after creating `TTvmClient` instance. Constructor leads to network I/O. Other methods always use memory. + +Exceptions maybe thrown from constructor: +* `TRetriableException` - maybe some network trouble: you can try to create client one more time. +* `TNonRetriableException` - settings are bad: fix them. +___ +You can choose way for fetching data for your service operation: +* http://localhost:{port}/tvm - recomended way +* https://tvm-api.yandex.net + +TvmTool +------------ +`TTvmClient` uses local http-interface to get state. This interface can be provided with tvmtool (local daemon) or Qloud/YP (local http api in container). +See more: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/. + +`TTvmClient` fetches configuration from tvmtool, so you need only to tell client how to connect to it and tell which alias of tvm id should be used for this `TvmClient` instance. + +TvmApi +------------ +First of all: please use `DiskCacheDir` - it provides reliability for your service and for tvm-api. +Please check restrictions of this field. + +Roles +=== +[Example](https://a.yandex-team.ru/arc/trunk/arcadia/library/cpp/tvmauth/client/examples/create_with_tvmapi/create.cpp?rev=r8888584#L84) + +You need to configure roles fetching +------------ +1. Enable disk cache: [DiskCacheDir](https://a.yandex-team.ru/arc/trunk/arcadia/library/cpp/tvmauth/client/misc/api/settings.h?rev=r9001419#L54) + +2. Enable ServiceTicket fetching: + [SelfTvmId](https://a.yandex-team.ru/arc/trunk/arcadia/library/cpp/tvmauth/client/misc/api/settings.h?rev=r9001419#L57) + [Secret](https://a.yandex-team.ru/arc/trunk/arcadia/library/cpp/tvmauth/client/misc/api/settings.h?rev=r9001419#L60) +3. Enable roles fetching from tirole: + [FetchRolesForIdmSystemSlug](https://a.yandex-team.ru/arc/trunk/arcadia/library/cpp/tvmauth/client/misc/api/settings.h?rev=r9001419#L78) + +You need to use roles for request check +------------ +1. Check ServiceTicket and/or UserTicket - as usual: + [CheckServiceTicket()](https://a.yandex-team.ru/arc/trunk/arcadia/library/cpp/tvmauth/client/facade.h?rev=r7890770#L91)/[CheckUserTicket()](https://a.yandex-team.ru/arc/trunk/arcadia/library/cpp/tvmauth/client/facade.h?rev=r7890770#L99) + +2. Get actual roles from `TvmClient`: [GetRoles()](https://a.yandex-team.ru/arc/trunk/arcadia/library/cpp/tvmauth/client/facade.h?rev=r7890770#L105) + +3. Use roles + - case#1: [get](https://a.yandex-team.ru/arc/trunk/arcadia/library/cpp/tvmauth/client/misc/roles/roles.h?rev=r7890770#L37-46) role list for service or user and check for the exact role you need. + - case#2: use [shortcuts](https://a.yandex-team.ru/arc/trunk/arcadia/library/cpp/tvmauth/client/misc/roles/roles.h?rev=r7890770#L50) - they are wrappers for case#1 + +4. If consumer (service or user) has required role, you can perform request. + If consumer doesn't have required role, you should show error message with useful message. diff --git a/library/cpp/tvmauth/client/client_status.cpp b/library/cpp/tvmauth/client/client_status.cpp new file mode 100644 index 0000000000..eca35ba22b --- /dev/null +++ b/library/cpp/tvmauth/client/client_status.cpp @@ -0,0 +1,6 @@ +#include "client_status.h" + +template <> +void Out<NTvmAuth::TClientStatus>(IOutputStream& out, const NTvmAuth::TClientStatus& s) { + out << s.GetCode() << ": " << s.GetLastError(); +} diff --git a/library/cpp/tvmauth/client/client_status.h b/library/cpp/tvmauth/client/client_status.h new file mode 100644 index 0000000000..831a66e299 --- /dev/null +++ b/library/cpp/tvmauth/client/client_status.h @@ -0,0 +1,84 @@ +#pragma once + +#include <util/generic/string.h> +#include <util/string/builder.h> + +namespace NTvmAuth { + class TClientStatus { + public: + enum ECode { + Ok, + Warning, + Error, + IncompleteTicketsSet, + NotInitialized, + }; + + TClientStatus(ECode state, TString&& lastError) + : Code_(state) + , LastError_(std::move(lastError)) + { + } + + TClientStatus() = default; + TClientStatus(const TClientStatus&) = default; + TClientStatus(TClientStatus&&) = default; + + TClientStatus& operator=(const TClientStatus&) = default; + TClientStatus& operator=(TClientStatus&&) = default; + + ECode GetCode() const { + return Code_; + } + + const TString& GetLastError() const { + return LastError_; + } + + TString CreateJugglerMessage() const { + return TStringBuilder() << GetJugglerCode() << ";TvmClient: " << LastError_ << "\n"; + } + + private: + int32_t GetJugglerCode() const { + switch (Code_) { + case ECode::Ok: + return 0; // OK juggler check state + case ECode::Warning: + case ECode::IncompleteTicketsSet: + case ECode::NotInitialized: + return 1; // WARN juggler check state + case ECode::Error: + return 2; // CRIT juggler check state + } + return 2; // This should not happen, so set check state as CRIT. + } + + ECode Code_ = Ok; + TString LastError_; + }; + + static inline bool operator==(const TClientStatus& l, const TClientStatus& r) noexcept { + return l.GetCode() == r.GetCode() && l.GetLastError() == r.GetLastError(); + } + + static inline bool operator==(const TClientStatus& l, const TClientStatus::ECode r) noexcept { + return l.GetCode() == r; + } + + static inline bool operator==(const TClientStatus::ECode l, const TClientStatus& r) noexcept { + return r.GetCode() == l; + } + + static inline bool operator!=(const TClientStatus& l, const TClientStatus& r) noexcept { + return !(l == r); + } + + static inline bool operator!=(const TClientStatus& l, const TClientStatus::ECode r) noexcept { + return !(l == r); + } + + static inline bool operator!=(const TClientStatus::ECode l, const TClientStatus& r) noexcept { + return !(l == r); + } +} diff --git a/library/cpp/tvmauth/client/exception.h b/library/cpp/tvmauth/client/exception.h new file mode 100644 index 0000000000..43de506f4e --- /dev/null +++ b/library/cpp/tvmauth/client/exception.h @@ -0,0 +1,29 @@ +#pragma once + +#include <library/cpp/tvmauth/exception.h> + +namespace NTvmAuth { + class TClientException: public TTvmException { + }; + + class TInternalException: public TTvmException { + }; + + class TRetriableException: public TClientException { + }; + class TNonRetriableException: public TClientException { + }; + + class TNotInitializedException: public TClientException { + }; + + class TIllegalUsage: public TNonRetriableException { + }; + + class TBrokenTvmClientSettings: public TIllegalUsage { + }; + class TMissingServiceTicket: public TNonRetriableException { + }; + class TPermissionDenied: public TNonRetriableException { + }; +} diff --git a/library/cpp/tvmauth/client/facade.cpp b/library/cpp/tvmauth/client/facade.cpp new file mode 100644 index 0000000000..2647c276fc --- /dev/null +++ b/library/cpp/tvmauth/client/facade.cpp @@ -0,0 +1,88 @@ +#include "facade.h" + +#include "misc/api/threaded_updater.h" +#include "misc/tool/threaded_updater.h" + +namespace NTvmAuth { + TTvmClient::TTvmClient(const NTvmTool::TClientSettings& settings, TLoggerPtr logger) + : Updater_(NTvmTool::TThreadedUpdater::Create(settings, std::move(logger))) + { + ServiceTicketCheckFlags_.NeedDstCheck = settings.ShouldCheckDst; + } + + TTvmClient::TTvmClient(const NTvmApi::TClientSettings& settings, TLoggerPtr logger) + : Updater_(NTvmApi::TThreadedUpdater::Create(settings, std::move(logger))) + { + ServiceTicketCheckFlags_.NeedDstCheck = settings.ShouldCheckDst; + } + + TTvmClient::TTvmClient( + TAsyncUpdaterPtr updater, + const TServiceContext::TCheckFlags& serviceTicketCheckFlags) + : Updater_(std::move(updater)) + , ServiceTicketCheckFlags_(serviceTicketCheckFlags) + { + try { + if (Updater_->GetRoles()) { + } + } catch (const TIllegalUsage&) { + // it is a test probably + } catch (const TNotInitializedException&) { + // it is a test probably + } + } + + TTvmClient::TTvmClient(TTvmClient&& o) = default; + TTvmClient::~TTvmClient() = default; + TTvmClient& TTvmClient::operator=(TTvmClient&& o) = default; + + TClientStatus TTvmClient::GetStatus() const { + Y_ENSURE(Updater_); + return Updater_->GetStatus(); + } + + TInstant TTvmClient::GetUpdateTimeOfPublicKeys() const { + Y_ENSURE(Updater_); + return Updater_->GetUpdateTimeOfPublicKeys(); + } + + TInstant TTvmClient::GetUpdateTimeOfServiceTickets() const { + Y_ENSURE(Updater_); + return Updater_->GetUpdateTimeOfServiceTickets(); + } + + TInstant TTvmClient::GetInvalidationTimeOfPublicKeys() const { + Y_ENSURE(Updater_); + return Updater_->GetInvalidationTimeOfPublicKeys(); + } + + TInstant TTvmClient::GetInvalidationTimeOfServiceTickets() const { + Y_ENSURE(Updater_); + return Updater_->GetInvalidationTimeOfServiceTickets(); + } + + TString TTvmClient::GetServiceTicketFor(const TClientSettings::TAlias& dst) const { + Y_ENSURE(Updater_); + return Updater_->GetServiceTicketFor(dst); + } + + TString TTvmClient::GetServiceTicketFor(const TTvmId dst) const { + Y_ENSURE(Updater_); + return Updater_->GetServiceTicketFor(dst); + } + + TCheckedServiceTicket TTvmClient::CheckServiceTicket(TStringBuf ticket) const { + Y_ENSURE(Updater_); + return Updater_->CheckServiceTicket(ticket, ServiceTicketCheckFlags_); + } + + TCheckedUserTicket TTvmClient::CheckUserTicket(TStringBuf ticket, TMaybe<EBlackboxEnv> overridenEnv) const { + Y_ENSURE(Updater_); + return Updater_->CheckUserTicket(ticket, overridenEnv); + } + + NRoles::TRolesPtr TTvmClient::GetRoles() const { + Y_ENSURE(Updater_); + return Updater_->GetRoles(); + } +} diff --git a/library/cpp/tvmauth/client/facade.h b/library/cpp/tvmauth/client/facade.h new file mode 100644 index 0000000000..8e8b635a2a --- /dev/null +++ b/library/cpp/tvmauth/client/facade.h @@ -0,0 +1,118 @@ +#pragma once + +#include "misc/async_updater.h" +#include "misc/api/settings.h" +#include "misc/tool/settings.h" + +#include <library/cpp/tvmauth/checked_service_ticket.h> +#include <library/cpp/tvmauth/checked_user_ticket.h> + +namespace NTvmAuth::NInternal { + class TClientCaningKnife; +} + +namespace NTvmAuth { + class TDefaultUidChecker; + class TServiceTicketGetter; + class TServiceTicketChecker; + class TSrcChecker; + class TUserTicketChecker; + + /*! + * Long lived thread-safe object for interacting with TVM. + * In 99% cases TvmClient shoud be created at service startup and live for the whole process lifetime. + */ + class TTvmClient { + public: + /*! + * Uses local http-interface to get state: http://localhost/tvm/. + * This interface can be provided with tvmtool (local daemon) or Qloud/YP (local http api in container). + * See more: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/. + * + * Starts thread for updating of in-memory cache in background + * @param settings + * @param logger is usefull for monitoring and debuging + */ + TTvmClient(const NTvmTool::TClientSettings& settings, TLoggerPtr logger); + + /*! + * Uses general way to get state: https://tvm-api.yandex.net. + * It is not recomended for Qloud/YP. + * + * Starts thread for updating of in-memory cache in background + * Reads cache from disk if specified + * @param settings + * @param logger is usefull for monitoring and debuging + */ + TTvmClient(const NTvmApi::TClientSettings& settings, TLoggerPtr logger); + + /*! + * Feel free to use custom updating logic in tests + */ + TTvmClient( + TAsyncUpdaterPtr updater, + const TServiceContext::TCheckFlags& serviceTicketCheckFlags = {}); + + TTvmClient(TTvmClient&&); + ~TTvmClient(); + TTvmClient& operator=(TTvmClient&&); + + /*! + * You should trigger your monitoring if status is not Ok. + * It will be unable to operate if status is Error. + * Description: https://a.yandex-team.ru/arc/trunk/arcadia/library/cpp/tvmauth/client/README.md#high-level-interface + * @return Current status of client. + */ + TClientStatus GetStatus() const; + + /*! + * Some tools for monitoring + */ + + TInstant GetUpdateTimeOfPublicKeys() const; + TInstant GetUpdateTimeOfServiceTickets() const; + TInstant GetInvalidationTimeOfPublicKeys() const; + TInstant GetInvalidationTimeOfServiceTickets() const; + + /*! + * Requires fetchinig options (from TClientSettings or Qloud/YP/tvmtool settings) + * Can throw exception if cache is invalid or wrong config + * + * Alias is local label for TvmID + * which can be used to avoid this number in every checking case in code. + * @param dst + */ + TString GetServiceTicketFor(const TClientSettings::TAlias& dst) const; + TString GetServiceTicketFor(const TTvmId dst) const; + + /*! + * For TTvmApi::TClientSettings: checking must be enabled in TClientSettings + * Can throw exception if checking was not enabled in settings + * + * ServiceTicket contains src: you should check it by yourself with ACL + * @param ticket + */ + TCheckedServiceTicket CheckServiceTicket(TStringBuf ticket) const; + + /*! + * Requires blackbox enviroment (from TClientSettings or Qloud/YP/tvmtool settings) + * Can throw exception if checking was not enabled in settings + * @param ticket + * @param overrideEnv allowes you to override env from settings + */ + TCheckedUserTicket CheckUserTicket(TStringBuf ticket, TMaybe<EBlackboxEnv> overrideEnv = {}) const; + + /*! + * Requires idm system slug (from TClientSettings or Qloud/YP/tvmtool settings) + * Can throw exception if slug was not specified in settings + */ + NRoles::TRolesPtr GetRoles() const; + + private: + TAsyncUpdaterPtr Updater_; + + TServiceContext::TCheckFlags ServiceTicketCheckFlags_; + + friend class NInternal::TClientCaningKnife; + }; +} diff --git a/library/cpp/tvmauth/client/logger.cpp b/library/cpp/tvmauth/client/logger.cpp new file mode 100644 index 0000000000..bd63773cdf --- /dev/null +++ b/library/cpp/tvmauth/client/logger.cpp @@ -0,0 +1,12 @@ +#include "logger.h" + +#include <util/datetime/base.h> +#include <util/generic/string.h> + +namespace NTvmAuth { + void TCerrLogger::Log(int lvl, const TString& msg) { + if (lvl > Level_) + return; + Cerr << TInstant::Now().ToStringLocal() << " lvl=" << lvl << " msg: " << msg << "\n"; + } +} diff --git a/library/cpp/tvmauth/client/logger.h b/library/cpp/tvmauth/client/logger.h new file mode 100644 index 0000000000..6f3718a2aa --- /dev/null +++ b/library/cpp/tvmauth/client/logger.h @@ -0,0 +1,59 @@ +#pragma once + +#include <util/generic/ptr.h> + +namespace NTvmAuth { + class ILogger: public TAtomicRefCount<ILogger> { + public: + virtual ~ILogger() = default; + + void Debug(const TString& msg) { + Log(7, msg); + } + + void Info(const TString& msg) { + Log(6, msg); + } + + void Warning(const TString& msg) { + Log(4, msg); + } + + void Error(const TString& msg) { + Log(3, msg); + } + + protected: + /*! + * Log event + * @param lvl is syslog level: 0(Emergency) ... 7(Debug) + * @param msg + */ + virtual void Log(int lvl, const TString& msg) = 0; + }; + + class TCerrLogger: public ILogger { + public: + TCerrLogger(int level) + : Level_(level) + { + } + + void Log(int lvl, const TString& msg) override; + + private: + const int Level_; + }; + + using TLoggerPtr = TIntrusivePtr<ILogger>; + + class TDevNullLogger: public ILogger { + public: + static TLoggerPtr IAmBrave() { + return MakeIntrusive<TDevNullLogger>(); + } + + void Log(int, const TString&) override { + } + }; +} diff --git a/library/cpp/tvmauth/client/misc/api/dynamic_dst/tvm_client.cpp b/library/cpp/tvmauth/client/misc/api/dynamic_dst/tvm_client.cpp new file mode 100644 index 0000000000..cd6ec45406 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/api/dynamic_dst/tvm_client.cpp @@ -0,0 +1,166 @@ +#include "tvm_client.h" + +#include <util/string/builder.h> + +namespace NTvmAuth::NDynamicClient { + TIntrusivePtr<TTvmClient> TTvmClient::Create(const NTvmApi::TClientSettings& settings, TLoggerPtr logger) { + Y_ENSURE_EX(logger, TNonRetriableException() << "Logger is required"); + THolder<TTvmClient> p(new TTvmClient(settings, std::move(logger))); + p->Init(); + p->StartWorker(); + return p.Release(); + } + + NThreading::TFuture<TAddResponse> TTvmClient::Add(TDsts&& dsts) { + if (dsts.empty()) { + LogDebug("Adding dst: got empty task"); + return NThreading::MakeFuture<TAddResponse>(TAddResponse{}); + } + + NThreading::TPromise<TAddResponse> promise = NThreading::NewPromise<TAddResponse>(); + + TServiceTickets::TMapIdStr requestedTicketsFromStartUpCache = GetRequestedTicketsFromStartUpCache(dsts); + + if (requestedTicketsFromStartUpCache.size() == dsts.size() && + !IsInvalid(TServiceTickets::GetInvalidationTime(requestedTicketsFromStartUpCache), TInstant::Now())) { + std::unique_lock lock(*ServiceTicketBatchUpdateMutex_); + + TPairTicketsErrors newCache; + TServiceTicketsPtr cache = GetCachedServiceTickets(); + + NTvmApi::TDstSetPtr oldDsts = GetDsts(); + std::shared_ptr<TDsts> newDsts = std::make_shared<TDsts>(oldDsts->begin(), oldDsts->end()); + + for (const auto& ticket : cache->TicketsById) { + newCache.Tickets.insert(ticket); + } + for (const auto& error : cache->ErrorsById) { + newCache.Errors.insert(error); + } + for (const auto& ticket : requestedTicketsFromStartUpCache) { + newCache.Tickets.insert(ticket); + newDsts->insert(ticket.first); + } + + UpdateServiceTicketsCache(std::move(newCache), GetStartUpCacheBornDate()); + SetDsts(std::move(newDsts)); + + lock.unlock(); + + TAddResponse response; + + for (const auto& dst : dsts) { + response.emplace(dst, TDstResponse{EDstStatus::Success, TString()}); + LogDebug(TStringBuilder() << "Got ticket from disk cache" + << ": dst=" << dst.Id << " got ticket"); + } + + promise.SetValue(std::move(response)); + return promise.GetFuture(); + } + + const size_t size = dsts.size(); + const ui64 id = ++TaskIds_; + + TaskQueue_.Enqueue(TTask{id, promise, std::move(dsts)}); + + LogDebug(TStringBuilder() << "Adding dst: got task #" << id << " with " << size << " dsts"); + return promise.GetFuture(); + } + + std::optional<TString> TTvmClient::GetOptionalServiceTicketFor(const TTvmId dst) { + TServiceTicketsPtr tickets = GetCachedServiceTickets(); + + Y_ENSURE_EX(tickets, + TBrokenTvmClientSettings() + << "Need to enable fetching of service tickets in settings"); + + auto it = tickets->TicketsById.find(dst); + if (it != tickets->TicketsById.end()) { + return it->second; + } + + it = tickets->ErrorsById.find(dst); + if (it != tickets->ErrorsById.end()) { + ythrow TMissingServiceTicket() + << "Failed to get ticket for '" << dst << "': " + << it->second; + } + + return {}; + } + + TTvmClient::TTvmClient(const NTvmApi::TClientSettings& settings, TLoggerPtr logger) + : TBase(settings, logger) + { + } + + TTvmClient::~TTvmClient() { + TBase::StopWorker(); + } + + void TTvmClient::Worker() { + TBase::Worker(); + ProcessTasks(); + } + + void TTvmClient::ProcessTasks() { + TaskQueue_.DequeueAll(&Tasks_); + if (Tasks_.empty()) { + return; + } + + TDsts required; + for (const TTask& task : Tasks_) { + for (const auto& dst : task.Dsts) { + required.insert(dst); + } + } + + TServiceTicketsPtr cache = UpdateMissingServiceTickets(required); + for (TTask& task : Tasks_) { + try { + SetResponseForTask(task, *cache); + } catch (const std::exception& e) { + LogError(TStringBuilder() + << "Adding dst: task #" << task.Id << ": exception: " << e.what()); + } catch (...) { + LogError(TStringBuilder() + << "Adding dst: task #" << task.Id << ": exception: " << CurrentExceptionMessage()); + } + } + + Tasks_.clear(); + } + + static const TString UNKNOWN = "Unknown reason"; + void TTvmClient::SetResponseForTask(TTvmClient::TTask& task, const TServiceTickets& cache) { + if (task.Promise.HasValue()) { + LogWarning(TStringBuilder() << "Adding dst: task #" << task.Id << " already has value"); + return; + } + + TAddResponse response; + + for (const auto& dst : task.Dsts) { + if (cache.TicketsById.contains(dst.Id)) { + response.emplace(dst, TDstResponse{EDstStatus::Success, TString()}); + + LogDebug(TStringBuilder() << "Adding dst: task #" << task.Id + << ": dst=" << dst.Id << " got ticket"); + continue; + } + + auto it = cache.ErrorsById.find(dst.Id); + const TString& error = it == cache.ErrorsById.end() ? UNKNOWN : it->second; + response.emplace(dst, TDstResponse{EDstStatus::Fail, error}); + + LogWarning(TStringBuilder() << "Adding dst: task #" << task.Id + << ": dst=" << dst.Id + << " failed to get ticket: " << error); + } + + LogDebug(TStringBuilder() << "Adding dst: task #" << task.Id << ": set value"); + task.Promise.SetValue(std::move(response)); + } +} diff --git a/library/cpp/tvmauth/client/misc/api/dynamic_dst/tvm_client.h b/library/cpp/tvmauth/client/misc/api/dynamic_dst/tvm_client.h new file mode 100644 index 0000000000..67eeb2618a --- /dev/null +++ b/library/cpp/tvmauth/client/misc/api/dynamic_dst/tvm_client.h @@ -0,0 +1,60 @@ +#pragma once + +#include <library/cpp/tvmauth/client/misc/api/threaded_updater.h> + +#include <library/cpp/threading/future/future.h> + +#include <util/generic/map.h> +#include <util/thread/lfqueue.h> + +#include <optional> + +namespace NTvmAuth::NDynamicClient { + enum class EDstStatus { + Success, + Fail, + }; + + struct TDstResponse { + EDstStatus Status = EDstStatus::Fail; + TString Error; + + bool operator==(const TDstResponse& o) const { + return Status == o.Status && Error == o.Error; + } + }; + + using TDsts = NTvmApi::TDstSet; + using TAddResponse = TMap<NTvmApi::TClientSettings::TDst, TDstResponse>; + + class TTvmClient: public NTvmApi::TThreadedUpdater { + public: + static TIntrusivePtr<TTvmClient> Create(const NTvmApi::TClientSettings& settings, TLoggerPtr logger); + virtual ~TTvmClient(); + + NThreading::TFuture<TAddResponse> Add(TDsts&& dsts); + std::optional<TString> GetOptionalServiceTicketFor(const TTvmId dst); + + protected: // for tests + struct TTask { + ui64 Id = 0; + NThreading::TPromise<TAddResponse> Promise; + TDsts Dsts; + }; + + using TBase = NTvmApi::TThreadedUpdater; + + protected: // for tests + TTvmClient(const NTvmApi::TClientSettings& settings, TLoggerPtr logger); + + void Worker() override; + void ProcessTasks(); + + void SetResponseForTask(TTask& task, const TServiceTickets& cache); + + private: + std::atomic<ui64> TaskIds_ = {0}; + TLockFreeQueue<TTask> TaskQueue_; + TVector<TTask> Tasks_; + }; +} diff --git a/library/cpp/tvmauth/client/misc/api/dynamic_dst/ya.make b/library/cpp/tvmauth/client/misc/api/dynamic_dst/ya.make new file mode 100644 index 0000000000..27908b39fe --- /dev/null +++ b/library/cpp/tvmauth/client/misc/api/dynamic_dst/ya.make @@ -0,0 +1,18 @@ +LIBRARY() + +PEERDIR( + library/cpp/threading/future + library/cpp/tvmauth/client +) + +SRCS( + tvm_client.cpp +) + +GENERATE_ENUM_SERIALIZATION(tvm_client.h) + +END() + +RECURSE_FOR_TESTS( + ut +) diff --git a/library/cpp/tvmauth/client/misc/api/retry_settings.h b/library/cpp/tvmauth/client/misc/api/retry_settings.h new file mode 100644 index 0000000000..607b230811 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/api/retry_settings.h @@ -0,0 +1,33 @@ +#pragma once + +#include <library/cpp/tvmauth/client/misc/exponential_backoff.h> + +namespace NTvmAuth::NTvmApi { + struct TRetrySettings { + TExponentialBackoff::TSettings BackoffSettings = { + TDuration::Seconds(0), + TDuration::Minutes(1), + 2, + 0.5, + }; + TDuration MaxRandomSleepDefault = TDuration::Seconds(5); + TDuration MaxRandomSleepWhenOk = TDuration::Minutes(1); + ui32 RetriesOnStart = 3; + ui32 RetriesInBackground = 2; + TDuration WorkerAwakingPeriod = TDuration::Seconds(10); + ui32 DstsLimit = 300; + TDuration RolesUpdatePeriod = TDuration::Minutes(10); + TDuration RolesWarnPeriod = TDuration::Minutes(20); + + bool operator==(const TRetrySettings& o) const { + return BackoffSettings == o.BackoffSettings && + MaxRandomSleepDefault == o.MaxRandomSleepDefault && + MaxRandomSleepWhenOk == o.MaxRandomSleepWhenOk && + RetriesOnStart == o.RetriesOnStart && + WorkerAwakingPeriod == o.WorkerAwakingPeriod && + DstsLimit == o.DstsLimit && + RolesUpdatePeriod == o.RolesUpdatePeriod && + RolesWarnPeriod == o.RolesWarnPeriod; + } + }; +} diff --git a/library/cpp/tvmauth/client/misc/api/roles_fetcher.cpp b/library/cpp/tvmauth/client/misc/api/roles_fetcher.cpp new file mode 100644 index 0000000000..8f4b359e8c --- /dev/null +++ b/library/cpp/tvmauth/client/misc/api/roles_fetcher.cpp @@ -0,0 +1,164 @@ +#include "roles_fetcher.h" + +#include <library/cpp/tvmauth/client/misc/disk_cache.h> +#include <library/cpp/tvmauth/client/misc/roles/decoder.h> +#include <library/cpp/tvmauth/client/misc/roles/parser.h> + +#include <library/cpp/http/misc/httpcodes.h> +#include <library/cpp/string_utils/quote/quote.h> + +#include <util/string/builder.h> +#include <util/string/join.h> + +namespace NTvmAuth::NTvmApi { + static TString CreatePath(const TString& dir, const TString& file) { + return dir.EndsWith("/") + ? dir + file + : dir + "/" + file; + } + + TRolesFetcher::TRolesFetcher(const TRolesFetcherSettings& settings, TLoggerPtr logger) + : Settings_(settings) + , Logger_(logger) + , CacheFilePath_(CreatePath(Settings_.CacheDir, "roles")) + { + Client_ = std::make_unique<TKeepAliveHttpClient>( + Settings_.TiroleHost, + Settings_.TirolePort, + Settings_.Timeout, + Settings_.Timeout); + } + + TInstant TRolesFetcher::ReadFromDisk() { + TDiskReader dr(CacheFilePath_, Logger_.Get()); + if (!dr.Read()) { + return {}; + } + + std::pair<TString, TString> data = ParseDiskFormat(dr.Data()); + if (data.second != Settings_.IdmSystemSlug) { + Logger_->Warning( + TStringBuilder() << "Roles in disk cache are for another slug (" << data.second + << "). Self=" << Settings_.IdmSystemSlug); + return {}; + } + + CurrentRoles_.Set(NRoles::TParser::Parse(std::make_shared<TString>(std::move(data.first)))); + Logger_->Debug( + TStringBuilder() << "Succeed to read roles with revision " + << CurrentRoles_.Get()->GetMeta().Revision + << " from " << CacheFilePath_); + + return dr.Time(); + } + + bool TRolesFetcher::AreRolesOk() const { + return bool(GetCurrentRoles()); + } + + bool TRolesFetcher::IsTimeToUpdate(const TRetrySettings& settings, TDuration sinceUpdate) { + return settings.RolesUpdatePeriod < sinceUpdate; + } + + bool TRolesFetcher::ShouldWarn(const TRetrySettings& settings, TDuration sinceUpdate) { + return settings.RolesWarnPeriod < sinceUpdate; + } + + NUtils::TFetchResult TRolesFetcher::FetchActualRoles(const TString& serviceTicket) { + TStringStream out; + THttpHeaders outHeaders; + + TRequest req = CreateTiroleRequest(serviceTicket); + TKeepAliveHttpClient::THttpCode code = Client_->DoGet( + req.Url, + &out, + req.Headers, + &outHeaders); + + const THttpInputHeader* reqId = outHeaders.FindHeader("X-Request-Id"); + + Logger_->Debug( + TStringBuilder() << "Succeed to perform request for roles to " << Settings_.TiroleHost + << " (request_id=" << (reqId ? reqId->Value() : "") + << "). code=" << code); + + return {code, std::move(outHeaders), "/v1/get_actual_roles", out.Str(), {}}; + } + + void TRolesFetcher::Update(NUtils::TFetchResult&& fetchResult, TInstant now) { + if (fetchResult.Code == HTTP_NOT_MODIFIED) { + Y_ENSURE(CurrentRoles_.Get(), + "tirole did not return any roles because current roles are actual," + " but there are no roles in memory - this should never happen"); + return; + } + + Y_ENSURE(fetchResult.Code == HTTP_OK, + "Unexpected code from tirole: " << fetchResult.Code << ". " << fetchResult.Response); + + const THttpInputHeader* codec = fetchResult.Headers.FindHeader("X-Tirole-Compression"); + const TStringBuf codecBuf = codec ? codec->Value() : ""; + + NRoles::TRawPtr blob; + try { + blob = std::make_shared<TString>(NRoles::TDecoder::Decode( + codecBuf, + std::move(fetchResult.Response))); + } catch (const std::exception& e) { + throw yexception() << "Failed to decode blob with codec '" << codecBuf + << "': " << e.what(); + } + + CurrentRoles_.Set(NRoles::TParser::Parse(blob)); + + Logger_->Debug( + TStringBuilder() << "Succeed to update roles with revision " + << CurrentRoles_.Get()->GetMeta().Revision); + + TDiskWriter dw(CacheFilePath_, Logger_.Get()); + dw.Write(PrepareDiskFormat(*blob, Settings_.IdmSystemSlug), now); + } + + NTvmAuth::NRoles::TRolesPtr TRolesFetcher::GetCurrentRoles() const { + return CurrentRoles_.Get(); + } + + void TRolesFetcher::ResetConnection() { + Client_->ResetConnection(); + } + + static const char DELIMETER = '\t'; + + std::pair<TString, TString> TRolesFetcher::ParseDiskFormat(TStringBuf filebody) { + TStringBuf slug = filebody.RNextTok(DELIMETER); + return {TString(filebody), CGIUnescapeRet(slug)}; + } + + TString TRolesFetcher::PrepareDiskFormat(TStringBuf roles, TStringBuf slug) { + TStringStream res; + res.Reserve(roles.size() + 1 + slug.size()); + res << roles << DELIMETER << CGIEscapeRet(slug); + return res.Str(); + } + + TRolesFetcher::TRequest TRolesFetcher::CreateTiroleRequest(const TString& serviceTicket) const { + TRolesFetcher::TRequest res; + + TStringStream url; + url.Reserve(512); + url << "/v1/get_actual_roles?"; + url << "system_slug=" << CGIEscapeRet(Settings_.IdmSystemSlug) << "&"; + Settings_.ProcInfo.AddToRequest(url); + res.Url = std::move(url.Str()); + + res.Headers.reserve(2); + res.Headers.emplace(XYaServiceTicket_, serviceTicket); + + NRoles::TRolesPtr roles = CurrentRoles_.Get(); + if (roles) { + res.Headers.emplace(IfNoneMatch_, Join("", "\"", roles->GetMeta().Revision, "\"")); + } + + return res; + } +} diff --git a/library/cpp/tvmauth/client/misc/api/roles_fetcher.h b/library/cpp/tvmauth/client/misc/api/roles_fetcher.h new file mode 100644 index 0000000000..63691223b5 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/api/roles_fetcher.h @@ -0,0 +1,63 @@ +#pragma once + +#include "retry_settings.h" + +#include <library/cpp/tvmauth/client/misc/fetch_result.h> +#include <library/cpp/tvmauth/client/misc/proc_info.h> +#include <library/cpp/tvmauth/client/misc/utils.h> +#include <library/cpp/tvmauth/client/misc/roles/roles.h> + +#include <library/cpp/tvmauth/client/logger.h> + +#include <library/cpp/http/simple/http_client.h> + +namespace NTvmAuth::NTvmApi { + struct TRolesFetcherSettings { + TString TiroleHost; + ui16 TirolePort = 0; + TString CacheDir; + NUtils::TProcInfo ProcInfo; + TTvmId SelfTvmId = 0; + TString IdmSystemSlug; + TDuration Timeout = TDuration::Seconds(30); + }; + + class TRolesFetcher { + public: + TRolesFetcher(const TRolesFetcherSettings& settings, TLoggerPtr logger); + + TInstant ReadFromDisk(); + + bool AreRolesOk() const; + static bool IsTimeToUpdate(const TRetrySettings& settings, TDuration sinceUpdate); + static bool ShouldWarn(const TRetrySettings& settings, TDuration sinceUpdate); + + NUtils::TFetchResult FetchActualRoles(const TString& serviceTicket); + void Update(NUtils::TFetchResult&& fetchResult, TInstant now = TInstant::Now()); + + NTvmAuth::NRoles::TRolesPtr GetCurrentRoles() const; + + void ResetConnection(); + + public: + static std::pair<TString, TString> ParseDiskFormat(TStringBuf filebody); + static TString PrepareDiskFormat(TStringBuf roles, TStringBuf slug); + + struct TRequest { + TString Url; + TKeepAliveHttpClient::THeaders Headers; + }; + TRequest CreateTiroleRequest(const TString& serviceTicket) const; + + private: + const TRolesFetcherSettings Settings_; + const TLoggerPtr Logger_; + const TString CacheFilePath_; + const TString XYaServiceTicket_ = "X-Ya-Service-Ticket"; + const TString IfNoneMatch_ = "If-None-Match"; + + NUtils::TProtectedValue<NTvmAuth::NRoles::TRolesPtr> CurrentRoles_; + + std::unique_ptr<TKeepAliveHttpClient> Client_; + }; +} diff --git a/library/cpp/tvmauth/client/misc/api/settings.cpp b/library/cpp/tvmauth/client/misc/api/settings.cpp new file mode 100644 index 0000000000..082f089f75 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/api/settings.cpp @@ -0,0 +1,89 @@ +#include "settings.h" + +#include <util/datetime/base.h> +#include <util/stream/file.h> +#include <util/system/fs.h> + +#include <set> + +namespace NTvmAuth::NTvmApi { + void TClientSettings::CheckPermissions(const TString& dir) { + const TString name = dir + "/check.tmp"; + + try { + NFs::EnsureExists(dir); + + TFile file(name, CreateAlways | RdWr); + + NFs::Remove(name); + } catch (const std::exception& e) { + NFs::Remove(name); + ythrow TPermissionDenied() << "Permission denied to disk cache directory: " << e.what(); + } + } + + void TClientSettings::CheckValid() const { + if (DiskCacheDir) { + CheckPermissions(DiskCacheDir); + } + + if (TStringBuf(Secret)) { + Y_ENSURE_EX(NeedServiceTicketsFetching(), + TBrokenTvmClientSettings() << "Secret is present but destinations list is empty. It makes no sense"); + } + if (NeedServiceTicketsFetching()) { + Y_ENSURE_EX(SelfTvmId != 0, + TBrokenTvmClientSettings() << "SelfTvmId cannot be 0 if fetching of Service Tickets required"); + Y_ENSURE_EX((TStringBuf)Secret, + TBrokenTvmClientSettings() << "Secret is required for fetching of Service Tickets"); + } + + if (CheckServiceTickets) { + Y_ENSURE_EX(SelfTvmId != 0 || !ShouldCheckDst, + TBrokenTvmClientSettings() << "SelfTvmId cannot be 0 if checking of dst in Service Tickets required"); + } + + if (FetchRolesForIdmSystemSlug) { + Y_ENSURE_EX(DiskCacheDir, + TBrokenTvmClientSettings() << "Disk cache must be enabled to use roles: " + "they can be heavy"); + } + + bool needSmth = NeedServiceTicketsFetching() || + CheckServiceTickets || + CheckUserTicketsWithBbEnv; + Y_ENSURE_EX(needSmth, TBrokenTvmClientSettings() << "Invalid settings: nothing to do"); + + // Useless now: keep it here to avoid forgetting check from TDst. TODO: PASSP-35377 + for (const auto& dst : FetchServiceTicketsForDsts) { + Y_ENSURE_EX(dst.Id != 0, TBrokenTvmClientSettings() << "TvmId cannot be 0"); + } + // TODO: check only FetchServiceTicketsForDsts_ + // Python binding checks settings before normalization + for (const auto& [alias, dst] : FetchServiceTicketsForDstsWithAliases) { + Y_ENSURE_EX(dst.Id != 0, TBrokenTvmClientSettings() << "TvmId cannot be 0"); + } + Y_ENSURE_EX(TiroleTvmId != 0, TBrokenTvmClientSettings() << "TiroleTvmId cannot be 0"); + } + + TClientSettings TClientSettings::CloneNormalized() const { + TClientSettings res = *this; + + std::set<TTvmId> allDsts; + for (const auto& tvmid : res.FetchServiceTicketsForDsts) { + allDsts.insert(tvmid.Id); + } + for (const auto& [alias, tvmid] : res.FetchServiceTicketsForDstsWithAliases) { + allDsts.insert(tvmid.Id); + } + if (FetchRolesForIdmSystemSlug) { + allDsts.insert(res.TiroleTvmId); + } + + res.FetchServiceTicketsForDsts = {allDsts.begin(), allDsts.end()}; + + res.CheckValid(); + + return res; + } +} diff --git a/library/cpp/tvmauth/client/misc/api/settings.h b/library/cpp/tvmauth/client/misc/api/settings.h new file mode 100644 index 0000000000..ce5890cb87 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/api/settings.h @@ -0,0 +1,239 @@ +#pragma once + +#include <library/cpp/tvmauth/client/misc/settings.h> + +#include <library/cpp/tvmauth/client/exception.h> + +#include <library/cpp/tvmauth/checked_user_ticket.h> +#include <library/cpp/tvmauth/type.h> + +#include <library/cpp/string_utils/secret_string/secret_string.h> + +#include <util/datetime/base.h> +#include <util/generic/hash.h> +#include <util/generic/maybe.h> + +namespace NTvmAuth::NTvmApi { + /** + * Settings for TVM client. Uses https://tvm-api.yandex.net to get state. + * At least one of them is required: + * FetchServiceTicketsForDsts_/FetchServiceTicketsForDstsWithAliases_ + * CheckServiceTickets_ + * CheckUserTicketsWithBbEnv_ + */ + class TClientSettings: public NTvmAuth::TClientSettings { + public: + class TDst; + + /** + * Alias is an internal name for destinations within your code. + * You can associate a name with an tvm_id once in your code and use the name as an alias for + * tvm_id to each calling point. Useful for several environments: prod/test/etc. + * @example: + * // init + * static const TString MY_BACKEND = "my backend"; + * TDstMap map = {{MY_BACKEND, TDst(config.get("my_back_tvm_id"))}}; + * ... + * // per request + * TString t = tvmClient.GetServiceTicket(MY_BACKEND); + */ + using TDstMap = THashMap<TAlias, TDst>; + using TDstVector = TVector<TDst>; + + public: + /*! + * NOTE: Please use this option: it provides the best reliability + * NOTE: Client requires read/write permissions + * WARNING: The same directory can be used only: + * - for TVM clients with the same settings + * OR + * - for new client replacing previous - with another config. + * System user must be the same for processes with these clients inside. + * Implementation doesn't provide other scenarios. + */ + TString DiskCacheDir; + + // Required for Service Ticket fetching or checking + TTvmId SelfTvmId = 0; + + // Options for Service Tickets fetching + NSecretString::TSecretString Secret; + /*! + * Client will process both attrs: + * FetchServiceTicketsForDsts_, FetchServiceTicketsForDstsWithAliases_ + * WARNING: It is not way to provide authorization for incoming ServiceTickets! + * It is way only to send your ServiceTickets to your backend! + */ + TDstVector FetchServiceTicketsForDsts; + TDstMap FetchServiceTicketsForDstsWithAliases; + bool IsIncompleteTicketsSetAnError = true; + + // Options for Service Tickets checking + bool CheckServiceTickets = false; + + // Options for User Tickets checking + TMaybe<EBlackboxEnv> CheckUserTicketsWithBbEnv; + + // Options for roles fetching + TString FetchRolesForIdmSystemSlug; + /*! + * By default client checks src from ServiceTicket or default uid from UserTicket - + * to prevent you from forgetting to check it yourself. + * It does binary checks only: + * ticket gets status NoRoles, if there is no role for src or default uid. + * You need to check roles on your own if you have a non-binary role system or + * you have disabled ShouldCheckSrc/ShouldCheckDefaultUid + * + * You may need to disable this check in the following cases: + * - You use GetRoles() to provide verbose message (with revision). + * Double check may be inconsistent: + * binary check inside client uses revision of roles X - i.e. src 100500 has no role, + * exact check in your code uses revision of roles Y - i.e. src 100500 has some roles. + */ + bool ShouldCheckSrc = true; + bool ShouldCheckDefaultUid = true; + /*! + * By default client checks dst from ServiceTicket. If this check is switched off + * incorrect dst does not result in error of checked ticket status + * DANGEROUS: This case you must check dst manualy using @link TCheckedServiceTicket::GetDst() + */ + bool ShouldCheckDst = true; + + // In case of unsuccessful initialization at startup the client will be initialized in the background + bool EnableLazyInitialization = false; + + // Options for tests + TString TvmHost = "https://tvm-api.yandex.net"; + ui16 TvmPort = 443; + TString TiroleHost = "https://tirole-api.yandex.net"; + TDuration TvmSocketTimeout = TDuration::Seconds(5); + TDuration TvmConnectTimeout = TDuration::Seconds(30); + ui16 TirolePort = 443; + TTvmId TiroleTvmId = TIROLE_TVMID; + + // for debug purposes + TString LibVersionPrefix; + + void CheckValid() const; + TClientSettings CloneNormalized() const; + + static inline const TTvmId TIROLE_TVMID = 2028120; + static inline const TTvmId TIROLE_TVMID_TEST = 2026536; + + // DEPRECATED API + // TODO: get rid of it: PASSP-35377 + public: + // Deprecated: set attributes directly + void SetSelfTvmId(TTvmId selfTvmId) { + SelfTvmId = selfTvmId; + } + + // Deprecated: set attributes directly + void EnableServiceTicketChecking() { + CheckServiceTickets = true; + } + + // Deprecated: set attributes directly + void EnableUserTicketChecking(EBlackboxEnv env) { + CheckUserTicketsWithBbEnv = env; + } + + // Deprecated: set attributes directly + void SetTvmHostPort(const TString& host, ui16 port) { + TvmHost = host; + TvmPort = port; + } + + // Deprecated: set attributes directly + void SetTiroleHostPort(const TString& host, ui16 port) { + TiroleHost = host; + TirolePort = port; + } + + // Deprecated: set attributes directly + void EnableRolesFetching(const TString& systemSlug, TTvmId tiroleTvmId = TIROLE_TVMID) { + TiroleTvmId = tiroleTvmId; + FetchRolesForIdmSystemSlug = systemSlug; + } + + // Deprecated: set attributes directly + void DoNotCheckSrcByDefault() { + ShouldCheckSrc = false; + } + + // Deprecated: set attributes directly + void DoNotCheckDefaultUidByDefault() { + ShouldCheckDefaultUid = false; + } + + // Deprecated: set attributes directly + void SetDiskCacheDir(const TString& dir) { + DiskCacheDir = dir; + } + + // Deprecated: set attributes directly + void EnableServiceTicketsFetchOptions(const TStringBuf selfSecret, + TDstMap&& dsts, + const bool considerIncompleteTicketsSetAsError = true) { + IsIncompleteTicketsSetAnError = considerIncompleteTicketsSetAsError; + Secret = selfSecret; + + FetchServiceTicketsForDsts = TDstVector{}; + FetchServiceTicketsForDsts.reserve(dsts.size()); + for (const auto& pair : dsts) { + FetchServiceTicketsForDsts.push_back(pair.second); + } + + FetchServiceTicketsForDstsWithAliases = std::move(dsts); + } + + // Deprecated: set attributes directly + void EnableServiceTicketsFetchOptions(const TStringBuf selfSecret, + TDstVector&& dsts, + const bool considerIncompleteTicketsSetAsError = true) { + IsIncompleteTicketsSetAnError = considerIncompleteTicketsSetAsError; + Secret = selfSecret; + FetchServiceTicketsForDsts = std::move(dsts); + } + + public: + bool IsServiceTicketFetchingRequired() const { + return bool(Secret.Value()); + } + + bool NeedServiceTicketsFetching() const { + return !FetchServiceTicketsForDsts.empty() || + !FetchServiceTicketsForDstsWithAliases.empty() || + FetchRolesForIdmSystemSlug; + } + + // TODO: get rid of TDst: PASSP-35377 + class TDst { + public: + TDst(TTvmId id) + : Id(id) + { + Y_ENSURE_EX(id != 0, TBrokenTvmClientSettings() << "TvmId cannot be 0"); + } + + TTvmId Id; + + bool operator==(const TDst& o) const { + return Id == o.Id; + } + + bool operator<(const TDst& o) const { + return Id < o.Id; + } + + public: // for python binding + TDst() + : Id(0) + { + } + }; + + public: + static void CheckPermissions(const TString& dir); + }; +} diff --git a/library/cpp/tvmauth/client/misc/api/threaded_updater.cpp b/library/cpp/tvmauth/client/misc/api/threaded_updater.cpp new file mode 100644 index 0000000000..51df498ead --- /dev/null +++ b/library/cpp/tvmauth/client/misc/api/threaded_updater.cpp @@ -0,0 +1,1128 @@ +#include "threaded_updater.h" + +#include <library/cpp/tvmauth/client/misc/checker.h> +#include <library/cpp/tvmauth/client/misc/default_uid_checker.h> +#include <library/cpp/tvmauth/client/misc/disk_cache.h> +#include <library/cpp/tvmauth/client/misc/getter.h> +#include <library/cpp/tvmauth/client/misc/src_checker.h> +#include <library/cpp/tvmauth/client/misc/utils.h> +#include <library/cpp/tvmauth/client/misc/retry_settings/v1/settings.pb.h> + +#include <library/cpp/tvmauth/client/logger.h> + +#include <library/cpp/json/json_reader.h> + +#include <util/stream/str.h> +#include <util/string/builder.h> +#include <util/string/cast.h> +#include <util/system/thread.h> + +namespace NTvmAuth::NTvmApi { + static TString CreatePublicKeysUrl(const TClientSettings& settings, + const NUtils::TProcInfo& procInfo) { + TStringStream s; + s << "/2/keys"; + s << "?"; + procInfo.AddToRequest(s); + + s << "&get_retry_settings=yes"; + + if (settings.SelfTvmId != 0) { + s << "&src=" << settings.SelfTvmId; + } + + if (settings.CheckUserTicketsWithBbEnv) { + s << "&env=" << static_cast<int>(*settings.CheckUserTicketsWithBbEnv); + } + + return s.Str(); + } + + TAsyncUpdaterPtr TThreadedUpdater::Create(const TClientSettings& settings, TLoggerPtr logger) { + Y_ENSURE_EX(logger, TNonRetriableException() << "Logger is required"); + THolder<TThreadedUpdater> p(new TThreadedUpdater(settings, std::move(logger))); + try { + p->Init(); + } catch (const TRetriableException& e) { + if (!settings.EnableLazyInitialization) { + throw e; + } + } + + p->StartWorker(); + return p.Release(); + } + + TThreadedUpdater::~TThreadedUpdater() { + ExpBackoff_.SetEnabled(false); + ExpBackoff_.Interrupt(); + StopWorker(); // Required here to avoid using of deleted members + } + + TClientStatus TThreadedUpdater::GetStatus() const { + const TClientStatus::ECode state = GetState(); + return TClientStatus(state, GetLastError(state == TClientStatus::Ok || state == TClientStatus::IncompleteTicketsSet)); + } + + TString TThreadedUpdater::GetServiceTicketFor(const TClientSettings::TAlias& dst) const { + Y_ENSURE_EX(Settings_.NeedServiceTicketsFetching(), TBrokenTvmClientSettings() + << "Need to enable ServiceTickets fetching"); + + Y_ENSURE_EX(IsInited(), TNotInitializedException() << "Client is not initialized"); + + auto c = GetCachedServiceTickets(); + return TServiceTicketGetter::GetTicket(dst, c); + } + + TString TThreadedUpdater::GetServiceTicketFor(const TTvmId dst) const { + Y_ENSURE_EX(Settings_.NeedServiceTicketsFetching(), TBrokenTvmClientSettings() + << "Need to enable ServiceTickets fetching"); + + Y_ENSURE_EX(IsInited(), TNotInitializedException() << "Client is not initialized"); + + auto c = GetCachedServiceTickets(); + return TServiceTicketGetter::GetTicket(dst, c); + } + + TCheckedServiceTicket TThreadedUpdater::CheckServiceTicket(TStringBuf ticket, const TServiceContext::TCheckFlags& flags) const { + Y_ENSURE_EX(Settings_.CheckServiceTickets, TBrokenTvmClientSettings() + << "Need to TClientSettings::EnableServiceTicketChecking()"); + + Y_ENSURE_EX(IsInited(), TNotInitializedException() << "Client is not initialized"); + + TServiceContextPtr c = GetCachedServiceContext(); + TCheckedServiceTicket res = TServiceTicketChecker::Check(ticket, c, flags); + if (Settings_.ShouldCheckSrc && Settings_.FetchRolesForIdmSystemSlug && res) { + NRoles::TRolesPtr roles = GetRoles(); + return TSrcChecker::Check(std::move(res), roles); + } + return res; + } + + TCheckedUserTicket TThreadedUpdater::CheckUserTicket(TStringBuf ticket, TMaybe<EBlackboxEnv> overridenEnv) const { + Y_ENSURE_EX(Settings_.CheckUserTicketsWithBbEnv, TBrokenTvmClientSettings() + << "Need to use TClientSettings::EnableUserTicketChecking()"); + + Y_ENSURE_EX(IsInited(), TNotInitializedException() << "Client is not initialized"); + + auto c = GetCachedUserContext(overridenEnv); + TCheckedUserTicket res = TUserTicketChecker::Check(ticket, c); + if (Settings_.ShouldCheckDefaultUid && Settings_.FetchRolesForIdmSystemSlug && res && res.GetEnv() == EBlackboxEnv::ProdYateam) { + NRoles::TRolesPtr roles = GetRoles(); + return TDefaultUidChecker::Check(std::move(res), roles); + } + return res; + } + + NRoles::TRolesPtr TThreadedUpdater::GetRoles() const { + Y_ENSURE_EX(RolesFetcher_, + TBrokenTvmClientSettings() << "Roles were not configured in settings"); + + Y_ENSURE_EX(IsInited(), TNotInitializedException() << "Client is not initialized"); + + return RolesFetcher_->GetCurrentRoles(); + } + + TClientStatus::ECode TThreadedUpdater::GetState() const { + const TInstant now = TInstant::Now(); + + if (!IsInited()) { + return TClientStatus::NotInitialized; + } + + if (Settings_.IsServiceTicketFetchingRequired()) { + if (AreServiceTicketsInvalid(now)) { + return TClientStatus::Error; + } + auto tickets = GetCachedServiceTickets(); + if (!tickets) { + return TClientStatus::Error; + } + if (tickets->TicketsById.size() < GetDsts()->size()) { + if (Settings_.IsIncompleteTicketsSetAnError) { + return TClientStatus::Error; + } else { + return TClientStatus::IncompleteTicketsSet; + } + } + } + if ((Settings_.CheckServiceTickets || Settings_.CheckUserTicketsWithBbEnv) && ArePublicKeysInvalid(now)) { + return TClientStatus::Error; + } + + const TDuration sincePublicKeysUpdate = now - GetUpdateTimeOfPublicKeys(); + const TDuration sinceServiceTicketsUpdate = now - GetUpdateTimeOfServiceTickets(); + const TDuration sinceRolesUpdate = now - GetUpdateTimeOfRoles(); + + if (Settings_.IsServiceTicketFetchingRequired() && sinceServiceTicketsUpdate > ServiceTicketsDurations_.Expiring) { + return TClientStatus::Warning; + } + if ((Settings_.CheckServiceTickets || Settings_.CheckUserTicketsWithBbEnv) && + sincePublicKeysUpdate > PublicKeysDurations_.Expiring) + { + return TClientStatus::Warning; + } + if (RolesFetcher_ && TRolesFetcher::ShouldWarn(RetrySettings_, sinceRolesUpdate)) { + return TClientStatus::Warning; + } + + return TClientStatus::Ok; + } + + TThreadedUpdater::TThreadedUpdater(const TClientSettings& settings, TLoggerPtr logger) + : TThreadedUpdaterBase( + TRetrySettings{}.WorkerAwakingPeriod, + std::move(logger), + settings.TvmHost, + settings.TvmPort, + settings.TvmSocketTimeout, + settings.TvmConnectTimeout) + , ExpBackoff_(RetrySettings_.BackoffSettings) + , ServiceTicketBatchUpdateMutex_(std::make_unique<std::mutex>()) + , Settings_(settings.CloneNormalized()) + , ProcInfo_(NUtils::TProcInfo::Create(Settings_.LibVersionPrefix)) + , PublicKeysUrl_(CreatePublicKeysUrl(Settings_, ProcInfo_)) + , DstAliases_(MakeAliasMap(Settings_)) + , Headers_({{"Content-Type", "application/x-www-form-urlencoded"}}) + , Random_(TInstant::Now().MicroSeconds()) + { + if (Settings_.IsServiceTicketFetchingRequired()) { + SigningContext_ = TServiceContext::SigningFactory(Settings_.Secret); + } + + if (Settings_.IsServiceTicketFetchingRequired()) { + Destinations_ = std::make_shared<const TDstSet>(Settings_.FetchServiceTicketsForDsts.begin(), Settings_.FetchServiceTicketsForDsts.end()); + } + + PublicKeysDurations_.RefreshPeriod = TDuration::Days(1); + ServiceTicketsDurations_.RefreshPeriod = TDuration::Hours(1); + + if (Settings_.CheckUserTicketsWithBbEnv) { + SetBbEnv(*Settings_.CheckUserTicketsWithBbEnv); + } + + if (Settings_.FetchRolesForIdmSystemSlug) { + RolesFetcher_ = std::make_unique<TRolesFetcher>( + TRolesFetcherSettings{ + Settings_.TiroleHost, + Settings_.TirolePort, + Settings_.DiskCacheDir, + ProcInfo_, + Settings_.SelfTvmId, + Settings_.FetchRolesForIdmSystemSlug, + }, + Logger_); + } + + if (Settings_.DiskCacheDir) { + TString path = Settings_.DiskCacheDir; + if (path.back() != '/') { + path.push_back('/'); + } + + if (Settings_.IsServiceTicketFetchingRequired()) { + ServiceTicketsFilepath_ = path; + ServiceTicketsFilepath_.append("service_tickets"); + } + + if (Settings_.CheckServiceTickets || Settings_.CheckUserTicketsWithBbEnv) { + PublicKeysFilepath_ = path; + PublicKeysFilepath_.append("public_keys"); + } + + RetrySettingsFilepath_ = path + "retry_settings"; + } else { + LogInfo("Disk cache disabled. Please set disk cache directory in settings for best reliability"); + } + } + + void TThreadedUpdater::Init() { + ReadStateFromDisk(); + ClearErrors(); + ExpBackoff_.SetEnabled(false); + + // First of all try to get tickets: there are a lot of reasons to fail this request. + // As far as disk cache usually disabled, client will not fetch keys before fail on every ctor call. + UpdateServiceTickets(); + if (!AreServicesTicketsOk()) { + ThrowLastError(); + } + + UpdatePublicKeys(); + if (!IsServiceContextOk() || !IsUserContextOk()) { + ThrowLastError(); + } + + UpdateRoles(); + if (RolesFetcher_ && !RolesFetcher_->AreRolesOk()) { + ThrowLastError(); + } + + SetInited(true); + ExpBackoff_.SetEnabled(true); + } + + void TThreadedUpdater::UpdateServiceTickets() { + if (!Settings_.IsServiceTicketFetchingRequired()) { + return; + } + + TInstant stut = GetUpdateTimeOfServiceTickets(); + try { + if (IsTimeToUpdateServiceTickets(stut)) { + UpdateAllServiceTickets(); + NeedFetchMissingServiceTickets_ = false; + } else if (NeedFetchMissingServiceTickets_) { + TDstSetPtr dsts = GetDsts(); + if (GetCachedServiceTickets()->TicketsById.size() < dsts->size()) { + UpdateMissingServiceTickets(*dsts); + NeedFetchMissingServiceTickets_ = false; + } + } + if (AreServicesTicketsOk()) { + ClearError(EScope::ServiceTickets); + } + } catch (const std::exception& e) { + ProcessError(EType::Retriable, EScope::ServiceTickets, e.what()); + LogWarning(TStringBuilder() << "Failed to update service tickets: " << e.what()); + if (TInstant::Now() - stut > ServiceTicketsDurations_.Expiring) { + LogError("Service tickets have not been refreshed for too long period"); + } + } + } + + void TThreadedUpdater::UpdateAllServiceTickets() { + TDstSetPtr dsts = GetDsts(); + THttpResult st = GetServiceTicketsFromHttp(*dsts, RetrySettings_.DstsLimit); + + std::unique_lock lock(*ServiceTicketBatchUpdateMutex_); + + auto oldCache = GetCachedServiceTickets(); + if (oldCache) { + if (dsts->size() < GetDsts()->size()) { + for (const auto& pair : oldCache->TicketsById) { + st.TicketsWithErrors.Tickets.insert(pair); + } + } + + for (const auto& pair : oldCache->ErrorsById) { + st.TicketsWithErrors.Errors.insert(pair); + } + } + + UpdateServiceTicketsCache(std::move(st.TicketsWithErrors), TInstant::Now()); + + lock.unlock(); + + if (ServiceTicketsFilepath_) { + DiskCacheServiceTickets_ = CreateJsonArray(st.Responses); + TDiskWriter w(ServiceTicketsFilepath_, Logger_.Get()); + w.Write(PrepareTicketsForDisk(DiskCacheServiceTickets_, Settings_.SelfTvmId)); + } + } + + TServiceTicketsPtr TThreadedUpdater::UpdateMissingServiceTickets(const TDstSet& required) { + TServiceTicketsPtr cache = GetCachedServiceTickets(); + TClientSettings::TDstVector missingDsts = FindMissingDsts(cache, required); + + if (missingDsts.empty()) { + return cache; + } + + THttpResult st = GetServiceTicketsFromHttp(missingDsts, RetrySettings_.DstsLimit); + + std::unique_lock lock(*ServiceTicketBatchUpdateMutex_); + + cache = GetCachedServiceTickets(); + size_t gotTickets = st.TicketsWithErrors.Tickets.size(); + + TDstSetPtr oldDsts = GetDsts(); + std::shared_ptr<TDstSet> newDsts = std::make_shared<TDstSet>(oldDsts->begin(), oldDsts->end()); + + for (const auto& pair : cache->TicketsById) { + st.TicketsWithErrors.Tickets.insert(pair); + } + for (const auto& pair : cache->ErrorsById) { + st.TicketsWithErrors.Errors.insert(pair); + } + for (const auto& pair : st.TicketsWithErrors.Tickets) { + st.TicketsWithErrors.Errors.erase(pair.first); + newDsts->insert(pair.first); + } + + TServiceTicketsPtr c = UpdateServiceTicketsCachePartly( + std::move(st.TicketsWithErrors), + gotTickets); + + SetDsts(std::move(newDsts)); + + lock.unlock(); + + if (!c) { + LogWarning("UpdateMissingServiceTickets: new cache is NULL. BUG?"); + c = cache; + } + + if (!ServiceTicketsFilepath_) { + return c; + } + + DiskCacheServiceTickets_ = AppendToJsonArray(DiskCacheServiceTickets_, st.Responses); + + TDiskWriter w(ServiceTicketsFilepath_, Logger_.Get()); + w.Write(PrepareTicketsForDisk(DiskCacheServiceTickets_, Settings_.SelfTvmId)); + + return c; + } + + void TThreadedUpdater::UpdatePublicKeys() { + if (!Settings_.CheckServiceTickets && !Settings_.CheckUserTicketsWithBbEnv) { + return; + } + + TInstant pkut = GetUpdateTimeOfPublicKeys(); + if (!IsTimeToUpdatePublicKeys(pkut)) { + return; + } + + try { + TString publicKeys = GetPublicKeysFromHttp(); + + UpdatePublicKeysCache(publicKeys, TInstant::Now()); + if (PublicKeysFilepath_) { + TDiskWriter w(PublicKeysFilepath_, Logger_.Get()); + w.Write(publicKeys); + } + if (IsServiceContextOk() && IsUserContextOk()) { + ClearError(EScope::PublicKeys); + } + } catch (const std::exception& e) { + ProcessError(EType::Retriable, EScope::PublicKeys, e.what()); + LogWarning(TStringBuilder() << "Failed to update public keys: " << e.what()); + if (TInstant::Now() - pkut > PublicKeysDurations_.Expiring) { + LogError("Public keys have not been refreshed for too long period"); + } + } + } + + void TThreadedUpdater::UpdateRoles() { + if (!RolesFetcher_) { + return; + } + + TInstant rut = GetUpdateTimeOfRoles(); + if (!TRolesFetcher::IsTimeToUpdate(RetrySettings_, TInstant::Now() - rut)) { + return; + } + + struct TCloser { + TRolesFetcher* Fetcher; + ~TCloser() { + Fetcher->ResetConnection(); + } + } closer{RolesFetcher_.get()}; + + try { + TServiceTicketsPtr st = GetCachedServiceTickets(); + Y_ENSURE(st, "No one service ticket in memory: how it possible?"); + auto it = st->TicketsById.find(Settings_.TiroleTvmId); + Y_ENSURE(it != st->TicketsById.end(), + "Missing tvmid for tirole in cache: " << Settings_.TiroleTvmId); + + RolesFetcher_->Update( + FetchWithRetries( + [&]() { return RolesFetcher_->FetchActualRoles(it->second); }, + EScope::Roles)); + SetUpdateTimeOfRoles(TInstant::Now()); + + if (RolesFetcher_->AreRolesOk()) { + ClearError(EScope::Roles); + } + } catch (const std::exception& e) { + ProcessError(EType::Retriable, EScope::Roles, e.what()); + LogWarning(TStringBuilder() << "Failed to update roles: " << e.what()); + if (TRolesFetcher::ShouldWarn(RetrySettings_, TInstant::Now() - rut)) { + LogError("Roles have not been refreshed for too long period"); + } + } + } + + TServiceTicketsPtr TThreadedUpdater::UpdateServiceTicketsCachePartly( + TAsyncUpdaterBase::TPairTicketsErrors&& tickets, + size_t got) { + size_t count = tickets.Tickets.size(); + TServiceTicketsPtr c = MakeIntrusiveConst<TServiceTickets>(std::move(tickets.Tickets), + std::move(tickets.Errors), + DstAliases_); + SetServiceTickets(c); + LogInfo(TStringBuilder() + << "Cache was partly updated with " << got + << " service ticket(s). total: " << count); + + return c; + } + + void TThreadedUpdater::UpdateServiceTicketsCache(TPairTicketsErrors&& tickets, TInstant time) { + size_t count = tickets.Tickets.size(); + SetServiceTickets(MakeIntrusiveConst<TServiceTickets>(std::move(tickets.Tickets), + std::move(tickets.Errors), + DstAliases_)); + + SetUpdateTimeOfServiceTickets(time); + + if (count > 0) { + LogInfo(TStringBuilder() << "Cache was updated with " << count << " service ticket(s): " << time); + } + } + + void TThreadedUpdater::UpdatePublicKeysCache(const TString& publicKeys, TInstant time) { + if (publicKeys.empty()) { + return; + } + + if (Settings_.CheckServiceTickets) { + SetServiceContext(MakeIntrusiveConst<TServiceContext>( + TServiceContext::CheckingFactory(Settings_.SelfTvmId, + publicKeys))); + } + + if (Settings_.CheckUserTicketsWithBbEnv) { + SetUserContext(publicKeys); + } + + SetUpdateTimeOfPublicKeys(time); + + LogInfo(TStringBuilder() << "Cache was updated with public keys: " << time); + } + + void TThreadedUpdater::ReadStateFromDisk() { + try { + TServiceTicketsFromDisk st; + std::tie(st, StartUpCache_) = ReadServiceTicketsFromDisk(); + UpdateServiceTicketsCache(std::move(st.TicketsWithErrors), st.BornDate); + DiskCacheServiceTickets_ = st.FileBody; + } catch (const std::exception& e) { + LogWarning(TStringBuilder() << "Failed to read service tickets from disk: " << e.what()); + } + + try { + std::pair<TString, TInstant> pk = ReadPublicKeysFromDisk(); + UpdatePublicKeysCache(pk.first, pk.second); + } catch (const std::exception& e) { + LogWarning(TStringBuilder() << "Failed to read public keys from disk: " << e.what()); + } + + try { + TString rs = ReadRetrySettingsFromDisk(); + UpdateRetrySettings(rs); + } catch (const std::exception& e) { + LogWarning(TStringBuilder() << "Failed to read retry settings from disk: " << e.what()); + } + + try { + if (RolesFetcher_) { + SetUpdateTimeOfRoles(RolesFetcher_->ReadFromDisk()); + } + } catch (const std::exception& e) { + LogWarning(TStringBuilder() << "Failed to read roles from disk: " << e.what()); + } + } + + std::pair<TThreadedUpdater::TServiceTicketsFromDisk, TThreadedUpdater::TServiceTicketsFromDisk> TThreadedUpdater::ReadServiceTicketsFromDisk() const { + if (!ServiceTicketsFilepath_) { + return {}; + } + + TDiskReader r(ServiceTicketsFilepath_, Logger_.Get()); + if (!r.Read()) { + return {}; + } + + std::pair<TStringBuf, TTvmId> data = ParseTicketsFromDisk(r.Data()); + if (data.second != Settings_.SelfTvmId) { + TStringStream s; + s << "Disk cache is for another tvmId (" << data.second << "). "; + s << "Self=" << Settings_.SelfTvmId; + LogWarning(s.Str()); + return {}; + } + + TThreadedUpdater::TServiceTicketsFromDisk resDst{ + .BornDate = r.Time(), + .FileBody = TString(data.first), + }; + + TThreadedUpdater::TServiceTicketsFromDisk resAll{ + .BornDate = r.Time(), + .FileBody = TString(data.first), + }; + + TDstSetPtr dsts = GetDsts(); + ParseTicketsFromResponse(data.first, *dsts, resDst.TicketsWithErrors); + + if (IsInvalid(TServiceTickets::GetInvalidationTime(resDst.TicketsWithErrors.Tickets), TInstant::Now())) { + LogWarning("Disk cache (service tickets) is too old"); + return {}; + } + + try { + ParseTicketsFromDiskCache(data.first, resAll.TicketsWithErrors); + } catch (std::exception& e) { + LogWarning(TStringBuilder() << "Failed to parse all service tickets from disk cache: " << e.what()); + LogInfo(TStringBuilder() << "Got " << resDst.TicketsWithErrors.Tickets.size() << " service ticket(s) from disk"); + return {std::move(resDst), {}}; + } + + if (resAll.TicketsWithErrors.Tickets.empty()) { + LogInfo(TStringBuilder() << "Got " << resDst.TicketsWithErrors.Tickets.size() << " service ticket(s) from disk"); + return {std::move(resDst), std::move(resAll)}; + } + + if (IsInvalid(TServiceTickets::GetInvalidationTime(resAll.TicketsWithErrors.Tickets), TInstant::Now())) { + LogWarning("Disk cache (service tickets) is too old"); + LogInfo(TStringBuilder() << "Got " << resDst.TicketsWithErrors.Tickets.size() << " service ticket(s) from disk"); + return {std::move(resDst), {}}; + } + + LogInfo(TStringBuilder() << "Got " << resAll.TicketsWithErrors.Tickets.size() << " service ticket(s) from disk"); + return {std::move(resDst), std::move(resAll)}; + } + + std::pair<TString, TInstant> TThreadedUpdater::ReadPublicKeysFromDisk() const { + if (!PublicKeysFilepath_) { + return {}; + } + + TDiskReader r(PublicKeysFilepath_, Logger_.Get()); + if (!r.Read()) { + return {}; + } + + if (TInstant::Now() - r.Time() > PublicKeysDurations_.Invalid) { + LogWarning("Disk cache (public keys) is too old"); + return {}; + } + + return {r.Data(), r.Time()}; + } + + TString TThreadedUpdater::ReadRetrySettingsFromDisk() const { + if (!RetrySettingsFilepath_) { + return {}; + } + + TDiskReader r(RetrySettingsFilepath_, Logger_.Get()); + if (!r.Read()) { + return {}; + } + + return r.Data(); + } + + template <class Dsts> + TThreadedUpdater::THttpResult TThreadedUpdater::GetServiceTicketsFromHttp(const Dsts& dsts, const size_t dstLimit) const { + Y_ENSURE(SigningContext_, "Internal error"); + + TClientSettings::TDstVector part; + part.reserve(dstLimit); + THttpResult res; + res.TicketsWithErrors.Tickets.reserve(dsts.size()); + res.Responses.reserve(dsts.size() / dstLimit + 1); + + for (auto it = dsts.begin(); it != dsts.end();) { + part.clear(); + for (size_t count = 0; it != dsts.end() && count < dstLimit; ++count, ++it) { + part.push_back(*it); + } + + TString response = + FetchWithRetries( + [this, &part]() { + // create request here to keep 'ts' actual + return FetchServiceTicketsFromHttp(PrepareRequestForServiceTickets( + Settings_.SelfTvmId, + *SigningContext_, + part, + ProcInfo_)); + }, + EScope::ServiceTickets) + .Response; + ParseTicketsFromResponse(response, part, res.TicketsWithErrors); + LogDebug(TStringBuilder() + << "Response with service tickets for " << part.size() + << " destination(s) was successfully fetched from " << TvmUrl_); + + res.Responses.push_back(response); + } + + LogDebug(TStringBuilder() + << "Got responses with service tickets with " << res.Responses.size() << " pages for " + << dsts.size() << " destination(s)"); + for (const auto& p : res.TicketsWithErrors.Errors) { + LogError(TStringBuilder() + << "Failed to get service ticket for dst=" << p.first << ": " << p.second); + } + + return res; + } + + TString TThreadedUpdater::GetPublicKeysFromHttp() const { + TString publicKeys = + FetchWithRetries( + [this]() { return FetchPublicKeysFromHttp(); }, + EScope::PublicKeys) + .Response; + + LogDebug("Public keys were successfully fetched from " + TvmUrl_); + + return publicKeys; + } + + TServiceTickets::TMapIdStr TThreadedUpdater::GetRequestedTicketsFromStartUpCache(const TDstSet& dsts) const { + TServiceTickets::TMapIdStr res; + for (const TClientSettings::TDst& dst : dsts) { + auto it = StartUpCache_.TicketsWithErrors.Tickets.find(dst.Id); + if (it != StartUpCache_.TicketsWithErrors.Tickets.end()) { + res[dst.Id] = it->second; + } + } + return res; + } + + TInstant TThreadedUpdater::GetStartUpCacheBornDate() const { + return StartUpCache_.BornDate; + } + + NUtils::TFetchResult TThreadedUpdater::FetchServiceTicketsFromHttp(const TString& body) const { + TStringStream s; + + THttpHeaders outHeaders; + TKeepAliveHttpClient::THttpCode code = GetClient().DoPost("/2/ticket", body, &s, Headers_, &outHeaders); + + const THttpInputHeader* settings = outHeaders.FindHeader("X-Ya-Retry-Settings"); + + return {code, {}, "/2/ticket", s.Str(), settings ? settings->Value() : ""}; + } + + NUtils::TFetchResult TThreadedUpdater::FetchPublicKeysFromHttp() const { + TStringStream s; + + THttpHeaders outHeaders; + TKeepAliveHttpClient::THttpCode code = GetClient().DoGet(PublicKeysUrl_, &s, {}, &outHeaders); + + const THttpInputHeader* settings = outHeaders.FindHeader("X-Ya-Retry-Settings"); + + return {code, {}, "/2/keys", s.Str(), settings ? settings->Value() : ""}; + } + + bool TThreadedUpdater::UpdateRetrySettings(const TString& header) const { + if (header.empty()) { + // Probably it is some kind of test? + return false; + } + + try { + TString raw = NUtils::Base64url2bin(header); + Y_ENSURE(raw, "Invalid base64url in settings"); + + retry_settings::v1::Settings proto; + Y_ENSURE(proto.ParseFromString(raw), "Invalid proto"); + + // This ugly hack helps to process these settings in any case + TThreadedUpdater& this_ = *const_cast<TThreadedUpdater*>(this); + TRetrySettings& res = this_.RetrySettings_; + + TStringStream diff; + auto update = [&diff](auto& l, const auto& r, TStringBuf desc) { + if (l != r) { + diff << desc << ":" << l << "->" << r << ";"; + l = r; + } + }; + + if (proto.has_exponential_backoff_min_sec()) { + update(res.BackoffSettings.Min, + TDuration::Seconds(proto.exponential_backoff_min_sec()), + "exponential_backoff_min"); + } + if (proto.has_exponential_backoff_max_sec()) { + update(res.BackoffSettings.Max, + TDuration::Seconds(proto.exponential_backoff_max_sec()), + "exponential_backoff_max"); + } + if (proto.has_exponential_backoff_factor()) { + update(res.BackoffSettings.Factor, + proto.exponential_backoff_factor(), + "exponential_backoff_factor"); + } + if (proto.has_exponential_backoff_jitter()) { + update(res.BackoffSettings.Jitter, + proto.exponential_backoff_jitter(), + "exponential_backoff_jitter"); + } + this_.ExpBackoff_.UpdateSettings(res.BackoffSettings); + + if (proto.has_max_random_sleep_default()) { + update(res.MaxRandomSleepDefault, + TDuration::MilliSeconds(proto.max_random_sleep_default()), + "max_random_sleep_default"); + } + if (proto.has_max_random_sleep_when_ok()) { + update(res.MaxRandomSleepWhenOk, + TDuration::MilliSeconds(proto.max_random_sleep_when_ok()), + "max_random_sleep_when_ok"); + } + if (proto.has_retries_on_start()) { + Y_ENSURE(proto.retries_on_start(), "retries_on_start==0"); + update(res.RetriesOnStart, + proto.retries_on_start(), + "retries_on_start"); + } + if (proto.has_retries_in_background()) { + Y_ENSURE(proto.retries_in_background(), "retries_in_background==0"); + update(res.RetriesInBackground, + proto.retries_in_background(), + "retries_in_background"); + } + if (proto.has_worker_awaking_period_sec()) { + update(res.WorkerAwakingPeriod, + TDuration::Seconds(proto.worker_awaking_period_sec()), + "worker_awaking_period"); + this_.WorkerAwakingPeriod_ = res.WorkerAwakingPeriod; + } + if (proto.has_dsts_limit()) { + Y_ENSURE(proto.dsts_limit(), "dsts_limit==0"); + update(res.DstsLimit, + proto.dsts_limit(), + "dsts_limit"); + } + + if (proto.has_roles_update_period_sec()) { + Y_ENSURE(proto.roles_update_period_sec(), "roles_update_period==0"); + update(res.RolesUpdatePeriod, + TDuration::Seconds(proto.roles_update_period_sec()), + "roles_update_period_sec"); + } + if (proto.has_roles_warn_period_sec()) { + Y_ENSURE(proto.roles_warn_period_sec(), "roles_warn_period_sec==0"); + update(res.RolesWarnPeriod, + TDuration::Seconds(proto.roles_warn_period_sec()), + "roles_warn_period_sec"); + } + + if (diff.empty()) { + return false; + } + + LogDebug("Retry settings were updated: " + diff.Str()); + return true; + } catch (const std::exception& e) { + LogWarning(TStringBuilder() + << "Failed to update retry settings from server, header '" + << header << "': " + << e.what()); + } + + return false; + } + + template <typename Func> + NUtils::TFetchResult TThreadedUpdater::FetchWithRetries(Func func, EScope scope) const { + const ui32 tries = IsInited() ? RetrySettings_.RetriesInBackground + : RetrySettings_.RetriesOnStart; + + for (size_t idx = 1;; ++idx) { + RandomSleep(); + + try { + NUtils::TFetchResult result = func(); + + if (UpdateRetrySettings(result.RetrySettings) && RetrySettingsFilepath_) { + TDiskWriter w(RetrySettingsFilepath_, Logger_.Get()); + w.Write(result.RetrySettings); + } + + if (400 <= result.Code && result.Code <= 499) { + throw TNonRetriableException() << ProcessHttpError(scope, result.Path, result.Code, result.Response); + } + if (result.Code < 200 || result.Code >= 399) { + throw yexception() << ProcessHttpError(scope, result.Path, result.Code, result.Response); + } + + ExpBackoff_.Decrease(); + return result; + } catch (const TNonRetriableException& e) { + LogWarning(TStringBuilder() << "Failed to get " << scope << ": " << e.what()); + ExpBackoff_.Increase(); + throw; + } catch (const std::exception& e) { + LogWarning(TStringBuilder() << "Failed to get " << scope << ": " << e.what()); + ExpBackoff_.Increase(); + if (idx >= tries) { + throw; + } + } + } + + throw yexception() << "unreachable"; + } + + void TThreadedUpdater::RandomSleep() const { + const TDuration maxSleep = TClientStatus::ECode::Ok == GetState() + ? RetrySettings_.MaxRandomSleepWhenOk + : RetrySettings_.MaxRandomSleepDefault; + + if (maxSleep) { + ui32 toSleep = Random_.GenRand() % maxSleep.MilliSeconds(); + ExpBackoff_.Sleep(TDuration::MilliSeconds(toSleep)); + } + } + + TString TThreadedUpdater::PrepareRequestForServiceTickets(TTvmId src, + const TServiceContext& ctx, + const TClientSettings::TDstVector& dsts, + const NUtils::TProcInfo& procInfo, + time_t now) { + TStringStream s; + + const TString ts = IntToString<10>(now); + TStringStream dst; + dst.Reserve(10 * dsts.size()); + for (const TClientSettings::TDst& d : dsts) { + if (dst.Str()) { + dst << ','; + } + dst << d.Id; + } + + s << "grant_type=client_credentials"; + s << "&src=" << src; + s << "&dst=" << dst.Str(); + s << "&ts=" << ts; + s << "&sign=" << ctx.SignCgiParamsForTvm(ts, dst.Str()); + s << "&get_retry_settings=yes"; + + s << "&"; + procInfo.AddToRequest(s); + + return s.Str(); + } + + template <class Dsts> + void TThreadedUpdater::ParseTicketsFromResponse(TStringBuf resp, + const Dsts& dsts, + TPairTicketsErrors& out) const { + NJson::TJsonValue doc; + Y_ENSURE(NJson::ReadJsonTree(resp, &doc), "Invalid json from tvm-api: " << resp); + const NJson::TJsonValue* currentResp = doc.IsMap() ? &doc : nullptr; + auto find = [¤tResp, &doc](TTvmId id, NJson::TJsonValue& obj) -> bool { + const TString idStr = IntToString<10>(id); + if (currentResp && currentResp->GetValue(idStr, &obj)) { + return true; + } + + for (const NJson::TJsonValue& val : doc.GetArray()) { + currentResp = &val; + if (currentResp->GetValue(idStr, &obj)) { + return true; + } + } + + return false; + }; + + for (const TClientSettings::TDst& d : dsts) { + NJson::TJsonValue obj; + NJson::TJsonValue val; + + if (!find(d.Id, obj) || !obj.GetValue("ticket", &val)) { + TString err; + if (obj.GetValue("error", &val)) { + err = val.GetString(); + } else { + err = "Missing tvm_id in response, should never happend: " + IntToString<10>(d.Id); + } + + TStringStream s; + s << "Failed to get ServiceTicket for " << d.Id << ": " << err; + ProcessError(EType::NonRetriable, EScope::ServiceTickets, s.Str()); + + out.Errors.insert({d.Id, std::move(err)}); + continue; + } + + out.Tickets.insert({d.Id, val.GetString()}); + } + } + + void TThreadedUpdater::ParseTicketsFromDiskCache(TStringBuf cache, + TPairTicketsErrors& out) const { + NJson::TJsonValue doc; + Y_ENSURE(NJson::ReadJsonTree(cache, &doc), "Invalid json from disk: " << cache); + + for (const NJson::TJsonValue& cacheItem : doc.GetArray()) { + for (const auto& [idStr, resp] : cacheItem.GetMap()) { + NJson::TJsonValue val; + TTvmId id; + if (!TryIntFromString<10, TTvmId, TString>(idStr, id)) { + LogWarning(TStringBuilder() << "tvm_id in cache is not integer: " << idStr); + continue; + } + + if (resp.GetValue("ticket", &val)) { + out.Tickets[id] = val.GetString(); + } else if (resp.GetValue("error", &val)) { + out.Errors[id] = val.GetString(); + } else { + out.Errors[id] = "tvm_id found, but response has no error or ticket, should never happend: " + idStr; + } + } + } + } + + static const char DELIMETER = '\t'; + TString TThreadedUpdater::PrepareTicketsForDisk(TStringBuf tvmResponse, TTvmId selfId) { + TStringStream s; + s << tvmResponse << DELIMETER << selfId; + return s.Str(); + } + + std::pair<TStringBuf, TTvmId> TThreadedUpdater::ParseTicketsFromDisk(TStringBuf data) { + TStringBuf tvmId = data.RNextTok(DELIMETER); + return {data, IntFromString<TTvmId, 10>(tvmId)}; + } + + TDstSetPtr TThreadedUpdater::GetDsts() const { + return Destinations_.Get(); + } + + void TThreadedUpdater::SetDsts(TDstSetPtr dsts) { + Destinations_.Set(std::move(dsts)); + } + + bool TThreadedUpdater::IsTimeToUpdateServiceTickets(TInstant lastUpdate) const { + return TInstant::Now() - lastUpdate > ServiceTicketsDurations_.RefreshPeriod; + } + + bool TThreadedUpdater::IsTimeToUpdatePublicKeys(TInstant lastUpdate) const { + return TInstant::Now() - lastUpdate > PublicKeysDurations_.RefreshPeriod; + } + + bool TThreadedUpdater::AreServicesTicketsOk() const { + if (!Settings_.IsServiceTicketFetchingRequired()) { + return true; + } + auto c = GetCachedServiceTickets(); + return c && (!Settings_.IsIncompleteTicketsSetAnError || c->TicketsById.size() == GetDsts()->size()); + } + + bool TThreadedUpdater::IsServiceContextOk() const { + if (!Settings_.CheckServiceTickets) { + return true; + } + + return bool(GetCachedServiceContext()); + } + + bool TThreadedUpdater::IsUserContextOk() const { + if (!Settings_.CheckUserTicketsWithBbEnv) { + return true; + } + return bool(GetCachedUserContext()); + } + + void TThreadedUpdater::Worker() { + if (IsInited()) { + UpdateServiceTickets(); + UpdatePublicKeys(); + UpdateRoles(); + } else { + try { + Init(); + } catch (const TRetriableException& e) { + // Still not initialized + } catch (const std::exception& e) { + // Can't retry, so we mark client as initialized and now GetStatus() will return TClientStatus::Error + SetInited(true); + } + } + } + + TServiceTickets::TMapAliasId TThreadedUpdater::MakeAliasMap(const TClientSettings& settings) { + TServiceTickets::TMapAliasId res; + + for (const auto& pair : settings.FetchServiceTicketsForDstsWithAliases) { + res.insert({pair.first, pair.second.Id}); + } + + return res; + } + + TClientSettings::TDstVector TThreadedUpdater::FindMissingDsts(TServiceTicketsPtr available, const TDstSet& required) { + Y_ENSURE(available); + TDstSet set; + // available->TicketsById is not sorted + for (const auto& pair : available->TicketsById) { + set.insert(pair.first); + } + return FindMissingDsts(set, required); + } + + TClientSettings::TDstVector TThreadedUpdater::FindMissingDsts(const TDstSet& available, const TDstSet& required) { + TClientSettings::TDstVector res; + std::set_difference(required.begin(), required.end(), + available.begin(), available.end(), + std::inserter(res, res.begin())); + return res; + } + + TString TThreadedUpdater::CreateJsonArray(const TSmallVec<TString>& responses) { + if (responses.empty()) { + return "[]"; + } + + size_t size = 0; + for (const TString& r : responses) { + size += r.size() + 1; + } + + TString res; + res.reserve(size + 2); + + res.push_back('['); + for (const TString& r : responses) { + res.append(r).push_back(','); + } + res.back() = ']'; + + return res; + } + + TString TThreadedUpdater::AppendToJsonArray(const TString& json, const TSmallVec<TString>& responses) { + Y_ENSURE(json, "previous body required"); + + size_t size = 0; + for (const TString& r : responses) { + size += r.size() + 1; + } + + TString res; + res.reserve(size + 2 + json.size()); + + res.push_back('['); + if (json.StartsWith('[')) { + Y_ENSURE(json.EndsWith(']'), "array is broken:" << json); + res.append(TStringBuf(json).Chop(1).Skip(1)); + } else { + res.append(json); + } + + res.push_back(','); + for (const TString& r : responses) { + res.append(r).push_back(','); + } + res.back() = ']'; + + return res; + } +} diff --git a/library/cpp/tvmauth/client/misc/api/threaded_updater.h b/library/cpp/tvmauth/client/misc/api/threaded_updater.h new file mode 100644 index 0000000000..8fd68ee678 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/api/threaded_updater.h @@ -0,0 +1,155 @@ +#pragma once + +#include "retry_settings.h" +#include "roles_fetcher.h" +#include "settings.h" + +#include <library/cpp/tvmauth/client/misc/async_updater.h> +#include <library/cpp/tvmauth/client/misc/threaded_updater.h> + +#include <util/generic/set.h> +#include <util/random/fast.h> + +#include <mutex> + +namespace NTvmAuth::NTvmApi { + using TDstSet = TSet<TClientSettings::TDst>; + using TDstSetPtr = std::shared_ptr<const TDstSet>; + + class TThreadedUpdater: public TThreadedUpdaterBase { + public: + /*! + * Starts thread for updating of in-memory cache in background + * Reads cache from disk if specified + * @param settings + * @param logger is usefull for monitoring and debuging + */ + static TAsyncUpdaterPtr Create(const TClientSettings& settings, TLoggerPtr logger); + ~TThreadedUpdater(); + + TClientStatus GetStatus() const override; + TString GetServiceTicketFor(const TClientSettings::TAlias& dst) const override; + TString GetServiceTicketFor(const TTvmId dst) const override; + TCheckedServiceTicket CheckServiceTicket(TStringBuf ticket, const TServiceContext::TCheckFlags& flags = TServiceContext::TCheckFlags{}) const override; + TCheckedUserTicket CheckUserTicket(TStringBuf ticket, TMaybe<EBlackboxEnv> overrideEnv = {}) const override; + NRoles::TRolesPtr GetRoles() const override; + + protected: // for tests + TClientStatus::ECode GetState() const; + + TThreadedUpdater(const TClientSettings& settings, TLoggerPtr logger); + void Init(); + + void UpdateServiceTickets(); + void UpdateAllServiceTickets(); + TServiceTicketsPtr UpdateMissingServiceTickets(const TDstSet& required); + void UpdatePublicKeys(); + void UpdateRoles(); + + TServiceTicketsPtr UpdateServiceTicketsCachePartly(TPairTicketsErrors&& tickets, size_t got); + void UpdateServiceTicketsCache(TPairTicketsErrors&& tickets, TInstant time); + void UpdatePublicKeysCache(const TString& publicKeys, TInstant time); + + void ReadStateFromDisk(); + + struct TServiceTicketsFromDisk { + TPairTicketsErrors TicketsWithErrors; + TInstant BornDate; + TString FileBody; + }; + + std::pair<TServiceTicketsFromDisk, TServiceTicketsFromDisk> ReadServiceTicketsFromDisk() const; + + std::pair<TString, TInstant> ReadPublicKeysFromDisk() const; + TString ReadRetrySettingsFromDisk() const; + + struct THttpResult { + TPairTicketsErrors TicketsWithErrors; + TSmallVec<TString> Responses; + }; + + template <class Dsts> + THttpResult GetServiceTicketsFromHttp(const Dsts& dsts, const size_t dstLimit) const; + TString GetPublicKeysFromHttp() const; + TServiceTickets::TMapIdStr GetRequestedTicketsFromStartUpCache(const TDstSet& dsts) const; + + virtual NUtils::TFetchResult FetchServiceTicketsFromHttp(const TString& body) const; + virtual NUtils::TFetchResult FetchPublicKeysFromHttp() const; + + bool UpdateRetrySettings(const TString& header) const; + + template <typename Func> + NUtils::TFetchResult FetchWithRetries(Func func, EScope scope) const; + void RandomSleep() const; + + static TString PrepareRequestForServiceTickets(TTvmId src, + const TServiceContext& ctx, + const TClientSettings::TDstVector& dsts, + const NUtils::TProcInfo& procInfo, + time_t now = time(nullptr)); + template <class Dsts> + void ParseTicketsFromResponse(TStringBuf resp, + const Dsts& dsts, + TPairTicketsErrors& out) const; + + void ParseTicketsFromDiskCache(TStringBuf cache, + TPairTicketsErrors& out) const; + + static TString PrepareTicketsForDisk(TStringBuf tvmResponse, TTvmId selfId); + static std::pair<TStringBuf, TTvmId> ParseTicketsFromDisk(TStringBuf data); + + TDstSetPtr GetDsts() const; + void SetDsts(TDstSetPtr dsts); + + TInstant GetStartUpCacheBornDate() const; + + bool IsTimeToUpdateServiceTickets(TInstant lastUpdate) const; + bool IsTimeToUpdatePublicKeys(TInstant lastUpdate) const; + + bool AreServicesTicketsOk() const; + bool IsServiceContextOk() const; + bool IsUserContextOk() const; + + void Worker() override; + + static TServiceTickets::TMapAliasId MakeAliasMap(const TClientSettings& settings); + static TClientSettings::TDstVector FindMissingDsts(TServiceTicketsPtr available, const TDstSet& required); + static TClientSettings::TDstVector FindMissingDsts(const TDstSet& available, const TDstSet& required); + + static TString CreateJsonArray(const TSmallVec<TString>& responses); + static TString AppendToJsonArray(const TString& json, const TSmallVec<TString>& responses); + + private: + TRetrySettings RetrySettings_; + + protected: + mutable TExponentialBackoff ExpBackoff_; + std::unique_ptr<std::mutex> ServiceTicketBatchUpdateMutex_; + + private: + const TClientSettings Settings_; + + const NUtils::TProcInfo ProcInfo_; + + const TString PublicKeysUrl_; + + const TServiceTickets::TMapAliasId DstAliases_; + + const TKeepAliveHttpClient::THeaders Headers_; + TMaybe<TServiceContext> SigningContext_; + + NUtils::TProtectedValue<TDstSetPtr> Destinations_; + + TString DiskCacheServiceTickets_; + TServiceTicketsFromDisk StartUpCache_; + bool NeedFetchMissingServiceTickets_ = true; + + TString PublicKeysFilepath_; + TString ServiceTicketsFilepath_; + TString RetrySettingsFilepath_; + + std::unique_ptr<TRolesFetcher> RolesFetcher_; + + mutable TReallyFastRng32 Random_; + }; +} diff --git a/library/cpp/tvmauth/client/misc/async_updater.cpp b/library/cpp/tvmauth/client/misc/async_updater.cpp new file mode 100644 index 0000000000..670033c684 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/async_updater.cpp @@ -0,0 +1,180 @@ +#include "async_updater.h" + +#include "utils.h" + +#include <library/cpp/tvmauth/client/exception.h> + +#include <util/string/builder.h> +#include <util/system/spin_wait.h> + +namespace NTvmAuth { + TAsyncUpdaterBase::TAsyncUpdaterBase() { + ServiceTicketsDurations_.RefreshPeriod = TDuration::Hours(1); + ServiceTicketsDurations_.Expiring = TDuration::Hours(2); + ServiceTicketsDurations_.Invalid = TDuration::Hours(11); + + PublicKeysDurations_.RefreshPeriod = TDuration::Days(1); + PublicKeysDurations_.Expiring = TDuration::Days(2); + PublicKeysDurations_.Invalid = TDuration::Days(6); + } + + NRoles::TRolesPtr TAsyncUpdaterBase::GetRoles() const { + ythrow TIllegalUsage() << "not implemented"; + } + + TString TAsyncUpdaterBase::GetServiceTicketFor(const TClientSettings::TAlias& dst) const { + Y_UNUSED(dst); + ythrow TIllegalUsage() << "not implemented"; + } + + TString TAsyncUpdaterBase::GetServiceTicketFor(const TTvmId dst) const { + Y_UNUSED(dst); + ythrow TIllegalUsage() << "not implemented"; + } + + TCheckedServiceTicket TAsyncUpdaterBase::CheckServiceTicket(TStringBuf ticket, const TServiceContext::TCheckFlags& flags) const { + Y_UNUSED(ticket, flags); + ythrow TIllegalUsage() << "not implemented"; + } + + TCheckedUserTicket TAsyncUpdaterBase::CheckUserTicket(TStringBuf ticket, TMaybe<EBlackboxEnv> overrideEnv) const { + Y_UNUSED(ticket, overrideEnv); + ythrow TIllegalUsage() << "not implemented"; + } + + TInstant TAsyncUpdaterBase::GetUpdateTimeOfPublicKeys() const { + return PublicKeysTime_.Get(); + } + + TInstant TAsyncUpdaterBase::GetUpdateTimeOfServiceTickets() const { + return ServiceTicketsTime_.Get(); + } + + TInstant TAsyncUpdaterBase::GetUpdateTimeOfRoles() const { + return RolesTime_.Get(); + } + + TInstant TAsyncUpdaterBase::GetInvalidationTimeOfPublicKeys() const { + TInstant ins = GetUpdateTimeOfPublicKeys(); + return ins == TInstant() ? TInstant() : ins + PublicKeysDurations_.Invalid; + } + + TInstant TAsyncUpdaterBase::GetInvalidationTimeOfServiceTickets() const { + TServiceTicketsPtr c = GetCachedServiceTickets(); + return c ? c->InvalidationTime : TInstant(); + } + + bool TAsyncUpdaterBase::ArePublicKeysInvalid(TInstant now) const { + return IsInvalid(GetInvalidationTimeOfPublicKeys(), now); + } + + bool TAsyncUpdaterBase::AreServiceTicketsInvalid(TInstant now) const { + TServiceTicketsPtr c = GetCachedServiceTickets(); + // Empty set of tickets is allways valid. + return c && !c->TicketsById.empty() && IsInvalid(GetInvalidationTimeOfServiceTickets(), now); + } + + bool TAsyncUpdaterBase::IsInvalid(TInstant invTime, TInstant now) { + return invTime - + TDuration::Minutes(1) // lag for closing from balancer + < now; + } + + void TAsyncUpdaterBase::SetBbEnv(EBlackboxEnv original, TMaybe<EBlackboxEnv> overrided) { + if (overrided) { + Y_ENSURE_EX(NUtils::CheckBbEnvOverriding(original, *overrided), + TBrokenTvmClientSettings() << "Overriding of BlackboxEnv is illegal: " + << original << " -> " << *overrided); + } + + Envs_.store({original, overrided}, std::memory_order_relaxed); + } + + TServiceTicketsPtr TAsyncUpdaterBase::GetCachedServiceTickets() const { + return ServiceTickets_.Get(); + } + + TServiceContextPtr TAsyncUpdaterBase::GetCachedServiceContext() const { + return ServiceContext_.Get(); + } + + TUserContextPtr TAsyncUpdaterBase::GetCachedUserContext(TMaybe<EBlackboxEnv> overridenEnv) const { + TAllUserContextsPtr ctx = AllUserContexts_.Get(); + if (!ctx) { + return nullptr; + } + + const TEnvs envs = Envs_.load(std::memory_order_relaxed); + if (!envs.Original) { + return nullptr; + } + + EBlackboxEnv env = *envs.Original; + + if (overridenEnv) { + Y_ENSURE_EX(NUtils::CheckBbEnvOverriding(*envs.Original, *overridenEnv), + TBrokenTvmClientSettings() << "Overriding of BlackboxEnv is illegal: " + << *envs.Original << " -> " << *overridenEnv); + env = *overridenEnv; + } else if (envs.Overrided) { + env = *envs.Overrided; + } + + return ctx->Get(env); + } + + void TAsyncUpdaterBase::SetServiceTickets(TServiceTicketsPtr c) { + ServiceTickets_.Set(std::move(c)); + } + + void TAsyncUpdaterBase::SetServiceContext(TServiceContextPtr c) { + ServiceContext_.Set(std::move(c)); + } + + void TAsyncUpdaterBase::SetUserContext(TStringBuf publicKeys) { + AllUserContexts_.Set(MakeIntrusiveConst<TAllUserContexts>(publicKeys)); + } + + void TAsyncUpdaterBase::SetUpdateTimeOfPublicKeys(TInstant ins) { + PublicKeysTime_.Set(ins); + } + + void TAsyncUpdaterBase::SetUpdateTimeOfServiceTickets(TInstant ins) { + ServiceTicketsTime_.Set(ins); + } + + void TAsyncUpdaterBase::SetUpdateTimeOfRoles(TInstant ins) { + RolesTime_.Set(ins); + } + + void TAsyncUpdaterBase::SetInited(bool value) { + Inited_.store(value, std::memory_order_relaxed); + } + + bool TAsyncUpdaterBase::IsInited() const { + return Inited_.load(std::memory_order_relaxed); + } + + bool TAsyncUpdaterBase::IsServiceTicketMapOk(TServiceTicketsPtr c, size_t expectedTicketCount, bool strict) { + return c && + (strict + ? c->TicketsById.size() == expectedTicketCount + : !c->TicketsById.empty()); + } + + TAllUserContexts::TAllUserContexts(TStringBuf publicKeys) { + auto add = [&, this](EBlackboxEnv env) { + Ctx_[(size_t)env] = MakeIntrusiveConst<TUserContext>(env, publicKeys); + }; + + add(EBlackboxEnv::Prod); + add(EBlackboxEnv::Test); + add(EBlackboxEnv::ProdYateam); + add(EBlackboxEnv::TestYateam); + add(EBlackboxEnv::Stress); + } + + TUserContextPtr TAllUserContexts::Get(EBlackboxEnv env) const { + return Ctx_[(size_t)env]; + } +} diff --git a/library/cpp/tvmauth/client/misc/async_updater.h b/library/cpp/tvmauth/client/misc/async_updater.h new file mode 100644 index 0000000000..0c0e81ccac --- /dev/null +++ b/library/cpp/tvmauth/client/misc/async_updater.h @@ -0,0 +1,123 @@ +#pragma once + +#include "last_error.h" +#include "service_tickets.h" +#include "settings.h" +#include "roles/roles.h" + +#include <library/cpp/tvmauth/client/client_status.h> +#include <library/cpp/tvmauth/client/logger.h> + +#include <library/cpp/tvmauth/deprecated/service_context.h> +#include <library/cpp/tvmauth/deprecated/user_context.h> +#include <library/cpp/tvmauth/src/utils.h> + +#include <util/datetime/base.h> +#include <util/generic/hash.h> +#include <util/generic/maybe.h> +#include <util/generic/noncopyable.h> +#include <util/generic/ptr.h> + +#include <array> +#include <atomic> + +namespace NTvmAuth { + + class TAllUserContexts: public TAtomicRefCount<TAllUserContexts> { + public: + TAllUserContexts(TStringBuf publicKeys); + + TUserContextPtr Get(EBlackboxEnv env) const; + + private: + std::array<TUserContextPtr, 5> Ctx_; + }; + using TAllUserContextsPtr = TIntrusiveConstPtr<TAllUserContexts>; + + class TAsyncUpdaterBase: public TAtomicRefCount<TAsyncUpdaterBase>, protected TLastError, TNonCopyable { + public: + TAsyncUpdaterBase(); + virtual ~TAsyncUpdaterBase() = default; + + virtual TClientStatus GetStatus() const = 0; + virtual TString GetServiceTicketFor(const TClientSettings::TAlias& dst) const; + virtual TString GetServiceTicketFor(const TTvmId dst) const; + virtual TCheckedServiceTicket CheckServiceTicket(TStringBuf ticket, const TServiceContext::TCheckFlags& flags = TServiceContext::TCheckFlags{}) const; + virtual TCheckedUserTicket CheckUserTicket(TStringBuf ticket, TMaybe<EBlackboxEnv> overrideEnv = {}) const; + virtual NRoles::TRolesPtr GetRoles() const; + + TServiceTicketsPtr GetCachedServiceTickets() const; + TServiceContextPtr GetCachedServiceContext() const; + TUserContextPtr GetCachedUserContext(TMaybe<EBlackboxEnv> overridenEnv = {}) const; + + TInstant GetUpdateTimeOfPublicKeys() const; + TInstant GetUpdateTimeOfServiceTickets() const; + TInstant GetUpdateTimeOfRoles() const; + TInstant GetInvalidationTimeOfPublicKeys() const; + TInstant GetInvalidationTimeOfServiceTickets() const; + + bool ArePublicKeysInvalid(TInstant now) const; + bool AreServiceTicketsInvalid(TInstant now) const; + static bool IsInvalid(TInstant invTime, TInstant now); + + protected: + void SetBbEnv(EBlackboxEnv original, TMaybe<EBlackboxEnv> overrided = {}); + + void SetServiceTickets(TServiceTicketsPtr c); + void SetServiceContext(TServiceContextPtr c); + void SetUserContext(TStringBuf publicKeys); + void SetUpdateTimeOfPublicKeys(TInstant ins); + void SetUpdateTimeOfServiceTickets(TInstant ins); + void SetUpdateTimeOfRoles(TInstant ins); + + void SetInited(bool value); + bool IsInited() const; + + static bool IsServiceTicketMapOk(TServiceTicketsPtr c, size_t expectedTicketCount, bool strict); + + protected: + struct TPairTicketsErrors { + TServiceTickets::TMapIdStr Tickets; + TServiceTickets::TMapIdStr Errors; + + bool operator==(const TPairTicketsErrors& o) const { + return Tickets == o.Tickets && Errors == o.Errors; + } + }; + + struct TStateDurations { + TDuration RefreshPeriod; + TDuration Expiring; + TDuration Invalid; + }; + + TStateDurations ServiceTicketsDurations_; + TStateDurations PublicKeysDurations_; + + protected: + virtual void StartTvmClientStopping() const { + } + virtual bool IsTvmClientStopped() const { + return true; + } + friend class NTvmAuth::NInternal::TClientCaningKnife; + + private: + struct TEnvs { + TMaybe<EBlackboxEnv> Original; + TMaybe<EBlackboxEnv> Overrided; + }; + static_assert(sizeof(TEnvs) <= 8, "Small struct is easy to store as atomic"); + std::atomic<TEnvs> Envs_ = {{}}; + + NUtils::TProtectedValue<TServiceTicketsPtr> ServiceTickets_; + NUtils::TProtectedValue<TServiceContextPtr> ServiceContext_; + NUtils::TProtectedValue<TAllUserContextsPtr> AllUserContexts_; + NUtils::TProtectedValue<TInstant> PublicKeysTime_; + NUtils::TProtectedValue<TInstant> ServiceTicketsTime_; + NUtils::TProtectedValue<TInstant> RolesTime_; + + std::atomic_bool Inited_{false}; + }; + using TAsyncUpdaterPtr = TIntrusiveConstPtr<TAsyncUpdaterBase>; +} diff --git a/library/cpp/tvmauth/client/misc/checker.h b/library/cpp/tvmauth/client/misc/checker.h new file mode 100644 index 0000000000..16f1a95200 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/checker.h @@ -0,0 +1,38 @@ +#pragma once + +#include <library/cpp/tvmauth/client/exception.h> + +#include <library/cpp/tvmauth/checked_service_ticket.h> +#include <library/cpp/tvmauth/checked_user_ticket.h> +#include <library/cpp/tvmauth/deprecated/service_context.h> +#include <library/cpp/tvmauth/deprecated/user_context.h> + +namespace NTvmAuth { + class TServiceTicketChecker { + public: + /*! + * Checking must be enabled in TClientSettings + * Can throw exception if cache is out of date or wrong config + * @param ticket + */ + static TCheckedServiceTicket Check( + TStringBuf ticket, + TServiceContextPtr c, + const TServiceContext::TCheckFlags& flags = {}) { + Y_ENSURE_EX(c, TBrokenTvmClientSettings() << "Need to use TClientSettings::EnableServiceTicketChecking()"); + return c->Check(ticket, flags); + } + }; + + class TUserTicketChecker { + public: + /*! + * Blackbox enviroment must be cofingured in TClientSettings + * Can throw exception if cache is out of date or wrong config + */ + static TCheckedUserTicket Check(TStringBuf ticket, TUserContextPtr c) { + Y_ENSURE_EX(c, TBrokenTvmClientSettings() << "Need to use TClientSettings::EnableUserTicketChecking()"); + return c->Check(ticket); + } + }; +} diff --git a/library/cpp/tvmauth/client/misc/default_uid_checker.h b/library/cpp/tvmauth/client/misc/default_uid_checker.h new file mode 100644 index 0000000000..b723d6e918 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/default_uid_checker.h @@ -0,0 +1,31 @@ +#pragma once + +#include "roles/roles.h" + +#include <library/cpp/tvmauth/client/exception.h> + +#include <library/cpp/tvmauth/checked_user_ticket.h> +#include <library/cpp/tvmauth/src/user_impl.h> +#include <library/cpp/tvmauth/src/utils.h> + +namespace NTvmAuth { + class TDefaultUidChecker { + public: + /*! + * Checking must be enabled in TClientSettings + * Can throw exception if cache is out of date or wrong config + * @param ticket + */ + static TCheckedUserTicket Check(TCheckedUserTicket ticket, NRoles::TRolesPtr r) { + Y_ENSURE_EX(r, TBrokenTvmClientSettings() << "Need to use TClientSettings::EnableRolesFetching()"); + NRoles::TConsumerRolesPtr roles = r->GetRolesForUser(ticket); + if (roles) { + return ticket; + } + + TUserTicketImplPtr impl = THolder(NInternal::TCanningKnife::GetU(ticket)); + impl->SetStatus(ETicketStatus::NoRoles); + return TCheckedUserTicket(std::move(impl)); + } + }; +} diff --git a/library/cpp/tvmauth/client/misc/disk_cache.cpp b/library/cpp/tvmauth/client/misc/disk_cache.cpp new file mode 100644 index 0000000000..8f3ab7770f --- /dev/null +++ b/library/cpp/tvmauth/client/misc/disk_cache.cpp @@ -0,0 +1,162 @@ +#include "disk_cache.h" + +#include <library/cpp/tvmauth/client/logger.h> + +#include <openssl/evp.h> +#include <openssl/hmac.h> + +#include <util/stream/file.h> +#include <util/stream/str.h> +#include <util/system/fs.h> +#include <util/system/sysstat.h> + +#include <exception> + +namespace NTvmAuth { + static const size_t HASH_SIZE = 32; + static const size_t TIMESTAMP_SIZE = sizeof(time_t); + + TDiskReader::TDiskReader(const TString& filename, ILogger* logger) + : Filename_(filename) + , Logger_(logger) + { + } + + bool TDiskReader::Read() { + TStringStream s; + + try { + if (!NFs::Exists(Filename_)) { + if (Logger_) { + s << "File '" << Filename_ << "' does not exist"; + Logger_->Debug(s.Str()); + } + return false; + } + + TFile file(Filename_, OpenExisting | RdOnly | Seq); + file.Flock(LOCK_SH | LOCK_NB); + + TFileInput input(file); + return ParseData(input.ReadAll()); + } catch (const std::exception& e) { + if (Logger_) { + s << "Failed to read '" << Filename_ << "': " << e.what(); + Logger_->Error(s.Str()); + } + } + + return false; + } + + bool TDiskReader::ParseData(TStringBuf buf) { + TStringStream s; + + if (buf.size() <= HASH_SIZE + TIMESTAMP_SIZE) { + if (Logger_) { + s << "File '" << Filename_ << "' is too small"; + Logger_->Warning(s.Str()); + } + return false; + } + + TStringBuf hash = buf.SubStr(0, HASH_SIZE); + if (hash != GetHash(buf.Skip(HASH_SIZE))) { + if (Logger_) { + s << "Content of '" << Filename_ << "' was incorrectly changed"; + Logger_->Warning(s.Str()); + } + return false; + } + + Time_ = TInstant::Seconds(GetTimestamp(buf.substr(0, TIMESTAMP_SIZE))); + Data_ = buf.Skip(TIMESTAMP_SIZE); + + if (Logger_) { + s << "File '" << Filename_ << "' was successfully read"; + Logger_->Info(s.Str()); + } + return true; + } + + TString TDiskReader::GetHash(TStringBuf data) { + TString value(EVP_MAX_MD_SIZE, 0); + unsigned macLen = 0; + if (!::HMAC(EVP_sha256(), + "", + 0, + (unsigned char*)data.data(), + data.size(), + (unsigned char*)value.data(), + &macLen)) { + return {}; + } + + if (macLen != EVP_MAX_MD_SIZE) { + value.resize(macLen); + } + + return value; + } + + time_t TDiskReader::GetTimestamp(TStringBuf data) { + time_t time = 0; + for (int idx = TIMESTAMP_SIZE - 1; idx >= 0; --idx) { + time <<= 8; + time |= static_cast<unsigned char>(data.at(idx)); + } + return time; + } + + TDiskWriter::TDiskWriter(const TString& filename, ILogger* logger) + : Filename_(filename) + , Logger_(logger) + { + } + + bool TDiskWriter::Write(TStringBuf data, TInstant now) { + TStringStream s; + + try { + { + if (NFs::Exists(Filename_)) { + Chmod(Filename_.c_str(), + S_IRUSR | S_IWUSR); // 600 + } + + TFile file(Filename_, CreateAlways | WrOnly | Seq | AWUser | ARUser); + file.Flock(LOCK_EX | LOCK_NB); + + TFileOutput output(file); + output << PrepareData(now, data); + } + + if (Logger_) { + s << "File '" << Filename_ << "' was successfully written"; + Logger_->Info(s.Str()); + } + return true; + } catch (const std::exception& e) { + if (Logger_) { + s << "Failed to write '" << Filename_ << "': " << e.what(); + Logger_->Error(s.Str()); + } + } + + return false; + } + + TString TDiskWriter::PrepareData(TInstant time, TStringBuf data) { + TString toHash = WriteTimestamp(time.TimeT()) + data; + return TDiskReader::GetHash(toHash) + toHash; + } + + TString TDiskWriter::WriteTimestamp(time_t time) { + TString res(TIMESTAMP_SIZE, 0); + for (size_t idx = 0; idx < TIMESTAMP_SIZE; ++idx) { + res[idx] = time & 0xFF; + time >>= 8; + } + return res; + } +} diff --git a/library/cpp/tvmauth/client/misc/disk_cache.h b/library/cpp/tvmauth/client/misc/disk_cache.h new file mode 100644 index 0000000000..9e77556f86 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/disk_cache.h @@ -0,0 +1,50 @@ +#pragma once + +#include <util/datetime/base.h> +#include <util/generic/string.h> + +namespace NTvmAuth { + class ILogger; + + class TDiskReader { + public: + TDiskReader(const TString& filename, ILogger* logger = nullptr); + + bool Read(); + + const TString& Data() const { + return Data_; + } + + TInstant Time() const { + return Time_; + } + + public: // for tests + bool ParseData(TStringBuf buf); + + static TString GetHash(TStringBuf data); + static time_t GetTimestamp(TStringBuf data); + + private: + TString Filename_; + ILogger* Logger_; + TInstant Time_; + TString Data_; + }; + + class TDiskWriter { + public: + TDiskWriter(const TString& filename, ILogger* logger = nullptr); + + bool Write(TStringBuf data, TInstant now = TInstant::Now()); + + public: // for tests + static TString PrepareData(TInstant time, TStringBuf data); + static TString WriteTimestamp(time_t time); + + private: + TString Filename_; + ILogger* Logger_; + }; +} diff --git a/library/cpp/tvmauth/client/misc/exponential_backoff.h b/library/cpp/tvmauth/client/misc/exponential_backoff.h new file mode 100644 index 0000000000..89a7a3c8ad --- /dev/null +++ b/library/cpp/tvmauth/client/misc/exponential_backoff.h @@ -0,0 +1,94 @@ +#pragma once + +#include <util/datetime/base.h> +#include <util/random/normal.h> +#include <util/system/event.h> + +#include <atomic> + +namespace NTvmAuth { + // https://habr.com/ru/post/227225/ + class TExponentialBackoff { + public: + struct TSettings { + TDuration Min; + TDuration Max; + double Factor = 1.001; + double Jitter = 0; + + bool operator==(const TSettings& o) const { + return Min == o.Min && + Max == o.Max && + Factor == o.Factor && + Jitter == o.Jitter; + } + }; + + TExponentialBackoff(const TSettings& settings, bool isEnabled = true) + : CurrentValue_(settings.Min) + , IsEnabled_(isEnabled) + { + UpdateSettings(settings); + } + + void UpdateSettings(const TSettings& settings) { + Y_ENSURE(settings.Factor > 1, "factor=" << settings.Factor << ". Should be > 1"); + Y_ENSURE(settings.Jitter >= 0 && settings.Jitter < 1, "jitter should be in range [0, 1)"); + + Min_ = settings.Min; + Max_ = settings.Max; + Factor_ = settings.Factor; + Jitter_ = settings.Jitter; + } + + TDuration Increase() { + CurrentValue_ = std::min(CurrentValue_ * Factor_, Max_); + + double rnd = StdNormalRandom<double>(); + const bool isNegative = rnd < 0; + rnd = std::abs(rnd); + + const TDuration diff = rnd * Jitter_ * CurrentValue_; + if (isNegative) { + CurrentValue_ -= diff; + } else { + CurrentValue_ += diff; + } + + return CurrentValue_; + } + + TDuration Decrease() { + CurrentValue_ = std::max(CurrentValue_ / Factor_, Min_); + return CurrentValue_; + } + + void Sleep(TDuration add = TDuration()) { + if (IsEnabled_.load(std::memory_order_relaxed)) { + Ev_.WaitT(CurrentValue_ + add); + } + } + + void Interrupt() { + Ev_.Signal(); + } + + TDuration GetCurrentValue() const { + return CurrentValue_; + } + + void SetEnabled(bool val) { + IsEnabled_.store(val); + } + + private: + TDuration Min_; + TDuration Max_; + double Factor_; + double Jitter_; + TDuration CurrentValue_; + std::atomic_bool IsEnabled_; + + TAutoEvent Ev_; + }; +} diff --git a/library/cpp/tvmauth/client/misc/fetch_result.h b/library/cpp/tvmauth/client/misc/fetch_result.h new file mode 100644 index 0000000000..4b0774e92f --- /dev/null +++ b/library/cpp/tvmauth/client/misc/fetch_result.h @@ -0,0 +1,13 @@ +#pragma once + +#include <library/cpp/http/simple/http_client.h> + +namespace NTvmAuth::NUtils { + struct TFetchResult { + TKeepAliveHttpClient::THttpCode Code; + THttpHeaders Headers; + TStringBuf Path; + TString Response; + TString RetrySettings; + }; +} diff --git a/library/cpp/tvmauth/client/misc/getter.h b/library/cpp/tvmauth/client/misc/getter.h new file mode 100644 index 0000000000..6c7617b418 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/getter.h @@ -0,0 +1,49 @@ +#pragma once + +#include "checker.h" +#include "service_tickets.h" + +namespace NTvmAuth { + class TServiceTicketGetter { + public: + /*! + * Fetching must enabled in TClientSettings + * Can throw exception if cache is invalid or wrong config + * @param dst + */ + static TString GetTicket(const TClientSettings::TAlias& dst, TServiceTicketsPtr c) { + Y_ENSURE_EX(c, TBrokenTvmClientSettings() << "Need to use TClientSettings::EnableServiceTicketsFetchOptions()"); + return GetTicketImpl(dst, c->TicketsByAlias, c->ErrorsByAlias, c->UnfetchedAliases); + } + + static TString GetTicket(const TTvmId dst, TServiceTicketsPtr c) { + Y_ENSURE_EX(c, TBrokenTvmClientSettings() << "Need to use TClientSettings::EnableServiceTicketsFetchOptions()"); + return GetTicketImpl(dst, c->TicketsById, c->ErrorsById, c->UnfetchedIds); + } + + private: + template <class Key, class Cont, class UnfetchedCont> + static TString GetTicketImpl(const Key& dst, const Cont& tickets, const Cont& errors, const UnfetchedCont& unfetched) { + auto it = tickets.find(dst); + if (it != tickets.end()) { + return it->second; + } + + it = errors.find(dst); + if (it != errors.end()) { + ythrow TMissingServiceTicket() + << "Failed to get ticket for '" << dst << "': " + << it->second; + } + + if (unfetched.contains(dst)) { + ythrow TMissingServiceTicket() + << "Failed to get ticket for '" << dst << "': this dst was not fetched yet."; + } + + ythrow TBrokenTvmClientSettings() + << "Destination '" << dst << "' was not specified in settings. " + << "Check your settings (if you use Qloud/YP/tvmtool - check it's settings)"; + } + }; +} diff --git a/library/cpp/tvmauth/client/misc/last_error.cpp b/library/cpp/tvmauth/client/misc/last_error.cpp new file mode 100644 index 0000000000..a5054b0342 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/last_error.cpp @@ -0,0 +1,115 @@ +#include "last_error.h" + +#include <util/string/builder.h> + +namespace NTvmAuth { + TLastError::TLastError() + : LastErrors_(MakeIntrusiveConst<TLastErrors>()) + { + } + + TString TLastError::GetLastError(bool isOk, EType* type) const { + if (isOk) { + return OK_; + } + + const TLastErrorsPtr ptr = LastErrors_.Get(); + + for (const TLastErr& err : ptr->Errors) { + if (err && err->first == EType::NonRetriable) { + if (type) { + *type = EType::NonRetriable; + } + return err->second; + } + } + + for (const TLastErr& err : ptr->Errors) { + if (err) { + if (type) { + *type = EType::Retriable; + } + return err->second; + } + } + + if (type) { + *type = EType::NonRetriable; + } + return "Internal client error: failed to collect last useful error message, please report this message to tvm-dev@yandex-team.ru"; + } + + TString TLastError::ProcessHttpError(TLastError::EScope scope, + TStringBuf path, + int code, + const TString& msg) const { + TString err = TStringBuilder() << "Path:" << path << ".Code=" << code << ": " << msg; + + ProcessError(code >= 400 && code < 500 ? EType::NonRetriable + : EType::Retriable, + scope, + err); + + return err; + } + + void TLastError::ProcessError(TLastError::EType type, TLastError::EScope scope, const TStringBuf msg) const { + Update(scope, [&](TLastErr& lastError) { + if (lastError && lastError->first == EType::NonRetriable && type == EType::Retriable) { + return false; + } + + TString err = TStringBuilder() << scope << ": " << msg; + err.erase(std::remove(err.begin(), err.vend(), '\r'), err.vend()); + std::replace(err.begin(), err.vend(), '\n', ' '); + + lastError = {type, std::move(err)}; + return true; + }); + } + + void TLastError::ClearError(TLastError::EScope scope) { + Update(scope, [&](TLastErr& lastError) { + if (!lastError) { + return false; + } + + lastError.Clear(); + return true; + }); + } + + void TLastError::ClearErrors() { + for (size_t idx = 0; idx < (size_t)EScope::COUNT; ++idx) { + ClearError((EScope)idx); + } + } + + void TLastError::ThrowLastError() { + EType type; + TString err = GetLastError(false, &type); + + switch (type) { + case EType::NonRetriable: + ythrow TNonRetriableException() + << "Failed to start TvmClient. Do not retry: " + << err; + case EType::Retriable: + ythrow TRetriableException() + << "Failed to start TvmClient. You can retry: " + << err; + } + } + + template <typename Func> + void TLastError::Update(TLastError::EScope scope, Func func) const { + Y_ABORT_UNLESS(scope != EScope::COUNT); + + TLastErrors errs = *LastErrors_.Get(); + TLastErr& lastError = errs.Errors[(size_t)scope]; + + if (func(lastError)) { + LastErrors_.Set(MakeIntrusiveConst<TLastErrors>(std::move(errs))); + } + } +} diff --git a/library/cpp/tvmauth/client/misc/last_error.h b/library/cpp/tvmauth/client/misc/last_error.h new file mode 100644 index 0000000000..b0ad33611f --- /dev/null +++ b/library/cpp/tvmauth/client/misc/last_error.h @@ -0,0 +1,51 @@ +#pragma once + +#include "utils.h" + +#include <array> + +namespace NTvmAuth { + class TLastError { + public: + enum class EType { + NonRetriable, + Retriable, + }; + + enum class EScope { + ServiceTickets, + PublicKeys, + Roles, + TvmtoolConfig, + + COUNT, + }; + + using TLastErr = TMaybe<std::pair<EType, TString>>; + + struct TLastErrors: public TAtomicRefCount<TLastErrors> { + std::array<TLastErr, (int)EScope::COUNT> Errors; + }; + using TLastErrorsPtr = TIntrusiveConstPtr<TLastErrors>; + + public: + TLastError(); + + TString GetLastError(bool isOk, EType* type = nullptr) const; + + TString ProcessHttpError(EScope scope, TStringBuf path, int code, const TString& msg) const; + void ProcessError(EType type, EScope scope, const TStringBuf msg) const; + void ClearError(EScope scope); + void ClearErrors(); + void ThrowLastError(); + + private: + template <typename Func> + void Update(EScope scope, Func func) const; + + private: + const TString OK_ = "OK"; + + mutable NUtils::TProtectedValue<TLastErrorsPtr> LastErrors_; + }; +} diff --git a/library/cpp/tvmauth/client/misc/proc_info.cpp b/library/cpp/tvmauth/client/misc/proc_info.cpp new file mode 100644 index 0000000000..e2e5ec15b9 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/proc_info.cpp @@ -0,0 +1,53 @@ +#include "proc_info.h" + +#include <library/cpp/tvmauth/version.h> + +#include <library/cpp/string_utils/quote/quote.h> + +#include <util/stream/file.h> +#include <util/string/cast.h> +#include <util/system/getpid.h> + +namespace NTvmAuth::NUtils { + void TProcInfo::AddToRequest(IOutputStream& out) const { + out << "_pid=" << Pid; + if (ProcessName) { + out << "&_procces_name=" << *ProcessName; + } + out << "&lib_version=client_" << VersionPrefix << LibVersion(); + } + + TProcInfo TProcInfo::Create(const TString& versionPrefix) { + TProcInfo res; + res.Pid = IntToString<10>(GetPID()); + res.ProcessName = GetProcessName(); + res.VersionPrefix = versionPrefix; + return res; + } + + std::optional<TString> TProcInfo::GetProcessName() { + try { + // works only for linux + TFileInput proc("/proc/self/status"); + + TString line; + while (proc.ReadLine(line)) { + TStringBuf buf(line); + if (!buf.SkipPrefix("Name:")) { + continue; + } + + while (buf && isspace(buf.front())) { + buf.Skip(1); + } + + TString res(buf); + CGIEscape(res); + return res; + } + } catch (...) { + } + + return {}; + } +} diff --git a/library/cpp/tvmauth/client/misc/proc_info.h b/library/cpp/tvmauth/client/misc/proc_info.h new file mode 100644 index 0000000000..b1526e5c47 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/proc_info.h @@ -0,0 +1,18 @@ +#pragma once + +#include <util/generic/string.h> + +#include <optional> + +namespace NTvmAuth::NUtils { + struct TProcInfo { + TString Pid; + std::optional<TString> ProcessName; + TString VersionPrefix; + + void AddToRequest(IOutputStream& out) const; + + static TProcInfo Create(const TString& versionPrefix); + static std::optional<TString> GetProcessName(); + }; +} diff --git a/library/cpp/tvmauth/client/misc/retry_settings/v1/settings.proto b/library/cpp/tvmauth/client/misc/retry_settings/v1/settings.proto new file mode 100644 index 0000000000..ddf1648777 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/retry_settings/v1/settings.proto @@ -0,0 +1,21 @@ +syntax = "proto2"; + +package retry_settings.v1; + +option cc_enable_arenas = true; +option go_package = "github.com/ydb-platform/ydb/library/cpp/tvmauth/client/misc/retry_settings/v1"; + +message Settings { + optional uint32 exponential_backoff_min_sec = 1; + optional uint32 exponential_backoff_max_sec = 2; + optional double exponential_backoff_factor = 3; + optional double exponential_backoff_jitter = 4; + optional uint32 max_random_sleep_default = 5; + optional uint32 max_random_sleep_when_ok = 12; + optional uint32 retries_on_start = 6; + optional uint32 worker_awaking_period_sec = 7; + optional uint32 dsts_limit = 8; + optional uint32 retries_in_background = 9; + optional uint32 roles_update_period_sec = 10; + optional uint32 roles_warn_period_sec = 11; +} diff --git a/library/cpp/tvmauth/client/misc/retry_settings/v1/ya.make b/library/cpp/tvmauth/client/misc/retry_settings/v1/ya.make new file mode 100644 index 0000000000..7ade82ac74 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/retry_settings/v1/ya.make @@ -0,0 +1,13 @@ +PROTO_LIBRARY() + +EXCLUDE_TAGS( + JAVA_PROTO + PY_PROTO + PY3_PROTO +) + +SRCS( + settings.proto +) + +END() diff --git a/library/cpp/tvmauth/client/misc/roles/decoder.cpp b/library/cpp/tvmauth/client/misc/roles/decoder.cpp new file mode 100644 index 0000000000..6337fb91c2 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/roles/decoder.cpp @@ -0,0 +1,93 @@ +#include "decoder.h" + +#include <library/cpp/tvmauth/client/misc/utils.h> + +#include <library/cpp/openssl/crypto/sha.h> +#include <library/cpp/streams/brotli/brotli.h> +#include <library/cpp/streams/zstd/zstd.h> + +#include <util/generic/yexception.h> +#include <util/stream/zlib.h> +#include <util/string/ascii.h> + +namespace NTvmAuth::NRoles { + TString TDecoder::Decode(const TStringBuf codec, TString&& blob) { + if (codec.empty()) { + return std::move(blob); + } + + const TCodecInfo info = ParseCodec(codec); + TString decoded = DecodeImpl(info.Type, blob); + + VerifySize(decoded, info.Size); + VerifyChecksum(decoded, info.Sha256); + + return decoded; + } + + TDecoder::TCodecInfo TDecoder::ParseCodec(TStringBuf codec) { + const char delim = ':'; + + const TStringBuf version = codec.NextTok(delim); + Y_ENSURE(version == "1", + "unknown codec format version; known: 1; got: " << version); + + TCodecInfo res; + res.Type = codec.NextTok(delim); + Y_ENSURE(res.Type, "codec type is empty"); + + const TStringBuf size = codec.NextTok(delim); + Y_ENSURE(TryIntFromString<10>(size, res.Size), + "decoded blob size is not number"); + + res.Sha256 = codec; + const size_t expectedSha256Size = 2 * NOpenSsl::NSha256::DIGEST_LENGTH; + Y_ENSURE(res.Sha256.size() == expectedSha256Size, + "sha256 of decoded blob has invalid length: expected " + << expectedSha256Size << ", got " << res.Sha256.size()); + + return res; + } + + TString TDecoder::DecodeImpl(TStringBuf codec, const TString& blob) { + if (AsciiEqualsIgnoreCase(codec, "brotli")) { + return DecodeBrolti(blob); + } else if (AsciiEqualsIgnoreCase(codec, "gzip")) { + return DecodeGzip(blob); + } else if (AsciiEqualsIgnoreCase(codec, "zstd")) { + return DecodeZstd(blob); + } + + ythrow yexception() << "unknown codec: '" << codec << "'"; + } + + TString TDecoder::DecodeBrolti(const TString& blob) { + TStringInput in(blob); + return TBrotliDecompress(&in).ReadAll(); + } + + TString TDecoder::DecodeGzip(const TString& blob) { + TStringInput in(blob); + return TZLibDecompress(&in).ReadAll(); + } + + TString TDecoder::DecodeZstd(const TString& blob) { + TStringInput in(blob); + return TZstdDecompress(&in).ReadAll(); + } + + void TDecoder::VerifySize(const TStringBuf decoded, size_t expected) { + Y_ENSURE(expected == decoded.size(), + "Decoded blob has bad size: expected " << expected << ", actual " << decoded.size()); + } + + void TDecoder::VerifyChecksum(const TStringBuf decoded, const TStringBuf expected) { + using namespace NOpenSsl::NSha256; + + const TDigest dig = Calc(decoded); + const TString actual = NUtils::ToHex(TStringBuf((char*)dig.data(), dig.size())); + + Y_ENSURE(AsciiEqualsIgnoreCase(actual, expected), + "Decoded blob has bad sha256: expected=" << expected << ", actual=" << actual); + } +} diff --git a/library/cpp/tvmauth/client/misc/roles/decoder.h b/library/cpp/tvmauth/client/misc/roles/decoder.h new file mode 100644 index 0000000000..de5cdb37e0 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/roles/decoder.h @@ -0,0 +1,32 @@ +#pragma once + +#include <util/generic/string.h> + +namespace NTvmAuth::NRoles { + class TDecoder { + public: + static TString Decode(const TStringBuf codec, TString&& blob); + + public: + struct TCodecInfo { + TStringBuf Type; + size_t Size = 0; + TStringBuf Sha256; + + bool operator==(const TCodecInfo& o) const { + return Type == o.Type && + Size == o.Size && + Sha256 == o.Sha256; + } + }; + + static TCodecInfo ParseCodec(TStringBuf codec); + static TString DecodeImpl(TStringBuf codec, const TString& blob); + static TString DecodeBrolti(const TString& blob); + static TString DecodeGzip(const TString& blob); + static TString DecodeZstd(const TString& blob); + + static void VerifySize(const TStringBuf decoded, size_t expected); + static void VerifyChecksum(const TStringBuf decoded, const TStringBuf expected); + }; +} diff --git a/library/cpp/tvmauth/client/misc/roles/entities_index.cpp b/library/cpp/tvmauth/client/misc/roles/entities_index.cpp new file mode 100644 index 0000000000..6c3cd5c192 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/roles/entities_index.cpp @@ -0,0 +1,111 @@ +#include "entities_index.h" + +#include <util/stream/str.h> + +#include <set> + +namespace NTvmAuth::NRoles { + TEntitiesIndex::TStage::TStage(const std::set<TString>& k) + : Keys_(k.begin(), k.end()) + { + } + + // TODO TStringBuf + bool TEntitiesIndex::TStage::GetNextKeySet(std::vector<TString>& out) { + out.clear(); + out.reserve(Keys_.size()); + + ++Id_; + for (size_t idx = 0; idx < Keys_.size(); ++idx) { + bool need = (Id_ >> idx) & 0x01; + + if (need) { + out.push_back(Keys_[idx]); + } + } + + return !out.empty(); + } + + TEntitiesIndex::TEntitiesIndex(const std::vector<TEntityPtr>& entities) { + const std::set<TString> uniqueKeys = GetUniqueSortedKeys(entities); + Idx_.Entities = entities; + Idx_.SubTree.reserve(uniqueKeys.size() * entities.size()); + + TStage stage(uniqueKeys); + std::vector<TString> keyset; + while (stage.GetNextKeySet(keyset)) { + for (const TEntityPtr& e : entities) { + TSubTree* currentBranch = &Idx_; + + for (const TString& key : keyset) { + auto it = e->find(key); + if (it == e->end()) { + continue; + } + + auto [i, ok] = currentBranch->SubTree.emplace( + TKeyValue{it->first, it->second}, + TSubTree()); + + currentBranch = &i->second; + currentBranch->Entities.push_back(e); + } + } + } + + MakeUnique(Idx_); + } + + std::set<TString> TEntitiesIndex::GetUniqueSortedKeys(const std::vector<TEntityPtr>& entities) { + std::set<TString> res; + + for (const TEntityPtr& e : entities) { + for (const auto& [key, value] : *e) { + res.insert(key); + } + } + + return res; + } + + void TEntitiesIndex::MakeUnique(TSubTree& branch) { + auto& vec = branch.Entities; + std::sort(vec.begin(), vec.end()); + vec.erase(std::unique(vec.begin(), vec.end()), vec.end()); + + for (auto& [_, restPart] : branch.SubTree) { + MakeUnique(restPart); + } + } + + static void Print(const TEntitiesIndex::TSubTree& part, IOutputStream& out, size_t offset = 0) { + std::vector<std::pair<TKeyValue, const TEntitiesIndex::TSubTree*>> vec; + vec.reserve(part.SubTree.size()); + + for (const auto& [key, value] : part.SubTree) { + vec.push_back({key, &value}); + } + + std::sort(vec.begin(), vec.end(), [](const auto& l, const auto& r) { + if (l.first.Key != r.first.Key) { + return l.first.Key < r.first.Key; + } + return l.first.Value < r.first.Value; + }); + + for (const auto& [key, value] : vec) { + out << TString(offset, ' ') << "\"" << key.Key << "/" << key.Value << "\"" << Endl; + Print(*value, out, offset + 4); + } + } + + TString TEntitiesIndex::PrintDebugString() const { + TStringStream res; + res << Endl; + + Print(Idx_, res); + + return res.Str(); + } +} diff --git a/library/cpp/tvmauth/client/misc/roles/entities_index.h b/library/cpp/tvmauth/client/misc/roles/entities_index.h new file mode 100644 index 0000000000..bf42750d52 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/roles/entities_index.h @@ -0,0 +1,107 @@ +#pragma once + +#include "types.h" + +#include <library/cpp/tvmauth/client/exception.h> + +#include <set> +#include <vector> + +namespace NTvmAuth::NRoles { + class TEntitiesIndex: TMoveOnly { + public: + struct TSubTree; + using TIdxByAttrs = THashMap<TKeyValue, TSubTree>; + + struct TSubTree { + std::vector<TEntityPtr> Entities; + TIdxByAttrs SubTree; + }; + + class TStage { + public: + TStage(const std::set<TString>& k); + + bool GetNextKeySet(std::vector<TString>& out); + + private: + std::vector<TString> Keys_; + size_t Id_ = 0; + }; + + public: + TEntitiesIndex(const std::vector<TEntityPtr>& entities); + + /** + * Iterators must be to sorted unique key/value + */ + template <typename Iterator> + bool ContainsExactEntity(Iterator begin, Iterator end) const; + + /** + * Iterators must be to sorted unique key/value + */ + template <typename Iterator> + const std::vector<TEntityPtr>& GetEntitiesWithAttrs(Iterator begin, Iterator end) const; + + public: // for tests + static std::set<TString> GetUniqueSortedKeys(const std::vector<TEntityPtr>& entities); + static void MakeUnique(TEntitiesIndex::TSubTree& branch); + + TString PrintDebugString() const; + + private: + template <typename Iterator> + const TSubTree* FindSubtree(Iterator begin, Iterator end, size_t& size) const; + + private: + TSubTree Idx_; + std::vector<TEntityPtr> EmptyResult_; + }; + + template <typename Iterator> + bool TEntitiesIndex::ContainsExactEntity(Iterator begin, Iterator end) const { + size_t size = 0; + const TSubTree* subtree = FindSubtree(begin, end, size); + if (!subtree) { + return false; + } + + auto res = std::find_if( + subtree->Entities.begin(), + subtree->Entities.end(), + [size](const auto& e) { return size == e->size(); }); + return res != subtree->Entities.end(); + } + + template <typename Iterator> + const std::vector<TEntityPtr>& TEntitiesIndex::GetEntitiesWithAttrs(Iterator begin, Iterator end) const { + size_t size = 0; + const TSubTree* subtree = FindSubtree(begin, end, size); + if (!subtree) { + return EmptyResult_; + } + + return subtree->Entities; + } + + template <typename Iterator> + const TEntitiesIndex::TSubTree* TEntitiesIndex::FindSubtree(Iterator begin, + Iterator end, + size_t& size) const { + const TSubTree* subtree = &Idx_; + size = 0; + + for (auto attr = begin; attr != end; ++attr) { + auto it = subtree->SubTree.find(TKeyValueView{attr->first, attr->second}); + if (it == subtree->SubTree.end()) { + return nullptr; + } + + ++size; + subtree = &it->second; + } + + return subtree; + } +} diff --git a/library/cpp/tvmauth/client/misc/roles/parser.cpp b/library/cpp/tvmauth/client/misc/roles/parser.cpp new file mode 100644 index 0000000000..28faf4c057 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/roles/parser.cpp @@ -0,0 +1,149 @@ +#include "parser.h" + +#include <library/cpp/json/json_reader.h> + +#include <util/string/cast.h> + +namespace NTvmAuth::NRoles { + static void GetRequiredValue(const NJson::TJsonValue& doc, + TStringBuf key, + NJson::TJsonValue& obj) { + Y_ENSURE(doc.GetValue(key, &obj), "Missing '" << key << "'"); + } + + static ui64 GetRequiredUInt(const NJson::TJsonValue& doc, + TStringBuf key) { + NJson::TJsonValue obj; + GetRequiredValue(doc, key, obj); + Y_ENSURE(obj.IsUInteger(), "key '" << key << "' must be uint"); + return obj.GetUInteger(); + } + + static bool GetOptionalMap(const NJson::TJsonValue& doc, + TStringBuf key, + NJson::TJsonValue& obj) { + if (!doc.GetValue(key, &obj)) { + return false; + } + + Y_ENSURE(obj.IsMap(), "'" << key << "' must be object"); + return true; + } + + TRolesPtr TParser::Parse(TRawPtr decodedBlob) { + try { + return ParseImpl(decodedBlob); + } catch (const std::exception& e) { + throw yexception() << "Failed to parse roles from tirole: " << e.what() + << ". '" << *decodedBlob << "'"; + } + } + + TRolesPtr TParser::ParseImpl(TRawPtr decodedBlob) { + NJson::TJsonValue doc; + Y_ENSURE(NJson::ReadJsonTree(*decodedBlob, &doc), "Invalid json"); + Y_ENSURE(doc.IsMap(), "Json must be object"); + + TRoles::TTvmConsumers tvm = GetConsumers<TTvmId>(doc, "tvm"); + TRoles::TUserConsumers user = GetConsumers<TUid>(doc, "user"); + + // fetch it last to provide more correct apply instant + TRoles::TMeta meta = GetMeta(doc); + + return std::make_shared<TRoles>( + std::move(meta), + std::move(tvm), + std::move(user), + std::move(decodedBlob)); + } + + TRoles::TMeta TParser::GetMeta(const NJson::TJsonValue& doc) { + TRoles::TMeta res; + + NJson::TJsonValue obj; + GetRequiredValue(doc, "revision", obj); + if (obj.IsString()) { + res.Revision = obj.GetString(); + } else if (obj.IsUInteger()) { + res.Revision = ToString(obj.GetUInteger()); + } else { + ythrow yexception() << "'revision' has unexpected type: " << obj.GetType(); + } + + res.BornTime = TInstant::Seconds(GetRequiredUInt(doc, "born_date")); + + return res; + } + + template <typename Id> + THashMap<Id, TConsumerRolesPtr> TParser::GetConsumers(const NJson::TJsonValue& doc, + TStringBuf type) { + THashMap<Id, TConsumerRolesPtr> res; + + NJson::TJsonValue obj; + if (!GetOptionalMap(doc, type, obj)) { + return res; + } + + for (const auto& [key, value] : obj.GetMap()) { + Y_ENSURE(value.IsMap(), + "roles for consumer must be map: '" << key << "' is " << value.GetType()); + + Id id = 0; + Y_ENSURE(TryIntFromString<10>(key, id), + "id must be valid positive number of proper size for " + << type << ". got '" + << key << "'"); + + Y_ENSURE(res.emplace(id, GetConsumer(value, key)).second, + "consumer duplicate detected: '" << key << "' for " << type); + } + + return res; + } + + TConsumerRolesPtr TParser::GetConsumer(const NJson::TJsonValue& obj, TStringBuf consumer) { + TEntitiesByRoles entities; + + for (const auto& [key, value] : obj.GetMap()) { + Y_ENSURE(value.IsArray(), + "entities for roles must be array: '" << key << "' is " << value.GetType()); + + entities.emplace(key, GetEntities(value, consumer, key)); + } + + return std::make_shared<TConsumerRoles>(std::move(entities)); + } + + TEntitiesPtr TParser::GetEntities(const NJson::TJsonValue& obj, + TStringBuf consumer, + TStringBuf role) { + std::vector<TEntityPtr> entities; + entities.reserve(obj.GetArray().size()); + + for (const NJson::TJsonValue& e : obj.GetArray()) { + Y_ENSURE(e.IsMap(), + "role entity for role must be map: consumer '" + << consumer << "' with role '" << role << "' has " << e.GetType()); + + entities.push_back(GetEntity(e, consumer, role)); + } + + return std::make_shared<TEntities>(TEntities(entities)); + } + + TEntityPtr TParser::GetEntity(const NJson::TJsonValue& obj, TStringBuf consumer, TStringBuf role) { + TEntityPtr res = std::make_shared<TEntity>(); + + for (const auto& [key, value] : obj.GetMap()) { + Y_ENSURE(value.IsString(), + "entity is map (str->str), got value " + << value.GetType() << ". consumer '" + << consumer << "' with role '" << role << "'"); + + res->emplace(key, value.GetString()); + } + + return res; + } +} diff --git a/library/cpp/tvmauth/client/misc/roles/parser.h b/library/cpp/tvmauth/client/misc/roles/parser.h new file mode 100644 index 0000000000..0982ba78c6 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/roles/parser.h @@ -0,0 +1,36 @@ +#pragma once + +#include "roles.h" +#include "types.h" + +namespace NJson { + class TJsonValue; +} + +namespace NTvmAuth::NRoles { + class TParser { + public: + static TRolesPtr Parse(TRawPtr decodedBlob); + + public: + static TRolesPtr ParseImpl(TRawPtr decodedBlob); + static TRoles::TMeta GetMeta(const NJson::TJsonValue& doc); + + template <typename Id> + static THashMap<Id, TConsumerRolesPtr> GetConsumers( + const NJson::TJsonValue& doc, + TStringBuf key); + + static TConsumerRolesPtr GetConsumer( + const NJson::TJsonValue& obj, + TStringBuf consumer); + static TEntitiesPtr GetEntities( + const NJson::TJsonValue& obj, + TStringBuf consumer, + TStringBuf role); + static TEntityPtr GetEntity( + const NJson::TJsonValue& obj, + TStringBuf consumer, + TStringBuf role); + }; +} diff --git a/library/cpp/tvmauth/client/misc/roles/roles.cpp b/library/cpp/tvmauth/client/misc/roles/roles.cpp new file mode 100644 index 0000000000..0761033104 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/roles/roles.cpp @@ -0,0 +1,101 @@ +#include "roles.h" + +#include <library/cpp/tvmauth/checked_service_ticket.h> +#include <library/cpp/tvmauth/checked_user_ticket.h> + +namespace NTvmAuth::NRoles { + TRoles::TRoles(TMeta&& meta, + TTvmConsumers tvm, + TUserConsumers user, + TRawPtr raw) + : Meta_(std::move(meta)) + , TvmIds_(std::move(tvm)) + , Users_(std::move(user)) + , Raw_(std::move(raw)) + { + Y_ENSURE(Raw_); + } + + TConsumerRolesPtr TRoles::GetRolesForService(const TCheckedServiceTicket& t) const { + Y_ENSURE_EX(t, + TIllegalUsage() << "Service ticket must be valid, got: " << t.GetStatus()); + auto it = TvmIds_.find(t.GetSrc()); + return it == TvmIds_.end() ? TConsumerRolesPtr() : it->second; + } + + TConsumerRolesPtr TRoles::GetRolesForUser(const TCheckedUserTicket& t, + std::optional<TUid> selectedUid) const { + Y_ENSURE_EX(t, + TIllegalUsage() << "User ticket must be valid, got: " << t.GetStatus()); + Y_ENSURE_EX(t.GetEnv() == EBlackboxEnv::ProdYateam, + TIllegalUsage() << "User ticket must be from ProdYateam, got from " << t.GetEnv()); + + TUid uid = t.GetDefaultUid(); + if (selectedUid) { + auto it = std::find(t.GetUids().begin(), t.GetUids().end(), *selectedUid); + Y_ENSURE_EX(it != t.GetUids().end(), + TIllegalUsage() << "selectedUid must be in user ticket but it's not: " + << *selectedUid); + uid = *selectedUid; + } + + auto it = Users_.find(uid); + return it == Users_.end() ? TConsumerRolesPtr() : it->second; + } + + const TRoles::TMeta& TRoles::GetMeta() const { + return Meta_; + } + + const TString& TRoles::GetRaw() const { + return *Raw_; + } + + bool TRoles::CheckServiceRole(const TCheckedServiceTicket& t, + const TStringBuf roleName) const { + TConsumerRolesPtr c = GetRolesForService(t); + return c ? c->HasRole(roleName) : false; + } + + bool TRoles::CheckUserRole(const TCheckedUserTicket& t, + const TStringBuf roleName, + std::optional<TUid> selectedUid) const { + TConsumerRolesPtr c = GetRolesForUser(t, selectedUid); + return c ? c->HasRole(roleName) : false; + } + + bool TRoles::CheckServiceRoleForExactEntity(const TCheckedServiceTicket& t, + const TStringBuf roleName, + const TEntity& exactEntity) const { + TConsumerRolesPtr c = GetRolesForService(t); + return c ? c->CheckRoleForExactEntity(roleName, exactEntity) : false; + } + + bool TRoles::CheckUserRoleForExactEntity(const TCheckedUserTicket& t, + const TStringBuf roleName, + const TEntity& exactEntity, + std::optional<TUid> selectedUid) const { + TConsumerRolesPtr c = GetRolesForUser(t, selectedUid); + return c ? c->CheckRoleForExactEntity(roleName, exactEntity) : false; + } + + TConsumerRoles::TConsumerRoles(TEntitiesByRoles roles) + : Roles_(std::move(roles)) + { + } + + bool TConsumerRoles::CheckRoleForExactEntity(const TStringBuf roleName, + const TEntity& exactEntity) const { + auto it = Roles_.find(roleName); + if (it == Roles_.end()) { + return false; + } + + return it->second->Contains(exactEntity); + } + + TEntities::TEntities(TEntitiesIndex idx) + : Idx_(std::move(idx)) + { + } +} diff --git a/library/cpp/tvmauth/client/misc/roles/roles.h b/library/cpp/tvmauth/client/misc/roles/roles.h new file mode 100644 index 0000000000..6d510ee8a1 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/roles/roles.h @@ -0,0 +1,186 @@ +#pragma once + +#include "entities_index.h" +#include "types.h" + +#include <library/cpp/tvmauth/client/exception.h> + +#include <library/cpp/tvmauth/type.h> + +#include <util/datetime/base.h> +#include <util/generic/array_ref.h> +#include <util/generic/hash.h> + +#include <vector> + +namespace NTvmAuth { + class TCheckedServiceTicket; + class TCheckedUserTicket; +} + +namespace NTvmAuth::NRoles { + class TRoles { + public: + struct TMeta { + TString Revision; + TInstant BornTime; + TInstant Applied = TInstant::Now(); + }; + + using TTvmConsumers = THashMap<TTvmId, TConsumerRolesPtr>; + using TUserConsumers = THashMap<TUid, TConsumerRolesPtr>; + + TRoles(TMeta&& meta, + TTvmConsumers tvm, + TUserConsumers user, + TRawPtr raw); + + /** + * @return ptr to roles. It will be nullptr if there are no roles + */ + TConsumerRolesPtr GetRolesForService(const TCheckedServiceTicket& t) const; + + /** + * @return ptr to roles. It will be nullptr if there are no roles + */ + TConsumerRolesPtr GetRolesForUser(const TCheckedUserTicket& t, + std::optional<TUid> selectedUid = {}) const; + + const TMeta& GetMeta() const; + const TString& GetRaw() const; + + public: // shortcuts + /** + * @brief CheckServiceRole() is shortcut for simple role checking - for any possible entity + */ + bool CheckServiceRole( + const TCheckedServiceTicket& t, + const TStringBuf roleName) const; + + /** + * @brief CheckUserRole() is shortcut for simple role checking - for any possible entity + */ + bool CheckUserRole( + const TCheckedUserTicket& t, + const TStringBuf roleName, + std::optional<TUid> selectedUid = {}) const; + + /** + * @brief CheckServiceRoleForExactEntity() is shortcut for simple role checking for exact entity + */ + bool CheckServiceRoleForExactEntity( + const TCheckedServiceTicket& t, + const TStringBuf roleName, + const TEntity& exactEntity) const; + + /** + * @brief CheckUserRoleForExactEntity() is shortcut for simple role checking for exact entity + */ + bool CheckUserRoleForExactEntity( + const TCheckedUserTicket& t, + const TStringBuf roleName, + const TEntity& exactEntity, + std::optional<TUid> selectedUid = {}) const; + + private: + TMeta Meta_; + TTvmConsumers TvmIds_; + TUserConsumers Users_; + TRawPtr Raw_; + }; + + class TConsumerRoles { + public: + TConsumerRoles(TEntitiesByRoles roles); + + bool HasRole(const TStringBuf roleName) const { + return Roles_.contains(roleName); + } + + const TEntitiesByRoles& GetRoles() const { + return Roles_; + } + + /** + * @return ptr to entries. It will be nullptr if there is no role + */ + TEntitiesPtr GetEntitiesForRole(const TStringBuf roleName) const { + auto it = Roles_.find(roleName); + return it == Roles_.end() ? TEntitiesPtr() : it->second; + } + + /** + * @brief CheckRoleForExactEntity() is shortcut for simple role checking for exact entity + */ + bool CheckRoleForExactEntity(const TStringBuf roleName, + const TEntity& exactEntity) const; + + private: + TEntitiesByRoles Roles_; + }; + + class TEntities { + public: + TEntities(TEntitiesIndex idx); + + /** + * @brief Contains() provides info about entity presence + */ + bool Contains(const TEntity& exactEntity) const { + return Idx_.ContainsExactEntity(exactEntity.begin(), exactEntity.end()); + } + + /** + * @brief The same as Contains() + * It checks span for sorted and unique properties. + */ + template <class StrKey = TString, class StrValue = TString> + bool ContainsSortedUnique( + const TArrayRef<const std::pair<StrKey, StrValue>>& exactEntity) const { + CheckSpan(exactEntity); + return Idx_.ContainsExactEntity(exactEntity.begin(), exactEntity.end()); + } + + /** + * @brief GetEntitiesWithAttrs() collects entities with ALL attributes from `attrs` + */ + template <class StrKey = TString, class StrValue = TString> + const std::vector<TEntityPtr>& GetEntitiesWithAttrs( + const std::map<StrKey, StrValue>& attrs) const { + return Idx_.GetEntitiesWithAttrs(attrs.begin(), attrs.end()); + } + + /** + * @brief The same as GetEntitiesWithAttrs() + * It checks span for sorted and unique properties. + */ + template <class StrKey = TString, class StrValue = TString> + const std::vector<TEntityPtr>& GetEntitiesWithSortedUniqueAttrs( + const TArrayRef<const std::pair<StrKey, StrValue>>& attrs) const { + CheckSpan(attrs); + return Idx_.GetEntitiesWithAttrs(attrs.begin(), attrs.end()); + } + + private: + template <class StrKey, class StrValue> + static void CheckSpan(const TArrayRef<const std::pair<StrKey, StrValue>>& attrs) { + if (attrs.empty()) { + return; + } + + auto prev = attrs.begin(); + for (auto it = prev + 1; it != attrs.end(); ++it) { + Y_ENSURE_EX(prev->first != it->first, + TIllegalUsage() << "attrs are not unique: '" << it->first << "'"); + Y_ENSURE_EX(prev->first < it->first, + TIllegalUsage() << "attrs are not sorted: '" << prev->first + << "' before '" << it->first << "'"); + + prev = it; + } + } + + private: + TEntitiesIndex Idx_; + }; +} diff --git a/library/cpp/tvmauth/client/misc/roles/types.h b/library/cpp/tvmauth/client/misc/roles/types.h new file mode 100644 index 0000000000..de0745e72e --- /dev/null +++ b/library/cpp/tvmauth/client/misc/roles/types.h @@ -0,0 +1,70 @@ +#pragma once + +#include <util/generic/hash_set.h> + +#include <map> +#include <memory> + +namespace NTvmAuth::NRoles { + using TEntity = std::map<TString, TString>; + using TEntityPtr = std::shared_ptr<TEntity>; + + class TEntities; + using TEntitiesPtr = std::shared_ptr<TEntities>; + + using TEntitiesByRoles = THashMap<TString, TEntitiesPtr>; + + class TConsumerRoles; + using TConsumerRolesPtr = std::shared_ptr<TConsumerRoles>; + + class TRoles; + using TRolesPtr = std::shared_ptr<TRoles>; + + using TRawPtr = std::shared_ptr<TString>; + + template <class T> + struct TKeyValueBase { + T Key; + T Value; + + template <typename U> + bool operator==(const TKeyValueBase<U>& o) const { + return Key == o.Key && Value == o.Value; + } + }; + + using TKeyValue = TKeyValueBase<TString>; + using TKeyValueView = TKeyValueBase<TStringBuf>; +} + +// Traits + +template <> +struct THash<NTvmAuth::NRoles::TKeyValue> { + std::size_t operator()(const NTvmAuth::NRoles::TKeyValue& e) const { + return std::hash<std::string_view>()(e.Key) + std::hash<std::string_view>()(e.Value); + } + + std::size_t operator()(const NTvmAuth::NRoles::TKeyValueView& e) const { + return std::hash<std::string_view>()(e.Key) + std::hash<std::string_view>()(e.Value); + } +}; + +template <> +struct TEqualTo<NTvmAuth::NRoles::TKeyValue> { + using is_transparent = std::true_type; + + template <typename T, typename U> + bool operator()(const NTvmAuth::NRoles::TKeyValueBase<T>& l, + const NTvmAuth::NRoles::TKeyValueBase<U>& r) { + return l == r; + } +}; + +inline bool operator<(const NTvmAuth::NRoles::TEntityPtr& l, const NTvmAuth::NRoles::TEntityPtr& r) { + return *l < *r; +} + +inline bool operator==(const NTvmAuth::NRoles::TEntityPtr& l, const NTvmAuth::NRoles::TEntityPtr& r) { + return *l == *r; +} diff --git a/library/cpp/tvmauth/client/misc/service_tickets.h b/library/cpp/tvmauth/client/misc/service_tickets.h new file mode 100644 index 0000000000..6a24bd5689 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/service_tickets.h @@ -0,0 +1,86 @@ +#pragma once + +#include "settings.h" +#include "roles/roles.h" + +#include <library/cpp/tvmauth/src/utils.h> + +#include <util/datetime/base.h> +#include <util/generic/hash.h> +#include <util/generic/maybe.h> +#include <util/generic/noncopyable.h> +#include <util/generic/ptr.h> + +namespace NTvmAuth::NInternal { + class TClientCaningKnife; +} + +namespace NTvmAuth { + class TServiceTickets: public TAtomicRefCount<TServiceTickets> { + public: + using TMapAliasStr = THashMap<TClientSettings::TAlias, TString>; + using TMapIdStr = THashMap<TTvmId, TString>; + using TIdSet = THashSet<TTvmId>; + using TAliasSet = THashSet<TClientSettings::TAlias>; + using TMapAliasId = THashMap<TClientSettings::TAlias, TTvmId>; + + TServiceTickets(TMapIdStr&& tickets, TMapIdStr&& errors, const TMapAliasId& dstMap) + : TicketsById(std::move(tickets)) + , ErrorsById(std::move(errors)) + { + InitAliasesAndUnfetchedIds(dstMap); + InitInvalidationTime(); + } + + static TInstant GetInvalidationTime(const TMapIdStr& ticketsById) { + TInstant res; + + for (const auto& pair : ticketsById) { + TMaybe<TInstant> t = NTvmAuth::NInternal::TCanningKnife::GetExpirationTime(pair.second); + if (!t) { + continue; + } + + res = res == TInstant() ? *t : std::min(res, *t); + } + + return res; + } + + public: + TMapIdStr TicketsById; + TMapIdStr ErrorsById; + TMapAliasStr TicketsByAlias; + TMapAliasStr ErrorsByAlias; + TInstant InvalidationTime; + TIdSet UnfetchedIds; + TAliasSet UnfetchedAliases; + + private: + void InitAliasesAndUnfetchedIds(const TMapAliasId& dstMap) { + for (const auto& pair : dstMap) { + auto it = TicketsById.find(pair.second); + auto errIt = ErrorsById.find(pair.second); + + if (it == TicketsById.end()) { + if (errIt != ErrorsById.end()) { + Y_ENSURE(ErrorsByAlias.insert({pair.first, errIt->second}).second, + "failed to add: " << pair.first); + } else { + UnfetchedAliases.insert(pair.first); + UnfetchedIds.insert(pair.second); + } + } else { + Y_ENSURE(TicketsByAlias.insert({pair.first, it->second}).second, + "failed to add: " << pair.first); + } + } + } + + void InitInvalidationTime() { + InvalidationTime = GetInvalidationTime(TicketsById); + } + }; + + using TServiceTicketsPtr = TIntrusiveConstPtr<TServiceTickets>; +} diff --git a/library/cpp/tvmauth/client/misc/settings.h b/library/cpp/tvmauth/client/misc/settings.h new file mode 100644 index 0000000000..8fae6c34d3 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/settings.h @@ -0,0 +1,13 @@ +#pragma once + +#include <util/generic/fwd.h> + +namespace NTvmAuth { + class TClientSettings { + public: + /*! + * Look at description in relevant settings: NTvmApi::TClientSettings or NTvmTool::TClientSettings + */ + using TAlias = TString; + }; +} diff --git a/library/cpp/tvmauth/client/misc/src_checker.h b/library/cpp/tvmauth/client/misc/src_checker.h new file mode 100644 index 0000000000..bb99fe8884 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/src_checker.h @@ -0,0 +1,31 @@ +#pragma once + +#include "roles/roles.h" + +#include <library/cpp/tvmauth/client/exception.h> + +#include <library/cpp/tvmauth/checked_service_ticket.h> +#include <library/cpp/tvmauth/src/service_impl.h> +#include <library/cpp/tvmauth/src/utils.h> + +namespace NTvmAuth { + class TSrcChecker { + public: + /*! + * Checking must be enabled in TClientSettings + * Can throw exception if cache is out of date or wrong config + * @param ticket + */ + static TCheckedServiceTicket Check(TCheckedServiceTicket ticket, NRoles::TRolesPtr r) { + Y_ENSURE_EX(r, TBrokenTvmClientSettings() << "Need to use TClientSettings::EnableRolesFetching()"); + NRoles::TConsumerRolesPtr roles = r->GetRolesForService(ticket); + if (roles) { + return ticket; + } + + TServiceTicketImplPtr impl = THolder(NInternal::TCanningKnife::GetS(ticket)); + impl->SetStatus(ETicketStatus::NoRoles); + return TCheckedServiceTicket(std::move(impl)); + } + }; +} diff --git a/library/cpp/tvmauth/client/misc/threaded_updater.cpp b/library/cpp/tvmauth/client/misc/threaded_updater.cpp new file mode 100644 index 0000000000..5d21ce67a7 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/threaded_updater.cpp @@ -0,0 +1,111 @@ +#include "threaded_updater.h" + +#include <library/cpp/tvmauth/client/exception.h> + +#include <util/string/builder.h> +#include <util/system/spin_wait.h> +#include <util/system/thread.h> + +namespace NTvmAuth { + TThreadedUpdaterBase::TThreadedUpdaterBase(TDuration workerAwakingPeriod, + TLoggerPtr logger, + const TString& url, + ui16 port, + TDuration socketTimeout, + TDuration connectTimeout) + : WorkerAwakingPeriod_(workerAwakingPeriod) + , Logger_(std::move(logger)) + , TvmUrl_(url) + , TvmPort_(port) + , TvmSocketTimeout_(socketTimeout) + , TvmConnectTimeout_(connectTimeout) + , IsStopped_(true) + { + Y_ENSURE_EX(Logger_, TNonRetriableException() << "Logger is required"); + + ServiceTicketsDurations_.RefreshPeriod = TDuration::Hours(1); + ServiceTicketsDurations_.Expiring = TDuration::Hours(2); + ServiceTicketsDurations_.Invalid = TDuration::Hours(11); + + PublicKeysDurations_.RefreshPeriod = TDuration::Days(1); + PublicKeysDurations_.Expiring = TDuration::Days(2); + PublicKeysDurations_.Invalid = TDuration::Days(6); + } + + TThreadedUpdaterBase::~TThreadedUpdaterBase() { + StopWorker(); + } + + void TThreadedUpdaterBase::StartWorker() { + if (HttpClient_) { + HttpClient_->ResetConnection(); + } + Thread_ = MakeHolder<TThread>(WorkerWrap, this); + Thread_->Start(); + Started_.Wait(); + IsStopped_ = false; + } + + void TThreadedUpdaterBase::StopWorker() { + Event_.Signal(); + if (Thread_) { + Thread_.Reset(); + } + } + + TKeepAliveHttpClient& TThreadedUpdaterBase::GetClient() const { + if (!HttpClient_) { + HttpClient_ = MakeHolder<TKeepAliveHttpClient>(TvmUrl_, TvmPort_, TvmSocketTimeout_, TvmConnectTimeout_); + } + + return *HttpClient_; + } + + void TThreadedUpdaterBase::LogDebug(const TString& msg) const { + if (Logger_) { + Logger_->Debug(msg); + } + } + + void TThreadedUpdaterBase::LogInfo(const TString& msg) const { + if (Logger_) { + Logger_->Info(msg); + } + } + + void TThreadedUpdaterBase::LogWarning(const TString& msg) const { + if (Logger_) { + Logger_->Warning(msg); + } + } + + void TThreadedUpdaterBase::LogError(const TString& msg) const { + if (Logger_) { + Logger_->Error(msg); + } + } + + void* TThreadedUpdaterBase::WorkerWrap(void* arg) { + TThread::SetCurrentThreadName("TicketParserUpd"); + TThreadedUpdaterBase& this_ = *reinterpret_cast<TThreadedUpdaterBase*>(arg); + this_.Started_.Signal(); + this_.LogDebug("Thread-worker started"); + + while (true) { + if (this_.Event_.WaitT(this_.WorkerAwakingPeriod_)) { + break; + } + + try { + this_.Worker(); + this_.GetClient().ResetConnection(); + } catch (const std::exception& e) { // impossible now + this_.LogError(TStringBuilder() << "Failed to generate new cache: " << e.what()); + } + } + + this_.LogDebug("Thread-worker stopped"); + this_.IsStopped_ = true; + return nullptr; + } +} diff --git a/library/cpp/tvmauth/client/misc/threaded_updater.h b/library/cpp/tvmauth/client/misc/threaded_updater.h new file mode 100644 index 0000000000..783684ba3b --- /dev/null +++ b/library/cpp/tvmauth/client/misc/threaded_updater.h @@ -0,0 +1,76 @@ +#pragma once + +#include "async_updater.h" +#include "settings.h" + +#include <library/cpp/tvmauth/client/logger.h> + +#include <library/cpp/http/simple/http_client.h> + +#include <util/datetime/base.h> +#include <util/generic/ptr.h> +#include <util/system/event.h> +#include <util/system/thread.h> + +class TKeepAliveHttpClient; + +namespace NTvmAuth::NInternal { + class TClientCaningKnife; +} +namespace NTvmAuth { + class TThreadedUpdaterBase: public TAsyncUpdaterBase { + public: + TThreadedUpdaterBase(TDuration workerAwakingPeriod, + TLoggerPtr logger, + const TString& url, + ui16 port, + TDuration socketTimeout, + TDuration connectTimeout); + virtual ~TThreadedUpdaterBase(); + + protected: + void StartWorker(); + void StopWorker(); + + virtual void Worker() { + } + + TKeepAliveHttpClient& GetClient() const; + + void LogDebug(const TString& msg) const; + void LogInfo(const TString& msg) const; + void LogWarning(const TString& msg) const; + void LogError(const TString& msg) const; + + protected: + TDuration WorkerAwakingPeriod_; + + const TLoggerPtr Logger_; + + protected: + const TString TvmUrl_; + + private: + static void* WorkerWrap(void* arg); + + void StartTvmClientStopping() const override { + Event_.Signal(); + } + + bool IsTvmClientStopped() const override { + return IsStopped_; + } + + private: + mutable THolder<TKeepAliveHttpClient> HttpClient_; + + const ui32 TvmPort_; + const TDuration TvmSocketTimeout_; + const TDuration TvmConnectTimeout_; + + mutable TAutoEvent Event_; + mutable TAutoEvent Started_; + std::atomic_bool IsStopped_; + THolder<TThread> Thread_; + }; +} diff --git a/library/cpp/tvmauth/client/misc/tool/meta_info.cpp b/library/cpp/tvmauth/client/misc/tool/meta_info.cpp new file mode 100644 index 0000000000..9a0ae228fe --- /dev/null +++ b/library/cpp/tvmauth/client/misc/tool/meta_info.cpp @@ -0,0 +1,208 @@ +#include "meta_info.h" + +#include <library/cpp/json/json_reader.h> + +#include <util/string/builder.h> + +namespace NTvmAuth::NTvmTool { + TString TMetaInfo::TConfig::ToString() const { + TStringStream s; + s << "self_tvm_id=" << SelfTvmId << ", " + << "bb_env=" << BbEnv << ", " + << "idm_slug=" << (IdmSlug ? IdmSlug : "<NULL>") << ", " + << "dsts=["; + + for (const auto& pair : DstAliases) { + s << "(" << pair.first << ":" << pair.second << ")"; + } + + s << "]"; + + return std::move(s.Str()); + } + + TMetaInfo::TMetaInfo(TLoggerPtr logger) + : Logger_(std::move(logger)) + { + } + + TMetaInfo::TConfigPtr TMetaInfo::Init(TKeepAliveHttpClient& client, + const TClientSettings& settings) { + ApplySettings(settings); + + TryPing(client); + const TString metaString = Fetch(client); + if (Logger_) { + TStringStream s; + s << "Meta info fetched from " << settings.GetHostname() << ":" << settings.GetPort(); + Logger_->Debug(s.Str()); + } + + try { + Config_.Set(ParseMetaString(metaString, SelfAlias_)); + } catch (const yexception& e) { + ythrow TNonRetriableException() << "Malformed json from tvmtool: " << e.what(); + } + TConfigPtr cfg = Config_.Get(); + Y_ENSURE_EX(cfg, TNonRetriableException() << "Alias '" << SelfAlias_ << "' not found in meta info"); + + if (Logger_) { + Logger_->Info("Meta: " + cfg->ToString()); + } + + return cfg; + } + + TString TMetaInfo::GetRequestForTickets(const TConfig& config) { + Y_ENSURE(!config.DstAliases.empty()); + + TStringStream s; + s << "/tvm/tickets" + << "?src=" << config.SelfTvmId + << "&dsts="; + + for (const auto& pair : config.DstAliases) { + s << pair.second << ","; // avoid aliases - url-encoding required + } + s.Str().pop_back(); + + return s.Str(); + } + + bool TMetaInfo::TryUpdateConfig(TKeepAliveHttpClient& client) { + const TString metaString = Fetch(client); + + TConfigPtr config; + try { + config = ParseMetaString(metaString, SelfAlias_); + } catch (const yexception& e) { + ythrow TNonRetriableException() << "Malformed json from tvmtool: " << e.what(); + } + Y_ENSURE_EX(config, TNonRetriableException() << "Alias '" << SelfAlias_ << "' not found in meta info"); + + TConfigPtr oldConfig = Config_.Get(); + if (*config == *oldConfig) { + return false; + } + + if (Logger_) { + Logger_->Info(TStringBuilder() + << "Meta was updated. Old: (" << oldConfig->ToString() + << "). New: (" << config->ToString() << ")"); + } + + Config_ = config; + return true; + } + + void TMetaInfo::TryPing(TKeepAliveHttpClient& client) { + try { + TStringStream s; + TKeepAliveHttpClient::THttpCode code = client.DoGet("/tvm/ping", &s); + if (code < 200 || 300 <= code) { + throw yexception() << "(" << code << ") " << s.Str(); + } + } catch (const std::exception& e) { + ythrow TNonRetriableException() << "Failed to connect to tvmtool: " << e.what(); + } + } + + TString TMetaInfo::Fetch(TKeepAliveHttpClient& client) const { + TStringStream res; + TKeepAliveHttpClient::THttpCode code; + try { + code = client.DoGet("/tvm/private_api/__meta__", &res, AuthHeader_); + } catch (const std::exception& e) { + ythrow TRetriableException() << "Failed to fetch meta data from tvmtool: " << e.what(); + } + + if (code != 200) { + Y_ENSURE_EX(code != 404, + TNonRetriableException() << "Library does not support so old tvmtool. You need tvmtool>=1.1.0"); + + TStringStream err; + err << "Failed to fetch meta from tvmtool: " << client.GetHost() << ":" << client.GetPort() + << " (" << code << "): " << res.Str(); + Y_ENSURE_EX(!(500 <= code && code < 600), TRetriableException() << err.Str()); + ythrow TNonRetriableException() << err.Str(); + } + + return res.Str(); + } + + static TMetaInfo::TDstAliases::value_type ParsePair(const NJson::TJsonValue& val, const TString& meta) { + NJson::TJsonValue jAlias; + Y_ENSURE(val.GetValue("alias", &jAlias), meta); + Y_ENSURE(jAlias.IsString(), meta); + + NJson::TJsonValue jClientId; + Y_ENSURE(val.GetValue("client_id", &jClientId), meta); + Y_ENSURE(jClientId.IsInteger(), meta); + + return {jAlias.GetString(), jClientId.GetInteger()}; + } + + TMetaInfo::TConfigPtr TMetaInfo::ParseMetaString(const TString& meta, const TString& self) { + NJson::TJsonValue jDoc; + Y_ENSURE(NJson::ReadJsonTree(meta, &jDoc), meta); + + NJson::TJsonValue jEnv; + Y_ENSURE(jDoc.GetValue("bb_env", &jEnv), meta); + + NJson::TJsonValue jTenants; + Y_ENSURE(jDoc.GetValue("tenants", &jTenants), meta); + Y_ENSURE(jTenants.IsArray(), meta); + + for (const NJson::TJsonValue& jTen : jTenants.GetArray()) { + NJson::TJsonValue jSelf; + Y_ENSURE(jTen.GetValue("self", &jSelf), meta); + auto selfPair = ParsePair(jSelf, meta); + if (selfPair.first != self) { + continue; + } + + TConfigPtr config = std::make_shared<TConfig>(); + config->SelfTvmId = selfPair.second; + config->BbEnv = BbEnvFromString(jEnv.GetString(), meta); + + { + NJson::TJsonValue jSlug; + if (jTen.GetValue("idm_slug", &jSlug)) { + config->IdmSlug = jSlug.GetString(); + } + } + + NJson::TJsonValue jDsts; + Y_ENSURE(jTen.GetValue("dsts", &jDsts), meta); + Y_ENSURE(jDsts.IsArray(), meta); + for (const NJson::TJsonValue& jDst : jDsts.GetArray()) { + config->DstAliases.insert(ParsePair(jDst, meta)); + } + + return config; + } + + return {}; + } + + void TMetaInfo::ApplySettings(const TClientSettings& settings) { + AuthHeader_ = {{"Authorization", settings.GetAuthToken()}}; + SelfAlias_ = settings.GetSelfAlias(); + } + + EBlackboxEnv TMetaInfo::BbEnvFromString(const TString& env, const TString& meta) { + if (env == "Prod") { + return EBlackboxEnv::Prod; + } else if (env == "Test") { + return EBlackboxEnv::Test; + } else if (env == "ProdYaTeam") { + return EBlackboxEnv::ProdYateam; + } else if (env == "TestYaTeam") { + return EBlackboxEnv::TestYateam; + } else if (env == "Stress") { + return EBlackboxEnv::Stress; + } + + ythrow yexception() << "'bb_env'=='" << env << "'. " << meta; + } +} diff --git a/library/cpp/tvmauth/client/misc/tool/meta_info.h b/library/cpp/tvmauth/client/misc/tool/meta_info.h new file mode 100644 index 0000000000..9dd4f0dbf8 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/tool/meta_info.h @@ -0,0 +1,69 @@ +#pragma once + +#include "settings.h" + +#include <library/cpp/tvmauth/client/misc/utils.h> + +#include <library/cpp/tvmauth/client/logger.h> + +#include <library/cpp/http/simple/http_client.h> + +namespace NTvmAuth::NTvmTool { + class TMetaInfo { + public: + using TDstAliases = THashMap<TClientSettings::TAlias, TTvmId>; + + struct TConfig { + TTvmId SelfTvmId = 0; + EBlackboxEnv BbEnv = EBlackboxEnv::Prod; + TString IdmSlug; + TDstAliases DstAliases; + + bool AreTicketsRequired() const { + return !DstAliases.empty(); + } + + TString ToString() const; + + bool operator==(const TConfig& c) const { + return SelfTvmId == c.SelfTvmId && + BbEnv == c.BbEnv && + IdmSlug == c.IdmSlug && + DstAliases == c.DstAliases; + } + }; + using TConfigPtr = std::shared_ptr<TConfig>; + + public: + TMetaInfo(TLoggerPtr logger); + + TConfigPtr Init(TKeepAliveHttpClient& client, + const TClientSettings& settings); + + static TString GetRequestForTickets(const TMetaInfo::TConfig& config); + + const TKeepAliveHttpClient::THeaders& GetAuthHeader() const { + return AuthHeader_; + } + + TConfigPtr GetConfig() const { + return Config_.Get(); + } + + bool TryUpdateConfig(TKeepAliveHttpClient& client); + + protected: + void TryPing(TKeepAliveHttpClient& client); + TString Fetch(TKeepAliveHttpClient& client) const; + static TConfigPtr ParseMetaString(const TString& meta, const TString& self); + void ApplySettings(const TClientSettings& settings); + static EBlackboxEnv BbEnvFromString(const TString& env, const TString& meta); + + protected: + NUtils::TProtectedValue<TConfigPtr> Config_; + TKeepAliveHttpClient::THeaders AuthHeader_; + + TLoggerPtr Logger_; + TString SelfAlias_; + }; +} diff --git a/library/cpp/tvmauth/client/misc/tool/roles_fetcher.cpp b/library/cpp/tvmauth/client/misc/tool/roles_fetcher.cpp new file mode 100644 index 0000000000..05b0856edc --- /dev/null +++ b/library/cpp/tvmauth/client/misc/tool/roles_fetcher.cpp @@ -0,0 +1,81 @@ +#include "roles_fetcher.h" + +#include <library/cpp/tvmauth/client/misc/roles/parser.h> + +#include <library/cpp/http/misc/httpcodes.h> +#include <library/cpp/string_utils/quote/quote.h> + +#include <util/string/builder.h> +#include <util/string/join.h> + +namespace NTvmAuth::NTvmTool { + TRolesFetcher::TRolesFetcher(const TRolesFetcherSettings& settings, TLoggerPtr logger) + : Settings_(settings) + , Logger_(std::move(logger)) + { + } + + bool TRolesFetcher::IsTimeToUpdate(TDuration sinceUpdate) const { + return Settings_.UpdatePeriod < sinceUpdate; + } + + bool TRolesFetcher::ShouldWarn(TDuration sinceUpdate) const { + return Settings_.WarnPeriod < sinceUpdate; + } + + bool TRolesFetcher::AreRolesOk() const { + return bool(GetCurrentRoles()); + } + + NUtils::TFetchResult TRolesFetcher::FetchActualRoles(const TKeepAliveHttpClient::THeaders& authHeader, + TKeepAliveHttpClient& client) const { + const TRequest req = CreateRequest(authHeader); + + TStringStream out; + THttpHeaders outHeaders; + + TKeepAliveHttpClient::THttpCode code = client.DoGet( + req.Url, + &out, + req.Headers, + &outHeaders); + + return {code, std::move(outHeaders), "/v2/roles", out.Str(), {}}; + } + + void TRolesFetcher::Update(NUtils::TFetchResult&& fetchResult) { + if (fetchResult.Code == HTTP_NOT_MODIFIED) { + Y_ENSURE(CurrentRoles_.Get(), + "tvmtool did not return any roles because current roles are actual," + " but there are no roles in memory - this should never happen"); + return; + } + + Y_ENSURE(fetchResult.Code == HTTP_OK, + "Unexpected code from tvmtool: " << fetchResult.Code << ". " << fetchResult.Response); + + CurrentRoles_.Set(NRoles::TParser::Parse(std::make_shared<TString>(std::move(fetchResult.Response)))); + + Logger_->Debug( + TStringBuilder() << "Succeed to update roles with revision " + << CurrentRoles_.Get()->GetMeta().Revision); + } + + NTvmAuth::NRoles::TRolesPtr TRolesFetcher::GetCurrentRoles() const { + return CurrentRoles_.Get(); + } + + TRolesFetcher::TRequest TRolesFetcher::CreateRequest(const TKeepAliveHttpClient::THeaders& authHeader) const { + TRequest request{ + .Url = "/v2/roles?self=" + CGIEscapeRet(Settings_.SelfAlias), + .Headers = authHeader, + }; + + NRoles::TRolesPtr roles = CurrentRoles_.Get(); + if (roles) { + request.Headers.emplace(IfNoneMatch_, Join("", "\"", roles->GetMeta().Revision, "\"")); + } + + return request; + } +} diff --git a/library/cpp/tvmauth/client/misc/tool/roles_fetcher.h b/library/cpp/tvmauth/client/misc/tool/roles_fetcher.h new file mode 100644 index 0000000000..8c60b59610 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/tool/roles_fetcher.h @@ -0,0 +1,49 @@ +#pragma once + +#include <library/cpp/tvmauth/client/misc/fetch_result.h> +#include <library/cpp/tvmauth/client/misc/utils.h> +#include <library/cpp/tvmauth/client/misc/roles/roles.h> + +#include <library/cpp/tvmauth/client/logger.h> + +#include <util/datetime/base.h> +#include <util/generic/string.h> + +namespace NTvmAuth::NTvmTool { + struct TRolesFetcherSettings { + TString SelfAlias; + TDuration UpdatePeriod = TDuration::Minutes(1); + TDuration WarnPeriod = TDuration::Minutes(20); + }; + + class TRolesFetcher { + public: + TRolesFetcher(const TRolesFetcherSettings& settings, TLoggerPtr logger); + + bool IsTimeToUpdate(TDuration sinceUpdate) const; + bool ShouldWarn(TDuration sinceUpdate) const; + bool AreRolesOk() const; + + NUtils::TFetchResult FetchActualRoles(const TKeepAliveHttpClient::THeaders& authHeader, + TKeepAliveHttpClient& client) const; + void Update(NUtils::TFetchResult&& fetchResult); + + NTvmAuth::NRoles::TRolesPtr GetCurrentRoles() const; + + protected: + struct TRequest { + TString Url; + TKeepAliveHttpClient::THeaders Headers; + }; + + protected: + TRequest CreateRequest(const TKeepAliveHttpClient::THeaders& authHeader) const; + + private: + const TRolesFetcherSettings Settings_; + const TLoggerPtr Logger_; + const TString IfNoneMatch_ = "If-None-Match"; + + NUtils::TProtectedValue<NTvmAuth::NRoles::TRolesPtr> CurrentRoles_; + }; +} diff --git a/library/cpp/tvmauth/client/misc/tool/settings.cpp b/library/cpp/tvmauth/client/misc/tool/settings.cpp new file mode 100644 index 0000000000..894501f19d --- /dev/null +++ b/library/cpp/tvmauth/client/misc/tool/settings.cpp @@ -0,0 +1,37 @@ +#include "settings.h" + +#include <library/cpp/string_utils/url/url.h> + +#include <util/system/env.h> + +namespace NTvmAuth::NTvmTool { + TClientSettings::TClientSettings(const TAlias& selfAias) + : SelfAias_(selfAias) + , Hostname_("localhost") + , Port_(1) + , SocketTimeout_(TDuration::Seconds(5)) + , ConnectTimeout_(TDuration::Seconds(30)) + { + AuthToken_ = GetEnv("TVMTOOL_LOCAL_AUTHTOKEN"); + if (!AuthToken_) { + AuthToken_ = GetEnv("QLOUD_TVM_TOKEN"); + } + TStringBuf auth(AuthToken_); + FixSpaces(auth); + AuthToken_ = auth; + + const TString url = GetEnv("DEPLOY_TVM_TOOL_URL"); + if (url) { + TStringBuf scheme, host; + TryGetSchemeHostAndPort(url, scheme, host, Port_); + } + + Y_ENSURE_EX(SelfAias_, TBrokenTvmClientSettings() << "Alias for your TVM client cannot be empty"); + } + + void TClientSettings::FixSpaces(TStringBuf& str) { + while (str && isspace(str.back())) { + str.Chop(1); + } + } +} diff --git a/library/cpp/tvmauth/client/misc/tool/settings.h b/library/cpp/tvmauth/client/misc/tool/settings.h new file mode 100644 index 0000000000..1267ca1527 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/tool/settings.h @@ -0,0 +1,178 @@ +#pragma once + +#include <library/cpp/tvmauth/client/misc/settings.h> + +#include <library/cpp/tvmauth/client/exception.h> + +#include <library/cpp/tvmauth/checked_user_ticket.h> + +#include <util/datetime/base.h> +#include <util/generic/maybe.h> + +namespace NTvmAuth::NTvmTool { + /** + * Uses local http-interface to get state: http://localhost/tvm/. + * This interface can be provided with tvmtool (local daemon) or Qloud/YP (local http api in container). + * See more: https://wiki.yandex-team.ru/passport/tvm2/qloud/. + * + * Most part of settings will be fetched from tvmtool on start of client. + * You need to use aliases for TVM-clients (src and dst) which you specified in tvmtool or Qloud/YP interface + */ + class TClientSettings: public NTvmAuth::TClientSettings { + public: + /*! + * Sets default values: + * - hostname == "localhost" + * - port detected with env["DEPLOY_TVM_TOOL_URL"] (provided with Yandex.Deploy), + * otherwise port == 1 (it is ok for Qloud) + * - authToken: env["TVMTOOL_LOCAL_AUTHTOKEN"] (provided with Yandex.Deploy), + * otherwise env["QLOUD_TVM_TOKEN"] (provided with Qloud) + * + * AuthToken is protection from SSRF. + * + * @param selfAias - alias for your TVM client, which you specified in tvmtool or YD interface + */ + TClientSettings(const TAlias& selfAias); + + /*! + * Look at comment for ctor + * @param port + */ + TClientSettings& SetPort(ui16 port) { + Port_ = port; + return *this; + } + + /*! + * Default value: hostname == "localhost" + * @param hostname + */ + TClientSettings& SetHostname(const TString& hostname) { + Y_ENSURE_EX(hostname, TBrokenTvmClientSettings() << "Hostname cannot be empty"); + Hostname_ = hostname; + return *this; + } + + TClientSettings& SetSocketTimeout(TDuration socketTimeout) { + SocketTimeout_ = socketTimeout; + return *this; + } + + TClientSettings& SetConnectTimeout(TDuration connectTimeout) { + ConnectTimeout_ = connectTimeout; + return *this; + } + + /*! + * Look at comment for ctor + * @param token + */ + TClientSettings& SetAuthToken(TStringBuf token) { + FixSpaces(token); + Y_ENSURE_EX(token, TBrokenTvmClientSettings() << "Auth token cannot be empty"); + AuthToken_ = token; + return *this; + } + + /*! + * Blackbox environmet is provided by tvmtool for client. + * You can override it for your purpose with limitations: + * (env from tvmtool) -> (override) + * - Prod/ProdYateam -> Prod/ProdYateam + * - Test/TestYateam -> Test/TestYateam + * - Stress -> Stress + * + * You can contact tvm-dev@yandex-team.ru if limitations are too strict + * @param env + */ + TClientSettings& OverrideBlackboxEnv(EBlackboxEnv env) { + BbEnv_ = env; + return *this; + } + + /*! + * By default client checks src from ServiceTicket or default uid from UserTicket - + * to prevent you from forgetting to check it yourself. + * It does binary checks only: + * ticket gets status NoRoles, if there is no role for src or default uid. + * You need to check roles on your own if you have a non-binary role system or + * you have disabled ShouldCheckSrc/ShouldCheckDefaultUid + * + * You may need to disable this check in the following cases: + * - You use GetRoles() to provide verbose message (with revision). + * Double check may be inconsistent: + * binary check inside client uses revision of roles X - i.e. src 100500 has no role, + * exact check in your code uses revision of roles Y - i.e. src 100500 has some roles. + */ + bool ShouldCheckSrc = true; + bool ShouldCheckDefaultUid = true; + /*! + * By default client checks dst from ServiceTicket. If this check is switched off + * incorrect dst does not result in error of checked ticket status + * DANGEROUS: This case you must check dst manualy using @link TCheckedServiceTicket::GetDst() + */ + bool ShouldCheckDst = true; + + // In case of unsuccessful initialization at startup the client will be initialized in the background + bool EnableLazyInitialization = false; + + // DEPRECATED API + // TODO: get rid of it: PASSP-35377 + public: + // Deprecated: set attributes directly + TClientSettings& SetShouldCheckSrc(bool val = true) { + ShouldCheckSrc = val; + return *this; + } + + // Deprecated: set attributes directly + TClientSettings& SetSShouldCheckDefaultUid(bool val = true) { + ShouldCheckDefaultUid = val; + return *this; + } + + public: // for TAsyncUpdaterBase + const TAlias& GetSelfAlias() const { + return SelfAias_; + } + + const TString& GetHostname() const { + return Hostname_; + } + + ui16 GetPort() const { + return Port_; + } + + TDuration GetSocketTimeout() const { + return SocketTimeout_; + } + + TDuration GetConnectTimeout() const { + return ConnectTimeout_; + } + + const TString& GetAuthToken() const { + Y_ENSURE_EX(AuthToken_, TBrokenTvmClientSettings() + << "Auth token cannot be empty. " + << "Env 'TVMTOOL_LOCAL_AUTHTOKEN' and 'QLOUD_TVM_TOKEN' are empty."); + return AuthToken_; + } + + TMaybe<EBlackboxEnv> GetOverridedBlackboxEnv() const { + return BbEnv_; + } + + private: + void FixSpaces(TStringBuf& str); + + private: + TAlias SelfAias_; + TString Hostname_; + ui16 Port_; + TDuration SocketTimeout_; + TDuration ConnectTimeout_; + TString AuthToken_; + TMaybe<EBlackboxEnv> BbEnv_; + }; +} diff --git a/library/cpp/tvmauth/client/misc/tool/threaded_updater.cpp b/library/cpp/tvmauth/client/misc/tool/threaded_updater.cpp new file mode 100644 index 0000000000..82dd57a77a --- /dev/null +++ b/library/cpp/tvmauth/client/misc/tool/threaded_updater.cpp @@ -0,0 +1,442 @@ +#include "threaded_updater.h" + +#include <library/cpp/tvmauth/client/misc/checker.h> +#include <library/cpp/tvmauth/client/misc/default_uid_checker.h> +#include <library/cpp/tvmauth/client/misc/getter.h> +#include <library/cpp/tvmauth/client/misc/src_checker.h> +#include <library/cpp/tvmauth/client/misc/utils.h> + +#include <library/cpp/json/json_reader.h> + +#include <util/generic/hash_set.h> +#include <util/stream/str.h> +#include <util/string/ascii.h> +#include <util/string/builder.h> +#include <util/string/cast.h> + +namespace NTvmAuth::NTvmTool { + TAsyncUpdaterPtr TThreadedUpdater::Create(const TClientSettings& settings, TLoggerPtr logger) { + Y_ENSURE_EX(logger, TNonRetriableException() << "Logger is required"); + THolder<TThreadedUpdater> p(new TThreadedUpdater(settings, std::move(logger))); + + try { + p->Init(settings); + } catch (const TRetriableException& e) { + if (!settings.EnableLazyInitialization) { + throw e; + } + } + + p->StartWorker(); + return p.Release(); + } + + TThreadedUpdater::~TThreadedUpdater() { + StopWorker(); // Required here to avoid using of deleted members + } + + TClientStatus TThreadedUpdater::GetStatus() const { + const TClientStatus::ECode state = GetState(); + return TClientStatus(state, GetLastError(state == TClientStatus::Ok)); + } + + TString TThreadedUpdater::GetServiceTicketFor(const TClientSettings::TAlias& dst) const { + Y_ENSURE_EX(IsInited(), TNotInitializedException() << "Client is not initialized"); + + if (!MetaInfo_.GetConfig()->AreTicketsRequired()) { + throw TBrokenTvmClientSettings() << "Need to enable ServiceTickets fetching"; + } + auto c = GetCachedServiceTickets(); + return TServiceTicketGetter::GetTicket(dst, c); + } + + TString TThreadedUpdater::GetServiceTicketFor(const TTvmId dst) const { + Y_ENSURE_EX(IsInited(), TNotInitializedException() << "Client is not initialized"); + + if (!MetaInfo_.GetConfig()->AreTicketsRequired()) { + throw TBrokenTvmClientSettings() << "Need to enable ServiceTickets fetching"; + } + auto c = GetCachedServiceTickets(); + return TServiceTicketGetter::GetTicket(dst, c); + } + + TCheckedServiceTicket TThreadedUpdater::CheckServiceTicket(TStringBuf ticket, const TServiceContext::TCheckFlags& flags) const { + Y_ENSURE_EX(IsInited(), TNotInitializedException() << "Client is not initialized"); + + TServiceContextPtr c = GetCachedServiceContext(); + TCheckedServiceTicket res = TServiceTicketChecker::Check(ticket, c, flags); + if (Settings_.ShouldCheckSrc && RolesFetcher_ && res) { + NRoles::TRolesPtr roles = GetRoles(); + return TSrcChecker::Check(std::move(res), roles); + } + return res; + } + + TCheckedUserTicket TThreadedUpdater::CheckUserTicket(TStringBuf ticket, TMaybe<EBlackboxEnv> overridenEnv) const { + Y_ENSURE_EX(IsInited(), TNotInitializedException() << "Client is not initialized"); + + auto c = GetCachedUserContext(overridenEnv); + TCheckedUserTicket res = TUserTicketChecker::Check(ticket, c); + if (Settings_.ShouldCheckDefaultUid && RolesFetcher_ && res && res.GetEnv() == EBlackboxEnv::ProdYateam) { + NRoles::TRolesPtr roles = GetRoles(); + return TDefaultUidChecker::Check(std::move(res), roles); + } + return res; + } + + NRoles::TRolesPtr TThreadedUpdater::GetRoles() const { + Y_ENSURE_EX(IsInited(), TNotInitializedException() << "Client is not initialized"); + + Y_ENSURE_EX(RolesFetcher_, + TBrokenTvmClientSettings() << "Roles were not configured in settings"); + + return RolesFetcher_->GetCurrentRoles(); + } + + TClientStatus::ECode TThreadedUpdater::GetState() const { + const TInstant now = TInstant::Now(); + + if (!IsInited()) { + return TClientStatus::NotInitialized; + } + + const TMetaInfo::TConfigPtr config = MetaInfo_.GetConfig(); + + if ((config->AreTicketsRequired() && AreServiceTicketsInvalid(now)) || ArePublicKeysInvalid(now)) { + return TClientStatus::Error; + } + + if (config->AreTicketsRequired()) { + if (!GetCachedServiceTickets() || config->DstAliases.size() > GetCachedServiceTickets()->TicketsByAlias.size()) { + return TClientStatus::Error; + } + } + + const TDuration st = now - GetUpdateTimeOfServiceTickets(); + const TDuration pk = now - GetUpdateTimeOfPublicKeys(); + + if ((config->AreTicketsRequired() && st > ServiceTicketsDurations_.Expiring) || pk > PublicKeysDurations_.Expiring) { + return TClientStatus::Warning; + } + + if (RolesFetcher_ && RolesFetcher_->ShouldWarn(now - GetUpdateTimeOfRoles())) { + return TClientStatus::Warning; + } + + if (IsConfigWarnTime()) { + return TClientStatus::Warning; + } + + return TClientStatus::Ok; + } + + TThreadedUpdater::TThreadedUpdater(const TClientSettings& settings, TLoggerPtr logger) + : TThreadedUpdaterBase(TDuration::Seconds(5), logger, settings.GetHostname(), settings.GetPort(), settings.GetSocketTimeout(), settings.GetConnectTimeout()) + , MetaInfo_(logger) + , ConfigWarnDelay_(TDuration::Seconds(30)) + , Settings_(settings) + { + ServiceTicketsDurations_.RefreshPeriod = TDuration::Minutes(10); + PublicKeysDurations_.RefreshPeriod = TDuration::Minutes(10); + } + + void TThreadedUpdater::Init(const TClientSettings& settings) { + const TMetaInfo::TConfigPtr config = MetaInfo_.Init(GetClient(), settings); + LastVisitForConfig_ = TInstant::Now(); + + SetBbEnv(config->BbEnv, settings.GetOverridedBlackboxEnv()); + if (settings.GetOverridedBlackboxEnv()) { + LogInfo(TStringBuilder() + << "Meta: override blackbox env: " << config->BbEnv + << "->" << *settings.GetOverridedBlackboxEnv()); + } + + if (config->IdmSlug) { + RolesFetcher_ = std::make_unique<TRolesFetcher>( + TRolesFetcherSettings{ + .SelfAlias = settings.GetSelfAlias(), + }, + Logger_); + } + + ui8 tries = 3; + do { + UpdateState(); + } while (!IsEverythingOk(*config) && --tries > 0); + + if (!IsEverythingOk(*config)) { + ThrowLastError(); + } + SetInited(true); + } + + void TThreadedUpdater::UpdateState() { + bool wasUpdated = false; + try { + wasUpdated = MetaInfo_.TryUpdateConfig(GetClient()); + LastVisitForConfig_ = TInstant::Now(); + ClearError(EScope::TvmtoolConfig); + } catch (const std::exception& e) { + ProcessError(EType::Retriable, EScope::TvmtoolConfig, e.what()); + LogWarning(TStringBuilder() << "Error while fetching of tvmtool config: " << e.what()); + } + if (IsConfigWarnTime()) { + LogError(TStringBuilder() << "Tvmtool config have not been refreshed for too long period"); + } + + TMetaInfo::TConfigPtr config = MetaInfo_.GetConfig(); + + if (wasUpdated || IsTimeToUpdateServiceTickets(*config, LastVisitForServiceTickets_)) { + try { + const TInstant updateTime = UpdateServiceTickets(*config); + SetUpdateTimeOfServiceTickets(updateTime); + LastVisitForServiceTickets_ = TInstant::Now(); + + if (AreServiceTicketsOk(*config)) { + ClearError(EScope::ServiceTickets); + } + LogDebug(TStringBuilder() << "Tickets fetched from tvmtool: " << updateTime); + } catch (const std::exception& e) { + ProcessError(EType::Retriable, EScope::ServiceTickets, e.what()); + LogWarning(TStringBuilder() << "Error while fetching of tickets: " << e.what()); + } + + if (TInstant::Now() - GetUpdateTimeOfServiceTickets() > ServiceTicketsDurations_.Expiring) { + LogError("Service tickets have not been refreshed for too long period"); + } + } + + if (wasUpdated || IsTimeToUpdatePublicKeys(LastVisitForPublicKeys_)) { + try { + const TInstant updateTime = UpdateKeys(*config); + SetUpdateTimeOfPublicKeys(updateTime); + LastVisitForPublicKeys_ = TInstant::Now(); + + if (ArePublicKeysOk()) { + ClearError(EScope::PublicKeys); + } + LogDebug(TStringBuilder() << "Public keys fetched from tvmtool: " << updateTime); + } catch (const std::exception& e) { + ProcessError(EType::Retriable, EScope::PublicKeys, e.what()); + LogWarning(TStringBuilder() << "Error while fetching of public keys: " << e.what()); + } + + if (TInstant::Now() - GetUpdateTimeOfPublicKeys() > PublicKeysDurations_.Expiring) { + LogError("Public keys have not been refreshed for too long period"); + } + } + + if (RolesFetcher_ && (wasUpdated || RolesFetcher_->IsTimeToUpdate(TInstant::Now() - GetUpdateTimeOfRoles()))) { + try { + RolesFetcher_->Update(RolesFetcher_->FetchActualRoles(MetaInfo_.GetAuthHeader(), GetClient())); + SetUpdateTimeOfRoles(TInstant::Now()); + + if (RolesFetcher_->AreRolesOk()) { + ClearError(EScope::Roles); + } + } catch (const std::exception& e) { + ProcessError(EType::Retriable, EScope::Roles, e.what()); + LogWarning(TStringBuilder() << "Failed to update roles: " << e.what()); + } + + if (RolesFetcher_->ShouldWarn(TInstant::Now() - GetUpdateTimeOfRoles())) { + LogError("Roles have not been refreshed for too long period"); + } + } + } + + TInstant TThreadedUpdater::UpdateServiceTickets(const TMetaInfo::TConfig& config) { + const std::pair<TString, TInstant> tickets = FetchServiceTickets(config); + + if (TInstant::Now() - tickets.second >= ServiceTicketsDurations_.Invalid) { + throw yexception() << "Service tickets are too old: " << tickets.second; + } + + TPairTicketsErrors p = ParseFetchTicketsResponse(tickets.first, config.DstAliases); + SetServiceTickets(MakeIntrusiveConst<TServiceTickets>(std::move(p.Tickets), + std::move(p.Errors), + config.DstAliases)); + return tickets.second; + } + + std::pair<TString, TInstant> TThreadedUpdater::FetchServiceTickets(const TMetaInfo::TConfig& config) const { + TStringStream s; + THttpHeaders headers; + + const TString request = TMetaInfo::GetRequestForTickets(config); + auto code = GetClient().DoGet(request, &s, MetaInfo_.GetAuthHeader(), &headers); + Y_ENSURE(code == 200, ProcessHttpError(EScope::ServiceTickets, request, code, s.Str())); + + return {s.Str(), GetBirthTimeFromResponse(headers, "tickets")}; + } + + static THashSet<TTvmId> GetAllTvmIds(const TMetaInfo::TDstAliases& dsts) { + THashSet<TTvmId> res; + res.reserve(dsts.size()); + + for (const auto& pair : dsts) { + res.insert(pair.second); + } + + return res; + } + + TAsyncUpdaterBase::TPairTicketsErrors TThreadedUpdater::ParseFetchTicketsResponse(const TString& resp, + const TMetaInfo::TDstAliases& dsts) const { + const THashSet<TTvmId> allTvmIds = GetAllTvmIds(dsts); + + TServiceTickets::TMapIdStr tickets; + TServiceTickets::TMapIdStr errors; + + auto procErr = [this](const TString& msg) { + ProcessError(EType::NonRetriable, EScope::ServiceTickets, msg); + LogError(msg); + }; + + NJson::TJsonValue doc; + Y_ENSURE(NJson::ReadJsonTree(resp, &doc), "Invalid json from tvmtool: " << resp); + + for (const auto& pair : doc.GetMap()) { + NJson::TJsonValue tvmId; + unsigned long long tvmIdNum = 0; + + if (!pair.second.GetValue("tvm_id", &tvmId) || + !tvmId.GetUInteger(&tvmIdNum)) { + procErr(TStringBuilder() + << "Failed to get 'tvm_id' from key, should never happend '" + << pair.first << "': " << resp); + continue; + } + + if (!allTvmIds.contains(tvmIdNum)) { + continue; + } + + NJson::TJsonValue val; + if (!pair.second.GetValue("ticket", &val)) { + TString err; + if (pair.second.GetValue("error", &val)) { + err = val.GetString(); + } else { + err = "Failed to get 'ticket' and 'error', should never happend: " + pair.first; + } + + procErr(TStringBuilder() + << "Failed to get ServiceTicket for " << pair.first + << " (" << tvmIdNum << "): " << err); + + errors.insert({tvmIdNum, std::move(err)}); + continue; + } + + tickets.insert({tvmIdNum, val.GetString()}); + } + + // This work-around is required because of bug in old verions of tvmtool: PASSP-24829 + for (const auto& pair : dsts) { + if (!tickets.contains(pair.second) && !errors.contains(pair.second)) { + TString err = "Missing tvm_id in response, should never happend: " + pair.first; + + procErr(TStringBuilder() + << "Failed to get ServiceTicket for " << pair.first + << " (" << pair.second << "): " << err); + + errors.emplace(pair.second, std::move(err)); + } + } + + return {std::move(tickets), std::move(errors)}; + } + + TInstant TThreadedUpdater::UpdateKeys(const TMetaInfo::TConfig& config) { + const std::pair<TString, TInstant> keys = FetchPublicKeys(); + + if (TInstant::Now() - keys.second >= PublicKeysDurations_.Invalid) { + throw yexception() << "Public keys are too old: " << keys.second; + } + + SetServiceContext(MakeIntrusiveConst<TServiceContext>( + TServiceContext::CheckingFactory(config.SelfTvmId, keys.first))); + SetUserContext(keys.first); + + return keys.second; + } + + std::pair<TString, TInstant> TThreadedUpdater::FetchPublicKeys() const { + TStringStream s; + THttpHeaders headers; + + auto code = GetClient().DoGet("/tvm/keys", &s, MetaInfo_.GetAuthHeader(), &headers); + Y_ENSURE(code == 200, ProcessHttpError(EScope::PublicKeys, "/tvm/keys", code, s.Str())); + + return {s.Str(), GetBirthTimeFromResponse(headers, "public keys")}; + } + + TInstant TThreadedUpdater::GetBirthTimeFromResponse(const THttpHeaders& headers, TStringBuf errMsg) { + auto it = std::find_if(headers.begin(), + headers.end(), + [](const THttpInputHeader& h) { + return AsciiEqualsIgnoreCase(h.Name(), "X-Ya-Tvmtool-Data-Birthtime"); + }); + Y_ENSURE(it != headers.end(), "Failed to fetch bithtime of " << errMsg << " from tvmtool"); + + ui64 time = 0; + Y_ENSURE(TryIntFromString<10>(it->Value(), time), + "Bithtime of " << errMsg << " from tvmtool must be unixtime. Got: " << it->Value()); + + return TInstant::Seconds(time); + } + + bool TThreadedUpdater::IsTimeToUpdateServiceTickets(const TMetaInfo::TConfig& config, + TInstant lastUpdate) const { + return config.AreTicketsRequired() && + TInstant::Now() - lastUpdate > ServiceTicketsDurations_.RefreshPeriod; + } + + bool TThreadedUpdater::IsTimeToUpdatePublicKeys(TInstant lastUpdate) const { + return TInstant::Now() - lastUpdate > PublicKeysDurations_.RefreshPeriod; + } + + bool TThreadedUpdater::IsEverythingOk(const TMetaInfo::TConfig& config) const { + if (RolesFetcher_ && !RolesFetcher_->AreRolesOk()) { + return false; + } + return AreServiceTicketsOk(config) && ArePublicKeysOk(); + } + + bool TThreadedUpdater::AreServiceTicketsOk(const TMetaInfo::TConfig& config) const { + return AreServiceTicketsOk(config.DstAliases.size()); + } + + bool TThreadedUpdater::AreServiceTicketsOk(size_t requiredCount) const { + if (requiredCount == 0) { + return true; + } + + auto c = GetCachedServiceTickets(); + return c && c->TicketsByAlias.size() == requiredCount; + } + + bool TThreadedUpdater::ArePublicKeysOk() const { + return GetCachedServiceContext() && GetCachedUserContext(); + } + + bool TThreadedUpdater::IsConfigWarnTime() const { + return LastVisitForConfig_ + ConfigWarnDelay_ < TInstant::Now(); + } + + void TThreadedUpdater::Worker() { + if (IsInited()) { + UpdateState(); + } else { + try { + Init(Settings_); + } catch (const TRetriableException& e) { + // Still not initialized + } catch (const std::exception& e) { + // Can't retry, so we mark client as initialized and now GetStatus() will return TClientStatus::Error + SetInited(true); + } + } + } +} diff --git a/library/cpp/tvmauth/client/misc/tool/threaded_updater.h b/library/cpp/tvmauth/client/misc/tool/threaded_updater.h new file mode 100644 index 0000000000..e007553f81 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/tool/threaded_updater.h @@ -0,0 +1,65 @@ +#pragma once + +#include "meta_info.h" +#include "roles_fetcher.h" + +#include <library/cpp/tvmauth/client/misc/async_updater.h> +#include <library/cpp/tvmauth/client/misc/threaded_updater.h> + +#include <atomic> + +namespace NTvmAuth::NTvmTool { + class TThreadedUpdater: public TThreadedUpdaterBase { + public: + static TAsyncUpdaterPtr Create(const TClientSettings& settings, TLoggerPtr logger); + ~TThreadedUpdater(); + + TClientStatus GetStatus() const override; + TString GetServiceTicketFor(const TClientSettings::TAlias& dst) const override; + TString GetServiceTicketFor(const TTvmId dst) const override; + TCheckedServiceTicket CheckServiceTicket(TStringBuf ticket, const TServiceContext::TCheckFlags& flags = TServiceContext::TCheckFlags{}) const override; + TCheckedUserTicket CheckUserTicket(TStringBuf ticket, TMaybe<EBlackboxEnv> overrideEnv = {}) const override; + NRoles::TRolesPtr GetRoles() const override; + + protected: // for tests + TClientStatus::ECode GetState() const; + + TThreadedUpdater(const TClientSettings& settings, TLoggerPtr logger); + + void Init(const TClientSettings& settings); + void UpdateState(); + + TInstant UpdateServiceTickets(const TMetaInfo::TConfig& config); + std::pair<TString, TInstant> FetchServiceTickets(const TMetaInfo::TConfig& config) const; + TPairTicketsErrors ParseFetchTicketsResponse(const TString& resp, + const TMetaInfo::TDstAliases& dsts) const; + + TInstant UpdateKeys(const TMetaInfo::TConfig& config); + std::pair<TString, TInstant> FetchPublicKeys() const; + + static TInstant GetBirthTimeFromResponse(const THttpHeaders& headers, TStringBuf errMsg); + + bool IsTimeToUpdateServiceTickets(const TMetaInfo::TConfig& config, TInstant lastUpdate) const; + bool IsTimeToUpdatePublicKeys(TInstant lastUpdate) const; + + bool IsEverythingOk(const TMetaInfo::TConfig& config) const; + bool AreServiceTicketsOk(const TMetaInfo::TConfig& config) const; + bool AreServiceTicketsOk(size_t requiredCount) const; + bool ArePublicKeysOk() const; + bool IsConfigWarnTime() const; + + private: + void Worker() override; + + protected: + TMetaInfo MetaInfo_; + TInstant LastVisitForServiceTickets_; + TInstant LastVisitForPublicKeys_; + TInstant LastVisitForConfig_; + TDuration ConfigWarnDelay_; + std::unique_ptr<TRolesFetcher> RolesFetcher_; + + private: + const TClientSettings Settings_; + }; +} diff --git a/library/cpp/tvmauth/client/misc/utils.cpp b/library/cpp/tvmauth/client/misc/utils.cpp new file mode 100644 index 0000000000..a124c7b11c --- /dev/null +++ b/library/cpp/tvmauth/client/misc/utils.cpp @@ -0,0 +1,46 @@ +#include "utils.h" + +#include <library/cpp/tvmauth/client/facade.h> + +#include <util/stream/format.h> + +namespace NTvmAuth::NInternal { + void TClientCaningKnife::StartTvmClientStopping(TTvmClient* c) { + if (c && c->Updater_) { + c->Updater_->StartTvmClientStopping(); + } + } + + bool TClientCaningKnife::IsTvmClientStopped(TTvmClient* c) { + return c && c->Updater_ ? c->Updater_->IsTvmClientStopped() : true; + } +} + +namespace NTvmAuth::NUtils { + TString ToHex(const TStringBuf s) { + TStringStream res; + res.Reserve(2 * s.size()); + + for (char c : s) { + res << Hex(c, HF_FULL); + } + + return std::move(res.Str()); + } + + bool CheckBbEnvOverriding(EBlackboxEnv original, EBlackboxEnv override) noexcept { + switch (original) { + case EBlackboxEnv::Prod: + case EBlackboxEnv::ProdYateam: + return override == EBlackboxEnv::Prod || override == EBlackboxEnv::ProdYateam; + case EBlackboxEnv::Test: + return true; + case EBlackboxEnv::TestYateam: + return override == EBlackboxEnv::Test || override == EBlackboxEnv::TestYateam; + case EBlackboxEnv::Stress: + return override == EBlackboxEnv::Stress; + } + + return false; + } +} diff --git a/library/cpp/tvmauth/client/misc/utils.h b/library/cpp/tvmauth/client/misc/utils.h new file mode 100644 index 0000000000..1aa5e61bf1 --- /dev/null +++ b/library/cpp/tvmauth/client/misc/utils.h @@ -0,0 +1,95 @@ +#pragma once + +#include "api/settings.h" +#include "tool/settings.h" + +#include <util/string/cast.h> +#include <util/system/spinlock.h> + +#include <optional> + +namespace NTvmAuth { + class TTvmClient; +} + +namespace NTvmAuth::NInternal { + class TClientCaningKnife { + public: + static void StartTvmClientStopping(TTvmClient* c); + static bool IsTvmClientStopped(TTvmClient* c); + }; +} + +namespace NTvmAuth::NUtils { + TString ToHex(const TStringBuf s); + + inline NTvmAuth::NTvmApi::TClientSettings::TDstMap ParseDstMap(TStringBuf dsts) { + NTvmAuth::NTvmApi::TClientSettings::TDstMap res; + + while (dsts) { + TStringBuf pair = dsts.NextTok(';'); + TStringBuf alias = pair.NextTok(':'); + res.insert(decltype(res)::value_type( + alias, + IntFromString<TTvmId, 10>(pair))); + } + + return res; + } + + inline NTvmAuth::NTvmApi::TClientSettings::TDstVector ParseDstVector(TStringBuf dsts) { + NTvmAuth::NTvmApi::TClientSettings::TDstVector res; + + while (dsts) { + res.push_back(IntFromString<TTvmId, 10>(dsts.NextTok(';'))); + } + + return res; + } + + bool CheckBbEnvOverriding(EBlackboxEnv original, EBlackboxEnv override) noexcept; + + template <class T> + class TProtectedValue { + class TAssignOp { + public: + static void Assign(T& l, const T& r) { + l = r; + } + + template <typename U> + static void Assign(std::shared_ptr<U>& l, std::shared_ptr<U>& r) { + l.swap(r); + } + + template <typename U> + static void Assign(TIntrusiveConstPtr<U>& l, TIntrusiveConstPtr<U>& r) { + l.Swap(r); + } + }; + + public: + TProtectedValue() = default; + + TProtectedValue(T value) + : Value_(value) + { + } + + T Get() const { + with_lock (Lock_) { + return Value_; + } + } + + void Set(T o) { + with_lock (Lock_) { + TAssignOp::Assign(Value_, o); + } + } + + private: + T Value_; + mutable TAdaptiveLock Lock_; + }; +} diff --git a/library/cpp/tvmauth/client/mocked_updater.cpp b/library/cpp/tvmauth/client/mocked_updater.cpp new file mode 100644 index 0000000000..54f94bc92a --- /dev/null +++ b/library/cpp/tvmauth/client/mocked_updater.cpp @@ -0,0 +1,60 @@ +#include "mocked_updater.h" + +#include <library/cpp/tvmauth/unittest.h> + +namespace NTvmAuth { + TMockedUpdater::TSettings TMockedUpdater::TSettings::CreateDeafult() { + TMockedUpdater::TSettings res; + + res.SelfTvmId = 100500; + + res.Backends = { + { + /*.Alias_ = */ "my_dest", + /*.Id_ = */ 42, + /*.Value_ = */ "3:serv:CBAQ__________9_IgYIlJEGECo:O9-vbod_8czkKrpwJAZCI8UgOIhNr2xKPcS-LWALrVC224jga2nIT6vLiw6q3d6pAT60g9K7NB39LEmh7vMuePtUMjzuZuL-uJg17BsH2iTLCZSxDjWxbU9piA2T6u607jiSyiy-FI74pEPqkz7KKJ28aPsefuC1VUweGkYFzNY", + }, + }; + + res.BadBackends = { + { + /*.Alias_ = */ "my_bad_dest", + /*.Id_ = */ 43, + /*.Value_ = */ "Dst is not found", + }, + }; + + return res; + } + + TMockedUpdater::TMockedUpdater(const TSettings& settings) + : Roles_(settings.Roles) + { + SetServiceContext(MakeIntrusiveConst<TServiceContext>(TServiceContext::CheckingFactory( + settings.SelfTvmId, + NUnittest::TVMKNIFE_PUBLIC_KEYS))); + + SetBbEnv(settings.UserTicketEnv); + SetUserContext(NUnittest::TVMKNIFE_PUBLIC_KEYS); + + TServiceTickets::TMapIdStr tickets, errors; + TServiceTickets::TMapAliasId aliases; + + for (const TSettings::TTuple& t : settings.Backends) { + tickets[t.Id] = t.Value; + aliases[t.Alias] = t.Id; + } + for (const TSettings::TTuple& t : settings.BadBackends) { + errors[t.Id] = t.Value; + aliases[t.Alias] = t.Id; + } + + SetServiceTickets(MakeIntrusiveConst<TServiceTickets>( + std::move(tickets), + std::move(errors), + std::move(aliases))); + + SetUpdateTimeOfPublicKeys(TInstant::Now()); + SetUpdateTimeOfServiceTickets(TInstant::Now()); + } +} diff --git a/library/cpp/tvmauth/client/mocked_updater.h b/library/cpp/tvmauth/client/mocked_updater.h new file mode 100644 index 0000000000..2b6daedb03 --- /dev/null +++ b/library/cpp/tvmauth/client/mocked_updater.h @@ -0,0 +1,81 @@ +#pragma once + +#include "misc/async_updater.h" +#include "misc/checker.h" +#include "misc/default_uid_checker.h" +#include "misc/getter.h" +#include "misc/src_checker.h" + +namespace NTvmAuth { + class TMockedUpdater: public TAsyncUpdaterBase { + public: + struct TSettings { + struct TTuple { + TClientSettings::TAlias Alias; + TTvmId Id = 0; + TString Value; // ticket or error + }; + + TTvmId SelfTvmId = 0; + TVector<TTuple> Backends; + TVector<TTuple> BadBackends; + EBlackboxEnv UserTicketEnv = EBlackboxEnv::Test; + NRoles::TRolesPtr Roles; + + static TSettings CreateDeafult(); + }; + + TMockedUpdater(const TSettings& settings = TSettings::CreateDeafult()); + + TClientStatus GetStatus() const override { + return TClientStatus(); + } + + NRoles::TRolesPtr GetRoles() const override { + Y_ENSURE_EX(Roles_, TIllegalUsage() << "Roles are not provided"); + return Roles_; + } + + TString GetServiceTicketFor(const TClientSettings::TAlias& dst) const override { + auto c = GetCachedServiceTickets(); + return TServiceTicketGetter::GetTicket(dst, c); + } + + TString GetServiceTicketFor(const TTvmId dst) const override { + auto c = GetCachedServiceTickets(); + return TServiceTicketGetter::GetTicket(dst, c); + } + + TCheckedServiceTicket CheckServiceTicket(TStringBuf ticket, const TServiceContext::TCheckFlags& flags) const override { + TServiceContextPtr c = GetCachedServiceContext(); + TCheckedServiceTicket res = TServiceTicketChecker::Check(ticket, c, flags); + + if (Roles_ && res) { + NRoles::TRolesPtr roles = GetRoles(); + return TSrcChecker::Check(std::move(res), roles); + } + + return res; + } + + TCheckedUserTicket CheckUserTicket(TStringBuf ticket, TMaybe<EBlackboxEnv> overridenEnv) const override { + auto c = GetCachedUserContext(overridenEnv); + TCheckedUserTicket res = TUserTicketChecker::Check(ticket, c); + + if (Roles_ && res && res.GetEnv() == EBlackboxEnv::ProdYateam) { + NRoles::TRolesPtr roles = GetRoles(); + return TDefaultUidChecker::Check(std::move(res), roles); + } + return res; + } + + using TAsyncUpdaterBase::SetServiceContext; + using TAsyncUpdaterBase::SetServiceTickets; + using TAsyncUpdaterBase::SetUpdateTimeOfPublicKeys; + using TAsyncUpdaterBase::SetUpdateTimeOfServiceTickets; + using TAsyncUpdaterBase::SetUserContext; + + protected: + NRoles::TRolesPtr Roles_; + }; +} diff --git a/library/cpp/tvmauth/client/ya.make b/library/cpp/tvmauth/client/ya.make new file mode 100644 index 0000000000..2c958f7b16 --- /dev/null +++ b/library/cpp/tvmauth/client/ya.make @@ -0,0 +1,48 @@ +LIBRARY() + +PEERDIR( + library/cpp/http/simple + library/cpp/json + library/cpp/openssl/crypto + library/cpp/streams/brotli + library/cpp/streams/zstd + library/cpp/string_utils/quote + library/cpp/tvmauth + library/cpp/tvmauth/client/misc/retry_settings/v1 +) + +SRCS( + client_status.cpp + facade.cpp + logger.cpp + misc/api/roles_fetcher.cpp + misc/api/settings.cpp + misc/api/threaded_updater.cpp + misc/async_updater.cpp + misc/disk_cache.cpp + misc/last_error.cpp + misc/proc_info.cpp + misc/roles/decoder.cpp + misc/roles/entities_index.cpp + misc/roles/parser.cpp + misc/roles/roles.cpp + misc/threaded_updater.cpp + misc/tool/meta_info.cpp + misc/tool/roles_fetcher.cpp + misc/tool/settings.cpp + misc/tool/threaded_updater.cpp + misc/utils.cpp + mocked_updater.cpp +) + +GENERATE_ENUM_SERIALIZATION(client_status.h) +GENERATE_ENUM_SERIALIZATION(misc/async_updater.h) +GENERATE_ENUM_SERIALIZATION(misc/last_error.h) + +END() + +RECURSE_FOR_TESTS( + examples + misc/api/dynamic_dst + ut +) diff --git a/library/cpp/tvmauth/deprecated/service_context.h b/library/cpp/tvmauth/deprecated/service_context.h new file mode 100644 index 0000000000..bdf1bb5224 --- /dev/null +++ b/library/cpp/tvmauth/deprecated/service_context.h @@ -0,0 +1,72 @@ +#pragma once + +#include <library/cpp/tvmauth/checked_service_ticket.h> + +#include <util/generic/ptr.h> + +namespace NTvmAuth { + class TServiceContext: public TAtomicRefCount<TServiceContext> { + public: + /*! + * @struct TCheckFlags holds flags that control checking + */ + struct TCheckFlags { + TCheckFlags() { + } + bool NeedDstCheck = true; + }; + + /*! + * Create service context. Serivce contexts are used to store TVM keys and parse service tickets. + * @param selfTvmId + * @param secretBase64 + * @param tvmKeysResponse + */ + TServiceContext(TStringBuf secretBase64, TTvmId selfTvmId, TStringBuf tvmKeysResponse); + TServiceContext(TServiceContext&&); + ~TServiceContext(); + + /*! + * Create service context only for checking service tickets + * \param[in] selfTvmId + * \param[in] tvmKeysResponse + * \return + */ + static TServiceContext CheckingFactory(TTvmId selfTvmId, TStringBuf tvmKeysResponse); + + /*! + * Create service context only for signing HTTP request to TVM-API + * \param[in] secretBase64 + * \return + */ + static TServiceContext SigningFactory(TStringBuf secretBase64); + + TServiceContext& operator=(TServiceContext&&); + + /*! + * Parse and validate service ticket body then create TCheckedServiceTicket object. + * @param ticketBody + * @return TCheckedServiceTicket object + */ + TCheckedServiceTicket Check(TStringBuf ticketBody, const TCheckFlags& flags = {}) const; + + /*! + * Sign params for TVM API + * @param ts Param 'ts' of request to TVM + * @param dst Param 'dst' of request to TVM + * @param scopes Param 'scopes' of request to TVM + * @return Signed string + */ + TString SignCgiParamsForTvm(TStringBuf ts, TStringBuf dst, TStringBuf scopes = TStringBuf()) const; + + class TImpl; + + private: + TServiceContext() = default; + + private: + THolder<TImpl> Impl_; + }; + + using TServiceContextPtr = TIntrusiveConstPtr<TServiceContext>; +} diff --git a/library/cpp/tvmauth/deprecated/user_context.h b/library/cpp/tvmauth/deprecated/user_context.h new file mode 100644 index 0000000000..f7fe67d02e --- /dev/null +++ b/library/cpp/tvmauth/deprecated/user_context.h @@ -0,0 +1,30 @@ +#pragma once + +#include <library/cpp/tvmauth/checked_user_ticket.h> + +#include <util/generic/ptr.h> + +namespace NTvmAuth { + class TUserContext: public TAtomicRefCount<TUserContext> { + public: + TUserContext(EBlackboxEnv env, TStringBuf tvmKeysResponse); + TUserContext(TUserContext&&); + ~TUserContext(); + + TUserContext& operator=(TUserContext&&); + + /*! + * Parse and validate user ticket body then create TCheckedUserTicket object. + * @param ticketBody + * @return TCheckedUserTicket object + */ + TCheckedUserTicket Check(TStringBuf ticketBody) const; + + class TImpl; + + private: + THolder<TImpl> Impl_; + }; + + using TUserContextPtr = TIntrusiveConstPtr<TUserContext>; +} diff --git a/library/cpp/tvmauth/exception.h b/library/cpp/tvmauth/exception.h new file mode 100644 index 0000000000..f528886b95 --- /dev/null +++ b/library/cpp/tvmauth/exception.h @@ -0,0 +1,20 @@ +#pragma once + +#include <util/generic/yexception.h> + +#include <exception> + +namespace NTvmAuth { + class TTvmException: public yexception { + }; + class TContextException: public TTvmException { + }; + class TMalformedTvmSecretException: public TContextException { + }; + class TMalformedTvmKeysException: public TContextException { + }; + class TEmptyTvmKeysException: public TContextException { + }; + class TNotAllowedException: public TTvmException { + }; +} diff --git a/library/cpp/tvmauth/src/parser.h b/library/cpp/tvmauth/src/parser.h new file mode 100644 index 0000000000..678e709444 --- /dev/null +++ b/library/cpp/tvmauth/src/parser.h @@ -0,0 +1,51 @@ +#pragma once + +#include <library/cpp/tvmauth/src/protos/ticket2.pb.h> +#include <library/cpp/tvmauth/src/rw/keys.h> + +#include <library/cpp/tvmauth/ticket_status.h> + +#include <util/generic/fwd.h> + +#include <string> + +namespace NTvmAuth { + struct TParserTvmKeys { + static inline const char DELIM = ':'; + static TString ParseStrV1(TStringBuf str); + }; + + struct TParserTickets { + static const char DELIM = ':'; + + static TStringBuf UserFlag(); + static TStringBuf ServiceFlag(); + + struct TRes { + TRes(ETicketStatus status) + : Status(status) + { + } + + ETicketStatus Status; + + ticket2::Ticket Ticket; + }; + static TRes ParseV3(TStringBuf body, const NRw::TPublicKeys& keys, TStringBuf type); + + // private: + struct TStrRes { + const ETicketStatus Status; + + TString Proto; + TString Sign; + + TStringBuf ForCheck; + + bool operator==(const TStrRes& o) const { // for tests + return Status == o.Status && Proto == o.Proto && Sign == o.Sign && ForCheck == o.ForCheck; + } + }; + static TStrRes ParseStrV3(TStringBuf body, TStringBuf type); + }; +} diff --git a/library/cpp/tvmauth/src/protos/ticket2.proto b/library/cpp/tvmauth/src/protos/ticket2.proto new file mode 100644 index 0000000000..47950a8861 --- /dev/null +++ b/library/cpp/tvmauth/src/protos/ticket2.proto @@ -0,0 +1,33 @@ +package ticket2; + +option go_package = "github.com/ydb-platform/ydb/library/cpp/tvmauth/src/protos"; + +import "library/cpp/tvmauth/src/protos/tvm_keys.proto"; + +message User { + required uint64 uid = 1; + optional uint64 porgId = 2; +} + +message UserTicket { + repeated User users = 1; + required uint64 defaultUid = 2; + repeated string scopes = 3; + required uint32 entryPoint = 4; + required tvm_keys.BbEnvType env = 5; + optional string loginId = 6; +} + +message ServiceTicket { + required uint32 srcClientId = 1; + required uint32 dstClientId = 2; + repeated string scopes = 3; + optional uint64 issuerUid = 4; +} + +message Ticket { + required uint32 keyId = 1; + required int64 expirationTime = 2; + optional UserTicket user = 3; + optional ServiceTicket service = 4; +} diff --git a/library/cpp/tvmauth/src/protos/tvm_keys.proto b/library/cpp/tvmauth/src/protos/tvm_keys.proto new file mode 100644 index 0000000000..d931e52071 --- /dev/null +++ b/library/cpp/tvmauth/src/protos/tvm_keys.proto @@ -0,0 +1,36 @@ +package tvm_keys; + +option go_package = "github.com/ydb-platform/ydb/library/cpp/tvmauth/src/protos"; + +enum KeyType { + RabinWilliams = 0; +} + +enum BbEnvType { + Prod = 0; + Test = 1; + ProdYateam = 2; + TestYateam = 3; + Stress = 4; +} + +message General { + required uint32 id = 1; + required KeyType type = 2; + required bytes body = 3; + optional int64 createdTime = 4; +} + +message BbKey { + required General gen = 1; + required BbEnvType env = 2; +} + +message TvmKey { + required General gen = 1; +} + +message Keys { + repeated BbKey bb = 1; + repeated TvmKey tvm = 2; +} diff --git a/library/cpp/tvmauth/src/protos/ya.make b/library/cpp/tvmauth/src/protos/ya.make new file mode 100644 index 0000000000..6a7fab902a --- /dev/null +++ b/library/cpp/tvmauth/src/protos/ya.make @@ -0,0 +1,10 @@ +PROTO_LIBRARY() + +INCLUDE_TAGS(GO_PROTO) + +SRCS( + ticket2.proto + tvm_keys.proto +) + +END() diff --git a/library/cpp/tvmauth/src/rw/keys.h b/library/cpp/tvmauth/src/rw/keys.h new file mode 100644 index 0000000000..e02b7e72a1 --- /dev/null +++ b/library/cpp/tvmauth/src/rw/keys.h @@ -0,0 +1,65 @@ +#pragma once + +#include <util/generic/ptr.h> +#include <util/generic/string.h> + +#include <unordered_map> + +struct TRwInternal; + +namespace NTvmAuth { + namespace NRw { + namespace NPrivate { + class TRwDestroyer { + public: + static void Destroy(TRwInternal* o); + }; + } + + using TRw = THolder<TRwInternal, NPrivate::TRwDestroyer>; + using TKeyId = ui32; + + struct TKeyPair { + TString Private; + TString Public; + }; + TKeyPair GenKeyPair(size_t size); + + class TRwPrivateKey { + public: + TRwPrivateKey(TStringBuf body, TKeyId id); + + TKeyId GetId() const; + TString SignTicket(TStringBuf ticket) const; + + private: + static TRw Deserialize(TStringBuf key); + + TKeyId Id_; + TRw Rw_; + int SignLen_; + }; + + class TRwPublicKey { + public: + TRwPublicKey(TStringBuf body); + + bool CheckSign(TStringBuf ticket, TStringBuf sign) const; + + private: + static TRw Deserialize(TStringBuf key); + + TRw Rw_; + }; + + using TPublicKeys = std::unordered_map<TKeyId, TRwPublicKey>; + + class TSecureHeap { + public: + TSecureHeap(size_t totalSize, int minChunkSize); + ~TSecureHeap(); + + static void Init(size_t totalSize = 16 * 1024 * 1024, int minChunkSize = 16); + }; + } +} diff --git a/library/cpp/tvmauth/src/service_impl.h b/library/cpp/tvmauth/src/service_impl.h new file mode 100644 index 0000000000..76400cffea --- /dev/null +++ b/library/cpp/tvmauth/src/service_impl.h @@ -0,0 +1,78 @@ +#pragma once + +#include <library/cpp/tvmauth/src/protos/ticket2.pb.h> +#include <library/cpp/tvmauth/src/protos/tvm_keys.pb.h> +#include <library/cpp/tvmauth/src/rw/keys.h> + +#include <library/cpp/tvmauth/type.h> +#include <library/cpp/tvmauth/deprecated/service_context.h> + +#include <library/cpp/charset/ci_string.h> +#include <library/cpp/string_utils/secret_string/secret_string.h> + +#include <util/generic/maybe.h> + +#include <string> + +namespace NTvmAuth { + using TServiceTicketImplPtr = THolder<TCheckedServiceTicket::TImpl>; + class TCheckedServiceTicket::TImpl { + public: + explicit operator bool() const; + + TTvmId GetDst() const; + TTvmId GetSrc() const; + const TScopes& GetScopes() const; + bool HasScope(TStringBuf scopeName) const; + ETicketStatus GetStatus() const; + time_t GetExpirationTime() const; + + TString DebugInfo() const; + TMaybe<TUid> GetIssuerUid() const; + + void SetStatus(ETicketStatus status); + + /*! + * Constructor for creation invalid ticket storing error status in TServiceContext + * @param status + * @param protobufTicket + */ + TImpl(ETicketStatus status, ticket2::Ticket&& protobufTicket); + + static TServiceTicketImplPtr CreateTicketForTests(ETicketStatus status, + TTvmId src, + TMaybe<TUid> issuerUid, + TTvmId dst = 100500); + + private: + ETicketStatus Status_; + ticket2::Ticket ProtobufTicket_; + mutable TScopes CachedScopes_; + mutable TString CachedDebugInfo_; + }; + + class TServiceContext::TImpl { + public: + TImpl(TStringBuf secretBase64, TTvmId selfTvmId, TStringBuf tvmKeysResponse); + TImpl(TTvmId selfTvmId, TStringBuf tvmKeysResponse); + TImpl(TStringBuf secretBase64); + + void ResetKeys(TStringBuf tvmKeysResponse); + + TServiceTicketImplPtr Check(TStringBuf ticketBody, const TServiceContext::TCheckFlags& flags = {}) const; + TString SignCgiParamsForTvm(TStringBuf ts, TStringBuf dst, TStringBuf scopes = TStringBuf()) const; + + const NRw::TPublicKeys& GetKeys() const { // for tests + return Keys_; + } + + private: + static TString ParseSecret(TStringBuf secretBase64); + + NRw::TPublicKeys Keys_; + const NSecretString::TSecretString Secret_; + const TTvmId SelfTvmId_ = 0; + + ::google::protobuf::LogSilencer LogSilencer_; + }; +} diff --git a/library/cpp/tvmauth/src/user_impl.h b/library/cpp/tvmauth/src/user_impl.h new file mode 100644 index 0000000000..b1190bd626 --- /dev/null +++ b/library/cpp/tvmauth/src/user_impl.h @@ -0,0 +1,76 @@ +#pragma once + +#include <library/cpp/tvmauth/src/protos/ticket2.pb.h> +#include <library/cpp/tvmauth/src/protos/tvm_keys.pb.h> +#include <library/cpp/tvmauth/src/rw/keys.h> + +#include <library/cpp/tvmauth/deprecated/user_context.h> + +#include <library/cpp/charset/ci_string.h> + +#include <optional> +#include <unordered_map> + +namespace NTvmAuth { + using TUserTicketImplPtr = THolder<TCheckedUserTicket::TImpl>; + class TCheckedUserTicket::TImpl { + public: + explicit operator bool() const; + + TUid GetDefaultUid() const; + time_t GetExpirationTime() const; + const TScopes& GetScopes() const; + bool HasScope(TStringBuf scopeName) const; + ETicketStatus GetStatus() const; + const TUids& GetUids() const; + TUidsExtFieldsMap GetUidsExtFields() const; + std::optional<TUserExtFields> GetDefaultUidExtFields() const; + const TString& GetLoginId() const; + + TString DebugInfo() const; + + EBlackboxEnv GetEnv() const; + + void SetStatus(ETicketStatus status); + + /*! + * Constructor for creation invalid ticket storing error status in TServiceContext + * @param status + * @param protobufTicket + */ + TImpl(ETicketStatus status, ticket2::Ticket&& protobufTicket); + + static TUserTicketImplPtr CreateTicketForTests(ETicketStatus status, + TUid defaultUid, + TScopes scopes, + TUids uids, + EBlackboxEnv env = EBlackboxEnv::Test); + + private: + static const int MaxUserCount = 15; + + ETicketStatus Status_; + ticket2::Ticket ProtobufTicket_; + mutable TScopes CachedScopes_; + mutable TUids CachedUids_; + mutable TString CachedDebugInfo_; + }; + + class TUserContext::TImpl { + public: + TImpl(EBlackboxEnv env, TStringBuf tvmKeysResponse); + void ResetKeys(TStringBuf tvmKeysResponse); + + TUserTicketImplPtr Check(TStringBuf ticketBody) const; + const NRw::TPublicKeys& GetKeys() const; + + bool IsAllowed(tvm_keys::BbEnvType env) const; + + private: + ETicketStatus CheckProtobufUserTicket(const ticket2::Ticket& ticket) const; + + NRw::TPublicKeys Keys_; + EBlackboxEnv Env_; + ::google::protobuf::LogSilencer LogSilencer_; + }; +} diff --git a/library/cpp/tvmauth/src/utils.h b/library/cpp/tvmauth/src/utils.h new file mode 100644 index 0000000000..e5847ac89f --- /dev/null +++ b/library/cpp/tvmauth/src/utils.h @@ -0,0 +1,30 @@ +#pragma once + +#include <library/cpp/tvmauth/checked_service_ticket.h> +#include <library/cpp/tvmauth/checked_user_ticket.h> +#include <library/cpp/tvmauth/ticket_status.h> + +#include <util/datetime/base.h> +#include <util/generic/fwd.h> + +namespace NTvmAuth::NUtils { + TString Bin2base64url(TStringBuf buf); + TString Base64url2bin(TStringBuf buf); + + TString SignCgiParamsForTvm(TStringBuf secret, TStringBuf ts, TStringBuf dstTvmId, TStringBuf scopes); +} + +namespace NTvmAuth::NInternal { + class TCanningKnife { + public: + static TCheckedServiceTicket::TImpl* GetS(TCheckedServiceTicket& t) { + return t.Impl_.Release(); + } + + static TCheckedUserTicket::TImpl* GetU(TCheckedUserTicket& t) { + return t.Impl_.Release(); + } + + static TMaybe<TInstant> GetExpirationTime(TStringBuf ticket); + }; +} diff --git a/library/cpp/tvmauth/ticket_status.h b/library/cpp/tvmauth/ticket_status.h new file mode 100644 index 0000000000..532d4de56e --- /dev/null +++ b/library/cpp/tvmauth/ticket_status.h @@ -0,0 +1,23 @@ +#pragma once + +#include <util/generic/strbuf.h> + +namespace NTvmAuth { + /*! + * Status mean result of ticket check + */ + enum class ETicketStatus { + Ok, + Expired, + InvalidBlackboxEnv, + InvalidDst, + InvalidTicketType, + Malformed, + MissingKey, + SignBroken, + UnsupportedVersion, + NoRoles, + }; + + TStringBuf StatusToString(ETicketStatus st); +} diff --git a/library/cpp/tvmauth/type.h b/library/cpp/tvmauth/type.h new file mode 100644 index 0000000000..acf0f9e1aa --- /dev/null +++ b/library/cpp/tvmauth/type.h @@ -0,0 +1,27 @@ +#pragma once + +#include <library/cpp/containers/stack_vector/stack_vec.h> + +#include <util/generic/hash.h> + +namespace NTvmAuth { + struct TUserExtFields; + + using TScopes = TSmallVec<TStringBuf>; + using TTvmId = ui32; + using TUid = ui64; + using TUids = TSmallVec<TUid>; + using TUidsExtFieldsMap = THashMap<TUid, TUserExtFields>; + using TAlias = TString; + using TPorgId = ui64; + + struct TUserExtFields { + bool operator==(const TUserExtFields& o) const { + return Uid == o.Uid && + CurrentPorgId == o.CurrentPorgId; + } + + TUid Uid = 0; + TPorgId CurrentPorgId = 0; + }; +} diff --git a/library/cpp/tvmauth/unittest.h b/library/cpp/tvmauth/unittest.h new file mode 100644 index 0000000000..79c9c6bf18 --- /dev/null +++ b/library/cpp/tvmauth/unittest.h @@ -0,0 +1,21 @@ +#pragma once + +#include "checked_service_ticket.h" +#include "checked_user_ticket.h" + +#include <util/generic/maybe.h> + +namespace NTvmAuth::NUnittest { + static const TString TVMKNIFE_PUBLIC_KEYS = "1:CpgCCpMCCAEQABqIAjCCAQQCggEAcLEXeH67FQESFUn4_7wnX7wN0PUrBoUsm3QQ4W5vC-qz6sXaEjSwnTV8w1o-z6X9KPLlhzMQvuS38NCNfK4uvJ4Zvfp3YsXJ25-rYtbnrYJHNvHohD-kPCCw_yZpMp21JdWigzQGuV7CtrxUhF-NNrsnUaJrE5-OpEWNt4X6nCItKIYeVcSK6XJUbEWbrNCRbvkSc4ak2ymFeMuHYJVjxh4eQbk7_ZPzodP0WvF6eUYrYeb42imVEOR8ofVLQWE5DVnb1z_TqZm4i1XkS7jMwZuBxBRw8DGdYei0lT_sAf7KST2jC0590NySB3vsBgWEVs1OdUUWA6r-Dvx9dsOQtSCVkQYQAAqZAgqUAggCEAAaiQIwggEFAoIBAQDhEBM5-6YsPWfogKtbluJoCX1WV2KdzOaQ0-OlRbBzeCzw-eQKu12c8WakHBbeCMd1I1TU64SDkDorWjXGIa_2xT6N3zzNAE50roTbPCcmeQrps26woTYfYIuqDdoxYKZNr0lvNLLW47vBr7EKqo1S4KSj7aXK_XYeEvUgIgf3nVIcNrio7VTnFmGGVQCepaL1Hi1gN4yIXjVZ06PBPZ-DxSRu6xOGbFrfKMJeMPs7KOyE-26Q3xOXdTIa1X-zYIucTd_bxUCL4BVbwW2AvbbFsaG7ISmVdGu0XUTmhXs1KrEfUVLRJhE4Dx99hAZXm1_HlYMUeJcMQ_oHOhV94ENFIJaRBhACCpYBCpEBCAMQABqGATCBgwKBgF9t2YJGAJkRRFq6fWhi3m1TFW1UOE0f6ZrfYhHAkpqGlKlh0QVfeTNPpeJhi75xXzCe6oReRUm-0DbqDNhTShC7uGUv1INYnRBQWH6E-5Fc5XrbDFSuGQw2EYjNfHy_HefHJXxQKAqPvxBDKMKkHgV58WtM6rC8jRi9sdX_ig2NIJeRBhABCpYBCpEBCAQQABqGATCBgwKBgGB4d6eLGUBv-Q6EPLehC4S-yuE2HB-_rJ7WkeYwyp-xIPolPrd-PQme2utHB4ZgpXHIu_OFksDe_0bPgZniNRSVRbl7W49DgS5Ya3kMfrYB4DnF5Fta5tn1oV6EwxYD4JONpFTenOJALPGTPawxXEfon_peiHOSBuQMu3_Vn-l1IJiRBhADCpcBCpIBCAUQABqHATCBhAKBgQCTJMKIfmfeZpaI7Q9rnsc29gdWawK7TnpVKRHws1iY7EUlYROeVcMdAwEqVM6f8BVCKLGgzQ7Gar_uuxfUGKwqEQzoppDraw4F75J464-7D5f6_oJQuGIBHZxqbMONtLjBCXRUhQW5szBLmTQ_R3qaJb5vf-h0APZfkYhq1cTttSCZkQYQBAqWAQqRAQgLEAAahgEwgYMCgYBvvGVH_M2H8qxxv94yaDYUTWbRnJ1uiIYc59KIQlfFimMPhSS7x2tqUa2-hI55JiII0Xym6GNkwLhyc1xtWChpVuIdSnbvttbrt4weDMLHqTwNOF6qAsVKGKT1Yh8yf-qb-DSmicgvFc74mBQm_6gAY1iQsf33YX8578ClhKBWHSCVkQYQAAqXAQqSAQgMEAAahwEwgYQCgYEAkuzFcd5TJu7lYWYe2hQLFfUWIIj91BvQQLa_Thln4YtGCO8gG1KJqJm-YlmJOWQG0B7H_5RVhxUxV9KpmFnsDVkzUFKOsCBaYGXc12xPVioawUlAwp5qp3QQtZyx_se97YIoLzuLr46UkLcLnkIrp-Jo46QzYi_QHq45WTm8MQ0glpEGEAIKlwEKkgEIDRAAGocBMIGEAoGBAIUzbxOknXf_rNt17_ir8JlWvrtnCWsQd1MAnl5mgArvavDtKeBYHzi5_Ak7DHlLzuA6YE8W175FxLFKpN2hkz-l-M7ltUSd8N1BvJRhK4t6WffWfC_1wPyoAbeSN2Yb1jygtZJQ8wGoXHcJQUXiMit3eFNyylwsJFj1gzAR4JCdIJeRBhABCpYBCpEBCA4QABqGATCBgwKBgFMcbEpl9ukVR6AO_R6sMyiU11I8b8MBSUCEC15iKsrVO8v_m47_TRRjWPYtQ9eZ7o1ocNJHaGUU7qqInFqtFaVnIceP6NmCsXhjs3MLrWPS8IRAy4Zf4FKmGOx3N9O2vemjUygZ9vUiSkULdVrecinRaT8JQ5RG4bUMY04XGIwFIJiRBhADCpYBCpEBCA8QABqGATCBgwKBgGpCkW-NR3li8GlRvqpq2YZGSIgm_PTyDI2Zwfw69grsBmPpVFW48Vw7xoMN35zcrojEpialB_uQzlpLYOvsMl634CRIuj-n1QE3-gaZTTTE8mg-AR4mcxnTKThPnRQpbuOlYAnriwiasWiQEMbGjq_HmWioYYxFo9USlklQn4-9IJmRBhAEEpUBCpIBCAYQABqHATCBhAKBgQCoZkFGm9oLTqjeXZAq6j5S6i7K20V0lNdBBLqfmFBIRuTkYxhs4vUYnWjZrKRAd5bp6_py0csmFmpl_5Yh0b-2pdo_E5PNP7LGRzKyKSiFddyykKKzVOazH8YYldDAfE8Z5HoS9e48an5JsPg0jr-TPu34DnJq3yv2a6dqiKL9zSCakQYSlQEKkgEIEBAAGocBMIGEAoGBALhrihbf3EpjDQS2sCQHazoFgN0nBbE9eesnnFTfzQELXb2gnJU9enmV_aDqaHKjgtLIPpCgn40lHrn5k6mvH5OdedyI6cCzE-N-GFp3nAq0NDJyMe0fhtIRD__CbT0ulcvkeow65ubXWfw6dBC2gR_34rdMe_L_TGRLMWjDULbNIJqRBg"; + + TCheckedServiceTicket CreateServiceTicket(ETicketStatus status, + TTvmId src, + TMaybe<TUid> issuerUid = TMaybe<TUid>(), + TTvmId dst = 100500); + + TCheckedUserTicket CreateUserTicket(ETicketStatus status, + TUid defaultUid, + const TScopes& scopes, + const TUids& uids = TUids(), + EBlackboxEnv env = EBlackboxEnv::Test); +} diff --git a/library/cpp/tvmauth/utils.h b/library/cpp/tvmauth/utils.h new file mode 100644 index 0000000000..ad8950cab5 --- /dev/null +++ b/library/cpp/tvmauth/utils.h @@ -0,0 +1,12 @@ +#pragma once + +#include <util/generic/strbuf.h> + +namespace NTvmAuth::NUtils { + /*! + * Remove signature from ticket string - rest part can be parsed later with `tvmknife parse_ticket ...` + * @param ticketBody Raw ticket body + * @return safe for logging part of ticket + */ + TStringBuf RemoveTicketSignature(TStringBuf ticketBody); +} diff --git a/library/cpp/tvmauth/version.h b/library/cpp/tvmauth/version.h new file mode 100644 index 0000000000..48ec279829 --- /dev/null +++ b/library/cpp/tvmauth/version.h @@ -0,0 +1,7 @@ +#pragma once + +#include <util/generic/strbuf.h> + +namespace NTvmAuth { + TStringBuf LibVersion(); +} diff --git a/library/cpp/tvmauth/ya.make b/library/cpp/tvmauth/ya.make new file mode 100644 index 0000000000..94be4cabd1 --- /dev/null +++ b/library/cpp/tvmauth/ya.make @@ -0,0 +1,40 @@ +LIBRARY() + +PEERDIR( + library/cpp/string_utils/secret_string + library/cpp/tvmauth/src/protos + library/cpp/tvmauth/src/rw +) + +SRCS( + deprecated/service_context.cpp + deprecated/user_context.cpp + src/parser.cpp + src/service_impl.cpp + src/service_ticket.cpp + src/status.cpp + src/unittest.cpp + src/user_impl.cpp + src/user_ticket.cpp + src/utils.cpp + src/version.cpp + utils.cpp +) + +GENERATE_ENUM_SERIALIZATION(checked_user_ticket.h) +GENERATE_ENUM_SERIALIZATION(ticket_status.h) + +RESOURCE( + src/version /builtin/version +) + +NO_BUILD_IF(OPENSOURCE) + +END() + +RECURSE( + client + src/rw + src/ut + test_all +) diff --git a/library/go/blockcodecs/all/all.go b/library/go/blockcodecs/all/all.go new file mode 100644 index 0000000000..362e85d663 --- /dev/null +++ b/library/go/blockcodecs/all/all.go @@ -0,0 +1,8 @@ +package all + +import ( + _ "github.com/ydb-platform/ydb/library/go/blockcodecs/blockbrotli" + _ "github.com/ydb-platform/ydb/library/go/blockcodecs/blocklz4" + _ "github.com/ydb-platform/ydb/library/go/blockcodecs/blocksnappy" + _ "github.com/ydb-platform/ydb/library/go/blockcodecs/blockzstd" +) diff --git a/library/go/blockcodecs/all/ya.make b/library/go/blockcodecs/all/ya.make new file mode 100644 index 0000000000..32dae74d26 --- /dev/null +++ b/library/go/blockcodecs/all/ya.make @@ -0,0 +1,5 @@ +GO_LIBRARY() + +SRCS(all.go) + +END() diff --git a/library/go/blockcodecs/blockbrotli/brotli.go b/library/go/blockcodecs/blockbrotli/brotli.go new file mode 100644 index 0000000000..8e806c47ed --- /dev/null +++ b/library/go/blockcodecs/blockbrotli/brotli.go @@ -0,0 +1,94 @@ +package blockbrotli + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/andybalholm/brotli" + "github.com/ydb-platform/ydb/library/go/blockcodecs" +) + +type brotliCodec int + +func (b brotliCodec) ID() blockcodecs.CodecID { + switch b { + case 1: + return 48947 + case 10: + return 43475 + case 11: + return 7241 + case 2: + return 63895 + case 3: + return 11408 + case 4: + return 47136 + case 5: + return 45284 + case 6: + return 63219 + case 7: + return 59675 + case 8: + return 40233 + case 9: + return 10380 + default: + panic("unsupported level") + } +} + +func (b brotliCodec) Name() string { + return fmt.Sprintf("brotli_%d", b) +} + +func (b brotliCodec) DecodedLen(in []byte) (int, error) { + return blockcodecs.DecodedLen(in) +} + +func (b brotliCodec) Encode(dst, src []byte) ([]byte, error) { + if cap(dst) < 8 { + dst = make([]byte, 8) + } + + dst = dst[:8] + binary.LittleEndian.PutUint64(dst, uint64(len(src))) + + wb := bytes.NewBuffer(dst) + w := brotli.NewWriterLevel(wb, int(b)) + + if _, err := w.Write(src); err != nil { + return nil, err + } + + if err := w.Close(); err != nil { + return nil, err + } + + return wb.Bytes(), nil +} + +func (b brotliCodec) Decode(dst, src []byte) ([]byte, error) { + if len(src) < 8 { + return nil, fmt.Errorf("short block: %d < 8", len(src)) + } + + rb := bytes.NewBuffer(src[8:]) + r := brotli.NewReader(rb) + + _, err := io.ReadFull(r, dst) + if err != nil { + return nil, err + } + + return dst, nil +} + +func init() { + for i := 1; i <= 11; i++ { + blockcodecs.Register(brotliCodec(i)) + } +} diff --git a/library/go/blockcodecs/blockbrotli/ya.make b/library/go/blockcodecs/blockbrotli/ya.make new file mode 100644 index 0000000000..3e14d0cb30 --- /dev/null +++ b/library/go/blockcodecs/blockbrotli/ya.make @@ -0,0 +1,5 @@ +GO_LIBRARY() + +SRCS(brotli.go) + +END() diff --git a/library/go/blockcodecs/blocklz4/lz4.go b/library/go/blockcodecs/blocklz4/lz4.go new file mode 100644 index 0000000000..058ad6d2bf --- /dev/null +++ b/library/go/blockcodecs/blocklz4/lz4.go @@ -0,0 +1,81 @@ +package blocklz4 + +import ( + "encoding/binary" + + "github.com/pierrec/lz4" + "github.com/ydb-platform/ydb/library/go/blockcodecs" +) + +type lz4Codec struct{} + +func (l lz4Codec) ID() blockcodecs.CodecID { + return 6051 +} + +func (l lz4Codec) Name() string { + return "lz4-fast14-safe" +} + +func (l lz4Codec) DecodedLen(in []byte) (int, error) { + return blockcodecs.DecodedLen(in) +} + +func (l lz4Codec) Encode(dst, src []byte) ([]byte, error) { + dst = dst[:cap(dst)] + + n := lz4.CompressBlockBound(len(src)) + 8 + if len(dst) < n { + dst = append(dst, make([]byte, n-len(dst))...) + } + binary.LittleEndian.PutUint64(dst, uint64(len(src))) + + m, err := lz4.CompressBlock(src, dst[8:], nil) + if err != nil { + return nil, err + } + + return dst[:8+m], nil +} + +func (l lz4Codec) Decode(dst, src []byte) ([]byte, error) { + n, err := lz4.UncompressBlock(src[8:], dst) + if err != nil { + return nil, err + } + return dst[:n], nil +} + +type lz4HCCodec struct { + lz4Codec +} + +func (l lz4HCCodec) ID() blockcodecs.CodecID { + return 62852 +} + +func (l lz4HCCodec) Name() string { + return "lz4-hc-safe" +} + +func (l lz4HCCodec) Encode(dst, src []byte) ([]byte, error) { + dst = dst[:cap(dst)] + + n := lz4.CompressBlockBound(len(src)) + 8 + if len(dst) < n { + dst = append(dst, make([]byte, n-len(dst))...) + } + binary.LittleEndian.PutUint64(dst, uint64(len(src))) + + m, err := lz4.CompressBlockHC(src, dst[8:], 0) + if err != nil { + return nil, err + } + + return dst[:8+m], nil +} + +func init() { + blockcodecs.Register(lz4Codec{}) + blockcodecs.Register(lz4HCCodec{}) +} diff --git a/library/go/blockcodecs/blocklz4/ya.make b/library/go/blockcodecs/blocklz4/ya.make new file mode 100644 index 0000000000..78c18bca64 --- /dev/null +++ b/library/go/blockcodecs/blocklz4/ya.make @@ -0,0 +1,5 @@ +GO_LIBRARY() + +SRCS(lz4.go) + +END() diff --git a/library/go/blockcodecs/blocksnappy/snappy.go b/library/go/blockcodecs/blocksnappy/snappy.go new file mode 100644 index 0000000000..eb9e888fcb --- /dev/null +++ b/library/go/blockcodecs/blocksnappy/snappy.go @@ -0,0 +1,32 @@ +package blocksnappy + +import ( + "github.com/golang/snappy" + "github.com/ydb-platform/ydb/library/go/blockcodecs" +) + +type snappyCodec struct{} + +func (s snappyCodec) ID() blockcodecs.CodecID { + return 50986 +} + +func (s snappyCodec) Name() string { + return "snappy" +} + +func (s snappyCodec) DecodedLen(in []byte) (int, error) { + return snappy.DecodedLen(in) +} + +func (s snappyCodec) Encode(dst, src []byte) ([]byte, error) { + return snappy.Encode(dst, src), nil +} + +func (s snappyCodec) Decode(dst, src []byte) ([]byte, error) { + return snappy.Decode(dst, src) +} + +func init() { + blockcodecs.Register(snappyCodec{}) +} diff --git a/library/go/blockcodecs/blocksnappy/ya.make b/library/go/blockcodecs/blocksnappy/ya.make new file mode 100644 index 0000000000..594e2c5443 --- /dev/null +++ b/library/go/blockcodecs/blocksnappy/ya.make @@ -0,0 +1,5 @@ +GO_LIBRARY() + +SRCS(snappy.go) + +END() diff --git a/library/go/blockcodecs/blockzstd/ya.make b/library/go/blockcodecs/blockzstd/ya.make new file mode 100644 index 0000000000..61a919b995 --- /dev/null +++ b/library/go/blockcodecs/blockzstd/ya.make @@ -0,0 +1,5 @@ +GO_LIBRARY() + +SRCS(zstd.go) + +END() diff --git a/library/go/blockcodecs/blockzstd/zstd.go b/library/go/blockcodecs/blockzstd/zstd.go new file mode 100644 index 0000000000..55aa79f174 --- /dev/null +++ b/library/go/blockcodecs/blockzstd/zstd.go @@ -0,0 +1,72 @@ +package blockzstd + +import ( + "encoding/binary" + "fmt" + + "github.com/klauspost/compress/zstd" + "github.com/ydb-platform/ydb/library/go/blockcodecs" +) + +type zstdCodec int + +func (z zstdCodec) ID() blockcodecs.CodecID { + switch z { + case 1: + return 55019 + case 3: + return 23308 + case 7: + return 33533 + default: + panic("unsupported level") + } +} + +func (z zstdCodec) Name() string { + return fmt.Sprintf("zstd08_%d", z) +} + +func (z zstdCodec) DecodedLen(in []byte) (int, error) { + return blockcodecs.DecodedLen(in) +} + +func (z zstdCodec) Encode(dst, src []byte) ([]byte, error) { + if cap(dst) < 8 { + dst = make([]byte, 8) + } + + dst = dst[:8] + binary.LittleEndian.PutUint64(dst, uint64(len(src))) + + w, err := zstd.NewWriter(nil, + zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(int(z))), + zstd.WithEncoderConcurrency(1)) + if err != nil { + return nil, err + } + + defer w.Close() + return w.EncodeAll(src, dst), nil +} + +func (z zstdCodec) Decode(dst, src []byte) ([]byte, error) { + if len(src) < 8 { + return nil, fmt.Errorf("short block: %d < 8", len(src)) + } + + r, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(1)) + if err != nil { + return nil, err + } + + defer r.Close() + return r.DecodeAll(src[8:], dst[:0]) +} + +func init() { + for _, i := range []int{1, 3, 7} { + blockcodecs.Register(zstdCodec(i)) + blockcodecs.RegisterAlias(fmt.Sprintf("zstd_%d", i), zstdCodec(i)) + } +} diff --git a/library/go/blockcodecs/codecs.go b/library/go/blockcodecs/codecs.go new file mode 100644 index 0000000000..b45bda6d61 --- /dev/null +++ b/library/go/blockcodecs/codecs.go @@ -0,0 +1,89 @@ +package blockcodecs + +import ( + "encoding/binary" + "fmt" + "sync" + + "go.uber.org/atomic" +) + +type CodecID uint16 + +type Codec interface { + ID() CodecID + Name() string + + DecodedLen(in []byte) (int, error) + Encode(dst, src []byte) ([]byte, error) + Decode(dst, src []byte) ([]byte, error) +} + +var ( + codecsByID sync.Map + codecsByName sync.Map +) + +// Register new codec. +// +// NOTE: update FindCodecByName description, after adding new codecs. +func Register(c Codec) { + if _, duplicate := codecsByID.LoadOrStore(c.ID(), c); duplicate { + panic(fmt.Sprintf("codec with id %d is already registered", c.ID())) + } + + RegisterAlias(c.Name(), c) +} + +func RegisterAlias(name string, c Codec) { + if _, duplicate := codecsByName.LoadOrStore(name, c); duplicate { + panic(fmt.Sprintf("codec with name %s is already registered", c.Name())) + } +} + +func ListCodecs() []Codec { + var c []Codec + codecsByID.Range(func(key, value interface{}) bool { + c = append(c, value.(Codec)) + return true + }) + return c +} + +func FindCodec(id CodecID) Codec { + c, ok := codecsByID.Load(id) + if ok { + return c.(Codec) + } else { + return nil + } +} + +// FindCodecByName returns codec by name. +// +// Possible names: +// +// null +// snappy +// zstd08_{level} - level is integer 1, 3 or 7. +// zstd_{level} - level is integer 1, 3 or 7. +func FindCodecByName(name string) Codec { + c, ok := codecsByName.Load(name) + if ok { + return c.(Codec) + } else { + return nil + } +} + +var ( + maxDecompressedBlockSize = atomic.NewInt32(16 << 20) // 16 MB +) + +func DecodedLen(in []byte) (int, error) { + if len(in) < 8 { + return 0, fmt.Errorf("short block: %d < 8", len(in)) + } + + return int(binary.LittleEndian.Uint64(in[:8])), nil +} diff --git a/library/go/blockcodecs/decoder.go b/library/go/blockcodecs/decoder.go new file mode 100644 index 0000000000..bb38dcf844 --- /dev/null +++ b/library/go/blockcodecs/decoder.go @@ -0,0 +1,155 @@ +package blockcodecs + +import ( + "encoding/binary" + "fmt" + "io" +) + +type Decoder struct { + // optional + codec Codec + + r io.Reader + header [10]byte + eof bool + checkEOF bool + + pos int + buffer []byte + + scratch []byte +} + +func (d *Decoder) getCodec(id CodecID) (Codec, error) { + if d.codec != nil { + if id != d.codec.ID() { + return nil, fmt.Errorf("blockcodecs: received block codec differs from provided: %d != %d", id, d.codec.ID()) + } + + return d.codec, nil + } + + if codec := FindCodec(id); codec != nil { + return codec, nil + } + + return nil, fmt.Errorf("blockcodecs: received block with unsupported codec %d", id) +} + +// SetCheckUnderlyingEOF changes EOF handling. +// +// Blockcodecs format contains end of stream separator. By default Decoder will stop right after +// that separator, without trying to read following bytes from underlying reader. +// +// That allows reading sequence of blockcodecs streams from one underlying stream of bytes, +// but messes up HTTP keep-alive, when using blockcodecs together with net/http connection pool. +// +// Setting CheckUnderlyingEOF to true, changes that. After encoutering end of stream block, +// Decoder will perform one more Read from underlying reader and check for io.EOF. +func (d *Decoder) SetCheckUnderlyingEOF(checkEOF bool) { + d.checkEOF = checkEOF +} + +func (d *Decoder) Read(p []byte) (int, error) { + if d.eof { + return 0, io.EOF + } + + if d.pos == len(d.buffer) { + if _, err := io.ReadFull(d.r, d.header[:]); err != nil { + return 0, fmt.Errorf("blockcodecs: invalid header: %w", err) + } + + codecID := CodecID(binary.LittleEndian.Uint16(d.header[:2])) + size := int(binary.LittleEndian.Uint64(d.header[2:])) + + codec, err := d.getCodec(codecID) + if err != nil { + return 0, err + } + + if limit := int(maxDecompressedBlockSize.Load()); size > limit { + return 0, fmt.Errorf("blockcodecs: block size exceeds limit: %d > %d", size, limit) + } + + if len(d.scratch) < size { + d.scratch = append(d.scratch, make([]byte, size-len(d.scratch))...) + } + d.scratch = d.scratch[:size] + + if _, err := io.ReadFull(d.r, d.scratch[:]); err != nil { + return 0, fmt.Errorf("blockcodecs: truncated block: %w", err) + } + + decodedSize, err := codec.DecodedLen(d.scratch[:]) + if err != nil { + return 0, fmt.Errorf("blockcodecs: corrupted block: %w", err) + } + + if decodedSize == 0 { + if d.checkEOF { + var scratch [1]byte + n, err := d.r.Read(scratch[:]) + if n != 0 { + return 0, fmt.Errorf("blockcodecs: data after EOF block") + } + if err != nil && err != io.EOF { + return 0, fmt.Errorf("blockcodecs: error after EOF block: %v", err) + } + } + + d.eof = true + return 0, io.EOF + } + + if limit := int(maxDecompressedBlockSize.Load()); decodedSize > limit { + return 0, fmt.Errorf("blockcodecs: decoded block size exceeds limit: %d > %d", decodedSize, limit) + } + + decodeInto := func(buf []byte) error { + out, err := codec.Decode(buf, d.scratch) + if err != nil { + return fmt.Errorf("blockcodecs: corrupted block: %w", err) + } else if len(out) != decodedSize { + return fmt.Errorf("blockcodecs: incorrect block size: %d != %d", len(out), decodedSize) + } + + return nil + } + + if len(p) >= decodedSize { + if err := decodeInto(p[:decodedSize]); err != nil { + return 0, err + } + + return decodedSize, nil + } + + if len(d.buffer) < decodedSize { + d.buffer = append(d.buffer, make([]byte, decodedSize-len(d.buffer))...) + } + d.buffer = d.buffer[:decodedSize] + d.pos = decodedSize + + if err := decodeInto(d.buffer); err != nil { + return 0, err + } + + d.pos = 0 + } + + n := copy(p, d.buffer[d.pos:]) + d.pos += n + return n, nil +} + +// NewDecoder creates decoder that supports input in any of registered codecs. +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{r: r} +} + +// NewDecoderCodec creates decode that tries to decode input using provided codec. +func NewDecoderCodec(r io.Reader, codec Codec) *Decoder { + return &Decoder{r: r, codec: codec} +} diff --git a/library/go/blockcodecs/encoder.go b/library/go/blockcodecs/encoder.go new file mode 100644 index 0000000000..b7bb154f79 --- /dev/null +++ b/library/go/blockcodecs/encoder.go @@ -0,0 +1,139 @@ +package blockcodecs + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +type encoder struct { + w io.Writer + codec Codec + + closed bool + header [10]byte + + buf []byte + pos int + + scratch []byte +} + +const ( + // defaultBufferSize is 32KB, same as size of buffer used in io.Copy. + defaultBufferSize = 32 << 10 +) + +var ( + _ io.WriteCloser = (*encoder)(nil) +) + +func (e *encoder) Write(p []byte) (int, error) { + if e.closed { + return 0, errors.New("blockcodecs: encoder is closed") + } + + n := len(p) + + // Complete current block + if e.pos != 0 { + m := copy(e.buf[e.pos:], p) + p = p[m:] + e.pos += m + + if e.pos == len(e.buf) { + e.pos = 0 + + if err := e.doFlush(e.buf); err != nil { + return 0, err + } + } + } + + // Copy huge input directly to output + for len(p) >= len(e.buf) { + if e.pos != 0 { + panic("broken invariant") + } + + var chunk []byte + if len(p) > len(e.buf) { + chunk = p[:len(e.buf)] + p = p[len(e.buf):] + } else { + chunk = p + p = nil + } + + if err := e.doFlush(chunk); err != nil { + return 0, err + } + } + + // Store suffix in buffer + m := copy(e.buf, p) + e.pos += m + if m != len(p) { + panic("broken invariant") + } + + return n, nil +} + +func (e *encoder) Close() error { + if e.closed { + return nil + } + + if err := e.Flush(); err != nil { + return err + } + + e.closed = true + + return e.doFlush(nil) +} + +func (e *encoder) doFlush(block []byte) error { + var err error + e.scratch, err = e.codec.Encode(e.scratch, block) + if err != nil { + return fmt.Errorf("blockcodecs: block compression error: %w", err) + } + + binary.LittleEndian.PutUint16(e.header[:2], uint16(e.codec.ID())) + binary.LittleEndian.PutUint64(e.header[2:], uint64(len(e.scratch))) + + if _, err := e.w.Write(e.header[:]); err != nil { + return err + } + + if _, err := e.w.Write(e.scratch); err != nil { + return err + } + + return nil +} + +func (e *encoder) Flush() error { + if e.closed { + return errors.New("blockcodecs: flushing closed encoder") + } + + if e.pos == 0 { + return nil + } + + err := e.doFlush(e.buf[:e.pos]) + e.pos = 0 + return err +} + +func NewEncoder(w io.Writer, codec Codec) io.WriteCloser { + return NewEncoderBuffer(w, codec, defaultBufferSize) +} + +func NewEncoderBuffer(w io.Writer, codec Codec, bufferSize int) io.WriteCloser { + return &encoder{w: w, codec: codec, buf: make([]byte, bufferSize)} +} diff --git a/library/go/blockcodecs/nop_codec.go b/library/go/blockcodecs/nop_codec.go new file mode 100644 index 0000000000..c15e65a29e --- /dev/null +++ b/library/go/blockcodecs/nop_codec.go @@ -0,0 +1,27 @@ +package blockcodecs + +type nopCodec struct{} + +func (n nopCodec) ID() CodecID { + return 54476 +} + +func (n nopCodec) Name() string { + return "null" +} + +func (n nopCodec) DecodedLen(in []byte) (int, error) { + return len(in), nil +} + +func (n nopCodec) Encode(dst, src []byte) ([]byte, error) { + return append(dst[:0], src...), nil +} + +func (n nopCodec) Decode(dst, src []byte) ([]byte, error) { + return append(dst[:0], src...), nil +} + +func init() { + Register(nopCodec{}) +} diff --git a/library/go/blockcodecs/ya.make b/library/go/blockcodecs/ya.make new file mode 100644 index 0000000000..a4544c28d2 --- /dev/null +++ b/library/go/blockcodecs/ya.make @@ -0,0 +1,19 @@ +GO_LIBRARY() + +SRCS( + codecs.go + decoder.go + encoder.go + nop_codec.go +) + +END() + +RECURSE( + all + blockbrotli + blocklz4 + blocksnappy + blockzstd + integration +) diff --git a/library/go/certifi/cas.go b/library/go/certifi/cas.go new file mode 100644 index 0000000000..093ce0b23b --- /dev/null +++ b/library/go/certifi/cas.go @@ -0,0 +1,35 @@ +package certifi + +import ( + "crypto/x509" + "sync" + + "github.com/ydb-platform/ydb/library/go/certifi/internal/certs" +) + +var ( + internalOnce sync.Once + commonOnce sync.Once + internalCAs []*x509.Certificate + commonCAs []*x509.Certificate +) + +// InternalCAs returns list of Yandex Internal certificates +func InternalCAs() []*x509.Certificate { + internalOnce.Do(initInternalCAs) + return internalCAs +} + +// CommonCAs returns list of common certificates +func CommonCAs() []*x509.Certificate { + commonOnce.Do(initCommonCAs) + return commonCAs +} + +func initInternalCAs() { + internalCAs = certsFromPEM(certs.InternalCAs()) +} + +func initCommonCAs() { + commonCAs = certsFromPEM(certs.CommonCAs()) +} diff --git a/library/go/certifi/certifi.go b/library/go/certifi/certifi.go new file mode 100644 index 0000000000..e969263883 --- /dev/null +++ b/library/go/certifi/certifi.go @@ -0,0 +1,80 @@ +package certifi + +import ( + "crypto/x509" + "os" +) + +var underYaMake = true + +// NewCertPool returns a copy of the system or bundled cert pool. +// +// Default behavior can be modified with env variable, e.g. use system pool: +// +// CERTIFI_USE_SYSTEM_CA=yes ./my-cool-program +func NewCertPool() (caCertPool *x509.CertPool, err error) { + if forceSystem() { + return NewCertPoolSystem() + } + + return NewCertPoolBundled() +} + +// NewCertPoolSystem returns a copy of the system cert pool + common CAs + internal CAs +// +// WARNING: system cert pool is not available on Windows +func NewCertPoolSystem() (caCertPool *x509.CertPool, err error) { + caCertPool, err = x509.SystemCertPool() + + if err != nil || caCertPool == nil { + caCertPool = x509.NewCertPool() + } + + for _, cert := range CommonCAs() { + caCertPool.AddCert(cert) + } + + for _, cert := range InternalCAs() { + caCertPool.AddCert(cert) + } + + return caCertPool, nil +} + +// NewCertPoolBundled returns a new cert pool with common CAs + internal CAs +func NewCertPoolBundled() (caCertPool *x509.CertPool, err error) { + caCertPool = x509.NewCertPool() + + for _, cert := range CommonCAs() { + caCertPool.AddCert(cert) + } + + for _, cert := range InternalCAs() { + caCertPool.AddCert(cert) + } + + return caCertPool, nil +} + +// NewCertPoolInternal returns a new cert pool with internal CAs +func NewCertPoolInternal() (caCertPool *x509.CertPool, err error) { + caCertPool = x509.NewCertPool() + + for _, cert := range InternalCAs() { + caCertPool.AddCert(cert) + } + + return caCertPool, nil +} + +func forceSystem() bool { + if os.Getenv("CERTIFI_USE_SYSTEM_CA") == "yes" { + return true + } + + if !underYaMake && len(InternalCAs()) == 0 { + return true + } + + return false +} diff --git a/library/go/certifi/doc.go b/library/go/certifi/doc.go new file mode 100644 index 0000000000..d988ba0563 --- /dev/null +++ b/library/go/certifi/doc.go @@ -0,0 +1,4 @@ +// Certifi is a collection of public and internal Root Certificates for validating the trustworthiness of SSL certificates while verifying the identity of TLS hosts. +// +// Certifi use Arcadia Root Certificates for that: https://github.com/ydb-platform/ydb/arc/trunk/arcadia/certs +package certifi diff --git a/library/go/certifi/internal/certs/certs.go b/library/go/certifi/internal/certs/certs.go new file mode 100644 index 0000000000..1e64fe7157 --- /dev/null +++ b/library/go/certifi/internal/certs/certs.go @@ -0,0 +1,13 @@ +package certs + +import ( + "github.com/ydb-platform/ydb/library/go/core/resource" +) + +func InternalCAs() []byte { + return resource.Get("/certifi/internal.pem") +} + +func CommonCAs() []byte { + return resource.Get("/certifi/common.pem") +} diff --git a/library/go/certifi/internal/certs/ya.make b/library/go/certifi/internal/certs/ya.make new file mode 100644 index 0000000000..d16d7ab5ad --- /dev/null +++ b/library/go/certifi/internal/certs/ya.make @@ -0,0 +1,10 @@ +GO_LIBRARY() + +RESOURCE( + certs/cacert.pem /certifi/common.pem + certs/yandex_internal.pem /certifi/internal.pem +) + +SRCS(certs.go) + +END() diff --git a/library/go/certifi/utils.go b/library/go/certifi/utils.go new file mode 100644 index 0000000000..76d90e3f1f --- /dev/null +++ b/library/go/certifi/utils.go @@ -0,0 +1,29 @@ +package certifi + +import ( + "crypto/x509" + "encoding/pem" +) + +func certsFromPEM(pemCerts []byte) []*x509.Certificate { + var result []*x509.Certificate + for len(pemCerts) > 0 { + var block *pem.Block + block, pemCerts = pem.Decode(pemCerts) + if block == nil { + break + } + if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { + continue + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + continue + } + + result = append(result, cert) + } + + return result +} diff --git a/library/go/certifi/ya.make b/library/go/certifi/ya.make new file mode 100644 index 0000000000..d8181f1d68 --- /dev/null +++ b/library/go/certifi/ya.make @@ -0,0 +1,21 @@ +GO_LIBRARY() + +SRCS( + cas.go + certifi.go + doc.go + utils.go +) + +GO_XTEST_SRCS( + certifi_example_test.go + certifi_test.go + utils_test.go +) + +END() + +RECURSE( + gotest + internal +) diff --git a/library/go/cgosem/sem.go b/library/go/cgosem/sem.go new file mode 100644 index 0000000000..357e0529db --- /dev/null +++ b/library/go/cgosem/sem.go @@ -0,0 +1,67 @@ +// Package cgosem implements fast and imprecise semaphore used to globally limit concurrency of _fast_ cgo calls. +// +// In the future, when go runtime scheduler gets smarter and stop suffering from uncontrolled growth the number of +// system threads, this package should be removed. +// +// See "Cgoroutines != Goroutines" section of https://www.cockroachlabs.com/blog/the-cost-and-complexity-of-cgo/ +// for explanation of the thread leak problem. +// +// To use this semaphore, put the following line at the beginning of the function doing Cgo calls. +// +// defer cgosem.S.Acquire().Release() +// +// This will globally limit number of concurrent Cgo calls to GOMAXPROCS, limiting number of additional threads created by the +// go runtime to the same number. +// +// Overhead of this semaphore is about 1us, which should be negligible compared to the work you are trying to do in the C function. +// +// To see code in action, run: +// +// ya make -r library/go/cgosem/gotest +// env GODEBUG=schedtrace=1000,scheddetail=1 library/go/cgosem/gotest/gotest --test.run TestLeak +// env GODEBUG=schedtrace=1000,scheddetail=1 library/go/cgosem/gotest/gotest --test.run TestLeakFix +// +// And look for the number of created M's. +package cgosem + +import "runtime" + +type Sem chan struct{} + +// new creates new semaphore with max concurrency of n. +func newSem(n int) (s Sem) { + s = make(chan struct{}, n) + for i := 0; i < n; i++ { + s <- struct{}{} + } + return +} + +func (s Sem) Acquire() Sem { + if s == nil { + return nil + } + + <-s + return s +} + +func (s Sem) Release() { + if s == nil { + return + } + + s <- struct{}{} +} + +// S is global semaphore with good enough settings for most cgo libraries. +var S Sem + +// Disable global cgo semaphore. Must be called from init() function. +func Disable() { + S = nil +} + +func init() { + S = newSem(runtime.GOMAXPROCS(0)) +} diff --git a/library/go/cgosem/ya.make b/library/go/cgosem/ya.make new file mode 100644 index 0000000000..8b383384c5 --- /dev/null +++ b/library/go/cgosem/ya.make @@ -0,0 +1,12 @@ +GO_LIBRARY() + +SRCS(sem.go) + +GO_TEST_SRCS(leak_test.go) + +END() + +RECURSE( + dummy + gotest +) diff --git a/library/go/core/metrics/buckets.go b/library/go/core/metrics/buckets.go new file mode 100644 index 0000000000..063c0c4418 --- /dev/null +++ b/library/go/core/metrics/buckets.go @@ -0,0 +1,147 @@ +package metrics + +import ( + "sort" + "time" +) + +var ( + _ DurationBuckets = (*durationBuckets)(nil) + _ Buckets = (*buckets)(nil) +) + +const ( + errBucketsCountNeedsGreaterThanZero = "n needs to be > 0" + errBucketsStartNeedsGreaterThanZero = "start needs to be > 0" + errBucketsFactorNeedsGreaterThanOne = "factor needs to be > 1" +) + +type durationBuckets struct { + buckets []time.Duration +} + +// NewDurationBuckets returns new DurationBuckets implementation. +func NewDurationBuckets(bk ...time.Duration) DurationBuckets { + sort.Slice(bk, func(i, j int) bool { + return bk[i] < bk[j] + }) + return durationBuckets{buckets: bk} +} + +func (d durationBuckets) Size() int { + return len(d.buckets) +} + +func (d durationBuckets) MapDuration(dv time.Duration) (idx int) { + for _, bound := range d.buckets { + if dv < bound { + break + } + idx++ + } + return +} + +func (d durationBuckets) UpperBound(idx int) time.Duration { + if idx > d.Size()-1 { + panic("idx is out of bounds") + } + return d.buckets[idx] +} + +type buckets struct { + buckets []float64 +} + +// NewBuckets returns new Buckets implementation. +func NewBuckets(bk ...float64) Buckets { + sort.Slice(bk, func(i, j int) bool { + return bk[i] < bk[j] + }) + return buckets{buckets: bk} +} + +func (d buckets) Size() int { + return len(d.buckets) +} + +func (d buckets) MapValue(v float64) (idx int) { + for _, bound := range d.buckets { + if v < bound { + break + } + idx++ + } + return +} + +func (d buckets) UpperBound(idx int) float64 { + if idx > d.Size()-1 { + panic("idx is out of bounds") + } + return d.buckets[idx] +} + +// MakeLinearBuckets creates a set of linear value buckets. +func MakeLinearBuckets(start, width float64, n int) Buckets { + if n <= 0 { + panic(errBucketsCountNeedsGreaterThanZero) + } + bounds := make([]float64, n) + for i := range bounds { + bounds[i] = start + (float64(i) * width) + } + return NewBuckets(bounds...) +} + +// MakeLinearDurationBuckets creates a set of linear duration buckets. +func MakeLinearDurationBuckets(start, width time.Duration, n int) DurationBuckets { + if n <= 0 { + panic(errBucketsCountNeedsGreaterThanZero) + } + buckets := make([]time.Duration, n) + for i := range buckets { + buckets[i] = start + (time.Duration(i) * width) + } + return NewDurationBuckets(buckets...) +} + +// MakeExponentialBuckets creates a set of exponential value buckets. +func MakeExponentialBuckets(start, factor float64, n int) Buckets { + if n <= 0 { + panic(errBucketsCountNeedsGreaterThanZero) + } + if start <= 0 { + panic(errBucketsStartNeedsGreaterThanZero) + } + if factor <= 1 { + panic(errBucketsFactorNeedsGreaterThanOne) + } + buckets := make([]float64, n) + curr := start + for i := range buckets { + buckets[i] = curr + curr *= factor + } + return NewBuckets(buckets...) +} + +// MakeExponentialDurationBuckets creates a set of exponential duration buckets. +func MakeExponentialDurationBuckets(start time.Duration, factor float64, n int) DurationBuckets { + if n <= 0 { + panic(errBucketsCountNeedsGreaterThanZero) + } + if start <= 0 { + panic(errBucketsStartNeedsGreaterThanZero) + } + if factor <= 1 { + panic(errBucketsFactorNeedsGreaterThanOne) + } + buckets := make([]time.Duration, n) + curr := start + for i := range buckets { + buckets[i] = curr + curr = time.Duration(float64(curr) * factor) + } + return NewDurationBuckets(buckets...) +} diff --git a/library/go/core/metrics/buckets_test.go b/library/go/core/metrics/buckets_test.go new file mode 100644 index 0000000000..70cb6398c2 --- /dev/null +++ b/library/go/core/metrics/buckets_test.go @@ -0,0 +1,183 @@ +package metrics + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewDurationBuckets(t *testing.T) { + buckets := []time.Duration{ + 1 * time.Second, + 3 * time.Second, + 5 * time.Second, + } + bk := NewDurationBuckets(buckets...) + + expect := durationBuckets{ + buckets: []time.Duration{ + 1 * time.Second, + 3 * time.Second, + 5 * time.Second, + }, + } + assert.Equal(t, expect, bk) +} + +func Test_durationBuckets_MapDuration(t *testing.T) { + bk := NewDurationBuckets([]time.Duration{ + 1 * time.Second, + 3 * time.Second, + 5 * time.Second, + }...) + + for i := 0; i <= bk.Size(); i++ { + assert.Equal(t, i, bk.MapDuration(time.Duration(i*2)*time.Second)) + } +} + +func Test_durationBuckets_Size(t *testing.T) { + var buckets []time.Duration + for i := 1; i < 3; i++ { + buckets = append(buckets, time.Duration(i)*time.Second) + bk := NewDurationBuckets(buckets...) + assert.Equal(t, i, bk.Size()) + } +} + +func Test_durationBuckets_UpperBound(t *testing.T) { + bk := NewDurationBuckets([]time.Duration{ + 1 * time.Second, + 2 * time.Second, + 3 * time.Second, + }...) + + assert.Panics(t, func() { bk.UpperBound(999) }) + + for i := 0; i < bk.Size()-1; i++ { + assert.Equal(t, time.Duration(i+1)*time.Second, bk.UpperBound(i)) + } +} + +func TestNewBuckets(t *testing.T) { + bk := NewBuckets(1, 3, 5) + + expect := buckets{ + buckets: []float64{1, 3, 5}, + } + assert.Equal(t, expect, bk) +} + +func Test_buckets_MapValue(t *testing.T) { + bk := NewBuckets(1, 3, 5) + + for i := 0; i <= bk.Size(); i++ { + assert.Equal(t, i, bk.MapValue(float64(i*2))) + } +} + +func Test_buckets_Size(t *testing.T) { + var buckets []float64 + for i := 1; i < 3; i++ { + buckets = append(buckets, float64(i)) + bk := NewBuckets(buckets...) + assert.Equal(t, i, bk.Size()) + } +} + +func Test_buckets_UpperBound(t *testing.T) { + bk := NewBuckets(1, 2, 3) + + assert.Panics(t, func() { bk.UpperBound(999) }) + + for i := 0; i < bk.Size()-1; i++ { + assert.Equal(t, float64(i+1), bk.UpperBound(i)) + } +} + +func TestMakeLinearBuckets_CorrectParameters_NotPanics(t *testing.T) { + assert.NotPanics(t, func() { + assert.Equal(t, + NewBuckets(0.0, 1.0, 2.0), + MakeLinearBuckets(0, 1, 3), + ) + }) +} + +func TestMakeLinearBucketsPanicsOnBadCount(t *testing.T) { + assert.Panics(t, func() { + MakeLinearBuckets(0, 1, 0) + }) +} + +func TestMakeLinearDurationBuckets(t *testing.T) { + assert.NotPanics(t, func() { + assert.Equal(t, + NewDurationBuckets(0, time.Second, 2*time.Second), + MakeLinearDurationBuckets(0*time.Second, 1*time.Second, 3), + ) + }) +} + +func TestMakeLinearDurationBucketsPanicsOnBadCount(t *testing.T) { + assert.Panics(t, func() { + MakeLinearDurationBuckets(0*time.Second, 1*time.Second, 0) + }) +} + +func TestMakeExponentialBuckets(t *testing.T) { + assert.NotPanics(t, func() { + assert.Equal( + t, + NewBuckets(2, 4, 8), + MakeExponentialBuckets(2, 2, 3), + ) + }) +} + +func TestMakeExponentialBucketsPanicsOnBadCount(t *testing.T) { + assert.Panics(t, func() { + MakeExponentialBuckets(2, 2, 0) + }) +} + +func TestMakeExponentialBucketsPanicsOnBadStart(t *testing.T) { + assert.Panics(t, func() { + MakeExponentialBuckets(0, 2, 2) + }) +} + +func TestMakeExponentialBucketsPanicsOnBadFactor(t *testing.T) { + assert.Panics(t, func() { + MakeExponentialBuckets(2, 1, 2) + }) +} + +func TestMakeExponentialDurationBuckets(t *testing.T) { + assert.NotPanics(t, func() { + assert.Equal( + t, + NewDurationBuckets(2*time.Second, 4*time.Second, 8*time.Second), + MakeExponentialDurationBuckets(2*time.Second, 2, 3), + ) + }) +} + +func TestMakeExponentialDurationBucketsPanicsOnBadCount(t *testing.T) { + assert.Panics(t, func() { + MakeExponentialDurationBuckets(2*time.Second, 2, 0) + }) +} + +func TestMakeExponentialDurationBucketsPanicsOnBadStart(t *testing.T) { + assert.Panics(t, func() { + MakeExponentialDurationBuckets(0, 2, 2) + }) +} + +func TestMakeExponentialDurationBucketsPanicsOnBadFactor(t *testing.T) { + assert.Panics(t, func() { + MakeExponentialDurationBuckets(2*time.Second, 1, 2) + }) +} diff --git a/library/go/core/metrics/collect/collect.go b/library/go/core/metrics/collect/collect.go new file mode 100644 index 0000000000..492a2f74a5 --- /dev/null +++ b/library/go/core/metrics/collect/collect.go @@ -0,0 +1,9 @@ +package collect + +import ( + "context" + + "github.com/ydb-platform/ydb/library/go/core/metrics" +) + +type Func func(ctx context.Context, r metrics.Registry, c metrics.CollectPolicy) diff --git a/library/go/core/metrics/collect/policy/inflight/inflight.go b/library/go/core/metrics/collect/policy/inflight/inflight.go new file mode 100644 index 0000000000..bc045fe188 --- /dev/null +++ b/library/go/core/metrics/collect/policy/inflight/inflight.go @@ -0,0 +1,78 @@ +package inflight + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/ydb-platform/ydb/library/go/core/metrics" + "github.com/ydb-platform/ydb/library/go/x/xsync" +) + +var _ metrics.CollectPolicy = (*inflightPolicy)(nil) + +type inflightPolicy struct { + addCollectLock sync.Mutex + collect atomic.Value // func(ctx context.Context) + + minUpdateInterval time.Duration + lastUpdate time.Time + + inflight xsync.SingleInflight +} + +func NewCollectorPolicy(opts ...Option) metrics.CollectPolicy { + c := &inflightPolicy{ + minUpdateInterval: time.Second, + inflight: xsync.NewSingleInflight(), + } + c.collect.Store(func(context.Context) {}) + + for _, opt := range opts { + opt(c) + } + + return c +} + +func (i *inflightPolicy) RegisteredCounter(counterFunc func() int64) func() int64 { + return func() int64 { + i.tryInflightUpdate() + return counterFunc() + } +} + +func (i *inflightPolicy) RegisteredGauge(gaugeFunc func() float64) func() float64 { + return func() float64 { + i.tryInflightUpdate() + return gaugeFunc() + } +} + +func (i *inflightPolicy) AddCollect(collect func(context.Context)) { + oldCollect := i.getCollect() + i.setCollect(func(ctx context.Context) { + oldCollect(ctx) + collect(ctx) + }) +} + +func (i *inflightPolicy) tryInflightUpdate() { + i.inflight.Do(func() { + if time.Since(i.lastUpdate) < i.minUpdateInterval { + return + } + + i.getCollect()(context.Background()) + i.lastUpdate = time.Now() + }) +} + +func (i *inflightPolicy) getCollect() func(context.Context) { + return i.collect.Load().(func(context.Context)) +} + +func (i *inflightPolicy) setCollect(collect func(context.Context)) { + i.collect.Store(collect) +} diff --git a/library/go/core/metrics/collect/policy/inflight/inflight_opts.go b/library/go/core/metrics/collect/policy/inflight/inflight_opts.go new file mode 100644 index 0000000000..cc277b0c71 --- /dev/null +++ b/library/go/core/metrics/collect/policy/inflight/inflight_opts.go @@ -0,0 +1,11 @@ +package inflight + +import "time" + +type Option func(*inflightPolicy) + +func WithMinCollectInterval(interval time.Duration) Option { + return func(c *inflightPolicy) { + c.minUpdateInterval = interval + } +} diff --git a/library/go/core/metrics/collect/policy/inflight/ya.make b/library/go/core/metrics/collect/policy/inflight/ya.make new file mode 100644 index 0000000000..6101e04049 --- /dev/null +++ b/library/go/core/metrics/collect/policy/inflight/ya.make @@ -0,0 +1,8 @@ +GO_LIBRARY() + +SRCS( + inflight.go + inflight_opts.go +) + +END() diff --git a/library/go/core/metrics/collect/policy/ya.make b/library/go/core/metrics/collect/policy/ya.make new file mode 100644 index 0000000000..2717ef9863 --- /dev/null +++ b/library/go/core/metrics/collect/policy/ya.make @@ -0,0 +1 @@ +RECURSE(inflight) diff --git a/library/go/core/metrics/collect/system.go b/library/go/core/metrics/collect/system.go new file mode 100644 index 0000000000..a21e91d632 --- /dev/null +++ b/library/go/core/metrics/collect/system.go @@ -0,0 +1,229 @@ +// dashboard generator for these metrics can be found at: github.com/ydb-platform/ydb/arcadia/library/go/yandex/monitoring-dashboards +package collect + +import ( + "context" + "os" + "runtime" + "runtime/debug" + "time" + + "github.com/prometheus/procfs" + "github.com/ydb-platform/ydb/library/go/core/buildinfo" + "github.com/ydb-platform/ydb/library/go/core/metrics" +) + +var _ Func = GoMetrics + +func GoMetrics(_ context.Context, r metrics.Registry, c metrics.CollectPolicy) { + if r == nil { + return + } + r = r.WithPrefix("go") + + var stats debug.GCStats + stats.PauseQuantiles = make([]time.Duration, 5) // Minimum, 25%, 50%, 75%, and maximum pause times. + var numGoroutine, numThread int + var ms runtime.MemStats + + c.AddCollect(func(context.Context) { + debug.ReadGCStats(&stats) + runtime.ReadMemStats(&ms) + + numThread, _ = runtime.ThreadCreateProfile(nil) + numGoroutine = runtime.NumGoroutine() + }) + + gcRegistry := r.WithPrefix("gc") + gcRegistry.FuncCounter("num", c.RegisteredCounter(func() int64 { + return stats.NumGC + })) + gcRegistry.FuncCounter(r.ComposeName("pause", "total", "ns"), c.RegisteredCounter(func() int64 { + return stats.PauseTotal.Nanoseconds() + })) + gcRegistry.FuncGauge(r.ComposeName("pause", "quantile", "min"), c.RegisteredGauge(func() float64 { + return stats.PauseQuantiles[0].Seconds() + })) + gcRegistry.FuncGauge(r.ComposeName("pause", "quantile", "25"), c.RegisteredGauge(func() float64 { + return stats.PauseQuantiles[1].Seconds() + })) + gcRegistry.FuncGauge(r.ComposeName("pause", "quantile", "50"), c.RegisteredGauge(func() float64 { + return stats.PauseQuantiles[2].Seconds() + })) + gcRegistry.FuncGauge(r.ComposeName("pause", "quantile", "75"), c.RegisteredGauge(func() float64 { + return stats.PauseQuantiles[3].Seconds() + })) + gcRegistry.FuncGauge(r.ComposeName("pause", "quantile", "max"), c.RegisteredGauge(func() float64 { + return stats.PauseQuantiles[4].Seconds() + })) + gcRegistry.FuncGauge(r.ComposeName("last", "ts"), c.RegisteredGauge(func() float64 { + return float64(ms.LastGC) + })) + gcRegistry.FuncCounter(r.ComposeName("forced", "num"), c.RegisteredCounter(func() int64 { + return int64(ms.NumForcedGC) + })) + + r.FuncGauge(r.ComposeName("goroutine", "num"), c.RegisteredGauge(func() float64 { + return float64(numGoroutine) + })) + r.FuncGauge(r.ComposeName("thread", "num"), c.RegisteredGauge(func() float64 { + return float64(numThread) + })) + + memRegistry := r.WithPrefix("mem") + memRegistry.FuncCounter(r.ComposeName("alloc", "total"), c.RegisteredCounter(func() int64 { + return int64(ms.TotalAlloc) + })) + memRegistry.FuncGauge("sys", c.RegisteredGauge(func() float64 { + return float64(ms.Sys) + })) + memRegistry.FuncCounter("lookups", c.RegisteredCounter(func() int64 { + return int64(ms.Lookups) + })) + memRegistry.FuncCounter("mallocs", c.RegisteredCounter(func() int64 { + return int64(ms.Mallocs) + })) + memRegistry.FuncCounter("frees", c.RegisteredCounter(func() int64 { + return int64(ms.Frees) + })) + memRegistry.FuncGauge(r.ComposeName("heap", "alloc"), c.RegisteredGauge(func() float64 { + return float64(ms.HeapAlloc) + })) + memRegistry.FuncGauge(r.ComposeName("heap", "sys"), c.RegisteredGauge(func() float64 { + return float64(ms.HeapSys) + })) + memRegistry.FuncGauge(r.ComposeName("heap", "idle"), c.RegisteredGauge(func() float64 { + return float64(ms.HeapIdle) + })) + memRegistry.FuncGauge(r.ComposeName("heap", "inuse"), c.RegisteredGauge(func() float64 { + return float64(ms.HeapInuse) + })) + memRegistry.FuncGauge(r.ComposeName("heap", "released"), c.RegisteredGauge(func() float64 { + return float64(ms.HeapReleased) + })) + memRegistry.FuncGauge(r.ComposeName("heap", "objects"), c.RegisteredGauge(func() float64 { + return float64(ms.HeapObjects) + })) + + memRegistry.FuncGauge(r.ComposeName("stack", "inuse"), c.RegisteredGauge(func() float64 { + return float64(ms.StackInuse) + })) + memRegistry.FuncGauge(r.ComposeName("stack", "sys"), c.RegisteredGauge(func() float64 { + return float64(ms.StackSys) + })) + + memRegistry.FuncGauge(r.ComposeName("span", "inuse"), c.RegisteredGauge(func() float64 { + return float64(ms.MSpanInuse) + })) + memRegistry.FuncGauge(r.ComposeName("span", "sys"), c.RegisteredGauge(func() float64 { + return float64(ms.MSpanSys) + })) + + memRegistry.FuncGauge(r.ComposeName("cache", "inuse"), c.RegisteredGauge(func() float64 { + return float64(ms.MCacheInuse) + })) + memRegistry.FuncGauge(r.ComposeName("cache", "sys"), c.RegisteredGauge(func() float64 { + return float64(ms.MCacheSys) + })) + + memRegistry.FuncGauge(r.ComposeName("buck", "hash", "sys"), c.RegisteredGauge(func() float64 { + return float64(ms.BuckHashSys) + })) + memRegistry.FuncGauge(r.ComposeName("gc", "sys"), c.RegisteredGauge(func() float64 { + return float64(ms.GCSys) + })) + memRegistry.FuncGauge(r.ComposeName("other", "sys"), c.RegisteredGauge(func() float64 { + return float64(ms.OtherSys) + })) + memRegistry.FuncGauge(r.ComposeName("gc", "next"), c.RegisteredGauge(func() float64 { + return float64(ms.NextGC) + })) + + memRegistry.FuncGauge(r.ComposeName("gc", "cpu", "fraction"), c.RegisteredGauge(func() float64 { + return ms.GCCPUFraction + })) +} + +var _ Func = ProcessMetrics + +func ProcessMetrics(_ context.Context, r metrics.Registry, c metrics.CollectPolicy) { + if r == nil { + return + } + buildVersion := buildinfo.Info.ArcadiaSourceRevision + r.WithTags(map[string]string{"revision": buildVersion}).Gauge("build").Set(1.0) + + pid := os.Getpid() + proc, err := procfs.NewProc(pid) + if err != nil { + return + } + + procRegistry := r.WithPrefix("proc") + + var ioStat procfs.ProcIO + var procStat procfs.ProcStat + var fd int + var cpuWait uint64 + + const clocksPerSec = 100 + + c.AddCollect(func(ctx context.Context) { + if gatheredFD, err := proc.FileDescriptorsLen(); err == nil { + fd = gatheredFD + } + + if gatheredIOStat, err := proc.IO(); err == nil { + ioStat.SyscW = gatheredIOStat.SyscW + ioStat.WriteBytes = gatheredIOStat.WriteBytes + ioStat.SyscR = gatheredIOStat.SyscR + ioStat.ReadBytes = gatheredIOStat.ReadBytes + } + + if gatheredStat, err := proc.Stat(); err == nil { + procStat.UTime = gatheredStat.UTime + procStat.STime = gatheredStat.STime + procStat.RSS = gatheredStat.RSS + } + + if gatheredSched, err := proc.Schedstat(); err == nil { + cpuWait = gatheredSched.WaitingNanoseconds + } + }) + + procRegistry.FuncGauge("fd", c.RegisteredGauge(func() float64 { + return float64(fd) + })) + + ioRegistry := procRegistry.WithPrefix("io") + ioRegistry.FuncCounter(r.ComposeName("read", "count"), c.RegisteredCounter(func() int64 { + return int64(ioStat.SyscR) + })) + ioRegistry.FuncCounter(r.ComposeName("read", "bytes"), c.RegisteredCounter(func() int64 { + return int64(ioStat.ReadBytes) + })) + ioRegistry.FuncCounter(r.ComposeName("write", "count"), c.RegisteredCounter(func() int64 { + return int64(ioStat.SyscW) + })) + ioRegistry.FuncCounter(r.ComposeName("write", "bytes"), c.RegisteredCounter(func() int64 { + return int64(ioStat.WriteBytes) + })) + + cpuRegistry := procRegistry.WithPrefix("cpu") + cpuRegistry.FuncCounter(r.ComposeName("total", "ns"), c.RegisteredCounter(func() int64 { + return int64(procStat.UTime+procStat.STime) * (1_000_000_000 / clocksPerSec) + })) + cpuRegistry.FuncCounter(r.ComposeName("user", "ns"), c.RegisteredCounter(func() int64 { + return int64(procStat.UTime) * (1_000_000_000 / clocksPerSec) + })) + cpuRegistry.FuncCounter(r.ComposeName("system", "ns"), c.RegisteredCounter(func() int64 { + return int64(procStat.STime) * (1_000_000_000 / clocksPerSec) + })) + cpuRegistry.FuncCounter(r.ComposeName("wait", "ns"), c.RegisteredCounter(func() int64 { + return int64(cpuWait) + })) + + procRegistry.FuncGauge(r.ComposeName("mem", "rss"), c.RegisteredGauge(func() float64 { + return float64(procStat.RSS) + })) +} diff --git a/library/go/core/metrics/collect/ya.make b/library/go/core/metrics/collect/ya.make new file mode 100644 index 0000000000..be81763221 --- /dev/null +++ b/library/go/core/metrics/collect/ya.make @@ -0,0 +1,10 @@ +GO_LIBRARY() + +SRCS( + collect.go + system.go +) + +END() + +RECURSE(policy) diff --git a/library/go/core/metrics/gotest/ya.make b/library/go/core/metrics/gotest/ya.make new file mode 100644 index 0000000000..d0bdf91982 --- /dev/null +++ b/library/go/core/metrics/gotest/ya.make @@ -0,0 +1,3 @@ +GO_TEST_FOR(library/go/core/metrics) + +END() diff --git a/library/go/core/metrics/internal/pkg/metricsutil/buckets.go b/library/go/core/metrics/internal/pkg/metricsutil/buckets.go new file mode 100644 index 0000000000..e9501fcceb --- /dev/null +++ b/library/go/core/metrics/internal/pkg/metricsutil/buckets.go @@ -0,0 +1,27 @@ +package metricsutil + +import ( + "sort" + + "github.com/ydb-platform/ydb/library/go/core/metrics" +) + +// BucketsBounds unwraps Buckets bounds to slice of float64. +func BucketsBounds(b metrics.Buckets) []float64 { + bkts := make([]float64, b.Size()) + for i := range bkts { + bkts[i] = b.UpperBound(i) + } + sort.Float64s(bkts) + return bkts +} + +// DurationBucketsBounds unwraps DurationBuckets bounds to slice of float64. +func DurationBucketsBounds(b metrics.DurationBuckets) []float64 { + bkts := make([]float64, b.Size()) + for i := range bkts { + bkts[i] = b.UpperBound(i).Seconds() + } + sort.Float64s(bkts) + return bkts +} diff --git a/library/go/core/metrics/internal/pkg/metricsutil/ya.make b/library/go/core/metrics/internal/pkg/metricsutil/ya.make new file mode 100644 index 0000000000..3058637089 --- /dev/null +++ b/library/go/core/metrics/internal/pkg/metricsutil/ya.make @@ -0,0 +1,5 @@ +GO_LIBRARY() + +SRCS(buckets.go) + +END() diff --git a/library/go/core/metrics/internal/pkg/registryutil/gotest/ya.make b/library/go/core/metrics/internal/pkg/registryutil/gotest/ya.make new file mode 100644 index 0000000000..55c204d140 --- /dev/null +++ b/library/go/core/metrics/internal/pkg/registryutil/gotest/ya.make @@ -0,0 +1,3 @@ +GO_TEST_FOR(library/go/core/metrics/internal/pkg/registryutil) + +END() diff --git a/library/go/core/metrics/internal/pkg/registryutil/registryutil.go b/library/go/core/metrics/internal/pkg/registryutil/registryutil.go new file mode 100644 index 0000000000..ebce50d8cb --- /dev/null +++ b/library/go/core/metrics/internal/pkg/registryutil/registryutil.go @@ -0,0 +1,104 @@ +package registryutil + +import ( + "errors" + "fmt" + "sort" + "strconv" + "strings" + + "github.com/OneOfOne/xxhash" +) + +// BuildRegistryKey creates registry name based on given prefix and tags +func BuildRegistryKey(prefix string, tags map[string]string) string { + var builder strings.Builder + + builder.WriteString(strconv.Quote(prefix)) + builder.WriteRune('{') + builder.WriteString(StringifyTags(tags)) + builder.WriteByte('}') + + return builder.String() +} + +// BuildFQName returns name parts joined by given separator. +// Mainly used to append prefix to registry +func BuildFQName(separator string, parts ...string) (name string) { + var b strings.Builder + for _, p := range parts { + if p == "" { + continue + } + if b.Len() > 0 { + b.WriteString(separator) + } + b.WriteString(strings.Trim(p, separator)) + } + return b.String() +} + +// MergeTags merges 2 sets of tags with the tags from tagsRight overriding values from tagsLeft +func MergeTags(leftTags map[string]string, rightTags map[string]string) map[string]string { + if leftTags == nil && rightTags == nil { + return nil + } + + if len(leftTags) == 0 { + return rightTags + } + + if len(rightTags) == 0 { + return leftTags + } + + newTags := make(map[string]string) + for key, value := range leftTags { + newTags[key] = value + } + for key, value := range rightTags { + newTags[key] = value + } + return newTags +} + +// StringifyTags returns string representation of given tags map. +// It is guaranteed that equal sets of tags will produce equal strings. +func StringifyTags(tags map[string]string) string { + keys := make([]string, 0, len(tags)) + for key := range tags { + keys = append(keys, key) + } + sort.Strings(keys) + + var builder strings.Builder + for i, key := range keys { + if i > 0 { + builder.WriteByte(',') + } + builder.WriteString(key + "=" + tags[key]) + } + + return builder.String() +} + +// VectorHash computes hash of metrics vector element +func VectorHash(tags map[string]string, labels []string) (uint64, error) { + if len(tags) != len(labels) { + return 0, errors.New("inconsistent tags and labels sets") + } + + h := xxhash.New64() + + for _, label := range labels { + v, ok := tags[label] + if !ok { + return 0, fmt.Errorf("label '%s' not found in tags", label) + } + _, _ = h.WriteString(label) + _, _ = h.WriteString(v) + _, _ = h.WriteString(",") + } + + return h.Sum64(), nil +} diff --git a/library/go/core/metrics/internal/pkg/registryutil/registryutil_test.go b/library/go/core/metrics/internal/pkg/registryutil/registryutil_test.go new file mode 100644 index 0000000000..5463f04755 --- /dev/null +++ b/library/go/core/metrics/internal/pkg/registryutil/registryutil_test.go @@ -0,0 +1,48 @@ +package registryutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBuildFQName(t *testing.T) { + testCases := []struct { + name string + parts []string + sep string + expected string + }{ + { + name: "empty", + parts: nil, + sep: "_", + expected: "", + }, + { + name: "one part", + parts: []string{"part"}, + sep: "_", + expected: "part", + }, + { + name: "two parts", + parts: []string{"part", "another"}, + sep: "_", + expected: "part_another", + }, + { + name: "parts with sep", + parts: []string{"abcde", "deabc"}, + sep: "abc", + expected: "deabcde", + }, + } + + for _, testCase := range testCases { + c := testCase + t.Run(c.name, func(t *testing.T) { + assert.Equal(t, c.expected, BuildFQName(c.sep, c.parts...)) + }) + } +} diff --git a/library/go/core/metrics/internal/pkg/registryutil/ya.make b/library/go/core/metrics/internal/pkg/registryutil/ya.make new file mode 100644 index 0000000000..4a1f976d40 --- /dev/null +++ b/library/go/core/metrics/internal/pkg/registryutil/ya.make @@ -0,0 +1,9 @@ +GO_LIBRARY() + +SRCS(registryutil.go) + +GO_TEST_SRCS(registryutil_test.go) + +END() + +RECURSE(gotest) diff --git a/library/go/core/metrics/internal/pkg/ya.make b/library/go/core/metrics/internal/pkg/ya.make new file mode 100644 index 0000000000..416d1b3e5d --- /dev/null +++ b/library/go/core/metrics/internal/pkg/ya.make @@ -0,0 +1,4 @@ +RECURSE( + metricsutil + registryutil +) diff --git a/library/go/core/metrics/internal/ya.make b/library/go/core/metrics/internal/ya.make new file mode 100644 index 0000000000..b2a587f35d --- /dev/null +++ b/library/go/core/metrics/internal/ya.make @@ -0,0 +1 @@ +RECURSE(pkg) diff --git a/library/go/core/metrics/metrics.go b/library/go/core/metrics/metrics.go new file mode 100644 index 0000000000..097fca9a55 --- /dev/null +++ b/library/go/core/metrics/metrics.go @@ -0,0 +1,163 @@ +// Package metrics provides interface collecting performance metrics. +package metrics + +import ( + "context" + "time" +) + +// Gauge tracks single float64 value. +type Gauge interface { + Set(value float64) + Add(value float64) +} + +// FuncGauge is Gauge with value provided by callback function. +type FuncGauge interface { + Function() func() float64 +} + +// IntGauge tracks single int64 value. +type IntGauge interface { + Set(value int64) + Add(value int64) +} + +// FuncIntGauge is IntGauge with value provided by callback function. +type FuncIntGauge interface { + Function() func() int64 +} + +// Counter tracks monotonically increasing value. +type Counter interface { + // Inc increments counter by 1. + Inc() + + // Add adds delta to the counter. Delta must be >=0. + Add(delta int64) +} + +// FuncCounter is Counter with value provided by callback function. +type FuncCounter interface { + Function() func() int64 +} + +// Histogram tracks distribution of value. +type Histogram interface { + RecordValue(value float64) +} + +// Timer measures durations. +type Timer interface { + RecordDuration(value time.Duration) +} + +// DurationBuckets defines buckets of the duration histogram. +type DurationBuckets interface { + // Size returns number of buckets. + Size() int + + // MapDuration returns index of the bucket. + // + // index is integer in range [0, Size()). + MapDuration(d time.Duration) int + + // UpperBound of the last bucket is always +Inf. + // + // bucketIndex is integer in range [0, Size()-1). + UpperBound(bucketIndex int) time.Duration +} + +// Buckets defines intervals of the regular histogram. +type Buckets interface { + // Size returns number of buckets. + Size() int + + // MapValue returns index of the bucket. + // + // Index is integer in range [0, Size()). + MapValue(v float64) int + + // UpperBound of the last bucket is always +Inf. + // + // bucketIndex is integer in range [0, Size()-1). + UpperBound(bucketIndex int) float64 +} + +// GaugeVec stores multiple dynamically created gauges. +type GaugeVec interface { + With(map[string]string) Gauge + + // Reset deletes all metrics in vector. + Reset() +} + +// IntGaugeVec stores multiple dynamically created gauges. +type IntGaugeVec interface { + With(map[string]string) IntGauge + + // Reset deletes all metrics in vector. + Reset() +} + +// CounterVec stores multiple dynamically created counters. +type CounterVec interface { + With(map[string]string) Counter + + // Reset deletes all metrics in vector. + Reset() +} + +// TimerVec stores multiple dynamically created timers. +type TimerVec interface { + With(map[string]string) Timer + + // Reset deletes all metrics in vector. + Reset() +} + +// HistogramVec stores multiple dynamically created histograms. +type HistogramVec interface { + With(map[string]string) Histogram + + // Reset deletes all metrics in vector. + Reset() +} + +// Registry creates profiling metrics. +type Registry interface { + // WithTags creates new sub-scope, where each metric has tags attached to it. + WithTags(tags map[string]string) Registry + // WithPrefix creates new sub-scope, where each metric has prefix added to it name. + WithPrefix(prefix string) Registry + + ComposeName(parts ...string) string + + Counter(name string) Counter + CounterVec(name string, labels []string) CounterVec + FuncCounter(name string, function func() int64) FuncCounter + + Gauge(name string) Gauge + GaugeVec(name string, labels []string) GaugeVec + FuncGauge(name string, function func() float64) FuncGauge + + IntGauge(name string) IntGauge + IntGaugeVec(name string, labels []string) IntGaugeVec + FuncIntGauge(name string, function func() int64) FuncIntGauge + + Timer(name string) Timer + TimerVec(name string, labels []string) TimerVec + + Histogram(name string, buckets Buckets) Histogram + HistogramVec(name string, buckets Buckets, labels []string) HistogramVec + + DurationHistogram(name string, buckets DurationBuckets) Timer + DurationHistogramVec(name string, buckets DurationBuckets, labels []string) TimerVec +} + +// CollectPolicy defines how registered gauge metrics are updated via collect func. +type CollectPolicy interface { + RegisteredCounter(counterFunc func() int64) func() int64 + RegisteredGauge(gaugeFunc func() float64) func() float64 + AddCollect(collect func(ctx context.Context)) +} diff --git a/library/go/core/metrics/mock/counter.go b/library/go/core/metrics/mock/counter.go new file mode 100644 index 0000000000..c3016ea1a9 --- /dev/null +++ b/library/go/core/metrics/mock/counter.go @@ -0,0 +1,35 @@ +package mock + +import ( + "github.com/ydb-platform/ydb/library/go/core/metrics" + "go.uber.org/atomic" +) + +var _ metrics.Counter = (*Counter)(nil) + +// Counter tracks monotonically increasing value. +type Counter struct { + Name string + Tags map[string]string + Value *atomic.Int64 +} + +// Inc increments counter by 1. +func (c *Counter) Inc() { + c.Add(1) +} + +// Add adds delta to the counter. Delta must be >=0. +func (c *Counter) Add(delta int64) { + c.Value.Add(delta) +} + +var _ metrics.FuncCounter = (*FuncCounter)(nil) + +type FuncCounter struct { + function func() int64 +} + +func (c FuncCounter) Function() func() int64 { + return c.function +} diff --git a/library/go/core/metrics/mock/gauge.go b/library/go/core/metrics/mock/gauge.go new file mode 100644 index 0000000000..58d2d29beb --- /dev/null +++ b/library/go/core/metrics/mock/gauge.go @@ -0,0 +1,33 @@ +package mock + +import ( + "github.com/ydb-platform/ydb/library/go/core/metrics" + "go.uber.org/atomic" +) + +var _ metrics.Gauge = (*Gauge)(nil) + +// Gauge tracks single float64 value. +type Gauge struct { + Name string + Tags map[string]string + Value *atomic.Float64 +} + +func (g *Gauge) Set(value float64) { + g.Value.Store(value) +} + +func (g *Gauge) Add(value float64) { + g.Value.Add(value) +} + +var _ metrics.FuncGauge = (*FuncGauge)(nil) + +type FuncGauge struct { + function func() float64 +} + +func (g FuncGauge) Function() func() float64 { + return g.function +} diff --git a/library/go/core/metrics/mock/histogram.go b/library/go/core/metrics/mock/histogram.go new file mode 100644 index 0000000000..734d7b5f88 --- /dev/null +++ b/library/go/core/metrics/mock/histogram.go @@ -0,0 +1,40 @@ +package mock + +import ( + "sort" + "sync" + "time" + + "github.com/ydb-platform/ydb/library/go/core/metrics" + "go.uber.org/atomic" +) + +var ( + _ metrics.Histogram = (*Histogram)(nil) + _ metrics.Timer = (*Histogram)(nil) +) + +type Histogram struct { + Name string + Tags map[string]string + BucketBounds []float64 + BucketValues []int64 + InfValue *atomic.Int64 + mutex sync.Mutex +} + +func (h *Histogram) RecordValue(value float64) { + boundIndex := sort.SearchFloat64s(h.BucketBounds, value) + + if boundIndex < len(h.BucketValues) { + h.mutex.Lock() + h.BucketValues[boundIndex] += 1 + h.mutex.Unlock() + } else { + h.InfValue.Inc() + } +} + +func (h *Histogram) RecordDuration(value time.Duration) { + h.RecordValue(value.Seconds()) +} diff --git a/library/go/core/metrics/mock/int_gauge.go b/library/go/core/metrics/mock/int_gauge.go new file mode 100644 index 0000000000..8955107da9 --- /dev/null +++ b/library/go/core/metrics/mock/int_gauge.go @@ -0,0 +1,33 @@ +package mock + +import ( + "github.com/ydb-platform/ydb/library/go/core/metrics" + "go.uber.org/atomic" +) + +var _ metrics.IntGauge = (*IntGauge)(nil) + +// IntGauge tracks single int64 value. +type IntGauge struct { + Name string + Tags map[string]string + Value *atomic.Int64 +} + +func (g *IntGauge) Set(value int64) { + g.Value.Store(value) +} + +func (g *IntGauge) Add(value int64) { + g.Value.Add(value) +} + +var _ metrics.FuncIntGauge = (*FuncIntGauge)(nil) + +type FuncIntGauge struct { + function func() int64 +} + +func (g FuncIntGauge) Function() func() int64 { + return g.function +} diff --git a/library/go/core/metrics/mock/registry.go b/library/go/core/metrics/mock/registry.go new file mode 100644 index 0000000000..77f465f8ea --- /dev/null +++ b/library/go/core/metrics/mock/registry.go @@ -0,0 +1,224 @@ +package mock + +import ( + "sync" + + "github.com/ydb-platform/ydb/library/go/core/metrics" + "github.com/ydb-platform/ydb/library/go/core/metrics/internal/pkg/metricsutil" + "github.com/ydb-platform/ydb/library/go/core/metrics/internal/pkg/registryutil" + "go.uber.org/atomic" +) + +var _ metrics.Registry = (*Registry)(nil) + +type Registry struct { + separator string + prefix string + tags map[string]string + allowLoadRegisteredMetrics bool + + subregistries map[string]*Registry + m *sync.Mutex + + metrics *sync.Map +} + +func NewRegistry(opts *RegistryOpts) *Registry { + r := &Registry{ + separator: ".", + + subregistries: make(map[string]*Registry), + m: new(sync.Mutex), + + metrics: new(sync.Map), + } + + if opts != nil { + r.separator = string(opts.Separator) + r.prefix = opts.Prefix + r.tags = opts.Tags + r.allowLoadRegisteredMetrics = opts.AllowLoadRegisteredMetrics + } + + return r +} + +// WithTags creates new sub-scope, where each metric has tags attached to it. +func (r Registry) WithTags(tags map[string]string) metrics.Registry { + return r.newSubregistry(r.prefix, registryutil.MergeTags(r.tags, tags)) +} + +// WithPrefix creates new sub-scope, where each metric has prefix added to it name. +func (r Registry) WithPrefix(prefix string) metrics.Registry { + return r.newSubregistry(registryutil.BuildFQName(r.separator, r.prefix, prefix), r.tags) +} + +func (r Registry) ComposeName(parts ...string) string { + return registryutil.BuildFQName(r.separator, parts...) +} + +func (r Registry) Counter(name string) metrics.Counter { + s := &Counter{ + Name: r.newMetricName(name), + Tags: r.tags, + Value: new(atomic.Int64), + } + + key := registryutil.BuildRegistryKey(s.Name, r.tags) + if val, loaded := r.metrics.LoadOrStore(key, s); loaded { + if r.allowLoadRegisteredMetrics { + return val.(*Counter) + } + panic("metric with key " + key + " already registered") + } + return s +} + +func (r Registry) FuncCounter(name string, function func() int64) metrics.FuncCounter { + metricName := r.newMetricName(name) + key := registryutil.BuildRegistryKey(metricName, r.tags) + s := FuncCounter{function: function} + if _, loaded := r.metrics.LoadOrStore(key, s); loaded { + panic("metric with key " + key + " already registered") + } + return s +} + +func (r Registry) Gauge(name string) metrics.Gauge { + s := &Gauge{ + Name: r.newMetricName(name), + Tags: r.tags, + Value: new(atomic.Float64), + } + + key := registryutil.BuildRegistryKey(s.Name, r.tags) + if val, loaded := r.metrics.LoadOrStore(key, s); loaded { + if r.allowLoadRegisteredMetrics { + return val.(*Gauge) + } + panic("metric with key " + key + " already registered") + } + return s +} + +func (r Registry) FuncGauge(name string, function func() float64) metrics.FuncGauge { + metricName := r.newMetricName(name) + key := registryutil.BuildRegistryKey(metricName, r.tags) + s := FuncGauge{function: function} + if _, loaded := r.metrics.LoadOrStore(key, s); loaded { + panic("metric with key " + key + " already registered") + } + return s +} + +func (r *Registry) IntGauge(name string) metrics.IntGauge { + s := &IntGauge{ + Name: r.newMetricName(name), + Tags: r.tags, + Value: new(atomic.Int64), + } + + key := registryutil.BuildRegistryKey(s.Name, r.tags) + if val, loaded := r.metrics.LoadOrStore(key, s); loaded { + if r.allowLoadRegisteredMetrics { + return val.(*IntGauge) + } + panic("metric with key " + key + " already registered") + } + return s +} + +func (r *Registry) FuncIntGauge(name string, function func() int64) metrics.FuncIntGauge { + metricName := r.newMetricName(name) + key := registryutil.BuildRegistryKey(metricName, r.tags) + s := FuncIntGauge{function: function} + if _, loaded := r.metrics.LoadOrStore(key, s); loaded { + panic("metric with key " + key + " already registered") + } + return s +} + +func (r Registry) Timer(name string) metrics.Timer { + s := &Timer{ + Name: r.newMetricName(name), + Tags: r.tags, + Value: new(atomic.Duration), + } + + key := registryutil.BuildRegistryKey(s.Name, r.tags) + if val, loaded := r.metrics.LoadOrStore(key, s); loaded { + if r.allowLoadRegisteredMetrics { + return val.(*Timer) + } + panic("metric with key " + key + " already registered") + } + return s +} + +func (r Registry) Histogram(name string, buckets metrics.Buckets) metrics.Histogram { + s := &Histogram{ + Name: r.newMetricName(name), + Tags: r.tags, + BucketBounds: metricsutil.BucketsBounds(buckets), + BucketValues: make([]int64, buckets.Size()), + InfValue: new(atomic.Int64), + } + + key := registryutil.BuildRegistryKey(s.Name, r.tags) + if val, loaded := r.metrics.LoadOrStore(key, s); loaded { + if r.allowLoadRegisteredMetrics { + return val.(*Histogram) + } + panic("metric with key " + key + " already registered") + } + return s +} + +func (r Registry) DurationHistogram(name string, buckets metrics.DurationBuckets) metrics.Timer { + s := &Histogram{ + Name: r.newMetricName(name), + Tags: r.tags, + BucketBounds: metricsutil.DurationBucketsBounds(buckets), + BucketValues: make([]int64, buckets.Size()), + InfValue: new(atomic.Int64), + } + + key := registryutil.BuildRegistryKey(s.Name, r.tags) + if val, loaded := r.metrics.LoadOrStore(key, s); loaded { + if r.allowLoadRegisteredMetrics { + return val.(*Histogram) + } + panic("metric with key " + key + " already registered") + } + return s +} + +func (r *Registry) newSubregistry(prefix string, tags map[string]string) *Registry { + registryKey := registryutil.BuildRegistryKey(prefix, tags) + + r.m.Lock() + defer r.m.Unlock() + + if existing, ok := r.subregistries[registryKey]; ok { + return existing + } + + subregistry := &Registry{ + separator: r.separator, + prefix: prefix, + tags: tags, + allowLoadRegisteredMetrics: r.allowLoadRegisteredMetrics, + + subregistries: r.subregistries, + m: r.m, + + metrics: r.metrics, + } + + r.subregistries[registryKey] = subregistry + return subregistry +} + +func (r *Registry) newMetricName(name string) string { + return registryutil.BuildFQName(r.separator, r.prefix, name) +} diff --git a/library/go/core/metrics/mock/registry_opts.go b/library/go/core/metrics/mock/registry_opts.go new file mode 100644 index 0000000000..1cc1c3970d --- /dev/null +++ b/library/go/core/metrics/mock/registry_opts.go @@ -0,0 +1,52 @@ +package mock + +import ( + "github.com/ydb-platform/ydb/library/go/core/metrics/internal/pkg/registryutil" +) + +type RegistryOpts struct { + Separator rune + Prefix string + Tags map[string]string + AllowLoadRegisteredMetrics bool +} + +// NewRegistryOpts returns new initialized instance of RegistryOpts +func NewRegistryOpts() *RegistryOpts { + return &RegistryOpts{ + Separator: '.', + Tags: make(map[string]string), + } +} + +// SetTags overrides existing tags +func (o *RegistryOpts) SetTags(tags map[string]string) *RegistryOpts { + o.Tags = tags + return o +} + +// AddTags merges given tags with existing +func (o *RegistryOpts) AddTags(tags map[string]string) *RegistryOpts { + for k, v := range tags { + o.Tags[k] = v + } + return o +} + +// SetPrefix overrides existing prefix +func (o *RegistryOpts) SetPrefix(prefix string) *RegistryOpts { + o.Prefix = prefix + return o +} + +// AppendPrefix adds given prefix as postfix to existing using separator +func (o *RegistryOpts) AppendPrefix(prefix string) *RegistryOpts { + o.Prefix = registryutil.BuildFQName(string(o.Separator), o.Prefix, prefix) + return o +} + +// SetSeparator overrides existing separator +func (o *RegistryOpts) SetSeparator(separator rune) *RegistryOpts { + o.Separator = separator + return o +} diff --git a/library/go/core/metrics/mock/timer.go b/library/go/core/metrics/mock/timer.go new file mode 100644 index 0000000000..3ea3629ca9 --- /dev/null +++ b/library/go/core/metrics/mock/timer.go @@ -0,0 +1,21 @@ +package mock + +import ( + "time" + + "github.com/ydb-platform/ydb/library/go/core/metrics" + "go.uber.org/atomic" +) + +var _ metrics.Timer = (*Timer)(nil) + +// Timer measures gauge duration. +type Timer struct { + Name string + Tags map[string]string + Value *atomic.Duration +} + +func (t *Timer) RecordDuration(value time.Duration) { + t.Value.Store(value) +} diff --git a/library/go/core/metrics/mock/vec.go b/library/go/core/metrics/mock/vec.go new file mode 100644 index 0000000000..f1cde3d47c --- /dev/null +++ b/library/go/core/metrics/mock/vec.go @@ -0,0 +1,256 @@ +package mock + +import ( + "sync" + + "github.com/ydb-platform/ydb/library/go/core/metrics" + "github.com/ydb-platform/ydb/library/go/core/metrics/internal/pkg/registryutil" +) + +type MetricsVector interface { + With(map[string]string) interface{} + + // Reset deletes all metrics in vector. + Reset() +} + +// Vector is base implementation of vector of metrics of any supported type +type Vector struct { + Labels []string + Mtx sync.RWMutex // Protects metrics. + Metrics map[uint64]interface{} + NewMetric func(map[string]string) interface{} +} + +func (v *Vector) With(tags map[string]string) interface{} { + hv, err := registryutil.VectorHash(tags, v.Labels) + if err != nil { + panic(err) + } + + v.Mtx.RLock() + metric, ok := v.Metrics[hv] + v.Mtx.RUnlock() + if ok { + return metric + } + + v.Mtx.Lock() + defer v.Mtx.Unlock() + + metric, ok = v.Metrics[hv] + if !ok { + metric = v.NewMetric(tags) + v.Metrics[hv] = metric + } + + return metric +} + +// Reset deletes all metrics in this vector. +func (v *Vector) Reset() { + v.Mtx.Lock() + defer v.Mtx.Unlock() + + for h := range v.Metrics { + delete(v.Metrics, h) + } +} + +var _ metrics.CounterVec = (*CounterVec)(nil) + +// CounterVec stores counters and +// implements metrics.CounterVec interface +type CounterVec struct { + Vec MetricsVector +} + +// CounterVec creates a new counters vector with given metric name and +// partitioned by the given label names. +func (r *Registry) CounterVec(name string, labels []string) metrics.CounterVec { + return &CounterVec{ + Vec: &Vector{ + Labels: append([]string(nil), labels...), + Metrics: make(map[uint64]interface{}), + NewMetric: func(tags map[string]string) interface{} { + return r.WithTags(tags).Counter(name) + }, + }, + } +} + +// With creates new or returns existing counter with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *CounterVec) With(tags map[string]string) metrics.Counter { + return v.Vec.With(tags).(*Counter) +} + +// Reset deletes all metrics in this vector. +func (v *CounterVec) Reset() { + v.Vec.Reset() +} + +var _ metrics.GaugeVec = new(GaugeVec) + +// GaugeVec stores gauges and +// implements metrics.GaugeVec interface +type GaugeVec struct { + Vec MetricsVector +} + +// GaugeVec creates a new gauges vector with given metric name and +// partitioned by the given label names. +func (r *Registry) GaugeVec(name string, labels []string) metrics.GaugeVec { + return &GaugeVec{ + Vec: &Vector{ + Labels: append([]string(nil), labels...), + Metrics: make(map[uint64]interface{}), + NewMetric: func(tags map[string]string) interface{} { + return r.WithTags(tags).Gauge(name) + }, + }, + } +} + +// With creates new or returns existing gauge with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *GaugeVec) With(tags map[string]string) metrics.Gauge { + return v.Vec.With(tags).(*Gauge) +} + +// Reset deletes all metrics in this vector. +func (v *GaugeVec) Reset() { + v.Vec.Reset() +} + +var _ metrics.IntGaugeVec = new(IntGaugeVec) + +// IntGaugeVec stores gauges and +// implements metrics.IntGaugeVec interface +type IntGaugeVec struct { + Vec MetricsVector +} + +// IntGaugeVec creates a new gauges vector with given metric name and +// partitioned by the given label names. +func (r *Registry) IntGaugeVec(name string, labels []string) metrics.IntGaugeVec { + return &IntGaugeVec{ + Vec: &Vector{ + Labels: append([]string(nil), labels...), + Metrics: make(map[uint64]interface{}), + NewMetric: func(tags map[string]string) interface{} { + return r.WithTags(tags).IntGauge(name) + }, + }, + } +} + +// With creates new or returns existing gauge with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *IntGaugeVec) With(tags map[string]string) metrics.IntGauge { + return v.Vec.With(tags).(*IntGauge) +} + +// Reset deletes all metrics in this vector. +func (v *IntGaugeVec) Reset() { + v.Vec.Reset() +} + +var _ metrics.TimerVec = new(TimerVec) + +// TimerVec stores timers and +// implements metrics.TimerVec interface +type TimerVec struct { + Vec MetricsVector +} + +// TimerVec creates a new timers vector with given metric name and +// partitioned by the given label names. +func (r *Registry) TimerVec(name string, labels []string) metrics.TimerVec { + return &TimerVec{ + Vec: &Vector{ + Labels: append([]string(nil), labels...), + Metrics: make(map[uint64]interface{}), + NewMetric: func(tags map[string]string) interface{} { + return r.WithTags(tags).Timer(name) + }, + }, + } +} + +// With creates new or returns existing timer with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *TimerVec) With(tags map[string]string) metrics.Timer { + return v.Vec.With(tags).(*Timer) +} + +// Reset deletes all metrics in this vector. +func (v *TimerVec) Reset() { + v.Vec.Reset() +} + +var _ metrics.HistogramVec = (*HistogramVec)(nil) + +// HistogramVec stores histograms and +// implements metrics.HistogramVec interface +type HistogramVec struct { + Vec MetricsVector +} + +// HistogramVec creates a new histograms vector with given metric name and buckets and +// partitioned by the given label names. +func (r *Registry) HistogramVec(name string, buckets metrics.Buckets, labels []string) metrics.HistogramVec { + return &HistogramVec{ + Vec: &Vector{ + Labels: append([]string(nil), labels...), + Metrics: make(map[uint64]interface{}), + NewMetric: func(tags map[string]string) interface{} { + return r.WithTags(tags).Histogram(name, buckets) + }, + }, + } +} + +// With creates new or returns existing histogram with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *HistogramVec) With(tags map[string]string) metrics.Histogram { + return v.Vec.With(tags).(*Histogram) +} + +// Reset deletes all metrics in this vector. +func (v *HistogramVec) Reset() { + v.Vec.Reset() +} + +var _ metrics.TimerVec = (*DurationHistogramVec)(nil) + +// DurationHistogramVec stores duration histograms and +// implements metrics.TimerVec interface +type DurationHistogramVec struct { + Vec MetricsVector +} + +// DurationHistogramVec creates a new duration histograms vector with given metric name and buckets and +// partitioned by the given label names. +func (r *Registry) DurationHistogramVec(name string, buckets metrics.DurationBuckets, labels []string) metrics.TimerVec { + return &DurationHistogramVec{ + Vec: &Vector{ + Labels: append([]string(nil), labels...), + Metrics: make(map[uint64]interface{}), + NewMetric: func(tags map[string]string) interface{} { + return r.WithTags(tags).DurationHistogram(name, buckets) + }, + }, + } +} + +// With creates new or returns existing duration histogram with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *DurationHistogramVec) With(tags map[string]string) metrics.Timer { + return v.Vec.With(tags).(*Histogram) +} + +// Reset deletes all metrics in this vector. +func (v *DurationHistogramVec) Reset() { + v.Vec.Reset() +} diff --git a/library/go/core/metrics/mock/ya.make b/library/go/core/metrics/mock/ya.make new file mode 100644 index 0000000000..0ddaf2285b --- /dev/null +++ b/library/go/core/metrics/mock/ya.make @@ -0,0 +1,14 @@ +GO_LIBRARY() + +SRCS( + counter.go + gauge.go + int_gauge.go + histogram.go + registry.go + registry_opts.go + timer.go + vec.go +) + +END() diff --git a/library/go/core/metrics/nop/counter.go b/library/go/core/metrics/nop/counter.go new file mode 100644 index 0000000000..65a36910da --- /dev/null +++ b/library/go/core/metrics/nop/counter.go @@ -0,0 +1,31 @@ +package nop + +import "github.com/ydb-platform/ydb/library/go/core/metrics" + +var _ metrics.Counter = (*Counter)(nil) + +type Counter struct{} + +func (Counter) Inc() {} + +func (Counter) Add(_ int64) {} + +var _ metrics.CounterVec = (*CounterVec)(nil) + +type CounterVec struct{} + +func (t CounterVec) With(_ map[string]string) metrics.Counter { + return Counter{} +} + +func (t CounterVec) Reset() {} + +var _ metrics.FuncCounter = (*FuncCounter)(nil) + +type FuncCounter struct { + function func() int64 +} + +func (c FuncCounter) Function() func() int64 { + return c.function +} diff --git a/library/go/core/metrics/nop/gauge.go b/library/go/core/metrics/nop/gauge.go new file mode 100644 index 0000000000..9ab9ff6d77 --- /dev/null +++ b/library/go/core/metrics/nop/gauge.go @@ -0,0 +1,31 @@ +package nop + +import "github.com/ydb-platform/ydb/library/go/core/metrics" + +var _ metrics.Gauge = (*Gauge)(nil) + +type Gauge struct{} + +func (Gauge) Set(_ float64) {} + +func (Gauge) Add(_ float64) {} + +var _ metrics.GaugeVec = (*GaugeVec)(nil) + +type GaugeVec struct{} + +func (t GaugeVec) With(_ map[string]string) metrics.Gauge { + return Gauge{} +} + +func (t GaugeVec) Reset() {} + +var _ metrics.FuncGauge = (*FuncGauge)(nil) + +type FuncGauge struct { + function func() float64 +} + +func (g FuncGauge) Function() func() float64 { + return g.function +} diff --git a/library/go/core/metrics/nop/histogram.go b/library/go/core/metrics/nop/histogram.go new file mode 100644 index 0000000000..bde571323c --- /dev/null +++ b/library/go/core/metrics/nop/histogram.go @@ -0,0 +1,38 @@ +package nop + +import ( + "time" + + "github.com/ydb-platform/ydb/library/go/core/metrics" +) + +var ( + _ metrics.Histogram = (*Histogram)(nil) + _ metrics.Timer = (*Histogram)(nil) +) + +type Histogram struct{} + +func (Histogram) RecordValue(_ float64) {} + +func (Histogram) RecordDuration(_ time.Duration) {} + +var _ metrics.HistogramVec = (*HistogramVec)(nil) + +type HistogramVec struct{} + +func (t HistogramVec) With(_ map[string]string) metrics.Histogram { + return Histogram{} +} + +func (t HistogramVec) Reset() {} + +var _ metrics.TimerVec = (*DurationHistogramVec)(nil) + +type DurationHistogramVec struct{} + +func (t DurationHistogramVec) With(_ map[string]string) metrics.Timer { + return Histogram{} +} + +func (t DurationHistogramVec) Reset() {} diff --git a/library/go/core/metrics/nop/int_gauge.go b/library/go/core/metrics/nop/int_gauge.go new file mode 100644 index 0000000000..226059a79d --- /dev/null +++ b/library/go/core/metrics/nop/int_gauge.go @@ -0,0 +1,31 @@ +package nop + +import "github.com/ydb-platform/ydb/library/go/core/metrics" + +var _ metrics.IntGauge = (*IntGauge)(nil) + +type IntGauge struct{} + +func (IntGauge) Set(_ int64) {} + +func (IntGauge) Add(_ int64) {} + +var _ metrics.IntGaugeVec = (*IntGaugeVec)(nil) + +type IntGaugeVec struct{} + +func (t IntGaugeVec) With(_ map[string]string) metrics.IntGauge { + return IntGauge{} +} + +func (t IntGaugeVec) Reset() {} + +var _ metrics.FuncIntGauge = (*FuncIntGauge)(nil) + +type FuncIntGauge struct { + function func() int64 +} + +func (g FuncIntGauge) Function() func() int64 { + return g.function +} diff --git a/library/go/core/metrics/nop/registry.go b/library/go/core/metrics/nop/registry.go new file mode 100644 index 0000000000..97ed977ed7 --- /dev/null +++ b/library/go/core/metrics/nop/registry.go @@ -0,0 +1,79 @@ +package nop + +import "github.com/ydb-platform/ydb/library/go/core/metrics" + +var _ metrics.Registry = (*Registry)(nil) + +type Registry struct{} + +func (r Registry) ComposeName(parts ...string) string { + return "" +} + +func (r Registry) WithTags(_ map[string]string) metrics.Registry { + return Registry{} +} + +func (r Registry) WithPrefix(_ string) metrics.Registry { + return Registry{} +} + +func (r Registry) Counter(_ string) metrics.Counter { + return Counter{} +} + +func (r Registry) FuncCounter(_ string, function func() int64) metrics.FuncCounter { + return FuncCounter{function: function} +} + +func (r Registry) Gauge(_ string) metrics.Gauge { + return Gauge{} +} + +func (r Registry) FuncGauge(_ string, function func() float64) metrics.FuncGauge { + return FuncGauge{function: function} +} + +func (r Registry) IntGauge(_ string) metrics.IntGauge { + return IntGauge{} +} + +func (r Registry) FuncIntGauge(_ string, function func() int64) metrics.FuncIntGauge { + return FuncIntGauge{function: function} +} + +func (r Registry) Timer(_ string) metrics.Timer { + return Timer{} +} + +func (r Registry) Histogram(_ string, _ metrics.Buckets) metrics.Histogram { + return Histogram{} +} + +func (r Registry) DurationHistogram(_ string, _ metrics.DurationBuckets) metrics.Timer { + return Histogram{} +} + +func (r Registry) CounterVec(_ string, _ []string) metrics.CounterVec { + return CounterVec{} +} + +func (r Registry) GaugeVec(_ string, _ []string) metrics.GaugeVec { + return GaugeVec{} +} + +func (r Registry) IntGaugeVec(_ string, _ []string) metrics.IntGaugeVec { + return IntGaugeVec{} +} + +func (r Registry) TimerVec(_ string, _ []string) metrics.TimerVec { + return TimerVec{} +} + +func (r Registry) HistogramVec(_ string, _ metrics.Buckets, _ []string) metrics.HistogramVec { + return HistogramVec{} +} + +func (r Registry) DurationHistogramVec(_ string, _ metrics.DurationBuckets, _ []string) metrics.TimerVec { + return DurationHistogramVec{} +} diff --git a/library/go/core/metrics/nop/timer.go b/library/go/core/metrics/nop/timer.go new file mode 100644 index 0000000000..61906032a2 --- /dev/null +++ b/library/go/core/metrics/nop/timer.go @@ -0,0 +1,23 @@ +package nop + +import ( + "time" + + "github.com/ydb-platform/ydb/library/go/core/metrics" +) + +var _ metrics.Timer = (*Timer)(nil) + +type Timer struct{} + +func (Timer) RecordDuration(_ time.Duration) {} + +var _ metrics.TimerVec = (*TimerVec)(nil) + +type TimerVec struct{} + +func (t TimerVec) With(_ map[string]string) metrics.Timer { + return Timer{} +} + +func (t TimerVec) Reset() {} diff --git a/library/go/core/metrics/nop/ya.make b/library/go/core/metrics/nop/ya.make new file mode 100644 index 0000000000..279bc22ef4 --- /dev/null +++ b/library/go/core/metrics/nop/ya.make @@ -0,0 +1,12 @@ +GO_LIBRARY() + +SRCS( + counter.go + gauge.go + int_gauge.go + histogram.go + registry.go + timer.go +) + +END() diff --git a/library/go/core/metrics/prometheus/counter.go b/library/go/core/metrics/prometheus/counter.go new file mode 100644 index 0000000000..1a07063f30 --- /dev/null +++ b/library/go/core/metrics/prometheus/counter.go @@ -0,0 +1,34 @@ +package prometheus + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/ydb-platform/ydb/library/go/core/metrics" +) + +var _ metrics.Counter = (*Counter)(nil) + +// Counter tracks monotonically increasing value. +type Counter struct { + cnt prometheus.Counter +} + +// Inc increments counter by 1. +func (c Counter) Inc() { + c.cnt.Inc() +} + +// Add adds delta to the counter. Delta must be >=0. +func (c Counter) Add(delta int64) { + c.cnt.Add(float64(delta)) +} + +var _ metrics.FuncCounter = (*FuncCounter)(nil) + +type FuncCounter struct { + cnt prometheus.CounterFunc + function func() int64 +} + +func (c FuncCounter) Function() func() int64 { + return c.function +} diff --git a/library/go/core/metrics/prometheus/counter_test.go b/library/go/core/metrics/prometheus/counter_test.go new file mode 100644 index 0000000000..04f0c894f8 --- /dev/null +++ b/library/go/core/metrics/prometheus/counter_test.go @@ -0,0 +1,38 @@ +package prometheus + +import ( + "testing" + + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "github.com/stretchr/testify/assert" +) + +func TestCounter_Add(t *testing.T) { + c := &Counter{cnt: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "test_counter_add", + })} + + var expectValue int64 = 42 + c.Add(expectValue) + + var res dto.Metric + err := c.cnt.Write(&res) + + assert.NoError(t, err) + assert.Equal(t, expectValue, int64(res.GetCounter().GetValue())) +} + +func TestCounter_Inc(t *testing.T) { + c := &Counter{cnt: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "test_counter_inc", + })} + + var res dto.Metric + for i := 1; i <= 10; i++ { + c.Inc() + err := c.cnt.Write(&res) + assert.NoError(t, err) + assert.Equal(t, int64(i), int64(res.GetCounter().GetValue())) + } +} diff --git a/library/go/core/metrics/prometheus/gauge.go b/library/go/core/metrics/prometheus/gauge.go new file mode 100644 index 0000000000..8683755561 --- /dev/null +++ b/library/go/core/metrics/prometheus/gauge.go @@ -0,0 +1,32 @@ +package prometheus + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/ydb-platform/ydb/library/go/core/metrics" +) + +var _ metrics.Gauge = (*Gauge)(nil) + +// Gauge tracks single float64 value. +type Gauge struct { + gg prometheus.Gauge +} + +func (g Gauge) Set(value float64) { + g.gg.Set(value) +} + +func (g Gauge) Add(value float64) { + g.gg.Add(value) +} + +var _ metrics.FuncGauge = (*FuncGauge)(nil) + +type FuncGauge struct { + ff prometheus.GaugeFunc + function func() float64 +} + +func (g FuncGauge) Function() func() float64 { + return g.function +} diff --git a/library/go/core/metrics/prometheus/gauge_test.go b/library/go/core/metrics/prometheus/gauge_test.go new file mode 100644 index 0000000000..aebb7586c1 --- /dev/null +++ b/library/go/core/metrics/prometheus/gauge_test.go @@ -0,0 +1,39 @@ +package prometheus + +import ( + "testing" + + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "github.com/stretchr/testify/assert" +) + +func TestGauge_Add(t *testing.T) { + g := &Gauge{gg: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "test_gauge_add", + })} + + var expectValue float64 = 42 + g.Add(expectValue) + + var res dto.Metric + err := g.gg.Write(&res) + + assert.NoError(t, err) + assert.Equal(t, expectValue, res.GetGauge().GetValue()) +} + +func TestGauge_Set(t *testing.T) { + g := &Gauge{gg: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "test_gauge_set", + })} + + var expectValue float64 = 42 + g.Set(expectValue) + + var res dto.Metric + err := g.gg.Write(&res) + + assert.NoError(t, err) + assert.Equal(t, expectValue, res.GetGauge().GetValue()) +} diff --git a/library/go/core/metrics/prometheus/gotest/ya.make b/library/go/core/metrics/prometheus/gotest/ya.make new file mode 100644 index 0000000000..466256dcaa --- /dev/null +++ b/library/go/core/metrics/prometheus/gotest/ya.make @@ -0,0 +1,3 @@ +GO_TEST_FOR(library/go/core/metrics/prometheus) + +END() diff --git a/library/go/core/metrics/prometheus/histogram.go b/library/go/core/metrics/prometheus/histogram.go new file mode 100644 index 0000000000..bd5e0dca66 --- /dev/null +++ b/library/go/core/metrics/prometheus/histogram.go @@ -0,0 +1,22 @@ +package prometheus + +import ( + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/ydb-platform/ydb/library/go/core/metrics" +) + +var _ metrics.Histogram = (*Histogram)(nil) + +type Histogram struct { + hm prometheus.Observer +} + +func (h Histogram) RecordValue(value float64) { + h.hm.Observe(value) +} + +func (h Histogram) RecordDuration(value time.Duration) { + h.hm.Observe(value.Seconds()) +} diff --git a/library/go/core/metrics/prometheus/histogram_test.go b/library/go/core/metrics/prometheus/histogram_test.go new file mode 100644 index 0000000000..0dec46589c --- /dev/null +++ b/library/go/core/metrics/prometheus/histogram_test.go @@ -0,0 +1,91 @@ +package prometheus + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + dto "github.com/prometheus/client_model/go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/core/metrics" + "github.com/ydb-platform/ydb/library/go/ptr" + "google.golang.org/protobuf/testing/protocmp" +) + +func TestHistogram_RecordValue(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + + h := rg.Histogram("test_histogram_record_value", + metrics.NewBuckets(0.1, 1.0, 15.47, 42.0, 128.256), + ) + + for _, value := range []float64{0.5, 0.7, 34.1234, 127} { + h.RecordValue(value) + } + + expectBuckets := []*dto.Bucket{ + {CumulativeCount: ptr.Uint64(0), UpperBound: ptr.Float64(0.1)}, + {CumulativeCount: ptr.Uint64(2), UpperBound: ptr.Float64(1.0)}, + {CumulativeCount: ptr.Uint64(2), UpperBound: ptr.Float64(15.47)}, + {CumulativeCount: ptr.Uint64(3), UpperBound: ptr.Float64(42.0)}, + {CumulativeCount: ptr.Uint64(4), UpperBound: ptr.Float64(128.256)}, + } + + gathered, err := rg.Gather() + require.NoError(t, err) + + resBuckets := gathered[0].Metric[0].GetHistogram().GetBucket() + + cmpOpts := []cmp.Option{ + cmpopts.IgnoreUnexported(), + protocmp.Transform(), + } + assert.True(t, cmp.Equal(expectBuckets, resBuckets, cmpOpts...), cmp.Diff(expectBuckets, resBuckets, cmpOpts...)) +} + +func TestDurationHistogram_RecordDuration(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + + ht := rg.DurationHistogram("test_histogram_record_value", + metrics.NewDurationBuckets( + 1*time.Millisecond, // 0.1 + 1*time.Second, // 1.0 + 15*time.Second+470*time.Millisecond, // 15.47 + 42*time.Second, // 42.0 + 128*time.Second+256*time.Millisecond, // 128.256 + ), + ) + + values := []time.Duration{ + 500 * time.Millisecond, + 700 * time.Millisecond, + 34*time.Second + 1234*time.Millisecond, + 127 * time.Second, + } + + for _, value := range values { + ht.RecordDuration(value) + } + + expectBuckets := []*dto.Bucket{ + {CumulativeCount: ptr.Uint64(0), UpperBound: ptr.Float64(0.001)}, + {CumulativeCount: ptr.Uint64(2), UpperBound: ptr.Float64(1)}, + {CumulativeCount: ptr.Uint64(2), UpperBound: ptr.Float64(15.47)}, + {CumulativeCount: ptr.Uint64(3), UpperBound: ptr.Float64(42)}, + {CumulativeCount: ptr.Uint64(4), UpperBound: ptr.Float64(128.256)}, + } + + gathered, err := rg.Gather() + require.NoError(t, err) + + resBuckets := gathered[0].Metric[0].GetHistogram().GetBucket() + + cmpOpts := []cmp.Option{ + cmpopts.IgnoreUnexported(), + protocmp.Transform(), + } + + assert.True(t, cmp.Equal(expectBuckets, resBuckets, cmpOpts...), cmp.Diff(expectBuckets, resBuckets, cmpOpts...)) +} diff --git a/library/go/core/metrics/prometheus/int_gauge.go b/library/go/core/metrics/prometheus/int_gauge.go new file mode 100644 index 0000000000..813b87828c --- /dev/null +++ b/library/go/core/metrics/prometheus/int_gauge.go @@ -0,0 +1,32 @@ +package prometheus + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/ydb-platform/ydb/library/go/core/metrics" +) + +var _ metrics.IntGauge = (*IntGauge)(nil) + +// IntGauge tracks single int64 value. +type IntGauge struct { + metrics.Gauge +} + +func (i IntGauge) Set(value int64) { + i.Gauge.Set(float64(value)) +} + +func (i IntGauge) Add(value int64) { + i.Gauge.Add(float64(value)) +} + +var _ metrics.FuncIntGauge = (*FuncIntGauge)(nil) + +type FuncIntGauge struct { + ff prometheus.GaugeFunc + function func() int64 +} + +func (g FuncIntGauge) Function() func() int64 { + return g.function +} diff --git a/library/go/core/metrics/prometheus/registry.go b/library/go/core/metrics/prometheus/registry.go new file mode 100644 index 0000000000..bad45fe617 --- /dev/null +++ b/library/go/core/metrics/prometheus/registry.go @@ -0,0 +1,254 @@ +package prometheus + +import ( + "sync" + + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "github.com/ydb-platform/ydb/library/go/core/metrics" + "github.com/ydb-platform/ydb/library/go/core/metrics/internal/pkg/metricsutil" + "github.com/ydb-platform/ydb/library/go/core/metrics/internal/pkg/registryutil" + "github.com/ydb-platform/ydb/library/go/core/xerrors" +) + +var _ metrics.Registry = (*Registry)(nil) + +type Registry struct { + rg *prometheus.Registry + + m *sync.Mutex + subregistries map[string]*Registry + + tags map[string]string + prefix string + nameSanitizer func(string) string +} + +// NewRegistry creates new Prometheus backed registry. +func NewRegistry(opts *RegistryOpts) *Registry { + r := &Registry{ + rg: prometheus.NewRegistry(), + m: new(sync.Mutex), + subregistries: make(map[string]*Registry), + tags: make(map[string]string), + } + + if opts != nil { + r.prefix = opts.Prefix + r.tags = opts.Tags + if opts.rg != nil { + r.rg = opts.rg + } + for _, collector := range opts.Collectors { + collector(r) + } + if opts.NameSanitizer != nil { + r.nameSanitizer = opts.NameSanitizer + } + } + + return r +} + +// WithTags creates new sub-scope, where each metric has tags attached to it. +func (r Registry) WithTags(tags map[string]string) metrics.Registry { + return r.newSubregistry(r.prefix, registryutil.MergeTags(r.tags, tags)) +} + +// WithPrefix creates new sub-scope, where each metric has prefix added to it name. +func (r Registry) WithPrefix(prefix string) metrics.Registry { + return r.newSubregistry(registryutil.BuildFQName("_", r.prefix, prefix), r.tags) +} + +// ComposeName builds FQ name with appropriate separator. +func (r Registry) ComposeName(parts ...string) string { + return registryutil.BuildFQName("_", parts...) +} + +func (r Registry) Counter(name string) metrics.Counter { + name = r.sanitizeName(name) + cnt := prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: r.prefix, + Name: name, + ConstLabels: r.tags, + }) + + if err := r.rg.Register(cnt); err != nil { + var existErr prometheus.AlreadyRegisteredError + if xerrors.As(err, &existErr) { + return &Counter{cnt: existErr.ExistingCollector.(prometheus.Counter)} + } + panic(err) + } + + return &Counter{cnt: cnt} +} + +func (r Registry) FuncCounter(name string, function func() int64) metrics.FuncCounter { + name = r.sanitizeName(name) + cnt := prometheus.NewCounterFunc(prometheus.CounterOpts{ + Namespace: r.prefix, + Name: name, + ConstLabels: r.tags, + }, func() float64 { + return float64(function()) + }) + + if err := r.rg.Register(cnt); err != nil { + panic(err) + } + + return &FuncCounter{ + cnt: cnt, + function: function, + } +} + +func (r Registry) Gauge(name string) metrics.Gauge { + name = r.sanitizeName(name) + gg := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: r.prefix, + Name: name, + ConstLabels: r.tags, + }) + + if err := r.rg.Register(gg); err != nil { + var existErr prometheus.AlreadyRegisteredError + if xerrors.As(err, &existErr) { + return &Gauge{gg: existErr.ExistingCollector.(prometheus.Gauge)} + } + panic(err) + } + + return &Gauge{gg: gg} +} + +func (r Registry) FuncGauge(name string, function func() float64) metrics.FuncGauge { + name = r.sanitizeName(name) + ff := prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Namespace: r.prefix, + Name: name, + ConstLabels: r.tags, + }, function) + if err := r.rg.Register(ff); err != nil { + panic(err) + } + return &FuncGauge{ + ff: ff, + function: function, + } +} + +func (r Registry) IntGauge(name string) metrics.IntGauge { + return &IntGauge{Gauge: r.Gauge(name)} +} + +func (r Registry) FuncIntGauge(name string, function func() int64) metrics.FuncIntGauge { + name = r.sanitizeName(name) + ff := prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Namespace: r.prefix, + Name: name, + ConstLabels: r.tags, + }, func() float64 { return float64(function()) }) + if err := r.rg.Register(ff); err != nil { + panic(err) + } + return &FuncIntGauge{ + ff: ff, + function: function, + } +} + +func (r Registry) Timer(name string) metrics.Timer { + name = r.sanitizeName(name) + gg := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: r.prefix, + Name: name, + ConstLabels: r.tags, + }) + + if err := r.rg.Register(gg); err != nil { + var existErr prometheus.AlreadyRegisteredError + if xerrors.As(err, &existErr) { + return &Timer{gg: existErr.ExistingCollector.(prometheus.Gauge)} + } + panic(err) + } + + return &Timer{gg: gg} +} + +func (r Registry) Histogram(name string, buckets metrics.Buckets) metrics.Histogram { + name = r.sanitizeName(name) + hm := prometheus.NewHistogram(prometheus.HistogramOpts{ + Namespace: r.prefix, + Name: name, + ConstLabels: r.tags, + Buckets: metricsutil.BucketsBounds(buckets), + }) + + if err := r.rg.Register(hm); err != nil { + var existErr prometheus.AlreadyRegisteredError + if xerrors.As(err, &existErr) { + return &Histogram{hm: existErr.ExistingCollector.(prometheus.Observer)} + } + panic(err) + } + + return &Histogram{hm: hm} +} + +func (r Registry) DurationHistogram(name string, buckets metrics.DurationBuckets) metrics.Timer { + name = r.sanitizeName(name) + hm := prometheus.NewHistogram(prometheus.HistogramOpts{ + Namespace: r.prefix, + Name: name, + ConstLabels: r.tags, + Buckets: metricsutil.DurationBucketsBounds(buckets), + }) + + if err := r.rg.Register(hm); err != nil { + var existErr prometheus.AlreadyRegisteredError + if xerrors.As(err, &existErr) { + return &Histogram{hm: existErr.ExistingCollector.(prometheus.Histogram)} + } + panic(err) + } + + return &Histogram{hm: hm} +} + +// Gather returns raw collected Prometheus metrics. +func (r Registry) Gather() ([]*dto.MetricFamily, error) { + return r.rg.Gather() +} + +func (r *Registry) newSubregistry(prefix string, tags map[string]string) *Registry { + registryKey := registryutil.BuildRegistryKey(prefix, tags) + + r.m.Lock() + defer r.m.Unlock() + + if old, ok := r.subregistries[registryKey]; ok { + return old + } + + subregistry := &Registry{ + rg: r.rg, + m: r.m, + subregistries: r.subregistries, + tags: tags, + prefix: prefix, + nameSanitizer: r.nameSanitizer, + } + + r.subregistries[registryKey] = subregistry + return subregistry +} + +func (r *Registry) sanitizeName(name string) string { + if r.nameSanitizer == nil { + return name + } + return r.nameSanitizer(name) +} diff --git a/library/go/core/metrics/prometheus/registry_opts.go b/library/go/core/metrics/prometheus/registry_opts.go new file mode 100644 index 0000000000..fedb019d85 --- /dev/null +++ b/library/go/core/metrics/prometheus/registry_opts.go @@ -0,0 +1,84 @@ +package prometheus + +import ( + "context" + + "github.com/prometheus/client_golang/prometheus" + "github.com/ydb-platform/ydb/library/go/core/metrics" + "github.com/ydb-platform/ydb/library/go/core/metrics/collect" + "github.com/ydb-platform/ydb/library/go/core/metrics/internal/pkg/registryutil" +) + +type RegistryOpts struct { + Prefix string + Tags map[string]string + rg *prometheus.Registry + Collectors []func(metrics.Registry) + NameSanitizer func(string) string +} + +// NewRegistryOpts returns new initialized instance of RegistryOpts. +func NewRegistryOpts() *RegistryOpts { + return &RegistryOpts{ + Tags: make(map[string]string), + } +} + +// SetTags overrides existing tags. +func (o *RegistryOpts) SetTags(tags map[string]string) *RegistryOpts { + o.Tags = tags + return o +} + +// AddTags merges given tags with existing. +func (o *RegistryOpts) AddTags(tags map[string]string) *RegistryOpts { + for k, v := range tags { + o.Tags[k] = v + } + return o +} + +// SetPrefix overrides existing prefix. +func (o *RegistryOpts) SetPrefix(prefix string) *RegistryOpts { + o.Prefix = prefix + return o +} + +// AppendPrefix adds given prefix as postfix to existing using separator. +func (o *RegistryOpts) AppendPrefix(prefix string) *RegistryOpts { + o.Prefix = registryutil.BuildFQName("_", o.Prefix, prefix) + return o +} + +// SetRegistry sets the given prometheus registry for further usage instead +// of creating a new one. +// +// This is primarily used to unite externally defined metrics with metrics kept +// in the core registry. +func (o *RegistryOpts) SetRegistry(rg *prometheus.Registry) *RegistryOpts { + o.rg = rg + return o +} + +// AddCollectors adds collectors that handle their metrics automatically (e.g. system metrics). +func (o *RegistryOpts) AddCollectors( + ctx context.Context, c metrics.CollectPolicy, collectors ...collect.Func, +) *RegistryOpts { + if len(collectors) == 0 { + return o + } + + o.Collectors = append(o.Collectors, func(r metrics.Registry) { + for _, collector := range collectors { + collector(ctx, r, c) + } + }) + return o +} + +// SetNameSanitizer sets a functions which will be called for each metric's name. +// It allows to alter names, for example to replace invalid characters +func (o *RegistryOpts) SetNameSanitizer(v func(string) string) *RegistryOpts { + o.NameSanitizer = v + return o +} diff --git a/library/go/core/metrics/prometheus/registry_test.go b/library/go/core/metrics/prometheus/registry_test.go new file mode 100644 index 0000000000..73d071a8de --- /dev/null +++ b/library/go/core/metrics/prometheus/registry_test.go @@ -0,0 +1,481 @@ +package prometheus + +import ( + "strings" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "github.com/stretchr/testify/assert" + "github.com/ydb-platform/ydb/library/go/core/metrics" + "github.com/ydb-platform/ydb/library/go/ptr" + "github.com/ydb-platform/ydb/library/go/test/assertpb" + "google.golang.org/protobuf/testing/protocmp" +) + +func TestNewRegistry(t *testing.T) { + expected := &Registry{ + rg: prometheus.NewRegistry(), + m: new(sync.Mutex), + subregistries: make(map[string]*Registry), + tags: map[string]string{}, + prefix: "", + } + + r := NewRegistry(nil) + assert.IsType(t, expected, r) + assert.Equal(t, expected, r) +} + +func TestRegistry_Subregisters(t *testing.T) { + r := NewRegistry(nil) + sr1 := r.WithPrefix("subregister1"). + WithTags(map[string]string{"ololo": "trololo"}) + sr2 := sr1.WithPrefix("subregister2"). + WithTags(map[string]string{"shimba": "boomba"}) + + // check global subregistries map + expectedMap := map[string]*Registry{ + "\"subregister1\"{}": { + rg: r.rg, + m: r.m, + subregistries: r.subregistries, + prefix: "subregister1", + tags: make(map[string]string), + }, + "\"subregister1\"{ololo=trololo}": { + rg: r.rg, + m: r.m, + subregistries: r.subregistries, + tags: map[string]string{"ololo": "trololo"}, + prefix: "subregister1", + }, + "\"subregister1_subregister2\"{ololo=trololo}": { + rg: r.rg, + m: r.m, + subregistries: r.subregistries, + tags: map[string]string{"ololo": "trololo"}, + prefix: "subregister1_subregister2", + }, + "\"subregister1_subregister2\"{ololo=trololo,shimba=boomba}": { + rg: r.rg, + m: r.m, + subregistries: r.subregistries, + tags: map[string]string{"ololo": "trololo", "shimba": "boomba"}, + prefix: "subregister1_subregister2", + }, + } + + assert.EqualValues(t, expectedMap, r.subregistries) + + // top-register write + rCnt := r.Counter("subregisters_count") + rCnt.Add(2) + + // sub-register write + srTm := sr1.Timer("mytimer") + srTm.RecordDuration(42 * time.Second) + + // sub-sub-register write + srHm := sr2.Histogram("myhistogram", metrics.NewBuckets(1, 2, 3)) + srHm.RecordValue(1.5) + + mr, err := r.Gather() + assert.NoError(t, err) + + assert.IsType(t, mr, []*dto.MetricFamily{}) + + expected := []*dto.MetricFamily{ + { + Name: ptr.String("subregister1_mytimer"), + Help: ptr.String(""), + Type: func(mt dto.MetricType) *dto.MetricType { return &mt }(dto.MetricType_GAUGE), + Metric: []*dto.Metric{ + { + Label: []*dto.LabelPair{ + {Name: ptr.String("ololo"), Value: ptr.String("trololo")}, + }, + Gauge: &dto.Gauge{Value: ptr.Float64(42)}, + }, + }, + }, + { + Name: ptr.String("subregister1_subregister2_myhistogram"), + Help: ptr.String(""), + Type: func(mt dto.MetricType) *dto.MetricType { return &mt }(dto.MetricType_HISTOGRAM), + Metric: []*dto.Metric{ + { + Label: []*dto.LabelPair{ + {Name: ptr.String("ololo"), Value: ptr.String("trololo")}, + {Name: ptr.String("shimba"), Value: ptr.String("boomba")}, + }, + Histogram: &dto.Histogram{ + SampleCount: ptr.Uint64(1), + SampleSum: ptr.Float64(1.5), + Bucket: []*dto.Bucket{ + {CumulativeCount: ptr.Uint64(0), UpperBound: ptr.Float64(1)}, + {CumulativeCount: ptr.Uint64(1), UpperBound: ptr.Float64(2)}, + {CumulativeCount: ptr.Uint64(1), UpperBound: ptr.Float64(3)}, + }, + }, + }, + }, + }, + { + Name: ptr.String("subregisters_count"), + Help: ptr.String(""), + Type: func(mt dto.MetricType) *dto.MetricType { return &mt }(dto.MetricType_COUNTER), + Metric: []*dto.Metric{ + { + Label: []*dto.LabelPair{}, + Counter: &dto.Counter{Value: ptr.Float64(2)}, + }, + }, + }, + } + + cmpOpts := []cmp.Option{ + protocmp.Transform(), + } + assert.True(t, cmp.Equal(expected, mr, cmpOpts...), cmp.Diff(expected, mr, cmpOpts...)) +} + +func TestRegistry_Counter(t *testing.T) { + r := NewRegistry(nil) + sr := r.WithPrefix("myprefix"). + WithTags(map[string]string{"ololo": "trololo"}) + + // must panic on empty name + assert.Panics(t, func() { r.Counter("") }) + + srCnt := sr.Counter("mycounter") + srCnt.Add(42) + + mr, err := r.Gather() + assert.NoError(t, err) + + assert.IsType(t, mr, []*dto.MetricFamily{}) + + expected := []*dto.MetricFamily{ + { + Name: ptr.String("myprefix_mycounter"), + Help: ptr.String(""), + Type: func(mt dto.MetricType) *dto.MetricType { return &mt }(dto.MetricType_COUNTER), + Metric: []*dto.Metric{ + { + Label: []*dto.LabelPair{ + {Name: ptr.String("ololo"), Value: ptr.String("trololo")}, + }, + Counter: &dto.Counter{Value: ptr.Float64(42)}, + }, + }, + }, + } + cmpOpts := []cmp.Option{ + protocmp.Transform(), + } + assert.True(t, cmp.Equal(expected, mr, cmpOpts...), cmp.Diff(expected, mr, cmpOpts...)) +} + +func TestRegistry_DurationHistogram(t *testing.T) { + r := NewRegistry(nil) + sr := r.WithPrefix("myprefix"). + WithTags(map[string]string{"ololo": "trololo"}) + + // must panic on empty name + assert.Panics(t, func() { r.DurationHistogram("", nil) }) + + cnt := sr.DurationHistogram("myhistogram", metrics.NewDurationBuckets( + 1*time.Second, 3*time.Second, 5*time.Second, + )) + + cnt.RecordDuration(2 * time.Second) + cnt.RecordDuration(4 * time.Second) + + mr, err := r.Gather() + assert.NoError(t, err) + + assert.IsType(t, mr, []*dto.MetricFamily{}) + + expected := []*dto.MetricFamily{ + { + Name: ptr.String("myprefix_myhistogram"), + Help: ptr.String(""), + Type: func(mt dto.MetricType) *dto.MetricType { return &mt }(dto.MetricType_HISTOGRAM), + Metric: []*dto.Metric{ + { + Label: []*dto.LabelPair{{Name: ptr.String("ololo"), Value: ptr.String("trololo")}}, + Histogram: &dto.Histogram{ + SampleCount: ptr.Uint64(2), + SampleSum: ptr.Float64(6), + Bucket: []*dto.Bucket{ + {CumulativeCount: ptr.Uint64(0), UpperBound: ptr.Float64(1)}, + {CumulativeCount: ptr.Uint64(1), UpperBound: ptr.Float64(3)}, + {CumulativeCount: ptr.Uint64(2), UpperBound: ptr.Float64(5)}, + }, + }, + }, + }, + }, + } + assertpb.Equal(t, expected, mr) +} + +func TestRegistry_Gauge(t *testing.T) { + r := NewRegistry(nil) + sr := r.WithPrefix("myprefix"). + WithTags(map[string]string{"ololo": "trololo"}) + + // must panic on empty name + assert.Panics(t, func() { r.Gauge("") }) + + cnt := sr.Gauge("mygauge") + cnt.Add(42) + + mr, err := r.Gather() + assert.NoError(t, err) + + assert.IsType(t, mr, []*dto.MetricFamily{}) + + expected := []*dto.MetricFamily{ + { + Name: ptr.String("myprefix_mygauge"), + Help: ptr.String(""), + Type: func(mt dto.MetricType) *dto.MetricType { return &mt }(dto.MetricType_GAUGE), + Metric: []*dto.Metric{ + { + Label: []*dto.LabelPair{{Name: ptr.String("ololo"), Value: ptr.String("trololo")}}, + Gauge: &dto.Gauge{Value: ptr.Float64(42)}, + }, + }, + }, + } + assertpb.Equal(t, expected, mr) +} + +func TestRegistry_Histogram(t *testing.T) { + r := NewRegistry(nil) + sr := r.WithPrefix("myprefix"). + WithTags(map[string]string{"ololo": "trololo"}) + + // must panic on empty name + assert.Panics(t, func() { r.Histogram("", nil) }) + + cnt := sr.Histogram("myhistogram", metrics.NewBuckets(1, 3, 5)) + + cnt.RecordValue(2) + cnt.RecordValue(4) + + mr, err := r.Gather() + assert.NoError(t, err) + + assert.IsType(t, mr, []*dto.MetricFamily{}) + + expected := []*dto.MetricFamily{ + { + Name: ptr.String("myprefix_myhistogram"), + Help: ptr.String(""), + Type: func(mt dto.MetricType) *dto.MetricType { return &mt }(dto.MetricType_HISTOGRAM), + Metric: []*dto.Metric{ + { + Label: []*dto.LabelPair{{Name: ptr.String("ololo"), Value: ptr.String("trololo")}}, + Histogram: &dto.Histogram{ + SampleCount: ptr.Uint64(2), + SampleSum: ptr.Float64(6), + Bucket: []*dto.Bucket{ + {CumulativeCount: ptr.Uint64(0), UpperBound: ptr.Float64(1)}, + {CumulativeCount: ptr.Uint64(1), UpperBound: ptr.Float64(3)}, + {CumulativeCount: ptr.Uint64(2), UpperBound: ptr.Float64(5)}, + }, + }, + }, + }, + }, + } + assertpb.Equal(t, expected, mr) +} + +func TestRegistry_Timer(t *testing.T) { + r := NewRegistry(nil) + sr := r.WithPrefix("myprefix"). + WithTags(map[string]string{"ololo": "trololo"}) + + // must panic on empty name + assert.Panics(t, func() { r.Timer("") }) + + cnt := sr.Timer("mytimer") + cnt.RecordDuration(42 * time.Second) + + mr, err := r.Gather() + assert.NoError(t, err) + + assert.IsType(t, mr, []*dto.MetricFamily{}) + + expected := []*dto.MetricFamily{ + { + Name: ptr.String("myprefix_mytimer"), + Help: ptr.String(""), + Type: func(mt dto.MetricType) *dto.MetricType { return &mt }(dto.MetricType_GAUGE), + Metric: []*dto.Metric{ + { + Label: []*dto.LabelPair{{Name: ptr.String("ololo"), Value: ptr.String("trololo")}}, + Gauge: &dto.Gauge{Value: ptr.Float64(42)}, + }, + }, + }, + } + assertpb.Equal(t, expected, mr) +} + +func TestRegistry_WithPrefix(t *testing.T) { + testCases := []struct { + r metrics.Registry + expected string + }{ + { + r: func() metrics.Registry { + return NewRegistry(nil) + }(), + expected: "", + }, + { + r: func() metrics.Registry { + return NewRegistry(nil).WithPrefix("myprefix") + }(), + expected: "myprefix", + }, + { + r: func() metrics.Registry { + return NewRegistry(nil).WithPrefix("__myprefix_") + }(), + expected: "myprefix", + }, + { + r: func() metrics.Registry { + return NewRegistry(nil).WithPrefix("__myprefix_").WithPrefix("mysubprefix______") + }(), + expected: "myprefix_mysubprefix", + }, + } + + for _, tc := range testCases { + t.Run("", func(t *testing.T) { + assert.Equal(t, tc.expected, tc.r.(*Registry).prefix) + }) + } +} + +func TestRegistry_WithTags(t *testing.T) { + testCases := []struct { + r metrics.Registry + expected map[string]string + }{ + { + r: func() metrics.Registry { + return NewRegistry(nil) + }(), + expected: map[string]string{}, + }, + { + r: func() metrics.Registry { + return NewRegistry(nil).WithTags(map[string]string{"shimba": "boomba"}) + }(), + expected: map[string]string{"shimba": "boomba"}, + }, + { + r: func() metrics.Registry { + return NewRegistry(nil). + WithTags(map[string]string{"shimba": "boomba"}). + WithTags(map[string]string{"looken": "tooken"}) + }(), + expected: map[string]string{ + "shimba": "boomba", + "looken": "tooken", + }, + }, + { + r: func() metrics.Registry { + return NewRegistry(nil). + WithTags(map[string]string{"shimba": "boomba"}). + WithTags(map[string]string{"looken": "tooken"}). + WithTags(map[string]string{"shimba": "cooken"}) + }(), + expected: map[string]string{ + "shimba": "cooken", + "looken": "tooken", + }, + }, + } + + for _, tc := range testCases { + t.Run("", func(t *testing.T) { + assert.Equal(t, tc.expected, tc.r.(*Registry).tags) + }) + } +} + +func TestRegistry_WithTags_NoPanic(t *testing.T) { + _ = NewRegistry(nil).WithTags(map[string]string{"foo": "bar"}) + _ = NewRegistry(nil).WithTags(map[string]string{"foo": "bar"}) +} + +func TestRegistry_Counter_NoPanic(t *testing.T) { + r := NewRegistry(nil) + sr := r.WithPrefix("myprefix"). + WithTags(map[string]string{"ololo": "trololo"}) + cntrRaz := sr.Counter("mycounter").(*Counter) + cntrDvaz := sr.Counter("mycounter").(*Counter) + assert.Equal(t, cntrRaz.cnt, cntrDvaz.cnt) + cntrRaz.Add(100) + cntrDvaz.Add(100) + mr, err := r.Gather() + assert.NoError(t, err) + expected := []*dto.MetricFamily{ + { + Name: ptr.String("myprefix_mycounter"), + Help: ptr.String(""), + Type: func(mt dto.MetricType) *dto.MetricType { return &mt }(dto.MetricType_COUNTER), + Metric: []*dto.Metric{ + { + Label: []*dto.LabelPair{{Name: ptr.String("ololo"), Value: ptr.String("trololo")}}, + Counter: &dto.Counter{Value: ptr.Float64(200)}, + }, + }, + }, + } + assertpb.Equal(t, expected, mr) +} + +func TestRegistry_NameSanitizer(t *testing.T) { + testCases := []struct { + opts *RegistryOpts + name string + want string + }{ + { + opts: nil, + name: "some_name", + want: "some_name", + }, + { + opts: NewRegistryOpts().SetNameSanitizer(func(s string) string { + return strings.ReplaceAll(s, "/", "_") + }), + name: "other/name", + want: "other_name", + }, + } + + for _, tc := range testCases { + r := NewRegistry(tc.opts) + _ = r.Counter(tc.name) + mfs, err := r.Gather() + assert.NoError(t, err) + assert.NotEmpty(t, mfs) + + assert.Equal(t, tc.want, *mfs[0].Name) + } +} diff --git a/library/go/core/metrics/prometheus/timer.go b/library/go/core/metrics/prometheus/timer.go new file mode 100644 index 0000000000..3350e5a61d --- /dev/null +++ b/library/go/core/metrics/prometheus/timer.go @@ -0,0 +1,19 @@ +package prometheus + +import ( + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/ydb-platform/ydb/library/go/core/metrics" +) + +var _ metrics.Timer = (*Timer)(nil) + +// Timer measures gauge duration. +type Timer struct { + gg prometheus.Gauge +} + +func (t Timer) RecordDuration(value time.Duration) { + t.gg.Set(value.Seconds()) +} diff --git a/library/go/core/metrics/prometheus/timer_test.go b/library/go/core/metrics/prometheus/timer_test.go new file mode 100644 index 0000000000..a520b6f477 --- /dev/null +++ b/library/go/core/metrics/prometheus/timer_test.go @@ -0,0 +1,24 @@ +package prometheus + +import ( + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "github.com/stretchr/testify/assert" +) + +func TestTimer_RecordDuration(t *testing.T) { + g := &Timer{gg: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "test_timer_record_duration", + })} + + g.RecordDuration(42 * time.Second) + + var res dto.Metric + err := g.gg.Write(&res) + + assert.NoError(t, err) + assert.Equal(t, float64(42), res.GetGauge().GetValue()) +} diff --git a/library/go/core/metrics/prometheus/vec.go b/library/go/core/metrics/prometheus/vec.go new file mode 100644 index 0000000000..731c7b752a --- /dev/null +++ b/library/go/core/metrics/prometheus/vec.go @@ -0,0 +1,248 @@ +package prometheus + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/ydb-platform/ydb/library/go/core/metrics" + "github.com/ydb-platform/ydb/library/go/core/metrics/internal/pkg/metricsutil" + "github.com/ydb-platform/ydb/library/go/core/xerrors" +) + +var _ metrics.CounterVec = (*CounterVec)(nil) + +// CounterVec wraps prometheus.CounterVec +// and implements metrics.CounterVec interface. +type CounterVec struct { + vec *prometheus.CounterVec +} + +// CounterVec creates a new counters vector with given metric name and +// partitioned by the given label names. +func (r *Registry) CounterVec(name string, labels []string) metrics.CounterVec { + name = r.sanitizeName(name) + vec := prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: r.prefix, + Name: name, + ConstLabels: r.tags, + }, labels) + + if err := r.rg.Register(vec); err != nil { + var existErr prometheus.AlreadyRegisteredError + if xerrors.As(err, &existErr) { + return &CounterVec{vec: existErr.ExistingCollector.(*prometheus.CounterVec)} + } + panic(err) + } + + return &CounterVec{vec: vec} +} + +// With creates new or returns existing counter with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *CounterVec) With(tags map[string]string) metrics.Counter { + return &Counter{cnt: v.vec.With(tags)} +} + +// Reset deletes all metrics in this vector. +func (v *CounterVec) Reset() { + v.vec.Reset() +} + +var _ metrics.GaugeVec = (*GaugeVec)(nil) + +// GaugeVec wraps prometheus.GaugeVec +// and implements metrics.GaugeVec interface. +type GaugeVec struct { + vec *prometheus.GaugeVec +} + +// GaugeVec creates a new gauges vector with given metric name and +// partitioned by the given label names. +func (r *Registry) GaugeVec(name string, labels []string) metrics.GaugeVec { + name = r.sanitizeName(name) + vec := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: r.prefix, + Name: name, + ConstLabels: r.tags, + }, labels) + + if err := r.rg.Register(vec); err != nil { + var existErr prometheus.AlreadyRegisteredError + if xerrors.As(err, &existErr) { + return &GaugeVec{vec: existErr.ExistingCollector.(*prometheus.GaugeVec)} + } + panic(err) + } + + return &GaugeVec{vec: vec} +} + +// With creates new or returns existing gauge with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *GaugeVec) With(tags map[string]string) metrics.Gauge { + return &Gauge{gg: v.vec.With(tags)} +} + +// Reset deletes all metrics in this vector. +func (v *GaugeVec) Reset() { + v.vec.Reset() +} + +// IntGaugeVec wraps prometheus.GaugeVec +// and implements metrics.IntGaugeVec interface. +type IntGaugeVec struct { + vec *prometheus.GaugeVec +} + +// IntGaugeVec creates a new gauges vector with given metric name and +// partitioned by the given label names. +func (r *Registry) IntGaugeVec(name string, labels []string) metrics.IntGaugeVec { + name = r.sanitizeName(name) + vec := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: r.prefix, + Name: name, + ConstLabels: r.tags, + }, labels) + + if err := r.rg.Register(vec); err != nil { + var existErr prometheus.AlreadyRegisteredError + if xerrors.As(err, &existErr) { + return &IntGaugeVec{vec: existErr.ExistingCollector.(*prometheus.GaugeVec)} + } + panic(err) + } + + return &IntGaugeVec{vec: vec} +} + +// With creates new or returns existing gauge with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *IntGaugeVec) With(tags map[string]string) metrics.IntGauge { + return &IntGauge{Gauge{gg: v.vec.With(tags)}} +} + +// Reset deletes all metrics in this vector. +func (v *IntGaugeVec) Reset() { + v.vec.Reset() +} + +var _ metrics.TimerVec = (*TimerVec)(nil) + +// TimerVec wraps prometheus.GaugeVec +// and implements metrics.TimerVec interface. +type TimerVec struct { + vec *prometheus.GaugeVec +} + +// TimerVec creates a new timers vector with given metric name and +// partitioned by the given label names. +func (r *Registry) TimerVec(name string, labels []string) metrics.TimerVec { + name = r.sanitizeName(name) + vec := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: r.prefix, + Name: name, + ConstLabels: r.tags, + }, labels) + + if err := r.rg.Register(vec); err != nil { + var existErr prometheus.AlreadyRegisteredError + if xerrors.As(err, &existErr) { + return &TimerVec{vec: existErr.ExistingCollector.(*prometheus.GaugeVec)} + } + panic(err) + } + + return &TimerVec{vec: vec} +} + +// With creates new or returns existing timer with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *TimerVec) With(tags map[string]string) metrics.Timer { + return &Timer{gg: v.vec.With(tags)} +} + +// Reset deletes all metrics in this vector. +func (v *TimerVec) Reset() { + v.vec.Reset() +} + +var _ metrics.HistogramVec = (*HistogramVec)(nil) + +// HistogramVec wraps prometheus.HistogramVec +// and implements metrics.HistogramVec interface. +type HistogramVec struct { + vec *prometheus.HistogramVec +} + +// HistogramVec creates a new histograms vector with given metric name and buckets and +// partitioned by the given label names. +func (r *Registry) HistogramVec(name string, buckets metrics.Buckets, labels []string) metrics.HistogramVec { + name = r.sanitizeName(name) + vec := prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: r.prefix, + Name: name, + ConstLabels: r.tags, + Buckets: metricsutil.BucketsBounds(buckets), + }, labels) + + if err := r.rg.Register(vec); err != nil { + var existErr prometheus.AlreadyRegisteredError + if xerrors.As(err, &existErr) { + return &HistogramVec{vec: existErr.ExistingCollector.(*prometheus.HistogramVec)} + } + panic(err) + } + + return &HistogramVec{vec: vec} +} + +// With creates new or returns existing histogram with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *HistogramVec) With(tags map[string]string) metrics.Histogram { + return &Histogram{hm: v.vec.With(tags)} +} + +// Reset deletes all metrics in this vector. +func (v *HistogramVec) Reset() { + v.vec.Reset() +} + +var _ metrics.TimerVec = (*DurationHistogramVec)(nil) + +// DurationHistogramVec wraps prometheus.HistogramVec +// and implements metrics.TimerVec interface. +type DurationHistogramVec struct { + vec *prometheus.HistogramVec +} + +// DurationHistogramVec creates a new duration histograms vector with given metric name and buckets and +// partitioned by the given label names. +func (r *Registry) DurationHistogramVec(name string, buckets metrics.DurationBuckets, labels []string) metrics.TimerVec { + name = r.sanitizeName(name) + vec := prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: r.prefix, + Name: name, + ConstLabels: r.tags, + Buckets: metricsutil.DurationBucketsBounds(buckets), + }, labels) + + if err := r.rg.Register(vec); err != nil { + var existErr prometheus.AlreadyRegisteredError + if xerrors.As(err, &existErr) { + return &DurationHistogramVec{vec: existErr.ExistingCollector.(*prometheus.HistogramVec)} + } + panic(err) + } + + return &DurationHistogramVec{vec: vec} +} + +// With creates new or returns existing duration histogram with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *DurationHistogramVec) With(tags map[string]string) metrics.Timer { + return &Histogram{hm: v.vec.With(tags)} +} + +// Reset deletes all metrics in this vector. +func (v *DurationHistogramVec) Reset() { + v.vec.Reset() +} diff --git a/library/go/core/metrics/prometheus/vec_test.go b/library/go/core/metrics/prometheus/vec_test.go new file mode 100644 index 0000000000..ccf088c17a --- /dev/null +++ b/library/go/core/metrics/prometheus/vec_test.go @@ -0,0 +1,137 @@ +package prometheus + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/ydb-platform/ydb/library/go/core/metrics" +) + +func TestCounterVec(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + vec := rg.CounterVec("ololo", []string{"shimba", "looken"}) + mt := vec.With(map[string]string{ + "shimba": "boomba", + "looken": "tooken", + }) + + assert.IsType(t, &CounterVec{}, vec) + assert.IsType(t, &Counter{}, mt) + + vec.Reset() + + metrics, err := rg.Gather() + assert.NoError(t, err) + assert.Empty(t, metrics) +} + +func TestCounterVec_RegisterAgain(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + vec1 := rg.CounterVec("ololo", []string{"shimba", "looken"}).(*CounterVec) + vec2 := rg.CounterVec("ololo", []string{"shimba", "looken"}).(*CounterVec) + assert.Same(t, vec1.vec, vec2.vec) +} + +func TestGaugeVec(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + vec := rg.GaugeVec("ololo", []string{"shimba", "looken"}) + mt := vec.With(map[string]string{ + "shimba": "boomba", + "looken": "tooken", + }) + + assert.IsType(t, &GaugeVec{}, vec) + assert.IsType(t, &Gauge{}, mt) + + vec.Reset() + + metrics, err := rg.Gather() + assert.NoError(t, err) + assert.Empty(t, metrics) +} + +func TestGaugeVec_RegisterAgain(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + vec1 := rg.GaugeVec("ololo", []string{"shimba", "looken"}).(*GaugeVec) + vec2 := rg.GaugeVec("ololo", []string{"shimba", "looken"}).(*GaugeVec) + assert.Same(t, vec1.vec, vec2.vec) +} + +func TestTimerVec(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + vec := rg.TimerVec("ololo", []string{"shimba", "looken"}) + mt := vec.With(map[string]string{ + "shimba": "boomba", + "looken": "tooken", + }) + + assert.IsType(t, &TimerVec{}, vec) + assert.IsType(t, &Timer{}, mt) + + vec.Reset() + + metrics, err := rg.Gather() + assert.NoError(t, err) + assert.Empty(t, metrics) +} + +func TestTimerVec_RegisterAgain(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + vec1 := rg.TimerVec("ololo", []string{"shimba", "looken"}).(*TimerVec) + vec2 := rg.TimerVec("ololo", []string{"shimba", "looken"}).(*TimerVec) + assert.Same(t, vec1.vec, vec2.vec) +} + +func TestHistogramVec(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + buckets := metrics.NewBuckets(1, 2, 3) + vec := rg.HistogramVec("ololo", buckets, []string{"shimba", "looken"}) + mt := vec.With(map[string]string{ + "shimba": "boomba", + "looken": "tooken", + }) + + assert.IsType(t, &HistogramVec{}, vec) + assert.IsType(t, &Histogram{}, mt) + + vec.Reset() + + metrics, err := rg.Gather() + assert.NoError(t, err) + assert.Empty(t, metrics) +} + +func TestHistogramVec_RegisterAgain(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + buckets := metrics.NewBuckets(1, 2, 3) + vec1 := rg.HistogramVec("ololo", buckets, []string{"shimba", "looken"}).(*HistogramVec) + vec2 := rg.HistogramVec("ololo", buckets, []string{"shimba", "looken"}).(*HistogramVec) + assert.Same(t, vec1.vec, vec2.vec) +} + +func TestDurationHistogramVec(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + buckets := metrics.NewDurationBuckets(1, 2, 3) + vec := rg.DurationHistogramVec("ololo", buckets, []string{"shimba", "looken"}) + mt := vec.With(map[string]string{ + "shimba": "boomba", + "looken": "tooken", + }) + + assert.IsType(t, &DurationHistogramVec{}, vec) + assert.IsType(t, &Histogram{}, mt) + + vec.Reset() + + metrics, err := rg.Gather() + assert.NoError(t, err) + assert.Empty(t, metrics) +} + +func TestDurationHistogramVec_RegisterAgain(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + buckets := metrics.NewDurationBuckets(1, 2, 3) + vec1 := rg.DurationHistogramVec("ololo", buckets, []string{"shimba", "looken"}).(*DurationHistogramVec) + vec2 := rg.DurationHistogramVec("ololo", buckets, []string{"shimba", "looken"}).(*DurationHistogramVec) + assert.Same(t, vec1.vec, vec2.vec) +} diff --git a/library/go/core/metrics/prometheus/ya.make b/library/go/core/metrics/prometheus/ya.make new file mode 100644 index 0000000000..b012835f4b --- /dev/null +++ b/library/go/core/metrics/prometheus/ya.make @@ -0,0 +1,25 @@ +GO_LIBRARY() + +SRCS( + counter.go + gauge.go + int_gauge.go + histogram.go + registry.go + registry_opts.go + timer.go + vec.go +) + +GO_TEST_SRCS( + counter_test.go + gauge_test.go + histogram_test.go + registry_test.go + timer_test.go + vec_test.go +) + +END() + +RECURSE(gotest) diff --git a/library/go/core/metrics/solomon/converter.go b/library/go/core/metrics/solomon/converter.go new file mode 100644 index 0000000000..6976b223ba --- /dev/null +++ b/library/go/core/metrics/solomon/converter.go @@ -0,0 +1,114 @@ +package solomon + +import ( + "fmt" + + dto "github.com/prometheus/client_model/go" + "go.uber.org/atomic" +) + +// PrometheusMetrics converts Prometheus metrics to Solomon metrics. +func PrometheusMetrics(metrics []*dto.MetricFamily) (*Metrics, error) { + s := &Metrics{ + metrics: make([]Metric, 0, len(metrics)), + } + + if len(metrics) == 0 { + return s, nil + } + + for _, mf := range metrics { + if len(mf.Metric) == 0 { + continue + } + + for _, metric := range mf.Metric { + + tags := make(map[string]string, len(metric.Label)) + for _, label := range metric.Label { + tags[label.GetName()] = label.GetValue() + } + + switch *mf.Type { + case dto.MetricType_COUNTER: + s.metrics = append(s.metrics, &Counter{ + name: mf.GetName(), + metricType: typeCounter, + tags: tags, + value: *atomic.NewInt64(int64(metric.Counter.GetValue())), + }) + case dto.MetricType_GAUGE: + s.metrics = append(s.metrics, &Gauge{ + name: mf.GetName(), + metricType: typeGauge, + tags: tags, + value: *atomic.NewFloat64(metric.Gauge.GetValue()), + }) + case dto.MetricType_HISTOGRAM: + bounds := make([]float64, 0, len(metric.Histogram.Bucket)) + values := make([]int64, 0, len(metric.Histogram.Bucket)) + + var prevValuesSum int64 + + for _, bucket := range metric.Histogram.Bucket { + // prometheus uses cumulative buckets where solomon uses instant + bucketValue := int64(bucket.GetCumulativeCount()) + bucketValue -= prevValuesSum + prevValuesSum += bucketValue + + bounds = append(bounds, bucket.GetUpperBound()) + values = append(values, bucketValue) + } + + s.metrics = append(s.metrics, &Histogram{ + name: mf.GetName(), + metricType: typeHistogram, + tags: tags, + bucketBounds: bounds, + bucketValues: values, + infValue: *atomic.NewInt64(int64(metric.Histogram.GetSampleCount()) - prevValuesSum), + }) + case dto.MetricType_SUMMARY: + bounds := make([]float64, 0, len(metric.Summary.Quantile)) + values := make([]int64, 0, len(metric.Summary.Quantile)) + + var prevValuesSum int64 + + for _, bucket := range metric.Summary.GetQuantile() { + // prometheus uses cumulative buckets where solomon uses instant + bucketValue := int64(bucket.GetValue()) + bucketValue -= prevValuesSum + prevValuesSum += bucketValue + + bounds = append(bounds, bucket.GetQuantile()) + values = append(values, bucketValue) + } + + mName := mf.GetName() + + s.metrics = append(s.metrics, &Histogram{ + name: mName, + metricType: typeHistogram, + tags: tags, + bucketBounds: bounds, + bucketValues: values, + infValue: *atomic.NewInt64(int64(*metric.Summary.SampleCount) - prevValuesSum), + }, &Counter{ + name: mName + "_count", + metricType: typeCounter, + tags: tags, + value: *atomic.NewInt64(int64(*metric.Summary.SampleCount)), + }, &Gauge{ + name: mName + "_sum", + metricType: typeGauge, + tags: tags, + value: *atomic.NewFloat64(*metric.Summary.SampleSum), + }) + default: + return nil, fmt.Errorf("unsupported type: %s", mf.Type.String()) + } + } + } + + return s, nil +} diff --git a/library/go/core/metrics/solomon/converter_test.go b/library/go/core/metrics/solomon/converter_test.go new file mode 100644 index 0000000000..5368029038 --- /dev/null +++ b/library/go/core/metrics/solomon/converter_test.go @@ -0,0 +1,200 @@ +package solomon + +import ( + "testing" + + dto "github.com/prometheus/client_model/go" + "github.com/stretchr/testify/assert" + "github.com/ydb-platform/ydb/library/go/ptr" + "go.uber.org/atomic" +) + +func TestPrometheusMetrics(t *testing.T) { + testCases := []struct { + name string + metrics []*dto.MetricFamily + expect *Metrics + expectErr error + }{ + { + name: "success", + metrics: []*dto.MetricFamily{ + { + Name: ptr.String("subregister1_mygauge"), + Help: ptr.String(""), + Type: ptr.T(dto.MetricType_GAUGE), + Metric: []*dto.Metric{ + { + Label: []*dto.LabelPair{ + {Name: ptr.String("ololo"), Value: ptr.String("trololo")}, + }, + Gauge: &dto.Gauge{Value: ptr.Float64(42)}, + }, + }, + }, + { + Name: ptr.String("subregisters_count"), + Help: ptr.String(""), + Type: ptr.T(dto.MetricType_COUNTER), + Metric: []*dto.Metric{ + { + Label: []*dto.LabelPair{}, + Counter: &dto.Counter{Value: ptr.Float64(2)}, + }, + }, + }, + { + Name: ptr.String("subregister1_subregister2_myhistogram"), + Help: ptr.String(""), + Type: ptr.T(dto.MetricType_HISTOGRAM), + Metric: []*dto.Metric{ + { + Label: []*dto.LabelPair{ + {Name: ptr.String("ololo"), Value: ptr.String("trololo")}, + {Name: ptr.String("shimba"), Value: ptr.String("boomba")}, + }, + Histogram: &dto.Histogram{ + SampleCount: ptr.Uint64(6), + SampleSum: ptr.Float64(4.2), + Bucket: []*dto.Bucket{ + {CumulativeCount: ptr.Uint64(1), UpperBound: ptr.Float64(1)}, // 0.5 written + {CumulativeCount: ptr.Uint64(3), UpperBound: ptr.Float64(2)}, // 1.5 & 1.7 written + {CumulativeCount: ptr.Uint64(4), UpperBound: ptr.Float64(3)}, // 2.2 written + }, + }, + }, + }, + }, + { + Name: ptr.String("metrics_group"), + Help: ptr.String(""), + Type: ptr.T(dto.MetricType_COUNTER), + Metric: []*dto.Metric{ + { + Label: []*dto.LabelPair{}, + Counter: &dto.Counter{Value: ptr.Float64(2)}, + }, + { + Label: []*dto.LabelPair{}, + Counter: &dto.Counter{Value: ptr.Float64(3)}, + }, + }, + }, + }, + expect: &Metrics{ + metrics: []Metric{ + &Gauge{ + name: "subregister1_mygauge", + metricType: typeGauge, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewFloat64(42), + }, + &Counter{ + name: "subregisters_count", + metricType: typeCounter, + tags: map[string]string{}, + value: *atomic.NewInt64(2), + }, + &Histogram{ + name: "subregister1_subregister2_myhistogram", + metricType: typeHistogram, + tags: map[string]string{"ololo": "trololo", "shimba": "boomba"}, + bucketBounds: []float64{1, 2, 3}, + bucketValues: []int64{1, 2, 1}, + infValue: *atomic.NewInt64(2), + }, + // group of metrics + &Counter{ + name: "metrics_group", + metricType: typeCounter, + tags: map[string]string{}, + value: *atomic.NewInt64(2), + }, + &Counter{ + name: "metrics_group", + metricType: typeCounter, + tags: map[string]string{}, + value: *atomic.NewInt64(3), + }, + }, + }, + expectErr: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s, err := PrometheusMetrics(tc.metrics) + + if tc.expectErr == nil { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, tc.expectErr.Error()) + } + + assert.Equal(t, tc.expect, s) + }) + } +} + +func TestPrometheusSummaryMetric(t *testing.T) { + src := []*dto.MetricFamily{ + { + Name: ptr.String("subregister1_subregister2_mysummary"), + Help: ptr.String(""), + Type: func(mt dto.MetricType) *dto.MetricType { return &mt }(dto.MetricType_SUMMARY), + Metric: []*dto.Metric{ + { + Label: []*dto.LabelPair{ + {Name: ptr.String("ololo"), Value: ptr.String("trololo")}, + {Name: ptr.String("shimba"), Value: ptr.String("boomba")}, + }, + Summary: &dto.Summary{ + SampleCount: ptr.Uint64(8), + SampleSum: ptr.Float64(4.2), + Quantile: []*dto.Quantile{ + {Value: ptr.Float64(1), Quantile: ptr.Float64(1)}, // 0.5 written + {Value: ptr.Float64(3), Quantile: ptr.Float64(2)}, // 1.5 & 1.7 written + {Value: ptr.Float64(4), Quantile: ptr.Float64(3)}, // 2.2 written + }, + }, + }, + }, + }, + } + + mName := "subregister1_subregister2_mysummary" + mTags := map[string]string{"ololo": "trololo", "shimba": "boomba"} + bBounds := []float64{1, 2, 3} + bValues := []int64{1, 2, 1} + + expect := &Metrics{ + metrics: []Metric{ + &Histogram{ + name: mName, + metricType: typeHistogram, + tags: mTags, + bucketBounds: bBounds, + bucketValues: bValues, + infValue: *atomic.NewInt64(4), + }, + &Counter{ + name: mName + "_count", + metricType: typeCounter, + tags: mTags, + value: *atomic.NewInt64(8), + }, + &Gauge{ + name: mName + "_sum", + metricType: typeGauge, + tags: mTags, + value: *atomic.NewFloat64(4.2), + }, + }, + } + + s, err := PrometheusMetrics(src) + assert.NoError(t, err) + + assert.Equal(t, expect, s) +} diff --git a/library/go/core/metrics/solomon/counter.go b/library/go/core/metrics/solomon/counter.go new file mode 100644 index 0000000000..e37933760c --- /dev/null +++ b/library/go/core/metrics/solomon/counter.go @@ -0,0 +1,97 @@ +package solomon + +import ( + "encoding/json" + "time" + + "github.com/ydb-platform/ydb/library/go/core/metrics" + "go.uber.org/atomic" +) + +var ( + _ metrics.Counter = (*Counter)(nil) + _ Metric = (*Counter)(nil) +) + +// Counter tracks monotonically increasing value. +type Counter struct { + name string + metricType metricType + tags map[string]string + value atomic.Int64 + timestamp *time.Time + + useNameTag bool +} + +// Inc increments counter by 1. +func (c *Counter) Inc() { + c.Add(1) +} + +// Add adds delta to the counter. Delta must be >=0. +func (c *Counter) Add(delta int64) { + c.value.Add(delta) +} + +func (c *Counter) Name() string { + return c.name +} + +func (c *Counter) getType() metricType { + return c.metricType +} + +func (c *Counter) getLabels() map[string]string { + return c.tags +} + +func (c *Counter) getValue() interface{} { + return c.value.Load() +} + +func (c *Counter) getTimestamp() *time.Time { + return c.timestamp +} + +func (c *Counter) getNameTag() string { + if c.useNameTag { + return "name" + } else { + return "sensor" + } +} + +// MarshalJSON implements json.Marshaler. +func (c *Counter) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Value int64 `json:"value"` + Timestamp *int64 `json:"ts,omitempty"` + }{ + Type: c.metricType.String(), + Value: c.value.Load(), + Labels: func() map[string]string { + labels := make(map[string]string, len(c.tags)+1) + labels[c.getNameTag()] = c.Name() + for k, v := range c.tags { + labels[k] = v + } + return labels + }(), + Timestamp: tsAsRef(c.timestamp), + }) +} + +// Snapshot returns independent copy on metric. +func (c *Counter) Snapshot() Metric { + return &Counter{ + name: c.name, + metricType: c.metricType, + tags: c.tags, + value: *atomic.NewInt64(c.value.Load()), + + useNameTag: c.useNameTag, + } +} diff --git a/library/go/core/metrics/solomon/counter_test.go b/library/go/core/metrics/solomon/counter_test.go new file mode 100644 index 0000000000..09284125d2 --- /dev/null +++ b/library/go/core/metrics/solomon/counter_test.go @@ -0,0 +1,90 @@ +package solomon + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestCounter_Add(t *testing.T) { + c := &Counter{ + name: "mycounter", + metricType: typeCounter, + tags: map[string]string{"ololo": "trololo"}, + } + + c.Add(1) + assert.Equal(t, int64(1), c.value.Load()) + + c.Add(42) + assert.Equal(t, int64(43), c.value.Load()) + + c.Add(1489) + assert.Equal(t, int64(1532), c.value.Load()) +} + +func TestCounter_Inc(t *testing.T) { + c := &Counter{ + name: "mycounter", + metricType: typeCounter, + tags: map[string]string{"ololo": "trololo"}, + } + + for i := 0; i < 10; i++ { + c.Inc() + } + assert.Equal(t, int64(10), c.value.Load()) + + c.Inc() + c.Inc() + assert.Equal(t, int64(12), c.value.Load()) +} + +func TestCounter_MarshalJSON(t *testing.T) { + c := &Counter{ + name: "mycounter", + metricType: typeCounter, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewInt64(42), + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"COUNTER","labels":{"ololo":"trololo","sensor":"mycounter"},"value":42}`) + assert.Equal(t, expected, b) +} + +func TestRatedCounter_MarshalJSON(t *testing.T) { + c := &Counter{ + name: "mycounter", + metricType: typeRated, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewInt64(42), + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"RATE","labels":{"ololo":"trololo","sensor":"mycounter"},"value":42}`) + assert.Equal(t, expected, b) +} + +func TestNameTagCounter_MarshalJSON(t *testing.T) { + c := &Counter{ + name: "mycounter", + metricType: typeCounter, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewInt64(42), + + useNameTag: true, + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"COUNTER","labels":{"name":"mycounter","ololo":"trololo"},"value":42}`) + assert.Equal(t, expected, b) +} diff --git a/library/go/core/metrics/solomon/func_counter.go b/library/go/core/metrics/solomon/func_counter.go new file mode 100644 index 0000000000..db862869e4 --- /dev/null +++ b/library/go/core/metrics/solomon/func_counter.go @@ -0,0 +1,86 @@ +package solomon + +import ( + "encoding/json" + "time" + + "go.uber.org/atomic" +) + +var _ Metric = (*FuncCounter)(nil) + +// FuncCounter tracks int64 value returned by function. +type FuncCounter struct { + name string + metricType metricType + tags map[string]string + function func() int64 + timestamp *time.Time + useNameTag bool +} + +func (c *FuncCounter) Name() string { + return c.name +} + +func (c *FuncCounter) Function() func() int64 { + return c.function +} + +func (c *FuncCounter) getType() metricType { + return c.metricType +} + +func (c *FuncCounter) getLabels() map[string]string { + return c.tags +} + +func (c *FuncCounter) getValue() interface{} { + return c.function() +} + +func (c *FuncCounter) getTimestamp() *time.Time { + return c.timestamp +} + +func (c *FuncCounter) getNameTag() string { + if c.useNameTag { + return "name" + } else { + return "sensor" + } +} + +// MarshalJSON implements json.Marshaler. +func (c *FuncCounter) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Value int64 `json:"value"` + Timestamp *int64 `json:"ts,omitempty"` + }{ + Type: c.metricType.String(), + Value: c.function(), + Labels: func() map[string]string { + labels := make(map[string]string, len(c.tags)+1) + labels[c.getNameTag()] = c.Name() + for k, v := range c.tags { + labels[k] = v + } + return labels + }(), + Timestamp: tsAsRef(c.timestamp), + }) +} + +// Snapshot returns independent copy on metric. +func (c *FuncCounter) Snapshot() Metric { + return &Counter{ + name: c.name, + metricType: c.metricType, + tags: c.tags, + value: *atomic.NewInt64(c.function()), + + useNameTag: c.useNameTag, + } +} diff --git a/library/go/core/metrics/solomon/func_counter_test.go b/library/go/core/metrics/solomon/func_counter_test.go new file mode 100644 index 0000000000..7849769d12 --- /dev/null +++ b/library/go/core/metrics/solomon/func_counter_test.go @@ -0,0 +1,82 @@ +package solomon + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestFuncCounter_Inc(t *testing.T) { + val := new(atomic.Int64) + c := &FuncCounter{ + name: "mycounter", + metricType: typeCounter, + tags: map[string]string{"ololo": "trololo"}, + function: func() int64 { + return val.Load() + }, + } + + val.Store(1) + assert.Equal(t, int64(1), c.Snapshot().(*Counter).value.Load()) + + val.Store(42) + assert.Equal(t, int64(42), c.Snapshot().(*Counter).value.Load()) + +} + +func TestFuncCounter_MarshalJSON(t *testing.T) { + c := &FuncCounter{ + name: "mycounter", + metricType: typeCounter, + tags: map[string]string{"ololo": "trololo"}, + function: func() int64 { + return 42 + }, + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"COUNTER","labels":{"ololo":"trololo","sensor":"mycounter"},"value":42}`) + assert.Equal(t, expected, b) +} + +func TestRatedFuncCounter_MarshalJSON(t *testing.T) { + c := &FuncCounter{ + name: "mycounter", + metricType: typeRated, + tags: map[string]string{"ololo": "trololo"}, + function: func() int64 { + return 42 + }, + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"RATE","labels":{"ololo":"trololo","sensor":"mycounter"},"value":42}`) + assert.Equal(t, expected, b) +} + +func TestNameTagFuncCounter_MarshalJSON(t *testing.T) { + c := &FuncCounter{ + name: "mycounter", + metricType: typeCounter, + tags: map[string]string{"ololo": "trololo"}, + + function: func() int64 { + return 42 + }, + + useNameTag: true, + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"COUNTER","labels":{"name":"mycounter","ololo":"trololo"},"value":42}`) + assert.Equal(t, expected, b) +} diff --git a/library/go/core/metrics/solomon/func_gauge.go b/library/go/core/metrics/solomon/func_gauge.go new file mode 100644 index 0000000000..ce824c6fa8 --- /dev/null +++ b/library/go/core/metrics/solomon/func_gauge.go @@ -0,0 +1,87 @@ +package solomon + +import ( + "encoding/json" + "time" + + "go.uber.org/atomic" +) + +var _ Metric = (*FuncGauge)(nil) + +// FuncGauge tracks float64 value returned by function. +type FuncGauge struct { + name string + metricType metricType + tags map[string]string + function func() float64 + timestamp *time.Time + + useNameTag bool +} + +func (g *FuncGauge) Name() string { + return g.name +} + +func (g *FuncGauge) Function() func() float64 { + return g.function +} + +func (g *FuncGauge) getType() metricType { + return g.metricType +} + +func (g *FuncGauge) getLabels() map[string]string { + return g.tags +} + +func (g *FuncGauge) getValue() interface{} { + return g.function() +} + +func (g *FuncGauge) getTimestamp() *time.Time { + return g.timestamp +} + +func (g *FuncGauge) getNameTag() string { + if g.useNameTag { + return "name" + } else { + return "sensor" + } +} + +// MarshalJSON implements json.Marshaler. +func (g *FuncGauge) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Value float64 `json:"value"` + Timestamp *int64 `json:"ts,omitempty"` + }{ + Type: g.metricType.String(), + Value: g.function(), + Labels: func() map[string]string { + labels := make(map[string]string, len(g.tags)+1) + labels[g.getNameTag()] = g.Name() + for k, v := range g.tags { + labels[k] = v + } + return labels + }(), + Timestamp: tsAsRef(g.timestamp), + }) +} + +// Snapshot returns independent copy on metric. +func (g *FuncGauge) Snapshot() Metric { + return &Gauge{ + name: g.name, + metricType: g.metricType, + tags: g.tags, + value: *atomic.NewFloat64(g.function()), + + useNameTag: g.useNameTag, + } +} diff --git a/library/go/core/metrics/solomon/func_gauge_test.go b/library/go/core/metrics/solomon/func_gauge_test.go new file mode 100644 index 0000000000..f4317a0cab --- /dev/null +++ b/library/go/core/metrics/solomon/func_gauge_test.go @@ -0,0 +1,64 @@ +package solomon + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestFuncGauge_Value(t *testing.T) { + val := new(atomic.Float64) + c := &FuncGauge{ + name: "mygauge", + metricType: typeGauge, + tags: map[string]string{"ololo": "trololo"}, + function: func() float64 { + return val.Load() + }, + } + + val.Store(1) + assert.Equal(t, float64(1), c.Snapshot().(*Gauge).value.Load()) + + val.Store(42) + assert.Equal(t, float64(42), c.Snapshot().(*Gauge).value.Load()) + +} + +func TestFunGauge_MarshalJSON(t *testing.T) { + c := &FuncGauge{ + name: "mygauge", + metricType: typeGauge, + tags: map[string]string{"ololo": "trololo"}, + function: func() float64 { + return 42.18 + }, + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"DGAUGE","labels":{"ololo":"trololo","sensor":"mygauge"},"value":42.18}`) + assert.Equal(t, expected, b) +} + +func TestNameTagFunGauge_MarshalJSON(t *testing.T) { + c := &FuncGauge{ + name: "mygauge", + metricType: typeGauge, + tags: map[string]string{"ololo": "trololo"}, + function: func() float64 { + return 42.18 + }, + + useNameTag: true, + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"DGAUGE","labels":{"name":"mygauge","ololo":"trololo"},"value":42.18}`) + assert.Equal(t, expected, b) +} diff --git a/library/go/core/metrics/solomon/func_int_gauge.go b/library/go/core/metrics/solomon/func_int_gauge.go new file mode 100644 index 0000000000..4e7f22949a --- /dev/null +++ b/library/go/core/metrics/solomon/func_int_gauge.go @@ -0,0 +1,87 @@ +package solomon + +import ( + "encoding/json" + "time" + + "go.uber.org/atomic" +) + +var _ Metric = (*FuncIntGauge)(nil) + +// FuncIntGauge tracks int64 value returned by function. +type FuncIntGauge struct { + name string + metricType metricType + tags map[string]string + function func() int64 + timestamp *time.Time + + useNameTag bool +} + +func (g *FuncIntGauge) Name() string { + return g.name +} + +func (g *FuncIntGauge) Function() func() int64 { + return g.function +} + +func (g *FuncIntGauge) getType() metricType { + return g.metricType +} + +func (g *FuncIntGauge) getLabels() map[string]string { + return g.tags +} + +func (g *FuncIntGauge) getValue() interface{} { + return g.function() +} + +func (g *FuncIntGauge) getTimestamp() *time.Time { + return g.timestamp +} + +func (g *FuncIntGauge) getNameTag() string { + if g.useNameTag { + return "name" + } else { + return "sensor" + } +} + +// MarshalJSON implements json.Marshaler. +func (g *FuncIntGauge) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Value int64 `json:"value"` + Timestamp *int64 `json:"ts,omitempty"` + }{ + Type: g.metricType.String(), + Value: g.function(), + Labels: func() map[string]string { + labels := make(map[string]string, len(g.tags)+1) + labels[g.getNameTag()] = g.Name() + for k, v := range g.tags { + labels[k] = v + } + return labels + }(), + Timestamp: tsAsRef(g.timestamp), + }) +} + +// Snapshot returns independent copy on metric. +func (g *FuncIntGauge) Snapshot() Metric { + return &IntGauge{ + name: g.name, + metricType: g.metricType, + tags: g.tags, + value: *atomic.NewInt64(g.function()), + + useNameTag: g.useNameTag, + } +} diff --git a/library/go/core/metrics/solomon/func_int_gauge_test.go b/library/go/core/metrics/solomon/func_int_gauge_test.go new file mode 100644 index 0000000000..4a576461e3 --- /dev/null +++ b/library/go/core/metrics/solomon/func_int_gauge_test.go @@ -0,0 +1,64 @@ +package solomon + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestFuncIntGauge_Value(t *testing.T) { + val := new(atomic.Int64) + c := &FuncIntGauge{ + name: "myintgauge", + metricType: typeIGauge, + tags: map[string]string{"ololo": "trololo"}, + function: func() int64 { + return val.Load() + }, + } + + val.Store(1) + assert.Equal(t, int64(1), c.Snapshot().(*IntGauge).value.Load()) + + val.Store(42) + assert.Equal(t, int64(42), c.Snapshot().(*IntGauge).value.Load()) + +} + +func TestFunIntGauge_MarshalJSON(t *testing.T) { + c := &FuncIntGauge{ + name: "myintgauge", + metricType: typeIGauge, + tags: map[string]string{"ololo": "trololo"}, + function: func() int64 { + return 42 + }, + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"IGAUGE","labels":{"ololo":"trololo","sensor":"myintgauge"},"value":42}`) + assert.Equal(t, expected, b) +} + +func TestNameTagFunIntGauge_MarshalJSON(t *testing.T) { + c := &FuncIntGauge{ + name: "myintgauge", + metricType: typeIGauge, + tags: map[string]string{"ololo": "trololo"}, + function: func() int64 { + return 42 + }, + + useNameTag: true, + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"IGAUGE","labels":{"name":"myintgauge","ololo":"trololo"},"value":42}`) + assert.Equal(t, expected, b) +} diff --git a/library/go/core/metrics/solomon/gauge.go b/library/go/core/metrics/solomon/gauge.go new file mode 100644 index 0000000000..4660d33c11 --- /dev/null +++ b/library/go/core/metrics/solomon/gauge.go @@ -0,0 +1,115 @@ +package solomon + +import ( + "encoding/json" + "time" + + "github.com/ydb-platform/ydb/library/go/core/metrics" + "go.uber.org/atomic" +) + +var ( + _ metrics.Gauge = (*Gauge)(nil) + _ Metric = (*Gauge)(nil) +) + +// Gauge tracks single float64 value. +type Gauge struct { + name string + metricType metricType + tags map[string]string + value atomic.Float64 + timestamp *time.Time + + useNameTag bool +} + +func NewGauge(name string, value float64, opts ...metricOpts) Gauge { + mOpts := MetricsOpts{} + for _, op := range opts { + op(&mOpts) + } + return Gauge{ + name: name, + metricType: typeGauge, + tags: mOpts.tags, + value: *atomic.NewFloat64(value), + useNameTag: mOpts.useNameTag, + timestamp: mOpts.timestamp, + } +} + +func (g *Gauge) Set(value float64) { + g.value.Store(value) +} + +func (g *Gauge) Add(value float64) { + g.value.Add(value) +} + +func (g *Gauge) Name() string { + return g.name +} + +func (g *Gauge) getType() metricType { + return g.metricType +} + +func (g *Gauge) getLabels() map[string]string { + return g.tags +} + +func (g *Gauge) getValue() interface{} { + return g.value.Load() +} + +func (g *Gauge) getTimestamp() *time.Time { + return g.timestamp +} + +func (g *Gauge) getNameTag() string { + if g.useNameTag { + return "name" + } else { + return "sensor" + } +} + +// MarshalJSON implements json.Marshaler. +func (g *Gauge) MarshalJSON() ([]byte, error) { + metricType := g.metricType.String() + value := g.value.Load() + labels := func() map[string]string { + labels := make(map[string]string, len(g.tags)+1) + labels[g.getNameTag()] = g.Name() + for k, v := range g.tags { + labels[k] = v + } + return labels + }() + + return json.Marshal(struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Value float64 `json:"value"` + Timestamp *int64 `json:"ts,omitempty"` + }{ + Type: metricType, + Value: value, + Labels: labels, + Timestamp: tsAsRef(g.timestamp), + }) +} + +// Snapshot returns independent copy of metric. +func (g *Gauge) Snapshot() Metric { + return &Gauge{ + name: g.name, + metricType: g.metricType, + tags: g.tags, + value: *atomic.NewFloat64(g.value.Load()), + + useNameTag: g.useNameTag, + timestamp: g.timestamp, + } +} diff --git a/library/go/core/metrics/solomon/gauge_test.go b/library/go/core/metrics/solomon/gauge_test.go new file mode 100644 index 0000000000..82659a49c4 --- /dev/null +++ b/library/go/core/metrics/solomon/gauge_test.go @@ -0,0 +1,75 @@ +package solomon + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestGauge_Add(t *testing.T) { + c := &Gauge{ + name: "mygauge", + metricType: typeGauge, + tags: map[string]string{"ololo": "trololo"}, + } + + c.Add(1) + assert.Equal(t, float64(1), c.value.Load()) + + c.Add(42) + assert.Equal(t, float64(43), c.value.Load()) + + c.Add(14.89) + assert.Equal(t, float64(57.89), c.value.Load()) +} + +func TestGauge_Set(t *testing.T) { + c := &Gauge{ + name: "mygauge", + metricType: typeGauge, + tags: map[string]string{"ololo": "trololo"}, + } + + c.Set(1) + assert.Equal(t, float64(1), c.value.Load()) + + c.Set(42) + assert.Equal(t, float64(42), c.value.Load()) + + c.Set(14.89) + assert.Equal(t, float64(14.89), c.value.Load()) +} + +func TestGauge_MarshalJSON(t *testing.T) { + c := &Gauge{ + name: "mygauge", + metricType: typeGauge, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewFloat64(42.18), + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"DGAUGE","labels":{"ololo":"trololo","sensor":"mygauge"},"value":42.18}`) + assert.Equal(t, expected, b) +} + +func TestNameTagGauge_MarshalJSON(t *testing.T) { + c := &Gauge{ + name: "mygauge", + metricType: typeGauge, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewFloat64(42.18), + + useNameTag: true, + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"DGAUGE","labels":{"name":"mygauge","ololo":"trololo"},"value":42.18}`) + assert.Equal(t, expected, b) +} diff --git a/library/go/core/metrics/solomon/gotest/ya.make b/library/go/core/metrics/solomon/gotest/ya.make new file mode 100644 index 0000000000..0c386167a4 --- /dev/null +++ b/library/go/core/metrics/solomon/gotest/ya.make @@ -0,0 +1,3 @@ +GO_TEST_FOR(library/go/core/metrics/solomon) + +END() diff --git a/library/go/core/metrics/solomon/histogram.go b/library/go/core/metrics/solomon/histogram.go new file mode 100644 index 0000000000..6f4d3629e0 --- /dev/null +++ b/library/go/core/metrics/solomon/histogram.go @@ -0,0 +1,182 @@ +package solomon + +import ( + "encoding/binary" + "encoding/json" + "io" + "sort" + "sync" + "time" + + "github.com/ydb-platform/ydb/library/go/core/metrics" + "github.com/ydb-platform/ydb/library/go/core/xerrors" + "go.uber.org/atomic" +) + +var ( + _ metrics.Histogram = (*Histogram)(nil) + _ metrics.Timer = (*Histogram)(nil) + _ Metric = (*Histogram)(nil) +) + +type Histogram struct { + name string + metricType metricType + tags map[string]string + bucketBounds []float64 + bucketValues []int64 + infValue atomic.Int64 + mutex sync.Mutex + timestamp *time.Time + useNameTag bool +} + +type histogram struct { + Bounds []float64 `json:"bounds"` + Buckets []int64 `json:"buckets"` + Inf int64 `json:"inf,omitempty"` +} + +func (h *histogram) writeHistogram(w io.Writer) error { + err := writeULEB128(w, uint32(len(h.Buckets))) + if err != nil { + return xerrors.Errorf("writeULEB128 size histogram buckets failed: %w", err) + } + + for _, upperBound := range h.Bounds { + err = binary.Write(w, binary.LittleEndian, float64(upperBound)) + if err != nil { + return xerrors.Errorf("binary.Write upper bound failed: %w", err) + } + } + + for _, bucketValue := range h.Buckets { + err = binary.Write(w, binary.LittleEndian, uint64(bucketValue)) + if err != nil { + return xerrors.Errorf("binary.Write histogram buckets failed: %w", err) + } + } + return nil +} + +func (h *Histogram) RecordValue(value float64) { + boundIndex := sort.SearchFloat64s(h.bucketBounds, value) + + if boundIndex < len(h.bucketValues) { + h.mutex.Lock() + h.bucketValues[boundIndex] += 1 + h.mutex.Unlock() + } else { + h.infValue.Inc() + } +} + +func (h *Histogram) RecordDuration(value time.Duration) { + h.RecordValue(value.Seconds()) +} + +func (h *Histogram) Reset() { + h.mutex.Lock() + defer h.mutex.Unlock() + + h.bucketValues = make([]int64, len(h.bucketValues)) + h.infValue.Store(0) +} + +func (h *Histogram) Name() string { + return h.name +} + +func (h *Histogram) getType() metricType { + return h.metricType +} + +func (h *Histogram) getLabels() map[string]string { + return h.tags +} + +func (h *Histogram) getValue() interface{} { + return histogram{ + Bounds: h.bucketBounds, + Buckets: h.bucketValues, + } +} + +func (h *Histogram) getTimestamp() *time.Time { + return h.timestamp +} + +func (h *Histogram) getNameTag() string { + if h.useNameTag { + return "name" + } else { + return "sensor" + } +} + +// MarshalJSON implements json.Marshaler. +func (h *Histogram) MarshalJSON() ([]byte, error) { + valuesCopy := make([]int64, len(h.bucketValues)) + h.mutex.Lock() + copy(valuesCopy, h.bucketValues) + h.mutex.Unlock() + return json.Marshal(struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Histogram histogram `json:"hist"` + Timestamp *int64 `json:"ts,omitempty"` + }{ + Type: h.metricType.String(), + Histogram: histogram{ + Bounds: h.bucketBounds, + Buckets: valuesCopy, + Inf: h.infValue.Load(), + }, + Labels: func() map[string]string { + labels := make(map[string]string, len(h.tags)+1) + labels[h.getNameTag()] = h.Name() + for k, v := range h.tags { + labels[k] = v + } + return labels + }(), + Timestamp: tsAsRef(h.timestamp), + }) +} + +// Snapshot returns independent copy on metric. +func (h *Histogram) Snapshot() Metric { + bucketBounds := make([]float64, len(h.bucketBounds)) + bucketValues := make([]int64, len(h.bucketValues)) + + copy(bucketBounds, h.bucketBounds) + h.mutex.Lock() + copy(bucketValues, h.bucketValues) + h.mutex.Unlock() + + return &Histogram{ + name: h.name, + metricType: h.metricType, + tags: h.tags, + bucketBounds: bucketBounds, + bucketValues: bucketValues, + infValue: *atomic.NewInt64(h.infValue.Load()), + useNameTag: h.useNameTag, + } +} + +// InitBucketValues cleans internal bucketValues and saves new values in order. +// Length of internal bucketValues stays unchanged. +// If length of slice in argument bucketValues more than length of internal one, +// the first extra element of bucketValues is stored in infValue. +func (h *Histogram) InitBucketValues(bucketValues []int64) { + h.mutex.Lock() + defer h.mutex.Unlock() + + h.bucketValues = make([]int64, len(h.bucketValues)) + h.infValue.Store(0) + copy(h.bucketValues, bucketValues) + if len(bucketValues) > len(h.bucketValues) { + h.infValue.Store(bucketValues[len(h.bucketValues)]) + } +} diff --git a/library/go/core/metrics/solomon/histogram_test.go b/library/go/core/metrics/solomon/histogram_test.go new file mode 100644 index 0000000000..be7042397c --- /dev/null +++ b/library/go/core/metrics/solomon/histogram_test.go @@ -0,0 +1,153 @@ +package solomon + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestHistogram_MarshalJSON(t *testing.T) { + h := &Histogram{ + name: "myhistogram", + metricType: typeHistogram, + tags: map[string]string{"ololo": "trololo"}, + bucketBounds: []float64{1, 2, 3}, + bucketValues: []int64{1, 2, 1}, + infValue: *atomic.NewInt64(2), + } + + b, err := json.Marshal(h) + assert.NoError(t, err) + + expected := []byte(`{"type":"HIST","labels":{"ololo":"trololo","sensor":"myhistogram"},"hist":{"bounds":[1,2,3],"buckets":[1,2,1],"inf":2}}`) + assert.Equal(t, expected, b) +} + +func TestRatedHistogram_MarshalJSON(t *testing.T) { + h := &Histogram{ + name: "myhistogram", + metricType: typeRatedHistogram, + tags: map[string]string{"ololo": "trololo"}, + bucketBounds: []float64{1, 2, 3}, + bucketValues: []int64{1, 2, 1}, + infValue: *atomic.NewInt64(2), + } + + b, err := json.Marshal(h) + assert.NoError(t, err) + + expected := []byte(`{"type":"HIST_RATE","labels":{"ololo":"trololo","sensor":"myhistogram"},"hist":{"bounds":[1,2,3],"buckets":[1,2,1],"inf":2}}`) + assert.Equal(t, expected, b) +} + +func TestNameTagHistogram_MarshalJSON(t *testing.T) { + h := &Histogram{ + name: "myhistogram", + metricType: typeRatedHistogram, + tags: map[string]string{"ololo": "trololo"}, + bucketBounds: []float64{1, 2, 3}, + bucketValues: []int64{1, 2, 1}, + infValue: *atomic.NewInt64(2), + useNameTag: true, + } + + b, err := json.Marshal(h) + assert.NoError(t, err) + + expected := []byte(`{"type":"HIST_RATE","labels":{"name":"myhistogram","ololo":"trololo"},"hist":{"bounds":[1,2,3],"buckets":[1,2,1],"inf":2}}`) + assert.Equal(t, expected, b) +} + +func TestHistogram_RecordDuration(t *testing.T) { + h := &Histogram{ + name: "myhistogram", + metricType: typeHistogram, + tags: map[string]string{"ololo": "trololo"}, + bucketBounds: []float64{1, 2, 3}, + bucketValues: make([]int64, 3), + } + + h.RecordDuration(500 * time.Millisecond) + h.RecordDuration(1 * time.Second) + h.RecordDuration(1800 * time.Millisecond) + h.RecordDuration(3 * time.Second) + h.RecordDuration(1 * time.Hour) + + expectedValues := []int64{2, 1, 1} + assert.Equal(t, expectedValues, h.bucketValues) + + var expectedInfValue int64 = 1 + assert.Equal(t, expectedInfValue, h.infValue.Load()) +} + +func TestHistogram_RecordValue(t *testing.T) { + h := &Histogram{ + name: "myhistogram", + metricType: typeHistogram, + tags: map[string]string{"ololo": "trololo"}, + bucketBounds: []float64{1, 2, 3}, + bucketValues: make([]int64, 3), + } + + h.RecordValue(0.5) + h.RecordValue(1) + h.RecordValue(1.8) + h.RecordValue(3) + h.RecordValue(60) + + expectedValues := []int64{2, 1, 1} + assert.Equal(t, expectedValues, h.bucketValues) + + var expectedInfValue int64 = 1 + assert.Equal(t, expectedInfValue, h.infValue.Load()) +} + +func TestHistogram_Reset(t *testing.T) { + h := &Histogram{ + name: "myhistogram", + metricType: typeHistogram, + tags: map[string]string{"ololo": "trololo"}, + bucketBounds: []float64{1, 2, 3}, + bucketValues: make([]int64, 3), + } + + h.RecordValue(0.5) + h.RecordValue(1) + h.RecordValue(1.8) + h.RecordValue(3) + h.RecordValue(60) + + assert.Equal(t, []int64{2, 1, 1}, h.bucketValues) + assert.Equal(t, int64(1), h.infValue.Load()) + + h.Reset() + + assert.Equal(t, []int64{0, 0, 0}, h.bucketValues) + assert.Equal(t, int64(0), h.infValue.Load()) +} + +func TestHistogram_InitBucketValues(t *testing.T) { + h := &Histogram{ + name: "myhistogram", + metricType: typeHistogram, + tags: map[string]string{"ololo": "trololo"}, + bucketBounds: []float64{1, 2, 3}, + bucketValues: make([]int64, 3), + } + + valsToInit := []int64{1, 2, 3, 4} + h.InitBucketValues(valsToInit[:2]) + assert.Equal(t, append(valsToInit[:2], 0), h.bucketValues) + assert.Equal(t, *atomic.NewInt64(0), h.infValue) + + h.InitBucketValues(valsToInit[:3]) + assert.Equal(t, valsToInit[:3], h.bucketValues) + assert.Equal(t, *atomic.NewInt64(0), h.infValue) + + h.InitBucketValues(valsToInit) + assert.Equal(t, valsToInit[:3], h.bucketValues) + assert.Equal(t, *atomic.NewInt64(valsToInit[3]), h.infValue) +} diff --git a/library/go/core/metrics/solomon/int_gauge.go b/library/go/core/metrics/solomon/int_gauge.go new file mode 100644 index 0000000000..8733bf11fe --- /dev/null +++ b/library/go/core/metrics/solomon/int_gauge.go @@ -0,0 +1,115 @@ +package solomon + +import ( + "encoding/json" + "time" + + "github.com/ydb-platform/ydb/library/go/core/metrics" + "go.uber.org/atomic" +) + +var ( + _ metrics.IntGauge = (*IntGauge)(nil) + _ Metric = (*IntGauge)(nil) +) + +// IntGauge tracks single float64 value. +type IntGauge struct { + name string + metricType metricType + tags map[string]string + value atomic.Int64 + timestamp *time.Time + + useNameTag bool +} + +func NewIntGauge(name string, value int64, opts ...metricOpts) IntGauge { + mOpts := MetricsOpts{} + for _, op := range opts { + op(&mOpts) + } + return IntGauge{ + name: name, + metricType: typeIGauge, + tags: mOpts.tags, + value: *atomic.NewInt64(value), + useNameTag: mOpts.useNameTag, + timestamp: mOpts.timestamp, + } +} + +func (g *IntGauge) Set(value int64) { + g.value.Store(value) +} + +func (g *IntGauge) Add(value int64) { + g.value.Add(value) +} + +func (g *IntGauge) Name() string { + return g.name +} + +func (g *IntGauge) getType() metricType { + return g.metricType +} + +func (g *IntGauge) getLabels() map[string]string { + return g.tags +} + +func (g *IntGauge) getValue() interface{} { + return g.value.Load() +} + +func (g *IntGauge) getTimestamp() *time.Time { + return g.timestamp +} + +func (g *IntGauge) getNameTag() string { + if g.useNameTag { + return "name" + } else { + return "sensor" + } +} + +// MarshalJSON implements json.Marshaler. +func (g *IntGauge) MarshalJSON() ([]byte, error) { + metricType := g.metricType.String() + value := g.value.Load() + labels := func() map[string]string { + labels := make(map[string]string, len(g.tags)+1) + labels[g.getNameTag()] = g.Name() + for k, v := range g.tags { + labels[k] = v + } + return labels + }() + + return json.Marshal(struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Value int64 `json:"value"` + Timestamp *int64 `json:"ts,omitempty"` + }{ + Type: metricType, + Value: value, + Labels: labels, + Timestamp: tsAsRef(g.timestamp), + }) +} + +// Snapshot returns independent copy of metric. +func (g *IntGauge) Snapshot() Metric { + return &IntGauge{ + name: g.name, + metricType: g.metricType, + tags: g.tags, + value: *atomic.NewInt64(g.value.Load()), + + useNameTag: g.useNameTag, + timestamp: g.timestamp, + } +} diff --git a/library/go/core/metrics/solomon/int_gauge_test.go b/library/go/core/metrics/solomon/int_gauge_test.go new file mode 100644 index 0000000000..5918ef9ac3 --- /dev/null +++ b/library/go/core/metrics/solomon/int_gauge_test.go @@ -0,0 +1,75 @@ +package solomon + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestIntGauge_Add(t *testing.T) { + c := &IntGauge{ + name: "myintgauge", + metricType: typeIGauge, + tags: map[string]string{"ololo": "trololo"}, + } + + c.Add(1) + assert.Equal(t, int64(1), c.value.Load()) + + c.Add(42) + assert.Equal(t, int64(43), c.value.Load()) + + c.Add(-45) + assert.Equal(t, int64(-2), c.value.Load()) +} + +func TestIntGauge_Set(t *testing.T) { + c := &IntGauge{ + name: "myintgauge", + metricType: typeIGauge, + tags: map[string]string{"ololo": "trololo"}, + } + + c.Set(1) + assert.Equal(t, int64(1), c.value.Load()) + + c.Set(42) + assert.Equal(t, int64(42), c.value.Load()) + + c.Set(-45) + assert.Equal(t, int64(-45), c.value.Load()) +} + +func TestIntGauge_MarshalJSON(t *testing.T) { + c := &IntGauge{ + name: "myintgauge", + metricType: typeIGauge, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewInt64(42), + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"IGAUGE","labels":{"ololo":"trololo","sensor":"myintgauge"},"value":42}`) + assert.Equal(t, expected, b) +} + +func TestNameTagIntGauge_MarshalJSON(t *testing.T) { + c := &IntGauge{ + name: "myintgauge", + metricType: typeIGauge, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewInt64(42), + + useNameTag: true, + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"IGAUGE","labels":{"name":"myintgauge","ololo":"trololo"},"value":42}`) + assert.Equal(t, expected, b) +} diff --git a/library/go/core/metrics/solomon/metrics.go b/library/go/core/metrics/solomon/metrics.go new file mode 100644 index 0000000000..6b73fd10a6 --- /dev/null +++ b/library/go/core/metrics/solomon/metrics.go @@ -0,0 +1,187 @@ +package solomon + +import ( + "bytes" + "context" + "encoding" + "encoding/json" + "fmt" + "time" + + "github.com/ydb-platform/ydb/library/go/core/xerrors" + "golang.org/x/exp/slices" +) + +// Gather collects all metrics data via snapshots. +func (r Registry) Gather() (*Metrics, error) { + metrics := make([]Metric, 0) + + var err error + r.metrics.Range(func(_, v interface{}) bool { + if s, ok := v.(Metric); ok { + metrics = append(metrics, s.Snapshot()) + return true + } + err = fmt.Errorf("unexpected value type: %T", v) + return false + }) + + if err != nil { + return nil, err + } + + return &Metrics{metrics: metrics}, nil +} + +func NewMetrics(metrics []Metric) Metrics { + return Metrics{metrics: metrics} +} + +func NewMetricsWithTimestamp(metrics []Metric, ts time.Time) Metrics { + return Metrics{metrics: metrics, timestamp: &ts} +} + +type valueType uint8 + +const ( + valueTypeNone valueType = iota + valueTypeOneWithoutTS valueType = 0x01 + valueTypeOneWithTS valueType = 0x02 + valueTypeManyWithTS valueType = 0x03 +) + +type metricType uint8 + +const ( + typeUnspecified metricType = iota + typeGauge metricType = 0x01 + typeCounter metricType = 0x02 + typeRated metricType = 0x03 + typeIGauge metricType = 0x04 + typeHistogram metricType = 0x05 + typeRatedHistogram metricType = 0x06 +) + +func (k metricType) String() string { + switch k { + case typeCounter: + return "COUNTER" + case typeGauge: + return "DGAUGE" + case typeIGauge: + return "IGAUGE" + case typeHistogram: + return "HIST" + case typeRated: + return "RATE" + case typeRatedHistogram: + return "HIST_RATE" + default: + panic("unknown metric type") + } +} + +// Metric is an any abstract solomon Metric. +type Metric interface { + json.Marshaler + + Name() string + getType() metricType + getLabels() map[string]string + getValue() interface{} + getNameTag() string + getTimestamp() *time.Time + + Snapshot() Metric +} + +// Rated marks given Solomon metric or vector as rated. +// Example: +// +// cnt := r.Counter("mycounter") +// Rated(cnt) +// +// cntvec := r.CounterVec("mycounter", []string{"mytag"}) +// Rated(cntvec) +// +// For additional info: https://docs.yandex-team.ru/solomon/data-collection/dataformat/json +func Rated(s interface{}) { + switch st := s.(type) { + case *Counter: + st.metricType = typeRated + case *FuncCounter: + st.metricType = typeRated + case *Histogram: + st.metricType = typeRatedHistogram + + case *CounterVec: + st.vec.rated = true + case *HistogramVec: + st.vec.rated = true + case *DurationHistogramVec: + st.vec.rated = true + } + // any other metrics types are unrateable +} + +var ( + _ json.Marshaler = (*Metrics)(nil) + _ encoding.BinaryMarshaler = (*Metrics)(nil) +) + +type Metrics struct { + metrics []Metric + timestamp *time.Time +} + +// MarshalJSON implements json.Marshaler. +func (s Metrics) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Metrics []Metric `json:"metrics"` + Timestamp *int64 `json:"ts,omitempty"` + }{s.metrics, tsAsRef(s.timestamp)}) +} + +// MarshalBinary implements encoding.BinaryMarshaler. +func (s Metrics) MarshalBinary() ([]byte, error) { + var buf bytes.Buffer + se := NewSpackEncoder(context.Background(), CompressionNone, &s) + n, err := se.Encode(&buf) + if err != nil { + return nil, xerrors.Errorf("encode only %d bytes: %w", n, err) + } + return buf.Bytes(), nil +} + +// SplitToChunks splits Metrics into a slice of chunks, each at most maxChunkSize long. +// The length of returned slice is always at least one. +// Zero maxChunkSize denotes unlimited chunk length. +func (s Metrics) SplitToChunks(maxChunkSize int) []Metrics { + if maxChunkSize == 0 || len(s.metrics) == 0 { + return []Metrics{s} + } + chunks := make([]Metrics, 0, len(s.metrics)/maxChunkSize+1) + + for leftBound := 0; leftBound < len(s.metrics); leftBound += maxChunkSize { + rightBound := leftBound + maxChunkSize + if rightBound > len(s.metrics) { + rightBound = len(s.metrics) + } + chunk := s.metrics[leftBound:rightBound] + chunks = append(chunks, Metrics{metrics: chunk}) + } + return chunks +} + +// List return list of metrics +func (s Metrics) List() []Metric { + return slices.Clone(s.metrics) +} + +func tsAsRef(t *time.Time) *int64 { + if t == nil { + return nil + } + ts := t.Unix() + return &ts +} diff --git a/library/go/core/metrics/solomon/metrics_opts.go b/library/go/core/metrics/solomon/metrics_opts.go new file mode 100644 index 0000000000..d9ade67966 --- /dev/null +++ b/library/go/core/metrics/solomon/metrics_opts.go @@ -0,0 +1,29 @@ +package solomon + +import "time" + +type MetricsOpts struct { + useNameTag bool + tags map[string]string + timestamp *time.Time +} + +type metricOpts func(*MetricsOpts) + +func WithTags(tags map[string]string) func(*MetricsOpts) { + return func(m *MetricsOpts) { + m.tags = tags + } +} + +func WithUseNameTag() func(*MetricsOpts) { + return func(m *MetricsOpts) { + m.useNameTag = true + } +} + +func WithTimestamp(t time.Time) func(*MetricsOpts) { + return func(m *MetricsOpts) { + m.timestamp = &t + } +} diff --git a/library/go/core/metrics/solomon/metrics_test.go b/library/go/core/metrics/solomon/metrics_test.go new file mode 100644 index 0000000000..610fa061a1 --- /dev/null +++ b/library/go/core/metrics/solomon/metrics_test.go @@ -0,0 +1,296 @@ +package solomon + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/ydb-platform/ydb/library/go/core/metrics" + "go.uber.org/atomic" +) + +func TestMetrics_MarshalJSON(t *testing.T) { + s := &Metrics{ + metrics: []Metric{ + &Counter{ + name: "mycounter", + metricType: typeCounter, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewInt64(42), + }, + &Counter{ + name: "myratedcounter", + metricType: typeRated, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewInt64(42), + }, + &Gauge{ + name: "mygauge", + metricType: typeGauge, + tags: map[string]string{"shimba": "boomba"}, + value: *atomic.NewFloat64(14.89), + }, + &Timer{ + name: "mytimer", + metricType: typeGauge, + tags: map[string]string{"looken": "tooken"}, + value: *atomic.NewDuration(1456 * time.Millisecond), + }, + &Histogram{ + name: "myhistogram", + metricType: typeHistogram, + tags: map[string]string{"chicken": "cooken"}, + bucketBounds: []float64{1, 2, 3}, + bucketValues: []int64{1, 2, 1}, + infValue: *atomic.NewInt64(1), + }, + &Histogram{ + name: "myratedhistogram", + metricType: typeRatedHistogram, + tags: map[string]string{"chicken": "cooken"}, + bucketBounds: []float64{1, 2, 3}, + bucketValues: []int64{1, 2, 1}, + infValue: *atomic.NewInt64(1), + }, + &Gauge{ + name: "mytimedgauge", + metricType: typeGauge, + tags: map[string]string{"oki": "toki"}, + value: *atomic.NewFloat64(42.24), + timestamp: timeAsRef(time.Unix(1500000000, 0)), + }, + }, + } + + b, err := json.Marshal(s) + assert.NoError(t, err) + + expected := []byte(`{"metrics":[` + + `{"type":"COUNTER","labels":{"ololo":"trololo","sensor":"mycounter"},"value":42},` + + `{"type":"RATE","labels":{"ololo":"trololo","sensor":"myratedcounter"},"value":42},` + + `{"type":"DGAUGE","labels":{"sensor":"mygauge","shimba":"boomba"},"value":14.89},` + + `{"type":"DGAUGE","labels":{"looken":"tooken","sensor":"mytimer"},"value":1.456},` + + `{"type":"HIST","labels":{"chicken":"cooken","sensor":"myhistogram"},"hist":{"bounds":[1,2,3],"buckets":[1,2,1],"inf":1}},` + + `{"type":"HIST_RATE","labels":{"chicken":"cooken","sensor":"myratedhistogram"},"hist":{"bounds":[1,2,3],"buckets":[1,2,1],"inf":1}},` + + `{"type":"DGAUGE","labels":{"oki":"toki","sensor":"mytimedgauge"},"value":42.24,"ts":1500000000}` + + `]}`) + assert.Equal(t, expected, b) +} + +func timeAsRef(t time.Time) *time.Time { + return &t +} + +func TestMetrics_with_timestamp_MarshalJSON(t *testing.T) { + s := &Metrics{ + metrics: []Metric{ + &Counter{ + name: "mycounter", + metricType: typeCounter, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewInt64(42), + }, + &Gauge{ + name: "mytimedgauge", + metricType: typeGauge, + tags: map[string]string{"oki": "toki"}, + value: *atomic.NewFloat64(42.24), + timestamp: timeAsRef(time.Unix(1500000000, 0)), + }, + }, + timestamp: timeAsRef(time.Unix(1657710477, 0)), + } + + b, err := json.Marshal(s) + assert.NoError(t, err) + + expected := []byte(`{"metrics":[` + + `{"type":"COUNTER","labels":{"ololo":"trololo","sensor":"mycounter"},"value":42},` + + `{"type":"DGAUGE","labels":{"oki":"toki","sensor":"mytimedgauge"},"value":42.24,"ts":1500000000}` + + `],"ts":1657710477}`) + assert.Equal(t, expected, b) +} + +func TestRated(t *testing.T) { + testCases := []struct { + name string + s interface{} + expected Metric + }{ + { + "counter", + &Counter{ + name: "mycounter", + metricType: typeCounter, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewInt64(42), + }, + &Counter{ + name: "mycounter", + metricType: typeRated, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewInt64(42), + }, + }, + { + "gauge", + &Gauge{ + name: "mygauge", + metricType: typeGauge, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewFloat64(42), + }, + &Gauge{ + name: "mygauge", + metricType: typeGauge, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewFloat64(42), + }, + }, + { + "timer", + &Timer{ + name: "mytimer", + metricType: typeGauge, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewDuration(1 * time.Second), + }, + &Timer{ + name: "mytimer", + metricType: typeGauge, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewDuration(1 * time.Second), + }, + }, + { + "histogram", + &Histogram{ + name: "myhistogram", + metricType: typeHistogram, + tags: map[string]string{"ololo": "trololo"}, + bucketBounds: []float64{1, 2, 3}, + infValue: *atomic.NewInt64(0), + }, + &Histogram{ + name: "myhistogram", + metricType: typeRatedHistogram, + tags: map[string]string{"ololo": "trololo"}, + bucketBounds: []float64{1, 2, 3}, + infValue: *atomic.NewInt64(0), + }, + }, + { + "metric_interface", + metrics.Counter(&Counter{ + name: "mycounter", + metricType: typeCounter, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewInt64(42), + }), + &Counter{ + name: "mycounter", + metricType: typeRated, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewInt64(42), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + Rated(tc.s) + assert.Equal(t, tc.expected, tc.s) + }) + } +} + +func TestSplitToChunks(t *testing.T) { + zeroMetrics := Metrics{ + metrics: []Metric{}, + } + oneMetric := Metrics{ + metrics: []Metric{ + &Counter{name: "a"}, + }, + } + twoMetrics := Metrics{ + metrics: []Metric{ + &Counter{name: "a"}, + &Counter{name: "b"}, + }, + } + fourMetrics := Metrics{ + metrics: []Metric{ + &Counter{name: "a"}, + &Counter{name: "b"}, + &Counter{name: "c"}, + &Counter{name: "d"}, + }, + } + fiveMetrics := Metrics{ + metrics: []Metric{ + &Counter{name: "a"}, + &Counter{name: "b"}, + &Counter{name: "c"}, + &Counter{name: "d"}, + &Counter{name: "e"}, + }, + } + + chunks := zeroMetrics.SplitToChunks(2) + assert.Equal(t, 1, len(chunks)) + assert.Equal(t, 0, len(chunks[0].metrics)) + + chunks = oneMetric.SplitToChunks(1) + assert.Equal(t, 1, len(chunks)) + assert.Equal(t, 1, len(chunks[0].metrics)) + assert.Equal(t, "a", chunks[0].metrics[0].Name()) + + chunks = oneMetric.SplitToChunks(2) + assert.Equal(t, 1, len(chunks)) + assert.Equal(t, 1, len(chunks[0].metrics)) + assert.Equal(t, "a", chunks[0].metrics[0].Name()) + + chunks = twoMetrics.SplitToChunks(1) + assert.Equal(t, 2, len(chunks)) + assert.Equal(t, 1, len(chunks[0].metrics)) + assert.Equal(t, 1, len(chunks[1].metrics)) + assert.Equal(t, "a", chunks[0].metrics[0].Name()) + assert.Equal(t, "b", chunks[1].metrics[0].Name()) + + chunks = twoMetrics.SplitToChunks(2) + assert.Equal(t, 1, len(chunks)) + assert.Equal(t, 2, len(chunks[0].metrics)) + assert.Equal(t, "a", chunks[0].metrics[0].Name()) + assert.Equal(t, "b", chunks[0].metrics[1].Name()) + + chunks = fourMetrics.SplitToChunks(2) + assert.Equal(t, 2, len(chunks)) + assert.Equal(t, 2, len(chunks[0].metrics)) + assert.Equal(t, 2, len(chunks[1].metrics)) + assert.Equal(t, "a", chunks[0].metrics[0].Name()) + assert.Equal(t, "b", chunks[0].metrics[1].Name()) + assert.Equal(t, "c", chunks[1].metrics[0].Name()) + assert.Equal(t, "d", chunks[1].metrics[1].Name()) + + chunks = fiveMetrics.SplitToChunks(2) + assert.Equal(t, 3, len(chunks)) + assert.Equal(t, 2, len(chunks[0].metrics)) + assert.Equal(t, 2, len(chunks[1].metrics)) + assert.Equal(t, 1, len(chunks[2].metrics)) + assert.Equal(t, "a", chunks[0].metrics[0].Name()) + assert.Equal(t, "b", chunks[0].metrics[1].Name()) + assert.Equal(t, "c", chunks[1].metrics[0].Name()) + assert.Equal(t, "d", chunks[1].metrics[1].Name()) + assert.Equal(t, "e", chunks[2].metrics[0].Name()) + + chunks = fiveMetrics.SplitToChunks(0) + assert.Equal(t, 1, len(chunks)) + assert.Equal(t, 5, len(chunks[0].metrics)) + assert.Equal(t, "a", chunks[0].metrics[0].Name()) + assert.Equal(t, "b", chunks[0].metrics[1].Name()) + assert.Equal(t, "c", chunks[0].metrics[2].Name()) + assert.Equal(t, "d", chunks[0].metrics[3].Name()) + assert.Equal(t, "e", chunks[0].metrics[4].Name()) +} diff --git a/library/go/core/metrics/solomon/race_test.go b/library/go/core/metrics/solomon/race_test.go new file mode 100644 index 0000000000..32be6f34fb --- /dev/null +++ b/library/go/core/metrics/solomon/race_test.go @@ -0,0 +1,150 @@ +package solomon + +import ( + "bytes" + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/core/metrics" +) + +type spinBarrier struct { + count int64 + waiting atomic.Int64 + step atomic.Int64 +} + +func newSpinBarrier(size int) *spinBarrier { + return &spinBarrier{count: int64(size)} +} + +func (b *spinBarrier) wait() { + s := b.step.Load() + w := b.waiting.Add(1) + if w == b.count { + b.waiting.Store(0) + b.step.Add(1) + } else { + for s == b.step.Load() { + // noop + } + } +} + +func TestRaceDurationHistogramVecVersusStreamJson(t *testing.T) { + // Regression test: https://github.com/ydb-platform/ydb/review/2690822/details + registry := NewRegistry(NewRegistryOpts()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const stepCount = 200 + + barrier := newSpinBarrier(2) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + // Consumer + defer wg.Done() + out := bytes.NewBuffer(nil) + for i := 0; i < stepCount; i++ { + out.Reset() + barrier.wait() + _, err := registry.StreamJSON(ctx, out) + if err != nil { + require.ErrorIs(t, err, context.Canceled) + break + } + } + }() + + wg.Add(1) + go func() { + // Producer + defer wg.Done() + + const success = "success" + const version = "version" + vecs := make([]metrics.TimerVec, 0) + buckets := metrics.NewDurationBuckets(1, 2, 3) + ProducerLoop: + for i := 0; i < stepCount; i++ { + barrier.wait() + vec := registry.DurationHistogramVec( + fmt.Sprintf("latency-%v", i), + buckets, + []string{success, version}, + ) + Rated(vec) + vecs = append(vecs, vec) + for _, v := range vecs { + v.With(map[string]string{success: "ok", version: "123"}).RecordDuration(time.Second) + v.With(map[string]string{success: "false", version: "123"}).RecordDuration(time.Millisecond) + } + select { + case <-ctx.Done(): + break ProducerLoop + default: + // noop + } + } + }() + wg.Wait() +} + +func TestRaceDurationHistogramRecordDurationVersusStreamJson(t *testing.T) { + // Regression test: https://github.com/ydb-platform/ydb/review/2690822/details + + registry := NewRegistry(NewRegistryOpts()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const stepCount = 200 + barrier := newSpinBarrier(2) + wg := sync.WaitGroup{} + + wg.Add(1) + go func() { + // Consumer + defer wg.Done() + out := bytes.NewBuffer(nil) + for i := 0; i < stepCount; i++ { + out.Reset() + barrier.wait() + _, err := registry.StreamJSON(ctx, out) + if err != nil { + require.ErrorIs(t, err, context.Canceled) + break + } + } + }() + + wg.Add(1) + go func() { + // Producer + defer wg.Done() + + buckets := metrics.NewDurationBuckets(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + hist := registry.DurationHistogram("latency", buckets) + // Rated(hist) + + ProducerLoop: + for i := 0; i < stepCount; i++ { + barrier.wait() + hist.RecordDuration(time.Duration(i % 10)) + select { + case <-ctx.Done(): + break ProducerLoop + default: + // noop + } + } + }() + wg.Wait() +} diff --git a/library/go/core/metrics/solomon/registry.go b/library/go/core/metrics/solomon/registry.go new file mode 100644 index 0000000000..0ad4d9378a --- /dev/null +++ b/library/go/core/metrics/solomon/registry.go @@ -0,0 +1,256 @@ +package solomon + +import ( + "reflect" + "strconv" + "sync" + + "github.com/ydb-platform/ydb/library/go/core/metrics" + "github.com/ydb-platform/ydb/library/go/core/metrics/internal/pkg/metricsutil" + "github.com/ydb-platform/ydb/library/go/core/metrics/internal/pkg/registryutil" +) + +var _ metrics.Registry = (*Registry)(nil) + +type Registry struct { + separator string + prefix string + tags map[string]string + rated bool + useNameTag bool + + subregistries map[string]*Registry + m *sync.Mutex + + metrics *sync.Map +} + +func NewRegistry(opts *RegistryOpts) *Registry { + r := &Registry{ + separator: ".", + useNameTag: false, + + subregistries: make(map[string]*Registry), + m: new(sync.Mutex), + + metrics: new(sync.Map), + } + + if opts != nil { + r.separator = string(opts.Separator) + r.prefix = opts.Prefix + r.tags = opts.Tags + r.rated = opts.Rated + r.useNameTag = opts.UseNameTag + for _, collector := range opts.Collectors { + collector(r) + } + } + + return r +} + +// Rated returns copy of registry with rated set to desired value. +func (r Registry) Rated(rated bool) metrics.Registry { + return &Registry{ + separator: r.separator, + prefix: r.prefix, + tags: r.tags, + rated: rated, + useNameTag: r.useNameTag, + + subregistries: r.subregistries, + m: r.m, + + metrics: r.metrics, + } +} + +// WithTags creates new sub-scope, where each metric has tags attached to it. +func (r Registry) WithTags(tags map[string]string) metrics.Registry { + return r.newSubregistry(r.prefix, registryutil.MergeTags(r.tags, tags)) +} + +// WithPrefix creates new sub-scope, where each metric has prefix added to it name. +func (r Registry) WithPrefix(prefix string) metrics.Registry { + return r.newSubregistry(registryutil.BuildFQName(r.separator, r.prefix, prefix), r.tags) +} + +// ComposeName builds FQ name with appropriate separator. +func (r Registry) ComposeName(parts ...string) string { + return registryutil.BuildFQName(r.separator, parts...) +} + +func (r Registry) Counter(name string) metrics.Counter { + s := &Counter{ + name: r.newMetricName(name), + metricType: typeCounter, + tags: r.tags, + + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.Counter) +} + +func (r Registry) FuncCounter(name string, function func() int64) metrics.FuncCounter { + s := &FuncCounter{ + name: r.newMetricName(name), + metricType: typeCounter, + tags: r.tags, + function: function, + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.FuncCounter) +} + +func (r Registry) Gauge(name string) metrics.Gauge { + s := &Gauge{ + name: r.newMetricName(name), + metricType: typeGauge, + tags: r.tags, + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.Gauge) +} + +func (r Registry) FuncGauge(name string, function func() float64) metrics.FuncGauge { + s := &FuncGauge{ + name: r.newMetricName(name), + metricType: typeGauge, + tags: r.tags, + function: function, + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.FuncGauge) +} + +func (r Registry) IntGauge(name string) metrics.IntGauge { + s := &IntGauge{ + name: r.newMetricName(name), + metricType: typeIGauge, + tags: r.tags, + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.IntGauge) +} + +func (r Registry) FuncIntGauge(name string, function func() int64) metrics.FuncIntGauge { + s := &FuncIntGauge{ + name: r.newMetricName(name), + metricType: typeIGauge, + tags: r.tags, + function: function, + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.FuncIntGauge) +} + +func (r Registry) Timer(name string) metrics.Timer { + s := &Timer{ + name: r.newMetricName(name), + metricType: typeGauge, + tags: r.tags, + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.Timer) +} + +func (r Registry) Histogram(name string, buckets metrics.Buckets) metrics.Histogram { + s := &Histogram{ + name: r.newMetricName(name), + metricType: typeHistogram, + tags: r.tags, + bucketBounds: metricsutil.BucketsBounds(buckets), + bucketValues: make([]int64, buckets.Size()), + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.Histogram) +} + +func (r Registry) DurationHistogram(name string, buckets metrics.DurationBuckets) metrics.Timer { + s := &Histogram{ + name: r.newMetricName(name), + metricType: typeHistogram, + tags: r.tags, + bucketBounds: metricsutil.DurationBucketsBounds(buckets), + bucketValues: make([]int64, buckets.Size()), + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.Timer) +} + +func (r *Registry) newSubregistry(prefix string, tags map[string]string) *Registry { + // differ simple and rated registries + keyTags := registryutil.MergeTags(tags, map[string]string{"rated": strconv.FormatBool(r.rated)}) + registryKey := registryutil.BuildRegistryKey(prefix, keyTags) + + r.m.Lock() + defer r.m.Unlock() + + if existing, ok := r.subregistries[registryKey]; ok { + return existing + } + + subregistry := &Registry{ + separator: r.separator, + prefix: prefix, + tags: tags, + rated: r.rated, + useNameTag: r.useNameTag, + + subregistries: r.subregistries, + m: r.m, + + metrics: r.metrics, + } + + r.subregistries[registryKey] = subregistry + return subregistry +} + +func (r *Registry) newMetricName(name string) string { + return registryutil.BuildFQName(r.separator, r.prefix, name) +} + +func (r *Registry) registerMetric(s Metric) Metric { + if r.rated { + Rated(s) + } + + key := r.metricKey(s) + + oldMetric, loaded := r.metrics.LoadOrStore(key, s) + if !loaded { + return s + } + + if reflect.TypeOf(oldMetric) == reflect.TypeOf(s) { + return oldMetric.(Metric) + } else { + r.metrics.Store(key, s) + return s + } +} + +func (r *Registry) unregisterMetric(s Metric) { + if r.rated { + Rated(s) + } + + r.metrics.Delete(r.metricKey(s)) +} + +func (r *Registry) metricKey(s Metric) string { + // differ simple and rated registries + keyTags := registryutil.MergeTags(r.tags, map[string]string{"rated": strconv.FormatBool(r.rated)}) + return registryutil.BuildRegistryKey(s.Name(), keyTags) +} diff --git a/library/go/core/metrics/solomon/registry_opts.go b/library/go/core/metrics/solomon/registry_opts.go new file mode 100644 index 0000000000..c3df17940a --- /dev/null +++ b/library/go/core/metrics/solomon/registry_opts.go @@ -0,0 +1,87 @@ +package solomon + +import ( + "context" + + "github.com/ydb-platform/ydb/library/go/core/metrics" + "github.com/ydb-platform/ydb/library/go/core/metrics/collect" + "github.com/ydb-platform/ydb/library/go/core/metrics/internal/pkg/registryutil" +) + +type RegistryOpts struct { + Separator rune + Prefix string + Tags map[string]string + Rated bool + UseNameTag bool + Collectors []func(metrics.Registry) +} + +// NewRegistryOpts returns new initialized instance of RegistryOpts +func NewRegistryOpts() *RegistryOpts { + return &RegistryOpts{ + Separator: '.', + Tags: make(map[string]string), + UseNameTag: false, + } +} + +// SetUseNameTag overrides current UseNameTag opt +func (o *RegistryOpts) SetUseNameTag(useNameTag bool) *RegistryOpts { + o.UseNameTag = useNameTag + return o +} + +// SetTags overrides existing tags +func (o *RegistryOpts) SetTags(tags map[string]string) *RegistryOpts { + o.Tags = tags + return o +} + +// AddTags merges given tags with existing +func (o *RegistryOpts) AddTags(tags map[string]string) *RegistryOpts { + for k, v := range tags { + o.Tags[k] = v + } + return o +} + +// SetPrefix overrides existing prefix +func (o *RegistryOpts) SetPrefix(prefix string) *RegistryOpts { + o.Prefix = prefix + return o +} + +// AppendPrefix adds given prefix as postfix to existing using separator +func (o *RegistryOpts) AppendPrefix(prefix string) *RegistryOpts { + o.Prefix = registryutil.BuildFQName(string(o.Separator), o.Prefix, prefix) + return o +} + +// SetSeparator overrides existing separator +func (o *RegistryOpts) SetSeparator(separator rune) *RegistryOpts { + o.Separator = separator + return o +} + +// SetRated overrides existing rated flag +func (o *RegistryOpts) SetRated(rated bool) *RegistryOpts { + o.Rated = rated + return o +} + +// AddCollectors adds collectors that handle their metrics automatically (e.g. system metrics). +func (o *RegistryOpts) AddCollectors( + ctx context.Context, c metrics.CollectPolicy, collectors ...collect.Func, +) *RegistryOpts { + if len(collectors) == 0 { + return o + } + + o.Collectors = append(o.Collectors, func(r metrics.Registry) { + for _, collector := range collectors { + collector(ctx, r, c) + } + }) + return o +} diff --git a/library/go/core/metrics/solomon/registry_test.go b/library/go/core/metrics/solomon/registry_test.go new file mode 100644 index 0000000000..a870203b31 --- /dev/null +++ b/library/go/core/metrics/solomon/registry_test.go @@ -0,0 +1,168 @@ +package solomon + +import ( + "encoding/json" + "fmt" + "reflect" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" +) + +func TestRegistry_Gather(t *testing.T) { + r := &Registry{ + separator: ".", + prefix: "myprefix", + tags: make(map[string]string), + subregistries: make(map[string]*Registry), + metrics: func() *sync.Map { + metrics := map[string]Metric{ + "myprefix.mycounter": &Counter{ + name: "myprefix.mycounter", + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewInt64(42), + }, + "myprefix.mygauge": &Gauge{ + name: "myprefix.mygauge", + tags: map[string]string{"shimba": "boomba"}, + value: *atomic.NewFloat64(14.89), + }, + "myprefix.mytimer": &Timer{ + name: "myprefix.mytimer", + tags: map[string]string{"looken": "tooken"}, + value: *atomic.NewDuration(1456 * time.Millisecond), + }, + "myprefix.myhistogram": &Histogram{ + name: "myprefix.myhistogram", + tags: map[string]string{"chicken": "cooken"}, + bucketBounds: []float64{1, 2, 3}, + bucketValues: []int64{1, 2, 1}, + infValue: *atomic.NewInt64(1), + }, + } + + sm := new(sync.Map) + for k, v := range metrics { + sm.Store(k, v) + } + + return sm + }(), + } + + s, err := r.Gather() + assert.NoError(t, err) + + expected := &Metrics{} + r.metrics.Range(func(_, s interface{}) bool { + expected.metrics = append(expected.metrics, s.(Metric)) + return true + }) + + opts := cmp.Options{ + cmp.AllowUnexported(Metrics{}, Counter{}, Gauge{}, Timer{}, Histogram{}), + cmpopts.IgnoreUnexported(sync.Mutex{}, atomic.Duration{}, atomic.Int64{}, atomic.Float64{}), + // this will sort both slices for latest tests as well + cmpopts.SortSlices(func(x, y Metric) bool { + return x.Name() < y.Name() + }), + } + + assert.True(t, cmp.Equal(expected, s, opts...), cmp.Diff(expected, s, opts...)) + + for _, sen := range s.metrics { + var expectedMetric Metric + for _, expSen := range expected.metrics { + if expSen.Name() == sen.Name() { + expectedMetric = expSen + break + } + } + require.NotNil(t, expectedMetric) + + assert.NotEqual(t, fmt.Sprintf("%p", expectedMetric), fmt.Sprintf("%p", sen)) + assert.IsType(t, expectedMetric, sen) + + switch st := sen.(type) { + case *Counter: + assert.NotEqual(t, fmt.Sprintf("%p", expectedMetric.(*Counter)), fmt.Sprintf("%p", st)) + case *Gauge: + assert.NotEqual(t, fmt.Sprintf("%p", expectedMetric.(*Gauge)), fmt.Sprintf("%p", st)) + case *Timer: + assert.NotEqual(t, fmt.Sprintf("%p", expectedMetric.(*Timer)), fmt.Sprintf("%p", st)) + case *Histogram: + assert.NotEqual(t, fmt.Sprintf("%p", expectedMetric.(*Histogram)), fmt.Sprintf("%p", st)) + default: + t.Fatalf("unexpected metric type: %T", sen) + } + } +} + +func TestDoubleRegistration(t *testing.T) { + r := NewRegistry(NewRegistryOpts()) + + c0 := r.Counter("counter") + c1 := r.Counter("counter") + require.Equal(t, c0, c1) + + g0 := r.Gauge("counter") + g1 := r.Gauge("counter") + require.Equal(t, g0, g1) + + c2 := r.Counter("counter") + require.NotEqual(t, reflect.ValueOf(c0).Elem().UnsafeAddr(), reflect.ValueOf(c2).Elem().UnsafeAddr()) +} + +func TestSubregistry(t *testing.T) { + r := NewRegistry(NewRegistryOpts()) + + r0 := r.WithPrefix("one") + r1 := r0.WithPrefix("two") + r2 := r0.WithTags(map[string]string{"foo": "bar"}) + + _ = r0.Counter("counter") + _ = r1.Counter("counter") + _ = r2.Counter("counter") +} + +func TestSubregistry_TagAndPrefixReorder(t *testing.T) { + r := NewRegistry(NewRegistryOpts()) + + r0 := r.WithPrefix("one") + r1 := r.WithTags(map[string]string{"foo": "bar"}) + + r3 := r0.WithTags(map[string]string{"foo": "bar"}) + r4 := r1.WithPrefix("one") + + require.True(t, r3 == r4) +} + +func TestRatedRegistry(t *testing.T) { + r := NewRegistry(NewRegistryOpts().SetRated(true)) + s := r.Counter("counter") + b, _ := json.Marshal(s) + expected := []byte(`{"type":"RATE","labels":{"sensor":"counter"},"value":0}`) + assert.Equal(t, expected, b) +} + +func TestNameTagRegistry(t *testing.T) { + r := NewRegistry(NewRegistryOpts().SetUseNameTag(true)) + s := r.Counter("counter") + + b, _ := json.Marshal(s) + expected := []byte(`{"type":"COUNTER","labels":{"name":"counter"},"value":0}`) + assert.Equal(t, expected, b) + + sr := r.WithTags(map[string]string{"foo": "bar"}) + ssr := sr.Counter("sub_counter") + + b1, _ := json.Marshal(ssr) + expected1 := []byte(`{"type":"COUNTER","labels":{"foo":"bar","name":"sub_counter"},"value":0}`) + assert.Equal(t, expected1, b1) +} diff --git a/library/go/core/metrics/solomon/spack.go b/library/go/core/metrics/solomon/spack.go new file mode 100644 index 0000000000..48938d19b6 --- /dev/null +++ b/library/go/core/metrics/solomon/spack.go @@ -0,0 +1,387 @@ +package solomon + +import ( + "bytes" + "context" + "encoding/binary" + "io" + + "github.com/ydb-platform/ydb/library/go/core/xerrors" +) + +type spackVersion uint16 + +const ( + version11 spackVersion = 0x0101 + version12 spackVersion = 0x0102 +) + +type errWriter struct { + w io.Writer + err error +} + +func (ew *errWriter) binaryWrite(data interface{}) { + if ew.err != nil { + return + } + switch t := data.(type) { + case uint8: + ew.err = binary.Write(ew.w, binary.LittleEndian, data.(uint8)) + case uint16: + ew.err = binary.Write(ew.w, binary.LittleEndian, data.(uint16)) + case uint32: + ew.err = binary.Write(ew.w, binary.LittleEndian, data.(uint32)) + default: + ew.err = xerrors.Errorf("binaryWrite not supported type %v", t) + } +} + +func writeULEB128(w io.Writer, value uint32) error { + remaining := value >> 7 + for remaining != 0 { + err := binary.Write(w, binary.LittleEndian, uint8(value&0x7f|0x80)) + if err != nil { + return xerrors.Errorf("binary.Write failed: %w", err) + } + value = remaining + remaining >>= 7 + } + err := binary.Write(w, binary.LittleEndian, uint8(value&0x7f)) + if err != nil { + return xerrors.Errorf("binary.Write failed: %w", err) + } + return err +} + +type spackMetric struct { + flags uint8 + + nameValueIndex uint32 + labelsCount uint32 + labels bytes.Buffer + + metric Metric +} + +func (s *spackMetric) writeLabelPool(se *spackEncoder, namesIdx map[string]uint32, valuesIdx map[string]uint32, name string, value string) error { + _, ok := namesIdx[name] + if !ok { + namesIdx[name] = se.nameCounter + se.nameCounter++ + _, err := se.labelNamePool.WriteString(name) + if err != nil { + return err + } + err = se.labelNamePool.WriteByte(0) + if err != nil { + return err + } + } + + _, ok = valuesIdx[value] + if !ok { + valuesIdx[value] = se.valueCounter + se.valueCounter++ + _, err := se.labelValuePool.WriteString(value) + if err != nil { + return err + } + err = se.labelValuePool.WriteByte(0) + if err != nil { + return err + } + } + + return nil +} + +func (s *spackMetric) writeLabel(se *spackEncoder, namesIdx map[string]uint32, valuesIdx map[string]uint32, name string, value string) error { + s.labelsCount++ + + err := s.writeLabelPool(se, namesIdx, valuesIdx, name, value) + if err != nil { + return err + } + + err = writeULEB128(&s.labels, uint32(namesIdx[name])) + if err != nil { + return err + } + err = writeULEB128(&s.labels, uint32(valuesIdx[value])) + if err != nil { + return err + } + + return nil +} + +func (s *spackMetric) writeMetric(w io.Writer, version spackVersion) error { + metricValueType := valueTypeOneWithoutTS + if s.metric.getTimestamp() != nil { + metricValueType = valueTypeOneWithTS + } + // library/cpp/monlib/encode/spack/spack_v1_encoder.cpp?rev=r9098142#L190 + types := uint8(s.metric.getType()<<2) | uint8(metricValueType) + err := binary.Write(w, binary.LittleEndian, types) + if err != nil { + return xerrors.Errorf("binary.Write types failed: %w", err) + } + + err = binary.Write(w, binary.LittleEndian, uint8(s.flags)) + if err != nil { + return xerrors.Errorf("binary.Write flags failed: %w", err) + } + if version >= version12 { + err = writeULEB128(w, uint32(s.nameValueIndex)) + if err != nil { + return xerrors.Errorf("writeULEB128 name value index: %w", err) + } + } + err = writeULEB128(w, uint32(s.labelsCount)) + if err != nil { + return xerrors.Errorf("writeULEB128 labels count failed: %w", err) + } + + _, err = w.Write(s.labels.Bytes()) // s.writeLabels(w) + if err != nil { + return xerrors.Errorf("write labels failed: %w", err) + } + if s.metric.getTimestamp() != nil { + err = binary.Write(w, binary.LittleEndian, uint32(s.metric.getTimestamp().Unix())) + if err != nil { + return xerrors.Errorf("write timestamp failed: %w", err) + } + } + + switch s.metric.getType() { + case typeGauge: + err = binary.Write(w, binary.LittleEndian, s.metric.getValue().(float64)) + if err != nil { + return xerrors.Errorf("binary.Write gauge value failed: %w", err) + } + case typeIGauge: + err = binary.Write(w, binary.LittleEndian, s.metric.getValue().(int64)) + if err != nil { + return xerrors.Errorf("binary.Write igauge value failed: %w", err) + } + case typeCounter, typeRated: + err = binary.Write(w, binary.LittleEndian, uint64(s.metric.getValue().(int64))) + if err != nil { + return xerrors.Errorf("binary.Write counter value failed: %w", err) + } + case typeHistogram, typeRatedHistogram: + h := s.metric.getValue().(histogram) + err = h.writeHistogram(w) + if err != nil { + return xerrors.Errorf("writeHistogram failed: %w", err) + } + default: + return xerrors.Errorf("unknown metric type: %v", s.metric.getType()) + } + return nil +} + +type SpackOpts func(*spackEncoder) + +func WithVersion12() func(*spackEncoder) { + return func(se *spackEncoder) { + se.version = version12 + } +} + +type spackEncoder struct { + context context.Context + compression uint8 + version spackVersion + + nameCounter uint32 + valueCounter uint32 + + labelNamePool bytes.Buffer + labelValuePool bytes.Buffer + + metrics Metrics +} + +func NewSpackEncoder(ctx context.Context, compression CompressionType, metrics *Metrics, opts ...SpackOpts) *spackEncoder { + if metrics == nil { + metrics = &Metrics{} + } + se := &spackEncoder{ + context: ctx, + compression: uint8(compression), + version: version11, + metrics: *metrics, + } + for _, op := range opts { + op(se) + } + return se +} + +func (se *spackEncoder) writeLabels() ([]spackMetric, error) { + namesIdx := make(map[string]uint32) + valuesIdx := make(map[string]uint32) + spackMetrics := make([]spackMetric, len(se.metrics.metrics)) + + for idx, metric := range se.metrics.metrics { + m := spackMetric{metric: metric} + + var err error + if se.version >= version12 { + err = m.writeLabelPool(se, namesIdx, valuesIdx, metric.getNameTag(), metric.Name()) + m.nameValueIndex = valuesIdx[metric.getNameTag()] + } else { + err = m.writeLabel(se, namesIdx, valuesIdx, metric.getNameTag(), metric.Name()) + } + if err != nil { + return nil, err + } + + for name, value := range metric.getLabels() { + if err := m.writeLabel(se, namesIdx, valuesIdx, name, value); err != nil { + return nil, err + } + + } + spackMetrics[idx] = m + } + + return spackMetrics, nil +} + +func (se *spackEncoder) Encode(w io.Writer) (written int, err error) { + spackMetrics, err := se.writeLabels() + if err != nil { + return written, xerrors.Errorf("writeLabels failed: %w", err) + } + + err = se.writeHeader(w) + if err != nil { + return written, xerrors.Errorf("writeHeader failed: %w", err) + } + written += HeaderSize + compression := CompressionType(se.compression) + + cw := newCompressedWriter(w, compression) + + err = se.writeLabelNamesPool(cw) + if err != nil { + return written, xerrors.Errorf("writeLabelNamesPool failed: %w", err) + } + + err = se.writeLabelValuesPool(cw) + if err != nil { + return written, xerrors.Errorf("writeLabelValuesPool failed: %w", err) + } + + err = se.writeCommonTime(cw) + if err != nil { + return written, xerrors.Errorf("writeCommonTime failed: %w", err) + } + + err = se.writeCommonLabels(cw) + if err != nil { + return written, xerrors.Errorf("writeCommonLabels failed: %w", err) + } + + err = se.writeMetricsData(cw, spackMetrics) + if err != nil { + return written, xerrors.Errorf("writeMetricsData failed: %w", err) + } + + err = cw.Close() + if err != nil { + return written, xerrors.Errorf("close failed: %w", err) + } + + switch compression { + case CompressionNone: + written += cw.(*noCompressionWriteCloser).written + case CompressionLz4: + written += cw.(*lz4CompressionWriteCloser).written + } + + return written, nil +} + +func (se *spackEncoder) writeHeader(w io.Writer) error { + if se.context.Err() != nil { + return xerrors.Errorf("streamSpack context error: %w", se.context.Err()) + } + ew := &errWriter{w: w} + ew.binaryWrite(uint16(0x5053)) // Magic + ew.binaryWrite(uint16(se.version)) // Version + ew.binaryWrite(uint16(24)) // HeaderSize + ew.binaryWrite(uint8(0)) // TimePrecision(SECONDS) + ew.binaryWrite(uint8(se.compression)) // CompressionAlg + ew.binaryWrite(uint32(se.labelNamePool.Len())) // LabelNamesSize + ew.binaryWrite(uint32(se.labelValuePool.Len())) // LabelValuesSize + ew.binaryWrite(uint32(len(se.metrics.metrics))) // MetricsCount + ew.binaryWrite(uint32(len(se.metrics.metrics))) // PointsCount + if ew.err != nil { + return xerrors.Errorf("binaryWrite failed: %w", ew.err) + } + return nil +} + +func (se *spackEncoder) writeLabelNamesPool(w io.Writer) error { + if se.context.Err() != nil { + return xerrors.Errorf("streamSpack context error: %w", se.context.Err()) + } + _, err := w.Write(se.labelNamePool.Bytes()) + if err != nil { + return xerrors.Errorf("write labelNamePool failed: %w", err) + } + return nil +} + +func (se *spackEncoder) writeLabelValuesPool(w io.Writer) error { + if se.context.Err() != nil { + return xerrors.Errorf("streamSpack context error: %w", se.context.Err()) + } + + _, err := w.Write(se.labelValuePool.Bytes()) + if err != nil { + return xerrors.Errorf("write labelValuePool failed: %w", err) + } + return nil +} + +func (se *spackEncoder) writeCommonTime(w io.Writer) error { + if se.context.Err() != nil { + return xerrors.Errorf("streamSpack context error: %w", se.context.Err()) + } + + if se.metrics.timestamp == nil { + return binary.Write(w, binary.LittleEndian, uint32(0)) + } + return binary.Write(w, binary.LittleEndian, uint32(se.metrics.timestamp.Unix())) +} + +func (se *spackEncoder) writeCommonLabels(w io.Writer) error { + if se.context.Err() != nil { + return xerrors.Errorf("streamSpack context error: %w", se.context.Err()) + } + + _, err := w.Write([]byte{0}) + if err != nil { + return xerrors.Errorf("write commonLabels failed: %w", err) + } + return nil +} + +func (se *spackEncoder) writeMetricsData(w io.Writer, metrics []spackMetric) error { + for _, s := range metrics { + if se.context.Err() != nil { + return xerrors.Errorf("streamSpack context error: %w", se.context.Err()) + } + + err := s.writeMetric(w, se.version) + if err != nil { + return xerrors.Errorf("write metric failed: %w", err) + } + } + return nil +} diff --git a/library/go/core/metrics/solomon/spack_compression.go b/library/go/core/metrics/solomon/spack_compression.go new file mode 100644 index 0000000000..004fe0150d --- /dev/null +++ b/library/go/core/metrics/solomon/spack_compression.go @@ -0,0 +1,162 @@ +package solomon + +import ( + "encoding/binary" + "io" + + "github.com/OneOfOne/xxhash" + "github.com/pierrec/lz4" +) + +type CompressionType uint8 + +const ( + CompressionNone CompressionType = 0x0 + CompressionZlib CompressionType = 0x1 + CompressionZstd CompressionType = 0x2 + CompressionLz4 CompressionType = 0x3 +) + +const ( + compressionFrameLength = 512 * 1024 + hashTableSize = 64 * 1024 +) + +type noCompressionWriteCloser struct { + underlying io.Writer + written int +} + +func (w *noCompressionWriteCloser) Write(p []byte) (int, error) { + n, err := w.underlying.Write(p) + w.written += n + return n, err +} + +func (w *noCompressionWriteCloser) Close() error { + return nil +} + +type lz4CompressionWriteCloser struct { + underlying io.Writer + buffer []byte + table []int + written int +} + +func (w *lz4CompressionWriteCloser) flushFrame() (written int, err error) { + src := w.buffer + dst := make([]byte, lz4.CompressBlockBound(len(src))) + + sz, err := lz4.CompressBlock(src, dst, w.table) + if err != nil { + return written, err + } + + if sz == 0 { + dst = src + } else { + dst = dst[:sz] + } + + err = binary.Write(w.underlying, binary.LittleEndian, uint32(len(dst))) + if err != nil { + return written, err + } + w.written += 4 + + err = binary.Write(w.underlying, binary.LittleEndian, uint32(len(src))) + if err != nil { + return written, err + } + w.written += 4 + + n, err := w.underlying.Write(dst) + if err != nil { + return written, err + } + w.written += n + + checksum := xxhash.Checksum32S(dst, 0x1337c0de) + err = binary.Write(w.underlying, binary.LittleEndian, checksum) + if err != nil { + return written, err + } + w.written += 4 + + w.buffer = w.buffer[:0] + + return written, nil +} + +func (w *lz4CompressionWriteCloser) Write(p []byte) (written int, err error) { + q := p[:] + for len(q) > 0 { + space := compressionFrameLength - len(w.buffer) + if space == 0 { + n, err := w.flushFrame() + if err != nil { + return written, err + } + w.written += n + space = compressionFrameLength + } + length := len(q) + if length > space { + length = space + } + w.buffer = append(w.buffer, q[:length]...) + q = q[length:] + } + return written, nil +} + +func (w *lz4CompressionWriteCloser) Close() error { + var err error + if len(w.buffer) > 0 { + n, err := w.flushFrame() + if err != nil { + return err + } + w.written += n + } + err = binary.Write(w.underlying, binary.LittleEndian, uint32(0)) + if err != nil { + return nil + } + w.written += 4 + + err = binary.Write(w.underlying, binary.LittleEndian, uint32(0)) + if err != nil { + return nil + } + w.written += 4 + + err = binary.Write(w.underlying, binary.LittleEndian, uint32(0)) + if err != nil { + return nil + } + w.written += 4 + + return nil +} + +func newCompressedWriter(w io.Writer, compression CompressionType) io.WriteCloser { + switch compression { + case CompressionNone: + return &noCompressionWriteCloser{w, 0} + case CompressionZlib: + panic("zlib compression not supported") + case CompressionZstd: + panic("zstd compression not supported") + case CompressionLz4: + return &lz4CompressionWriteCloser{ + w, + make([]byte, 0, compressionFrameLength), + make([]int, hashTableSize), + 0, + } + default: + panic("unsupported compression algorithm") + } +} diff --git a/library/go/core/metrics/solomon/spack_compression_test.go b/library/go/core/metrics/solomon/spack_compression_test.go new file mode 100644 index 0000000000..baa8a8d1e9 --- /dev/null +++ b/library/go/core/metrics/solomon/spack_compression_test.go @@ -0,0 +1,26 @@ +package solomon + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func compress(t *testing.T, c uint8, s string) []byte { + buf := bytes.Buffer{} + w := newCompressedWriter(&buf, CompressionType(c)) + _, err := w.Write([]byte(s)) + assert.Equal(t, nil, err) + assert.Equal(t, nil, w.Close()) + return buf.Bytes() +} + +func TestCompression_None(t *testing.T) { + assert.Equal(t, []byte(nil), compress(t, uint8(CompressionNone), "")) + assert.Equal(t, []byte{'a'}, compress(t, uint8(CompressionNone), "a")) +} + +func TestCompression_Lz4(t *testing.T) { + assert.Equal(t, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, compress(t, uint8(CompressionLz4), "")) +} diff --git a/library/go/core/metrics/solomon/spack_test.go b/library/go/core/metrics/solomon/spack_test.go new file mode 100644 index 0000000000..64b504bf42 --- /dev/null +++ b/library/go/core/metrics/solomon/spack_test.go @@ -0,0 +1,184 @@ +package solomon + +import ( + "bytes" + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_metrics_encode(t *testing.T) { + expectHeader := []byte{ + 0x53, 0x50, // magic + 0x01, 0x01, // version + 0x18, 0x00, // header size + 0x0, // time precision + 0x0, // compression algorithm + 0x7, 0x0, 0x0, 0x0, // label names size + 0x8, 0x0, 0x0, 0x0, // label values size + 0x1, 0x0, 0x0, 0x0, // metric count + 0x1, 0x0, 0x0, 0x0, // point count + // label names pool + 0x73, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x0, // "sensor" + // label values pool + 0x6d, 0x79, 0x67, 0x61, 0x75, 0x67, 0x65, 0x0, // "gauge" + } + + testCases := []struct { + name string + metrics *Metrics + expectCommonTime []byte + expectCommonLabels []byte + expectMetrics [][]byte + expectWritten int + }{ + { + "common-ts+gauge", + &Metrics{ + metrics: []Metric{ + func() Metric { + g := NewGauge("mygauge", 43) + return &g + }(), + }, + timestamp: timeAsRef(time.Unix(1500000000, 0)), + }, + []byte{0x0, 0x2f, 0x68, 0x59}, // common time /1500000000 + []byte{0x0}, // common labels count and indexes + [][]byte{ + { + 0x5, // types + 0x0, // flags + 0x1, // labels index size + 0x0, // indexes of name labels + 0x0, // indexes of value labels + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x45, 0x40, // 43 // metrics value + + }, + }, + 57, + }, + { + "gauge+ts", + &Metrics{ + metrics: []Metric{ + func() Metric { + g := NewGauge("mygauge", 43, WithTimestamp(time.Unix(1657710476, 0))) + return &g + }(), + }, + }, + []byte{0x0, 0x0, 0x0, 0x0}, // common time + []byte{0x0}, // common labels count and indexes + [][]byte{ + { + 0x6, // uint8(typeGauge << 2) | uint8(valueTypeOneWithTS) + 0x0, // flags + 0x1, // labels index size + 0x0, // indexes of name labels + 0x0, // indexes of value labels + + 0x8c, 0xa7, 0xce, 0x62, //metric ts + 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x45, 0x40, // 43 // metrics value + + }, + }, + 61, + }, + { + "common-ts+gauge+ts", + &Metrics{ + metrics: []Metric{ + func() Metric { + g := NewGauge("mygauge", 43, WithTimestamp(time.Unix(1657710476, 0))) + return &g + }(), + func() Metric { + g := NewGauge("mygauge", 42, WithTimestamp(time.Unix(1500000000, 0))) + return &g + }(), + }, + timestamp: timeAsRef(time.Unix(1500000000, 0)), + }, + []byte{0x0, 0x2f, 0x68, 0x59}, // common time /1500000000 + []byte{0x0}, // common labels count and indexes + [][]byte{ + { + 0x6, // types + 0x0, // flags + 0x1, // labels index size + 0x0, // indexes of name labels + 0x0, // indexes of value labels + + 0x8c, 0xa7, 0xce, 0x62, //metric ts + 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x45, 0x40, // 43 // metrics value + + }, + { + 0x6, // types + 0x0, // flags + 0x1, // labels index size + 0x0, // indexes of name labels + 0x0, // indexes of value labels + + 0x0, 0x2f, 0x68, 0x59, // metric ts + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x45, 0x40, //42 // metrics value + + }, + }, + 78, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + ctx := context.Background() + + written, err := NewSpackEncoder(ctx, CompressionNone, tc.metrics).Encode(&buf) + + assert.NoError(t, err) + assert.Equal(t, tc.expectWritten, written) + + body := buf.Bytes() + setMetricsCount(expectHeader, len(tc.metrics.metrics)) + + require.True(t, bytes.HasPrefix(body, expectHeader)) + body = body[len(expectHeader):] + + require.True(t, bytes.HasPrefix(body, tc.expectCommonTime)) + body = body[len(tc.expectCommonTime):] + + require.True(t, bytes.HasPrefix(body, tc.expectCommonLabels)) + body = body[len(tc.expectCommonLabels):] + + expectButMissing := [][]byte{} + for range tc.expectMetrics { + var seen bool + var val []byte + for _, v := range tc.expectMetrics { + val = v + if bytes.HasPrefix(body, v) { + body = bytes.Replace(body, v, []byte{}, 1) + seen = true + break + } + } + if !seen { + expectButMissing = append(expectButMissing, val) + } + } + assert.Empty(t, body, "unexpected bytes seen") + assert.Empty(t, expectButMissing, "missing metrics bytes") + }) + } +} + +func setMetricsCount(header []byte, count int) { + header[16] = uint8(count) + header[20] = uint8(count) +} diff --git a/library/go/core/metrics/solomon/stream.go b/library/go/core/metrics/solomon/stream.go new file mode 100644 index 0000000000..7cf6d70064 --- /dev/null +++ b/library/go/core/metrics/solomon/stream.go @@ -0,0 +1,89 @@ +package solomon + +import ( + "context" + "encoding/json" + "io" + + "github.com/ydb-platform/ydb/library/go/core/xerrors" +) + +const HeaderSize = 24 + +type StreamFormat string + +func (r *Registry) StreamJSON(ctx context.Context, w io.Writer) (written int, err error) { + cw := newCompressedWriter(w, CompressionNone) + + if ctx.Err() != nil { + return written, xerrors.Errorf("streamJSON context error: %w", ctx.Err()) + } + _, err = cw.Write([]byte("{\"metrics\":[")) + if err != nil { + return written, xerrors.Errorf("write metrics failed: %w", err) + } + + first := true + r.metrics.Range(func(_, s interface{}) bool { + if ctx.Err() != nil { + err = xerrors.Errorf("streamJSON context error: %w", ctx.Err()) + return false + } + + // write trailing comma + if !first { + _, err = cw.Write([]byte(",")) + if err != nil { + err = xerrors.Errorf("write metrics failed: %w", err) + return false + } + } + + var b []byte + + b, err = json.Marshal(s) + if err != nil { + err = xerrors.Errorf("marshal metric failed: %w", err) + return false + } + + // write metric json + _, err = cw.Write(b) + if err != nil { + err = xerrors.Errorf("write metric failed: %w", err) + return false + } + + first = false + return true + }) + if err != nil { + return written, err + } + + if ctx.Err() != nil { + return written, xerrors.Errorf("streamJSON context error: %w", ctx.Err()) + } + _, err = cw.Write([]byte("]}")) + if err != nil { + return written, xerrors.Errorf("write metrics failed: %w", err) + } + + if ctx.Err() != nil { + return written, xerrors.Errorf("streamJSON context error: %w", ctx.Err()) + } + err = cw.Close() + if err != nil { + return written, xerrors.Errorf("close failed: %w", err) + } + + return cw.(*noCompressionWriteCloser).written, nil +} + +func (r *Registry) StreamSpack(ctx context.Context, w io.Writer, compression CompressionType) (int, error) { + metrics, err := r.Gather() + if err != nil { + return 0, err + } + return NewSpackEncoder(ctx, compression, metrics).Encode(w) +} diff --git a/library/go/core/metrics/solomon/stream_test.go b/library/go/core/metrics/solomon/stream_test.go new file mode 100644 index 0000000000..7548f77dbb --- /dev/null +++ b/library/go/core/metrics/solomon/stream_test.go @@ -0,0 +1,595 @@ +package solomon + +import ( + "bytes" + "context" + "encoding/json" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/core/metrics" +) + +func Test_streamJson(t *testing.T) { + testCases := []struct { + name string + registry *Registry + expect string + expectWritten int + expectErr error + }{ + { + "success", + func() *Registry { + r := NewRegistry(NewRegistryOpts()) + + cnt := r.Counter("mycounter") + cnt.Add(42) + + gg := r.Gauge("mygauge") + gg.Set(2) + + return r + }(), + `{"metrics":[{"type":"COUNTER","labels":{"sensor":"mycounter"},"value":42},{"type":"DGAUGE","labels":{"sensor":"mygauge"},"value":2}]}`, + 133, + nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + w := httptest.NewRecorder() + ctx := context.Background() + + written, err := tc.registry.StreamJSON(ctx, w) + + if tc.expectErr == nil { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, tc.expectErr.Error()) + } + + assert.Equal(t, tc.expectWritten, written) + assert.Equal(t, len(tc.expect), w.Body.Len()) + + if tc.expect != "" { + var expectedObj, givenObj map[string]interface{} + err = json.Unmarshal([]byte(tc.expect), &expectedObj) + assert.NoError(t, err) + err = json.Unmarshal(w.Body.Bytes(), &givenObj) + assert.NoError(t, err) + + sameMap(t, expectedObj, givenObj) + } + }) + } +} + +func Test_streamSpack(t *testing.T) { + testCases := []struct { + name string + registry *Registry + compression CompressionType + expectHeader []byte + expectLabelNamesPool [][]byte + expectValueNamesPool [][]byte + expectCommonTime []byte + expectCommonLabels []byte + expectMetrics [][]byte + expectWritten int + }{ + { + "counter", + func() *Registry { + r := NewRegistry(NewRegistryOpts()) + + cnt := r.Counter("counter") + cnt.Add(42) + + return r + }(), + CompressionNone, + []byte{ + 0x53, 0x50, // magic + 0x01, 0x01, // version + 0x18, 0x00, // header size + 0x0, // time precision + 0x0, // compression algorithm + 0x7, 0x0, 0x0, 0x0, // label names size + 0x8, 0x0, 0x0, 0x0, // label values size + 0x1, 0x0, 0x0, 0x0, // metric count + 0x1, 0x0, 0x0, 0x0, // point count + }, + [][]byte{ // label names pool + {0x73, 0x65, 0x6e, 0x73, 0x6f, 0x72}, // "sensor" + }, + [][]byte{ // label values pool + {0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72}, // "counter" + }, + []byte{0x0, 0x0, 0x0, 0x0}, // common time + []byte{0x0}, // common labels count and indexes + [][]byte{ + { + 0x9, // types + 0x0, // flags + 0x1, // labels index size + 0x0, // indexes of name labels + 0x0, // indexes of value labels + + 0x2a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, // 42 // metrics value + }, + }, + 57, + }, + { + "counter_lz4", + func() *Registry { + r := NewRegistry(NewRegistryOpts()) + + cnt := r.Counter("counter") + cnt.Add(0) + + return r + }(), + CompressionLz4, + []byte{ + 0x53, 0x50, // magic + 0x01, 0x01, // version + 0x18, 0x00, // header size + 0x0, // time precision + 0x3, // compression algorithm + 0x7, 0x0, 0x0, 0x0, // label names size + 0x8, 0x0, 0x0, 0x0, // label values size + 0x1, 0x0, 0x0, 0x0, // metric count + 0x1, 0x0, 0x0, 0x0, // point count + 0x23, 0x00, 0x00, 0x00, // compressed length + 0x21, 0x00, 0x00, 0x00, // uncompressed length + 0xf0, 0x12, + }, + [][]byte{ // label names pool + {0x73, 0x65, 0x6e, 0x73, 0x6f, 0x72}, // "sensor" + }, + [][]byte{ // label values pool + {0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72}, // "counter" + }, + []byte{0x0, 0x0, 0x0, 0x0}, // common time + []byte{0x0}, // common labels count and indexes + [][]byte{ + { + 0x9, // types + 0x0, // flags + 0x1, // labels index size + 0x0, // indexes of name labels + 0x0, // indexes of value labels + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 0 //metrics value + 0x10, 0x11, 0xa4, 0x22, // checksum + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // end stream + }, + }, + 83, + }, + { + "rate", + func() *Registry { + r := NewRegistry(NewRegistryOpts()) + + cnt := r.Counter("counter") + Rated(cnt) + cnt.Add(0) + + return r + }(), + CompressionNone, + []byte{ + 0x53, 0x50, // magic + 0x01, 0x01, // version + 0x18, 0x00, // header size + 0x0, // time precision + 0x0, // compression algorithm + 0x7, 0x0, 0x0, 0x0, // label names size + 0x8, 0x0, 0x0, 0x0, // label values size + 0x1, 0x0, 0x0, 0x0, // metric count + 0x1, 0x0, 0x0, 0x0, // point count + }, + [][]byte{ // label names pool + {0x73, 0x65, 0x6e, 0x73, 0x6f, 0x72}, // "sensor" + }, + [][]byte{ // label values pool + {0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72}, // "counter" + }, + []byte{0x0, 0x0, 0x0, 0x0}, // common time + []byte{0x0}, // common labels count and indexes + [][]byte{ + { + 0xd, // types + 0x0, // flags + 0x1, // labels index size + 0x0, // indexes of name labels + 0x0, // indexes of value labels + + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, //42 // metrics value + }, + }, + 57, + }, + { + "timer", + func() *Registry { + r := NewRegistry(NewRegistryOpts()) + + t := r.Timer("timer") + t.RecordDuration(2 * time.Second) + + return r + }(), + CompressionNone, + []byte{ + 0x53, 0x50, // magic + 0x01, 0x01, // version + 0x18, 0x00, // header size + 0x0, // time precision + 0x0, // compression algorithm + 0x7, 0x0, 0x0, 0x0, // label names size + 0x6, 0x0, 0x0, 0x0, // label values size + 0x1, 0x0, 0x0, 0x0, // metric count + 0x1, 0x0, 0x0, 0x0, // point count + }, + [][]byte{ // label names pool + {0x73, 0x65, 0x6e, 0x73, 0x6f, 0x72}, // "sensor" + }, + [][]byte{ // label values pool + {0x74, 0x69, 0x6d, 0x65, 0x72}, // "timer" + }, + []byte{0x0, 0x0, 0x0, 0x0}, // common time + []byte{0x0}, // common labels count and indexes + [][]byte{ + { + 0x5, // types + 0x0, // flags + 0x1, // labels index size + 0x0, // indexes of name labels + 0x0, // indexes of value labels + + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x40, //2.0 // metrics value + }, + }, + 55, + }, + { + "gauge", + func() *Registry { + r := NewRegistry(NewRegistryOpts()) + + g := r.Gauge("gauge") + g.Set(42) + + return r + }(), + CompressionNone, + []byte{ + 0x53, 0x50, // magic + 0x01, 0x01, // version + 0x18, 0x00, // header size + 0x0, // time precision + 0x0, // compression algorithm + 0x7, 0x0, 0x0, 0x0, // label names size + 0x6, 0x0, 0x0, 0x0, // label values size + 0x1, 0x0, 0x0, 0x0, // metric count + 0x1, 0x0, 0x0, 0x0, // point count + }, + [][]byte{ // label names pool + {0x73, 0x65, 0x6e, 0x73, 0x6f, 0x72}, // "sensor" + }, + [][]byte{ // label values pool + {0x67, 0x61, 0x75, 0x67, 0x65}, // "gauge" + }, + []byte{0x0, 0x0, 0x0, 0x0}, // common time + []byte{0x0}, // common labels count and indexes + [][]byte{ + { + 0x5, // types + 0x0, // flags + 0x1, // labels index size + 0x0, // indexes of name labels + 0x0, // indexes of value labels + + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x45, 0x40, //42 // metrics value + + }, + }, + 55, + }, + { + "histogram", + func() *Registry { + r := NewRegistry(NewRegistryOpts()) + + _ = r.Histogram("histogram", metrics.NewBuckets(0, 0.1, 0.11)) + + return r + }(), + CompressionNone, + []byte{ + 0x53, 0x50, // magic + 0x01, 0x01, // version + 0x18, 0x00, // header size + 0x0, // time precision + 0x0, // compression algorithm + 0x7, 0x0, 0x0, 0x0, // label names size + 0xa, 0x0, 0x0, 0x0, // label values size + 0x1, 0x0, 0x0, 0x0, // metric count + 0x1, 0x0, 0x0, 0x0, // point count + }, + [][]byte{ // label names pool + {0x73, 0x65, 0x6e, 0x73, 0x6f, 0x72}, // "sensor" + }, + [][]byte{ // label values pool + {0x68, 0x69, 0x73, 0x74, 0x6F, 0x67, 0x72, 0x61, 0x6D}, // "histogram" + }, + []byte{0x0, 0x0, 0x0, 0x0}, // common time + []byte{0x0}, // common labels count and indexes + [][]byte{ + { + /*types*/ 0x15, + /*flags*/ 0x0, + /*labels*/ 0x1, // ? + /*name*/ 0x0, + /*value*/ 0x0, + /*buckets count*/ 0x3, + /*upper bound 0*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /*upper bound 1*/ 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xb9, 0x3f, + /*upper bound 2*/ 0x29, 0x5c, 0x8f, 0xc2, 0xf5, 0x28, 0xbc, 0x3f, + /*counter 0*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /*counter 1*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /*counter 2*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + }, + }, + 100, + }, + { + "rate_histogram", + func() *Registry { + r := NewRegistry(NewRegistryOpts()) + + h := r.Histogram("histogram", metrics.NewBuckets(0, 0.1, 0.11)) + Rated(h) + + return r + }(), + CompressionNone, + []byte{ + 0x53, 0x50, // magic + 0x01, 0x01, // version + 0x18, 0x00, // header size + 0x0, // time precision + 0x0, // compression algorithm + 0x7, 0x0, 0x0, 0x0, // label names size + 0xa, 0x0, 0x0, 0x0, // label values size + 0x1, 0x0, 0x0, 0x0, // metric count + 0x1, 0x0, 0x0, 0x0, // point count + }, + [][]byte{ // label names pool + {0x73, 0x65, 0x6e, 0x73, 0x6f, 0x72}, // "sensor" + }, + [][]byte{ // label values pool + {0x68, 0x69, 0x73, 0x74, 0x6F, 0x67, 0x72, 0x61, 0x6D}, // "histogram" + }, + []byte{0x0, 0x0, 0x0, 0x0}, // common time + []byte{0x0}, // common labels count and indexes + [][]byte{ + { + /*types*/ 0x19, + /*flags*/ 0x0, + /*labels*/ 0x1, // ? + /*name*/ 0x0, + /*value*/ 0x0, + /*buckets count*/ 0x3, + /*upper bound 0*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /*upper bound 1*/ 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xb9, 0x3f, + /*upper bound 2*/ 0x29, 0x5c, 0x8f, 0xc2, 0xf5, 0x28, 0xbc, 0x3f, + /*counter 0*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /*counter 1*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /*counter 2*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + }, + }, + 100, + }, + { + "counter+timer", + func() *Registry { + r := NewRegistry(NewRegistryOpts()) + + cnt := r.Counter("counter") + cnt.Add(42) + + t := r.Timer("timer") + t.RecordDuration(2 * time.Second) + + return r + }(), + CompressionNone, + []byte{ + 0x53, 0x50, // magic + 0x01, 0x01, // version + 0x18, 0x00, // header size + 0x0, // time precision + 0x0, // compression algorithm + 0x7, 0x0, 0x0, 0x0, // label names size + 0xe, 0x0, 0x0, 0x0, // label values size + 0x2, 0x0, 0x0, 0x0, // metric count + 0x2, 0x0, 0x0, 0x0, // point count + }, + [][]byte{ // label names pool + {0x73, 0x65, 0x6e, 0x73, 0x6f, 0x72}, // "sensor" + }, + [][]byte{ // label values pool + {0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72}, // "counter" + {0x74, 0x69, 0x6d, 0x65, 0x72}, // "timer" + }, + []byte{0x0, 0x0, 0x0, 0x0}, // common time + []byte{0x0}, // common labels count and indexes + [][]byte{ + { + /*types*/ 0x9, + /*flags*/ 0x0, + /*labels*/ 0x1, // ? + /*name*/ 0x0, + /*value*/ 0x0, + /*metrics value*/ 0x2a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, //42 + }, + { + /*types*/ 0x5, + /*flags*/ 0x0, + /*labels*/ 0x1, // ? + /*name*/ 0x0, + /*value*/ 0x1, + /*metrics value*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x40, //2.0 + + }, + }, + 76, + }, + { + "gauge+histogram", + func() *Registry { + r := NewRegistry(NewRegistryOpts()) + + g := r.Gauge("gauge") + g.Set(42) + + _ = r.Histogram("histogram", metrics.NewBuckets(0, 0.1, 0.11)) + + return r + }(), + CompressionNone, + []byte{ + 0x53, 0x50, // magic + 0x01, 0x01, // version + 0x18, 0x00, // header size + 0x0, // time precision + 0x0, // compression algorithm + 0x7, 0x0, 0x0, 0x0, // label names size + 0x10, 0x0, 0x0, 0x0, // label values size + 0x2, 0x0, 0x0, 0x0, // metric count + 0x2, 0x0, 0x0, 0x0, // point count + }, + [][]byte{ // label names pool + {0x73, 0x65, 0x6e, 0x73, 0x6f, 0x72}, // "sensor" + }, + [][]byte{ // label values pool + {0x67, 0x61, 0x75, 0x67, 0x65}, // "gauge" + {0x68, 0x69, 0x73, 0x74, 0x6F, 0x67, 0x72, 0x61, 0x6D}, // "histogram" + }, + []byte{0x0, 0x0, 0x0, 0x0}, // common time + []byte{0x0}, // common labels count and indexes + [][]byte{ + { + + /*types*/ 0x5, + /*flags*/ 0x0, + /*labels*/ 0x1, // ? + /*name*/ 0x0, + /*value*/ 0x0, + /*metrics value*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x45, 0x40, //42 + }, + { + /*types*/ 0x15, + /*flags*/ 0x0, + /*labels*/ 0x1, // ? + /*name*/ 0x0, + /*value*/ 0x1, + /*buckets count*/ 0x3, + /*upper bound 0*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /*upper bound 1*/ 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xb9, 0x3f, + /*upper bound 2*/ 0x29, 0x5c, 0x8f, 0xc2, 0xf5, 0x28, 0xbc, 0x3f, + /*counter 0*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /*counter 1*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /*counter 2*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + }, + }, + 119, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + w := httptest.NewRecorder() + ctx := context.Background() + + written, err := tc.registry.StreamSpack(ctx, w, tc.compression) + + assert.NoError(t, err) + assert.Equal(t, tc.expectWritten, written) + body := w.Body.Bytes() + require.True(t, bytes.HasPrefix(body, tc.expectHeader)) + body = body[len(tc.expectHeader):] + + t.Logf("expectLabelNamesPool: %v", tc.expectLabelNamesPool) + labelNamesPoolBytes := body[:len(bytes.Join(tc.expectLabelNamesPool, []byte{0x0}))+1] + labelNamesPool := bytes.Split(bytes.Trim(labelNamesPoolBytes, "\x00"), []byte{0x0}) + require.ElementsMatch(t, tc.expectLabelNamesPool, labelNamesPool) + body = body[len(labelNamesPoolBytes):] + + t.Logf("expectValueNamesPool: %v", tc.expectValueNamesPool) + valueNamesPoolBytes := body[:len(bytes.Join(tc.expectValueNamesPool, []byte{0x0}))+1] + valueNamesPool := bytes.Split(bytes.Trim(valueNamesPoolBytes, "\x00"), []byte{0x0}) + require.ElementsMatch(t, tc.expectValueNamesPool, valueNamesPool) + body = body[len(valueNamesPoolBytes):] + + require.True(t, bytes.HasPrefix(body, tc.expectCommonTime)) + body = body[len(tc.expectCommonTime):] + + require.True(t, bytes.HasPrefix(body, tc.expectCommonLabels)) + body = body[len(tc.expectCommonLabels):] + + expectButMissing := [][]byte{} + for idx := range tc.expectMetrics { + var seen bool + var val []byte + for _, v := range tc.expectMetrics { + val = v[:] + fixValueNameIndex(idx, val) + if bytes.HasPrefix(body, val) { + body = bytes.Replace(body, val, []byte{}, 1) + seen = true + break + } + } + if !seen { + expectButMissing = append(expectButMissing, val) + } + } + assert.Empty(t, body, "unexpected bytes seen") + assert.Empty(t, expectButMissing, "missing metrics bytes") + }) + } +} + +func fixValueNameIndex(idx int, metric []byte) { + // ASSUMPTION_FOR_TESTS: the size of the index is always equal to one + // That is, the number of points in the metric is always one + metric[4] = uint8(idx) // fix value name index +} + +func sameMap(t *testing.T, expected, actual map[string]interface{}) bool { + if !assert.Len(t, actual, len(expected)) { + return false + } + + for k := range expected { + actualMetric := actual[k] + if !assert.NotNil(t, actualMetric, "expected key %q not found", k) { + return false + } + + if !assert.ElementsMatch(t, expected[k], actualMetric, "%q must have same elements", k) { + return false + } + } + return true +} diff --git a/library/go/core/metrics/solomon/timer.go b/library/go/core/metrics/solomon/timer.go new file mode 100644 index 0000000000..d36940a9f7 --- /dev/null +++ b/library/go/core/metrics/solomon/timer.go @@ -0,0 +1,91 @@ +package solomon + +import ( + "encoding/json" + "time" + + "github.com/ydb-platform/ydb/library/go/core/metrics" + "go.uber.org/atomic" +) + +var ( + _ metrics.Timer = (*Timer)(nil) + _ Metric = (*Timer)(nil) +) + +// Timer measures gauge duration. +type Timer struct { + name string + metricType metricType + tags map[string]string + value atomic.Duration + timestamp *time.Time + + useNameTag bool +} + +func (t *Timer) RecordDuration(value time.Duration) { + t.value.Store(value) +} + +func (t *Timer) Name() string { + return t.name +} + +func (t *Timer) getType() metricType { + return t.metricType +} + +func (t *Timer) getLabels() map[string]string { + return t.tags +} + +func (t *Timer) getValue() interface{} { + return t.value.Load().Seconds() +} + +func (t *Timer) getTimestamp() *time.Time { + return t.timestamp +} + +func (t *Timer) getNameTag() string { + if t.useNameTag { + return "name" + } else { + return "sensor" + } +} + +// MarshalJSON implements json.Marshaler. +func (t *Timer) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Value float64 `json:"value"` + Timestamp *int64 `json:"ts,omitempty"` + }{ + Type: t.metricType.String(), + Value: t.value.Load().Seconds(), + Labels: func() map[string]string { + labels := make(map[string]string, len(t.tags)+1) + labels[t.getNameTag()] = t.Name() + for k, v := range t.tags { + labels[k] = v + } + return labels + }(), + Timestamp: tsAsRef(t.timestamp), + }) +} + +// Snapshot returns independent copy on metric. +func (t *Timer) Snapshot() Metric { + return &Timer{ + name: t.name, + metricType: t.metricType, + tags: t.tags, + value: *atomic.NewDuration(t.value.Load()), + + useNameTag: t.useNameTag, + } +} diff --git a/library/go/core/metrics/solomon/timer_test.go b/library/go/core/metrics/solomon/timer_test.go new file mode 100644 index 0000000000..4904815701 --- /dev/null +++ b/library/go/core/metrics/solomon/timer_test.go @@ -0,0 +1,56 @@ +package solomon + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestTimer_RecordDuration(t *testing.T) { + c := &Timer{ + name: "mytimer", + metricType: typeGauge, + tags: map[string]string{"ololo": "trololo"}, + } + + c.RecordDuration(1 * time.Second) + assert.Equal(t, 1*time.Second, c.value.Load()) + + c.RecordDuration(42 * time.Millisecond) + assert.Equal(t, 42*time.Millisecond, c.value.Load()) +} + +func TestTimerRated_MarshalJSON(t *testing.T) { + c := &Timer{ + name: "mytimer", + metricType: typeRated, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewDuration(42 * time.Millisecond), + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"RATE","labels":{"ololo":"trololo","sensor":"mytimer"},"value":0.042}`) + assert.Equal(t, expected, b) +} + +func TestNameTagTimer_MarshalJSON(t *testing.T) { + c := &Timer{ + name: "mytimer", + metricType: typeRated, + tags: map[string]string{"ololo": "trololo"}, + value: *atomic.NewDuration(42 * time.Millisecond), + + useNameTag: true, + } + + b, err := json.Marshal(c) + assert.NoError(t, err) + + expected := []byte(`{"type":"RATE","labels":{"name":"mytimer","ololo":"trololo"},"value":0.042}`) + assert.Equal(t, expected, b) +} diff --git a/library/go/core/metrics/solomon/vec.go b/library/go/core/metrics/solomon/vec.go new file mode 100644 index 0000000000..323919e9f8 --- /dev/null +++ b/library/go/core/metrics/solomon/vec.go @@ -0,0 +1,279 @@ +package solomon + +import ( + "sync" + + "github.com/ydb-platform/ydb/library/go/core/metrics" + "github.com/ydb-platform/ydb/library/go/core/metrics/internal/pkg/registryutil" +) + +// metricsVector is a base implementation of vector of metrics of any supported type. +type metricsVector struct { + labels []string + mtx sync.RWMutex // Protects metrics. + metrics map[uint64]Metric + rated bool + newMetric func(map[string]string) Metric + removeMetric func(m Metric) +} + +func (v *metricsVector) with(tags map[string]string) Metric { + hv, err := registryutil.VectorHash(tags, v.labels) + if err != nil { + panic(err) + } + + v.mtx.RLock() + metric, ok := v.metrics[hv] + v.mtx.RUnlock() + if ok { + return metric + } + + v.mtx.Lock() + defer v.mtx.Unlock() + + metric, ok = v.metrics[hv] + if !ok { + metric = v.newMetric(tags) + v.metrics[hv] = metric + } + + return metric +} + +// reset deletes all metrics in this vector. +func (v *metricsVector) reset() { + v.mtx.Lock() + defer v.mtx.Unlock() + + for h, m := range v.metrics { + delete(v.metrics, h) + v.removeMetric(m) + } +} + +var _ metrics.CounterVec = (*CounterVec)(nil) + +// CounterVec stores counters and +// implements metrics.CounterVec interface. +type CounterVec struct { + vec *metricsVector +} + +// CounterVec creates a new counters vector with given metric name and +// partitioned by the given label names. +func (r *Registry) CounterVec(name string, labels []string) metrics.CounterVec { + var vec *metricsVector + vec = &metricsVector{ + labels: append([]string(nil), labels...), + metrics: make(map[uint64]Metric), + rated: r.rated, + newMetric: func(tags map[string]string) Metric { + return r.Rated(vec.rated). + WithTags(tags). + Counter(name).(*Counter) + }, + removeMetric: func(m Metric) { + r.WithTags(m.getLabels()).(*Registry).unregisterMetric(m) + }, + } + return &CounterVec{vec: vec} +} + +// With creates new or returns existing counter with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *CounterVec) With(tags map[string]string) metrics.Counter { + return v.vec.with(tags).(*Counter) +} + +// Reset deletes all metrics in this vector. +func (v *CounterVec) Reset() { + v.vec.reset() +} + +var _ metrics.GaugeVec = (*GaugeVec)(nil) + +// GaugeVec stores gauges and +// implements metrics.GaugeVec interface. +type GaugeVec struct { + vec *metricsVector +} + +// GaugeVec creates a new gauges vector with given metric name and +// partitioned by the given label names. +func (r *Registry) GaugeVec(name string, labels []string) metrics.GaugeVec { + return &GaugeVec{ + vec: &metricsVector{ + labels: append([]string(nil), labels...), + metrics: make(map[uint64]Metric), + newMetric: func(tags map[string]string) Metric { + return r.WithTags(tags).Gauge(name).(*Gauge) + }, + removeMetric: func(m Metric) { + r.WithTags(m.getLabels()).(*Registry).unregisterMetric(m) + }, + }, + } +} + +// With creates new or returns existing gauge with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *GaugeVec) With(tags map[string]string) metrics.Gauge { + return v.vec.with(tags).(*Gauge) +} + +// Reset deletes all metrics in this vector. +func (v *GaugeVec) Reset() { + v.vec.reset() +} + +var _ metrics.IntGaugeVec = (*IntGaugeVec)(nil) + +// IntGaugeVec stores gauges and +// implements metrics.IntGaugeVec interface. +type IntGaugeVec struct { + vec *metricsVector +} + +// IntGaugeVec creates a new gauges vector with given metric name and +// partitioned by the given label names. +func (r *Registry) IntGaugeVec(name string, labels []string) metrics.IntGaugeVec { + return &IntGaugeVec{ + vec: &metricsVector{ + labels: append([]string(nil), labels...), + metrics: make(map[uint64]Metric), + newMetric: func(tags map[string]string) Metric { + return r.WithTags(tags).IntGauge(name).(*IntGauge) + }, + removeMetric: func(m Metric) { + r.WithTags(m.getLabels()).(*Registry).unregisterMetric(m) + }, + }, + } +} + +// With creates new or returns existing gauge with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *IntGaugeVec) With(tags map[string]string) metrics.IntGauge { + return v.vec.with(tags).(*IntGauge) +} + +// Reset deletes all metrics in this vector. +func (v *IntGaugeVec) Reset() { + v.vec.reset() +} + +var _ metrics.TimerVec = (*TimerVec)(nil) + +// TimerVec stores timers and +// implements metrics.TimerVec interface. +type TimerVec struct { + vec *metricsVector +} + +// TimerVec creates a new timers vector with given metric name and +// partitioned by the given label names. +func (r *Registry) TimerVec(name string, labels []string) metrics.TimerVec { + return &TimerVec{ + vec: &metricsVector{ + labels: append([]string(nil), labels...), + metrics: make(map[uint64]Metric), + newMetric: func(tags map[string]string) Metric { + return r.WithTags(tags).Timer(name).(*Timer) + }, + removeMetric: func(m Metric) { + r.WithTags(m.getLabels()).(*Registry).unregisterMetric(m) + }, + }, + } +} + +// With creates new or returns existing timer with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *TimerVec) With(tags map[string]string) metrics.Timer { + return v.vec.with(tags).(*Timer) +} + +// Reset deletes all metrics in this vector. +func (v *TimerVec) Reset() { + v.vec.reset() +} + +var _ metrics.HistogramVec = (*HistogramVec)(nil) + +// HistogramVec stores histograms and +// implements metrics.HistogramVec interface. +type HistogramVec struct { + vec *metricsVector +} + +// HistogramVec creates a new histograms vector with given metric name and buckets and +// partitioned by the given label names. +func (r *Registry) HistogramVec(name string, buckets metrics.Buckets, labels []string) metrics.HistogramVec { + var vec *metricsVector + vec = &metricsVector{ + labels: append([]string(nil), labels...), + metrics: make(map[uint64]Metric), + rated: r.rated, + newMetric: func(tags map[string]string) Metric { + return r.Rated(vec.rated). + WithTags(tags). + Histogram(name, buckets).(*Histogram) + }, + removeMetric: func(m Metric) { + r.WithTags(m.getLabels()).(*Registry).unregisterMetric(m) + }, + } + return &HistogramVec{vec: vec} +} + +// With creates new or returns existing histogram with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *HistogramVec) With(tags map[string]string) metrics.Histogram { + return v.vec.with(tags).(*Histogram) +} + +// Reset deletes all metrics in this vector. +func (v *HistogramVec) Reset() { + v.vec.reset() +} + +var _ metrics.TimerVec = (*DurationHistogramVec)(nil) + +// DurationHistogramVec stores duration histograms and +// implements metrics.TimerVec interface. +type DurationHistogramVec struct { + vec *metricsVector +} + +// DurationHistogramVec creates a new duration histograms vector with given metric name and buckets and +// partitioned by the given label names. +func (r *Registry) DurationHistogramVec(name string, buckets metrics.DurationBuckets, labels []string) metrics.TimerVec { + var vec *metricsVector + vec = &metricsVector{ + labels: append([]string(nil), labels...), + metrics: make(map[uint64]Metric), + rated: r.rated, + newMetric: func(tags map[string]string) Metric { + return r.Rated(vec.rated). + WithTags(tags). + DurationHistogram(name, buckets).(*Histogram) + }, + removeMetric: func(m Metric) { + r.WithTags(m.getLabels()).(*Registry).unregisterMetric(m) + }, + } + return &DurationHistogramVec{vec: vec} +} + +// With creates new or returns existing duration histogram with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *DurationHistogramVec) With(tags map[string]string) metrics.Timer { + return v.vec.with(tags).(*Histogram) +} + +// Reset deletes all metrics in this vector. +func (v *DurationHistogramVec) Reset() { + v.vec.reset() +} diff --git a/library/go/core/metrics/solomon/vec_test.go b/library/go/core/metrics/solomon/vec_test.go new file mode 100644 index 0000000000..cac437f434 --- /dev/null +++ b/library/go/core/metrics/solomon/vec_test.go @@ -0,0 +1,339 @@ +package solomon + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/core/metrics" +) + +func TestVec(t *testing.T) { + type args struct { + name string + labels []string + buckets metrics.Buckets + dbuckets metrics.DurationBuckets + } + + testCases := []struct { + name string + args args + expectedType interface{} + expectLabels []string + }{ + { + name: "CounterVec", + args: args{ + name: "cntvec", + labels: []string{"shimba", "looken"}, + }, + expectedType: &CounterVec{}, + expectLabels: []string{"shimba", "looken"}, + }, + { + name: "GaugeVec", + args: args{ + name: "ggvec", + labels: []string{"shimba", "looken"}, + }, + expectedType: &GaugeVec{}, + expectLabels: []string{"shimba", "looken"}, + }, + { + name: "TimerVec", + args: args{ + name: "tvec", + labels: []string{"shimba", "looken"}, + }, + expectedType: &TimerVec{}, + expectLabels: []string{"shimba", "looken"}, + }, + { + name: "HistogramVec", + args: args{ + name: "hvec", + labels: []string{"shimba", "looken"}, + buckets: metrics.NewBuckets(1, 2, 3, 4), + }, + expectedType: &HistogramVec{}, + expectLabels: []string{"shimba", "looken"}, + }, + { + name: "DurationHistogramVec", + args: args{ + name: "dhvec", + labels: []string{"shimba", "looken"}, + dbuckets: metrics.NewDurationBuckets(1, 2, 3, 4), + }, + expectedType: &DurationHistogramVec{}, + expectLabels: []string{"shimba", "looken"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + + switch vect := tc.expectedType.(type) { + case *CounterVec: + vec := rg.CounterVec(tc.args.name, tc.args.labels) + assert.IsType(t, vect, vec) + assert.Equal(t, tc.expectLabels, vec.(*CounterVec).vec.labels) + case *GaugeVec: + vec := rg.GaugeVec(tc.args.name, tc.args.labels) + assert.IsType(t, vect, vec) + assert.Equal(t, tc.expectLabels, vec.(*GaugeVec).vec.labels) + case *TimerVec: + vec := rg.TimerVec(tc.args.name, tc.args.labels) + assert.IsType(t, vect, vec) + assert.Equal(t, tc.expectLabels, vec.(*TimerVec).vec.labels) + case *HistogramVec: + vec := rg.HistogramVec(tc.args.name, tc.args.buckets, tc.args.labels) + assert.IsType(t, vect, vec) + assert.Equal(t, tc.expectLabels, vec.(*HistogramVec).vec.labels) + case *DurationHistogramVec: + vec := rg.DurationHistogramVec(tc.args.name, tc.args.dbuckets, tc.args.labels) + assert.IsType(t, vect, vec) + assert.Equal(t, tc.expectLabels, vec.(*DurationHistogramVec).vec.labels) + default: + t.Errorf("unknown type: %T", vect) + } + }) + } +} + +func TestCounterVecWith(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + + t.Run("plain", func(t *testing.T) { + vec := rg.CounterVec("ololo", []string{"shimba", "looken"}) + tags := map[string]string{ + "shimba": "boomba", + "looken": "tooken", + } + metric := vec.With(tags) + + assert.IsType(t, &CounterVec{}, vec) + assert.IsType(t, &Counter{}, metric) + assert.Equal(t, typeCounter, metric.(*Counter).metricType) + + assert.NotEmpty(t, vec.(*CounterVec).vec.metrics) + vec.Reset() + assert.Empty(t, vec.(*CounterVec).vec.metrics) + assertMetricRemoved(t, rg.WithTags(tags).(*Registry), metric.(*Counter)) + }) + + t.Run("rated", func(t *testing.T) { + vec := rg.CounterVec("ololo", []string{"shimba", "looken"}) + Rated(vec) + tags := map[string]string{ + "shimba": "boomba", + "looken": "tooken", + } + metric := vec.With(tags) + + assert.IsType(t, &CounterVec{}, vec) + assert.IsType(t, &Counter{}, metric) + assert.Equal(t, typeRated, metric.(*Counter).metricType) + + assert.NotEmpty(t, vec.(*CounterVec).vec.metrics) + vec.Reset() + assert.Empty(t, vec.(*CounterVec).vec.metrics) + assertMetricRemoved(t, rg.WithTags(tags).(*Registry), metric.(*Counter)) + }) +} + +func TestGaugeVecWith(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + + vec := rg.GaugeVec("ololo", []string{"shimba", "looken"}) + tags := map[string]string{ + "shimba": "boomba", + "looken": "tooken", + } + metric := vec.With(tags) + + assert.IsType(t, &GaugeVec{}, vec) + assert.IsType(t, &Gauge{}, metric) + assert.Equal(t, typeGauge, metric.(*Gauge).metricType) + + assert.NotEmpty(t, vec.(*GaugeVec).vec.metrics) + vec.Reset() + assert.Empty(t, vec.(*GaugeVec).vec.metrics) + assertMetricRemoved(t, rg.WithTags(tags).(*Registry), metric.(*Gauge)) +} + +func TestTimerVecWith(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + vec := rg.TimerVec("ololo", []string{"shimba", "looken"}) + tags := map[string]string{ + "shimba": "boomba", + "looken": "tooken", + } + metric := vec.With(tags) + + assert.IsType(t, &TimerVec{}, vec) + assert.IsType(t, &Timer{}, metric) + assert.Equal(t, typeGauge, metric.(*Timer).metricType) + + assert.NotEmpty(t, vec.(*TimerVec).vec.metrics) + vec.Reset() + assert.Empty(t, vec.(*TimerVec).vec.metrics) + assertMetricRemoved(t, rg.WithTags(tags).(*Registry), metric.(*Timer)) +} + +func TestHistogramVecWith(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + + t.Run("plain", func(t *testing.T) { + buckets := metrics.NewBuckets(1, 2, 3) + vec := rg.HistogramVec("ololo", buckets, []string{"shimba", "looken"}) + tags := map[string]string{ + "shimba": "boomba", + "looken": "tooken", + } + metric := vec.With(tags) + + assert.IsType(t, &HistogramVec{}, vec) + assert.IsType(t, &Histogram{}, metric) + assert.Equal(t, typeHistogram, metric.(*Histogram).metricType) + + assert.NotEmpty(t, vec.(*HistogramVec).vec.metrics) + vec.Reset() + assert.Empty(t, vec.(*HistogramVec).vec.metrics) + assertMetricRemoved(t, rg.WithTags(tags).(*Registry), metric.(*Histogram)) + }) + + t.Run("rated", func(t *testing.T) { + buckets := metrics.NewBuckets(1, 2, 3) + vec := rg.HistogramVec("ololo", buckets, []string{"shimba", "looken"}) + Rated(vec) + tags := map[string]string{ + "shimba": "boomba", + "looken": "tooken", + } + metric := vec.With(tags) + + assert.IsType(t, &HistogramVec{}, vec) + assert.IsType(t, &Histogram{}, metric) + assert.Equal(t, typeRatedHistogram, metric.(*Histogram).metricType) + + assert.NotEmpty(t, vec.(*HistogramVec).vec.metrics) + vec.Reset() + assert.Empty(t, vec.(*HistogramVec).vec.metrics) + assertMetricRemoved(t, rg.WithTags(tags).(*Registry), metric.(*Histogram)) + }) +} + +func TestDurationHistogramVecWith(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + + t.Run("plain", func(t *testing.T) { + buckets := metrics.NewDurationBuckets(1, 2, 3) + vec := rg.DurationHistogramVec("ololo", buckets, []string{"shimba", "looken"}) + tags := map[string]string{ + "shimba": "boomba", + "looken": "tooken", + } + metric := vec.With(tags) + + assert.IsType(t, &DurationHistogramVec{}, vec) + assert.IsType(t, &Histogram{}, metric) + assert.Equal(t, typeHistogram, metric.(*Histogram).metricType) + + assert.NotEmpty(t, vec.(*DurationHistogramVec).vec.metrics) + vec.Reset() + assert.Empty(t, vec.(*DurationHistogramVec).vec.metrics) + assertMetricRemoved(t, rg.WithTags(tags).(*Registry), metric.(*Histogram)) + }) + + t.Run("rated", func(t *testing.T) { + buckets := metrics.NewDurationBuckets(1, 2, 3) + vec := rg.DurationHistogramVec("ololo", buckets, []string{"shimba", "looken"}) + Rated(vec) + tags := map[string]string{ + "shimba": "boomba", + "looken": "tooken", + } + metric := vec.With(tags) + + assert.IsType(t, &DurationHistogramVec{}, vec) + assert.IsType(t, &Histogram{}, metric) + assert.Equal(t, typeRatedHistogram, metric.(*Histogram).metricType) + + assert.NotEmpty(t, vec.(*DurationHistogramVec).vec.metrics) + vec.Reset() + assert.Empty(t, vec.(*DurationHistogramVec).vec.metrics) + assertMetricRemoved(t, rg.WithTags(tags).(*Registry), metric.(*Histogram)) + }) +} + +func TestMetricsVectorWith(t *testing.T) { + rg := NewRegistry(NewRegistryOpts()) + + name := "ololo" + tags := map[string]string{ + "shimba": "boomba", + "looken": "tooken", + } + + vec := &metricsVector{ + labels: []string{"shimba", "looken"}, + metrics: make(map[uint64]Metric), + newMetric: func(tags map[string]string) Metric { + return rg.WithTags(tags).Counter(name).(*Counter) + }, + removeMetric: func(m Metric) { + rg.WithTags(m.getLabels()).(*Registry).unregisterMetric(m) + }, + } + + // check first counter + metric := vec.with(tags) + require.IsType(t, &Counter{}, metric) + cnt := metric.(*Counter) + assert.Equal(t, name, cnt.name) + assert.Equal(t, tags, cnt.tags) + + // check vector length + assert.Equal(t, 1, len(vec.metrics)) + + // check same counter returned for same tags set + cnt2 := vec.with(tags) + assert.Same(t, cnt, cnt2) + + // check vector length + assert.Equal(t, 1, len(vec.metrics)) + + // return new counter + cnt3 := vec.with(map[string]string{ + "shimba": "boomba", + "looken": "cooken", + }) + assert.NotSame(t, cnt, cnt3) + + // check vector length + assert.Equal(t, 2, len(vec.metrics)) + + // check for panic + assert.Panics(t, func() { + vec.with(map[string]string{"chicken": "cooken"}) + }) + assert.Panics(t, func() { + vec.with(map[string]string{"shimba": "boomba", "chicken": "cooken"}) + }) + + // check reset + vec.reset() + assert.Empty(t, vec.metrics) + assertMetricRemoved(t, rg.WithTags(tags).(*Registry), cnt2) + assertMetricRemoved(t, rg.WithTags(tags).(*Registry), cnt3) +} + +func assertMetricRemoved(t *testing.T, rg *Registry, m Metric) { + t.Helper() + + v, ok := rg.metrics.Load(rg.metricKey(m)) + assert.False(t, ok, v) +} diff --git a/library/go/core/metrics/solomon/ya.make b/library/go/core/metrics/solomon/ya.make new file mode 100644 index 0000000000..a4de14cadf --- /dev/null +++ b/library/go/core/metrics/solomon/ya.make @@ -0,0 +1,44 @@ +GO_LIBRARY() + +SRCS( + converter.go + counter.go + func_counter.go + func_gauge.go + func_int_gauge.go + gauge.go + int_gauge.go + histogram.go + metrics.go + metrics_opts.go + registry.go + registry_opts.go + spack.go + spack_compression.go + stream.go + timer.go + vec.go +) + +GO_TEST_SRCS( + converter_test.go + counter_test.go + func_counter_test.go + func_gauge_test.go + func_int_gauge_test.go + gauge_test.go + int_gauge_test.go + histogram_test.go + metrics_test.go + registry_test.go + spack_compression_test.go + spack_test.go + stream_test.go + timer_test.go + vec_test.go + race_test.go +) + +END() + +RECURSE(gotest) diff --git a/library/go/core/metrics/ya.make b/library/go/core/metrics/ya.make new file mode 100644 index 0000000000..0a42f422af --- /dev/null +++ b/library/go/core/metrics/ya.make @@ -0,0 +1,20 @@ +GO_LIBRARY() + +SRCS( + buckets.go + metrics.go +) + +GO_TEST_SRCS(buckets_test.go) + +END() + +RECURSE( + collect + gotest + internal + mock + nop + prometheus + solomon +) diff --git a/library/go/core/resource/cc/main.go b/library/go/core/resource/cc/main.go new file mode 100644 index 0000000000..50887343d6 --- /dev/null +++ b/library/go/core/resource/cc/main.go @@ -0,0 +1,91 @@ +package main + +import ( + "bufio" + "flag" + "fmt" + "io" + "os" + "strings" +) + +func fatalf(msg string, args ...interface{}) { + _, _ = fmt.Fprintf(os.Stderr, msg+"\n", args...) + os.Exit(1) +} + +func generate(w io.Writer, pkg string, blobs [][]byte, keys []string) { + _, _ = fmt.Fprint(w, "// Code generated by github.com/ydb-platform/ydb/library/go/core/resource/cc DO NOT EDIT.\n") + _, _ = fmt.Fprintf(w, "package %s\n\n", pkg) + _, _ = fmt.Fprint(w, "import \"github.com/ydb-platform/ydb/library/go/core/resource\"\n") + + for i := 0; i < len(blobs); i++ { + blob := blobs[i] + + _, _ = fmt.Fprint(w, "\nfunc init() {\n") + + _, _ = fmt.Fprint(w, "\tblob := []byte(") + _, _ = fmt.Fprintf(w, "%+q", blob) + _, _ = fmt.Fprint(w, ")\n") + _, _ = fmt.Fprintf(w, "\tresource.InternalRegister(%q, blob)\n", keys[i]) + _, _ = fmt.Fprint(w, "}\n") + } +} + +func main() { + var pkg, output string + + flag.StringVar(&pkg, "package", "", "package name") + flag.StringVar(&output, "o", "", "output filename") + flag.Parse() + + if flag.NArg()%2 != 0 { + fatalf("cc: must provide even number of arguments") + } + + var keys []string + var blobs [][]byte + for i := 0; 2*i < flag.NArg(); i++ { + file := flag.Arg(2 * i) + key := flag.Arg(2*i + 1) + + if !strings.HasPrefix(key, "notafile") { + fatalf("cc: key argument must start with \"notafile\" string") + } + key = key[8:] + + if file == "-" { + parts := strings.SplitN(key, "=", 2) + if len(parts) != 2 { + fatalf("cc: invalid key syntax: %q", key) + } + + keys = append(keys, parts[0]) + blobs = append(blobs, []byte(parts[1])) + } else { + blob, err := os.ReadFile(file) + if err != nil { + fatalf("cc: %v", err) + } + + keys = append(keys, key) + blobs = append(blobs, blob) + } + } + + f, err := os.Create(output) + if err != nil { + fatalf("cc: %v", err) + } + + b := bufio.NewWriter(f) + generate(b, pkg, blobs, keys) + + if err = b.Flush(); err != nil { + fatalf("cc: %v", err) + } + + if err = f.Close(); err != nil { + fatalf("cc: %v", err) + } +} diff --git a/library/go/core/resource/cc/ya.make b/library/go/core/resource/cc/ya.make new file mode 100644 index 0000000000..4d99fcc9c0 --- /dev/null +++ b/library/go/core/resource/cc/ya.make @@ -0,0 +1,9 @@ +GO_PROGRAM() + +SRCS(main.go) + +GO_TEST_SRCS(generate_test.go) + +END() + +RECURSE(gotest) diff --git a/library/go/core/resource/resource.go b/library/go/core/resource/resource.go new file mode 100644 index 0000000000..686ea73c3b --- /dev/null +++ b/library/go/core/resource/resource.go @@ -0,0 +1,56 @@ +// Package resource provides integration with RESOURCE and RESOURCE_FILES macros. +// +// Use RESOURCE macro to "link" file into the library or executable. +// +// RESOURCE(my_file.txt some_key) +// +// And then retrieve file content in the runtime. +// +// blob := resource.Get("some_key") +// +// Warning: Excessive consumption of resource leads to obesity. +package resource + +import ( + "fmt" + "sort" +) + +var resources = map[string][]byte{} + +// InternalRegister is private API used by generated code. +func InternalRegister(key string, blob []byte) { + if _, ok := resources[key]; ok { + panic(fmt.Sprintf("resource key %q is already defined", key)) + } + + resources[key] = blob +} + +// Get returns content of the file registered by the given key. +// +// If no file was registered for the given key, nil slice is returned. +// +// User should take care, to avoid mutating returned slice. +func Get(key string) []byte { + return resources[key] +} + +// MustGet is like Get, but panics when associated resource is not defined. +func MustGet(key string) []byte { + r, ok := resources[key] + if !ok { + panic(fmt.Sprintf("resource with key %q is not defined", key)) + } + return r +} + +// Keys returns sorted keys of all registered resources inside the binary +func Keys() []string { + keys := make([]string, 0, len(resources)) + for k := range resources { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} diff --git a/library/go/core/resource/ya.make b/library/go/core/resource/ya.make new file mode 100644 index 0000000000..4860291e25 --- /dev/null +++ b/library/go/core/resource/ya.make @@ -0,0 +1,14 @@ +GO_LIBRARY() + +SRCS(resource.go) + +END() + +RECURSE( + cc + test + test-bin + test-fileonly + test-files + test-keyonly +) diff --git a/library/go/httputil/headers/accept.go b/library/go/httputil/headers/accept.go new file mode 100644 index 0000000000..394bed7360 --- /dev/null +++ b/library/go/httputil/headers/accept.go @@ -0,0 +1,259 @@ +package headers + +import ( + "fmt" + "sort" + "strconv" + "strings" +) + +const ( + AcceptKey = "Accept" + AcceptEncodingKey = "Accept-Encoding" +) + +type AcceptableEncodings []AcceptableEncoding + +type AcceptableEncoding struct { + Encoding ContentEncoding + Weight float32 + + pos int +} + +func (as AcceptableEncodings) IsAcceptable(encoding ContentEncoding) bool { + for _, ae := range as { + if ae.Encoding == encoding { + return ae.Weight != 0 + } + } + return false +} + +func (as AcceptableEncodings) String() string { + if len(as) == 0 { + return "" + } + + var b strings.Builder + for i, ae := range as { + b.WriteString(ae.Encoding.String()) + + if ae.Weight > 0.0 && ae.Weight < 1.0 { + b.WriteString(";q=" + strconv.FormatFloat(float64(ae.Weight), 'f', 1, 32)) + } + + if i < len(as)-1 { + b.WriteString(", ") + } + } + return b.String() +} + +type AcceptableTypes []AcceptableType + +func (as AcceptableTypes) IsAcceptable(contentType ContentType) bool { + for _, ae := range as { + if ae.Type == contentType { + return ae.Weight != 0 + } + } + return false +} + +type AcceptableType struct { + Type ContentType + Weight float32 + Extension map[string]string + + pos int +} + +func (as AcceptableTypes) String() string { + if len(as) == 0 { + return "" + } + + var b strings.Builder + for i, at := range as { + b.WriteString(at.Type.String()) + + if at.Weight > 0.0 && at.Weight < 1.0 { + b.WriteString(";q=" + strconv.FormatFloat(float64(at.Weight), 'f', 1, 32)) + } + + for k, v := range at.Extension { + b.WriteString(";" + k + "=" + v) + } + + if i < len(as)-1 { + b.WriteString(", ") + } + } + return b.String() +} + +// ParseAccept parses Accept HTTP header. +// It will sort acceptable types by weight, specificity and position. +// See: https://tools.ietf.org/html/rfc2616#section-14.1 +func ParseAccept(headerValue string) (AcceptableTypes, error) { + if headerValue == "" { + return nil, nil + } + + parsedValues, err := parseAcceptFamilyHeader(headerValue) + if err != nil { + return nil, err + } + ah := make(AcceptableTypes, 0, len(parsedValues)) + for _, parsedValue := range parsedValues { + ah = append(ah, AcceptableType{ + Type: ContentType(parsedValue.Value), + Weight: parsedValue.Weight, + Extension: parsedValue.Extension, + pos: parsedValue.pos, + }) + } + + sort.Slice(ah, func(i, j int) bool { + // sort by weight only + if ah[i].Weight != ah[j].Weight { + return ah[i].Weight > ah[j].Weight + } + + // sort by most specific if types are equal + if ah[i].Type == ah[j].Type { + return len(ah[i].Extension) > len(ah[j].Extension) + } + + // move counterpart up if one of types is ANY + if ah[i].Type == ContentTypeAny { + return false + } + if ah[j].Type == ContentTypeAny { + return true + } + + // i type has j type as prefix + if strings.HasSuffix(string(ah[j].Type), "/*") && + strings.HasPrefix(string(ah[i].Type), string(ah[j].Type)[:len(ah[j].Type)-1]) { + return true + } + + // j type has i type as prefix + if strings.HasSuffix(string(ah[i].Type), "/*") && + strings.HasPrefix(string(ah[j].Type), string(ah[i].Type)[:len(ah[i].Type)-1]) { + return false + } + + // sort by position if nothing else left + return ah[i].pos < ah[j].pos + }) + + return ah, nil +} + +// ParseAcceptEncoding parses Accept-Encoding HTTP header. +// It will sort acceptable encodings by weight and position. +// See: https://tools.ietf.org/html/rfc2616#section-14.3 +func ParseAcceptEncoding(headerValue string) (AcceptableEncodings, error) { + if headerValue == "" { + return nil, nil + } + + // e.g. gzip;q=1.0, compress, identity + parsedValues, err := parseAcceptFamilyHeader(headerValue) + if err != nil { + return nil, err + } + acceptableEncodings := make(AcceptableEncodings, 0, len(parsedValues)) + for _, parsedValue := range parsedValues { + acceptableEncodings = append(acceptableEncodings, AcceptableEncoding{ + Encoding: ContentEncoding(parsedValue.Value), + Weight: parsedValue.Weight, + pos: parsedValue.pos, + }) + } + sort.Slice(acceptableEncodings, func(i, j int) bool { + // sort by weight only + if acceptableEncodings[i].Weight != acceptableEncodings[j].Weight { + return acceptableEncodings[i].Weight > acceptableEncodings[j].Weight + } + + // move counterpart up if one of encodings is ANY + if acceptableEncodings[i].Encoding == EncodingAny { + return false + } + if acceptableEncodings[j].Encoding == EncodingAny { + return true + } + + // sort by position if nothing else left + return acceptableEncodings[i].pos < acceptableEncodings[j].pos + }) + + return acceptableEncodings, nil +} + +type acceptHeaderValue struct { + Value string + Weight float32 + Extension map[string]string + + pos int +} + +// parseAcceptFamilyHeader parses family of Accept* HTTP headers +// See: https://tools.ietf.org/html/rfc2616#section-14.1 +func parseAcceptFamilyHeader(header string) ([]acceptHeaderValue, error) { + headerValues := strings.Split(header, ",") + + parsedValues := make([]acceptHeaderValue, 0, len(headerValues)) + for i, headerValue := range headerValues { + valueParams := strings.Split(headerValue, ";") + + parsedValue := acceptHeaderValue{ + Value: strings.TrimSpace(valueParams[0]), + Weight: 1.0, + pos: i, + } + + // parse quality factor and/or accept extension + if len(valueParams) > 1 { + for _, rawParam := range valueParams[1:] { + rawParam = strings.TrimSpace(rawParam) + params := strings.SplitN(rawParam, "=", 2) + key := strings.TrimSpace(params[0]) + + // quality factor + if key == "q" { + if len(params) != 2 { + return nil, fmt.Errorf("invalid quality factor format: %q", rawParam) + } + + w, err := strconv.ParseFloat(params[1], 32) + if err != nil { + return nil, err + } + parsedValue.Weight = float32(w) + + continue + } + + // extension + if parsedValue.Extension == nil { + parsedValue.Extension = make(map[string]string) + } + + var value string + if len(params) == 2 { + value = strings.TrimSpace(params[1]) + } + parsedValue.Extension[key] = value + } + } + + parsedValues = append(parsedValues, parsedValue) + } + return parsedValues, nil +} diff --git a/library/go/httputil/headers/accept_test.go b/library/go/httputil/headers/accept_test.go new file mode 100644 index 0000000000..09d3da086f --- /dev/null +++ b/library/go/httputil/headers/accept_test.go @@ -0,0 +1,309 @@ +package headers_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/httputil/headers" +) + +// examples for tests taken from https://tools.ietf.org/html/rfc2616#section-14.3 +func TestParseAcceptEncoding(t *testing.T) { + testCases := []struct { + name string + input string + expected headers.AcceptableEncodings + expectedErr error + }{ + { + "ietf_example_1", + "compress, gzip", + headers.AcceptableEncodings{ + {Encoding: headers.ContentEncoding("compress"), Weight: 1.0}, + {Encoding: headers.ContentEncoding("gzip"), Weight: 1.0}, + }, + nil, + }, + { + "ietf_example_2", + "", + nil, + nil, + }, + { + "ietf_example_3", + "*", + headers.AcceptableEncodings{ + {Encoding: headers.ContentEncoding("*"), Weight: 1.0}, + }, + nil, + }, + { + "ietf_example_4", + "compress;q=0.5, gzip;q=1.0", + headers.AcceptableEncodings{ + {Encoding: headers.ContentEncoding("gzip"), Weight: 1.0}, + {Encoding: headers.ContentEncoding("compress"), Weight: 0.5}, + }, + nil, + }, + { + "ietf_example_5", + "gzip;q=1.0, identity; q=0.5, *;q=0", + headers.AcceptableEncodings{ + {Encoding: headers.ContentEncoding("gzip"), Weight: 1.0}, + {Encoding: headers.ContentEncoding("identity"), Weight: 0.5}, + {Encoding: headers.ContentEncoding("*"), Weight: 0}, + }, + nil, + }, + { + "solomon_headers", + "zstd,lz4,gzip,deflate", + headers.AcceptableEncodings{ + {Encoding: headers.ContentEncoding("zstd"), Weight: 1.0}, + {Encoding: headers.ContentEncoding("lz4"), Weight: 1.0}, + {Encoding: headers.ContentEncoding("gzip"), Weight: 1.0}, + {Encoding: headers.ContentEncoding("deflate"), Weight: 1.0}, + }, + nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + acceptableEncodings, err := headers.ParseAcceptEncoding(tc.input) + + if tc.expectedErr != nil { + assert.EqualError(t, err, tc.expectedErr.Error()) + } else { + assert.NoError(t, err) + } + + require.Len(t, acceptableEncodings, len(tc.expected)) + + opt := cmpopts.IgnoreUnexported(headers.AcceptableEncoding{}) + assert.True(t, cmp.Equal(tc.expected, acceptableEncodings, opt), cmp.Diff(tc.expected, acceptableEncodings, opt)) + }) + } +} + +func TestParseAccept(t *testing.T) { + testCases := []struct { + name string + input string + expected headers.AcceptableTypes + expectedErr error + }{ + { + "empty_header", + "", + nil, + nil, + }, + { + "accept_any", + "*/*", + headers.AcceptableTypes{ + {Type: headers.ContentTypeAny, Weight: 1.0}, + }, + nil, + }, + { + "accept_single", + "application/json", + headers.AcceptableTypes{ + {Type: headers.TypeApplicationJSON, Weight: 1.0}, + }, + nil, + }, + { + "accept_multiple", + "application/json, application/protobuf", + headers.AcceptableTypes{ + {Type: headers.TypeApplicationJSON, Weight: 1.0}, + {Type: headers.TypeApplicationProtobuf, Weight: 1.0}, + }, + nil, + }, + { + "accept_multiple_weighted", + "application/json;q=0.8, application/protobuf", + headers.AcceptableTypes{ + {Type: headers.TypeApplicationProtobuf, Weight: 1.0}, + {Type: headers.TypeApplicationJSON, Weight: 0.8}, + }, + nil, + }, + { + "accept_multiple_weighted_unsorted", + "text/plain;q=0.5, application/protobuf, application/json;q=0.5", + headers.AcceptableTypes{ + {Type: headers.TypeApplicationProtobuf, Weight: 1.0}, + {Type: headers.TypeTextPlain, Weight: 0.5}, + {Type: headers.TypeApplicationJSON, Weight: 0.5}, + }, + nil, + }, + { + "unknown_type", + "custom/type, unknown/my_type;q=0.2", + headers.AcceptableTypes{ + {Type: headers.ContentType("custom/type"), Weight: 1.0}, + {Type: headers.ContentType("unknown/my_type"), Weight: 0.2}, + }, + nil, + }, + { + "yabro_19.6.0", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3", + headers.AcceptableTypes{ + {Type: headers.ContentType("text/html"), Weight: 1.0}, + {Type: headers.ContentType("application/xhtml+xml"), Weight: 1.0}, + {Type: headers.ContentType("image/webp"), Weight: 1.0}, + {Type: headers.ContentType("image/apng"), Weight: 1.0}, + {Type: headers.ContentType("application/signed-exchange"), Weight: 1.0, Extension: map[string]string{"v": "b3"}}, + {Type: headers.ContentType("application/xml"), Weight: 0.9}, + {Type: headers.ContentType("*/*"), Weight: 0.8}, + }, + nil, + }, + { + "chrome_81.0.4044", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", + headers.AcceptableTypes{ + {Type: headers.ContentType("text/html"), Weight: 1.0}, + {Type: headers.ContentType("application/xhtml+xml"), Weight: 1.0}, + {Type: headers.ContentType("image/webp"), Weight: 1.0}, + {Type: headers.ContentType("image/apng"), Weight: 1.0}, + {Type: headers.ContentType("application/xml"), Weight: 0.9}, + {Type: headers.ContentType("application/signed-exchange"), Weight: 0.9, Extension: map[string]string{"v": "b3"}}, + {Type: headers.ContentType("*/*"), Weight: 0.8}, + }, + nil, + }, + { + "firefox_77.0b3", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8", + headers.AcceptableTypes{ + {Type: headers.ContentType("text/html"), Weight: 1.0}, + {Type: headers.ContentType("application/xhtml+xml"), Weight: 1.0}, + {Type: headers.ContentType("image/webp"), Weight: 1.0}, + {Type: headers.ContentType("application/xml"), Weight: 0.9}, + {Type: headers.ContentType("*/*"), Weight: 0.8}, + }, + nil, + }, + { + "sort_by_most_specific", + "text/*, text/html, */*, text/html;level=1", + headers.AcceptableTypes{ + {Type: headers.ContentType("text/html"), Weight: 1.0, Extension: map[string]string{"level": "1"}}, + {Type: headers.ContentType("text/html"), Weight: 1.0}, + {Type: headers.ContentType("text/*"), Weight: 1.0}, + {Type: headers.ContentType("*/*"), Weight: 1.0}, + }, + nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + at, err := headers.ParseAccept(tc.input) + + if tc.expectedErr != nil { + assert.EqualError(t, err, tc.expectedErr.Error()) + } else { + assert.NoError(t, err) + } + + require.Len(t, at, len(tc.expected)) + + opt := cmpopts.IgnoreUnexported(headers.AcceptableType{}) + assert.True(t, cmp.Equal(tc.expected, at, opt), cmp.Diff(tc.expected, at, opt)) + }) + } +} + +func TestAcceptableTypesString(t *testing.T) { + testCases := []struct { + name string + types headers.AcceptableTypes + expected string + }{ + { + "empty", + headers.AcceptableTypes{}, + "", + }, + { + "single", + headers.AcceptableTypes{ + {Type: headers.TypeApplicationJSON}, + }, + "application/json", + }, + { + "single_weighted", + headers.AcceptableTypes{ + {Type: headers.TypeApplicationJSON, Weight: 0.8}, + }, + "application/json;q=0.8", + }, + { + "multiple", + headers.AcceptableTypes{ + {Type: headers.TypeApplicationJSON}, + {Type: headers.TypeApplicationProtobuf}, + }, + "application/json, application/protobuf", + }, + { + "multiple_weighted", + headers.AcceptableTypes{ + {Type: headers.TypeApplicationProtobuf}, + {Type: headers.TypeApplicationJSON, Weight: 0.8}, + }, + "application/protobuf, application/json;q=0.8", + }, + { + "multiple_weighted_with_extension", + headers.AcceptableTypes{ + {Type: headers.TypeApplicationProtobuf}, + {Type: headers.TypeApplicationJSON, Weight: 0.8}, + {Type: headers.TypeApplicationXML, Weight: 0.5, Extension: map[string]string{"label": "1"}}, + }, + "application/protobuf, application/json;q=0.8, application/xml;q=0.5;label=1", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.types.String()) + }) + } +} + +func BenchmarkParseAccept(b *testing.B) { + benchCases := []string{ + "", + "*/*", + "application/json", + "application/json, application/protobuf", + "application/json;q=0.8, application/protobuf", + "text/plain;q=0.5, application/protobuf, application/json;q=0.5", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8", + "text/*, text/html, */*, text/html;level=1", + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = headers.ParseAccept(benchCases[i%len(benchCases)]) + } +} diff --git a/library/go/httputil/headers/authorization.go b/library/go/httputil/headers/authorization.go new file mode 100644 index 0000000000..145e04f931 --- /dev/null +++ b/library/go/httputil/headers/authorization.go @@ -0,0 +1,31 @@ +package headers + +import "strings" + +const ( + AuthorizationKey = "Authorization" + + TokenTypeBearer TokenType = "bearer" + TokenTypeMAC TokenType = "mac" +) + +type TokenType string + +// String implements stringer interface +func (tt TokenType) String() string { + return string(tt) +} + +func AuthorizationTokenType(token string) TokenType { + if len(token) > len(TokenTypeBearer) && + strings.ToLower(token[:len(TokenTypeBearer)]) == TokenTypeBearer.String() { + return TokenTypeBearer + } + + if len(token) > len(TokenTypeMAC) && + strings.ToLower(token[:len(TokenTypeMAC)]) == TokenTypeMAC.String() { + return TokenTypeMAC + } + + return TokenType("unknown") +} diff --git a/library/go/httputil/headers/authorization_test.go b/library/go/httputil/headers/authorization_test.go new file mode 100644 index 0000000000..4e93aac1cd --- /dev/null +++ b/library/go/httputil/headers/authorization_test.go @@ -0,0 +1,30 @@ +package headers_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/ydb-platform/ydb/library/go/httputil/headers" +) + +func TestAuthorizationTokenType(t *testing.T) { + testCases := []struct { + name string + token string + expected headers.TokenType + }{ + {"bearer", "bearer ololo.trololo", headers.TokenTypeBearer}, + {"Bearer", "Bearer ololo.trololo", headers.TokenTypeBearer}, + {"BEARER", "BEARER ololo.trololo", headers.TokenTypeBearer}, + {"mac", "mac ololo.trololo", headers.TokenTypeMAC}, + {"Mac", "Mac ololo.trololo", headers.TokenTypeMAC}, + {"MAC", "MAC ololo.trololo", headers.TokenTypeMAC}, + {"unknown", "shimba ololo.trololo", headers.TokenType("unknown")}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, headers.AuthorizationTokenType(tc.token)) + }) + } +} diff --git a/library/go/httputil/headers/content.go b/library/go/httputil/headers/content.go new file mode 100644 index 0000000000..b92e013cc3 --- /dev/null +++ b/library/go/httputil/headers/content.go @@ -0,0 +1,57 @@ +package headers + +type ContentType string + +// String implements stringer interface +func (ct ContentType) String() string { + return string(ct) +} + +type ContentEncoding string + +// String implements stringer interface +func (ce ContentEncoding) String() string { + return string(ce) +} + +const ( + ContentTypeKey = "Content-Type" + ContentLength = "Content-Length" + ContentEncodingKey = "Content-Encoding" + + ContentTypeAny ContentType = "*/*" + + TypeApplicationJSON ContentType = "application/json" + TypeApplicationXML ContentType = "application/xml" + TypeApplicationOctetStream ContentType = "application/octet-stream" + TypeApplicationProtobuf ContentType = "application/protobuf" + TypeApplicationMsgpack ContentType = "application/msgpack" + TypeApplicationXSolomonSpack ContentType = "application/x-solomon-spack" + + EncodingAny ContentEncoding = "*" + EncodingZSTD ContentEncoding = "zstd" + EncodingLZ4 ContentEncoding = "lz4" + EncodingGZIP ContentEncoding = "gzip" + EncodingDeflate ContentEncoding = "deflate" + + TypeTextPlain ContentType = "text/plain" + TypeTextHTML ContentType = "text/html" + TypeTextCSV ContentType = "text/csv" + TypeTextCmd ContentType = "text/cmd" + TypeTextCSS ContentType = "text/css" + TypeTextXML ContentType = "text/xml" + TypeTextMarkdown ContentType = "text/markdown" + + TypeImageAny ContentType = "image/*" + TypeImageJPEG ContentType = "image/jpeg" + TypeImageGIF ContentType = "image/gif" + TypeImagePNG ContentType = "image/png" + TypeImageSVG ContentType = "image/svg+xml" + TypeImageTIFF ContentType = "image/tiff" + TypeImageWebP ContentType = "image/webp" + + TypeVideoMPEG ContentType = "video/mpeg" + TypeVideoMP4 ContentType = "video/mp4" + TypeVideoOgg ContentType = "video/ogg" + TypeVideoWebM ContentType = "video/webm" +) diff --git a/library/go/httputil/headers/content_test.go b/library/go/httputil/headers/content_test.go new file mode 100644 index 0000000000..36c7b8ea8f --- /dev/null +++ b/library/go/httputil/headers/content_test.go @@ -0,0 +1,41 @@ +package headers_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/ydb-platform/ydb/library/go/httputil/headers" +) + +func TestContentTypeConsts(t *testing.T) { + assert.Equal(t, headers.ContentTypeKey, "Content-Type") + + assert.Equal(t, headers.ContentTypeAny, headers.ContentType("*/*")) + + assert.Equal(t, headers.TypeApplicationJSON, headers.ContentType("application/json")) + assert.Equal(t, headers.TypeApplicationXML, headers.ContentType("application/xml")) + assert.Equal(t, headers.TypeApplicationOctetStream, headers.ContentType("application/octet-stream")) + assert.Equal(t, headers.TypeApplicationProtobuf, headers.ContentType("application/protobuf")) + assert.Equal(t, headers.TypeApplicationMsgpack, headers.ContentType("application/msgpack")) + + assert.Equal(t, headers.TypeTextPlain, headers.ContentType("text/plain")) + assert.Equal(t, headers.TypeTextHTML, headers.ContentType("text/html")) + assert.Equal(t, headers.TypeTextCSV, headers.ContentType("text/csv")) + assert.Equal(t, headers.TypeTextCmd, headers.ContentType("text/cmd")) + assert.Equal(t, headers.TypeTextCSS, headers.ContentType("text/css")) + assert.Equal(t, headers.TypeTextXML, headers.ContentType("text/xml")) + assert.Equal(t, headers.TypeTextMarkdown, headers.ContentType("text/markdown")) + + assert.Equal(t, headers.TypeImageAny, headers.ContentType("image/*")) + assert.Equal(t, headers.TypeImageJPEG, headers.ContentType("image/jpeg")) + assert.Equal(t, headers.TypeImageGIF, headers.ContentType("image/gif")) + assert.Equal(t, headers.TypeImagePNG, headers.ContentType("image/png")) + assert.Equal(t, headers.TypeImageSVG, headers.ContentType("image/svg+xml")) + assert.Equal(t, headers.TypeImageTIFF, headers.ContentType("image/tiff")) + assert.Equal(t, headers.TypeImageWebP, headers.ContentType("image/webp")) + + assert.Equal(t, headers.TypeVideoMPEG, headers.ContentType("video/mpeg")) + assert.Equal(t, headers.TypeVideoMP4, headers.ContentType("video/mp4")) + assert.Equal(t, headers.TypeVideoOgg, headers.ContentType("video/ogg")) + assert.Equal(t, headers.TypeVideoWebM, headers.ContentType("video/webm")) +} diff --git a/library/go/httputil/headers/cookie.go b/library/go/httputil/headers/cookie.go new file mode 100644 index 0000000000..bcc685c474 --- /dev/null +++ b/library/go/httputil/headers/cookie.go @@ -0,0 +1,5 @@ +package headers + +const ( + CookieKey = "Cookie" +) diff --git a/library/go/httputil/headers/gotest/ya.make b/library/go/httputil/headers/gotest/ya.make new file mode 100644 index 0000000000..467fc88ca4 --- /dev/null +++ b/library/go/httputil/headers/gotest/ya.make @@ -0,0 +1,3 @@ +GO_TEST_FOR(library/go/httputil/headers) + +END() diff --git a/library/go/httputil/headers/tvm.go b/library/go/httputil/headers/tvm.go new file mode 100644 index 0000000000..1737cc69d7 --- /dev/null +++ b/library/go/httputil/headers/tvm.go @@ -0,0 +1,8 @@ +package headers + +const ( + // XYaServiceTicket is http header that should be used for service ticket transfer. + XYaServiceTicketKey = "X-Ya-Service-Ticket" + // XYaUserTicket is http header that should be used for user ticket transfer. + XYaUserTicketKey = "X-Ya-User-Ticket" +) diff --git a/library/go/httputil/headers/user_agent.go b/library/go/httputil/headers/user_agent.go new file mode 100644 index 0000000000..366606a01d --- /dev/null +++ b/library/go/httputil/headers/user_agent.go @@ -0,0 +1,5 @@ +package headers + +const ( + UserAgentKey = "User-Agent" +) diff --git a/library/go/httputil/headers/warning.go b/library/go/httputil/headers/warning.go new file mode 100644 index 0000000000..20df80e664 --- /dev/null +++ b/library/go/httputil/headers/warning.go @@ -0,0 +1,167 @@ +package headers + +import ( + "errors" + "net/http" + "strconv" + "strings" + "time" + + "github.com/ydb-platform/ydb/library/go/core/xerrors" +) + +const ( + WarningKey = "Warning" + + WarningResponseIsStale = 110 // RFC 7234, 5.5.1 + WarningRevalidationFailed = 111 // RFC 7234, 5.5.2 + WarningDisconnectedOperation = 112 // RFC 7234, 5.5.3 + WarningHeuristicExpiration = 113 // RFC 7234, 5.5.4 + WarningMiscellaneousWarning = 199 // RFC 7234, 5.5.5 + WarningTransformationApplied = 214 // RFC 7234, 5.5.6 + WarningMiscellaneousPersistentWarning = 299 // RFC 7234, 5.5.7 +) + +var warningStatusText = map[int]string{ + WarningResponseIsStale: "Response is Stale", + WarningRevalidationFailed: "Revalidation Failed", + WarningDisconnectedOperation: "Disconnected Operation", + WarningHeuristicExpiration: "Heuristic Expiration", + WarningMiscellaneousWarning: "Miscellaneous Warning", + WarningTransformationApplied: "Transformation Applied", + WarningMiscellaneousPersistentWarning: "Miscellaneous Persistent Warning", +} + +// WarningText returns a text for the warning header code. It returns the empty +// string if the code is unknown. +func WarningText(warn int) string { + return warningStatusText[warn] +} + +// AddWarning adds Warning to http.Header with proper formatting +// see: https://tools.ietf.org/html/rfc7234#section-5.5 +func AddWarning(h http.Header, warn int, agent, reason string, date time.Time) { + values := make([]string, 0, 4) + values = append(values, strconv.Itoa(warn)) + + if agent != "" { + values = append(values, agent) + } else { + values = append(values, "-") + } + + if reason != "" { + values = append(values, strconv.Quote(reason)) + } + + if !date.IsZero() { + values = append(values, strconv.Quote(date.Format(time.RFC1123))) + } + + h.Add(WarningKey, strings.Join(values, " ")) +} + +type WarningHeader struct { + Code int + Agent string + Reason string + Date time.Time +} + +// ParseWarnings reads and parses Warning headers from http.Header +func ParseWarnings(h http.Header) ([]WarningHeader, error) { + warnings, ok := h[WarningKey] + if !ok { + return nil, nil + } + + res := make([]WarningHeader, 0, len(warnings)) + for _, warn := range warnings { + wh, err := parseWarning(warn) + if err != nil { + return nil, xerrors.Errorf("cannot parse '%s' header: %w", warn, err) + } + res = append(res, wh) + } + + return res, nil +} + +func parseWarning(warn string) (WarningHeader, error) { + var res WarningHeader + + // parse code + { + codeSP := strings.Index(warn, " ") + + // fast path - code only warning + if codeSP == -1 { + code, err := strconv.Atoi(warn) + res.Code = code + return res, err + } + + code, err := strconv.Atoi(warn[:codeSP]) + if err != nil { + return WarningHeader{}, err + } + res.Code = code + + warn = strings.TrimSpace(warn[codeSP+1:]) + } + + // parse agent + { + agentSP := strings.Index(warn, " ") + + // fast path - no data after agent + if agentSP == -1 { + res.Agent = warn + return res, nil + } + + res.Agent = warn[:agentSP] + warn = strings.TrimSpace(warn[agentSP+1:]) + } + + // parse reason + { + if len(warn) == 0 { + return res, nil + } + + // reason must by quoted, so we search for second quote + reasonSP := strings.Index(warn[1:], `"`) + + // fast path - bad reason + if reasonSP == -1 { + return WarningHeader{}, errors.New("bad reason formatting") + } + + res.Reason = warn[1 : reasonSP+1] + warn = strings.TrimSpace(warn[reasonSP+2:]) + } + + // parse date + { + if len(warn) == 0 { + return res, nil + } + + // optional date must by quoted, so we search for second quote + dateSP := strings.Index(warn[1:], `"`) + + // fast path - bad date + if dateSP == -1 { + return WarningHeader{}, errors.New("bad date formatting") + } + + dt, err := time.Parse(time.RFC1123, warn[1:dateSP+1]) + if err != nil { + return WarningHeader{}, err + } + res.Date = dt + } + + return res, nil +} diff --git a/library/go/httputil/headers/warning_test.go b/library/go/httputil/headers/warning_test.go new file mode 100644 index 0000000000..9decb2f52f --- /dev/null +++ b/library/go/httputil/headers/warning_test.go @@ -0,0 +1,245 @@ +package headers + +import ( + "net/http" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestWarningText(t *testing.T) { + testCases := []struct { + code int + expect string + }{ + {WarningResponseIsStale, "Response is Stale"}, + {WarningRevalidationFailed, "Revalidation Failed"}, + {WarningDisconnectedOperation, "Disconnected Operation"}, + {WarningHeuristicExpiration, "Heuristic Expiration"}, + {WarningMiscellaneousWarning, "Miscellaneous Warning"}, + {WarningTransformationApplied, "Transformation Applied"}, + {WarningMiscellaneousPersistentWarning, "Miscellaneous Persistent Warning"}, + {42, ""}, + {1489, ""}, + } + + for _, tc := range testCases { + t.Run(strconv.Itoa(tc.code), func(t *testing.T) { + assert.Equal(t, tc.expect, WarningText(tc.code)) + }) + } +} + +func TestAddWarning(t *testing.T) { + type args struct { + warn int + agent string + reason string + date time.Time + } + + testCases := []struct { + name string + args args + expect http.Header + }{ + { + name: "code_only", + args: args{warn: WarningResponseIsStale, agent: "", reason: "", date: time.Time{}}, + expect: http.Header{ + WarningKey: []string{ + "110 -", + }, + }, + }, + { + name: "code_agent", + args: args{warn: WarningResponseIsStale, agent: "ololo/trololo", reason: "", date: time.Time{}}, + expect: http.Header{ + WarningKey: []string{ + "110 ololo/trololo", + }, + }, + }, + { + name: "code_agent_reason", + args: args{warn: WarningResponseIsStale, agent: "ololo/trololo", reason: "shimba-boomba", date: time.Time{}}, + expect: http.Header{ + WarningKey: []string{ + `110 ololo/trololo "shimba-boomba"`, + }, + }, + }, + { + name: "code_agent_reason_date", + args: args{ + warn: WarningResponseIsStale, + agent: "ololo/trololo", + reason: "shimba-boomba", + date: time.Date(2019, time.January, 14, 10, 50, 43, 0, time.UTC), + }, + expect: http.Header{ + WarningKey: []string{ + `110 ololo/trololo "shimba-boomba" "Mon, 14 Jan 2019 10:50:43 UTC"`, + }, + }, + }, + { + name: "code_reason_date", + args: args{ + warn: WarningResponseIsStale, + agent: "", + reason: "shimba-boomba", + date: time.Date(2019, time.January, 14, 10, 50, 43, 0, time.UTC), + }, + expect: http.Header{ + WarningKey: []string{ + `110 - "shimba-boomba" "Mon, 14 Jan 2019 10:50:43 UTC"`, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + h := http.Header{} + AddWarning(h, tc.args.warn, tc.args.agent, tc.args.reason, tc.args.date) + assert.Equal(t, tc.expect, h) + }) + } +} + +func TestParseWarnings(t *testing.T) { + testCases := []struct { + name string + h http.Header + expect []WarningHeader + expectErr bool + }{ + { + name: "no_warnings", + h: http.Header{}, + expect: nil, + expectErr: false, + }, + { + name: "single_code_only", + h: http.Header{ + WarningKey: []string{ + "110", + }, + }, + expect: []WarningHeader{ + { + Code: 110, + Agent: "", + Reason: "", + Date: time.Time{}, + }, + }, + }, + { + name: "single_code_and_empty_agent", + h: http.Header{ + WarningKey: []string{ + "110 -", + }, + }, + expect: []WarningHeader{ + { + Code: 110, + Agent: "-", + Reason: "", + Date: time.Time{}, + }, + }, + }, + { + name: "single_code_and_agent", + h: http.Header{ + WarningKey: []string{ + "110 shimba/boomba", + }, + }, + expect: []WarningHeader{ + { + Code: 110, + Agent: "shimba/boomba", + Reason: "", + Date: time.Time{}, + }, + }, + }, + { + name: "single_code_agent_and_reason", + h: http.Header{ + WarningKey: []string{ + `110 shimba/boomba "looken tooken"`, + }, + }, + expect: []WarningHeader{ + { + Code: 110, + Agent: "shimba/boomba", + Reason: "looken tooken", + Date: time.Time{}, + }, + }, + }, + { + name: "single_full", + h: http.Header{ + WarningKey: []string{ + `110 shimba/boomba "looken tooken" "Mon, 14 Jan 2019 10:50:43 UTC"`, + }, + }, + expect: []WarningHeader{ + { + Code: 110, + Agent: "shimba/boomba", + Reason: "looken tooken", + Date: time.Date(2019, time.January, 14, 10, 50, 43, 0, time.UTC), + }, + }, + }, + { + name: "multiple_full", + h: http.Header{ + WarningKey: []string{ + `110 shimba/boomba "looken tooken" "Mon, 14 Jan 2019 10:50:43 UTC"`, + `112 chiken "cooken" "Mon, 15 Jan 2019 10:51:43 UTC"`, + }, + }, + expect: []WarningHeader{ + { + Code: 110, + Agent: "shimba/boomba", + Reason: "looken tooken", + Date: time.Date(2019, time.January, 14, 10, 50, 43, 0, time.UTC), + }, + { + Code: 112, + Agent: "chiken", + Reason: "cooken", + Date: time.Date(2019, time.January, 15, 10, 51, 43, 0, time.UTC), + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := ParseWarnings(tc.h) + + if tc.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, tc.expect, got) + }) + } +} diff --git a/library/go/httputil/headers/ya.make b/library/go/httputil/headers/ya.make new file mode 100644 index 0000000000..d249197dc3 --- /dev/null +++ b/library/go/httputil/headers/ya.make @@ -0,0 +1,23 @@ +GO_LIBRARY() + +SRCS( + accept.go + authorization.go + content.go + cookie.go + tvm.go + user_agent.go + warning.go +) + +GO_TEST_SRCS(warning_test.go) + +GO_XTEST_SRCS( + accept_test.go + authorization_test.go + content_test.go +) + +END() + +RECURSE(gotest) diff --git a/library/go/httputil/middleware/tvm/gotest/ya.make b/library/go/httputil/middleware/tvm/gotest/ya.make new file mode 100644 index 0000000000..f8ad1ffb46 --- /dev/null +++ b/library/go/httputil/middleware/tvm/gotest/ya.make @@ -0,0 +1,3 @@ +GO_TEST_FOR(library/go/httputil/middleware/tvm) + +END() diff --git a/library/go/httputil/middleware/tvm/middleware.go b/library/go/httputil/middleware/tvm/middleware.go new file mode 100644 index 0000000000..2e578ffca1 --- /dev/null +++ b/library/go/httputil/middleware/tvm/middleware.go @@ -0,0 +1,112 @@ +package tvm + +import ( + "context" + "net/http" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/ctxlog" + "github.com/ydb-platform/ydb/library/go/core/log/nop" + "github.com/ydb-platform/ydb/library/go/httputil/headers" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "golang.org/x/xerrors" +) + +const ( + // XYaServiceTicket is http header that should be used for service ticket transfer. + XYaServiceTicket = headers.XYaServiceTicketKey + // XYaUserTicket is http header that should be used for user ticket transfer. + XYaUserTicket = headers.XYaUserTicketKey +) + +type ( + MiddlewareOption func(*middleware) + + middleware struct { + l log.Structured + + clients []tvm.Client + + authClient func(context.Context, tvm.ClientID, tvm.ClientID) error + + onError func(w http.ResponseWriter, r *http.Request, err error) + } +) + +func defaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) { + http.Error(w, err.Error(), http.StatusForbidden) +} + +func getMiddleware(clients []tvm.Client, opts ...MiddlewareOption) middleware { + m := middleware{ + clients: clients, + onError: defaultErrorHandler, + } + + for _, opt := range opts { + opt(&m) + } + + if m.authClient == nil { + panic("must provide authorization policy") + } + + if m.l == nil { + m.l = &nop.Logger{} + } + + return m +} + +// CheckServiceTicketMultiClient returns http middleware that validates service tickets for all incoming requests. +// It tries to check ticket with all the given clients in the given order +// ServiceTicket is stored on request context. It might be retrieved by calling tvm.ContextServiceTicket. +func CheckServiceTicketMultiClient(clients []tvm.Client, opts ...MiddlewareOption) func(next http.Handler) http.Handler { + m := getMiddleware(clients, opts...) + return m.wrap +} + +// CheckServiceTicket returns http middleware that validates service tickets for all incoming requests. +// +// ServiceTicket is stored on request context. It might be retrieved by calling tvm.ContextServiceTicket. +func CheckServiceTicket(client tvm.Client, opts ...MiddlewareOption) func(next http.Handler) http.Handler { + m := getMiddleware([]tvm.Client{client}, opts...) + return m.wrap +} + +func (m *middleware) wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + serviceTicket := r.Header.Get(XYaServiceTicket) + if serviceTicket == "" { + ctxlog.Error(r.Context(), m.l.Logger(), "missing service ticket") + m.onError(w, r, xerrors.New("missing service ticket")) + return + } + var ( + ticket *tvm.CheckedServiceTicket + err error + ) + for _, client := range m.clients { + ticket, err = client.CheckServiceTicket(r.Context(), serviceTicket) + if err == nil { + break + } + } + if err != nil { + ctxlog.Error(r.Context(), m.l.Logger(), "service ticket check failed", log.Error(err)) + m.onError(w, r, xerrors.Errorf("service ticket check failed: %w", err)) + return + } + + if err := m.authClient(r.Context(), ticket.SrcID, ticket.DstID); err != nil { + ctxlog.Error(r.Context(), m.l.Logger(), "client authorization failed", + log.String("ticket", ticket.LogInfo), + log.Error(err)) + m.onError(w, r, xerrors.Errorf("client authorization failed: %w", err)) + return + } + + r = r.WithContext(tvm.WithServiceTicket(r.Context(), ticket)) + next.ServeHTTP(w, r) + }) +} diff --git a/library/go/httputil/middleware/tvm/middleware_opts.go b/library/go/httputil/middleware/tvm/middleware_opts.go new file mode 100644 index 0000000000..4e33b4ee59 --- /dev/null +++ b/library/go/httputil/middleware/tvm/middleware_opts.go @@ -0,0 +1,46 @@ +package tvm + +import ( + "context" + "net/http" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "golang.org/x/xerrors" +) + +// WithAllowedClients sets list of allowed clients. +func WithAllowedClients(allowedClients []tvm.ClientID) MiddlewareOption { + return func(m *middleware) { + m.authClient = func(_ context.Context, src tvm.ClientID, dst tvm.ClientID) error { + for _, allowed := range allowedClients { + if allowed == src { + return nil + } + } + + return xerrors.Errorf("client with tvm_id=%d is not whitelisted", dst) + } + } +} + +// WithClientAuth sets custom function for client authorization. +func WithClientAuth(authClient func(ctx context.Context, src tvm.ClientID, dst tvm.ClientID) error) MiddlewareOption { + return func(m *middleware) { + m.authClient = authClient + } +} + +// WithErrorHandler sets http handler invoked for rejected requests. +func WithErrorHandler(h func(w http.ResponseWriter, r *http.Request, err error)) MiddlewareOption { + return func(m *middleware) { + m.onError = h + } +} + +// WithLogger sets logger. +func WithLogger(l log.Structured) MiddlewareOption { + return func(m *middleware) { + m.l = l + } +} diff --git a/library/go/httputil/middleware/tvm/middleware_test.go b/library/go/httputil/middleware/tvm/middleware_test.go new file mode 100644 index 0000000000..e6005a76a6 --- /dev/null +++ b/library/go/httputil/middleware/tvm/middleware_test.go @@ -0,0 +1,126 @@ +package tvm + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" +) + +type fakeClient struct { + ticket *tvm.CheckedServiceTicket + err error +} + +func (f *fakeClient) GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) { + panic("implement me") +} + +func (f *fakeClient) GetServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (string, error) { + panic("implement me") +} + +func (f *fakeClient) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + return f.ticket, f.err +} + +func (f *fakeClient) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + panic("implement me") +} + +func (f *fakeClient) GetStatus(ctx context.Context) (tvm.ClientStatusInfo, error) { + panic("implement me") +} + +func (f *fakeClient) GetRoles(ctx context.Context) (*tvm.Roles, error) { + panic("implement me") +} + +func TestMiddlewareOkTicket(t *testing.T) { + var f fakeClient + f.ticket = &tvm.CheckedServiceTicket{SrcID: 42} + + m := CheckServiceTicket(&f, WithAllowedClients([]tvm.ClientID{42})) + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set(XYaServiceTicket, "123") + + var handlerCalled bool + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerCalled = true + require.Equal(t, f.ticket, tvm.ContextServiceTicket(r.Context())) + }) + + m(handler).ServeHTTP(nil, r) + require.True(t, handlerCalled) +} + +func TestMiddlewareClientNotAllowed(t *testing.T) { + var f fakeClient + f.ticket = &tvm.CheckedServiceTicket{SrcID: 43} + + m := CheckServiceTicket(&f, WithAllowedClients([]tvm.ClientID{42})) + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set(XYaServiceTicket, "123") + w := httptest.NewRecorder() + + m(nil).ServeHTTP(w, r) + require.Equal(t, 403, w.Code) +} + +func TestMiddlewareMissingTicket(t *testing.T) { + m := CheckServiceTicket(nil, WithAllowedClients([]tvm.ClientID{42})) + + r := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + m(nil).ServeHTTP(w, r) + require.Equal(t, 403, w.Code) +} + +func TestMiddlewareInvalidTicket(t *testing.T) { + var f fakeClient + f.err = &tvm.Error{} + + m := CheckServiceTicket(&f, WithAllowedClients([]tvm.ClientID{42})) + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set(XYaServiceTicket, "123") + w := httptest.NewRecorder() + + m(nil).ServeHTTP(w, r) + require.Equal(t, 403, w.Code) +} + +func TestMiddlewareMultipleDsts(t *testing.T) { + var f1, f2, f3 fakeClient + f1.err = &tvm.Error{} + f2.err = &tvm.Error{} + f3.ticket = &tvm.CheckedServiceTicket{SrcID: 42, DstID: 43} + + m := CheckServiceTicketMultiClient([]tvm.Client{ + &f1, + &f3, + &f2, + }, WithClientAuth(func(ctx context.Context, src tvm.ClientID, dst tvm.ClientID) error { + require.Equal(t, tvm.ClientID(43), dst) + require.Equal(t, tvm.ClientID(42), src) + return nil + })) + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set(XYaServiceTicket, "123") + + var handlerCalled bool + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerCalled = true + require.Equal(t, f3.ticket, tvm.ContextServiceTicket(r.Context())) + }) + + m(handler).ServeHTTP(nil, r) + require.True(t, handlerCalled) +} diff --git a/library/go/httputil/middleware/tvm/ya.make b/library/go/httputil/middleware/tvm/ya.make new file mode 100644 index 0000000000..7aab530b70 --- /dev/null +++ b/library/go/httputil/middleware/tvm/ya.make @@ -0,0 +1,12 @@ +GO_LIBRARY() + +SRCS( + middleware.go + middleware_opts.go +) + +GO_TEST_SRCS(middleware_test.go) + +END() + +RECURSE(gotest) diff --git a/library/go/maxprocs/cgroups.go b/library/go/maxprocs/cgroups.go new file mode 100644 index 0000000000..ab4b3240ac --- /dev/null +++ b/library/go/maxprocs/cgroups.go @@ -0,0 +1,173 @@ +package maxprocs + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + + "github.com/prometheus/procfs" + "github.com/ydb-platform/ydb/library/go/slices" +) + +const ( + unifiedHierarchy = "unified" + cpuHierarchy = "cpu" +) + +var ErrNoCgroups = errors.New("no suitable cgroups were found") + +func isCgroupsExists() bool { + mounts, err := procfs.GetMounts() + if err != nil { + return false + } + + for _, m := range mounts { + if m.FSType == "cgroup" || m.FSType == "cgroup2" { + return true + } + } + + return false +} + +func parseCgroupsMountPoints() (map[string]string, error) { + mounts, err := procfs.GetMounts() + if err != nil { + return nil, err + } + + out := make(map[string]string) + for _, mount := range mounts { + switch mount.FSType { + case "cgroup2": + out[unifiedHierarchy] = mount.MountPoint + case "cgroup": + for opt := range mount.SuperOptions { + if opt == cpuHierarchy { + out[cpuHierarchy] = mount.MountPoint + break + } + } + } + } + + return out, nil +} + +func getCFSQuota() (float64, error) { + self, err := procfs.Self() + if err != nil { + return 0, err + } + + selfCgroups, err := self.Cgroups() + if err != nil { + return 0, fmt.Errorf("parse self cgroups: %w", err) + } + + cgroups, err := parseCgroupsMountPoints() + if err != nil { + return 0, fmt.Errorf("parse cgroups: %w", err) + } + + if len(selfCgroups) == 0 || len(cgroups) == 0 { + return 0, ErrNoCgroups + } + + for _, cgroup := range selfCgroups { + var quota float64 + switch { + case cgroup.HierarchyID == 0: + // for the cgroups v2 hierarchy id is always 0 + mp, ok := cgroups[unifiedHierarchy] + if !ok { + continue + } + + quota, _ = parseV2CPUQuota(mp, cgroup.Path) + case slices.ContainsString(cgroup.Controllers, cpuHierarchy): + mp, ok := cgroups[cpuHierarchy] + if !ok { + continue + } + + quota, _ = parseV1CPUQuota(mp, cgroup.Path) + } + + if quota > 0 { + return quota, nil + } + } + + return 0, ErrNoCgroups +} + +func parseV1CPUQuota(mountPoint string, cgroupPath string) (float64, error) { + basePath := filepath.Join(mountPoint, cgroupPath) + cfsQuota, err := readFileInt(filepath.Join(basePath, "cpu.cfs_quota_us")) + if err != nil { + return -1, fmt.Errorf("parse cpu.cfs_quota_us: %w", err) + } + + // A value of -1 for cpu.cfs_quota_us indicates that the group does not have any + // bandwidth restriction in place + // https://www.kernel.org/doc/Documentation/scheduler/sched-bwc.txt + if cfsQuota == -1 { + return float64(runtime.NumCPU()), nil + } + + cfsPeriod, err := readFileInt(filepath.Join(basePath, "cpu.cfs_period_us")) + if err != nil { + return -1, fmt.Errorf("parse cpu.cfs_period_us: %w", err) + } + + return float64(cfsQuota) / float64(cfsPeriod), nil +} + +func parseV2CPUQuota(mountPoint string, cgroupPath string) (float64, error) { + /* + https://www.kernel.org/doc/Documentation/cgroup-v2.txt + + cpu.max + A read-write two value file which exists on non-root cgroups. + The default is "max 100000". + + The maximum bandwidth limit. It's in the following format:: + $MAX $PERIOD + + which indicates that the group may consume upto $MAX in each + $PERIOD duration. "max" for $MAX indicates no limit. If only + one number is written, $MAX is updated. + */ + rawCPUMax, err := os.ReadFile(filepath.Join(mountPoint, cgroupPath, "cpu.max")) + if err != nil { + return -1, fmt.Errorf("read cpu.max: %w", err) + } + + parts := strings.Fields(string(rawCPUMax)) + if len(parts) != 2 { + return -1, fmt.Errorf("invalid cpu.max format: %s", string(rawCPUMax)) + } + + // "max" for $MAX indicates no limit + if parts[0] == "max" { + return float64(runtime.NumCPU()), nil + } + + cpuMax, err := strconv.Atoi(parts[0]) + if err != nil { + return -1, fmt.Errorf("parse cpu.max[max] (%q): %w", parts[0], err) + } + + cpuPeriod, err := strconv.Atoi(parts[1]) + if err != nil { + return -1, fmt.Errorf("parse cpu.max[period] (%q): %w", parts[1], err) + } + + return float64(cpuMax) / float64(cpuPeriod), nil +} diff --git a/library/go/maxprocs/doc.go b/library/go/maxprocs/doc.go new file mode 100644 index 0000000000..2461d6022c --- /dev/null +++ b/library/go/maxprocs/doc.go @@ -0,0 +1,9 @@ +// Automatically sets GOMAXPROCS to match Yandex clouds container CPU quota. +// +// This package always adjust GOMAXPROCS to some "safe" value. +// "safe" values are: +// - 2 or more +// - no more than logical cores +// - no moore than container guarantees +// - no more than 8 +package maxprocs diff --git a/library/go/maxprocs/helpers.go b/library/go/maxprocs/helpers.go new file mode 100644 index 0000000000..f1192623b5 --- /dev/null +++ b/library/go/maxprocs/helpers.go @@ -0,0 +1,45 @@ +package maxprocs + +import ( + "bytes" + "math" + "os" + "strconv" +) + +func getEnv(envName string) (string, bool) { + val, ok := os.LookupEnv(envName) + return val, ok && val != "" +} + +func applyIntStringLimit(val string) int { + maxProc, err := strconv.Atoi(val) + if err == nil { + return Adjust(maxProc) + } + + return Adjust(SafeProc) +} + +func applyFloatStringLimit(val string) int { + maxProc, err := strconv.ParseFloat(val, 64) + if err != nil { + return Adjust(SafeProc) + } + + return applyFloatLimit(maxProc) +} + +func applyFloatLimit(val float64) int { + maxProc := int(math.Floor(val)) + return Adjust(maxProc) +} + +func readFileInt(filename string) (int, error) { + raw, err := os.ReadFile(filename) + if err != nil { + return 0, err + } + + return strconv.Atoi(string(bytes.TrimSpace(raw))) +} diff --git a/library/go/maxprocs/maxprocs.go b/library/go/maxprocs/maxprocs.go new file mode 100644 index 0000000000..c04cca8dfb --- /dev/null +++ b/library/go/maxprocs/maxprocs.go @@ -0,0 +1,159 @@ +package maxprocs + +import ( + "context" + "os" + "runtime" + "strings" + + "github.com/ydb-platform/ydb/library/go/yandex/deploy/podagent" + "github.com/ydb-platform/ydb/library/go/yandex/yplite" +) + +const ( + SafeProc = 4 + MinProc = 2 + MaxProc = 8 + + GoMaxProcEnvName = "GOMAXPROCS" + QloudCPUEnvName = "QLOUD_CPU_GUARANTEE" + InstancectlCPUEnvName = "CPU_GUARANTEE" + DeloyBoxIDName = podagent.EnvBoxIDKey +) + +// Adjust adjust the maximum number of CPUs that can be executing. +// Takes a minimum between n and CPU counts and returns the previous setting +func Adjust(n int) int { + if n < MinProc { + n = MinProc + } + + nCPU := runtime.NumCPU() + if n < nCPU { + return runtime.GOMAXPROCS(n) + } + + return runtime.GOMAXPROCS(nCPU) +} + +// AdjustAuto automatically adjust the maximum number of CPUs that can be executing to safe value +// and returns the previous setting +func AdjustAuto() int { + if val, ok := getEnv(GoMaxProcEnvName); ok { + return applyIntStringLimit(val) + } + + if isCgroupsExists() { + return AdjustCgroup() + } + + if val, ok := getEnv(InstancectlCPUEnvName); ok { + return applyFloatStringLimit(strings.TrimRight(val, "c")) + } + + if val, ok := getEnv(QloudCPUEnvName); ok { + return applyFloatStringLimit(val) + } + + if boxID, ok := os.LookupEnv(DeloyBoxIDName); ok { + return adjustYPBox(boxID) + } + + if yplite.IsAPIAvailable() { + return AdjustYPLite() + } + + return Adjust(SafeProc) +} + +// AdjustQloud automatically adjust the maximum number of CPUs in case of Qloud env +// and returns the previous setting +func AdjustQloud() int { + if val, ok := getEnv(GoMaxProcEnvName); ok { + return applyIntStringLimit(val) + } + + if val, ok := getEnv(QloudCPUEnvName); ok { + return applyFloatStringLimit(val) + } + + return Adjust(MaxProc) +} + +// AdjustYP automatically adjust the maximum number of CPUs in case of YP/Y.Deploy/YP.Hard env +// and returns the previous setting +func AdjustYP() int { + if val, ok := getEnv(GoMaxProcEnvName); ok { + return applyIntStringLimit(val) + } + + if isCgroupsExists() { + return AdjustCgroup() + } + + return adjustYPBox(os.Getenv(DeloyBoxIDName)) +} + +func adjustYPBox(boxID string) int { + resources, err := podagent.NewClient().PodAttributes(context.Background()) + if err != nil { + return Adjust(SafeProc) + } + + var cpuGuarantee float64 + if boxResources, ok := resources.BoxesRequirements[boxID]; ok { + cpuGuarantee = boxResources.CPU.Guarantee / 1000 + } + + if cpuGuarantee <= 0 { + // if we don't have guarantees for current box, let's use pod guarantees + cpuGuarantee = resources.PodRequirements.CPU.Guarantee / 1000 + } + + return applyFloatLimit(cpuGuarantee) +} + +// AdjustYPLite automatically adjust the maximum number of CPUs in case of YP.Lite env +// and returns the previous setting +func AdjustYPLite() int { + if val, ok := getEnv(GoMaxProcEnvName); ok { + return applyIntStringLimit(val) + } + + podAttributes, err := yplite.FetchPodAttributes() + if err != nil { + return Adjust(SafeProc) + } + + return applyFloatLimit(float64(podAttributes.ResourceRequirements.CPU.Guarantee / 1000)) +} + +// AdjustInstancectl automatically adjust the maximum number of CPUs +// and returns the previous setting +// WARNING: supported only instancectl v1.177+ (https://wiki.yandex-team.ru/runtime-cloud/nanny/instancectl-change-log/#1.177) +func AdjustInstancectl() int { + if val, ok := getEnv(GoMaxProcEnvName); ok { + return applyIntStringLimit(val) + } + + if val, ok := getEnv(InstancectlCPUEnvName); ok { + return applyFloatStringLimit(strings.TrimRight(val, "c")) + } + + return Adjust(MaxProc) +} + +// AdjustCgroup automatically adjust the maximum number of CPUs based on the CFS quota +// and returns the previous setting. +func AdjustCgroup() int { + if val, ok := getEnv(GoMaxProcEnvName); ok { + return applyIntStringLimit(val) + } + + quota, err := getCFSQuota() + if err != nil { + return Adjust(SafeProc) + } + + return applyFloatLimit(quota) +} diff --git a/library/go/maxprocs/ya.make b/library/go/maxprocs/ya.make new file mode 100644 index 0000000000..eaaa397a9b --- /dev/null +++ b/library/go/maxprocs/ya.make @@ -0,0 +1,20 @@ +GO_LIBRARY() + +SRCS( + cgroups.go + doc.go + helpers.go + maxprocs.go +) + +GO_XTEST_SRCS( + example_test.go + maxprocs_test.go +) + +END() + +RECURSE( + example + gotest +) diff --git a/library/go/ptr/ptr.go b/library/go/ptr/ptr.go new file mode 100644 index 0000000000..7ebf3dbd72 --- /dev/null +++ b/library/go/ptr/ptr.go @@ -0,0 +1,75 @@ +package ptr + +import "time" + +// Int returns pointer to provided value +func Int(v int) *int { return &v } + +// Int8 returns pointer to provided value +func Int8(v int8) *int8 { return &v } + +// Int16 returns pointer to provided value +func Int16(v int16) *int16 { return &v } + +// Int32 returns pointer to provided value +func Int32(v int32) *int32 { return &v } + +// Int64 returns pointer to provided value +func Int64(v int64) *int64 { return &v } + +// Uint returns pointer to provided value +func Uint(v uint) *uint { return &v } + +// Uint8 returns pointer to provided value +func Uint8(v uint8) *uint8 { return &v } + +// Uint16 returns pointer to provided value +func Uint16(v uint16) *uint16 { return &v } + +// Uint32 returns pointer to provided value +func Uint32(v uint32) *uint32 { return &v } + +// Uint64 returns pointer to provided value +func Uint64(v uint64) *uint64 { return &v } + +// Float32 returns pointer to provided value +func Float32(v float32) *float32 { return &v } + +// Float64 returns pointer to provided value +func Float64(v float64) *float64 { return &v } + +// Bool returns pointer to provided value +func Bool(v bool) *bool { return &v } + +// String returns pointer to provided value +func String(v string) *string { return &v } + +// Byte returns pointer to provided value +func Byte(v byte) *byte { return &v } + +// Rune returns pointer to provided value +func Rune(v rune) *rune { return &v } + +// Complex64 returns pointer to provided value +func Complex64(v complex64) *complex64 { return &v } + +// Complex128 returns pointer to provided value +func Complex128(v complex128) *complex128 { return &v } + +// Time returns pointer to provided value +func Time(v time.Time) *time.Time { return &v } + +// Duration returns pointer to provided value +func Duration(v time.Duration) *time.Duration { return &v } + +// T returns pointer to provided value +func T[T any](v T) *T { return &v } + +// From returns value from pointer +func From[T any](v *T) T { + if v == nil { + return *new(T) + } + + return *v +} diff --git a/library/go/ptr/ya.make b/library/go/ptr/ya.make new file mode 100644 index 0000000000..17cf07a3c6 --- /dev/null +++ b/library/go/ptr/ya.make @@ -0,0 +1,5 @@ +GO_LIBRARY() + +SRCS(ptr.go) + +END() diff --git a/library/go/slices/chunk.go b/library/go/slices/chunk.go new file mode 100644 index 0000000000..2a69eb475d --- /dev/null +++ b/library/go/slices/chunk.go @@ -0,0 +1,21 @@ +package slices + +func Chunk[T any](slice []T, chunkSize int) [][]T { + if chunkSize < 1 { + return [][]T{slice} + } + chunksCount := len(slice) / chunkSize + if len(slice)%chunkSize > 0 { + chunksCount++ + } + chunks := make([][]T, chunksCount) + + for i := range chunks { + if len(slice) < chunkSize { + chunkSize = len(slice) + } + chunks[i] = slice[0:chunkSize] + slice = slice[chunkSize:] + } + return chunks +} diff --git a/library/go/slices/contains.go b/library/go/slices/contains.go new file mode 100644 index 0000000000..7253b185ed --- /dev/null +++ b/library/go/slices/contains.go @@ -0,0 +1,90 @@ +package slices + +import ( + "bytes" + "net" + + "github.com/gofrs/uuid" + "golang.org/x/exp/slices" +) + +// ContainsString checks if string slice contains given string. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsString = slices.Contains[[]string, string] + +// ContainsBool checks if bool slice contains given bool. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsBool = slices.Contains[[]bool, bool] + +// ContainsInt checks if int slice contains given int +var ContainsInt = slices.Contains[[]int, int] + +// ContainsInt8 checks if int8 slice contains given int8. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsInt8 = slices.Contains[[]int8, int8] + +// ContainsInt16 checks if int16 slice contains given int16. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsInt16 = slices.Contains[[]int16, int16] + +// ContainsInt32 checks if int32 slice contains given int32. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsInt32 = slices.Contains[[]int32, int32] + +// ContainsInt64 checks if int64 slice contains given int64. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsInt64 = slices.Contains[[]int64, int64] + +// ContainsUint checks if uint slice contains given uint. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsUint = slices.Contains[[]uint, uint] + +// ContainsUint8 checks if uint8 slice contains given uint8. +func ContainsUint8(haystack []uint8, needle uint8) bool { + return bytes.IndexByte(haystack, needle) != -1 +} + +// ContainsUint16 checks if uint16 slice contains given uint16. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsUint16 = slices.Contains[[]uint16, uint16] + +// ContainsUint32 checks if uint32 slice contains given uint32. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsUint32 = slices.Contains[[]uint32, uint32] + +// ContainsUint64 checks if uint64 slice contains given uint64. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsUint64 = slices.Contains[[]uint64, uint64] + +// ContainsFloat32 checks if float32 slice contains given float32. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsFloat32 = slices.Contains[[]float32, float32] + +// ContainsFloat64 checks if float64 slice contains given float64. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsFloat64 = slices.Contains[[]float64, float64] + +// ContainsByte checks if byte slice contains given byte +func ContainsByte(haystack []byte, needle byte) bool { + return bytes.IndexByte(haystack, needle) != -1 +} + +// ContainsIP checks if net.IP slice contains given net.IP +func ContainsIP(haystack []net.IP, needle net.IP) bool { + for _, e := range haystack { + if e.Equal(needle) { + return true + } + } + return false +} + +// ContainsUUID checks if UUID slice contains given UUID. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsUUID = slices.Contains[[]uuid.UUID, uuid.UUID] + +// Contains checks if slice of T contains given T +// Deprecated: use golang.org/x/exp/slices.Contains instead. +func Contains[E comparable](haystack []E, needle E) (bool, error) { + return slices.Contains(haystack, needle), nil +} diff --git a/library/go/slices/contains_all.go b/library/go/slices/contains_all.go new file mode 100644 index 0000000000..3c3e8e1878 --- /dev/null +++ b/library/go/slices/contains_all.go @@ -0,0 +1,23 @@ +package slices + +// ContainsAll checks if slice of type E contains all elements of given slice, order independent +func ContainsAll[E comparable](haystack []E, needle []E) bool { + m := make(map[E]struct{}, len(haystack)) + for _, i := range haystack { + m[i] = struct{}{} + } + for _, v := range needle { + if _, ok := m[v]; !ok { + return false + } + } + return true +} + +// ContainsAllStrings checks if string slice contains all elements of given slice +// Deprecated: use ContainsAll instead +var ContainsAllStrings = ContainsAll[string] + +// ContainsAllBools checks if bool slice contains all elements of given slice +// Deprecated: use ContainsAll instead +var ContainsAllBools = ContainsAll[bool] diff --git a/library/go/slices/contains_any.go b/library/go/slices/contains_any.go new file mode 100644 index 0000000000..0fc6a7ace4 --- /dev/null +++ b/library/go/slices/contains_any.go @@ -0,0 +1,72 @@ +package slices + +import ( + "bytes" +) + +// ContainsAny checks if slice of type E contains any element from given slice +func ContainsAny[E comparable](haystack, needle []E) bool { + return len(Intersection(haystack, needle)) > 0 +} + +// ContainsAnyString checks if string slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyString = ContainsAny[string] + +// ContainsAnyBool checks if bool slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyBool = ContainsAny[bool] + +// ContainsAnyInt checks if int slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyInt = ContainsAny[int] + +// ContainsAnyInt8 checks if int8 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyInt8 = ContainsAny[int8] + +// ContainsAnyInt16 checks if int16 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyInt16 = ContainsAny[int16] + +// ContainsAnyInt32 checks if int32 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyInt32 = ContainsAny[int32] + +// ContainsAnyInt64 checks if int64 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyInt64 = ContainsAny[int64] + +// ContainsAnyUint checks if uint slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyUint = ContainsAny[uint] + +// ContainsAnyUint8 checks if uint8 slice contains any element from given slice +func ContainsAnyUint8(haystack []uint8, needle []uint8) bool { + return bytes.Contains(haystack, needle) +} + +// ContainsAnyUint16 checks if uint16 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyUint16 = ContainsAny[uint16] + +// ContainsAnyUint32 checks if uint32 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyUint32 = ContainsAny[uint32] + +// ContainsAnyUint64 checks if uint64 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyUint64 = ContainsAny[uint64] + +// ContainsAnyFloat32 checks if float32 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyFloat32 = ContainsAny[float32] + +// ContainsAnyFloat64 checks if float64 slice any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyFloat64 = ContainsAny[float64] + +// ContainsAnyByte checks if byte slice contains any element from given slice +func ContainsAnyByte(haystack []byte, needle []byte) bool { + return bytes.Contains(haystack, needle) +} diff --git a/library/go/slices/dedup.go b/library/go/slices/dedup.go new file mode 100644 index 0000000000..365f3b2d74 --- /dev/null +++ b/library/go/slices/dedup.go @@ -0,0 +1,109 @@ +package slices + +import ( + "sort" + + "golang.org/x/exp/constraints" + "golang.org/x/exp/slices" +) + +// Dedup removes duplicate values from slice. +// It will alter original non-empty slice, consider copy it beforehand. +func Dedup[E constraints.Ordered](s []E) []E { + if len(s) < 2 { + return s + } + slices.Sort(s) + tmp := s[:1] + cur := s[0] + for i := 1; i < len(s); i++ { + if s[i] != cur { + tmp = append(tmp, s[i]) + cur = s[i] + } + } + return tmp +} + +// DedupBools removes duplicate values from bool slice. +// It will alter original non-empty slice, consider copy it beforehand. +func DedupBools(a []bool) []bool { + if len(a) < 2 { + return a + } + sort.Slice(a, func(i, j int) bool { return a[i] != a[j] }) + tmp := a[:1] + cur := a[0] + for i := 1; i < len(a); i++ { + if a[i] != cur { + tmp = append(tmp, a[i]) + cur = a[i] + } + } + return tmp +} + +// DedupStrings removes duplicate values from string slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupStrings = Dedup[string] + +// DedupInts removes duplicate values from ints slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupInts = Dedup[int] + +// DedupInt8s removes duplicate values from int8 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupInt8s = Dedup[int8] + +// DedupInt16s removes duplicate values from int16 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupInt16s = Dedup[int16] + +// DedupInt32s removes duplicate values from int32 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupInt32s = Dedup[int32] + +// DedupInt64s removes duplicate values from int64 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupInt64s = Dedup[int64] + +// DedupUints removes duplicate values from uint slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupUints = Dedup[uint] + +// DedupUint8s removes duplicate values from uint8 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupUint8s = Dedup[uint8] + +// DedupUint16s removes duplicate values from uint16 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupUint16s = Dedup[uint16] + +// DedupUint32s removes duplicate values from uint32 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupUint32s = Dedup[uint32] + +// DedupUint64s removes duplicate values from uint64 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupUint64s = Dedup[uint64] + +// DedupFloat32s removes duplicate values from float32 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupFloat32s = Dedup[float32] + +// DedupFloat64s removes duplicate values from float64 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupFloat64s = Dedup[float64] diff --git a/library/go/slices/equal.go b/library/go/slices/equal.go new file mode 100644 index 0000000000..6f8ae9973e --- /dev/null +++ b/library/go/slices/equal.go @@ -0,0 +1,24 @@ +package slices + +// EqualUnordered checks if slices of type E are equal, order independent. +func EqualUnordered[E comparable](a []E, b []E) bool { + if len(a) != len(b) { + return false + } + + ma := make(map[E]int) + for _, v := range a { + ma[v]++ + } + for _, v := range b { + if ma[v] == 0 { + return false + } + ma[v]-- + } + return true +} + +// EqualAnyOrderStrings checks if string slices are equal, order independent. +// Deprecated: use EqualUnordered instead. +var EqualAnyOrderStrings = EqualUnordered[string] diff --git a/library/go/slices/filter.go b/library/go/slices/filter.go new file mode 100644 index 0000000000..8b383bfcb2 --- /dev/null +++ b/library/go/slices/filter.go @@ -0,0 +1,29 @@ +package slices + +import ( + "golang.org/x/exp/slices" +) + +// Filter reduces slice values using given function. +// It operates with a copy of given slice +func Filter[S ~[]T, T any](s S, fn func(T) bool) S { + if len(s) == 0 { + return s + } + return Reduce(slices.Clone(s), fn) +} + +// Reduce is like Filter, but modifies original slice. +func Reduce[S ~[]T, T any](s S, fn func(T) bool) S { + if len(s) == 0 { + return s + } + var p int + for _, v := range s { + if fn(v) { + s[p] = v + p++ + } + } + return s[:p] +} diff --git a/library/go/slices/group_by.go b/library/go/slices/group_by.go new file mode 100644 index 0000000000..fb61a29314 --- /dev/null +++ b/library/go/slices/group_by.go @@ -0,0 +1,90 @@ +package slices + +import ( + "fmt" +) + +func createNotUniqueKeyError[T comparable](key T) error { + return fmt.Errorf("duplicated key \"%v\" found. keys are supposed to be unique", key) +} + +// GroupBy groups slice entities into map by key provided via keyGetter. +func GroupBy[S ~[]T, T any, K comparable](s S, keyGetter func(T) K) map[K][]T { + res := map[K][]T{} + + for _, entity := range s { + key := keyGetter(entity) + res[key] = append(res[key], entity) + } + + return res +} + +// GroupByUniqueKey groups slice entities into map by key provided via keyGetter with assumption that each key is unique. +// +// Returns an error in case of key ununiqueness. +func GroupByUniqueKey[S ~[]T, T any, K comparable](s S, keyGetter func(T) K) (map[K]T, error) { + res := map[K]T{} + + for _, entity := range s { + key := keyGetter(entity) + + _, duplicated := res[key] + if duplicated { + return res, createNotUniqueKeyError(key) + } + + res[key] = entity + } + + return res, nil +} + +// IndexedEntity stores an entity of original slice with its initial index in that slice +type IndexedEntity[T any] struct { + Value T + Index int +} + +// GroupByWithIndex groups slice entities into map by key provided via keyGetter. +// Each entity of underlying result slice contains the value itself and its index in the original slice +// (See IndexedEntity). +func GroupByWithIndex[S ~[]T, T any, K comparable](s S, keyGetter func(T) K) map[K][]IndexedEntity[T] { + res := map[K][]IndexedEntity[T]{} + + for i, entity := range s { + key := keyGetter(entity) + res[key] = append(res[key], IndexedEntity[T]{ + Value: entity, + Index: i, + }) + } + + return res +} + +// GroupByUniqueKeyWithIndex groups slice entities into map by key provided via keyGetter with assumption that +// each key is unique. +// Each result entity contains the value itself and its index in the original slice +// (See IndexedEntity). +// +// Returns an error in case of key ununiqueness. +func GroupByUniqueKeyWithIndex[S ~[]T, T any, K comparable](s S, keyGetter func(T) K) (map[K]IndexedEntity[T], error) { + res := map[K]IndexedEntity[T]{} + + for i, entity := range s { + key := keyGetter(entity) + + _, duplicated := res[key] + if duplicated { + return res, createNotUniqueKeyError(key) + } + + res[key] = IndexedEntity[T]{ + Value: entity, + Index: i, + } + } + + return res, nil +} diff --git a/library/go/slices/intersects.go b/library/go/slices/intersects.go new file mode 100644 index 0000000000..d40c0e8d29 --- /dev/null +++ b/library/go/slices/intersects.go @@ -0,0 +1,86 @@ +package slices + +// Intersection returns intersection for slices of various built-in types. +// +// Note that this function does not perform deduplication on result slice, +// expect duplicate entries to be present in it. +func Intersection[E comparable](a, b []E) []E { + if len(a) == 0 || len(b) == 0 { + return nil + } + + p, s := a, b + if len(b) > len(a) { + p, s = b, a + } + + m := make(map[E]struct{}) + for _, i := range s { + m[i] = struct{}{} + } + + var res []E + for _, v := range p { + if _, exists := m[v]; exists { + res = append(res, v) + } + } + + return res +} + +// IntersectStrings returns intersection of two string slices +// Deprecated: use Intersection instead. +var IntersectStrings = Intersection[string] + +// IntersectInts returns intersection of two int slices +// Deprecated: use Intersection instead. +var IntersectInts = Intersection[int] + +// IntersectInt8s returns intersection of two int8 slices +// Deprecated: use Intersection instead. +var IntersectInt8s = Intersection[int8] + +// IntersectInt16s returns intersection of two int16 slices +// Deprecated: use Intersection instead. +var IntersectInt16s = Intersection[int16] + +// IntersectInt32s returns intersection of two int32 slices +// Deprecated: use Intersection instead. +var IntersectInt32s = Intersection[int32] + +// IntersectInt64s returns intersection of two int64 slices +// Deprecated: use Intersection instead. +var IntersectInt64s = Intersection[int64] + +// IntersectUints returns intersection of two uint slices +// Deprecated: use Intersection instead. +var IntersectUints = Intersection[uint] + +// IntersectUint8s returns intersection of two uint8 slices +// Deprecated: use Intersection instead. +var IntersectUint8s = Intersection[uint8] + +// IntersectUint16s returns intersection of two uint16 slices +// Deprecated: use Intersection instead. +var IntersectUint16s = Intersection[uint16] + +// IntersectUint32s returns intersection of two uint32 slices +// Deprecated: use Intersection instead. +var IntersectUint32s = Intersection[uint32] + +// IntersectUint64s returns intersection of two uint64 slices +// Deprecated: use Intersection instead. +var IntersectUint64s = Intersection[uint64] + +// IntersectFloat32s returns intersection of two float32 slices +// Deprecated: use Intersection instead. +var IntersectFloat32s = Intersection[float32] + +// IntersectFloat64s returns intersection of two float64 slices +// Deprecated: use Intersection instead. +var IntersectFloat64s = Intersection[float64] + +// IntersectBools returns intersection of two bool slices +// Deprecated: use Intersection instead. +var IntersectBools = Intersection[bool] diff --git a/library/go/slices/join.go b/library/go/slices/join.go new file mode 100644 index 0000000000..7b72db5ed1 --- /dev/null +++ b/library/go/slices/join.go @@ -0,0 +1,14 @@ +package slices + +import ( + "fmt" + "strings" +) + +// Join joins slice of any types +func Join(s interface{}, glue string) string { + if t, ok := s.([]string); ok { + return strings.Join(t, glue) + } + return strings.Trim(strings.Join(strings.Fields(fmt.Sprint(s)), glue), "[]") +} diff --git a/library/go/slices/map.go b/library/go/slices/map.go new file mode 100644 index 0000000000..943261f786 --- /dev/null +++ b/library/go/slices/map.go @@ -0,0 +1,27 @@ +package slices + +// Map applies given function to every value of slice +func Map[S ~[]T, T, M any](s S, fn func(T) M) []M { + if s == nil { + return []M(nil) + } + if len(s) == 0 { + return make([]M, 0) + } + res := make([]M, len(s)) + for i, v := range s { + res[i] = fn(v) + } + return res +} + +// Mutate is like Map, but it prohibits type changes and modifies original slice. +func Mutate[S ~[]T, T any](s S, fn func(T) T) S { + if len(s) == 0 { + return s + } + for i, v := range s { + s[i] = fn(v) + } + return s +} diff --git a/library/go/slices/reverse.go b/library/go/slices/reverse.go new file mode 100644 index 0000000000..a436617b67 --- /dev/null +++ b/library/go/slices/reverse.go @@ -0,0 +1,83 @@ +package slices + +// Reverse reverses given slice. +// It will alter original non-empty slice, consider copy it beforehand. +func Reverse[E any](s []E) []E { + if len(s) < 2 { + return s + } + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } + return s +} + +// ReverseStrings reverses given string slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseStrings = Reverse[string] + +// ReverseInts reverses given int slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseInts = Reverse[int] + +// ReverseInt8s reverses given int8 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseInt8s = Reverse[int8] + +// ReverseInt16s reverses given int16 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseInt16s = Reverse[int16] + +// ReverseInt32s reverses given int32 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseInt32s = Reverse[int32] + +// ReverseInt64s reverses given int64 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseInt64s = Reverse[int64] + +// ReverseUints reverses given uint slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseUints = Reverse[uint] + +// ReverseUint8s reverses given uint8 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseUint8s = Reverse[uint8] + +// ReverseUint16s reverses given uint16 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseUint16s = Reverse[uint16] + +// ReverseUint32s reverses given uint32 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseUint32s = Reverse[uint32] + +// ReverseUint64s reverses given uint64 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseUint64s = Reverse[uint64] + +// ReverseFloat32s reverses given float32 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseFloat32s = Reverse[float32] + +// ReverseFloat64s reverses given float64 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseFloat64s = Reverse[float64] + +// ReverseBools reverses given bool slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseBools = Reverse[bool] diff --git a/library/go/slices/shuffle.go b/library/go/slices/shuffle.go new file mode 100644 index 0000000000..5df9b33c3c --- /dev/null +++ b/library/go/slices/shuffle.go @@ -0,0 +1,95 @@ +package slices + +import ( + "math/rand" +) + +// Shuffle shuffles values in slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +func Shuffle[E any](a []E, src rand.Source) []E { + if len(a) < 2 { + return a + } + shuffle(src)(len(a), func(i, j int) { + a[i], a[j] = a[j], a[i] + }) + return a +} + +// ShuffleStrings shuffles values in string slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleStrings = Shuffle[string] + +// ShuffleInts shuffles values in int slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleInts = Shuffle[int] + +// ShuffleInt8s shuffles values in int8 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleInt8s = Shuffle[int8] + +// ShuffleInt16s shuffles values in int16 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleInt16s = Shuffle[int16] + +// ShuffleInt32s shuffles values in int32 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleInt32s = Shuffle[int32] + +// ShuffleInt64s shuffles values in int64 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleInt64s = Shuffle[int64] + +// ShuffleUints shuffles values in uint slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleUints = Shuffle[uint] + +// ShuffleUint8s shuffles values in uint8 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleUint8s = Shuffle[uint8] + +// ShuffleUint16s shuffles values in uint16 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleUint16s = Shuffle[uint16] + +// ShuffleUint32s shuffles values in uint32 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleUint32s = Shuffle[uint32] + +// ShuffleUint64s shuffles values in uint64 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleUint64s = Shuffle[uint64] + +// ShuffleFloat32s shuffles values in float32 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleFloat32s = Shuffle[float32] + +// ShuffleFloat64s shuffles values in float64 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleFloat64s = Shuffle[float64] + +// ShuffleBools shuffles values in bool slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleBools = Shuffle[bool] + +func shuffle(src rand.Source) func(n int, swap func(i, j int)) { + shuf := rand.Shuffle + if src != nil { + shuf = rand.New(src).Shuffle + } + return shuf +} diff --git a/library/go/slices/ya.make b/library/go/slices/ya.make new file mode 100644 index 0000000000..97a793cfd1 --- /dev/null +++ b/library/go/slices/ya.make @@ -0,0 +1,34 @@ +GO_LIBRARY() + +SRCS( + chunk.go + contains.go + contains_all.go + contains_any.go + dedup.go + equal.go + filter.go + group_by.go + intersects.go + join.go + map.go + reverse.go + shuffle.go +) + +GO_XTEST_SRCS( + chunk_test.go + dedup_test.go + equal_test.go + filter_test.go + group_by_test.go + intersects_test.go + join_test.go + map_test.go + reverse_test.go + shuffle_test.go +) + +END() + +RECURSE(gotest) diff --git a/library/go/test/assertpb/assert.go b/library/go/test/assertpb/assert.go new file mode 100644 index 0000000000..f5420748d2 --- /dev/null +++ b/library/go/test/assertpb/assert.go @@ -0,0 +1,35 @@ +package assertpb + +import ( + "fmt" + + "github.com/golang/protobuf/proto" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" +) + +type TestingT interface { + Errorf(format string, args ...interface{}) + FailNow() + Helper() +} + +func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + t.Helper() + + if cmp.Equal(expected, actual, cmp.Comparer(proto.Equal)) { + return true + } + + diff := cmp.Diff(expected, actual, cmp.Comparer(proto.Equal)) + return assert.Fail(t, fmt.Sprintf("Not equal: \n"+ + "expected: %s\n"+ + "actual : %s\n"+ + "diff : %s", expected, actual, diff), msgAndArgs) +} + +func Equalf(t TestingT, expected, actual interface{}, msg string, args ...interface{}) bool { + t.Helper() + + return Equal(t, expected, actual, append([]interface{}{msg}, args...)...) +} diff --git a/library/go/test/assertpb/ya.make b/library/go/test/assertpb/ya.make new file mode 100644 index 0000000000..109571f55a --- /dev/null +++ b/library/go/test/assertpb/ya.make @@ -0,0 +1,9 @@ +GO_LIBRARY() + +SRCS(assert.go) + +GO_TEST_SRCS(assert_test.go) + +END() + +RECURSE(gotest) diff --git a/library/go/x/xsync/singleinflight.go b/library/go/x/xsync/singleinflight.go new file mode 100644 index 0000000000..3beee1ea67 --- /dev/null +++ b/library/go/x/xsync/singleinflight.go @@ -0,0 +1,36 @@ +package xsync + +import ( + "sync" + "sync/atomic" +) + +// SingleInflight allows only one execution of function at time. +// For more exhaustive functionality see https://pkg.go.dev/golang.org/x/sync/singleflight. +type SingleInflight struct { + updatingOnce atomic.Value +} + +// NewSingleInflight creates new SingleInflight. +func NewSingleInflight() SingleInflight { + var v atomic.Value + v.Store(new(sync.Once)) + return SingleInflight{updatingOnce: v} +} + +// Do executes the given function, making sure that only one execution is in-flight. +// If another caller comes in, it waits for the original to complete. +func (i *SingleInflight) Do(f func()) { + i.getOnce().Do(func() { + f() + i.setOnce() + }) +} + +func (i *SingleInflight) getOnce() *sync.Once { + return i.updatingOnce.Load().(*sync.Once) +} + +func (i *SingleInflight) setOnce() { + i.updatingOnce.Store(new(sync.Once)) +} diff --git a/library/go/x/xsync/ya.make b/library/go/x/xsync/ya.make new file mode 100644 index 0000000000..cffc9f89b8 --- /dev/null +++ b/library/go/x/xsync/ya.make @@ -0,0 +1,9 @@ +GO_LIBRARY() + +SRCS(singleinflight.go) + +GO_TEST_SRCS(singleinflight_test.go) + +END() + +RECURSE(gotest) diff --git a/library/go/yandex/deploy/podagent/client.go b/library/go/yandex/deploy/podagent/client.go new file mode 100644 index 0000000000..8f87d0e682 --- /dev/null +++ b/library/go/yandex/deploy/podagent/client.go @@ -0,0 +1,66 @@ +package podagent + +import ( + "context" + "time" + + "github.com/go-resty/resty/v2" + "github.com/ydb-platform/ydb/library/go/core/xerrors" + "github.com/ydb-platform/ydb/library/go/httputil/headers" +) + +const ( + EndpointURL = "http://127.0.0.1:1/" + HTTPTimeout = 500 * time.Millisecond +) + +type Client struct { + httpc *resty.Client +} + +func NewClient(opts ...Option) *Client { + c := &Client{ + httpc: resty.New(). + SetBaseURL(EndpointURL). + SetTimeout(HTTPTimeout), + } + + for _, opt := range opts { + opt(c) + } + return c +} + +// PodAttributes returns current pod attributes. +// +// Documentation: https://deploy.yandex-team.ru/docs/reference/api/pod-agent-public-api#localhost:1pod_attributes +func (c *Client) PodAttributes(ctx context.Context) (rsp PodAttributesResponse, err error) { + err = c.call(ctx, "/pod_attributes", &rsp) + return +} + +// PodStatus returns current pod status. +// +// Documentation: https://deploy.yandex-team.ru/docs/reference/api/pod-agent-public-api#localhost:1pod_status +func (c *Client) PodStatus(ctx context.Context) (rsp PodStatusResponse, err error) { + err = c.call(ctx, "/pod_status", &rsp) + return +} + +func (c *Client) call(ctx context.Context, handler string, result interface{}) error { + rsp, err := c.httpc.R(). + SetContext(ctx). + ExpectContentType(headers.TypeApplicationJSON.String()). + SetResult(&result). + Get(handler) + + if err != nil { + return xerrors.Errorf("failed to request pod agent API: %w", err) + } + + if !rsp.IsSuccess() { + return xerrors.Errorf("unexpected status code: %d", rsp.StatusCode()) + } + + return nil +} diff --git a/library/go/yandex/deploy/podagent/doc.go b/library/go/yandex/deploy/podagent/doc.go new file mode 100644 index 0000000000..326b84040f --- /dev/null +++ b/library/go/yandex/deploy/podagent/doc.go @@ -0,0 +1,4 @@ +// Package podagent provides the client and types for making API requests to Y.Deploy PodAgent. +// +// Official documentation for PogAgent public API: https://deploy.yandex-team.ru/docs/reference/api/pod-agent-public-api +package podagent diff --git a/library/go/yandex/deploy/podagent/env.go b/library/go/yandex/deploy/podagent/env.go new file mode 100644 index 0000000000..4dd4ae1790 --- /dev/null +++ b/library/go/yandex/deploy/podagent/env.go @@ -0,0 +1,33 @@ +package podagent + +import "os" + +// Box/Workload environment variable names, documentation references: +// - https://deploy.yandex-team.ru/docs/concepts/pod/box#systemenv +// - https://deploy.yandex-team.ru/docs/concepts/pod/workload/workload#system_env +const ( + EnvWorkloadIDKey = "DEPLOY_WORKLOAD_ID" + EnvContainerIDKey = "DEPLOY_CONTAINER_ID" + EnvBoxIDKey = "DEPLOY_BOX_ID" + EnvPodIDKey = "DEPLOY_POD_ID" + EnvProjectIDKey = "DEPLOY_PROJECT_ID" + EnvStageIDKey = "DEPLOY_STAGE_ID" + EnvUnitIDKey = "DEPLOY_UNIT_ID" + + EnvLogsEndpointKey = "DEPLOY_LOGS_ENDPOINT" + EnvLogsNameKey = "DEPLOY_LOGS_DEFAULT_NAME" + EnvLogsSecretKey = "DEPLOY_LOGS_SECRET" + + EnvNodeClusterKey = "DEPLOY_NODE_CLUSTER" + EnvNodeDCKey = "DEPLOY_NODE_DC" + EnvNodeFQDNKey = "DEPLOY_NODE_FQDN" + + EnvPodPersistentFQDN = "DEPLOY_POD_PERSISTENT_FQDN" + EnvPodTransientFQDN = "DEPLOY_POD_TRANSIENT_FQDN" +) + +// UnderPodAgent returns true if application managed by pod-agent. +func UnderPodAgent() bool { + _, ok := os.LookupEnv(EnvPodIDKey) + return ok +} diff --git a/library/go/yandex/deploy/podagent/options.go b/library/go/yandex/deploy/podagent/options.go new file mode 100644 index 0000000000..f0ab9ba4c3 --- /dev/null +++ b/library/go/yandex/deploy/podagent/options.go @@ -0,0 +1,17 @@ +package podagent + +import "github.com/ydb-platform/ydb/library/go/core/log" + +type Option func(client *Client) + +func WithEndpoint(endpointURL string) Option { + return func(c *Client) { + c.httpc.SetBaseURL(endpointURL) + } +} + +func WithLogger(l log.Fmt) Option { + return func(c *Client) { + c.httpc.SetLogger(l) + } +} diff --git a/library/go/yandex/deploy/podagent/responses.go b/library/go/yandex/deploy/podagent/responses.go new file mode 100644 index 0000000000..e97c70dc7c --- /dev/null +++ b/library/go/yandex/deploy/podagent/responses.go @@ -0,0 +1,82 @@ +package podagent + +import ( + "encoding/json" + "net" +) + +type BoxStatus struct { + ID string `json:"id"` + Revision uint32 `json:"revision"` +} + +type WorkloadStatus struct { + ID string `json:"id"` + Revision uint32 `json:"revision"` +} + +type PodStatusResponse struct { + Boxes []BoxStatus `json:"boxes"` + Workloads []WorkloadStatus `json:"workloads"` +} + +type MemoryResource struct { + Guarantee uint64 `json:"memory_guarantee_bytes"` + Limit uint64 `json:"memory_limit_bytes"` +} + +type CPUResource struct { + Guarantee float64 `json:"cpu_guarantee_millicores"` + Limit float64 `json:"cpu_limit_millicores"` +} + +type ResourceRequirements struct { + Memory MemoryResource `json:"memory"` + CPU CPUResource `json:"cpu"` +} + +type NodeMeta struct { + DC string `json:"dc"` + Cluster string `json:"cluster"` + FQDN string `json:"fqdn"` +} + +type PodMeta struct { + PodID string `json:"pod_id"` + PodSetID string `json:"pod_set_id"` + Annotations json.RawMessage `json:"annotations"` + Labels json.RawMessage `json:"labels"` +} + +type Resources struct { + Boxes map[string]ResourceRequirements `json:"box_resource_requirements"` + Pod ResourceRequirements `json:"resource_requirements"` +} + +type InternetAddress struct { + Address net.IP `json:"ip4_address"` + ID string `json:"id"` +} + +type VirtualService struct { + IPv4Addrs []net.IP `json:"ip4_addresses"` + IPv6Addrs []net.IP `json:"ip6_addresses"` +} + +type IPAllocation struct { + InternetAddress InternetAddress `json:"internet_address"` + TransientFQDN string `json:"transient_fqdn"` + PersistentFQDN string `json:"persistent_fqdn"` + Addr net.IP `json:"address"` + VlanID string `json:"vlan_id"` + VirtualServices []VirtualService `json:"virtual_services"` + Labels map[string]string `json:"labels"` +} + +type PodAttributesResponse struct { + NodeMeta NodeMeta `json:"node_meta"` + PodMeta PodMeta `json:"metadata"` + BoxesRequirements map[string]ResourceRequirements `json:"box_resource_requirements"` + PodRequirements ResourceRequirements `json:"resource_requirements"` + IPAllocations []IPAllocation `json:"ip6_address_allocations"` +} diff --git a/library/go/yandex/deploy/podagent/ya.make b/library/go/yandex/deploy/podagent/ya.make new file mode 100644 index 0000000000..4ae3d12925 --- /dev/null +++ b/library/go/yandex/deploy/podagent/ya.make @@ -0,0 +1,15 @@ +GO_LIBRARY() + +SRCS( + client.go + doc.go + env.go + options.go + responses.go +) + +GO_XTEST_SRCS(client_test.go) + +END() + +RECURSE(gotest) diff --git a/library/go/yandex/solomon/reporters/puller/httppuller/example_test.go b/library/go/yandex/solomon/reporters/puller/httppuller/example_test.go new file mode 100644 index 0000000000..c04c81168d --- /dev/null +++ b/library/go/yandex/solomon/reporters/puller/httppuller/example_test.go @@ -0,0 +1,40 @@ +package httppuller_test + +import ( + "net/http" + "time" + + "github.com/ydb-platform/ydb/library/go/core/metrics/solomon" + "github.com/ydb-platform/ydb/library/go/yandex/solomon/reporters/puller/httppuller" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" +) + +func ExampleNewHandler() { + // create metrics registry + opts := solomon.NewRegistryOpts(). + SetSeparator('_'). + SetPrefix("myprefix") + + reg := solomon.NewRegistry(opts) + + // register new metric + cnt := reg.Counter("cyclesCount") + + // pass metric to your function and do job + go func() { + for { + cnt.Inc() + time.Sleep(1 * time.Second) + } + }() + + // start HTTP server with handler on /metrics URI + mux := http.NewServeMux() + mux.Handle("/metrics", httppuller.NewHandler(reg)) + + // Or start + var tvm tvm.Client + mux.Handle("/secure_metrics", httppuller.NewHandler(reg, httppuller.WithTVM(tvm))) + + _ = http.ListenAndServe(":80", mux) +} diff --git a/library/go/yandex/solomon/reporters/puller/httppuller/gotest/ya.make b/library/go/yandex/solomon/reporters/puller/httppuller/gotest/ya.make new file mode 100644 index 0000000000..cf11e75a33 --- /dev/null +++ b/library/go/yandex/solomon/reporters/puller/httppuller/gotest/ya.make @@ -0,0 +1,3 @@ +GO_TEST_FOR(library/go/yandex/solomon/reporters/puller/httppuller) + +END() diff --git a/library/go/yandex/solomon/reporters/puller/httppuller/handler.go b/library/go/yandex/solomon/reporters/puller/httppuller/handler.go new file mode 100644 index 0000000000..9521d41bdc --- /dev/null +++ b/library/go/yandex/solomon/reporters/puller/httppuller/handler.go @@ -0,0 +1,120 @@ +package httppuller + +import ( + "context" + "fmt" + "io" + "net/http" + "reflect" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/nop" + "github.com/ydb-platform/ydb/library/go/core/metrics/solomon" + "github.com/ydb-platform/ydb/library/go/httputil/headers" + "github.com/ydb-platform/ydb/library/go/httputil/middleware/tvm" +) + +const nilRegistryPanicMsg = "nil registry given" + +type MetricsStreamer interface { + StreamJSON(context.Context, io.Writer) (int, error) + StreamSpack(context.Context, io.Writer, solomon.CompressionType) (int, error) +} + +type handler struct { + registry MetricsStreamer + streamFormat headers.ContentType + checkTicket func(h http.Handler) http.Handler + logger log.Logger +} + +type Option interface { + isOption() +} + +// NewHandler returns new HTTP handler to expose gathered metrics using metrics dumper +func NewHandler(r MetricsStreamer, opts ...Option) http.Handler { + if v := reflect.ValueOf(r); !v.IsValid() || v.Kind() == reflect.Ptr && v.IsNil() { + panic(nilRegistryPanicMsg) + } + + h := handler{ + registry: r, + streamFormat: headers.TypeApplicationJSON, + checkTicket: func(h http.Handler) http.Handler { + return h + }, + logger: &nop.Logger{}, + } + + for _, opt := range opts { + switch o := opt.(type) { + case *tvmOption: + h.checkTicket = tvm.CheckServiceTicket(o.client, tvm.WithAllowedClients(AllFetchers)) + case *spackOption: + h.streamFormat = headers.TypeApplicationXSolomonSpack + case *loggerOption: + h.logger = o.logger + default: + panic(fmt.Sprintf("unsupported option %T", opt)) + } + } + + return h.checkTicket(h) +} + +func (h handler) okSpack(header http.Header) bool { + if h.streamFormat != headers.TypeApplicationXSolomonSpack { + return false + } + for _, header := range header[headers.AcceptKey] { + types, err := headers.ParseAccept(header) + if err != nil { + h.logger.Warn("Can't parse accept header", log.Error(err), log.String("header", header)) + continue + } + for _, acceptableType := range types { + if acceptableType.Type == headers.TypeApplicationXSolomonSpack { + return true + } + } + } + return false +} + +func (h handler) okLZ4Compression(header http.Header) bool { + for _, header := range header[headers.AcceptEncodingKey] { + encodings, err := headers.ParseAcceptEncoding(header) + if err != nil { + h.logger.Warn("Can't parse accept-encoding header", log.Error(err), log.String("header", header)) + continue + } + for _, acceptableEncoding := range encodings { + if acceptableEncoding.Encoding == headers.EncodingLZ4 { + return true + } + } + } + return false +} + +func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if h.okSpack(r.Header) { + compression := solomon.CompressionNone + if h.okLZ4Compression(r.Header) { + compression = solomon.CompressionLz4 + } + w.Header().Set(headers.ContentTypeKey, headers.TypeApplicationXSolomonSpack.String()) + _, err := h.registry.StreamSpack(r.Context(), w, compression) + if err != nil { + h.logger.Error("Failed to write compressed spack", log.Error(err)) + } + return + } + + w.Header().Set(headers.ContentTypeKey, headers.TypeApplicationJSON.String()) + _, err := h.registry.StreamJSON(r.Context(), w) + if err != nil { + h.logger.Error("Failed to write json", log.Error(err)) + } +} diff --git a/library/go/yandex/solomon/reporters/puller/httppuller/handler_test.go b/library/go/yandex/solomon/reporters/puller/httppuller/handler_test.go new file mode 100644 index 0000000000..686f9b60f9 --- /dev/null +++ b/library/go/yandex/solomon/reporters/puller/httppuller/handler_test.go @@ -0,0 +1,197 @@ +package httppuller + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "sort" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/ydb-platform/ydb/library/go/core/metrics" + "github.com/ydb-platform/ydb/library/go/core/metrics/solomon" + "github.com/ydb-platform/ydb/library/go/httputil/headers" +) + +type testMetricsData struct { + Metrics []struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Value float64 `json:"value"` + Histogram struct { + Bounds []float64 `json:"bounds"` + Buckets []int64 `json:"buckets"` + Inf int64 `json:"inf"` + } `json:"hist"` + } `json:"metrics"` +} + +type testStreamer struct{} + +func (s testStreamer) StreamJSON(context.Context, io.Writer) (int, error) { return 0, nil } +func (s testStreamer) StreamSpack(context.Context, io.Writer, solomon.CompressionType) (int, error) { + return 0, nil +} + +func TestHandler_NewHandler(t *testing.T) { + assert.PanicsWithValue(t, nilRegistryPanicMsg, func() { NewHandler(nil) }) + assert.PanicsWithValue(t, nilRegistryPanicMsg, func() { var s *solomon.Registry; NewHandler(s) }) + assert.PanicsWithValue(t, nilRegistryPanicMsg, func() { var ts *testStreamer; NewHandler(ts) }) + assert.NotPanics(t, func() { NewHandler(&solomon.Registry{}) }) + assert.NotPanics(t, func() { NewHandler(&testStreamer{}) }) + assert.NotPanics(t, func() { NewHandler(testStreamer{}) }) +} + +func TestHandler_ServeHTTP(t *testing.T) { + testCases := []struct { + name string + registry *solomon.Registry + expectStatus int + expectedApplicationType headers.ContentType + expectBody []byte + }{ + { + "success_json", + func() *solomon.Registry { + r := solomon.NewRegistry(solomon.NewRegistryOpts()) + + cnt := r.Counter("mycounter") + cnt.Add(42) + + gg := r.Gauge("mygauge") + gg.Set(2.4) + + hs := r.Histogram("myhistogram", metrics.NewBuckets(1, 2, 3)) + hs.RecordValue(0.5) + hs.RecordValue(1.5) + hs.RecordValue(1.7) + hs.RecordValue(2.2) + hs.RecordValue(42) + + return r + }(), + http.StatusOK, + headers.TypeApplicationJSON, + []byte(` + { + "metrics": [ + { + "type": "COUNTER", + "labels": { + "sensor": "mycounter" + }, + "value": 42 + }, + { + "type": "DGAUGE", + "labels": { + "sensor": "mygauge" + }, + "value": 2.4 + }, + { + "type": "HIST", + "labels": { + "sensor": "myhistogram" + }, + "hist": { + "bounds": [ + 1, + 2, + 3 + ], + "buckets": [ + 1, + 2, + 1 + ], + "inf": 1 + } + } + ] + } + `), + }, + { + "success_spack", + func() *solomon.Registry { + r := solomon.NewRegistry(solomon.NewRegistryOpts()) + _ = r.Histogram("histogram", metrics.NewBuckets(0, 0.1, 0.11)) + return r + }(), + http.StatusOK, + headers.TypeApplicationXSolomonSpack, + []byte{ + 0x53, 0x50, // magic + 0x01, 0x01, // version + 0x18, 0x00, // header size + 0x0, // time precision + 0x0, // compression algorithm + 0x7, 0x0, 0x0, 0x0, // label names size + 0xa, 0x0, 0x0, 0x0, // label values size + 0x1, 0x0, 0x0, 0x0, // metric count + 0x1, 0x0, 0x0, 0x0, // point count + // label names pool + 0x73, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x0, // "sensor" + // label values pool + 0x68, 0x69, 0x73, 0x74, 0x6F, 0x67, 0x72, 0x61, 0x6D, 0x0, // "histogram" + // common time + 0x0, 0x0, 0x0, 0x0, + // common labels + 0x0, + /*types*/ 0x15, + /*flags*/ 0x0, + /*labels*/ 0x1, // ? + /*name*/ 0x0, + /*value*/ 0x0, + /*buckets count*/ 0x3, + /*upper bound 0*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /*upper bound 1*/ 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xb9, 0x3f, + /*upper bound 2*/ 0x29, 0x5c, 0x8f, 0xc2, 0xf5, 0x28, 0xbc, 0x3f, + /*counter 0*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /*counter 1*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + /*counter 2*/ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "/metrics", nil) + + var h http.Handler + if tc.expectedApplicationType == headers.TypeApplicationXSolomonSpack { + h = NewHandler(tc.registry, WithSpack()) + } else { + h = NewHandler(tc.registry) + } + + r.Header.Set(headers.AcceptKey, tc.expectedApplicationType.String()) + h.ServeHTTP(w, r) + assert.Equal(t, tc.expectStatus, w.Code) + assert.Equal(t, tc.expectedApplicationType.String(), w.Header().Get(headers.ContentTypeKey)) + + if tc.expectedApplicationType == headers.TypeApplicationXSolomonSpack { + assert.EqualValues(t, tc.expectBody, w.Body.Bytes()) + } else { + var expectedObj, givenObj testMetricsData + err := json.Unmarshal(tc.expectBody, &expectedObj) + assert.NoError(t, err) + err = json.Unmarshal(w.Body.Bytes(), &givenObj) + assert.NoError(t, err) + + sort.Slice(expectedObj.Metrics, func(i, j int) bool { + return expectedObj.Metrics[i].Type < expectedObj.Metrics[j].Type + }) + sort.Slice(givenObj.Metrics, func(i, j int) bool { + return givenObj.Metrics[i].Type < givenObj.Metrics[j].Type + }) + + assert.EqualValues(t, expectedObj, givenObj) + } + }) + } +} diff --git a/library/go/yandex/solomon/reporters/puller/httppuller/logger.go b/library/go/yandex/solomon/reporters/puller/httppuller/logger.go new file mode 100644 index 0000000000..19fe4bf733 --- /dev/null +++ b/library/go/yandex/solomon/reporters/puller/httppuller/logger.go @@ -0,0 +1,15 @@ +package httppuller + +import "github.com/ydb-platform/ydb/library/go/core/log" + +type loggerOption struct { + logger log.Logger +} + +func (*loggerOption) isOption() {} + +func WithLogger(logger log.Logger) Option { + return &loggerOption{ + logger: logger, + } +} diff --git a/library/go/yandex/solomon/reporters/puller/httppuller/spack.go b/library/go/yandex/solomon/reporters/puller/httppuller/spack.go new file mode 100644 index 0000000000..cf59abd52a --- /dev/null +++ b/library/go/yandex/solomon/reporters/puller/httppuller/spack.go @@ -0,0 +1,10 @@ +package httppuller + +type spackOption struct { +} + +func (*spackOption) isOption() {} + +func WithSpack() Option { + return &spackOption{} +} diff --git a/library/go/yandex/solomon/reporters/puller/httppuller/tvm.go b/library/go/yandex/solomon/reporters/puller/httppuller/tvm.go new file mode 100644 index 0000000000..e6afeec115 --- /dev/null +++ b/library/go/yandex/solomon/reporters/puller/httppuller/tvm.go @@ -0,0 +1,27 @@ +package httppuller + +import "github.com/ydb-platform/ydb/library/go/yandex/tvm" + +const ( + FetcherPreTVMID = 2012024 + FetcherTestTVMID = 2012026 + FetcherProdTVMID = 2012028 +) + +var ( + AllFetchers = []tvm.ClientID{ + FetcherPreTVMID, + FetcherTestTVMID, + FetcherProdTVMID, + } +) + +type tvmOption struct { + client tvm.Client +} + +func (*tvmOption) isOption() {} + +func WithTVM(tvm tvm.Client) Option { + return &tvmOption{client: tvm} +} diff --git a/library/go/yandex/solomon/reporters/puller/httppuller/tvm_test.go b/library/go/yandex/solomon/reporters/puller/httppuller/tvm_test.go new file mode 100644 index 0000000000..8eb4d27942 --- /dev/null +++ b/library/go/yandex/solomon/reporters/puller/httppuller/tvm_test.go @@ -0,0 +1,80 @@ +package httppuller_test + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/ydb-platform/ydb/library/go/core/metrics/solomon" + "github.com/ydb-platform/ydb/library/go/yandex/solomon/reporters/puller/httppuller" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" +) + +type fakeTVMClient struct{} + +func (f *fakeTVMClient) GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) { + return "", &tvm.Error{Code: tvm.ErrorMissingServiceTicket} +} + +func (f *fakeTVMClient) GetServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (string, error) { + return "", &tvm.Error{Code: tvm.ErrorMissingServiceTicket} +} + +func (f *fakeTVMClient) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + if ticket == "qwerty" { + return &tvm.CheckedServiceTicket{SrcID: httppuller.FetcherProdTVMID}, nil + } + + return nil, &tvm.Error{Code: tvm.ErrorMissingServiceTicket} +} + +func (f *fakeTVMClient) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + return nil, &tvm.Error{Code: tvm.ErrorMissingServiceTicket} +} + +func (f *fakeTVMClient) GetStatus(ctx context.Context) (tvm.ClientStatusInfo, error) { + return tvm.ClientStatusInfo{}, &tvm.Error{Code: tvm.ErrorMissingServiceTicket} +} + +func (f *fakeTVMClient) GetRoles(ctx context.Context) (*tvm.Roles, error) { + return nil, errors.New("not implemented") +} + +var _ tvm.Client = &fakeTVMClient{} + +func TestHandler_ServiceTicketValidation(t *testing.T) { + registry := solomon.NewRegistry(solomon.NewRegistryOpts()) + h := httppuller.NewHandler(registry, httppuller.WithTVM(&fakeTVMClient{})) + + t.Run("MissingTicket", func(t *testing.T) { + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "/metrics", nil) + + h.ServeHTTP(w, r) + assert.Equal(t, 403, w.Code) + assert.Equal(t, "missing service ticket\n", w.Body.String()) + }) + + t.Run("InvalidTicket", func(t *testing.T) { + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "/metrics", nil) + r.Header.Add("X-Ya-Service-Ticket", "123456") + + h.ServeHTTP(w, r) + assert.Equal(t, 403, w.Code) + assert.Truef(t, strings.HasPrefix(w.Body.String(), "service ticket check failed"), "body=%q", w.Body.String()) + }) + + t.Run("GoodTicket", func(t *testing.T) { + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "/metrics", nil) + r.Header.Add("X-Ya-Service-Ticket", "qwerty") + + h.ServeHTTP(w, r) + assert.Equal(t, 200, w.Code) + }) +} diff --git a/library/go/yandex/solomon/reporters/puller/httppuller/ya.make b/library/go/yandex/solomon/reporters/puller/httppuller/ya.make new file mode 100644 index 0000000000..283ca566af --- /dev/null +++ b/library/go/yandex/solomon/reporters/puller/httppuller/ya.make @@ -0,0 +1,19 @@ +GO_LIBRARY() + +SRCS( + handler.go + logger.go + spack.go + tvm.go +) + +GO_TEST_SRCS(handler_test.go) + +GO_XTEST_SRCS( + example_test.go + tvm_test.go +) + +END() + +RECURSE(gotest) diff --git a/library/go/yandex/tvm/cachedtvm/cache.go b/library/go/yandex/tvm/cachedtvm/cache.go new file mode 100644 index 0000000000..a04e2baf8a --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/cache.go @@ -0,0 +1,22 @@ +package cachedtvm + +import ( + "time" + + "github.com/karlseguin/ccache/v2" +) + +type cache struct { + *ccache.Cache + ttl time.Duration +} + +func (c *cache) Fetch(key string, fn func() (interface{}, error)) (*ccache.Item, error) { + return c.Cache.Fetch(key, c.ttl, fn) +} + +func (c *cache) Stop() { + if c.Cache != nil { + c.Cache.Stop() + } +} diff --git a/library/go/yandex/tvm/cachedtvm/client.go b/library/go/yandex/tvm/cachedtvm/client.go new file mode 100644 index 0000000000..ed7d51d8d1 --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/client.go @@ -0,0 +1,116 @@ +package cachedtvm + +import ( + "context" + "fmt" + "time" + + "github.com/karlseguin/ccache/v2" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" +) + +const ( + DefaultTTL = 1 * time.Minute + DefaultMaxItems = 100 + MaxServiceTicketTTL = 5 * time.Minute + MaxUserTicketTTL = 1 * time.Minute +) + +type CachedClient struct { + tvm.Client + serviceTicketCache cache + userTicketCache cache + userTicketFn func(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) +} + +func NewClient(tvmClient tvm.Client, opts ...Option) (*CachedClient, error) { + newCache := func(o cacheOptions) cache { + return cache{ + Cache: ccache.New( + ccache.Configure().MaxSize(o.maxItems), + ), + ttl: o.ttl, + } + } + + out := &CachedClient{ + Client: tvmClient, + serviceTicketCache: newCache(cacheOptions{ + ttl: DefaultTTL, + maxItems: DefaultMaxItems, + }), + userTicketFn: tvmClient.CheckUserTicket, + } + + for _, opt := range opts { + switch o := opt.(type) { + case OptionServiceTicket: + if o.ttl > MaxServiceTicketTTL { + return nil, fmt.Errorf("maximum TTL for check service ticket exceed: %s > %s", o.ttl, MaxServiceTicketTTL) + } + + out.serviceTicketCache = newCache(o.cacheOptions) + case OptionUserTicket: + if o.ttl > MaxUserTicketTTL { + return nil, fmt.Errorf("maximum TTL for check user ticket exceed: %s > %s", o.ttl, MaxUserTicketTTL) + } + + out.userTicketFn = out.cacheCheckUserTicket + out.userTicketCache = newCache(o.cacheOptions) + default: + panic(fmt.Sprintf("unexpected cache option: %T", o)) + } + } + + return out, nil +} + +func (c *CachedClient) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + out, err := c.serviceTicketCache.Fetch(ticket, func() (interface{}, error) { + return c.Client.CheckServiceTicket(ctx, ticket) + }) + + if err != nil { + return nil, err + } + + return out.Value().(*tvm.CheckedServiceTicket), nil +} + +func (c *CachedClient) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + return c.userTicketFn(ctx, ticket, opts...) +} + +func (c *CachedClient) cacheCheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + cacheKey := func(ticket string, opts ...tvm.CheckUserTicketOption) string { + if len(opts) == 0 { + return ticket + } + + var options tvm.CheckUserTicketOptions + for _, opt := range opts { + opt(&options) + } + + if options.EnvOverride == nil { + return ticket + } + + return fmt.Sprintf("%d:%s", *options.EnvOverride, ticket) + } + + out, err := c.userTicketCache.Fetch(cacheKey(ticket, opts...), func() (interface{}, error) { + return c.Client.CheckUserTicket(ctx, ticket, opts...) + }) + + if err != nil { + return nil, err + } + + return out.Value().(*tvm.CheckedUserTicket), nil +} + +func (c *CachedClient) Close() { + c.serviceTicketCache.Stop() + c.userTicketCache.Stop() +} diff --git a/library/go/yandex/tvm/cachedtvm/client_example_test.go b/library/go/yandex/tvm/cachedtvm/client_example_test.go new file mode 100644 index 0000000000..749e1d34c5 --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/client_example_test.go @@ -0,0 +1,40 @@ +package cachedtvm_test + +import ( + "context" + "fmt" + "time" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/zap" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/cachedtvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewClient_checkServiceTicket() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewAnyClient(tvmtool.WithLogger(zlog)) + if err != nil { + panic(err) + } + + cachedTvmClient, err := cachedtvm.NewClient( + tvmClient, + cachedtvm.WithCheckServiceTicket(1*time.Minute, 1000), + ) + if err != nil { + panic(err) + } + defer cachedTvmClient.Close() + + ticketInfo, err := cachedTvmClient.CheckServiceTicket(context.TODO(), "3:serv:....") + if err != nil { + panic(err) + } + + fmt.Println("ticket info: ", ticketInfo.LogInfo) +} diff --git a/library/go/yandex/tvm/cachedtvm/client_test.go b/library/go/yandex/tvm/cachedtvm/client_test.go new file mode 100644 index 0000000000..0d9f0a8732 --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/client_test.go @@ -0,0 +1,194 @@ +package cachedtvm_test + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/cachedtvm" +) + +const ( + checkPasses = 5 +) + +type mockTvmClient struct { + tvm.Client + checkServiceTicketCalls int + checkUserTicketCalls int +} + +func (c *mockTvmClient) CheckServiceTicket(_ context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + defer func() { c.checkServiceTicketCalls++ }() + + return &tvm.CheckedServiceTicket{ + LogInfo: ticket, + IssuerUID: tvm.UID(c.checkServiceTicketCalls), + }, nil +} + +func (c *mockTvmClient) CheckUserTicket(_ context.Context, ticket string, _ ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + defer func() { c.checkUserTicketCalls++ }() + + return &tvm.CheckedUserTicket{ + LogInfo: ticket, + DefaultUID: tvm.UID(c.checkUserTicketCalls), + }, nil +} + +func (c *mockTvmClient) GetServiceTicketForAlias(_ context.Context, alias string) (string, error) { + return alias, nil +} + +func checkServiceTickets(t *testing.T, client tvm.Client, equal bool) { + var prev *tvm.CheckedServiceTicket + for i := 0; i < checkPasses; i++ { + t.Run(strconv.Itoa(i), func(t *testing.T) { + cur, err := client.CheckServiceTicket(context.Background(), "3:serv:tst") + require.NoError(t, err) + + if prev == nil { + return + } + + if equal { + require.Equal(t, *prev, *cur) + } else { + require.NotEqual(t, *prev, *cur) + } + }) + } +} + +func runEqualServiceTickets(client tvm.Client) func(t *testing.T) { + return func(t *testing.T) { + checkServiceTickets(t, client, true) + } +} + +func runNotEqualServiceTickets(client tvm.Client) func(t *testing.T) { + return func(t *testing.T) { + checkServiceTickets(t, client, false) + } +} + +func checkUserTickets(t *testing.T, client tvm.Client, equal bool) { + var prev *tvm.CheckedServiceTicket + for i := 0; i < checkPasses; i++ { + t.Run(strconv.Itoa(i), func(t *testing.T) { + cur, err := client.CheckUserTicket(context.Background(), "3:user:tst") + require.NoError(t, err) + + if prev == nil { + return + } + + if equal { + require.Equal(t, *prev, *cur) + } else { + require.NotEqual(t, *prev, *cur) + } + }) + } +} + +func runEqualUserTickets(client tvm.Client) func(t *testing.T) { + return func(t *testing.T) { + checkUserTickets(t, client, true) + } +} + +func runNotEqualUserTickets(client tvm.Client) func(t *testing.T) { + return func(t *testing.T) { + checkUserTickets(t, client, false) + } +} +func TestDefaultBehavior(t *testing.T) { + nestedClient := &mockTvmClient{} + client, err := cachedtvm.NewClient(nestedClient) + require.NoError(t, err) + + t.Run("first_pass_srv", runEqualServiceTickets(client)) + t.Run("first_pass_usr", runNotEqualUserTickets(client)) + + require.Equal(t, 1, nestedClient.checkServiceTicketCalls) + require.Equal(t, checkPasses, nestedClient.checkUserTicketCalls) + + ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst") + require.NoError(t, err) + require.Equal(t, "tst", ticket) +} + +func TestCheckServiceTicket(t *testing.T) { + nestedClient := &mockTvmClient{} + client, err := cachedtvm.NewClient(nestedClient, cachedtvm.WithCheckServiceTicket(10*time.Second, 10)) + require.NoError(t, err) + + t.Run("first_pass_srv", runEqualServiceTickets(client)) + t.Run("first_pass_usr", runNotEqualUserTickets(client)) + time.Sleep(20 * time.Second) + t.Run("second_pass_srv", runEqualServiceTickets(client)) + t.Run("second_pass_usr", runNotEqualUserTickets(client)) + + require.Equal(t, 2, nestedClient.checkServiceTicketCalls) + require.Equal(t, 2*checkPasses, nestedClient.checkUserTicketCalls) + + ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst") + require.NoError(t, err) + require.Equal(t, "tst", ticket) +} + +func TestCheckUserTicket(t *testing.T) { + nestedClient := &mockTvmClient{} + client, err := cachedtvm.NewClient(nestedClient, cachedtvm.WithCheckUserTicket(10*time.Second, 10)) + require.NoError(t, err) + + t.Run("first_pass_usr", runEqualUserTickets(client)) + time.Sleep(20 * time.Second) + t.Run("second_pass_usr", runEqualUserTickets(client)) + require.Equal(t, 2, nestedClient.checkUserTicketCalls) + + ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst") + require.NoError(t, err) + require.Equal(t, "tst", ticket) +} + +func TestCheckServiceAndUserTicket(t *testing.T) { + nestedClient := &mockTvmClient{} + client, err := cachedtvm.NewClient(nestedClient, + cachedtvm.WithCheckServiceTicket(10*time.Second, 10), + cachedtvm.WithCheckUserTicket(10*time.Second, 10), + ) + require.NoError(t, err) + + t.Run("first_pass_srv", runEqualServiceTickets(client)) + t.Run("first_pass_usr", runEqualUserTickets(client)) + time.Sleep(20 * time.Second) + t.Run("second_pass_srv", runEqualServiceTickets(client)) + t.Run("second_pass_usr", runEqualUserTickets(client)) + + require.Equal(t, 2, nestedClient.checkUserTicketCalls) + require.Equal(t, 2, nestedClient.checkServiceTicketCalls) + + ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst") + require.NoError(t, err) + require.Equal(t, "tst", ticket) +} + +func TestErrors(t *testing.T) { + cases := []cachedtvm.Option{ + cachedtvm.WithCheckServiceTicket(12*time.Hour, 1), + cachedtvm.WithCheckUserTicket(30*time.Minute, 1), + } + + nestedClient := &mockTvmClient{} + for i, tc := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + _, err := cachedtvm.NewClient(nestedClient, tc) + require.Error(t, err) + }) + } +} diff --git a/library/go/yandex/tvm/cachedtvm/gotest/ya.make b/library/go/yandex/tvm/cachedtvm/gotest/ya.make new file mode 100644 index 0000000000..342ec17e53 --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/gotest/ya.make @@ -0,0 +1,5 @@ +GO_TEST_FOR(library/go/yandex/tvm/cachedtvm) + +FORK_TESTS() + +END() diff --git a/library/go/yandex/tvm/cachedtvm/opts.go b/library/go/yandex/tvm/cachedtvm/opts.go new file mode 100644 index 0000000000..0df9dfa89e --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/opts.go @@ -0,0 +1,40 @@ +package cachedtvm + +import "time" + +type ( + Option interface{ isCachedOption() } + + cacheOptions struct { + ttl time.Duration + maxItems int64 + } + + OptionServiceTicket struct { + Option + cacheOptions + } + + OptionUserTicket struct { + Option + cacheOptions + } +) + +func WithCheckServiceTicket(ttl time.Duration, maxSize int) Option { + return OptionServiceTicket{ + cacheOptions: cacheOptions{ + ttl: ttl, + maxItems: int64(maxSize), + }, + } +} + +func WithCheckUserTicket(ttl time.Duration, maxSize int) Option { + return OptionUserTicket{ + cacheOptions: cacheOptions{ + ttl: ttl, + maxItems: int64(maxSize), + }, + } +} diff --git a/library/go/yandex/tvm/cachedtvm/ya.make b/library/go/yandex/tvm/cachedtvm/ya.make new file mode 100644 index 0000000000..e9c82b90f5 --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/ya.make @@ -0,0 +1,16 @@ +GO_LIBRARY() + +SRCS( + cache.go + client.go + opts.go +) + +GO_XTEST_SRCS( + client_example_test.go + client_test.go +) + +END() + +RECURSE(gotest) diff --git a/library/go/yandex/tvm/client.go b/library/go/yandex/tvm/client.go new file mode 100644 index 0000000000..f05a97e2d5 --- /dev/null +++ b/library/go/yandex/tvm/client.go @@ -0,0 +1,64 @@ +package tvm + +//go:generate ya tool mockgen -source=$GOFILE -destination=mocks/tvm.gen.go Client + +import ( + "context" + "fmt" +) + +type ClientStatus int + +// This constants must be in sync with EStatus from library/cpp/tvmauth/client/client_status.h +const ( + ClientOK ClientStatus = iota + ClientWarning + ClientError +) + +func (s ClientStatus) String() string { + switch s { + case ClientOK: + return "OK" + case ClientWarning: + return "Warning" + case ClientError: + return "Error" + default: + return fmt.Sprintf("Unknown%d", s) + } +} + +type ClientStatusInfo struct { + Status ClientStatus + + // This message allows to trigger alert with useful message + // It returns "OK" if Status==Ok + LastError string +} + +// Client allows to use aliases for ClientID. +// +// Alias is local label for ClientID which can be used to avoid this number in every checking case in code. +type Client interface { + GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) + GetServiceTicketForID(ctx context.Context, dstID ClientID) (string, error) + + // CheckServiceTicket returns struct with SrcID: you should check it by yourself with ACL + CheckServiceTicket(ctx context.Context, ticket string) (*CheckedServiceTicket, error) + CheckUserTicket(ctx context.Context, ticket string, opts ...CheckUserTicketOption) (*CheckedUserTicket, error) + GetRoles(ctx context.Context) (*Roles, error) + + // GetStatus returns current status of client: + // * you should trigger your monitoring if status is not Ok + // * it will be unable to operate if status is Invalid + GetStatus(ctx context.Context) (ClientStatusInfo, error) +} + +// Dynamic client allows to add dsts dynamically +type DynamicClient interface { + Client + + GetOptionalServiceTicketForID(ctx context.Context, dstID ClientID) (*string, error) + AddDsts(ctx context.Context, dsts []ClientID) error +} diff --git a/library/go/yandex/tvm/context.go b/library/go/yandex/tvm/context.go new file mode 100644 index 0000000000..3a30dbb0b6 --- /dev/null +++ b/library/go/yandex/tvm/context.go @@ -0,0 +1,33 @@ +package tvm + +import "context" + +type ( + serviceTicketContextKey struct{} + userTicketContextKey struct{} +) + +var ( + stKey serviceTicketContextKey + utKey userTicketContextKey +) + +// WithServiceTicket returns copy of the ctx with service ticket attached to it. +func WithServiceTicket(ctx context.Context, t *CheckedServiceTicket) context.Context { + return context.WithValue(ctx, &stKey, t) +} + +// WithUserTicket returns copy of the ctx with user ticket attached to it. +func WithUserTicket(ctx context.Context, t *CheckedUserTicket) context.Context { + return context.WithValue(ctx, &utKey, t) +} + +func ContextServiceTicket(ctx context.Context) (t *CheckedServiceTicket) { + t, _ = ctx.Value(&stKey).(*CheckedServiceTicket) + return +} + +func ContextUserTicket(ctx context.Context) (t *CheckedUserTicket) { + t, _ = ctx.Value(&utKey).(*CheckedUserTicket) + return +} diff --git a/library/go/yandex/tvm/errors.go b/library/go/yandex/tvm/errors.go new file mode 100644 index 0000000000..bd511d05f3 --- /dev/null +++ b/library/go/yandex/tvm/errors.go @@ -0,0 +1,107 @@ +package tvm + +import ( + "errors" + "fmt" +) + +// ErrNotSupported - error to be used within cgo disabled builds. +var ErrNotSupported = errors.New("ticket_parser2 is not available when building with -DCGO_ENABLED=0") + +var ( + ErrTicketExpired = &TicketError{Status: TicketExpired} + ErrTicketInvalidBlackboxEnv = &TicketError{Status: TicketInvalidBlackboxEnv} + ErrTicketInvalidDst = &TicketError{Status: TicketInvalidDst} + ErrTicketInvalidTicketType = &TicketError{Status: TicketInvalidTicketType} + ErrTicketMalformed = &TicketError{Status: TicketMalformed} + ErrTicketMissingKey = &TicketError{Status: TicketMissingKey} + ErrTicketSignBroken = &TicketError{Status: TicketSignBroken} + ErrTicketUnsupportedVersion = &TicketError{Status: TicketUnsupportedVersion} + ErrTicketStatusOther = &TicketError{Status: TicketStatusOther} + ErrTicketInvalidScopes = &TicketError{Status: TicketInvalidScopes} + ErrTicketInvalidSrcID = &TicketError{Status: TicketInvalidSrcID} +) + +type TicketError struct { + Status TicketStatus + Msg string +} + +func (e *TicketError) Is(err error) bool { + otherTickerErr, ok := err.(*TicketError) + if !ok { + return false + } + if e == nil && otherTickerErr == nil { + return true + } + if e == nil || otherTickerErr == nil { + return false + } + return e.Status == otherTickerErr.Status +} + +func (e *TicketError) Error() string { + if e.Msg != "" { + return fmt.Sprintf("tvm: invalid ticket: %s: %s", e.Status, e.Msg) + } + return fmt.Sprintf("tvm: invalid ticket: %s", e.Status) +} + +type ErrorCode int + +// This constants must be in sync with code in go/tvmauth/tvm.cpp:CatchError +const ( + ErrorOK ErrorCode = iota + ErrorMalformedSecret + ErrorMalformedKeys + ErrorEmptyKeys + ErrorNotAllowed + ErrorBrokenTvmClientSettings + ErrorMissingServiceTicket + ErrorPermissionDenied + ErrorOther + + // Go-only errors below + ErrorBadRequest + ErrorAuthFail +) + +func (e ErrorCode) String() string { + switch e { + case ErrorOK: + return "OK" + case ErrorMalformedSecret: + return "MalformedSecret" + case ErrorMalformedKeys: + return "MalformedKeys" + case ErrorEmptyKeys: + return "EmptyKeys" + case ErrorNotAllowed: + return "NotAllowed" + case ErrorBrokenTvmClientSettings: + return "BrokenTvmClientSettings" + case ErrorMissingServiceTicket: + return "MissingServiceTicket" + case ErrorPermissionDenied: + return "PermissionDenied" + case ErrorOther: + return "Other" + case ErrorBadRequest: + return "ErrorBadRequest" + case ErrorAuthFail: + return "AuthFail" + default: + return fmt.Sprintf("Unknown%d", e) + } +} + +type Error struct { + Code ErrorCode + Retriable bool + Msg string +} + +func (e *Error) Error() string { + return fmt.Sprintf("tvm: %s (code %s)", e.Msg, e.Code) +} diff --git a/library/go/yandex/tvm/examples/tvm_example_test.go b/library/go/yandex/tvm/examples/tvm_example_test.go new file mode 100644 index 0000000000..b10c6ce95f --- /dev/null +++ b/library/go/yandex/tvm/examples/tvm_example_test.go @@ -0,0 +1,59 @@ +package tvm_test + +import ( + "context" + "fmt" + + "github.com/ydb-platform/ydb/library/go/core/log/nop" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmauth" +) + +func ExampleClient_alias() { + blackboxAlias := "blackbox" + + settings := tvmauth.TvmAPISettings{ + SelfID: 1000502, + ServiceTicketOptions: tvmauth.NewAliasesOptions( + "...", + map[string]tvm.ClientID{ + blackboxAlias: 1000501, + }), + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.GetServiceTicketForAlias(context.Background(), blackboxAlias) + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) +} + +func ExampleClient_roles() { + settings := tvmauth.TvmAPISettings{ + SelfID: 1000502, + ServiceTicketOptions: tvmauth.NewIDsOptions("...", nil), + FetchRolesForIdmSystemSlug: "some_idm_system", + DiskCacheDir: "...", + EnableServiceTicketChecking: true, + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.CheckServiceTicket(context.Background(), "3:serv:...") + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) + + r, err := c.GetRoles(context.Background()) + if err != nil { + panic(err) + } + fmt.Println(r.GetMeta().Revision) +} diff --git a/library/go/yandex/tvm/examples/ya.make b/library/go/yandex/tvm/examples/ya.make new file mode 100644 index 0000000000..eaf54a8c35 --- /dev/null +++ b/library/go/yandex/tvm/examples/ya.make @@ -0,0 +1,5 @@ +GO_TEST() + +GO_XTEST_SRCS(tvm_example_test.go) + +END() diff --git a/library/go/yandex/tvm/gotest/ya.make b/library/go/yandex/tvm/gotest/ya.make new file mode 100644 index 0000000000..cc1ed6eeb7 --- /dev/null +++ b/library/go/yandex/tvm/gotest/ya.make @@ -0,0 +1,3 @@ +GO_TEST_FOR(library/go/yandex/tvm) + +END() diff --git a/library/go/yandex/tvm/mocks/tvm.gen.go b/library/go/yandex/tvm/mocks/tvm.gen.go new file mode 100644 index 0000000000..151b8f4c72 --- /dev/null +++ b/library/go/yandex/tvm/mocks/tvm.gen.go @@ -0,0 +1,130 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: client.go + +// Package mock_tvm is a generated GoMock package. +package mock_tvm + +import ( + tvm "github.com/ydb-platform/ydb/library/go/yandex/tvm" + context "context" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// GetServiceTicketForAlias mocks base method. +func (m *MockClient) GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServiceTicketForAlias", ctx, alias) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServiceTicketForAlias indicates an expected call of GetServiceTicketForAlias. +func (mr *MockClientMockRecorder) GetServiceTicketForAlias(ctx, alias interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTicketForAlias", reflect.TypeOf((*MockClient)(nil).GetServiceTicketForAlias), ctx, alias) +} + +// GetServiceTicketForID mocks base method. +func (m *MockClient) GetServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServiceTicketForID", ctx, dstID) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServiceTicketForID indicates an expected call of GetServiceTicketForID. +func (mr *MockClientMockRecorder) GetServiceTicketForID(ctx, dstID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTicketForID", reflect.TypeOf((*MockClient)(nil).GetServiceTicketForID), ctx, dstID) +} + +// CheckServiceTicket mocks base method. +func (m *MockClient) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CheckServiceTicket", ctx, ticket) + ret0, _ := ret[0].(*tvm.CheckedServiceTicket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CheckServiceTicket indicates an expected call of CheckServiceTicket. +func (mr *MockClientMockRecorder) CheckServiceTicket(ctx, ticket interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckServiceTicket", reflect.TypeOf((*MockClient)(nil).CheckServiceTicket), ctx, ticket) +} + +// CheckUserTicket mocks base method. +func (m *MockClient) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, ticket} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CheckUserTicket", varargs...) + ret0, _ := ret[0].(*tvm.CheckedUserTicket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CheckUserTicket indicates an expected call of CheckUserTicket. +func (mr *MockClientMockRecorder) CheckUserTicket(ctx, ticket interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, ticket}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckUserTicket", reflect.TypeOf((*MockClient)(nil).CheckUserTicket), varargs...) +} + +// GetRoles mocks base method. +func (m *MockClient) GetRoles(ctx context.Context) (*tvm.Roles, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRoles", ctx) + ret0, _ := ret[0].(*tvm.Roles) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRoles indicates an expected call of GetRoles. +func (mr *MockClientMockRecorder) GetRoles(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoles", reflect.TypeOf((*MockClient)(nil).GetRoles), ctx) +} + +// GetStatus mocks base method. +func (m *MockClient) GetStatus(ctx context.Context) (tvm.ClientStatusInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStatus", ctx) + ret0, _ := ret[0].(tvm.ClientStatusInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetStatus indicates an expected call of GetStatus. +func (mr *MockClientMockRecorder) GetStatus(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatus", reflect.TypeOf((*MockClient)(nil).GetStatus), ctx) +} diff --git a/library/go/yandex/tvm/mocks/ya.make b/library/go/yandex/tvm/mocks/ya.make new file mode 100644 index 0000000000..ca5a5d3d35 --- /dev/null +++ b/library/go/yandex/tvm/mocks/ya.make @@ -0,0 +1,5 @@ +GO_LIBRARY() + +SRCS(tvm.gen.go) + +END() diff --git a/library/go/yandex/tvm/roles.go b/library/go/yandex/tvm/roles.go new file mode 100644 index 0000000000..12447d8b11 --- /dev/null +++ b/library/go/yandex/tvm/roles.go @@ -0,0 +1,150 @@ +package tvm + +import ( + "encoding/json" + + "github.com/ydb-platform/ydb/library/go/core/xerrors" +) + +func (r *Roles) GetRolesForService(t *CheckedServiceTicket) *ConsumerRoles { + return r.tvmRoles[t.SrcID] +} + +func (r *Roles) GetRolesForUser(t *CheckedUserTicket, uid *UID) (*ConsumerRoles, error) { + if t.Env != BlackboxProdYateam { + return nil, xerrors.Errorf("user ticket must be from ProdYateam, got from %s", t.Env) + } + + if uid == nil { + if t.DefaultUID == 0 { + return nil, xerrors.Errorf("default uid is 0 - it cannot have any role") + } + uid = &t.DefaultUID + } else { + found := false + for _, u := range t.UIDs { + if u == *uid { + found = true + break + } + } + if !found { + return nil, xerrors.Errorf("'uid' must be in user ticket but it is not: %d", *uid) + } + } + + return r.userRoles[*uid], nil +} + +func (r *Roles) GetRaw() []byte { + return r.raw +} + +func (r *Roles) GetMeta() Meta { + return r.meta +} + +func (r *Roles) CheckServiceRole(t *CheckedServiceTicket, roleName string, opts *CheckServiceOptions) bool { + roles := r.GetRolesForService(t) + + if !roles.HasRole(roleName) { + return false + } + + if opts != nil && opts.Entity != nil { + e := roles.GetEntitiesForRole(roleName) + if e == nil { + return false + } + + if !e.ContainsExactEntity(opts.Entity) { + return false + } + } + + return true +} + +func (r *Roles) CheckUserRole(t *CheckedUserTicket, roleName string, opts *CheckUserOptions) (bool, error) { + var uid *UID + if opts != nil && opts.UID != 0 { + uid = &opts.UID + } + + roles, err := r.GetRolesForUser(t, uid) + if err != nil { + return false, err + } + + if !roles.HasRole(roleName) { + return false, nil + } + + if opts != nil && opts.Entity != nil { + e := roles.GetEntitiesForRole(roleName) + if e == nil { + return false, nil + } + + if !e.ContainsExactEntity(opts.Entity) { + return false, nil + } + } + + return true, nil +} + +func (r *ConsumerRoles) HasRole(roleName string) bool { + if r == nil { + return false + } + + _, ok := r.roles[roleName] + return ok +} + +func (r *ConsumerRoles) GetRoles() EntitiesByRoles { + if r == nil { + return nil + } + return r.roles +} + +func (r *ConsumerRoles) GetEntitiesForRole(roleName string) *Entities { + if r == nil { + return nil + } + return r.roles[roleName] +} + +func (r *ConsumerRoles) DebugPrint() string { + tmp := make(map[string][]Entity) + + for k, v := range r.roles { + if v != nil { + tmp[k] = v.subtree.entities + } else { + tmp[k] = nil + } + } + + res, err := json.MarshalIndent(tmp, "", " ") + if err != nil { + panic(err) + } + return string(res) +} + +func (e *Entities) ContainsExactEntity(entity Entity) bool { + if e == nil { + return false + } + return e.subtree.containsExactEntity(entity) +} + +func (e *Entities) GetEntitiesWithAttrs(entityPart Entity) []Entity { + if e == nil { + return nil + } + return e.subtree.getEntitiesWithAttrs(entityPart) +} diff --git a/library/go/yandex/tvm/roles_entities_index.go b/library/go/yandex/tvm/roles_entities_index.go new file mode 100644 index 0000000000..488ce7fb09 --- /dev/null +++ b/library/go/yandex/tvm/roles_entities_index.go @@ -0,0 +1,73 @@ +package tvm + +import "sort" + +type entityAttribute struct { + key string + value string +} + +// subTree provides index for fast entity lookup with attributes +// +// or some subset of entity attributes +type subTree struct { + // entities contains entities with attributes from previous branches of tree: + // * root subTree contains all entities + // * next subTree contains entities with {"key#X": "value#X"} + // * next subTree after next contains entities with {"key#X": "value#X", "key#Y": "value#Y"} + // * and so on + // "key#X", "key#Y", ... - are sorted + entities []Entity + // entityLengths provides O(1) for exact entity lookup + entityLengths map[int]interface{} + // entityIds is creation-time crutch + entityIds []int + idxByAttrs *idxByAttrs +} + +type idxByAttrs = map[entityAttribute]*subTree + +func (s *subTree) containsExactEntity(entity Entity) bool { + subtree := s.findSubTree(entity) + if subtree == nil { + return false + } + + _, ok := subtree.entityLengths[len(entity)] + return ok +} + +func (s *subTree) getEntitiesWithAttrs(entityPart Entity) []Entity { + subtree := s.findSubTree(entityPart) + if subtree == nil { + return nil + } + + return subtree.entities +} + +func (s *subTree) findSubTree(e Entity) *subTree { + keys := make([]string, 0, len(e)) + for k := range e { + keys = append(keys, k) + } + sort.Strings(keys) + + res := s + + for _, k := range keys { + if res.idxByAttrs == nil { + return nil + } + + kv := entityAttribute{key: k, value: e[k]} + ok := false + + res, ok = (*res.idxByAttrs)[kv] + if !ok { + return nil + } + } + + return res +} diff --git a/library/go/yandex/tvm/roles_entities_index_builder.go b/library/go/yandex/tvm/roles_entities_index_builder.go new file mode 100644 index 0000000000..b4e2769287 --- /dev/null +++ b/library/go/yandex/tvm/roles_entities_index_builder.go @@ -0,0 +1,129 @@ +package tvm + +import "sort" + +type stages struct { + keys []string + id uint64 +} + +func createStages(keys []string) stages { + return stages{ + keys: keys, + } +} + +func (s *stages) getNextStage(keys *[]string) bool { + s.id += 1 + *keys = (*keys)[:0] + + for idx := range s.keys { + need := (s.id >> idx) & 0x01 + if need == 1 { + *keys = append(*keys, s.keys[idx]) + } + } + + return len(*keys) > 0 +} + +func buildLightEntities(entities []Entity) *Entities { + if len(entities) == 0 || len(entities[0]) == 0 { + return nil + } + + return &Entities{ + subtree: subTree{ + entities: entities, + }, + } +} + +func buildEntities(entities []Entity) *Entities { + root := make(idxByAttrs) + res := &Entities{ + subtree: subTree{ + idxByAttrs: &root, + }, + } + + stage := createStages(getUniqueSortedKeys(entities)) + + keySet := make([]string, 0, len(stage.keys)) + for stage.getNextStage(&keySet) { + for entityID, entity := range entities { + currentBranch := &res.subtree + + for _, key := range keySet { + entValue, ok := entity[key] + if !ok { + continue + } + + if currentBranch.idxByAttrs == nil { + index := make(idxByAttrs) + currentBranch.idxByAttrs = &index + } + + kv := entityAttribute{key: key, value: entValue} + subtree, ok := (*currentBranch.idxByAttrs)[kv] + if !ok { + subtree = &subTree{} + (*currentBranch.idxByAttrs)[kv] = subtree + } + + currentBranch = subtree + currentBranch.entityIds = append(currentBranch.entityIds, entityID) + res.subtree.entityIds = append(res.subtree.entityIds, entityID) + } + } + } + + postProcessSubTree(&res.subtree, entities) + + return res +} + +func postProcessSubTree(sub *subTree, entities []Entity) { + tmp := make(map[int]interface{}, len(entities)) + for _, e := range sub.entityIds { + tmp[e] = nil + } + sub.entityIds = sub.entityIds[:0] + for i := range tmp { + sub.entityIds = append(sub.entityIds, i) + } + sort.Ints(sub.entityIds) + + sub.entities = make([]Entity, 0, len(sub.entityIds)) + sub.entityLengths = make(map[int]interface{}) + for _, idx := range sub.entityIds { + sub.entities = append(sub.entities, entities[idx]) + sub.entityLengths[len(entities[idx])] = nil + } + sub.entityIds = nil + + if sub.idxByAttrs != nil { + for _, rest := range *sub.idxByAttrs { + postProcessSubTree(rest, entities) + } + } +} + +func getUniqueSortedKeys(entities []Entity) []string { + tmp := map[string]interface{}{} + + for _, e := range entities { + for k := range e { + tmp[k] = nil + } + } + + res := make([]string, 0, len(tmp)) + for k := range tmp { + res = append(res, k) + } + + sort.Strings(res) + return res +} diff --git a/library/go/yandex/tvm/roles_entities_index_builder_test.go b/library/go/yandex/tvm/roles_entities_index_builder_test.go new file mode 100644 index 0000000000..dd795369d5 --- /dev/null +++ b/library/go/yandex/tvm/roles_entities_index_builder_test.go @@ -0,0 +1,259 @@ +package tvm + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRolesGetNextStage(t *testing.T) { + s := createStages([]string{"key#1", "key#2", "key#3", "key#4"}) + + results := [][]string{ + {"key#1"}, + {"key#2"}, + {"key#1", "key#2"}, + {"key#3"}, + {"key#1", "key#3"}, + {"key#2", "key#3"}, + {"key#1", "key#2", "key#3"}, + {"key#4"}, + {"key#1", "key#4"}, + {"key#2", "key#4"}, + {"key#1", "key#2", "key#4"}, + {"key#3", "key#4"}, + {"key#1", "key#3", "key#4"}, + {"key#2", "key#3", "key#4"}, + {"key#1", "key#2", "key#3", "key#4"}, + } + + keySet := make([]string, 0) + for idx, exp := range results { + s.getNextStage(&keySet) + require.Equal(t, exp, keySet, idx) + } + + // require.False(t, s.getNextStage(&keySet)) +} + +func TestRolesBuildEntities(t *testing.T) { + type TestCase struct { + in []Entity + out Entities + } + cases := []TestCase{ + { + in: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + out: Entities{subtree: subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + entityLengths: map[int]interface{}{1: nil, 2: nil, 3: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#1", value: "value#1"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{1: nil, 3: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#2", value: "value#2"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + }, + }, + }, + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + }, + }, + }, + entityAttribute{key: "key#1", value: "value#2"}: &subTree{ + entities: []Entity{ + {"key#1": "value#2", "key#2": "value#2"}, + }, + entityLengths: map[int]interface{}{2: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#2", value: "value#2"}: &subTree{ + entities: []Entity{ + {"key#1": "value#2", "key#2": "value#2"}, + }, + entityLengths: map[int]interface{}{2: nil}, + }, + }, + }, + entityAttribute{key: "key#2", value: "value#2"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + }, + entityLengths: map[int]interface{}{2: nil, 3: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + }, + }, + }, + entityAttribute{key: "key#3", value: "value#3"}: &subTree{ + entities: []Entity{ + {"key#3": "value#3"}, + }, + entityLengths: map[int]interface{}{1: nil}, + }, + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + }, + }, + }}, + }, + } + + for idx, c := range cases { + require.Equal(t, c.out, *buildEntities(c.in), idx) + } +} + +func TestRolesPostProcessSubTree(t *testing.T) { + type TestCase struct { + in subTree + out subTree + } + + cases := []TestCase{ + { + in: subTree{ + entityIds: []int{1, 1, 1, 1, 1, 2, 0, 0, 0}, + }, + out: subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + entityLengths: map[int]interface{}{1: nil, 2: nil}, + }, + }, + { + in: subTree{ + entityIds: []int{1, 0}, + entityLengths: map[int]interface{}{1: nil, 2: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#1", value: "value#1"}: &subTree{ + entityIds: []int{2, 0, 0}, + }, + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entityIds: []int{0, 0, 0}, + }, + }, + }, + out: subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2", "key#2": "value#2"}, + }, + entityLengths: map[int]interface{}{1: nil, 2: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#1", value: "value#1"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#3": "value#3"}, + }, + entityLengths: map[int]interface{}{1: nil}, + }, + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + }, + entityLengths: map[int]interface{}{1: nil}, + }, + }, + }, + }, + } + + entities := []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + } + + for idx, c := range cases { + postProcessSubTree(&c.in, entities) + require.Equal(t, c.out, c.in, idx) + } +} + +func TestRolesGetUniqueSortedKeys(t *testing.T) { + type TestCase struct { + in []Entity + out []string + } + + cases := []TestCase{ + { + in: nil, + out: []string{}, + }, + { + in: []Entity{}, + out: []string{}, + }, + { + in: []Entity{ + {}, + }, + out: []string{}, + }, + { + in: []Entity{ + {"key#1": "value#1"}, + {}, + }, + out: []string{"key#1"}, + }, + { + in: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2"}, + }, + out: []string{"key#1"}, + }, + { + in: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + out: []string{"key#1", "key#2", "key#3"}, + }, + } + + for idx, c := range cases { + require.Equal(t, c.out, getUniqueSortedKeys(c.in), idx) + } +} diff --git a/library/go/yandex/tvm/roles_entities_index_test.go b/library/go/yandex/tvm/roles_entities_index_test.go new file mode 100644 index 0000000000..e1abaa0f0e --- /dev/null +++ b/library/go/yandex/tvm/roles_entities_index_test.go @@ -0,0 +1,113 @@ +package tvm + +import ( + "math/rand" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRolesSubTreeContainsExactEntity(t *testing.T) { + origEntities := []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#1", "key#2": "value#2"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + } + entities := buildEntities(origEntities) + + for _, e := range generatedRandEntities() { + found := false + for _, o := range origEntities { + if reflect.DeepEqual(e, o) { + found = true + break + } + } + + require.Equal(t, found, entities.subtree.containsExactEntity(e), e) + } +} + +func generatedRandEntities() []Entity { + rand.Seed(time.Now().UnixNano()) + + keysStages := createStages([]string{"key#1", "key#2", "key#3", "key#4", "key#5"}) + valuesSet := []string{"value#1", "value#2", "value#3", "value#4", "value#5"} + + res := make([]Entity, 0) + + keySet := make([]string, 0, 5) + for keysStages.getNextStage(&keySet) { + entity := Entity{} + for _, key := range keySet { + entity[key] = valuesSet[rand.Intn(len(valuesSet))] + + e := Entity{} + for k, v := range entity { + e[k] = v + } + res = append(res, e) + } + } + + return res +} + +func TestRolesGetEntitiesWithAttrs(t *testing.T) { + type TestCase struct { + in Entity + out []Entity + } + + cases := []TestCase{ + { + out: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + }, + { + in: Entity{"key#1": "value#1"}, + out: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + }, + { + in: Entity{"key#1": "value#2"}, + out: []Entity{ + {"key#1": "value#2", "key#2": "value#2"}, + }, + }, + { + in: Entity{"key#2": "value#2"}, + out: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + }, + }, + { + in: Entity{"key#3": "value#3"}, + out: []Entity{ + {"key#3": "value#3"}, + }, + }, + } + + entities := buildEntities([]Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }) + + for idx, c := range cases { + require.Equal(t, c.out, entities.subtree.getEntitiesWithAttrs(c.in), idx) + } +} diff --git a/library/go/yandex/tvm/roles_opts.go b/library/go/yandex/tvm/roles_opts.go new file mode 100644 index 0000000000..8e0a0e0608 --- /dev/null +++ b/library/go/yandex/tvm/roles_opts.go @@ -0,0 +1,10 @@ +package tvm + +type CheckServiceOptions struct { + Entity Entity +} + +type CheckUserOptions struct { + Entity Entity + UID UID +} diff --git a/library/go/yandex/tvm/roles_parser.go b/library/go/yandex/tvm/roles_parser.go new file mode 100644 index 0000000000..0c74698efe --- /dev/null +++ b/library/go/yandex/tvm/roles_parser.go @@ -0,0 +1,77 @@ +package tvm + +import ( + "encoding/json" + "strconv" + "time" + + "github.com/ydb-platform/ydb/library/go/core/xerrors" +) + +type rawRoles struct { + Revision string `json:"revision"` + BornDate int64 `json:"born_date"` + Tvm rawConsumers `json:"tvm"` + User rawConsumers `json:"user"` +} + +type rawConsumers = map[string]rawConsumerRoles +type rawConsumerRoles = map[string][]Entity + +func NewRoles(buf []byte) (*Roles, error) { + return NewRolesWithOpts(buf) +} + +func NewRolesWithOpts(buf []byte, opts ...RoleParserOption) (*Roles, error) { + options := newRolesParserOptions(opts...) + + var raw rawRoles + if err := json.Unmarshal(buf, &raw); err != nil { + return nil, xerrors.Errorf("failed to parse roles: invalid json: %w", err) + } + + tvmRoles := map[ClientID]*ConsumerRoles{} + for key, value := range raw.Tvm { + id, err := strconv.ParseUint(key, 10, 32) + if err != nil { + return nil, xerrors.Errorf("failed to parse roles: invalid tvmid '%s': %w", key, err) + } + tvmRoles[ClientID(id)] = buildConsumerRoles(value, options) + } + + userRoles := map[UID]*ConsumerRoles{} + for key, value := range raw.User { + id, err := strconv.ParseUint(key, 10, 64) + if err != nil { + return nil, xerrors.Errorf("failed to parse roles: invalid UID '%s': %w", key, err) + } + userRoles[UID(id)] = buildConsumerRoles(value, options) + } + + return &Roles{ + tvmRoles: tvmRoles, + userRoles: userRoles, + raw: buf, + meta: Meta{ + Revision: raw.Revision, + BornTime: time.Unix(raw.BornDate, 0), + Applied: time.Now(), + }, + }, nil +} + +func buildConsumerRoles(rawConsumerRoles rawConsumerRoles, opts *rolesParserOptions) *ConsumerRoles { + roles := &ConsumerRoles{ + roles: make(EntitiesByRoles, len(rawConsumerRoles)), + } + + for r, ents := range rawConsumerRoles { + if opts.UseLightIndex { + roles.roles[r] = buildLightEntities(ents) + } else { + roles.roles[r] = buildEntities(ents) + } + } + + return roles +} diff --git a/library/go/yandex/tvm/roles_parser_opts.go b/library/go/yandex/tvm/roles_parser_opts.go new file mode 100644 index 0000000000..b9b0fb6819 --- /dev/null +++ b/library/go/yandex/tvm/roles_parser_opts.go @@ -0,0 +1,22 @@ +package tvm + +type RoleParserOption func(options *rolesParserOptions) +type rolesParserOptions struct { + UseLightIndex bool +} + +func newRolesParserOptions(opts ...RoleParserOption) *rolesParserOptions { + options := &rolesParserOptions{} + + for _, opt := range opts { + opt(options) + } + + return options +} + +func WithLightIndex() RoleParserOption { + return func(options *rolesParserOptions) { + options.UseLightIndex = true + } +} diff --git a/library/go/yandex/tvm/roles_parser_test.go b/library/go/yandex/tvm/roles_parser_test.go new file mode 100644 index 0000000000..c7fd069d4a --- /dev/null +++ b/library/go/yandex/tvm/roles_parser_test.go @@ -0,0 +1,126 @@ +package tvm + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewRolesWithOpts(t *testing.T) { + type TestCase struct { + buf string + opts []RoleParserOption + roles Roles + err string + } + + cases := []TestCase{ + { + buf: `{"revision":100500}`, + opts: []RoleParserOption{}, + err: "failed to parse roles: invalid json", + }, + { + buf: `{"born_date":1612791978.42}`, + opts: []RoleParserOption{}, + err: "failed to parse roles: invalid json", + }, + { + buf: `{"tvm":{"asd":{}}}`, + opts: []RoleParserOption{}, + err: "failed to parse roles: invalid tvmid 'asd'", + }, + { + buf: `{"user":{"asd":{}}}`, + opts: []RoleParserOption{}, + err: "failed to parse roles: invalid UID 'asd'", + }, + { + buf: `{"tvm":{"1120000000000493":{}}}`, + opts: []RoleParserOption{}, + err: "failed to parse roles: invalid tvmid '1120000000000493'", + }, + { + buf: `{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"}],"/group/system/system_on/abc/role/tree_edit/":[{"scope":"/"}]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/"}]}}}`, + opts: []RoleParserOption{}, + roles: Roles{ + tvmRoles: map[ClientID]*ConsumerRoles{ + ClientID(2012192): { + roles: EntitiesByRoles{ + "/group/system/system_on/abc/role/impersonator/": {}, + "/group/system/system_on/abc/role/tree_edit/": {}, + }, + }, + }, + userRoles: map[UID]*ConsumerRoles{ + UID(1120000000000493): { + roles: EntitiesByRoles{ + "/group/system/system_on/abc/role/roles_manage/": {}, + }, + }, + }, + raw: []byte(`{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"}],"/group/system/system_on/abc/role/tree_edit/":[{"scope":"/"}]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/"}]}}}`), + meta: Meta{ + Revision: "GYYDEMJUGBQWC", + BornTime: time.Unix(1612791978, 0), + }, + }, + }, + { + buf: `{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"}],"/group/system/system_on/abc/role/tree_edit/":[]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/","blank":""}],"/group/system/system_on/abc/role/admin/":[{}]}}}`, + opts: []RoleParserOption{WithLightIndex()}, + roles: Roles{ + tvmRoles: map[ClientID]*ConsumerRoles{ + ClientID(2012192): { + roles: EntitiesByRoles{ + "/group/system/system_on/abc/role/impersonator/": {}, + "/group/system/system_on/abc/role/tree_edit/": nil, + }, + }, + }, + userRoles: map[UID]*ConsumerRoles{ + UID(1120000000000493): { + roles: EntitiesByRoles{ + "/group/system/system_on/abc/role/roles_manage/": {}, + "/group/system/system_on/abc/role/admin/": nil, + }, + }, + }, + raw: []byte(`{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"}],"/group/system/system_on/abc/role/tree_edit/":[]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/","blank":""}],"/group/system/system_on/abc/role/admin/":[{}]}}}`), + meta: Meta{ + Revision: "GYYDEMJUGBQWC", + BornTime: time.Unix(1612791978, 0), + }, + }, + }, + } + + for idx, c := range cases { + r, err := NewRolesWithOpts([]byte(c.buf), c.opts...) + if c.err == "" { + require.NoError(t, err, idx) + + r.meta.Applied = time.Time{} + for _, roles := range r.tvmRoles { + for _, v := range roles.roles { + if v != nil { + v.subtree = subTree{} + } + } + } + for _, roles := range r.userRoles { + for _, v := range roles.roles { + if v != nil { + v.subtree = subTree{} + } + } + } + + require.Equal(t, c.roles, *r, idx) + } else { + require.Error(t, err, idx) + require.Contains(t, err.Error(), c.err, idx) + } + } +} diff --git a/library/go/yandex/tvm/roles_test.go b/library/go/yandex/tvm/roles_test.go new file mode 100644 index 0000000000..719516308b --- /dev/null +++ b/library/go/yandex/tvm/roles_test.go @@ -0,0 +1,297 @@ +package tvm + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRolesPublicServiceTicket(t *testing.T) { + roles, err := NewRoles([]byte(`{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"},{"blank":""}],"/group/system/system_on/abc/role/tree_edit/":[{"scope":"/"}],"/group/system/system_on/abc/role/admin/":[]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/"}]}}}`)) + require.NoError(t, err) + + st := &CheckedServiceTicket{SrcID: 42} + require.Nil(t, roles.GetRolesForService(st)) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", nil)) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/admin/", nil)) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity: Entity{"scope": "/"}})) + + st = &CheckedServiceTicket{SrcID: 2012192} + r := roles.GetRolesForService(st) + require.NotNil(t, r) + require.EqualValues(t, + `{ + "/group/system/system_on/abc/role/admin/": [], + "/group/system/system_on/abc/role/impersonator/": [ + { + "scope": "/" + }, + { + "blank": "" + } + ], + "/group/system/system_on/abc/role/tree_edit/": [ + { + "scope": "/" + } + ] +}`, + r.DebugPrint(), + ) + require.Equal(t, 3, len(r.GetRoles())) + require.False(t, r.HasRole("/")) + require.True(t, r.HasRole("/group/system/system_on/abc/role/impersonator/")) + require.True(t, r.HasRole("/group/system/system_on/abc/role/admin/")) + require.False(t, roles.CheckServiceRole(st, "/", nil)) + require.True(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", nil)) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity: Entity{"scope": "kek"}})) + require.True(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity{"scope": "/"}})) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity{"blank": "/"}})) + require.True(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity{"blank": ""}})) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/admin/", &CheckServiceOptions{Entity{"scope": "/"}})) + require.Nil(t, r.GetEntitiesForRole("/")) + + en := r.GetEntitiesForRole("/group/system/system_on/abc/role/impersonator/") + require.NotNil(t, en) + require.False(t, en.ContainsExactEntity(Entity{"scope": "kek"})) + require.True(t, en.ContainsExactEntity(Entity{"scope": "/"})) + require.False(t, en.ContainsExactEntity(Entity{"blank": "/"})) + require.True(t, en.ContainsExactEntity(Entity{"blank": ""})) + + require.Nil(t, en.GetEntitiesWithAttrs(Entity{"scope": "kek"})) + require.Equal(t, []Entity{{"scope": "/"}}, en.GetEntitiesWithAttrs(Entity{"scope": "/"})) + require.Nil(t, en.GetEntitiesWithAttrs(Entity{"blank": "kek"})) + require.Equal(t, []Entity{{"blank": ""}}, en.GetEntitiesWithAttrs(Entity{"blank": ""})) + require.ElementsMatch(t, []Entity{{"scope": "/"}, {"blank": ""}}, en.GetEntitiesWithAttrs(nil)) + + en = r.GetEntitiesForRole("/group/system/system_on/abc/role/admin/") + require.NotNil(t, en) + require.False(t, en.ContainsExactEntity(Entity{"scope": "/"})) + + require.Nil(t, en.GetEntitiesWithAttrs(Entity{"scope": "/"})) +} + +func TestRolesPublicServiceTicketWithNilEntities(t *testing.T) { + roles, err := NewRolesWithOpts( + []byte(`{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"},{"blank":""}],"/group/system/system_on/abc/role/tree_edit/":[{"scope":"/"}],"/group/system/system_on/abc/role/admin/":[{}]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/"}]}}}`), + WithLightIndex(), + ) + require.NoError(t, err) + + st := &CheckedServiceTicket{SrcID: 42} + require.Nil(t, roles.GetRolesForService(st)) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", nil)) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/admin/", nil)) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity: Entity{"scope": "/"}})) + + st = &CheckedServiceTicket{SrcID: 2012192} + r := roles.GetRolesForService(st) + require.NotNil(t, r) + require.EqualValues(t, + `{ + "/group/system/system_on/abc/role/admin/": null, + "/group/system/system_on/abc/role/impersonator/": [ + { + "scope": "/" + }, + { + "blank": "" + } + ], + "/group/system/system_on/abc/role/tree_edit/": [ + { + "scope": "/" + } + ] +}`, + r.DebugPrint(), + ) + require.Equal(t, 3, len(r.GetRoles())) + require.False(t, r.HasRole("/")) + require.True(t, r.HasRole("/group/system/system_on/abc/role/impersonator/")) + require.True(t, r.HasRole("/group/system/system_on/abc/role/admin/")) + require.False(t, roles.CheckServiceRole(st, "/", nil)) + require.True(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", nil)) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity: Entity{"scope": "kek"}})) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity{"scope": "/"}})) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity{"blank": "/"}})) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity{"blank": ""}})) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/admin/", &CheckServiceOptions{Entity{"scope": "/"}})) + require.Nil(t, r.GetEntitiesForRole("/")) + + en := r.GetEntitiesForRole("/group/system/system_on/abc/role/impersonator/") + require.NotNil(t, en) + require.False(t, en.ContainsExactEntity(Entity{"scope": "kek"})) + require.False(t, en.ContainsExactEntity(Entity{"scope": "/"})) + require.False(t, en.ContainsExactEntity(Entity{"blank": "/"})) + require.False(t, en.ContainsExactEntity(Entity{"blank": ""})) + + require.Nil(t, en.GetEntitiesWithAttrs(Entity{"scope": "kek"})) + require.Nil(t, en.GetEntitiesWithAttrs(Entity{"scope": "/"})) + require.Nil(t, en.GetEntitiesWithAttrs(Entity{"blank": "kek"})) + require.Nil(t, en.GetEntitiesWithAttrs(Entity{"blank": ""})) + require.ElementsMatch(t, []Entity{{"scope": "/"}, {"blank": ""}}, en.GetEntitiesWithAttrs(nil)) + + en = r.GetEntitiesForRole("/group/system/system_on/abc/role/admin/") + require.Nil(t, en) + require.False(t, en.ContainsExactEntity(Entity{"scope": "/"})) + + require.Nil(t, en.GetEntitiesWithAttrs(Entity{"scope": "/"})) +} + +func TestRolesPublicUserTicket(t *testing.T) { + roles, err := NewRoles([]byte(`{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"},{"blank":""}],"/group/system/system_on/abc/role/tree_edit/":[{"scope":"/"}]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/"}],"/group/system/system_on/abc/role/roles_admin/":[]}}}`)) + require.NoError(t, err) + + ut := &CheckedUserTicket{DefaultUID: 42} + _, err = roles.GetRolesForUser(ut, nil) + require.EqualError(t, err, "user ticket must be from ProdYateam, got from Prod") + ut.Env = BlackboxProdYateam + + r, err := roles.GetRolesForUser(ut, nil) + require.NoError(t, err) + require.Nil(t, r) + ok, err := roles.CheckUserRole(ut, "/group/system/system_on/abc/role/impersonator/", nil) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/impersonator/", &CheckUserOptions{Entity: Entity{"scope": "/"}}) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/admin/", nil) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_admin/", &CheckUserOptions{Entity: Entity{"scope": "/"}}) + require.NoError(t, err) + require.False(t, ok) + + ut = &CheckedUserTicket{DefaultUID: 1120000000000493, UIDs: []UID{42}, Env: BlackboxProdYateam} + r, err = roles.GetRolesForUser(ut, nil) + require.NoError(t, err) + require.NotNil(t, r) + require.EqualValues(t, + `{ + "/group/system/system_on/abc/role/roles_admin/": [], + "/group/system/system_on/abc/role/roles_manage/": [ + { + "scope": "/services/meta_infra/tools/jobjira/" + }, + { + "scope": "/services/meta_edu/infrastructure/" + } + ] +}`, + r.DebugPrint(), + ) + require.Equal(t, 2, len(r.GetRoles())) + require.False(t, r.HasRole("/")) + require.True(t, r.HasRole("/group/system/system_on/abc/role/roles_manage/")) + ok, err = roles.CheckUserRole(ut, "/", nil) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", nil) + require.NoError(t, err) + require.True(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", &CheckUserOptions{Entity: Entity{"scope": "kek"}}) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", &CheckUserOptions{Entity: Entity{"scope": "/services/meta_infra/tools/jobjira/"}}) + require.NoError(t, err) + require.True(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_admin/", nil) + require.NoError(t, err) + require.True(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_admin/", &CheckUserOptions{Entity: Entity{"scope": "/"}}) + require.NoError(t, err) + require.False(t, ok) + + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", &CheckUserOptions{UID: UID(42)}) + require.NoError(t, err) + require.False(t, ok) + + ut = &CheckedUserTicket{DefaultUID: 0, UIDs: []UID{42}, Env: BlackboxProdYateam} + _, err = roles.GetRolesForUser(ut, nil) + require.EqualError(t, err, "default uid is 0 - it cannot have any role") + uid := UID(83) + _, err = roles.GetRolesForUser(ut, &uid) + require.EqualError(t, err, "'uid' must be in user ticket but it is not: 83") +} + +func TestRolesPublicUserTicketWithNilEntities(t *testing.T) { + roles, err := NewRolesWithOpts( + []byte(`{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"},{"blank":""}],"/group/system/system_on/abc/role/tree_edit/":[{"scope":"/"}]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/"}],"/group/system/system_on/abc/role/roles_admin/":[{}]}}}`), + WithLightIndex(), + ) + require.NoError(t, err) + + ut := &CheckedUserTicket{DefaultUID: 42} + _, err = roles.GetRolesForUser(ut, nil) + require.EqualError(t, err, "user ticket must be from ProdYateam, got from Prod") + ut.Env = BlackboxProdYateam + + r, err := roles.GetRolesForUser(ut, nil) + require.NoError(t, err) + require.Nil(t, r) + ok, err := roles.CheckUserRole(ut, "/group/system/system_on/abc/role/impersonator/", nil) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/impersonator/", &CheckUserOptions{Entity: Entity{"scope": "/"}}) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/admin/", nil) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_admin/", &CheckUserOptions{Entity: Entity{"scope": "/"}}) + require.NoError(t, err) + require.False(t, ok) + + ut = &CheckedUserTicket{DefaultUID: 1120000000000493, UIDs: []UID{42}, Env: BlackboxProdYateam} + r, err = roles.GetRolesForUser(ut, nil) + require.NoError(t, err) + require.NotNil(t, r) + require.EqualValues(t, + `{ + "/group/system/system_on/abc/role/roles_admin/": null, + "/group/system/system_on/abc/role/roles_manage/": [ + { + "scope": "/services/meta_infra/tools/jobjira/" + }, + { + "scope": "/services/meta_edu/infrastructure/" + } + ] +}`, + r.DebugPrint(), + ) + require.Equal(t, 2, len(r.GetRoles())) + require.False(t, r.HasRole("/")) + require.True(t, r.HasRole("/group/system/system_on/abc/role/roles_manage/")) + ok, err = roles.CheckUserRole(ut, "/", nil) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", nil) + require.NoError(t, err) + require.True(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", &CheckUserOptions{Entity: Entity{"scope": "kek"}}) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", &CheckUserOptions{Entity: Entity{"scope": "/services/meta_infra/tools/jobjira/"}}) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_admin/", nil) + require.NoError(t, err) + require.True(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_admin/", &CheckUserOptions{Entity: Entity{"scope": "/"}}) + require.NoError(t, err) + require.False(t, ok) + + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", &CheckUserOptions{UID: UID(42)}) + require.NoError(t, err) + require.False(t, ok) + + ut = &CheckedUserTicket{DefaultUID: 0, UIDs: []UID{42}, Env: BlackboxProdYateam} + _, err = roles.GetRolesForUser(ut, nil) + require.EqualError(t, err, "default uid is 0 - it cannot have any role") + uid := UID(83) + _, err = roles.GetRolesForUser(ut, &uid) + require.EqualError(t, err, "'uid' must be in user ticket but it is not: 83") +} diff --git a/library/go/yandex/tvm/roles_types.go b/library/go/yandex/tvm/roles_types.go new file mode 100644 index 0000000000..d1bfb07b3c --- /dev/null +++ b/library/go/yandex/tvm/roles_types.go @@ -0,0 +1,30 @@ +package tvm + +import ( + "time" +) + +type Roles struct { + tvmRoles map[ClientID]*ConsumerRoles + userRoles map[UID]*ConsumerRoles + raw []byte + meta Meta +} + +type Meta struct { + Revision string + BornTime time.Time + Applied time.Time +} + +type ConsumerRoles struct { + roles EntitiesByRoles +} + +type EntitiesByRoles = map[string]*Entities + +type Entities struct { + subtree subTree +} + +type Entity = map[string]string diff --git a/library/go/yandex/tvm/service_ticket.go b/library/go/yandex/tvm/service_ticket.go new file mode 100644 index 0000000000..77eab31047 --- /dev/null +++ b/library/go/yandex/tvm/service_ticket.go @@ -0,0 +1,52 @@ +package tvm + +import ( + "fmt" +) + +// CheckedServiceTicket is service credential +type CheckedServiceTicket struct { + // SrcID is ID of request source service. You should check SrcID by yourself with your ACL. + SrcID ClientID + // DstID is ID of request destination service. It should be checked manually if DisableDstCheck is specified + DstID ClientID + // IssuerUID is UID of developer who is debuging something, so he(she) issued CheckedServiceTicket with his(her) ssh-sign: + // it is grant_type=sshkey in tvm-api + // https://wiki.yandex-team.ru/passport/tvm2/debug/#sxoditvapizakrytoeserviceticketami. + IssuerUID UID + // DbgInfo is human readable data for debug purposes + DbgInfo string + // LogInfo is safe for logging part of ticket - it can be parsed later with `tvmknife parse_ticket -t ...` + LogInfo string +} + +func (t *CheckedServiceTicket) CheckSrcID(allowedSrcIDsMap map[uint32]struct{}) error { + if len(allowedSrcIDsMap) == 0 { + return nil + } + if _, allowed := allowedSrcIDsMap[uint32(t.SrcID)]; !allowed { + return &TicketError{ + Status: TicketInvalidSrcID, + Msg: fmt.Sprintf("service ticket srcID is not in allowed srcIDs: %v (actual: %v)", allowedSrcIDsMap, t.SrcID), + } + } + return nil +} + +func (t CheckedServiceTicket) String() string { + return fmt.Sprintf("%s (%s)", t.LogInfo, t.DbgInfo) +} + +type ServiceTicketACL func(ticket *CheckedServiceTicket) error + +func AllowAllServiceTickets() ServiceTicketACL { + return func(ticket *CheckedServiceTicket) error { + return nil + } +} + +func CheckServiceTicketSrcID(allowedSrcIDs map[uint32]struct{}) ServiceTicketACL { + return func(ticket *CheckedServiceTicket) error { + return ticket.CheckSrcID(allowedSrcIDs) + } +} diff --git a/library/go/yandex/tvm/tvm.go b/library/go/yandex/tvm/tvm.go new file mode 100644 index 0000000000..2e561bd842 --- /dev/null +++ b/library/go/yandex/tvm/tvm.go @@ -0,0 +1,129 @@ +// This package defines interface which provides fast and cryptographically secure authorization tickets: https://wiki.yandex-team.ru/passport/tvm2/. +// +// Encoded ticket is a valid ASCII string: [0-9a-zA-Z_-:]+. +// +// This package defines interface. All libraries should depend on this package. +// Pure Go implementations of interface is located in library/go/yandex/tvm/tvmtool. +// CGO implementation is located in library/ticket_parser2/go/ticket_parser2. +package tvm + +import ( + "fmt" + "strings" + + "github.com/ydb-platform/ydb/library/go/core/xerrors" +) + +// ClientID represents ID of the application. Another name - TvmID. +type ClientID uint32 + +// UID represents ID of the user in Passport. +type UID uint64 + +// PorgID represents ID of the porganization +type PorgID uint64 + +// BlackboxEnv describes environment of Passport: https://wiki.yandex-team.ru/passport/tvm2/user-ticket/#0-opredeljaemsjasokruzhenijami +type BlackboxEnv int + +type UserExtFields struct { + UID UID + CurrentPorgID PorgID +} + +// This constants must be in sync with EBlackboxEnv from library/cpp/tvmauth/checked_user_ticket.h +const ( + BlackboxProd BlackboxEnv = iota + BlackboxTest + BlackboxProdYateam + BlackboxTestYateam + BlackboxStress +) + +func (e BlackboxEnv) String() string { + switch e { + case BlackboxProd: + return "Prod" + case BlackboxTest: + return "Test" + case BlackboxProdYateam: + return "ProdYateam" + case BlackboxTestYateam: + return "TestYateam" + case BlackboxStress: + return "Stress" + default: + return fmt.Sprintf("Unknown%d", e) + } +} + +func BlackboxEnvFromString(envStr string) (BlackboxEnv, error) { + switch strings.ToLower(envStr) { + case "prod": + return BlackboxProd, nil + case "test": + return BlackboxTest, nil + case "prodyateam", "prod_yateam": + return BlackboxProdYateam, nil + case "testyateam", "test_yateam": + return BlackboxTestYateam, nil + case "stress": + return BlackboxStress, nil + default: + return BlackboxEnv(-1), xerrors.Errorf("blackbox env is unknown: '%s'", envStr) + } +} + +type TicketStatus int + +// This constants must be in sync with EStatus from library/cpp/tvmauth/ticket_status.h +const ( + TicketOk TicketStatus = iota + TicketExpired + TicketInvalidBlackboxEnv + TicketInvalidDst + TicketInvalidTicketType + TicketMalformed + TicketMissingKey + TicketSignBroken + TicketUnsupportedVersion + TicketNoRoles + + // Go-only statuses below + TicketStatusOther + TicketInvalidScopes + TicketInvalidSrcID +) + +func (s TicketStatus) String() string { + switch s { + case TicketOk: + return "Ok" + case TicketExpired: + return "Expired" + case TicketInvalidBlackboxEnv: + return "InvalidBlackboxEnv" + case TicketInvalidDst: + return "InvalidDst" + case TicketInvalidTicketType: + return "InvalidTicketType" + case TicketMalformed: + return "Malformed" + case TicketMissingKey: + return "MissingKey" + case TicketSignBroken: + return "SignBroken" + case TicketUnsupportedVersion: + return "UnsupportedVersion" + case TicketNoRoles: + return "NoRoles" + case TicketStatusOther: + return "Other" + case TicketInvalidScopes: + return "InvalidScopes" + case TicketInvalidSrcID: + return "InvalidSrcID" + default: + return fmt.Sprintf("Unknown%d", s) + } +} diff --git a/library/go/yandex/tvm/tvm_test.go b/library/go/yandex/tvm/tvm_test.go new file mode 100644 index 0000000000..b7cb5605d1 --- /dev/null +++ b/library/go/yandex/tvm/tvm_test.go @@ -0,0 +1,245 @@ +package tvm_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" +) + +func TestUserTicketCheckScopes(t *testing.T) { + cases := map[string]struct { + ticketScopes []string + requiredScopes []string + err bool + }{ + "wo_required_scopes": { + ticketScopes: []string{"bb:sessionid"}, + requiredScopes: nil, + err: false, + }, + "multiple_scopes_0": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"bb:sessionid", "test:test"}, + err: false, + }, + "multiple_scopes_1": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"test:test", "bb:sessionid"}, + err: false, + }, + "wo_scopes": { + ticketScopes: nil, + requiredScopes: []string{"bb:sessionid"}, + err: true, + }, + "invalid_scopes": { + ticketScopes: []string{"bb:sessionid"}, + requiredScopes: []string{"test:test"}, + err: true, + }, + "not_all_scopes": { + ticketScopes: []string{"bb:sessionid", "test:test1"}, + requiredScopes: []string{"bb:sessionid", "test:test"}, + err: true, + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + ticket := tvm.CheckedUserTicket{ + Scopes: testCase.ticketScopes, + } + err := ticket.CheckScopes(testCase.requiredScopes...) + if testCase.err { + require.Error(t, err) + require.IsType(t, &tvm.TicketError{}, err) + ticketErr := err.(*tvm.TicketError) + require.Equal(t, tvm.TicketInvalidScopes, ticketErr.Status) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestUserTicketCheckScopesAny(t *testing.T) { + cases := map[string]struct { + ticketScopes []string + requiredScopes []string + err bool + }{ + "wo_required_scopes": { + ticketScopes: []string{"bb:sessionid"}, + requiredScopes: nil, + err: false, + }, + "multiple_scopes_0": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"bb:sessionid"}, + err: false, + }, + "multiple_scopes_1": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"test:test"}, + err: false, + }, + "multiple_scopes_2": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"bb:sessionid", "test:test"}, + err: false, + }, + "multiple_scopes_3": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"test:test", "bb:sessionid"}, + err: false, + }, + "wo_scopes": { + ticketScopes: nil, + requiredScopes: []string{"bb:sessionid"}, + err: true, + }, + "invalid_scopes": { + ticketScopes: []string{"bb:sessionid"}, + requiredScopes: []string{"test:test"}, + err: true, + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + ticket := tvm.CheckedUserTicket{ + Scopes: testCase.ticketScopes, + } + err := ticket.CheckScopes(testCase.requiredScopes...) + if testCase.err { + require.Error(t, err) + require.IsType(t, &tvm.TicketError{}, err) + ticketErr := err.(*tvm.TicketError) + require.Equal(t, tvm.TicketInvalidScopes, ticketErr.Status) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestServiceTicketAllowedSrcIDs(t *testing.T) { + cases := map[string]struct { + srcID uint32 + allowedSrcIDs []uint32 + err bool + }{ + "empty_allow_list_allows_any_srcID": {srcID: 162, allowedSrcIDs: []uint32{}, err: false}, + "known_src_id_is_allowed": {srcID: 42, allowedSrcIDs: []uint32{42, 100500}, err: false}, + "unknown_src_id_is_not_allowed": {srcID: 404, allowedSrcIDs: []uint32{42, 100500}, err: true}, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + ticket := tvm.CheckedServiceTicket{ + SrcID: tvm.ClientID(testCase.srcID), + } + allowedSrcIDsMap := make(map[uint32]struct{}, len(testCase.allowedSrcIDs)) + for _, allowedSrcID := range testCase.allowedSrcIDs { + allowedSrcIDsMap[allowedSrcID] = struct{}{} + } + err := ticket.CheckSrcID(allowedSrcIDsMap) + if testCase.err { + require.Error(t, err) + require.IsType(t, &tvm.TicketError{}, err) + ticketErr := err.(*tvm.TicketError) + require.Equal(t, tvm.TicketInvalidSrcID, ticketErr.Status) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestTicketError_Is(t *testing.T) { + err1 := &tvm.TicketError{ + Status: tvm.TicketInvalidSrcID, + Msg: "uh oh", + } + err2 := &tvm.TicketError{ + Status: tvm.TicketInvalidSrcID, + Msg: "uh oh", + } + err3 := &tvm.TicketError{ + Status: tvm.TicketInvalidSrcID, + Msg: "other uh oh message", + } + err4 := &tvm.TicketError{ + Status: tvm.TicketExpired, + Msg: "uh oh", + } + err5 := &tvm.TicketError{ + Status: tvm.TicketMalformed, + Msg: "i am completely different", + } + var nilErr *tvm.TicketError = nil + + // ticketErrors are equal to themselves + require.True(t, err1.Is(err1)) + require.True(t, err2.Is(err2)) + require.True(t, nilErr.Is(nilErr)) + + // equal value ticketErrors are equal + require.True(t, err1.Is(err2)) + require.True(t, err2.Is(err1)) + // equal status ticketErrors are equal + require.True(t, err1.Is(err3)) + require.True(t, err1.Is(tvm.ErrTicketInvalidSrcID)) + require.True(t, err2.Is(tvm.ErrTicketInvalidSrcID)) + require.True(t, err3.Is(tvm.ErrTicketInvalidSrcID)) + require.True(t, err4.Is(tvm.ErrTicketExpired)) + require.True(t, err5.Is(tvm.ErrTicketMalformed)) + + // different status ticketErrors are not equal + require.False(t, err1.Is(err4)) + + // completely different ticketErrors are not equal + require.False(t, err1.Is(err5)) + + // non-nil ticketErrors are not equal to nil errors + require.False(t, err1.Is(nil)) + require.False(t, err2.Is(nil)) + + // non-nil ticketErrors are not equal to nil ticketErrors + require.False(t, err1.Is(nilErr)) + require.False(t, err2.Is(nilErr)) +} + +func TestBbEnvFromString(t *testing.T) { + type Case struct { + in string + env tvm.BlackboxEnv + err string + } + cases := []Case{ + {in: "prod", env: tvm.BlackboxProd}, + {in: "Prod", env: tvm.BlackboxProd}, + {in: "ProD", env: tvm.BlackboxProd}, + {in: "PROD", env: tvm.BlackboxProd}, + {in: "test", env: tvm.BlackboxTest}, + {in: "prod_yateam", env: tvm.BlackboxProdYateam}, + {in: "ProdYateam", env: tvm.BlackboxProdYateam}, + {in: "test_yateam", env: tvm.BlackboxTestYateam}, + {in: "TestYateam", env: tvm.BlackboxTestYateam}, + {in: "stress", env: tvm.BlackboxStress}, + {in: "", err: "blackbox env is unknown: ''"}, + {in: "kek", err: "blackbox env is unknown: 'kek'"}, + } + + for idx, c := range cases { + res, err := tvm.BlackboxEnvFromString(c.in) + + if c.err == "" { + require.NoError(t, err, idx) + require.Equal(t, c.env, res, idx) + } else { + require.EqualError(t, err, c.err, idx) + } + } +} diff --git a/library/go/yandex/tvm/tvmauth/apitest/.arcignore b/library/go/yandex/tvm/tvmauth/apitest/.arcignore new file mode 100644 index 0000000000..c8a6e77006 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/apitest/.arcignore @@ -0,0 +1 @@ +apitest diff --git a/library/go/yandex/tvm/tvmauth/apitest/client_test.go b/library/go/yandex/tvm/tvmauth/apitest/client_test.go new file mode 100644 index 0000000000..9d946b2297 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/apitest/client_test.go @@ -0,0 +1,351 @@ +package apitest + +import ( + "context" + "os" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/nop" + "github.com/ydb-platform/ydb/library/go/core/log/zap" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmauth" + uzap "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" +) + +func apiSettings(t testing.TB, client tvm.ClientID) tvmauth.TvmAPISettings { + var portStr []byte + portStr, err := os.ReadFile("tvmapi.port") + require.NoError(t, err) + + var port int + port, err = strconv.Atoi(string(portStr)) + require.NoError(t, err) + env := tvm.BlackboxProd + + if client == 1000501 { + return tvmauth.TvmAPISettings{ + SelfID: 1000501, + + EnableServiceTicketChecking: true, + BlackboxEnv: &env, + + ServiceTicketOptions: tvmauth.NewIDsOptions( + "bAicxJVa5uVY7MjDlapthw", + []tvm.ClientID{1000502}), + + TVMHost: "localhost", + TVMPort: port, + } + } else if client == 1000502 { + return tvmauth.TvmAPISettings{ + SelfID: 1000502, + + EnableServiceTicketChecking: true, + BlackboxEnv: &env, + + ServiceTicketOptions: tvmauth.NewAliasesOptions( + "e5kL0vM3nP-nPf-388Hi6Q", + map[string]tvm.ClientID{ + "cl1000501": 1000501, + "cl1000503": 1000503, + }), + + TVMHost: "localhost", + TVMPort: port, + } + } else if client == 1000503 { + return tvmauth.TvmAPISettings{ + SelfID: 1000503, + + EnableServiceTicketChecking: true, + BlackboxEnv: &env, + ServiceTicketOptions: tvmauth.NewAliasesOptions( + "S3TyTYVqjlbsflVEwxj33w", + map[string]tvm.ClientID{ + "cl1000501": 1000501, + "cl1000503": 1000503, + }), + + TVMHost: "localhost", + TVMPort: port, + DisableDstCheck: true, + } + } else { + t.Fatalf("Bad client id: %d", client) + return tvmauth.TvmAPISettings{} + } +} + +func TestErrorPassing(t *testing.T) { + _, err := tvmauth.NewAPIClient(tvmauth.TvmAPISettings{}, &nop.Logger{}) + require.Error(t, err) +} + +func TestGetServiceTicketForID(t *testing.T) { + c1000501, err := tvmauth.NewAPIClient(apiSettings(t, 1000501), &nop.Logger{}) + require.NoError(t, err) + defer c1000501.Destroy() + + c1000502, err := tvmauth.NewAPIClient(apiSettings(t, 1000502), &nop.Logger{}) + require.NoError(t, err) + defer c1000502.Destroy() + + ticketStr, err := c1000501.GetServiceTicketForID(context.Background(), 1000502) + require.NoError(t, err) + + ticket, err := c1000502.CheckServiceTicket(context.Background(), ticketStr) + require.NoError(t, err) + require.Equal(t, tvm.ClientID(1000501), ticket.SrcID) + + ticketStrByAlias, err := c1000501.GetServiceTicketForAlias(context.Background(), "1000502") + require.NoError(t, err) + require.Equal(t, ticketStr, ticketStrByAlias) + + _, err = c1000501.CheckServiceTicket(context.Background(), ticketStr) + require.Error(t, err) + require.IsType(t, err, &tvm.TicketError{}) + require.Equal(t, tvm.TicketInvalidDst, err.(*tvm.TicketError).Status) + + _, err = c1000501.GetServiceTicketForID(context.Background(), 127) + require.Error(t, err) + require.IsType(t, err, &tvm.Error{}) + + ticketStr, err = c1000502.GetServiceTicketForID(context.Background(), 1000501) + require.NoError(t, err) + ticketStrByAlias, err = c1000502.GetServiceTicketForAlias(context.Background(), "cl1000501") + require.NoError(t, err) + require.Equal(t, ticketStr, ticketStrByAlias) + + _, err = c1000502.GetServiceTicketForAlias(context.Background(), "1000501") + require.Error(t, err) + require.IsType(t, err, &tvm.Error{}) +} + +func TestLogger(t *testing.T) { + logger, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + require.NoError(t, err) + + core, logs := observer.New(zap.ZapifyLevel(log.DebugLevel)) + logger.L = logger.L.WithOptions(uzap.WrapCore(func(_ zapcore.Core) zapcore.Core { + return core + })) + + c1000502, err := tvmauth.NewAPIClient(apiSettings(t, 1000502), logger) + require.NoError(t, err) + defer c1000502.Destroy() + + loggedEntries := logs.AllUntimed() + for idx := 0; len(loggedEntries) < 7 && idx < 250; idx++ { + time.Sleep(100 * time.Millisecond) + loggedEntries = logs.AllUntimed() + } + + var plainLog string + for _, le := range loggedEntries { + plainLog += le.Message + "\n" + } + + require.Contains( + t, + plainLog, + "Thread-worker started") +} + +func BenchmarkServiceTicket(b *testing.B) { + c1000501, err := tvmauth.NewAPIClient(apiSettings(b, 1000501), &nop.Logger{}) + require.NoError(b, err) + defer c1000501.Destroy() + + c1000502, err := tvmauth.NewAPIClient(apiSettings(b, 1000502), &nop.Logger{}) + require.NoError(b, err) + defer c1000502.Destroy() + + b.Run("GetServiceTicketForID", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := c1000501.GetServiceTicketForID(context.Background(), 1000502) + require.NoError(b, err) + } + }) + }) + + ticketStr, err := c1000501.GetServiceTicketForID(context.Background(), 1000502) + require.NoError(b, err) + + b.Run("CheckServiceTicket", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := c1000502.CheckServiceTicket(context.Background(), ticketStr) + require.NoError(b, err) + } + }) + }) +} + +const serviceTicketStr = "3:serv:CBAQ__________9_IggIlJEGELaIPQ:KC8zKTnoM7GQ8UkBixoAlDt7CAuNIO_6J4rzeqelj7wn7vCKBfsy1jlg2UIvBw0JKUUc6116s5aBw1-vr4BD1V0eh0z-k_CSGC4DKKlnBEEAwcpHRjOZUdW_5UJFe-l77KMObvZUPLckWUaQKybMSBYDGrAeo1TqHHmkumwSG5s" +const serviceTicketStr2 = "3:serv:CBAQ__________9_IgcIt4g9ENwB:Vw6y8J5k80qeHgZlvT1LLd9CXAQlKW92w1LVxke65AHkK9jOUy6cteUGp3-brIya--n35e3ltJfMuKF0pYRBsYin5PsP7x4KwXUY1ZNUcvCd4URuwAgaWFEASs4Nx62sQmCkToGZG6zEv95C_nuq0aGkv0v_JPSmWu7D2EyaFzA" +const userTicketStr = "3:user:CAsQ__________9_GikKAgh7CgMIyAMQyAMaBmJiOmtlaxoLc29tZTpzY29wZXMg0oXYzAQoAA:LPpzn2ILhY1BHXA1a51mtU1emb2QSMH3UhTxsmL07iJ7m2AMc2xloXCKQOI7uK6JuLDf7aSWd9QQJpaRV0mfPzvFTnz2j78hvO3bY8KT_TshA3A-M5-t5gip8CfTVGPmEPwnuUhmKqAGkGSL-sCHyu1RIjHkGbJA250ThHHKgAY" +const userTicketStr2 = "3:user:CA0Q__________9_GjIKAgh7CgMIyAMQyAMaB2JiOnNlc3MaCGJiOnNlc3MyIBIoATINdGVzdC1sb2dpbi1pZA:Bz6R7gV283K3bFzWLcew0G8FTmMz6afl49QUtkgZSniShcahmWEQlG1ANXeHblhfq8IH3VcVPWnUT4rnYRVIXjPIQt4yoOD6rRXbqK7QdBDq9P2fCshfZJUFlYdSxMFnbD7ev3PxrtM6w-jWhMbsK6GZ551RAYjHXzUU5l0Nnqk" +const userTicketStr3 = "3:user:CA0Q__________9_Gj8KAgh7CgMIyAMKBgiVBhDbBxCVBhoIYmI6c2VzczEaCGJiOnNlc3MyINKF2MwEKAEyDXRlc3QtbG9naW4taWQ:Gcl5nYCOsgwWG146HP0dcLSbU1jaV0zr6TEXrPTL02qgwaSsOL1GO37LOPnoa0mTSqQzek3U7uwpfOVr50C65IUXDF64F9H6uIgkl4LizcnIShIkFQcMVE8gPKv_hDxBTY-N1SRBKraJ4jtIDbTropDHGgdyu72riUOsGOfAsU0" +const userTicketDefaultUID0 = "3:user:CA0Q__________9_Gh0KAggBEAAg0oXYzAQoATINdGVzdC1sb2dpbi1pZA:CHkdr6eh5CRcC7878r-SBrq59YzlJ-yvgv6fIoaik3Z4y0tYprwKQwLt-1BME6GMG7grlALscZmU8zlWJ8GvASHyGH1cQ76SpLdwzoFqPYSvNii3mkDwEH2iFk-aSczh9FGpb3_6mbQvsZYiXpxRa2BYn56s4k5yEHq5T2ytFeE" + +func TestDebugInfo(t *testing.T) { + c1000502, err := tvmauth.NewAPIClient(apiSettings(t, 1000502), &nop.Logger{}) + require.NoError(t, err) + defer c1000502.Destroy() + + ticketS, err := c1000502.CheckServiceTicket(context.Background(), serviceTicketStr) + require.NoError(t, err) + require.Equal(t, tvm.ClientID(100500), ticketS.SrcID) + require.Equal(t, tvm.UID(0), ticketS.IssuerUID) + require.Equal(t, "ticket_type=serv;expiration_time=9223372036854775807;src=100500;dst=1000502;", ticketS.DbgInfo) + require.Equal(t, "3:serv:CBAQ__________9_IggIlJEGELaIPQ:", ticketS.LogInfo) + + ticketS, err = c1000502.CheckServiceTicket(context.Background(), serviceTicketStr[:len(serviceTicketStr)-1]) + require.Error(t, err) + require.IsType(t, err, &tvm.TicketError{}) + require.Equal(t, err.(*tvm.TicketError).Status, tvm.TicketSignBroken) + require.Equal(t, "ticket_type=serv;expiration_time=9223372036854775807;src=100500;dst=1000502;", ticketS.DbgInfo) + require.Equal(t, "3:serv:CBAQ__________9_IggIlJEGELaIPQ:", ticketS.LogInfo) + + ticketU, err := c1000502.CheckUserTicket(context.Background(), userTicketStr) + require.NoError(t, err) + require.Equal(t, []tvm.UID{123, 456}, ticketU.UIDs) + require.Equal(t, tvm.UID(456), ticketU.DefaultUID) + require.Equal(t, []string{"bb:kek", "some:scopes"}, ticketU.Scopes) + require.Equal(t, map[tvm.UID]tvm.UserExtFields{123: {UID: 123, CurrentPorgID: 0}, 456: {UID: 456, CurrentPorgID: 0}}, ticketU.UidsExtFieldsMap) + require.Equal(t, "ticket_type=user;expiration_time=9223372036854775807;scope=bb:kek;scope=some:scopes;default_uid=456;uid=123;uid=456;env=Prod;", ticketU.DbgInfo) + require.Equal(t, "3:user:CAsQ__________9_GikKAgh7CgMIyAMQyAMaBmJiOmtlaxoLc29tZTpzY29wZXMg0oXYzAQoAA:", ticketU.LogInfo) + + _, err = c1000502.CheckUserTicket(context.Background(), userTicketStr, tvm.WithBlackboxOverride(tvm.BlackboxProdYateam)) + require.Error(t, err) + require.IsType(t, err, &tvm.TicketError{}) + require.Equal(t, err.(*tvm.TicketError).Status, tvm.TicketInvalidBlackboxEnv) + + ticketU, err = c1000502.CheckUserTicket(context.Background(), userTicketStr[:len(userTicketStr)-1]) + require.Error(t, err) + require.IsType(t, err, &tvm.TicketError{}) + require.Equal(t, err.(*tvm.TicketError).Status, tvm.TicketSignBroken) + require.Equal(t, "ticket_type=user;expiration_time=9223372036854775807;scope=bb:kek;scope=some:scopes;default_uid=456;uid=123;uid=456;env=Prod;", ticketU.DbgInfo) + require.Equal(t, "3:user:CAsQ__________9_GikKAgh7CgMIyAMQyAMaBmJiOmtlaxoLc29tZTpzY29wZXMg0oXYzAQoAA:", ticketU.LogInfo) + + s := apiSettings(t, 1000502) + env := tvm.BlackboxTest + s.BlackboxEnv = &env + c, err := tvmauth.NewAPIClient(s, &nop.Logger{}) + require.NoError(t, err) + defer c.Destroy() + + ticketU, err = c.CheckUserTicket(context.Background(), userTicketStr2) + require.NoError(t, err) + require.Equal(t, "test-login-id", ticketU.LoginID) + require.Equal(t, "ticket_type=user;expiration_time=9223372036854775807;scope=bb:sess;scope=bb:sess2;default_uid=456;uid=123;uid=456;env=Test;login_id=test-login-id;", ticketU.DbgInfo) + + ticketU, err = c.CheckUserTicket(context.Background(), userTicketStr3) + require.NoError(t, err) + require.Equal(t, map[tvm.UID]tvm.UserExtFields{123: {UID: 123, CurrentPorgID: 0}, 456: {UID: 456, CurrentPorgID: 0}, 789: {UID: 789, CurrentPorgID: 987}}, ticketU.UidsExtFieldsMap) + require.Equal(t, 789, int(ticketU.DefaultUIDExtFields.UID)) + require.Equal(t, 987, int(ticketU.DefaultUIDExtFields.CurrentPorgID)) + + ticketU, err = c.CheckUserTicket(context.Background(), userTicketDefaultUID0) + require.NoError(t, err) + require.Nil(t, ticketU.DefaultUIDExtFields) + + s = apiSettings(t, 1000503) + s.DisableDstCheck = true + c, err = tvmauth.NewAPIClient(s, &nop.Logger{}) + require.NoError(t, err) + defer c.Destroy() + + ticketS, err = c.CheckServiceTicket(context.Background(), serviceTicketStr2) + require.NoError(t, err) + require.Equal(t, 220, int(ticketS.DstID)) +} + +func TestUnittestClient(t *testing.T) { + _, err := tvmauth.NewUnittestClient(tvmauth.TvmUnittestSettings{}) + require.NoError(t, err) + + client, err := tvmauth.NewUnittestClient(tvmauth.TvmUnittestSettings{ + SelfID: 1000502, + }) + require.NoError(t, err) + + _, err = client.GetRoles(context.Background()) + require.ErrorContains(t, err, "Roles are not provided") + _, err = client.GetServiceTicketForID(context.Background(), tvm.ClientID(42)) + require.ErrorContains(t, err, "Destination '42' was not specified in settings") + + status, err := client.GetStatus(context.Background()) + require.NoError(t, err) + require.EqualValues(t, tvm.ClientOK, status.Status) + + st, err := client.CheckServiceTicket(context.Background(), serviceTicketStr) + require.NoError(t, err) + require.EqualValues(t, tvm.ClientID(100500), st.SrcID) + + ut, err := client.CheckUserTicket(context.Background(), userTicketStr) + require.NoError(t, err) + require.EqualValues(t, tvm.UID(456), ut.DefaultUID) +} + +func TestDynamicClient(t *testing.T) { + logger, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + require.NoError(t, err) + + core, logs := observer.New(zap.ZapifyLevel(log.DebugLevel)) + logger.L = logger.L.WithOptions(uzap.WrapCore(func(_ zapcore.Core) zapcore.Core { + return core + })) + + c1000501, err := tvmauth.NewDynamicApiClient(apiSettings(t, 1000501), logger) + require.NoError(t, err) + + c1000502, err := tvmauth.NewDynamicApiClient(apiSettings(t, 1000502), &nop.Logger{}) + require.NoError(t, err) + defer c1000502.Destroy() + + ticketStr, err := c1000501.GetOptionalServiceTicketForID(context.Background(), tvm.ClientID(1000502)) + require.NoError(t, err) + require.NotNil(t, ticketStr) + + ticket, err := c1000502.CheckServiceTicket(context.Background(), *ticketStr) + require.NoError(t, err) + require.Equal(t, tvm.ClientID(1000501), ticket.SrcID) + + err = c1000501.AddDsts(context.Background(), []tvm.ClientID{1000503, 1000504}) + require.NoError(t, err) + + ticketStr, err = c1000501.GetOptionalServiceTicketForID(context.Background(), tvm.ClientID(1000503)) + require.NoError(t, err) + require.Nil(t, ticketStr) + + ticketStr, err = c1000501.GetOptionalServiceTicketForID(context.Background(), tvm.ClientID(1000504)) + require.NoError(t, err) + require.Nil(t, ticketStr) + + c1000501.Destroy() + + loggedEntries := logs.AllUntimed() + for idx := 0; len(loggedEntries) < 7 && idx < 250; idx++ { + time.Sleep(100 * time.Millisecond) + loggedEntries = logs.AllUntimed() + } + + var plainLog string + for _, le := range loggedEntries { + plainLog += le.Message + "\n" + } + + require.Contains( + t, + plainLog, + "Adding dst: got task #1 with 2 dsts") + +} diff --git a/library/go/yandex/tvm/tvmauth/apitest/ya.make b/library/go/yandex/tvm/tvmauth/apitest/ya.make new file mode 100644 index 0000000000..93b8746b8d --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/apitest/ya.make @@ -0,0 +1,9 @@ +GO_TEST() + +ENV(GODEBUG="cgocheck=2") + +INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmapi/recipe.inc) + +GO_TEST_SRCS(client_test.go) + +END() diff --git a/library/go/yandex/tvm/tvmauth/client.go b/library/go/yandex/tvm/tvmauth/client.go new file mode 100644 index 0000000000..78424498c2 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/client.go @@ -0,0 +1,641 @@ +//go:build cgo +// +build cgo + +package tvmauth + +// #include <stdlib.h> +// +// #include "tvm.h" +import "C" +import ( + "context" + "encoding/json" + "fmt" + "runtime" + "sync" + "unsafe" + + "github.com/ydb-platform/ydb/library/go/cgosem" + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" +) + +// NewIDsOptions creates options for fetching CheckedServiceTicket's with ClientID +func NewIDsOptions(secret string, dsts []tvm.ClientID) *TVMAPIOptions { + tmp := make(map[string]tvm.ClientID) + for _, dst := range dsts { + tmp[fmt.Sprintf("%d", dst)] = dst + } + + res, err := json.Marshal(tmp) + if err != nil { + panic(err) + } + + return &TVMAPIOptions{ + selfSecret: secret, + dstAliases: res, + } +} + +// NewAliasesOptions creates options for fetching CheckedServiceTicket's with alias+ClientID +func NewAliasesOptions(secret string, dsts map[string]tvm.ClientID) *TVMAPIOptions { + if dsts == nil { + dsts = make(map[string]tvm.ClientID) + } + + res, err := json.Marshal(dsts) + if err != nil { + panic(err) + } + + return &TVMAPIOptions{ + selfSecret: secret, + dstAliases: res, + } +} + +func (o *TvmAPISettings) pack(out *C.TVM_ApiSettings) { + out.SelfId = C.uint32_t(o.SelfID) + + if o.EnableServiceTicketChecking { + out.EnableServiceTicketChecking = 1 + } + + if o.BlackboxEnv != nil { + out.EnableUserTicketChecking = 1 + out.BlackboxEnv = C.int(*o.BlackboxEnv) + } + + if o.FetchRolesForIdmSystemSlug != "" { + o.fetchRolesForIdmSystemSlug = []byte(o.FetchRolesForIdmSystemSlug) + out.IdmSystemSlug = (*C.uchar)(&o.fetchRolesForIdmSystemSlug[0]) + out.IdmSystemSlugSize = C.int(len(o.fetchRolesForIdmSystemSlug)) + } + if o.DisableSrcCheck { + out.DisableSrcCheck = 1 + } + if o.DisableDefaultUIDCheck { + out.DisableDefaultUIDCheck = 1 + } + if o.DisableDstCheck { + out.DisableDstCheck = 1 + } + if o.TVMHost != "" { + o.tvmHost = []byte(o.TVMHost) + out.TVMHost = (*C.uchar)(&o.tvmHost[0]) + out.TVMHostSize = C.int(len(o.tvmHost)) + } + out.TVMPort = C.int(o.TVMPort) + + if o.TiroleHost != "" { + o.tiroleHost = []byte(o.TiroleHost) + out.TiroleHost = (*C.uchar)(&o.tiroleHost[0]) + out.TiroleHostSize = C.int(len(o.tiroleHost)) + } + out.TirolePort = C.int(o.TirolePort) + out.TiroleTvmId = C.uint32_t(o.TiroleTvmID) + + if o.ServiceTicketOptions != nil { + if o.ServiceTicketOptions.selfSecret != "" { + o.ServiceTicketOptions.selfSecretB = []byte(o.ServiceTicketOptions.selfSecret) + out.SelfSecret = (*C.uchar)(&o.ServiceTicketOptions.selfSecretB[0]) + out.SelfSecretSize = C.int(len(o.ServiceTicketOptions.selfSecretB)) + } + + if len(o.ServiceTicketOptions.dstAliases) != 0 { + out.DstAliases = (*C.uchar)(&o.ServiceTicketOptions.dstAliases[0]) + out.DstAliasesSize = C.int(len(o.ServiceTicketOptions.dstAliases)) + } + } + + if o.DiskCacheDir != "" { + o.diskCacheDir = []byte(o.DiskCacheDir) + + out.DiskCacheDir = (*C.uchar)(&o.diskCacheDir[0]) + out.DiskCacheDirSize = C.int(len(o.diskCacheDir)) + } +} + +func (o *TvmToolSettings) pack(out *C.TVM_ToolSettings) { + if o.Alias != "" { + o.alias = []byte(o.Alias) + + out.Alias = (*C.uchar)(&o.alias[0]) + out.AliasSize = C.int(len(o.alias)) + } + + out.Port = C.int(o.Port) + + if o.Hostname != "" { + o.hostname = []byte(o.Hostname) + out.Hostname = (*C.uchar)(&o.hostname[0]) + out.HostnameSize = C.int(len(o.hostname)) + } + + if o.AuthToken != "" { + o.authToken = []byte(o.AuthToken) + out.AuthToken = (*C.uchar)(&o.authToken[0]) + out.AuthTokenSize = C.int(len(o.authToken)) + } + + if o.DisableSrcCheck { + out.DisableSrcCheck = 1 + } + if o.DisableDefaultUIDCheck { + out.DisableDefaultUIDCheck = 1 + } + if o.DisableDstCheck { + out.DisableDstCheck = 1 + } +} + +func (o *TvmUnittestSettings) pack(out *C.TVM_UnittestSettings) { + out.SelfId = C.uint32_t(o.SelfID) + out.BlackboxEnv = C.int(o.BlackboxEnv) +} + +// Destroy stops client and delete it from memory. +// Do not try to use client after destroying it +func (c *Client) Destroy() { + if c.handle == nil { + return + } + + C.TVM_DestroyClient(c.handle) + c.handle = nil + + if c.logger != nil { + unregisterLogger(*c.logger) + } +} + +func (c *DynamicClient) Destroy() { + c.dynHandle = nil + c.Client.Destroy() +} + +func unpackString(s *C.TVM_String) string { + if s.Data == nil { + return "" + } + + return C.GoStringN(s.Data, s.Size) +} + +func unpackErr(err *C.TVM_Error) error { + msg := unpackString(&err.Message) + code := tvm.ErrorCode(err.Code) + + if code != 0 { + return &tvm.Error{Code: code, Retriable: err.Retriable != 0, Msg: msg} + } + + return nil +} + +func unpackScopes(scopes *C.TVM_String, scopeSize C.int) (s []string) { + if scopeSize == 0 { + return + } + + s = make([]string, int(scopeSize)) + scopesArr := (*[1 << 30]C.TVM_String)(unsafe.Pointer(scopes)) + + for i := 0; i < int(scopeSize); i++ { + s[i] = C.GoStringN(scopesArr[i].Data, scopesArr[i].Size) + } + + return +} + +func unpackUidsExtFieldsMap(uidsExtFields *C.TVM_UserExtFields, uidsExtFieldsSize C.int) map[tvm.UID]tvm.UserExtFields { + res := make(map[tvm.UID]tvm.UserExtFields, int(uidsExtFieldsSize)) + + extFieldsArr := (*[1 << 30]C.TVM_UserExtFields)(unsafe.Pointer(uidsExtFields)) + + for i := 0; i < int(uidsExtFieldsSize); i++ { + uid := C.uint64_t(extFieldsArr[i].Uid) + currentPorgId := C.uint64_t(extFieldsArr[i].CurrentPorgId) + + res[tvm.UID(uid)] = tvm.UserExtFields{ + UID: tvm.UID(uid), + CurrentPorgID: tvm.PorgID(currentPorgId), + } + } + + return res +} + +func unpackDefaultUIDExtFields(defaultUIDExtFields *C.TVM_UserExtFields) *tvm.UserExtFields { + fields := (*C.TVM_UserExtFields)(unsafe.Pointer(defaultUIDExtFields)) + + if fields == nil { + return nil + } + + uid := C.uint64_t(fields.Uid) + currentPorgId := C.uint64_t(fields.CurrentPorgId) + + res := &tvm.UserExtFields{ + UID: tvm.UID(uid), + CurrentPorgID: tvm.PorgID(currentPorgId), + } + + return res +} + +func unpackStatus(status C.int) error { + if status == 0 { + return nil + } + + return &tvm.TicketError{ + Status: tvm.TicketStatus(status), + Msg: C.GoString(C.TVM_TicketStatusToString(status)), + } +} + +func unpackServiceTicket(t *C.TVM_ServiceTicket) (*tvm.CheckedServiceTicket, error) { + ticket := &tvm.CheckedServiceTicket{} + ticket.SrcID = tvm.ClientID(t.SrcId) + ticket.DstID = tvm.ClientID(t.DstId) + ticket.IssuerUID = tvm.UID(t.IssuerUid) + ticket.DbgInfo = unpackString(&t.DbgInfo) + ticket.LogInfo = unpackString(&t.LogInfo) + return ticket, unpackStatus(t.Status) +} + +func unpackUserTicket(t *C.TVM_UserTicket) (*tvm.CheckedUserTicket, error) { + ticket := &tvm.CheckedUserTicket{} + ticket.DefaultUID = tvm.UID(t.DefaultUid) + if t.UidsSize != 0 { + ticket.UIDs = make([]tvm.UID, int(t.UidsSize)) + uids := (*[1 << 30]C.uint64_t)(unsafe.Pointer(t.Uids)) + for i := 0; i < int(t.UidsSize); i++ { + ticket.UIDs[i] = tvm.UID(uids[i]) + } + } + + ticket.Env = tvm.BlackboxEnv(t.Env) + + ticket.Scopes = unpackScopes(t.Scopes, t.ScopesSize) + ticket.DbgInfo = unpackString(&t.DbgInfo) + ticket.LogInfo = unpackString(&t.LogInfo) + ticket.LoginID = unpackString(&t.LoginId) + ticket.UidsExtFieldsMap = unpackUidsExtFieldsMap(t.UidsExtFields, t.UidsExtFieldsSize) + ticket.DefaultUIDExtFields = unpackDefaultUIDExtFields(t.DefaultUidExtFields) + return ticket, unpackStatus(t.Status) +} + +func unpackClientStatus(s *C.TVM_ClientStatus) (status tvm.ClientStatusInfo) { + status.Status = tvm.ClientStatus(s.Status) + status.LastError = C.GoStringN(s.LastError.Data, s.LastError.Size) + + return +} + +// NewAPIClient creates client which uses https://tvm-api.yandex.net to get state +func NewAPIClient(options TvmAPISettings, log log.Logger) (*Client, error) { + var settings C.TVM_ApiSettings + options.pack(&settings) + + client := &Client{ + mutex: &sync.RWMutex{}, + } + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + loggerId := registerLogger(log) + client.logger = &loggerId + + var tvmErr C.TVM_Error + C.TVM_NewApiClient(settings, C.int(loggerId), &client.handle, &tvmErr, &pool) + + if err := unpackErr(&tvmErr); err != nil { + unregisterLogger(loggerId) + return nil, err + } + + runtime.SetFinalizer(client, (*Client).Destroy) + return client, nil +} + +func NewDynamicApiClient(options TvmAPISettings, log log.Logger) (*DynamicClient, error) { + var settings C.TVM_ApiSettings + options.pack(&settings) + + client := &DynamicClient{ + Client: &Client{ + mutex: &sync.RWMutex{}, + }, + } + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + loggerId := registerLogger(log) + client.logger = &loggerId + + var tvmErr C.TVM_Error + C.TVM_NewDynamicApiClient(settings, C.int(loggerId), &client.Client.handle, &client.dynHandle, &tvmErr, &pool) + + if err := unpackErr(&tvmErr); err != nil { + unregisterLogger(loggerId) + return nil, err + } + + runtime.SetFinalizer(client, (*DynamicClient).Destroy) + return client, nil +} + +// NewToolClient creates client uses local http-interface to get state: http://localhost/tvm/. +// Details: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/. +func NewToolClient(options TvmToolSettings, log log.Logger) (*Client, error) { + var settings C.TVM_ToolSettings + options.pack(&settings) + + client := &Client{ + mutex: &sync.RWMutex{}, + } + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + loggerId := registerLogger(log) + client.logger = &loggerId + + var tvmErr C.TVM_Error + C.TVM_NewToolClient(settings, C.int(loggerId), &client.handle, &tvmErr, &pool) + + if err := unpackErr(&tvmErr); err != nil { + unregisterLogger(loggerId) + return nil, err + } + + runtime.SetFinalizer(client, (*Client).Destroy) + return client, nil +} + +// NewUnittestClient creates client with mocked state. +func NewUnittestClient(options TvmUnittestSettings) (*Client, error) { + var settings C.TVM_UnittestSettings + options.pack(&settings) + + client := &Client{ + mutex: &sync.RWMutex{}, + } + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var tvmErr C.TVM_Error + C.TVM_NewUnittestClient(settings, &client.handle, &tvmErr, &pool) + + if err := unpackErr(&tvmErr); err != nil { + return nil, err + } + + runtime.SetFinalizer(client, (*Client).Destroy) + return client, nil +} + +// CheckServiceTicket always checks ticket with keys from memory +func (c *Client) CheckServiceTicket(ctx context.Context, ticketStr string) (*tvm.CheckedServiceTicket, error) { + defer cgosem.S.Acquire().Release() + + ticketBytes := []byte(ticketStr) + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var ticket C.TVM_ServiceTicket + var tvmErr C.TVM_Error + C.TVM_CheckServiceTicket( + c.handle, + (*C.uchar)(&ticketBytes[0]), C.int(len(ticketBytes)), + &ticket, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return nil, err + } + + return unpackServiceTicket(&ticket) +} + +// CheckUserTicket always checks ticket with keys from memory +func (c *Client) CheckUserTicket(ctx context.Context, ticketStr string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + defer cgosem.S.Acquire().Release() + + var options tvm.CheckUserTicketOptions + for _, opt := range opts { + opt(&options) + } + + ticketBytes := []byte(ticketStr) + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var bbEnv *C.int + var bbEnvOverrided C.int + if options.EnvOverride != nil { + bbEnvOverrided = C.int(*options.EnvOverride) + bbEnv = &bbEnvOverrided + } + + var ticket C.TVM_UserTicket + var tvmErr C.TVM_Error + C.TVM_CheckUserTicket( + c.handle, + (*C.uchar)(&ticketBytes[0]), C.int(len(ticketBytes)), + bbEnv, + &ticket, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return nil, err + } + + return unpackUserTicket(&ticket) +} + +// GetServiceTicketForAlias always returns ticket from memory +func (c *Client) GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) { + defer cgosem.S.Acquire().Release() + + aliasBytes := []byte(alias) + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var ticket *C.char + var tvmErr C.TVM_Error + C.TVM_GetServiceTicketForAlias( + c.handle, + (*C.uchar)(&aliasBytes[0]), C.int(len(aliasBytes)), + &ticket, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return "", err + } + + return C.GoString(ticket), nil +} + +// GetServiceTicketForID always returns ticket from memory +func (c *Client) GetServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (string, error) { + defer cgosem.S.Acquire().Release() + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var ticket *C.char + var tvmErr C.TVM_Error + C.TVM_GetServiceTicket( + c.handle, + C.uint32_t(dstID), + &ticket, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return "", err + } + + return C.GoString(ticket), nil +} + +func (c *DynamicClient) GetOptionalServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (*string, error) { + defer cgosem.S.Acquire().Release() + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var ticket *C.char + var tvmErr C.TVM_Error + C.TVM_GetOptionalServiceTicketFor( + c.dynHandle, + C.uint32_t(dstID), + &ticket, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return nil, err + } + + if ticket == nil { + return nil, nil + } + + res := C.GoString(ticket) + return &res, nil +} + +func (c *DynamicClient) AddDsts(ctx context.Context, dsts []tvm.ClientID) error { + defer cgosem.S.Acquire().Release() + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var tvmErr C.TVM_Error + C.TVM_AddDsts( + c.dynHandle, + (*C.uint32_t)(&dsts[0]), + C.int(len(dsts)), + &tvmErr, + &pool) + runtime.KeepAlive(c) + runtime.KeepAlive(dsts) + + if err := unpackErr(&tvmErr); err != nil { + return err + } + + return nil +} + +// GetStatus returns current status of client. +// See detials: https://godoc.yandex-team.ru/pkg/github.com/ydb-platform/ydb/library/go/yandex/tvm/#Client +func (c *Client) GetStatus(ctx context.Context) (tvm.ClientStatusInfo, error) { + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var status C.TVM_ClientStatus + var tvmErr C.TVM_Error + C.TVM_GetStatus(c.handle, &status, &tvmErr, &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return tvm.ClientStatusInfo{}, err + } + + return unpackClientStatus(&status), nil +} + +func (c *Client) GetRoles(ctx context.Context) (*tvm.Roles, error) { + defer cgosem.S.Acquire().Release() + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + currentRoles := c.getCurrentRoles() + var currentRevision []byte + var currentRevisionPtr *C.uchar + if currentRoles != nil { + currentRevision = []byte(currentRoles.GetMeta().Revision) + currentRevisionPtr = (*C.uchar)(¤tRevision[0]) + } + + var raw *C.char + var rawSize C.int + var tvmErr C.TVM_Error + C.TVM_GetRoles( + c.handle, + currentRevisionPtr, C.int(len(currentRevision)), + &raw, + &rawSize, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return nil, err + } + if raw == nil { + return currentRoles, nil + } + + c.mutex.Lock() + defer c.mutex.Unlock() + + if currentRoles != c.roles { + return c.roles, nil + } + + roles, err := tvm.NewRoles(C.GoBytes(unsafe.Pointer(raw), rawSize)) + if err != nil { + return nil, err + } + + c.roles = roles + return c.roles, nil +} + +func (c *Client) getCurrentRoles() *tvm.Roles { + c.mutex.RLock() + defer c.mutex.RUnlock() + return c.roles +} diff --git a/library/go/yandex/tvm/tvmauth/client_example_test.go b/library/go/yandex/tvm/tvmauth/client_example_test.go new file mode 100644 index 0000000000..09e39b5af7 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/client_example_test.go @@ -0,0 +1,182 @@ +package tvmauth_test + +import ( + "context" + "fmt" + + "github.com/ydb-platform/ydb/library/go/core/log/nop" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmauth" +) + +func ExampleNewAPIClient_getServiceTicketsWithAliases() { + blackboxAlias := "blackbox" + datasyncAlias := "datasync" + + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + ServiceTicketOptions: tvmauth.NewAliasesOptions( + "bAicxJVa5uVY7MjDlapthw", + map[string]tvm.ClientID{ + blackboxAlias: 1000502, + datasyncAlias: 1000503, + }), + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.GetServiceTicketForAlias(context.Background(), blackboxAlias) + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) +} + +func ExampleNewAPIClient_getServiceTicketsWithID() { + blackboxID := tvm.ClientID(1000502) + datasyncID := tvm.ClientID(1000503) + + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + ServiceTicketOptions: tvmauth.NewIDsOptions( + "bAicxJVa5uVY7MjDlapthw", + []tvm.ClientID{ + blackboxID, + datasyncID, + }), + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.GetServiceTicketForID(context.Background(), blackboxID) + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) +} + +func ExampleNewAPIClient_checkServiceTicket() { + // allowed tvm consumers for your service + acl := map[tvm.ClientID]interface{}{} + + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + EnableServiceTicketChecking: true, + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + serviceTicketFromRequest := "kek" + + serviceTicketStruct, err := c.CheckServiceTicket(context.Background(), serviceTicketFromRequest) + if err != nil { + response := map[string]string{ + "error": "service ticket is invalid", + "desc": err.Error(), + "status": err.(*tvm.TicketError).Status.String(), + } + if serviceTicketStruct != nil { + response["debug_info"] = serviceTicketStruct.DbgInfo + } + panic(response) // return 403 + } + if _, ok := acl[serviceTicketStruct.SrcID]; !ok { + response := map[string]string{ + "error": fmt.Sprintf("tvm client id is not allowed: %d", serviceTicketStruct.SrcID), + } + panic(response) // return 403 + } + + // proceed... +} + +func ExampleNewAPIClient_checkUserTicket() { + env := tvm.BlackboxTest + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + EnableServiceTicketChecking: true, + BlackboxEnv: &env, + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + serviceTicketFromRequest := "kek" + userTicketFromRequest := "lol" + + _, _ = c.CheckServiceTicket(context.Background(), serviceTicketFromRequest) // See example for this method + + userTicketStruct, err := c.CheckUserTicket(context.Background(), userTicketFromRequest) + if err != nil { + response := map[string]string{ + "error": "user ticket is invalid", + "desc": err.Error(), + "status": err.(*tvm.TicketError).Status.String(), + } + if userTicketStruct != nil { + response["debug_info"] = userTicketStruct.DbgInfo + } + panic(response) // return 403 + } + + fmt.Printf("Got user in request: %d", userTicketStruct.DefaultUID) + // proceed... +} + +func ExampleNewAPIClient_createClientWithAllSettings() { + blackboxAlias := "blackbox" + datasyncAlias := "datasync" + + env := tvm.BlackboxTest + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + ServiceTicketOptions: tvmauth.NewAliasesOptions( + "bAicxJVa5uVY7MjDlapthw", + map[string]tvm.ClientID{ + blackboxAlias: 1000502, + datasyncAlias: 1000503, + }), + EnableServiceTicketChecking: true, + BlackboxEnv: &env, + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + _, _ = tvmauth.NewAPIClient(settings, &nop.Logger{}) +} + +func ExampleNewToolClient_getServiceTicketsWithAliases() { + // should be configured in tvmtool + blackboxAlias := "blackbox" + + settings := tvmauth.TvmToolSettings{ + Alias: "my_service", + Port: 18000, + AuthToken: "kek", + } + + c, err := tvmauth.NewToolClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.GetServiceTicketForAlias(context.Background(), blackboxAlias) + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) + // please extrapolate other methods for this way of construction +} diff --git a/library/go/yandex/tvm/tvmauth/doc.go b/library/go/yandex/tvm/tvmauth/doc.go new file mode 100644 index 0000000000..ece7efd3ba --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/doc.go @@ -0,0 +1,10 @@ +// CGO implementation of tvm-interface based on ticket_parser2. +// +// Package allows you to get service/user TVM-tickets, as well as check them. +// This package provides client via tvm-api or tvmtool. +// Also this package provides the most efficient way for checking tickets regardless of the client construction way. +// All scenerios are provided without any request after construction. +// +// You should create client with NewAPIClient() or NewToolClient(). +// Also you need to check status of client with GetStatus(). +package tvmauth diff --git a/library/go/yandex/tvm/tvmauth/gotest/ya.make b/library/go/yandex/tvm/tvmauth/gotest/ya.make new file mode 100644 index 0000000000..ef0de6851d --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/gotest/ya.make @@ -0,0 +1,3 @@ +GO_TEST_FOR(library/go/yandex/tvm/tvmauth) + +END() diff --git a/library/go/yandex/tvm/tvmauth/logger.go b/library/go/yandex/tvm/tvmauth/logger.go new file mode 100644 index 0000000000..c32d9dd895 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/logger.go @@ -0,0 +1,77 @@ +//go:build cgo +// +build cgo + +package tvmauth + +import "C" +import ( + "fmt" + "sync" + + "github.com/ydb-platform/ydb/library/go/core/log" +) + +// CGO pointer rules state: +// +// Go code may pass a Go pointer to C provided the Go memory to which it points **does not contain any Go pointers**. +// +// Logger is an interface and contains pointer to implementation. That means, we are forbidden from +// passing Logger to C code. +// +// Instead, we put logger into a global map and pass key to the C code. +// +// This might seem inefficient, but we are not concerned with performance here, since the logger is not on the hot path anyway. + +var ( + loggersLock sync.Mutex + nextSlot int + loggers = map[int]log.Logger{} +) + +func registerLogger(l log.Logger) int { + loggersLock.Lock() + defer loggersLock.Unlock() + + i := nextSlot + nextSlot++ + loggers[i] = l + return i +} + +func unregisterLogger(i int) { + loggersLock.Lock() + defer loggersLock.Unlock() + + if _, ok := loggers[i]; !ok { + panic(fmt.Sprintf("attempt to unregister unknown logger %d", i)) + } + + delete(loggers, i) +} + +func findLogger(i int) log.Logger { + loggersLock.Lock() + defer loggersLock.Unlock() + + return loggers[i] +} + +// TVM_WriteToLog is technical artifact +// +//export TVM_WriteToLog +func TVM_WriteToLog(logger int, level int, msgData *C.char, msgSize C.int) { + l := findLogger(logger) + + msg := C.GoStringN(msgData, msgSize) + + switch level { + case 3: + l.Error(msg) + case 4: + l.Warn(msg) + case 6: + l.Info(msg) + default: + l.Debug(msg) + } +} diff --git a/library/go/yandex/tvm/tvmauth/stub.go b/library/go/yandex/tvm/tvmauth/stub.go new file mode 100644 index 0000000000..eb76ea6e69 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/stub.go @@ -0,0 +1,90 @@ +//go:build !cgo +// +build !cgo + +package tvmauth + +// +// Pure 'go' stub to avoid linting CGO constrains violation errors on +// sandbox build stage of dependant projects. +// + +import ( + "context" + "errors" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" +) + +// NewIDsOptions stub for tvmauth.NewIDsOptions. +func NewIDsOptions(secret string, dsts []tvm.ClientID) *TVMAPIOptions { + return nil +} + +// NewAliasesOptions stub for tvmauth.NewAliasesOptions +func NewAliasesOptions(secret string, dsts map[string]tvm.ClientID) *TVMAPIOptions { + return nil +} + +// NewAPIClient implemtation of tvm.Client interface. +// nolint: go-lint +func NewAPIClient(options TvmAPISettings, log log.Logger) (*Client, error) { + return nil, tvm.ErrNotSupported +} + +// NewDynamicApiClient implemtation of tvm.DynamicClient interface. +// +//nolint:st1003 +func NewDynamicApiClient(options TvmAPISettings, log log.Logger) (*DynamicClient, error) { + return nil, tvm.ErrNotSupported +} + +// NewToolClient stub. +func NewToolClient(options TvmToolSettings, log log.Logger) (*Client, error) { + return nil, tvm.ErrNotSupported +} + +// NewUnittestClient stub. +func NewUnittestClient(options TvmUnittestSettings) (*Client, error) { + return nil, tvm.ErrNotSupported +} + +// CheckServiceTicket implementation of tvm.Client interface. +func (c *Client) CheckServiceTicket(ctx context.Context, ticketStr string) (*tvm.CheckedServiceTicket, error) { + return nil, tvm.ErrNotSupported +} + +// CheckUserTicket implemtation of tvm.Client interface. +func (c *Client) CheckUserTicket(ctx context.Context, ticketStr string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + return nil, tvm.ErrNotSupported +} + +// GetServiceTicketForAlias implemtation of tvm.Client interface. +func (c *Client) GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) { + return "", tvm.ErrNotSupported +} + +// GetServiceTicketForID implemtation of tvm.Client interface. +func (c *Client) GetServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (string, error) { + return "", tvm.ErrNotSupported +} + +// GetStatus implemtation of tvm.Client interface. +func (c *Client) GetStatus(ctx context.Context) (tvm.ClientStatusInfo, error) { + return tvm.ClientStatusInfo{}, tvm.ErrNotSupported +} + +func (c *Client) GetRoles(ctx context.Context) (*tvm.Roles, error) { + return nil, errors.New("not implemented") +} + +func (c *Client) GetOptionalServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (*string, error) { + return nil, tvm.ErrNotSupported +} + +func (c *Client) AddDsts(ctx context.Context, dsts []tvm.ClientID) error { + return tvm.ErrNotSupported +} + +func (c *Client) Destroy() { +} diff --git a/library/go/yandex/tvm/tvmauth/tiroletest/client_test.go b/library/go/yandex/tvm/tvmauth/tiroletest/client_test.go new file mode 100644 index 0000000000..585bf40d17 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tiroletest/client_test.go @@ -0,0 +1,343 @@ +package tiroletest + +import ( + "context" + "os" + "strconv" + "testing" + + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/core/log/nop" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmauth" +) + +func getPort(t *testing.T, filename string) int { + body, err := os.ReadFile(filename) + require.NoError(t, err) + + res, err := strconv.Atoi(string(body)) + require.NoError(t, err, "port is invalid: ", filename) + + return res +} + +func createClientWithTirole(t *testing.T, disableSrcCheck bool, disableDefaultUIDCheck bool) *tvmauth.Client { + env := tvm.BlackboxProdYateam + client, err := tvmauth.NewAPIClient( + tvmauth.TvmAPISettings{ + SelfID: 1000502, + ServiceTicketOptions: tvmauth.NewIDsOptions("e5kL0vM3nP-nPf-388Hi6Q", nil), + DiskCacheDir: "./", + FetchRolesForIdmSystemSlug: "some_slug_2", + EnableServiceTicketChecking: true, + DisableSrcCheck: disableSrcCheck, + DisableDefaultUIDCheck: disableDefaultUIDCheck, + BlackboxEnv: &env, + TVMHost: "http://localhost", + TVMPort: getPort(t, "tvmapi.port"), + TiroleHost: "http://localhost", + TirolePort: getPort(t, "tirole.port"), + TiroleTvmID: 1000001, + }, + &nop.Logger{}, + ) + require.NoError(t, err) + + return client +} + +func createClientWithTvmtool(t *testing.T, disableSrcCheck bool, disableDefaultUIDCheck bool) *tvmauth.Client { + token, err := os.ReadFile("tvmtool.authtoken") + require.NoError(t, err) + + client, err := tvmauth.NewToolClient( + tvmauth.TvmToolSettings{ + Alias: "me", + AuthToken: string(token), + DisableSrcCheck: disableSrcCheck, + DisableDefaultUIDCheck: disableDefaultUIDCheck, + Port: getPort(t, "tvmtool.port"), + }, + &nop.Logger{}, + ) + require.NoError(t, err) + + return client +} + +func checkServiceNoRoles(t *testing.T, clientsWithAutoCheck, clientsWithoutAutoCheck []tvm.Client) { + // src=1000000000: tvmknife unittest service -s 1000000000 -d 1000502 + stWithoutRoles := "3:serv:CBAQ__________9_IgoIgJTr3AMQtog9:Sv3SKuDQ4p-2419PKqc1vo9EC128K6Iv7LKck5SyliJZn5gTAqMDAwb9aYWHhf49HTR-Qmsjw4i_Lh-sNhge-JHWi5PTGFJm03CZHOCJG9Y0_G1pcgTfodtAsvDykMxLhiXGB4N84cGhVVqn1pFWz6SPmMeKUPulTt7qH1ifVtQ" + + ctx := context.Background() + + for _, cl := range clientsWithAutoCheck { + _, err := cl.CheckServiceTicket(ctx, stWithoutRoles) + require.EqualValues(t, + &tvm.TicketError{ + Status: tvm.TicketNoRoles, + Msg: "Subject (src or defaultUid) does not have any roles in IDM", + }, + err, + ) + } + + for _, cl := range clientsWithoutAutoCheck { + st, err := cl.CheckServiceTicket(ctx, stWithoutRoles) + require.NoError(t, err) + + roles, err := cl.GetRoles(ctx) + require.NoError(t, err) + + res := roles.GetRolesForService(st) + require.Nil(t, res) + } +} + +func checkServiceHasRoles(t *testing.T, clientsWithAutoCheck, clientsWithoutAutoCheck []tvm.Client) { + // src=1000000001: tvmknife unittest service -s 1000000001 -d 1000502 + stWithRoles := "3:serv:CBAQ__________9_IgoIgZTr3AMQtog9:EyPympmoLBM6jyiQLcK8ummNmL5IUAdTvKM1do8ppuEgY6yHfto3s_WAKmP9Pf9EiNqPBe18HR7yKmVS7gvdFJY4gP4Ut51ejS-iBPlsbsApJOYTgodQPhkmjHVKIT0ub0pT3fWHQtapb8uimKpGcO6jCfopFQSVG04Ehj7a0jw" + + ctx := context.Background() + + check := func(cl tvm.Client) { + checked, err := cl.CheckServiceTicket(ctx, stWithRoles) + require.NoError(t, err) + + clientRoles, err := cl.GetRoles(ctx) + require.NoError(t, err) + + require.EqualValues(t, + `{ + "/role/service/read/": [], + "/role/service/write/": [ + { + "foo": "bar", + "kek": "lol" + } + ] +}`, + clientRoles.GetRolesForService(checked).DebugPrint(), + ) + + require.True(t, clientRoles.CheckServiceRole(checked, "/role/service/read/", nil)) + require.True(t, clientRoles.CheckServiceRole(checked, "/role/service/write/", nil)) + require.False(t, clientRoles.CheckServiceRole(checked, "/role/foo/", nil)) + + require.False(t, clientRoles.CheckServiceRole(checked, "/role/service/read/", &tvm.CheckServiceOptions{ + Entity: tvm.Entity{"foo": "bar", "kek": "lol"}, + })) + require.False(t, clientRoles.CheckServiceRole(checked, "/role/service/write/", &tvm.CheckServiceOptions{ + Entity: tvm.Entity{"kek": "lol"}, + })) + require.True(t, clientRoles.CheckServiceRole(checked, "/role/service/write/", &tvm.CheckServiceOptions{ + Entity: tvm.Entity{"foo": "bar", "kek": "lol"}, + })) + } + + for _, cl := range clientsWithAutoCheck { + check(cl) + } + for _, cl := range clientsWithoutAutoCheck { + check(cl) + } +} + +func checkUserNoRoles(t *testing.T, clientsWithAutoCheck, clientsWithoutAutoCheck []tvm.Client) { + // default_uid=1000000000: tvmknife unittest user -d 1000000000 --env prod_yateam + utWithoutRoles := "3:user:CAwQ__________9_GhYKBgiAlOvcAxCAlOvcAyDShdjMBCgC:LloRDlCZ4vd0IUTOj6MD1mxBPgGhS6EevnnWvHgyXmxc--2CVVkAtNKNZJqCJ6GtDY4nknEnYmWvEu6-MInibD-Uk6saI1DN-2Y3C1Wdsz2SJCq2OYgaqQsrM5PagdyP9PLrftkuV_ZluS_FUYebMXPzjJb0L0ALKByMPkCVWuk" + + ctx := context.Background() + + for _, cl := range clientsWithAutoCheck { + _, err := cl.CheckUserTicket(ctx, utWithoutRoles) + require.EqualValues(t, + &tvm.TicketError{ + Status: tvm.TicketNoRoles, + Msg: "Subject (src or defaultUid) does not have any roles in IDM", + }, + err, + ) + } + + for _, cl := range clientsWithoutAutoCheck { + ut, err := cl.CheckUserTicket(ctx, utWithoutRoles) + require.NoError(t, err) + + roles, err := cl.GetRoles(ctx) + require.NoError(t, err) + + res, err := roles.GetRolesForUser(ut, nil) + require.NoError(t, err) + require.Nil(t, res) + } +} + +func checkUserHasRoles(t *testing.T, clientsWithAutoCheck, clientsWithoutAutoCheck []tvm.Client) { + // default_uid=1120000000000001: tvmknife unittest user -d 1120000000000001 --env prod_yateam + utWithRoles := "3:user:CAwQ__________9_GhwKCQiBgJiRpdT-ARCBgJiRpdT-ASDShdjMBCgC:SQV7Z9hDpZ_F62XGkSF6yr8PoZHezRp0ZxCINf_iAbT2rlEiO6j4UfLjzwn3EnRXkAOJxuAtTDCnHlrzdh3JgSKK7gciwPstdRT5GGTixBoUU9kI_UlxEbfGBX1DfuDsw_GFQ2eCLu4Svq6jC3ynuqQ41D2RKopYL8Bx8PDZKQc" + + ctx := context.Background() + + check := func(cl tvm.Client) { + checked, err := cl.CheckUserTicket(ctx, utWithRoles) + require.NoError(t, err) + + clientRoles, err := cl.GetRoles(ctx) + require.NoError(t, err) + + ut, err := clientRoles.GetRolesForUser(checked, nil) + require.NoError(t, err) + require.EqualValues(t, + `{ + "/role/user/read/": [ + { + "foo": "bar", + "kek": "lol" + } + ], + "/role/user/write/": [] +}`, + ut.DebugPrint(), + ) + + res, err := clientRoles.CheckUserRole(checked, "/role/user/write/", nil) + require.NoError(t, err) + require.True(t, res) + res, err = clientRoles.CheckUserRole(checked, "/role/user/read/", nil) + require.NoError(t, err) + require.True(t, res) + res, err = clientRoles.CheckUserRole(checked, "/role/foo/", nil) + require.NoError(t, err) + require.False(t, res) + + res, err = clientRoles.CheckUserRole(checked, "/role/user/write/", &tvm.CheckUserOptions{ + Entity: tvm.Entity{"foo": "bar", "kek": "lol"}, + }) + require.NoError(t, err) + require.False(t, res) + res, err = clientRoles.CheckUserRole(checked, "/role/user/read/", &tvm.CheckUserOptions{ + Entity: tvm.Entity{"kek": "lol"}, + }) + require.NoError(t, err) + require.False(t, res) + res, err = clientRoles.CheckUserRole(checked, "/role/user/read/", &tvm.CheckUserOptions{ + Entity: tvm.Entity{"foo": "bar", "kek": "lol"}, + }) + require.NoError(t, err) + require.True(t, res) + } + + for _, cl := range clientsWithAutoCheck { + check(cl) + } + for _, cl := range clientsWithoutAutoCheck { + check(cl) + } + +} + +func TestRolesFromTiroleCheckSrc_noRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTirole(t, false, true) + clientWithoutAutoCheck := createClientWithTirole(t, true, true) + + checkServiceNoRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTiroleCheckSrc_HasRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTirole(t, false, true) + clientWithoutAutoCheck := createClientWithTirole(t, true, true) + + checkServiceHasRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTiroleCheckDefaultUid_noRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTirole(t, true, false) + clientWithoutAutoCheck := createClientWithTirole(t, true, true) + + checkUserNoRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTiroleCheckDefaultUid_HasRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTirole(t, true, false) + clientWithoutAutoCheck := createClientWithTirole(t, true, true) + + checkUserHasRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTvmtoolCheckSrc_noRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTvmtool(t, false, true) + clientWithoutAutoCheck := createClientWithTvmtool(t, true, true) + + checkServiceNoRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTvmtoolCheckSrc_HasRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTvmtool(t, false, true) + clientWithoutAutoCheck := createClientWithTvmtool(t, true, true) + + checkServiceHasRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTvmtoolCheckDefaultUid_noRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTvmtool(t, true, false) + clientWithoutAutoCheck := createClientWithTvmtool(t, true, true) + + checkUserNoRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTvmtoolCheckDefaultUid_HasRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTvmtool(t, true, false) + clientWithoutAutoCheck := createClientWithTvmtool(t, true, true) + + checkUserHasRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} diff --git a/library/go/yandex/tvm/tvmauth/tiroletest/roles/mapping.yaml b/library/go/yandex/tvm/tvmauth/tiroletest/roles/mapping.yaml new file mode 100644 index 0000000000..d2fcaead59 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tiroletest/roles/mapping.yaml @@ -0,0 +1,5 @@ +slugs: + some_slug_2: + tvmid: + - 1000502 + - 1000503 diff --git a/library/go/yandex/tvm/tvmauth/tiroletest/roles/some_slug_2.json b/library/go/yandex/tvm/tvmauth/tiroletest/roles/some_slug_2.json new file mode 100644 index 0000000000..84d85fae19 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tiroletest/roles/some_slug_2.json @@ -0,0 +1,22 @@ +{ + "revision": "some_revision_2", + "born_date": 1642160002, + "tvm": { + "1000000001": { + "/role/service/read/": [{}], + "/role/service/write/": [{ + "foo": "bar", + "kek": "lol" + }] + } + }, + "user": { + "1120000000000001": { + "/role/user/write/": [{}], + "/role/user/read/": [{ + "foo": "bar", + "kek": "lol" + }] + } + } +} diff --git a/library/go/yandex/tvm/tvmauth/tiroletest/tvmtool.cfg b/library/go/yandex/tvm/tvmauth/tiroletest/tvmtool.cfg new file mode 100644 index 0000000000..dbb8fcd458 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tiroletest/tvmtool.cfg @@ -0,0 +1,10 @@ +{ + "BbEnvType": 2, + "clients": { + "me": { + "secret": "fake_secret", + "self_tvm_id": 1000502, + "roles_for_idm_slug": "some_slug_2" + } + } +} diff --git a/library/go/yandex/tvm/tvmauth/tiroletest/ya.make b/library/go/yandex/tvm/tvmauth/tiroletest/ya.make new file mode 100644 index 0000000000..05f9b95651 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tiroletest/ya.make @@ -0,0 +1,25 @@ +GO_TEST() + +ENV(GODEBUG="cgocheck=2") + +GO_TEST_SRCS(client_test.go) + +# tirole +INCLUDE(${ARCADIA_ROOT}/library/recipes/tirole/recipe.inc) +USE_RECIPE( + library/recipes/tirole/tirole + --roles-dir library/go/yandex/tvm/tvmauth/tiroletest/roles +) + +# tvmapi - to provide service ticket for tirole +INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmapi/recipe.inc) + +# tvmtool +INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmtool/recipe.inc) +USE_RECIPE( + library/recipes/tvmtool/tvmtool + library/go/yandex/tvm/tvmauth/tiroletest/tvmtool.cfg + --with-roles-dir library/go/yandex/tvm/tvmauth/tiroletest/roles +) + +END() diff --git a/library/go/yandex/tvm/tvmauth/tooltest/.arcignore b/library/go/yandex/tvm/tvmauth/tooltest/.arcignore new file mode 100644 index 0000000000..251ded04a5 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tooltest/.arcignore @@ -0,0 +1 @@ +tooltest diff --git a/library/go/yandex/tvm/tvmauth/tooltest/client_test.go b/library/go/yandex/tvm/tvmauth/tooltest/client_test.go new file mode 100644 index 0000000000..615256a6f6 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tooltest/client_test.go @@ -0,0 +1,74 @@ +package tooltest + +import ( + "context" + "os" + "strconv" + "testing" + + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/core/log/nop" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmauth" +) + +const serviceTicketStr = "3:serv:CBAQ__________9_IgYIKhCWkQY:DnbhBOAMpunP9TuhCvXV8Hg9MEUHSFbRETf710eHVS7plghVsdM-JlLR6XtGeiofX3yiCFMs4Nq7aFJqZwX75HFgGiQymyWWKm2pWTyF0pp8QnaTivIM-Q6xmMqfInUlYrozhkVPmIxT4fqsdrKEACq-Zh8VtuNQYrTLZgsUfWo" + +func recipeToolOptions(t *testing.T) tvmauth.TvmToolSettings { + var portStr, token []byte + portStr, err := os.ReadFile("tvmtool.port") + require.NoError(t, err) + + var port int + port, err = strconv.Atoi(string(portStr)) + require.NoError(t, err) + + token, err = os.ReadFile("tvmtool.authtoken") + require.NoError(t, err) + + return tvmauth.TvmToolSettings{Alias: "me", Port: port, AuthToken: string(token)} +} + +func disableDstCheckOptions(t *testing.T) tvmauth.TvmToolSettings { + s := recipeToolOptions(t) + s.DisableDstCheck = true + return s +} + +func TestToolClient(t *testing.T) { + c, err := tvmauth.NewToolClient(recipeToolOptions(t), &nop.Logger{}) + require.NoError(t, err) + defer c.Destroy() + + t.Run("GetServiceTicketForID", func(t *testing.T) { + _, err := c.GetServiceTicketForID(context.Background(), 100500) + require.NoError(t, err) + }) + + t.Run("GetInvalidTicket", func(t *testing.T) { + _, err := c.GetServiceTicketForID(context.Background(), 100999) + require.Error(t, err) + require.IsType(t, &tvm.Error{}, err) + require.Equal(t, tvm.ErrorBrokenTvmClientSettings, err.(*tvm.Error).Code) + }) + + t.Run("ClientStatus", func(t *testing.T) { + status, err := c.GetStatus(context.Background()) + require.NoError(t, err) + + t.Logf("Got client status: %v", status) + + require.Equal(t, tvm.ClientStatus(0), status.Status) + require.Equal(t, "OK", status.LastError) + }) +} + +func TestDisableDstCheck(t *testing.T) { + c, err := tvmauth.NewToolClient(disableDstCheckOptions(t), &nop.Logger{}) + require.NoError(t, err) + defer c.Destroy() + + ticketS, err := c.CheckServiceTicket(context.Background(), serviceTicketStr) + require.NoError(t, err) + require.Equal(t, 100502, int(ticketS.DstID)) +} diff --git a/library/go/yandex/tvm/tvmauth/tooltest/logger_test.go b/library/go/yandex/tvm/tvmauth/tooltest/logger_test.go new file mode 100644 index 0000000000..f12adc650a --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tooltest/logger_test.go @@ -0,0 +1,33 @@ +package tooltest + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/nop" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmauth" +) + +type testLogger struct { + nop.Logger + + msgs []string +} + +func (l *testLogger) Info(msg string, fields ...log.Field) { + l.msgs = append(l.msgs, msg) +} + +func TestLogger(t *testing.T) { + var l testLogger + + c, err := tvmauth.NewToolClient(recipeToolOptions(t), &l) + require.NoError(t, err) + defer c.Destroy() + + time.Sleep(time.Second) + + require.NotEmpty(t, l.msgs) +} diff --git a/library/go/yandex/tvm/tvmauth/tooltest/ya.make b/library/go/yandex/tvm/tvmauth/tooltest/ya.make new file mode 100644 index 0000000000..65bd6ccec1 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tooltest/ya.make @@ -0,0 +1,12 @@ +GO_TEST() + +ENV(GODEBUG="cgocheck=2") + +INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmtool/recipe_with_default_cfg.inc) + +GO_TEST_SRCS( + client_test.go + logger_test.go +) + +END() diff --git a/library/go/yandex/tvm/tvmauth/tvm.cpp b/library/go/yandex/tvm/tvmauth/tvm.cpp new file mode 100644 index 0000000000..2b17495bec --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tvm.cpp @@ -0,0 +1,542 @@ +#include "tvm.h" + +#include "_cgo_export.h" + +#include <library/cpp/json/json_reader.h> +#include <library/cpp/tvmauth/type.h> +#include <library/cpp/tvmauth/client/facade.h> +#include <library/cpp/tvmauth/client/logger.h> +#include <library/cpp/tvmauth/client/mocked_updater.h> +#include <library/cpp/tvmauth/client/misc/utils.h> +#include <library/cpp/tvmauth/client/misc/api/settings.h> +#include <library/cpp/tvmauth/client/misc/api/dynamic_dst/tvm_client.h> +#include <library/cpp/tvmauth/client/misc/roles/roles.h> + +#include <cstddef> +#include <optional> + +using namespace NTvmAuth; + +void TVM_DestroyMemPool(TVM_MemPool* pool) { + auto freeStr = [](char*& str) { + if (str != nullptr) { + free(str); + str = nullptr; + } + }; + + freeStr(pool->ErrorStr); + + if (pool->Scopes != nullptr) { + free(reinterpret_cast<void*>(pool->Scopes)); + pool->Scopes = nullptr; + } + + if (pool->TicketStr != nullptr) { + delete reinterpret_cast<TString*>(pool->TicketStr); + pool->TicketStr = nullptr; + } + if (pool->RawRolesStr != nullptr) { + delete reinterpret_cast<TString*>(pool->RawRolesStr); + pool->RawRolesStr = nullptr; + } + + if (pool->CheckedUserTicket != nullptr) { + delete reinterpret_cast<TCheckedUserTicket*>(pool->CheckedUserTicket); + pool->CheckedUserTicket = nullptr; + } + + if (pool->CheckedServiceTicket != nullptr) { + delete reinterpret_cast<TCheckedServiceTicket*>(pool->CheckedServiceTicket); + pool->CheckedServiceTicket = nullptr; + } + + if (pool->UidsExtFields != nullptr) { + free(reinterpret_cast<void*>(pool->UidsExtFields)); + pool->UidsExtFields = nullptr; + } + + if (pool->DefaultUidExtFields != nullptr) { + free(reinterpret_cast<void*>(pool->DefaultUidExtFields)); + pool->DefaultUidExtFields = nullptr; + } + + freeStr(pool->DbgInfo); + freeStr(pool->LogInfo); + freeStr(pool->LoginId); + freeStr(pool->LastError.Data); +} + +static void PackStr(TStringBuf in, TVM_String* out, char*& poolStr) noexcept { + out->Data = poolStr = reinterpret_cast<char*>(malloc(in.size())); + out->Size = in.size(); + memcpy(out->Data, in.data(), in.size()); +} + +static void UnpackSettings( + TVM_ApiSettings* in, + NTvmApi::TClientSettings* out) { + if (in->SelfId != 0) { + out->SelfTvmId = in->SelfId; + } + + if (in->EnableServiceTicketChecking != 0) { + out->CheckServiceTickets = true; + } + + if (in->EnableUserTicketChecking != 0) { + out->CheckUserTicketsWithBbEnv = static_cast<EBlackboxEnv>(in->BlackboxEnv); + } + + if (in->SelfSecret != nullptr) { + out->Secret = TString(reinterpret_cast<char*>(in->SelfSecret), in->SelfSecretSize); + } + + TStringBuf aliases(reinterpret_cast<char*>(in->DstAliases), in->DstAliasesSize); + if (aliases) { + NJson::TJsonValue doc; + Y_ENSURE(NJson::ReadJsonTree(aliases, &doc), "Invalid json: from go part: " << aliases); + Y_ENSURE(doc.IsMap(), "Dsts is not map: from go part: " << aliases); + + for (const auto& pair : doc.GetMap()) { + Y_ENSURE(pair.second.IsUInteger(), "dstID must be number"); + out->FetchServiceTicketsForDstsWithAliases.emplace(pair.first, pair.second.GetUInteger()); + } + } + + if (in->IdmSystemSlug != nullptr) { + out->FetchRolesForIdmSystemSlug = TString(reinterpret_cast<char*>(in->IdmSystemSlug), in->IdmSystemSlugSize); + out->ShouldCheckSrc = in->DisableSrcCheck == 0; + out->ShouldCheckDefaultUid = in->DisableDefaultUIDCheck == 0; + } + + if (in->TVMHost != nullptr) { + out->TvmHost = TString(reinterpret_cast<char*>(in->TVMHost), in->TVMHostSize); + out->TvmPort = in->TVMPort; + } + if (in->TiroleHost != nullptr) { + out->TiroleHost = TString(reinterpret_cast<char*>(in->TiroleHost), in->TiroleHostSize); + out->TirolePort = in->TirolePort; + } + if (in->TiroleTvmId != 0) { + out->TiroleTvmId = in->TiroleTvmId; + } + + if (in->DiskCacheDir != nullptr) { + out->DiskCacheDir = TString(reinterpret_cast<char*>(in->DiskCacheDir), in->DiskCacheDirSize); + } + out->ShouldCheckDst = in->DisableDstCheck == 0; +} + +static void UnpackSettings( + TVM_ToolSettings* in, + NTvmTool::TClientSettings* out) { + if (in->Port != 0) { + out->SetPort(in->Port); + } + + if (in->HostnameSize != 0) { + out->SetHostname(TString(reinterpret_cast<char*>(in->Hostname), in->HostnameSize)); + } + + if (in->AuthTokenSize != 0) { + out->SetAuthToken(TString(reinterpret_cast<char*>(in->AuthToken), in->AuthTokenSize)); + } + + out->ShouldCheckSrc = in->DisableSrcCheck == 0; + out->ShouldCheckDefaultUid = in->DisableDefaultUIDCheck == 0; + out->ShouldCheckDst = in->DisableDstCheck == 0; +} + +static void UnpackSettings( + TVM_UnittestSettings* in, + TMockedUpdater::TSettings* out) { + out->SelfTvmId = in->SelfId; + out->UserTicketEnv = static_cast<EBlackboxEnv>(in->BlackboxEnv); +} + +template <class TTicket> +static void PackScopes( + const TScopes& scopes, + TTicket* ticket, + TVM_MemPool* pool) { + if (scopes.empty()) { + return; + } + + pool->Scopes = ticket->Scopes = reinterpret_cast<TVM_String*>(malloc(scopes.size() * sizeof(TVM_String))); + + for (size_t i = 0; i < scopes.size(); i++) { + ticket->Scopes[i].Data = const_cast<char*>(scopes[i].data()); + ticket->Scopes[i].Size = scopes[i].size(); + } + ticket->ScopesSize = scopes.size(); +} + +static void PackUidsExtFields( + const TUidsExtFieldsMap& uidsExtFields, + TVM_UserTicket* ticket, + TVM_MemPool* pool) { + pool->UidsExtFields = ticket->UidsExtFields = reinterpret_cast<TVM_UserExtFields*>(malloc(uidsExtFields.size() * sizeof(TVM_UserExtFields))); + + size_t i = 0; + for (const auto& [uid, userExtFields] : uidsExtFields) { + ticket->UidsExtFields[i].Uid = userExtFields.Uid; + ticket->UidsExtFields[i].CurrentPorgId = userExtFields.CurrentPorgId; + i++; + } + + ticket->UidsExtFieldsSize = uidsExtFields.size(); +} + +static void PackDefaultUidExtFields( + const std::optional<TUserExtFields>& defaultUidExtFields, + TVM_UserTicket* ticket, + TVM_MemPool* pool) { + if (!defaultUidExtFields) { + return; + } + + pool->DefaultUidExtFields = ticket->DefaultUidExtFields = reinterpret_cast<TVM_UserExtFields*>(malloc(sizeof(TVM_UserExtFields))); + ticket->DefaultUidExtFields->Uid = defaultUidExtFields->Uid; + ticket->DefaultUidExtFields->CurrentPorgId = defaultUidExtFields->CurrentPorgId; +} + +static void PackUserTicket( + TCheckedUserTicket in, + TVM_UserTicket* out, + TVM_MemPool* pool, + TStringBuf originalStr) noexcept { + auto copy = new TCheckedUserTicket(std::move(in)); + pool->CheckedUserTicket = reinterpret_cast<void*>(copy); + + PackStr(copy->DebugInfo(), &out->DbgInfo, pool->DbgInfo); + PackStr(NUtils::RemoveTicketSignature(originalStr), &out->LogInfo, pool->LogInfo); + + out->Status = static_cast<int>(copy->GetStatus()); + if (out->Status != static_cast<int>(ETicketStatus::Ok)) { + return; + } + + out->DefaultUid = copy->GetDefaultUid(); + + const auto& uids = copy->GetUids(); + if (!uids.empty()) { + out->Uids = const_cast<TUid*>(uids.data()); + out->UidsSize = uids.size(); + } + + out->Env = static_cast<int>(copy->GetEnv()); + + PackScopes(copy->GetScopes(), out, pool); + + PackStr(copy->GetLoginId(), &out->LoginId, pool->LoginId); + + PackUidsExtFields(copy->GetUidsExtFields(), out, pool); + + PackDefaultUidExtFields(copy->GetDefaultUidExtFields(), out, pool); +} + +static void PackServiceTicket( + TCheckedServiceTicket in, + TVM_ServiceTicket* out, + TVM_MemPool* pool, + TStringBuf originalStr) noexcept { + auto copy = new TCheckedServiceTicket(std::move(in)); + pool->CheckedServiceTicket = reinterpret_cast<void*>(copy); + + PackStr(copy->DebugInfo(), &out->DbgInfo, pool->DbgInfo); + PackStr(NUtils::RemoveTicketSignature(originalStr), &out->LogInfo, pool->LogInfo); + + out->Status = static_cast<int>(copy->GetStatus()); + if (out->Status != static_cast<int>(ETicketStatus::Ok)) { + return; + } + + out->SrcId = copy->GetSrc(); + out->DstId = copy->GetDst(); + + auto issuer = copy->GetIssuerUid(); + if (issuer) { + out->IssuerUid = *issuer; + } +} + +template <class F> +static void CatchError(TVM_Error* err, TVM_MemPool* pool, const F& f) { + try { + f(); + } catch (const TMalformedTvmSecretException& ex) { + err->Code = 1; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TMalformedTvmKeysException& ex) { + err->Code = 2; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TEmptyTvmKeysException& ex) { + err->Code = 3; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TNotAllowedException& ex) { + err->Code = 4; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TBrokenTvmClientSettings& ex) { + err->Code = 5; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TMissingServiceTicket& ex) { + err->Code = 6; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TPermissionDenied& ex) { + err->Code = 7; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TRetriableException& ex) { + err->Code = 8; + err->Retriable = 1; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const std::exception& ex) { + err->Code = 8; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } +} + +namespace { + class TGoLogger: public ILogger { + public: + TGoLogger(int loggerHandle) + : LoggerHandle_(loggerHandle) + { + } + + void Log(int lvl, const TString& msg) override { + TVM_WriteToLog(LoggerHandle_, lvl, const_cast<char*>(msg.data()), msg.size()); + } + + private: + int LoggerHandle_; + }; + +} + +extern "C" void TVM_NewApiClient( + TVM_ApiSettings settings, + int loggerHandle, + void** handle, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + NTvmApi::TClientSettings realSettings; + UnpackSettings(&settings, &realSettings); + + realSettings.LibVersionPrefix = "go_"; + + auto client = new TTvmClient(realSettings, MakeIntrusive<TGoLogger>(loggerHandle)); + *handle = static_cast<void*>(client); + }); +} + +extern "C" void TVM_NewDynamicApiClient( + TVM_ApiSettings settings, + int loggerHandle, + void** handle, + void** dynHandle, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + NTvmApi::TClientSettings realSettings; + UnpackSettings(&settings, &realSettings); + + realSettings.LibVersionPrefix = "go_"; + TServiceContext::TCheckFlags flags; + flags.NeedDstCheck = realSettings.ShouldCheckDst; + + auto dynamicClient = NDynamicClient::TTvmClient::Create(realSettings, MakeIntrusive<TGoLogger>(loggerHandle)).Release(); + auto client = new TTvmClient(dynamicClient, flags); + + *handle = static_cast<void*>(client); + *dynHandle = static_cast<void*>(dynamicClient); + }); +} + +extern "C" void TVM_NewToolClient( + TVM_ToolSettings settings, + int loggerHandle, + void** handle, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + TString alias(reinterpret_cast<char*>(settings.Alias), settings.AliasSize); + NTvmTool::TClientSettings realSettings(alias); + UnpackSettings(&settings, &realSettings); + + auto client = new TTvmClient(realSettings, MakeIntrusive<TGoLogger>(loggerHandle)); + *handle = static_cast<void*>(client); + }); +} + +extern "C" void TVM_NewUnittestClient( + TVM_UnittestSettings settings, + void** handle, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + TMockedUpdater::TSettings realSettings; + UnpackSettings(&settings, &realSettings); + + auto client = new TTvmClient(MakeIntrusiveConst<TMockedUpdater>(realSettings)); + *handle = static_cast<void*>(client); + }); +} + +extern "C" void TVM_DestroyClient(void* handle) { + delete static_cast<TTvmClient*>(handle); +} + +extern "C" void TVM_GetStatus( + void* handle, + TVM_ClientStatus* status, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + + TClientStatus s = client->GetStatus(); + status->Status = static_cast<int>(s.GetCode()); + + PackStr(s.GetLastError(), &status->LastError, pool->LastError.Data); + }); +} + +extern "C" void TVM_CheckUserTicket( + void* handle, + unsigned char* ticketStr, int ticketSize, + int* env, + TVM_UserTicket* ticket, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + TStringBuf str(reinterpret_cast<char*>(ticketStr), ticketSize); + + TMaybe<EBlackboxEnv> optEnv; + if (env) { + optEnv = (EBlackboxEnv)*env; + } + + auto userTicket = client->CheckUserTicket(str, optEnv); + PackUserTicket(std::move(userTicket), ticket, pool, str); + }); +} + +extern "C" void TVM_CheckServiceTicket( + void* handle, + unsigned char* ticketStr, int ticketSize, + TVM_ServiceTicket* ticket, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + TStringBuf str(reinterpret_cast<char*>(ticketStr), ticketSize); + auto serviceTicket = client->CheckServiceTicket(str); + PackServiceTicket(std::move(serviceTicket), ticket, pool, str); + }); +} + +extern "C" void TVM_GetServiceTicket( + void* handle, + ui32 dstId, + char** ticket, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + auto ticketPtr = new TString(client->GetServiceTicketFor(dstId)); + + pool->TicketStr = reinterpret_cast<void*>(ticketPtr); + *ticket = const_cast<char*>(ticketPtr->c_str()); + }); +} + +extern "C" void TVM_GetServiceTicketForAlias( + void* handle, + unsigned char* alias, int aliasSize, + char** ticket, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + auto ticketPtr = new TString(client->GetServiceTicketFor(TString((char*)alias, aliasSize))); + + pool->TicketStr = reinterpret_cast<void*>(ticketPtr); + *ticket = const_cast<char*>(ticketPtr->c_str()); + }); +} + +extern "C" void TVM_GetRoles( + void* handle, + unsigned char* currentRevision, int currentRevisionSize, + char** raw, + int* rawSize, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + NTvmAuth::NRoles::TRolesPtr roles = client->GetRoles(); + + if (currentRevision && + roles->GetMeta().Revision == TStringBuf(reinterpret_cast<char*>(currentRevision), currentRevisionSize)) { + return; + } + + auto rawPtr = new TString(roles->GetRaw()); + + pool->RawRolesStr = reinterpret_cast<void*>(rawPtr); + *raw = const_cast<char*>(rawPtr->c_str()); + *rawSize = rawPtr->size(); + }); +} + +extern "C" void TVM_AddDsts( + void* dynHandle, + ui32* dsts, + int size, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<NDynamicClient::TTvmClient*>(dynHandle); + NDynamicClient::TDsts destinations; + for (int i = 0; i < size; i++) { + destinations.insert(dsts[i]); + } + + client->Add(std::move(destinations)); + }); +} + +extern "C" void TVM_GetOptionalServiceTicketFor( + void* dynHandle, + ui32 dstId, + char** ticket, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<NDynamicClient::TTvmClient*>(dynHandle); + std::optional<TString> optionalTicket = client->GetOptionalServiceTicketFor(dstId); + *ticket = nullptr; + + if (!optionalTicket) { + return; + } + + auto ticketPtr = new TString(*optionalTicket); + pool->TicketStr = reinterpret_cast<void*>(ticketPtr); + *ticket = const_cast<char*>(ticketPtr->c_str()); + }); +} + +static const char* UNKNOWN_STATUS = "unknown status"; + +extern "C" const char* TVM_TicketStatusToString(int status) { + try { + return StatusToString(static_cast<ETicketStatus>(status)).data(); + } catch (const std::exception&) { + return UNKNOWN_STATUS; + } +} diff --git a/library/go/yandex/tvm/tvmauth/tvm.h b/library/go/yandex/tvm/tvmauth/tvm.h new file mode 100644 index 0000000000..189ba02f6e --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tvm.h @@ -0,0 +1,235 @@ +#pragma once + +#include <util/system/types.h> + +#include <stdint.h> +#include <time.h> + +#ifdef __cplusplus +extern "C" { +#endif + + typedef struct _TVM_String { + char* Data; + int Size; + } TVM_String; + + typedef struct { + ui64 Uid; + ui64 CurrentPorgId; + } TVM_UserExtFields; + + // MemPool owns memory allocated by C. + typedef struct { + char* ErrorStr; + void* TicketStr; + void* RawRolesStr; + TVM_String* Scopes; + void* CheckedUserTicket; + void* CheckedServiceTicket; + char* DbgInfo; + char* LogInfo; + char* LoginId; + TVM_UserExtFields* UidsExtFields; + TVM_UserExtFields* DefaultUidExtFields; + TVM_String LastError; + } TVM_MemPool; + + void TVM_DestroyMemPool(TVM_MemPool* pool); + + typedef struct { + int Code; + int Retriable; + + TVM_String Message; + } TVM_Error; + + typedef struct { + int Status; + + ui64 DefaultUid; + + ui64* Uids; + int UidsSize; + + int Env; + + TVM_String* Scopes; + int ScopesSize; + + TVM_String DbgInfo; + TVM_String LogInfo; + + TVM_String LoginId; + + TVM_UserExtFields* UidsExtFields; + int UidsExtFieldsSize; + + TVM_UserExtFields* DefaultUidExtFields; + } TVM_UserTicket; + + typedef struct { + int Status; + + ui32 SrcId; + ui32 DstId; + + ui64 IssuerUid; + + TVM_String DbgInfo; + TVM_String LogInfo; + } TVM_ServiceTicket; + + typedef struct { + ui32 SelfId; + + int EnableServiceTicketChecking; + + int EnableUserTicketChecking; + int BlackboxEnv; + + unsigned char* SelfSecret; + int SelfSecretSize; + unsigned char* DstAliases; + int DstAliasesSize; + + unsigned char* IdmSystemSlug; + int IdmSystemSlugSize; + int DisableSrcCheck; + int DisableDefaultUIDCheck; + + unsigned char* TVMHost; + int TVMHostSize; + int TVMPort; + unsigned char* TiroleHost; + int TiroleHostSize; + int TirolePort; + ui32 TiroleTvmId; + + unsigned char* DiskCacheDir; + int DiskCacheDirSize; + + int DisableDstCheck; + } TVM_ApiSettings; + + typedef struct { + unsigned char* Alias; + int AliasSize; + + int Port; + + unsigned char* Hostname; + int HostnameSize; + + unsigned char* AuthToken; + int AuthTokenSize; + + int DisableSrcCheck; + int DisableDefaultUIDCheck; + int DisableDstCheck; + } TVM_ToolSettings; + + typedef struct { + ui32 SelfId; + int BlackboxEnv; + } TVM_UnittestSettings; + + typedef struct { + int Status; + TVM_String LastError; + } TVM_ClientStatus; + + // First argument must be passed by value. "Go code may pass a Go pointer to C + // provided the Go memory to which it points does not contain any Go pointers." + void TVM_NewApiClient( + TVM_ApiSettings settings, + int loggerHandle, + void** handle, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_NewDynamicApiClient( + TVM_ApiSettings settings, + int loggerHandle, + void** handle, + void** dynHandle, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_NewToolClient( + TVM_ToolSettings settings, + int loggerHandle, + void** handle, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_NewUnittestClient( + TVM_UnittestSettings settings, + void** handle, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_DestroyClient(void* handle); + + void TVM_GetStatus( + void* handle, + TVM_ClientStatus* status, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_CheckUserTicket( + void* handle, + unsigned char* ticketStr, int ticketSize, + int* env, + TVM_UserTicket* ticket, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_CheckServiceTicket( + void* handle, + unsigned char* ticketStr, int ticketSize, + TVM_ServiceTicket* ticket, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_GetServiceTicket( + void* handle, + ui32 dstId, + char** ticket, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_GetServiceTicketForAlias( + void* handle, + unsigned char* alias, int aliasSize, + char** ticket, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_GetRoles( + void* handle, + unsigned char* currentRevision, int currentRevisionSize, + char** raw, + int* rawSize, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_AddDsts( + void* dynHandle, + ui32* dsts, + int size, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_GetOptionalServiceTicketFor( + void* dynHandle, + ui32 dstId, + char** ticket, + TVM_Error* err, + TVM_MemPool* pool); + + const char* TVM_TicketStatusToString(int status); + +#ifdef __cplusplus +} +#endif diff --git a/library/go/yandex/tvm/tvmauth/types.go b/library/go/yandex/tvm/tvmauth/types.go new file mode 100644 index 0000000000..5e3e7878c2 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/types.go @@ -0,0 +1,154 @@ +package tvmauth + +import ( + "sync" + "unsafe" + + "github.com/ydb-platform/ydb/library/go/yandex/tvm" +) + +// TvmAPISettings may be used to fetch data from tvm-api +type TvmAPISettings struct { + // SelfID is required for ServiceTicketOptions and EnableServiceTicketChecking + SelfID tvm.ClientID + + // ServiceTicketOptions provides info for fetching Service Tickets from tvm-api + // to allow you send them to your backends. + // + // WARNING: It is not way to provide authorization for incoming ServiceTickets! + // It is way only to send your ServiceTickets to your backend! + ServiceTicketOptions *TVMAPIOptions + + // EnableServiceTicketChecking enables fetching of public keys for signature checking + EnableServiceTicketChecking bool + + // BlackboxEnv with not nil value enables UserTicket checking + // and enables fetching of public keys for signature checking + BlackboxEnv *tvm.BlackboxEnv + + fetchRolesForIdmSystemSlug []byte + // Non-empty FetchRolesForIdmSystemSlug enables roles fetching from tirole + FetchRolesForIdmSystemSlug string + // By default, client checks src from ServiceTicket or default uid from UserTicket - + // to prevent you from forgetting to check it yourself. + // It does binary checks only: + // ticket gets status NoRoles, if there is no role for src or default uid. + // You need to check roles on your own if you have a non-binary role system or + // you have switched DisableSrcCheck/DisableDefaultUIDCheck + // + // You may need to disable this check in the following cases: + // - You use GetRoles() to provide verbose message (with revision). + // Double check may be inconsistent: + // binary check inside client uses revision of roles X - i.e. src 100500 has no role, + // exact check in your code uses revision of roles Y - i.e. src 100500 has some roles. + DisableSrcCheck bool + // See comment for DisableSrcCheck + DisableDefaultUIDCheck bool + // By default client checks dst from ServiceTicket. If this check is switched off + // incorrect dst does not result in error of checked ticket status + // DANGEROUS: In this case you must check dst manually, you can get it via DstID option in ServiceTicket. + DisableDstCheck bool + + tvmHost []byte + // TVMHost should be used only in tests + TVMHost string + // TVMPort should be used only in tests + TVMPort int + + tiroleHost []byte + // TiroleHost should be used only in tests or for tirole-api-test.yandex.net + TiroleHost string + // TirolePort should be used only in tests + TirolePort int + // TiroleTvmID should be used only in tests or for tirole-api-test.yandex.net + TiroleTvmID tvm.ClientID + + // Directory for disk cache. + // Requires read/write permissions. Permissions will be checked before start. + // WARNING: The same directory can be used only: + // - for TVM clients with the same settings + // OR + // - for new client replacing previous - with another config. + // System user must be the same for processes with these clients inside. + // Implementation doesn't provide other scenarios. + DiskCacheDir string + diskCacheDir []byte +} + +// TVMAPIOptions is part of TvmAPISettings: allows to enable fetching of ServiceTickets +type TVMAPIOptions struct { + selfSecret string + selfSecretB []byte + dstAliases []byte +} + +// TvmToolSettings may be used to fetch data from tvmtool +type TvmToolSettings struct { + // Alias is required: self alias of your tvm ClientID + Alias string + alias []byte + + // By default, client checks src from ServiceTicket or default uid from UserTicket - + // to prevent you from forgetting to check it yourself. + // It does binary checks only: + // ticket gets status NoRoles, if there is no role for src or default uid. + // You need to check roles on your own if you have a non-binary role system or + // you have switched DisableSrcCheck/DisableDefaultUIDCheck + // + // You may need to disable this check in the following cases: + // - You use GetRoles() to provide verbose message (with revision). + // Double check may be inconsistent: + // binary check inside client uses revision of roles X - i.e. src 100500 has no role, + // exact check in your code uses revision of roles Y - i.e. src 100500 has some roles. + DisableSrcCheck bool + // See comment for DisableSrcCheck + DisableDefaultUIDCheck bool + // By default client checks dst from ServiceTicket. If this check is switched off + // incorrect dst does not result in error of checked ticket status + // DANGEROUS: In this case you must check dst manually, you can get it via DstID option in ServiceTicket. + DisableDstCheck bool + + // Port will be detected with env["DEPLOY_TVM_TOOL_URL"] (provided with Yandex.Deploy), + // otherwise port == 1 (it is ok for Qloud) + Port int + // Hostname == "localhost" by default + Hostname string + hostname []byte + + // AuthToken is protection from SSRF. + // By default it is fetched from env: + // * TVMTOOL_LOCAL_AUTHTOKEN (provided with Yandex.Deploy) + // * QLOUD_TVM_TOKEN (provided with Qloud) + AuthToken string + authToken []byte +} + +type TvmUnittestSettings struct { + // SelfID is required for service ticket checking + SelfID tvm.ClientID + + // Service ticket checking is enabled by default + + // User ticket checking is enabled by default: choose required environment + BlackboxEnv tvm.BlackboxEnv + + // Other features are not supported yet +} + +// Client contains raw pointer for C++ object +type Client struct { + handle unsafe.Pointer + + logger *int + + roles *tvm.Roles + mutex *sync.RWMutex +} + +type DynamicClient struct { + *Client + dynHandle unsafe.Pointer +} + +var _ tvm.Client = (*Client)(nil) +var _ tvm.DynamicClient = (*DynamicClient)(nil) diff --git a/library/go/yandex/tvm/tvmauth/ya.make b/library/go/yandex/tvm/tvmauth/ya.make new file mode 100644 index 0000000000..1bcc512d0d --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/ya.make @@ -0,0 +1,43 @@ +GO_LIBRARY() + +IF (CGO_ENABLED) + USE_CXX() + + PEERDIR( + library/cpp/tvmauth/client + library/cpp/tvmauth/client/misc/api/dynamic_dst + ) + + + SRCS( + CGO_EXPORT + tvm.cpp + ) + + CGO_SRCS( + client.go + logger.go + ) +ELSE() + SRCS( + stub.go + ) +ENDIF() + +SRCS( + doc.go + types.go +) + +GO_XTEST_SRCS(client_example_test.go) + +END() + +IF (CGO_ENABLED) + RECURSE_FOR_TESTS( + apitest + gotest + tiroletest + tooltest + ) +ENDIF() diff --git a/library/go/yandex/tvm/tvmtool/any.go b/library/go/yandex/tvm/tvmtool/any.go new file mode 100644 index 0000000000..4a690f4169 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/any.go @@ -0,0 +1,37 @@ +package tvmtool + +import ( + "os" + + "github.com/ydb-platform/ydb/library/go/core/xerrors" +) + +const ( + LocalEndpointEnvKey = "TVMTOOL_URL" + LocalTokenEnvKey = "TVMTOOL_LOCAL_AUTHTOKEN" +) + +var ErrUnknownTvmtoolEnvironment = xerrors.NewSentinel("unknown tvmtool environment") + +// NewAnyClient method creates a new tvmtool client with environment auto-detection. +// You must reuse it to prevent connection/goroutines leakage. +func NewAnyClient(opts ...Option) (*Client, error) { + switch { + case os.Getenv(QloudEndpointEnvKey) != "": + // it's Qloud + return NewQloudClient(opts...) + case os.Getenv(DeployEndpointEnvKey) != "": + // it's Y.Deploy + return NewDeployClient(opts...) + case os.Getenv(LocalEndpointEnvKey) != "": + passedOpts := append( + []Option{ + WithAuthToken(os.Getenv(LocalTokenEnvKey)), + }, + opts..., + ) + return NewClient(os.Getenv(LocalEndpointEnvKey), passedOpts...) + default: + return nil, ErrUnknownTvmtoolEnvironment.WithFrame() + } +} diff --git a/library/go/yandex/tvm/tvmtool/any_example_test.go b/library/go/yandex/tvm/tvmtool/any_example_test.go new file mode 100644 index 0000000000..77972b4377 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/any_example_test.go @@ -0,0 +1,70 @@ +package tvmtool_test + +import ( + "context" + "fmt" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/zap" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewAnyClient_simple() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewAnyClient(tvmtool.WithLogger(zlog)) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.TODO(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} + +func ExampleNewAnyClient_custom() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewAnyClient( + tvmtool.WithSrc("second_app"), + tvmtool.WithLogger(zlog), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/clients_test.go b/library/go/yandex/tvm/tvmtool/clients_test.go new file mode 100644 index 0000000000..2f01b832d4 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/clients_test.go @@ -0,0 +1,153 @@ +//go:build linux || darwin +// +build linux darwin + +package tvmtool_test + +import ( + "os" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmtool" +) + +func TestNewClients(t *testing.T) { + type TestCase struct { + env map[string]string + willFail bool + expectedErr string + expectedBaseURI string + expectedAuthToken string + } + + cases := map[string]struct { + constructor func(opts ...tvmtool.Option) (*tvmtool.Client, error) + cases map[string]TestCase + }{ + "qloud": { + constructor: tvmtool.NewQloudClient, + cases: map[string]TestCase{ + "no-auth": { + willFail: true, + expectedErr: "empty auth token (looked at ENV[QLOUD_TVM_TOKEN])", + }, + "ok-default-origin": { + env: map[string]string{ + "QLOUD_TVM_TOKEN": "ok-default-origin-token", + }, + willFail: false, + expectedBaseURI: "http://localhost:1/tvm", + expectedAuthToken: "ok-default-origin-token", + }, + "ok-custom-origin": { + env: map[string]string{ + "QLOUD_TVM_INTERFACE_ORIGIN": "http://localhost:9000", + "QLOUD_TVM_TOKEN": "ok-custom-origin-token", + }, + willFail: false, + expectedBaseURI: "http://localhost:9000/tvm", + expectedAuthToken: "ok-custom-origin-token", + }, + }, + }, + "deploy": { + constructor: tvmtool.NewDeployClient, + cases: map[string]TestCase{ + "no-url": { + willFail: true, + expectedErr: "empty tvmtool url (looked at ENV[DEPLOY_TVM_TOOL_URL])", + }, + "no-auth": { + env: map[string]string{ + "DEPLOY_TVM_TOOL_URL": "http://localhost:2", + }, + willFail: true, + expectedErr: "empty auth token (looked at ENV[TVMTOOL_LOCAL_AUTHTOKEN])", + }, + "ok": { + env: map[string]string{ + "DEPLOY_TVM_TOOL_URL": "http://localhost:1337", + "TVMTOOL_LOCAL_AUTHTOKEN": "ok-token", + }, + willFail: false, + expectedBaseURI: "http://localhost:1337/tvm", + expectedAuthToken: "ok-token", + }, + }, + }, + "any": { + constructor: tvmtool.NewAnyClient, + cases: map[string]TestCase{ + "empty": { + willFail: true, + expectedErr: "unknown tvmtool environment", + }, + "ok-qloud": { + env: map[string]string{ + "QLOUD_TVM_INTERFACE_ORIGIN": "http://qloud:9000", + "QLOUD_TVM_TOKEN": "ok-qloud", + }, + expectedBaseURI: "http://qloud:9000/tvm", + expectedAuthToken: "ok-qloud", + }, + "ok-deploy": { + env: map[string]string{ + "DEPLOY_TVM_TOOL_URL": "http://deploy:1337", + "TVMTOOL_LOCAL_AUTHTOKEN": "ok-deploy", + }, + expectedBaseURI: "http://deploy:1337/tvm", + expectedAuthToken: "ok-deploy", + }, + "ok-local": { + env: map[string]string{ + "TVMTOOL_URL": "http://local:1338", + "TVMTOOL_LOCAL_AUTHTOKEN": "ok-local", + }, + willFail: false, + expectedBaseURI: "http://local:1338/tvm", + expectedAuthToken: "ok-local", + }, + }, + }, + } + + // NB! this checks are not thread safe, never use t.Parallel() and so on + for clientName, client := range cases { + t.Run(clientName, func(t *testing.T) { + for name, tc := range client.cases { + t.Run(name, func(t *testing.T) { + savedEnv := os.Environ() + defer func() { + os.Clearenv() + for _, env := range savedEnv { + parts := strings.SplitN(env, "=", 2) + err := os.Setenv(parts[0], parts[1]) + require.NoError(t, err) + } + }() + + os.Clearenv() + for key, val := range tc.env { + _ = os.Setenv(key, val) + } + + tvmClient, err := client.constructor() + if tc.willFail { + require.Error(t, err) + if tc.expectedErr != "" { + require.EqualError(t, err, tc.expectedErr) + } + + require.Nil(t, tvmClient) + } else { + require.NoError(t, err) + require.NotNil(t, tvmClient) + require.Equal(t, tc.expectedBaseURI, tvmClient.BaseURI()) + require.Equal(t, tc.expectedAuthToken, tvmClient.AuthToken()) + } + }) + } + }) + } +} diff --git a/library/go/yandex/tvm/tvmtool/deploy.go b/library/go/yandex/tvm/tvmtool/deploy.go new file mode 100644 index 0000000000..d7a2eac62b --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/deploy.go @@ -0,0 +1,31 @@ +package tvmtool + +import ( + "fmt" + "os" +) + +const ( + DeployEndpointEnvKey = "DEPLOY_TVM_TOOL_URL" + DeployTokenEnvKey = "TVMTOOL_LOCAL_AUTHTOKEN" +) + +// NewDeployClient method creates a new tvmtool client for Deploy environment. +// You must reuse it to prevent connection/goroutines leakage. +func NewDeployClient(opts ...Option) (*Client, error) { + baseURI := os.Getenv(DeployEndpointEnvKey) + if baseURI == "" { + return nil, fmt.Errorf("empty tvmtool url (looked at ENV[%s])", DeployEndpointEnvKey) + } + + authToken := os.Getenv(DeployTokenEnvKey) + if authToken == "" { + return nil, fmt.Errorf("empty auth token (looked at ENV[%s])", DeployTokenEnvKey) + } + + opts = append([]Option{WithAuthToken(authToken)}, opts...) + return NewClient( + baseURI, + opts..., + ) +} diff --git a/library/go/yandex/tvm/tvmtool/deploy_example_test.go b/library/go/yandex/tvm/tvmtool/deploy_example_test.go new file mode 100644 index 0000000000..674a59083b --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/deploy_example_test.go @@ -0,0 +1,70 @@ +package tvmtool_test + +import ( + "context" + "fmt" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/zap" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewDeployClient_simple() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewDeployClient(tvmtool.WithLogger(zlog)) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.TODO(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} + +func ExampleNewDeployClient_custom() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewDeployClient( + tvmtool.WithSrc("second_app"), + tvmtool.WithLogger(zlog), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/doc.go b/library/go/yandex/tvm/tvmtool/doc.go new file mode 100644 index 0000000000..d46dca8132 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/doc.go @@ -0,0 +1,7 @@ +// Pure Go implementation of tvm-interface based on TVMTool client. +// +// https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/. +// Package allows you to get service/user TVM-tickets, as well as check them. +// This package can provide fast getting of service tickets (from cache), other cases lead to http request to localhost. +// Also this package provides TVM client for Qloud (NewQloudClient) and Yandex.Deploy (NewDeployClient) environments. +package tvmtool diff --git a/library/go/yandex/tvm/tvmtool/errors.go b/library/go/yandex/tvm/tvmtool/errors.go new file mode 100644 index 0000000000..85ccde73e0 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/errors.go @@ -0,0 +1,61 @@ +package tvmtool + +import ( + "fmt" + + "github.com/ydb-platform/ydb/library/go/yandex/tvm" +) + +// Generic TVM errors, before retry any request it check .Retriable field. +type Error = tvm.Error + +const ( + // ErrorAuthFail - auth failed, probably you provides invalid auth token + ErrorAuthFail = tvm.ErrorAuthFail + // ErrorBadRequest - tvmtool rejected our request, check .Msg for details + ErrorBadRequest = tvm.ErrorBadRequest + // ErrorOther - any other TVM-related errors, check .Msg for details + ErrorOther = tvm.ErrorOther +) + +// Ticket validation error +type TicketError = tvm.TicketError + +const ( + TicketErrorInvalidScopes = tvm.TicketInvalidScopes + TicketErrorOther = tvm.TicketStatusOther +) + +type PingCode uint32 + +const ( + PingCodeDie = iota + PingCodeWarning + PingCodeError + PingCodeOther +) + +func (e PingCode) String() string { + switch e { + case PingCodeDie: + return "HttpDie" + case PingCodeWarning: + return "Warning" + case PingCodeError: + return "Error" + case PingCodeOther: + return "Other" + default: + return fmt.Sprintf("Unknown%d", e) + } +} + +// Special ping error +type PingError struct { + Code PingCode + Err error +} + +func (e *PingError) Error() string { + return fmt.Sprintf("tvm: %s (code %s)", e.Err.Error(), e.Code) +} diff --git a/library/go/yandex/tvm/tvmtool/examples/check_tickets/main.go b/library/go/yandex/tvm/tvmtool/examples/check_tickets/main.go new file mode 100644 index 0000000000..772d9790bb --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/examples/check_tickets/main.go @@ -0,0 +1,68 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/zap" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmtool" +) + +var ( + baseURI = "http://localhost:3000" + srvTicket string + userTicket string +) + +func main() { + flag.StringVar(&baseURI, "tool-uri", baseURI, "TVM tool uri") + flag.StringVar(&srvTicket, "srv", "", "service ticket to check") + flag.StringVar(&userTicket, "usr", "", "user ticket to check") + flag.Parse() + + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + auth := os.Getenv("TVMTOOL_LOCAL_AUTHTOKEN") + if auth == "" { + zlog.Fatal("Please provide tvm-tool auth in env[TVMTOOL_LOCAL_AUTHTOKEN]") + return + } + + tvmClient, err := tvmtool.NewClient( + baseURI, + tvmtool.WithAuthToken(auth), + tvmtool.WithLogger(zlog), + ) + if err != nil { + zlog.Fatal("failed create tvm client", log.Error(err)) + return + } + defer tvmClient.Close() + + fmt.Printf("------ Check service ticket ------\n\n") + srvCheck, err := tvmClient.CheckServiceTicket(context.Background(), srvTicket) + if err != nil { + fmt.Printf("Failed\nTicket: %s\nError: %s\n", srvCheck, err) + } else { + fmt.Printf("OK\nInfo: %s\n", srvCheck) + } + + if userTicket == "" { + return + } + + fmt.Printf("\n------ Check user ticket result ------\n\n") + + usrCheck, err := tvmClient.CheckUserTicket(context.Background(), userTicket) + if err != nil { + fmt.Printf("Failed\nTicket: %s\nError: %s\n", usrCheck, err) + return + } + fmt.Printf("OK\nInfo: %s\n", usrCheck) +} diff --git a/library/go/yandex/tvm/tvmtool/examples/check_tickets/ya.make b/library/go/yandex/tvm/tvmtool/examples/check_tickets/ya.make new file mode 100644 index 0000000000..6a0765382d --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/examples/check_tickets/ya.make @@ -0,0 +1,5 @@ +GO_PROGRAM() + +SRCS(main.go) + +END() diff --git a/library/go/yandex/tvm/tvmtool/examples/get_service_ticket/main.go b/library/go/yandex/tvm/tvmtool/examples/get_service_ticket/main.go new file mode 100644 index 0000000000..90e12f35c4 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/examples/get_service_ticket/main.go @@ -0,0 +1,53 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/zap" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmtool" +) + +var ( + baseURI = "http://localhost:3000" + dst = "dst" +) + +func main() { + flag.StringVar(&baseURI, "tool-uri", baseURI, "TVM tool uri") + flag.StringVar(&dst, "dst", dst, "Destination TVM app (must be configured in tvm-tool)") + flag.Parse() + + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + auth := os.Getenv("TVMTOOL_LOCAL_AUTHTOKEN") + if auth == "" { + zlog.Fatal("Please provide tvm-tool auth in env[TVMTOOL_LOCAL_AUTHTOKEN]") + return + } + + tvmClient, err := tvmtool.NewClient( + baseURI, + tvmtool.WithAuthToken(auth), + tvmtool.WithLogger(zlog), + ) + if err != nil { + zlog.Fatal("failed create tvm client", log.Error(err)) + return + } + defer tvmClient.Close() + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), dst) + if err != nil { + zlog.Fatal("failed to get tvm ticket", log.String("dst", dst), log.Error(err)) + return + } + + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/examples/get_service_ticket/ya.make b/library/go/yandex/tvm/tvmtool/examples/get_service_ticket/ya.make new file mode 100644 index 0000000000..6a0765382d --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/examples/get_service_ticket/ya.make @@ -0,0 +1,5 @@ +GO_PROGRAM() + +SRCS(main.go) + +END() diff --git a/library/go/yandex/tvm/tvmtool/examples/ya.make b/library/go/yandex/tvm/tvmtool/examples/ya.make new file mode 100644 index 0000000000..a9518720af --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/examples/ya.make @@ -0,0 +1,4 @@ +RECURSE( + check_tickets + get_service_ticket +) diff --git a/library/go/yandex/tvm/tvmtool/gotest/tvmtool.conf.json b/library/go/yandex/tvm/tvmtool/gotest/tvmtool.conf.json new file mode 100644 index 0000000000..7ff1aa7979 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/gotest/tvmtool.conf.json @@ -0,0 +1,32 @@ +{ + "BbEnvType": 3, + "clients": { + "main": { + "secret": "fake_secret", + "self_tvm_id": 42, + "dsts": { + "he": { + "dst_id": 100500 + }, + "he_clone": { + "dst_id": 100500 + }, + "slave": { + "dst_id": 43 + }, + "self": { + "dst_id": 42 + } + } + }, + "slave": { + "secret": "fake_secret", + "self_tvm_id": 43, + "dsts": { + "he": { + "dst_id": 100500 + } + } + } + } +} diff --git a/library/go/yandex/tvm/tvmtool/gotest/ya.make b/library/go/yandex/tvm/tvmtool/gotest/ya.make new file mode 100644 index 0000000000..c1de9253ff --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/gotest/ya.make @@ -0,0 +1,33 @@ +GO_TEST_FOR(library/go/yandex/tvm/tvmtool) + +SIZE(MEDIUM) + +DEFAULT( + USE_TVM_TOOL + 0 +) + +# tvmtool recipe exists only for linux & darwin + +IF (OS_LINUX) + SET( + USE_TVM_TOOL + 1 + ) +ELSEIF(OS_DARWIN) + SET( + USE_TVM_TOOL + 1 + ) +ENDIF() + +IF (USE_TVM_TOOL) + INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmtool/recipe.inc) + + USE_RECIPE( + library/recipes/tvmtool/tvmtool + library/go/yandex/tvm/tvmtool/gotest/tvmtool.conf.json + ) +ENDIF() + +END() diff --git a/library/go/yandex/tvm/tvmtool/internal/cache/cache.go b/library/go/yandex/tvm/tvmtool/internal/cache/cache.go new file mode 100644 index 0000000000..9ec9665682 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/internal/cache/cache.go @@ -0,0 +1,128 @@ +package cache + +import ( + "sync" + "time" + + "github.com/ydb-platform/ydb/library/go/yandex/tvm" +) + +const ( + Hit Status = iota + Miss + GonnaMissy +) + +type ( + Status int + + Cache struct { + ttl time.Duration + maxTTL time.Duration + tickets map[tvm.ClientID]entry + aliases map[string]tvm.ClientID + lock sync.RWMutex + } + + entry struct { + value *string + born time.Time + } +) + +func New(ttl, maxTTL time.Duration) *Cache { + return &Cache{ + ttl: ttl, + maxTTL: maxTTL, + tickets: make(map[tvm.ClientID]entry, 1), + aliases: make(map[string]tvm.ClientID, 1), + } +} + +func (c *Cache) Gc() { + now := time.Now() + + c.lock.Lock() + defer c.lock.Unlock() + for clientID, ticket := range c.tickets { + if ticket.born.Add(c.maxTTL).After(now) { + continue + } + + delete(c.tickets, clientID) + for alias, aClientID := range c.aliases { + if clientID == aClientID { + delete(c.aliases, alias) + } + } + } +} + +func (c *Cache) ClientIDs() []tvm.ClientID { + c.lock.RLock() + defer c.lock.RUnlock() + + clientIDs := make([]tvm.ClientID, 0, len(c.tickets)) + for clientID := range c.tickets { + clientIDs = append(clientIDs, clientID) + } + return clientIDs +} + +func (c *Cache) Aliases() []string { + c.lock.RLock() + defer c.lock.RUnlock() + + aliases := make([]string, 0, len(c.aliases)) + for alias := range c.aliases { + aliases = append(aliases, alias) + } + return aliases +} + +func (c *Cache) Load(clientID tvm.ClientID) (*string, Status) { + c.lock.RLock() + e, ok := c.tickets[clientID] + c.lock.RUnlock() + if !ok { + return nil, Miss + } + + now := time.Now() + exp := e.born.Add(c.ttl) + if exp.After(now) { + return e.value, Hit + } + + exp = e.born.Add(c.maxTTL) + if exp.After(now) { + return e.value, GonnaMissy + } + + c.lock.Lock() + delete(c.tickets, clientID) + c.lock.Unlock() + return nil, Miss +} + +func (c *Cache) LoadByAlias(alias string) (*string, Status) { + c.lock.RLock() + clientID, ok := c.aliases[alias] + c.lock.RUnlock() + if !ok { + return nil, Miss + } + + return c.Load(clientID) +} + +func (c *Cache) Store(clientID tvm.ClientID, alias string, value *string) { + c.lock.Lock() + defer c.lock.Unlock() + + c.aliases[alias] = clientID + c.tickets[clientID] = entry{ + value: value, + born: time.Now(), + } +} diff --git a/library/go/yandex/tvm/tvmtool/internal/cache/cache_test.go b/library/go/yandex/tvm/tvmtool/internal/cache/cache_test.go new file mode 100644 index 0000000000..1d493dc3a3 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/internal/cache/cache_test.go @@ -0,0 +1,124 @@ +package cache_test + +import ( + "sort" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmtool/internal/cache" +) + +var ( + testDst = "test_dst" + testDstAlias = "test_dst_alias" + testDstID = tvm.ClientID(1) + testValue = "test_val" +) + +func TestNewAtHour(t *testing.T) { + c := cache.New(time.Hour, 11*time.Hour) + assert.NotNil(t, c, "failed to create cache") +} + +func TestCache_Load(t *testing.T) { + + c := cache.New(time.Second, time.Hour) + c.Store(testDstID, testDst, &testValue) + // checking before + { + r, hit := c.Load(testDstID) + assert.Equal(t, cache.Hit, hit, "failed to get '%d' from cache before deadline", testDstID) + assert.NotNil(t, r, "failed to get '%d' from cache before deadline", testDstID) + assert.Equal(t, testValue, *r) + + r, hit = c.LoadByAlias(testDst) + assert.Equal(t, cache.Hit, hit, "failed to get '%s' from cache before deadline", testDst) + assert.NotNil(t, r, "failed to get %q from tickets before deadline", testDst) + assert.Equal(t, testValue, *r) + } + { + r, hit := c.Load(999833321) + assert.Equal(t, cache.Miss, hit, "got tickets for '999833321', but that key must be never existed") + assert.Nil(t, r, "got tickets for '999833321', but that key must be never existed") + + r, hit = c.LoadByAlias("kek") + assert.Equal(t, cache.Miss, hit, "got tickets for 'kek', but that key must be never existed") + assert.Nil(t, r, "got tickets for 'kek', but that key must be never existed") + } + + time.Sleep(3 * time.Second) + // checking after + { + r, hit := c.Load(testDstID) + assert.Equal(t, cache.GonnaMissy, hit) + assert.Equal(t, testValue, *r) + + r, hit = c.LoadByAlias(testDst) + assert.Equal(t, cache.GonnaMissy, hit) + assert.Equal(t, testValue, *r) + } +} + +func TestCache_Keys(t *testing.T) { + c := cache.New(time.Second, time.Hour) + c.Store(testDstID, testDst, &testValue) + c.Store(testDstID, testDstAlias, &testValue) + + t.Run("aliases", func(t *testing.T) { + aliases := c.Aliases() + sort.Strings(aliases) + require.Equal(t, 2, len(aliases), "not correct length of aliases") + require.EqualValues(t, []string{testDst, testDstAlias}, aliases) + }) + + t.Run("client_ids", func(t *testing.T) { + ids := c.ClientIDs() + require.Equal(t, 1, len(ids), "not correct length of client ids") + require.EqualValues(t, []tvm.ClientID{testDstID}, ids) + }) +} + +func TestCache_ExpiredKeys(t *testing.T) { + c := cache.New(time.Second, 10*time.Second) + c.Store(testDstID, testDst, &testValue) + c.Store(testDstID, testDstAlias, &testValue) + + time.Sleep(3 * time.Second) + c.Gc() + + var ( + newDst = "new_dst" + newDstID = tvm.ClientID(2) + ) + c.Store(newDstID, newDst, &testValue) + + t.Run("aliases", func(t *testing.T) { + aliases := c.Aliases() + require.Equal(t, 3, len(aliases), "not correct length of aliases") + require.ElementsMatch(t, []string{testDst, testDstAlias, newDst}, aliases) + }) + + t.Run("client_ids", func(t *testing.T) { + ids := c.ClientIDs() + require.Equal(t, 2, len(ids), "not correct length of client ids") + require.ElementsMatch(t, []tvm.ClientID{testDstID, newDstID}, ids) + }) + + time.Sleep(8 * time.Second) + c.Gc() + + t.Run("aliases", func(t *testing.T) { + aliases := c.Aliases() + require.Equal(t, 1, len(aliases), "not correct length of aliases") + require.ElementsMatch(t, []string{newDst}, aliases) + }) + + t.Run("client_ids", func(t *testing.T) { + ids := c.ClientIDs() + require.Equal(t, 1, len(ids), "not correct length of client ids") + require.ElementsMatch(t, []tvm.ClientID{newDstID}, ids) + }) +} diff --git a/library/go/yandex/tvm/tvmtool/internal/cache/gotest/ya.make b/library/go/yandex/tvm/tvmtool/internal/cache/gotest/ya.make new file mode 100644 index 0000000000..8fe6d0ae33 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/internal/cache/gotest/ya.make @@ -0,0 +1,3 @@ +GO_TEST_FOR(library/go/yandex/tvm/tvmtool/internal/cache) + +END() diff --git a/library/go/yandex/tvm/tvmtool/internal/cache/ya.make b/library/go/yandex/tvm/tvmtool/internal/cache/ya.make new file mode 100644 index 0000000000..8f870e1943 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/internal/cache/ya.make @@ -0,0 +1,9 @@ +GO_LIBRARY() + +SRCS(cache.go) + +GO_XTEST_SRCS(cache_test.go) + +END() + +RECURSE_FOR_TESTS(gotest) diff --git a/library/go/yandex/tvm/tvmtool/internal/ya.make b/library/go/yandex/tvm/tvmtool/internal/ya.make new file mode 100644 index 0000000000..9ef654573f --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/internal/ya.make @@ -0,0 +1 @@ +RECURSE(cache) diff --git a/library/go/yandex/tvm/tvmtool/opts.go b/library/go/yandex/tvm/tvmtool/opts.go new file mode 100644 index 0000000000..2004a56f53 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/opts.go @@ -0,0 +1,103 @@ +package tvmtool + +import ( + "context" + "net/http" + "strings" + "time" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/xerrors" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmtool/internal/cache" +) + +type ( + Option func(tool *Client) error +) + +// Source TVM client (id or alias) +// +// WARNING: id/alias must be configured in tvmtool. Documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#konfig +func WithSrc(src string) Option { + return func(tool *Client) error { + tool.src = src + return nil + } +} + +// Auth token +func WithAuthToken(token string) Option { + return func(tool *Client) error { + tool.authToken = token + return nil + } +} + +// Use custom HTTP client +func WithHTTPClient(client *http.Client) Option { + return func(tool *Client) error { + tool.ownHTTPClient = false + tool.httpClient = client + return nil + } +} + +// Enable or disable service tickets cache +// +// Enabled by default +func WithCacheEnabled(enabled bool) Option { + return func(tool *Client) error { + switch { + case enabled && tool.cache == nil: + tool.cache = cache.New(cacheTTL, cacheMaxTTL) + case !enabled: + tool.cache = nil + } + return nil + } +} + +// Overrides blackbox environment defined in config. +// +// Documentation about environment overriding: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/checkusr +func WithOverrideEnv(bbEnv tvm.BlackboxEnv) Option { + return func(tool *Client) error { + tool.bbEnv = strings.ToLower(bbEnv.String()) + return nil + } +} + +// WithLogger sets logger for tvm client. +func WithLogger(l log.Structured) Option { + return func(tool *Client) error { + tool.l = l + return nil + } +} + +// WithRefreshFrequency sets service tickets refresh frequency. +// Frequency must be lower chan cacheTTL (10 min) +// +// Default: 8 min +func WithRefreshFrequency(freq time.Duration) Option { + return func(tool *Client) error { + if freq > cacheTTL { + return xerrors.Errorf("refresh frequency must be lower than cacheTTL (%d > %d)", freq, cacheTTL) + } + + tool.refreshFreq = int64(freq.Seconds()) + return nil + } +} + +// WithBackgroundUpdate force Client to update all service ticket at background. +// You must manually cancel given ctx to stops refreshing. +// +// Default: disabled +func WithBackgroundUpdate(ctx context.Context) Option { + return func(tool *Client) error { + tool.bgCtx, tool.bgCancel = context.WithCancel(ctx) + return nil + } +} diff --git a/library/go/yandex/tvm/tvmtool/qloud.go b/library/go/yandex/tvm/tvmtool/qloud.go new file mode 100644 index 0000000000..4dcf0648db --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/qloud.go @@ -0,0 +1,32 @@ +package tvmtool + +import ( + "fmt" + "os" +) + +const ( + QloudEndpointEnvKey = "QLOUD_TVM_INTERFACE_ORIGIN" + QloudTokenEnvKey = "QLOUD_TVM_TOKEN" + QloudDefaultEndpoint = "http://localhost:1" +) + +// NewQloudClient method creates a new tvmtool client for Qloud environment. +// You must reuse it to prevent connection/goroutines leakage. +func NewQloudClient(opts ...Option) (*Client, error) { + baseURI := os.Getenv(QloudEndpointEnvKey) + if baseURI == "" { + baseURI = QloudDefaultEndpoint + } + + authToken := os.Getenv(QloudTokenEnvKey) + if authToken == "" { + return nil, fmt.Errorf("empty auth token (looked at ENV[%s])", QloudTokenEnvKey) + } + + opts = append([]Option{WithAuthToken(authToken)}, opts...) + return NewClient( + baseURI, + opts..., + ) +} diff --git a/library/go/yandex/tvm/tvmtool/qloud_example_test.go b/library/go/yandex/tvm/tvmtool/qloud_example_test.go new file mode 100644 index 0000000000..38ec35026d --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/qloud_example_test.go @@ -0,0 +1,71 @@ +package tvmtool_test + +import ( + "context" + "fmt" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/zap" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewQloudClient_simple() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewQloudClient(tvmtool.WithLogger(zlog)) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} + +func ExampleNewQloudClient_custom() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewQloudClient( + tvmtool.WithSrc("second_app"), + tvmtool.WithOverrideEnv(tvm.BlackboxProd), + tvmtool.WithLogger(zlog), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/tool.go b/library/go/yandex/tvm/tvmtool/tool.go new file mode 100644 index 0000000000..2e8dee98e1 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool.go @@ -0,0 +1,586 @@ +package tvmtool + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/nop" + "github.com/ydb-platform/ydb/library/go/core/xerrors" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmtool/internal/cache" +) + +const ( + dialTimeout = 100 * time.Millisecond + requestTimeout = 500 * time.Millisecond + keepAlive = 60 * time.Second + cacheTTL = 10 * time.Minute + cacheMaxTTL = 11 * time.Hour +) + +var _ tvm.Client = (*Client)(nil) + +type ( + Client struct { + lastSync int64 + apiURI string + baseURI string + src string + authToken string + bbEnv string + refreshFreq int64 + bgCtx context.Context + bgCancel context.CancelFunc + inFlightRefresh uint32 + cache *cache.Cache + pingRequest *http.Request + ownHTTPClient bool + httpClient *http.Client + l log.Structured + cachedRoles *atomic.Pointer[tvm.Roles] + } + + ticketsResponse map[string]struct { + Error string `json:"error"` + Ticket string `json:"ticket"` + TvmID tvm.ClientID `json:"tvm_id"` + } + + checkSrvResponse struct { + SrcID tvm.ClientID `json:"src"` + DstID tvm.ClientID `json:"dst"` + Error string `json:"error"` + DbgInfo string `json:"debug_string"` + LogInfo string `json:"logging_string"` + } + + checkUserResponse struct { + DefaultUID tvm.UID `json:"default_uid"` + UIDs []tvm.UID `json:"uids"` + Scopes []string `json:"scopes"` + Error string `json:"error"` + DbgInfo string `json:"debug_string"` + LogInfo string `json:"logging_string"` + } +) + +// NewClient method creates a new tvmtool client. +// You must reuse it to prevent connection/goroutines leakage. +func NewClient(apiURI string, opts ...Option) (*Client, error) { + baseURI := strings.TrimRight(apiURI, "/") + "/tvm" + pingRequest, err := http.NewRequest("GET", baseURI+"/ping", nil) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to configure client: %w", err) + } + + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.DialContext = (&net.Dialer{ + Timeout: dialTimeout, + KeepAlive: keepAlive, + }).DialContext + + tool := &Client{ + apiURI: apiURI, + baseURI: baseURI, + refreshFreq: 8 * 60, + cache: cache.New(cacheTTL, cacheMaxTTL), + pingRequest: pingRequest, + l: &nop.Logger{}, + ownHTTPClient: true, + httpClient: &http.Client{ + Transport: transport, + Timeout: requestTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + }, + cachedRoles: &atomic.Pointer[tvm.Roles]{}, + } + + for _, opt := range opts { + if err := opt(tool); err != nil { + return nil, xerrors.Errorf("tvmtool: failed to configure client: %w", err) + } + } + + if tool.bgCtx != nil { + go tool.serviceTicketsRefreshLoop() + } + + return tool, nil +} + +// GetServiceTicketForAlias returns TVM service ticket for alias +// +// WARNING: alias must be configured in tvmtool +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/tickets +func (c *Client) GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) { + var ( + cachedTicket *string + cacheStatus = cache.Miss + ) + + if c.cache != nil { + c.refreshServiceTickets() + + if cachedTicket, cacheStatus = c.cache.LoadByAlias(alias); cacheStatus == cache.Hit { + return *cachedTicket, nil + } + } + + tickets, err := c.getServiceTickets(ctx, alias) + if err != nil { + if cachedTicket != nil && cacheStatus == cache.GonnaMissy { + return *cachedTicket, nil + } + return "", err + } + + entry, ok := tickets[alias] + if !ok { + return "", xerrors.Errorf("tvmtool: alias %q was not found in response", alias) + } + + if entry.Error != "" { + return "", &Error{Code: ErrorOther, Msg: entry.Error} + } + + ticket := entry.Ticket + if c.cache != nil { + c.cache.Store(entry.TvmID, alias, &ticket) + } + return ticket, nil +} + +// GetServiceTicketForID returns TVM service ticket for destination application id +// +// WARNING: id must be configured in tvmtool +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/tickets +func (c *Client) GetServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (string, error) { + var ( + cachedTicket *string + cacheStatus = cache.Miss + ) + + if c.cache != nil { + c.refreshServiceTickets() + + if cachedTicket, cacheStatus = c.cache.Load(dstID); cacheStatus == cache.Hit { + return *cachedTicket, nil + } + } + + alias := strconv.FormatUint(uint64(dstID), 10) + tickets, err := c.getServiceTickets(ctx, alias) + if err != nil { + if cachedTicket != nil && cacheStatus == cache.GonnaMissy { + return *cachedTicket, nil + } + return "", err + } + + entry, ok := tickets[alias] + if !ok { + // ok, let's find him + for candidateAlias, candidate := range tickets { + if candidate.TvmID == dstID { + entry = candidate + alias = candidateAlias + ok = true + break + } + } + + if !ok { + return "", xerrors.Errorf("tvmtool: dst %q was not found in response", alias) + } + } + + if entry.Error != "" { + return "", &Error{Code: ErrorOther, Msg: entry.Error} + } + + ticket := entry.Ticket + if c.cache != nil { + c.cache.Store(dstID, alias, &ticket) + } + return ticket, nil +} + +// Close stops background ticket updates (if configured) and closes idle connections. +func (c *Client) Close() { + if c.bgCancel != nil { + c.bgCancel() + } + + if c.ownHTTPClient { + c.httpClient.CloseIdleConnections() + } +} + +func (c *Client) refreshServiceTickets() { + if c.bgCtx != nil { + // service tickets will be updated at background in the separated goroutine + return + } + + now := time.Now().Unix() + if now-atomic.LoadInt64(&c.lastSync) > c.refreshFreq { + atomic.StoreInt64(&c.lastSync, now) + if atomic.CompareAndSwapUint32(&c.inFlightRefresh, 0, 1) { + go c.doServiceTicketsRefresh(context.Background()) + } + } +} + +func (c *Client) serviceTicketsRefreshLoop() { + var ticker = time.NewTicker(time.Duration(c.refreshFreq) * time.Second) + defer ticker.Stop() + for { + select { + case <-c.bgCtx.Done(): + return + case <-ticker.C: + c.doServiceTicketsRefresh(c.bgCtx) + } + } +} + +func (c *Client) doServiceTicketsRefresh(ctx context.Context) { + defer atomic.CompareAndSwapUint32(&c.inFlightRefresh, 1, 0) + + c.cache.Gc() + aliases := c.cache.Aliases() + if len(aliases) == 0 { + return + } + + c.l.Debug("tvmtool: service ticket update started") + defer c.l.Debug("tvmtool: service ticket update finished") + + // fast path: batch update, must work most of time + err := c.refreshServiceTicket(ctx, aliases...) + if err == nil { + return + } + + if tvmErr, ok := err.(*Error); ok && tvmErr.Code != ErrorBadRequest { + c.l.Error( + "tvmtool: failed to refresh all service tickets at background", + log.Strings("dsts", aliases), + log.Error(err), + ) + + // if we have non "bad request" error - something really terrible happens, nothing to do with it :( + // TODO(buglloc): implement adaptive refreshFreq based on errors? + return + } + + // slow path: trying to update service tickets one by one + c.l.Error( + "tvmtool: failed to refresh all service tickets at background, switched to slow path", + log.Strings("dsts", aliases), + log.Error(err), + ) + + for _, dst := range aliases { + if err := c.refreshServiceTicket(ctx, dst); err != nil { + c.l.Error( + "tvmtool: failed to refresh service ticket at background", + log.String("dst", dst), + log.Error(err), + ) + } + } +} + +func (c *Client) refreshServiceTicket(ctx context.Context, dsts ...string) error { + tickets, err := c.getServiceTickets(ctx, strings.Join(dsts, ",")) + if err != nil { + return err + } + + for _, dst := range dsts { + entry, ok := tickets[dst] + if !ok { + c.l.Error( + "tvmtool: destination was not found in tvmtool response", + log.String("dst", dst), + ) + continue + } + + if entry.Error != "" { + c.l.Error( + "tvmtool: failed to get service ticket for destination", + log.String("dst", dst), + log.String("err", entry.Error), + ) + continue + } + + c.cache.Store(entry.TvmID, dst, &entry.Ticket) + } + return nil +} + +func (c *Client) getServiceTickets(ctx context.Context, dst string) (ticketsResponse, error) { + params := url.Values{ + "dsts": {dst}, + } + if c.src != "" { + params.Set("src", c.src) + } + + req, err := http.NewRequest("GET", c.baseURI+"/tickets?"+params.Encode(), nil) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + req.Header.Set("Authorization", c.authToken) + + req = req.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + + var result ticketsResponse + err = readResponse(resp, &result) + return result, err +} + +// Check TVM service ticket +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/checksrv +func (c *Client) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + req, err := http.NewRequest("GET", c.baseURI+"/checksrv", nil) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + + if c.src != "" { + req.URL.RawQuery += "dst=" + url.QueryEscape(c.src) + } + req.Header.Set("Authorization", c.authToken) + req.Header.Set("X-Ya-Service-Ticket", ticket) + + req = req.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + + var result checkSrvResponse + if err = readResponse(resp, &result); err != nil { + return nil, err + } + + ticketInfo := &tvm.CheckedServiceTicket{ + SrcID: result.SrcID, + DstID: result.DstID, + DbgInfo: result.DbgInfo, + LogInfo: result.LogInfo, + } + + if resp.StatusCode == http.StatusForbidden { + err = &TicketError{Status: TicketErrorOther, Msg: result.Error} + } + + return ticketInfo, err +} + +// Check TVM user ticket +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/checkusr +func (c *Client) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + for range opts { + panic("implement me") + } + + req, err := http.NewRequest("GET", c.baseURI+"/checkusr", nil) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + if c.bbEnv != "" { + req.URL.RawQuery += "override_env=" + url.QueryEscape(c.bbEnv) + } + req.Header.Set("Authorization", c.authToken) + req.Header.Set("X-Ya-User-Ticket", ticket) + + req = req.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + + var result checkUserResponse + if err = readResponse(resp, &result); err != nil { + return nil, err + } + + ticketInfo := &tvm.CheckedUserTicket{ + DefaultUID: result.DefaultUID, + UIDs: result.UIDs, + Scopes: result.Scopes, + DbgInfo: result.DbgInfo, + LogInfo: result.LogInfo, + } + + if resp.StatusCode == http.StatusForbidden { + err = &TicketError{Status: TicketErrorOther, Msg: result.Error} + } + + return ticketInfo, err +} + +// Checks TVMTool liveness +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/ping +func (c *Client) GetStatus(ctx context.Context) (tvm.ClientStatusInfo, error) { + req := c.pingRequest.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return tvm.ClientStatusInfo{Status: tvm.ClientError}, + &PingError{Code: PingCodeDie, Err: err} + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return tvm.ClientStatusInfo{Status: tvm.ClientError}, + &PingError{Code: PingCodeDie, Err: err} + } + + var status tvm.ClientStatusInfo + switch resp.StatusCode { + case http.StatusOK: + // OK! + status = tvm.ClientStatusInfo{Status: tvm.ClientOK} + err = nil + case http.StatusPartialContent: + status = tvm.ClientStatusInfo{Status: tvm.ClientWarning} + err = &PingError{Code: PingCodeWarning, Err: xerrors.New(string(body))} + case http.StatusInternalServerError: + status = tvm.ClientStatusInfo{Status: tvm.ClientError} + err = &PingError{Code: PingCodeError, Err: xerrors.New(string(body))} + default: + status = tvm.ClientStatusInfo{Status: tvm.ClientError} + err = &PingError{Code: PingCodeOther, Err: xerrors.Errorf("tvmtool: unexpected status: %d", resp.StatusCode)} + } + return status, err +} + +// Returns TVMTool version +func (c *Client) Version(ctx context.Context) (string, error) { + req := c.pingRequest.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return "", xerrors.Errorf("tvmtool: failed to call tmvtool: %w", err) + } + _, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + + return resp.Header.Get("Server"), nil +} + +func (c *Client) GetRoles(ctx context.Context) (*tvm.Roles, error) { + var cachedRevision string + cachedRolesValue := c.cachedRoles.Load() + if cachedRolesValue != nil { + cachedRevision = cachedRolesValue.GetMeta().Revision + } + + params := url.Values{ + "self": []string{c.src}, + } + req, err := http.NewRequest("GET", c.apiURI+"/v2/roles?"+params.Encode(), nil) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to make request to roles: %w", err) + } + req.Header.Set("Authorization", c.authToken) + if cachedRevision != "" { + req.Header.Set("If-None-Match", "\""+cachedRevision+"\"") + } + req = req.WithContext(ctx) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + defer func() { + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + }() + + if resp.StatusCode == http.StatusNotModified { + if cachedRolesValue == nil { + return nil, xerrors.Errorf("tvmtool: logic error got 304 on empty cached roles data") + } + return cachedRolesValue, nil + } + + if resp.StatusCode != http.StatusOK { + return nil, xerrors.Errorf("tvmtool: getroles: [%d] %s", resp.StatusCode, resp.Status) + } + + b, err := io.ReadAll(resp.Body) + if err != nil { + return nil, xerrors.Errorf("tvmtool: getroles: [%d] %s: %w", resp.StatusCode, resp.Status, err) + } + + roles, err := tvm.NewRoles(b) + if err != nil { + return nil, xerrors.Errorf("tvmtool: unable to parse roles: %w", err) + } + + c.cachedRoles.Store(roles) + + return roles, nil +} + +func readResponse(resp *http.Response, dst interface{}) error { + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + return xerrors.Errorf("tvmtool: failed to read response: %w", err) + } + + switch resp.StatusCode { + case http.StatusOK, http.StatusForbidden: + // ok + return json.Unmarshal(body, dst) + case http.StatusUnauthorized: + return &Error{ + Code: ErrorAuthFail, + Msg: string(body), + } + case http.StatusBadRequest: + return &Error{ + Code: ErrorBadRequest, + Msg: string(body), + } + case http.StatusInternalServerError: + return &Error{ + Code: ErrorOther, + Msg: string(body), + Retriable: true, + } + default: + return &Error{ + Code: ErrorOther, + Msg: fmt.Sprintf("tvmtool: unexpected status: %d, msg: %s", resp.StatusCode, string(body)), + } + } +} diff --git a/library/go/yandex/tvm/tvmtool/tool_bg_update_test.go b/library/go/yandex/tvm/tvmtool/tool_bg_update_test.go new file mode 100644 index 0000000000..43cf0b69e1 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool_bg_update_test.go @@ -0,0 +1,353 @@ +package tvmtool_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/zap" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmtool" + "go.uber.org/atomic" +) + +func newMockClient(upstream string, options ...tvmtool.Option) (*tvmtool.Client, error) { + zlog, _ := zap.New(zap.ConsoleConfig(log.DebugLevel)) + options = append(options, tvmtool.WithLogger(zlog), tvmtool.WithAuthToken("token")) + return tvmtool.NewClient(upstream, options...) +} + +// TestClientBackgroundUpdate_Updatable checks that TVMTool client updates tickets state +func TestClientBackgroundUpdate_Updatable(t *testing.T) { + type TestCase struct { + client func(ctx context.Context, t *testing.T, url string) *tvmtool.Client + } + cases := map[string]TestCase{ + "async": { + client: func(ctx context.Context, t *testing.T, url string) *tvmtool.Client { + tvmClient, err := newMockClient(url, tvmtool.WithRefreshFrequency(500*time.Millisecond)) + require.NoError(t, err) + return tvmClient + }, + }, + "background": { + client: func(ctx context.Context, t *testing.T, url string) *tvmtool.Client { + tvmClient, err := newMockClient( + url, + tvmtool.WithRefreshFrequency(1*time.Second), + tvmtool.WithBackgroundUpdate(ctx), + ) + require.NoError(t, err) + return tvmClient + }, + }, + } + + tester := func(name string, tc TestCase) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + var ( + testDstAlias = "test" + testDstID = tvm.ClientID(2002456) + testTicket = atomic.NewString("3:serv:original-test-ticket:signature") + testFooDstAlias = "test_foo" + testFooDstID = tvm.ClientID(2002457) + testFooTicket = atomic.NewString("3:serv:original-test-foo-ticket:signature") + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/tvm/tickets", r.URL.Path) + assert.Equal(t, "token", r.Header.Get("Authorization")) + switch r.URL.RawQuery { + case "dsts=test", "dsts=test_foo", "dsts=test%2Ctest_foo", "dsts=test_foo%2Ctest": + // ok + case "dsts=2002456", "dsts=2002457", "dsts=2002456%2C2002457", "dsts=2002457%2C2002456": + // ok + default: + t.Errorf("unknown tvm-request query: %q", r.URL.RawQuery) + } + + w.Header().Set("Content-Type", "application/json") + rsp := map[string]struct { + Ticket string `json:"ticket"` + TVMID tvm.ClientID `json:"tvm_id"` + }{ + testDstAlias: { + Ticket: testTicket.Load(), + TVMID: testDstID, + }, + testFooDstAlias: { + Ticket: testFooTicket.Load(), + TVMID: testFooDstID, + }, + } + + err := json.NewEncoder(w).Encode(rsp) + assert.NoError(t, err) + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tvmClient := tc.client(ctx, t, srv.URL) + + requestTickets := func(mustEquals bool) { + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), testDstAlias) + require.NoError(t, err) + if mustEquals { + require.Equal(t, testTicket.Load(), ticket) + } + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), testDstID) + require.NoError(t, err) + if mustEquals { + require.Equal(t, testTicket.Load(), ticket) + } + + ticket, err = tvmClient.GetServiceTicketForAlias(context.Background(), testFooDstAlias) + require.NoError(t, err) + if mustEquals { + require.Equal(t, testFooTicket.Load(), ticket) + } + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), testFooDstID) + require.NoError(t, err) + if mustEquals { + require.Equal(t, testFooTicket.Load(), ticket) + } + } + + // populate tickets cache + requestTickets(true) + + // now change tickets + newTicket := "3:serv:changed-test-ticket:signature" + testTicket.Store(newTicket) + testFooTicket.Store("3:serv:changed-test-foo-ticket:signature") + + // wait some time + time.Sleep(2 * time.Second) + + // request new tickets + requestTickets(false) + + // and wait updates some time + for idx := 0; idx < 250; idx++ { + time.Sleep(100 * time.Millisecond) + ticket, _ := tvmClient.GetServiceTicketForAlias(context.Background(), testDstAlias) + if ticket == newTicket { + break + } + } + + // now out tvmclient MUST returns new tickets + requestTickets(true) + }) + } + + for name, tc := range cases { + tester(name, tc) + } +} + +// TestClientBackgroundUpdate_NotTooOften checks that TVMTool client request tvmtool not too often +func TestClientBackgroundUpdate_NotTooOften(t *testing.T) { + type TestCase struct { + client func(ctx context.Context, t *testing.T, url string) *tvmtool.Client + } + cases := map[string]TestCase{ + "async": { + client: func(ctx context.Context, t *testing.T, url string) *tvmtool.Client { + tvmClient, err := newMockClient(url, tvmtool.WithRefreshFrequency(20*time.Second)) + require.NoError(t, err) + return tvmClient + }, + }, + "background": { + client: func(ctx context.Context, t *testing.T, url string) *tvmtool.Client { + tvmClient, err := newMockClient( + url, + tvmtool.WithRefreshFrequency(20*time.Second), + tvmtool.WithBackgroundUpdate(ctx), + ) + require.NoError(t, err) + return tvmClient + }, + }, + } + + tester := func(name string, tc TestCase) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + var ( + reqCount = atomic.NewUint32(0) + testDstAlias = "test" + testDstID = tvm.ClientID(2002456) + testTicket = "3:serv:original-test-ticket:signature" + testFooDstAlias = "test_foo" + testFooDstID = tvm.ClientID(2002457) + testFooTicket = "3:serv:original-test-foo-ticket:signature" + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqCount.Add(1) + assert.Equal(t, "/tvm/tickets", r.URL.Path) + assert.Equal(t, "token", r.Header.Get("Authorization")) + switch r.URL.RawQuery { + case "dsts=test", "dsts=test_foo", "dsts=test%2Ctest_foo", "dsts=test_foo%2Ctest": + // ok + case "dsts=2002456", "dsts=2002457", "dsts=2002456%2C2002457", "dsts=2002457%2C2002456": + // ok + default: + t.Errorf("unknown tvm-request query: %q", r.URL.RawQuery) + } + + w.Header().Set("Content-Type", "application/json") + rsp := map[string]struct { + Ticket string `json:"ticket"` + TVMID tvm.ClientID `json:"tvm_id"` + }{ + testDstAlias: { + Ticket: testTicket, + TVMID: testDstID, + }, + testFooDstAlias: { + Ticket: testFooTicket, + TVMID: testFooDstID, + }, + } + + err := json.NewEncoder(w).Encode(rsp) + assert.NoError(t, err) + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tvmClient := tc.client(ctx, t, srv.URL) + + requestTickets := func() { + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), testDstAlias) + require.NoError(t, err) + require.Equal(t, testTicket, ticket) + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), testDstID) + require.NoError(t, err) + require.Equal(t, testTicket, ticket) + + ticket, err = tvmClient.GetServiceTicketForAlias(context.Background(), testFooDstAlias) + require.NoError(t, err) + require.Equal(t, testFooTicket, ticket) + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), testFooDstID) + require.NoError(t, err) + require.Equal(t, testFooTicket, ticket) + } + + // populate cache + requestTickets() + + // requests tickets some time that lower than refresh frequency + for i := 0; i < 10; i++ { + requestTickets() + time.Sleep(200 * time.Millisecond) + } + + require.Equal(t, uint32(2), reqCount.Load(), "tvmtool client calls tvmtool too many times") + }) + } + + for name, tc := range cases { + tester(name, tc) + } +} + +func TestClient_RefreshFrequency(t *testing.T) { + cases := map[string]struct { + freq time.Duration + err bool + }{ + "too_high": { + freq: 20 * time.Minute, + err: true, + }, + "ok": { + freq: 2 * time.Minute, + err: false, + }, + } + + for name, cs := range cases { + t.Run(name, func(t *testing.T) { + _, err := tvmtool.NewClient("fake", tvmtool.WithRefreshFrequency(cs.freq)) + if cs.err { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestClient_MultipleAliases(t *testing.T) { + reqCount := atomic.NewUint32(0) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqCount.Add(1) + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ +"test": {"ticket": "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature","tvm_id": 2002456}, +"test_alias": {"ticket": "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature","tvm_id": 2002456} +}`)) + })) + defer srv.Close() + + bgCtx, bgCancel := context.WithCancel(context.Background()) + defer bgCancel() + + tvmClient, err := newMockClient( + srv.URL, + tvmtool.WithRefreshFrequency(2*time.Second), + tvmtool.WithBackgroundUpdate(bgCtx), + ) + require.NoError(t, err) + + requestTickets := func(t *testing.T) { + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "test") + require.NoError(t, err) + require.Equal(t, "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature", ticket) + + ticket, err = tvmClient.GetServiceTicketForAlias(context.Background(), "test_alias") + require.NoError(t, err) + require.Equal(t, "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature", ticket) + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), tvm.ClientID(2002456)) + require.NoError(t, err) + require.Equal(t, "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature", ticket) + } + + t.Run("first", requestTickets) + + t.Run("check_requests", func(t *testing.T) { + // reqCount must be 2 - one for each aliases + require.Equal(t, uint32(2), reqCount.Load()) + }) + + // now wait GC + reqCount.Store(0) + time.Sleep(3 * time.Second) + + t.Run("after_gc", requestTickets) + t.Run("check_requests", func(t *testing.T) { + // reqCount must be 1 + require.Equal(t, uint32(1), reqCount.Load()) + }) +} diff --git a/library/go/yandex/tvm/tvmtool/tool_example_test.go b/library/go/yandex/tvm/tvmtool/tool_example_test.go new file mode 100644 index 0000000000..574f66b3b8 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool_example_test.go @@ -0,0 +1,81 @@ +package tvmtool_test + +import ( + "context" + "fmt" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/zap" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewClient() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewClient( + "http://localhost:9000", + tvmtool.WithAuthToken("auth-token"), + tvmtool.WithSrc("my-cool-app"), + tvmtool.WithLogger(zlog), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} + +func ExampleNewClient_backgroundServiceTicketsUpdate() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + bgCtx, bgCancel := context.WithCancel(context.Background()) + defer bgCancel() + + tvmClient, err := tvmtool.NewClient( + "http://localhost:9000", + tvmtool.WithAuthToken("auth-token"), + tvmtool.WithSrc("my-cool-app"), + tvmtool.WithLogger(zlog), + tvmtool.WithBackgroundUpdate(bgCtx), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/tool_export_test.go b/library/go/yandex/tvm/tvmtool/tool_export_test.go new file mode 100644 index 0000000000..7981a2db72 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool_export_test.go @@ -0,0 +1,9 @@ +package tvmtool + +func (c *Client) BaseURI() string { + return c.baseURI +} + +func (c *Client) AuthToken() string { + return c.authToken +} diff --git a/library/go/yandex/tvm/tvmtool/tool_test.go b/library/go/yandex/tvm/tvmtool/tool_test.go new file mode 100644 index 0000000000..a043741cad --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool_test.go @@ -0,0 +1,295 @@ +//go:build linux || darwin +// +build linux darwin + +// tvmtool recipe exists only for linux & darwin so we skip another OSes +package tvmtool_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "os" + "regexp" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/zap" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "github.com/ydb-platform/ydb/library/go/yandex/tvm/tvmtool" +) + +const ( + tvmToolPortFile = "tvmtool.port" + tvmToolAuthTokenFile = "tvmtool.authtoken" + userTicketFor1120000000038691 = "3:user" + + ":CA4Q__________9_GjUKCQijrpqRpdT-ARCjrpqRpdT-ARoMYmI6c2Vzc2lvbmlkGgl0ZXN0OnRlc3Qg0oXY" + + "zAQoAw:A-YI2yhoD7BbGU80_dKQ6vm7XADdvgD2QUFCeTI3XZ4MS4N8iENvsNDvYwsW89-vLQPv9pYqn8jxx" + + "awkvu_ZS2aAfpU8vXtnEHvzUQfes2kMjweRJE71cyX8B0VjENdXC5QAfGyK7Y0b4elTDJzw8b28Ro7IFFbNe" + + "qgcPInXndY" + serviceTicketFor41_42 = "3:serv:CBAQ__________9_IgQIKRAq" + + ":VVXL3wkhpBHB7OXSeG0IhqM5AP2CP-gJRD31ksAb-q7pmssBJKtPNbH34BSyLpBllmM1dgOfwL8ICUOGUA3l" + + "jOrwuxZ9H8ayfdrpM7q1-BVPE0sh0L9cd8lwZIW6yHejTe59s6wk1tG5MdSfncdaJpYiF3MwNHSRklNAkb6hx" + + "vg" + serviceTicketFor41_99 = "3:serv:CBAQ__________9_IgQIKRBj" + + ":PjJKDOsEk8VyxZFZwsVnKrW1bRyA82nGd0oIxnEFEf7DBTVZmNuxEejncDrMxnjkKwimrumV9POK4ptTo0ZPY" + + "6Du9zHR5QxekZYwDzFkECVrv9YT2QI03odwZJX8_WCpmlgI8hUog_9yZ5YCYxrQpWaOwDXx4T7VVMwH_Z9YTZk" +) + +var ( + srvTicketRe = regexp.MustCompile(`^3:serv:[A-Za-z0-9_\-]+:[A-Za-z0-9_\-]+$`) +) + +func newTvmToolClient(src string, authToken ...string) (*tvmtool.Client, error) { + raw, err := os.ReadFile(tvmToolPortFile) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(string(raw)) + if err != nil { + return nil, err + } + + return newTvmToolClientAtURL(src, fmt.Sprintf("http://localhost:%d", port), authToken...) +} + +func newTvmToolClientAtURL(src string, apiURL string, authToken ...string) (*tvmtool.Client, error) { + var auth string + if len(authToken) > 0 { + auth = authToken[0] + } else { + raw, err := os.ReadFile(tvmToolAuthTokenFile) + if err != nil { + return nil, err + } + auth = string(raw) + } + + zlog, _ := zap.New(zap.ConsoleConfig(log.DebugLevel)) + + return tvmtool.NewClient( + apiURL, + tvmtool.WithAuthToken(auth), + tvmtool.WithCacheEnabled(false), + tvmtool.WithSrc(src), + tvmtool.WithLogger(zlog), + ) +} + +func TestNewClient(t *testing.T) { + client, err := newTvmToolClient("main") + require.NoError(t, err) + require.NotNil(t, client) +} + +func TestClient_GetStatus(t *testing.T) { + client, err := newTvmToolClient("main") + require.NoError(t, err) + status, err := client.GetStatus(context.Background()) + require.NoError(t, err, "ping must work") + require.Equal(t, tvm.ClientOK, status.Status) +} + +func TestClient_BadAuth(t *testing.T) { + badClient, err := newTvmToolClient("main", "fake-auth") + require.NoError(t, err) + + _, err = badClient.GetServiceTicketForAlias(context.Background(), "lala") + require.Error(t, err) + require.IsType(t, err, &tvmtool.Error{}) + srvTickerErr := err.(*tvmtool.Error) + require.Equal(t, tvmtool.ErrorAuthFail, srvTickerErr.Code) +} + +func TestClient_GetServiceTicket(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + ctx := context.Background() + + t.Run("invalid_alias", func(t *testing.T) { + // Ticket for invalid alias must fails + t.Parallel() + _, err := tvmClient.GetServiceTicketForAlias(ctx, "not_exists") + require.Error(t, err, "ticket for invalid alias must fails") + assert.IsType(t, err, &tvmtool.Error{}, "must return tvm err") + assert.EqualError(t, err, "tvm: can't find in config destination tvmid for src = 42, dstparam = not_exists (strconv) (code ErrorBadRequest)") + }) + + t.Run("invalid_dst_id", func(t *testing.T) { + // Ticket for invalid client id must fails + t.Parallel() + _, err := tvmClient.GetServiceTicketForID(ctx, 123123123) + require.Error(t, err, "ticket for invalid ID must fails") + assert.IsType(t, err, &tvmtool.Error{}, "must return tvm err") + assert.EqualError(t, err, "tvm: can't find in config destination tvmid for src = 42, dstparam = 123123123 (by number) (code ErrorBadRequest)") + }) + + t.Run("by_alias", func(t *testing.T) { + // Try to get ticket by alias + t.Parallel() + heTicketByAlias, err := tvmClient.GetServiceTicketForAlias(ctx, "he") + if assert.NoError(t, err, "failed to get srv ticket to 'he'") { + assert.Regexp(t, srvTicketRe, heTicketByAlias, "invalid 'he' srv ticket") + } + + heCloneTicketAlias, err := tvmClient.GetServiceTicketForAlias(ctx, "he_clone") + if assert.NoError(t, err, "failed to get srv ticket to 'he_clone'") { + assert.Regexp(t, srvTicketRe, heCloneTicketAlias, "invalid 'he_clone' srv ticket") + } + }) + + t.Run("by_dst_id", func(t *testing.T) { + // Try to get ticket by id + t.Parallel() + heTicketByID, err := tvmClient.GetServiceTicketForID(ctx, 100500) + if assert.NoError(t, err, "failed to get srv ticket to '100500'") { + assert.Regexp(t, srvTicketRe, heTicketByID, "invalid '100500' srv ticket") + } + }) +} + +func TestClient_CheckServiceTicket(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + ctx := context.Background() + t.Run("self_to_self", func(t *testing.T) { + t.Parallel() + + // Check from self to self + selfTicket, err := tvmClient.GetServiceTicketForAlias(ctx, "self") + require.NoError(t, err, "failed to get service ticket to 'self'") + assert.Regexp(t, srvTicketRe, selfTicket, "invalid 'self' srv ticket") + + // Now we can check srv ticket + ticketInfo, err := tvmClient.CheckServiceTicket(ctx, selfTicket) + require.NoError(t, err, "failed to check srv ticket main -> self") + + assert.Equal(t, tvm.ClientID(42), ticketInfo.SrcID) + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) + }) + + t.Run("to_another", func(t *testing.T) { + t.Parallel() + + // Check from another client (41) to self + ticketInfo, err := tvmClient.CheckServiceTicket(ctx, serviceTicketFor41_42) + require.NoError(t, err, "failed to check srv ticket 41 -> 42") + + assert.Equal(t, tvm.ClientID(41), ticketInfo.SrcID) + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) + }) + + t.Run("invalid_dst", func(t *testing.T) { + t.Parallel() + + // Check from another client (41) to invalid dst (99) + ticketInfo, err := tvmClient.CheckServiceTicket(ctx, serviceTicketFor41_99) + require.Error(t, err, "srv ticket for 41 -> 99 must fails") + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) + + ticketErr := err.(*tvmtool.TicketError) + require.IsType(t, err, &tvmtool.TicketError{}) + assert.Equal(t, tvmtool.TicketErrorOther, ticketErr.Status) + assert.Equal(t, "Wrong ticket dst, expected 42, got 99", ticketErr.Msg) + }) + + t.Run("broken", func(t *testing.T) { + t.Parallel() + + // Check with broken sign + _, err := tvmClient.CheckServiceTicket(ctx, "lalala") + require.Error(t, err, "srv ticket with broken sign must fails") + ticketErr := err.(*tvmtool.TicketError) + require.IsType(t, err, &tvmtool.TicketError{}) + assert.Equal(t, tvmtool.TicketErrorOther, ticketErr.Status) + assert.Equal(t, "invalid ticket format", ticketErr.Msg) + }) +} + +func TestClient_MultipleClients(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + slaveClient, err := newTvmToolClient("slave") + require.NoError(t, err) + + ctx := context.Background() + + ticket, err := tvmClient.GetServiceTicketForAlias(ctx, "slave") + require.NoError(t, err, "failed to get service ticket to 'slave'") + assert.Regexp(t, srvTicketRe, ticket, "invalid 'slave' srv ticket") + + ticketInfo, err := slaveClient.CheckServiceTicket(ctx, ticket) + require.NoError(t, err, "failed to check srv ticket main -> self") + + assert.Equal(t, tvm.ClientID(42), ticketInfo.SrcID) + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) +} + +func TestClient_CheckUserTicket(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + ticketInfo, err := tvmClient.CheckUserTicket(context.Background(), userTicketFor1120000000038691) + require.NoError(t, err, "failed to check user ticket") + + assert.Equal(t, tvm.UID(1120000000038691), ticketInfo.DefaultUID) + assert.Subset(t, []tvm.UID{1120000000038691}, ticketInfo.UIDs) + assert.Subset(t, []string{"bb:sessionid", "test:test"}, ticketInfo.Scopes) + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) +} + +func TestClient_Version(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + version, err := tvmClient.Version(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, version) +} + +func TestClient_GetRoles(t *testing.T) { + // ходить в настоящий tvmtool не получилось, + // потому что он не может стартовать в рецепте с "roles_for_idm_slug" + tvmtoolServer := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + require.Equal(t, "main", req.URL.Query().Get("self")) + if req.Header.Get("If-None-Match") == "\"GY2GCMTFMQ2DE\"" { + res.WriteHeader(http.StatusNotModified) + return + } + res.WriteHeader(http.StatusOK) + resp := `{"revision":"GY2GCMTFMQ2DE","born_date":1688399170,"user":{"1120000000022901":{"/role/advanced/":[{}]}}}` + _, err := res.Write([]byte(resp)) + require.NoError(t, err) + })) + defer tvmtoolServer.Close() + + tvmClient, err := newTvmToolClientAtURL("main", tvmtoolServer.URL, "12345") + require.NoError(t, err) + + // первый раз отвечаем с непустыми данными из tvmtool + // второй раз отвечаем из кеша + for i := 0; i < 2; i++ { + roles, err := tvmClient.GetRoles(context.Background()) + require.NoError(t, err) + require.NotNil(t, roles) + + userRoles, err := roles.GetRolesForUser(&tvm.CheckedUserTicket{ + DefaultUID: 1120000000022901, + Env: tvm.BlackboxProdYateam, + }, nil) + require.NoError(t, err) + require.True(t, userRoles.HasRole("/role/advanced/")) + } +} diff --git a/library/go/yandex/tvm/tvmtool/ya.make b/library/go/yandex/tvm/tvmtool/ya.make new file mode 100644 index 0000000000..f492903e50 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/ya.make @@ -0,0 +1,44 @@ +GO_LIBRARY() + +SRCS( + any.go + deploy.go + doc.go + errors.go + opts.go + qloud.go + tool.go +) + +GO_TEST_SRCS(tool_export_test.go) + +GO_XTEST_SRCS( + any_example_test.go + deploy_example_test.go + qloud_example_test.go + tool_bg_update_test.go + tool_example_test.go +) + +IF (OS_LINUX) + GO_XTEST_SRCS( + clients_test.go + tool_test.go + ) +ENDIF() + +IF (OS_DARWIN) + GO_XTEST_SRCS( + clients_test.go + tool_test.go + ) +ENDIF() + +END() + +RECURSE( + examples + internal +) + +RECURSE_FOR_TESTS(gotest) diff --git a/library/go/yandex/tvm/user_ticket.go b/library/go/yandex/tvm/user_ticket.go new file mode 100644 index 0000000000..d745c9e508 --- /dev/null +++ b/library/go/yandex/tvm/user_ticket.go @@ -0,0 +1,128 @@ +package tvm + +import ( + "fmt" +) + +// CheckedUserTicket is short-lived user credential. +// +// CheckedUserTicket contains only valid users. +// Details: https://wiki.yandex-team.ru/passport/tvm2/user-ticket/#chtoestvusertickete +type CheckedUserTicket struct { + // DefaultUID is default user - maybe 0 + DefaultUID UID + // UIDs is array of valid users - never empty + UIDs []UID + // Env is blackbox environment which created this UserTicket - provides only tvmauth now + Env BlackboxEnv + // Scopes is array of scopes inherited from credential - never empty + Scopes []string + // DbgInfo is human readable data for debug purposes + DbgInfo string + // LogInfo is safe for logging part of ticket - it can be parsed later with `tvmknife parse_ticket -t ...` + LogInfo string + //LoginID of a user, can be empty if ticket does not contain LoginID + LoginID string + //UIDs of users in ticket with extended fields + UidsExtFieldsMap map[UID]UserExtFields + //Default user in ticket with extended fields, can be nil if there is no default uid in ticket + DefaultUIDExtFields *UserExtFields +} + +func (t CheckedUserTicket) String() string { + return fmt.Sprintf("%s (%s)", t.LogInfo, t.DbgInfo) +} + +// CheckScopes verify that ALL needed scopes presents in the user ticket +func (t *CheckedUserTicket) CheckScopes(scopes ...string) error { + switch { + case len(scopes) == 0: + // ok, no scopes. no checks. no rules + return nil + case len(t.Scopes) == 0: + msg := fmt.Sprintf("user ticket doesn't contain expected scopes: %s (actual: nil)", scopes) + return &TicketError{Status: TicketInvalidScopes, Msg: msg} + default: + actualScopes := make(map[string]struct{}, len(t.Scopes)) + for _, s := range t.Scopes { + actualScopes[s] = struct{}{} + } + + for _, s := range scopes { + if _, found := actualScopes[s]; !found { + // exit on first nonexistent scope + msg := fmt.Sprintf( + "user ticket doesn't contain one of expected scopes: %s (actual: %s)", + scopes, t.Scopes, + ) + + return &TicketError{Status: TicketInvalidScopes, Msg: msg} + } + } + + return nil + } +} + +// CheckScopesAny verify that ANY of needed scopes presents in the user ticket +func (t *CheckedUserTicket) CheckScopesAny(scopes ...string) error { + switch { + case len(scopes) == 0: + // ok, no scopes. no checks. no rules + return nil + case len(t.Scopes) == 0: + msg := fmt.Sprintf("user ticket doesn't contain any of expected scopes: %s (actual: nil)", scopes) + return &TicketError{Status: TicketInvalidScopes, Msg: msg} + default: + actualScopes := make(map[string]struct{}, len(t.Scopes)) + for _, s := range t.Scopes { + actualScopes[s] = struct{}{} + } + + for _, s := range scopes { + if _, found := actualScopes[s]; found { + // exit on first valid scope + return nil + } + } + + msg := fmt.Sprintf( + "user ticket doesn't contain any of expected scopes: %s (actual: %s)", + scopes, t.Scopes, + ) + + return &TicketError{Status: TicketInvalidScopes, Msg: msg} + } +} + +type CheckUserTicketOptions struct { + EnvOverride *BlackboxEnv +} + +type CheckUserTicketOption func(*CheckUserTicketOptions) + +func WithBlackboxOverride(env BlackboxEnv) CheckUserTicketOption { + return func(opts *CheckUserTicketOptions) { + opts.EnvOverride = &env + } +} + +type UserTicketACL func(ticket *CheckedUserTicket) error + +func AllowAllUserTickets() UserTicketACL { + return func(ticket *CheckedUserTicket) error { + return nil + } +} + +func CheckAllUserTicketScopesPresent(scopes []string) UserTicketACL { + return func(ticket *CheckedUserTicket) error { + return ticket.CheckScopes(scopes...) + } +} + +func CheckAnyUserTicketScopesPresent(scopes []string) UserTicketACL { + return func(ticket *CheckedUserTicket) error { + return ticket.CheckScopesAny(scopes...) + } +} diff --git a/library/go/yandex/tvm/ya.make b/library/go/yandex/tvm/ya.make new file mode 100644 index 0000000000..80997806ac --- /dev/null +++ b/library/go/yandex/tvm/ya.make @@ -0,0 +1,38 @@ +GO_LIBRARY() + +SRCS( + client.go + context.go + errors.go + roles.go + roles_entities_index.go + roles_entities_index_builder.go + roles_opts.go + roles_parser.go + roles_parser_opts.go + roles_types.go + service_ticket.go + tvm.go + user_ticket.go +) + +GO_TEST_SRCS( + roles_entities_index_builder_test.go + roles_entities_index_test.go + roles_parser_test.go + roles_test.go +) + +GO_XTEST_SRCS(tvm_test.go) + +END() + +RECURSE(examples) + +RECURSE_FOR_TESTS( + cachedtvm + gotest + mocks + tvmauth + tvmtool +) diff --git a/library/go/yandex/unistat/aggr/aggr.go b/library/go/yandex/unistat/aggr/aggr.go new file mode 100644 index 0000000000..515c5c0335 --- /dev/null +++ b/library/go/yandex/unistat/aggr/aggr.go @@ -0,0 +1,64 @@ +package aggr + +import "github.com/ydb-platform/ydb/library/go/yandex/unistat" + +// Histogram returns delta histogram aggregation (dhhh). +func Histogram() unistat.Aggregation { + return unistat.StructuredAggregation{ + AggregationType: unistat.Delta, + Group: unistat.Hgram, + MetaGroup: unistat.Hgram, + Rollup: unistat.Hgram, + } +} + +// AbsoluteHistogram returns absolute histogram aggregation (ahhh). +func AbsoluteHistogram() unistat.Aggregation { + return unistat.StructuredAggregation{ + AggregationType: unistat.Absolute, + Group: unistat.Hgram, + MetaGroup: unistat.Hgram, + Rollup: unistat.Hgram, + } +} + +// Counter returns counter aggregation (dmmm) +func Counter() unistat.Aggregation { + return unistat.StructuredAggregation{ + AggregationType: unistat.Delta, + Group: unistat.Sum, + MetaGroup: unistat.Sum, + Rollup: unistat.Sum, + } +} + +// Absolute returns value aggregation (ammm) +func Absolute() unistat.Aggregation { + return unistat.StructuredAggregation{ + AggregationType: unistat.Absolute, + Group: unistat.Sum, + MetaGroup: unistat.Sum, + Rollup: unistat.Sum, + } +} + +// SummAlias corresponds to _summ suffix +type SummAlias struct{} + +func (s SummAlias) Suffix() string { + return "summ" +} + +// SummAlias corresponds to _hgram suffix +type HgramAlias struct{} + +func (s HgramAlias) Suffix() string { + return "hgram" +} + +// SummAlias corresponds to _max suffix +type MaxAlias struct{} + +func (s MaxAlias) Suffix() string { + return "max" +} diff --git a/library/go/yandex/unistat/aggr/ya.make b/library/go/yandex/unistat/aggr/ya.make new file mode 100644 index 0000000000..21d6599197 --- /dev/null +++ b/library/go/yandex/unistat/aggr/ya.make @@ -0,0 +1,5 @@ +GO_LIBRARY() + +SRCS(aggr.go) + +END() diff --git a/library/go/yandex/unistat/histogram.go b/library/go/yandex/unistat/histogram.go new file mode 100644 index 0000000000..ca7b78dadb --- /dev/null +++ b/library/go/yandex/unistat/histogram.go @@ -0,0 +1,86 @@ +package unistat + +import ( + "sync" + + "github.com/goccy/go-json" +) + +// Histogram implements Metric interface +type Histogram struct { + mu sync.RWMutex + name string + tags []Tag + priority Priority + aggr Aggregation + + intervals []float64 + weights []int64 + size int64 +} + +// NewHistogram allocates Histogram metric. +// For naming rules see https://wiki.yandex-team.ru/golovan/tagsandsignalnaming. +// Intervals in left edges of histograms buckets (maximum 50 allowed). +func NewHistogram(name string, priority Priority, aggr Aggregation, intervals []float64, tags ...Tag) *Histogram { + return &Histogram{ + name: formatTags(tags) + name, + priority: priority, + aggr: aggr, + intervals: intervals, + weights: make([]int64, len(intervals)), + } +} + +// Name from Metric interface. +func (h *Histogram) Name() string { + return h.name +} + +// Priority from Metric interface. +func (h *Histogram) Priority() Priority { + return h.priority +} + +// Aggregation from Metric interface. +func (h *Histogram) Aggregation() Aggregation { + return h.aggr +} + +// Update from Metric interface. +func (h *Histogram) Update(value float64) { + h.mu.Lock() + defer h.mu.Unlock() + + for i := len(h.intervals); i > 0; i-- { + if value >= h.intervals[i-1] { + h.weights[i-1]++ + h.size++ + break + } + } +} + +// MarshalJSON from Metric interface. +func (h *Histogram) MarshalJSON() ([]byte, error) { + h.mu.RLock() + defer h.mu.RUnlock() + + buckets := [][2]interface{}{} + for i := range h.intervals { + b := h.intervals[i] + w := h.weights[i] + buckets = append(buckets, [2]interface{}{b, w}) + } + + jsonName := h.name + "_" + h.aggr.Suffix() + return json.Marshal([]interface{}{jsonName, buckets}) +} + +// GetSize returns histogram's values count. +func (h *Histogram) GetSize() int64 { + h.mu.Lock() + defer h.mu.Unlock() + + return h.size +} diff --git a/library/go/yandex/unistat/number.go b/library/go/yandex/unistat/number.go new file mode 100644 index 0000000000..c9b26ef073 --- /dev/null +++ b/library/go/yandex/unistat/number.go @@ -0,0 +1,86 @@ +package unistat + +import ( + "math" + "sync" + + "github.com/goccy/go-json" +) + +// Numeric implements Metric interface. +type Numeric struct { + mu sync.RWMutex + name string + tags []Tag + priority Priority + aggr Aggregation + localAggr AggregationRule + + value float64 +} + +// NewNumeric allocates Numeric value metric. +func NewNumeric(name string, priority Priority, aggr Aggregation, localAggr AggregationRule, tags ...Tag) *Numeric { + return &Numeric{ + name: formatTags(tags) + name, + priority: priority, + aggr: aggr, + localAggr: localAggr, + } +} + +// Name from Metric interface. +func (n *Numeric) Name() string { + return n.name +} + +// Aggregation from Metric interface. +func (n *Numeric) Aggregation() Aggregation { + return n.aggr +} + +// Priority from Metric interface. +func (n *Numeric) Priority() Priority { + return n.priority +} + +// Update from Metric interface. +func (n *Numeric) Update(value float64) { + n.mu.Lock() + defer n.mu.Unlock() + + switch n.localAggr { + case Max: + n.value = math.Max(n.value, value) + case Min: + n.value = math.Min(n.value, value) + case Sum: + n.value += value + case Last: + n.value = value + default: + n.value = -1 + } +} + +// MarshalJSON from Metric interface. +func (n *Numeric) MarshalJSON() ([]byte, error) { + jsonName := n.name + "_" + n.aggr.Suffix() + return json.Marshal([]interface{}{jsonName, n.GetValue()}) +} + +// GetValue returns current metric value. +func (n *Numeric) GetValue() float64 { + n.mu.RLock() + defer n.mu.RUnlock() + + return n.value +} + +// SetValue sets current metric value. +func (n *Numeric) SetValue(value float64) { + n.mu.Lock() + defer n.mu.Unlock() + + n.value = value +} diff --git a/library/go/yandex/unistat/registry.go b/library/go/yandex/unistat/registry.go new file mode 100644 index 0000000000..d846d6f23c --- /dev/null +++ b/library/go/yandex/unistat/registry.go @@ -0,0 +1,60 @@ +package unistat + +import ( + "sort" + "sync" + + "github.com/goccy/go-json" +) + +type registry struct { + mu sync.Mutex + byName map[string]Metric + + metrics []Metric + unsorted bool +} + +// NewRegistry allocate new registry container for unistat metrics. +func NewRegistry() Registry { + return ®istry{ + byName: map[string]Metric{}, + metrics: []Metric{}, + } +} + +func (r *registry) Register(m Metric) { + r.mu.Lock() + defer r.mu.Unlock() + + if _, ok := r.byName[m.Name()]; ok { + panic(ErrDuplicate) + } + + r.byName[m.Name()] = m + r.metrics = append(r.metrics, m) + r.unsorted = true +} + +func (r *registry) MarshalJSON() ([]byte, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.unsorted { + sort.Sort(byPriority(r.metrics)) + r.unsorted = false + } + return json.Marshal(r.metrics) +} + +type byPriority []Metric + +func (m byPriority) Len() int { return len(m) } +func (m byPriority) Less(i, j int) bool { + if m[i].Priority() == m[j].Priority() { + return m[i].Name() < m[j].Name() + } + + return m[i].Priority() > m[j].Priority() +} +func (m byPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] } diff --git a/library/go/yandex/unistat/tags.go b/library/go/yandex/unistat/tags.go new file mode 100644 index 0000000000..dd1872bdaf --- /dev/null +++ b/library/go/yandex/unistat/tags.go @@ -0,0 +1,30 @@ +package unistat + +import ( + "sort" + "strings" +) + +type Tag struct { + Name string + Value string +} + +func formatTags(tags []Tag) string { + if len(tags) == 0 { + return "" + } + + sort.Slice(tags, func(i, j int) bool { + return tags[i].Name < tags[j].Name + }) + + var result strings.Builder + for i := range tags { + value := tags[i].Name + "=" + tags[i].Value + ";" + + result.WriteString(value) + } + + return result.String() +} diff --git a/library/go/yandex/unistat/unistat.go b/library/go/yandex/unistat/unistat.go new file mode 100644 index 0000000000..6abc68d14e --- /dev/null +++ b/library/go/yandex/unistat/unistat.go @@ -0,0 +1,171 @@ +package unistat + +import ( + "errors" + "fmt" + "time" + + "github.com/goccy/go-json" +) + +// StructuredAggregation provides type safe API to create an Aggregation. For more +// information see: https://wiki.yandex-team.ru/golovan/aggregation-types/ +type StructuredAggregation struct { + AggregationType AggregationType + Group AggregationRule + MetaGroup AggregationRule + Rollup AggregationRule +} + +// Aggregation defines rules how to aggregate signal on each level. For more +// information see: https://wiki.yandex-team.ru/golovan/aggregation-types/ +type Aggregation interface { + Suffix() string +} + +const ( + AggregationUnknown = "<unknown>" +) + +// Suffix defines signal aggregation on each level: +// 1 - Signal type: absolute (A) or delta (D). +// 2 - Group aggregation. +// 3 - Meta-group aggregation type. +// 4 - Time aggregation for roll-up. +// +// Doc: https://doc.yandex-team.ru/Search/golovan-quickstart/concepts/signal-aggregation.html#agrr-levels +func (a StructuredAggregation) Suffix() string { + return fmt.Sprintf("%s%s%s%s", a.AggregationType, a.Group, a.MetaGroup, a.Rollup) +} + +// Priority is used to order signals in unistat report. +// https://wiki.yandex-team.ru/golovan/stat-handle/#protokol +type Priority int + +// AggregationType is Absolute or Delta. +type AggregationType int + +// Value types +const ( + Absolute AggregationType = iota // Absolute value. Use for gauges. + Delta // Delta value. Use for increasing counters. +) + +func (v AggregationType) String() string { + switch v { + case Absolute: + return "a" + case Delta: + return "d" + default: + return AggregationUnknown + } +} + +// AggregationRule defines aggregation rules: +// +// https://wiki.yandex-team.ru/golovan/aggregation-types/#algoritmyagregacii +type AggregationRule int + +// Aggregation rules +const ( + Hgram AggregationRule = iota // Hgram is histogram aggregation. + Max // Max value. + Min // Min value. + Sum // Sum with default 0. + SumNone // SumNone is sum with default None. + Last // Last value. + Average // Average value. +) + +func (r AggregationRule) String() string { + switch r { + case Hgram: + return "h" + case Max: + return "x" + case Min: + return "n" + case Sum: + return "m" + case SumNone: + return "e" + case Last: + return "t" + case Average: + return "v" + default: + return AggregationUnknown + } +} + +func (r *AggregationRule) UnmarshalText(source []byte) error { + text := string(source) + switch text { + case "h": + *r = Hgram + case "x": + *r = Max + case "n": + *r = Min + case "m": + *r = Sum + case "e": + *r = SumNone + case "t": + *r = Last + case "v": + *r = Average + default: + return fmt.Errorf("unknown aggregation rule '%s'", text) + } + return nil +} + +// ErrDuplicate is raised on duplicate metric name registration. +var ErrDuplicate = errors.New("unistat: duplicate metric") + +// Metric is interface that accepted by Registry. +type Metric interface { + Name() string + Priority() Priority + Aggregation() Aggregation + MarshalJSON() ([]byte, error) +} + +// Updater is interface that wraps basic Update() method. +type Updater interface { + Update(value float64) +} + +// Registry is interface for container that generates stat report +type Registry interface { + Register(metric Metric) + MarshalJSON() ([]byte, error) +} + +var defaultRegistry = NewRegistry() + +// Register metric in default registry. +func Register(metric Metric) { + defaultRegistry.Register(metric) +} + +// MarshalJSON marshals default registry to JSON. +func MarshalJSON() ([]byte, error) { + return json.Marshal(defaultRegistry) +} + +// MeasureMicrosecondsSince updates metric with duration that started +// at ts and ends now. +func MeasureMicrosecondsSince(m Updater, ts time.Time) { + measureMicrosecondsSince(time.Since, m, ts) +} + +// For unittest +type timeSinceFunc func(t time.Time) time.Duration + +func measureMicrosecondsSince(sinceFunc timeSinceFunc, m Updater, ts time.Time) { + dur := sinceFunc(ts) + m.Update(float64(dur / time.Microsecond)) // to microseconds +} diff --git a/library/go/yandex/unistat/ya.make b/library/go/yandex/unistat/ya.make new file mode 100644 index 0000000000..6e74f2270b --- /dev/null +++ b/library/go/yandex/unistat/ya.make @@ -0,0 +1,25 @@ +GO_LIBRARY() + +SRCS( + histogram.go + number.go + registry.go + tags.go + unistat.go +) + +GO_TEST_SRCS( + histogram_test.go + number_test.go + registry_test.go + tags_test.go + unistat_test.go +) + +END() + +RECURSE( + aggr + example_server + gotest +) diff --git a/library/go/yandex/yplite/spec.go b/library/go/yandex/yplite/spec.go new file mode 100644 index 0000000000..228f9627ef --- /dev/null +++ b/library/go/yandex/yplite/spec.go @@ -0,0 +1,46 @@ +package yplite + +type PodSpec struct { + DNS PodDNS `json:"dns"` + ResourceRequests ResourceRequest `json:"resourceRequests"` + PortoProperties []PortoProperty `json:"portoProperties"` + IP6AddressAllocations []IP6AddressAllocation `json:"ip6AddressAllocations"` +} + +type PodAttributes struct { + ResourceRequirements struct { + CPU struct { + Guarantee uint64 `json:"cpu_guarantee_millicores,string"` + Limit uint64 `json:"cpu_limit_millicores,string"` + } `json:"cpu"` + Memory struct { + Guarantee uint64 `json:"memory_guarantee_bytes,string"` + Limit uint64 `json:"memory_limit_bytes,string"` + } `json:"memory"` + } `json:"resource_requirements"` +} + +type ResourceRequest struct { + CPUGuarantee uint64 `json:"vcpuGuarantee,string"` + CPULimit uint64 `json:"vcpuLimit,string"` + MemoryGuarantee uint64 `json:"memoryGuarantee,string"` + MemoryLimit uint64 `json:"memoryLimit,string"` + AnonymousMemoryLimit uint64 `json:"anonymousMemoryLimit,string"` +} + +type IP6AddressAllocation struct { + Address string `json:"address"` + VlanID string `json:"vlanId"` + PersistentFQDN string `json:"persistentFqdn"` + TransientFQDN string `json:"transientFqdn"` +} + +type PortoProperty struct { + Name string `json:"key"` + Value string `json:"value"` +} + +type PodDNS struct { + PersistentFqdn string `json:"persistentFqdn"` + TransientFqdn string `json:"transientFqdn"` +} diff --git a/library/go/yandex/yplite/ya.make b/library/go/yandex/yplite/ya.make new file mode 100644 index 0000000000..8583357d0a --- /dev/null +++ b/library/go/yandex/yplite/ya.make @@ -0,0 +1,8 @@ +GO_LIBRARY() + +SRCS( + spec.go + yplite.go +) + +END() diff --git a/library/go/yandex/yplite/yplite.go b/library/go/yandex/yplite/yplite.go new file mode 100644 index 0000000000..a39d889391 --- /dev/null +++ b/library/go/yandex/yplite/yplite.go @@ -0,0 +1,67 @@ +package yplite + +import ( + "context" + "encoding/json" + "net" + "net/http" + "os" + "time" + + "github.com/ydb-platform/ydb/library/go/core/xerrors" +) + +const ( + PodSocketPath = "/run/iss/pod.socket" + NodeAgentTimeout = 1 * time.Second +) + +var ( + httpClient = http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.DialTimeout("unix", PodSocketPath, NodeAgentTimeout) + }, + }, + Timeout: NodeAgentTimeout, + } +) + +func IsAPIAvailable() bool { + if _, err := os.Stat(PodSocketPath); err == nil { + return true + } + return false +} + +func FetchPodSpec() (*PodSpec, error) { + res, err := httpClient.Get("http://localhost/pod_spec") + if err != nil { + return nil, xerrors.Errorf("failed to request pod spec: %w", err) + } + defer func() { _ = res.Body.Close() }() + + spec := new(PodSpec) + err = json.NewDecoder(res.Body).Decode(spec) + if err != nil { + return nil, xerrors.Errorf("failed to decode pod spec: %w", err) + } + + return spec, nil +} + +func FetchPodAttributes() (*PodAttributes, error) { + res, err := httpClient.Get("http://localhost/pod_attributes") + if err != nil { + return nil, xerrors.Errorf("failed to request pod attributes: %w", err) + } + defer func() { _ = res.Body.Close() }() + + attrs := new(PodAttributes) + err = json.NewDecoder(res.Body).Decode(attrs) + if err != nil { + return nil, xerrors.Errorf("failed to decode pod attributes: %w", err) + } + + return attrs, nil +} diff --git a/library/recipes/tirole/README.md b/library/recipes/tirole/README.md new file mode 100644 index 0000000000..1cade61912 --- /dev/null +++ b/library/recipes/tirole/README.md @@ -0,0 +1,16 @@ +Tirole recipe +-- + +Этот рецепт позволяет в тестах поднять демон, который скрывается за `tirole-api.yandex.net`. +Демон слушает на порте из файла `tirole.port` - только http. + +База ролей в tirole - это каталог в Аркадии с файлами: + * `<slug>.json` - роли, которые надо отдавать в API + * `mapping.yaml` - соответствие между slug и tvmid + +В рецепте API принимает service-тикеты с dst==1000001: их можно получать из `tvmapi`/`tvmtool`, запущеного в рецепте. + +Примеры: +1. `ut_simple` + +Вопросы можно писать в [PASSPORTDUTY](https://st.yandex-team.ru/createTicket?queue=PASSPORTDUTY&_form=77618) diff --git a/library/recipes/tirole/__main__.py b/library/recipes/tirole/__main__.py new file mode 100644 index 0000000000..e61c1517d7 --- /dev/null +++ b/library/recipes/tirole/__main__.py @@ -0,0 +1,102 @@ +import argparse +import datetime +import json +import os +import requests +import sys + +import yatest.common +from library.python.testing.recipe import declare_recipe +from library.recipes.common import start_daemon, stop_daemon +from yatest.common import network + +TIROLE_PORT_FILE = "tirole.port" +TIROLE_PID_FILE = "tirole.pid" + +CONFIG_PATH = './tirole.config.json' + + +PORT_MANAGER = network.PortManager() + + +def _gen_config(roles_dir): + http_port = PORT_MANAGER.get_tcp_port(80) + + cfg = { + "http_common": { + "listen_address": "localhost", + "port": http_port, + }, + "logger": { + "file": yatest.common.output_path("tirole-common.log"), + }, + "service": { + "common": { + "access_log": yatest.common.output_path("tirole-access.log"), + }, + "tvm": { + "self_tvm_id": 1000001, + }, + "key_map": { + "keys_file": yatest.common.source_path("library/recipes/tirole/data/sign.keys"), + "default_key": "1", + }, + "unittest": { + "roles_dir": yatest.common.source_path(roles_dir) + "/", + }, + }, + } + + with open(CONFIG_PATH, 'wt') as f: + json.dump(cfg, f, sort_keys=True, indent=4) + + return http_port + + +def start(argv): + _log('Starting Tirole recipe') + + parser = argparse.ArgumentParser() + parser.add_argument('--roles-dir', dest='roles_dir', type=str, required=True) + input_args = parser.parse_args(argv) + + http_port = _gen_config(input_args.roles_dir) + + print(http_port, file=sys.stderr) + with open(TIROLE_PORT_FILE, "w") as f: + f.write(str(http_port)) + + # launch + args = [ + yatest.common.build_path() + '/passport/infra/daemons/tirole/cmd/tirole', + '-c', + CONFIG_PATH, + ] + + def check(): + try: + r = requests.get("http://localhost:%d/ping" % http_port) + if r.status_code == 200: + return True + else: + _log("ping: %d : %s" % (r.status_code, r.text)) + except Exception as e: + _log("ping: %s" % e) + return False + + start_daemon(command=args, environment=os.environ.copy(), is_alive_check=check, pid_file_name=TIROLE_PID_FILE) + + +def stop(argv): + with open(TIROLE_PID_FILE) as f: + pid = f.read() + if not stop_daemon(pid): + _log("pid is dead: %s" % pid) + + +def _log(msg): + print("%s : tirole-recipe : %s" % (datetime.datetime.now(), msg), file=sys.stdout) + + +if __name__ == "__main__": + declare_recipe(start, stop) diff --git a/library/recipes/tirole/data/sign.keys b/library/recipes/tirole/data/sign.keys new file mode 100644 index 0000000000..bc1a0be4da --- /dev/null +++ b/library/recipes/tirole/data/sign.keys @@ -0,0 +1 @@ +{"1": "733f9cdba433040287a4235247f8f31a326fee9e0f094d2987aac16d5eb0b883"} diff --git a/library/recipes/tirole/recipe.inc b/library/recipes/tirole/recipe.inc new file mode 100644 index 0000000000..50ed13c092 --- /dev/null +++ b/library/recipes/tirole/recipe.inc @@ -0,0 +1,8 @@ +DEPENDS( + library/recipes/tirole + passport/infra/daemons/tirole/cmd +) + +DATA( + arcadia/library/recipes/tirole/data +) diff --git a/library/recipes/tirole/ya.make b/library/recipes/tirole/ya.make new file mode 100644 index 0000000000..64de3cd8ea --- /dev/null +++ b/library/recipes/tirole/ya.make @@ -0,0 +1,18 @@ +PY3_PROGRAM() + +PY_SRCS(__main__.py) + +PEERDIR( + contrib/python/requests + library/python/testing/recipe + library/python/testing/yatest_common + library/recipes/common +) + +END() + +IF (NOT OS_WINDOWS AND NOT SANITIZER_TYPE) + RECURSE_FOR_TESTS( + ut_simple + ) +ENDIF() diff --git a/library/recipes/tvmapi/README.md b/library/recipes/tvmapi/README.md new file mode 100644 index 0000000000..cab650a5f6 --- /dev/null +++ b/library/recipes/tvmapi/README.md @@ -0,0 +1,16 @@ +TVM-API recipe +-- + +Этот рецепт позволяет в тестах поднять демон, который скрывается за `tvm-api.yandex.net`. +Демон слушает на порте из файла `tvmapi.port` - только http. + +База у этого демона в read-only режиме, список доступных TVM-приложений с секретами лежит [здесь](clients/clients.json). + +Публичные ключи этого демона позволяют проверять тикеты, сгенерированные через `tvmknife unittest`. + +Примеры: +1. `ut_simple` - поднимается tvm-api + +Примеры комбинирования с tvmtool можно найти в `library/recipes/tvmtool` + +Вопросы можно писать в [PASSPORTDUTY](https://st.yandex-team.ru/createTicket?queue=PASSPORTDUTY&_form=77618) diff --git a/library/recipes/tvmapi/__main__.py b/library/recipes/tvmapi/__main__.py new file mode 100644 index 0000000000..93544eaeb6 --- /dev/null +++ b/library/recipes/tvmapi/__main__.py @@ -0,0 +1,121 @@ +import datetime +import os +import requests +import subprocess +import sys + +import yatest.common +from library.python.testing.recipe import declare_recipe +from library.recipes.common import start_daemon, stop_daemon +from yatest.common import network + +TVMAPI_PORT_FILE = "tvmapi.port" +TVMAPI_PID_FILE = "tvmapi.pid" + +TVMCERT_PORT_FILE = "tvmcert.port" + +CONFIG_PATH = './tvm-api.config.xml' + + +def test_data_path(): + return yatest.common.source_path() + '/library/recipes/tvmapi/data/' + + +PORT_MANAGER = network.PortManager() + + +def _gen_config(cfg_template): + http_port = PORT_MANAGER.get_tcp_port(80) + tvmcert_port = PORT_MANAGER.get_tcp_port(9001) + + f = open(cfg_template) + cfg = f.read() + + cfg = cfg.replace('{port}', str(http_port)) + + cfg = cfg.replace('{secret.key}', test_data_path() + 'secret.key') + cfg = cfg.replace('{test_secret.key}', test_data_path() + 'test_secret.key') + + cfg = cfg.replace('{tvmdb_credentials}', test_data_path() + 'tvmdb.credentials') + cfg = cfg.replace('{client_secret}', test_data_path() + 'client_secret.secret') + cfg = cfg.replace('{tvm_cache}', test_data_path() + "tvm_cache") + + cfg = cfg.replace('{abc.json}', test_data_path() + 'abc.json') + cfg = cfg.replace('{staff.json}', test_data_path() + 'staff.json') + + cfg = cfg.replace('{tvmcert_port}', str(tvmcert_port)) + + print(cfg, file=sys.stderr) + + f = open(CONFIG_PATH, 'wt') + f.write(cfg) + + return http_port, tvmcert_port + + +def _prepare_db(sql, db): + SQLITE_BIN = yatest.common.build_path() + '/contrib/tools/sqlite3/sqlite3' + if os.path.isfile(db): + os.remove(db) + + input_sql = open(sql) + p = subprocess.run([SQLITE_BIN, db], stdin=input_sql) + assert 0 == p.returncode + + +def start(argv): + _log('Starting TVM recipe') + + def pop_arg(def_val): + if len(argv) > 0: + return yatest.common.source_path(argv.pop(0)) + return test_data_path() + def_val + + dbfile = pop_arg('tvm.sql') + cfg_template = pop_arg('config.xml') + + _prepare_db(dbfile, './tvm.db') + + http_port, tvmcert_port = _gen_config(cfg_template) + + print(http_port, tvmcert_port, file=sys.stderr) + with open(TVMAPI_PORT_FILE, "w") as f: + f.write(str(http_port)) + + with open(TVMCERT_PORT_FILE, "w") as f: + f.write(str(tvmcert_port)) + + # launch tvm + args = [ + yatest.common.build_path() + '/passport/infra/daemons/tvmapi/daemon/tvm', + '-c', + CONFIG_PATH, + ] + + def check(): + try: + r = requests.get("http://localhost:%d/nagios" % http_port) + if r.status_code == 200: + return True + else: + _log("ping: %d : %s" % (r.status_code, r.text)) + except Exception as e: + _log("ping: %s" % e) + return False + + start_daemon(command=args, environment=os.environ.copy(), is_alive_check=check, pid_file_name=TVMAPI_PID_FILE) + + +def stop(argv): + with open(TVMAPI_PID_FILE) as f: + pid = f.read() + if not stop_daemon(pid): + _log("pid is dead: %s" % pid) + + +def _log(msg): + print("%s : tvmapi-recipe : %s" % (datetime.datetime.now(), msg), file=sys.stdout) + + +if __name__ == "__main__": + declare_recipe(start, stop) diff --git a/library/recipes/tvmapi/clients/clients.json b/library/recipes/tvmapi/clients/clients.json new file mode 100644 index 0000000000..bb7c9686f5 --- /dev/null +++ b/library/recipes/tvmapi/clients/clients.json @@ -0,0 +1,302 @@ +{ + "1000501": { + "secret": "bAicxJVa5uVY7MjDlapthw" + }, + "1000502": { + "secret": "e5kL0vM3nP-nPf-388Hi6Q" + }, + "1000503": { + "secret": "S3TyTYVqjlbsflVEwxj33w" + }, + "1000504": { + "secret": "CJua5YZXEPuVLgJDquPOTA" + }, + "1000505": { + "secret": "z5oaXOjgB5nV5gycBpzZ-A" + }, + "1000506": { + "secret": "VAMgcBS0wRB5fu-3jBoNUA" + }, + "1000507": { + "secret": "4bT0rnjSnM0CBrskSLVViA" + }, + "1000508": { + "secret": "MIMTd8qQQ3ALLXD6Irv_fA" + }, + "1000509": { + "secret": "UAsWdsDA93sNfI8LfuPE5w" + }, + "1000510": { + "secret": "LUTTSCreg1f976_B_EHKzg" + }, + "1000511": { + "secret": "Qp7JAt_KUJ0PAFYi1Z96Cg" + }, + "1000512": { + "secret": "FM1XM4Ek2QNyz-hpzA1v_g" + }, + "1000513": { + "secret": "v7uPV5HZxYMxlH9D2fHKMw" + }, + "1000514": { + "secret": "shkEKUUBGJ8t-GM4GNjKPg" + }, + "1000515": { + "secret": "z8Rj_ogbBldm8XBtqIqB4w" + }, + "1000516": { + "secret": "3B17sAZVKFP6MUWVzDPIzw" + }, + "1000517": { + "secret": "Veli9VD280mLcIv0UtPbWw" + }, + "1000518": { + "secret": "qbdUNAfMk7hX0M9xJtHEsA" + }, + "1000519": { + "secret": "Uz-ISKFvVoVYiy9q-PC_9Q" + }, + "1000520": { + "secret": "UN5tvVicOZaHYKFcII1q7g" + }, + "1000521": { + "secret": "P5hRcHEkmK5zbZcEqAvlKA" + }, + "1000522": { + "secret": "erqXUL7bRxOCJB5fEorfiw" + }, + "1000523": { + "secret": "zFCVMjkmn0d2kt47unq4Uw" + }, + "1000524": { + "secret": "J7cUdsWKVVoeqvbCjOSRhQ" + }, + "1000525": { + "secret": "vcIV1ae41FnAF4OcOJANfQ" + }, + "1000526": { + "secret": "6iecbb_OxUNcsDqD6dwWdw" + }, + "1000527": { + "secret": "3nGLXI-LqzFICq_FVyg_dQ" + }, + "1000528": { + "secret": "_98qV1ROSO4-rN6pdt_mxA" + }, + "1000529": { + "secret": "gk8O8U6il5Fet6txZV8Wkw" + }, + "1000530": { + "secret": "jhkJGPcsruRy8rrvjBqHCQ" + }, + "1000531": { + "secret": "Jx_9QZcbS6pgi8tqM4FyeQ" + }, + "1000532": { + "secret": "Pt_gGMVoe-LpjAGGUL2L_Q" + }, + "1000533": { + "secret": "XCoskT_D_q-udy5misBRKg" + }, + "1000534": { + "secret": "HRJ5deobpngW2_D6cs0mXQ" + }, + "1000535": { + "secret": "bx2LxRdx8sR_qOFKJEmGqQ" + }, + "1000536": { + "secret": "u8OGPAQAsA6TdEngaksR5g" + }, + "1000537": { + "secret": "PpnbWrIJDEuBkoSJCoDr8g" + }, + "1000538": { + "secret": "jZoAveXpNuJrgH-6RYbP2g" + }, + "1000539": { + "secret": "7Jr4Fs82YwpPZ65Nq0fo7w" + }, + "1000540": { + "secret": "64n_JH6faRdgTon7potVyg" + }, + "1000541": { + "secret": "SxwDfkJySfnLPOvQBJlIaQ" + }, + "1000542": { + "secret": "A0wk3RnDWU7e6GGCbuSqtw" + }, + "1000543": { + "secret": "f2jmQiEjijWo8xivmuLQ0A" + }, + "1000544": { + "secret": "onhYk5AkKuTEb-QDwyuAng" + }, + "1000545": { + "secret": "bQxn_3sZvgPazJFMNUuFtw" + }, + "1000546": { + "secret": "kzw-gG6HCYqd8FPNHwdFcw" + }, + "1000547": { + "secret": "XHHuPPwLsrzE4I6RkPVMAg" + }, + "1000548": { + "secret": "UsBWJoLx-nWlziCVwB3ffA" + }, + "1000549": { + "secret": "EIZZooj2P6UO53YOdWyVQw" + }, + "1000550": { + "secret": "xbopwHPpJ9R-bTDQGeRZNQ" + }, + "1000551": { + "secret": "BrYem08Mz3Tt9RrM1wfk-w" + }, + "1000552": { + "secret": "NlNnf6wL8Y8SjEg-IorTAQ" + }, + "1000553": { + "secret": "EvZ6FrZkWGaBpaLrssWBcA" + }, + "1000554": { + "secret": "liVByAp3FXOb4xGcV_U-hg" + }, + "1000555": { + "secret": "fEof4p9_LWGwUSp2-HyAdw" + }, + "1000556": { + "secret": "B4KRHvr6Z2HS8St4JNGZFg" + }, + "1000557": { + "secret": "A90VljsKm1lpDge0OBGmTw" + }, + "1000558": { + "secret": "qnuJcyD9p5TdHd6GoXDmxQ" + }, + "1000559": { + "secret": "r1HDEiHbIWwR8EwVvAmodw" + }, + "1000560": { + "secret": "jSbOzmOKPuFtPoLcSfETog" + }, + "1000561": { + "secret": "QIN_pHBHI-2JE5w0mrnh3Q" + }, + "1000562": { + "secret": "4aWu0E97B1yxoLEH7Cwh_g" + }, + "1000563": { + "secret": "rgqM-66GNODKGrCKH5IXPA" + }, + "1000564": { + "secret": "J8-Utmooq3nMQPAWBt9nDg" + }, + "1000565": { + "secret": "P2K-8EjVdTMd6PZWSvMdmA" + }, + "1000566": { + "secret": "rNhgkxjydpxaitLDZT8anQ" + }, + "1000567": { + "secret": "BQ-RK_41r8FCtjWaA08vWg" + }, + "1000568": { + "secret": "PcnbfQ5whheKwP_XhJ7tNg" + }, + "1000569": { + "secret": "HqFWUa01Wq09p5VvgREyeA" + }, + "1000570": { + "secret": "th6F_p-EqoMKCdy-Huf7Zw" + }, + "1000571": { + "secret": "nji7fu5bGBdGNG6Yc1uvbg" + }, + "1000572": { + "secret": "hYsochsur9VNiOjGaE0UYQ" + }, + "1000573": { + "secret": "ZIhUGzPK8qsTI6SI7DsRdQ" + }, + "1000574": { + "secret": "h3Fi3rWXKSzAgoG3cnuj-g" + }, + "1000575": { + "secret": "lid7tysEVHMmWp158FsbCQ" + }, + "1000576": { + "secret": "JuW7Y0IwyhXZAKdSQ0Ub9g" + }, + "1000577": { + "secret": "QvPpqp7fLjLME2GUyCdcwQ" + }, + "1000578": { + "secret": "DmLAHLp6nHmaBeSbNEPflQ" + }, + "1000579": { + "secret": "xU7G-496l-1cs0kda_tY3g" + }, + "1000580": { + "secret": "DyW0jyW4N2XIyluzICxOuA" + }, + "1000581": { + "secret": "jmAkUgWa1HIEvb0Wty1znw" + }, + "1000582": { + "secret": "9d1mWWTZj8gh72GANWr9Hg" + }, + "1000583": { + "secret": "WHC11W2RMpGZBDSNbGmjLw" + }, + "1000584": { + "secret": "E5FLiFaAo8vF-d0iGeiOQw" + }, + "1000585": { + "secret": "PlO6AfAKNvGWO3hB3Brq0A" + }, + "1000586": { + "secret": "TmpKdeV-_ZjJrdnUK1uLdw" + }, + "1000587": { + "secret": "H4uSYOv-yfAAll2wn0UjAA" + }, + "1000588": { + "secret": "XdVwnji2sutxZN_TmiEHjQ" + }, + "1000589": { + "secret": "giDih1G7y3hMlKN-1WHs9A" + }, + "1000590": { + "secret": "O4fYUDRWu0YudtWLkk3vmg" + }, + "1000591": { + "secret": "scYvtH5Wk7zPjue1fj-TsQ" + }, + "1000592": { + "secret": "rw7dQX1SZvdf_uZ954u-pg" + }, + "1000593": { + "secret": "5oRwpw-GXoGk2qAT_6LfbQ" + }, + "1000594": { + "secret": "bGVzBZdt8aIJfzjeO46WWw" + }, + "1000595": { + "secret": "stH1g0OLh5qOduhoh1CEYQ" + }, + "1000596": { + "secret": "N2ViD5vCYRdrPN1Q-iT9Mw" + }, + "1000597": { + "secret": "lJluRpLdw90FWzhJtAraHQ" + }, + "1000598": { + "secret": "mCD9wUI2JoDisHmLwIWIJw" + }, + "1000599": { + "secret": "VAH7Kxuxyldq8lT4ecWS9A" + }, + "1000600": { + "secret": "WiLXMmbqhyOqxicxAb76ow" + } +}
\ No newline at end of file diff --git a/library/recipes/tvmapi/data/abc.json b/library/recipes/tvmapi/data/abc.json new file mode 100644 index 0000000000..27725868b7 --- /dev/null +++ b/library/recipes/tvmapi/data/abc.json @@ -0,0 +1,4 @@ +[{"results":[ +{"id":44793,"person":{"id":16360,"login":"cerevra","first_name":{"ru":"Игорь","en":"Igor"},"last_name":{"ru":"Клеванец","en":"Klevanets"},"uid":"1120000000026887","name":{"ru":"Игорь Клеванец","en":"Igor Klevanets"}},"service":{"id":14,"slug":"passp","name":{"ru":"Паспорт","en":"Passport"},"parent":848},"role":{"id":631,"name":{"ru":"TVM ssh пользователь","en":"TVM ssh user"},"service":null,"scope":{"slug":"tvm_management","name":{"ru":"Управление TVM","en":"TVM management"}},"code":"tvm_ssh_user"},"created_at":"2017-12-11T20:26:03.999074Z","modified_at":"2018-05-04T10:05:42.653792Z","state":"approved"}, +{"id":44794,"person":{"id":16361,"login":"robot-passport-test","first_name":{"ru":"Robot","en":"Robot"},"last_name":{"ru":"Test","en":"Test"},"uid":"1120000000021014","name":{"ru":"Robot Test","en":"Robot Test"}},"service":{"id":2280,"slug":"passporttestservice","name":{"ru":"Паспорт","en":"Passport"},"parent":848},"role":{"id":631,"name":{"ru":"TVM ssh пользователь","en":"TVM ssh user"},"service":null,"scope":{"slug":"tvm_management","name":{"ru":"Управление TVM","en":"TVM management"}},"code":"tvm_ssh_user"},"created_at":"2017-12-11T20:26:03.999074Z","modified_at":"2018-05-04T10:05:42.653792Z","state":"approved"} +]}] diff --git a/library/recipes/tvmapi/data/client_secret.secret b/library/recipes/tvmapi/data/client_secret.secret new file mode 100644 index 0000000000..d07b60a003 --- /dev/null +++ b/library/recipes/tvmapi/data/client_secret.secret @@ -0,0 +1 @@ +unused_value diff --git a/library/recipes/tvmapi/data/config.xml b/library/recipes/tvmapi/data/config.xml new file mode 100644 index 0000000000..d938087b05 --- /dev/null +++ b/library/recipes/tvmapi/data/config.xml @@ -0,0 +1,107 @@ +<?xml version="1.0"?> +<config xmlns:xi="http://www.w3.org/2003/XInclude"> + <http_daemon> + <listen_address>localhost</listen_address> + <port>{port}</port> + <max_connections>4096</max_connections> + <max_queue_size>4096</max_queue_size> + </http_daemon> + <components> + <component name="tvm"> + <force_down_file>./tvm.down</force_down_file> + <checksecret_client_id>39</checksecret_client_id> + <verifyssh_client_id>27</verifyssh_client_id> + <keys_acceptable_age>34560000000000</keys_acceptable_age> + <logger_common> + <level>DEBUG</level> + <print-level>yes</print-level> + <time-format>_DEFAULT_</time-format> + <file>./testing_out_stuff/tvm-error.log</file> + </logger_common> + <logger_access> + <level>INFO</level> + <file>./testing_out_stuff/tvm-access.log</file> + <time-format>%Y-%m-%d %T</time-format> + </logger_access> + <log_dbpool>./testing_out_stuff/tvm-dbpool.log</log_dbpool> + <log_notify>./testing_out_stuff/tvm-notify.log</log_notify> + <tvm_db> + <poolsize>1</poolsize> + <get_timeout>500</get_timeout> + <connect_timeout>500</connect_timeout> + <query_timeout>5000</query_timeout> + <fail_threshold>1500</fail_threshold> + <db_driver>sqlite</db_driver> + <db_host>.</db_host> + <db_port>3306</db_port> + <db_name>./tvm.db</db_name> + <db_credentials>{tvmdb_credentials}</db_credentials> + </tvm_db> + + <tvm_client> + <disk_cache>{tvm_cache}</disk_cache> + <tvm>172</tvm> + <tvm_host>localhost</tvm_host> + <tvm_port>{port}</tvm_port> + <client_secret_file>{client_secret}</client_secret_file> + </tvm_client> + + <staff> + <cache_path>{staff.json}</cache_path> + <tvm>2001974</tvm> + <db_host>.</db_host> + <db_port>80</db_port> + <connect_timeout>1000</connect_timeout> + <query_timeout>1000</query_timeout> + </staff> + <staff_processor> + <enabled>1</enabled> + <refresh_period>900</refresh_period> + <retries_per_request>3</retries_per_request> + <limit_per_request>1000</limit_per_request> + </staff_processor> + <abc> + <enabled>1</enabled> + <tvm>2012190</tvm> + <cache_path>{abc.json}</cache_path> + <refresh_period>900</refresh_period> + <retries_per_request>3</retries_per_request> + <limit_per_request>1000</limit_per_request> + <tvm_manager_role_id>631</tvm_manager_role_id> + <db_host>.</db_host> + <db_port>443</db_port> + <connect_timeout>1000</connect_timeout> + <query_timeout>1000</query_timeout> + </abc> + <cache> + <switching_threads>1</switching_threads> + <bucket_count>2048</bucket_count> + <bucket_size>128</bucket_size> + <ttl>75</ttl> + </cache> + <timestamp_allowed_diff>60</timestamp_allowed_diff> + <db_fetcher> + <refresh_period>600</refresh_period> + <key_file>{secret.key}</key_file> + <retries_per_request>3</retries_per_request> + <disk_cache>./db.cache</disk_cache> + <prefered_private_key_idx>14</prefered_private_key_idx> + <min_key_count>1</min_key_count> + </db_fetcher> + <passport_ids> + <bb_prod>162</bb_prod> + <bb_prod_yateam>164</bb_prod_yateam> + <bb_test>166</bb_test> + <bb_test_yateam>168</bb_test_yateam> + <bb_stress>170</bb_stress> + <bb_mimino>188</bb_mimino> + <tvm>172</tvm> + </passport_ids> + <pregeneration> + <period>60</period> + <key_ttl>60</key_ttl> + <raw_list_file>./raw_list</raw_list_file> + </pregeneration> + </component> + </components> +</config> diff --git a/library/recipes/tvmapi/data/secret.key b/library/recipes/tvmapi/data/secret.key new file mode 100644 index 0000000000..d959e077c9 --- /dev/null +++ b/library/recipes/tvmapi/data/secret.key @@ -0,0 +1 @@ +BtniSXCXhroPtOgXA61i5ZxSeX/solWR diff --git a/library/recipes/tvmapi/data/staff.json b/library/recipes/tvmapi/data/staff.json new file mode 100644 index 0000000000..ce16637583 --- /dev/null +++ b/library/recipes/tvmapi/data/staff.json @@ -0,0 +1,4 @@ +[{"links": {}, "page": 1, "limit": 1000, "result": [ +{"keys": [{"key": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDiaMj8sW6KLfOGMbnxZ9RX5LI6+yXWNsHd+DuJKAdGPT526HyfUKoFYV4a/SWTXqq0sPOGZvgphFUZ0VhteZ2dOKNPPrDYumtB/DfBMbT0Q32vCAfKc6ggPkGOdtNzZRZg92SfAzMLvlDguBfqgN+z/Jraa7QpqzpaYd2aoG7GWAlT+ViK3VrbeL5R9Jzts5qP92baq+gZ1MBtmjCKXON/tG9NfJXPEImUduHE4e0uaLF0ZWQXPr6iLR4WC1OR+QyYFnhVmmFAiG1Z5T8o6WGb210gE7oaDhUeZAD3CZseT6vyZSyvBpeREI89kBOV44KlO1ExhIBFblWk07Jlvl9j cerevra@yandex-team.ru", "fingerprint": "ef:bd:51:d4:7f:6e:be:40:f8:8f:9c:36:03:ae:3a:58"}], "login": "cerevra", "uid": "1120000000026887", "id":1}, +{"keys": [{"key": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDnEfnmPTd+PgE1PoF9X5iXeInqM3ygruOGwAaWk/n9uTbK1GhrslbzUEQOsbq4k6urmeYu8iTzfE/NR+pScubrEV/tvOsUp9ysTzQQVFsolcoUWf0PV4bMTmhRqYm782yE7o0dZ6b3aFDkRQgRchcq2uSIepCZ3sIVKKfhSiUBia7faPwf5szLga/xXs0sTDzR1ZcbSxQdn/h0ZORiu0nreLZp2hIak9U+bCfgzXmoeWM3QieWgkK+m6IJslek0gqo6wXmwssg26Pv1zJVrul/qL36KaGnBqC/2kppOUbfHQHO4cqR1QWP2OYlwx2EQQlinAAmSkyiz6Wu9TybaHJT tester@yandex-team.ru", "fingerprint": "8b:ca:b6:37:49:0f:53:2a:4d:3a:7a:39:30:12:26:05"}], "login": "robot-passport-test", "uid": "1120000000021014", "id":2} +]}] diff --git a/library/recipes/tvmapi/data/tvm.sql b/library/recipes/tvmapi/data/tvm.sql new file mode 100644 index 0000000000..d5b88c8e98 --- /dev/null +++ b/library/recipes/tvmapi/data/tvm.sql @@ -0,0 +1,128 @@ +CREATE TABLE tvm_client_attributes (id int(11), type smallint(6), value varbinary(65523), PRIMARY KEY (id, type)); +CREATE TABLE tvm_secret_key_attributes (id int(11), type smallint(6), value varbinary(65523), PRIMARY KEY (id, type)); + +INSERT INTO tvm_client_attributes(id, type, value) VALUES + (162, 6, '|11|'), + (164, 6, '|12|'), + (166, 6, '|13|'), + (168, 6, '|14|'), + (170, 6, '|15|'), + (172, 6, '|16|'), + (188, 6, '|522|'), + (1000001, 1, 'this is tvmid for tirole'), + (1000501, 5, '2:laWH0udZPd81EIZQ:w7aKyqEeqyc71Nqmg5xljYuy0azHu4qVzTHNnYNm:jWYDv6qx1cFusMinouxi7A'), + (1000502, 5, '2:E0isTVuM9IPMaqeh:zjzNLf_JlhzbvjjaXcMx3bsrWvxIudWi8Nl2Omvv:EiuxbzkUvJmxTee7XvpJGQ'), + (1000503, 5, '2:daHjG5dIqa0CEVbK:f26HIFO_Z0x0zM90yyUw-QijvnejXnWsSszuxlDx:X6H4hgc4dF9iWdz28FDYgA'), + (1000504, 5, '2:KG4538iTP3DfUGoT:ISaPcTE1a0bgLquoJ60LwD7W76VfKmAYbR0SYo75:ctwxUO1EiQSd_91ya7JGSQ'), + (1000505, 5, '2:geNDu3vFC5MOrKBR:Yrt5xx1ienXP4hQuCWQOV1GSq2ywUStcr-9q1H2b:Q-ynKWCuTW8EqbW0XbOKDA'), + (1000506, 5, '2:JMGQIOfpUEbUmWmV:Wqkd7iAchbw95vncNvRZMXhePB0tNcPxuLlEMIW-:ptC39VRBbDCkB6X-PKUPKQ'), + (1000507, 5, '2:V0DlGBiqRBheMQn2:svLTsP7B2l6HhHkjxTrknhbKQyJ2orBK4kcw7i_H:I3MPBdluZqWNU1deFnia-g'), + (1000508, 5, '2:ITYgZWQgkf9YPmlm:820OXSHj9Qi9QxGrDxmohZ0aLEY0-1fAJwG9cPA9:20AFBj9j0FG_eymUJ2ZZ_A'), + (1000509, 5, '2:7AYdZ7tu77hRcdGv:NNZTGGmzo5WLrUfSvH5G0NAEmgUBpS9s1p-7ZPjS:AZ1g3rFPyJ9r-wT65urhIQ'), + (1000510, 5, '2:XTgBklls03cWC7gO:sonF9bE4IPj1xH3Z74gpL4Gb8U5K2Z2w3U1obSs7:Y4h-REnInnkVmZt1Y9DzLg'), + (1000511, 5, '2:baMBAxbtMuE5elvF:WsR_jX6R-PKiJ-bpmrCM39Z0ioDbDv10Vz6TISRA:HmM02sHq_wR0NkbL0dDRKg'), + (1000512, 5, '2:u8lLL0JzC_GpRWt9:IvKNqfV8hplbVDJGQ_SbxIOTSJ5_5lWYMoKKlxJO:WDexVfKF4TKV2MhAd00-oA'), + (1000513, 5, '2:VhvNVO84E01lJWqB:eC-V7vhICPtCwT8FuR_o0Zq0SnLnTtXxpHXxDoMs:wPyeemLE3p_WvzK2Nkq7jg'), + (1000514, 5, '2:b4cVAihF0mVJEapw:CXw0cVSrOI3h_YrybYlhPWK6Y6MNjrHYwyUZrT2F:EJQ1qgNFdiPXgkBM5BfGaA'), + (1000515, 5, '2:G9YiYqYM6p_gRaFl:g9XEz17LnPJlA55RxLEQu3wKRgLhCDvq-ZCPlW3X:AoUQzlIiIIcR_PmpMjM8PA'), + (1000516, 5, '2:uHXi4fULgyFAVvfS:lF5VTes1wNstvZDkiWjCtBVI-ubZ_zvVR-lS6xyH:CfF2bZwW8o_1lB9RNxDrYw'), + (1000517, 5, '2:JELmdJ9aIJmIA9_j:qMl6TNhrHdFDUJOLn-4AxmBGy5xqrPApcwnaihIG:qYsYAXx6fPgyBxVkYykEig'), + (1000518, 5, '2:iOO4Gj3-E9ypSjdl:0nRDxDccoFFt_kTdRgo96q1lMYQr0LIeptFemu3a:AA14v5ixobCuCkQ-bPcZiA'), + (1000519, 5, '2:sdhZV8riyK8S_BQR:0RkLd9QSo9t7bFxkYqRzxRJJw-vY5J-x-P32i_Rq:YLIajFH70ZwDu4vDr9whCg'), + (1000520, 5, '2:djFYLoQDt0-L9T3l:p2STOUIZ2hgAnT05Llt5CZQSQkG-ocpjA83_Nnin:d99Y9k9IiXlrOMxmU8TyIA'), + (1000521, 5, '2:kek-1jXU2QGuB1jk:r-oOna--NrJ7wNTWZmV66mcG7BRef9BrCMhrqeV6:2-cWynuDw-jKlgG8VMTCJA'), + (1000522, 5, '2:u73pk7AhnBIYp6Vf:DPdQjFXyfiXlOqgYUSEBeBSvBq--B9wZBBEaGboP:y4akpriHob-1YZQW7CyT2g'), + (1000523, 5, '2:EJk-Z7ai-aghm4wn:i4-k8f3h92ypTK2hV_3n3rlRDVTyJWbsO0ZvlY5E:84cZN1pZs9jyh6MVjaZYUA'), + (1000524, 5, '2:gpjIdv65_MIkJSR_:S8-cz5-rKyPu_YpuWaCm6Cvd5216UJ462loKmtnh:7aXCAXudf0FoV9gsyrwqGQ'), + (1000525, 5, '2:RRyZY74THKpMF18T:rqmSy8yaFcshxd2m06JN6L70YzpkCtr7KWcQz-Ed:jYgJct7msl9mstsyejdBjg'), + (1000526, 5, '2:Wkdmx_Dp3zvUkWQO:Ai8XXqJbqCfRFw4vWV9yAvG1Xrv3LC_NZPYae6yG:uyXu5oK7E89FQrOJndurOw'), + (1000527, 5, '2:hW7izuInoBGuv8ax:SbfN8UaxvYT-GKaV9V2piFapYhRyygGUcwAfQHyk:ElccQQQFDTw8y1sh2O9MFw'), + (1000528, 5, '2:2j_ZvR4EwNucCxUk:9ThJ83pSzWx0OTlCvWUEVjx1eHG9O00oKhxuQxxj:NvUc2J1Ea0gs1aoqynIG2A'), + (1000529, 5, '2:cNn0u4R5kLuqPT84:M8Sw0Q44RVoFd0bn62mdCGexbfdPGLlTqvNkELog:KAuct7XjaWm-2LrVTYR0iA'), + (1000530, 5, '2:aMUqoKgJbrJoC0uu:kYQMpDqU12AwHKWFLEpkFjCy0tD-IQozt2ZxmYdt:IdCu84j8FDUNsyTnXSE7zw'), + (1000531, 5, '2:62jDZlhC79P-xPCZ:I4Oi8AX-aLE40LLb9x3oxw2ACfTlxkFQMSWXdS-N:X6s2glYGk4c1ONr-ebGWbQ'), + (1000532, 5, '2:vYAPMz7UDTVdU_ce:gKqb5Pa8TKI7U9nOgq_zd-oB02kBwd6kYB6eaOzN:Azk72M93oAfQ33hT-z9m3Q'), + (1000533, 5, '2:_aPMXhhguCLIuJU9:0r4YCPbqK_TJOiYBhZE9iC2SVxXkV9WQviN19-wc:6Sj5ejmrG4PrTITbF92qEQ'), + (1000534, 5, '2:2XpO3eoRJbcbmM4b:PoK87HPw2l907Ga4CUiQkUwyQT1zKahEHg04aQFE:ngco61-31xFxQFUYtRpfbw'), + (1000535, 5, '2:qqgnVolNSc50UBWp:hfK57xl7fWSHR9Ljw_ez01wPkdMSyCEKbmG8o5gW:SL7mvMEWy2oxxjxcJQnE6A'), + (1000536, 5, '2:e6r_D7kNcWiKcs2z:gMnmNNN0mj7uiFvrhufGBvHJYlc5fdusRZwd3jJm:V4TlrLvLt6bkrDuEEMXq8g'), + (1000537, 5, '2:ujR8rBs5jR_5MNn4:nUm6auDNJF6cRSN56FQAq4MrIMdq-MOvB5Xa34E-:emZode7gqJCsRwTcCI2trg'), + (1000538, 5, '2:1fU_nMnSiQh8IHpv:_7aZi-RpgHcV3-EXRqNKOH21e45npMqrRRQ_bc_K:DlwTjK1cRsnL2y-hT7XGHQ'), + (1000539, 5, '2:tJBRNwTVw21ib61g:F1oGEaVUIRO0NPi-QooujvKbQnBwF4WScgniONvJ:AR8Oa9DOZwYO4YBO27nUgw'), + (1000540, 5, '2:z9T3xQ-0SoZLR6iC:W1SHm4hWXh78_lU0l858JBF4J1GwNQtqOPrp4zAb:nT5eei25qpqz8Uy9ji1sLQ'), + (1000541, 5, '2:KKwDmhKe_NOqIy51:uhn2BF3urlKoSQ4Vi7aZJtCK_sub3AHc3SiSwzSM:W17foTbZ-J_dpZ4gijWleA'), + (1000542, 5, '2:ZSZCp8ECZJzYCAvR:sEpivxBxVoqtmlPB8k7GEaTwR71tT7bGHhbUWJMS:FclGKW7O5N08EnTQJBTUPQ'), + (1000543, 5, '2:czg_J9NmNcjCzGm0:FVHiWcgV7wrKN4vT_YykdFqfJWdWxrm1zed012Un:RAbFadB8j1RXRuD9gePClA'), + (1000544, 5, '2:2_2mZp5u7rZ_ndv2:7eza1NMFZcZxGEsl3OlF3xQOgEZ8tCII47tQb85O:oN_A6LeeCJTT45TmHK68yg'), + (1000545, 5, '2:q9fN9r_PS2mKgAkI:jzfLTENaOzTnIufwlE5jEZSP0uQusM_oXv61F3NA:49TSXY4x_i6wzTFQ_OADAQ'), + (1000546, 5, '2:TVuyfVinJ9H0tipz:U0Do2X6E_is9EJX8Q-e_n11igjdByHBJkjXWnH8_:XqajQM1iSyOWv99sdmcQcQ'), + (1000547, 5, '2:-SNDYmwCXLi82ETv:KS3Thgj67oqXbEVNmNGvfuc_qEFDkZoIabECnJh6:TAD_JyxigwHbXF2vcVPTvA'), + (1000548, 5, '2:JlnO7NKifeCuFNwY:pa-3J_7JG55h3IaEdnWlUwBvzB_nmfvLgnxjyeom:hnaYWVUYRdtSiNjhroLcaw'), + (1000549, 5, '2:ObQsSt8GLaTOEqBr:1TFSYFZYFw5fWvi_E8Xf4dw-xn0lv2L07_dF-3oO:nVAb8zj39MxCaVkQvP7vtg'), + (1000550, 5, '2:pwoySbETze67TDCO:-smk_WjFzjOuaCbbE1unT_jYmTtNoKeixKhM5j-6:SG1q1iIbppKm3iPvG1R3dw'), + (1000551, 5, '2:S25sbkoXx6oyx0aM:LpWcsr6zxsuip3stQiIFZIqOhGb1Klh-qInDVZOa:UE38JsRultaz9IR6TKjZ9w'), + (1000552, 5, '2:EQm1w2_RmBMZ81_D:zQslMNxzSUQXdMrng-JegFKGbCFkYGC1zlV9KxPM:dhrkxrAnyZvGhZz3IXjTLg'), + (1000553, 5, '2:SbOPVWNxVwX9OS-q:GcNxNLfdg3dcJB33Vuf3ejWuBUKfcYaLuaLEVs9J:D2HDbb8K7RWTq9nbkYWbRQ'), + (1000554, 5, '2:FyvKgueosQ26VB4i:Iz1pcJ7SXVQMDV7PnWFK4mskVcNVRXaPPdBh1MRF:JlLtXKnpA5VRLYwaZ78-Qg'), + (1000555, 5, '2:c2H9sb4ovM9Quaxa:DHvbnrht9_oELwJm3QUY9IvDAFaKht4RFSWMjacR:rCi_wXQ0e9juSs3o6YmnwQ'), + (1000556, 5, '2:1yywS_osfg4qlbKz:tnFkQGG_sYoBkKvXlrVfJvOhwaiIdo6oeS6Usg8b:qkWPlaqHPIF0qlRLfTaQSQ'), + (1000557, 5, '2:hU6iZqMcnGyM_R2U:rOoDpoAC1tqH80QC6gH14KygtPzcKrvhTvZZTB-N:rhdmtlLLceV-RLmmWUI5Ug'), + (1000558, 5, '2:mvqm5cTya_a9XJOH:Q8NhuTg5WVERmEZjxdx5aU00zwmTcaRVQPExPjCN:qumXyKXAlB7t4oLXGOA15Q'), + (1000559, 5, '2:gbfiET4Druz9PlYO:bryPvuhWijqK42EkTgPwWfQVjdf2JRVKqV-c8AEM:x_xQQ1tIjIKFOJA_6kiEVA'), + (1000560, 5, '2:Mz18kmSxOm1WSjUr:Q4aYvn8gU5tq_MQRyq94b3IOgCcXf2VdPoAzCI-E:mehRmrYXBXTCgEqy07sZZg'), + (1000561, 5, '2:PMoy2sv2GaNyzIxx:ITDIBcOpCVaaFLoTdGKAjQZVqPfOXdaEs8o44zKm:1cPm7JE6zDetfzlLABP5nw'), + (1000562, 5, '2:ODrjBVFyLvMHuqZX:yCynrW02YMxuBhwGTS3QDWTJM5GATgN4sIBPdHTL:uvVw_YAmxO74GQ97upHatQ'), + (1000563, 5, '2:58GZ5WTNKqNc8fv0:hroOmRWBRaAZaT8t1ojXU-ptgTUhlJq2jK3uIXzx:5WUvHTPm5OG269txkY2cNg'), + (1000564, 5, '2:TI5CY0euHMznIqUm:N5VXOCbQhP0qO2lO23P6k-FMtFJM53rfDlTJb2uu:bDra4A5S8himRapb-gVDgA'), + (1000565, 5, '2:mWPnAtxVbQg_k6xw:jny4EaKrmxg7EfQZ3BfP27zWte7KrTAQ4Sof_Ihi:Xc2J_S2XBAvM7v1f-WSBsw'), + (1000566, 5, '2:0i9GGU1zBGjXqLJL:VHt_4aHHdGXuI-RHSwz3rAn6gxh34976UxNSCs4S:b_7K51YlIQPAqNo2njsdnQ'), + (1000567, 5, '2:cnwe5H5wjI6GOnvI:-zKRGIhQJMs-9-F6YCzwYZJDc3fejCVIbpLJyApz:ZS_ts4ddp9xx1OAgjZQTtg'), + (1000568, 5, '2:upsEzkjfV81oGumE:qu-NbQLPbx6RW3DFdPiTp-T7T1IGpbIIWl_Wybb0:Lytchszwmc8rZ3xPBQaVPQ'), + (1000569, 5, '2:iHLTtiqcvoVrma5U:SDKu0-lTsaz78N9p94z2rhvjsy9e4WpE8kZA30E-:aF88ehqAXBhHIXQoZT_Wbg'), + (1000570, 5, '2:pnzon4T9d0NZYf2W:RK6clDtCy5E3dyb2LDkuhuHJ3SLhxfRdoKjojen9:y9hJtRuH_b4nQ64THBoKUA'), + (1000571, 5, '2:M20myHYFXTb362rz:ehtbs1cDLNeIPPi-X1OjK6Tb9BiTFBTodrI1eU3E:C9POT-fnacN4GZY6WvIi3w'), + (1000572, 5, '2:oOW_CzLdbAwuqx_S:5-8pSGU5GO-efgN3WNoR9tX4l8ot6omHoz-88Us5:vT8hNKA4d51mTjLJSejlgA'), + (1000573, 5, '2:Tfhxe-GNpb4UjPrP:rzM1rS24FVd4JLWunx099UHwugnxaYrGkkO147J9:68fAzhQ5UR_2SpPC22rjdQ'), + (1000574, 5, '2:oCjAmD0sftLTWX5b:c03BruTzS8FDWLHjNsoCLPUwRIMUcn63DKCzc8TW:NsCn2DKGTqN8gunzI64Lqg'), + (1000575, 5, '2:FgbbYF-jfGZq7EpU:Nz4g-17KbNgUsGC7VdHpXMZ5L2olmQgLcfCChU-l:fdG8EOaaLXQ6awia4zTcZA'), + (1000576, 5, '2:iVIDEqDPDscFE34P:yIPREH7hlGf1kCI1g60FUYsnoIRce0KJ3JyYyjXl:vsdsZkbXQ7vVk1JuPd97XA'), + (1000577, 5, '2:8-IflTPf_XJYqu9A:Do9-zfjw_-eWET0hdCtQwzIR1w6stQ5eKeWf5-1c:whV2rBkFnOhONO2tuIrZ1A'), + (1000578, 5, '2:OJySLr4ms8pqM747:sSJkOkeFljJGvVGXSrxempZG90_ousZ9BcBnARpl:1TjRxj-yZqB6-T7HGIALxQ'), + (1000579, 5, '2:11dtUtwAPsnMrmzJ:6UmNOHvfmRmqCY5aUp3zI489p6w8pP2Cmq_Nu7vr:oO8CgStzxRWHJ64sZJ-_rg'), + (1000580, 5, '2:f2VPZy3oJhGBgCnI:daPkPzLkMVAgPC6x2JdCy6SktocQsdwtiJuvCR7K:B4e_dSIUGH5pMrF_KUATzQ'), + (1000581, 5, '2:cLDvN4Mh2__Zikkp:RCCnCl5aCrNbXw7zrw3yGtgxck7XCBcQ5SyJDH5O:h3WaSR6G6iR7Smlgf4fgMw'), + (1000582, 5, '2:Pzh0wHFdvmMR-FnU:eZzpZj9pLE8aqNJUicm9wA48-MRQlnN13xXcJ5ig:RoTHBdEuGrBo_joxFc3kyw'), + (1000583, 5, '2:D9R3I0HXYnq2Naej:vtfRDxIw8p9ExxMT5Iup2Kf3N2wLGVExEPrcOmBs:FnnZjLCxqGMLkC3opiH3_Q'), + (1000584, 5, '2:OCWZ9drBaeqj9ooQ:8sr1AOIwMcGmZhYu9LREVXpH_PCLZrqAz4BQAfM_:jniYjGkw3H7JlWttPI7hQw'), + (1000585, 5, '2:JIUgGHAlB4BTnAHv:_3XNmLqNImDkymlUMHJtZn37ZzBMj9FC_owMvGO6:rhkWpR3weqBcNC9CvxYOxA'), + (1000586, 5, '2:HE_cwYaqcEq_X5uM:GOIW12INz_wWtaFlFnnDb0CciCq-puF6pTJSnsSa:Y6xr_CHCStUq4D6jOy5kLg'), + (1000587, 5, '2:py2gtY93QUHidY5j:ojHqgBCkV0yfJpz4OSxTILGCXZj8DbAykl6n0vvk:Ha17OwU0kNdpNZQp7EqmTg'), + (1000588, 5, '2:Yz05_d4-drOJkjYR:CzV7qIeil2C25tF1aj0iF3EXM1SC7Fk6C7HNOPHP:sSoAtX0ofIdeJWJrXrsV4g'), + (1000589, 5, '2:uj_u8zprvxX9zD_Y:-_uS0YHs5IwnSyVV3Co6M4kTNle9SGc4XSKbaim6:uvNGS7Ovh-isWV487CasIg'), + (1000590, 5, '2:qOufgoMFDHoHztU2:U8X0QspTt8p6UHmOXW26z8D4sB0vyb59WLcjY_YI:dj_TuYW1_vPA555y1whMVw'), + (1000591, 5, '2:aLcHZKefZkT-5mA6:BCEKm5rNhfYg0rLSiCFIg8jjw39mo6Gzw0cIpVEf:TrUGLgLQ6woZL2wFssHNNA'), + (1000592, 5, '2:MBDmDtmjIBY2vW3Z:1-XdLeya2OdQbqPedPURrJELx7qK2mFxcI4SO4VA:cZ3Q14_qtiVWddVB8lpDKw'), + (1000593, 5, '2:W_scRZo5XFuRJSf1:VyNDW49oxE9h5sVeem9J89ZzlzF6W8Cs1h2SII4v:ThCrNzZAils18Y68PzQaEw'), + (1000594, 5, '2:3DE0dsrQpti5LTFo:0ix_f4Aj_CkMgwG8zKvA9EszLaFkFlZ54TbakS8l:kLEIyGMwk0LCsdP2IW7o9Q'), + (1000595, 5, '2:1gd9hqbgI0E4W28I:KyiFnT0b9JJq7ZP3lBdVK6pRtRjZ0YdHDrBFIs4J:J6oA_zs5J6HrXxZffzVYbw'), + (1000596, 5, '2:Pb0BZ1n_7xhVSZYK:6ojkjL2EgBsu90qljNWdyGulJu47dBvNwigS6WYw:V4QBWtxvGg96pH0wPu9C9w'), + (1000597, 5, '2:bLPpTCi2b2R5dUxt:Ozhd-7AVqCIJzCL6iNv5UmhpMx6WeqBoAvba3Ot5:3UWX39H8MWJpWT_7YYqmNg'), + (1000598, 5, '2:xQyAWsoDqjR_POH6:bZwxht5YpZjf0qZk6J1rCBJQXjOkkJSkiKMt4xYE:m-Hzqvs_pckIV9PdhKqc0g'), + (1000599, 5, '2:tRtGweoTGw4wyl_R:0ygW8d8S7j97etG85Z2jjpp2hwWwaEdkA9ZLQGmn:3NreHDFl6chaq6yhQv0f2Q'), + (1000600, 5, '2:nbCgg2CdIIQBEmQE:CB-XvIY_HxZwqQ7bWLlp3QFde6xj1O90tCTfbvhm:GVBpQealINSR2Hi9c1Ou-A'); + +INSERT INTO tvm_secret_key_attributes(id, type, value) VALUES + (11, 1, '2:7YbhlHl0nIsaZagY:_MaLvZ0ZsYKdZ6CqiwI5xbRnZUIvivBR6Ir6n4D6vR9sNDdpgj5DwSjCYYcZUcnEGwoG1COQCVIpvXKAc-w9v2bsP-J57S711kiNqT4qbyJ7WnPB10wTB8yEbAZKBv7MVh7TKr-zJcxvfhLYr95-TTtfYcwn0y0N2oWwz4URdByphjn3-Zy-HXPK1bkaBZzelIjeOkN4AKsrHToBZM2jjnFqnCAFCYL9x2CZxPErHWayeJBXtMnbPblbRA:5Cd81ytYslGSlQVTn4gDvQ'), + (11, 2, '2:r6uCZDsDpvQG18QZ:lJICo0pfD_04aNSzBSWf6FZC0yvL1TA7R9mZ5RHzCbj3CQLNRri2gTniWyv-w5ClN-957GIU8zeYoqT9iZnlhciju2jdKxKuPAMn0-C12Vo-tpRcmywyNRa2E38Dt7NVv7pXTOMB_xcesiH8__HjVDceDcQ-qkWTVo0szSEEVwlH6f-l7hlNOA-bE2Y0dqE2UNB6KpmXAwdS6vFi0BlcPpJT_zA-iMPKHdVjkMVWgEWq6__9S2Br3KY2emu0Et5dOH8rbpVfDYWywnriUJ6UfW1kFMSVrPTcD_vJgxDw9uG-cL41FX9PPSGs0coH04wLsw0UyvGflrUA0PpN-SKSXx-U5jnKa3LsPX9q7FOSQxytD2xaBDWzhHb3QCHIXTTmDohyswJLvMV_BYsajHm1MnajcAI98yp_I2n6Llj0O02fBE2q14HbpJTYqWRIuSO2wLq6eSOrQjsy4YKv0qkr37hGYShY27vMrX2Quy0GwITA0d-vsiE-GPvhh445V2mmfWdG2gHPz3VTzqSbwlk9w2dUKOmmZVt4Il76F4fsVw6BcY1lG-TeOXzXm-Vh_R1dw6siis-4bysqXp0o3eoHo-2ntBsNNKmTcwFcMb1kpVpV6XgqFL0I1rZM1OoUY-w6OHCioyAcUd5Ou7GA2LhivdafFEwOqdQzAcyw4dNaafG1-sWvLbWl4PpXvgMw8pPJ-2VC1V9hxk2gQRTiNe0ShTZVhsqOgbQgqYMl26I4c5uwYqXwz2WisJ6KPRlAIY9LhIFlTBSwZc4RYH_NVpDlpYm5YsKjO_cfNSrwv11MB7wIAzQat0ASDtYmBvKpXS-_xfV_mzfU1MR3qyJqEBjniKI9MJRn_cQMH7f0rFVLmjBnIHE0hT2edpYS9XYcnV8MT6yB7DUQi8ihL3iFGsqbIl5Q-YBsc2NwQKBxb_ZK0LwNCj0NMDeojgSy_TEE90t0ew6BFZiQcy2g8s9clgs__ERTfAutzcvQ3JSAVJ63ZSpbkJOMhc-8P_geZOmevKqb4mALGJs4IHAHFYqKcYJfaoa--xTJxtIdBtl3Kh8wxJq9j6dwODvk:yIUUn4YWuGO7ZX55vPQYRg'), + (12, 1, '2:Df8zob2vAPhTXMTw:TxGEMFi1zsQgu7QJpehsjEJQduuclrS5DpjLuI7sZDBaH6W9P6-cqejXrYud61WH39xI9rzOnAJP1xL5VPQiOCzOeB54Hmkd7pSSeQv-7i40MPplqC8YQC5Ff_ywUlWhgXUEzmk-137Fhm_VujhCa58LAAywBWr_er_haa10qMXdfto_IGmWMl01QsQ1KhE6CHVpIGEyEaMLNS7_x3dNk8nbZj_pwpyAcYTfQxxlnoEzp0umS6PKjsT0GJ8:zFGTFDMOl36l5hPH50rrWQ'), + (12, 2, '2:0n9FshEW6Pg5wDDl:6XBHnGdCm-1XIHh8c9WbodhmnAQq2gc9dNmBUxsqWTag0Zoi6c02zOkPUTi-WZQYnOeIcELkZyHeNqwIqa5rBSET_dBAaDpnN68L21aMGaykLU935hANgCNt7jcLueAkMkKoBOQg4ix7U6H_qgOR0UIvRourF4BHa72Rl5-6x__FnUhyUthP64pywjEgEpMR_DMTMXUc_QqFAtu99__TdDDJU_SnsCfi2fin2LYULxtES2MNgJ-meBp1i7T1mlMTmqchRl8kVnkpKwWXOBcZ9JerSzUGpLtA_Sa_XsBYNIsEhiafwBHO_7a2G25A4bJhG63xUVJawgXJK3xBAPFNz7mSlG1MiVD5gv5YsloT-2732JPsBrQ9GSS0_KtEUdMSzHP-gLDd2KA9hpPlSzDd1x5qT4xf7YoN3Vv5SaB49kdHwMlDQTcue6BXnfC-7mlVYR9HwECX232PTQAjGdOw3i0NUpI2qVCmPJt275s80gl85jo2C77B06ngnMHDmy4CUYxr3Z5AxN1veCkUF6ViI4RcuihhaUg6DIZmcTDMEF5nCPmm1i2_-qRcORrp3h_aJaG4dWPw28VEojws-GLYrOjZ67YwSerOuMhg8vKIpP2yTT7m3ioGZFlMubQyvB6KLRdf6u-_ZshZXj5xvMNraiyerXVI9RUI9sg2hWrcnhO7fHZrWbSCNpg2RGeXsqDcu3mEfEa4ETK6mJ_EF4WJPctkiHkERhjrfF8B8bOIdpnPqoT5dQ7Z0B0TFWI6a_iBVFq3eXIYAr-Vpei7lWATqjjmMzVmIlVmovQvWQTXRcOSuK9IdRdlZIh_kJiwzCWGSx5tglgX0GpOUrED1aWvYETeoCo6D6SqeLVFjPND8PKOf5Tic0khJXFXSqb1lh63JNG4ch6K6B0eEyG10MJ8MFNS7LwYsmDcaSe9pvfZtxmsPFZa-0W7cfA7C7vsWKwmNY3pTx1Mf-NLQxs2CpcvYljwjj3XKua1okHSxVp3jXVObrNVRyuw-iyJfXsRw7fMcIJQIWy-oLLlbOFw_Nn2X5Vm6f2TXjST1OpDxmbvSUYh2IBsOBZ0Hw:8PAxRdI_H5FehVHQJt_2rw'), + (13, 1, '2:WiHYcsoSi-6CVuuE:jY83J3929Hi3i2CgRUtS1fs_Tf2AnGcW7TnG05ht9RvjbdBgPTYtShsQgXUnWWOV2fh6-IlVpOwnE3v9FqGT8rEhy6V4QOqa300YS13hknms19Lt6UKEHsJKrYemK5EGqkjFTijnyC-AMqOtSeDwmGW_2hoNYOmQgaJD93DFiPEFU2KYt0TTfr5avXjkH4TCzpGUGRfPz2WB422Msh3dc0_hI1COkbGaAUGLuNpVcAohLOvu_fAJw7tgXNI:CDdhvDyzAz8cTgjsfw-dIQ'), + (13, 2, '2:h715IWhxexGa6LJl:iAz3GelexXYNegx9UuiCWMdxOA5RouRD7cZY5xOuTLyab9G5mhwzNjDlhNx1mxcmp7GukeqViZusBzZ8yOrcgA2NjPwgWp8s41sb0z0PXZYJGqqTteKV7o_ALy7qRmeSb07x3g7v2NX4ezciWJfPuJtC6Z5AibJVt6_YgFBvGrbJwqQLnDslKNUSDexaMpqWOglwCNHTS9akXS_Wgs6n_bzm-3qB8peY8pMipJMsd_XofNTRBp8MH2Q_0PsPgfpRdtY_heSSjpIsVmCbT_kNP7zAXotKlXA70oAn4MAd72Wkt3pL-UjMlNY972H2jw7TNTn0207wwZ-37ZVHlRCx2kCqLkMIL1s_U_vEyGlnR_QQQM9dNA6ONe0mLW5lhEeZOTamWjvYwTKRma1NbyBk5EpCJMMbbovJdItHsJgic94u0r9KGBgREyp_MbEZcXbkd28U7eF2m1wY7TrpRsEc4ZLTRUgbqW5F9rqee0Zp_v7MNX8Q4JpexYJ3ok8U5nWBqqJQaYKj43A0OtuasQDkjmrgAEVEVU6IlY1dqBR1DRU7ZDeZk2yWErtE7ePGept1savAB7aUdZSw3PkTRHxKJqGZQJtXzF8Hrt4WUB9S2QOE9Hm4OWvjqHfM9IUF2lhDF82M0oaUcpVMTLW3ytjsFMKOkZgzxpp4GbQXsexL_WOIkURPlBHww1WYHJudazPqnszxIrsRtI1coRc6HiVJ925hlBKY8g3bwzTLM2cfXMkZZwmjajLKBx-926LiCJ2M-JOfEMnFfpFpOeLYbTJ_dddxxd5beT-eNRhvOVxiHvinw3n06CxyUqh4oHyY2dOQlNHTh1G7QRVBcUXHgRaYxPc-Jum3OD3Yxpt9rso_6saSmbT15LeF3EhyXZVkSPkAcWK4MKLUH77qxchYr6hKIg6SeZ2CeXByRYvd1k-XbwOrqFWJNdaf1eBSa6mgHgy5NKCMbQUlunKcO1pWAD_watKQbfW6bJ-R80_4bqwSCfFMBcJETNxf_Lef8Xg22HvBRdcYPxNa9TQjoOy0VU0mYOeq_NQrw2pXZpm7teZZrIqSkETKn_oc4A:gSyNyreLOTB2YSbM1ZOTfw'), + (14, 1, '2:IrGSkY3MxWglnFsN:wbgibmmeJPMAaeJl3__OQP-i4WUqyWgUFDtZOZgk6z7TApmNbTDoNq59y-HAWWiSWckVOhZI25MwjQu8nTJyEC8r04l8vUCEvwP85VWUQWJaw_OvpWDHdytm1Uy4jTwdxdXPxIzSEWOxQgK18ZKl-gJhoH6kFOpjvxIEJYhaCSFxWoqqYUNz7oFvuLjUMFSecCtUbce7GXPzzE9sZzxH4Crn5GRKipd63izJUmnYxaLYm0zg9A1h9jPMYA:fUK2Vw79Vs_0knIyF5S3HQ'), + (14, 2, '2:jNYZ-e5P4JTLlbQP:QeyMk0DB2ThiCV7en5hhwZZhOxWw0R-z_sWJlKNh7esKm5ONuvkx_NjGFq428qBuc1KoQVdwh4tvT2Be64DMIwL2Pk4BXY8xc06XETqYnMWmc4T2vyBLYNbqsnE4wwoqM4swFl9hUVxLdynJn30H4c9n021Hyzcjb2OsCsGBSs8jUz-OysdFdS59CnI3su26foSSXw2s5VTDNKfjX7cZdm1s6EYMwjl6Dah9q8LLwmWlxKORhE4VXOwljyTVHB-HqFIsxVD1MitrOBXXiwWETkWJFpBcXs5g0bLUChYOJx-W66EwKveWd4q3ntpBrVKRURKC-GuHaeXHMventnplUiJHa2vilQJqc8radcxmwLPm-8IsOk_3BGCrM7NbDisqzPq1EWEEgJAx-kr8B3I1o0BA-XiA6CkvwORzjpb4nZR9R13R8q6_VFl1-_Xd8WP2rn3bJDQ0LnBHeBjPr9zxTPsel6OxH2hohv-u25fhT1lO3wF_Y7un6ha0a9RgJLAG7Hi1IKHmDv5igFUv_8NfYgmIF29WD2iYS9M46vojkI1d3uxJVoK27to4CQfWepk29iYsWpp6mUzy8mMbhPt7NFp3HuHN7xBeUlB6Gxq1-UNQEbC61bgN7c-lbDzSZJNFv-NT0phTmnm24kkAJEiP2beuvVV5akDkeABXtfqf4hrHJOYazjBaTgqu-EnxvuRf6AA-EvQjst7y6olcqS0pkrjhldwkXN0m8W9aj_aruerBGIEz6_oRbf_SzAtjDJgsDIvMObm9GdxN40oTzZzspB7YUkiH1irFmb38GionpHUwCoA1bLa6_wvW8OYwBFJ393CAmio9pr-WrSI-L9ztvEBjXfiG1WKy0jqKE_3QvDw05DmN57WXJawHrmGTI576FYaxC3QO679KJ56A8se2Psfo89yvDqoHHn_R9SK-Dtg0g2m6ron66Yd3qJ8seiO5higQaNu9ZCLBqV-5uBqT63MELQztv-chN8-snXPa1C4nJHz03Un4_cSZtasEBifApPEu3aM-o0ske6AM4LcSzcpp2w597NNDPVxJZYS7DMP64tpPkjU1:1vzoNnD-6colsWPAJFi-fg'), + (15, 1, '2:c_ETCkusuxVnpSEo:Bz75HmMt-ZISKVcjGh1E3TR1bBbXXKcLmdOvv9z-89MzVRK4kTahouH-s9DoJNvJD6qv-mvElWHvKN6-WU987c21oUjxLlGI6WFMvp0tmUUJAUWeEpjhnb4xqV7qKIcoWRkD8-ws3BW0iMGI8ksC6-AeFX3zzQ9bPQms6KvNKlRibuE8eSJaBaF3Cb0gXh-51Up_LglRp8PDuf2uP8GRIM5Xui6QULM4vv28FEpIIyn_mRC8gWG1I1Fc5A:H3udl-uIufUzaDJG576ZVw'), + (15, 2, '2:tnjjNmVDQIh2cfFe:y0fSCdoOVulR2KMHBV90O1y52r7hTSFZ39_f6JS6i1baQveQZ7rXd-DLuNLPHh18iz7Cx0GJjev8YVGe4FxU7c3-AEZC999NDT-vqRj8735VI6G1FG4O0qiiVA7EpyURxsFvw9xq2vojCnirw9yY4pZd4HNQnb9hCymKqPUdE6_GzYKC0ESK9ZYtbnLfdH0P3r35IRZm8gx7E99mGR3xWbxLUAJV20OZ6I7ZAsPutAx8VYeXiiKXpTbHrxJsfnLgPiPKQlwBZqrhDw6jj9UdVB0jwOvUUSJRSDBZIY_POvgMNebqHDSYXvmpTqMc9wckRuJTAMHRxawM36gjupVDGUhDtfKEh7-ERNhS1sqZwHfvlptswQj4WrB09E3i67phxUH9GO142m6hZsVpJFf0F-BGer2WljTMN_NqaWEvTjp1SQxSDl288SD4KQkUUovc9z078knIDNmqnmvcSt9lTZ7vaAcwrPJ1a72gKaPLe-olRga6lnGsTXgpohgQjYXpi_G08i4AibzJrlrhIodYczuQ4O8c52aGHUoQFV254PPi3tHwd2f7L3j-PMv6eC4t-7972Rxrmt9BGrbfyvZ-4qH4KWslzCKzcLCqYcdVSANIn3fMZ2RZCMFiE3WWyWSH1zEbGkE-DdDuFEGpybFQwswSRXSk95HLe-oNAQYl_33iQTxYLplpyrAwmiFd-kLRh9FMvNDiU9XvOYPrnAMmeOxzcCdcFi8NSjXFGgWe458urNOiB7fNLNJPMUnaRwYTlWrQBhMKLB_-xsAn9qEg1GkmxApR5s3pVB0giMwBFePgXrCK1fU7gXghrz31UR504y97anvV4Q0REJo0wNXrLDSZwkb8Tz21XKNEOPAo13wWCM9d2n4EtfNRhTNLOgGtMJGAYr1CH2xI7XSlG3Ji2Vp6krheBzy4mzEKkT219yQ1_ELpX8ejXIbWIiyGx4UarOEhFFl7PYgv0UKhMTax0Oqa8H5RB874bPq_2fCMLWeCShlopOdqbt95Prek3glUyjTN50rx1-vGNBXGRIeUbjbfgneDq6rV-G9oKjwe5j-2rLnkAgra:kbxMH81ZBYt-GTJQMDAUSw'), + (16, 1, '2:a46VR6JExNtreCmX:bF9bxIEj63fSI_PNOH36UBoTw587X5OdYEl4K33JhGBs4DH84VKTARcZ6LiBzwPwN2Tf__PWy9fAt4V4V5eAKKOjS_erpv8FCB5LC8dJuDaHOZOuZCTDhRom9gLbGxotDwB39LOaH8eBOEf6AkwtxzRJrXyZBQFtb8KlvnlnqetfgBLaf6111MyYO6P4SiA0LXrE9MjS1WY_R76K6kGF9_8j9qbCke-EF0ougcqWXsuaz8mEjpHG3O8M2SM:F4SLvUdC-07ISrovDWChDg'), + (16, 2, '2:v063VzSOGooDsKyA:ZQEhJIZS6WcabCE0YjE1Q_R6pO9xBeXIlzlHNAlipv4pzZ2b9HKi8fFahNoHJ8709dtxvi0WWnrhxdkCMKe764fPx8vuHtMuTUHderctwynuawb8gZ_9OUQWD9nD8mvJv7sY61g6KrrzD6RZDkSXtcgJp3G8duN_Hmk2zhZHcWjbeDo1SmfC1lPHIB9QiBM_OCr3JL3dhOQK_3cGeDTBbHmEDsTTUlco5cPIE7UlsdwrJ46S1d1nvCT3B3nondfn40p9MdZ0-1FGNYb-d-QGrlvZMZpdIfgBgneUGHc3N9s1Yqp6ZLi-TgDWHi73t88n-BbSTYJnLUgYrMd_UZmvudRZ-RYye94bdShoO6yF4v2hfUvxQvsk0l9fayXiP0cAx8EcZ85Ndo442X9bBjwRr5FEcDDWc0MehJ4pW0d_5fDlcVHeQcrdMDHGssPpbpwWEv3gveoaf2S8AYXf3bu4A_RnaUN922T0Cg79pyRfhnsxKVHdezc96B3si7b-wrFqSNxJxYXj4y25licxLIVDnva8OVCLB4HGKnx5XEeSB-ocjtYmqeCcTG49NuEey-3jRgUkzDwNVYfFfIdHCn5I3Rg77KwpYijnNo88F0kUzFui3DGHxshT72E0ZBt6ieolMaghept398KBSlJ6UnWp6P83rrznmZjVhyXMaiCBrnwnS271AP618ypLjh8rKM0h6zjj1Sop-PVtdkV7Z1oFPyLGGJLOaOFO1OXooVIXgyLMT_FLnzulQ5Las6SvfL_44fT0zPPigyNTM3PHmca5ZjOBo9RM_CYcDXdZ-IADkkzfBiGgzZiRRWCdFP43oZjP86wQBOhoIqYqmVXrwU--9JAGElBvYnyFcmmXaVRdYldx8fF0KXFPZVx-O0ydonJJ0PxR5B3XzUv3B_TU15Xvnc-uEatO5JFynhF8Gombw9_jMhmm3Nog7JDDO0kHPybUKYbLMB02nXBSFddYrF8SjJUknYsxII79YNFepnl1jUKi_eFj4EwHfBq0sLKj1b1tBKUjrMSBBZw-7t6DzegfCzSCqsXJsyyE-hb4vjfZBgqUrswJRE5JmQAv:sCeBbbtxCVsU3uZj6Ud6dQ'), + (522, 1, '2:asXyM0G-XlIBTfmx:miPgnyrt3q5l0VzPPrrMITi_pIDksbmVtrATJx7io06EAN4xHb2nNusQahdg7X5mdBE-BJPOViBAsT4IJ6ADAzbT_ZNYstxV6ll6E_4eimtTNFe4UTM4dXsvpVDB3tBoyvnkK73pwqbgKVHK-3b8LDpfqlAr6Ptrdz5glZggykTCE7sCNzOF9jWEnI7vp2e7r7DUoxHZzh4bsbzPZzdd_7nZ24YEhcPF9O8oioV4RtmARkdkQHmj4olM2To:W5Qmtz39HUKTM_KUHpoBBw'), + (522, 2, '2:RHjTGvuuioRH4g_e:GJLw8xu-qk0Tu_SFhbAAsZ149BSwKhimGBgIzYQbWtA_4xAXMwkD5-aGsPBcKnze4PMPYgRPKpEYO2NHfKF2EVfTHRWvnzDnLTbkjamB6oLt-f0_EvaF6brKVME9vYqqYx3CHFGRC5Ytj1z14Cc9yxbXFuWh59GmlqvddyzlLtcbDjlV7PAIzPz_shjiZxl5FpR0ZNkranhHP45tINkyZsZDA-V-ix_I6UZGuUlDcBl4HMcOusxgouqQNCw4Z5W2fNdjvHpa8MGdkCDf4pDfVaxw0Len7dVfARehBrVehPDvp4b5JPuCu7Zo4QgLIE_GX0fw4rGhFrCHxy_YbtiHIHz-3mjThxwzHL6F3JLbwbzA_bQXX7Ryk_MBsRFzSEHXidJZpnPGN-SYykWlqxl8rXhEYOtt1rX6NDYLuj94tPN33c7ND7fZ1ZsThiHyGOzqQEJSEPLO_1UfoBvYp3QuZjWReLF8T4ZutouAnOTE2Ff3alucKItlcATNGU4L8P65-nd9-lij9Sitqi8z2IH4tgEUDguzEE4liyPqjePwiav02IR0bCziDaBZAbzu83LHC9Wn0VKDVULHesNRDLnSlCxXkQrykDY3yV_AngW-uISof_lnJFPCQSlelNoswqY2SdqMVABvmQ_5cYMUuJyIaiM1Q_9AhE9ZoDkiSsjMT8sf_MaZ3MDF9_kMN12wowCTID0cWNC3rFw-T94G4Pv71s-ufdKVJDL_y6NFG2kzeAz6xy-UeprXPnCNFXKHzF69qcgD9cazICZllTYRisYa-oqedlHvABtbNtnbSOuK07RBJLio2gQATkwkGgIribzp-JtnsgqmFE0StPctf0rSLl4OGxP6393SvdopsuFSJyuWkS3we_Pbm2Jh2ckhETEnfVSovmDh5rsBkbXi39XAtGfdHW09ZAt9p9wkzRhTp6MCyJLRpo85t14ElLTwJjKTo_5oPRiWU_ZYxdbsgO6FiQFsCmmHeIX3i_lECqHYrJXdlZnlzjNj6exdiO7e87ljt6UNlDgWd4fRGsOK2YuwK4jh_85sBuQV5rCg2VDruxF6-XYPxdTzkMXq9A:Q1adWaSGK1LOLFsq0bKGIg'); diff --git a/library/recipes/tvmapi/data/tvmdb.credentials b/library/recipes/tvmapi/data/tvmdb.credentials new file mode 100644 index 0000000000..e3a2f0e121 --- /dev/null +++ b/library/recipes/tvmapi/data/tvmdb.credentials @@ -0,0 +1,4 @@ +{ + "db_user": "unused_value", + "db_pass": "unused_value" +} diff --git a/library/recipes/tvmapi/recipe.inc b/library/recipes/tvmapi/recipe.inc new file mode 100644 index 0000000000..b59cc2da2b --- /dev/null +++ b/library/recipes/tvmapi/recipe.inc @@ -0,0 +1,12 @@ +DEPENDS( + contrib/tools/sqlite3 + library/recipes/tvmapi + passport/infra/daemons/tvmapi/daemon +) + +DATA( + arcadia/library/recipes/tvmapi/clients + arcadia/library/recipes/tvmapi/data +) + +USE_RECIPE(library/recipes/tvmapi/tvmapi) diff --git a/library/recipes/tvmapi/ya.make b/library/recipes/tvmapi/ya.make new file mode 100644 index 0000000000..3f0f93ad3b --- /dev/null +++ b/library/recipes/tvmapi/ya.make @@ -0,0 +1,16 @@ +PY3_PROGRAM() + +PY_SRCS(__main__.py) + +PEERDIR( + contrib/python/requests + library/python/testing/recipe + library/python/testing/yatest_common + library/recipes/common +) + +END() + +RECURSE_FOR_TESTS( + ut_simple +) diff --git a/library/recipes/tvmtool/README.md b/library/recipes/tvmtool/README.md new file mode 100644 index 0000000000..222fdc9752 --- /dev/null +++ b/library/recipes/tvmtool/README.md @@ -0,0 +1,81 @@ +tvmtool recipe +--- + +Этот рецепт позволяет в тестах поднять [tvmtool](https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/), который в проде разворачивается на localhost. + +Демон слушает на порте из файла `tvmtool.port`. Для запросов к нему следует использовать AUTHTOKEN из `tvmtool.authtoken`. См. [пример](https://a.yandex-team.ru/arc/trunk/arcadia/library/recipes/tvmtool/ut/test.py). + +Варианты подключения: + 1) `recipe_with_default_cfg.inc` - для запуска демона с дефолтным [конфигом](https://a.yandex-team.ru/arc/trunk/arcadia/library/recipes/tvmtool/tvmtool.default.conf). Например: + ``` + INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmtool/recipe_with_default_cfg.inc) + ``` + [Пример](https://a.yandex-team.ru/arc_vcs/library/recipes/tvmtool/examples/ut_simple) +2) `recipe.inc` - для запуска демона со своим конфигом. Например + ``` + INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmtool/recipe.inc) + + USE_RECIPE( + library/recipes/tvmtool/tvmtool + foo/tvmtool.conf + ) + ``` + [Пример](https://a.yandex-team.ru/arc_vcs/library/recipes/tvmtool/examples/ut_with_custom_config) +3) `recipe.inc` + `--with-roles-dir` - запуск со своим конфигом и с поддержкой ролей + ``` + INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmtool/recipe.inc) + + USE_RECIPE( + library/recipes/tvmtool/tvmtool + foo/tvmtool.conf + --with-roles-dir foo/roles + ) + ``` + В каталоге `foo/` ожидается наличие файлов с именами вида `{slug}.json` - для всех slug из tvmtool.conf. + [Пример](https://a.yandex-team.ru/arc_vcs/library/recipes/tvmtool/examples/ut_with_roles) +4) `recipe.inc` + `--with-tvmapi` - для запуска демона, который будет ходить в tvm-api (тоже рецепт). Например: + ``` + # start tvm-api + INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmapi/recipe.inc) + + # start tvmtool + INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmtool/recipe.inc) + USE_RECIPE( + library/recipes/tvmtool/tvmtool + foo/tvmtool.conf + --with-tvmapi + ) + ``` + [Пример](https://a.yandex-team.ru/arc_vcs/library/recipes/tvmtool/examples/ut_with_tvmapi) +5) `recipe.inc` + `--with-tvmapi` + `--with-tirole` - для запуска демона, который будет ходить в tvm-api и tirole (тоже рецепты). Например: + ``` + # start tvm-api + INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmapi/recipe.inc) + + # start tirole + INCLUDE(${ARCADIA_ROOT}/library/recipes/tirole/recipe.inc) + USE_RECIPE( + library/recipes/tirole/tirole + --roles-dir library/recipes/tirole/ut_simple/roles_dir + ) + + # start tvmtool + INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmtool/recipe.inc) + USE_RECIPE( + library/recipes/tvmtool/tvmtool + foo/tvmtool.conf + --with-tvmapi + --with-tirole + ) + ``` + [Пример](https://a.yandex-team.ru/arc_vcs/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole) + +Варианты 1, 2 и 3 запустят tvmtool с флагом `--unittest`. Это значит, что: + * в конфиге можно указывать какие угодно tvm_id + * в конфиге секрет может быть пустым или равен строке "fake_secret" + +Вариант 4 и 5 запустит tvmtool, который будет ходить в tvm-api. Это значит, он сможет работать только с теми приложениями и их секретами, которые есть в [базе](https://a.yandex-team.ru/arc/trunk/arcadia/library/recipes/tvmapi/clients/clients.json) tvm-api. В этом варианте можно получать ServiceTicket в tvm-api и проверять в tvmtool. + +Любой из этих вариантов позволяет проверять ServiceTicket'ы/UserTicket'ы, сгенерированные через `tvmknife unittest`. + +Вопросы можно писать в [PASSPORTDUTY](https://st.yandex-team.ru/createTicket?queue=PASSPORTDUTY&_form=77618) diff --git a/library/recipes/tvmtool/__main__.py b/library/recipes/tvmtool/__main__.py new file mode 100644 index 0000000000..9fb70253b2 --- /dev/null +++ b/library/recipes/tvmtool/__main__.py @@ -0,0 +1,107 @@ +import argparse +import datetime +import binascii +import os +import requests +import sys + +from library.python.testing.recipe import declare_recipe +from library.recipes.common import start_daemon, stop_daemon +import yatest.common +import yatest.common.network + +TVMTOOL_PORT_FILE = "tvmtool.port" +TVMTOOL_AUTHTOKEN_FILE = "tvmtool.authtoken" + +TIROLE_PORT_FILE = "tirole.port" +TVMAPI_PORT_FILE = "tvmapi.port" +TVMTOOL_PID_FILE = "tvmtool.pid" + + +def start(argv): + parser = argparse.ArgumentParser() + parser.add_argument('cfgfile', type=str) + parser.add_argument('--with-roles-dir', dest='with_roles_dir', type=str) + parser.add_argument('--with-tirole', dest='with_tirole', action='store_true') + parser.add_argument('--with-tvmapi', dest='with_tvmapi', action='store_true') + input_args = parser.parse_args(argv) + + _log("cfgfile: %s" % input_args.cfgfile) + _log("with-roles-dir: %s" % input_args.with_roles_dir) + _log("with-tirole: %s" % input_args.with_tirole) + _log("with-tvmapi: %s" % input_args.with_tvmapi) + + pm = yatest.common.network.PortManager() + port = pm.get_tcp_port(80) + + with open(TVMTOOL_PORT_FILE, "w") as f: + f.write(str(port)) + _log("port: %d" % port) + + authtoken = binascii.hexlify(os.urandom(16)) + with open(TVMTOOL_AUTHTOKEN_FILE, "wb") as f: + f.write(authtoken) + _log("authtoken: %s" % authtoken) + + args = [ + yatest.common.build_path('passport/infra/daemons/tvmtool/cmd/tvmtool'), + '--port', + str(port), + '-c', + yatest.common.source_path(input_args.cfgfile), + '-v', + '--cache-dir', + './', + ] + env = { + 'QLOUD_TVM_TOKEN': authtoken, + } + + if input_args.with_tvmapi: + with open(TVMAPI_PORT_FILE) as f: + env['__TEST_TVM_API_URL'] = "http://localhost:%s" % f.read() + else: + args.append('--unittest') + + if input_args.with_tirole: + with open(TIROLE_PORT_FILE) as f: + env['__TEST_TIROLE_URL'] = "http://localhost:%s" % f.read() + + if input_args.with_roles_dir: + assert not input_args.with_tirole, "--with-roles-dir and --with-tirole conflicts with each other" + args += [ + '--unittest-roles-dir', + yatest.common.source_path(input_args.with_roles_dir), + ] + + def check(): + try: + r = requests.get("http://localhost:%d/tvm/ping" % port) + if r.status_code == 200: + _log("ping: 200!") + return True + else: + _log("ping: %d : %s" % (r.status_code, r.text)) + except Exception as e: + _log("ping: %s" % e) + return False + + start_daemon(command=args, environment=env, is_alive_check=check, pid_file_name=TVMTOOL_PID_FILE) + + +def stop(argv): + with open(TVMTOOL_PID_FILE) as f: + pid = f.read() + if not stop_daemon(pid): + _log("pid is dead: %s" % pid) + + +def _log(msg): + print("%s : tvmtool-recipe : %s" % (datetime.datetime.now(), msg), file=sys.stdout) + + +if __name__ == "__main__": + try: + declare_recipe(start, stop) + except Exception as e: + _log("exception: %s" % e) diff --git a/library/recipes/tvmtool/a.yaml b/library/recipes/tvmtool/a.yaml new file mode 100644 index 0000000000..1a46b3547f --- /dev/null +++ b/library/recipes/tvmtool/a.yaml @@ -0,0 +1,23 @@ +service: passport_infra +title: tvmtool recipe + +arcanum: + review: + auto_assign: true + + groups: + - name: backend-developers + roles: developer + + rules: + - reviewers: + name: backend-developers + ship: 2 + assign: 2 + +ci: + release-title-source: flow + autocheck: + fast-targets: + - library/recipes/tvmtool + strong: true diff --git a/library/recipes/tvmtool/examples/ut_simple/test.py b/library/recipes/tvmtool/examples/ut_simple/test.py new file mode 100644 index 0000000000..32b44ad564 --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_simple/test.py @@ -0,0 +1,26 @@ +import os +import os.path +import requests + +TVMTOOL_PORT_FILE = "tvmtool.port" +TVMTOOL_AUTHTOKEN_FILE = "tvmtool.authtoken" + + +def _get_tvmtool_params(): + port = int(open(TVMTOOL_PORT_FILE).read()) + authtoken = open(TVMTOOL_AUTHTOKEN_FILE).read() + return port, authtoken + + +def test_tvmtool(): + assert os.path.isfile(TVMTOOL_PORT_FILE) + assert os.path.isfile(TVMTOOL_AUTHTOKEN_FILE) + + port, authtoken = _get_tvmtool_params() + + r = requests.get("http://localhost:%d/tvm/ping" % port) + assert r.text == 'OK' + assert r.status_code == 200 + + r = requests.get("http://localhost:%d/tvm/keys" % port, headers={'Authorization': authtoken}) + assert r.status_code == 200 diff --git a/library/recipes/tvmtool/examples/ut_simple/ya.make b/library/recipes/tvmtool/examples/ut_simple/ya.make new file mode 100644 index 0000000000..c53024fdd5 --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_simple/ya.make @@ -0,0 +1,15 @@ +PY3TEST() + +OWNER(g:passport_infra) + +TEST_SRCS( + test.py +) + +PEERDIR( + contrib/python/requests +) + +INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmtool/recipe_with_default_cfg.inc) + +END() diff --git a/library/recipes/tvmtool/examples/ut_with_custom_config/custom.cfg b/library/recipes/tvmtool/examples/ut_with_custom_config/custom.cfg new file mode 100644 index 0000000000..e412046c74 --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_custom_config/custom.cfg @@ -0,0 +1,8 @@ +{ + "BbEnvType": 1, + "clients": { + "me": { + "self_tvm_id": 42 + } + } +} diff --git a/library/recipes/tvmtool/examples/ut_with_custom_config/test.py b/library/recipes/tvmtool/examples/ut_with_custom_config/test.py new file mode 100644 index 0000000000..9cbf48dd85 --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_custom_config/test.py @@ -0,0 +1,23 @@ +import os +import os.path +import requests + +TVMTOOL_PORT_FILE = "tvmtool.port" +TVMTOOL_AUTHTOKEN_FILE = "tvmtool.authtoken" + + +def _get_tvmtool_params(): + port = int(open(TVMTOOL_PORT_FILE).read()) + authtoken = open(TVMTOOL_AUTHTOKEN_FILE).read() + return port, authtoken + + +def test_tvmtool(): + assert os.path.isfile(TVMTOOL_PORT_FILE) + assert os.path.isfile(TVMTOOL_AUTHTOKEN_FILE) + + port, authtoken = _get_tvmtool_params() + + r = requests.get("http://localhost:%d/tvm/ping" % port) + assert r.text == 'OK' + assert r.status_code == 200 diff --git a/library/recipes/tvmtool/examples/ut_with_custom_config/ya.make b/library/recipes/tvmtool/examples/ut_with_custom_config/ya.make new file mode 100644 index 0000000000..04d39144b4 --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_custom_config/ya.make @@ -0,0 +1,19 @@ +PY3TEST() + +OWNER(g:passport_infra) + +TEST_SRCS( + test.py +) + +PEERDIR( + contrib/python/requests +) + +INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmtool/recipe.inc) + +USE_RECIPE( + library/recipes/tvmtool/tvmtool library/recipes/tvmtool/examples/ut_with_custom_config/custom.cfg +) + +END() diff --git a/library/recipes/tvmtool/examples/ut_with_roles/custom.cfg b/library/recipes/tvmtool/examples/ut_with_roles/custom.cfg new file mode 100644 index 0000000000..8c90b8f62e --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_roles/custom.cfg @@ -0,0 +1,10 @@ +{ + "BbEnvType": 1, + "clients": { + "me": { + "secret": "fake_secret", + "self_tvm_id": 42, + "roles_for_idm_slug": "some_slug" + } + } +} diff --git a/library/recipes/tvmtool/examples/ut_with_roles/roles/some_slug.json b/library/recipes/tvmtool/examples/ut_with_roles/roles/some_slug.json new file mode 100644 index 0000000000..15d6ff7e13 --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_roles/roles/some_slug.json @@ -0,0 +1,19 @@ +{ + "revision": "foobar", + "born_date": 1642160000, + "tvm": { + "101": { + "/role/service/auth_type/with_user/access_type/write/handlers/routes/": [{}], + "/role/service/auth_type/without_user/access_type/read/handlers/blockedphones/": [{}] + }, + "104": { + "/role/service/auth_type/without_user/access_type/read/handlers/routes/": [{}] + } + }, + "user": { + "1120000000000001": { + "/role/user/access_type/read/handlers/all/": [{}], + "/role/user/access_type/write/handlers/all/": [{}] + } + } +} diff --git a/library/recipes/tvmtool/examples/ut_with_roles/test.py b/library/recipes/tvmtool/examples/ut_with_roles/test.py new file mode 100644 index 0000000000..4a976bc228 --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_roles/test.py @@ -0,0 +1,26 @@ +import os +import os.path +import requests + +TVMTOOL_PORT_FILE = "tvmtool.port" +TVMTOOL_AUTHTOKEN_FILE = "tvmtool.authtoken" + + +def _get_tvmtool_params(): + port = int(open(TVMTOOL_PORT_FILE).read()) + authtoken = open(TVMTOOL_AUTHTOKEN_FILE).read() + return port, authtoken + + +def test_tvmtool(): + assert os.path.isfile(TVMTOOL_PORT_FILE) + assert os.path.isfile(TVMTOOL_AUTHTOKEN_FILE) + + port, authtoken = _get_tvmtool_params() + + r = requests.get("http://localhost:%d/tvm/ping" % port) + assert r.text == 'OK' + assert r.status_code == 200 + + r = requests.get("http://localhost:%d/v2/roles?self=me" % port, headers={'Authorization': authtoken}) + assert r.status_code == 200, r.text diff --git a/library/recipes/tvmtool/examples/ut_with_roles/ya.make b/library/recipes/tvmtool/examples/ut_with_roles/ya.make new file mode 100644 index 0000000000..6d658b7318 --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_roles/ya.make @@ -0,0 +1,21 @@ +PY3TEST() + +OWNER(g:passport_infra) + +TEST_SRCS( + test.py +) + +PEERDIR( + contrib/python/requests +) + +INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmtool/recipe.inc) + +USE_RECIPE( + library/recipes/tvmtool/tvmtool + library/recipes/tvmtool/examples/ut_with_roles/custom.cfg + --with-roles-dir library/recipes/tvmtool/examples/ut_with_roles/roles +) + +END() diff --git a/library/recipes/tvmtool/examples/ut_with_tvmapi/test.py b/library/recipes/tvmtool/examples/ut_with_tvmapi/test.py new file mode 100644 index 0000000000..f1c8a3fc35 --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_tvmapi/test.py @@ -0,0 +1,49 @@ +import os +import os.path + +import ticket_parser2 as tp2 + + +TVMAPI_PORT_FILE = "tvmapi.port" +TVMTOOL_PORT_FILE = "tvmtool.port" +TVMTOOL_AUTHTOKEN_FILE = "tvmtool.authtoken" + + +def _get_tvmapi_port(): + with open(TVMAPI_PORT_FILE) as f: + return int(f.read()) + + +def _get_tvmtool_params(): + port = int(open(TVMTOOL_PORT_FILE).read()) + authtoken = open(TVMTOOL_AUTHTOKEN_FILE).read() + return port, authtoken + + +def test_tvmapi(): + assert os.path.isfile(TVMAPI_PORT_FILE) + assert os.path.isfile(TVMTOOL_PORT_FILE) + assert os.path.isfile(TVMTOOL_AUTHTOKEN_FILE) + + port = _get_tvmapi_port() + + cs = tp2.TvmApiClientSettings( + self_client_id=1000501, + self_secret='bAicxJVa5uVY7MjDlapthw', + dsts={'my backend': 1000502}, + enable_service_ticket_checking=True, + ) + cs.__set_localhost(port) + + ca = tp2.TvmClient(cs) + assert ca.status == tp2.TvmClientStatus.Ok + + port, authtoken = _get_tvmtool_params() + ct = tp2.TvmClient(tp2.TvmToolClientSettings("me", auth_token=authtoken, port=port)) + assert ct.status == tp2.TvmClientStatus.Ok + + st = ca.check_service_ticket(ct.get_service_ticket_for(client_id=1000501)) + assert st.src == 1000503 + + ct.stop() + ca.stop() diff --git a/library/recipes/tvmtool/examples/ut_with_tvmapi/tvmtool.conf b/library/recipes/tvmtool/examples/ut_with_tvmapi/tvmtool.conf new file mode 100644 index 0000000000..18332548ac --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_tvmapi/tvmtool.conf @@ -0,0 +1,17 @@ +{ + "BbEnvType": 1, + "clients": { + "me": { + "secret": "S3TyTYVqjlbsflVEwxj33w", + "self_tvm_id": 1000503, + "dsts": { + "he": { + "dst_id": 1000504 + }, + "she": { + "dst_id": 1000501 + } + } + } + } +} diff --git a/library/recipes/tvmtool/examples/ut_with_tvmapi/ya.make b/library/recipes/tvmtool/examples/ut_with_tvmapi/ya.make new file mode 100644 index 0000000000..10c710b141 --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_tvmapi/ya.make @@ -0,0 +1,22 @@ +PY3TEST() + +OWNER(g:passport_infra) + +TEST_SRCS(test.py) + +PEERDIR( + library/python/deprecated/ticket_parser2 +) + +# common usage +INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmapi/recipe.inc) + +# tvmtool for connoisseurs +INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmtool/recipe.inc) +USE_RECIPE( + library/recipes/tvmtool/tvmtool + library/recipes/tvmtool/examples/ut_with_tvmapi/tvmtool.conf + --with-tvmapi +) + +END() diff --git a/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/roles_dir/mapping.yaml b/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/roles_dir/mapping.yaml new file mode 100644 index 0000000000..d2fcaead59 --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/roles_dir/mapping.yaml @@ -0,0 +1,5 @@ +slugs: + some_slug_2: + tvmid: + - 1000502 + - 1000503 diff --git a/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/roles_dir/some_slug_2.json b/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/roles_dir/some_slug_2.json new file mode 100644 index 0000000000..27e38c5bc1 --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/roles_dir/some_slug_2.json @@ -0,0 +1,14 @@ +{ + "revision": "some_revision_2", + "born_date": 1642160002, + "tvm": { + "1000501": { + "/role/service/auth_type/without_user/access_type/read/handlers/routes/": [{}] + } + }, + "user": { + "1120000000000001": { + "/role/user/access_type/write/handlers/all/": [{}] + } + } +} diff --git a/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/test.py b/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/test.py new file mode 100644 index 0000000000..7700a45160 --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/test.py @@ -0,0 +1,50 @@ +import os +import os.path + +import tvmauth + + +TVMAPI_PORT_FILE = "tvmapi.port" +TVMTOOL_PORT_FILE = "tvmtool.port" +TVMTOOL_AUTHTOKEN_FILE = "tvmtool.authtoken" + + +def _get_tvmapi_port(): + with open(TVMAPI_PORT_FILE) as f: + return int(f.read()) + + +def _get_tvmtool_params(): + tvmtool_port = int(open(TVMTOOL_PORT_FILE).read()) + authtoken = open(TVMTOOL_AUTHTOKEN_FILE).read() + return tvmtool_port, authtoken + + +def test_tvmapi(): + assert os.path.isfile(TVMAPI_PORT_FILE) + assert os.path.isfile(TVMTOOL_PORT_FILE) + assert os.path.isfile(TVMTOOL_AUTHTOKEN_FILE) + + ca = tvmauth.TvmClient( + tvmauth.TvmApiClientSettings( + self_tvm_id=1000501, + self_secret='bAicxJVa5uVY7MjDlapthw', + disk_cache_dir="./", + dsts={'my backend': 1000502}, + localhost_port=_get_tvmapi_port(), + ) + ) + assert ca.status == tvmauth.TvmClientStatus.Ok + + tvmtool_port, authtoken = _get_tvmtool_params() + ct = tvmauth.TvmClient(tvmauth.TvmToolClientSettings("me", auth_token=authtoken, port=tvmtool_port)) + assert ct.status == tvmauth.TvmClientStatus.Ok + + st = ct.check_service_ticket(ca.get_service_ticket_for('my backend')) + assert st.src == 1000501 + + expected_role = '/role/service/auth_type/without_user/access_type/read/handlers/routes/' + assert expected_role in ct.get_roles().get_service_roles(st) + + ct.stop() + ca.stop() diff --git a/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/tvmtool.conf b/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/tvmtool.conf new file mode 100644 index 0000000000..bb2827199b --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/tvmtool.conf @@ -0,0 +1,10 @@ +{ + "BbEnvType": 1, + "clients": { + "me": { + "secret": "e5kL0vM3nP-nPf-388Hi6Q", + "roles_for_idm_slug": "some_slug_2", + "self_tvm_id": 1000502 + } + } +} diff --git a/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/ya.make b/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/ya.make new file mode 100644 index 0000000000..09e9694b63 --- /dev/null +++ b/library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/ya.make @@ -0,0 +1,30 @@ +PY3TEST() + +OWNER(g:passport_infra) + +TEST_SRCS(test.py) + +PEERDIR( + library/python/tvmauth +) + +# start tvm-api +INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmapi/recipe.inc) + +# start tirole +INCLUDE(${ARCADIA_ROOT}/library/recipes/tirole/recipe.inc) +USE_RECIPE( + library/recipes/tirole/tirole + --roles-dir library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/roles_dir +) + +# tvmtool for connoisseurs +INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmtool/recipe.inc) +USE_RECIPE( + library/recipes/tvmtool/tvmtool + library/recipes/tvmtool/examples/ut_with_tvmapi_and_tirole/tvmtool.conf + --with-tvmapi + --with-tirole +) + +END() diff --git a/library/recipes/tvmtool/examples/ya.make b/library/recipes/tvmtool/examples/ya.make new file mode 100644 index 0000000000..624c084ab5 --- /dev/null +++ b/library/recipes/tvmtool/examples/ya.make @@ -0,0 +1,14 @@ +OWNER(g:passport_infra) + +RECURSE_FOR_TESTS( + ut_simple + ut_with_custom_config + ut_with_roles + ut_with_tvmapi +) + +IF (NOT SANITIZER_TYPE) + RECURSE( + ut_with_tvmapi_and_tirole + ) +ENDIF() diff --git a/library/recipes/tvmtool/recipe.inc b/library/recipes/tvmtool/recipe.inc new file mode 100644 index 0000000000..2e648e109d --- /dev/null +++ b/library/recipes/tvmtool/recipe.inc @@ -0,0 +1,8 @@ +DEPENDS( + library/recipes/tvmtool + passport/infra/daemons/tvmtool/cmd +) + +DATA( + arcadia/library/recipes/tvmtool +) diff --git a/library/recipes/tvmtool/recipe_with_default_cfg.inc b/library/recipes/tvmtool/recipe_with_default_cfg.inc new file mode 100644 index 0000000000..dd979322b6 --- /dev/null +++ b/library/recipes/tvmtool/recipe_with_default_cfg.inc @@ -0,0 +1,5 @@ +INCLUDE(${ARCADIA_ROOT}/library/recipes/tvmtool/recipe.inc) + +USE_RECIPE( + library/recipes/tvmtool/tvmtool library/recipes/tvmtool/tvmtool.default.conf +) diff --git a/library/recipes/tvmtool/tvmtool.default.conf b/library/recipes/tvmtool/tvmtool.default.conf new file mode 100644 index 0000000000..1de49f9d1c --- /dev/null +++ b/library/recipes/tvmtool/tvmtool.default.conf @@ -0,0 +1,17 @@ +{ + "BbEnvType": 1, + "clients": { + "me": { + "secret": "fake_secret", + "self_tvm_id": 42, + "dsts": { + "he": { + "dst_id": 100500 + }, + "she": { + "dst_id": 100501 + } + } + } + } +} diff --git a/library/recipes/tvmtool/ya.make b/library/recipes/tvmtool/ya.make new file mode 100644 index 0000000000..ea0e14e5ca --- /dev/null +++ b/library/recipes/tvmtool/ya.make @@ -0,0 +1,16 @@ +PY3_PROGRAM() + +PY_SRCS(__main__.py) + +PEERDIR( + contrib/python/requests + library/python/testing/recipe + library/python/testing/yatest_common + library/recipes/common +) + +END() + +RECURSE_FOR_TESTS( + examples +) |