diff options
author | qrort <qrort@yandex-team.com> | 2022-11-30 23:47:12 +0300 |
---|---|---|
committer | qrort <qrort@yandex-team.com> | 2022-11-30 23:47:12 +0300 |
commit | 22f8ae0e3f5d68b92aecccdf96c1d841a0334311 (patch) | |
tree | bffa27765faf54126ad44bcafa89fadecb7a73d7 /library/go | |
parent | 332b99e2173f0425444abb759eebcb2fafaa9209 (diff) | |
download | ydb-22f8ae0e3f5d68b92aecccdf96c1d841a0334311.tar.gz |
validate canons without yatest_common
Diffstat (limited to 'library/go')
163 files changed, 16529 insertions, 0 deletions
diff --git a/library/go/blockcodecs/all/all.go b/library/go/blockcodecs/all/all.go new file mode 100644 index 0000000000..75c6ef285b --- /dev/null +++ b/library/go/blockcodecs/all/all.go @@ -0,0 +1,8 @@ +package all + +import ( + _ "a.yandex-team.ru/library/go/blockcodecs/blockbrotli" + _ "a.yandex-team.ru/library/go/blockcodecs/blocklz4" + _ "a.yandex-team.ru/library/go/blockcodecs/blocksnappy" + _ "a.yandex-team.ru/library/go/blockcodecs/blockzstd" +) diff --git a/library/go/blockcodecs/blockbrotli/brotli.go b/library/go/blockcodecs/blockbrotli/brotli.go new file mode 100644 index 0000000000..0597a24e08 --- /dev/null +++ b/library/go/blockcodecs/blockbrotli/brotli.go @@ -0,0 +1,95 @@ +package blockbrotli + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/andybalholm/brotli" + + "a.yandex-team.ru/library/go/blockcodecs" +) + +type brotliCodec int + +func (b brotliCodec) ID() blockcodecs.CodecID { + switch b { + case 1: + return 48947 + case 10: + return 43475 + case 11: + return 7241 + case 2: + return 63895 + case 3: + return 11408 + case 4: + return 47136 + case 5: + return 45284 + case 6: + return 63219 + case 7: + return 59675 + case 8: + return 40233 + case 9: + return 10380 + default: + panic("unsupported level") + } +} + +func (b brotliCodec) Name() string { + return fmt.Sprintf("brotli_%d", b) +} + +func (b brotliCodec) DecodedLen(in []byte) (int, error) { + return blockcodecs.DecodedLen(in) +} + +func (b brotliCodec) Encode(dst, src []byte) ([]byte, error) { + if cap(dst) < 8 { + dst = make([]byte, 8) + } + + dst = dst[:8] + binary.LittleEndian.PutUint64(dst, uint64(len(src))) + + wb := bytes.NewBuffer(dst) + w := brotli.NewWriterLevel(wb, int(b)) + + if _, err := w.Write(src); err != nil { + return nil, err + } + + if err := w.Close(); err != nil { + return nil, err + } + + return wb.Bytes(), nil +} + +func (b brotliCodec) Decode(dst, src []byte) ([]byte, error) { + if len(src) < 8 { + return nil, fmt.Errorf("short block: %d < 8", len(src)) + } + + rb := bytes.NewBuffer(src[8:]) + r := brotli.NewReader(rb) + + _, err := io.ReadFull(r, dst) + if err != nil { + return nil, err + } + + return dst, nil +} + +func init() { + for i := 1; i <= 11; i++ { + blockcodecs.Register(brotliCodec(i)) + } +} diff --git a/library/go/blockcodecs/blocklz4/lz4.go b/library/go/blockcodecs/blocklz4/lz4.go new file mode 100644 index 0000000000..bb549b9beb --- /dev/null +++ b/library/go/blockcodecs/blocklz4/lz4.go @@ -0,0 +1,82 @@ +package blocklz4 + +import ( + "encoding/binary" + + "github.com/pierrec/lz4" + + "a.yandex-team.ru/library/go/blockcodecs" +) + +type lz4Codec struct{} + +func (l lz4Codec) ID() blockcodecs.CodecID { + return 6051 +} + +func (l lz4Codec) Name() string { + return "lz4-fast14-safe" +} + +func (l lz4Codec) DecodedLen(in []byte) (int, error) { + return blockcodecs.DecodedLen(in) +} + +func (l lz4Codec) Encode(dst, src []byte) ([]byte, error) { + dst = dst[:cap(dst)] + + n := lz4.CompressBlockBound(len(src)) + 8 + if len(dst) < n { + dst = append(dst, make([]byte, n-len(dst))...) + } + binary.LittleEndian.PutUint64(dst, uint64(len(src))) + + m, err := lz4.CompressBlock(src, dst[8:], nil) + if err != nil { + return nil, err + } + + return dst[:8+m], nil +} + +func (l lz4Codec) Decode(dst, src []byte) ([]byte, error) { + n, err := lz4.UncompressBlock(src[8:], dst) + if err != nil { + return nil, err + } + return dst[:n], nil +} + +type lz4HCCodec struct { + lz4Codec +} + +func (l lz4HCCodec) ID() blockcodecs.CodecID { + return 62852 +} + +func (l lz4HCCodec) Name() string { + return "lz4-hc-safe" +} + +func (l lz4HCCodec) Encode(dst, src []byte) ([]byte, error) { + dst = dst[:cap(dst)] + + n := lz4.CompressBlockBound(len(src)) + 8 + if len(dst) < n { + dst = append(dst, make([]byte, n-len(dst))...) + } + binary.LittleEndian.PutUint64(dst, uint64(len(src))) + + m, err := lz4.CompressBlockHC(src, dst[8:], 0) + if err != nil { + return nil, err + } + + return dst[:8+m], nil +} + +func init() { + blockcodecs.Register(lz4Codec{}) + blockcodecs.Register(lz4HCCodec{}) +} diff --git a/library/go/blockcodecs/blocksnappy/snappy.go b/library/go/blockcodecs/blocksnappy/snappy.go new file mode 100644 index 0000000000..0f6d22cde9 --- /dev/null +++ b/library/go/blockcodecs/blocksnappy/snappy.go @@ -0,0 +1,33 @@ +package blocksnappy + +import ( + "github.com/golang/snappy" + + "a.yandex-team.ru/library/go/blockcodecs" +) + +type snappyCodec struct{} + +func (s snappyCodec) ID() blockcodecs.CodecID { + return 50986 +} + +func (s snappyCodec) Name() string { + return "snappy" +} + +func (s snappyCodec) DecodedLen(in []byte) (int, error) { + return snappy.DecodedLen(in) +} + +func (s snappyCodec) Encode(dst, src []byte) ([]byte, error) { + return snappy.Encode(dst, src), nil +} + +func (s snappyCodec) Decode(dst, src []byte) ([]byte, error) { + return snappy.Decode(dst, src) +} + +func init() { + blockcodecs.Register(snappyCodec{}) +} diff --git a/library/go/blockcodecs/blockzstd/zstd.go b/library/go/blockcodecs/blockzstd/zstd.go new file mode 100644 index 0000000000..7822672b3d --- /dev/null +++ b/library/go/blockcodecs/blockzstd/zstd.go @@ -0,0 +1,73 @@ +package blockzstd + +import ( + "encoding/binary" + "fmt" + + "github.com/klauspost/compress/zstd" + + "a.yandex-team.ru/library/go/blockcodecs" +) + +type zstdCodec int + +func (z zstdCodec) ID() blockcodecs.CodecID { + switch z { + case 1: + return 55019 + case 3: + return 23308 + case 7: + return 33533 + default: + panic("unsupported level") + } +} + +func (z zstdCodec) Name() string { + return fmt.Sprintf("zstd08_%d", z) +} + +func (z zstdCodec) DecodedLen(in []byte) (int, error) { + return blockcodecs.DecodedLen(in) +} + +func (z zstdCodec) Encode(dst, src []byte) ([]byte, error) { + if cap(dst) < 8 { + dst = make([]byte, 8) + } + + dst = dst[:8] + binary.LittleEndian.PutUint64(dst, uint64(len(src))) + + w, err := zstd.NewWriter(nil, + zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(int(z))), + zstd.WithEncoderConcurrency(1)) + if err != nil { + return nil, err + } + + defer w.Close() + return w.EncodeAll(src, dst), nil +} + +func (z zstdCodec) Decode(dst, src []byte) ([]byte, error) { + if len(src) < 8 { + return nil, fmt.Errorf("short block: %d < 8", len(src)) + } + + r, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(1)) + if err != nil { + return nil, err + } + + defer r.Close() + return r.DecodeAll(src[8:], dst[:0]) +} + +func init() { + for _, i := range []int{1, 3, 7} { + blockcodecs.Register(zstdCodec(i)) + blockcodecs.RegisterAlias(fmt.Sprintf("zstd_%d", i), zstdCodec(i)) + } +} diff --git a/library/go/blockcodecs/codecs.go b/library/go/blockcodecs/codecs.go new file mode 100644 index 0000000000..b45bda6d61 --- /dev/null +++ b/library/go/blockcodecs/codecs.go @@ -0,0 +1,89 @@ +package blockcodecs + +import ( + "encoding/binary" + "fmt" + "sync" + + "go.uber.org/atomic" +) + +type CodecID uint16 + +type Codec interface { + ID() CodecID + Name() string + + DecodedLen(in []byte) (int, error) + Encode(dst, src []byte) ([]byte, error) + Decode(dst, src []byte) ([]byte, error) +} + +var ( + codecsByID sync.Map + codecsByName sync.Map +) + +// Register new codec. +// +// NOTE: update FindCodecByName description, after adding new codecs. +func Register(c Codec) { + if _, duplicate := codecsByID.LoadOrStore(c.ID(), c); duplicate { + panic(fmt.Sprintf("codec with id %d is already registered", c.ID())) + } + + RegisterAlias(c.Name(), c) +} + +func RegisterAlias(name string, c Codec) { + if _, duplicate := codecsByName.LoadOrStore(name, c); duplicate { + panic(fmt.Sprintf("codec with name %s is already registered", c.Name())) + } +} + +func ListCodecs() []Codec { + var c []Codec + codecsByID.Range(func(key, value interface{}) bool { + c = append(c, value.(Codec)) + return true + }) + return c +} + +func FindCodec(id CodecID) Codec { + c, ok := codecsByID.Load(id) + if ok { + return c.(Codec) + } else { + return nil + } +} + +// FindCodecByName returns codec by name. +// +// Possible names: +// +// null +// snappy +// zstd08_{level} - level is integer 1, 3 or 7. +// zstd_{level} - level is integer 1, 3 or 7. +func FindCodecByName(name string) Codec { + c, ok := codecsByName.Load(name) + if ok { + return c.(Codec) + } else { + return nil + } +} + +var ( + maxDecompressedBlockSize = atomic.NewInt32(16 << 20) // 16 MB +) + +func DecodedLen(in []byte) (int, error) { + if len(in) < 8 { + return 0, fmt.Errorf("short block: %d < 8", len(in)) + } + + return int(binary.LittleEndian.Uint64(in[:8])), nil +} diff --git a/library/go/blockcodecs/decoder.go b/library/go/blockcodecs/decoder.go new file mode 100644 index 0000000000..bb38dcf844 --- /dev/null +++ b/library/go/blockcodecs/decoder.go @@ -0,0 +1,155 @@ +package blockcodecs + +import ( + "encoding/binary" + "fmt" + "io" +) + +type Decoder struct { + // optional + codec Codec + + r io.Reader + header [10]byte + eof bool + checkEOF bool + + pos int + buffer []byte + + scratch []byte +} + +func (d *Decoder) getCodec(id CodecID) (Codec, error) { + if d.codec != nil { + if id != d.codec.ID() { + return nil, fmt.Errorf("blockcodecs: received block codec differs from provided: %d != %d", id, d.codec.ID()) + } + + return d.codec, nil + } + + if codec := FindCodec(id); codec != nil { + return codec, nil + } + + return nil, fmt.Errorf("blockcodecs: received block with unsupported codec %d", id) +} + +// SetCheckUnderlyingEOF changes EOF handling. +// +// Blockcodecs format contains end of stream separator. By default Decoder will stop right after +// that separator, without trying to read following bytes from underlying reader. +// +// That allows reading sequence of blockcodecs streams from one underlying stream of bytes, +// but messes up HTTP keep-alive, when using blockcodecs together with net/http connection pool. +// +// Setting CheckUnderlyingEOF to true, changes that. After encoutering end of stream block, +// Decoder will perform one more Read from underlying reader and check for io.EOF. +func (d *Decoder) SetCheckUnderlyingEOF(checkEOF bool) { + d.checkEOF = checkEOF +} + +func (d *Decoder) Read(p []byte) (int, error) { + if d.eof { + return 0, io.EOF + } + + if d.pos == len(d.buffer) { + if _, err := io.ReadFull(d.r, d.header[:]); err != nil { + return 0, fmt.Errorf("blockcodecs: invalid header: %w", err) + } + + codecID := CodecID(binary.LittleEndian.Uint16(d.header[:2])) + size := int(binary.LittleEndian.Uint64(d.header[2:])) + + codec, err := d.getCodec(codecID) + if err != nil { + return 0, err + } + + if limit := int(maxDecompressedBlockSize.Load()); size > limit { + return 0, fmt.Errorf("blockcodecs: block size exceeds limit: %d > %d", size, limit) + } + + if len(d.scratch) < size { + d.scratch = append(d.scratch, make([]byte, size-len(d.scratch))...) + } + d.scratch = d.scratch[:size] + + if _, err := io.ReadFull(d.r, d.scratch[:]); err != nil { + return 0, fmt.Errorf("blockcodecs: truncated block: %w", err) + } + + decodedSize, err := codec.DecodedLen(d.scratch[:]) + if err != nil { + return 0, fmt.Errorf("blockcodecs: corrupted block: %w", err) + } + + if decodedSize == 0 { + if d.checkEOF { + var scratch [1]byte + n, err := d.r.Read(scratch[:]) + if n != 0 { + return 0, fmt.Errorf("blockcodecs: data after EOF block") + } + if err != nil && err != io.EOF { + return 0, fmt.Errorf("blockcodecs: error after EOF block: %v", err) + } + } + + d.eof = true + return 0, io.EOF + } + + if limit := int(maxDecompressedBlockSize.Load()); decodedSize > limit { + return 0, fmt.Errorf("blockcodecs: decoded block size exceeds limit: %d > %d", decodedSize, limit) + } + + decodeInto := func(buf []byte) error { + out, err := codec.Decode(buf, d.scratch) + if err != nil { + return fmt.Errorf("blockcodecs: corrupted block: %w", err) + } else if len(out) != decodedSize { + return fmt.Errorf("blockcodecs: incorrect block size: %d != %d", len(out), decodedSize) + } + + return nil + } + + if len(p) >= decodedSize { + if err := decodeInto(p[:decodedSize]); err != nil { + return 0, err + } + + return decodedSize, nil + } + + if len(d.buffer) < decodedSize { + d.buffer = append(d.buffer, make([]byte, decodedSize-len(d.buffer))...) + } + d.buffer = d.buffer[:decodedSize] + d.pos = decodedSize + + if err := decodeInto(d.buffer); err != nil { + return 0, err + } + + d.pos = 0 + } + + n := copy(p, d.buffer[d.pos:]) + d.pos += n + return n, nil +} + +// NewDecoder creates decoder that supports input in any of registered codecs. +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{r: r} +} + +// NewDecoderCodec creates decode that tries to decode input using provided codec. +func NewDecoderCodec(r io.Reader, codec Codec) *Decoder { + return &Decoder{r: r, codec: codec} +} diff --git a/library/go/blockcodecs/encoder.go b/library/go/blockcodecs/encoder.go new file mode 100644 index 0000000000..b7bb154f79 --- /dev/null +++ b/library/go/blockcodecs/encoder.go @@ -0,0 +1,139 @@ +package blockcodecs + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +type encoder struct { + w io.Writer + codec Codec + + closed bool + header [10]byte + + buf []byte + pos int + + scratch []byte +} + +const ( + // defaultBufferSize is 32KB, same as size of buffer used in io.Copy. + defaultBufferSize = 32 << 10 +) + +var ( + _ io.WriteCloser = (*encoder)(nil) +) + +func (e *encoder) Write(p []byte) (int, error) { + if e.closed { + return 0, errors.New("blockcodecs: encoder is closed") + } + + n := len(p) + + // Complete current block + if e.pos != 0 { + m := copy(e.buf[e.pos:], p) + p = p[m:] + e.pos += m + + if e.pos == len(e.buf) { + e.pos = 0 + + if err := e.doFlush(e.buf); err != nil { + return 0, err + } + } + } + + // Copy huge input directly to output + for len(p) >= len(e.buf) { + if e.pos != 0 { + panic("broken invariant") + } + + var chunk []byte + if len(p) > len(e.buf) { + chunk = p[:len(e.buf)] + p = p[len(e.buf):] + } else { + chunk = p + p = nil + } + + if err := e.doFlush(chunk); err != nil { + return 0, err + } + } + + // Store suffix in buffer + m := copy(e.buf, p) + e.pos += m + if m != len(p) { + panic("broken invariant") + } + + return n, nil +} + +func (e *encoder) Close() error { + if e.closed { + return nil + } + + if err := e.Flush(); err != nil { + return err + } + + e.closed = true + + return e.doFlush(nil) +} + +func (e *encoder) doFlush(block []byte) error { + var err error + e.scratch, err = e.codec.Encode(e.scratch, block) + if err != nil { + return fmt.Errorf("blockcodecs: block compression error: %w", err) + } + + binary.LittleEndian.PutUint16(e.header[:2], uint16(e.codec.ID())) + binary.LittleEndian.PutUint64(e.header[2:], uint64(len(e.scratch))) + + if _, err := e.w.Write(e.header[:]); err != nil { + return err + } + + if _, err := e.w.Write(e.scratch); err != nil { + return err + } + + return nil +} + +func (e *encoder) Flush() error { + if e.closed { + return errors.New("blockcodecs: flushing closed encoder") + } + + if e.pos == 0 { + return nil + } + + err := e.doFlush(e.buf[:e.pos]) + e.pos = 0 + return err +} + +func NewEncoder(w io.Writer, codec Codec) io.WriteCloser { + return NewEncoderBuffer(w, codec, defaultBufferSize) +} + +func NewEncoderBuffer(w io.Writer, codec Codec, bufferSize int) io.WriteCloser { + return &encoder{w: w, codec: codec, buf: make([]byte, bufferSize)} +} diff --git a/library/go/blockcodecs/nop_codec.go b/library/go/blockcodecs/nop_codec.go new file mode 100644 index 0000000000..c15e65a29e --- /dev/null +++ b/library/go/blockcodecs/nop_codec.go @@ -0,0 +1,27 @@ +package blockcodecs + +type nopCodec struct{} + +func (n nopCodec) ID() CodecID { + return 54476 +} + +func (n nopCodec) Name() string { + return "null" +} + +func (n nopCodec) DecodedLen(in []byte) (int, error) { + return len(in), nil +} + +func (n nopCodec) Encode(dst, src []byte) ([]byte, error) { + return append(dst[:0], src...), nil +} + +func (n nopCodec) Decode(dst, src []byte) ([]byte, error) { + return append(dst[:0], src...), nil +} + +func init() { + Register(nopCodec{}) +} diff --git a/library/go/certifi/cas.go b/library/go/certifi/cas.go new file mode 100644 index 0000000000..a195e28a54 --- /dev/null +++ b/library/go/certifi/cas.go @@ -0,0 +1,35 @@ +package certifi + +import ( + "crypto/x509" + "sync" + + "a.yandex-team.ru/library/go/certifi/internal/certs" +) + +var ( + internalOnce sync.Once + commonOnce sync.Once + internalCAs []*x509.Certificate + commonCAs []*x509.Certificate +) + +// InternalCAs returns list of Yandex Internal certificates +func InternalCAs() []*x509.Certificate { + internalOnce.Do(initInternalCAs) + return internalCAs +} + +// CommonCAs returns list of common certificates +func CommonCAs() []*x509.Certificate { + commonOnce.Do(initCommonCAs) + return commonCAs +} + +func initInternalCAs() { + internalCAs = certsFromPEM(certs.InternalCAs()) +} + +func initCommonCAs() { + commonCAs = certsFromPEM(certs.CommonCAs()) +} diff --git a/library/go/certifi/certifi.go b/library/go/certifi/certifi.go new file mode 100644 index 0000000000..e969263883 --- /dev/null +++ b/library/go/certifi/certifi.go @@ -0,0 +1,80 @@ +package certifi + +import ( + "crypto/x509" + "os" +) + +var underYaMake = true + +// NewCertPool returns a copy of the system or bundled cert pool. +// +// Default behavior can be modified with env variable, e.g. use system pool: +// +// CERTIFI_USE_SYSTEM_CA=yes ./my-cool-program +func NewCertPool() (caCertPool *x509.CertPool, err error) { + if forceSystem() { + return NewCertPoolSystem() + } + + return NewCertPoolBundled() +} + +// NewCertPoolSystem returns a copy of the system cert pool + common CAs + internal CAs +// +// WARNING: system cert pool is not available on Windows +func NewCertPoolSystem() (caCertPool *x509.CertPool, err error) { + caCertPool, err = x509.SystemCertPool() + + if err != nil || caCertPool == nil { + caCertPool = x509.NewCertPool() + } + + for _, cert := range CommonCAs() { + caCertPool.AddCert(cert) + } + + for _, cert := range InternalCAs() { + caCertPool.AddCert(cert) + } + + return caCertPool, nil +} + +// NewCertPoolBundled returns a new cert pool with common CAs + internal CAs +func NewCertPoolBundled() (caCertPool *x509.CertPool, err error) { + caCertPool = x509.NewCertPool() + + for _, cert := range CommonCAs() { + caCertPool.AddCert(cert) + } + + for _, cert := range InternalCAs() { + caCertPool.AddCert(cert) + } + + return caCertPool, nil +} + +// NewCertPoolInternal returns a new cert pool with internal CAs +func NewCertPoolInternal() (caCertPool *x509.CertPool, err error) { + caCertPool = x509.NewCertPool() + + for _, cert := range InternalCAs() { + caCertPool.AddCert(cert) + } + + return caCertPool, nil +} + +func forceSystem() bool { + if os.Getenv("CERTIFI_USE_SYSTEM_CA") == "yes" { + return true + } + + if !underYaMake && len(InternalCAs()) == 0 { + return true + } + + return false +} diff --git a/library/go/certifi/doc.go b/library/go/certifi/doc.go new file mode 100644 index 0000000000..76e963b2cc --- /dev/null +++ b/library/go/certifi/doc.go @@ -0,0 +1,4 @@ +// Certifi is a collection of public and internal Root Certificates for validating the trustworthiness of SSL certificates while verifying the identity of TLS hosts. +// +// Certifi use Arcadia Root Certificates for that: https://a.yandex-team.ru/arc/trunk/arcadia/certs +package certifi diff --git a/library/go/certifi/internal/certs/certs.go b/library/go/certifi/internal/certs/certs.go new file mode 100644 index 0000000000..b2c6443d21 --- /dev/null +++ b/library/go/certifi/internal/certs/certs.go @@ -0,0 +1,13 @@ +package certs + +import ( + "a.yandex-team.ru/library/go/core/resource" +) + +func InternalCAs() []byte { + return resource.Get("/certifi/internal.pem") +} + +func CommonCAs() []byte { + return resource.Get("/certifi/common.pem") +} diff --git a/library/go/certifi/utils.go b/library/go/certifi/utils.go new file mode 100644 index 0000000000..76d90e3f1f --- /dev/null +++ b/library/go/certifi/utils.go @@ -0,0 +1,29 @@ +package certifi + +import ( + "crypto/x509" + "encoding/pem" +) + +func certsFromPEM(pemCerts []byte) []*x509.Certificate { + var result []*x509.Certificate + for len(pemCerts) > 0 { + var block *pem.Block + block, pemCerts = pem.Decode(pemCerts) + if block == nil { + break + } + if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { + continue + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + continue + } + + result = append(result, cert) + } + + return result +} diff --git a/library/go/cgosem/sem.go b/library/go/cgosem/sem.go new file mode 100644 index 0000000000..357e0529db --- /dev/null +++ b/library/go/cgosem/sem.go @@ -0,0 +1,67 @@ +// Package cgosem implements fast and imprecise semaphore used to globally limit concurrency of _fast_ cgo calls. +// +// In the future, when go runtime scheduler gets smarter and stop suffering from uncontrolled growth the number of +// system threads, this package should be removed. +// +// See "Cgoroutines != Goroutines" section of https://www.cockroachlabs.com/blog/the-cost-and-complexity-of-cgo/ +// for explanation of the thread leak problem. +// +// To use this semaphore, put the following line at the beginning of the function doing Cgo calls. +// +// defer cgosem.S.Acquire().Release() +// +// This will globally limit number of concurrent Cgo calls to GOMAXPROCS, limiting number of additional threads created by the +// go runtime to the same number. +// +// Overhead of this semaphore is about 1us, which should be negligible compared to the work you are trying to do in the C function. +// +// To see code in action, run: +// +// ya make -r library/go/cgosem/gotest +// env GODEBUG=schedtrace=1000,scheddetail=1 library/go/cgosem/gotest/gotest --test.run TestLeak +// env GODEBUG=schedtrace=1000,scheddetail=1 library/go/cgosem/gotest/gotest --test.run TestLeakFix +// +// And look for the number of created M's. +package cgosem + +import "runtime" + +type Sem chan struct{} + +// new creates new semaphore with max concurrency of n. +func newSem(n int) (s Sem) { + s = make(chan struct{}, n) + for i := 0; i < n; i++ { + s <- struct{}{} + } + return +} + +func (s Sem) Acquire() Sem { + if s == nil { + return nil + } + + <-s + return s +} + +func (s Sem) Release() { + if s == nil { + return + } + + s <- struct{}{} +} + +// S is global semaphore with good enough settings for most cgo libraries. +var S Sem + +// Disable global cgo semaphore. Must be called from init() function. +func Disable() { + S = nil +} + +func init() { + S = newSem(runtime.GOMAXPROCS(0)) +} diff --git a/library/go/core/log/ctxlog/ctxlog.go b/library/go/core/log/ctxlog/ctxlog.go new file mode 100644 index 0000000000..185c4cf5e7 --- /dev/null +++ b/library/go/core/log/ctxlog/ctxlog.go @@ -0,0 +1,124 @@ +package ctxlog + +import ( + "context" + "fmt" + + "a.yandex-team.ru/library/go/core/log" +) + +type ctxKey struct{} + +// ContextFields returns log.Fields bound with ctx. +// If no fields are bound, it returns nil. +func ContextFields(ctx context.Context) []log.Field { + fs, _ := ctx.Value(ctxKey{}).([]log.Field) + return fs +} + +// WithFields returns a new context that is bound with given fields and based +// on parent ctx. +func WithFields(ctx context.Context, fields ...log.Field) context.Context { + if len(fields) == 0 { + return ctx + } + + return context.WithValue(ctx, ctxKey{}, mergeFields(ContextFields(ctx), fields)) +} + +// Trace logs at Trace log level using fields both from arguments and ones that +// are bound to ctx. +func Trace(ctx context.Context, l log.Logger, msg string, fields ...log.Field) { + log.AddCallerSkip(l, 1).Trace(msg, mergeFields(ContextFields(ctx), fields)...) +} + +// Debug logs at Debug log level using fields both from arguments and ones that +// are bound to ctx. +func Debug(ctx context.Context, l log.Logger, msg string, fields ...log.Field) { + log.AddCallerSkip(l, 1).Debug(msg, mergeFields(ContextFields(ctx), fields)...) +} + +// Info logs at Info log level using fields both from arguments and ones that +// are bound to ctx. +func Info(ctx context.Context, l log.Logger, msg string, fields ...log.Field) { + log.AddCallerSkip(l, 1).Info(msg, mergeFields(ContextFields(ctx), fields)...) +} + +// Warn logs at Warn log level using fields both from arguments and ones that +// are bound to ctx. +func Warn(ctx context.Context, l log.Logger, msg string, fields ...log.Field) { + log.AddCallerSkip(l, 1).Warn(msg, mergeFields(ContextFields(ctx), fields)...) +} + +// Error logs at Error log level using fields both from arguments and ones that +// are bound to ctx. +func Error(ctx context.Context, l log.Logger, msg string, fields ...log.Field) { + log.AddCallerSkip(l, 1).Error(msg, mergeFields(ContextFields(ctx), fields)...) +} + +// Fatal logs at Fatal log level using fields both from arguments and ones that +// are bound to ctx. +func Fatal(ctx context.Context, l log.Logger, msg string, fields ...log.Field) { + log.AddCallerSkip(l, 1).Fatal(msg, mergeFields(ContextFields(ctx), fields)...) +} + +// Tracef logs at Trace log level using fields that are bound to ctx. +// The message is formatted using provided arguments. +func Tracef(ctx context.Context, l log.Logger, format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + log.AddCallerSkip(l, 1).Trace(msg, ContextFields(ctx)...) +} + +// Debugf logs at Debug log level using fields that are bound to ctx. +// The message is formatted using provided arguments. +func Debugf(ctx context.Context, l log.Logger, format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + log.AddCallerSkip(l, 1).Debug(msg, ContextFields(ctx)...) +} + +// Infof logs at Info log level using fields that are bound to ctx. +// The message is formatted using provided arguments. +func Infof(ctx context.Context, l log.Logger, format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + log.AddCallerSkip(l, 1).Info(msg, ContextFields(ctx)...) +} + +// Warnf logs at Warn log level using fields that are bound to ctx. +// The message is formatted using provided arguments. +func Warnf(ctx context.Context, l log.Logger, format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + log.AddCallerSkip(l, 1).Warn(msg, ContextFields(ctx)...) +} + +// Errorf logs at Error log level using fields that are bound to ctx. +// The message is formatted using provided arguments. +func Errorf(ctx context.Context, l log.Logger, format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + log.AddCallerSkip(l, 1).Error(msg, ContextFields(ctx)...) +} + +// Fatalf logs at Fatal log level using fields that are bound to ctx. +// The message is formatted using provided arguments. +func Fatalf(ctx context.Context, l log.Logger, format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + log.AddCallerSkip(l, 1).Fatal(msg, ContextFields(ctx)...) +} + +func mergeFields(a, b []log.Field) []log.Field { + if a == nil { + return b + } + if b == nil { + return a + } + + // NOTE: just append() here is unsafe. If a caller passed slice of fields + // followed by ... with capacity greater than length, then simultaneous + // logging will lead to a data race condition. + // + // See https://golang.org/ref/spec#Passing_arguments_to_..._parameters + c := make([]log.Field, len(a)+len(b)) + n := copy(c, a) + copy(c[n:], b) + return c +} diff --git a/library/go/core/log/fields.go b/library/go/core/log/fields.go new file mode 100644 index 0000000000..afd41c197e --- /dev/null +++ b/library/go/core/log/fields.go @@ -0,0 +1,446 @@ +package log + +import ( + "fmt" + "time" +) + +const ( + // DefaultErrorFieldName is the default field name used for errors + DefaultErrorFieldName = "error" +) + +// FieldType is a type of data Field can represent +type FieldType int + +const ( + // FieldTypeNil is for a pure nil + FieldTypeNil FieldType = iota + // FieldTypeString is for a string + FieldTypeString + // FieldTypeBinary is for a binary array + FieldTypeBinary + // FieldTypeBoolean is for boolean + FieldTypeBoolean + // FieldTypeSigned is for signed integers + FieldTypeSigned + // FieldTypeUnsigned is for unsigned integers + FieldTypeUnsigned + // FieldTypeFloat is for float + FieldTypeFloat + // FieldTypeTime is for time.Time + FieldTypeTime + // FieldTypeDuration is for time.Duration + FieldTypeDuration + // FieldTypeError is for an error + FieldTypeError + // FieldTypeArray is for an array of any type + FieldTypeArray + // FieldTypeAny is for any type + FieldTypeAny + // FieldTypeReflect is for unknown types + FieldTypeReflect + // FieldTypeByteString is for a bytes that can be represented as UTF-8 string + FieldTypeByteString +) + +// Field stores one structured logging field +type Field struct { + key string + ftype FieldType + string string + signed int64 + unsigned uint64 + float float64 + iface interface{} +} + +// Key returns field key +func (f Field) Key() string { + return f.key +} + +// Type returns field type +func (f Field) Type() FieldType { + return f.ftype +} + +// String returns field string +func (f Field) String() string { + return f.string +} + +// Binary constructs field of []byte +func (f Field) Binary() []byte { + if f.iface == nil { + return nil + } + return f.iface.([]byte) +} + +// Bool returns field bool +func (f Field) Bool() bool { + return f.Signed() != 0 +} + +// Signed returns field int64 +func (f Field) Signed() int64 { + return f.signed +} + +// Unsigned returns field uint64 +func (f Field) Unsigned() uint64 { + return f.unsigned +} + +// Float returns field float64 +func (f Field) Float() float64 { + return f.float +} + +// Time returns field time.Time +func (f Field) Time() time.Time { + return time.Unix(0, f.signed) +} + +// Duration returns field time.Duration +func (f Field) Duration() time.Duration { + return time.Nanosecond * time.Duration(f.signed) +} + +// Error constructs field of error type +func (f Field) Error() error { + if f.iface == nil { + return nil + } + return f.iface.(error) +} + +// Interface returns field interface +func (f Field) Interface() interface{} { + return f.iface +} + +// Any returns contained data as interface{} +// nolint: gocyclo +func (f Field) Any() interface{} { + switch f.Type() { + case FieldTypeNil: + return nil + case FieldTypeString: + return f.String() + case FieldTypeBinary: + return f.Interface() + case FieldTypeBoolean: + return f.Bool() + case FieldTypeSigned: + return f.Signed() + case FieldTypeUnsigned: + return f.Unsigned() + case FieldTypeFloat: + return f.Float() + case FieldTypeTime: + return f.Time() + case FieldTypeDuration: + return f.Duration() + case FieldTypeError: + return f.Error() + case FieldTypeArray: + return f.Interface() + case FieldTypeAny: + return f.Interface() + case FieldTypeReflect: + return f.Interface() + case FieldTypeByteString: + return f.Interface() + default: + // For when new field type is not added to this func + panic(fmt.Sprintf("unknown field type: %d", f.Type())) + } +} + +// Nil constructs field of nil type +func Nil(key string) Field { + return Field{key: key, ftype: FieldTypeNil} +} + +// String constructs field of string type +func String(key, value string) Field { + return Field{key: key, ftype: FieldTypeString, string: value} +} + +// Sprintf constructs field of string type with formatting +func Sprintf(key, format string, args ...interface{}) Field { + return Field{key: key, ftype: FieldTypeString, string: fmt.Sprintf(format, args...)} +} + +// Strings constructs Field from []string +func Strings(key string, value []string) Field { + return Array(key, value) +} + +// Binary constructs field of []byte type +func Binary(key string, value []byte) Field { + return Field{key: key, ftype: FieldTypeBinary, iface: value} +} + +// Bool constructs field of bool type +func Bool(key string, value bool) Field { + field := Field{key: key, ftype: FieldTypeBoolean} + if value { + field.signed = 1 + } else { + field.signed = 0 + } + + return field +} + +// Bools constructs Field from []bool +func Bools(key string, value []bool) Field { + return Array(key, value) +} + +// Int constructs Field from int +func Int(key string, value int) Field { + return Int64(key, int64(value)) +} + +// Ints constructs Field from []int +func Ints(key string, value []int) Field { + return Array(key, value) +} + +// Int8 constructs Field from int8 +func Int8(key string, value int8) Field { + return Int64(key, int64(value)) +} + +// Int8s constructs Field from []int8 +func Int8s(key string, value []int8) Field { + return Array(key, value) +} + +// Int16 constructs Field from int16 +func Int16(key string, value int16) Field { + return Int64(key, int64(value)) +} + +// Int16s constructs Field from []int16 +func Int16s(key string, value []int16) Field { + return Array(key, value) +} + +// Int32 constructs Field from int32 +func Int32(key string, value int32) Field { + return Int64(key, int64(value)) +} + +// Int32s constructs Field from []int32 +func Int32s(key string, value []int32) Field { + return Array(key, value) +} + +// Int64 constructs Field from int64 +func Int64(key string, value int64) Field { + return Field{key: key, ftype: FieldTypeSigned, signed: value} +} + +// Int64s constructs Field from []int64 +func Int64s(key string, value []int64) Field { + return Array(key, value) +} + +// UInt constructs Field from uint +func UInt(key string, value uint) Field { + return UInt64(key, uint64(value)) +} + +// UInts constructs Field from []uint +func UInts(key string, value []uint) Field { + return Array(key, value) +} + +// UInt8 constructs Field from uint8 +func UInt8(key string, value uint8) Field { + return UInt64(key, uint64(value)) +} + +// UInt8s constructs Field from []uint8 +func UInt8s(key string, value []uint8) Field { + return Array(key, value) +} + +// UInt16 constructs Field from uint16 +func UInt16(key string, value uint16) Field { + return UInt64(key, uint64(value)) +} + +// UInt16s constructs Field from []uint16 +func UInt16s(key string, value []uint16) Field { + return Array(key, value) +} + +// UInt32 constructs Field from uint32 +func UInt32(key string, value uint32) Field { + return UInt64(key, uint64(value)) +} + +// UInt32s constructs Field from []uint32 +func UInt32s(key string, value []uint32) Field { + return Array(key, value) +} + +// UInt64 constructs Field from uint64 +func UInt64(key string, value uint64) Field { + return Field{key: key, ftype: FieldTypeUnsigned, unsigned: value} +} + +// UInt64s constructs Field from []uint64 +func UInt64s(key string, value []uint64) Field { + return Array(key, value) +} + +// Float32 constructs Field from float32 +func Float32(key string, value float32) Field { + return Float64(key, float64(value)) +} + +// Float32s constructs Field from []float32 +func Float32s(key string, value []float32) Field { + return Array(key, value) +} + +// Float64 constructs Field from float64 +func Float64(key string, value float64) Field { + return Field{key: key, ftype: FieldTypeFloat, float: value} +} + +// Float64s constructs Field from []float64 +func Float64s(key string, value []float64) Field { + return Array(key, value) +} + +// Time constructs field of time.Time type +func Time(key string, value time.Time) Field { + return Field{key: key, ftype: FieldTypeTime, signed: value.UnixNano()} +} + +// Times constructs Field from []time.Time +func Times(key string, value []time.Time) Field { + return Array(key, value) +} + +// Duration constructs field of time.Duration type +func Duration(key string, value time.Duration) Field { + return Field{key: key, ftype: FieldTypeDuration, signed: value.Nanoseconds()} +} + +// Durations constructs Field from []time.Duration +func Durations(key string, value []time.Duration) Field { + return Array(key, value) +} + +// NamedError constructs field of error type +func NamedError(key string, value error) Field { + return Field{key: key, ftype: FieldTypeError, iface: value} +} + +// Error constructs field of error type with default field name +func Error(value error) Field { + return NamedError(DefaultErrorFieldName, value) +} + +// Errors constructs Field from []error +func Errors(key string, value []error) Field { + return Array(key, value) +} + +// Array constructs field of array type +func Array(key string, value interface{}) Field { + return Field{key: key, ftype: FieldTypeArray, iface: value} +} + +// Reflect constructs field of unknown type +func Reflect(key string, value interface{}) Field { + return Field{key: key, ftype: FieldTypeReflect, iface: value} +} + +// ByteString constructs field of bytes that could represent UTF-8 string +func ByteString(key string, value []byte) Field { + return Field{key: key, ftype: FieldTypeByteString, iface: value} +} + +// Any tries to deduce interface{} underlying type and constructs Field from it. +// Use of this function is ok only for the sole purpose of not repeating its entire code +// or parts of it in user's code (when you need to log interface{} types with unknown content). +// Otherwise please use specialized functions. +// nolint: gocyclo +func Any(key string, value interface{}) Field { + switch val := value.(type) { + case bool: + return Bool(key, val) + case float64: + return Float64(key, val) + case float32: + return Float32(key, val) + case int: + return Int(key, val) + case []int: + return Ints(key, val) + case int64: + return Int64(key, val) + case []int64: + return Int64s(key, val) + case int32: + return Int32(key, val) + case []int32: + return Int32s(key, val) + case int16: + return Int16(key, val) + case []int16: + return Int16s(key, val) + case int8: + return Int8(key, val) + case []int8: + return Int8s(key, val) + case string: + return String(key, val) + case []string: + return Strings(key, val) + case uint: + return UInt(key, val) + case []uint: + return UInts(key, val) + case uint64: + return UInt64(key, val) + case []uint64: + return UInt64s(key, val) + case uint32: + return UInt32(key, val) + case []uint32: + return UInt32s(key, val) + case uint16: + return UInt16(key, val) + case []uint16: + return UInt16s(key, val) + case uint8: + return UInt8(key, val) + case []byte: + return Binary(key, val) + case time.Time: + return Time(key, val) + case []time.Time: + return Times(key, val) + case time.Duration: + return Duration(key, val) + case []time.Duration: + return Durations(key, val) + case error: + return NamedError(key, val) + case []error: + return Errors(key, val) + default: + return Field{key: key, ftype: FieldTypeAny, iface: value} + } +} diff --git a/library/go/core/log/levels.go b/library/go/core/log/levels.go new file mode 100644 index 0000000000..54810410b9 --- /dev/null +++ b/library/go/core/log/levels.go @@ -0,0 +1,108 @@ +package log + +import ( + "fmt" + "strings" +) + +// Level of logging +type Level int + +// MarshalText marshals level to text +func (l Level) MarshalText() ([]byte, error) { + if l >= maxLevel || l < 0 { + return nil, fmt.Errorf("failed to marshal log level: level value (%d) is not in the allowed range (0-%d)", l, maxLevel-1) + } + return []byte(l.String()), nil +} + +// UnmarshalText unmarshals level from text +func (l *Level) UnmarshalText(text []byte) error { + level, err := ParseLevel(string(text)) + if err != nil { + return err + } + + *l = level + return nil +} + +// Standard log levels +const ( + TraceLevel Level = iota + DebugLevel + InfoLevel + WarnLevel + ErrorLevel + FatalLevel + maxLevel +) + +func Levels() (l []Level) { + for i := 0; i < int(maxLevel); i++ { + l = append(l, Level(i)) + } + return +} + +// String values for standard log levels +const ( + TraceString = "trace" + DebugString = "debug" + InfoString = "info" + WarnString = "warn" + ErrorString = "error" + FatalString = "fatal" +) + +// String implements Stringer interface for Level +func (l Level) String() string { + switch l { + case TraceLevel: + return TraceString + case DebugLevel: + return DebugString + case InfoLevel: + return InfoString + case WarnLevel: + return WarnString + case ErrorLevel: + return ErrorString + case FatalLevel: + return FatalString + default: + // For when new log level is not added to this func (most likely never). + panic(fmt.Sprintf("unknown log level: %d", l)) + } +} + +// Set implements flag.Value interface +func (l *Level) Set(v string) error { + lvl, err := ParseLevel(v) + if err != nil { + return err + } + + *l = lvl + return nil +} + +// ParseLevel parses log level from string. Returns ErrUnknownLevel for unknown log level. +func ParseLevel(l string) (Level, error) { + switch strings.ToLower(l) { + case TraceString: + return TraceLevel, nil + case DebugString: + return DebugLevel, nil + case InfoString: + return InfoLevel, nil + case WarnString: + return WarnLevel, nil + case ErrorString: + return ErrorLevel, nil + case FatalString: + return FatalLevel, nil + default: + return FatalLevel, fmt.Errorf("unknown log level: %s", l) + } +} diff --git a/library/go/core/log/log.go b/library/go/core/log/log.go new file mode 100644 index 0000000000..3e1f76e870 --- /dev/null +++ b/library/go/core/log/log.go @@ -0,0 +1,134 @@ +package log + +import "errors" + +// Logger is the universal logger that can do everything. +type Logger interface { + loggerStructured + loggerFmt + toStructured + toFmt + withName +} + +type withName interface { + WithName(name string) Logger +} + +type toLogger interface { + // Logger returns general logger + Logger() Logger +} + +// Structured provides interface for logging using fields. +type Structured interface { + loggerStructured + toFmt + toLogger +} + +type loggerStructured interface { + // Trace logs at Trace log level using fields + Trace(msg string, fields ...Field) + // Debug logs at Debug log level using fields + Debug(msg string, fields ...Field) + // Info logs at Info log level using fields + Info(msg string, fields ...Field) + // Warn logs at Warn log level using fields + Warn(msg string, fields ...Field) + // Error logs at Error log level using fields + Error(msg string, fields ...Field) + // Fatal logs at Fatal log level using fields + Fatal(msg string, fields ...Field) +} + +type toFmt interface { + // Fmt returns fmt logger + Fmt() Fmt +} + +// Fmt provides interface for logging using fmt formatter. +type Fmt interface { + loggerFmt + toStructured + toLogger +} + +type loggerFmt interface { + // Tracef logs at Trace log level using fmt formatter + Tracef(format string, args ...interface{}) + // Debugf logs at Debug log level using fmt formatter + Debugf(format string, args ...interface{}) + // Infof logs at Info log level using fmt formatter + Infof(format string, args ...interface{}) + // Warnf logs at Warn log level using fmt formatter + Warnf(format string, args ...interface{}) + // Errorf logs at Error log level using fmt formatter + Errorf(format string, args ...interface{}) + // Fatalf logs at Fatal log level using fmt formatter + Fatalf(format string, args ...interface{}) +} + +type toStructured interface { + // Structured returns structured logger + Structured() Structured +} + +// LoggerWith is an interface for 'With' function +// LoggerWith provides interface for logger modifications. +type LoggerWith interface { + // With implements 'With' + With(fields ...Field) Logger +} + +// With for loggers that implement LoggerWith interface, returns logger that +// always adds provided key/value to every log entry. Otherwise returns same logger. +func With(l Logger, fields ...Field) Logger { + e, ok := l.(LoggerWith) + if !ok { + return l + } + + return e.With(fields...) +} + +// LoggerAddCallerSkip is an interface for 'AddCallerSkip' function +type LoggerAddCallerSkip interface { + // AddCallerSkip implements 'AddCallerSkip' + AddCallerSkip(skip int) Logger +} + +// AddCallerSkip for loggers that implement LoggerAddCallerSkip interface, returns logger that +// adds caller skip to each log entry. Otherwise returns same logger. +func AddCallerSkip(l Logger, skip int) Logger { + e, ok := l.(LoggerAddCallerSkip) + if !ok { + return l + } + + return e.AddCallerSkip(skip) +} + +// WriteAt is a helper method that checks logger and writes message at given level +func WriteAt(l Structured, lvl Level, msg string, fields ...Field) error { + if l == nil { + return errors.New("nil logger given") + } + + switch lvl { + case DebugLevel: + l.Debug(msg, fields...) + case TraceLevel: + l.Trace(msg, fields...) + case InfoLevel: + l.Info(msg, fields...) + case WarnLevel: + l.Warn(msg, fields...) + case ErrorLevel: + l.Error(msg, fields...) + case FatalLevel: + l.Fatal(msg, fields...) + } + + return nil +} diff --git a/library/go/core/log/nop/nop.go b/library/go/core/log/nop/nop.go new file mode 100644 index 0000000000..1db66ca6fa --- /dev/null +++ b/library/go/core/log/nop/nop.go @@ -0,0 +1,73 @@ +package nop + +import ( + "os" + + "a.yandex-team.ru/library/go/core/log" +) + +// Logger that does nothing +type Logger struct{} + +var _ log.Logger = &Logger{} +var _ log.Structured = &Logger{} +var _ log.Fmt = &Logger{} + +// Logger returns general logger +func (l *Logger) Logger() log.Logger { + return l +} + +// Fmt returns fmt logger +func (l *Logger) Fmt() log.Fmt { + return l +} + +// Structured returns structured logger +func (l *Logger) Structured() log.Structured { + return l +} + +// Trace implements Trace method of log.Logger interface +func (l *Logger) Trace(msg string, fields ...log.Field) {} + +// Tracef implements Tracef method of log.Logger interface +func (l *Logger) Tracef(format string, args ...interface{}) {} + +// Debug implements Debug method of log.Logger interface +func (l *Logger) Debug(msg string, fields ...log.Field) {} + +// Debugf implements Debugf method of log.Logger interface +func (l *Logger) Debugf(format string, args ...interface{}) {} + +// Info implements Info method of log.Logger interface +func (l *Logger) Info(msg string, fields ...log.Field) {} + +// Infof implements Infof method of log.Logger interface +func (l *Logger) Infof(format string, args ...interface{}) {} + +// Warn implements Warn method of log.Logger interface +func (l *Logger) Warn(msg string, fields ...log.Field) {} + +// Warnf implements Warnf method of log.Logger interface +func (l *Logger) Warnf(format string, args ...interface{}) {} + +// Error implements Error method of log.Logger interface +func (l *Logger) Error(msg string, fields ...log.Field) {} + +// Errorf implements Errorf method of log.Logger interface +func (l *Logger) Errorf(format string, args ...interface{}) {} + +// Fatal implements Fatal method of log.Logger interface +func (l *Logger) Fatal(msg string, fields ...log.Field) { + os.Exit(1) +} + +// Fatalf implements Fatalf method of log.Logger interface +func (l *Logger) Fatalf(format string, args ...interface{}) { + os.Exit(1) +} + +func (l *Logger) WithName(name string) log.Logger { + return l +} diff --git a/library/go/core/log/zap/deploy.go b/library/go/core/log/zap/deploy.go new file mode 100644 index 0000000000..652d1c11f8 --- /dev/null +++ b/library/go/core/log/zap/deploy.go @@ -0,0 +1,98 @@ +package zap + +import ( + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "a.yandex-team.ru/library/go/core/log" +) + +// NewDeployEncoderConfig returns an opinionated EncoderConfig for +// deploy environment. +func NewDeployEncoderConfig() zapcore.EncoderConfig { + return zapcore.EncoderConfig{ + MessageKey: "msg", + LevelKey: "levelStr", + StacktraceKey: "stackTrace", + TimeKey: "@timestamp", + CallerKey: "", + NameKey: "loggerName", + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeDuration: zapcore.StringDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } +} + +// NewDeployConfig returns default configuration (with no sampling). +// Not recommended for production use. +func NewDeployConfig() zap.Config { + return zap.Config{ + Level: zap.NewAtomicLevelAt(zap.DebugLevel), + Encoding: "json", + OutputPaths: []string{"stdout"}, + ErrorOutputPaths: []string{"stderr"}, + EncoderConfig: NewDeployEncoderConfig(), + } +} + +// NewDeployLogger constructs fully-fledged Deploy compatible logger +// based on predefined config. See https://deploy.yandex-team.ru/docs/concepts/pod/sidecars/logs/logs#format +// for more information +func NewDeployLogger(level log.Level, opts ...zap.Option) (*Logger, error) { + cfg := NewDeployConfig() + cfg.Level = zap.NewAtomicLevelAt(ZapifyLevel(level)) + + zl, err := cfg.Build(opts...) + if err != nil { + return nil, err + } + + return &Logger{ + L: addDeployContext(zl).(*zap.Logger), + }, nil +} + +// NewProductionDeployConfig returns configuration, suitable for production use. +// +// It uses a JSON encoder, writes to standard error, and enables sampling. +// Stacktraces are automatically included on logs of ErrorLevel and above. +func NewProductionDeployConfig() zap.Config { + return zap.Config{ + Sampling: &zap.SamplingConfig{ + Initial: 100, + Thereafter: 100, + }, + Development: false, + Level: zap.NewAtomicLevelAt(zap.InfoLevel), + Encoding: "json", + OutputPaths: []string{"stdout"}, + ErrorOutputPaths: []string{"stderr"}, + EncoderConfig: NewDeployEncoderConfig(), + } +} + +// Same as NewDeployLogger, but with sampling +func NewProductionDeployLogger(level log.Level, opts ...zap.Option) (*Logger, error) { + cfg := NewProductionDeployConfig() + cfg.Level = zap.NewAtomicLevelAt(ZapifyLevel(level)) + + zl, err := cfg.Build(opts...) + if err != nil { + return nil, err + } + + return &Logger{ + L: addDeployContext(zl).(*zap.Logger), + }, nil +} + +func addDeployContext(i interface{}) interface{} { + switch c := i.(type) { + case *zap.Logger: + return c.With(zap.Namespace("@fields")) + case zapcore.Core: + return c.With([]zapcore.Field{zap.Namespace("@fields")}) + } + return i +} diff --git a/library/go/core/log/zap/encoders/cli.go b/library/go/core/log/zap/encoders/cli.go new file mode 100644 index 0000000000..f19d8527df --- /dev/null +++ b/library/go/core/log/zap/encoders/cli.go @@ -0,0 +1,78 @@ +package encoders + +import ( + "sync" + + "go.uber.org/zap/buffer" + "go.uber.org/zap/zapcore" +) + +const ( + // EncoderNameCli is the encoder name to use for zap config + EncoderNameCli = "cli" +) + +var cliPool = sync.Pool{New: func() interface{} { + return &cliEncoder{} +}} + +func getCliEncoder() *cliEncoder { + return cliPool.Get().(*cliEncoder) +} + +type cliEncoder struct { + *kvEncoder +} + +// NewCliEncoder constructs cli encoder +func NewCliEncoder(cfg zapcore.EncoderConfig) (zapcore.Encoder, error) { + return newCliEncoder(cfg), nil +} + +func newCliEncoder(cfg zapcore.EncoderConfig) *cliEncoder { + return &cliEncoder{ + kvEncoder: newKVEncoder(cfg), + } +} + +func (enc *cliEncoder) Clone() zapcore.Encoder { + clone := enc.clone() + _, _ = clone.buf.Write(enc.buf.Bytes()) + return clone +} + +func (enc *cliEncoder) clone() *cliEncoder { + clone := getCliEncoder() + clone.kvEncoder = getKVEncoder() + clone.cfg = enc.cfg + clone.openNamespaces = enc.openNamespaces + clone.pool = enc.pool + clone.buf = enc.pool.Get() + return clone +} + +func (enc *cliEncoder) EncodeEntry(ent zapcore.Entry, fields []zapcore.Field) (*buffer.Buffer, error) { + final := enc.clone() + + // Direct write because we do not want to quote message in cli mode + final.buf.AppendString(ent.Message) + + // Add any structured context. + for _, f := range fields { + f.AddTo(final) + } + + // If there's no stacktrace key, honor that; this allows users to force + // single-line output. + if ent.Stack != "" && final.cfg.StacktraceKey != "" { + final.buf.AppendByte('\n') + final.AppendString(ent.Stack) + } + + if final.cfg.LineEnding != "" { + final.AppendString(final.cfg.LineEnding) + } else { + final.AppendString(zapcore.DefaultLineEnding) + } + return final.buf, nil +} diff --git a/library/go/core/log/zap/encoders/kv.go b/library/go/core/log/zap/encoders/kv.go new file mode 100644 index 0000000000..8fd6c607c6 --- /dev/null +++ b/library/go/core/log/zap/encoders/kv.go @@ -0,0 +1,386 @@ +package encoders + +import ( + "encoding/base64" + "encoding/json" + "math" + "strings" + "sync" + "time" + + "go.uber.org/zap/buffer" + "go.uber.org/zap/zapcore" +) + +const ( + // EncoderNameKV is the encoder name to use for zap config + EncoderNameKV = "kv" +) + +const ( + // We use ' for quote symbol instead of " so that it doesn't interfere with %q of fmt package + stringQuoteSymbol = '\'' + kvArraySeparator = ',' +) + +var kvPool = sync.Pool{New: func() interface{} { + return &kvEncoder{} +}} + +func getKVEncoder() *kvEncoder { + return kvPool.Get().(*kvEncoder) +} + +type kvEncoder struct { + cfg zapcore.EncoderConfig + pool buffer.Pool + buf *buffer.Buffer + openNamespaces int + + // for encoding generic values by reflection + reflectBuf *buffer.Buffer + reflectEnc *json.Encoder +} + +// NewKVEncoder constructs kv encoder +func NewKVEncoder(cfg zapcore.EncoderConfig) (zapcore.Encoder, error) { + return newKVEncoder(cfg), nil +} + +func newKVEncoder(cfg zapcore.EncoderConfig) *kvEncoder { + pool := buffer.NewPool() + return &kvEncoder{ + cfg: cfg, + pool: pool, + buf: pool.Get(), + } +} + +func (enc *kvEncoder) addElementSeparator() { + if enc.buf.Len() == 0 { + return + } + + enc.buf.AppendByte(' ') +} + +func (enc *kvEncoder) addKey(key string) { + enc.addElementSeparator() + enc.buf.AppendString(key) + enc.buf.AppendByte('=') +} + +func (enc *kvEncoder) appendFloat(val float64, bitSize int) { + enc.appendArrayItemSeparator() + switch { + case math.IsNaN(val): + enc.buf.AppendString(`"NaN"`) + case math.IsInf(val, 1): + enc.buf.AppendString(`"+Inf"`) + case math.IsInf(val, -1): + enc.buf.AppendString(`"-Inf"`) + default: + enc.buf.AppendFloat(val, bitSize) + } +} + +func (enc *kvEncoder) AddArray(key string, arr zapcore.ArrayMarshaler) error { + enc.addKey(key) + return enc.AppendArray(arr) +} + +func (enc *kvEncoder) AddObject(key string, obj zapcore.ObjectMarshaler) error { + enc.addKey(key) + return enc.AppendObject(obj) +} + +func (enc *kvEncoder) AddBinary(key string, val []byte) { + enc.AddString(key, base64.StdEncoding.EncodeToString(val)) +} + +func (enc *kvEncoder) AddByteString(key string, val []byte) { + enc.addKey(key) + enc.AppendByteString(val) +} + +func (enc *kvEncoder) AddBool(key string, val bool) { + enc.addKey(key) + enc.AppendBool(val) +} + +func (enc *kvEncoder) AddComplex128(key string, val complex128) { + enc.addKey(key) + enc.AppendComplex128(val) +} + +func (enc *kvEncoder) AddDuration(key string, val time.Duration) { + enc.addKey(key) + enc.AppendDuration(val) +} + +func (enc *kvEncoder) AddFloat64(key string, val float64) { + enc.addKey(key) + enc.AppendFloat64(val) +} + +func (enc *kvEncoder) AddInt64(key string, val int64) { + enc.addKey(key) + enc.AppendInt64(val) +} + +func (enc *kvEncoder) resetReflectBuf() { + if enc.reflectBuf == nil { + enc.reflectBuf = enc.pool.Get() + enc.reflectEnc = json.NewEncoder(enc.reflectBuf) + } else { + enc.reflectBuf.Reset() + } +} + +func (enc *kvEncoder) AddReflected(key string, obj interface{}) error { + enc.resetReflectBuf() + err := enc.reflectEnc.Encode(obj) + if err != nil { + return err + } + enc.reflectBuf.TrimNewline() + enc.addKey(key) + _, err = enc.buf.Write(enc.reflectBuf.Bytes()) + return err +} + +func (enc *kvEncoder) OpenNamespace(key string) { + enc.addKey(key) + enc.buf.AppendByte('{') + enc.openNamespaces++ +} + +func (enc *kvEncoder) AddString(key, val string) { + enc.addKey(key) + enc.AppendString(val) +} + +func (enc *kvEncoder) AddTime(key string, val time.Time) { + enc.addKey(key) + enc.AppendTime(val) +} + +func (enc *kvEncoder) AddUint64(key string, val uint64) { + enc.addKey(key) + enc.AppendUint64(val) +} + +func (enc *kvEncoder) appendArrayItemSeparator() { + last := enc.buf.Len() - 1 + if last < 0 { + return + } + + switch enc.buf.Bytes()[last] { + case '[', '{', '=': + return + default: + enc.buf.AppendByte(kvArraySeparator) + } +} + +func (enc *kvEncoder) AppendArray(arr zapcore.ArrayMarshaler) error { + enc.appendArrayItemSeparator() + enc.buf.AppendByte('[') + err := arr.MarshalLogArray(enc) + enc.buf.AppendByte(']') + return err +} + +func (enc *kvEncoder) AppendObject(obj zapcore.ObjectMarshaler) error { + enc.appendArrayItemSeparator() + enc.buf.AppendByte('{') + err := obj.MarshalLogObject(enc) + enc.buf.AppendByte('}') + return err +} + +func (enc *kvEncoder) AppendBool(val bool) { + enc.appendArrayItemSeparator() + enc.buf.AppendBool(val) +} + +func (enc *kvEncoder) AppendByteString(val []byte) { + enc.appendArrayItemSeparator() + _, _ = enc.buf.Write(val) +} + +func (enc *kvEncoder) AppendComplex128(val complex128) { + enc.appendArrayItemSeparator() + r, i := real(val), imag(val) + + enc.buf.AppendByte('"') + // Because we're always in a quoted string, we can use strconv without + // special-casing NaN and +/-Inf. + enc.buf.AppendFloat(r, 64) + enc.buf.AppendByte('+') + enc.buf.AppendFloat(i, 64) + enc.buf.AppendByte('i') + enc.buf.AppendByte('"') +} + +func (enc *kvEncoder) AppendDuration(val time.Duration) { + cur := enc.buf.Len() + enc.cfg.EncodeDuration(val, enc) + if cur == enc.buf.Len() { + // User-supplied EncodeDuration is a no-op. Fall back to nanoseconds to keep + // JSON valid. + enc.AppendInt64(int64(val)) + } +} + +func (enc *kvEncoder) AppendInt64(val int64) { + enc.appendArrayItemSeparator() + enc.buf.AppendInt(val) +} + +func (enc *kvEncoder) AppendReflected(val interface{}) error { + enc.appendArrayItemSeparator() + enc.resetReflectBuf() + err := enc.reflectEnc.Encode(val) + if err != nil { + return err + } + enc.reflectBuf.TrimNewline() + enc.addElementSeparator() + _, err = enc.buf.Write(enc.reflectBuf.Bytes()) + return err +} + +func (enc *kvEncoder) AppendString(val string) { + enc.appendArrayItemSeparator() + var quotes bool + if strings.ContainsAny(val, " =[]{}") { + quotes = true + } + + if quotes { + enc.buf.AppendByte(stringQuoteSymbol) + } + enc.buf.AppendString(val) + if quotes { + enc.buf.AppendByte(stringQuoteSymbol) + } +} + +func (enc *kvEncoder) AppendTime(val time.Time) { + cur := enc.buf.Len() + enc.cfg.EncodeTime(val, enc) + if cur == enc.buf.Len() { + // User-supplied EncodeTime is a no-op. Fall back to nanos since epoch to keep + // output JSON valid. + enc.AppendInt64(val.UnixNano()) + } +} + +func (enc *kvEncoder) AppendUint64(val uint64) { + enc.appendArrayItemSeparator() + enc.buf.AppendUint(val) +} + +func (enc *kvEncoder) AddComplex64(k string, v complex64) { enc.AddComplex128(k, complex128(v)) } +func (enc *kvEncoder) AddFloat32(k string, v float32) { enc.AddFloat64(k, float64(v)) } +func (enc *kvEncoder) AddInt(k string, v int) { enc.AddInt64(k, int64(v)) } +func (enc *kvEncoder) AddInt32(k string, v int32) { enc.AddInt64(k, int64(v)) } +func (enc *kvEncoder) AddInt16(k string, v int16) { enc.AddInt64(k, int64(v)) } +func (enc *kvEncoder) AddInt8(k string, v int8) { enc.AddInt64(k, int64(v)) } +func (enc *kvEncoder) AddUint(k string, v uint) { enc.AddUint64(k, uint64(v)) } +func (enc *kvEncoder) AddUint32(k string, v uint32) { enc.AddUint64(k, uint64(v)) } +func (enc *kvEncoder) AddUint16(k string, v uint16) { enc.AddUint64(k, uint64(v)) } +func (enc *kvEncoder) AddUint8(k string, v uint8) { enc.AddUint64(k, uint64(v)) } +func (enc *kvEncoder) AddUintptr(k string, v uintptr) { enc.AddUint64(k, uint64(v)) } +func (enc *kvEncoder) AppendComplex64(v complex64) { enc.AppendComplex128(complex128(v)) } +func (enc *kvEncoder) AppendFloat64(v float64) { enc.appendFloat(v, 64) } +func (enc *kvEncoder) AppendFloat32(v float32) { enc.appendFloat(float64(v), 32) } +func (enc *kvEncoder) AppendInt(v int) { enc.AppendInt64(int64(v)) } +func (enc *kvEncoder) AppendInt32(v int32) { enc.AppendInt64(int64(v)) } +func (enc *kvEncoder) AppendInt16(v int16) { enc.AppendInt64(int64(v)) } +func (enc *kvEncoder) AppendInt8(v int8) { enc.AppendInt64(int64(v)) } +func (enc *kvEncoder) AppendUint(v uint) { enc.AppendUint64(uint64(v)) } +func (enc *kvEncoder) AppendUint32(v uint32) { enc.AppendUint64(uint64(v)) } +func (enc *kvEncoder) AppendUint16(v uint16) { enc.AppendUint64(uint64(v)) } +func (enc *kvEncoder) AppendUint8(v uint8) { enc.AppendUint64(uint64(v)) } +func (enc *kvEncoder) AppendUintptr(v uintptr) { enc.AppendUint64(uint64(v)) } + +func (enc *kvEncoder) Clone() zapcore.Encoder { + clone := enc.clone() + _, _ = clone.buf.Write(enc.buf.Bytes()) + return clone +} + +func (enc *kvEncoder) clone() *kvEncoder { + clone := getKVEncoder() + clone.cfg = enc.cfg + clone.openNamespaces = enc.openNamespaces + clone.pool = enc.pool + clone.buf = enc.pool.Get() + return clone +} + +// nolint: gocyclo +func (enc *kvEncoder) EncodeEntry(ent zapcore.Entry, fields []zapcore.Field) (*buffer.Buffer, error) { + final := enc.clone() + if final.cfg.TimeKey != "" && final.cfg.EncodeTime != nil { + final.addElementSeparator() + final.buf.AppendString(final.cfg.TimeKey + "=") + final.cfg.EncodeTime(ent.Time, final) + } + if final.cfg.LevelKey != "" && final.cfg.EncodeLevel != nil { + final.addElementSeparator() + final.buf.AppendString(final.cfg.LevelKey + "=") + final.cfg.EncodeLevel(ent.Level, final) + } + if ent.LoggerName != "" && final.cfg.NameKey != "" { + nameEncoder := final.cfg.EncodeName + + if nameEncoder == nil { + // Fall back to FullNameEncoder for backward compatibility. + nameEncoder = zapcore.FullNameEncoder + } + + final.addElementSeparator() + final.buf.AppendString(final.cfg.NameKey + "=") + nameEncoder(ent.LoggerName, final) + } + if ent.Caller.Defined && final.cfg.CallerKey != "" && final.cfg.EncodeCaller != nil { + final.addElementSeparator() + final.buf.AppendString(final.cfg.CallerKey + "=") + final.cfg.EncodeCaller(ent.Caller, final) + } + + if enc.buf.Len() > 0 { + final.addElementSeparator() + _, _ = final.buf.Write(enc.buf.Bytes()) + } + + // Add the message itself. + if final.cfg.MessageKey != "" { + final.addElementSeparator() + final.buf.AppendString(final.cfg.MessageKey + "=") + final.AppendString(ent.Message) + } + + // Add any structured context. + for _, f := range fields { + f.AddTo(final) + } + + // If there's no stacktrace key, honor that; this allows users to force + // single-line output. + if ent.Stack != "" && final.cfg.StacktraceKey != "" { + final.buf.AppendByte('\n') + final.buf.AppendString(ent.Stack) + } + + if final.cfg.LineEnding != "" { + final.buf.AppendString(final.cfg.LineEnding) + } else { + final.buf.AppendString(zapcore.DefaultLineEnding) + } + return final.buf, nil +} diff --git a/library/go/core/log/zap/encoders/tskv.go b/library/go/core/log/zap/encoders/tskv.go new file mode 100644 index 0000000000..98950f1f2b --- /dev/null +++ b/library/go/core/log/zap/encoders/tskv.go @@ -0,0 +1,443 @@ +package encoders + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "math" + "strings" + "sync" + "time" + + "go.uber.org/zap/buffer" + "go.uber.org/zap/zapcore" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +const ( + // EncoderNameKV is the encoder name to use for zap config + EncoderNameTSKV = "tskv" +) + +const ( + tskvLineEnding = '\n' + tskvElementSeparator = '\t' + tskvKVSeparator = '=' + tskvMark = "tskv" + tskvArrayStart = '[' + tskvArrayEnd = ']' + tskvArraySeparator = ',' +) + +var tskvKeyEscapeRules = []string{ + `\`, `\\`, + "\t", "\\t", + "\n", "\\n", + "\r", `\r`, + "\x00", `\0`, + "=", `\=`, +} + +var tskvValueEscapeRules = []string{ + `\`, `\\`, + "\t", "\\t", + "\n", `\n`, + "\r", `\r`, + "\x00", `\0`, +} + +type tskvEscaper struct { + keyReplacer *strings.Replacer + valueReplacer *strings.Replacer +} + +func newTSKVEscaper() tskvEscaper { + return tskvEscaper{ + keyReplacer: strings.NewReplacer(tskvKeyEscapeRules...), + valueReplacer: strings.NewReplacer(tskvValueEscapeRules...), + } +} + +func (esc *tskvEscaper) escapeKey(key string) string { + return esc.keyReplacer.Replace(key) +} + +func (esc *tskvEscaper) escapeValue(val string) string { + return esc.valueReplacer.Replace(val) +} + +func hexEncode(val []byte) []byte { + dst := make([]byte, hex.EncodedLen(len(val))) + hex.Encode(dst, val) + return dst +} + +var tskvPool = sync.Pool{New: func() interface{} { + return &tskvEncoder{} +}} + +func getTSKVEncoder() *tskvEncoder { + return tskvPool.Get().(*tskvEncoder) +} + +type tskvEncoder struct { + cfg zapcore.EncoderConfig + pool buffer.Pool + buf *buffer.Buffer + + // for encoding generic values by reflection + reflectBuf *buffer.Buffer + reflectEnc *json.Encoder + + tskvEscaper tskvEscaper +} + +// NewKVEncoder constructs tskv encoder +func NewTSKVEncoder(cfg zapcore.EncoderConfig) (zapcore.Encoder, error) { + return newTSKVEncoder(cfg), nil +} + +func newTSKVEncoder(cfg zapcore.EncoderConfig) *tskvEncoder { + pool := buffer.NewPool() + return &tskvEncoder{ + cfg: cfg, + pool: pool, + buf: pool.Get(), + tskvEscaper: newTSKVEscaper(), + } +} + +func (enc *tskvEncoder) appendElementSeparator() { + if enc.buf.Len() == 0 { + return + } + + enc.buf.AppendByte(tskvElementSeparator) +} + +func (enc *tskvEncoder) appendArrayItemSeparator() { + last := enc.buf.Len() - 1 + if last < 0 { + return + } + + switch enc.buf.Bytes()[last] { + case tskvArrayStart, tskvKVSeparator: + return + default: + enc.buf.AppendByte(tskvArraySeparator) + } +} + +func (enc *tskvEncoder) safeAppendKey(key string) { + enc.appendElementSeparator() + enc.buf.AppendString(enc.tskvEscaper.escapeKey(key)) + enc.buf.AppendByte(tskvKVSeparator) +} + +func (enc *tskvEncoder) safeAppendString(val string) { + enc.buf.AppendString(enc.tskvEscaper.escapeValue(val)) +} + +func (enc *tskvEncoder) appendFloat(val float64, bitSize int) { + enc.appendArrayItemSeparator() + switch { + case math.IsNaN(val): + enc.buf.AppendString(`"NaN"`) + case math.IsInf(val, 1): + enc.buf.AppendString(`"+Inf"`) + case math.IsInf(val, -1): + enc.buf.AppendString(`"-Inf"`) + default: + enc.buf.AppendFloat(val, bitSize) + } +} + +func (enc *tskvEncoder) AddArray(key string, arr zapcore.ArrayMarshaler) error { + enc.safeAppendKey(key) + return enc.AppendArray(arr) +} + +func (enc *tskvEncoder) AddObject(key string, obj zapcore.ObjectMarshaler) error { + enc.safeAppendKey(key) + return enc.AppendObject(obj) +} + +func (enc *tskvEncoder) AddBinary(key string, val []byte) { + enc.AddByteString(key, val) +} + +func (enc *tskvEncoder) AddByteString(key string, val []byte) { + enc.safeAppendKey(key) + enc.AppendByteString(val) +} + +func (enc *tskvEncoder) AddBool(key string, val bool) { + enc.safeAppendKey(key) + enc.AppendBool(val) +} + +func (enc *tskvEncoder) AddComplex128(key string, val complex128) { + enc.safeAppendKey(key) + enc.AppendComplex128(val) +} + +func (enc *tskvEncoder) AddDuration(key string, val time.Duration) { + enc.safeAppendKey(key) + enc.AppendDuration(val) +} + +func (enc *tskvEncoder) AddFloat64(key string, val float64) { + enc.safeAppendKey(key) + enc.AppendFloat64(val) +} + +func (enc *tskvEncoder) AddInt64(key string, val int64) { + enc.safeAppendKey(key) + enc.AppendInt64(val) +} + +func (enc *tskvEncoder) resetReflectBuf() { + if enc.reflectBuf == nil { + enc.reflectBuf = enc.pool.Get() + enc.reflectEnc = json.NewEncoder(enc.reflectBuf) + } else { + enc.reflectBuf.Reset() + } +} + +func (enc *tskvEncoder) AddReflected(key string, obj interface{}) error { + enc.resetReflectBuf() + err := enc.reflectEnc.Encode(obj) + if err != nil { + return err + } + enc.reflectBuf.TrimNewline() + enc.safeAppendKey(key) + enc.safeAppendString(enc.reflectBuf.String()) + return err +} + +// OpenNamespace is not supported due to tskv format design +// See AppendObject() for more details +func (enc *tskvEncoder) OpenNamespace(key string) { + panic("TSKV encoder does not support namespaces") +} + +func (enc *tskvEncoder) AddString(key, val string) { + enc.safeAppendKey(key) + enc.safeAppendString(val) +} + +func (enc *tskvEncoder) AddTime(key string, val time.Time) { + enc.safeAppendKey(key) + enc.AppendTime(val) +} + +func (enc *tskvEncoder) AddUint64(key string, val uint64) { + enc.safeAppendKey(key) + enc.AppendUint64(val) +} + +func (enc *tskvEncoder) AppendArray(arr zapcore.ArrayMarshaler) error { + enc.appendArrayItemSeparator() + enc.buf.AppendByte(tskvArrayStart) + err := arr.MarshalLogArray(enc) + enc.buf.AppendByte(tskvArrayEnd) + return err +} + +// TSKV format does not support hierarchy data so we can't log Objects here +// The only thing we can do is to implicitly use fmt.Stringer interface +// +// ObjectMarshaler interface requires MarshalLogObject method +// from within MarshalLogObject you only have access to ObjectEncoder methods (AddString, AddBool ...) +// so if you call AddString then object log will be split by \t sign +// but \t is key-value separator and tskv doesn't have another separators +// e.g +// json encoded: objLogFieldName={"innerObjKey1":{"innerObjKey2":"value"}} +// tskv encoded: objLogFieldName={ \tinnerObjKey1={ \tinnerObjKey2=value}} +func (enc *tskvEncoder) AppendObject(obj zapcore.ObjectMarshaler) error { + var err error + + enc.appendArrayItemSeparator() + enc.buf.AppendByte('{') + stringerObj, ok := obj.(fmt.Stringer) + if !ok { + err = xerrors.Errorf("fmt.Stringer implementation required due to marshall into tskv format") + } else { + enc.safeAppendString(stringerObj.String()) + } + enc.buf.AppendByte('}') + + return err +} + +func (enc *tskvEncoder) AppendBool(val bool) { + enc.appendArrayItemSeparator() + enc.buf.AppendBool(val) +} + +func (enc *tskvEncoder) AppendByteString(val []byte) { + enc.appendArrayItemSeparator() + _, _ = enc.buf.Write(hexEncode(val)) +} + +func (enc *tskvEncoder) AppendComplex128(val complex128) { // TODO + enc.appendArrayItemSeparator() + + r, i := real(val), imag(val) + enc.buf.AppendByte('"') + // Because we're always in a quoted string, we can use strconv without + // special-casing NaN and +/-Inf. + enc.buf.AppendFloat(r, 64) + enc.buf.AppendByte('+') + enc.buf.AppendFloat(i, 64) + enc.buf.AppendByte('i') + enc.buf.AppendByte('"') +} + +func (enc *tskvEncoder) AppendDuration(val time.Duration) { + cur := enc.buf.Len() + enc.cfg.EncodeDuration(val, enc) + if cur == enc.buf.Len() { + // User-supplied EncodeDuration is a no-op. Fall back to nanoseconds + enc.AppendInt64(int64(val)) + } +} + +func (enc *tskvEncoder) AppendInt64(val int64) { + enc.appendArrayItemSeparator() + enc.buf.AppendInt(val) +} + +func (enc *tskvEncoder) AppendReflected(val interface{}) error { + enc.appendArrayItemSeparator() + + enc.resetReflectBuf() + err := enc.reflectEnc.Encode(val) + if err != nil { + return err + } + enc.reflectBuf.TrimNewline() + enc.safeAppendString(enc.reflectBuf.String()) + return nil +} + +func (enc *tskvEncoder) AppendString(val string) { + enc.appendArrayItemSeparator() + enc.safeAppendString(val) +} + +func (enc *tskvEncoder) AppendTime(val time.Time) { + cur := enc.buf.Len() + enc.cfg.EncodeTime(val, enc) + if cur == enc.buf.Len() { + // User-supplied EncodeTime is a no-op. Fall back to nanos since epoch to keep output tskv valid. + enc.AppendInt64(val.Unix()) + } +} + +func (enc *tskvEncoder) AppendUint64(val uint64) { + enc.appendArrayItemSeparator() + enc.buf.AppendUint(val) +} + +func (enc *tskvEncoder) AddComplex64(k string, v complex64) { enc.AddComplex128(k, complex128(v)) } +func (enc *tskvEncoder) AddFloat32(k string, v float32) { enc.AddFloat64(k, float64(v)) } +func (enc *tskvEncoder) AddInt(k string, v int) { enc.AddInt64(k, int64(v)) } +func (enc *tskvEncoder) AddInt32(k string, v int32) { enc.AddInt64(k, int64(v)) } +func (enc *tskvEncoder) AddInt16(k string, v int16) { enc.AddInt64(k, int64(v)) } +func (enc *tskvEncoder) AddInt8(k string, v int8) { enc.AddInt64(k, int64(v)) } +func (enc *tskvEncoder) AddUint(k string, v uint) { enc.AddUint64(k, uint64(v)) } +func (enc *tskvEncoder) AddUint32(k string, v uint32) { enc.AddUint64(k, uint64(v)) } +func (enc *tskvEncoder) AddUint16(k string, v uint16) { enc.AddUint64(k, uint64(v)) } +func (enc *tskvEncoder) AddUint8(k string, v uint8) { enc.AddUint64(k, uint64(v)) } +func (enc *tskvEncoder) AddUintptr(k string, v uintptr) { enc.AddUint64(k, uint64(v)) } +func (enc *tskvEncoder) AppendComplex64(v complex64) { enc.AppendComplex128(complex128(v)) } +func (enc *tskvEncoder) AppendFloat64(v float64) { enc.appendFloat(v, 64) } +func (enc *tskvEncoder) AppendFloat32(v float32) { enc.appendFloat(float64(v), 32) } +func (enc *tskvEncoder) AppendInt(v int) { enc.AppendInt64(int64(v)) } +func (enc *tskvEncoder) AppendInt32(v int32) { enc.AppendInt64(int64(v)) } +func (enc *tskvEncoder) AppendInt16(v int16) { enc.AppendInt64(int64(v)) } +func (enc *tskvEncoder) AppendInt8(v int8) { enc.AppendInt64(int64(v)) } +func (enc *tskvEncoder) AppendUint(v uint) { enc.AppendUint64(uint64(v)) } +func (enc *tskvEncoder) AppendUint32(v uint32) { enc.AppendUint64(uint64(v)) } +func (enc *tskvEncoder) AppendUint16(v uint16) { enc.AppendUint64(uint64(v)) } +func (enc *tskvEncoder) AppendUint8(v uint8) { enc.AppendUint64(uint64(v)) } +func (enc *tskvEncoder) AppendUintptr(v uintptr) { enc.AppendUint64(uint64(v)) } + +func (enc *tskvEncoder) Clone() zapcore.Encoder { + clone := enc.clone() + _, _ = clone.buf.Write(enc.buf.Bytes()) + return clone +} + +func (enc *tskvEncoder) clone() *tskvEncoder { + clone := getTSKVEncoder() + clone.cfg = enc.cfg + clone.pool = enc.pool + clone.buf = enc.pool.Get() + clone.tskvEscaper = enc.tskvEscaper + return clone +} + +// nolint: gocyclo +func (enc *tskvEncoder) EncodeEntry(ent zapcore.Entry, fields []zapcore.Field) (*buffer.Buffer, error) { + final := enc.clone() + final.AppendString(tskvMark) + + if final.cfg.TimeKey != "" && final.cfg.EncodeTime != nil { + final.safeAppendKey(final.cfg.TimeKey) + final.cfg.EncodeTime(ent.Time, final) + } + if final.cfg.LevelKey != "" && final.cfg.EncodeLevel != nil { + final.safeAppendKey(final.cfg.LevelKey) + final.cfg.EncodeLevel(ent.Level, final) + } + if ent.LoggerName != "" && final.cfg.NameKey != "" { + nameEncoder := final.cfg.EncodeName + + if nameEncoder == nil { + // Fall back to FullNameEncoder for backward compatibility. + nameEncoder = zapcore.FullNameEncoder + } + + final.safeAppendKey(final.cfg.NameKey) + nameEncoder(ent.LoggerName, final) + } + if ent.Caller.Defined && final.cfg.CallerKey != "" && final.cfg.EncodeCaller != nil { + final.safeAppendKey(final.cfg.CallerKey) + final.cfg.EncodeCaller(ent.Caller, final) + } + + if enc.buf.Len() > 0 { + final.appendElementSeparator() + _, _ = final.buf.Write(enc.buf.Bytes()) + } + + // Add the message itself. + if final.cfg.MessageKey != "" { + final.safeAppendKey(final.cfg.MessageKey) + final.safeAppendString(ent.Message) + } + + // Add any structured context. + for _, f := range fields { + f.AddTo(final) + } + + if ent.Stack != "" && final.cfg.StacktraceKey != "" { + final.safeAppendKey(final.cfg.StacktraceKey) + final.safeAppendString(ent.Stack) + } + + if final.cfg.LineEnding != "" { + final.buf.AppendString(final.cfg.LineEnding) + } else { + final.buf.AppendByte(tskvLineEnding) + } + + return final.buf, nil +} diff --git a/library/go/core/log/zap/logrotate/sink.go b/library/go/core/log/zap/logrotate/sink.go new file mode 100644 index 0000000000..e813316201 --- /dev/null +++ b/library/go/core/log/zap/logrotate/sink.go @@ -0,0 +1,122 @@ +//go:build darwin || freebsd || linux +// +build darwin freebsd linux + +package logrotate + +import ( + "fmt" + "net/url" + "os" + "os/signal" + "sync/atomic" + "unsafe" + + "go.uber.org/zap" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +const defaultSchemeName = "logrotate" + +// Register logrotate sink in zap sink registry. +// This sink internally is like file sink, but listens to provided logrotate signal +// and reopens file when that signal is delivered +// This can be called only once. Any future calls will result in an error +func RegisterLogrotateSink(sig ...os.Signal) error { + return RegisterNamedLogrotateSink(defaultSchemeName, sig...) +} + +// Same as RegisterLogrotateSink, but use provided schemeName instead of default `logrotate` +// Can be useful in special cases for registering different types of sinks for different signal +func RegisterNamedLogrotateSink(schemeName string, sig ...os.Signal) error { + factory := func(url *url.URL) (sink zap.Sink, e error) { + return NewLogrotateSink(url, sig...) + } + return zap.RegisterSink(schemeName, factory) +} + +// sink itself, use RegisterLogrotateSink to register it in zap machinery +type sink struct { + path string + notifier chan os.Signal + file unsafe.Pointer +} + +// Factory for logrotate sink, which accepts os.Signals to listen to for reloading +// Generally if you don't build your own core it is used by zap machinery. +// See RegisterLogrotateSink. +func NewLogrotateSink(u *url.URL, sig ...os.Signal) (zap.Sink, error) { + notifier := make(chan os.Signal, 1) + signal.Notify(notifier, sig...) + + if u.User != nil { + return nil, fmt.Errorf("user and password not allowed with logrotate file URLs: got %v", u) + } + if u.Fragment != "" { + return nil, fmt.Errorf("fragments not allowed with logrotate file URLs: got %v", u) + } + // Error messages are better if we check hostname and port separately. + if u.Port() != "" { + return nil, fmt.Errorf("ports not allowed with logrotate file URLs: got %v", u) + } + if hn := u.Hostname(); hn != "" && hn != "localhost" { + return nil, fmt.Errorf("logrotate file URLs must leave host empty or use localhost: got %v", u) + } + + sink := &sink{ + path: u.Path, + notifier: notifier, + } + if err := sink.reopen(); err != nil { + return nil, err + } + go sink.listenToSignal() + return sink, nil +} + +// wait for signal delivery or chanel close +func (m *sink) listenToSignal() { + for { + _, ok := <-m.notifier + if !ok { + return + } + if err := m.reopen(); err != nil { + // Last chance to signalize about an error + _, _ = fmt.Fprintf(os.Stderr, "%s", err) + } + } +} + +func (m *sink) reopen() error { + file, err := os.OpenFile(m.path, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) + if err != nil { + return xerrors.Errorf("failed to open log file on %s: %w", m.path, err) + } + old := (*os.File)(m.file) + atomic.StorePointer(&m.file, unsafe.Pointer(file)) + if old != nil { + if err := old.Close(); err != nil { + return xerrors.Errorf("failed to close old file: %w", err) + } + } + return nil +} + +func (m *sink) getFile() *os.File { + return (*os.File)(atomic.LoadPointer(&m.file)) +} + +func (m *sink) Close() error { + signal.Stop(m.notifier) + close(m.notifier) + return m.getFile().Close() +} + +func (m *sink) Write(p []byte) (n int, err error) { + return m.getFile().Write(p) +} + +func (m *sink) Sync() error { + return m.getFile().Sync() +} diff --git a/library/go/core/log/zap/qloud.go b/library/go/core/log/zap/qloud.go new file mode 100644 index 0000000000..0672a09ff3 --- /dev/null +++ b/library/go/core/log/zap/qloud.go @@ -0,0 +1,50 @@ +package zap + +import ( + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "a.yandex-team.ru/library/go/core/log" +) + +// NewQloudLogger constructs fully-fledged Qloud compatible logger +// based on predefined config. See https://wiki.yandex-team.ru/qloud/doc/logs +// for more information +func NewQloudLogger(level log.Level, opts ...zap.Option) (*Logger, error) { + cfg := zap.Config{ + Level: zap.NewAtomicLevelAt(ZapifyLevel(level)), + Encoding: "json", + OutputPaths: []string{"stdout"}, + ErrorOutputPaths: []string{"stderr"}, + EncoderConfig: zapcore.EncoderConfig{ + MessageKey: "msg", + LevelKey: "level", + StacktraceKey: "stackTrace", + TimeKey: "", + CallerKey: "", + EncodeLevel: zapcore.LowercaseLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeDuration: zapcore.StringDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + }, + } + + zl, err := cfg.Build(opts...) + if err != nil { + return nil, err + } + + return &Logger{ + L: addQloudContext(zl).(*zap.Logger), + }, nil +} + +func addQloudContext(i interface{}) interface{} { + switch c := i.(type) { + case *zap.Logger: + return c.With(zap.Namespace("@fields")) + case zapcore.Core: + return c.With([]zapcore.Field{zap.Namespace("@fields")}) + } + return i +} diff --git a/library/go/core/log/zap/zap.go b/library/go/core/log/zap/zap.go new file mode 100644 index 0000000000..2870ece4ab --- /dev/null +++ b/library/go/core/log/zap/zap.go @@ -0,0 +1,253 @@ +package zap + +import ( + "fmt" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap/encoders" +) + +const ( + // callerSkip is number of stack frames to skip when logging caller + callerSkip = 1 +) + +func init() { + if err := zap.RegisterEncoder(encoders.EncoderNameKV, encoders.NewKVEncoder); err != nil { + panic(err) + } + if err := zap.RegisterEncoder(encoders.EncoderNameCli, encoders.NewCliEncoder); err != nil { + panic(err) + } + if err := zap.RegisterEncoder(encoders.EncoderNameTSKV, encoders.NewTSKVEncoder); err != nil { + panic(err) + } +} + +// Logger implements log.Logger interface +type Logger struct { + L *zap.Logger +} + +var _ log.Logger = &Logger{} +var _ log.Structured = &Logger{} +var _ log.Fmt = &Logger{} +var _ log.LoggerWith = &Logger{} +var _ log.LoggerAddCallerSkip = &Logger{} + +// New constructs zap-based logger from provided config +func New(cfg zap.Config) (*Logger, error) { + zl, err := cfg.Build(zap.AddCallerSkip(callerSkip)) + if err != nil { + return nil, err + } + + return &Logger{ + L: zl, + }, nil +} + +// NewWithCore constructs zap-based logger from provided core +func NewWithCore(core zapcore.Core, options ...zap.Option) *Logger { + options = append(options, zap.AddCallerSkip(callerSkip)) + return &Logger{L: zap.New(core, options...)} +} + +// Must constructs zap-based logger from provided config and panics on error +func Must(cfg zap.Config) *Logger { + l, err := New(cfg) + if err != nil { + panic(fmt.Sprintf("failed to construct zap logger: %v", err)) + } + return l +} + +// JSONConfig returns zap config for structured logging (zap's json encoder) +func JSONConfig(level log.Level) zap.Config { + return StandardConfig("json", level) +} + +// ConsoleConfig returns zap config for logging to console (zap's console encoder) +func ConsoleConfig(level log.Level) zap.Config { + return StandardConfig("console", level) +} + +// CLIConfig returns zap config for cli logging (custom cli encoder) +func CLIConfig(level log.Level) zap.Config { + return StandardConfig("cli", level) +} + +// KVConfig returns zap config for logging to kv (custom kv encoder) +func KVConfig(level log.Level) zap.Config { + return StandardConfig("kv", level) +} + +// TSKVConfig returns zap config for logging to tskv (custom tskv encoder) +func TSKVConfig(level log.Level) zap.Config { + return zap.Config{ + Level: zap.NewAtomicLevelAt(ZapifyLevel(level)), + Encoding: "tskv", + OutputPaths: []string{"stdout"}, + ErrorOutputPaths: []string{"stderr"}, + EncoderConfig: zapcore.EncoderConfig{ + MessageKey: "message", + LevelKey: "levelname", + TimeKey: "unixtime", + CallerKey: "caller", + NameKey: "name", + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.EpochTimeEncoder, + EncodeDuration: zapcore.StringDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + }, + } +} + +// StandardConfig returns default zap config with specified encoding and level +func StandardConfig(encoding string, level log.Level) zap.Config { + return zap.Config{ + Level: zap.NewAtomicLevelAt(ZapifyLevel(level)), + Encoding: encoding, + OutputPaths: []string{"stdout"}, + ErrorOutputPaths: []string{"stderr"}, + EncoderConfig: zapcore.EncoderConfig{ + MessageKey: "msg", + LevelKey: "level", + TimeKey: "ts", + CallerKey: "caller", + NameKey: "name", + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeDuration: zapcore.StringDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + }, + } +} + +// Logger returns general logger +func (l *Logger) Logger() log.Logger { + return l +} + +// Fmt returns fmt logger +func (l *Logger) Fmt() log.Fmt { + return l +} + +// Structured returns structured logger +func (l *Logger) Structured() log.Structured { + return l +} + +// With returns logger that always adds provided key/value to every log entry +func (l *Logger) With(fields ...log.Field) log.Logger { + return &Logger{ + L: l.L.With(zapifyFields(fields...)...), + } +} + +func (l *Logger) AddCallerSkip(skip int) log.Logger { + return &Logger{ + L: l.L.WithOptions(zap.AddCallerSkip(skip)), + } +} + +// Trace logs at Trace log level using fields +func (l *Logger) Trace(msg string, fields ...log.Field) { + if ce := l.L.Check(zap.DebugLevel, msg); ce != nil { + ce.Write(zapifyFields(fields...)...) + } +} + +// Tracef logs at Trace log level using fmt formatter +func (l *Logger) Tracef(msg string, args ...interface{}) { + if ce := l.L.Check(zap.DebugLevel, ""); ce != nil { + ce.Message = fmt.Sprintf(msg, args...) + ce.Write() + } +} + +// Debug logs at Debug log level using fields +func (l *Logger) Debug(msg string, fields ...log.Field) { + if ce := l.L.Check(zap.DebugLevel, msg); ce != nil { + ce.Write(zapifyFields(fields...)...) + } +} + +// Debugf logs at Debug log level using fmt formatter +func (l *Logger) Debugf(msg string, args ...interface{}) { + if ce := l.L.Check(zap.DebugLevel, ""); ce != nil { + ce.Message = fmt.Sprintf(msg, args...) + ce.Write() + } +} + +// Info logs at Info log level using fields +func (l *Logger) Info(msg string, fields ...log.Field) { + if ce := l.L.Check(zap.InfoLevel, msg); ce != nil { + ce.Write(zapifyFields(fields...)...) + } +} + +// Infof logs at Info log level using fmt formatter +func (l *Logger) Infof(msg string, args ...interface{}) { + if ce := l.L.Check(zap.InfoLevel, ""); ce != nil { + ce.Message = fmt.Sprintf(msg, args...) + ce.Write() + } +} + +// Warn logs at Warn log level using fields +func (l *Logger) Warn(msg string, fields ...log.Field) { + if ce := l.L.Check(zap.WarnLevel, msg); ce != nil { + ce.Write(zapifyFields(fields...)...) + } +} + +// Warnf logs at Warn log level using fmt formatter +func (l *Logger) Warnf(msg string, args ...interface{}) { + if ce := l.L.Check(zap.WarnLevel, ""); ce != nil { + ce.Message = fmt.Sprintf(msg, args...) + ce.Write() + } +} + +// Error logs at Error log level using fields +func (l *Logger) Error(msg string, fields ...log.Field) { + if ce := l.L.Check(zap.ErrorLevel, msg); ce != nil { + ce.Write(zapifyFields(fields...)...) + } +} + +// Errorf logs at Error log level using fmt formatter +func (l *Logger) Errorf(msg string, args ...interface{}) { + if ce := l.L.Check(zap.ErrorLevel, ""); ce != nil { + ce.Message = fmt.Sprintf(msg, args...) + ce.Write() + } +} + +// Fatal logs at Fatal log level using fields +func (l *Logger) Fatal(msg string, fields ...log.Field) { + if ce := l.L.Check(zap.FatalLevel, msg); ce != nil { + ce.Write(zapifyFields(fields...)...) + } +} + +// Fatalf logs at Fatal log level using fmt formatter +func (l *Logger) Fatalf(msg string, args ...interface{}) { + if ce := l.L.Check(zap.FatalLevel, ""); ce != nil { + ce.Message = fmt.Sprintf(msg, args...) + ce.Write() + } +} + +// WithName adds name to logger +func (l *Logger) WithName(name string) log.Logger { + return &Logger{ + L: l.L.Named(name), + } +} diff --git a/library/go/core/log/zap/zapify.go b/library/go/core/log/zap/zapify.go new file mode 100644 index 0000000000..43ff6da697 --- /dev/null +++ b/library/go/core/log/zap/zapify.go @@ -0,0 +1,96 @@ +package zap + +import ( + "fmt" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "a.yandex-team.ru/library/go/core/log" +) + +// ZapifyLevel turns interface log level to zap log level +func ZapifyLevel(level log.Level) zapcore.Level { + switch level { + case log.TraceLevel: + return zapcore.DebugLevel + case log.DebugLevel: + return zapcore.DebugLevel + case log.InfoLevel: + return zapcore.InfoLevel + case log.WarnLevel: + return zapcore.WarnLevel + case log.ErrorLevel: + return zapcore.ErrorLevel + case log.FatalLevel: + return zapcore.FatalLevel + default: + // For when new log level is not added to this func (most likely never). + panic(fmt.Sprintf("unknown log level: %d", level)) + } +} + +// UnzapifyLevel turns zap log level to interface log level. +func UnzapifyLevel(level zapcore.Level) log.Level { + switch level { + case zapcore.DebugLevel: + return log.DebugLevel + case zapcore.InfoLevel: + return log.InfoLevel + case zapcore.WarnLevel: + return log.WarnLevel + case zapcore.ErrorLevel: + return log.ErrorLevel + case zapcore.FatalLevel, zapcore.DPanicLevel, zapcore.PanicLevel: + return log.FatalLevel + default: + // For when new log level is not added to this func (most likely never). + panic(fmt.Sprintf("unknown log level: %d", level)) + } +} + +// nolint: gocyclo +func zapifyField(field log.Field) zap.Field { + switch field.Type() { + case log.FieldTypeNil: + return zap.Reflect(field.Key(), nil) + case log.FieldTypeString: + return zap.String(field.Key(), field.String()) + case log.FieldTypeBinary: + return zap.Binary(field.Key(), field.Binary()) + case log.FieldTypeBoolean: + return zap.Bool(field.Key(), field.Bool()) + case log.FieldTypeSigned: + return zap.Int64(field.Key(), field.Signed()) + case log.FieldTypeUnsigned: + return zap.Uint64(field.Key(), field.Unsigned()) + case log.FieldTypeFloat: + return zap.Float64(field.Key(), field.Float()) + case log.FieldTypeTime: + return zap.Time(field.Key(), field.Time()) + case log.FieldTypeDuration: + return zap.Duration(field.Key(), field.Duration()) + case log.FieldTypeError: + return zap.NamedError(field.Key(), field.Error()) + case log.FieldTypeArray: + return zap.Any(field.Key(), field.Interface()) + case log.FieldTypeAny: + return zap.Any(field.Key(), field.Interface()) + case log.FieldTypeReflect: + return zap.Reflect(field.Key(), field.Interface()) + case log.FieldTypeByteString: + return zap.ByteString(field.Key(), field.Binary()) + default: + // For when new field type is not added to this func + panic(fmt.Sprintf("unknown field type: %d", field.Type())) + } +} + +func zapifyFields(fields ...log.Field) []zapcore.Field { + zapFields := make([]zapcore.Field, 0, len(fields)) + for _, field := range fields { + zapFields = append(zapFields, zapifyField(field)) + } + + return zapFields +} diff --git a/library/go/core/metrics/buckets.go b/library/go/core/metrics/buckets.go new file mode 100644 index 0000000000..063c0c4418 --- /dev/null +++ b/library/go/core/metrics/buckets.go @@ -0,0 +1,147 @@ +package metrics + +import ( + "sort" + "time" +) + +var ( + _ DurationBuckets = (*durationBuckets)(nil) + _ Buckets = (*buckets)(nil) +) + +const ( + errBucketsCountNeedsGreaterThanZero = "n needs to be > 0" + errBucketsStartNeedsGreaterThanZero = "start needs to be > 0" + errBucketsFactorNeedsGreaterThanOne = "factor needs to be > 1" +) + +type durationBuckets struct { + buckets []time.Duration +} + +// NewDurationBuckets returns new DurationBuckets implementation. +func NewDurationBuckets(bk ...time.Duration) DurationBuckets { + sort.Slice(bk, func(i, j int) bool { + return bk[i] < bk[j] + }) + return durationBuckets{buckets: bk} +} + +func (d durationBuckets) Size() int { + return len(d.buckets) +} + +func (d durationBuckets) MapDuration(dv time.Duration) (idx int) { + for _, bound := range d.buckets { + if dv < bound { + break + } + idx++ + } + return +} + +func (d durationBuckets) UpperBound(idx int) time.Duration { + if idx > d.Size()-1 { + panic("idx is out of bounds") + } + return d.buckets[idx] +} + +type buckets struct { + buckets []float64 +} + +// NewBuckets returns new Buckets implementation. +func NewBuckets(bk ...float64) Buckets { + sort.Slice(bk, func(i, j int) bool { + return bk[i] < bk[j] + }) + return buckets{buckets: bk} +} + +func (d buckets) Size() int { + return len(d.buckets) +} + +func (d buckets) MapValue(v float64) (idx int) { + for _, bound := range d.buckets { + if v < bound { + break + } + idx++ + } + return +} + +func (d buckets) UpperBound(idx int) float64 { + if idx > d.Size()-1 { + panic("idx is out of bounds") + } + return d.buckets[idx] +} + +// MakeLinearBuckets creates a set of linear value buckets. +func MakeLinearBuckets(start, width float64, n int) Buckets { + if n <= 0 { + panic(errBucketsCountNeedsGreaterThanZero) + } + bounds := make([]float64, n) + for i := range bounds { + bounds[i] = start + (float64(i) * width) + } + return NewBuckets(bounds...) +} + +// MakeLinearDurationBuckets creates a set of linear duration buckets. +func MakeLinearDurationBuckets(start, width time.Duration, n int) DurationBuckets { + if n <= 0 { + panic(errBucketsCountNeedsGreaterThanZero) + } + buckets := make([]time.Duration, n) + for i := range buckets { + buckets[i] = start + (time.Duration(i) * width) + } + return NewDurationBuckets(buckets...) +} + +// MakeExponentialBuckets creates a set of exponential value buckets. +func MakeExponentialBuckets(start, factor float64, n int) Buckets { + if n <= 0 { + panic(errBucketsCountNeedsGreaterThanZero) + } + if start <= 0 { + panic(errBucketsStartNeedsGreaterThanZero) + } + if factor <= 1 { + panic(errBucketsFactorNeedsGreaterThanOne) + } + buckets := make([]float64, n) + curr := start + for i := range buckets { + buckets[i] = curr + curr *= factor + } + return NewBuckets(buckets...) +} + +// MakeExponentialDurationBuckets creates a set of exponential duration buckets. +func MakeExponentialDurationBuckets(start time.Duration, factor float64, n int) DurationBuckets { + if n <= 0 { + panic(errBucketsCountNeedsGreaterThanZero) + } + if start <= 0 { + panic(errBucketsStartNeedsGreaterThanZero) + } + if factor <= 1 { + panic(errBucketsFactorNeedsGreaterThanOne) + } + buckets := make([]time.Duration, n) + curr := start + for i := range buckets { + buckets[i] = curr + curr = time.Duration(float64(curr) * factor) + } + return NewDurationBuckets(buckets...) +} diff --git a/library/go/core/metrics/collect/collect.go b/library/go/core/metrics/collect/collect.go new file mode 100644 index 0000000000..3abbbcdfa9 --- /dev/null +++ b/library/go/core/metrics/collect/collect.go @@ -0,0 +1,9 @@ +package collect + +import ( + "context" + + "a.yandex-team.ru/library/go/core/metrics" +) + +type Func func(ctx context.Context, r metrics.Registry, c metrics.CollectPolicy) diff --git a/library/go/core/metrics/collect/system.go b/library/go/core/metrics/collect/system.go new file mode 100644 index 0000000000..8ce89ebc05 --- /dev/null +++ b/library/go/core/metrics/collect/system.go @@ -0,0 +1,229 @@ +package collect + +import ( + "context" + "os" + "runtime" + "runtime/debug" + "time" + + "github.com/prometheus/procfs" + + "a.yandex-team.ru/library/go/core/buildinfo" + "a.yandex-team.ru/library/go/core/metrics" +) + +var _ Func = GoMetrics + +func GoMetrics(_ context.Context, r metrics.Registry, c metrics.CollectPolicy) { + if r == nil { + return + } + r = r.WithPrefix("go") + + var stats debug.GCStats + stats.PauseQuantiles = make([]time.Duration, 5) // Minimum, 25%, 50%, 75%, and maximum pause times. + var numGoroutine, numThread int + var ms runtime.MemStats + + c.AddCollect(func(context.Context) { + debug.ReadGCStats(&stats) + runtime.ReadMemStats(&ms) + + numThread, _ = runtime.ThreadCreateProfile(nil) + numGoroutine = runtime.NumGoroutine() + }) + + gcRegistry := r.WithPrefix("gc") + gcRegistry.FuncCounter("num", c.RegisteredCounter(func() int64 { + return stats.NumGC + })) + gcRegistry.FuncCounter(r.ComposeName("pause", "total", "ns"), c.RegisteredCounter(func() int64 { + return stats.PauseTotal.Nanoseconds() + })) + gcRegistry.FuncGauge(r.ComposeName("pause", "quantile", "min"), c.RegisteredGauge(func() float64 { + return stats.PauseQuantiles[0].Seconds() + })) + gcRegistry.FuncGauge(r.ComposeName("pause", "quantile", "25"), c.RegisteredGauge(func() float64 { + return stats.PauseQuantiles[1].Seconds() + })) + gcRegistry.FuncGauge(r.ComposeName("pause", "quantile", "50"), c.RegisteredGauge(func() float64 { + return stats.PauseQuantiles[2].Seconds() + })) + gcRegistry.FuncGauge(r.ComposeName("pause", "quantile", "75"), c.RegisteredGauge(func() float64 { + return stats.PauseQuantiles[3].Seconds() + })) + gcRegistry.FuncGauge(r.ComposeName("pause", "quantile", "max"), c.RegisteredGauge(func() float64 { + return stats.PauseQuantiles[4].Seconds() + })) + gcRegistry.FuncGauge(r.ComposeName("last", "ts"), c.RegisteredGauge(func() float64 { + return float64(ms.LastGC) + })) + gcRegistry.FuncCounter(r.ComposeName("forced", "num"), c.RegisteredCounter(func() int64 { + return int64(ms.NumForcedGC) + })) + + r.FuncGauge(r.ComposeName("goroutine", "num"), c.RegisteredGauge(func() float64 { + return float64(numGoroutine) + })) + r.FuncGauge(r.ComposeName("thread", "num"), c.RegisteredGauge(func() float64 { + return float64(numThread) + })) + + memRegistry := r.WithPrefix("mem") + memRegistry.FuncCounter(r.ComposeName("alloc", "total"), c.RegisteredCounter(func() int64 { + return int64(ms.TotalAlloc) + })) + memRegistry.FuncGauge("sys", c.RegisteredGauge(func() float64 { + return float64(ms.Sys) + })) + memRegistry.FuncCounter("lookups", c.RegisteredCounter(func() int64 { + return int64(ms.Lookups) + })) + memRegistry.FuncCounter("mallocs", c.RegisteredCounter(func() int64 { + return int64(ms.Mallocs) + })) + memRegistry.FuncCounter("frees", c.RegisteredCounter(func() int64 { + return int64(ms.Frees) + })) + memRegistry.FuncGauge(r.ComposeName("heap", "alloc"), c.RegisteredGauge(func() float64 { + return float64(ms.HeapAlloc) + })) + memRegistry.FuncGauge(r.ComposeName("heap", "sys"), c.RegisteredGauge(func() float64 { + return float64(ms.HeapSys) + })) + memRegistry.FuncGauge(r.ComposeName("heap", "idle"), c.RegisteredGauge(func() float64 { + return float64(ms.HeapIdle) + })) + memRegistry.FuncGauge(r.ComposeName("heap", "inuse"), c.RegisteredGauge(func() float64 { + return float64(ms.HeapInuse) + })) + memRegistry.FuncGauge(r.ComposeName("heap", "released"), c.RegisteredGauge(func() float64 { + return float64(ms.HeapReleased) + })) + memRegistry.FuncGauge(r.ComposeName("heap", "objects"), c.RegisteredGauge(func() float64 { + return float64(ms.HeapObjects) + })) + + memRegistry.FuncGauge(r.ComposeName("stack", "inuse"), c.RegisteredGauge(func() float64 { + return float64(ms.StackInuse) + })) + memRegistry.FuncGauge(r.ComposeName("stack", "sys"), c.RegisteredGauge(func() float64 { + return float64(ms.StackSys) + })) + + memRegistry.FuncGauge(r.ComposeName("span", "inuse"), c.RegisteredGauge(func() float64 { + return float64(ms.MSpanInuse) + })) + memRegistry.FuncGauge(r.ComposeName("span", "sys"), c.RegisteredGauge(func() float64 { + return float64(ms.MSpanSys) + })) + + memRegistry.FuncGauge(r.ComposeName("cache", "inuse"), c.RegisteredGauge(func() float64 { + return float64(ms.MCacheInuse) + })) + memRegistry.FuncGauge(r.ComposeName("cache", "sys"), c.RegisteredGauge(func() float64 { + return float64(ms.MCacheSys) + })) + + memRegistry.FuncGauge(r.ComposeName("buck", "hash", "sys"), c.RegisteredGauge(func() float64 { + return float64(ms.BuckHashSys) + })) + memRegistry.FuncGauge(r.ComposeName("gc", "sys"), c.RegisteredGauge(func() float64 { + return float64(ms.GCSys) + })) + memRegistry.FuncGauge(r.ComposeName("other", "sys"), c.RegisteredGauge(func() float64 { + return float64(ms.OtherSys) + })) + memRegistry.FuncGauge(r.ComposeName("gc", "next"), c.RegisteredGauge(func() float64 { + return float64(ms.NextGC) + })) + + memRegistry.FuncGauge(r.ComposeName("gc", "cpu", "fraction"), c.RegisteredGauge(func() float64 { + return ms.GCCPUFraction + })) +} + +var _ Func = ProcessMetrics + +func ProcessMetrics(_ context.Context, r metrics.Registry, c metrics.CollectPolicy) { + if r == nil { + return + } + buildVersion := buildinfo.Info.ArcadiaSourceRevision + r.WithTags(map[string]string{"revision": buildVersion}).Gauge("build").Set(1.0) + + pid := os.Getpid() + proc, err := procfs.NewProc(pid) + if err != nil { + return + } + + procRegistry := r.WithPrefix("proc") + + var ioStat procfs.ProcIO + var procStat procfs.ProcStat + var fd int + var cpuWait uint64 + + const clocksPerSec = 100 + + c.AddCollect(func(ctx context.Context) { + if gatheredFD, err := proc.FileDescriptorsLen(); err == nil { + fd = gatheredFD + } + + if gatheredIOStat, err := proc.IO(); err == nil { + ioStat.SyscW = gatheredIOStat.SyscW + ioStat.WriteBytes = gatheredIOStat.WriteBytes + ioStat.SyscR = gatheredIOStat.SyscR + ioStat.ReadBytes = gatheredIOStat.ReadBytes + } + + if gatheredStat, err := proc.Stat(); err == nil { + procStat.UTime = gatheredStat.UTime + procStat.STime = gatheredStat.STime + procStat.RSS = gatheredStat.RSS + } + + if gatheredSched, err := proc.Schedstat(); err == nil { + cpuWait = gatheredSched.WaitingNanoseconds + } + }) + + procRegistry.FuncGauge("fd", c.RegisteredGauge(func() float64 { + return float64(fd) + })) + + ioRegistry := procRegistry.WithPrefix("io") + ioRegistry.FuncCounter(r.ComposeName("read", "count"), c.RegisteredCounter(func() int64 { + return int64(ioStat.SyscR) + })) + ioRegistry.FuncCounter(r.ComposeName("read", "bytes"), c.RegisteredCounter(func() int64 { + return int64(ioStat.ReadBytes) + })) + ioRegistry.FuncCounter(r.ComposeName("write", "count"), c.RegisteredCounter(func() int64 { + return int64(ioStat.SyscW) + })) + ioRegistry.FuncCounter(r.ComposeName("write", "bytes"), c.RegisteredCounter(func() int64 { + return int64(ioStat.WriteBytes) + })) + + cpuRegistry := procRegistry.WithPrefix("cpu") + cpuRegistry.FuncCounter(r.ComposeName("total", "ns"), c.RegisteredCounter(func() int64 { + return int64(procStat.UTime+procStat.STime) * (1_000_000_000 / clocksPerSec) + })) + cpuRegistry.FuncCounter(r.ComposeName("user", "ns"), c.RegisteredCounter(func() int64 { + return int64(procStat.UTime) * (1_000_000_000 / clocksPerSec) + })) + cpuRegistry.FuncCounter(r.ComposeName("system", "ns"), c.RegisteredCounter(func() int64 { + return int64(procStat.STime) * (1_000_000_000 / clocksPerSec) + })) + cpuRegistry.FuncCounter(r.ComposeName("wait", "ns"), c.RegisteredCounter(func() int64 { + return int64(cpuWait) + })) + + procRegistry.FuncGauge(r.ComposeName("mem", "rss"), c.RegisteredGauge(func() float64 { + return float64(procStat.RSS) + })) +} diff --git a/library/go/core/metrics/internal/pkg/metricsutil/buckets.go b/library/go/core/metrics/internal/pkg/metricsutil/buckets.go new file mode 100644 index 0000000000..5db605cd4d --- /dev/null +++ b/library/go/core/metrics/internal/pkg/metricsutil/buckets.go @@ -0,0 +1,27 @@ +package metricsutil + +import ( + "sort" + + "a.yandex-team.ru/library/go/core/metrics" +) + +// BucketsBounds unwraps Buckets bounds to slice of float64. +func BucketsBounds(b metrics.Buckets) []float64 { + bkts := make([]float64, b.Size()) + for i := range bkts { + bkts[i] = b.UpperBound(i) + } + sort.Float64s(bkts) + return bkts +} + +// DurationBucketsBounds unwraps DurationBuckets bounds to slice of float64. +func DurationBucketsBounds(b metrics.DurationBuckets) []float64 { + bkts := make([]float64, b.Size()) + for i := range bkts { + bkts[i] = b.UpperBound(i).Seconds() + } + sort.Float64s(bkts) + return bkts +} diff --git a/library/go/core/metrics/internal/pkg/registryutil/registryutil.go b/library/go/core/metrics/internal/pkg/registryutil/registryutil.go new file mode 100644 index 0000000000..ebce50d8cb --- /dev/null +++ b/library/go/core/metrics/internal/pkg/registryutil/registryutil.go @@ -0,0 +1,104 @@ +package registryutil + +import ( + "errors" + "fmt" + "sort" + "strconv" + "strings" + + "github.com/OneOfOne/xxhash" +) + +// BuildRegistryKey creates registry name based on given prefix and tags +func BuildRegistryKey(prefix string, tags map[string]string) string { + var builder strings.Builder + + builder.WriteString(strconv.Quote(prefix)) + builder.WriteRune('{') + builder.WriteString(StringifyTags(tags)) + builder.WriteByte('}') + + return builder.String() +} + +// BuildFQName returns name parts joined by given separator. +// Mainly used to append prefix to registry +func BuildFQName(separator string, parts ...string) (name string) { + var b strings.Builder + for _, p := range parts { + if p == "" { + continue + } + if b.Len() > 0 { + b.WriteString(separator) + } + b.WriteString(strings.Trim(p, separator)) + } + return b.String() +} + +// MergeTags merges 2 sets of tags with the tags from tagsRight overriding values from tagsLeft +func MergeTags(leftTags map[string]string, rightTags map[string]string) map[string]string { + if leftTags == nil && rightTags == nil { + return nil + } + + if len(leftTags) == 0 { + return rightTags + } + + if len(rightTags) == 0 { + return leftTags + } + + newTags := make(map[string]string) + for key, value := range leftTags { + newTags[key] = value + } + for key, value := range rightTags { + newTags[key] = value + } + return newTags +} + +// StringifyTags returns string representation of given tags map. +// It is guaranteed that equal sets of tags will produce equal strings. +func StringifyTags(tags map[string]string) string { + keys := make([]string, 0, len(tags)) + for key := range tags { + keys = append(keys, key) + } + sort.Strings(keys) + + var builder strings.Builder + for i, key := range keys { + if i > 0 { + builder.WriteByte(',') + } + builder.WriteString(key + "=" + tags[key]) + } + + return builder.String() +} + +// VectorHash computes hash of metrics vector element +func VectorHash(tags map[string]string, labels []string) (uint64, error) { + if len(tags) != len(labels) { + return 0, errors.New("inconsistent tags and labels sets") + } + + h := xxhash.New64() + + for _, label := range labels { + v, ok := tags[label] + if !ok { + return 0, fmt.Errorf("label '%s' not found in tags", label) + } + _, _ = h.WriteString(label) + _, _ = h.WriteString(v) + _, _ = h.WriteString(",") + } + + return h.Sum64(), nil +} diff --git a/library/go/core/metrics/metrics.go b/library/go/core/metrics/metrics.go new file mode 100644 index 0000000000..0eb436046b --- /dev/null +++ b/library/go/core/metrics/metrics.go @@ -0,0 +1,140 @@ +// Package metrics provides interface collecting performance metrics. +package metrics + +import ( + "context" + "time" +) + +// Gauge tracks single float64 value. +type Gauge interface { + Set(value float64) + Add(value float64) +} + +// FuncGauge is Gauge with value provided by callback function. +type FuncGauge interface { + Function() func() float64 +} + +// Counter tracks monotonically increasing value. +type Counter interface { + // Inc increments counter by 1. + Inc() + + // Add adds delta to the counter. Delta must be >=0. + Add(delta int64) +} + +// FuncCounter is Counter with value provided by callback function. +type FuncCounter interface { + Function() func() int64 +} + +// Histogram tracks distribution of value. +type Histogram interface { + RecordValue(value float64) +} + +// Timer measures durations. +type Timer interface { + RecordDuration(value time.Duration) +} + +// DurationBuckets defines buckets of the duration histogram. +type DurationBuckets interface { + // Size returns number of buckets. + Size() int + + // MapDuration returns index of the bucket. + // + // index is integer in range [0, Size()). + MapDuration(d time.Duration) int + + // UpperBound of the last bucket is always +Inf. + // + // bucketIndex is integer in range [0, Size()-1). + UpperBound(bucketIndex int) time.Duration +} + +// Buckets defines intervals of the regular histogram. +type Buckets interface { + // Size returns number of buckets. + Size() int + + // MapValue returns index of the bucket. + // + // Index is integer in range [0, Size()). + MapValue(v float64) int + + // UpperBound of the last bucket is always +Inf. + // + // bucketIndex is integer in range [0, Size()-1). + UpperBound(bucketIndex int) float64 +} + +// GaugeVec stores multiple dynamically created gauges. +type GaugeVec interface { + With(map[string]string) Gauge + + // Reset deletes all metrics in vector. + Reset() +} + +// CounterVec stores multiple dynamically created counters. +type CounterVec interface { + With(map[string]string) Counter + + // Reset deletes all metrics in vector. + Reset() +} + +// TimerVec stores multiple dynamically created timers. +type TimerVec interface { + With(map[string]string) Timer + + // Reset deletes all metrics in vector. + Reset() +} + +// HistogramVec stores multiple dynamically created histograms. +type HistogramVec interface { + With(map[string]string) Histogram + + // Reset deletes all metrics in vector. + Reset() +} + +// Registry creates profiling metrics. +type Registry interface { + // WithTags creates new sub-scope, where each metric has tags attached to it. + WithTags(tags map[string]string) Registry + // WithPrefix creates new sub-scope, where each metric has prefix added to it name. + WithPrefix(prefix string) Registry + + ComposeName(parts ...string) string + + Counter(name string) Counter + CounterVec(name string, labels []string) CounterVec + FuncCounter(name string, function func() int64) FuncCounter + + Gauge(name string) Gauge + GaugeVec(name string, labels []string) GaugeVec + FuncGauge(name string, function func() float64) FuncGauge + + Timer(name string) Timer + TimerVec(name string, labels []string) TimerVec + + Histogram(name string, buckets Buckets) Histogram + HistogramVec(name string, buckets Buckets, labels []string) HistogramVec + + DurationHistogram(name string, buckets DurationBuckets) Timer + DurationHistogramVec(name string, buckets DurationBuckets, labels []string) TimerVec +} + +// CollectPolicy defines how registered gauge metrics are updated via collect func. +type CollectPolicy interface { + RegisteredCounter(counterFunc func() int64) func() int64 + RegisteredGauge(gaugeFunc func() float64) func() float64 + AddCollect(collect func(ctx context.Context)) +} diff --git a/library/go/core/metrics/solomon/converter.go b/library/go/core/metrics/solomon/converter.go new file mode 100644 index 0000000000..4458d1a932 --- /dev/null +++ b/library/go/core/metrics/solomon/converter.go @@ -0,0 +1,73 @@ +package solomon + +import ( + "fmt" + + dto "github.com/prometheus/client_model/go" + "go.uber.org/atomic" +) + +// PrometheusMetrics converts Prometheus metrics to Solomon metrics. +func PrometheusMetrics(metrics []*dto.MetricFamily) (*Metrics, error) { + s := &Metrics{ + metrics: make([]Metric, 0, len(metrics)), + } + + if len(metrics) == 0 { + return s, nil + } + + for _, mf := range metrics { + if len(mf.Metric) == 0 { + continue + } + + metric := mf.Metric[0] + + tags := make(map[string]string, len(metric.Label)) + for _, label := range metric.Label { + tags[label.GetName()] = label.GetValue() + } + + switch *mf.Type { + case dto.MetricType_COUNTER: + s.metrics = append(s.metrics, &Counter{ + name: mf.GetName(), + tags: tags, + value: *atomic.NewInt64(int64(metric.Counter.GetValue())), + }) + case dto.MetricType_GAUGE: + s.metrics = append(s.metrics, &Gauge{ + name: mf.GetName(), + tags: tags, + value: *atomic.NewFloat64(metric.Gauge.GetValue()), + }) + case dto.MetricType_HISTOGRAM: + bounds := make([]float64, 0, len(metric.Histogram.Bucket)) + values := make([]int64, 0, len(metric.Histogram.Bucket)) + + var prevValuesSum int64 + + for _, bucket := range metric.Histogram.Bucket { + // prometheus uses cumulative buckets where solomon uses instant + bucketValue := int64(bucket.GetCumulativeCount()) + bucketValue -= prevValuesSum + prevValuesSum += bucketValue + + bounds = append(bounds, bucket.GetUpperBound()) + values = append(values, bucketValue) + } + + s.metrics = append(s.metrics, &Histogram{ + name: mf.GetName(), + tags: tags, + bucketBounds: bounds, + bucketValues: values, + }) + default: + return nil, fmt.Errorf("unsupported type: %s", mf.Type.String()) + } + } + + return s, nil +} diff --git a/library/go/core/metrics/solomon/counter.go b/library/go/core/metrics/solomon/counter.go new file mode 100644 index 0000000000..64ea1b47ca --- /dev/null +++ b/library/go/core/metrics/solomon/counter.go @@ -0,0 +1,98 @@ +package solomon + +import ( + "encoding/json" + "time" + + "go.uber.org/atomic" + + "a.yandex-team.ru/library/go/core/metrics" +) + +var ( + _ metrics.Counter = (*Counter)(nil) + _ Metric = (*Counter)(nil) +) + +// Counter tracks monotonically increasing value. +type Counter struct { + name string + metricType metricType + tags map[string]string + value atomic.Int64 + timestamp *time.Time + + useNameTag bool +} + +// Inc increments counter by 1. +func (c *Counter) Inc() { + c.Add(1) +} + +// Add adds delta to the counter. Delta must be >=0. +func (c *Counter) Add(delta int64) { + c.value.Add(delta) +} + +func (c *Counter) Name() string { + return c.name +} + +func (c *Counter) getType() metricType { + return c.metricType +} + +func (c *Counter) getLabels() map[string]string { + return c.tags +} + +func (c *Counter) getValue() interface{} { + return c.value.Load() +} + +func (c *Counter) getTimestamp() *time.Time { + return c.timestamp +} + +func (c *Counter) getNameTag() string { + if c.useNameTag { + return "name" + } else { + return "sensor" + } +} + +// MarshalJSON implements json.Marshaler. +func (c *Counter) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Value int64 `json:"value"` + Timestamp *int64 `json:"ts,omitempty"` + }{ + Type: c.metricType.String(), + Value: c.value.Load(), + Labels: func() map[string]string { + labels := make(map[string]string, len(c.tags)+1) + labels[c.getNameTag()] = c.Name() + for k, v := range c.tags { + labels[k] = v + } + return labels + }(), + Timestamp: tsAsRef(c.timestamp), + }) +} + +// Snapshot returns independent copy on metric. +func (c *Counter) Snapshot() Metric { + return &Counter{ + name: c.name, + metricType: c.metricType, + tags: c.tags, + value: *atomic.NewInt64(c.value.Load()), + + useNameTag: c.useNameTag, + } +} diff --git a/library/go/core/metrics/solomon/func_counter.go b/library/go/core/metrics/solomon/func_counter.go new file mode 100644 index 0000000000..db862869e4 --- /dev/null +++ b/library/go/core/metrics/solomon/func_counter.go @@ -0,0 +1,86 @@ +package solomon + +import ( + "encoding/json" + "time" + + "go.uber.org/atomic" +) + +var _ Metric = (*FuncCounter)(nil) + +// FuncCounter tracks int64 value returned by function. +type FuncCounter struct { + name string + metricType metricType + tags map[string]string + function func() int64 + timestamp *time.Time + useNameTag bool +} + +func (c *FuncCounter) Name() string { + return c.name +} + +func (c *FuncCounter) Function() func() int64 { + return c.function +} + +func (c *FuncCounter) getType() metricType { + return c.metricType +} + +func (c *FuncCounter) getLabels() map[string]string { + return c.tags +} + +func (c *FuncCounter) getValue() interface{} { + return c.function() +} + +func (c *FuncCounter) getTimestamp() *time.Time { + return c.timestamp +} + +func (c *FuncCounter) getNameTag() string { + if c.useNameTag { + return "name" + } else { + return "sensor" + } +} + +// MarshalJSON implements json.Marshaler. +func (c *FuncCounter) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Value int64 `json:"value"` + Timestamp *int64 `json:"ts,omitempty"` + }{ + Type: c.metricType.String(), + Value: c.function(), + Labels: func() map[string]string { + labels := make(map[string]string, len(c.tags)+1) + labels[c.getNameTag()] = c.Name() + for k, v := range c.tags { + labels[k] = v + } + return labels + }(), + Timestamp: tsAsRef(c.timestamp), + }) +} + +// Snapshot returns independent copy on metric. +func (c *FuncCounter) Snapshot() Metric { + return &Counter{ + name: c.name, + metricType: c.metricType, + tags: c.tags, + value: *atomic.NewInt64(c.function()), + + useNameTag: c.useNameTag, + } +} diff --git a/library/go/core/metrics/solomon/func_gauge.go b/library/go/core/metrics/solomon/func_gauge.go new file mode 100644 index 0000000000..ce824c6fa8 --- /dev/null +++ b/library/go/core/metrics/solomon/func_gauge.go @@ -0,0 +1,87 @@ +package solomon + +import ( + "encoding/json" + "time" + + "go.uber.org/atomic" +) + +var _ Metric = (*FuncGauge)(nil) + +// FuncGauge tracks float64 value returned by function. +type FuncGauge struct { + name string + metricType metricType + tags map[string]string + function func() float64 + timestamp *time.Time + + useNameTag bool +} + +func (g *FuncGauge) Name() string { + return g.name +} + +func (g *FuncGauge) Function() func() float64 { + return g.function +} + +func (g *FuncGauge) getType() metricType { + return g.metricType +} + +func (g *FuncGauge) getLabels() map[string]string { + return g.tags +} + +func (g *FuncGauge) getValue() interface{} { + return g.function() +} + +func (g *FuncGauge) getTimestamp() *time.Time { + return g.timestamp +} + +func (g *FuncGauge) getNameTag() string { + if g.useNameTag { + return "name" + } else { + return "sensor" + } +} + +// MarshalJSON implements json.Marshaler. +func (g *FuncGauge) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Value float64 `json:"value"` + Timestamp *int64 `json:"ts,omitempty"` + }{ + Type: g.metricType.String(), + Value: g.function(), + Labels: func() map[string]string { + labels := make(map[string]string, len(g.tags)+1) + labels[g.getNameTag()] = g.Name() + for k, v := range g.tags { + labels[k] = v + } + return labels + }(), + Timestamp: tsAsRef(g.timestamp), + }) +} + +// Snapshot returns independent copy on metric. +func (g *FuncGauge) Snapshot() Metric { + return &Gauge{ + name: g.name, + metricType: g.metricType, + tags: g.tags, + value: *atomic.NewFloat64(g.function()), + + useNameTag: g.useNameTag, + } +} diff --git a/library/go/core/metrics/solomon/gauge.go b/library/go/core/metrics/solomon/gauge.go new file mode 100644 index 0000000000..4d7e17195d --- /dev/null +++ b/library/go/core/metrics/solomon/gauge.go @@ -0,0 +1,116 @@ +package solomon + +import ( + "encoding/json" + "time" + + "go.uber.org/atomic" + + "a.yandex-team.ru/library/go/core/metrics" +) + +var ( + _ metrics.Gauge = (*Gauge)(nil) + _ Metric = (*Gauge)(nil) +) + +// Gauge tracks single float64 value. +type Gauge struct { + name string + metricType metricType + tags map[string]string + value atomic.Float64 + timestamp *time.Time + + useNameTag bool +} + +func NewGauge(name string, value float64, opts ...metricOpts) Gauge { + mOpts := MetricsOpts{} + for _, op := range opts { + op(&mOpts) + } + return Gauge{ + name: name, + metricType: typeGauge, + tags: mOpts.tags, + value: *atomic.NewFloat64(value), + useNameTag: mOpts.useNameTag, + timestamp: mOpts.timestamp, + } +} + +func (g *Gauge) Set(value float64) { + g.value.Store(value) +} + +func (g *Gauge) Add(value float64) { + g.value.Add(value) +} + +func (g *Gauge) Name() string { + return g.name +} + +func (g *Gauge) getType() metricType { + return g.metricType +} + +func (g *Gauge) getLabels() map[string]string { + return g.tags +} + +func (g *Gauge) getValue() interface{} { + return g.value.Load() +} + +func (g *Gauge) getTimestamp() *time.Time { + return g.timestamp +} + +func (g *Gauge) getNameTag() string { + if g.useNameTag { + return "name" + } else { + return "sensor" + } +} + +// MarshalJSON implements json.Marshaler. +func (g *Gauge) MarshalJSON() ([]byte, error) { + metricType := g.metricType.String() + value := g.value.Load() + labels := func() map[string]string { + labels := make(map[string]string, len(g.tags)+1) + labels[g.getNameTag()] = g.Name() + for k, v := range g.tags { + labels[k] = v + } + return labels + }() + + return json.Marshal(struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Value float64 `json:"value"` + Timestamp *int64 `json:"ts,omitempty"` + }{ + Type: metricType, + Value: value, + Labels: labels, + Timestamp: tsAsRef(g.timestamp), + }) +} + +// Snapshot returns independent copy of metric. +func (g *Gauge) Snapshot() Metric { + return &Gauge{ + name: g.name, + metricType: g.metricType, + tags: g.tags, + value: *atomic.NewFloat64(g.value.Load()), + + useNameTag: g.useNameTag, + timestamp: g.timestamp, + } +} diff --git a/library/go/core/metrics/solomon/histogram.go b/library/go/core/metrics/solomon/histogram.go new file mode 100644 index 0000000000..574aeaccf6 --- /dev/null +++ b/library/go/core/metrics/solomon/histogram.go @@ -0,0 +1,177 @@ +package solomon + +import ( + "encoding/binary" + "encoding/json" + "io" + "sort" + "sync" + "time" + + "go.uber.org/atomic" + + "a.yandex-team.ru/library/go/core/metrics" + "a.yandex-team.ru/library/go/core/xerrors" +) + +var ( + _ metrics.Histogram = (*Histogram)(nil) + _ metrics.Timer = (*Histogram)(nil) + _ Metric = (*Histogram)(nil) +) + +type Histogram struct { + name string + metricType metricType + tags map[string]string + bucketBounds []float64 + bucketValues []int64 + infValue atomic.Int64 + mutex sync.Mutex + timestamp *time.Time + useNameTag bool +} + +type histogram struct { + Bounds []float64 `json:"bounds"` + Buckets []int64 `json:"buckets"` + Inf int64 `json:"inf,omitempty"` +} + +func (h *histogram) writeHistogram(w io.Writer) error { + err := writeULEB128(w, uint32(len(h.Buckets))) + if err != nil { + return xerrors.Errorf("writeULEB128 size histogram buckets failed: %w", err) + } + + for _, upperBound := range h.Bounds { + err = binary.Write(w, binary.LittleEndian, float64(upperBound)) + if err != nil { + return xerrors.Errorf("binary.Write upper bound failed: %w", err) + } + } + + for _, bucketValue := range h.Buckets { + err = binary.Write(w, binary.LittleEndian, uint64(bucketValue)) + if err != nil { + return xerrors.Errorf("binary.Write histogram buckets failed: %w", err) + } + } + return nil +} + +func (h *Histogram) RecordValue(value float64) { + boundIndex := sort.SearchFloat64s(h.bucketBounds, value) + + if boundIndex < len(h.bucketValues) { + h.mutex.Lock() + h.bucketValues[boundIndex] += 1 + h.mutex.Unlock() + } else { + h.infValue.Inc() + } +} + +func (h *Histogram) RecordDuration(value time.Duration) { + h.RecordValue(value.Seconds()) +} + +func (h *Histogram) Reset() { + h.mutex.Lock() + defer h.mutex.Unlock() + + h.bucketValues = make([]int64, len(h.bucketValues)) + h.infValue.Store(0) +} + +func (h *Histogram) Name() string { + return h.name +} + +func (h *Histogram) getType() metricType { + return h.metricType +} + +func (h *Histogram) getLabels() map[string]string { + return h.tags +} + +func (h *Histogram) getValue() interface{} { + return histogram{ + Bounds: h.bucketBounds, + Buckets: h.bucketValues, + } +} + +func (h *Histogram) getTimestamp() *time.Time { + return h.timestamp +} + +func (h *Histogram) getNameTag() string { + if h.useNameTag { + return "name" + } else { + return "sensor" + } +} + +// MarshalJSON implements json.Marshaler. +func (h *Histogram) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Histogram histogram `json:"hist"` + Timestamp *int64 `json:"ts,omitempty"` + }{ + Type: h.metricType.String(), + Histogram: histogram{ + Bounds: h.bucketBounds, + Buckets: h.bucketValues, + Inf: h.infValue.Load(), + }, + Labels: func() map[string]string { + labels := make(map[string]string, len(h.tags)+1) + labels[h.getNameTag()] = h.Name() + for k, v := range h.tags { + labels[k] = v + } + return labels + }(), + Timestamp: tsAsRef(h.timestamp), + }) +} + +// Snapshot returns independent copy on metric. +func (h *Histogram) Snapshot() Metric { + bucketBounds := make([]float64, len(h.bucketBounds)) + bucketValues := make([]int64, len(h.bucketValues)) + + copy(bucketBounds, h.bucketBounds) + copy(bucketValues, h.bucketValues) + + return &Histogram{ + name: h.name, + metricType: h.metricType, + tags: h.tags, + bucketBounds: bucketBounds, + bucketValues: bucketValues, + infValue: *atomic.NewInt64(h.infValue.Load()), + useNameTag: h.useNameTag, + } +} + +// InitBucketValues cleans internal bucketValues and saves new values in order. +// Length of internal bucketValues stays unchanged. +// If length of slice in argument bucketValues more than length of internal one, +// the first extra element of bucketValues is stored in infValue. +func (h *Histogram) InitBucketValues(bucketValues []int64) { + h.mutex.Lock() + defer h.mutex.Unlock() + + h.bucketValues = make([]int64, len(h.bucketValues)) + h.infValue.Store(0) + copy(h.bucketValues, bucketValues) + if len(bucketValues) > len(h.bucketValues) { + h.infValue.Store(bucketValues[len(h.bucketValues)]) + } +} diff --git a/library/go/core/metrics/solomon/metrics.go b/library/go/core/metrics/solomon/metrics.go new file mode 100644 index 0000000000..7f4bf4b5ec --- /dev/null +++ b/library/go/core/metrics/solomon/metrics.go @@ -0,0 +1,178 @@ +package solomon + +import ( + "bytes" + "context" + "encoding" + "encoding/json" + "fmt" + "time" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +// Gather collects all metrics data via snapshots. +func (r Registry) Gather() (*Metrics, error) { + metrics := make([]Metric, 0) + + var err error + r.metrics.Range(func(_, v interface{}) bool { + if s, ok := v.(Metric); ok { + metrics = append(metrics, s.Snapshot()) + return true + } + err = fmt.Errorf("unexpected value type: %T", v) + return false + }) + + if err != nil { + return nil, err + } + + return &Metrics{metrics: metrics}, nil +} + +func NewMetrics(metrics []Metric) Metrics { + return Metrics{metrics: metrics} +} + +func NewMetricsWithTimestamp(metrics []Metric, ts time.Time) Metrics { + return Metrics{metrics: metrics, timestamp: &ts} +} + +type valueType uint8 + +const ( + valueTypeNone valueType = iota + valueTypeOneWithoutTS valueType = 0x01 + valueTypeOneWithTS valueType = 0x02 + valueTypeManyWithTS valueType = 0x03 +) + +type metricType uint8 + +const ( + typeUnspecified metricType = iota + typeGauge metricType = 0x01 + typeCounter metricType = 0x02 + typeRated metricType = 0x03 + typeHistogram metricType = 0x05 + typeRatedHistogram metricType = 0x06 +) + +func (k metricType) String() string { + switch k { + case typeCounter: + return "COUNTER" + case typeGauge: + return "DGAUGE" + case typeHistogram: + return "HIST" + case typeRated: + return "RATE" + case typeRatedHistogram: + return "HIST_RATE" + default: + panic("unknown metric type") + } +} + +// Metric is an any abstract solomon Metric. +type Metric interface { + json.Marshaler + + Name() string + getType() metricType + getLabels() map[string]string + getValue() interface{} + getNameTag() string + getTimestamp() *time.Time + + Snapshot() Metric +} + +// Rated marks given Solomon metric or vector as rated. +// Example: +// +// cnt := r.Counter("mycounter") +// Rated(cnt) +// +// cntvec := r.CounterVec("mycounter", []string{"mytag"}) +// Rated(cntvec) +// +// For additional info: https://docs.yandex-team.ru/solomon/data-collection/dataformat/json +func Rated(s interface{}) { + switch st := s.(type) { + case *Counter: + st.metricType = typeRated + case *FuncCounter: + st.metricType = typeRated + case *Histogram: + st.metricType = typeRatedHistogram + + case *CounterVec: + st.vec.rated = true + case *HistogramVec: + st.vec.rated = true + case *DurationHistogramVec: + st.vec.rated = true + } + // any other metrics types are unrateable +} + +var ( + _ json.Marshaler = (*Metrics)(nil) + _ encoding.BinaryMarshaler = (*Metrics)(nil) +) + +type Metrics struct { + metrics []Metric + timestamp *time.Time +} + +// MarshalJSON implements json.Marshaler. +func (s Metrics) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Metrics []Metric `json:"metrics"` + Timestamp *int64 `json:"ts,omitempty"` + }{s.metrics, tsAsRef(s.timestamp)}) +} + +// MarshalBinary implements encoding.BinaryMarshaler. +func (s Metrics) MarshalBinary() ([]byte, error) { + var buf bytes.Buffer + se := NewSpackEncoder(context.Background(), CompressionNone, &s) + n, err := se.Encode(&buf) + if err != nil { + return nil, xerrors.Errorf("encode only %d bytes: %w", n, err) + } + return buf.Bytes(), nil +} + +// SplitToChunks splits Metrics into a slice of chunks, each at most maxChunkSize long. +// The length of returned slice is always at least one. +// Zero maxChunkSize denotes unlimited chunk length. +func (s Metrics) SplitToChunks(maxChunkSize int) []Metrics { + if maxChunkSize == 0 || len(s.metrics) == 0 { + return []Metrics{s} + } + chunks := make([]Metrics, 0, len(s.metrics)/maxChunkSize+1) + + for leftBound := 0; leftBound < len(s.metrics); leftBound += maxChunkSize { + rightBound := leftBound + maxChunkSize + if rightBound > len(s.metrics) { + rightBound = len(s.metrics) + } + chunk := s.metrics[leftBound:rightBound] + chunks = append(chunks, Metrics{metrics: chunk}) + } + return chunks +} + +func tsAsRef(t *time.Time) *int64 { + if t == nil { + return nil + } + ts := t.Unix() + return &ts +} diff --git a/library/go/core/metrics/solomon/metrics_opts.go b/library/go/core/metrics/solomon/metrics_opts.go new file mode 100644 index 0000000000..d9ade67966 --- /dev/null +++ b/library/go/core/metrics/solomon/metrics_opts.go @@ -0,0 +1,29 @@ +package solomon + +import "time" + +type MetricsOpts struct { + useNameTag bool + tags map[string]string + timestamp *time.Time +} + +type metricOpts func(*MetricsOpts) + +func WithTags(tags map[string]string) func(*MetricsOpts) { + return func(m *MetricsOpts) { + m.tags = tags + } +} + +func WithUseNameTag() func(*MetricsOpts) { + return func(m *MetricsOpts) { + m.useNameTag = true + } +} + +func WithTimestamp(t time.Time) func(*MetricsOpts) { + return func(m *MetricsOpts) { + m.timestamp = &t + } +} diff --git a/library/go/core/metrics/solomon/registry.go b/library/go/core/metrics/solomon/registry.go new file mode 100644 index 0000000000..cdd7489843 --- /dev/null +++ b/library/go/core/metrics/solomon/registry.go @@ -0,0 +1,221 @@ +package solomon + +import ( + "reflect" + "strconv" + "sync" + + "a.yandex-team.ru/library/go/core/metrics" + "a.yandex-team.ru/library/go/core/metrics/internal/pkg/metricsutil" + "a.yandex-team.ru/library/go/core/metrics/internal/pkg/registryutil" +) + +var _ metrics.Registry = (*Registry)(nil) + +type Registry struct { + separator string + prefix string + tags map[string]string + rated bool + useNameTag bool + + subregistries map[string]*Registry + m *sync.Mutex + + metrics *sync.Map +} + +func NewRegistry(opts *RegistryOpts) *Registry { + r := &Registry{ + separator: ".", + useNameTag: false, + + subregistries: make(map[string]*Registry), + m: new(sync.Mutex), + + metrics: new(sync.Map), + } + + if opts != nil { + r.separator = string(opts.Separator) + r.prefix = opts.Prefix + r.tags = opts.Tags + r.rated = opts.Rated + r.useNameTag = opts.UseNameTag + for _, collector := range opts.Collectors { + collector(r) + } + } + + return r +} + +// Rated returns copy of registry with rated set to desired value. +func (r Registry) Rated(rated bool) metrics.Registry { + return &Registry{ + separator: r.separator, + prefix: r.prefix, + tags: r.tags, + rated: rated, + useNameTag: r.useNameTag, + + subregistries: r.subregistries, + m: r.m, + + metrics: r.metrics, + } +} + +// WithTags creates new sub-scope, where each metric has tags attached to it. +func (r Registry) WithTags(tags map[string]string) metrics.Registry { + return r.newSubregistry(r.prefix, registryutil.MergeTags(r.tags, tags)) +} + +// WithPrefix creates new sub-scope, where each metric has prefix added to it name. +func (r Registry) WithPrefix(prefix string) metrics.Registry { + return r.newSubregistry(registryutil.BuildFQName(r.separator, r.prefix, prefix), r.tags) +} + +// ComposeName builds FQ name with appropriate separator. +func (r Registry) ComposeName(parts ...string) string { + return registryutil.BuildFQName(r.separator, parts...) +} + +func (r Registry) Counter(name string) metrics.Counter { + s := &Counter{ + name: r.newMetricName(name), + metricType: typeCounter, + tags: r.tags, + + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.Counter) +} + +func (r Registry) FuncCounter(name string, function func() int64) metrics.FuncCounter { + s := &FuncCounter{ + name: r.newMetricName(name), + metricType: typeCounter, + tags: r.tags, + function: function, + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.FuncCounter) +} + +func (r Registry) Gauge(name string) metrics.Gauge { + s := &Gauge{ + name: r.newMetricName(name), + metricType: typeGauge, + tags: r.tags, + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.Gauge) +} + +func (r Registry) FuncGauge(name string, function func() float64) metrics.FuncGauge { + s := &FuncGauge{ + name: r.newMetricName(name), + metricType: typeGauge, + tags: r.tags, + function: function, + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.FuncGauge) +} + +func (r Registry) Timer(name string) metrics.Timer { + s := &Timer{ + name: r.newMetricName(name), + metricType: typeGauge, + tags: r.tags, + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.Timer) +} + +func (r Registry) Histogram(name string, buckets metrics.Buckets) metrics.Histogram { + s := &Histogram{ + name: r.newMetricName(name), + metricType: typeHistogram, + tags: r.tags, + bucketBounds: metricsutil.BucketsBounds(buckets), + bucketValues: make([]int64, buckets.Size()), + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.Histogram) +} + +func (r Registry) DurationHistogram(name string, buckets metrics.DurationBuckets) metrics.Timer { + s := &Histogram{ + name: r.newMetricName(name), + metricType: typeHistogram, + tags: r.tags, + bucketBounds: metricsutil.DurationBucketsBounds(buckets), + bucketValues: make([]int64, buckets.Size()), + useNameTag: r.useNameTag, + } + + return r.registerMetric(s).(metrics.Timer) +} + +func (r *Registry) newSubregistry(prefix string, tags map[string]string) *Registry { + // differ simple and rated registries + keyTags := registryutil.MergeTags(tags, map[string]string{"rated": strconv.FormatBool(r.rated)}) + registryKey := registryutil.BuildRegistryKey(prefix, keyTags) + + r.m.Lock() + defer r.m.Unlock() + + if existing, ok := r.subregistries[registryKey]; ok { + return existing + } + + subregistry := &Registry{ + separator: r.separator, + prefix: prefix, + tags: tags, + rated: r.rated, + useNameTag: r.useNameTag, + + subregistries: r.subregistries, + m: r.m, + + metrics: r.metrics, + } + + r.subregistries[registryKey] = subregistry + return subregistry +} + +func (r *Registry) newMetricName(name string) string { + return registryutil.BuildFQName(r.separator, r.prefix, name) +} + +func (r *Registry) registerMetric(s Metric) Metric { + if r.rated { + Rated(s) + } + + // differ simple and rated registries + keyTags := registryutil.MergeTags(r.tags, map[string]string{"rated": strconv.FormatBool(r.rated)}) + key := registryutil.BuildRegistryKey(s.Name(), keyTags) + + oldMetric, loaded := r.metrics.LoadOrStore(key, s) + if !loaded { + return s + } + + if reflect.TypeOf(oldMetric) == reflect.TypeOf(s) { + return oldMetric.(Metric) + } else { + r.metrics.Store(key, s) + return s + } +} diff --git a/library/go/core/metrics/solomon/registry_opts.go b/library/go/core/metrics/solomon/registry_opts.go new file mode 100644 index 0000000000..d2d19718ee --- /dev/null +++ b/library/go/core/metrics/solomon/registry_opts.go @@ -0,0 +1,87 @@ +package solomon + +import ( + "context" + + "a.yandex-team.ru/library/go/core/metrics" + "a.yandex-team.ru/library/go/core/metrics/collect" + "a.yandex-team.ru/library/go/core/metrics/internal/pkg/registryutil" +) + +type RegistryOpts struct { + Separator rune + Prefix string + Tags map[string]string + Rated bool + UseNameTag bool + Collectors []func(metrics.Registry) +} + +// NewRegistryOpts returns new initialized instance of RegistryOpts +func NewRegistryOpts() *RegistryOpts { + return &RegistryOpts{ + Separator: '.', + Tags: make(map[string]string), + UseNameTag: false, + } +} + +// SetUseNameTag overrides current UseNameTag opt +func (o *RegistryOpts) SetUseNameTag(useNameTag bool) *RegistryOpts { + o.UseNameTag = useNameTag + return o +} + +// SetTags overrides existing tags +func (o *RegistryOpts) SetTags(tags map[string]string) *RegistryOpts { + o.Tags = tags + return o +} + +// AddTags merges given tags with existing +func (o *RegistryOpts) AddTags(tags map[string]string) *RegistryOpts { + for k, v := range tags { + o.Tags[k] = v + } + return o +} + +// SetPrefix overrides existing prefix +func (o *RegistryOpts) SetPrefix(prefix string) *RegistryOpts { + o.Prefix = prefix + return o +} + +// AppendPrefix adds given prefix as postfix to existing using separator +func (o *RegistryOpts) AppendPrefix(prefix string) *RegistryOpts { + o.Prefix = registryutil.BuildFQName(string(o.Separator), o.Prefix, prefix) + return o +} + +// SetSeparator overrides existing separator +func (o *RegistryOpts) SetSeparator(separator rune) *RegistryOpts { + o.Separator = separator + return o +} + +// SetRated overrides existing rated flag +func (o *RegistryOpts) SetRated(rated bool) *RegistryOpts { + o.Rated = rated + return o +} + +// AddCollectors adds collectors that handle their metrics automatically (e.g. system metrics). +func (o *RegistryOpts) AddCollectors( + ctx context.Context, c metrics.CollectPolicy, collectors ...collect.Func, +) *RegistryOpts { + if len(collectors) == 0 { + return o + } + + o.Collectors = append(o.Collectors, func(r metrics.Registry) { + for _, collector := range collectors { + collector(ctx, r, c) + } + }) + return o +} diff --git a/library/go/core/metrics/solomon/spack.go b/library/go/core/metrics/solomon/spack.go new file mode 100644 index 0000000000..9dc0434716 --- /dev/null +++ b/library/go/core/metrics/solomon/spack.go @@ -0,0 +1,340 @@ +package solomon + +import ( + "bytes" + "context" + "encoding/binary" + "io" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +type errWriter struct { + w io.Writer + err error +} + +func (ew *errWriter) binaryWrite(data interface{}) { + if ew.err != nil { + return + } + switch t := data.(type) { + case uint8: + ew.err = binary.Write(ew.w, binary.LittleEndian, data.(uint8)) + case uint16: + ew.err = binary.Write(ew.w, binary.LittleEndian, data.(uint16)) + case uint32: + ew.err = binary.Write(ew.w, binary.LittleEndian, data.(uint32)) + default: + ew.err = xerrors.Errorf("binaryWrite not supported type %v", t) + } +} + +func writeULEB128(w io.Writer, value uint32) error { + remaining := value >> 7 + for remaining != 0 { + err := binary.Write(w, binary.LittleEndian, uint8(value&0x7f|0x80)) + if err != nil { + return xerrors.Errorf("binary.Write failed: %w", err) + } + value = remaining + remaining >>= 7 + } + err := binary.Write(w, binary.LittleEndian, uint8(value&0x7f)) + if err != nil { + return xerrors.Errorf("binary.Write failed: %w", err) + } + return err +} + +type spackMetric struct { + flags uint8 + + labelsCount uint32 + labels bytes.Buffer + + metric Metric +} + +func (s *spackMetric) writeLabel(se *spackEncoder, namesIdx map[string]uint32, valuesIdx map[string]uint32, name string, value string) error { + s.labelsCount++ + + _, ok := namesIdx[name] + if !ok { + namesIdx[name] = se.nameCounter + se.nameCounter++ + _, err := se.labelNamePool.WriteString(name) + if err != nil { + return err + } + err = se.labelNamePool.WriteByte(0) + if err != nil { + return err + } + } + + _, ok = valuesIdx[value] + if !ok { + valuesIdx[value] = se.valueCounter + se.valueCounter++ + _, err := se.labelValuePool.WriteString(value) + if err != nil { + return err + } + err = se.labelValuePool.WriteByte(0) + if err != nil { + return err + } + } + + err := writeULEB128(&s.labels, uint32(namesIdx[name])) + if err != nil { + return err + } + err = writeULEB128(&s.labels, uint32(valuesIdx[value])) + if err != nil { + return err + } + + return nil +} + +func (s *spackMetric) writeMetric(w io.Writer) error { + metricValueType := valueTypeOneWithoutTS + if s.metric.getTimestamp() != nil { + metricValueType = valueTypeOneWithTS + } + // library/cpp/monlib/encode/spack/spack_v1_encoder.cpp?rev=r9098142#L190 + types := uint8(s.metric.getType()<<2) | uint8(metricValueType) + err := binary.Write(w, binary.LittleEndian, types) + if err != nil { + return xerrors.Errorf("binary.Write types failed: %w", err) + } + + err = binary.Write(w, binary.LittleEndian, uint8(s.flags)) + if err != nil { + return xerrors.Errorf("binary.Write flags failed: %w", err) + } + + err = writeULEB128(w, uint32(s.labelsCount)) + if err != nil { + return xerrors.Errorf("writeULEB128 labels count failed: %w", err) + } + + _, err = w.Write(s.labels.Bytes()) // s.writeLabels(w) + if err != nil { + return xerrors.Errorf("write labels failed: %w", err) + } + if s.metric.getTimestamp() != nil { + err = binary.Write(w, binary.LittleEndian, uint32(s.metric.getTimestamp().Unix())) + if err != nil { + return xerrors.Errorf("write timestamp failed: %w", err) + } + } + + switch s.metric.getType() { + case typeGauge: + err = binary.Write(w, binary.LittleEndian, s.metric.getValue().(float64)) + if err != nil { + return xerrors.Errorf("binary.Write gauge value failed: %w", err) + } + case typeCounter, typeRated: + err = binary.Write(w, binary.LittleEndian, uint64(s.metric.getValue().(int64))) + if err != nil { + return xerrors.Errorf("binary.Write counter value failed: %w", err) + } + case typeHistogram, typeRatedHistogram: + h := s.metric.getValue().(histogram) + err = h.writeHistogram(w) + if err != nil { + return xerrors.Errorf("writeHistogram failed: %w", err) + } + default: + return xerrors.Errorf("unknown metric type: %v", s.metric.getType()) + } + return nil +} + +type spackEncoder struct { + context context.Context + compression uint8 + + nameCounter uint32 + valueCounter uint32 + + labelNamePool bytes.Buffer + labelValuePool bytes.Buffer + + metrics Metrics +} + +func NewSpackEncoder(ctx context.Context, compression CompressionType, metrics *Metrics) *spackEncoder { + if metrics == nil { + metrics = &Metrics{} + } + return &spackEncoder{ + context: ctx, + compression: uint8(compression), + metrics: *metrics, + } +} + +func (se *spackEncoder) writeLabels() ([]spackMetric, error) { + namesIdx := make(map[string]uint32) + valuesIdx := make(map[string]uint32) + spackMetrics := make([]spackMetric, len(se.metrics.metrics)) + + for idx, metric := range se.metrics.metrics { + m := spackMetric{metric: metric} + + err := m.writeLabel(se, namesIdx, valuesIdx, metric.getNameTag(), metric.Name()) + if err != nil { + return nil, err + } + + for name, value := range metric.getLabels() { + if err := m.writeLabel(se, namesIdx, valuesIdx, name, value); err != nil { + return nil, err + } + + } + spackMetrics[idx] = m + } + + return spackMetrics, nil +} + +func (se *spackEncoder) Encode(w io.Writer) (written int, err error) { + spackMetrics, err := se.writeLabels() + if err != nil { + return written, xerrors.Errorf("writeLabels failed: %w", err) + } + + err = se.writeHeader(w) + if err != nil { + return written, xerrors.Errorf("writeHeader failed: %w", err) + } + written += HeaderSize + compression := CompressionType(se.compression) + + cw := newCompressedWriter(w, compression) + + err = se.writeLabelNamesPool(cw) + if err != nil { + return written, xerrors.Errorf("writeLabelNamesPool failed: %w", err) + } + + err = se.writeLabelValuesPool(cw) + if err != nil { + return written, xerrors.Errorf("writeLabelValuesPool failed: %w", err) + } + + err = se.writeCommonTime(cw) + if err != nil { + return written, xerrors.Errorf("writeCommonTime failed: %w", err) + } + + err = se.writeCommonLabels(cw) + if err != nil { + return written, xerrors.Errorf("writeCommonLabels failed: %w", err) + } + + err = se.writeMetricsData(cw, spackMetrics) + if err != nil { + return written, xerrors.Errorf("writeMetricsData failed: %w", err) + } + + err = cw.Close() + if err != nil { + return written, xerrors.Errorf("close failed: %w", err) + } + + switch compression { + case CompressionNone: + written += cw.(*noCompressionWriteCloser).written + case CompressionLz4: + written += cw.(*lz4CompressionWriteCloser).written + } + + return written, nil +} + +func (se *spackEncoder) writeHeader(w io.Writer) error { + if se.context.Err() != nil { + return xerrors.Errorf("streamSpack context error: %w", se.context.Err()) + } + ew := &errWriter{w: w} + ew.binaryWrite(uint16(0x5053)) // Magic + ew.binaryWrite(uint16(0x0101)) // Version + ew.binaryWrite(uint16(24)) // HeaderSize + ew.binaryWrite(uint8(0)) // TimePrecision(SECONDS) + ew.binaryWrite(uint8(se.compression)) // CompressionAlg + ew.binaryWrite(uint32(se.labelNamePool.Len())) // LabelNamesSize + ew.binaryWrite(uint32(se.labelValuePool.Len())) // LabelValuesSize + ew.binaryWrite(uint32(len(se.metrics.metrics))) // MetricsCount + ew.binaryWrite(uint32(len(se.metrics.metrics))) // PointsCount + if ew.err != nil { + return xerrors.Errorf("binaryWrite failed: %w", ew.err) + } + return nil +} + +func (se *spackEncoder) writeLabelNamesPool(w io.Writer) error { + if se.context.Err() != nil { + return xerrors.Errorf("streamSpack context error: %w", se.context.Err()) + } + _, err := w.Write(se.labelNamePool.Bytes()) + if err != nil { + return xerrors.Errorf("write labelNamePool failed: %w", err) + } + return nil +} + +func (se *spackEncoder) writeLabelValuesPool(w io.Writer) error { + if se.context.Err() != nil { + return xerrors.Errorf("streamSpack context error: %w", se.context.Err()) + } + + _, err := w.Write(se.labelValuePool.Bytes()) + if err != nil { + return xerrors.Errorf("write labelValuePool failed: %w", err) + } + return nil +} + +func (se *spackEncoder) writeCommonTime(w io.Writer) error { + if se.context.Err() != nil { + return xerrors.Errorf("streamSpack context error: %w", se.context.Err()) + } + + if se.metrics.timestamp == nil { + return binary.Write(w, binary.LittleEndian, uint32(0)) + } + return binary.Write(w, binary.LittleEndian, uint32(se.metrics.timestamp.Unix())) +} + +func (se *spackEncoder) writeCommonLabels(w io.Writer) error { + if se.context.Err() != nil { + return xerrors.Errorf("streamSpack context error: %w", se.context.Err()) + } + + _, err := w.Write([]byte{0}) + if err != nil { + return xerrors.Errorf("write commonLabels failed: %w", err) + } + return nil +} + +func (se *spackEncoder) writeMetricsData(w io.Writer, metrics []spackMetric) error { + for _, s := range metrics { + if se.context.Err() != nil { + return xerrors.Errorf("streamSpack context error: %w", se.context.Err()) + } + + err := s.writeMetric(w) + if err != nil { + return xerrors.Errorf("write metric failed: %w", err) + } + } + return nil +} diff --git a/library/go/core/metrics/solomon/spack_compression.go b/library/go/core/metrics/solomon/spack_compression.go new file mode 100644 index 0000000000..004fe0150d --- /dev/null +++ b/library/go/core/metrics/solomon/spack_compression.go @@ -0,0 +1,162 @@ +package solomon + +import ( + "encoding/binary" + "io" + + "github.com/OneOfOne/xxhash" + "github.com/pierrec/lz4" +) + +type CompressionType uint8 + +const ( + CompressionNone CompressionType = 0x0 + CompressionZlib CompressionType = 0x1 + CompressionZstd CompressionType = 0x2 + CompressionLz4 CompressionType = 0x3 +) + +const ( + compressionFrameLength = 512 * 1024 + hashTableSize = 64 * 1024 +) + +type noCompressionWriteCloser struct { + underlying io.Writer + written int +} + +func (w *noCompressionWriteCloser) Write(p []byte) (int, error) { + n, err := w.underlying.Write(p) + w.written += n + return n, err +} + +func (w *noCompressionWriteCloser) Close() error { + return nil +} + +type lz4CompressionWriteCloser struct { + underlying io.Writer + buffer []byte + table []int + written int +} + +func (w *lz4CompressionWriteCloser) flushFrame() (written int, err error) { + src := w.buffer + dst := make([]byte, lz4.CompressBlockBound(len(src))) + + sz, err := lz4.CompressBlock(src, dst, w.table) + if err != nil { + return written, err + } + + if sz == 0 { + dst = src + } else { + dst = dst[:sz] + } + + err = binary.Write(w.underlying, binary.LittleEndian, uint32(len(dst))) + if err != nil { + return written, err + } + w.written += 4 + + err = binary.Write(w.underlying, binary.LittleEndian, uint32(len(src))) + if err != nil { + return written, err + } + w.written += 4 + + n, err := w.underlying.Write(dst) + if err != nil { + return written, err + } + w.written += n + + checksum := xxhash.Checksum32S(dst, 0x1337c0de) + err = binary.Write(w.underlying, binary.LittleEndian, checksum) + if err != nil { + return written, err + } + w.written += 4 + + w.buffer = w.buffer[:0] + + return written, nil +} + +func (w *lz4CompressionWriteCloser) Write(p []byte) (written int, err error) { + q := p[:] + for len(q) > 0 { + space := compressionFrameLength - len(w.buffer) + if space == 0 { + n, err := w.flushFrame() + if err != nil { + return written, err + } + w.written += n + space = compressionFrameLength + } + length := len(q) + if length > space { + length = space + } + w.buffer = append(w.buffer, q[:length]...) + q = q[length:] + } + return written, nil +} + +func (w *lz4CompressionWriteCloser) Close() error { + var err error + if len(w.buffer) > 0 { + n, err := w.flushFrame() + if err != nil { + return err + } + w.written += n + } + err = binary.Write(w.underlying, binary.LittleEndian, uint32(0)) + if err != nil { + return nil + } + w.written += 4 + + err = binary.Write(w.underlying, binary.LittleEndian, uint32(0)) + if err != nil { + return nil + } + w.written += 4 + + err = binary.Write(w.underlying, binary.LittleEndian, uint32(0)) + if err != nil { + return nil + } + w.written += 4 + + return nil +} + +func newCompressedWriter(w io.Writer, compression CompressionType) io.WriteCloser { + switch compression { + case CompressionNone: + return &noCompressionWriteCloser{w, 0} + case CompressionZlib: + panic("zlib compression not supported") + case CompressionZstd: + panic("zstd compression not supported") + case CompressionLz4: + return &lz4CompressionWriteCloser{ + w, + make([]byte, 0, compressionFrameLength), + make([]int, hashTableSize), + 0, + } + default: + panic("unsupported compression algorithm") + } +} diff --git a/library/go/core/metrics/solomon/stream.go b/library/go/core/metrics/solomon/stream.go new file mode 100644 index 0000000000..26dc768c98 --- /dev/null +++ b/library/go/core/metrics/solomon/stream.go @@ -0,0 +1,89 @@ +package solomon + +import ( + "context" + "encoding/json" + "io" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +const HeaderSize = 24 + +type StreamFormat string + +func (r *Registry) StreamJSON(ctx context.Context, w io.Writer) (written int, err error) { + cw := newCompressedWriter(w, CompressionNone) + + if ctx.Err() != nil { + return written, xerrors.Errorf("streamJSON context error: %w", ctx.Err()) + } + _, err = cw.Write([]byte("{\"metrics\":[")) + if err != nil { + return written, xerrors.Errorf("write metrics failed: %w", err) + } + + first := true + r.metrics.Range(func(_, s interface{}) bool { + if ctx.Err() != nil { + err = xerrors.Errorf("streamJSON context error: %w", ctx.Err()) + return false + } + + // write trailing comma + if !first { + _, err = cw.Write([]byte(",")) + if err != nil { + err = xerrors.Errorf("write metrics failed: %w", err) + return false + } + } + + var b []byte + + b, err = json.Marshal(s) + if err != nil { + err = xerrors.Errorf("marshal metric failed: %w", err) + return false + } + + // write metric json + _, err = cw.Write(b) + if err != nil { + err = xerrors.Errorf("write metric failed: %w", err) + return false + } + + first = false + return true + }) + if err != nil { + return written, err + } + + if ctx.Err() != nil { + return written, xerrors.Errorf("streamJSON context error: %w", ctx.Err()) + } + _, err = cw.Write([]byte("]}")) + if err != nil { + return written, xerrors.Errorf("write metrics failed: %w", err) + } + + if ctx.Err() != nil { + return written, xerrors.Errorf("streamJSON context error: %w", ctx.Err()) + } + err = cw.Close() + if err != nil { + return written, xerrors.Errorf("close failed: %w", err) + } + + return cw.(*noCompressionWriteCloser).written, nil +} + +func (r *Registry) StreamSpack(ctx context.Context, w io.Writer, compression CompressionType) (int, error) { + metrics, err := r.Gather() + if err != nil { + return 0, err + } + return NewSpackEncoder(ctx, compression, metrics).Encode(w) +} diff --git a/library/go/core/metrics/solomon/timer.go b/library/go/core/metrics/solomon/timer.go new file mode 100644 index 0000000000..b26acc490b --- /dev/null +++ b/library/go/core/metrics/solomon/timer.go @@ -0,0 +1,92 @@ +package solomon + +import ( + "encoding/json" + "time" + + "go.uber.org/atomic" + + "a.yandex-team.ru/library/go/core/metrics" +) + +var ( + _ metrics.Timer = (*Timer)(nil) + _ Metric = (*Timer)(nil) +) + +// Timer measures gauge duration. +type Timer struct { + name string + metricType metricType + tags map[string]string + value atomic.Duration + timestamp *time.Time + + useNameTag bool +} + +func (t *Timer) RecordDuration(value time.Duration) { + t.value.Store(value) +} + +func (t *Timer) Name() string { + return t.name +} + +func (t *Timer) getType() metricType { + return t.metricType +} + +func (t *Timer) getLabels() map[string]string { + return t.tags +} + +func (t *Timer) getValue() interface{} { + return t.value.Load().Seconds() +} + +func (t *Timer) getTimestamp() *time.Time { + return t.timestamp +} + +func (t *Timer) getNameTag() string { + if t.useNameTag { + return "name" + } else { + return "sensor" + } +} + +// MarshalJSON implements json.Marshaler. +func (t *Timer) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string `json:"type"` + Labels map[string]string `json:"labels"` + Value float64 `json:"value"` + Timestamp *int64 `json:"ts,omitempty"` + }{ + Type: t.metricType.String(), + Value: t.value.Load().Seconds(), + Labels: func() map[string]string { + labels := make(map[string]string, len(t.tags)+1) + labels[t.getNameTag()] = t.Name() + for k, v := range t.tags { + labels[k] = v + } + return labels + }(), + Timestamp: tsAsRef(t.timestamp), + }) +} + +// Snapshot returns independent copy on metric. +func (t *Timer) Snapshot() Metric { + return &Timer{ + name: t.name, + metricType: t.metricType, + tags: t.tags, + value: *atomic.NewDuration(t.value.Load()), + + useNameTag: t.useNameTag, + } +} diff --git a/library/go/core/metrics/solomon/vec.go b/library/go/core/metrics/solomon/vec.go new file mode 100644 index 0000000000..a4d3ab1a83 --- /dev/null +++ b/library/go/core/metrics/solomon/vec.go @@ -0,0 +1,226 @@ +package solomon + +import ( + "sync" + + "a.yandex-team.ru/library/go/core/metrics" + "a.yandex-team.ru/library/go/core/metrics/internal/pkg/registryutil" +) + +// metricsVector is a base implementation of vector of metrics of any supported type. +type metricsVector struct { + labels []string + mtx sync.RWMutex // Protects metrics. + metrics map[uint64]Metric + rated bool + newMetric func(map[string]string) Metric +} + +func (v *metricsVector) with(tags map[string]string) Metric { + hv, err := registryutil.VectorHash(tags, v.labels) + if err != nil { + panic(err) + } + + v.mtx.RLock() + metric, ok := v.metrics[hv] + v.mtx.RUnlock() + if ok { + return metric + } + + v.mtx.Lock() + defer v.mtx.Unlock() + + metric, ok = v.metrics[hv] + if !ok { + metric = v.newMetric(tags) + v.metrics[hv] = metric + } + + return metric +} + +// reset deletes all metrics in this vector. +func (v *metricsVector) reset() { + v.mtx.Lock() + defer v.mtx.Unlock() + + for h := range v.metrics { + delete(v.metrics, h) + } +} + +var _ metrics.CounterVec = (*CounterVec)(nil) + +// CounterVec stores counters and +// implements metrics.CounterVec interface. +type CounterVec struct { + vec *metricsVector +} + +// CounterVec creates a new counters vector with given metric name and +// partitioned by the given label names. +func (r *Registry) CounterVec(name string, labels []string) metrics.CounterVec { + var vec *metricsVector + vec = &metricsVector{ + labels: append([]string(nil), labels...), + metrics: make(map[uint64]Metric), + rated: r.rated, + newMetric: func(tags map[string]string) Metric { + return r.Rated(vec.rated). + WithTags(tags). + Counter(name).(*Counter) + }, + } + return &CounterVec{vec: vec} +} + +// With creates new or returns existing counter with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *CounterVec) With(tags map[string]string) metrics.Counter { + return v.vec.with(tags).(*Counter) +} + +// Reset deletes all metrics in this vector. +func (v *CounterVec) Reset() { + v.vec.reset() +} + +var _ metrics.GaugeVec = (*GaugeVec)(nil) + +// GaugeVec stores gauges and +// implements metrics.GaugeVec interface. +type GaugeVec struct { + vec *metricsVector +} + +// GaugeVec creates a new gauges vector with given metric name and +// partitioned by the given label names. +func (r *Registry) GaugeVec(name string, labels []string) metrics.GaugeVec { + return &GaugeVec{ + vec: &metricsVector{ + labels: append([]string(nil), labels...), + metrics: make(map[uint64]Metric), + newMetric: func(tags map[string]string) Metric { + return r.WithTags(tags).Gauge(name).(*Gauge) + }, + }, + } +} + +// With creates new or returns existing gauge with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *GaugeVec) With(tags map[string]string) metrics.Gauge { + return v.vec.with(tags).(*Gauge) +} + +// Reset deletes all metrics in this vector. +func (v *GaugeVec) Reset() { + v.vec.reset() +} + +var _ metrics.TimerVec = (*TimerVec)(nil) + +// TimerVec stores timers and +// implements metrics.TimerVec interface. +type TimerVec struct { + vec *metricsVector +} + +// TimerVec creates a new timers vector with given metric name and +// partitioned by the given label names. +func (r *Registry) TimerVec(name string, labels []string) metrics.TimerVec { + return &TimerVec{ + vec: &metricsVector{ + labels: append([]string(nil), labels...), + metrics: make(map[uint64]Metric), + newMetric: func(tags map[string]string) Metric { + return r.WithTags(tags).Timer(name).(*Timer) + }, + }, + } +} + +// With creates new or returns existing timer with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *TimerVec) With(tags map[string]string) metrics.Timer { + return v.vec.with(tags).(*Timer) +} + +// Reset deletes all metrics in this vector. +func (v *TimerVec) Reset() { + v.vec.reset() +} + +var _ metrics.HistogramVec = (*HistogramVec)(nil) + +// HistogramVec stores histograms and +// implements metrics.HistogramVec interface. +type HistogramVec struct { + vec *metricsVector +} + +// HistogramVec creates a new histograms vector with given metric name and buckets and +// partitioned by the given label names. +func (r *Registry) HistogramVec(name string, buckets metrics.Buckets, labels []string) metrics.HistogramVec { + var vec *metricsVector + vec = &metricsVector{ + labels: append([]string(nil), labels...), + metrics: make(map[uint64]Metric), + rated: r.rated, + newMetric: func(tags map[string]string) Metric { + return r.Rated(vec.rated). + WithTags(tags). + Histogram(name, buckets).(*Histogram) + }, + } + return &HistogramVec{vec: vec} +} + +// With creates new or returns existing histogram with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *HistogramVec) With(tags map[string]string) metrics.Histogram { + return v.vec.with(tags).(*Histogram) +} + +// Reset deletes all metrics in this vector. +func (v *HistogramVec) Reset() { + v.vec.reset() +} + +var _ metrics.TimerVec = (*DurationHistogramVec)(nil) + +// DurationHistogramVec stores duration histograms and +// implements metrics.TimerVec interface. +type DurationHistogramVec struct { + vec *metricsVector +} + +// DurationHistogramVec creates a new duration histograms vector with given metric name and buckets and +// partitioned by the given label names. +func (r *Registry) DurationHistogramVec(name string, buckets metrics.DurationBuckets, labels []string) metrics.TimerVec { + var vec *metricsVector + vec = &metricsVector{ + labels: append([]string(nil), labels...), + metrics: make(map[uint64]Metric), + rated: r.rated, + newMetric: func(tags map[string]string) Metric { + return r.Rated(vec.rated). + WithTags(tags). + DurationHistogram(name, buckets).(*Histogram) + }, + } + return &DurationHistogramVec{vec: vec} +} + +// With creates new or returns existing duration histogram with given tags from vector. +// It will panic if tags keys set is not equal to vector labels. +func (v *DurationHistogramVec) With(tags map[string]string) metrics.Timer { + return v.vec.with(tags).(*Histogram) +} + +// Reset deletes all metrics in this vector. +func (v *DurationHistogramVec) Reset() { + v.vec.reset() +} diff --git a/library/go/core/resource/cc/main.go b/library/go/core/resource/cc/main.go new file mode 100644 index 0000000000..35687c65a6 --- /dev/null +++ b/library/go/core/resource/cc/main.go @@ -0,0 +1,92 @@ +package main + +import ( + "bufio" + "flag" + "fmt" + "io" + "io/ioutil" + "os" + "strings" +) + +func fatalf(msg string, args ...interface{}) { + _, _ = fmt.Fprintf(os.Stderr, msg+"\n", args...) + os.Exit(1) +} + +func generate(w io.Writer, pkg string, blobs [][]byte, keys []string) { + _, _ = fmt.Fprint(w, "// Code generated by a.yandex-team.ru/library/go/core/resource/cc DO NOT EDIT.\n") + _, _ = fmt.Fprintf(w, "package %s\n\n", pkg) + _, _ = fmt.Fprint(w, "import \"a.yandex-team.ru/library/go/core/resource\"\n") + + for i := 0; i < len(blobs); i++ { + blob := blobs[i] + + _, _ = fmt.Fprint(w, "\nfunc init() {\n") + + _, _ = fmt.Fprint(w, "\tblob := []byte(") + _, _ = fmt.Fprintf(w, "%+q", blob) + _, _ = fmt.Fprint(w, ")\n") + _, _ = fmt.Fprintf(w, "\tresource.InternalRegister(%q, blob)\n", keys[i]) + _, _ = fmt.Fprint(w, "}\n") + } +} + +func main() { + var pkg, output string + + flag.StringVar(&pkg, "package", "", "package name") + flag.StringVar(&output, "o", "", "output filename") + flag.Parse() + + if flag.NArg()%2 != 0 { + fatalf("cc: must provide even number of arguments") + } + + var keys []string + var blobs [][]byte + for i := 0; 2*i < flag.NArg(); i++ { + file := flag.Arg(2 * i) + key := flag.Arg(2*i + 1) + + if !strings.HasPrefix(key, "notafile") { + fatalf("cc: key argument must start with \"notafile\" string") + } + key = key[8:] + + if file == "-" { + parts := strings.SplitN(key, "=", 2) + if len(parts) != 2 { + fatalf("cc: invalid key syntax: %q", key) + } + + keys = append(keys, parts[0]) + blobs = append(blobs, []byte(parts[1])) + } else { + blob, err := ioutil.ReadFile(file) + if err != nil { + fatalf("cc: %v", err) + } + + keys = append(keys, key) + blobs = append(blobs, blob) + } + } + + f, err := os.Create(output) + if err != nil { + fatalf("cc: %v", err) + } + + b := bufio.NewWriter(f) + generate(b, pkg, blobs, keys) + + if err = b.Flush(); err != nil { + fatalf("cc: %v", err) + } + + if err = f.Close(); err != nil { + fatalf("cc: %v", err) + } +} diff --git a/library/go/core/resource/resource.go b/library/go/core/resource/resource.go new file mode 100644 index 0000000000..686ea73c3b --- /dev/null +++ b/library/go/core/resource/resource.go @@ -0,0 +1,56 @@ +// Package resource provides integration with RESOURCE and RESOURCE_FILES macros. +// +// Use RESOURCE macro to "link" file into the library or executable. +// +// RESOURCE(my_file.txt some_key) +// +// And then retrieve file content in the runtime. +// +// blob := resource.Get("some_key") +// +// Warning: Excessive consumption of resource leads to obesity. +package resource + +import ( + "fmt" + "sort" +) + +var resources = map[string][]byte{} + +// InternalRegister is private API used by generated code. +func InternalRegister(key string, blob []byte) { + if _, ok := resources[key]; ok { + panic(fmt.Sprintf("resource key %q is already defined", key)) + } + + resources[key] = blob +} + +// Get returns content of the file registered by the given key. +// +// If no file was registered for the given key, nil slice is returned. +// +// User should take care, to avoid mutating returned slice. +func Get(key string) []byte { + return resources[key] +} + +// MustGet is like Get, but panics when associated resource is not defined. +func MustGet(key string) []byte { + r, ok := resources[key] + if !ok { + panic(fmt.Sprintf("resource with key %q is not defined", key)) + } + return r +} + +// Keys returns sorted keys of all registered resources inside the binary +func Keys() []string { + keys := make([]string, 0, len(resources)) + for k := range resources { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} diff --git a/library/go/core/xerrors/doc.go b/library/go/core/xerrors/doc.go new file mode 100644 index 0000000000..de06dd15d2 --- /dev/null +++ b/library/go/core/xerrors/doc.go @@ -0,0 +1,2 @@ +// package xerrors is a drop in replacement for errors and golang.org/x/xerrors packages and functionally for github.com/pkg/errors. +package xerrors diff --git a/library/go/core/xerrors/errorf.go b/library/go/core/xerrors/errorf.go new file mode 100644 index 0000000000..de9e5248bb --- /dev/null +++ b/library/go/core/xerrors/errorf.go @@ -0,0 +1,88 @@ +package xerrors + +import ( + "fmt" + "io" + "strings" + + "a.yandex-team.ru/library/go/x/xruntime" +) + +type wrappedErrorf struct { + err error + stacktrace *xruntime.StackTrace +} + +var _ ErrorStackTrace = &wrappedErrorf{} + +func Errorf(format string, a ...interface{}) error { + err := fmt.Errorf(format, a...) + return &wrappedErrorf{ + err: err, + stacktrace: newStackTrace(1, err), + } +} + +func SkipErrorf(skip int, format string, a ...interface{}) error { + err := fmt.Errorf(format, a...) + return &wrappedErrorf{ + err: err, + stacktrace: newStackTrace(skip+1, err), + } +} + +func (e *wrappedErrorf) Format(s fmt.State, v rune) { + switch v { + case 'v': + if s.Flag('+') { + msg := e.err.Error() + inner := Unwrap(e.err) + // If Errorf wrapped another error then it will be our message' suffix. If so, cut it since otherwise we will + // print it again as part of formatting that error. + if inner != nil { + if strings.HasSuffix(msg, inner.Error()) { + msg = msg[:len(msg)-len(inner.Error())] + // Cut last space if needed but only if there is stacktrace present (very likely) + if e.stacktrace != nil && strings.HasSuffix(msg, ": ") { + msg = msg[:len(msg)-1] + } + } + } + + _, _ = io.WriteString(s, msg) + if e.stacktrace != nil { + // New line is useful only when printing frames, otherwise it is better to print next error in the chain + // right after we print this one + _, _ = io.WriteString(s, "\n") + writeStackTrace(s, e.stacktrace) + } + + // Print next error down the chain if there is one + if inner != nil { + _, _ = fmt.Fprintf(s, "%+v", inner) + } + + return + } + fallthrough + case 's': + _, _ = io.WriteString(s, e.err.Error()) + case 'q': + _, _ = fmt.Fprintf(s, "%q", e.err.Error()) + } +} + +func (e *wrappedErrorf) Error() string { + // Wrapped error has correct formatting + return e.err.Error() +} + +func (e *wrappedErrorf) Unwrap() error { + // Skip wrapped error and return whatever it is wrapping + // TODO: test for correct unwrap + return Unwrap(e.err) +} + +func (e *wrappedErrorf) StackTrace() *xruntime.StackTrace { + return e.stacktrace +} diff --git a/library/go/core/xerrors/forward.go b/library/go/core/xerrors/forward.go new file mode 100644 index 0000000000..aaa900133c --- /dev/null +++ b/library/go/core/xerrors/forward.go @@ -0,0 +1,56 @@ +package xerrors + +import "errors" + +// Unwrap returns the result of calling the Unwrap method on err, if err's +// type contains an Unwrap method returning error. +// Otherwise, Unwrap returns nil. +func Unwrap(err error) error { + return errors.Unwrap(err) +} + +// Is reports whether any error in err's chain matches target. +// +// The chain consists of err itself followed by the sequence of errors obtained by +// repeatedly calling Unwrap. +// +// An error is considered to match a target if it is equal to that target or if +// it implements a method Is(error) bool such that Is(target) returns true. +// +// An error type might provide an Is method so it can be treated as equivalent +// to an existing error. For example, if MyError defines +// +// func (m MyError) Is(target error) bool { return target == os.ErrExist } +// +// then Is(MyError{}, os.ErrExist) returns true. See syscall.Errno.Is for +// an example in the standard library. +func Is(err, target error) bool { + return errors.Is(err, target) +} + +// As finds the first error in err's chain that matches target, and if so, sets +// target to that error value and returns true. Otherwise, it returns false. +// +// The chain consists of err itself followed by the sequence of errors obtained by +// repeatedly calling Unwrap. +// +// An error matches target if the error's concrete value is assignable to the value +// pointed to by target, or if the error has a method As(interface{}) bool such that +// As(target) returns true. In the latter case, the As method is responsible for +// setting target. +// +// An error type might provide an As method so it can be treated as if it were a +// different error type. +// +// As panics if target is not a non-nil pointer to either a type that implements +// error, or to any interface type. +func As(err error, target interface{}) bool { + return errors.As(err, target) +} + +// Wrapper provides context around another error. +type Wrapper interface { + // Unwrap returns the next error in the error chain. + // If there is no next error, Unwrap returns nil. + Unwrap() error +} diff --git a/library/go/core/xerrors/internal/modes/stack_frames_count.go b/library/go/core/xerrors/internal/modes/stack_frames_count.go new file mode 100644 index 0000000000..c117becf6a --- /dev/null +++ b/library/go/core/xerrors/internal/modes/stack_frames_count.go @@ -0,0 +1,22 @@ +package modes + +import "sync/atomic" + +type StackFramesCount = int32 + +const ( + StackFramesCount16 StackFramesCount = 16 + StackFramesCount32 StackFramesCount = 32 + StackFramesCount64 StackFramesCount = 64 + StackFramesCount128 StackFramesCount = 128 +) + +var StackFramesCountMax = StackFramesCount32 + +func SetStackFramesCountMax(count StackFramesCount) { + atomic.StoreInt32(&StackFramesCountMax, count) +} + +func GetStackFramesCountMax() StackFramesCount { + return atomic.LoadInt32(&StackFramesCountMax) +} diff --git a/library/go/core/xerrors/internal/modes/stack_trace_mode.go b/library/go/core/xerrors/internal/modes/stack_trace_mode.go new file mode 100644 index 0000000000..04f78ffd3d --- /dev/null +++ b/library/go/core/xerrors/internal/modes/stack_trace_mode.go @@ -0,0 +1,48 @@ +package modes + +import "sync/atomic" + +type StackTraceMode int32 + +const ( + StackTraceModeFrames StackTraceMode = iota + StackTraceModeStacks + StackTraceModeStackThenFrames + StackTraceModeStackThenNothing + StackTraceModeNothing +) + +func (m StackTraceMode) String() string { + return []string{"Frames", "Stacks", "StackThenFrames", "StackThenNothing", "Nothing"}[m] +} + +const defaultStackTraceMode = StackTraceModeFrames + +var ( + // Default mode + stackTraceMode = defaultStackTraceMode + // Known modes (used in tests) + knownStackTraceModes = []StackTraceMode{ + StackTraceModeFrames, + StackTraceModeStacks, + StackTraceModeStackThenFrames, + StackTraceModeStackThenNothing, + StackTraceModeNothing, + } +) + +func SetStackTraceMode(v StackTraceMode) { + atomic.StoreInt32((*int32)(&stackTraceMode), int32(v)) +} + +func GetStackTraceMode() StackTraceMode { + return StackTraceMode(atomic.LoadInt32((*int32)(&stackTraceMode))) +} + +func DefaultStackTraceMode() { + SetStackTraceMode(defaultStackTraceMode) +} + +func KnownStackTraceModes() []StackTraceMode { + return knownStackTraceModes +} diff --git a/library/go/core/xerrors/mode.go b/library/go/core/xerrors/mode.go new file mode 100644 index 0000000000..e6051625b2 --- /dev/null +++ b/library/go/core/xerrors/mode.go @@ -0,0 +1,93 @@ +package xerrors + +import ( + "fmt" + + "a.yandex-team.ru/library/go/core/xerrors/internal/modes" + "a.yandex-team.ru/library/go/x/xruntime" +) + +func DefaultStackTraceMode() { + modes.DefaultStackTraceMode() +} + +func EnableFrames() { + modes.SetStackTraceMode(modes.StackTraceModeFrames) +} + +func EnableStacks() { + modes.SetStackTraceMode(modes.StackTraceModeStacks) +} + +func EnableStackThenFrames() { + modes.SetStackTraceMode(modes.StackTraceModeStackThenFrames) +} + +func EnableStackThenNothing() { + modes.SetStackTraceMode(modes.StackTraceModeStackThenNothing) +} + +func DisableStackTraces() { + modes.SetStackTraceMode(modes.StackTraceModeNothing) +} + +// newStackTrace returns stacktrace based on current mode and frames count +func newStackTrace(skip int, err error) *xruntime.StackTrace { + skip++ + m := modes.GetStackTraceMode() + switch m { + case modes.StackTraceModeFrames: + return xruntime.NewFrame(skip) + case modes.StackTraceModeStackThenFrames: + if err != nil && StackTraceOfEffect(err) != nil { + return xruntime.NewFrame(skip) + } + + return _newStackTrace(skip) + case modes.StackTraceModeStackThenNothing: + if err != nil && StackTraceOfEffect(err) != nil { + return nil + } + + return _newStackTrace(skip) + case modes.StackTraceModeStacks: + return _newStackTrace(skip) + case modes.StackTraceModeNothing: + return nil + } + + panic(fmt.Sprintf("unknown stack trace mode %d", m)) +} + +func MaxStackFrames16() { + modes.SetStackFramesCountMax(modes.StackFramesCount16) +} + +func MaxStackFrames32() { + modes.SetStackFramesCountMax(modes.StackFramesCount32) +} + +func MaxStackFrames64() { + modes.SetStackFramesCountMax(modes.StackFramesCount64) +} + +func MaxStackFrames128() { + modes.SetStackFramesCountMax(modes.StackFramesCount128) +} + +func _newStackTrace(skip int) *xruntime.StackTrace { + skip++ + count := modes.GetStackFramesCountMax() + switch count { + case 16: + return xruntime.NewStackTrace16(skip) + case 32: + return xruntime.NewStackTrace32(skip) + case 64: + return xruntime.NewStackTrace64(skip) + case 128: + return xruntime.NewStackTrace128(skip) + } + + panic(fmt.Sprintf("unknown stack frames count %d", count)) +} diff --git a/library/go/core/xerrors/new.go b/library/go/core/xerrors/new.go new file mode 100644 index 0000000000..e4ad213410 --- /dev/null +++ b/library/go/core/xerrors/new.go @@ -0,0 +1,48 @@ +package xerrors + +import ( + "fmt" + "io" + + "a.yandex-team.ru/library/go/x/xruntime" +) + +type newError struct { + msg string + stacktrace *xruntime.StackTrace +} + +var _ ErrorStackTrace = &newError{} + +func New(text string) error { + return &newError{ + msg: text, + stacktrace: newStackTrace(1, nil), + } +} + +func (e *newError) Error() string { + return e.msg +} + +func (e *newError) Format(s fmt.State, v rune) { + switch v { + case 'v': + if s.Flag('+') && e.stacktrace != nil { + _, _ = io.WriteString(s, e.msg) + _, _ = io.WriteString(s, "\n") + writeStackTrace(s, e.stacktrace) + return + } + + fallthrough + case 's': + _, _ = io.WriteString(s, e.msg) + case 'q': + _, _ = fmt.Fprintf(s, "%q", e.msg) + } +} + +func (e *newError) StackTrace() *xruntime.StackTrace { + return e.stacktrace +} diff --git a/library/go/core/xerrors/sentinel.go b/library/go/core/xerrors/sentinel.go new file mode 100644 index 0000000000..9727b84824 --- /dev/null +++ b/library/go/core/xerrors/sentinel.go @@ -0,0 +1,150 @@ +package xerrors + +import ( + "errors" + "fmt" + "io" + "strings" + + "a.yandex-team.ru/library/go/x/xreflect" + "a.yandex-team.ru/library/go/x/xruntime" +) + +// NewSentinel acts as New but does not add stack frame +func NewSentinel(text string) *Sentinel { + return &Sentinel{error: errors.New(text)} +} + +// Sentinel error +type Sentinel struct { + error +} + +// WithFrame adds stack frame to sentinel error (DEPRECATED) +func (s *Sentinel) WithFrame() error { + return &sentinelWithStackTrace{ + err: s, + stacktrace: newStackTrace(1, nil), + } +} + +func (s *Sentinel) WithStackTrace() error { + return &sentinelWithStackTrace{ + err: s, + stacktrace: newStackTrace(1, nil), + } +} + +// Wrap error with this sentinel error. Adds stack frame. +func (s *Sentinel) Wrap(err error) error { + if err == nil { + panic("tried to wrap a nil error") + } + + return &sentinelWrapper{ + err: s, + wrapped: err, + stacktrace: newStackTrace(1, err), + } +} + +type sentinelWithStackTrace struct { + err error + stacktrace *xruntime.StackTrace +} + +func (e *sentinelWithStackTrace) Error() string { + return e.err.Error() +} + +func (e *sentinelWithStackTrace) Format(s fmt.State, v rune) { + switch v { + case 'v': + if s.Flag('+') && e.stacktrace != nil { + msg := e.err.Error() + _, _ = io.WriteString(s, msg) + writeMsgAndStackTraceSeparator(s, msg) + writeStackTrace(s, e.stacktrace) + return + } + fallthrough + case 's': + _, _ = io.WriteString(s, e.err.Error()) + case 'q': + _, _ = fmt.Fprintf(s, "%q", e.err.Error()) + } +} + +func writeMsgAndStackTraceSeparator(w io.Writer, msg string) { + separator := "\n" + if !strings.HasSuffix(msg, ":") { + separator = ":\n" + } + + _, _ = io.WriteString(w, separator) +} + +// Is checks if e holds the specified error. Checks only immediate error. +func (e *sentinelWithStackTrace) Is(target error) bool { + return e.err == target +} + +// As checks if ew holds the specified error type. Checks only immediate error. +// It does NOT perform target checks as it relies on errors.As to do it +func (e *sentinelWithStackTrace) As(target interface{}) bool { + return xreflect.Assign(e.err, target) +} + +type sentinelWrapper struct { + err error + wrapped error + stacktrace *xruntime.StackTrace +} + +func (e *sentinelWrapper) Error() string { + return fmt.Sprintf("%s", e) +} + +func (e *sentinelWrapper) Format(s fmt.State, v rune) { + switch v { + case 'v': + if s.Flag('+') { + if e.stacktrace != nil { + msg := e.err.Error() + _, _ = io.WriteString(s, msg) + writeMsgAndStackTraceSeparator(s, msg) + writeStackTrace(s, e.stacktrace) + _, _ = fmt.Fprintf(s, "%+v", e.wrapped) + } else { + _, _ = io.WriteString(s, e.err.Error()) + _, _ = io.WriteString(s, ": ") + _, _ = fmt.Fprintf(s, "%+v", e.wrapped) + } + + return + } + fallthrough + case 's': + _, _ = io.WriteString(s, e.err.Error()) + _, _ = io.WriteString(s, ": ") + _, _ = io.WriteString(s, e.wrapped.Error()) + case 'q': + _, _ = fmt.Fprintf(s, "%q", fmt.Sprintf("%s: %s", e.err.Error(), e.wrapped.Error())) + } +} + +// Unwrap implements Wrapper interface +func (e *sentinelWrapper) Unwrap() error { + return e.wrapped +} + +// Is checks if ew holds the specified error. Checks only immediate error. +func (e *sentinelWrapper) Is(target error) bool { + return e.err == target +} + +// As checks if error holds the specified error type. Checks only immediate error. +// It does NOT perform target checks as it relies on errors.As to do it +func (e *sentinelWrapper) As(target interface{}) bool { + return xreflect.Assign(e.err, target) +} diff --git a/library/go/core/xerrors/stacktrace.go b/library/go/core/xerrors/stacktrace.go new file mode 100644 index 0000000000..fab7dc28a8 --- /dev/null +++ b/library/go/core/xerrors/stacktrace.go @@ -0,0 +1,80 @@ +package xerrors + +import ( + "errors" + "fmt" + "io" + + "a.yandex-team.ru/library/go/x/xruntime" +) + +func writeStackTrace(w io.Writer, stacktrace *xruntime.StackTrace) { + for _, frame := range stacktrace.Frames() { + if frame.Function != "" { + _, _ = fmt.Fprintf(w, " %s\n ", frame.Function) + } + + if frame.File != "" { + _, _ = fmt.Fprintf(w, " %s:%d\n", frame.File, frame.Line) + } + } +} + +type ErrorStackTrace interface { + StackTrace() *xruntime.StackTrace +} + +// StackTraceOfEffect returns last stacktrace that was added to error chain (furthest from the root error). +// Guarantees that returned value has valid StackTrace object (but not that there are any frames). +func StackTraceOfEffect(err error) ErrorStackTrace { + var st ErrorStackTrace + for { + if !As(err, &st) { + return nil + } + + if st.StackTrace() != nil { + return st + } + + err = st.(error) + err = errors.Unwrap(err) + } +} + +// StackTraceOfCause returns first stacktrace that was added to error chain (closest to the root error). +// Guarantees that returned value has valid StackTrace object (but not that there are any frames). +func StackTraceOfCause(err error) ErrorStackTrace { + var res ErrorStackTrace + var st ErrorStackTrace + for { + if !As(err, &st) { + return res + } + + if st.StackTrace() != nil { + res = st + } + + err = st.(error) + err = errors.Unwrap(err) + } +} + +// NextStackTracer returns next error with stack trace. +// Guarantees that returned value has valid StackTrace object (but not that there are any frames). +func NextStackTrace(st ErrorStackTrace) ErrorStackTrace { + var res ErrorStackTrace + for { + err := st.(error) + err = errors.Unwrap(err) + + if !As(err, &res) { + return nil + } + + if res.StackTrace() != nil { + return res + } + } +} diff --git a/library/go/httputil/headers/accept.go b/library/go/httputil/headers/accept.go new file mode 100644 index 0000000000..394bed7360 --- /dev/null +++ b/library/go/httputil/headers/accept.go @@ -0,0 +1,259 @@ +package headers + +import ( + "fmt" + "sort" + "strconv" + "strings" +) + +const ( + AcceptKey = "Accept" + AcceptEncodingKey = "Accept-Encoding" +) + +type AcceptableEncodings []AcceptableEncoding + +type AcceptableEncoding struct { + Encoding ContentEncoding + Weight float32 + + pos int +} + +func (as AcceptableEncodings) IsAcceptable(encoding ContentEncoding) bool { + for _, ae := range as { + if ae.Encoding == encoding { + return ae.Weight != 0 + } + } + return false +} + +func (as AcceptableEncodings) String() string { + if len(as) == 0 { + return "" + } + + var b strings.Builder + for i, ae := range as { + b.WriteString(ae.Encoding.String()) + + if ae.Weight > 0.0 && ae.Weight < 1.0 { + b.WriteString(";q=" + strconv.FormatFloat(float64(ae.Weight), 'f', 1, 32)) + } + + if i < len(as)-1 { + b.WriteString(", ") + } + } + return b.String() +} + +type AcceptableTypes []AcceptableType + +func (as AcceptableTypes) IsAcceptable(contentType ContentType) bool { + for _, ae := range as { + if ae.Type == contentType { + return ae.Weight != 0 + } + } + return false +} + +type AcceptableType struct { + Type ContentType + Weight float32 + Extension map[string]string + + pos int +} + +func (as AcceptableTypes) String() string { + if len(as) == 0 { + return "" + } + + var b strings.Builder + for i, at := range as { + b.WriteString(at.Type.String()) + + if at.Weight > 0.0 && at.Weight < 1.0 { + b.WriteString(";q=" + strconv.FormatFloat(float64(at.Weight), 'f', 1, 32)) + } + + for k, v := range at.Extension { + b.WriteString(";" + k + "=" + v) + } + + if i < len(as)-1 { + b.WriteString(", ") + } + } + return b.String() +} + +// ParseAccept parses Accept HTTP header. +// It will sort acceptable types by weight, specificity and position. +// See: https://tools.ietf.org/html/rfc2616#section-14.1 +func ParseAccept(headerValue string) (AcceptableTypes, error) { + if headerValue == "" { + return nil, nil + } + + parsedValues, err := parseAcceptFamilyHeader(headerValue) + if err != nil { + return nil, err + } + ah := make(AcceptableTypes, 0, len(parsedValues)) + for _, parsedValue := range parsedValues { + ah = append(ah, AcceptableType{ + Type: ContentType(parsedValue.Value), + Weight: parsedValue.Weight, + Extension: parsedValue.Extension, + pos: parsedValue.pos, + }) + } + + sort.Slice(ah, func(i, j int) bool { + // sort by weight only + if ah[i].Weight != ah[j].Weight { + return ah[i].Weight > ah[j].Weight + } + + // sort by most specific if types are equal + if ah[i].Type == ah[j].Type { + return len(ah[i].Extension) > len(ah[j].Extension) + } + + // move counterpart up if one of types is ANY + if ah[i].Type == ContentTypeAny { + return false + } + if ah[j].Type == ContentTypeAny { + return true + } + + // i type has j type as prefix + if strings.HasSuffix(string(ah[j].Type), "/*") && + strings.HasPrefix(string(ah[i].Type), string(ah[j].Type)[:len(ah[j].Type)-1]) { + return true + } + + // j type has i type as prefix + if strings.HasSuffix(string(ah[i].Type), "/*") && + strings.HasPrefix(string(ah[j].Type), string(ah[i].Type)[:len(ah[i].Type)-1]) { + return false + } + + // sort by position if nothing else left + return ah[i].pos < ah[j].pos + }) + + return ah, nil +} + +// ParseAcceptEncoding parses Accept-Encoding HTTP header. +// It will sort acceptable encodings by weight and position. +// See: https://tools.ietf.org/html/rfc2616#section-14.3 +func ParseAcceptEncoding(headerValue string) (AcceptableEncodings, error) { + if headerValue == "" { + return nil, nil + } + + // e.g. gzip;q=1.0, compress, identity + parsedValues, err := parseAcceptFamilyHeader(headerValue) + if err != nil { + return nil, err + } + acceptableEncodings := make(AcceptableEncodings, 0, len(parsedValues)) + for _, parsedValue := range parsedValues { + acceptableEncodings = append(acceptableEncodings, AcceptableEncoding{ + Encoding: ContentEncoding(parsedValue.Value), + Weight: parsedValue.Weight, + pos: parsedValue.pos, + }) + } + sort.Slice(acceptableEncodings, func(i, j int) bool { + // sort by weight only + if acceptableEncodings[i].Weight != acceptableEncodings[j].Weight { + return acceptableEncodings[i].Weight > acceptableEncodings[j].Weight + } + + // move counterpart up if one of encodings is ANY + if acceptableEncodings[i].Encoding == EncodingAny { + return false + } + if acceptableEncodings[j].Encoding == EncodingAny { + return true + } + + // sort by position if nothing else left + return acceptableEncodings[i].pos < acceptableEncodings[j].pos + }) + + return acceptableEncodings, nil +} + +type acceptHeaderValue struct { + Value string + Weight float32 + Extension map[string]string + + pos int +} + +// parseAcceptFamilyHeader parses family of Accept* HTTP headers +// See: https://tools.ietf.org/html/rfc2616#section-14.1 +func parseAcceptFamilyHeader(header string) ([]acceptHeaderValue, error) { + headerValues := strings.Split(header, ",") + + parsedValues := make([]acceptHeaderValue, 0, len(headerValues)) + for i, headerValue := range headerValues { + valueParams := strings.Split(headerValue, ";") + + parsedValue := acceptHeaderValue{ + Value: strings.TrimSpace(valueParams[0]), + Weight: 1.0, + pos: i, + } + + // parse quality factor and/or accept extension + if len(valueParams) > 1 { + for _, rawParam := range valueParams[1:] { + rawParam = strings.TrimSpace(rawParam) + params := strings.SplitN(rawParam, "=", 2) + key := strings.TrimSpace(params[0]) + + // quality factor + if key == "q" { + if len(params) != 2 { + return nil, fmt.Errorf("invalid quality factor format: %q", rawParam) + } + + w, err := strconv.ParseFloat(params[1], 32) + if err != nil { + return nil, err + } + parsedValue.Weight = float32(w) + + continue + } + + // extension + if parsedValue.Extension == nil { + parsedValue.Extension = make(map[string]string) + } + + var value string + if len(params) == 2 { + value = strings.TrimSpace(params[1]) + } + parsedValue.Extension[key] = value + } + } + + parsedValues = append(parsedValues, parsedValue) + } + return parsedValues, nil +} diff --git a/library/go/httputil/headers/authorization.go b/library/go/httputil/headers/authorization.go new file mode 100644 index 0000000000..145e04f931 --- /dev/null +++ b/library/go/httputil/headers/authorization.go @@ -0,0 +1,31 @@ +package headers + +import "strings" + +const ( + AuthorizationKey = "Authorization" + + TokenTypeBearer TokenType = "bearer" + TokenTypeMAC TokenType = "mac" +) + +type TokenType string + +// String implements stringer interface +func (tt TokenType) String() string { + return string(tt) +} + +func AuthorizationTokenType(token string) TokenType { + if len(token) > len(TokenTypeBearer) && + strings.ToLower(token[:len(TokenTypeBearer)]) == TokenTypeBearer.String() { + return TokenTypeBearer + } + + if len(token) > len(TokenTypeMAC) && + strings.ToLower(token[:len(TokenTypeMAC)]) == TokenTypeMAC.String() { + return TokenTypeMAC + } + + return TokenType("unknown") +} diff --git a/library/go/httputil/headers/content.go b/library/go/httputil/headers/content.go new file mode 100644 index 0000000000..b92e013cc3 --- /dev/null +++ b/library/go/httputil/headers/content.go @@ -0,0 +1,57 @@ +package headers + +type ContentType string + +// String implements stringer interface +func (ct ContentType) String() string { + return string(ct) +} + +type ContentEncoding string + +// String implements stringer interface +func (ce ContentEncoding) String() string { + return string(ce) +} + +const ( + ContentTypeKey = "Content-Type" + ContentLength = "Content-Length" + ContentEncodingKey = "Content-Encoding" + + ContentTypeAny ContentType = "*/*" + + TypeApplicationJSON ContentType = "application/json" + TypeApplicationXML ContentType = "application/xml" + TypeApplicationOctetStream ContentType = "application/octet-stream" + TypeApplicationProtobuf ContentType = "application/protobuf" + TypeApplicationMsgpack ContentType = "application/msgpack" + TypeApplicationXSolomonSpack ContentType = "application/x-solomon-spack" + + EncodingAny ContentEncoding = "*" + EncodingZSTD ContentEncoding = "zstd" + EncodingLZ4 ContentEncoding = "lz4" + EncodingGZIP ContentEncoding = "gzip" + EncodingDeflate ContentEncoding = "deflate" + + TypeTextPlain ContentType = "text/plain" + TypeTextHTML ContentType = "text/html" + TypeTextCSV ContentType = "text/csv" + TypeTextCmd ContentType = "text/cmd" + TypeTextCSS ContentType = "text/css" + TypeTextXML ContentType = "text/xml" + TypeTextMarkdown ContentType = "text/markdown" + + TypeImageAny ContentType = "image/*" + TypeImageJPEG ContentType = "image/jpeg" + TypeImageGIF ContentType = "image/gif" + TypeImagePNG ContentType = "image/png" + TypeImageSVG ContentType = "image/svg+xml" + TypeImageTIFF ContentType = "image/tiff" + TypeImageWebP ContentType = "image/webp" + + TypeVideoMPEG ContentType = "video/mpeg" + TypeVideoMP4 ContentType = "video/mp4" + TypeVideoOgg ContentType = "video/ogg" + TypeVideoWebM ContentType = "video/webm" +) diff --git a/library/go/httputil/headers/cookie.go b/library/go/httputil/headers/cookie.go new file mode 100644 index 0000000000..bcc685c474 --- /dev/null +++ b/library/go/httputil/headers/cookie.go @@ -0,0 +1,5 @@ +package headers + +const ( + CookieKey = "Cookie" +) diff --git a/library/go/httputil/headers/user_agent.go b/library/go/httputil/headers/user_agent.go new file mode 100644 index 0000000000..366606a01d --- /dev/null +++ b/library/go/httputil/headers/user_agent.go @@ -0,0 +1,5 @@ +package headers + +const ( + UserAgentKey = "User-Agent" +) diff --git a/library/go/httputil/headers/warning.go b/library/go/httputil/headers/warning.go new file mode 100644 index 0000000000..2013fdf08d --- /dev/null +++ b/library/go/httputil/headers/warning.go @@ -0,0 +1,167 @@ +package headers + +import ( + "errors" + "net/http" + "strconv" + "strings" + "time" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +const ( + WarningKey = "Warning" + + WarningResponseIsStale = 110 // RFC 7234, 5.5.1 + WarningRevalidationFailed = 111 // RFC 7234, 5.5.2 + WarningDisconnectedOperation = 112 // RFC 7234, 5.5.3 + WarningHeuristicExpiration = 113 // RFC 7234, 5.5.4 + WarningMiscellaneousWarning = 199 // RFC 7234, 5.5.5 + WarningTransformationApplied = 214 // RFC 7234, 5.5.6 + WarningMiscellaneousPersistentWarning = 299 // RFC 7234, 5.5.7 +) + +var warningStatusText = map[int]string{ + WarningResponseIsStale: "Response is Stale", + WarningRevalidationFailed: "Revalidation Failed", + WarningDisconnectedOperation: "Disconnected Operation", + WarningHeuristicExpiration: "Heuristic Expiration", + WarningMiscellaneousWarning: "Miscellaneous Warning", + WarningTransformationApplied: "Transformation Applied", + WarningMiscellaneousPersistentWarning: "Miscellaneous Persistent Warning", +} + +// WarningText returns a text for the warning header code. It returns the empty +// string if the code is unknown. +func WarningText(warn int) string { + return warningStatusText[warn] +} + +// AddWarning adds Warning to http.Header with proper formatting +// see: https://tools.ietf.org/html/rfc7234#section-5.5 +func AddWarning(h http.Header, warn int, agent, reason string, date time.Time) { + values := make([]string, 0, 4) + values = append(values, strconv.Itoa(warn)) + + if agent != "" { + values = append(values, agent) + } else { + values = append(values, "-") + } + + if reason != "" { + values = append(values, strconv.Quote(reason)) + } + + if !date.IsZero() { + values = append(values, strconv.Quote(date.Format(time.RFC1123))) + } + + h.Add(WarningKey, strings.Join(values, " ")) +} + +type WarningHeader struct { + Code int + Agent string + Reason string + Date time.Time +} + +// ParseWarnings reads and parses Warning headers from http.Header +func ParseWarnings(h http.Header) ([]WarningHeader, error) { + warnings, ok := h[WarningKey] + if !ok { + return nil, nil + } + + res := make([]WarningHeader, 0, len(warnings)) + for _, warn := range warnings { + wh, err := parseWarning(warn) + if err != nil { + return nil, xerrors.Errorf("cannot parse '%s' header: %w", warn, err) + } + res = append(res, wh) + } + + return res, nil +} + +func parseWarning(warn string) (WarningHeader, error) { + var res WarningHeader + + // parse code + { + codeSP := strings.Index(warn, " ") + + // fast path - code only warning + if codeSP == -1 { + code, err := strconv.Atoi(warn) + res.Code = code + return res, err + } + + code, err := strconv.Atoi(warn[:codeSP]) + if err != nil { + return WarningHeader{}, err + } + res.Code = code + + warn = strings.TrimSpace(warn[codeSP+1:]) + } + + // parse agent + { + agentSP := strings.Index(warn, " ") + + // fast path - no data after agent + if agentSP == -1 { + res.Agent = warn + return res, nil + } + + res.Agent = warn[:agentSP] + warn = strings.TrimSpace(warn[agentSP+1:]) + } + + // parse reason + { + if len(warn) == 0 { + return res, nil + } + + // reason must by quoted, so we search for second quote + reasonSP := strings.Index(warn[1:], `"`) + + // fast path - bad reason + if reasonSP == -1 { + return WarningHeader{}, errors.New("bad reason formatting") + } + + res.Reason = warn[1 : reasonSP+1] + warn = strings.TrimSpace(warn[reasonSP+2:]) + } + + // parse date + { + if len(warn) == 0 { + return res, nil + } + + // optional date must by quoted, so we search for second quote + dateSP := strings.Index(warn[1:], `"`) + + // fast path - bad date + if dateSP == -1 { + return WarningHeader{}, errors.New("bad date formatting") + } + + dt, err := time.Parse(time.RFC1123, warn[1:dateSP+1]) + if err != nil { + return WarningHeader{}, err + } + res.Date = dt + } + + return res, nil +} diff --git a/library/go/httputil/middleware/tvm/middleware.go b/library/go/httputil/middleware/tvm/middleware.go new file mode 100644 index 0000000000..cbcf1aae1d --- /dev/null +++ b/library/go/httputil/middleware/tvm/middleware.go @@ -0,0 +1,91 @@ +package tvm + +import ( + "context" + "net/http" + + "golang.org/x/xerrors" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/ctxlog" + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/yandex/tvm" +) + +const ( + // XYaServiceTicket is http header that should be used for service ticket transfer. + XYaServiceTicket = "X-Ya-Service-Ticket" + // XYaUserTicket is http header that should be used for user ticket transfer. + XYaUserTicket = "X-Ya-User-Ticket" +) + +type ( + MiddlewareOption func(*middleware) + + middleware struct { + l log.Structured + + tvm tvm.Client + + authClient func(context.Context, tvm.ClientID) error + + onError func(w http.ResponseWriter, r *http.Request, err error) + } +) + +func defaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) { + http.Error(w, err.Error(), http.StatusForbidden) +} + +// CheckServiceTicket returns http middleware that validates service tickets for all incoming requests. +// +// ServiceTicket is stored on request context. It might be retrieved by calling tvm.ContextServiceTicket. +func CheckServiceTicket(tvm tvm.Client, opts ...MiddlewareOption) func(next http.Handler) http.Handler { + m := middleware{ + tvm: tvm, + onError: defaultErrorHandler, + } + + for _, opt := range opts { + opt(&m) + } + + if m.authClient == nil { + panic("must provide authorization policy") + } + + if m.l == nil { + m.l = &nop.Logger{} + } + + return m.wrap +} + +func (m *middleware) wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + serviceTicket := r.Header.Get(XYaServiceTicket) + if serviceTicket == "" { + ctxlog.Error(r.Context(), m.l.Logger(), "missing service ticket") + m.onError(w, r, xerrors.New("missing service ticket")) + return + } + + ticket, err := m.tvm.CheckServiceTicket(r.Context(), serviceTicket) + if err != nil { + ctxlog.Error(r.Context(), m.l.Logger(), "service ticket check failed", log.Error(err)) + m.onError(w, r, xerrors.Errorf("service ticket check failed: %w", err)) + return + } + + if err := m.authClient(r.Context(), ticket.SrcID); err != nil { + ctxlog.Error(r.Context(), m.l.Logger(), "client authorization failed", + log.String("ticket", ticket.LogInfo), + log.Error(err)) + m.onError(w, r, xerrors.Errorf("client authorization failed: %w", err)) + return + } + + r = r.WithContext(tvm.WithServiceTicket(r.Context(), ticket)) + next.ServeHTTP(w, r) + }) +} diff --git a/library/go/httputil/middleware/tvm/middleware_opts.go b/library/go/httputil/middleware/tvm/middleware_opts.go new file mode 100644 index 0000000000..86df5d4796 --- /dev/null +++ b/library/go/httputil/middleware/tvm/middleware_opts.go @@ -0,0 +1,47 @@ +package tvm + +import ( + "context" + "net/http" + + "golang.org/x/xerrors" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/yandex/tvm" +) + +// WithAllowedClients sets list of allowed clients. +func WithAllowedClients(allowedClients []tvm.ClientID) MiddlewareOption { + return func(m *middleware) { + m.authClient = func(_ context.Context, cid tvm.ClientID) error { + for _, allowed := range allowedClients { + if allowed == cid { + return nil + } + } + + return xerrors.Errorf("client with tvm_id=%d is not whitelisted", cid) + } + } +} + +// WithClientAuth sets custom function for client authorization. +func WithClientAuth(authClient func(context.Context, tvm.ClientID) error) MiddlewareOption { + return func(m *middleware) { + m.authClient = authClient + } +} + +// WithErrorHandler sets http handler invoked for rejected requests. +func WithErrorHandler(h func(w http.ResponseWriter, r *http.Request, err error)) MiddlewareOption { + return func(m *middleware) { + m.onError = h + } +} + +// WithLogger sets logger. +func WithLogger(l log.Structured) MiddlewareOption { + return func(m *middleware) { + m.l = l + } +} diff --git a/library/go/maxprocs/cgroups.go b/library/go/maxprocs/cgroups.go new file mode 100644 index 0000000000..0482788d37 --- /dev/null +++ b/library/go/maxprocs/cgroups.go @@ -0,0 +1,174 @@ +package maxprocs + +import ( + "errors" + "fmt" + "io/ioutil" + "path/filepath" + "runtime" + "strconv" + "strings" + + "github.com/prometheus/procfs" + + "a.yandex-team.ru/library/go/slices" +) + +const ( + unifiedHierarchy = "unified" + cpuHierarchy = "cpu" +) + +var ErrNoCgroups = errors.New("no suitable cgroups were found") + +func isCgroupsExists() bool { + mounts, err := procfs.GetMounts() + if err != nil { + return false + } + + for _, m := range mounts { + if m.FSType == "cgroup" || m.FSType == "cgroup2" { + return true + } + } + + return false +} + +func parseCgroupsMountPoints() (map[string]string, error) { + mounts, err := procfs.GetMounts() + if err != nil { + return nil, err + } + + out := make(map[string]string) + for _, mount := range mounts { + switch mount.FSType { + case "cgroup2": + out[unifiedHierarchy] = mount.MountPoint + case "cgroup": + for opt := range mount.SuperOptions { + if opt == cpuHierarchy { + out[cpuHierarchy] = mount.MountPoint + break + } + } + } + } + + return out, nil +} + +func getCFSQuota() (float64, error) { + self, err := procfs.Self() + if err != nil { + return 0, err + } + + selfCgroups, err := self.Cgroups() + if err != nil { + return 0, fmt.Errorf("parse self cgroups: %w", err) + } + + cgroups, err := parseCgroupsMountPoints() + if err != nil { + return 0, fmt.Errorf("parse cgroups: %w", err) + } + + if len(selfCgroups) == 0 || len(cgroups) == 0 { + return 0, ErrNoCgroups + } + + for _, cgroup := range selfCgroups { + var quota float64 + switch { + case cgroup.HierarchyID == 0: + // for the cgroups v2 hierarchy id is always 0 + mp, ok := cgroups[unifiedHierarchy] + if !ok { + continue + } + + quota, _ = parseV2CPUQuota(mp, cgroup.Path) + case slices.ContainsString(cgroup.Controllers, cpuHierarchy): + mp, ok := cgroups[cpuHierarchy] + if !ok { + continue + } + + quota, _ = parseV1CPUQuota(mp, cgroup.Path) + } + + if quota > 0 { + return quota, nil + } + } + + return 0, ErrNoCgroups +} + +func parseV1CPUQuota(mountPoint string, cgroupPath string) (float64, error) { + basePath := filepath.Join(mountPoint, cgroupPath) + cfsQuota, err := readFileInt(filepath.Join(basePath, "cpu.cfs_quota_us")) + if err != nil { + return -1, fmt.Errorf("parse cpu.cfs_quota_us: %w", err) + } + + // A value of -1 for cpu.cfs_quota_us indicates that the group does not have any + // bandwidth restriction in place + // https://www.kernel.org/doc/Documentation/scheduler/sched-bwc.txt + if cfsQuota == -1 { + return float64(runtime.NumCPU()), nil + } + + cfsPeriod, err := readFileInt(filepath.Join(basePath, "cpu.cfs_period_us")) + if err != nil { + return -1, fmt.Errorf("parse cpu.cfs_period_us: %w", err) + } + + return float64(cfsQuota) / float64(cfsPeriod), nil +} + +func parseV2CPUQuota(mountPoint string, cgroupPath string) (float64, error) { + /* + https://www.kernel.org/doc/Documentation/cgroup-v2.txt + + cpu.max + A read-write two value file which exists on non-root cgroups. + The default is "max 100000". + + The maximum bandwidth limit. It's in the following format:: + $MAX $PERIOD + + which indicates that the group may consume upto $MAX in each + $PERIOD duration. "max" for $MAX indicates no limit. If only + one number is written, $MAX is updated. + */ + rawCPUMax, err := ioutil.ReadFile(filepath.Join(mountPoint, cgroupPath, "cpu.max")) + if err != nil { + return -1, fmt.Errorf("read cpu.max: %w", err) + } + + parts := strings.Fields(string(rawCPUMax)) + if len(parts) != 2 { + return -1, fmt.Errorf("invalid cpu.max format: %s", string(rawCPUMax)) + } + + // "max" for $MAX indicates no limit + if parts[0] == "max" { + return float64(runtime.NumCPU()), nil + } + + cpuMax, err := strconv.Atoi(parts[0]) + if err != nil { + return -1, fmt.Errorf("parse cpu.max[max] (%q): %w", parts[0], err) + } + + cpuPeriod, err := strconv.Atoi(parts[1]) + if err != nil { + return -1, fmt.Errorf("parse cpu.max[period] (%q): %w", parts[1], err) + } + + return float64(cpuMax) / float64(cpuPeriod), nil +} diff --git a/library/go/maxprocs/doc.go b/library/go/maxprocs/doc.go new file mode 100644 index 0000000000..2461d6022c --- /dev/null +++ b/library/go/maxprocs/doc.go @@ -0,0 +1,9 @@ +// Automatically sets GOMAXPROCS to match Yandex clouds container CPU quota. +// +// This package always adjust GOMAXPROCS to some "safe" value. +// "safe" values are: +// - 2 or more +// - no more than logical cores +// - no moore than container guarantees +// - no more than 8 +package maxprocs diff --git a/library/go/maxprocs/helpers.go b/library/go/maxprocs/helpers.go new file mode 100644 index 0000000000..70263e6eb3 --- /dev/null +++ b/library/go/maxprocs/helpers.go @@ -0,0 +1,46 @@ +package maxprocs + +import ( + "bytes" + "io/ioutil" + "math" + "os" + "strconv" +) + +func getEnv(envName string) (string, bool) { + val, ok := os.LookupEnv(envName) + return val, ok && val != "" +} + +func applyIntStringLimit(val string) int { + maxProc, err := strconv.Atoi(val) + if err == nil { + return Adjust(maxProc) + } + + return Adjust(SafeProc) +} + +func applyFloatStringLimit(val string) int { + maxProc, err := strconv.ParseFloat(val, 64) + if err != nil { + return Adjust(SafeProc) + } + + return applyFloatLimit(maxProc) +} + +func applyFloatLimit(val float64) int { + maxProc := int(math.Floor(val)) + return Adjust(maxProc) +} + +func readFileInt(filename string) (int, error) { + raw, err := ioutil.ReadFile(filename) + if err != nil { + return 0, err + } + + return strconv.Atoi(string(bytes.TrimSpace(raw))) +} diff --git a/library/go/maxprocs/maxprocs.go b/library/go/maxprocs/maxprocs.go new file mode 100644 index 0000000000..b5996ec6bc --- /dev/null +++ b/library/go/maxprocs/maxprocs.go @@ -0,0 +1,159 @@ +package maxprocs + +import ( + "context" + "os" + "runtime" + "strings" + + "a.yandex-team.ru/library/go/yandex/deploy/podagent" + "a.yandex-team.ru/library/go/yandex/yplite" +) + +const ( + SafeProc = 4 + MinProc = 2 + MaxProc = 8 + + GoMaxProcEnvName = "GOMAXPROCS" + QloudCPUEnvName = "QLOUD_CPU_GUARANTEE" + InstancectlCPUEnvName = "CPU_GUARANTEE" + DeloyBoxIDName = podagent.EnvBoxIDKey +) + +// Adjust adjust the maximum number of CPUs that can be executing. +// Takes a minimum between n and CPU counts and returns the previous setting +func Adjust(n int) int { + if n < MinProc { + n = MinProc + } + + nCPU := runtime.NumCPU() + if n < nCPU { + return runtime.GOMAXPROCS(n) + } + + return runtime.GOMAXPROCS(nCPU) +} + +// AdjustAuto automatically adjust the maximum number of CPUs that can be executing to safe value +// and returns the previous setting +func AdjustAuto() int { + if val, ok := getEnv(GoMaxProcEnvName); ok { + return applyIntStringLimit(val) + } + + if isCgroupsExists() { + return AdjustCgroup() + } + + if val, ok := getEnv(InstancectlCPUEnvName); ok { + return applyFloatStringLimit(strings.TrimRight(val, "c")) + } + + if val, ok := getEnv(QloudCPUEnvName); ok { + return applyFloatStringLimit(val) + } + + if boxID, ok := os.LookupEnv(DeloyBoxIDName); ok { + return adjustYPBox(boxID) + } + + if yplite.IsAPIAvailable() { + return AdjustYPLite() + } + + return Adjust(SafeProc) +} + +// AdjustQloud automatically adjust the maximum number of CPUs in case of Qloud env +// and returns the previous setting +func AdjustQloud() int { + if val, ok := getEnv(GoMaxProcEnvName); ok { + return applyIntStringLimit(val) + } + + if val, ok := getEnv(QloudCPUEnvName); ok { + return applyFloatStringLimit(val) + } + + return Adjust(MaxProc) +} + +// AdjustYP automatically adjust the maximum number of CPUs in case of YP/Y.Deploy/YP.Hard env +// and returns the previous setting +func AdjustYP() int { + if val, ok := getEnv(GoMaxProcEnvName); ok { + return applyIntStringLimit(val) + } + + if isCgroupsExists() { + return AdjustCgroup() + } + + return adjustYPBox(os.Getenv(DeloyBoxIDName)) +} + +func adjustYPBox(boxID string) int { + resources, err := podagent.NewClient().PodAttributes(context.Background()) + if err != nil { + return Adjust(SafeProc) + } + + var cpuGuarantee float64 + if boxResources, ok := resources.BoxesRequirements[boxID]; ok { + cpuGuarantee = boxResources.CPU.Guarantee / 1000 + } + + if cpuGuarantee <= 0 { + // if we don't have guarantees for current box, let's use pod guarantees + cpuGuarantee = resources.PodRequirements.CPU.Guarantee / 1000 + } + + return applyFloatLimit(cpuGuarantee) +} + +// AdjustYPLite automatically adjust the maximum number of CPUs in case of YP.Lite env +// and returns the previous setting +func AdjustYPLite() int { + if val, ok := getEnv(GoMaxProcEnvName); ok { + return applyIntStringLimit(val) + } + + podAttributes, err := yplite.FetchPodAttributes() + if err != nil { + return Adjust(SafeProc) + } + + return applyFloatLimit(float64(podAttributes.ResourceRequirements.CPU.Guarantee / 1000)) +} + +// AdjustInstancectl automatically adjust the maximum number of CPUs +// and returns the previous setting +// WARNING: supported only instancectl v1.177+ (https://wiki.yandex-team.ru/runtime-cloud/nanny/instancectl-change-log/#1.177) +func AdjustInstancectl() int { + if val, ok := getEnv(GoMaxProcEnvName); ok { + return applyIntStringLimit(val) + } + + if val, ok := getEnv(InstancectlCPUEnvName); ok { + return applyFloatStringLimit(strings.TrimRight(val, "c")) + } + + return Adjust(MaxProc) +} + +// AdjustCgroup automatically adjust the maximum number of CPUs based on the CFS quota +// and returns the previous setting. +func AdjustCgroup() int { + if val, ok := getEnv(GoMaxProcEnvName); ok { + return applyIntStringLimit(val) + } + + quota, err := getCFSQuota() + if err != nil { + return Adjust(SafeProc) + } + + return applyFloatLimit(quota) +} diff --git a/library/go/ptr/ptr.go b/library/go/ptr/ptr.go new file mode 100644 index 0000000000..ae4ad1015f --- /dev/null +++ b/library/go/ptr/ptr.go @@ -0,0 +1,66 @@ +package ptr + +import "time" + +// Int returns pointer to provided value +func Int(v int) *int { return &v } + +// Int8 returns pointer to provided value +func Int8(v int8) *int8 { return &v } + +// Int16 returns pointer to provided value +func Int16(v int16) *int16 { return &v } + +// Int32 returns pointer to provided value +func Int32(v int32) *int32 { return &v } + +// Int64 returns pointer to provided value +func Int64(v int64) *int64 { return &v } + +// Uint returns pointer to provided value +func Uint(v uint) *uint { return &v } + +// Uint8 returns pointer to provided value +func Uint8(v uint8) *uint8 { return &v } + +// Uint16 returns pointer to provided value +func Uint16(v uint16) *uint16 { return &v } + +// Uint32 returns pointer to provided value +func Uint32(v uint32) *uint32 { return &v } + +// Uint64 returns pointer to provided value +func Uint64(v uint64) *uint64 { return &v } + +// Float32 returns pointer to provided value +func Float32(v float32) *float32 { return &v } + +// Float64 returns pointer to provided value +func Float64(v float64) *float64 { return &v } + +// Bool returns pointer to provided value +func Bool(v bool) *bool { return &v } + +// String returns pointer to provided value +func String(v string) *string { return &v } + +// Byte returns pointer to provided value +func Byte(v byte) *byte { return &v } + +// Rune returns pointer to provided value +func Rune(v rune) *rune { return &v } + +// Complex64 returns pointer to provided value +func Complex64(v complex64) *complex64 { return &v } + +// Complex128 returns pointer to provided value +func Complex128(v complex128) *complex128 { return &v } + +// Time returns pointer to provided value +func Time(v time.Time) *time.Time { return &v } + +// Duration returns pointer to provided value +func Duration(v time.Duration) *time.Duration { return &v } + +// T returns pointer to provided value +func T[T any](v T) *T { return &v } diff --git a/library/go/slices/contains.go b/library/go/slices/contains.go new file mode 100644 index 0000000000..0b68109f5c --- /dev/null +++ b/library/go/slices/contains.go @@ -0,0 +1,90 @@ +package slices + +import ( + "bytes" + "net" + + "github.com/gofrs/uuid" + "golang.org/x/exp/slices" +) + +// ContainsString checks if string slice contains given string. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsString = slices.Contains[string] + +// ContainsBool checks if bool slice contains given bool. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsBool = slices.Contains[bool] + +// ContainsInt checks if int slice contains given int +var ContainsInt = slices.Contains[int] + +// ContainsInt8 checks if int8 slice contains given int8. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsInt8 = slices.Contains[int8] + +// ContainsInt16 checks if int16 slice contains given int16. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsInt16 = slices.Contains[int16] + +// ContainsInt32 checks if int32 slice contains given int32. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsInt32 = slices.Contains[int32] + +// ContainsInt64 checks if int64 slice contains given int64. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsInt64 = slices.Contains[int64] + +// ContainsUint checks if uint slice contains given uint. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsUint = slices.Contains[uint] + +// ContainsUint8 checks if uint8 slice contains given uint8. +func ContainsUint8(haystack []uint8, needle uint8) bool { + return bytes.IndexByte(haystack, needle) != -1 +} + +// ContainsUint16 checks if uint16 slice contains given uint16. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsUint16 = slices.Contains[uint16] + +// ContainsUint32 checks if uint32 slice contains given uint32. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsUint32 = slices.Contains[uint32] + +// ContainsUint64 checks if uint64 slice contains given uint64. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsUint64 = slices.Contains[uint64] + +// ContainsFloat32 checks if float32 slice contains given float32. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsFloat32 = slices.Contains[float32] + +// ContainsFloat64 checks if float64 slice contains given float64. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsFloat64 = slices.Contains[float64] + +// ContainsByte checks if byte slice contains given byte +func ContainsByte(haystack []byte, needle byte) bool { + return bytes.IndexByte(haystack, needle) != -1 +} + +// ContainsIP checks if net.IP slice contains given net.IP +func ContainsIP(haystack []net.IP, needle net.IP) bool { + for _, e := range haystack { + if e.Equal(needle) { + return true + } + } + return false +} + +// ContainsUUID checks if UUID slice contains given UUID. +// Deprecated: use golang.org/x/exp/slices.Contains instead +var ContainsUUID = slices.Contains[uuid.UUID] + +// Contains checks if slice of T contains given T +// Deprecated: use golang.org/x/exp/slices.Contains instead. +func Contains[E comparable](haystack []E, needle E) (bool, error) { + return slices.Contains(haystack, needle), nil +} diff --git a/library/go/slices/contains_all.go b/library/go/slices/contains_all.go new file mode 100644 index 0000000000..3c3e8e1878 --- /dev/null +++ b/library/go/slices/contains_all.go @@ -0,0 +1,23 @@ +package slices + +// ContainsAll checks if slice of type E contains all elements of given slice, order independent +func ContainsAll[E comparable](haystack []E, needle []E) bool { + m := make(map[E]struct{}, len(haystack)) + for _, i := range haystack { + m[i] = struct{}{} + } + for _, v := range needle { + if _, ok := m[v]; !ok { + return false + } + } + return true +} + +// ContainsAllStrings checks if string slice contains all elements of given slice +// Deprecated: use ContainsAll instead +var ContainsAllStrings = ContainsAll[string] + +// ContainsAllBools checks if bool slice contains all elements of given slice +// Deprecated: use ContainsAll instead +var ContainsAllBools = ContainsAll[bool] diff --git a/library/go/slices/contains_any.go b/library/go/slices/contains_any.go new file mode 100644 index 0000000000..0fc6a7ace4 --- /dev/null +++ b/library/go/slices/contains_any.go @@ -0,0 +1,72 @@ +package slices + +import ( + "bytes" +) + +// ContainsAny checks if slice of type E contains any element from given slice +func ContainsAny[E comparable](haystack, needle []E) bool { + return len(Intersection(haystack, needle)) > 0 +} + +// ContainsAnyString checks if string slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyString = ContainsAny[string] + +// ContainsAnyBool checks if bool slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyBool = ContainsAny[bool] + +// ContainsAnyInt checks if int slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyInt = ContainsAny[int] + +// ContainsAnyInt8 checks if int8 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyInt8 = ContainsAny[int8] + +// ContainsAnyInt16 checks if int16 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyInt16 = ContainsAny[int16] + +// ContainsAnyInt32 checks if int32 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyInt32 = ContainsAny[int32] + +// ContainsAnyInt64 checks if int64 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyInt64 = ContainsAny[int64] + +// ContainsAnyUint checks if uint slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyUint = ContainsAny[uint] + +// ContainsAnyUint8 checks if uint8 slice contains any element from given slice +func ContainsAnyUint8(haystack []uint8, needle []uint8) bool { + return bytes.Contains(haystack, needle) +} + +// ContainsAnyUint16 checks if uint16 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyUint16 = ContainsAny[uint16] + +// ContainsAnyUint32 checks if uint32 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyUint32 = ContainsAny[uint32] + +// ContainsAnyUint64 checks if uint64 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyUint64 = ContainsAny[uint64] + +// ContainsAnyFloat32 checks if float32 slice contains any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyFloat32 = ContainsAny[float32] + +// ContainsAnyFloat64 checks if float64 slice any element from given slice +// Deprecated: use ContainsAny instead. +var ContainsAnyFloat64 = ContainsAny[float64] + +// ContainsAnyByte checks if byte slice contains any element from given slice +func ContainsAnyByte(haystack []byte, needle []byte) bool { + return bytes.Contains(haystack, needle) +} diff --git a/library/go/slices/dedup.go b/library/go/slices/dedup.go new file mode 100644 index 0000000000..365f3b2d74 --- /dev/null +++ b/library/go/slices/dedup.go @@ -0,0 +1,109 @@ +package slices + +import ( + "sort" + + "golang.org/x/exp/constraints" + "golang.org/x/exp/slices" +) + +// Dedup removes duplicate values from slice. +// It will alter original non-empty slice, consider copy it beforehand. +func Dedup[E constraints.Ordered](s []E) []E { + if len(s) < 2 { + return s + } + slices.Sort(s) + tmp := s[:1] + cur := s[0] + for i := 1; i < len(s); i++ { + if s[i] != cur { + tmp = append(tmp, s[i]) + cur = s[i] + } + } + return tmp +} + +// DedupBools removes duplicate values from bool slice. +// It will alter original non-empty slice, consider copy it beforehand. +func DedupBools(a []bool) []bool { + if len(a) < 2 { + return a + } + sort.Slice(a, func(i, j int) bool { return a[i] != a[j] }) + tmp := a[:1] + cur := a[0] + for i := 1; i < len(a); i++ { + if a[i] != cur { + tmp = append(tmp, a[i]) + cur = a[i] + } + } + return tmp +} + +// DedupStrings removes duplicate values from string slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupStrings = Dedup[string] + +// DedupInts removes duplicate values from ints slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupInts = Dedup[int] + +// DedupInt8s removes duplicate values from int8 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupInt8s = Dedup[int8] + +// DedupInt16s removes duplicate values from int16 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupInt16s = Dedup[int16] + +// DedupInt32s removes duplicate values from int32 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupInt32s = Dedup[int32] + +// DedupInt64s removes duplicate values from int64 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupInt64s = Dedup[int64] + +// DedupUints removes duplicate values from uint slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupUints = Dedup[uint] + +// DedupUint8s removes duplicate values from uint8 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupUint8s = Dedup[uint8] + +// DedupUint16s removes duplicate values from uint16 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupUint16s = Dedup[uint16] + +// DedupUint32s removes duplicate values from uint32 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupUint32s = Dedup[uint32] + +// DedupUint64s removes duplicate values from uint64 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupUint64s = Dedup[uint64] + +// DedupFloat32s removes duplicate values from float32 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupFloat32s = Dedup[float32] + +// DedupFloat64s removes duplicate values from float64 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Dedup instead. +var DedupFloat64s = Dedup[float64] diff --git a/library/go/slices/equal.go b/library/go/slices/equal.go new file mode 100644 index 0000000000..3c698c7865 --- /dev/null +++ b/library/go/slices/equal.go @@ -0,0 +1,22 @@ +package slices + +// EqualUnordered checks if slices of type E are equal, order independent. +func EqualUnordered[E comparable](a []E, b []E) bool { + if len(a) != len(b) { + return false + } + + ma := make(map[E]struct{}) + for _, v := range a { + ma[v] = struct{}{} + } + l := len(ma) + for _, v := range b { + ma[v] = struct{}{} + } + return len(ma) == l +} + +// EqualAnyOrderStrings checks if string slices are equal, order independent. +// Deprecated: use EqualUnordered instead. +var EqualAnyOrderStrings = EqualUnordered[string] diff --git a/library/go/slices/filter.go b/library/go/slices/filter.go new file mode 100644 index 0000000000..8b383bfcb2 --- /dev/null +++ b/library/go/slices/filter.go @@ -0,0 +1,29 @@ +package slices + +import ( + "golang.org/x/exp/slices" +) + +// Filter reduces slice values using given function. +// It operates with a copy of given slice +func Filter[S ~[]T, T any](s S, fn func(T) bool) S { + if len(s) == 0 { + return s + } + return Reduce(slices.Clone(s), fn) +} + +// Reduce is like Filter, but modifies original slice. +func Reduce[S ~[]T, T any](s S, fn func(T) bool) S { + if len(s) == 0 { + return s + } + var p int + for _, v := range s { + if fn(v) { + s[p] = v + p++ + } + } + return s[:p] +} diff --git a/library/go/slices/group_by.go b/library/go/slices/group_by.go new file mode 100644 index 0000000000..fb61a29314 --- /dev/null +++ b/library/go/slices/group_by.go @@ -0,0 +1,90 @@ +package slices + +import ( + "fmt" +) + +func createNotUniqueKeyError[T comparable](key T) error { + return fmt.Errorf("duplicated key \"%v\" found. keys are supposed to be unique", key) +} + +// GroupBy groups slice entities into map by key provided via keyGetter. +func GroupBy[S ~[]T, T any, K comparable](s S, keyGetter func(T) K) map[K][]T { + res := map[K][]T{} + + for _, entity := range s { + key := keyGetter(entity) + res[key] = append(res[key], entity) + } + + return res +} + +// GroupByUniqueKey groups slice entities into map by key provided via keyGetter with assumption that each key is unique. +// +// Returns an error in case of key ununiqueness. +func GroupByUniqueKey[S ~[]T, T any, K comparable](s S, keyGetter func(T) K) (map[K]T, error) { + res := map[K]T{} + + for _, entity := range s { + key := keyGetter(entity) + + _, duplicated := res[key] + if duplicated { + return res, createNotUniqueKeyError(key) + } + + res[key] = entity + } + + return res, nil +} + +// IndexedEntity stores an entity of original slice with its initial index in that slice +type IndexedEntity[T any] struct { + Value T + Index int +} + +// GroupByWithIndex groups slice entities into map by key provided via keyGetter. +// Each entity of underlying result slice contains the value itself and its index in the original slice +// (See IndexedEntity). +func GroupByWithIndex[S ~[]T, T any, K comparable](s S, keyGetter func(T) K) map[K][]IndexedEntity[T] { + res := map[K][]IndexedEntity[T]{} + + for i, entity := range s { + key := keyGetter(entity) + res[key] = append(res[key], IndexedEntity[T]{ + Value: entity, + Index: i, + }) + } + + return res +} + +// GroupByUniqueKeyWithIndex groups slice entities into map by key provided via keyGetter with assumption that +// each key is unique. +// Each result entity contains the value itself and its index in the original slice +// (See IndexedEntity). +// +// Returns an error in case of key ununiqueness. +func GroupByUniqueKeyWithIndex[S ~[]T, T any, K comparable](s S, keyGetter func(T) K) (map[K]IndexedEntity[T], error) { + res := map[K]IndexedEntity[T]{} + + for i, entity := range s { + key := keyGetter(entity) + + _, duplicated := res[key] + if duplicated { + return res, createNotUniqueKeyError(key) + } + + res[key] = IndexedEntity[T]{ + Value: entity, + Index: i, + } + } + + return res, nil +} diff --git a/library/go/slices/intersects.go b/library/go/slices/intersects.go new file mode 100644 index 0000000000..b7785952df --- /dev/null +++ b/library/go/slices/intersects.go @@ -0,0 +1,83 @@ +package slices + +// Intersection returns intersection for slices of various built-in types +func Intersection[E comparable](a, b []E) []E { + if len(a) == 0 || len(b) == 0 { + return nil + } + + p, s := a, b + if len(b) > len(a) { + p, s = b, a + } + + m := make(map[E]struct{}) + for _, i := range p { + m[i] = struct{}{} + } + + var res []E + for _, v := range s { + if _, exists := m[v]; exists { + res = append(res, v) + } + } + + return res +} + +// IntersectStrings returns intersection of two string slices +// Deprecated: use Intersection instead. +var IntersectStrings = Intersection[string] + +// IntersectInts returns intersection of two int slices +// Deprecated: use Intersection instead. +var IntersectInts = Intersection[int] + +// IntersectInt8s returns intersection of two int8 slices +// Deprecated: use Intersection instead. +var IntersectInt8s = Intersection[int8] + +// IntersectInt16s returns intersection of two int16 slices +// Deprecated: use Intersection instead. +var IntersectInt16s = Intersection[int16] + +// IntersectInt32s returns intersection of two int32 slices +// Deprecated: use Intersection instead. +var IntersectInt32s = Intersection[int32] + +// IntersectInt64s returns intersection of two int64 slices +// Deprecated: use Intersection instead. +var IntersectInt64s = Intersection[int64] + +// IntersectUints returns intersection of two uint slices +// Deprecated: use Intersection instead. +var IntersectUints = Intersection[uint] + +// IntersectUint8s returns intersection of two uint8 slices +// Deprecated: use Intersection instead. +var IntersectUint8s = Intersection[uint8] + +// IntersectUint16s returns intersection of two uint16 slices +// Deprecated: use Intersection instead. +var IntersectUint16s = Intersection[uint16] + +// IntersectUint32s returns intersection of two uint32 slices +// Deprecated: use Intersection instead. +var IntersectUint32s = Intersection[uint32] + +// IntersectUint64s returns intersection of two uint64 slices +// Deprecated: use Intersection instead. +var IntersectUint64s = Intersection[uint64] + +// IntersectFloat32s returns intersection of two float32 slices +// Deprecated: use Intersection instead. +var IntersectFloat32s = Intersection[float32] + +// IntersectFloat64s returns intersection of two float64 slices +// Deprecated: use Intersection instead. +var IntersectFloat64s = Intersection[float64] + +// IntersectBools returns intersection of two bool slices +// Deprecated: use Intersection instead. +var IntersectBools = Intersection[bool] diff --git a/library/go/slices/join.go b/library/go/slices/join.go new file mode 100644 index 0000000000..7b72db5ed1 --- /dev/null +++ b/library/go/slices/join.go @@ -0,0 +1,14 @@ +package slices + +import ( + "fmt" + "strings" +) + +// Join joins slice of any types +func Join(s interface{}, glue string) string { + if t, ok := s.([]string); ok { + return strings.Join(t, glue) + } + return strings.Trim(strings.Join(strings.Fields(fmt.Sprint(s)), glue), "[]") +} diff --git a/library/go/slices/map.go b/library/go/slices/map.go new file mode 100644 index 0000000000..943261f786 --- /dev/null +++ b/library/go/slices/map.go @@ -0,0 +1,27 @@ +package slices + +// Map applies given function to every value of slice +func Map[S ~[]T, T, M any](s S, fn func(T) M) []M { + if s == nil { + return []M(nil) + } + if len(s) == 0 { + return make([]M, 0) + } + res := make([]M, len(s)) + for i, v := range s { + res[i] = fn(v) + } + return res +} + +// Mutate is like Map, but it prohibits type changes and modifies original slice. +func Mutate[S ~[]T, T any](s S, fn func(T) T) S { + if len(s) == 0 { + return s + } + for i, v := range s { + s[i] = fn(v) + } + return s +} diff --git a/library/go/slices/reverse.go b/library/go/slices/reverse.go new file mode 100644 index 0000000000..a436617b67 --- /dev/null +++ b/library/go/slices/reverse.go @@ -0,0 +1,83 @@ +package slices + +// Reverse reverses given slice. +// It will alter original non-empty slice, consider copy it beforehand. +func Reverse[E any](s []E) []E { + if len(s) < 2 { + return s + } + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } + return s +} + +// ReverseStrings reverses given string slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseStrings = Reverse[string] + +// ReverseInts reverses given int slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseInts = Reverse[int] + +// ReverseInt8s reverses given int8 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseInt8s = Reverse[int8] + +// ReverseInt16s reverses given int16 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseInt16s = Reverse[int16] + +// ReverseInt32s reverses given int32 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseInt32s = Reverse[int32] + +// ReverseInt64s reverses given int64 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseInt64s = Reverse[int64] + +// ReverseUints reverses given uint slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseUints = Reverse[uint] + +// ReverseUint8s reverses given uint8 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseUint8s = Reverse[uint8] + +// ReverseUint16s reverses given uint16 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseUint16s = Reverse[uint16] + +// ReverseUint32s reverses given uint32 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseUint32s = Reverse[uint32] + +// ReverseUint64s reverses given uint64 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseUint64s = Reverse[uint64] + +// ReverseFloat32s reverses given float32 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseFloat32s = Reverse[float32] + +// ReverseFloat64s reverses given float64 slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseFloat64s = Reverse[float64] + +// ReverseBools reverses given bool slice. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Reverse instead. +var ReverseBools = Reverse[bool] diff --git a/library/go/slices/shuffle.go b/library/go/slices/shuffle.go new file mode 100644 index 0000000000..5df9b33c3c --- /dev/null +++ b/library/go/slices/shuffle.go @@ -0,0 +1,95 @@ +package slices + +import ( + "math/rand" +) + +// Shuffle shuffles values in slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +func Shuffle[E any](a []E, src rand.Source) []E { + if len(a) < 2 { + return a + } + shuffle(src)(len(a), func(i, j int) { + a[i], a[j] = a[j], a[i] + }) + return a +} + +// ShuffleStrings shuffles values in string slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleStrings = Shuffle[string] + +// ShuffleInts shuffles values in int slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleInts = Shuffle[int] + +// ShuffleInt8s shuffles values in int8 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleInt8s = Shuffle[int8] + +// ShuffleInt16s shuffles values in int16 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleInt16s = Shuffle[int16] + +// ShuffleInt32s shuffles values in int32 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleInt32s = Shuffle[int32] + +// ShuffleInt64s shuffles values in int64 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleInt64s = Shuffle[int64] + +// ShuffleUints shuffles values in uint slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleUints = Shuffle[uint] + +// ShuffleUint8s shuffles values in uint8 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleUint8s = Shuffle[uint8] + +// ShuffleUint16s shuffles values in uint16 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleUint16s = Shuffle[uint16] + +// ShuffleUint32s shuffles values in uint32 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleUint32s = Shuffle[uint32] + +// ShuffleUint64s shuffles values in uint64 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleUint64s = Shuffle[uint64] + +// ShuffleFloat32s shuffles values in float32 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleFloat32s = Shuffle[float32] + +// ShuffleFloat64s shuffles values in float64 slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleFloat64s = Shuffle[float64] + +// ShuffleBools shuffles values in bool slice using given or pseudo-random source. +// It will alter original non-empty slice, consider copy it beforehand. +// Deprecated: use Shuffle instead. +var ShuffleBools = Shuffle[bool] + +func shuffle(src rand.Source) func(n int, swap func(i, j int)) { + shuf := rand.Shuffle + if src != nil { + shuf = rand.New(src).Shuffle + } + return shuf +} diff --git a/library/go/x/xreflect/assign.go b/library/go/x/xreflect/assign.go new file mode 100644 index 0000000000..624612575c --- /dev/null +++ b/library/go/x/xreflect/assign.go @@ -0,0 +1,17 @@ +package xreflect + +import "reflect" + +// Assign source's value to target's value it points to. Source must be value, target must be pointer to existing value. +// Source must be assignable to target's value it points to. +func Assign(source interface{}, target interface{}) bool { + val := reflect.ValueOf(target) + typ := val.Type() + targetType := typ.Elem() + if reflect.TypeOf(source).AssignableTo(targetType) { + val.Elem().Set(reflect.ValueOf(source)) + return true + } + + return false +} diff --git a/library/go/x/xruntime/stacktrace.go b/library/go/x/xruntime/stacktrace.go new file mode 100644 index 0000000000..5c5e661188 --- /dev/null +++ b/library/go/x/xruntime/stacktrace.go @@ -0,0 +1,69 @@ +package xruntime + +import ( + "runtime" +) + +type StackTrace struct { + frames []uintptr + full bool +} + +func NewStackTrace16(skip int) *StackTrace { + var pcs [16]uintptr + return newStackTrace(skip+2, pcs[:]) +} + +func NewStackTrace32(skip int) *StackTrace { + var pcs [32]uintptr + return newStackTrace(skip+2, pcs[:]) +} + +func NewStackTrace64(skip int) *StackTrace { + var pcs [64]uintptr + return newStackTrace(skip+2, pcs[:]) +} + +func NewStackTrace128(skip int) *StackTrace { + var pcs [128]uintptr + return newStackTrace(skip+2, pcs[:]) +} + +func newStackTrace(skip int, pcs []uintptr) *StackTrace { + n := runtime.Callers(skip+1, pcs) + return &StackTrace{frames: pcs[:n], full: true} +} + +func NewFrame(skip int) *StackTrace { + var pcs [3]uintptr + n := runtime.Callers(skip+1, pcs[:]) + return &StackTrace{frames: pcs[:n]} +} + +func (st *StackTrace) Frames() []runtime.Frame { + frames := runtime.CallersFrames(st.frames[:]) + if !st.full { + if _, ok := frames.Next(); !ok { + return nil + } + + fr, ok := frames.Next() + if !ok { + return nil + } + + return []runtime.Frame{fr} + } + + var res []runtime.Frame + for { + frame, more := frames.Next() + if !more { + break + } + + res = append(res, frame) + } + + return res +} diff --git a/library/go/yandex/deploy/podagent/client.go b/library/go/yandex/deploy/podagent/client.go new file mode 100644 index 0000000000..09dc10ada6 --- /dev/null +++ b/library/go/yandex/deploy/podagent/client.go @@ -0,0 +1,67 @@ +package podagent + +import ( + "context" + "time" + + "github.com/go-resty/resty/v2" + + "a.yandex-team.ru/library/go/core/xerrors" + "a.yandex-team.ru/library/go/httputil/headers" +) + +const ( + EndpointURL = "http://127.0.0.1:1/" + HTTPTimeout = 500 * time.Millisecond +) + +type Client struct { + httpc *resty.Client +} + +func NewClient(opts ...Option) *Client { + c := &Client{ + httpc: resty.New(). + SetBaseURL(EndpointURL). + SetTimeout(HTTPTimeout), + } + + for _, opt := range opts { + opt(c) + } + return c +} + +// PodAttributes returns current pod attributes. +// +// Documentation: https://deploy.yandex-team.ru/docs/reference/api/pod-agent-public-api#localhost:1pod_attributes +func (c *Client) PodAttributes(ctx context.Context) (rsp PodAttributesResponse, err error) { + err = c.call(ctx, "/pod_attributes", &rsp) + return +} + +// PodStatus returns current pod status. +// +// Documentation: https://deploy.yandex-team.ru/docs/reference/api/pod-agent-public-api#localhost:1pod_status +func (c *Client) PodStatus(ctx context.Context) (rsp PodStatusResponse, err error) { + err = c.call(ctx, "/pod_status", &rsp) + return +} + +func (c *Client) call(ctx context.Context, handler string, result interface{}) error { + rsp, err := c.httpc.R(). + SetContext(ctx). + ExpectContentType(headers.TypeApplicationJSON.String()). + SetResult(&result). + Get(handler) + + if err != nil { + return xerrors.Errorf("failed to request pod agent API: %w", err) + } + + if !rsp.IsSuccess() { + return xerrors.Errorf("unexpected status code: %d", rsp.StatusCode()) + } + + return nil +} diff --git a/library/go/yandex/deploy/podagent/doc.go b/library/go/yandex/deploy/podagent/doc.go new file mode 100644 index 0000000000..326b84040f --- /dev/null +++ b/library/go/yandex/deploy/podagent/doc.go @@ -0,0 +1,4 @@ +// Package podagent provides the client and types for making API requests to Y.Deploy PodAgent. +// +// Official documentation for PogAgent public API: https://deploy.yandex-team.ru/docs/reference/api/pod-agent-public-api +package podagent diff --git a/library/go/yandex/deploy/podagent/env.go b/library/go/yandex/deploy/podagent/env.go new file mode 100644 index 0000000000..4dd4ae1790 --- /dev/null +++ b/library/go/yandex/deploy/podagent/env.go @@ -0,0 +1,33 @@ +package podagent + +import "os" + +// Box/Workload environment variable names, documentation references: +// - https://deploy.yandex-team.ru/docs/concepts/pod/box#systemenv +// - https://deploy.yandex-team.ru/docs/concepts/pod/workload/workload#system_env +const ( + EnvWorkloadIDKey = "DEPLOY_WORKLOAD_ID" + EnvContainerIDKey = "DEPLOY_CONTAINER_ID" + EnvBoxIDKey = "DEPLOY_BOX_ID" + EnvPodIDKey = "DEPLOY_POD_ID" + EnvProjectIDKey = "DEPLOY_PROJECT_ID" + EnvStageIDKey = "DEPLOY_STAGE_ID" + EnvUnitIDKey = "DEPLOY_UNIT_ID" + + EnvLogsEndpointKey = "DEPLOY_LOGS_ENDPOINT" + EnvLogsNameKey = "DEPLOY_LOGS_DEFAULT_NAME" + EnvLogsSecretKey = "DEPLOY_LOGS_SECRET" + + EnvNodeClusterKey = "DEPLOY_NODE_CLUSTER" + EnvNodeDCKey = "DEPLOY_NODE_DC" + EnvNodeFQDNKey = "DEPLOY_NODE_FQDN" + + EnvPodPersistentFQDN = "DEPLOY_POD_PERSISTENT_FQDN" + EnvPodTransientFQDN = "DEPLOY_POD_TRANSIENT_FQDN" +) + +// UnderPodAgent returns true if application managed by pod-agent. +func UnderPodAgent() bool { + _, ok := os.LookupEnv(EnvPodIDKey) + return ok +} diff --git a/library/go/yandex/deploy/podagent/options.go b/library/go/yandex/deploy/podagent/options.go new file mode 100644 index 0000000000..32c3ac71aa --- /dev/null +++ b/library/go/yandex/deploy/podagent/options.go @@ -0,0 +1,17 @@ +package podagent + +import "a.yandex-team.ru/library/go/core/log" + +type Option func(client *Client) + +func WithEndpoint(endpointURL string) Option { + return func(c *Client) { + c.httpc.SetBaseURL(endpointURL) + } +} + +func WithLogger(l log.Fmt) Option { + return func(c *Client) { + c.httpc.SetLogger(l) + } +} diff --git a/library/go/yandex/deploy/podagent/responses.go b/library/go/yandex/deploy/podagent/responses.go new file mode 100644 index 0000000000..e97c70dc7c --- /dev/null +++ b/library/go/yandex/deploy/podagent/responses.go @@ -0,0 +1,82 @@ +package podagent + +import ( + "encoding/json" + "net" +) + +type BoxStatus struct { + ID string `json:"id"` + Revision uint32 `json:"revision"` +} + +type WorkloadStatus struct { + ID string `json:"id"` + Revision uint32 `json:"revision"` +} + +type PodStatusResponse struct { + Boxes []BoxStatus `json:"boxes"` + Workloads []WorkloadStatus `json:"workloads"` +} + +type MemoryResource struct { + Guarantee uint64 `json:"memory_guarantee_bytes"` + Limit uint64 `json:"memory_limit_bytes"` +} + +type CPUResource struct { + Guarantee float64 `json:"cpu_guarantee_millicores"` + Limit float64 `json:"cpu_limit_millicores"` +} + +type ResourceRequirements struct { + Memory MemoryResource `json:"memory"` + CPU CPUResource `json:"cpu"` +} + +type NodeMeta struct { + DC string `json:"dc"` + Cluster string `json:"cluster"` + FQDN string `json:"fqdn"` +} + +type PodMeta struct { + PodID string `json:"pod_id"` + PodSetID string `json:"pod_set_id"` + Annotations json.RawMessage `json:"annotations"` + Labels json.RawMessage `json:"labels"` +} + +type Resources struct { + Boxes map[string]ResourceRequirements `json:"box_resource_requirements"` + Pod ResourceRequirements `json:"resource_requirements"` +} + +type InternetAddress struct { + Address net.IP `json:"ip4_address"` + ID string `json:"id"` +} + +type VirtualService struct { + IPv4Addrs []net.IP `json:"ip4_addresses"` + IPv6Addrs []net.IP `json:"ip6_addresses"` +} + +type IPAllocation struct { + InternetAddress InternetAddress `json:"internet_address"` + TransientFQDN string `json:"transient_fqdn"` + PersistentFQDN string `json:"persistent_fqdn"` + Addr net.IP `json:"address"` + VlanID string `json:"vlan_id"` + VirtualServices []VirtualService `json:"virtual_services"` + Labels map[string]string `json:"labels"` +} + +type PodAttributesResponse struct { + NodeMeta NodeMeta `json:"node_meta"` + PodMeta PodMeta `json:"metadata"` + BoxesRequirements map[string]ResourceRequirements `json:"box_resource_requirements"` + PodRequirements ResourceRequirements `json:"resource_requirements"` + IPAllocations []IPAllocation `json:"ip6_address_allocations"` +} diff --git a/library/go/yandex/solomon/reporters/puller/httppuller/handler.go b/library/go/yandex/solomon/reporters/puller/httppuller/handler.go new file mode 100644 index 0000000000..c0763c9ff0 --- /dev/null +++ b/library/go/yandex/solomon/reporters/puller/httppuller/handler.go @@ -0,0 +1,120 @@ +package httppuller + +import ( + "context" + "fmt" + "io" + "net/http" + "reflect" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/core/metrics/solomon" + "a.yandex-team.ru/library/go/httputil/headers" + "a.yandex-team.ru/library/go/httputil/middleware/tvm" +) + +const nilRegistryPanicMsg = "nil registry given" + +type MetricsStreamer interface { + StreamJSON(context.Context, io.Writer) (int, error) + StreamSpack(context.Context, io.Writer, solomon.CompressionType) (int, error) +} + +type handler struct { + registry MetricsStreamer + streamFormat headers.ContentType + checkTicket func(h http.Handler) http.Handler + logger log.Logger +} + +type Option interface { + isOption() +} + +// NewHandler returns new HTTP handler to expose gathered metrics using metrics dumper +func NewHandler(r MetricsStreamer, opts ...Option) http.Handler { + if v := reflect.ValueOf(r); !v.IsValid() || v.Kind() == reflect.Ptr && v.IsNil() { + panic(nilRegistryPanicMsg) + } + + h := handler{ + registry: r, + streamFormat: headers.TypeApplicationJSON, + checkTicket: func(h http.Handler) http.Handler { + return h + }, + logger: &nop.Logger{}, + } + + for _, opt := range opts { + switch o := opt.(type) { + case *tvmOption: + h.checkTicket = tvm.CheckServiceTicket(o.client, tvm.WithAllowedClients(AllFetchers)) + case *spackOption: + h.streamFormat = headers.TypeApplicationXSolomonSpack + case *loggerOption: + h.logger = o.logger + default: + panic(fmt.Sprintf("unsupported option %T", opt)) + } + } + + return h.checkTicket(h) +} + +func (h handler) okSpack(header http.Header) bool { + if h.streamFormat != headers.TypeApplicationXSolomonSpack { + return false + } + for _, header := range header[headers.AcceptKey] { + types, err := headers.ParseAccept(header) + if err != nil { + h.logger.Warn("Can't parse accept header", log.Error(err), log.String("header", header)) + continue + } + for _, acceptableType := range types { + if acceptableType.Type == headers.TypeApplicationXSolomonSpack { + return true + } + } + } + return false +} + +func (h handler) okLZ4Compression(header http.Header) bool { + for _, header := range header[headers.AcceptEncodingKey] { + encodings, err := headers.ParseAcceptEncoding(header) + if err != nil { + h.logger.Warn("Can't parse accept-encoding header", log.Error(err), log.String("header", header)) + continue + } + for _, acceptableEncoding := range encodings { + if acceptableEncoding.Encoding == headers.EncodingLZ4 { + return true + } + } + } + return false +} + +func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if h.okSpack(r.Header) { + compression := solomon.CompressionNone + if h.okLZ4Compression(r.Header) { + compression = solomon.CompressionLz4 + } + w.Header().Set(headers.ContentTypeKey, headers.TypeApplicationXSolomonSpack.String()) + _, err := h.registry.StreamSpack(r.Context(), w, compression) + if err != nil { + h.logger.Error("Failed to write compressed spack", log.Error(err)) + } + return + } + + w.Header().Set(headers.ContentTypeKey, headers.TypeApplicationJSON.String()) + _, err := h.registry.StreamJSON(r.Context(), w) + if err != nil { + h.logger.Error("Failed to write json", log.Error(err)) + } +} diff --git a/library/go/yandex/solomon/reporters/puller/httppuller/logger.go b/library/go/yandex/solomon/reporters/puller/httppuller/logger.go new file mode 100644 index 0000000000..c8cf242aae --- /dev/null +++ b/library/go/yandex/solomon/reporters/puller/httppuller/logger.go @@ -0,0 +1,15 @@ +package httppuller + +import "a.yandex-team.ru/library/go/core/log" + +type loggerOption struct { + logger log.Logger +} + +func (*loggerOption) isOption() {} + +func WithLogger(logger log.Logger) Option { + return &loggerOption{ + logger: logger, + } +} diff --git a/library/go/yandex/solomon/reporters/puller/httppuller/spack.go b/library/go/yandex/solomon/reporters/puller/httppuller/spack.go new file mode 100644 index 0000000000..cf59abd52a --- /dev/null +++ b/library/go/yandex/solomon/reporters/puller/httppuller/spack.go @@ -0,0 +1,10 @@ +package httppuller + +type spackOption struct { +} + +func (*spackOption) isOption() {} + +func WithSpack() Option { + return &spackOption{} +} diff --git a/library/go/yandex/solomon/reporters/puller/httppuller/tvm.go b/library/go/yandex/solomon/reporters/puller/httppuller/tvm.go new file mode 100644 index 0000000000..2b842bcc20 --- /dev/null +++ b/library/go/yandex/solomon/reporters/puller/httppuller/tvm.go @@ -0,0 +1,27 @@ +package httppuller + +import "a.yandex-team.ru/library/go/yandex/tvm" + +const ( + FetcherPreTVMID = 2012024 + FetcherTestTVMID = 2012026 + FetcherProdTVMID = 2012028 +) + +var ( + AllFetchers = []tvm.ClientID{ + FetcherPreTVMID, + FetcherTestTVMID, + FetcherProdTVMID, + } +) + +type tvmOption struct { + client tvm.Client +} + +func (*tvmOption) isOption() {} + +func WithTVM(tvm tvm.Client) Option { + return &tvmOption{client: tvm} +} diff --git a/library/go/yandex/tvm/cachedtvm/cache.go b/library/go/yandex/tvm/cachedtvm/cache.go new file mode 100644 index 0000000000..a04e2baf8a --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/cache.go @@ -0,0 +1,22 @@ +package cachedtvm + +import ( + "time" + + "github.com/karlseguin/ccache/v2" +) + +type cache struct { + *ccache.Cache + ttl time.Duration +} + +func (c *cache) Fetch(key string, fn func() (interface{}, error)) (*ccache.Item, error) { + return c.Cache.Fetch(key, c.ttl, fn) +} + +func (c *cache) Stop() { + if c.Cache != nil { + c.Cache.Stop() + } +} diff --git a/library/go/yandex/tvm/cachedtvm/client.go b/library/go/yandex/tvm/cachedtvm/client.go new file mode 100644 index 0000000000..503c973e8c --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/client.go @@ -0,0 +1,117 @@ +package cachedtvm + +import ( + "context" + "fmt" + "time" + + "github.com/karlseguin/ccache/v2" + + "a.yandex-team.ru/library/go/yandex/tvm" +) + +const ( + DefaultTTL = 1 * time.Minute + DefaultMaxItems = 100 + MaxServiceTicketTTL = 5 * time.Minute + MaxUserTicketTTL = 1 * time.Minute +) + +type CachedClient struct { + tvm.Client + serviceTicketCache cache + userTicketCache cache + userTicketFn func(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) +} + +func NewClient(tvmClient tvm.Client, opts ...Option) (*CachedClient, error) { + newCache := func(o cacheOptions) cache { + return cache{ + Cache: ccache.New( + ccache.Configure().MaxSize(o.maxItems), + ), + ttl: o.ttl, + } + } + + out := &CachedClient{ + Client: tvmClient, + serviceTicketCache: newCache(cacheOptions{ + ttl: DefaultTTL, + maxItems: DefaultMaxItems, + }), + userTicketFn: tvmClient.CheckUserTicket, + } + + for _, opt := range opts { + switch o := opt.(type) { + case OptionServiceTicket: + if o.ttl > MaxServiceTicketTTL { + return nil, fmt.Errorf("maximum TTL for check service ticket exceed: %s > %s", o.ttl, MaxServiceTicketTTL) + } + + out.serviceTicketCache = newCache(o.cacheOptions) + case OptionUserTicket: + if o.ttl > MaxUserTicketTTL { + return nil, fmt.Errorf("maximum TTL for check user ticket exceed: %s > %s", o.ttl, MaxUserTicketTTL) + } + + out.userTicketFn = out.cacheCheckUserTicket + out.userTicketCache = newCache(o.cacheOptions) + default: + panic(fmt.Sprintf("unexpected cache option: %T", o)) + } + } + + return out, nil +} + +func (c *CachedClient) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + out, err := c.serviceTicketCache.Fetch(ticket, func() (interface{}, error) { + return c.Client.CheckServiceTicket(ctx, ticket) + }) + + if err != nil { + return nil, err + } + + return out.Value().(*tvm.CheckedServiceTicket), nil +} + +func (c *CachedClient) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + return c.userTicketFn(ctx, ticket, opts...) +} + +func (c *CachedClient) cacheCheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + cacheKey := func(ticket string, opts ...tvm.CheckUserTicketOption) string { + if len(opts) == 0 { + return ticket + } + + var options tvm.CheckUserTicketOptions + for _, opt := range opts { + opt(&options) + } + + if options.EnvOverride == nil { + return ticket + } + + return fmt.Sprintf("%d:%s", *options.EnvOverride, ticket) + } + + out, err := c.userTicketCache.Fetch(cacheKey(ticket, opts...), func() (interface{}, error) { + return c.Client.CheckUserTicket(ctx, ticket, opts...) + }) + + if err != nil { + return nil, err + } + + return out.Value().(*tvm.CheckedUserTicket), nil +} + +func (c *CachedClient) Close() { + c.serviceTicketCache.Stop() + c.userTicketCache.Stop() +} diff --git a/library/go/yandex/tvm/cachedtvm/client_example_test.go b/library/go/yandex/tvm/cachedtvm/client_example_test.go new file mode 100644 index 0000000000..a95b1674a3 --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/client_example_test.go @@ -0,0 +1,40 @@ +package cachedtvm_test + +import ( + "context" + "fmt" + "time" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm/cachedtvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewClient_checkServiceTicket() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewAnyClient(tvmtool.WithLogger(zlog)) + if err != nil { + panic(err) + } + + cachedTvmClient, err := cachedtvm.NewClient( + tvmClient, + cachedtvm.WithCheckServiceTicket(1*time.Minute, 1000), + ) + if err != nil { + panic(err) + } + defer cachedTvmClient.Close() + + ticketInfo, err := cachedTvmClient.CheckServiceTicket(context.TODO(), "3:serv:....") + if err != nil { + panic(err) + } + + fmt.Println("ticket info: ", ticketInfo.LogInfo) +} diff --git a/library/go/yandex/tvm/cachedtvm/client_test.go b/library/go/yandex/tvm/cachedtvm/client_test.go new file mode 100644 index 0000000000..a3c3081e30 --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/client_test.go @@ -0,0 +1,195 @@ +package cachedtvm_test + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/cachedtvm" +) + +const ( + checkPasses = 5 +) + +type mockTvmClient struct { + tvm.Client + checkServiceTicketCalls int + checkUserTicketCalls int +} + +func (c *mockTvmClient) CheckServiceTicket(_ context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + defer func() { c.checkServiceTicketCalls++ }() + + return &tvm.CheckedServiceTicket{ + LogInfo: ticket, + IssuerUID: tvm.UID(c.checkServiceTicketCalls), + }, nil +} + +func (c *mockTvmClient) CheckUserTicket(_ context.Context, ticket string, _ ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + defer func() { c.checkUserTicketCalls++ }() + + return &tvm.CheckedUserTicket{ + LogInfo: ticket, + DefaultUID: tvm.UID(c.checkUserTicketCalls), + }, nil +} + +func (c *mockTvmClient) GetServiceTicketForAlias(_ context.Context, alias string) (string, error) { + return alias, nil +} + +func checkServiceTickets(t *testing.T, client tvm.Client, equal bool) { + var prev *tvm.CheckedServiceTicket + for i := 0; i < checkPasses; i++ { + t.Run(strconv.Itoa(i), func(t *testing.T) { + cur, err := client.CheckServiceTicket(context.Background(), "3:serv:tst") + require.NoError(t, err) + + if prev == nil { + return + } + + if equal { + require.Equal(t, *prev, *cur) + } else { + require.NotEqual(t, *prev, *cur) + } + }) + } +} + +func runEqualServiceTickets(client tvm.Client) func(t *testing.T) { + return func(t *testing.T) { + checkServiceTickets(t, client, true) + } +} + +func runNotEqualServiceTickets(client tvm.Client) func(t *testing.T) { + return func(t *testing.T) { + checkServiceTickets(t, client, false) + } +} + +func checkUserTickets(t *testing.T, client tvm.Client, equal bool) { + var prev *tvm.CheckedServiceTicket + for i := 0; i < checkPasses; i++ { + t.Run(strconv.Itoa(i), func(t *testing.T) { + cur, err := client.CheckUserTicket(context.Background(), "3:user:tst") + require.NoError(t, err) + + if prev == nil { + return + } + + if equal { + require.Equal(t, *prev, *cur) + } else { + require.NotEqual(t, *prev, *cur) + } + }) + } +} + +func runEqualUserTickets(client tvm.Client) func(t *testing.T) { + return func(t *testing.T) { + checkUserTickets(t, client, true) + } +} + +func runNotEqualUserTickets(client tvm.Client) func(t *testing.T) { + return func(t *testing.T) { + checkUserTickets(t, client, false) + } +} +func TestDefaultBehavior(t *testing.T) { + nestedClient := &mockTvmClient{} + client, err := cachedtvm.NewClient(nestedClient) + require.NoError(t, err) + + t.Run("first_pass_srv", runEqualServiceTickets(client)) + t.Run("first_pass_usr", runNotEqualUserTickets(client)) + + require.Equal(t, 1, nestedClient.checkServiceTicketCalls) + require.Equal(t, checkPasses, nestedClient.checkUserTicketCalls) + + ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst") + require.NoError(t, err) + require.Equal(t, "tst", ticket) +} + +func TestCheckServiceTicket(t *testing.T) { + nestedClient := &mockTvmClient{} + client, err := cachedtvm.NewClient(nestedClient, cachedtvm.WithCheckServiceTicket(10*time.Second, 10)) + require.NoError(t, err) + + t.Run("first_pass_srv", runEqualServiceTickets(client)) + t.Run("first_pass_usr", runNotEqualUserTickets(client)) + time.Sleep(20 * time.Second) + t.Run("second_pass_srv", runEqualServiceTickets(client)) + t.Run("second_pass_usr", runNotEqualUserTickets(client)) + + require.Equal(t, 2, nestedClient.checkServiceTicketCalls) + require.Equal(t, 2*checkPasses, nestedClient.checkUserTicketCalls) + + ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst") + require.NoError(t, err) + require.Equal(t, "tst", ticket) +} + +func TestCheckUserTicket(t *testing.T) { + nestedClient := &mockTvmClient{} + client, err := cachedtvm.NewClient(nestedClient, cachedtvm.WithCheckUserTicket(10*time.Second, 10)) + require.NoError(t, err) + + t.Run("first_pass_usr", runEqualUserTickets(client)) + time.Sleep(20 * time.Second) + t.Run("second_pass_usr", runEqualUserTickets(client)) + require.Equal(t, 2, nestedClient.checkUserTicketCalls) + + ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst") + require.NoError(t, err) + require.Equal(t, "tst", ticket) +} + +func TestCheckServiceAndUserTicket(t *testing.T) { + nestedClient := &mockTvmClient{} + client, err := cachedtvm.NewClient(nestedClient, + cachedtvm.WithCheckServiceTicket(10*time.Second, 10), + cachedtvm.WithCheckUserTicket(10*time.Second, 10), + ) + require.NoError(t, err) + + t.Run("first_pass_srv", runEqualServiceTickets(client)) + t.Run("first_pass_usr", runEqualUserTickets(client)) + time.Sleep(20 * time.Second) + t.Run("second_pass_srv", runEqualServiceTickets(client)) + t.Run("second_pass_usr", runEqualUserTickets(client)) + + require.Equal(t, 2, nestedClient.checkUserTicketCalls) + require.Equal(t, 2, nestedClient.checkServiceTicketCalls) + + ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst") + require.NoError(t, err) + require.Equal(t, "tst", ticket) +} + +func TestErrors(t *testing.T) { + cases := []cachedtvm.Option{ + cachedtvm.WithCheckServiceTicket(12*time.Hour, 1), + cachedtvm.WithCheckUserTicket(30*time.Minute, 1), + } + + nestedClient := &mockTvmClient{} + for i, tc := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + _, err := cachedtvm.NewClient(nestedClient, tc) + require.Error(t, err) + }) + } +} diff --git a/library/go/yandex/tvm/cachedtvm/opts.go b/library/go/yandex/tvm/cachedtvm/opts.go new file mode 100644 index 0000000000..0df9dfa89e --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/opts.go @@ -0,0 +1,40 @@ +package cachedtvm + +import "time" + +type ( + Option interface{ isCachedOption() } + + cacheOptions struct { + ttl time.Duration + maxItems int64 + } + + OptionServiceTicket struct { + Option + cacheOptions + } + + OptionUserTicket struct { + Option + cacheOptions + } +) + +func WithCheckServiceTicket(ttl time.Duration, maxSize int) Option { + return OptionServiceTicket{ + cacheOptions: cacheOptions{ + ttl: ttl, + maxItems: int64(maxSize), + }, + } +} + +func WithCheckUserTicket(ttl time.Duration, maxSize int) Option { + return OptionUserTicket{ + cacheOptions: cacheOptions{ + ttl: ttl, + maxItems: int64(maxSize), + }, + } +} diff --git a/library/go/yandex/tvm/client.go b/library/go/yandex/tvm/client.go new file mode 100644 index 0000000000..2a969fb1c6 --- /dev/null +++ b/library/go/yandex/tvm/client.go @@ -0,0 +1,56 @@ +package tvm + +//go:generate ya tool mockgen -source=$GOFILE -destination=mocks/tvm.gen.go Client + +import ( + "context" + "fmt" +) + +type ClientStatus int + +// This constants must be in sync with EStatus from library/cpp/tvmauth/client/client_status.h +const ( + ClientOK ClientStatus = iota + ClientWarning + ClientError +) + +func (s ClientStatus) String() string { + switch s { + case ClientOK: + return "OK" + case ClientWarning: + return "Warning" + case ClientError: + return "Error" + default: + return fmt.Sprintf("Unknown%d", s) + } +} + +type ClientStatusInfo struct { + Status ClientStatus + + // This message allows to trigger alert with useful message + // It returns "OK" if Status==Ok + LastError string +} + +// Client allows to use aliases for ClientID. +// +// Alias is local label for ClientID which can be used to avoid this number in every checking case in code. +type Client interface { + GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) + GetServiceTicketForID(ctx context.Context, dstID ClientID) (string, error) + + // CheckServiceTicket returns struct with SrcID: you should check it by yourself with ACL + CheckServiceTicket(ctx context.Context, ticket string) (*CheckedServiceTicket, error) + CheckUserTicket(ctx context.Context, ticket string, opts ...CheckUserTicketOption) (*CheckedUserTicket, error) + GetRoles(ctx context.Context) (*Roles, error) + + // GetStatus returns current status of client: + // * you should trigger your monitoring if status is not Ok + // * it will be unable to operate if status is Invalid + GetStatus(ctx context.Context) (ClientStatusInfo, error) +} diff --git a/library/go/yandex/tvm/context.go b/library/go/yandex/tvm/context.go new file mode 100644 index 0000000000..3a30dbb0b6 --- /dev/null +++ b/library/go/yandex/tvm/context.go @@ -0,0 +1,33 @@ +package tvm + +import "context" + +type ( + serviceTicketContextKey struct{} + userTicketContextKey struct{} +) + +var ( + stKey serviceTicketContextKey + utKey userTicketContextKey +) + +// WithServiceTicket returns copy of the ctx with service ticket attached to it. +func WithServiceTicket(ctx context.Context, t *CheckedServiceTicket) context.Context { + return context.WithValue(ctx, &stKey, t) +} + +// WithUserTicket returns copy of the ctx with user ticket attached to it. +func WithUserTicket(ctx context.Context, t *CheckedUserTicket) context.Context { + return context.WithValue(ctx, &utKey, t) +} + +func ContextServiceTicket(ctx context.Context) (t *CheckedServiceTicket) { + t, _ = ctx.Value(&stKey).(*CheckedServiceTicket) + return +} + +func ContextUserTicket(ctx context.Context) (t *CheckedUserTicket) { + t, _ = ctx.Value(&utKey).(*CheckedUserTicket) + return +} diff --git a/library/go/yandex/tvm/errors.go b/library/go/yandex/tvm/errors.go new file mode 100644 index 0000000000..bd511d05f3 --- /dev/null +++ b/library/go/yandex/tvm/errors.go @@ -0,0 +1,107 @@ +package tvm + +import ( + "errors" + "fmt" +) + +// ErrNotSupported - error to be used within cgo disabled builds. +var ErrNotSupported = errors.New("ticket_parser2 is not available when building with -DCGO_ENABLED=0") + +var ( + ErrTicketExpired = &TicketError{Status: TicketExpired} + ErrTicketInvalidBlackboxEnv = &TicketError{Status: TicketInvalidBlackboxEnv} + ErrTicketInvalidDst = &TicketError{Status: TicketInvalidDst} + ErrTicketInvalidTicketType = &TicketError{Status: TicketInvalidTicketType} + ErrTicketMalformed = &TicketError{Status: TicketMalformed} + ErrTicketMissingKey = &TicketError{Status: TicketMissingKey} + ErrTicketSignBroken = &TicketError{Status: TicketSignBroken} + ErrTicketUnsupportedVersion = &TicketError{Status: TicketUnsupportedVersion} + ErrTicketStatusOther = &TicketError{Status: TicketStatusOther} + ErrTicketInvalidScopes = &TicketError{Status: TicketInvalidScopes} + ErrTicketInvalidSrcID = &TicketError{Status: TicketInvalidSrcID} +) + +type TicketError struct { + Status TicketStatus + Msg string +} + +func (e *TicketError) Is(err error) bool { + otherTickerErr, ok := err.(*TicketError) + if !ok { + return false + } + if e == nil && otherTickerErr == nil { + return true + } + if e == nil || otherTickerErr == nil { + return false + } + return e.Status == otherTickerErr.Status +} + +func (e *TicketError) Error() string { + if e.Msg != "" { + return fmt.Sprintf("tvm: invalid ticket: %s: %s", e.Status, e.Msg) + } + return fmt.Sprintf("tvm: invalid ticket: %s", e.Status) +} + +type ErrorCode int + +// This constants must be in sync with code in go/tvmauth/tvm.cpp:CatchError +const ( + ErrorOK ErrorCode = iota + ErrorMalformedSecret + ErrorMalformedKeys + ErrorEmptyKeys + ErrorNotAllowed + ErrorBrokenTvmClientSettings + ErrorMissingServiceTicket + ErrorPermissionDenied + ErrorOther + + // Go-only errors below + ErrorBadRequest + ErrorAuthFail +) + +func (e ErrorCode) String() string { + switch e { + case ErrorOK: + return "OK" + case ErrorMalformedSecret: + return "MalformedSecret" + case ErrorMalformedKeys: + return "MalformedKeys" + case ErrorEmptyKeys: + return "EmptyKeys" + case ErrorNotAllowed: + return "NotAllowed" + case ErrorBrokenTvmClientSettings: + return "BrokenTvmClientSettings" + case ErrorMissingServiceTicket: + return "MissingServiceTicket" + case ErrorPermissionDenied: + return "PermissionDenied" + case ErrorOther: + return "Other" + case ErrorBadRequest: + return "ErrorBadRequest" + case ErrorAuthFail: + return "AuthFail" + default: + return fmt.Sprintf("Unknown%d", e) + } +} + +type Error struct { + Code ErrorCode + Retriable bool + Msg string +} + +func (e *Error) Error() string { + return fmt.Sprintf("tvm: %s (code %s)", e.Msg, e.Code) +} diff --git a/library/go/yandex/tvm/examples/tvm_example_test.go b/library/go/yandex/tvm/examples/tvm_example_test.go new file mode 100644 index 0000000000..2d47502584 --- /dev/null +++ b/library/go/yandex/tvm/examples/tvm_example_test.go @@ -0,0 +1,59 @@ +package tvm_test + +import ( + "context" + "fmt" + + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmauth" +) + +func ExampleClient_alias() { + blackboxAlias := "blackbox" + + settings := tvmauth.TvmAPISettings{ + SelfID: 1000502, + ServiceTicketOptions: tvmauth.NewAliasesOptions( + "...", + map[string]tvm.ClientID{ + blackboxAlias: 1000501, + }), + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.GetServiceTicketForAlias(context.Background(), blackboxAlias) + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) +} + +func ExampleClient_roles() { + settings := tvmauth.TvmAPISettings{ + SelfID: 1000502, + ServiceTicketOptions: tvmauth.NewIDsOptions("...", nil), + FetchRolesForIdmSystemSlug: "some_idm_system", + DiskCacheDir: "...", + EnableServiceTicketChecking: true, + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.CheckServiceTicket(context.Background(), "3:serv:...") + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) + + r, err := c.GetRoles(context.Background()) + if err != nil { + panic(err) + } + fmt.Println(r.GetMeta().Revision) +} diff --git a/library/go/yandex/tvm/mocks/tvm.gen.go b/library/go/yandex/tvm/mocks/tvm.gen.go new file mode 100644 index 0000000000..9f56f65fec --- /dev/null +++ b/library/go/yandex/tvm/mocks/tvm.gen.go @@ -0,0 +1,130 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: client.go + +// Package mock_tvm is a generated GoMock package. +package mock_tvm + +import ( + tvm "a.yandex-team.ru/library/go/yandex/tvm" + context "context" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// GetServiceTicketForAlias mocks base method. +func (m *MockClient) GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServiceTicketForAlias", ctx, alias) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServiceTicketForAlias indicates an expected call of GetServiceTicketForAlias. +func (mr *MockClientMockRecorder) GetServiceTicketForAlias(ctx, alias interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTicketForAlias", reflect.TypeOf((*MockClient)(nil).GetServiceTicketForAlias), ctx, alias) +} + +// GetServiceTicketForID mocks base method. +func (m *MockClient) GetServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServiceTicketForID", ctx, dstID) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServiceTicketForID indicates an expected call of GetServiceTicketForID. +func (mr *MockClientMockRecorder) GetServiceTicketForID(ctx, dstID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTicketForID", reflect.TypeOf((*MockClient)(nil).GetServiceTicketForID), ctx, dstID) +} + +// CheckServiceTicket mocks base method. +func (m *MockClient) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CheckServiceTicket", ctx, ticket) + ret0, _ := ret[0].(*tvm.CheckedServiceTicket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CheckServiceTicket indicates an expected call of CheckServiceTicket. +func (mr *MockClientMockRecorder) CheckServiceTicket(ctx, ticket interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckServiceTicket", reflect.TypeOf((*MockClient)(nil).CheckServiceTicket), ctx, ticket) +} + +// CheckUserTicket mocks base method. +func (m *MockClient) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, ticket} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CheckUserTicket", varargs...) + ret0, _ := ret[0].(*tvm.CheckedUserTicket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CheckUserTicket indicates an expected call of CheckUserTicket. +func (mr *MockClientMockRecorder) CheckUserTicket(ctx, ticket interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, ticket}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckUserTicket", reflect.TypeOf((*MockClient)(nil).CheckUserTicket), varargs...) +} + +// GetRoles mocks base method. +func (m *MockClient) GetRoles(ctx context.Context) (*tvm.Roles, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRoles", ctx) + ret0, _ := ret[0].(*tvm.Roles) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRoles indicates an expected call of GetRoles. +func (mr *MockClientMockRecorder) GetRoles(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoles", reflect.TypeOf((*MockClient)(nil).GetRoles), ctx) +} + +// GetStatus mocks base method. +func (m *MockClient) GetStatus(ctx context.Context) (tvm.ClientStatusInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStatus", ctx) + ret0, _ := ret[0].(tvm.ClientStatusInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetStatus indicates an expected call of GetStatus. +func (mr *MockClientMockRecorder) GetStatus(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatus", reflect.TypeOf((*MockClient)(nil).GetStatus), ctx) +} diff --git a/library/go/yandex/tvm/roles.go b/library/go/yandex/tvm/roles.go new file mode 100644 index 0000000000..03c2a97af6 --- /dev/null +++ b/library/go/yandex/tvm/roles.go @@ -0,0 +1,130 @@ +package tvm + +import ( + "encoding/json" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +func (r *Roles) GetRolesForService(t *CheckedServiceTicket) *ConsumerRoles { + return r.tvmRoles[t.SrcID] +} + +func (r *Roles) GetRolesForUser(t *CheckedUserTicket, uid *UID) (*ConsumerRoles, error) { + if t.Env != BlackboxProdYateam { + return nil, xerrors.Errorf("user ticket must be from ProdYateam, got from %s", t.Env) + } + + if uid == nil { + if t.DefaultUID == 0 { + return nil, xerrors.Errorf("default uid is 0 - it cannot have any role") + } + uid = &t.DefaultUID + } else { + found := false + for _, u := range t.UIDs { + if u == *uid { + found = true + break + } + } + if !found { + return nil, xerrors.Errorf("'uid' must be in user ticket but it is not: %d", *uid) + } + } + + return r.userRoles[*uid], nil +} + +func (r *Roles) GetRaw() []byte { + return r.raw +} + +func (r *Roles) GetMeta() Meta { + return r.meta +} + +func (r *Roles) CheckServiceRole(t *CheckedServiceTicket, roleName string, opts *CheckServiceOptions) bool { + e := r.GetRolesForService(t).GetEntitiesForRole(roleName) + if e == nil { + return false + } + + if opts != nil { + if opts.Entity != nil && !e.ContainsExactEntity(opts.Entity) { + return false + } + } + + return true +} + +func (r *Roles) CheckUserRole(t *CheckedUserTicket, roleName string, opts *CheckUserOptions) (bool, error) { + var uid *UID + if opts != nil && opts.UID != 0 { + uid = &opts.UID + } + + roles, err := r.GetRolesForUser(t, uid) + if err != nil { + return false, err + } + e := roles.GetEntitiesForRole(roleName) + if e == nil { + return false, nil + } + + if opts != nil { + if opts.Entity != nil && !e.ContainsExactEntity(opts.Entity) { + return false, nil + } + } + + return true, nil +} + +func (r *ConsumerRoles) HasRole(roleName string) bool { + return r.GetEntitiesForRole(roleName) != nil +} + +func (r *ConsumerRoles) GetRoles() EntitiesByRoles { + if r == nil { + return nil + } + return r.roles +} + +func (r *ConsumerRoles) GetEntitiesForRole(roleName string) *Entities { + if r == nil { + return nil + } + return r.roles[roleName] +} + +func (r *ConsumerRoles) DebugPrint() string { + tmp := make(map[string][]Entity) + + for k, v := range r.roles { + tmp[k] = v.subtree.entities + } + + res, err := json.MarshalIndent(tmp, "", " ") + if err != nil { + panic(err) + } + return string(res) +} + +func (e *Entities) ContainsExactEntity(entity Entity) bool { + if e == nil { + return false + } + return e.subtree.containsExactEntity(entity) +} + +func (e *Entities) GetEntitiesWithAttrs(entityPart Entity) []Entity { + if e == nil { + return nil + } + return e.subtree.getEntitiesWithAttrs(entityPart) +} diff --git a/library/go/yandex/tvm/roles_entities_index.go b/library/go/yandex/tvm/roles_entities_index.go new file mode 100644 index 0000000000..488ce7fb09 --- /dev/null +++ b/library/go/yandex/tvm/roles_entities_index.go @@ -0,0 +1,73 @@ +package tvm + +import "sort" + +type entityAttribute struct { + key string + value string +} + +// subTree provides index for fast entity lookup with attributes +// +// or some subset of entity attributes +type subTree struct { + // entities contains entities with attributes from previous branches of tree: + // * root subTree contains all entities + // * next subTree contains entities with {"key#X": "value#X"} + // * next subTree after next contains entities with {"key#X": "value#X", "key#Y": "value#Y"} + // * and so on + // "key#X", "key#Y", ... - are sorted + entities []Entity + // entityLengths provides O(1) for exact entity lookup + entityLengths map[int]interface{} + // entityIds is creation-time crutch + entityIds []int + idxByAttrs *idxByAttrs +} + +type idxByAttrs = map[entityAttribute]*subTree + +func (s *subTree) containsExactEntity(entity Entity) bool { + subtree := s.findSubTree(entity) + if subtree == nil { + return false + } + + _, ok := subtree.entityLengths[len(entity)] + return ok +} + +func (s *subTree) getEntitiesWithAttrs(entityPart Entity) []Entity { + subtree := s.findSubTree(entityPart) + if subtree == nil { + return nil + } + + return subtree.entities +} + +func (s *subTree) findSubTree(e Entity) *subTree { + keys := make([]string, 0, len(e)) + for k := range e { + keys = append(keys, k) + } + sort.Strings(keys) + + res := s + + for _, k := range keys { + if res.idxByAttrs == nil { + return nil + } + + kv := entityAttribute{key: k, value: e[k]} + ok := false + + res, ok = (*res.idxByAttrs)[kv] + if !ok { + return nil + } + } + + return res +} diff --git a/library/go/yandex/tvm/roles_entities_index_builder.go b/library/go/yandex/tvm/roles_entities_index_builder.go new file mode 100644 index 0000000000..20bde16a00 --- /dev/null +++ b/library/go/yandex/tvm/roles_entities_index_builder.go @@ -0,0 +1,117 @@ +package tvm + +import "sort" + +type stages struct { + keys []string + id uint64 +} + +func createStages(keys []string) stages { + return stages{ + keys: keys, + } +} + +func (s *stages) getNextStage(keys *[]string) bool { + s.id += 1 + *keys = (*keys)[:0] + + for idx := range s.keys { + need := (s.id >> idx) & 0x01 + if need == 1 { + *keys = append(*keys, s.keys[idx]) + } + } + + return len(*keys) > 0 +} + +func buildEntities(entities []Entity) *Entities { + root := make(idxByAttrs) + res := &Entities{ + subtree: subTree{ + idxByAttrs: &root, + }, + } + + stage := createStages(getUniqueSortedKeys(entities)) + + keySet := make([]string, 0, len(stage.keys)) + for stage.getNextStage(&keySet) { + for entityID, entity := range entities { + currentBranch := &res.subtree + + for _, key := range keySet { + entValue, ok := entity[key] + if !ok { + continue + } + + if currentBranch.idxByAttrs == nil { + index := make(idxByAttrs) + currentBranch.idxByAttrs = &index + } + + kv := entityAttribute{key: key, value: entValue} + subtree, ok := (*currentBranch.idxByAttrs)[kv] + if !ok { + subtree = &subTree{} + (*currentBranch.idxByAttrs)[kv] = subtree + } + + currentBranch = subtree + currentBranch.entityIds = append(currentBranch.entityIds, entityID) + res.subtree.entityIds = append(res.subtree.entityIds, entityID) + } + } + } + + postProcessSubTree(&res.subtree, entities) + + return res +} + +func postProcessSubTree(sub *subTree, entities []Entity) { + tmp := make(map[int]interface{}, len(entities)) + for _, e := range sub.entityIds { + tmp[e] = nil + } + sub.entityIds = sub.entityIds[:0] + for i := range tmp { + sub.entityIds = append(sub.entityIds, i) + } + sort.Ints(sub.entityIds) + + sub.entities = make([]Entity, 0, len(sub.entityIds)) + sub.entityLengths = make(map[int]interface{}) + for _, idx := range sub.entityIds { + sub.entities = append(sub.entities, entities[idx]) + sub.entityLengths[len(entities[idx])] = nil + } + sub.entityIds = nil + + if sub.idxByAttrs != nil { + for _, rest := range *sub.idxByAttrs { + postProcessSubTree(rest, entities) + } + } +} + +func getUniqueSortedKeys(entities []Entity) []string { + tmp := map[string]interface{}{} + + for _, e := range entities { + for k := range e { + tmp[k] = nil + } + } + + res := make([]string, 0, len(tmp)) + for k := range tmp { + res = append(res, k) + } + + sort.Strings(res) + return res +} diff --git a/library/go/yandex/tvm/roles_entities_index_builder_test.go b/library/go/yandex/tvm/roles_entities_index_builder_test.go new file mode 100644 index 0000000000..dd795369d5 --- /dev/null +++ b/library/go/yandex/tvm/roles_entities_index_builder_test.go @@ -0,0 +1,259 @@ +package tvm + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRolesGetNextStage(t *testing.T) { + s := createStages([]string{"key#1", "key#2", "key#3", "key#4"}) + + results := [][]string{ + {"key#1"}, + {"key#2"}, + {"key#1", "key#2"}, + {"key#3"}, + {"key#1", "key#3"}, + {"key#2", "key#3"}, + {"key#1", "key#2", "key#3"}, + {"key#4"}, + {"key#1", "key#4"}, + {"key#2", "key#4"}, + {"key#1", "key#2", "key#4"}, + {"key#3", "key#4"}, + {"key#1", "key#3", "key#4"}, + {"key#2", "key#3", "key#4"}, + {"key#1", "key#2", "key#3", "key#4"}, + } + + keySet := make([]string, 0) + for idx, exp := range results { + s.getNextStage(&keySet) + require.Equal(t, exp, keySet, idx) + } + + // require.False(t, s.getNextStage(&keySet)) +} + +func TestRolesBuildEntities(t *testing.T) { + type TestCase struct { + in []Entity + out Entities + } + cases := []TestCase{ + { + in: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + out: Entities{subtree: subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + entityLengths: map[int]interface{}{1: nil, 2: nil, 3: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#1", value: "value#1"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{1: nil, 3: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#2", value: "value#2"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + }, + }, + }, + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + }, + }, + }, + entityAttribute{key: "key#1", value: "value#2"}: &subTree{ + entities: []Entity{ + {"key#1": "value#2", "key#2": "value#2"}, + }, + entityLengths: map[int]interface{}{2: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#2", value: "value#2"}: &subTree{ + entities: []Entity{ + {"key#1": "value#2", "key#2": "value#2"}, + }, + entityLengths: map[int]interface{}{2: nil}, + }, + }, + }, + entityAttribute{key: "key#2", value: "value#2"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + }, + entityLengths: map[int]interface{}{2: nil, 3: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + }, + }, + }, + entityAttribute{key: "key#3", value: "value#3"}: &subTree{ + entities: []Entity{ + {"key#3": "value#3"}, + }, + entityLengths: map[int]interface{}{1: nil}, + }, + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + }, + }, + }}, + }, + } + + for idx, c := range cases { + require.Equal(t, c.out, *buildEntities(c.in), idx) + } +} + +func TestRolesPostProcessSubTree(t *testing.T) { + type TestCase struct { + in subTree + out subTree + } + + cases := []TestCase{ + { + in: subTree{ + entityIds: []int{1, 1, 1, 1, 1, 2, 0, 0, 0}, + }, + out: subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + entityLengths: map[int]interface{}{1: nil, 2: nil}, + }, + }, + { + in: subTree{ + entityIds: []int{1, 0}, + entityLengths: map[int]interface{}{1: nil, 2: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#1", value: "value#1"}: &subTree{ + entityIds: []int{2, 0, 0}, + }, + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entityIds: []int{0, 0, 0}, + }, + }, + }, + out: subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2", "key#2": "value#2"}, + }, + entityLengths: map[int]interface{}{1: nil, 2: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#1", value: "value#1"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#3": "value#3"}, + }, + entityLengths: map[int]interface{}{1: nil}, + }, + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + }, + entityLengths: map[int]interface{}{1: nil}, + }, + }, + }, + }, + } + + entities := []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + } + + for idx, c := range cases { + postProcessSubTree(&c.in, entities) + require.Equal(t, c.out, c.in, idx) + } +} + +func TestRolesGetUniqueSortedKeys(t *testing.T) { + type TestCase struct { + in []Entity + out []string + } + + cases := []TestCase{ + { + in: nil, + out: []string{}, + }, + { + in: []Entity{}, + out: []string{}, + }, + { + in: []Entity{ + {}, + }, + out: []string{}, + }, + { + in: []Entity{ + {"key#1": "value#1"}, + {}, + }, + out: []string{"key#1"}, + }, + { + in: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2"}, + }, + out: []string{"key#1"}, + }, + { + in: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + out: []string{"key#1", "key#2", "key#3"}, + }, + } + + for idx, c := range cases { + require.Equal(t, c.out, getUniqueSortedKeys(c.in), idx) + } +} diff --git a/library/go/yandex/tvm/roles_entities_index_test.go b/library/go/yandex/tvm/roles_entities_index_test.go new file mode 100644 index 0000000000..e1abaa0f0e --- /dev/null +++ b/library/go/yandex/tvm/roles_entities_index_test.go @@ -0,0 +1,113 @@ +package tvm + +import ( + "math/rand" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRolesSubTreeContainsExactEntity(t *testing.T) { + origEntities := []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#1", "key#2": "value#2"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + } + entities := buildEntities(origEntities) + + for _, e := range generatedRandEntities() { + found := false + for _, o := range origEntities { + if reflect.DeepEqual(e, o) { + found = true + break + } + } + + require.Equal(t, found, entities.subtree.containsExactEntity(e), e) + } +} + +func generatedRandEntities() []Entity { + rand.Seed(time.Now().UnixNano()) + + keysStages := createStages([]string{"key#1", "key#2", "key#3", "key#4", "key#5"}) + valuesSet := []string{"value#1", "value#2", "value#3", "value#4", "value#5"} + + res := make([]Entity, 0) + + keySet := make([]string, 0, 5) + for keysStages.getNextStage(&keySet) { + entity := Entity{} + for _, key := range keySet { + entity[key] = valuesSet[rand.Intn(len(valuesSet))] + + e := Entity{} + for k, v := range entity { + e[k] = v + } + res = append(res, e) + } + } + + return res +} + +func TestRolesGetEntitiesWithAttrs(t *testing.T) { + type TestCase struct { + in Entity + out []Entity + } + + cases := []TestCase{ + { + out: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + }, + { + in: Entity{"key#1": "value#1"}, + out: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + }, + { + in: Entity{"key#1": "value#2"}, + out: []Entity{ + {"key#1": "value#2", "key#2": "value#2"}, + }, + }, + { + in: Entity{"key#2": "value#2"}, + out: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + }, + }, + { + in: Entity{"key#3": "value#3"}, + out: []Entity{ + {"key#3": "value#3"}, + }, + }, + } + + entities := buildEntities([]Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }) + + for idx, c := range cases { + require.Equal(t, c.out, entities.subtree.getEntitiesWithAttrs(c.in), idx) + } +} diff --git a/library/go/yandex/tvm/roles_opts.go b/library/go/yandex/tvm/roles_opts.go new file mode 100644 index 0000000000..8e0a0e0608 --- /dev/null +++ b/library/go/yandex/tvm/roles_opts.go @@ -0,0 +1,10 @@ +package tvm + +type CheckServiceOptions struct { + Entity Entity +} + +type CheckUserOptions struct { + Entity Entity + UID UID +} diff --git a/library/go/yandex/tvm/roles_parser.go b/library/go/yandex/tvm/roles_parser.go new file mode 100644 index 0000000000..f46c6b99b0 --- /dev/null +++ b/library/go/yandex/tvm/roles_parser.go @@ -0,0 +1,67 @@ +package tvm + +import ( + "encoding/json" + "strconv" + "time" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +type rawRoles struct { + Revision string `json:"revision"` + BornDate int64 `json:"born_date"` + Tvm rawConsumers `json:"tvm"` + User rawConsumers `json:"user"` +} + +type rawConsumers = map[string]rawConsumerRoles +type rawConsumerRoles = map[string][]Entity + +func NewRoles(buf []byte) (*Roles, error) { + var raw rawRoles + if err := json.Unmarshal(buf, &raw); err != nil { + return nil, xerrors.Errorf("failed to parse roles: invalid json: %w", err) + } + + tvmRoles := map[ClientID]*ConsumerRoles{} + for key, value := range raw.Tvm { + id, err := strconv.ParseUint(key, 10, 32) + if err != nil { + return nil, xerrors.Errorf("failed to parse roles: invalid tvmid '%s': %w", key, err) + } + tvmRoles[ClientID(id)] = buildConsumerRoles(value) + } + + userRoles := map[UID]*ConsumerRoles{} + for key, value := range raw.User { + id, err := strconv.ParseUint(key, 10, 64) + if err != nil { + return nil, xerrors.Errorf("failed to parse roles: invalid UID '%s': %w", key, err) + } + userRoles[UID(id)] = buildConsumerRoles(value) + } + + return &Roles{ + tvmRoles: tvmRoles, + userRoles: userRoles, + raw: buf, + meta: Meta{ + Revision: raw.Revision, + BornTime: time.Unix(raw.BornDate, 0), + Applied: time.Now(), + }, + }, nil +} + +func buildConsumerRoles(rawConsumerRoles rawConsumerRoles) *ConsumerRoles { + roles := &ConsumerRoles{ + roles: make(EntitiesByRoles, len(rawConsumerRoles)), + } + + for r, ents := range rawConsumerRoles { + roles.roles[r] = buildEntities(ents) + } + + return roles +} diff --git a/library/go/yandex/tvm/roles_parser_test.go b/library/go/yandex/tvm/roles_parser_test.go new file mode 100644 index 0000000000..2b27100ff0 --- /dev/null +++ b/library/go/yandex/tvm/roles_parser_test.go @@ -0,0 +1,88 @@ +package tvm + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRolesUserTicketCheckScopes(t *testing.T) { + type TestCase struct { + buf string + roles Roles + err string + } + + cases := []TestCase{ + { + buf: `{"revision":100500}`, + err: "failed to parse roles: invalid json", + }, + { + buf: `{"born_date":1612791978.42}`, + err: "failed to parse roles: invalid json", + }, + { + buf: `{"tvm":{"asd":{}}}`, + err: "failed to parse roles: invalid tvmid 'asd'", + }, + { + buf: `{"user":{"asd":{}}}`, + err: "failed to parse roles: invalid UID 'asd'", + }, + { + buf: `{"tvm":{"1120000000000493":{}}}`, + err: "failed to parse roles: invalid tvmid '1120000000000493'", + }, + { + buf: `{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"}],"/group/system/system_on/abc/role/tree_edit/":[{"scope":"/"}]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/"}]}}}`, + roles: Roles{ + tvmRoles: map[ClientID]*ConsumerRoles{ + ClientID(2012192): { + roles: EntitiesByRoles{ + "/group/system/system_on/abc/role/impersonator/": {}, + "/group/system/system_on/abc/role/tree_edit/": {}, + }, + }, + }, + userRoles: map[UID]*ConsumerRoles{ + UID(1120000000000493): { + roles: EntitiesByRoles{ + "/group/system/system_on/abc/role/roles_manage/": {}, + }, + }, + }, + raw: []byte(`{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"}],"/group/system/system_on/abc/role/tree_edit/":[{"scope":"/"}]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/"}]}}}`), + meta: Meta{ + Revision: "GYYDEMJUGBQWC", + BornTime: time.Unix(1612791978, 0), + }, + }, + }, + } + + for idx, c := range cases { + r, err := NewRoles([]byte(c.buf)) + if c.err == "" { + require.NoError(t, err, idx) + + r.meta.Applied = time.Time{} + for _, roles := range r.tvmRoles { + for _, v := range roles.roles { + v.subtree = subTree{} + } + } + for _, roles := range r.userRoles { + for _, v := range roles.roles { + v.subtree = subTree{} + } + } + + require.Equal(t, c.roles, *r, idx) + } else { + require.Error(t, err, idx) + require.Contains(t, err.Error(), c.err, idx) + } + } +} diff --git a/library/go/yandex/tvm/roles_test.go b/library/go/yandex/tvm/roles_test.go new file mode 100644 index 0000000000..d0c913984f --- /dev/null +++ b/library/go/yandex/tvm/roles_test.go @@ -0,0 +1,116 @@ +package tvm + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRolesPublicServiceTicket(t *testing.T) { + roles, err := NewRoles([]byte(`{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"}],"/group/system/system_on/abc/role/tree_edit/":[{"scope":"/"}]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/"}]}}}`)) + require.NoError(t, err) + + st := &CheckedServiceTicket{SrcID: 42} + require.Nil(t, roles.GetRolesForService(st)) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", nil)) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity: Entity{"scope": "/"}})) + + st = &CheckedServiceTicket{SrcID: 2012192} + r := roles.GetRolesForService(st) + require.NotNil(t, r) + require.EqualValues(t, + `{ + "/group/system/system_on/abc/role/impersonator/": [ + { + "scope": "/" + } + ], + "/group/system/system_on/abc/role/tree_edit/": [ + { + "scope": "/" + } + ] +}`, + r.DebugPrint(), + ) + require.Equal(t, 2, len(r.GetRoles())) + require.False(t, r.HasRole("/")) + require.True(t, r.HasRole("/group/system/system_on/abc/role/impersonator/")) + require.False(t, roles.CheckServiceRole(st, "/", nil)) + require.True(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", nil)) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity: Entity{"scope": "kek"}})) + require.True(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity{"scope": "/"}})) + require.Nil(t, r.GetEntitiesForRole("/")) + + en := r.GetEntitiesForRole("/group/system/system_on/abc/role/impersonator/") + require.NotNil(t, en) + require.False(t, en.ContainsExactEntity(Entity{"scope": "kek"})) + require.True(t, en.ContainsExactEntity(Entity{"scope": "/"})) + + require.Nil(t, en.GetEntitiesWithAttrs(Entity{"scope": "kek"})) + require.Equal(t, []Entity{{"scope": "/"}}, en.GetEntitiesWithAttrs(Entity{"scope": "/"})) +} + +func TestRolesPublicUserTicket(t *testing.T) { + roles, err := NewRoles([]byte(`{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"}],"/group/system/system_on/abc/role/tree_edit/":[{"scope":"/"}]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/"}]}}}`)) + require.NoError(t, err) + + ut := &CheckedUserTicket{DefaultUID: 42} + _, err = roles.GetRolesForUser(ut, nil) + require.EqualError(t, err, "user ticket must be from ProdYateam, got from Prod") + ut.Env = BlackboxProdYateam + + r, err := roles.GetRolesForUser(ut, nil) + require.NoError(t, err) + require.Nil(t, r) + ok, err := roles.CheckUserRole(ut, "/group/system/system_on/abc/role/impersonator/", nil) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/impersonator/", &CheckUserOptions{Entity: Entity{"scope": "/"}}) + require.NoError(t, err) + require.False(t, ok) + + ut = &CheckedUserTicket{DefaultUID: 1120000000000493, UIDs: []UID{42}, Env: BlackboxProdYateam} + r, err = roles.GetRolesForUser(ut, nil) + require.NoError(t, err) + require.NotNil(t, r) + require.EqualValues(t, + `{ + "/group/system/system_on/abc/role/roles_manage/": [ + { + "scope": "/services/meta_infra/tools/jobjira/" + }, + { + "scope": "/services/meta_edu/infrastructure/" + } + ] +}`, + r.DebugPrint(), + ) + require.Equal(t, 1, len(r.GetRoles())) + require.False(t, r.HasRole("/")) + require.True(t, r.HasRole("/group/system/system_on/abc/role/roles_manage/")) + ok, err = roles.CheckUserRole(ut, "/", nil) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", nil) + require.NoError(t, err) + require.True(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", &CheckUserOptions{Entity: Entity{"scope": "kek"}}) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", &CheckUserOptions{Entity: Entity{"scope": "/services/meta_infra/tools/jobjira/"}}) + require.NoError(t, err) + require.True(t, ok) + + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", &CheckUserOptions{UID: UID(42)}) + require.NoError(t, err) + require.False(t, ok) + + ut = &CheckedUserTicket{DefaultUID: 0, UIDs: []UID{42}, Env: BlackboxProdYateam} + _, err = roles.GetRolesForUser(ut, nil) + require.EqualError(t, err, "default uid is 0 - it cannot have any role") + uid := UID(83) + _, err = roles.GetRolesForUser(ut, &uid) + require.EqualError(t, err, "'uid' must be in user ticket but it is not: 83") +} diff --git a/library/go/yandex/tvm/roles_types.go b/library/go/yandex/tvm/roles_types.go new file mode 100644 index 0000000000..d1bfb07b3c --- /dev/null +++ b/library/go/yandex/tvm/roles_types.go @@ -0,0 +1,30 @@ +package tvm + +import ( + "time" +) + +type Roles struct { + tvmRoles map[ClientID]*ConsumerRoles + userRoles map[UID]*ConsumerRoles + raw []byte + meta Meta +} + +type Meta struct { + Revision string + BornTime time.Time + Applied time.Time +} + +type ConsumerRoles struct { + roles EntitiesByRoles +} + +type EntitiesByRoles = map[string]*Entities + +type Entities struct { + subtree subTree +} + +type Entity = map[string]string diff --git a/library/go/yandex/tvm/service_ticket.go b/library/go/yandex/tvm/service_ticket.go new file mode 100644 index 0000000000..2341ba2b17 --- /dev/null +++ b/library/go/yandex/tvm/service_ticket.go @@ -0,0 +1,50 @@ +package tvm + +import ( + "fmt" +) + +// CheckedServiceTicket is service credential +type CheckedServiceTicket struct { + // SrcID is ID of request source service. You should check SrcID by yourself with your ACL. + SrcID ClientID + // IssuerUID is UID of developer who is debuging something, so he(she) issued CheckedServiceTicket with his(her) ssh-sign: + // it is grant_type=sshkey in tvm-api + // https://wiki.yandex-team.ru/passport/tvm2/debug/#sxoditvapizakrytoeserviceticketami. + IssuerUID UID + // DbgInfo is human readable data for debug purposes + DbgInfo string + // LogInfo is safe for logging part of ticket - it can be parsed later with `tvmknife parse_ticket -t ...` + LogInfo string +} + +func (t *CheckedServiceTicket) CheckSrcID(allowedSrcIDsMap map[uint32]struct{}) error { + if len(allowedSrcIDsMap) == 0 { + return nil + } + if _, allowed := allowedSrcIDsMap[uint32(t.SrcID)]; !allowed { + return &TicketError{ + Status: TicketInvalidSrcID, + Msg: fmt.Sprintf("service ticket srcID is not in allowed srcIDs: %v (actual: %v)", allowedSrcIDsMap, t.SrcID), + } + } + return nil +} + +func (t CheckedServiceTicket) String() string { + return fmt.Sprintf("%s (%s)", t.LogInfo, t.DbgInfo) +} + +type ServiceTicketACL func(ticket *CheckedServiceTicket) error + +func AllowAllServiceTickets() ServiceTicketACL { + return func(ticket *CheckedServiceTicket) error { + return nil + } +} + +func CheckServiceTicketSrcID(allowedSrcIDs map[uint32]struct{}) ServiceTicketACL { + return func(ticket *CheckedServiceTicket) error { + return ticket.CheckSrcID(allowedSrcIDs) + } +} diff --git a/library/go/yandex/tvm/tvm.go b/library/go/yandex/tvm/tvm.go new file mode 100644 index 0000000000..663589efd5 --- /dev/null +++ b/library/go/yandex/tvm/tvm.go @@ -0,0 +1,121 @@ +// This package defines interface which provides fast and cryptographically secure authorization tickets: https://wiki.yandex-team.ru/passport/tvm2/. +// +// Encoded ticket is a valid ASCII string: [0-9a-zA-Z_-:]+. +// +// This package defines interface. All libraries should depend on this package. +// Pure Go implementations of interface is located in library/go/yandex/tvm/tvmtool. +// CGO implementation is located in library/ticket_parser2/go/ticket_parser2. +package tvm + +import ( + "fmt" + "strings" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +// ClientID represents ID of the application. Another name - TvmID. +type ClientID uint32 + +// UID represents ID of the user in Passport. +type UID uint64 + +// BlackboxEnv describes environment of Passport: https://wiki.yandex-team.ru/passport/tvm2/user-ticket/#0-opredeljaemsjasokruzhenijami +type BlackboxEnv int + +// This constants must be in sync with EBlackboxEnv from library/cpp/tvmauth/checked_user_ticket.h +const ( + BlackboxProd BlackboxEnv = iota + BlackboxTest + BlackboxProdYateam + BlackboxTestYateam + BlackboxStress +) + +func (e BlackboxEnv) String() string { + switch e { + case BlackboxProd: + return "Prod" + case BlackboxTest: + return "Test" + case BlackboxProdYateam: + return "ProdYateam" + case BlackboxTestYateam: + return "TestYateam" + case BlackboxStress: + return "Stress" + default: + return fmt.Sprintf("Unknown%d", e) + } +} + +func BlackboxEnvFromString(envStr string) (BlackboxEnv, error) { + switch strings.ToLower(envStr) { + case "prod": + return BlackboxProd, nil + case "test": + return BlackboxTest, nil + case "prodyateam", "prod_yateam": + return BlackboxProdYateam, nil + case "testyateam", "test_yateam": + return BlackboxTestYateam, nil + case "stress": + return BlackboxStress, nil + default: + return BlackboxEnv(-1), xerrors.Errorf("blackbox env is unknown: '%s'", envStr) + } +} + +type TicketStatus int + +// This constants must be in sync with EStatus from library/cpp/tvmauth/ticket_status.h +const ( + TicketOk TicketStatus = iota + TicketExpired + TicketInvalidBlackboxEnv + TicketInvalidDst + TicketInvalidTicketType + TicketMalformed + TicketMissingKey + TicketSignBroken + TicketUnsupportedVersion + TicketNoRoles + + // Go-only statuses below + TicketStatusOther + TicketInvalidScopes + TicketInvalidSrcID +) + +func (s TicketStatus) String() string { + switch s { + case TicketOk: + return "Ok" + case TicketExpired: + return "Expired" + case TicketInvalidBlackboxEnv: + return "InvalidBlackboxEnv" + case TicketInvalidDst: + return "InvalidDst" + case TicketInvalidTicketType: + return "InvalidTicketType" + case TicketMalformed: + return "Malformed" + case TicketMissingKey: + return "MissingKey" + case TicketSignBroken: + return "SignBroken" + case TicketUnsupportedVersion: + return "UnsupportedVersion" + case TicketNoRoles: + return "NoRoles" + case TicketStatusOther: + return "Other" + case TicketInvalidScopes: + return "InvalidScopes" + case TicketInvalidSrcID: + return "InvalidSrcID" + default: + return fmt.Sprintf("Unknown%d", s) + } +} diff --git a/library/go/yandex/tvm/tvm_test.go b/library/go/yandex/tvm/tvm_test.go new file mode 100644 index 0000000000..3d8f9f0532 --- /dev/null +++ b/library/go/yandex/tvm/tvm_test.go @@ -0,0 +1,246 @@ +package tvm_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/yandex/tvm" +) + +func TestUserTicketCheckScopes(t *testing.T) { + cases := map[string]struct { + ticketScopes []string + requiredScopes []string + err bool + }{ + "wo_required_scopes": { + ticketScopes: []string{"bb:sessionid"}, + requiredScopes: nil, + err: false, + }, + "multiple_scopes_0": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"bb:sessionid", "test:test"}, + err: false, + }, + "multiple_scopes_1": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"test:test", "bb:sessionid"}, + err: false, + }, + "wo_scopes": { + ticketScopes: nil, + requiredScopes: []string{"bb:sessionid"}, + err: true, + }, + "invalid_scopes": { + ticketScopes: []string{"bb:sessionid"}, + requiredScopes: []string{"test:test"}, + err: true, + }, + "not_all_scopes": { + ticketScopes: []string{"bb:sessionid", "test:test1"}, + requiredScopes: []string{"bb:sessionid", "test:test"}, + err: true, + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + ticket := tvm.CheckedUserTicket{ + Scopes: testCase.ticketScopes, + } + err := ticket.CheckScopes(testCase.requiredScopes...) + if testCase.err { + require.Error(t, err) + require.IsType(t, &tvm.TicketError{}, err) + ticketErr := err.(*tvm.TicketError) + require.Equal(t, tvm.TicketInvalidScopes, ticketErr.Status) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestUserTicketCheckScopesAny(t *testing.T) { + cases := map[string]struct { + ticketScopes []string + requiredScopes []string + err bool + }{ + "wo_required_scopes": { + ticketScopes: []string{"bb:sessionid"}, + requiredScopes: nil, + err: false, + }, + "multiple_scopes_0": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"bb:sessionid"}, + err: false, + }, + "multiple_scopes_1": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"test:test"}, + err: false, + }, + "multiple_scopes_2": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"bb:sessionid", "test:test"}, + err: false, + }, + "multiple_scopes_3": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"test:test", "bb:sessionid"}, + err: false, + }, + "wo_scopes": { + ticketScopes: nil, + requiredScopes: []string{"bb:sessionid"}, + err: true, + }, + "invalid_scopes": { + ticketScopes: []string{"bb:sessionid"}, + requiredScopes: []string{"test:test"}, + err: true, + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + ticket := tvm.CheckedUserTicket{ + Scopes: testCase.ticketScopes, + } + err := ticket.CheckScopes(testCase.requiredScopes...) + if testCase.err { + require.Error(t, err) + require.IsType(t, &tvm.TicketError{}, err) + ticketErr := err.(*tvm.TicketError) + require.Equal(t, tvm.TicketInvalidScopes, ticketErr.Status) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestServiceTicketAllowedSrcIDs(t *testing.T) { + cases := map[string]struct { + srcID uint32 + allowedSrcIDs []uint32 + err bool + }{ + "empty_allow_list_allows_any_srcID": {srcID: 162, allowedSrcIDs: []uint32{}, err: false}, + "known_src_id_is_allowed": {srcID: 42, allowedSrcIDs: []uint32{42, 100500}, err: false}, + "unknown_src_id_is_not_allowed": {srcID: 404, allowedSrcIDs: []uint32{42, 100500}, err: true}, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + ticket := tvm.CheckedServiceTicket{ + SrcID: tvm.ClientID(testCase.srcID), + } + allowedSrcIDsMap := make(map[uint32]struct{}, len(testCase.allowedSrcIDs)) + for _, allowedSrcID := range testCase.allowedSrcIDs { + allowedSrcIDsMap[allowedSrcID] = struct{}{} + } + err := ticket.CheckSrcID(allowedSrcIDsMap) + if testCase.err { + require.Error(t, err) + require.IsType(t, &tvm.TicketError{}, err) + ticketErr := err.(*tvm.TicketError) + require.Equal(t, tvm.TicketInvalidSrcID, ticketErr.Status) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestTicketError_Is(t *testing.T) { + err1 := &tvm.TicketError{ + Status: tvm.TicketInvalidSrcID, + Msg: "uh oh", + } + err2 := &tvm.TicketError{ + Status: tvm.TicketInvalidSrcID, + Msg: "uh oh", + } + err3 := &tvm.TicketError{ + Status: tvm.TicketInvalidSrcID, + Msg: "other uh oh message", + } + err4 := &tvm.TicketError{ + Status: tvm.TicketExpired, + Msg: "uh oh", + } + err5 := &tvm.TicketError{ + Status: tvm.TicketMalformed, + Msg: "i am completely different", + } + var nilErr *tvm.TicketError = nil + + // ticketErrors are equal to themselves + require.True(t, err1.Is(err1)) + require.True(t, err2.Is(err2)) + require.True(t, nilErr.Is(nilErr)) + + // equal value ticketErrors are equal + require.True(t, err1.Is(err2)) + require.True(t, err2.Is(err1)) + // equal status ticketErrors are equal + require.True(t, err1.Is(err3)) + require.True(t, err1.Is(tvm.ErrTicketInvalidSrcID)) + require.True(t, err2.Is(tvm.ErrTicketInvalidSrcID)) + require.True(t, err3.Is(tvm.ErrTicketInvalidSrcID)) + require.True(t, err4.Is(tvm.ErrTicketExpired)) + require.True(t, err5.Is(tvm.ErrTicketMalformed)) + + // different status ticketErrors are not equal + require.False(t, err1.Is(err4)) + + // completely different ticketErrors are not equal + require.False(t, err1.Is(err5)) + + // non-nil ticketErrors are not equal to nil errors + require.False(t, err1.Is(nil)) + require.False(t, err2.Is(nil)) + + // non-nil ticketErrors are not equal to nil ticketErrors + require.False(t, err1.Is(nilErr)) + require.False(t, err2.Is(nilErr)) +} + +func TestBbEnvFromString(t *testing.T) { + type Case struct { + in string + env tvm.BlackboxEnv + err string + } + cases := []Case{ + {in: "prod", env: tvm.BlackboxProd}, + {in: "Prod", env: tvm.BlackboxProd}, + {in: "ProD", env: tvm.BlackboxProd}, + {in: "PROD", env: tvm.BlackboxProd}, + {in: "test", env: tvm.BlackboxTest}, + {in: "prod_yateam", env: tvm.BlackboxProdYateam}, + {in: "ProdYateam", env: tvm.BlackboxProdYateam}, + {in: "test_yateam", env: tvm.BlackboxTestYateam}, + {in: "TestYateam", env: tvm.BlackboxTestYateam}, + {in: "stress", env: tvm.BlackboxStress}, + {in: "", err: "blackbox env is unknown: ''"}, + {in: "kek", err: "blackbox env is unknown: 'kek'"}, + } + + for idx, c := range cases { + res, err := tvm.BlackboxEnvFromString(c.in) + + if c.err == "" { + require.NoError(t, err, idx) + require.Equal(t, c.env, res, idx) + } else { + require.EqualError(t, err, c.err, idx) + } + } +} diff --git a/library/go/yandex/tvm/tvmauth/apitest/.arcignore b/library/go/yandex/tvm/tvmauth/apitest/.arcignore new file mode 100644 index 0000000000..c8a6e77006 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/apitest/.arcignore @@ -0,0 +1 @@ +apitest diff --git a/library/go/yandex/tvm/tvmauth/apitest/client_test.go b/library/go/yandex/tvm/tvmauth/apitest/client_test.go new file mode 100644 index 0000000000..8868abe473 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/apitest/client_test.go @@ -0,0 +1,243 @@ +package apitest + +import ( + "context" + "io/ioutil" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + uzap "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmauth" +) + +func apiSettings(t testing.TB, client tvm.ClientID) tvmauth.TvmAPISettings { + var portStr []byte + portStr, err := ioutil.ReadFile("tvmapi.port") + require.NoError(t, err) + + var port int + port, err = strconv.Atoi(string(portStr)) + require.NoError(t, err) + env := tvm.BlackboxProd + + if client == 1000501 { + return tvmauth.TvmAPISettings{ + SelfID: 1000501, + + EnableServiceTicketChecking: true, + BlackboxEnv: &env, + + ServiceTicketOptions: tvmauth.NewIDsOptions( + "bAicxJVa5uVY7MjDlapthw", + []tvm.ClientID{1000502}), + + TVMHost: "localhost", + TVMPort: port, + } + } else if client == 1000502 { + return tvmauth.TvmAPISettings{ + SelfID: 1000502, + + EnableServiceTicketChecking: true, + BlackboxEnv: &env, + + ServiceTicketOptions: tvmauth.NewAliasesOptions( + "e5kL0vM3nP-nPf-388Hi6Q", + map[string]tvm.ClientID{ + "cl1000501": 1000501, + "cl1000503": 1000503, + }), + + TVMHost: "localhost", + TVMPort: port, + } + } else { + t.Fatalf("Bad client id: %d", client) + return tvmauth.TvmAPISettings{} + } +} + +func TestErrorPassing(t *testing.T) { + _, err := tvmauth.NewAPIClient(tvmauth.TvmAPISettings{}, &nop.Logger{}) + require.Error(t, err) +} + +func TestGetServiceTicketForID(t *testing.T) { + c1000501, err := tvmauth.NewAPIClient(apiSettings(t, 1000501), &nop.Logger{}) + require.NoError(t, err) + defer c1000501.Destroy() + + c1000502, err := tvmauth.NewAPIClient(apiSettings(t, 1000502), &nop.Logger{}) + require.NoError(t, err) + defer c1000502.Destroy() + + ticketStr, err := c1000501.GetServiceTicketForID(context.Background(), 1000502) + require.NoError(t, err) + + ticket, err := c1000502.CheckServiceTicket(context.Background(), ticketStr) + require.NoError(t, err) + require.Equal(t, tvm.ClientID(1000501), ticket.SrcID) + + ticketStrByAlias, err := c1000501.GetServiceTicketForAlias(context.Background(), "1000502") + require.NoError(t, err) + require.Equal(t, ticketStr, ticketStrByAlias) + + _, err = c1000501.CheckServiceTicket(context.Background(), ticketStr) + require.Error(t, err) + require.IsType(t, err, &tvm.TicketError{}) + require.Equal(t, tvm.TicketInvalidDst, err.(*tvm.TicketError).Status) + + _, err = c1000501.GetServiceTicketForID(context.Background(), 127) + require.Error(t, err) + require.IsType(t, err, &tvm.Error{}) + + ticketStr, err = c1000502.GetServiceTicketForID(context.Background(), 1000501) + require.NoError(t, err) + ticketStrByAlias, err = c1000502.GetServiceTicketForAlias(context.Background(), "cl1000501") + require.NoError(t, err) + require.Equal(t, ticketStr, ticketStrByAlias) + + _, err = c1000502.GetServiceTicketForAlias(context.Background(), "1000501") + require.Error(t, err) + require.IsType(t, err, &tvm.Error{}) +} + +func TestLogger(t *testing.T) { + logger, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + require.NoError(t, err) + + core, logs := observer.New(zap.ZapifyLevel(log.DebugLevel)) + logger.L = logger.L.WithOptions(uzap.WrapCore(func(_ zapcore.Core) zapcore.Core { + return core + })) + + c1000502, err := tvmauth.NewAPIClient(apiSettings(t, 1000502), logger) + require.NoError(t, err) + defer c1000502.Destroy() + + loggedEntries := logs.AllUntimed() + for idx := 0; len(loggedEntries) < 7 && idx < 250; idx++ { + time.Sleep(100 * time.Millisecond) + loggedEntries = logs.AllUntimed() + } + + var plainLog string + for _, le := range loggedEntries { + plainLog += le.Message + "\n" + } + + require.Contains( + t, + plainLog, + "Thread-worker started") +} + +func BenchmarkServiceTicket(b *testing.B) { + c1000501, err := tvmauth.NewAPIClient(apiSettings(b, 1000501), &nop.Logger{}) + require.NoError(b, err) + defer c1000501.Destroy() + + c1000502, err := tvmauth.NewAPIClient(apiSettings(b, 1000502), &nop.Logger{}) + require.NoError(b, err) + defer c1000502.Destroy() + + b.Run("GetServiceTicketForID", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := c1000501.GetServiceTicketForID(context.Background(), 1000502) + require.NoError(b, err) + } + }) + }) + + ticketStr, err := c1000501.GetServiceTicketForID(context.Background(), 1000502) + require.NoError(b, err) + + b.Run("CheckServiceTicket", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := c1000502.CheckServiceTicket(context.Background(), ticketStr) + require.NoError(b, err) + } + }) + }) +} + +const serviceTicketStr = "3:serv:CBAQ__________9_IggIlJEGELaIPQ:KC8zKTnoM7GQ8UkBixoAlDt7CAuNIO_6J4rzeqelj7wn7vCKBfsy1jlg2UIvBw0JKUUc6116s5aBw1-vr4BD1V0eh0z-k_CSGC4DKKlnBEEAwcpHRjOZUdW_5UJFe-l77KMObvZUPLckWUaQKybMSBYDGrAeo1TqHHmkumwSG5s" +const userTicketStr = "3:user:CAsQ__________9_GikKAgh7CgMIyAMQyAMaBmJiOmtlaxoLc29tZTpzY29wZXMg0oXYzAQoAA:LPpzn2ILhY1BHXA1a51mtU1emb2QSMH3UhTxsmL07iJ7m2AMc2xloXCKQOI7uK6JuLDf7aSWd9QQJpaRV0mfPzvFTnz2j78hvO3bY8KT_TshA3A-M5-t5gip8CfTVGPmEPwnuUhmKqAGkGSL-sCHyu1RIjHkGbJA250ThHHKgAY" + +func TestDebugInfo(t *testing.T) { + c1000502, err := tvmauth.NewAPIClient(apiSettings(t, 1000502), &nop.Logger{}) + require.NoError(t, err) + defer c1000502.Destroy() + + ticketS, err := c1000502.CheckServiceTicket(context.Background(), serviceTicketStr) + require.NoError(t, err) + require.Equal(t, tvm.ClientID(100500), ticketS.SrcID) + require.Equal(t, tvm.UID(0), ticketS.IssuerUID) + require.Equal(t, "ticket_type=serv;expiration_time=9223372036854775807;src=100500;dst=1000502;", ticketS.DbgInfo) + require.Equal(t, "3:serv:CBAQ__________9_IggIlJEGELaIPQ:", ticketS.LogInfo) + + ticketS, err = c1000502.CheckServiceTicket(context.Background(), serviceTicketStr[:len(serviceTicketStr)-1]) + require.Error(t, err) + require.IsType(t, err, &tvm.TicketError{}) + require.Equal(t, err.(*tvm.TicketError).Status, tvm.TicketSignBroken) + require.Equal(t, "ticket_type=serv;expiration_time=9223372036854775807;src=100500;dst=1000502;", ticketS.DbgInfo) + require.Equal(t, "3:serv:CBAQ__________9_IggIlJEGELaIPQ:", ticketS.LogInfo) + + ticketU, err := c1000502.CheckUserTicket(context.Background(), userTicketStr) + require.NoError(t, err) + require.Equal(t, []tvm.UID{123, 456}, ticketU.UIDs) + require.Equal(t, tvm.UID(456), ticketU.DefaultUID) + require.Equal(t, []string{"bb:kek", "some:scopes"}, ticketU.Scopes) + require.Equal(t, "ticket_type=user;expiration_time=9223372036854775807;scope=bb:kek;scope=some:scopes;default_uid=456;uid=123;uid=456;env=Prod;", ticketU.DbgInfo) + require.Equal(t, "3:user:CAsQ__________9_GikKAgh7CgMIyAMQyAMaBmJiOmtlaxoLc29tZTpzY29wZXMg0oXYzAQoAA:", ticketU.LogInfo) + + _, err = c1000502.CheckUserTicket(context.Background(), userTicketStr, tvm.WithBlackboxOverride(tvm.BlackboxProdYateam)) + require.Error(t, err) + require.IsType(t, err, &tvm.TicketError{}) + require.Equal(t, err.(*tvm.TicketError).Status, tvm.TicketInvalidBlackboxEnv) + + ticketU, err = c1000502.CheckUserTicket(context.Background(), userTicketStr[:len(userTicketStr)-1]) + require.Error(t, err) + require.IsType(t, err, &tvm.TicketError{}) + require.Equal(t, err.(*tvm.TicketError).Status, tvm.TicketSignBroken) + require.Equal(t, "ticket_type=user;expiration_time=9223372036854775807;scope=bb:kek;scope=some:scopes;default_uid=456;uid=123;uid=456;env=Prod;", ticketU.DbgInfo) + require.Equal(t, "3:user:CAsQ__________9_GikKAgh7CgMIyAMQyAMaBmJiOmtlaxoLc29tZTpzY29wZXMg0oXYzAQoAA:", ticketU.LogInfo) +} + +func TestUnittestClient(t *testing.T) { + _, err := tvmauth.NewUnittestClient(tvmauth.TvmUnittestSettings{}) + require.NoError(t, err) + + client, err := tvmauth.NewUnittestClient(tvmauth.TvmUnittestSettings{ + SelfID: 1000502, + }) + require.NoError(t, err) + + _, err = client.GetRoles(context.Background()) + require.ErrorContains(t, err, "Roles are not provided") + _, err = client.GetServiceTicketForID(context.Background(), tvm.ClientID(42)) + require.ErrorContains(t, err, "Destination '42' was not specified in settings") + + status, err := client.GetStatus(context.Background()) + require.NoError(t, err) + require.EqualValues(t, tvm.ClientOK, status.Status) + + st, err := client.CheckServiceTicket(context.Background(), serviceTicketStr) + require.NoError(t, err) + require.EqualValues(t, tvm.ClientID(100500), st.SrcID) + + ut, err := client.CheckUserTicket(context.Background(), userTicketStr) + require.NoError(t, err) + require.EqualValues(t, tvm.UID(456), ut.DefaultUID) +} diff --git a/library/go/yandex/tvm/tvmauth/client.go b/library/go/yandex/tvm/tvmauth/client.go new file mode 100644 index 0000000000..0282b2939f --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/client.go @@ -0,0 +1,509 @@ +//go:build cgo +// +build cgo + +package tvmauth + +// #include <stdlib.h> +// +// #include "tvm.h" +import "C" +import ( + "context" + "encoding/json" + "fmt" + "runtime" + "sync" + "unsafe" + + "a.yandex-team.ru/library/go/cgosem" + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/yandex/tvm" +) + +// NewIDsOptions creates options for fetching CheckedServiceTicket's with ClientID +func NewIDsOptions(secret string, dsts []tvm.ClientID) *TVMAPIOptions { + tmp := make(map[string]tvm.ClientID) + for _, dst := range dsts { + tmp[fmt.Sprintf("%d", dst)] = dst + } + + res, err := json.Marshal(tmp) + if err != nil { + panic(err) + } + + return &TVMAPIOptions{ + selfSecret: secret, + dstAliases: res, + } +} + +// NewAliasesOptions creates options for fetching CheckedServiceTicket's with alias+ClientID +func NewAliasesOptions(secret string, dsts map[string]tvm.ClientID) *TVMAPIOptions { + if dsts == nil { + dsts = make(map[string]tvm.ClientID) + } + + res, err := json.Marshal(dsts) + if err != nil { + panic(err) + } + + return &TVMAPIOptions{ + selfSecret: secret, + dstAliases: res, + } +} + +func (o *TvmAPISettings) pack(out *C.TVM_ApiSettings) { + out.SelfId = C.uint32_t(o.SelfID) + + if o.EnableServiceTicketChecking { + out.EnableServiceTicketChecking = 1 + } + + if o.BlackboxEnv != nil { + out.EnableUserTicketChecking = 1 + out.BlackboxEnv = C.int(*o.BlackboxEnv) + } + + if o.FetchRolesForIdmSystemSlug != "" { + o.fetchRolesForIdmSystemSlug = []byte(o.FetchRolesForIdmSystemSlug) + out.IdmSystemSlug = (*C.uchar)(&o.fetchRolesForIdmSystemSlug[0]) + out.IdmSystemSlugSize = C.int(len(o.fetchRolesForIdmSystemSlug)) + } + if o.DisableSrcCheck { + out.DisableSrcCheck = 1 + } + if o.DisableDefaultUIDCheck { + out.DisableDefaultUIDCheck = 1 + } + + if o.TVMHost != "" { + o.tvmHost = []byte(o.TVMHost) + out.TVMHost = (*C.uchar)(&o.tvmHost[0]) + out.TVMHostSize = C.int(len(o.tvmHost)) + } + out.TVMPort = C.int(o.TVMPort) + + if o.TiroleHost != "" { + o.tiroleHost = []byte(o.TiroleHost) + out.TiroleHost = (*C.uchar)(&o.tiroleHost[0]) + out.TiroleHostSize = C.int(len(o.tiroleHost)) + } + out.TirolePort = C.int(o.TirolePort) + out.TiroleTvmId = C.uint32_t(o.TiroleTvmID) + + if o.ServiceTicketOptions != nil { + if (o.ServiceTicketOptions.selfSecret != "") { + o.ServiceTicketOptions.selfSecretB = []byte(o.ServiceTicketOptions.selfSecret) + out.SelfSecret = (*C.uchar)(&o.ServiceTicketOptions.selfSecretB[0]) + out.SelfSecretSize = C.int(len(o.ServiceTicketOptions.selfSecretB)) + } + + if (len(o.ServiceTicketOptions.dstAliases) != 0) { + out.DstAliases = (*C.uchar)(&o.ServiceTicketOptions.dstAliases[0]) + out.DstAliasesSize = C.int(len(o.ServiceTicketOptions.dstAliases)) + } + } + + if o.DiskCacheDir != "" { + o.diskCacheDir = []byte(o.DiskCacheDir) + + out.DiskCacheDir = (*C.uchar)(&o.diskCacheDir[0]) + out.DiskCacheDirSize = C.int(len(o.diskCacheDir)) + } +} + +func (o *TvmToolSettings) pack(out *C.TVM_ToolSettings) { + if o.Alias != "" { + o.alias = []byte(o.Alias) + + out.Alias = (*C.uchar)(&o.alias[0]) + out.AliasSize = C.int(len(o.alias)) + } + + out.Port = C.int(o.Port) + + if o.Hostname != "" { + o.hostname = []byte(o.Hostname) + out.Hostname = (*C.uchar)(&o.hostname[0]) + out.HostnameSize = C.int(len(o.hostname)) + } + + if o.AuthToken != "" { + o.authToken = []byte(o.AuthToken) + out.AuthToken = (*C.uchar)(&o.authToken[0]) + out.AuthTokenSize = C.int(len(o.authToken)) + } + + if o.DisableSrcCheck { + out.DisableSrcCheck = 1 + } + if o.DisableDefaultUIDCheck { + out.DisableDefaultUIDCheck = 1 + } +} + +func (o *TvmUnittestSettings) pack(out *C.TVM_UnittestSettings) { + out.SelfId = C.uint32_t(o.SelfID) + out.BlackboxEnv = C.int(o.BlackboxEnv) +} + +// Destroy stops client and delete it from memory. +// Do not try to use client after destroying it +func (c *Client) Destroy() { + if c.handle == nil { + return + } + + C.TVM_DestroyClient(c.handle) + c.handle = nil + + if c.logger != nil { + unregisterLogger(*c.logger) + } +} + +func unpackString(s *C.TVM_String) string { + if s.Data == nil { + return "" + } + + return C.GoStringN(s.Data, s.Size) +} + +func unpackErr(err *C.TVM_Error) error { + msg := unpackString(&err.Message) + code := tvm.ErrorCode(err.Code) + + if code != 0 { + return &tvm.Error{Code: code, Retriable: err.Retriable != 0, Msg: msg} + } + + return nil +} + +func unpackScopes(scopes *C.TVM_String, scopeSize C.int) (s []string) { + if scopeSize == 0 { + return + } + + s = make([]string, int(scopeSize)) + scopesArr := (*[1 << 30]C.TVM_String)(unsafe.Pointer(scopes)) + + for i := 0; i < int(scopeSize); i++ { + s[i] = C.GoStringN(scopesArr[i].Data, scopesArr[i].Size) + } + + return +} + +func unpackStatus(status C.int) error { + if status == 0 { + return nil + } + + return &tvm.TicketError{Status: tvm.TicketStatus(status)} +} + +func unpackServiceTicket(t *C.TVM_ServiceTicket) (*tvm.CheckedServiceTicket, error) { + ticket := &tvm.CheckedServiceTicket{} + ticket.SrcID = tvm.ClientID(t.SrcId) + ticket.IssuerUID = tvm.UID(t.IssuerUid) + ticket.DbgInfo = unpackString(&t.DbgInfo) + ticket.LogInfo = unpackString(&t.LogInfo) + return ticket, unpackStatus(t.Status) +} + +func unpackUserTicket(t *C.TVM_UserTicket) (*tvm.CheckedUserTicket, error) { + ticket := &tvm.CheckedUserTicket{} + ticket.DefaultUID = tvm.UID(t.DefaultUid) + if t.UidsSize != 0 { + ticket.UIDs = make([]tvm.UID, int(t.UidsSize)) + uids := (*[1 << 30]C.uint64_t)(unsafe.Pointer(t.Uids)) + for i := 0; i < int(t.UidsSize); i++ { + ticket.UIDs[i] = tvm.UID(uids[i]) + } + } + + ticket.Env = tvm.BlackboxEnv(t.Env) + + ticket.Scopes = unpackScopes(t.Scopes, t.ScopesSize) + ticket.DbgInfo = unpackString(&t.DbgInfo) + ticket.LogInfo = unpackString(&t.LogInfo) + return ticket, unpackStatus(t.Status) +} + +func unpackClientStatus(s *C.TVM_ClientStatus) (status tvm.ClientStatusInfo) { + status.Status = tvm.ClientStatus(s.Status) + status.LastError = C.GoStringN(s.LastError.Data, s.LastError.Size) + + return +} + +// NewAPIClient creates client which uses https://tvm-api.yandex.net to get state +func NewAPIClient(options TvmAPISettings, log log.Logger) (*Client, error) { + var settings C.TVM_ApiSettings + options.pack(&settings) + + client := &Client{ + mutex: &sync.RWMutex{}, + } + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + loggerId := registerLogger(log) + client.logger = &loggerId + + var tvmErr C.TVM_Error + C.TVM_NewApiClient(settings, C.int(loggerId), &client.handle, &tvmErr, &pool) + + if err := unpackErr(&tvmErr); err != nil { + unregisterLogger(loggerId) + return nil, err + } + + runtime.SetFinalizer(client, (*Client).Destroy) + return client, nil +} + +// NewToolClient creates client uses local http-interface to get state: http://localhost/tvm/. +// Details: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/. +func NewToolClient(options TvmToolSettings, log log.Logger) (*Client, error) { + var settings C.TVM_ToolSettings + options.pack(&settings) + + client := &Client{ + mutex: &sync.RWMutex{}, + } + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + loggerId := registerLogger(log) + client.logger = &loggerId + + var tvmErr C.TVM_Error + C.TVM_NewToolClient(settings, C.int(loggerId), &client.handle, &tvmErr, &pool) + + if err := unpackErr(&tvmErr); err != nil { + unregisterLogger(loggerId) + return nil, err + } + + runtime.SetFinalizer(client, (*Client).Destroy) + return client, nil +} + +// NewUnittestClient creates client with mocked state. +func NewUnittestClient(options TvmUnittestSettings) (*Client, error) { + var settings C.TVM_UnittestSettings + options.pack(&settings) + + client := &Client{ + mutex: &sync.RWMutex{}, + } + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var tvmErr C.TVM_Error + C.TVM_NewUnittestClient(settings, &client.handle, &tvmErr, &pool) + + if err := unpackErr(&tvmErr); err != nil { + return nil, err + } + + runtime.SetFinalizer(client, (*Client).Destroy) + return client, nil +} + +// CheckServiceTicket always checks ticket with keys from memory +func (c *Client) CheckServiceTicket(ctx context.Context, ticketStr string) (*tvm.CheckedServiceTicket, error) { + defer cgosem.S.Acquire().Release() + + ticketBytes := []byte(ticketStr) + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var ticket C.TVM_ServiceTicket + var tvmErr C.TVM_Error + C.TVM_CheckServiceTicket( + c.handle, + (*C.uchar)(&ticketBytes[0]), C.int(len(ticketBytes)), + &ticket, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return nil, err + } + + return unpackServiceTicket(&ticket) +} + +// CheckUserTicket always checks ticket with keys from memory +func (c *Client) CheckUserTicket(ctx context.Context, ticketStr string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + defer cgosem.S.Acquire().Release() + + var options tvm.CheckUserTicketOptions + for _, opt := range opts { + opt(&options) + } + + ticketBytes := []byte(ticketStr) + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var bbEnv *C.int + var bbEnvOverrided C.int + if options.EnvOverride != nil { + bbEnvOverrided = C.int(*options.EnvOverride) + bbEnv = &bbEnvOverrided + } + + var ticket C.TVM_UserTicket + var tvmErr C.TVM_Error + C.TVM_CheckUserTicket( + c.handle, + (*C.uchar)(&ticketBytes[0]), C.int(len(ticketBytes)), + bbEnv, + &ticket, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return nil, err + } + + return unpackUserTicket(&ticket) +} + +// GetServiceTicketForAlias always returns ticket from memory +func (c *Client) GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) { + defer cgosem.S.Acquire().Release() + + aliasBytes := []byte(alias) + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var ticket *C.char + var tvmErr C.TVM_Error + C.TVM_GetServiceTicketForAlias( + c.handle, + (*C.uchar)(&aliasBytes[0]), C.int(len(aliasBytes)), + &ticket, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return "", err + } + + return C.GoString(ticket), nil +} + +// GetServiceTicketForID always returns ticket from memory +func (c *Client) GetServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (string, error) { + defer cgosem.S.Acquire().Release() + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var ticket *C.char + var tvmErr C.TVM_Error + C.TVM_GetServiceTicket( + c.handle, + C.uint32_t(dstID), + &ticket, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return "", err + } + + return C.GoString(ticket), nil +} + +// GetStatus returns current status of client. +// See detials: https://godoc.yandex-team.ru/pkg/a.yandex-team.ru/library/go/yandex/tvm/#Client +func (c *Client) GetStatus(ctx context.Context) (tvm.ClientStatusInfo, error) { + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var status C.TVM_ClientStatus + var tvmErr C.TVM_Error + C.TVM_GetStatus(c.handle, &status, &tvmErr, &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return tvm.ClientStatusInfo{}, err + } + + return unpackClientStatus(&status), nil +} + +func (c *Client) GetRoles(ctx context.Context) (*tvm.Roles, error) { + defer cgosem.S.Acquire().Release() + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + currentRoles := c.getCurrentRoles() + var currentRevision []byte + var currentRevisionPtr *C.uchar + if currentRoles != nil { + currentRevision = []byte(currentRoles.GetMeta().Revision) + currentRevisionPtr = (*C.uchar)(¤tRevision[0]) + } + + var raw *C.char + var rawSize C.int + var tvmErr C.TVM_Error + C.TVM_GetRoles( + c.handle, + currentRevisionPtr, C.int(len(currentRevision)), + &raw, + &rawSize, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return nil, err + } + if raw == nil { + return currentRoles, nil + } + + c.mutex.Lock() + defer c.mutex.Unlock() + + if currentRoles != c.roles { + return c.roles, nil + } + + roles, err := tvm.NewRoles(C.GoBytes(unsafe.Pointer(raw), rawSize)) + if err != nil { + return nil, err + } + + c.roles = roles + return c.roles, nil +} + +func (c *Client) getCurrentRoles() *tvm.Roles { + c.mutex.RLock() + defer c.mutex.RUnlock() + return c.roles +} diff --git a/library/go/yandex/tvm/tvmauth/client_example_test.go b/library/go/yandex/tvm/tvmauth/client_example_test.go new file mode 100644 index 0000000000..babf8d51b1 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/client_example_test.go @@ -0,0 +1,182 @@ +package tvmauth_test + +import ( + "context" + "fmt" + + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmauth" +) + +func ExampleNewAPIClient_getServiceTicketsWithAliases() { + blackboxAlias := "blackbox" + datasyncAlias := "datasync" + + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + ServiceTicketOptions: tvmauth.NewAliasesOptions( + "bAicxJVa5uVY7MjDlapthw", + map[string]tvm.ClientID{ + blackboxAlias: 1000502, + datasyncAlias: 1000503, + }), + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.GetServiceTicketForAlias(context.Background(), blackboxAlias) + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) +} + +func ExampleNewAPIClient_getServiceTicketsWithID() { + blackboxID := tvm.ClientID(1000502) + datasyncID := tvm.ClientID(1000503) + + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + ServiceTicketOptions: tvmauth.NewIDsOptions( + "bAicxJVa5uVY7MjDlapthw", + []tvm.ClientID{ + blackboxID, + datasyncID, + }), + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.GetServiceTicketForID(context.Background(), blackboxID) + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) +} + +func ExampleNewAPIClient_checkServiceTicket() { + // allowed tvm consumers for your service + acl := map[tvm.ClientID]interface{}{} + + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + EnableServiceTicketChecking: true, + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + serviceTicketFromRequest := "kek" + + serviceTicketStruct, err := c.CheckServiceTicket(context.Background(), serviceTicketFromRequest) + if err != nil { + response := map[string]string{ + "error": "service ticket is invalid", + "desc": err.Error(), + "status": err.(*tvm.TicketError).Status.String(), + } + if serviceTicketStruct != nil { + response["debug_info"] = serviceTicketStruct.DbgInfo + } + panic(response) // return 403 + } + if _, ok := acl[serviceTicketStruct.SrcID]; !ok { + response := map[string]string{ + "error": fmt.Sprintf("tvm client id is not allowed: %d", serviceTicketStruct.SrcID), + } + panic(response) // return 403 + } + + // proceed... +} + +func ExampleNewAPIClient_checkUserTicket() { + env := tvm.BlackboxTest + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + EnableServiceTicketChecking: true, + BlackboxEnv: &env, + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + serviceTicketFromRequest := "kek" + userTicketFromRequest := "lol" + + _, _ = c.CheckServiceTicket(context.Background(), serviceTicketFromRequest) // See example for this method + + userTicketStruct, err := c.CheckUserTicket(context.Background(), userTicketFromRequest) + if err != nil { + response := map[string]string{ + "error": "user ticket is invalid", + "desc": err.Error(), + "status": err.(*tvm.TicketError).Status.String(), + } + if userTicketStruct != nil { + response["debug_info"] = userTicketStruct.DbgInfo + } + panic(response) // return 403 + } + + fmt.Printf("Got user in request: %d", userTicketStruct.DefaultUID) + // proceed... +} + +func ExampleNewAPIClient_createClientWithAllSettings() { + blackboxAlias := "blackbox" + datasyncAlias := "datasync" + + env := tvm.BlackboxTest + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + ServiceTicketOptions: tvmauth.NewAliasesOptions( + "bAicxJVa5uVY7MjDlapthw", + map[string]tvm.ClientID{ + blackboxAlias: 1000502, + datasyncAlias: 1000503, + }), + EnableServiceTicketChecking: true, + BlackboxEnv: &env, + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + _, _ = tvmauth.NewAPIClient(settings, &nop.Logger{}) +} + +func ExampleNewToolClient_getServiceTicketsWithAliases() { + // should be configured in tvmtool + blackboxAlias := "blackbox" + + settings := tvmauth.TvmToolSettings{ + Alias: "my_service", + Port: 18000, + AuthToken: "kek", + } + + c, err := tvmauth.NewToolClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.GetServiceTicketForAlias(context.Background(), blackboxAlias) + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) + // please extrapolate other methods for this way of construction +} diff --git a/library/go/yandex/tvm/tvmauth/doc.go b/library/go/yandex/tvm/tvmauth/doc.go new file mode 100644 index 0000000000..ece7efd3ba --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/doc.go @@ -0,0 +1,10 @@ +// CGO implementation of tvm-interface based on ticket_parser2. +// +// Package allows you to get service/user TVM-tickets, as well as check them. +// This package provides client via tvm-api or tvmtool. +// Also this package provides the most efficient way for checking tickets regardless of the client construction way. +// All scenerios are provided without any request after construction. +// +// You should create client with NewAPIClient() or NewToolClient(). +// Also you need to check status of client with GetStatus(). +package tvmauth diff --git a/library/go/yandex/tvm/tvmauth/logger.go b/library/go/yandex/tvm/tvmauth/logger.go new file mode 100644 index 0000000000..3731b16b65 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/logger.go @@ -0,0 +1,77 @@ +//go:build cgo +// +build cgo + +package tvmauth + +import "C" +import ( + "fmt" + "sync" + + "a.yandex-team.ru/library/go/core/log" +) + +// CGO pointer rules state: +// +// Go code may pass a Go pointer to C provided the Go memory to which it points **does not contain any Go pointers**. +// +// Logger is an interface and contains pointer to implementation. That means, we are forbidden from +// passing Logger to C code. +// +// Instead, we put logger into a global map and pass key to the C code. +// +// This might seem inefficient, but we are not concerned with performance here, since the logger is not on the hot path anyway. + +var ( + loggersLock sync.Mutex + nextSlot int + loggers = map[int]log.Logger{} +) + +func registerLogger(l log.Logger) int { + loggersLock.Lock() + defer loggersLock.Unlock() + + i := nextSlot + nextSlot++ + loggers[i] = l + return i +} + +func unregisterLogger(i int) { + loggersLock.Lock() + defer loggersLock.Unlock() + + if _, ok := loggers[i]; !ok { + panic(fmt.Sprintf("attempt to unregister unknown logger %d", i)) + } + + delete(loggers, i) +} + +func findLogger(i int) log.Logger { + loggersLock.Lock() + defer loggersLock.Unlock() + + return loggers[i] +} + +//export TVM_WriteToLog +// +// TVM_WriteToLog is technical artifact +func TVM_WriteToLog(logger int, level int, msgData *C.char, msgSize C.int) { + l := findLogger(logger) + + msg := C.GoStringN(msgData, msgSize) + + switch level { + case 3: + l.Error(msg) + case 4: + l.Warn(msg) + case 6: + l.Info(msg) + default: + l.Debug(msg) + } +} diff --git a/library/go/yandex/tvm/tvmauth/tiroletest/client_test.go b/library/go/yandex/tvm/tvmauth/tiroletest/client_test.go new file mode 100644 index 0000000000..37e467e286 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tiroletest/client_test.go @@ -0,0 +1,338 @@ +package tiroletest + +import ( + "context" + "io/ioutil" + "strconv" + "testing" + + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmauth" +) + +func getPort(t *testing.T, filename string) int { + body, err := ioutil.ReadFile(filename) + require.NoError(t, err) + + res, err := strconv.Atoi(string(body)) + require.NoError(t, err, "port is invalid: ", filename) + + return res +} + +func createClientWithTirole(t *testing.T, disableSrcCheck bool, disableDefaultUIDCheck bool) *tvmauth.Client { + env := tvm.BlackboxProdYateam + client, err := tvmauth.NewAPIClient( + tvmauth.TvmAPISettings{ + SelfID: 1000502, + ServiceTicketOptions: tvmauth.NewIDsOptions("e5kL0vM3nP-nPf-388Hi6Q", nil), + DiskCacheDir: "./", + FetchRolesForIdmSystemSlug: "some_slug_2", + EnableServiceTicketChecking: true, + DisableSrcCheck: disableSrcCheck, + DisableDefaultUIDCheck: disableDefaultUIDCheck, + BlackboxEnv: &env, + TVMHost: "http://localhost", + TVMPort: getPort(t, "tvmapi.port"), + TiroleHost: "http://localhost", + TirolePort: getPort(t, "tirole.port"), + TiroleTvmID: 1000001, + }, + &nop.Logger{}, + ) + require.NoError(t, err) + + return client +} + +func createClientWithTvmtool(t *testing.T, disableSrcCheck bool, disableDefaultUIDCheck bool) *tvmauth.Client { + token, err := ioutil.ReadFile("tvmtool.authtoken") + require.NoError(t, err) + + client, err := tvmauth.NewToolClient( + tvmauth.TvmToolSettings{ + Alias: "me", + AuthToken: string(token), + DisableSrcCheck: disableSrcCheck, + DisableDefaultUIDCheck: disableDefaultUIDCheck, + Port: getPort(t, "tvmtool.port"), + }, + &nop.Logger{}, + ) + require.NoError(t, err) + + return client +} + +func checkServiceNoRoles(t *testing.T, clientsWithAutoCheck, clientsWithoutAutoCheck []tvm.Client) { + // src=1000000000: tvmknife unittest service -s 1000000000 -d 1000502 + stWithoutRoles := "3:serv:CBAQ__________9_IgoIgJTr3AMQtog9:Sv3SKuDQ4p-2419PKqc1vo9EC128K6Iv7LKck5SyliJZn5gTAqMDAwb9aYWHhf49HTR-Qmsjw4i_Lh-sNhge-JHWi5PTGFJm03CZHOCJG9Y0_G1pcgTfodtAsvDykMxLhiXGB4N84cGhVVqn1pFWz6SPmMeKUPulTt7qH1ifVtQ" + + ctx := context.Background() + + for _, cl := range clientsWithAutoCheck { + _, err := cl.CheckServiceTicket(ctx, stWithoutRoles) + require.EqualValues(t, + &tvm.TicketError{Status: tvm.TicketNoRoles}, + err, + ) + } + + for _, cl := range clientsWithoutAutoCheck { + st, err := cl.CheckServiceTicket(ctx, stWithoutRoles) + require.NoError(t, err) + + roles, err := cl.GetRoles(ctx) + require.NoError(t, err) + + res := roles.GetRolesForService(st) + require.Nil(t, res) + } +} + +func checkServiceHasRoles(t *testing.T, clientsWithAutoCheck, clientsWithoutAutoCheck []tvm.Client) { + // src=1000000001: tvmknife unittest service -s 1000000001 -d 1000502 + stWithRoles := "3:serv:CBAQ__________9_IgoIgZTr3AMQtog9:EyPympmoLBM6jyiQLcK8ummNmL5IUAdTvKM1do8ppuEgY6yHfto3s_WAKmP9Pf9EiNqPBe18HR7yKmVS7gvdFJY4gP4Ut51ejS-iBPlsbsApJOYTgodQPhkmjHVKIT0ub0pT3fWHQtapb8uimKpGcO6jCfopFQSVG04Ehj7a0jw" + + ctx := context.Background() + + check := func(cl tvm.Client) { + checked, err := cl.CheckServiceTicket(ctx, stWithRoles) + require.NoError(t, err) + + clientRoles, err := cl.GetRoles(ctx) + require.NoError(t, err) + + require.EqualValues(t, + `{ + "/role/service/read/": [], + "/role/service/write/": [ + { + "foo": "bar", + "kek": "lol" + } + ] +}`, + clientRoles.GetRolesForService(checked).DebugPrint(), + ) + + require.True(t, clientRoles.CheckServiceRole(checked, "/role/service/read/", nil)) + require.True(t, clientRoles.CheckServiceRole(checked, "/role/service/write/", nil)) + require.False(t, clientRoles.CheckServiceRole(checked, "/role/foo/", nil)) + + require.False(t, clientRoles.CheckServiceRole(checked, "/role/service/read/", &tvm.CheckServiceOptions{ + Entity: tvm.Entity{"foo": "bar", "kek": "lol"}, + })) + require.False(t, clientRoles.CheckServiceRole(checked, "/role/service/write/", &tvm.CheckServiceOptions{ + Entity: tvm.Entity{"kek": "lol"}, + })) + require.True(t, clientRoles.CheckServiceRole(checked, "/role/service/write/", &tvm.CheckServiceOptions{ + Entity: tvm.Entity{"foo": "bar", "kek": "lol"}, + })) + } + + for _, cl := range clientsWithAutoCheck { + check(cl) + } + for _, cl := range clientsWithoutAutoCheck { + check(cl) + } +} + +func checkUserNoRoles(t *testing.T, clientsWithAutoCheck, clientsWithoutAutoCheck []tvm.Client) { + // default_uid=1000000000: tvmknife unittest user -d 1000000000 --env prod_yateam + utWithoutRoles := "3:user:CAwQ__________9_GhYKBgiAlOvcAxCAlOvcAyDShdjMBCgC:LloRDlCZ4vd0IUTOj6MD1mxBPgGhS6EevnnWvHgyXmxc--2CVVkAtNKNZJqCJ6GtDY4nknEnYmWvEu6-MInibD-Uk6saI1DN-2Y3C1Wdsz2SJCq2OYgaqQsrM5PagdyP9PLrftkuV_ZluS_FUYebMXPzjJb0L0ALKByMPkCVWuk" + + ctx := context.Background() + + for _, cl := range clientsWithAutoCheck { + _, err := cl.CheckUserTicket(ctx, utWithoutRoles) + require.EqualValues(t, + &tvm.TicketError{Status: tvm.TicketNoRoles}, + err, + ) + } + + for _, cl := range clientsWithoutAutoCheck { + ut, err := cl.CheckUserTicket(ctx, utWithoutRoles) + require.NoError(t, err) + + roles, err := cl.GetRoles(ctx) + require.NoError(t, err) + + res, err := roles.GetRolesForUser(ut, nil) + require.NoError(t, err) + require.Nil(t, res) + } +} + +func checkUserHasRoles(t *testing.T, clientsWithAutoCheck, clientsWithoutAutoCheck []tvm.Client) { + // default_uid=1120000000000001: tvmknife unittest user -d 1120000000000001 --env prod_yateam + utWithRoles := "3:user:CAwQ__________9_GhwKCQiBgJiRpdT-ARCBgJiRpdT-ASDShdjMBCgC:SQV7Z9hDpZ_F62XGkSF6yr8PoZHezRp0ZxCINf_iAbT2rlEiO6j4UfLjzwn3EnRXkAOJxuAtTDCnHlrzdh3JgSKK7gciwPstdRT5GGTixBoUU9kI_UlxEbfGBX1DfuDsw_GFQ2eCLu4Svq6jC3ynuqQ41D2RKopYL8Bx8PDZKQc" + + ctx := context.Background() + + check := func(cl tvm.Client) { + checked, err := cl.CheckUserTicket(ctx, utWithRoles) + require.NoError(t, err) + + clientRoles, err := cl.GetRoles(ctx) + require.NoError(t, err) + + ut, err := clientRoles.GetRolesForUser(checked, nil) + require.NoError(t, err) + require.EqualValues(t, + `{ + "/role/user/read/": [ + { + "foo": "bar", + "kek": "lol" + } + ], + "/role/user/write/": [] +}`, + ut.DebugPrint(), + ) + + res, err := clientRoles.CheckUserRole(checked, "/role/user/write/", nil) + require.NoError(t, err) + require.True(t, res) + res, err = clientRoles.CheckUserRole(checked, "/role/user/read/", nil) + require.NoError(t, err) + require.True(t, res) + res, err = clientRoles.CheckUserRole(checked, "/role/foo/", nil) + require.NoError(t, err) + require.False(t, res) + + res, err = clientRoles.CheckUserRole(checked, "/role/user/write/", &tvm.CheckUserOptions{ + Entity: tvm.Entity{"foo": "bar", "kek": "lol"}, + }) + require.NoError(t, err) + require.False(t, res) + res, err = clientRoles.CheckUserRole(checked, "/role/user/read/", &tvm.CheckUserOptions{ + Entity: tvm.Entity{"kek": "lol"}, + }) + require.NoError(t, err) + require.False(t, res) + res, err = clientRoles.CheckUserRole(checked, "/role/user/read/", &tvm.CheckUserOptions{ + Entity: tvm.Entity{"foo": "bar", "kek": "lol"}, + }) + require.NoError(t, err) + require.True(t, res) + } + + for _, cl := range clientsWithAutoCheck { + check(cl) + } + for _, cl := range clientsWithoutAutoCheck { + check(cl) + } + +} + +func TestRolesFromTiroleCheckSrc_noRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTirole(t, false, true) + clientWithoutAutoCheck := createClientWithTirole(t, true, true) + + checkServiceNoRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTiroleCheckSrc_HasRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTirole(t, false, true) + clientWithoutAutoCheck := createClientWithTirole(t, true, true) + + checkServiceHasRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTiroleCheckDefaultUid_noRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTirole(t, true, false) + clientWithoutAutoCheck := createClientWithTirole(t, true, true) + + checkUserNoRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTiroleCheckDefaultUid_HasRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTirole(t, true, false) + clientWithoutAutoCheck := createClientWithTirole(t, true, true) + + checkUserHasRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTvmtoolCheckSrc_noRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTvmtool(t, false, true) + clientWithoutAutoCheck := createClientWithTvmtool(t, true, true) + + checkServiceNoRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTvmtoolCheckSrc_HasRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTvmtool(t, false, true) + clientWithoutAutoCheck := createClientWithTvmtool(t, true, true) + + checkServiceHasRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTvmtoolCheckDefaultUid_noRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTvmtool(t, true, false) + clientWithoutAutoCheck := createClientWithTvmtool(t, true, true) + + checkUserNoRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTvmtoolCheckDefaultUid_HasRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTvmtool(t, true, false) + clientWithoutAutoCheck := createClientWithTvmtool(t, true, true) + + checkUserHasRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} diff --git a/library/go/yandex/tvm/tvmauth/tiroletest/roles/mapping.yaml b/library/go/yandex/tvm/tvmauth/tiroletest/roles/mapping.yaml new file mode 100644 index 0000000000..d2fcaead59 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tiroletest/roles/mapping.yaml @@ -0,0 +1,5 @@ +slugs: + some_slug_2: + tvmid: + - 1000502 + - 1000503 diff --git a/library/go/yandex/tvm/tvmauth/tiroletest/roles/some_slug_2.json b/library/go/yandex/tvm/tvmauth/tiroletest/roles/some_slug_2.json new file mode 100644 index 0000000000..84d85fae19 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tiroletest/roles/some_slug_2.json @@ -0,0 +1,22 @@ +{ + "revision": "some_revision_2", + "born_date": 1642160002, + "tvm": { + "1000000001": { + "/role/service/read/": [{}], + "/role/service/write/": [{ + "foo": "bar", + "kek": "lol" + }] + } + }, + "user": { + "1120000000000001": { + "/role/user/write/": [{}], + "/role/user/read/": [{ + "foo": "bar", + "kek": "lol" + }] + } + } +} diff --git a/library/go/yandex/tvm/tvmauth/tiroletest/tvmtool.cfg b/library/go/yandex/tvm/tvmauth/tiroletest/tvmtool.cfg new file mode 100644 index 0000000000..dbb8fcd458 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tiroletest/tvmtool.cfg @@ -0,0 +1,10 @@ +{ + "BbEnvType": 2, + "clients": { + "me": { + "secret": "fake_secret", + "self_tvm_id": 1000502, + "roles_for_idm_slug": "some_slug_2" + } + } +} diff --git a/library/go/yandex/tvm/tvmauth/tooltest/.arcignore b/library/go/yandex/tvm/tvmauth/tooltest/.arcignore new file mode 100644 index 0000000000..251ded04a5 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tooltest/.arcignore @@ -0,0 +1 @@ +tooltest diff --git a/library/go/yandex/tvm/tvmauth/tooltest/client_test.go b/library/go/yandex/tvm/tvmauth/tooltest/client_test.go new file mode 100644 index 0000000000..a8d68e55ee --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tooltest/client_test.go @@ -0,0 +1,57 @@ +package tooltest + +import ( + "context" + "io/ioutil" + "strconv" + "testing" + + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmauth" +) + +func recipeToolOptions(t *testing.T) tvmauth.TvmToolSettings { + var portStr, token []byte + portStr, err := ioutil.ReadFile("tvmtool.port") + require.NoError(t, err) + + var port int + port, err = strconv.Atoi(string(portStr)) + require.NoError(t, err) + + token, err = ioutil.ReadFile("tvmtool.authtoken") + require.NoError(t, err) + + return tvmauth.TvmToolSettings{Alias: "me", Port: port, AuthToken: string(token)} +} + +func TestToolClient(t *testing.T) { + c, err := tvmauth.NewToolClient(recipeToolOptions(t), &nop.Logger{}) + require.NoError(t, err) + defer c.Destroy() + + t.Run("GetServiceTicketForID", func(t *testing.T) { + _, err := c.GetServiceTicketForID(context.Background(), 100500) + require.NoError(t, err) + }) + + t.Run("GetInvalidTicket", func(t *testing.T) { + _, err := c.GetServiceTicketForID(context.Background(), 100999) + require.Error(t, err) + require.IsType(t, &tvm.Error{}, err) + require.Equal(t, tvm.ErrorBrokenTvmClientSettings, err.(*tvm.Error).Code) + }) + + t.Run("ClientStatus", func(t *testing.T) { + status, err := c.GetStatus(context.Background()) + require.NoError(t, err) + + t.Logf("Got client status: %v", status) + + require.Equal(t, tvm.ClientStatus(0), status.Status) + require.Equal(t, "OK", status.LastError) + }) +} diff --git a/library/go/yandex/tvm/tvmauth/tooltest/logger_test.go b/library/go/yandex/tvm/tvmauth/tooltest/logger_test.go new file mode 100644 index 0000000000..99e6a5835e --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tooltest/logger_test.go @@ -0,0 +1,34 @@ +package tooltest + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/yandex/tvm/tvmauth" +) + +type testLogger struct { + nop.Logger + + msgs []string +} + +func (l *testLogger) Info(msg string, fields ...log.Field) { + l.msgs = append(l.msgs, msg) +} + +func TestLogger(t *testing.T) { + var l testLogger + + c, err := tvmauth.NewToolClient(recipeToolOptions(t), &l) + require.NoError(t, err) + defer c.Destroy() + + time.Sleep(time.Second) + + require.NotEmpty(t, l.msgs) +} diff --git a/library/go/yandex/tvm/tvmauth/tvm.cpp b/library/go/yandex/tvm/tvmauth/tvm.cpp new file mode 100644 index 0000000000..b3d2070df0 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tvm.cpp @@ -0,0 +1,417 @@ +#include "tvm.h" + +#include "_cgo_export.h" + +#include <library/cpp/json/json_reader.h> +#include <library/cpp/tvmauth/client/facade.h> +#include <library/cpp/tvmauth/client/logger.h> +#include <library/cpp/tvmauth/client/mocked_updater.h> +#include <library/cpp/tvmauth/client/misc/utils.h> +#include <library/cpp/tvmauth/client/misc/api/settings.h> +#include <library/cpp/tvmauth/client/misc/roles/roles.h> + +using namespace NTvmAuth; + +void TVM_DestroyMemPool(TVM_MemPool* pool) { + auto freeStr = [](char*& str) { + if (str != nullptr) { + free(str); + str = nullptr; + } + }; + + freeStr(pool->ErrorStr); + + if (pool->Scopes != nullptr) { + free(reinterpret_cast<void*>(pool->Scopes)); + pool->Scopes = nullptr; + } + + if (pool->TicketStr != nullptr) { + delete reinterpret_cast<TString*>(pool->TicketStr); + pool->TicketStr = nullptr; + } + if (pool->RawRolesStr != nullptr) { + delete reinterpret_cast<TString*>(pool->RawRolesStr); + pool->RawRolesStr = nullptr; + } + + if (pool->CheckedUserTicket != nullptr) { + delete reinterpret_cast<TCheckedUserTicket*>(pool->CheckedUserTicket); + pool->CheckedUserTicket = nullptr; + } + + if (pool->CheckedServiceTicket != nullptr) { + delete reinterpret_cast<TCheckedServiceTicket*>(pool->CheckedServiceTicket); + pool->CheckedServiceTicket = nullptr; + } + + freeStr(pool->DbgInfo); + freeStr(pool->LogInfo); + freeStr(pool->LastError.Data); +} + +static void PackStr(TStringBuf in, TVM_String* out, char*& poolStr) noexcept { + out->Data = poolStr = reinterpret_cast<char*>(malloc(in.size())); + out->Size = in.size(); + memcpy(out->Data, in.data(), in.size()); +} + +static void UnpackSettings( + TVM_ApiSettings* in, + NTvmApi::TClientSettings* out) { + if (in->SelfId != 0) { + out->SelfTvmId = in->SelfId; + } + + if (in->EnableServiceTicketChecking != 0) { + out->CheckServiceTickets = true; + } + + if (in->EnableUserTicketChecking != 0) { + out->CheckUserTicketsWithBbEnv = static_cast<EBlackboxEnv>(in->BlackboxEnv); + } + + if (in->SelfSecret != nullptr) { + out->Secret = TString(reinterpret_cast<char*>(in->SelfSecret), in->SelfSecretSize); + } + + TStringBuf aliases(reinterpret_cast<char*>(in->DstAliases), in->DstAliasesSize); + if (aliases) { + NJson::TJsonValue doc; + Y_ENSURE(NJson::ReadJsonTree(aliases, &doc), "Invalid json: from go part: " << aliases); + Y_ENSURE(doc.IsMap(), "Dsts is not map: from go part: " << aliases); + + for (const auto& pair : doc.GetMap()) { + Y_ENSURE(pair.second.IsUInteger(), "dstID must be number"); + out->FetchServiceTicketsForDstsWithAliases.emplace(pair.first, pair.second.GetUInteger()); + } + } + + if (in->IdmSystemSlug != nullptr) { + out->FetchRolesForIdmSystemSlug = TString(reinterpret_cast<char*>(in->IdmSystemSlug), in->IdmSystemSlugSize); + out->ShouldCheckSrc = in->DisableSrcCheck == 0; + out->ShouldCheckDefaultUid = in->DisableDefaultUIDCheck == 0; + } + + if (in->TVMHost != nullptr) { + out->TvmHost = TString(reinterpret_cast<char*>(in->TVMHost), in->TVMHostSize); + out->TvmPort = in->TVMPort; + } + if (in->TiroleHost != nullptr) { + out->TiroleHost = TString(reinterpret_cast<char*>(in->TiroleHost), in->TiroleHostSize); + out->TirolePort = in->TirolePort; + } + if (in->TiroleTvmId != 0) { + out->TiroleTvmId = in->TiroleTvmId; + } + + if (in->DiskCacheDir != nullptr) { + out->DiskCacheDir = TString(reinterpret_cast<char*>(in->DiskCacheDir), in->DiskCacheDirSize); + } +} + +static void UnpackSettings( + TVM_ToolSettings* in, + NTvmTool::TClientSettings* out) { + if (in->Port != 0) { + out->SetPort(in->Port); + } + + if (in->HostnameSize != 0) { + out->SetHostname(TString(reinterpret_cast<char*>(in->Hostname), in->HostnameSize)); + } + + if (in->AuthTokenSize != 0) { + out->SetAuthToken(TString(reinterpret_cast<char*>(in->AuthToken), in->AuthTokenSize)); + } + + out->ShouldCheckSrc = in->DisableSrcCheck == 0; + out->ShouldCheckDefaultUid = in->DisableDefaultUIDCheck == 0; +} + +static void UnpackSettings( + TVM_UnittestSettings* in, + TMockedUpdater::TSettings* out) { + out->SelfTvmId = in->SelfId; + out->UserTicketEnv = static_cast<EBlackboxEnv>(in->BlackboxEnv); +} + +template <class TTicket> +static void PackScopes( + const TScopes& scopes, + TTicket* ticket, + TVM_MemPool* pool) { + if (scopes.empty()) { + return; + } + + pool->Scopes = ticket->Scopes = reinterpret_cast<TVM_String*>(malloc(scopes.size() * sizeof(TVM_String))); + + for (size_t i = 0; i < scopes.size(); i++) { + ticket->Scopes[i].Data = const_cast<char*>(scopes[i].data()); + ticket->Scopes[i].Size = scopes[i].size(); + } + ticket->ScopesSize = scopes.size(); +} + +static void PackUserTicket( + TCheckedUserTicket in, + TVM_UserTicket* out, + TVM_MemPool* pool, + TStringBuf originalStr) noexcept { + auto copy = new TCheckedUserTicket(std::move(in)); + pool->CheckedUserTicket = reinterpret_cast<void*>(copy); + + PackStr(copy->DebugInfo(), &out->DbgInfo, pool->DbgInfo); + PackStr(NUtils::RemoveTicketSignature(originalStr), &out->LogInfo, pool->LogInfo); + + out->Status = static_cast<int>(copy->GetStatus()); + if (out->Status != static_cast<int>(ETicketStatus::Ok)) { + return; + } + + out->DefaultUid = copy->GetDefaultUid(); + + const auto& uids = copy->GetUids(); + if (!uids.empty()) { + out->Uids = const_cast<TUid*>(uids.data()); + out->UidsSize = uids.size(); + } + + out->Env = static_cast<int>(copy->GetEnv()); + + PackScopes(copy->GetScopes(), out, pool); +} + +static void PackServiceTicket( + TCheckedServiceTicket in, + TVM_ServiceTicket* out, + TVM_MemPool* pool, + TStringBuf originalStr) noexcept { + auto copy = new TCheckedServiceTicket(std::move(in)); + pool->CheckedServiceTicket = reinterpret_cast<void*>(copy); + + PackStr(copy->DebugInfo(), &out->DbgInfo, pool->DbgInfo); + PackStr(NUtils::RemoveTicketSignature(originalStr), &out->LogInfo, pool->LogInfo); + + out->Status = static_cast<int>(copy->GetStatus()); + if (out->Status != static_cast<int>(ETicketStatus::Ok)) { + return; + } + + out->SrcId = copy->GetSrc(); + + auto issuer = copy->GetIssuerUid(); + if (issuer) { + out->IssuerUid = *issuer; + } +} + +template <class F> +static void CatchError(TVM_Error* err, TVM_MemPool* pool, const F& f) { + try { + f(); + } catch (const TMalformedTvmSecretException& ex) { + err->Code = 1; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TMalformedTvmKeysException& ex) { + err->Code = 2; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TEmptyTvmKeysException& ex) { + err->Code = 3; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TNotAllowedException& ex) { + err->Code = 4; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TBrokenTvmClientSettings& ex) { + err->Code = 5; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TMissingServiceTicket& ex) { + err->Code = 6; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TPermissionDenied& ex) { + err->Code = 7; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TRetriableException& ex) { + err->Code = 8; + err->Retriable = 1; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const std::exception& ex) { + err->Code = 8; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } +} + +namespace { + class TGoLogger: public ILogger { + public: + TGoLogger(int loggerHandle) + : LoggerHandle_(loggerHandle) + { + } + + void Log(int lvl, const TString& msg) override { + TVM_WriteToLog(LoggerHandle_, lvl, const_cast<char*>(msg.data()), msg.size()); + } + + private: + int LoggerHandle_; + }; + +} + +extern "C" void TVM_NewApiClient( + TVM_ApiSettings settings, + int loggerHandle, + void** handle, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + NTvmApi::TClientSettings realSettings; + UnpackSettings(&settings, &realSettings); + + realSettings.LibVersionPrefix = "go_"; + + auto client = new TTvmClient(realSettings, MakeIntrusive<TGoLogger>(loggerHandle)); + *handle = static_cast<void*>(client); + }); +} + +extern "C" void TVM_NewToolClient( + TVM_ToolSettings settings, + int loggerHandle, + void** handle, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + TString alias(reinterpret_cast<char*>(settings.Alias), settings.AliasSize); + NTvmTool::TClientSettings realSettings(alias); + UnpackSettings(&settings, &realSettings); + + auto client = new TTvmClient(realSettings, MakeIntrusive<TGoLogger>(loggerHandle)); + *handle = static_cast<void*>(client); + }); +} + +extern "C" void TVM_NewUnittestClient( + TVM_UnittestSettings settings, + void** handle, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + TMockedUpdater::TSettings realSettings; + UnpackSettings(&settings, &realSettings); + + auto client = new TTvmClient(MakeIntrusiveConst<TMockedUpdater>(realSettings)); + *handle = static_cast<void*>(client); + }); +} + +extern "C" void TVM_DestroyClient(void* handle) { + delete static_cast<TTvmClient*>(handle); +} + +extern "C" void TVM_GetStatus( + void* handle, + TVM_ClientStatus* status, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + + TClientStatus s = client->GetStatus(); + status->Status = static_cast<int>(s.GetCode()); + + PackStr(s.GetLastError(), &status->LastError, pool->LastError.Data); + }); +} + +extern "C" void TVM_CheckUserTicket( + void* handle, + unsigned char* ticketStr, int ticketSize, + int* env, + TVM_UserTicket* ticket, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + TStringBuf str(reinterpret_cast<char*>(ticketStr), ticketSize); + + TMaybe<EBlackboxEnv> optEnv; + if (env) { + optEnv = (EBlackboxEnv)*env; + } + + auto userTicket = client->CheckUserTicket(str, optEnv); + PackUserTicket(std::move(userTicket), ticket, pool, str); + }); +} + +extern "C" void TVM_CheckServiceTicket( + void* handle, + unsigned char* ticketStr, int ticketSize, + TVM_ServiceTicket* ticket, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + TStringBuf str(reinterpret_cast<char*>(ticketStr), ticketSize); + auto serviceTicket = client->CheckServiceTicket(str); + PackServiceTicket(std::move(serviceTicket), ticket, pool, str); + }); +} + +extern "C" void TVM_GetServiceTicket( + void* handle, + ui32 dstId, + char** ticket, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + auto ticketPtr = new TString(client->GetServiceTicketFor(dstId)); + + pool->TicketStr = reinterpret_cast<void*>(ticketPtr); + *ticket = const_cast<char*>(ticketPtr->c_str()); + }); +} + +extern "C" void TVM_GetServiceTicketForAlias( + void* handle, + unsigned char* alias, int aliasSize, + char** ticket, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + auto ticketPtr = new TString(client->GetServiceTicketFor(TString((char*)alias, aliasSize))); + + pool->TicketStr = reinterpret_cast<void*>(ticketPtr); + *ticket = const_cast<char*>(ticketPtr->c_str()); + }); +} + +extern "C" void TVM_GetRoles( + void* handle, + unsigned char* currentRevision, int currentRevisionSize, + char** raw, + int* rawSize, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + NTvmAuth::NRoles::TRolesPtr roles = client->GetRoles(); + + if (currentRevision && + roles->GetMeta().Revision == TStringBuf(reinterpret_cast<char*>(currentRevision), currentRevisionSize)) { + return; + } + + auto rawPtr = new TString(roles->GetRaw()); + + pool->RawRolesStr = reinterpret_cast<void*>(rawPtr); + *raw = const_cast<char*>(rawPtr->c_str()); + *rawSize = rawPtr->size(); + }); +} diff --git a/library/go/yandex/tvm/tvmauth/tvm.h b/library/go/yandex/tvm/tvmauth/tvm.h new file mode 100644 index 0000000000..f7c7a5b2bc --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tvm.h @@ -0,0 +1,192 @@ +#pragma once + +#include <util/system/types.h> + +#include <stdint.h> +#include <time.h> + +#ifdef __cplusplus +extern "C" { +#endif + + typedef struct _TVM_String { + char* Data; + int Size; + } TVM_String; + + // MemPool owns memory allocated by C. + typedef struct { + char* ErrorStr; + void* TicketStr; + void* RawRolesStr; + TVM_String* Scopes; + void* CheckedUserTicket; + void* CheckedServiceTicket; + char* DbgInfo; + char* LogInfo; + TVM_String LastError; + } TVM_MemPool; + + void TVM_DestroyMemPool(TVM_MemPool* pool); + + typedef struct { + int Code; + int Retriable; + + TVM_String Message; + } TVM_Error; + + typedef struct { + int Status; + + ui64 DefaultUid; + + ui64* Uids; + int UidsSize; + + int Env; + + TVM_String* Scopes; + int ScopesSize; + + TVM_String DbgInfo; + TVM_String LogInfo; + } TVM_UserTicket; + + typedef struct { + int Status; + + ui32 SrcId; + + ui64 IssuerUid; + + TVM_String DbgInfo; + TVM_String LogInfo; + } TVM_ServiceTicket; + + typedef struct { + ui32 SelfId; + + int EnableServiceTicketChecking; + + int EnableUserTicketChecking; + int BlackboxEnv; + + unsigned char* SelfSecret; + int SelfSecretSize; + unsigned char* DstAliases; + int DstAliasesSize; + + unsigned char* IdmSystemSlug; + int IdmSystemSlugSize; + int DisableSrcCheck; + int DisableDefaultUIDCheck; + + unsigned char* TVMHost; + int TVMHostSize; + int TVMPort; + unsigned char* TiroleHost; + int TiroleHostSize; + int TirolePort; + ui32 TiroleTvmId; + + unsigned char* DiskCacheDir; + int DiskCacheDirSize; + } TVM_ApiSettings; + + typedef struct { + unsigned char* Alias; + int AliasSize; + + int Port; + + unsigned char* Hostname; + int HostnameSize; + + unsigned char* AuthToken; + int AuthTokenSize; + + int DisableSrcCheck; + int DisableDefaultUIDCheck; + } TVM_ToolSettings; + + typedef struct { + ui32 SelfId; + int BlackboxEnv; + } TVM_UnittestSettings; + + typedef struct { + int Status; + TVM_String LastError; + } TVM_ClientStatus; + + // First argument must be passed by value. "Go code may pass a Go pointer to C + // provided the Go memory to which it points does not contain any Go pointers." + void TVM_NewApiClient( + TVM_ApiSettings settings, + int loggerHandle, + void** handle, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_NewToolClient( + TVM_ToolSettings settings, + int loggerHandle, + void** handle, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_NewUnittestClient( + TVM_UnittestSettings settings, + void** handle, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_DestroyClient(void* handle); + + void TVM_GetStatus( + void* handle, + TVM_ClientStatus* status, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_CheckUserTicket( + void* handle, + unsigned char* ticketStr, int ticketSize, + int* env, + TVM_UserTicket* ticket, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_CheckServiceTicket( + void* handle, + unsigned char* ticketStr, int ticketSize, + TVM_ServiceTicket* ticket, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_GetServiceTicket( + void* handle, + ui32 dstId, + char** ticket, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_GetServiceTicketForAlias( + void* handle, + unsigned char* alias, int aliasSize, + char** ticket, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_GetRoles( + void* handle, + unsigned char* currentRevision, int currentRevisionSize, + char** raw, + int* rawSize, + TVM_Error* err, + TVM_MemPool* pool); + +#ifdef __cplusplus +} +#endif diff --git a/library/go/yandex/tvm/tvmauth/types.go b/library/go/yandex/tvm/tvmauth/types.go new file mode 100644 index 0000000000..e9df007ad1 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/types.go @@ -0,0 +1,139 @@ +package tvmauth + +import ( + "sync" + "unsafe" + + "a.yandex-team.ru/library/go/yandex/tvm" +) + +// TvmAPISettings may be used to fetch data from tvm-api +type TvmAPISettings struct { + // SelfID is required for ServiceTicketOptions and EnableServiceTicketChecking + SelfID tvm.ClientID + + // ServiceTicketOptions provides info for fetching Service Tickets from tvm-api + // to allow you send them to your backends. + // + // WARNING: It is not way to provide authorization for incoming ServiceTickets! + // It is way only to send your ServiceTickets to your backend! + ServiceTicketOptions *TVMAPIOptions + + // EnableServiceTicketChecking enables fetching of public keys for signature checking + EnableServiceTicketChecking bool + + // BlackboxEnv with not nil value enables UserTicket checking + // and enables fetching of public keys for signature checking + BlackboxEnv *tvm.BlackboxEnv + + fetchRolesForIdmSystemSlug []byte + // Non-empty FetchRolesForIdmSystemSlug enables roles fetching from tirole + FetchRolesForIdmSystemSlug string + // By default, client checks src from ServiceTicket or default uid from UserTicket - + // to prevent you from forgetting to check it yourself. + // It does binary checks only: + // ticket gets status NoRoles, if there is no role for src or default uid. + // You need to check roles on your own if you have a non-binary role system or + // you have switched DisableSrcCheck/DisableDefaultUIDCheck + // + // You may need to disable this check in the following cases: + // - You use GetRoles() to provide verbose message (with revision). + // Double check may be inconsistent: + // binary check inside client uses revision of roles X - i.e. src 100500 has no role, + // exact check in your code uses revision of roles Y - i.e. src 100500 has some roles. + DisableSrcCheck bool + // See comment for DisableSrcCheck + DisableDefaultUIDCheck bool + + tvmHost []byte + // TVMHost should be used only in tests + TVMHost string + // TVMPort should be used only in tests + TVMPort int + + tiroleHost []byte + // TiroleHost should be used only in tests or for tirole-api-test.yandex.net + TiroleHost string + // TirolePort should be used only in tests + TirolePort int + // TiroleTvmID should be used only in tests or for tirole-api-test.yandex.net + TiroleTvmID tvm.ClientID + + // Directory for disk cache. + // Requires read/write permissions. Permissions will be checked before start. + // WARNING: The same directory can be used only: + // - for TVM clients with the same settings + // OR + // - for new client replacing previous - with another config. + // System user must be the same for processes with these clients inside. + // Implementation doesn't provide other scenarios. + DiskCacheDir string + diskCacheDir []byte +} + +// TVMAPIOptions is part of TvmAPISettings: allows to enable fetching of ServiceTickets +type TVMAPIOptions struct { + selfSecret string + selfSecretB []byte + dstAliases []byte +} + +// TvmToolSettings may be used to fetch data from tvmtool +type TvmToolSettings struct { + // Alias is required: self alias of your tvm ClientID + Alias string + alias []byte + + // By default, client checks src from ServiceTicket or default uid from UserTicket - + // to prevent you from forgetting to check it yourself. + // It does binary checks only: + // ticket gets status NoRoles, if there is no role for src or default uid. + // You need to check roles on your own if you have a non-binary role system or + // you have switched DisableSrcCheck/DisableDefaultUIDCheck + // + // You may need to disable this check in the following cases: + // - You use GetRoles() to provide verbose message (with revision). + // Double check may be inconsistent: + // binary check inside client uses revision of roles X - i.e. src 100500 has no role, + // exact check in your code uses revision of roles Y - i.e. src 100500 has some roles. + DisableSrcCheck bool + // See comment for DisableSrcCheck + DisableDefaultUIDCheck bool + + // Port will be detected with env["DEPLOY_TVM_TOOL_URL"] (provided with Yandex.Deploy), + // otherwise port == 1 (it is ok for Qloud) + Port int + // Hostname == "localhost" by default + Hostname string + hostname []byte + + // AuthToken is protection from SSRF. + // By default it is fetched from env: + // * TVMTOOL_LOCAL_AUTHTOKEN (provided with Yandex.Deploy) + // * QLOUD_TVM_TOKEN (provided with Qloud) + AuthToken string + authToken []byte +} + +type TvmUnittestSettings struct { + // SelfID is required for service ticket checking + SelfID tvm.ClientID + + // Service ticket checking is enabled by default + + // User ticket checking is enabled by default: choose required environment + BlackboxEnv tvm.BlackboxEnv + + // Other features are not supported yet +} + +// Client contains raw pointer for C++ object +type Client struct { + handle unsafe.Pointer + logger *int + + roles *tvm.Roles + mutex *sync.RWMutex +} + +var _ tvm.Client = (*Client)(nil) diff --git a/library/go/yandex/tvm/tvmtool/any.go b/library/go/yandex/tvm/tvmtool/any.go new file mode 100644 index 0000000000..5c394af771 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/any.go @@ -0,0 +1,37 @@ +package tvmtool + +import ( + "os" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +const ( + LocalEndpointEnvKey = "TVMTOOL_URL" + LocalTokenEnvKey = "TVMTOOL_LOCAL_AUTHTOKEN" +) + +var ErrUnknownTvmtoolEnvironment = xerrors.NewSentinel("unknown tvmtool environment") + +// NewAnyClient method creates a new tvmtool client with environment auto-detection. +// You must reuse it to prevent connection/goroutines leakage. +func NewAnyClient(opts ...Option) (*Client, error) { + switch { + case os.Getenv(QloudEndpointEnvKey) != "": + // it's Qloud + return NewQloudClient(opts...) + case os.Getenv(DeployEndpointEnvKey) != "": + // it's Y.Deploy + return NewDeployClient(opts...) + case os.Getenv(LocalEndpointEnvKey) != "": + passedOpts := append( + []Option{ + WithAuthToken(os.Getenv(LocalTokenEnvKey)), + }, + opts..., + ) + return NewClient(os.Getenv(LocalEndpointEnvKey), passedOpts...) + default: + return nil, ErrUnknownTvmtoolEnvironment.WithFrame() + } +} diff --git a/library/go/yandex/tvm/tvmtool/any_example_test.go b/library/go/yandex/tvm/tvmtool/any_example_test.go new file mode 100644 index 0000000000..d5959426bc --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/any_example_test.go @@ -0,0 +1,70 @@ +package tvmtool_test + +import ( + "context" + "fmt" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewAnyClient_simple() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewAnyClient(tvmtool.WithLogger(zlog)) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.TODO(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} + +func ExampleNewAnyClient_custom() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewAnyClient( + tvmtool.WithSrc("second_app"), + tvmtool.WithLogger(zlog), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/clients_test.go b/library/go/yandex/tvm/tvmtool/clients_test.go new file mode 100644 index 0000000000..5bf34b93fd --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/clients_test.go @@ -0,0 +1,154 @@ +//go:build linux || darwin +// +build linux darwin + +package tvmtool_test + +import ( + "os" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +func TestNewClients(t *testing.T) { + type TestCase struct { + env map[string]string + willFail bool + expectedErr string + expectedBaseURI string + expectedAuthToken string + } + + cases := map[string]struct { + constructor func(opts ...tvmtool.Option) (*tvmtool.Client, error) + cases map[string]TestCase + }{ + "qloud": { + constructor: tvmtool.NewQloudClient, + cases: map[string]TestCase{ + "no-auth": { + willFail: true, + expectedErr: "empty auth token (looked at ENV[QLOUD_TVM_TOKEN])", + }, + "ok-default-origin": { + env: map[string]string{ + "QLOUD_TVM_TOKEN": "ok-default-origin-token", + }, + willFail: false, + expectedBaseURI: "http://localhost:1/tvm", + expectedAuthToken: "ok-default-origin-token", + }, + "ok-custom-origin": { + env: map[string]string{ + "QLOUD_TVM_INTERFACE_ORIGIN": "http://localhost:9000", + "QLOUD_TVM_TOKEN": "ok-custom-origin-token", + }, + willFail: false, + expectedBaseURI: "http://localhost:9000/tvm", + expectedAuthToken: "ok-custom-origin-token", + }, + }, + }, + "deploy": { + constructor: tvmtool.NewDeployClient, + cases: map[string]TestCase{ + "no-url": { + willFail: true, + expectedErr: "empty tvmtool url (looked at ENV[DEPLOY_TVM_TOOL_URL])", + }, + "no-auth": { + env: map[string]string{ + "DEPLOY_TVM_TOOL_URL": "http://localhost:2", + }, + willFail: true, + expectedErr: "empty auth token (looked at ENV[TVMTOOL_LOCAL_AUTHTOKEN])", + }, + "ok": { + env: map[string]string{ + "DEPLOY_TVM_TOOL_URL": "http://localhost:1337", + "TVMTOOL_LOCAL_AUTHTOKEN": "ok-token", + }, + willFail: false, + expectedBaseURI: "http://localhost:1337/tvm", + expectedAuthToken: "ok-token", + }, + }, + }, + "any": { + constructor: tvmtool.NewAnyClient, + cases: map[string]TestCase{ + "empty": { + willFail: true, + expectedErr: "unknown tvmtool environment", + }, + "ok-qloud": { + env: map[string]string{ + "QLOUD_TVM_INTERFACE_ORIGIN": "http://qloud:9000", + "QLOUD_TVM_TOKEN": "ok-qloud", + }, + expectedBaseURI: "http://qloud:9000/tvm", + expectedAuthToken: "ok-qloud", + }, + "ok-deploy": { + env: map[string]string{ + "DEPLOY_TVM_TOOL_URL": "http://deploy:1337", + "TVMTOOL_LOCAL_AUTHTOKEN": "ok-deploy", + }, + expectedBaseURI: "http://deploy:1337/tvm", + expectedAuthToken: "ok-deploy", + }, + "ok-local": { + env: map[string]string{ + "TVMTOOL_URL": "http://local:1338", + "TVMTOOL_LOCAL_AUTHTOKEN": "ok-local", + }, + willFail: false, + expectedBaseURI: "http://local:1338/tvm", + expectedAuthToken: "ok-local", + }, + }, + }, + } + + // NB! this checks are not thread safe, never use t.Parallel() and so on + for clientName, client := range cases { + t.Run(clientName, func(t *testing.T) { + for name, tc := range client.cases { + t.Run(name, func(t *testing.T) { + savedEnv := os.Environ() + defer func() { + os.Clearenv() + for _, env := range savedEnv { + parts := strings.SplitN(env, "=", 2) + err := os.Setenv(parts[0], parts[1]) + require.NoError(t, err) + } + }() + + os.Clearenv() + for key, val := range tc.env { + _ = os.Setenv(key, val) + } + + tvmClient, err := client.constructor() + if tc.willFail { + require.Error(t, err) + if tc.expectedErr != "" { + require.EqualError(t, err, tc.expectedErr) + } + + require.Nil(t, tvmClient) + } else { + require.NoError(t, err) + require.NotNil(t, tvmClient) + require.Equal(t, tc.expectedBaseURI, tvmClient.BaseURI()) + require.Equal(t, tc.expectedAuthToken, tvmClient.AuthToken()) + } + }) + } + }) + } +} diff --git a/library/go/yandex/tvm/tvmtool/deploy.go b/library/go/yandex/tvm/tvmtool/deploy.go new file mode 100644 index 0000000000..d7a2eac62b --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/deploy.go @@ -0,0 +1,31 @@ +package tvmtool + +import ( + "fmt" + "os" +) + +const ( + DeployEndpointEnvKey = "DEPLOY_TVM_TOOL_URL" + DeployTokenEnvKey = "TVMTOOL_LOCAL_AUTHTOKEN" +) + +// NewDeployClient method creates a new tvmtool client for Deploy environment. +// You must reuse it to prevent connection/goroutines leakage. +func NewDeployClient(opts ...Option) (*Client, error) { + baseURI := os.Getenv(DeployEndpointEnvKey) + if baseURI == "" { + return nil, fmt.Errorf("empty tvmtool url (looked at ENV[%s])", DeployEndpointEnvKey) + } + + authToken := os.Getenv(DeployTokenEnvKey) + if authToken == "" { + return nil, fmt.Errorf("empty auth token (looked at ENV[%s])", DeployTokenEnvKey) + } + + opts = append([]Option{WithAuthToken(authToken)}, opts...) + return NewClient( + baseURI, + opts..., + ) +} diff --git a/library/go/yandex/tvm/tvmtool/deploy_example_test.go b/library/go/yandex/tvm/tvmtool/deploy_example_test.go new file mode 100644 index 0000000000..d352336d58 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/deploy_example_test.go @@ -0,0 +1,70 @@ +package tvmtool_test + +import ( + "context" + "fmt" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewDeployClient_simple() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewDeployClient(tvmtool.WithLogger(zlog)) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.TODO(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} + +func ExampleNewDeployClient_custom() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewDeployClient( + tvmtool.WithSrc("second_app"), + tvmtool.WithLogger(zlog), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/doc.go b/library/go/yandex/tvm/tvmtool/doc.go new file mode 100644 index 0000000000..d46dca8132 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/doc.go @@ -0,0 +1,7 @@ +// Pure Go implementation of tvm-interface based on TVMTool client. +// +// https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/. +// Package allows you to get service/user TVM-tickets, as well as check them. +// This package can provide fast getting of service tickets (from cache), other cases lead to http request to localhost. +// Also this package provides TVM client for Qloud (NewQloudClient) and Yandex.Deploy (NewDeployClient) environments. +package tvmtool diff --git a/library/go/yandex/tvm/tvmtool/errors.go b/library/go/yandex/tvm/tvmtool/errors.go new file mode 100644 index 0000000000..f0b08a9878 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/errors.go @@ -0,0 +1,61 @@ +package tvmtool + +import ( + "fmt" + + "a.yandex-team.ru/library/go/yandex/tvm" +) + +// Generic TVM errors, before retry any request it check .Retriable field. +type Error = tvm.Error + +const ( + // ErrorAuthFail - auth failed, probably you provides invalid auth token + ErrorAuthFail = tvm.ErrorAuthFail + // ErrorBadRequest - tvmtool rejected our request, check .Msg for details + ErrorBadRequest = tvm.ErrorBadRequest + // ErrorOther - any other TVM-related errors, check .Msg for details + ErrorOther = tvm.ErrorOther +) + +// Ticket validation error +type TicketError = tvm.TicketError + +const ( + TicketErrorInvalidScopes = tvm.TicketInvalidScopes + TicketErrorOther = tvm.TicketStatusOther +) + +type PingCode uint32 + +const ( + PingCodeDie = iota + PingCodeWarning + PingCodeError + PingCodeOther +) + +func (e PingCode) String() string { + switch e { + case PingCodeDie: + return "HttpDie" + case PingCodeWarning: + return "Warning" + case PingCodeError: + return "Error" + case PingCodeOther: + return "Other" + default: + return fmt.Sprintf("Unknown%d", e) + } +} + +// Special ping error +type PingError struct { + Code PingCode + Err error +} + +func (e *PingError) Error() string { + return fmt.Sprintf("tvm: %s (code %s)", e.Err.Error(), e.Code) +} diff --git a/library/go/yandex/tvm/tvmtool/examples/check_tickets/main.go b/library/go/yandex/tvm/tvmtool/examples/check_tickets/main.go new file mode 100644 index 0000000000..95fcc0bd51 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/examples/check_tickets/main.go @@ -0,0 +1,68 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +var ( + baseURI = "http://localhost:3000" + srvTicket string + userTicket string +) + +func main() { + flag.StringVar(&baseURI, "tool-uri", baseURI, "TVM tool uri") + flag.StringVar(&srvTicket, "srv", "", "service ticket to check") + flag.StringVar(&userTicket, "usr", "", "user ticket to check") + flag.Parse() + + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + auth := os.Getenv("TVMTOOL_LOCAL_AUTHTOKEN") + if auth == "" { + zlog.Fatal("Please provide tvm-tool auth in env[TVMTOOL_LOCAL_AUTHTOKEN]") + return + } + + tvmClient, err := tvmtool.NewClient( + baseURI, + tvmtool.WithAuthToken(auth), + tvmtool.WithLogger(zlog), + ) + if err != nil { + zlog.Fatal("failed create tvm client", log.Error(err)) + return + } + defer tvmClient.Close() + + fmt.Printf("------ Check service ticket ------\n\n") + srvCheck, err := tvmClient.CheckServiceTicket(context.Background(), srvTicket) + if err != nil { + fmt.Printf("Failed\nTicket: %s\nError: %s\n", srvCheck, err) + } else { + fmt.Printf("OK\nInfo: %s\n", srvCheck) + } + + if userTicket == "" { + return + } + + fmt.Printf("\n------ Check user ticket result ------\n\n") + + usrCheck, err := tvmClient.CheckUserTicket(context.Background(), userTicket) + if err != nil { + fmt.Printf("Failed\nTicket: %s\nError: %s\n", usrCheck, err) + return + } + fmt.Printf("OK\nInfo: %s\n", usrCheck) +} diff --git a/library/go/yandex/tvm/tvmtool/examples/get_service_ticket/main.go b/library/go/yandex/tvm/tvmtool/examples/get_service_ticket/main.go new file mode 100644 index 0000000000..2abfca8bfb --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/examples/get_service_ticket/main.go @@ -0,0 +1,53 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +var ( + baseURI = "http://localhost:3000" + dst = "dst" +) + +func main() { + flag.StringVar(&baseURI, "tool-uri", baseURI, "TVM tool uri") + flag.StringVar(&dst, "dst", dst, "Destination TVM app (must be configured in tvm-tool)") + flag.Parse() + + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + auth := os.Getenv("TVMTOOL_LOCAL_AUTHTOKEN") + if auth == "" { + zlog.Fatal("Please provide tvm-tool auth in env[TVMTOOL_LOCAL_AUTHTOKEN]") + return + } + + tvmClient, err := tvmtool.NewClient( + baseURI, + tvmtool.WithAuthToken(auth), + tvmtool.WithLogger(zlog), + ) + if err != nil { + zlog.Fatal("failed create tvm client", log.Error(err)) + return + } + defer tvmClient.Close() + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), dst) + if err != nil { + zlog.Fatal("failed to get tvm ticket", log.String("dst", dst), log.Error(err)) + return + } + + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/gotest/tvmtool.conf.json b/library/go/yandex/tvm/tvmtool/gotest/tvmtool.conf.json new file mode 100644 index 0000000000..db768f5d53 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/gotest/tvmtool.conf.json @@ -0,0 +1,32 @@ +{ + "BbEnvType": 3, + "clients": { + "main": { + "secret": "fake_secret", + "self_tvm_id": 42, + "dsts": { + "he": { + "dst_id": 100500 + }, + "he_clone": { + "dst_id": 100500 + }, + "slave": { + "dst_id": 43 + }, + "self": { + "dst_id": 42 + } + } + }, + "slave": { + "secret": "fake_secret", + "self_tvm_id": 43, + "dsts": { + "he": { + "dst_id": 100500 + } + } + } + } +}
\ No newline at end of file diff --git a/library/go/yandex/tvm/tvmtool/internal/cache/cache.go b/library/go/yandex/tvm/tvmtool/internal/cache/cache.go new file mode 100644 index 0000000000..b625ca774f --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/internal/cache/cache.go @@ -0,0 +1,128 @@ +package cache + +import ( + "sync" + "time" + + "a.yandex-team.ru/library/go/yandex/tvm" +) + +const ( + Hit Status = iota + Miss + GonnaMissy +) + +type ( + Status int + + Cache struct { + ttl time.Duration + maxTTL time.Duration + tickets map[tvm.ClientID]entry + aliases map[string]tvm.ClientID + lock sync.RWMutex + } + + entry struct { + value *string + born time.Time + } +) + +func New(ttl, maxTTL time.Duration) *Cache { + return &Cache{ + ttl: ttl, + maxTTL: maxTTL, + tickets: make(map[tvm.ClientID]entry, 1), + aliases: make(map[string]tvm.ClientID, 1), + } +} + +func (c *Cache) Gc() { + now := time.Now() + + c.lock.Lock() + defer c.lock.Unlock() + for clientID, ticket := range c.tickets { + if ticket.born.Add(c.maxTTL).After(now) { + continue + } + + delete(c.tickets, clientID) + for alias, aClientID := range c.aliases { + if clientID == aClientID { + delete(c.aliases, alias) + } + } + } +} + +func (c *Cache) ClientIDs() []tvm.ClientID { + c.lock.RLock() + defer c.lock.RUnlock() + + clientIDs := make([]tvm.ClientID, 0, len(c.tickets)) + for clientID := range c.tickets { + clientIDs = append(clientIDs, clientID) + } + return clientIDs +} + +func (c *Cache) Aliases() []string { + c.lock.RLock() + defer c.lock.RUnlock() + + aliases := make([]string, 0, len(c.aliases)) + for alias := range c.aliases { + aliases = append(aliases, alias) + } + return aliases +} + +func (c *Cache) Load(clientID tvm.ClientID) (*string, Status) { + c.lock.RLock() + e, ok := c.tickets[clientID] + c.lock.RUnlock() + if !ok { + return nil, Miss + } + + now := time.Now() + exp := e.born.Add(c.ttl) + if exp.After(now) { + return e.value, Hit + } + + exp = e.born.Add(c.maxTTL) + if exp.After(now) { + return e.value, GonnaMissy + } + + c.lock.Lock() + delete(c.tickets, clientID) + c.lock.Unlock() + return nil, Miss +} + +func (c *Cache) LoadByAlias(alias string) (*string, Status) { + c.lock.RLock() + clientID, ok := c.aliases[alias] + c.lock.RUnlock() + if !ok { + return nil, Miss + } + + return c.Load(clientID) +} + +func (c *Cache) Store(clientID tvm.ClientID, alias string, value *string) { + c.lock.Lock() + defer c.lock.Unlock() + + c.aliases[alias] = clientID + c.tickets[clientID] = entry{ + value: value, + born: time.Now(), + } +} diff --git a/library/go/yandex/tvm/tvmtool/internal/cache/cache_test.go b/library/go/yandex/tvm/tvmtool/internal/cache/cache_test.go new file mode 100644 index 0000000000..d9a1780108 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/internal/cache/cache_test.go @@ -0,0 +1,125 @@ +package cache_test + +import ( + "sort" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool/internal/cache" +) + +var ( + testDst = "test_dst" + testDstAlias = "test_dst_alias" + testDstID = tvm.ClientID(1) + testValue = "test_val" +) + +func TestNewAtHour(t *testing.T) { + c := cache.New(time.Hour, 11*time.Hour) + assert.NotNil(t, c, "failed to create cache") +} + +func TestCache_Load(t *testing.T) { + + c := cache.New(time.Second, time.Hour) + c.Store(testDstID, testDst, &testValue) + // checking before + { + r, hit := c.Load(testDstID) + assert.Equal(t, cache.Hit, hit, "failed to get '%d' from cache before deadline", testDstID) + assert.NotNil(t, r, "failed to get '%d' from cache before deadline", testDstID) + assert.Equal(t, testValue, *r) + + r, hit = c.LoadByAlias(testDst) + assert.Equal(t, cache.Hit, hit, "failed to get '%s' from cache before deadline", testDst) + assert.NotNil(t, r, "failed to get %q from tickets before deadline", testDst) + assert.Equal(t, testValue, *r) + } + { + r, hit := c.Load(999833321) + assert.Equal(t, cache.Miss, hit, "got tickets for '999833321', but that key must be never existed") + assert.Nil(t, r, "got tickets for '999833321', but that key must be never existed") + + r, hit = c.LoadByAlias("kek") + assert.Equal(t, cache.Miss, hit, "got tickets for 'kek', but that key must be never existed") + assert.Nil(t, r, "got tickets for 'kek', but that key must be never existed") + } + + time.Sleep(3 * time.Second) + // checking after + { + r, hit := c.Load(testDstID) + assert.Equal(t, cache.GonnaMissy, hit) + assert.Equal(t, testValue, *r) + + r, hit = c.LoadByAlias(testDst) + assert.Equal(t, cache.GonnaMissy, hit) + assert.Equal(t, testValue, *r) + } +} + +func TestCache_Keys(t *testing.T) { + c := cache.New(time.Second, time.Hour) + c.Store(testDstID, testDst, &testValue) + c.Store(testDstID, testDstAlias, &testValue) + + t.Run("aliases", func(t *testing.T) { + aliases := c.Aliases() + sort.Strings(aliases) + require.Equal(t, 2, len(aliases), "not correct length of aliases") + require.EqualValues(t, []string{testDst, testDstAlias}, aliases) + }) + + t.Run("client_ids", func(t *testing.T) { + ids := c.ClientIDs() + require.Equal(t, 1, len(ids), "not correct length of client ids") + require.EqualValues(t, []tvm.ClientID{testDstID}, ids) + }) +} + +func TestCache_ExpiredKeys(t *testing.T) { + c := cache.New(time.Second, 10*time.Second) + c.Store(testDstID, testDst, &testValue) + c.Store(testDstID, testDstAlias, &testValue) + + time.Sleep(3 * time.Second) + c.Gc() + + var ( + newDst = "new_dst" + newDstID = tvm.ClientID(2) + ) + c.Store(newDstID, newDst, &testValue) + + t.Run("aliases", func(t *testing.T) { + aliases := c.Aliases() + require.Equal(t, 3, len(aliases), "not correct length of aliases") + require.ElementsMatch(t, []string{testDst, testDstAlias, newDst}, aliases) + }) + + t.Run("client_ids", func(t *testing.T) { + ids := c.ClientIDs() + require.Equal(t, 2, len(ids), "not correct length of client ids") + require.ElementsMatch(t, []tvm.ClientID{testDstID, newDstID}, ids) + }) + + time.Sleep(8 * time.Second) + c.Gc() + + t.Run("aliases", func(t *testing.T) { + aliases := c.Aliases() + require.Equal(t, 1, len(aliases), "not correct length of aliases") + require.ElementsMatch(t, []string{newDst}, aliases) + }) + + t.Run("client_ids", func(t *testing.T) { + ids := c.ClientIDs() + require.Equal(t, 1, len(ids), "not correct length of client ids") + require.ElementsMatch(t, []tvm.ClientID{newDstID}, ids) + }) +} diff --git a/library/go/yandex/tvm/tvmtool/opts.go b/library/go/yandex/tvm/tvmtool/opts.go new file mode 100644 index 0000000000..91d29139d8 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/opts.go @@ -0,0 +1,103 @@ +package tvmtool + +import ( + "context" + "net/http" + "strings" + "time" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/xerrors" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool/internal/cache" +) + +type ( + Option func(tool *Client) error +) + +// Source TVM client (id or alias) +// +// WARNING: id/alias must be configured in tvmtool. Documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#konfig +func WithSrc(src string) Option { + return func(tool *Client) error { + tool.src = src + return nil + } +} + +// Auth token +func WithAuthToken(token string) Option { + return func(tool *Client) error { + tool.authToken = token + return nil + } +} + +// Use custom HTTP client +func WithHTTPClient(client *http.Client) Option { + return func(tool *Client) error { + tool.ownHTTPClient = false + tool.httpClient = client + return nil + } +} + +// Enable or disable service tickets cache +// +// Enabled by default +func WithCacheEnabled(enabled bool) Option { + return func(tool *Client) error { + switch { + case enabled && tool.cache == nil: + tool.cache = cache.New(cacheTTL, cacheMaxTTL) + case !enabled: + tool.cache = nil + } + return nil + } +} + +// Overrides blackbox environment defined in config. +// +// Documentation about environment overriding: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/checkusr +func WithOverrideEnv(bbEnv tvm.BlackboxEnv) Option { + return func(tool *Client) error { + tool.bbEnv = strings.ToLower(bbEnv.String()) + return nil + } +} + +// WithLogger sets logger for tvm client. +func WithLogger(l log.Structured) Option { + return func(tool *Client) error { + tool.l = l + return nil + } +} + +// WithRefreshFrequency sets service tickets refresh frequency. +// Frequency must be lower chan cacheTTL (10 min) +// +// Default: 8 min +func WithRefreshFrequency(freq time.Duration) Option { + return func(tool *Client) error { + if freq > cacheTTL { + return xerrors.Errorf("refresh frequency must be lower than cacheTTL (%d > %d)", freq, cacheTTL) + } + + tool.refreshFreq = int64(freq.Seconds()) + return nil + } +} + +// WithBackgroundUpdate force Client to update all service ticket at background. +// You must manually cancel given ctx to stops refreshing. +// +// Default: disabled +func WithBackgroundUpdate(ctx context.Context) Option { + return func(tool *Client) error { + tool.bgCtx, tool.bgCancel = context.WithCancel(ctx) + return nil + } +} diff --git a/library/go/yandex/tvm/tvmtool/qloud.go b/library/go/yandex/tvm/tvmtool/qloud.go new file mode 100644 index 0000000000..4dcf0648db --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/qloud.go @@ -0,0 +1,32 @@ +package tvmtool + +import ( + "fmt" + "os" +) + +const ( + QloudEndpointEnvKey = "QLOUD_TVM_INTERFACE_ORIGIN" + QloudTokenEnvKey = "QLOUD_TVM_TOKEN" + QloudDefaultEndpoint = "http://localhost:1" +) + +// NewQloudClient method creates a new tvmtool client for Qloud environment. +// You must reuse it to prevent connection/goroutines leakage. +func NewQloudClient(opts ...Option) (*Client, error) { + baseURI := os.Getenv(QloudEndpointEnvKey) + if baseURI == "" { + baseURI = QloudDefaultEndpoint + } + + authToken := os.Getenv(QloudTokenEnvKey) + if authToken == "" { + return nil, fmt.Errorf("empty auth token (looked at ENV[%s])", QloudTokenEnvKey) + } + + opts = append([]Option{WithAuthToken(authToken)}, opts...) + return NewClient( + baseURI, + opts..., + ) +} diff --git a/library/go/yandex/tvm/tvmtool/qloud_example_test.go b/library/go/yandex/tvm/tvmtool/qloud_example_test.go new file mode 100644 index 0000000000..a6bfcbede6 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/qloud_example_test.go @@ -0,0 +1,71 @@ +package tvmtool_test + +import ( + "context" + "fmt" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewQloudClient_simple() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewQloudClient(tvmtool.WithLogger(zlog)) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} + +func ExampleNewQloudClient_custom() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewQloudClient( + tvmtool.WithSrc("second_app"), + tvmtool.WithOverrideEnv(tvm.BlackboxProd), + tvmtool.WithLogger(zlog), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/tool.go b/library/go/yandex/tvm/tvmtool/tool.go new file mode 100644 index 0000000000..0273902b6f --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool.go @@ -0,0 +1,530 @@ +package tvmtool + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "sync/atomic" + "time" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/core/xerrors" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool/internal/cache" +) + +const ( + dialTimeout = 100 * time.Millisecond + requestTimeout = 500 * time.Millisecond + keepAlive = 60 * time.Second + cacheTTL = 10 * time.Minute + cacheMaxTTL = 11 * time.Hour +) + +var _ tvm.Client = (*Client)(nil) + +type ( + Client struct { + lastSync int64 + baseURI string + src string + authToken string + bbEnv string + refreshFreq int64 + bgCtx context.Context + bgCancel context.CancelFunc + inFlightRefresh uint32 + cache *cache.Cache + pingRequest *http.Request + ownHTTPClient bool + httpClient *http.Client + l log.Structured + } + + ticketsResponse map[string]struct { + Error string `json:"error"` + Ticket string `json:"ticket"` + TvmID tvm.ClientID `json:"tvm_id"` + } + + checkSrvResponse struct { + SrcID tvm.ClientID `json:"src"` + Error string `json:"error"` + DbgInfo string `json:"debug_string"` + LogInfo string `json:"logging_string"` + } + + checkUserResponse struct { + DefaultUID tvm.UID `json:"default_uid"` + UIDs []tvm.UID `json:"uids"` + Scopes []string `json:"scopes"` + Error string `json:"error"` + DbgInfo string `json:"debug_string"` + LogInfo string `json:"logging_string"` + } +) + +// NewClient method creates a new tvmtool client. +// You must reuse it to prevent connection/goroutines leakage. +func NewClient(apiURI string, opts ...Option) (*Client, error) { + baseURI := strings.TrimRight(apiURI, "/") + "/tvm" + pingRequest, err := http.NewRequest("GET", baseURI+"/ping", nil) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to configure client: %w", err) + } + + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.DialContext = (&net.Dialer{ + Timeout: dialTimeout, + KeepAlive: keepAlive, + }).DialContext + + tool := &Client{ + baseURI: baseURI, + refreshFreq: 8 * 60, + cache: cache.New(cacheTTL, cacheMaxTTL), + pingRequest: pingRequest, + l: &nop.Logger{}, + ownHTTPClient: true, + httpClient: &http.Client{ + Transport: transport, + Timeout: requestTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + }, + } + + for _, opt := range opts { + if err := opt(tool); err != nil { + return nil, xerrors.Errorf("tvmtool: failed to configure client: %w", err) + } + } + + if tool.bgCtx != nil { + go tool.serviceTicketsRefreshLoop() + } + + return tool, nil +} + +// GetServiceTicketForAlias returns TVM service ticket for alias +// +// WARNING: alias must be configured in tvmtool +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/tickets +func (c *Client) GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) { + var ( + cachedTicket *string + cacheStatus = cache.Miss + ) + + if c.cache != nil { + c.refreshServiceTickets() + + if cachedTicket, cacheStatus = c.cache.LoadByAlias(alias); cacheStatus == cache.Hit { + return *cachedTicket, nil + } + } + + tickets, err := c.getServiceTickets(ctx, alias) + if err != nil { + if cachedTicket != nil && cacheStatus == cache.GonnaMissy { + return *cachedTicket, nil + } + return "", err + } + + entry, ok := tickets[alias] + if !ok { + return "", xerrors.Errorf("tvmtool: alias %q was not found in response", alias) + } + + if entry.Error != "" { + return "", &Error{Code: ErrorOther, Msg: entry.Error} + } + + ticket := entry.Ticket + if c.cache != nil { + c.cache.Store(entry.TvmID, alias, &ticket) + } + return ticket, nil +} + +// GetServiceTicketForID returns TVM service ticket for destination application id +// +// WARNING: id must be configured in tvmtool +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/tickets +func (c *Client) GetServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (string, error) { + var ( + cachedTicket *string + cacheStatus = cache.Miss + ) + + if c.cache != nil { + c.refreshServiceTickets() + + if cachedTicket, cacheStatus = c.cache.Load(dstID); cacheStatus == cache.Hit { + return *cachedTicket, nil + } + } + + alias := strconv.FormatUint(uint64(dstID), 10) + tickets, err := c.getServiceTickets(ctx, alias) + if err != nil { + if cachedTicket != nil && cacheStatus == cache.GonnaMissy { + return *cachedTicket, nil + } + return "", err + } + + entry, ok := tickets[alias] + if !ok { + // ok, let's find him + for candidateAlias, candidate := range tickets { + if candidate.TvmID == dstID { + entry = candidate + alias = candidateAlias + ok = true + break + } + } + + if !ok { + return "", xerrors.Errorf("tvmtool: dst %q was not found in response", alias) + } + } + + if entry.Error != "" { + return "", &Error{Code: ErrorOther, Msg: entry.Error} + } + + ticket := entry.Ticket + if c.cache != nil { + c.cache.Store(dstID, alias, &ticket) + } + return ticket, nil +} + +// Close stops background ticket updates (if configured) and closes idle connections. +func (c *Client) Close() { + if c.bgCancel != nil { + c.bgCancel() + } + + if c.ownHTTPClient { + c.httpClient.CloseIdleConnections() + } +} + +func (c *Client) refreshServiceTickets() { + if c.bgCtx != nil { + // service tickets will be updated at background in the separated goroutine + return + } + + now := time.Now().Unix() + if now-atomic.LoadInt64(&c.lastSync) > c.refreshFreq { + atomic.StoreInt64(&c.lastSync, now) + if atomic.CompareAndSwapUint32(&c.inFlightRefresh, 0, 1) { + go c.doServiceTicketsRefresh(context.Background()) + } + } +} + +func (c *Client) serviceTicketsRefreshLoop() { + var ticker = time.NewTicker(time.Duration(c.refreshFreq) * time.Second) + defer ticker.Stop() + for { + select { + case <-c.bgCtx.Done(): + return + case <-ticker.C: + c.doServiceTicketsRefresh(c.bgCtx) + } + } +} + +func (c *Client) doServiceTicketsRefresh(ctx context.Context) { + defer atomic.CompareAndSwapUint32(&c.inFlightRefresh, 1, 0) + + c.cache.Gc() + aliases := c.cache.Aliases() + if len(aliases) == 0 { + return + } + + c.l.Debug("tvmtool: service ticket update started") + defer c.l.Debug("tvmtool: service ticket update finished") + + // fast path: batch update, must work most of time + err := c.refreshServiceTicket(ctx, aliases...) + if err == nil { + return + } + + if tvmErr, ok := err.(*Error); ok && tvmErr.Code != ErrorBadRequest { + c.l.Error( + "tvmtool: failed to refresh all service tickets at background", + log.Strings("dsts", aliases), + log.Error(err), + ) + + // if we have non "bad request" error - something really terrible happens, nothing to do with it :( + // TODO(buglloc): implement adaptive refreshFreq based on errors? + return + } + + // slow path: trying to update service tickets one by one + c.l.Error( + "tvmtool: failed to refresh all service tickets at background, switched to slow path", + log.Strings("dsts", aliases), + log.Error(err), + ) + + for _, dst := range aliases { + if err := c.refreshServiceTicket(ctx, dst); err != nil { + c.l.Error( + "tvmtool: failed to refresh service ticket at background", + log.String("dst", dst), + log.Error(err), + ) + } + } +} + +func (c *Client) refreshServiceTicket(ctx context.Context, dsts ...string) error { + tickets, err := c.getServiceTickets(ctx, strings.Join(dsts, ",")) + if err != nil { + return err + } + + for _, dst := range dsts { + entry, ok := tickets[dst] + if !ok { + c.l.Error( + "tvmtool: destination was not found in tvmtool response", + log.String("dst", dst), + ) + continue + } + + if entry.Error != "" { + c.l.Error( + "tvmtool: failed to get service ticket for destination", + log.String("dst", dst), + log.String("err", entry.Error), + ) + continue + } + + c.cache.Store(entry.TvmID, dst, &entry.Ticket) + } + return nil +} + +func (c *Client) getServiceTickets(ctx context.Context, dst string) (ticketsResponse, error) { + params := url.Values{ + "dsts": {dst}, + } + if c.src != "" { + params.Set("src", c.src) + } + + req, err := http.NewRequest("GET", c.baseURI+"/tickets?"+params.Encode(), nil) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + req.Header.Set("Authorization", c.authToken) + + req = req.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + + var result ticketsResponse + err = readResponse(resp, &result) + return result, err +} + +// Check TVM service ticket +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/checksrv +func (c *Client) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + req, err := http.NewRequest("GET", c.baseURI+"/checksrv", nil) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + + if c.src != "" { + req.URL.RawQuery += "dst=" + url.QueryEscape(c.src) + } + req.Header.Set("Authorization", c.authToken) + req.Header.Set("X-Ya-Service-Ticket", ticket) + + req = req.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + + var result checkSrvResponse + if err = readResponse(resp, &result); err != nil { + return nil, err + } + + ticketInfo := &tvm.CheckedServiceTicket{ + SrcID: result.SrcID, + DbgInfo: result.DbgInfo, + LogInfo: result.LogInfo, + } + + if resp.StatusCode == http.StatusForbidden { + err = &TicketError{Status: TicketErrorOther, Msg: result.Error} + } + + return ticketInfo, err +} + +// Check TVM user ticket +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/checkusr +func (c *Client) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + for range opts { + panic("implement me") + } + + req, err := http.NewRequest("GET", c.baseURI+"/checkusr", nil) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + if c.bbEnv != "" { + req.URL.RawQuery += "override_env=" + url.QueryEscape(c.bbEnv) + } + req.Header.Set("Authorization", c.authToken) + req.Header.Set("X-Ya-User-Ticket", ticket) + + req = req.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + + var result checkUserResponse + if err = readResponse(resp, &result); err != nil { + return nil, err + } + + ticketInfo := &tvm.CheckedUserTicket{ + DefaultUID: result.DefaultUID, + UIDs: result.UIDs, + Scopes: result.Scopes, + DbgInfo: result.DbgInfo, + LogInfo: result.LogInfo, + } + + if resp.StatusCode == http.StatusForbidden { + err = &TicketError{Status: TicketErrorOther, Msg: result.Error} + } + + return ticketInfo, err +} + +// Checks TVMTool liveness +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/ping +func (c *Client) GetStatus(ctx context.Context) (tvm.ClientStatusInfo, error) { + req := c.pingRequest.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return tvm.ClientStatusInfo{Status: tvm.ClientError}, + &PingError{Code: PingCodeDie, Err: err} + } + defer func() { _ = resp.Body.Close() }() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return tvm.ClientStatusInfo{Status: tvm.ClientError}, + &PingError{Code: PingCodeDie, Err: err} + } + + var status tvm.ClientStatusInfo + switch resp.StatusCode { + case http.StatusOK: + // OK! + status = tvm.ClientStatusInfo{Status: tvm.ClientOK} + err = nil + case http.StatusPartialContent: + status = tvm.ClientStatusInfo{Status: tvm.ClientWarning} + err = &PingError{Code: PingCodeWarning, Err: xerrors.New(string(body))} + case http.StatusInternalServerError: + status = tvm.ClientStatusInfo{Status: tvm.ClientError} + err = &PingError{Code: PingCodeError, Err: xerrors.New(string(body))} + default: + status = tvm.ClientStatusInfo{Status: tvm.ClientError} + err = &PingError{Code: PingCodeOther, Err: xerrors.Errorf("tvmtool: unexpected status: %d", resp.StatusCode)} + } + return status, err +} + +// Returns TVMTool version +func (c *Client) Version(ctx context.Context) (string, error) { + req := c.pingRequest.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return "", xerrors.Errorf("tvmtool: failed to call tmvtool: %w", err) + } + _, _ = ioutil.ReadAll(resp.Body) + _ = resp.Body.Close() + + return resp.Header.Get("Server"), nil +} + +func (c *Client) GetRoles(ctx context.Context) (*tvm.Roles, error) { + return nil, errors.New("not implemented") +} + +func readResponse(resp *http.Response, dst interface{}) error { + body, err := ioutil.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + return xerrors.Errorf("tvmtool: failed to read response: %w", err) + } + + switch resp.StatusCode { + case http.StatusOK, http.StatusForbidden: + // ok + return json.Unmarshal(body, dst) + case http.StatusUnauthorized: + return &Error{ + Code: ErrorAuthFail, + Msg: string(body), + } + case http.StatusBadRequest: + return &Error{ + Code: ErrorBadRequest, + Msg: string(body), + } + case http.StatusInternalServerError: + return &Error{ + Code: ErrorOther, + Msg: string(body), + Retriable: true, + } + default: + return &Error{ + Code: ErrorOther, + Msg: fmt.Sprintf("tvmtool: unexpected status: %d, msg: %s", resp.StatusCode, string(body)), + } + } +} diff --git a/library/go/yandex/tvm/tvmtool/tool_bg_update_test.go b/library/go/yandex/tvm/tvmtool/tool_bg_update_test.go new file mode 100644 index 0000000000..e1b9f114c0 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool_bg_update_test.go @@ -0,0 +1,354 @@ +package tvmtool_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +func newMockClient(upstream string, options ...tvmtool.Option) (*tvmtool.Client, error) { + zlog, _ := zap.New(zap.ConsoleConfig(log.DebugLevel)) + options = append(options, tvmtool.WithLogger(zlog), tvmtool.WithAuthToken("token")) + return tvmtool.NewClient(upstream, options...) +} + +// TestClientBackgroundUpdate_Updatable checks that TVMTool client updates tickets state +func TestClientBackgroundUpdate_Updatable(t *testing.T) { + type TestCase struct { + client func(ctx context.Context, t *testing.T, url string) *tvmtool.Client + } + cases := map[string]TestCase{ + "async": { + client: func(ctx context.Context, t *testing.T, url string) *tvmtool.Client { + tvmClient, err := newMockClient(url, tvmtool.WithRefreshFrequency(500*time.Millisecond)) + require.NoError(t, err) + return tvmClient + }, + }, + "background": { + client: func(ctx context.Context, t *testing.T, url string) *tvmtool.Client { + tvmClient, err := newMockClient( + url, + tvmtool.WithRefreshFrequency(1*time.Second), + tvmtool.WithBackgroundUpdate(ctx), + ) + require.NoError(t, err) + return tvmClient + }, + }, + } + + tester := func(name string, tc TestCase) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + var ( + testDstAlias = "test" + testDstID = tvm.ClientID(2002456) + testTicket = atomic.NewString("3:serv:original-test-ticket:signature") + testFooDstAlias = "test_foo" + testFooDstID = tvm.ClientID(2002457) + testFooTicket = atomic.NewString("3:serv:original-test-foo-ticket:signature") + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/tvm/tickets", r.URL.Path) + assert.Equal(t, "token", r.Header.Get("Authorization")) + switch r.URL.RawQuery { + case "dsts=test", "dsts=test_foo", "dsts=test%2Ctest_foo", "dsts=test_foo%2Ctest": + // ok + case "dsts=2002456", "dsts=2002457", "dsts=2002456%2C2002457", "dsts=2002457%2C2002456": + // ok + default: + t.Errorf("unknown tvm-request query: %q", r.URL.RawQuery) + } + + w.Header().Set("Content-Type", "application/json") + rsp := map[string]struct { + Ticket string `json:"ticket"` + TVMID tvm.ClientID `json:"tvm_id"` + }{ + testDstAlias: { + Ticket: testTicket.Load(), + TVMID: testDstID, + }, + testFooDstAlias: { + Ticket: testFooTicket.Load(), + TVMID: testFooDstID, + }, + } + + err := json.NewEncoder(w).Encode(rsp) + assert.NoError(t, err) + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tvmClient := tc.client(ctx, t, srv.URL) + + requestTickets := func(mustEquals bool) { + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), testDstAlias) + require.NoError(t, err) + if mustEquals { + require.Equal(t, testTicket.Load(), ticket) + } + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), testDstID) + require.NoError(t, err) + if mustEquals { + require.Equal(t, testTicket.Load(), ticket) + } + + ticket, err = tvmClient.GetServiceTicketForAlias(context.Background(), testFooDstAlias) + require.NoError(t, err) + if mustEquals { + require.Equal(t, testFooTicket.Load(), ticket) + } + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), testFooDstID) + require.NoError(t, err) + if mustEquals { + require.Equal(t, testFooTicket.Load(), ticket) + } + } + + // populate tickets cache + requestTickets(true) + + // now change tickets + newTicket := "3:serv:changed-test-ticket:signature" + testTicket.Store(newTicket) + testFooTicket.Store("3:serv:changed-test-foo-ticket:signature") + + // wait some time + time.Sleep(2 * time.Second) + + // request new tickets + requestTickets(false) + + // and wait updates some time + for idx := 0; idx < 250; idx++ { + time.Sleep(100 * time.Millisecond) + ticket, _ := tvmClient.GetServiceTicketForAlias(context.Background(), testDstAlias) + if ticket == newTicket { + break + } + } + + // now out tvmclient MUST returns new tickets + requestTickets(true) + }) + } + + for name, tc := range cases { + tester(name, tc) + } +} + +// TestClientBackgroundUpdate_NotTooOften checks that TVMTool client request tvmtool not too often +func TestClientBackgroundUpdate_NotTooOften(t *testing.T) { + type TestCase struct { + client func(ctx context.Context, t *testing.T, url string) *tvmtool.Client + } + cases := map[string]TestCase{ + "async": { + client: func(ctx context.Context, t *testing.T, url string) *tvmtool.Client { + tvmClient, err := newMockClient(url, tvmtool.WithRefreshFrequency(20*time.Second)) + require.NoError(t, err) + return tvmClient + }, + }, + "background": { + client: func(ctx context.Context, t *testing.T, url string) *tvmtool.Client { + tvmClient, err := newMockClient( + url, + tvmtool.WithRefreshFrequency(20*time.Second), + tvmtool.WithBackgroundUpdate(ctx), + ) + require.NoError(t, err) + return tvmClient + }, + }, + } + + tester := func(name string, tc TestCase) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + var ( + reqCount = atomic.NewUint32(0) + testDstAlias = "test" + testDstID = tvm.ClientID(2002456) + testTicket = "3:serv:original-test-ticket:signature" + testFooDstAlias = "test_foo" + testFooDstID = tvm.ClientID(2002457) + testFooTicket = "3:serv:original-test-foo-ticket:signature" + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqCount.Add(1) + assert.Equal(t, "/tvm/tickets", r.URL.Path) + assert.Equal(t, "token", r.Header.Get("Authorization")) + switch r.URL.RawQuery { + case "dsts=test", "dsts=test_foo", "dsts=test%2Ctest_foo", "dsts=test_foo%2Ctest": + // ok + case "dsts=2002456", "dsts=2002457", "dsts=2002456%2C2002457", "dsts=2002457%2C2002456": + // ok + default: + t.Errorf("unknown tvm-request query: %q", r.URL.RawQuery) + } + + w.Header().Set("Content-Type", "application/json") + rsp := map[string]struct { + Ticket string `json:"ticket"` + TVMID tvm.ClientID `json:"tvm_id"` + }{ + testDstAlias: { + Ticket: testTicket, + TVMID: testDstID, + }, + testFooDstAlias: { + Ticket: testFooTicket, + TVMID: testFooDstID, + }, + } + + err := json.NewEncoder(w).Encode(rsp) + assert.NoError(t, err) + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tvmClient := tc.client(ctx, t, srv.URL) + + requestTickets := func() { + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), testDstAlias) + require.NoError(t, err) + require.Equal(t, testTicket, ticket) + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), testDstID) + require.NoError(t, err) + require.Equal(t, testTicket, ticket) + + ticket, err = tvmClient.GetServiceTicketForAlias(context.Background(), testFooDstAlias) + require.NoError(t, err) + require.Equal(t, testFooTicket, ticket) + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), testFooDstID) + require.NoError(t, err) + require.Equal(t, testFooTicket, ticket) + } + + // populate cache + requestTickets() + + // requests tickets some time that lower than refresh frequency + for i := 0; i < 10; i++ { + requestTickets() + time.Sleep(200 * time.Millisecond) + } + + require.Equal(t, uint32(2), reqCount.Load(), "tvmtool client calls tvmtool too many times") + }) + } + + for name, tc := range cases { + tester(name, tc) + } +} + +func TestClient_RefreshFrequency(t *testing.T) { + cases := map[string]struct { + freq time.Duration + err bool + }{ + "too_high": { + freq: 20 * time.Minute, + err: true, + }, + "ok": { + freq: 2 * time.Minute, + err: false, + }, + } + + for name, cs := range cases { + t.Run(name, func(t *testing.T) { + _, err := tvmtool.NewClient("fake", tvmtool.WithRefreshFrequency(cs.freq)) + if cs.err { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestClient_MultipleAliases(t *testing.T) { + reqCount := atomic.NewUint32(0) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqCount.Add(1) + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ +"test": {"ticket": "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature","tvm_id": 2002456}, +"test_alias": {"ticket": "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature","tvm_id": 2002456} +}`)) + })) + defer srv.Close() + + bgCtx, bgCancel := context.WithCancel(context.Background()) + defer bgCancel() + + tvmClient, err := newMockClient( + srv.URL, + tvmtool.WithRefreshFrequency(2*time.Second), + tvmtool.WithBackgroundUpdate(bgCtx), + ) + require.NoError(t, err) + + requestTickets := func(t *testing.T) { + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "test") + require.NoError(t, err) + require.Equal(t, "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature", ticket) + + ticket, err = tvmClient.GetServiceTicketForAlias(context.Background(), "test_alias") + require.NoError(t, err) + require.Equal(t, "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature", ticket) + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), tvm.ClientID(2002456)) + require.NoError(t, err) + require.Equal(t, "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature", ticket) + } + + t.Run("first", requestTickets) + + t.Run("check_requests", func(t *testing.T) { + // reqCount must be 2 - one for each aliases + require.Equal(t, uint32(2), reqCount.Load()) + }) + + // now wait GC + reqCount.Store(0) + time.Sleep(3 * time.Second) + + t.Run("after_gc", requestTickets) + t.Run("check_requests", func(t *testing.T) { + // reqCount must be 1 + require.Equal(t, uint32(1), reqCount.Load()) + }) +} diff --git a/library/go/yandex/tvm/tvmtool/tool_example_test.go b/library/go/yandex/tvm/tvmtool/tool_example_test.go new file mode 100644 index 0000000000..f3e482de91 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool_example_test.go @@ -0,0 +1,81 @@ +package tvmtool_test + +import ( + "context" + "fmt" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewClient() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewClient( + "http://localhost:9000", + tvmtool.WithAuthToken("auth-token"), + tvmtool.WithSrc("my-cool-app"), + tvmtool.WithLogger(zlog), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} + +func ExampleNewClient_backgroundServiceTicketsUpdate() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + bgCtx, bgCancel := context.WithCancel(context.Background()) + defer bgCancel() + + tvmClient, err := tvmtool.NewClient( + "http://localhost:9000", + tvmtool.WithAuthToken("auth-token"), + tvmtool.WithSrc("my-cool-app"), + tvmtool.WithLogger(zlog), + tvmtool.WithBackgroundUpdate(bgCtx), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/tool_export_test.go b/library/go/yandex/tvm/tvmtool/tool_export_test.go new file mode 100644 index 0000000000..7981a2db72 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool_export_test.go @@ -0,0 +1,9 @@ +package tvmtool + +func (c *Client) BaseURI() string { + return c.baseURI +} + +func (c *Client) AuthToken() string { + return c.authToken +} diff --git a/library/go/yandex/tvm/tvmtool/tool_test.go b/library/go/yandex/tvm/tvmtool/tool_test.go new file mode 100644 index 0000000000..4329e1d101 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool_test.go @@ -0,0 +1,255 @@ +//go:build linux || darwin +// +build linux darwin + +// tvmtool recipe exists only for linux & darwin so we skip another OSes +package tvmtool_test + +import ( + "context" + "fmt" + "io/ioutil" + "regexp" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +const ( + tvmToolPortFile = "tvmtool.port" + tvmToolAuthTokenFile = "tvmtool.authtoken" + userTicketFor1120000000038691 = "3:user" + + ":CA4Q__________9_GjUKCQijrpqRpdT-ARCjrpqRpdT-ARoMYmI6c2Vzc2lvbmlkGgl0ZXN0OnRlc3Qg0oXY" + + "zAQoAw:A-YI2yhoD7BbGU80_dKQ6vm7XADdvgD2QUFCeTI3XZ4MS4N8iENvsNDvYwsW89-vLQPv9pYqn8jxx" + + "awkvu_ZS2aAfpU8vXtnEHvzUQfes2kMjweRJE71cyX8B0VjENdXC5QAfGyK7Y0b4elTDJzw8b28Ro7IFFbNe" + + "qgcPInXndY" + serviceTicketFor41_42 = "3:serv:CBAQ__________9_IgQIKRAq" + + ":VVXL3wkhpBHB7OXSeG0IhqM5AP2CP-gJRD31ksAb-q7pmssBJKtPNbH34BSyLpBllmM1dgOfwL8ICUOGUA3l" + + "jOrwuxZ9H8ayfdrpM7q1-BVPE0sh0L9cd8lwZIW6yHejTe59s6wk1tG5MdSfncdaJpYiF3MwNHSRklNAkb6hx" + + "vg" + serviceTicketFor41_99 = "3:serv:CBAQ__________9_IgQIKRBj" + + ":PjJKDOsEk8VyxZFZwsVnKrW1bRyA82nGd0oIxnEFEf7DBTVZmNuxEejncDrMxnjkKwimrumV9POK4ptTo0ZPY" + + "6Du9zHR5QxekZYwDzFkECVrv9YT2QI03odwZJX8_WCpmlgI8hUog_9yZ5YCYxrQpWaOwDXx4T7VVMwH_Z9YTZk" +) + +var ( + srvTicketRe = regexp.MustCompile(`^3:serv:[A-Za-z0-9_\-]+:[A-Za-z0-9_\-]+$`) +) + +func newTvmToolClient(src string, authToken ...string) (*tvmtool.Client, error) { + raw, err := ioutil.ReadFile(tvmToolPortFile) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(string(raw)) + if err != nil { + return nil, err + } + + var auth string + if len(authToken) > 0 { + auth = authToken[0] + } else { + raw, err = ioutil.ReadFile(tvmToolAuthTokenFile) + if err != nil { + return nil, err + } + auth = string(raw) + } + + zlog, _ := zap.New(zap.ConsoleConfig(log.DebugLevel)) + + return tvmtool.NewClient( + fmt.Sprintf("http://localhost:%d", port), + tvmtool.WithAuthToken(auth), + tvmtool.WithCacheEnabled(false), + tvmtool.WithSrc(src), + tvmtool.WithLogger(zlog), + ) +} + +func TestNewClient(t *testing.T) { + client, err := newTvmToolClient("main") + require.NoError(t, err) + require.NotNil(t, client) +} + +func TestClient_GetStatus(t *testing.T) { + client, err := newTvmToolClient("main") + require.NoError(t, err) + status, err := client.GetStatus(context.Background()) + require.NoError(t, err, "ping must work") + require.Equal(t, tvm.ClientOK, status.Status) +} + +func TestClient_BadAuth(t *testing.T) { + badClient, err := newTvmToolClient("main", "fake-auth") + require.NoError(t, err) + + _, err = badClient.GetServiceTicketForAlias(context.Background(), "lala") + require.Error(t, err) + require.IsType(t, err, &tvmtool.Error{}) + srvTickerErr := err.(*tvmtool.Error) + require.Equal(t, tvmtool.ErrorAuthFail, srvTickerErr.Code) +} + +func TestClient_GetServiceTicket(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + ctx := context.Background() + + t.Run("invalid_alias", func(t *testing.T) { + // Ticket for invalid alias must fails + t.Parallel() + _, err := tvmClient.GetServiceTicketForAlias(ctx, "not_exists") + require.Error(t, err, "ticket for invalid alias must fails") + assert.IsType(t, err, &tvmtool.Error{}, "must return tvm err") + assert.EqualError(t, err, "tvm: can't find in config destination tvmid for src = 42, dstparam = not_exists (strconv) (code ErrorBadRequest)") + }) + + t.Run("invalid_dst_id", func(t *testing.T) { + // Ticket for invalid client id must fails + t.Parallel() + _, err := tvmClient.GetServiceTicketForID(ctx, 123123123) + require.Error(t, err, "ticket for invalid ID must fails") + assert.IsType(t, err, &tvmtool.Error{}, "must return tvm err") + assert.EqualError(t, err, "tvm: can't find in config destination tvmid for src = 42, dstparam = 123123123 (by number) (code ErrorBadRequest)") + }) + + t.Run("by_alias", func(t *testing.T) { + // Try to get ticket by alias + t.Parallel() + heTicketByAlias, err := tvmClient.GetServiceTicketForAlias(ctx, "he") + if assert.NoError(t, err, "failed to get srv ticket to 'he'") { + assert.Regexp(t, srvTicketRe, heTicketByAlias, "invalid 'he' srv ticket") + } + + heCloneTicketAlias, err := tvmClient.GetServiceTicketForAlias(ctx, "he_clone") + if assert.NoError(t, err, "failed to get srv ticket to 'he_clone'") { + assert.Regexp(t, srvTicketRe, heCloneTicketAlias, "invalid 'he_clone' srv ticket") + } + }) + + t.Run("by_dst_id", func(t *testing.T) { + // Try to get ticket by id + t.Parallel() + heTicketByID, err := tvmClient.GetServiceTicketForID(ctx, 100500) + if assert.NoError(t, err, "failed to get srv ticket to '100500'") { + assert.Regexp(t, srvTicketRe, heTicketByID, "invalid '100500' srv ticket") + } + }) +} + +func TestClient_CheckServiceTicket(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + ctx := context.Background() + t.Run("self_to_self", func(t *testing.T) { + t.Parallel() + + // Check from self to self + selfTicket, err := tvmClient.GetServiceTicketForAlias(ctx, "self") + require.NoError(t, err, "failed to get service ticket to 'self'") + assert.Regexp(t, srvTicketRe, selfTicket, "invalid 'self' srv ticket") + + // Now we can check srv ticket + ticketInfo, err := tvmClient.CheckServiceTicket(ctx, selfTicket) + require.NoError(t, err, "failed to check srv ticket main -> self") + + assert.Equal(t, tvm.ClientID(42), ticketInfo.SrcID) + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) + }) + + t.Run("to_another", func(t *testing.T) { + t.Parallel() + + // Check from another client (41) to self + ticketInfo, err := tvmClient.CheckServiceTicket(ctx, serviceTicketFor41_42) + require.NoError(t, err, "failed to check srv ticket 41 -> 42") + + assert.Equal(t, tvm.ClientID(41), ticketInfo.SrcID) + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) + }) + + t.Run("invalid_dst", func(t *testing.T) { + t.Parallel() + + // Check from another client (41) to invalid dst (99) + ticketInfo, err := tvmClient.CheckServiceTicket(ctx, serviceTicketFor41_99) + require.Error(t, err, "srv ticket for 41 -> 99 must fails") + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) + + ticketErr := err.(*tvmtool.TicketError) + require.IsType(t, err, &tvmtool.TicketError{}) + assert.Equal(t, tvmtool.TicketErrorOther, ticketErr.Status) + assert.Equal(t, "Wrong ticket dst, expected 42, got 99", ticketErr.Msg) + }) + + t.Run("broken", func(t *testing.T) { + t.Parallel() + + // Check with broken sign + _, err := tvmClient.CheckServiceTicket(ctx, "lalala") + require.Error(t, err, "srv ticket with broken sign must fails") + ticketErr := err.(*tvmtool.TicketError) + require.IsType(t, err, &tvmtool.TicketError{}) + assert.Equal(t, tvmtool.TicketErrorOther, ticketErr.Status) + assert.Equal(t, "invalid ticket format", ticketErr.Msg) + }) +} + +func TestClient_MultipleClients(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + slaveClient, err := newTvmToolClient("slave") + require.NoError(t, err) + + ctx := context.Background() + + ticket, err := tvmClient.GetServiceTicketForAlias(ctx, "slave") + require.NoError(t, err, "failed to get service ticket to 'slave'") + assert.Regexp(t, srvTicketRe, ticket, "invalid 'slave' srv ticket") + + ticketInfo, err := slaveClient.CheckServiceTicket(ctx, ticket) + require.NoError(t, err, "failed to check srv ticket main -> self") + + assert.Equal(t, tvm.ClientID(42), ticketInfo.SrcID) + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) +} + +func TestClient_CheckUserTicket(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + ticketInfo, err := tvmClient.CheckUserTicket(context.Background(), userTicketFor1120000000038691) + require.NoError(t, err, "failed to check user ticket") + + assert.Equal(t, tvm.UID(1120000000038691), ticketInfo.DefaultUID) + assert.Subset(t, []tvm.UID{1120000000038691}, ticketInfo.UIDs) + assert.Subset(t, []string{"bb:sessionid", "test:test"}, ticketInfo.Scopes) + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) +} + +func TestClient_Version(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + version, err := tvmClient.Version(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, version) +} diff --git a/library/go/yandex/tvm/user_ticket.go b/library/go/yandex/tvm/user_ticket.go new file mode 100644 index 0000000000..e68e5e5032 --- /dev/null +++ b/library/go/yandex/tvm/user_ticket.go @@ -0,0 +1,122 @@ +package tvm + +import ( + "fmt" +) + +// CheckedUserTicket is short-lived user credential. +// +// CheckedUserTicket contains only valid users. +// Details: https://wiki.yandex-team.ru/passport/tvm2/user-ticket/#chtoestvusertickete +type CheckedUserTicket struct { + // DefaultUID is default user - maybe 0 + DefaultUID UID + // UIDs is array of valid users - never empty + UIDs []UID + // Env is blackbox environment which created this UserTicket - provides only tvmauth now + Env BlackboxEnv + // Scopes is array of scopes inherited from credential - never empty + Scopes []string + // DbgInfo is human readable data for debug purposes + DbgInfo string + // LogInfo is safe for logging part of ticket - it can be parsed later with `tvmknife parse_ticket -t ...` + LogInfo string +} + +func (t CheckedUserTicket) String() string { + return fmt.Sprintf("%s (%s)", t.LogInfo, t.DbgInfo) +} + +// CheckScopes verify that ALL needed scopes presents in the user ticket +func (t *CheckedUserTicket) CheckScopes(scopes ...string) error { + switch { + case len(scopes) == 0: + // ok, no scopes. no checks. no rules + return nil + case len(t.Scopes) == 0: + msg := fmt.Sprintf("user ticket doesn't contain expected scopes: %s (actual: nil)", scopes) + return &TicketError{Status: TicketInvalidScopes, Msg: msg} + default: + actualScopes := make(map[string]struct{}, len(t.Scopes)) + for _, s := range t.Scopes { + actualScopes[s] = struct{}{} + } + + for _, s := range scopes { + if _, found := actualScopes[s]; !found { + // exit on first nonexistent scope + msg := fmt.Sprintf( + "user ticket doesn't contain one of expected scopes: %s (actual: %s)", + scopes, t.Scopes, + ) + + return &TicketError{Status: TicketInvalidScopes, Msg: msg} + } + } + + return nil + } +} + +// CheckScopesAny verify that ANY of needed scopes presents in the user ticket +func (t *CheckedUserTicket) CheckScopesAny(scopes ...string) error { + switch { + case len(scopes) == 0: + // ok, no scopes. no checks. no rules + return nil + case len(t.Scopes) == 0: + msg := fmt.Sprintf("user ticket doesn't contain any of expected scopes: %s (actual: nil)", scopes) + return &TicketError{Status: TicketInvalidScopes, Msg: msg} + default: + actualScopes := make(map[string]struct{}, len(t.Scopes)) + for _, s := range t.Scopes { + actualScopes[s] = struct{}{} + } + + for _, s := range scopes { + if _, found := actualScopes[s]; found { + // exit on first valid scope + return nil + } + } + + msg := fmt.Sprintf( + "user ticket doesn't contain any of expected scopes: %s (actual: %s)", + scopes, t.Scopes, + ) + + return &TicketError{Status: TicketInvalidScopes, Msg: msg} + } +} + +type CheckUserTicketOptions struct { + EnvOverride *BlackboxEnv +} + +type CheckUserTicketOption func(*CheckUserTicketOptions) + +func WithBlackboxOverride(env BlackboxEnv) CheckUserTicketOption { + return func(opts *CheckUserTicketOptions) { + opts.EnvOverride = &env + } +} + +type UserTicketACL func(ticket *CheckedUserTicket) error + +func AllowAllUserTickets() UserTicketACL { + return func(ticket *CheckedUserTicket) error { + return nil + } +} + +func CheckAllUserTicketScopesPresent(scopes []string) UserTicketACL { + return func(ticket *CheckedUserTicket) error { + return ticket.CheckScopes(scopes...) + } +} + +func CheckAnyUserTicketScopesPresent(scopes []string) UserTicketACL { + return func(ticket *CheckedUserTicket) error { + return ticket.CheckScopesAny(scopes...) + } +} diff --git a/library/go/yandex/unistat/aggr/aggr.go b/library/go/yandex/unistat/aggr/aggr.go new file mode 100644 index 0000000000..acd933ec55 --- /dev/null +++ b/library/go/yandex/unistat/aggr/aggr.go @@ -0,0 +1,64 @@ +package aggr + +import "a.yandex-team.ru/library/go/yandex/unistat" + +// Histogram returns delta histogram aggregation (dhhh). +func Histogram() unistat.Aggregation { + return unistat.StructuredAggregation{ + AggregationType: unistat.Delta, + Group: unistat.Hgram, + MetaGroup: unistat.Hgram, + Rollup: unistat.Hgram, + } +} + +// AbsoluteHistogram returns absolute histogram aggregation (ahhh). +func AbsoluteHistogram() unistat.Aggregation { + return unistat.StructuredAggregation{ + AggregationType: unistat.Absolute, + Group: unistat.Hgram, + MetaGroup: unistat.Hgram, + Rollup: unistat.Hgram, + } +} + +// Counter returns counter aggregation (dmmm) +func Counter() unistat.Aggregation { + return unistat.StructuredAggregation{ + AggregationType: unistat.Delta, + Group: unistat.Sum, + MetaGroup: unistat.Sum, + Rollup: unistat.Sum, + } +} + +// Absolute returns value aggregation (ammm) +func Absolute() unistat.Aggregation { + return unistat.StructuredAggregation{ + AggregationType: unistat.Absolute, + Group: unistat.Sum, + MetaGroup: unistat.Sum, + Rollup: unistat.Sum, + } +} + +// SummAlias corresponds to _summ suffix +type SummAlias struct{} + +func (s SummAlias) Suffix() string { + return "summ" +} + +// SummAlias corresponds to _hgram suffix +type HgramAlias struct{} + +func (s HgramAlias) Suffix() string { + return "hgram" +} + +// SummAlias corresponds to _max suffix +type MaxAlias struct{} + +func (s MaxAlias) Suffix() string { + return "max" +} diff --git a/library/go/yandex/unistat/histogram.go b/library/go/yandex/unistat/histogram.go new file mode 100644 index 0000000000..7abb9a8a27 --- /dev/null +++ b/library/go/yandex/unistat/histogram.go @@ -0,0 +1,84 @@ +package unistat + +import ( + "encoding/json" + "sync" +) + +// Histogram implements Metric interface +type Histogram struct { + mu sync.RWMutex + name string + priority Priority + aggr Aggregation + + intervals []float64 + weights []int64 + size int64 +} + +// NewHistogram allocates Histogram metric. +// For naming rules see https://wiki.yandex-team.ru/golovan/tagsandsignalnaming. +// Intervals in left edges of histograms buckets (maximum 50 allowed). +func NewHistogram(name string, priority Priority, aggr Aggregation, intervals []float64) *Histogram { + return &Histogram{ + name: name, + priority: priority, + aggr: aggr, + intervals: intervals, + weights: make([]int64, len(intervals)), + } +} + +// Name from Metric interface. +func (h *Histogram) Name() string { + return h.name +} + +// Priority from Metric interface. +func (h *Histogram) Priority() Priority { + return h.priority +} + +// Aggregation from Metric interface. +func (h *Histogram) Aggregation() Aggregation { + return h.aggr +} + +// Update from Metric interface. +func (h *Histogram) Update(value float64) { + h.mu.Lock() + defer h.mu.Unlock() + + for i := len(h.intervals); i > 0; i-- { + if value >= h.intervals[i-1] { + h.weights[i-1]++ + h.size++ + break + } + } +} + +// MarshalJSON from Metric interface. +func (h *Histogram) MarshalJSON() ([]byte, error) { + h.mu.RLock() + defer h.mu.RUnlock() + + buckets := [][2]interface{}{} + for i := range h.intervals { + b := h.intervals[i] + w := h.weights[i] + buckets = append(buckets, [2]interface{}{b, w}) + } + + jsonName := h.name + "_" + h.aggr.Suffix() + return json.Marshal([]interface{}{jsonName, buckets}) +} + +// GetSize returns histogram's values count. +func (h *Histogram) GetSize() int64 { + h.mu.Lock() + defer h.mu.Unlock() + + return h.size +} diff --git a/library/go/yandex/unistat/number.go b/library/go/yandex/unistat/number.go new file mode 100644 index 0000000000..e38ff64475 --- /dev/null +++ b/library/go/yandex/unistat/number.go @@ -0,0 +1,84 @@ +package unistat + +import ( + "encoding/json" + "math" + "sync" +) + +// Numeric implements Metric interface. +type Numeric struct { + mu sync.RWMutex + name string + priority Priority + aggr Aggregation + localAggr AggregationRule + + value float64 +} + +// NewNumeric allocates Numeric value metric. +func NewNumeric(name string, priority Priority, aggr Aggregation, localAggr AggregationRule) *Numeric { + return &Numeric{ + name: name, + priority: priority, + aggr: aggr, + localAggr: localAggr, + } +} + +// Name from Metric interface. +func (n *Numeric) Name() string { + return n.name +} + +// Aggregation from Metric interface. +func (n *Numeric) Aggregation() Aggregation { + return n.aggr +} + +// Priority from Metric interface. +func (n *Numeric) Priority() Priority { + return n.priority +} + +// Update from Metric interface. +func (n *Numeric) Update(value float64) { + n.mu.Lock() + defer n.mu.Unlock() + + switch n.localAggr { + case Max: + n.value = math.Max(n.value, value) + case Min: + n.value = math.Min(n.value, value) + case Sum: + n.value += value + case Last: + n.value = value + default: + n.value = -1 + } +} + +// MarshalJSON from Metric interface. +func (n *Numeric) MarshalJSON() ([]byte, error) { + jsonName := n.name + "_" + n.aggr.Suffix() + return json.Marshal([]interface{}{jsonName, n.GetValue()}) +} + +// GetValue returns current metric value. +func (n *Numeric) GetValue() float64 { + n.mu.RLock() + defer n.mu.RUnlock() + + return n.value +} + +// SetValue sets current metric value. +func (n *Numeric) SetValue(value float64) { + n.mu.Lock() + defer n.mu.Unlock() + + n.value = value +} diff --git a/library/go/yandex/unistat/registry.go b/library/go/yandex/unistat/registry.go new file mode 100644 index 0000000000..0873ab7c66 --- /dev/null +++ b/library/go/yandex/unistat/registry.go @@ -0,0 +1,59 @@ +package unistat + +import ( + "encoding/json" + "sort" + "sync" +) + +type registry struct { + mu sync.Mutex + byName map[string]Metric + + metrics []Metric + unsorted bool +} + +// NewRegistry allocate new registry container for unistat metrics. +func NewRegistry() Registry { + return ®istry{ + byName: map[string]Metric{}, + metrics: []Metric{}, + } +} + +func (r *registry) Register(m Metric) { + r.mu.Lock() + defer r.mu.Unlock() + + if _, ok := r.byName[m.Name()]; ok { + panic(ErrDuplicate) + } + + r.byName[m.Name()] = m + r.metrics = append(r.metrics, m) + r.unsorted = true +} + +func (r *registry) MarshalJSON() ([]byte, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.unsorted { + sort.Sort(byPriority(r.metrics)) + r.unsorted = false + } + return json.Marshal(r.metrics) +} + +type byPriority []Metric + +func (m byPriority) Len() int { return len(m) } +func (m byPriority) Less(i, j int) bool { + if m[i].Priority() == m[j].Priority() { + return m[i].Name() < m[j].Name() + } + + return m[i].Priority() > m[j].Priority() +} +func (m byPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] } diff --git a/library/go/yandex/unistat/unistat.go b/library/go/yandex/unistat/unistat.go new file mode 100644 index 0000000000..e3664dac38 --- /dev/null +++ b/library/go/yandex/unistat/unistat.go @@ -0,0 +1,170 @@ +package unistat + +import ( + "encoding/json" + "errors" + "fmt" + "time" +) + +// StructuredAggregation provides type safe API to create an Aggregation. For more +// information see: https://wiki.yandex-team.ru/golovan/aggregation-types/ +type StructuredAggregation struct { + AggregationType AggregationType + Group AggregationRule + MetaGroup AggregationRule + Rollup AggregationRule +} + +// Aggregation defines rules how to aggregate signal on each level. For more +// information see: https://wiki.yandex-team.ru/golovan/aggregation-types/ +type Aggregation interface { + Suffix() string +} + +const ( + AggregationUnknown = "<unknown>" +) + +// Suffix defines signal aggregation on each level: +// 1 - Signal type: absolute (A) or delta (D). +// 2 - Group aggregation. +// 3 - Meta-group aggregation type. +// 4 - Time aggregation for roll-up. +// +// Doc: https://doc.yandex-team.ru/Search/golovan-quickstart/concepts/signal-aggregation.html#agrr-levels +func (a StructuredAggregation) Suffix() string { + return fmt.Sprintf("%s%s%s%s", a.AggregationType, a.Group, a.MetaGroup, a.Rollup) +} + +// Priority is used to order signals in unistat report. +// https://wiki.yandex-team.ru/golovan/stat-handle/#protokol +type Priority int + +// AggregationType is Absolute or Delta. +type AggregationType int + +// Value types +const ( + Absolute AggregationType = iota // Absolute value. Use for gauges. + Delta // Delta value. Use for increasing counters. +) + +func (v AggregationType) String() string { + switch v { + case Absolute: + return "a" + case Delta: + return "d" + default: + return AggregationUnknown + } +} + +// AggregationRule defines aggregation rules: +// +// https://wiki.yandex-team.ru/golovan/aggregation-types/#algoritmyagregacii +type AggregationRule int + +// Aggregation rules +const ( + Hgram AggregationRule = iota // Hgram is histogram aggregation. + Max // Max value. + Min // Min value. + Sum // Sum with default 0. + SumNone // SumNone is sum with default None. + Last // Last value. + Average // Average value. +) + +func (r AggregationRule) String() string { + switch r { + case Hgram: + return "h" + case Max: + return "x" + case Min: + return "n" + case Sum: + return "m" + case SumNone: + return "e" + case Last: + return "t" + case Average: + return "v" + default: + return AggregationUnknown + } +} + +func (r *AggregationRule) UnmarshalText(source []byte) error { + text := string(source) + switch text { + case "h": + *r = Hgram + case "x": + *r = Max + case "n": + *r = Min + case "m": + *r = Sum + case "e": + *r = SumNone + case "t": + *r = Last + case "v": + *r = Average + default: + return fmt.Errorf("unknown aggregation rule '%s'", text) + } + return nil +} + +// ErrDuplicate is raised on duplicate metric name registration. +var ErrDuplicate = errors.New("unistat: duplicate metric") + +// Metric is interface that accepted by Registry. +type Metric interface { + Name() string + Priority() Priority + Aggregation() Aggregation + MarshalJSON() ([]byte, error) +} + +// Updater is interface that wraps basic Update() method. +type Updater interface { + Update(value float64) +} + +// Registry is interface for container that generates stat report +type Registry interface { + Register(metric Metric) + MarshalJSON() ([]byte, error) +} + +var defaultRegistry = NewRegistry() + +// Register metric in default registry. +func Register(metric Metric) { + defaultRegistry.Register(metric) +} + +// MarshalJSON marshals default registry to JSON. +func MarshalJSON() ([]byte, error) { + return json.Marshal(defaultRegistry) +} + +// MeasureMicrosecondsSince updates metric with duration that started +// at ts and ends now. +func MeasureMicrosecondsSince(m Updater, ts time.Time) { + measureMicrosecondsSince(time.Since, m, ts) +} + +// For unittest +type timeSinceFunc func(t time.Time) time.Duration + +func measureMicrosecondsSince(sinceFunc timeSinceFunc, m Updater, ts time.Time) { + dur := sinceFunc(ts) + m.Update(float64(dur / time.Microsecond)) // to microseconds +} diff --git a/library/go/yandex/yplite/spec.go b/library/go/yandex/yplite/spec.go new file mode 100644 index 0000000000..228f9627ef --- /dev/null +++ b/library/go/yandex/yplite/spec.go @@ -0,0 +1,46 @@ +package yplite + +type PodSpec struct { + DNS PodDNS `json:"dns"` + ResourceRequests ResourceRequest `json:"resourceRequests"` + PortoProperties []PortoProperty `json:"portoProperties"` + IP6AddressAllocations []IP6AddressAllocation `json:"ip6AddressAllocations"` +} + +type PodAttributes struct { + ResourceRequirements struct { + CPU struct { + Guarantee uint64 `json:"cpu_guarantee_millicores,string"` + Limit uint64 `json:"cpu_limit_millicores,string"` + } `json:"cpu"` + Memory struct { + Guarantee uint64 `json:"memory_guarantee_bytes,string"` + Limit uint64 `json:"memory_limit_bytes,string"` + } `json:"memory"` + } `json:"resource_requirements"` +} + +type ResourceRequest struct { + CPUGuarantee uint64 `json:"vcpuGuarantee,string"` + CPULimit uint64 `json:"vcpuLimit,string"` + MemoryGuarantee uint64 `json:"memoryGuarantee,string"` + MemoryLimit uint64 `json:"memoryLimit,string"` + AnonymousMemoryLimit uint64 `json:"anonymousMemoryLimit,string"` +} + +type IP6AddressAllocation struct { + Address string `json:"address"` + VlanID string `json:"vlanId"` + PersistentFQDN string `json:"persistentFqdn"` + TransientFQDN string `json:"transientFqdn"` +} + +type PortoProperty struct { + Name string `json:"key"` + Value string `json:"value"` +} + +type PodDNS struct { + PersistentFqdn string `json:"persistentFqdn"` + TransientFqdn string `json:"transientFqdn"` +} diff --git a/library/go/yandex/yplite/yplite.go b/library/go/yandex/yplite/yplite.go new file mode 100644 index 0000000000..32062b3118 --- /dev/null +++ b/library/go/yandex/yplite/yplite.go @@ -0,0 +1,67 @@ +package yplite + +import ( + "context" + "encoding/json" + "net" + "net/http" + "os" + "time" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +const ( + PodSocketPath = "/run/iss/pod.socket" + NodeAgentTimeout = 1 * time.Second +) + +var ( + httpClient = http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.DialTimeout("unix", PodSocketPath, NodeAgentTimeout) + }, + }, + Timeout: NodeAgentTimeout, + } +) + +func IsAPIAvailable() bool { + if _, err := os.Stat(PodSocketPath); err == nil { + return true + } + return false +} + +func FetchPodSpec() (*PodSpec, error) { + res, err := httpClient.Get("http://localhost/pod_spec") + if err != nil { + return nil, xerrors.Errorf("failed to request pod spec: %w", err) + } + defer func() { _ = res.Body.Close() }() + + spec := new(PodSpec) + err = json.NewDecoder(res.Body).Decode(spec) + if err != nil { + return nil, xerrors.Errorf("failed to decode pod spec: %w", err) + } + + return spec, nil +} + +func FetchPodAttributes() (*PodAttributes, error) { + res, err := httpClient.Get("http://localhost/pod_attributes") + if err != nil { + return nil, xerrors.Errorf("failed to request pod attributes: %w", err) + } + defer func() { _ = res.Body.Close() }() + + attrs := new(PodAttributes) + err = json.NewDecoder(res.Body).Decode(attrs) + if err != nil { + return nil, xerrors.Errorf("failed to decode pod attributes: %w", err) + } + + return attrs, nil +} |