diff options
author | qrort <qrort@yandex-team.com> | 2022-11-30 23:47:12 +0300 |
---|---|---|
committer | qrort <qrort@yandex-team.com> | 2022-11-30 23:47:12 +0300 |
commit | 22f8ae0e3f5d68b92aecccdf96c1d841a0334311 (patch) | |
tree | bffa27765faf54126ad44bcafa89fadecb7a73d7 /library/go/yandex/tvm | |
parent | 332b99e2173f0425444abb759eebcb2fafaa9209 (diff) | |
download | ydb-22f8ae0e3f5d68b92aecccdf96c1d841a0334311.tar.gz |
validate canons without yatest_common
Diffstat (limited to 'library/go/yandex/tvm')
60 files changed, 6849 insertions, 0 deletions
diff --git a/library/go/yandex/tvm/cachedtvm/cache.go b/library/go/yandex/tvm/cachedtvm/cache.go new file mode 100644 index 0000000000..a04e2baf8a --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/cache.go @@ -0,0 +1,22 @@ +package cachedtvm + +import ( + "time" + + "github.com/karlseguin/ccache/v2" +) + +type cache struct { + *ccache.Cache + ttl time.Duration +} + +func (c *cache) Fetch(key string, fn func() (interface{}, error)) (*ccache.Item, error) { + return c.Cache.Fetch(key, c.ttl, fn) +} + +func (c *cache) Stop() { + if c.Cache != nil { + c.Cache.Stop() + } +} diff --git a/library/go/yandex/tvm/cachedtvm/client.go b/library/go/yandex/tvm/cachedtvm/client.go new file mode 100644 index 0000000000..503c973e8c --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/client.go @@ -0,0 +1,117 @@ +package cachedtvm + +import ( + "context" + "fmt" + "time" + + "github.com/karlseguin/ccache/v2" + + "a.yandex-team.ru/library/go/yandex/tvm" +) + +const ( + DefaultTTL = 1 * time.Minute + DefaultMaxItems = 100 + MaxServiceTicketTTL = 5 * time.Minute + MaxUserTicketTTL = 1 * time.Minute +) + +type CachedClient struct { + tvm.Client + serviceTicketCache cache + userTicketCache cache + userTicketFn func(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) +} + +func NewClient(tvmClient tvm.Client, opts ...Option) (*CachedClient, error) { + newCache := func(o cacheOptions) cache { + return cache{ + Cache: ccache.New( + ccache.Configure().MaxSize(o.maxItems), + ), + ttl: o.ttl, + } + } + + out := &CachedClient{ + Client: tvmClient, + serviceTicketCache: newCache(cacheOptions{ + ttl: DefaultTTL, + maxItems: DefaultMaxItems, + }), + userTicketFn: tvmClient.CheckUserTicket, + } + + for _, opt := range opts { + switch o := opt.(type) { + case OptionServiceTicket: + if o.ttl > MaxServiceTicketTTL { + return nil, fmt.Errorf("maximum TTL for check service ticket exceed: %s > %s", o.ttl, MaxServiceTicketTTL) + } + + out.serviceTicketCache = newCache(o.cacheOptions) + case OptionUserTicket: + if o.ttl > MaxUserTicketTTL { + return nil, fmt.Errorf("maximum TTL for check user ticket exceed: %s > %s", o.ttl, MaxUserTicketTTL) + } + + out.userTicketFn = out.cacheCheckUserTicket + out.userTicketCache = newCache(o.cacheOptions) + default: + panic(fmt.Sprintf("unexpected cache option: %T", o)) + } + } + + return out, nil +} + +func (c *CachedClient) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + out, err := c.serviceTicketCache.Fetch(ticket, func() (interface{}, error) { + return c.Client.CheckServiceTicket(ctx, ticket) + }) + + if err != nil { + return nil, err + } + + return out.Value().(*tvm.CheckedServiceTicket), nil +} + +func (c *CachedClient) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + return c.userTicketFn(ctx, ticket, opts...) +} + +func (c *CachedClient) cacheCheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + cacheKey := func(ticket string, opts ...tvm.CheckUserTicketOption) string { + if len(opts) == 0 { + return ticket + } + + var options tvm.CheckUserTicketOptions + for _, opt := range opts { + opt(&options) + } + + if options.EnvOverride == nil { + return ticket + } + + return fmt.Sprintf("%d:%s", *options.EnvOverride, ticket) + } + + out, err := c.userTicketCache.Fetch(cacheKey(ticket, opts...), func() (interface{}, error) { + return c.Client.CheckUserTicket(ctx, ticket, opts...) + }) + + if err != nil { + return nil, err + } + + return out.Value().(*tvm.CheckedUserTicket), nil +} + +func (c *CachedClient) Close() { + c.serviceTicketCache.Stop() + c.userTicketCache.Stop() +} diff --git a/library/go/yandex/tvm/cachedtvm/client_example_test.go b/library/go/yandex/tvm/cachedtvm/client_example_test.go new file mode 100644 index 0000000000..a95b1674a3 --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/client_example_test.go @@ -0,0 +1,40 @@ +package cachedtvm_test + +import ( + "context" + "fmt" + "time" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm/cachedtvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewClient_checkServiceTicket() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewAnyClient(tvmtool.WithLogger(zlog)) + if err != nil { + panic(err) + } + + cachedTvmClient, err := cachedtvm.NewClient( + tvmClient, + cachedtvm.WithCheckServiceTicket(1*time.Minute, 1000), + ) + if err != nil { + panic(err) + } + defer cachedTvmClient.Close() + + ticketInfo, err := cachedTvmClient.CheckServiceTicket(context.TODO(), "3:serv:....") + if err != nil { + panic(err) + } + + fmt.Println("ticket info: ", ticketInfo.LogInfo) +} diff --git a/library/go/yandex/tvm/cachedtvm/client_test.go b/library/go/yandex/tvm/cachedtvm/client_test.go new file mode 100644 index 0000000000..a3c3081e30 --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/client_test.go @@ -0,0 +1,195 @@ +package cachedtvm_test + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/cachedtvm" +) + +const ( + checkPasses = 5 +) + +type mockTvmClient struct { + tvm.Client + checkServiceTicketCalls int + checkUserTicketCalls int +} + +func (c *mockTvmClient) CheckServiceTicket(_ context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + defer func() { c.checkServiceTicketCalls++ }() + + return &tvm.CheckedServiceTicket{ + LogInfo: ticket, + IssuerUID: tvm.UID(c.checkServiceTicketCalls), + }, nil +} + +func (c *mockTvmClient) CheckUserTicket(_ context.Context, ticket string, _ ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + defer func() { c.checkUserTicketCalls++ }() + + return &tvm.CheckedUserTicket{ + LogInfo: ticket, + DefaultUID: tvm.UID(c.checkUserTicketCalls), + }, nil +} + +func (c *mockTvmClient) GetServiceTicketForAlias(_ context.Context, alias string) (string, error) { + return alias, nil +} + +func checkServiceTickets(t *testing.T, client tvm.Client, equal bool) { + var prev *tvm.CheckedServiceTicket + for i := 0; i < checkPasses; i++ { + t.Run(strconv.Itoa(i), func(t *testing.T) { + cur, err := client.CheckServiceTicket(context.Background(), "3:serv:tst") + require.NoError(t, err) + + if prev == nil { + return + } + + if equal { + require.Equal(t, *prev, *cur) + } else { + require.NotEqual(t, *prev, *cur) + } + }) + } +} + +func runEqualServiceTickets(client tvm.Client) func(t *testing.T) { + return func(t *testing.T) { + checkServiceTickets(t, client, true) + } +} + +func runNotEqualServiceTickets(client tvm.Client) func(t *testing.T) { + return func(t *testing.T) { + checkServiceTickets(t, client, false) + } +} + +func checkUserTickets(t *testing.T, client tvm.Client, equal bool) { + var prev *tvm.CheckedServiceTicket + for i := 0; i < checkPasses; i++ { + t.Run(strconv.Itoa(i), func(t *testing.T) { + cur, err := client.CheckUserTicket(context.Background(), "3:user:tst") + require.NoError(t, err) + + if prev == nil { + return + } + + if equal { + require.Equal(t, *prev, *cur) + } else { + require.NotEqual(t, *prev, *cur) + } + }) + } +} + +func runEqualUserTickets(client tvm.Client) func(t *testing.T) { + return func(t *testing.T) { + checkUserTickets(t, client, true) + } +} + +func runNotEqualUserTickets(client tvm.Client) func(t *testing.T) { + return func(t *testing.T) { + checkUserTickets(t, client, false) + } +} +func TestDefaultBehavior(t *testing.T) { + nestedClient := &mockTvmClient{} + client, err := cachedtvm.NewClient(nestedClient) + require.NoError(t, err) + + t.Run("first_pass_srv", runEqualServiceTickets(client)) + t.Run("first_pass_usr", runNotEqualUserTickets(client)) + + require.Equal(t, 1, nestedClient.checkServiceTicketCalls) + require.Equal(t, checkPasses, nestedClient.checkUserTicketCalls) + + ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst") + require.NoError(t, err) + require.Equal(t, "tst", ticket) +} + +func TestCheckServiceTicket(t *testing.T) { + nestedClient := &mockTvmClient{} + client, err := cachedtvm.NewClient(nestedClient, cachedtvm.WithCheckServiceTicket(10*time.Second, 10)) + require.NoError(t, err) + + t.Run("first_pass_srv", runEqualServiceTickets(client)) + t.Run("first_pass_usr", runNotEqualUserTickets(client)) + time.Sleep(20 * time.Second) + t.Run("second_pass_srv", runEqualServiceTickets(client)) + t.Run("second_pass_usr", runNotEqualUserTickets(client)) + + require.Equal(t, 2, nestedClient.checkServiceTicketCalls) + require.Equal(t, 2*checkPasses, nestedClient.checkUserTicketCalls) + + ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst") + require.NoError(t, err) + require.Equal(t, "tst", ticket) +} + +func TestCheckUserTicket(t *testing.T) { + nestedClient := &mockTvmClient{} + client, err := cachedtvm.NewClient(nestedClient, cachedtvm.WithCheckUserTicket(10*time.Second, 10)) + require.NoError(t, err) + + t.Run("first_pass_usr", runEqualUserTickets(client)) + time.Sleep(20 * time.Second) + t.Run("second_pass_usr", runEqualUserTickets(client)) + require.Equal(t, 2, nestedClient.checkUserTicketCalls) + + ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst") + require.NoError(t, err) + require.Equal(t, "tst", ticket) +} + +func TestCheckServiceAndUserTicket(t *testing.T) { + nestedClient := &mockTvmClient{} + client, err := cachedtvm.NewClient(nestedClient, + cachedtvm.WithCheckServiceTicket(10*time.Second, 10), + cachedtvm.WithCheckUserTicket(10*time.Second, 10), + ) + require.NoError(t, err) + + t.Run("first_pass_srv", runEqualServiceTickets(client)) + t.Run("first_pass_usr", runEqualUserTickets(client)) + time.Sleep(20 * time.Second) + t.Run("second_pass_srv", runEqualServiceTickets(client)) + t.Run("second_pass_usr", runEqualUserTickets(client)) + + require.Equal(t, 2, nestedClient.checkUserTicketCalls) + require.Equal(t, 2, nestedClient.checkServiceTicketCalls) + + ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst") + require.NoError(t, err) + require.Equal(t, "tst", ticket) +} + +func TestErrors(t *testing.T) { + cases := []cachedtvm.Option{ + cachedtvm.WithCheckServiceTicket(12*time.Hour, 1), + cachedtvm.WithCheckUserTicket(30*time.Minute, 1), + } + + nestedClient := &mockTvmClient{} + for i, tc := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + _, err := cachedtvm.NewClient(nestedClient, tc) + require.Error(t, err) + }) + } +} diff --git a/library/go/yandex/tvm/cachedtvm/opts.go b/library/go/yandex/tvm/cachedtvm/opts.go new file mode 100644 index 0000000000..0df9dfa89e --- /dev/null +++ b/library/go/yandex/tvm/cachedtvm/opts.go @@ -0,0 +1,40 @@ +package cachedtvm + +import "time" + +type ( + Option interface{ isCachedOption() } + + cacheOptions struct { + ttl time.Duration + maxItems int64 + } + + OptionServiceTicket struct { + Option + cacheOptions + } + + OptionUserTicket struct { + Option + cacheOptions + } +) + +func WithCheckServiceTicket(ttl time.Duration, maxSize int) Option { + return OptionServiceTicket{ + cacheOptions: cacheOptions{ + ttl: ttl, + maxItems: int64(maxSize), + }, + } +} + +func WithCheckUserTicket(ttl time.Duration, maxSize int) Option { + return OptionUserTicket{ + cacheOptions: cacheOptions{ + ttl: ttl, + maxItems: int64(maxSize), + }, + } +} diff --git a/library/go/yandex/tvm/client.go b/library/go/yandex/tvm/client.go new file mode 100644 index 0000000000..2a969fb1c6 --- /dev/null +++ b/library/go/yandex/tvm/client.go @@ -0,0 +1,56 @@ +package tvm + +//go:generate ya tool mockgen -source=$GOFILE -destination=mocks/tvm.gen.go Client + +import ( + "context" + "fmt" +) + +type ClientStatus int + +// This constants must be in sync with EStatus from library/cpp/tvmauth/client/client_status.h +const ( + ClientOK ClientStatus = iota + ClientWarning + ClientError +) + +func (s ClientStatus) String() string { + switch s { + case ClientOK: + return "OK" + case ClientWarning: + return "Warning" + case ClientError: + return "Error" + default: + return fmt.Sprintf("Unknown%d", s) + } +} + +type ClientStatusInfo struct { + Status ClientStatus + + // This message allows to trigger alert with useful message + // It returns "OK" if Status==Ok + LastError string +} + +// Client allows to use aliases for ClientID. +// +// Alias is local label for ClientID which can be used to avoid this number in every checking case in code. +type Client interface { + GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) + GetServiceTicketForID(ctx context.Context, dstID ClientID) (string, error) + + // CheckServiceTicket returns struct with SrcID: you should check it by yourself with ACL + CheckServiceTicket(ctx context.Context, ticket string) (*CheckedServiceTicket, error) + CheckUserTicket(ctx context.Context, ticket string, opts ...CheckUserTicketOption) (*CheckedUserTicket, error) + GetRoles(ctx context.Context) (*Roles, error) + + // GetStatus returns current status of client: + // * you should trigger your monitoring if status is not Ok + // * it will be unable to operate if status is Invalid + GetStatus(ctx context.Context) (ClientStatusInfo, error) +} diff --git a/library/go/yandex/tvm/context.go b/library/go/yandex/tvm/context.go new file mode 100644 index 0000000000..3a30dbb0b6 --- /dev/null +++ b/library/go/yandex/tvm/context.go @@ -0,0 +1,33 @@ +package tvm + +import "context" + +type ( + serviceTicketContextKey struct{} + userTicketContextKey struct{} +) + +var ( + stKey serviceTicketContextKey + utKey userTicketContextKey +) + +// WithServiceTicket returns copy of the ctx with service ticket attached to it. +func WithServiceTicket(ctx context.Context, t *CheckedServiceTicket) context.Context { + return context.WithValue(ctx, &stKey, t) +} + +// WithUserTicket returns copy of the ctx with user ticket attached to it. +func WithUserTicket(ctx context.Context, t *CheckedUserTicket) context.Context { + return context.WithValue(ctx, &utKey, t) +} + +func ContextServiceTicket(ctx context.Context) (t *CheckedServiceTicket) { + t, _ = ctx.Value(&stKey).(*CheckedServiceTicket) + return +} + +func ContextUserTicket(ctx context.Context) (t *CheckedUserTicket) { + t, _ = ctx.Value(&utKey).(*CheckedUserTicket) + return +} diff --git a/library/go/yandex/tvm/errors.go b/library/go/yandex/tvm/errors.go new file mode 100644 index 0000000000..bd511d05f3 --- /dev/null +++ b/library/go/yandex/tvm/errors.go @@ -0,0 +1,107 @@ +package tvm + +import ( + "errors" + "fmt" +) + +// ErrNotSupported - error to be used within cgo disabled builds. +var ErrNotSupported = errors.New("ticket_parser2 is not available when building with -DCGO_ENABLED=0") + +var ( + ErrTicketExpired = &TicketError{Status: TicketExpired} + ErrTicketInvalidBlackboxEnv = &TicketError{Status: TicketInvalidBlackboxEnv} + ErrTicketInvalidDst = &TicketError{Status: TicketInvalidDst} + ErrTicketInvalidTicketType = &TicketError{Status: TicketInvalidTicketType} + ErrTicketMalformed = &TicketError{Status: TicketMalformed} + ErrTicketMissingKey = &TicketError{Status: TicketMissingKey} + ErrTicketSignBroken = &TicketError{Status: TicketSignBroken} + ErrTicketUnsupportedVersion = &TicketError{Status: TicketUnsupportedVersion} + ErrTicketStatusOther = &TicketError{Status: TicketStatusOther} + ErrTicketInvalidScopes = &TicketError{Status: TicketInvalidScopes} + ErrTicketInvalidSrcID = &TicketError{Status: TicketInvalidSrcID} +) + +type TicketError struct { + Status TicketStatus + Msg string +} + +func (e *TicketError) Is(err error) bool { + otherTickerErr, ok := err.(*TicketError) + if !ok { + return false + } + if e == nil && otherTickerErr == nil { + return true + } + if e == nil || otherTickerErr == nil { + return false + } + return e.Status == otherTickerErr.Status +} + +func (e *TicketError) Error() string { + if e.Msg != "" { + return fmt.Sprintf("tvm: invalid ticket: %s: %s", e.Status, e.Msg) + } + return fmt.Sprintf("tvm: invalid ticket: %s", e.Status) +} + +type ErrorCode int + +// This constants must be in sync with code in go/tvmauth/tvm.cpp:CatchError +const ( + ErrorOK ErrorCode = iota + ErrorMalformedSecret + ErrorMalformedKeys + ErrorEmptyKeys + ErrorNotAllowed + ErrorBrokenTvmClientSettings + ErrorMissingServiceTicket + ErrorPermissionDenied + ErrorOther + + // Go-only errors below + ErrorBadRequest + ErrorAuthFail +) + +func (e ErrorCode) String() string { + switch e { + case ErrorOK: + return "OK" + case ErrorMalformedSecret: + return "MalformedSecret" + case ErrorMalformedKeys: + return "MalformedKeys" + case ErrorEmptyKeys: + return "EmptyKeys" + case ErrorNotAllowed: + return "NotAllowed" + case ErrorBrokenTvmClientSettings: + return "BrokenTvmClientSettings" + case ErrorMissingServiceTicket: + return "MissingServiceTicket" + case ErrorPermissionDenied: + return "PermissionDenied" + case ErrorOther: + return "Other" + case ErrorBadRequest: + return "ErrorBadRequest" + case ErrorAuthFail: + return "AuthFail" + default: + return fmt.Sprintf("Unknown%d", e) + } +} + +type Error struct { + Code ErrorCode + Retriable bool + Msg string +} + +func (e *Error) Error() string { + return fmt.Sprintf("tvm: %s (code %s)", e.Msg, e.Code) +} diff --git a/library/go/yandex/tvm/examples/tvm_example_test.go b/library/go/yandex/tvm/examples/tvm_example_test.go new file mode 100644 index 0000000000..2d47502584 --- /dev/null +++ b/library/go/yandex/tvm/examples/tvm_example_test.go @@ -0,0 +1,59 @@ +package tvm_test + +import ( + "context" + "fmt" + + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmauth" +) + +func ExampleClient_alias() { + blackboxAlias := "blackbox" + + settings := tvmauth.TvmAPISettings{ + SelfID: 1000502, + ServiceTicketOptions: tvmauth.NewAliasesOptions( + "...", + map[string]tvm.ClientID{ + blackboxAlias: 1000501, + }), + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.GetServiceTicketForAlias(context.Background(), blackboxAlias) + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) +} + +func ExampleClient_roles() { + settings := tvmauth.TvmAPISettings{ + SelfID: 1000502, + ServiceTicketOptions: tvmauth.NewIDsOptions("...", nil), + FetchRolesForIdmSystemSlug: "some_idm_system", + DiskCacheDir: "...", + EnableServiceTicketChecking: true, + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.CheckServiceTicket(context.Background(), "3:serv:...") + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) + + r, err := c.GetRoles(context.Background()) + if err != nil { + panic(err) + } + fmt.Println(r.GetMeta().Revision) +} diff --git a/library/go/yandex/tvm/mocks/tvm.gen.go b/library/go/yandex/tvm/mocks/tvm.gen.go new file mode 100644 index 0000000000..9f56f65fec --- /dev/null +++ b/library/go/yandex/tvm/mocks/tvm.gen.go @@ -0,0 +1,130 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: client.go + +// Package mock_tvm is a generated GoMock package. +package mock_tvm + +import ( + tvm "a.yandex-team.ru/library/go/yandex/tvm" + context "context" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// GetServiceTicketForAlias mocks base method. +func (m *MockClient) GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServiceTicketForAlias", ctx, alias) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServiceTicketForAlias indicates an expected call of GetServiceTicketForAlias. +func (mr *MockClientMockRecorder) GetServiceTicketForAlias(ctx, alias interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTicketForAlias", reflect.TypeOf((*MockClient)(nil).GetServiceTicketForAlias), ctx, alias) +} + +// GetServiceTicketForID mocks base method. +func (m *MockClient) GetServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServiceTicketForID", ctx, dstID) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServiceTicketForID indicates an expected call of GetServiceTicketForID. +func (mr *MockClientMockRecorder) GetServiceTicketForID(ctx, dstID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTicketForID", reflect.TypeOf((*MockClient)(nil).GetServiceTicketForID), ctx, dstID) +} + +// CheckServiceTicket mocks base method. +func (m *MockClient) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CheckServiceTicket", ctx, ticket) + ret0, _ := ret[0].(*tvm.CheckedServiceTicket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CheckServiceTicket indicates an expected call of CheckServiceTicket. +func (mr *MockClientMockRecorder) CheckServiceTicket(ctx, ticket interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckServiceTicket", reflect.TypeOf((*MockClient)(nil).CheckServiceTicket), ctx, ticket) +} + +// CheckUserTicket mocks base method. +func (m *MockClient) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, ticket} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CheckUserTicket", varargs...) + ret0, _ := ret[0].(*tvm.CheckedUserTicket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CheckUserTicket indicates an expected call of CheckUserTicket. +func (mr *MockClientMockRecorder) CheckUserTicket(ctx, ticket interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, ticket}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckUserTicket", reflect.TypeOf((*MockClient)(nil).CheckUserTicket), varargs...) +} + +// GetRoles mocks base method. +func (m *MockClient) GetRoles(ctx context.Context) (*tvm.Roles, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRoles", ctx) + ret0, _ := ret[0].(*tvm.Roles) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRoles indicates an expected call of GetRoles. +func (mr *MockClientMockRecorder) GetRoles(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoles", reflect.TypeOf((*MockClient)(nil).GetRoles), ctx) +} + +// GetStatus mocks base method. +func (m *MockClient) GetStatus(ctx context.Context) (tvm.ClientStatusInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStatus", ctx) + ret0, _ := ret[0].(tvm.ClientStatusInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetStatus indicates an expected call of GetStatus. +func (mr *MockClientMockRecorder) GetStatus(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatus", reflect.TypeOf((*MockClient)(nil).GetStatus), ctx) +} diff --git a/library/go/yandex/tvm/roles.go b/library/go/yandex/tvm/roles.go new file mode 100644 index 0000000000..03c2a97af6 --- /dev/null +++ b/library/go/yandex/tvm/roles.go @@ -0,0 +1,130 @@ +package tvm + +import ( + "encoding/json" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +func (r *Roles) GetRolesForService(t *CheckedServiceTicket) *ConsumerRoles { + return r.tvmRoles[t.SrcID] +} + +func (r *Roles) GetRolesForUser(t *CheckedUserTicket, uid *UID) (*ConsumerRoles, error) { + if t.Env != BlackboxProdYateam { + return nil, xerrors.Errorf("user ticket must be from ProdYateam, got from %s", t.Env) + } + + if uid == nil { + if t.DefaultUID == 0 { + return nil, xerrors.Errorf("default uid is 0 - it cannot have any role") + } + uid = &t.DefaultUID + } else { + found := false + for _, u := range t.UIDs { + if u == *uid { + found = true + break + } + } + if !found { + return nil, xerrors.Errorf("'uid' must be in user ticket but it is not: %d", *uid) + } + } + + return r.userRoles[*uid], nil +} + +func (r *Roles) GetRaw() []byte { + return r.raw +} + +func (r *Roles) GetMeta() Meta { + return r.meta +} + +func (r *Roles) CheckServiceRole(t *CheckedServiceTicket, roleName string, opts *CheckServiceOptions) bool { + e := r.GetRolesForService(t).GetEntitiesForRole(roleName) + if e == nil { + return false + } + + if opts != nil { + if opts.Entity != nil && !e.ContainsExactEntity(opts.Entity) { + return false + } + } + + return true +} + +func (r *Roles) CheckUserRole(t *CheckedUserTicket, roleName string, opts *CheckUserOptions) (bool, error) { + var uid *UID + if opts != nil && opts.UID != 0 { + uid = &opts.UID + } + + roles, err := r.GetRolesForUser(t, uid) + if err != nil { + return false, err + } + e := roles.GetEntitiesForRole(roleName) + if e == nil { + return false, nil + } + + if opts != nil { + if opts.Entity != nil && !e.ContainsExactEntity(opts.Entity) { + return false, nil + } + } + + return true, nil +} + +func (r *ConsumerRoles) HasRole(roleName string) bool { + return r.GetEntitiesForRole(roleName) != nil +} + +func (r *ConsumerRoles) GetRoles() EntitiesByRoles { + if r == nil { + return nil + } + return r.roles +} + +func (r *ConsumerRoles) GetEntitiesForRole(roleName string) *Entities { + if r == nil { + return nil + } + return r.roles[roleName] +} + +func (r *ConsumerRoles) DebugPrint() string { + tmp := make(map[string][]Entity) + + for k, v := range r.roles { + tmp[k] = v.subtree.entities + } + + res, err := json.MarshalIndent(tmp, "", " ") + if err != nil { + panic(err) + } + return string(res) +} + +func (e *Entities) ContainsExactEntity(entity Entity) bool { + if e == nil { + return false + } + return e.subtree.containsExactEntity(entity) +} + +func (e *Entities) GetEntitiesWithAttrs(entityPart Entity) []Entity { + if e == nil { + return nil + } + return e.subtree.getEntitiesWithAttrs(entityPart) +} diff --git a/library/go/yandex/tvm/roles_entities_index.go b/library/go/yandex/tvm/roles_entities_index.go new file mode 100644 index 0000000000..488ce7fb09 --- /dev/null +++ b/library/go/yandex/tvm/roles_entities_index.go @@ -0,0 +1,73 @@ +package tvm + +import "sort" + +type entityAttribute struct { + key string + value string +} + +// subTree provides index for fast entity lookup with attributes +// +// or some subset of entity attributes +type subTree struct { + // entities contains entities with attributes from previous branches of tree: + // * root subTree contains all entities + // * next subTree contains entities with {"key#X": "value#X"} + // * next subTree after next contains entities with {"key#X": "value#X", "key#Y": "value#Y"} + // * and so on + // "key#X", "key#Y", ... - are sorted + entities []Entity + // entityLengths provides O(1) for exact entity lookup + entityLengths map[int]interface{} + // entityIds is creation-time crutch + entityIds []int + idxByAttrs *idxByAttrs +} + +type idxByAttrs = map[entityAttribute]*subTree + +func (s *subTree) containsExactEntity(entity Entity) bool { + subtree := s.findSubTree(entity) + if subtree == nil { + return false + } + + _, ok := subtree.entityLengths[len(entity)] + return ok +} + +func (s *subTree) getEntitiesWithAttrs(entityPart Entity) []Entity { + subtree := s.findSubTree(entityPart) + if subtree == nil { + return nil + } + + return subtree.entities +} + +func (s *subTree) findSubTree(e Entity) *subTree { + keys := make([]string, 0, len(e)) + for k := range e { + keys = append(keys, k) + } + sort.Strings(keys) + + res := s + + for _, k := range keys { + if res.idxByAttrs == nil { + return nil + } + + kv := entityAttribute{key: k, value: e[k]} + ok := false + + res, ok = (*res.idxByAttrs)[kv] + if !ok { + return nil + } + } + + return res +} diff --git a/library/go/yandex/tvm/roles_entities_index_builder.go b/library/go/yandex/tvm/roles_entities_index_builder.go new file mode 100644 index 0000000000..20bde16a00 --- /dev/null +++ b/library/go/yandex/tvm/roles_entities_index_builder.go @@ -0,0 +1,117 @@ +package tvm + +import "sort" + +type stages struct { + keys []string + id uint64 +} + +func createStages(keys []string) stages { + return stages{ + keys: keys, + } +} + +func (s *stages) getNextStage(keys *[]string) bool { + s.id += 1 + *keys = (*keys)[:0] + + for idx := range s.keys { + need := (s.id >> idx) & 0x01 + if need == 1 { + *keys = append(*keys, s.keys[idx]) + } + } + + return len(*keys) > 0 +} + +func buildEntities(entities []Entity) *Entities { + root := make(idxByAttrs) + res := &Entities{ + subtree: subTree{ + idxByAttrs: &root, + }, + } + + stage := createStages(getUniqueSortedKeys(entities)) + + keySet := make([]string, 0, len(stage.keys)) + for stage.getNextStage(&keySet) { + for entityID, entity := range entities { + currentBranch := &res.subtree + + for _, key := range keySet { + entValue, ok := entity[key] + if !ok { + continue + } + + if currentBranch.idxByAttrs == nil { + index := make(idxByAttrs) + currentBranch.idxByAttrs = &index + } + + kv := entityAttribute{key: key, value: entValue} + subtree, ok := (*currentBranch.idxByAttrs)[kv] + if !ok { + subtree = &subTree{} + (*currentBranch.idxByAttrs)[kv] = subtree + } + + currentBranch = subtree + currentBranch.entityIds = append(currentBranch.entityIds, entityID) + res.subtree.entityIds = append(res.subtree.entityIds, entityID) + } + } + } + + postProcessSubTree(&res.subtree, entities) + + return res +} + +func postProcessSubTree(sub *subTree, entities []Entity) { + tmp := make(map[int]interface{}, len(entities)) + for _, e := range sub.entityIds { + tmp[e] = nil + } + sub.entityIds = sub.entityIds[:0] + for i := range tmp { + sub.entityIds = append(sub.entityIds, i) + } + sort.Ints(sub.entityIds) + + sub.entities = make([]Entity, 0, len(sub.entityIds)) + sub.entityLengths = make(map[int]interface{}) + for _, idx := range sub.entityIds { + sub.entities = append(sub.entities, entities[idx]) + sub.entityLengths[len(entities[idx])] = nil + } + sub.entityIds = nil + + if sub.idxByAttrs != nil { + for _, rest := range *sub.idxByAttrs { + postProcessSubTree(rest, entities) + } + } +} + +func getUniqueSortedKeys(entities []Entity) []string { + tmp := map[string]interface{}{} + + for _, e := range entities { + for k := range e { + tmp[k] = nil + } + } + + res := make([]string, 0, len(tmp)) + for k := range tmp { + res = append(res, k) + } + + sort.Strings(res) + return res +} diff --git a/library/go/yandex/tvm/roles_entities_index_builder_test.go b/library/go/yandex/tvm/roles_entities_index_builder_test.go new file mode 100644 index 0000000000..dd795369d5 --- /dev/null +++ b/library/go/yandex/tvm/roles_entities_index_builder_test.go @@ -0,0 +1,259 @@ +package tvm + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRolesGetNextStage(t *testing.T) { + s := createStages([]string{"key#1", "key#2", "key#3", "key#4"}) + + results := [][]string{ + {"key#1"}, + {"key#2"}, + {"key#1", "key#2"}, + {"key#3"}, + {"key#1", "key#3"}, + {"key#2", "key#3"}, + {"key#1", "key#2", "key#3"}, + {"key#4"}, + {"key#1", "key#4"}, + {"key#2", "key#4"}, + {"key#1", "key#2", "key#4"}, + {"key#3", "key#4"}, + {"key#1", "key#3", "key#4"}, + {"key#2", "key#3", "key#4"}, + {"key#1", "key#2", "key#3", "key#4"}, + } + + keySet := make([]string, 0) + for idx, exp := range results { + s.getNextStage(&keySet) + require.Equal(t, exp, keySet, idx) + } + + // require.False(t, s.getNextStage(&keySet)) +} + +func TestRolesBuildEntities(t *testing.T) { + type TestCase struct { + in []Entity + out Entities + } + cases := []TestCase{ + { + in: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + out: Entities{subtree: subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + entityLengths: map[int]interface{}{1: nil, 2: nil, 3: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#1", value: "value#1"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{1: nil, 3: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#2", value: "value#2"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + }, + }, + }, + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + }, + }, + }, + entityAttribute{key: "key#1", value: "value#2"}: &subTree{ + entities: []Entity{ + {"key#1": "value#2", "key#2": "value#2"}, + }, + entityLengths: map[int]interface{}{2: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#2", value: "value#2"}: &subTree{ + entities: []Entity{ + {"key#1": "value#2", "key#2": "value#2"}, + }, + entityLengths: map[int]interface{}{2: nil}, + }, + }, + }, + entityAttribute{key: "key#2", value: "value#2"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + }, + entityLengths: map[int]interface{}{2: nil, 3: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + }, + }, + }, + entityAttribute{key: "key#3", value: "value#3"}: &subTree{ + entities: []Entity{ + {"key#3": "value#3"}, + }, + entityLengths: map[int]interface{}{1: nil}, + }, + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + entityLengths: map[int]interface{}{3: nil}, + }, + }, + }}, + }, + } + + for idx, c := range cases { + require.Equal(t, c.out, *buildEntities(c.in), idx) + } +} + +func TestRolesPostProcessSubTree(t *testing.T) { + type TestCase struct { + in subTree + out subTree + } + + cases := []TestCase{ + { + in: subTree{ + entityIds: []int{1, 1, 1, 1, 1, 2, 0, 0, 0}, + }, + out: subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + entityLengths: map[int]interface{}{1: nil, 2: nil}, + }, + }, + { + in: subTree{ + entityIds: []int{1, 0}, + entityLengths: map[int]interface{}{1: nil, 2: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#1", value: "value#1"}: &subTree{ + entityIds: []int{2, 0, 0}, + }, + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entityIds: []int{0, 0, 0}, + }, + }, + }, + out: subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2", "key#2": "value#2"}, + }, + entityLengths: map[int]interface{}{1: nil, 2: nil}, + idxByAttrs: &idxByAttrs{ + entityAttribute{key: "key#1", value: "value#1"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + {"key#3": "value#3"}, + }, + entityLengths: map[int]interface{}{1: nil}, + }, + entityAttribute{key: "key#4", value: "value#4"}: &subTree{ + entities: []Entity{ + {"key#1": "value#1"}, + }, + entityLengths: map[int]interface{}{1: nil}, + }, + }, + }, + }, + } + + entities := []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + } + + for idx, c := range cases { + postProcessSubTree(&c.in, entities) + require.Equal(t, c.out, c.in, idx) + } +} + +func TestRolesGetUniqueSortedKeys(t *testing.T) { + type TestCase struct { + in []Entity + out []string + } + + cases := []TestCase{ + { + in: nil, + out: []string{}, + }, + { + in: []Entity{}, + out: []string{}, + }, + { + in: []Entity{ + {}, + }, + out: []string{}, + }, + { + in: []Entity{ + {"key#1": "value#1"}, + {}, + }, + out: []string{"key#1"}, + }, + { + in: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2"}, + }, + out: []string{"key#1"}, + }, + { + in: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + out: []string{"key#1", "key#2", "key#3"}, + }, + } + + for idx, c := range cases { + require.Equal(t, c.out, getUniqueSortedKeys(c.in), idx) + } +} diff --git a/library/go/yandex/tvm/roles_entities_index_test.go b/library/go/yandex/tvm/roles_entities_index_test.go new file mode 100644 index 0000000000..e1abaa0f0e --- /dev/null +++ b/library/go/yandex/tvm/roles_entities_index_test.go @@ -0,0 +1,113 @@ +package tvm + +import ( + "math/rand" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRolesSubTreeContainsExactEntity(t *testing.T) { + origEntities := []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#1", "key#2": "value#2"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + } + entities := buildEntities(origEntities) + + for _, e := range generatedRandEntities() { + found := false + for _, o := range origEntities { + if reflect.DeepEqual(e, o) { + found = true + break + } + } + + require.Equal(t, found, entities.subtree.containsExactEntity(e), e) + } +} + +func generatedRandEntities() []Entity { + rand.Seed(time.Now().UnixNano()) + + keysStages := createStages([]string{"key#1", "key#2", "key#3", "key#4", "key#5"}) + valuesSet := []string{"value#1", "value#2", "value#3", "value#4", "value#5"} + + res := make([]Entity, 0) + + keySet := make([]string, 0, 5) + for keysStages.getNextStage(&keySet) { + entity := Entity{} + for _, key := range keySet { + entity[key] = valuesSet[rand.Intn(len(valuesSet))] + + e := Entity{} + for k, v := range entity { + e[k] = v + } + res = append(res, e) + } + } + + return res +} + +func TestRolesGetEntitiesWithAttrs(t *testing.T) { + type TestCase struct { + in Entity + out []Entity + } + + cases := []TestCase{ + { + out: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }, + }, + { + in: Entity{"key#1": "value#1"}, + out: []Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + }, + }, + { + in: Entity{"key#1": "value#2"}, + out: []Entity{ + {"key#1": "value#2", "key#2": "value#2"}, + }, + }, + { + in: Entity{"key#2": "value#2"}, + out: []Entity{ + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + }, + }, + { + in: Entity{"key#3": "value#3"}, + out: []Entity{ + {"key#3": "value#3"}, + }, + }, + } + + entities := buildEntities([]Entity{ + {"key#1": "value#1"}, + {"key#1": "value#1", "key#2": "value#2", "key#4": "value#4"}, + {"key#1": "value#2", "key#2": "value#2"}, + {"key#3": "value#3"}, + }) + + for idx, c := range cases { + require.Equal(t, c.out, entities.subtree.getEntitiesWithAttrs(c.in), idx) + } +} diff --git a/library/go/yandex/tvm/roles_opts.go b/library/go/yandex/tvm/roles_opts.go new file mode 100644 index 0000000000..8e0a0e0608 --- /dev/null +++ b/library/go/yandex/tvm/roles_opts.go @@ -0,0 +1,10 @@ +package tvm + +type CheckServiceOptions struct { + Entity Entity +} + +type CheckUserOptions struct { + Entity Entity + UID UID +} diff --git a/library/go/yandex/tvm/roles_parser.go b/library/go/yandex/tvm/roles_parser.go new file mode 100644 index 0000000000..f46c6b99b0 --- /dev/null +++ b/library/go/yandex/tvm/roles_parser.go @@ -0,0 +1,67 @@ +package tvm + +import ( + "encoding/json" + "strconv" + "time" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +type rawRoles struct { + Revision string `json:"revision"` + BornDate int64 `json:"born_date"` + Tvm rawConsumers `json:"tvm"` + User rawConsumers `json:"user"` +} + +type rawConsumers = map[string]rawConsumerRoles +type rawConsumerRoles = map[string][]Entity + +func NewRoles(buf []byte) (*Roles, error) { + var raw rawRoles + if err := json.Unmarshal(buf, &raw); err != nil { + return nil, xerrors.Errorf("failed to parse roles: invalid json: %w", err) + } + + tvmRoles := map[ClientID]*ConsumerRoles{} + for key, value := range raw.Tvm { + id, err := strconv.ParseUint(key, 10, 32) + if err != nil { + return nil, xerrors.Errorf("failed to parse roles: invalid tvmid '%s': %w", key, err) + } + tvmRoles[ClientID(id)] = buildConsumerRoles(value) + } + + userRoles := map[UID]*ConsumerRoles{} + for key, value := range raw.User { + id, err := strconv.ParseUint(key, 10, 64) + if err != nil { + return nil, xerrors.Errorf("failed to parse roles: invalid UID '%s': %w", key, err) + } + userRoles[UID(id)] = buildConsumerRoles(value) + } + + return &Roles{ + tvmRoles: tvmRoles, + userRoles: userRoles, + raw: buf, + meta: Meta{ + Revision: raw.Revision, + BornTime: time.Unix(raw.BornDate, 0), + Applied: time.Now(), + }, + }, nil +} + +func buildConsumerRoles(rawConsumerRoles rawConsumerRoles) *ConsumerRoles { + roles := &ConsumerRoles{ + roles: make(EntitiesByRoles, len(rawConsumerRoles)), + } + + for r, ents := range rawConsumerRoles { + roles.roles[r] = buildEntities(ents) + } + + return roles +} diff --git a/library/go/yandex/tvm/roles_parser_test.go b/library/go/yandex/tvm/roles_parser_test.go new file mode 100644 index 0000000000..2b27100ff0 --- /dev/null +++ b/library/go/yandex/tvm/roles_parser_test.go @@ -0,0 +1,88 @@ +package tvm + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRolesUserTicketCheckScopes(t *testing.T) { + type TestCase struct { + buf string + roles Roles + err string + } + + cases := []TestCase{ + { + buf: `{"revision":100500}`, + err: "failed to parse roles: invalid json", + }, + { + buf: `{"born_date":1612791978.42}`, + err: "failed to parse roles: invalid json", + }, + { + buf: `{"tvm":{"asd":{}}}`, + err: "failed to parse roles: invalid tvmid 'asd'", + }, + { + buf: `{"user":{"asd":{}}}`, + err: "failed to parse roles: invalid UID 'asd'", + }, + { + buf: `{"tvm":{"1120000000000493":{}}}`, + err: "failed to parse roles: invalid tvmid '1120000000000493'", + }, + { + buf: `{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"}],"/group/system/system_on/abc/role/tree_edit/":[{"scope":"/"}]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/"}]}}}`, + roles: Roles{ + tvmRoles: map[ClientID]*ConsumerRoles{ + ClientID(2012192): { + roles: EntitiesByRoles{ + "/group/system/system_on/abc/role/impersonator/": {}, + "/group/system/system_on/abc/role/tree_edit/": {}, + }, + }, + }, + userRoles: map[UID]*ConsumerRoles{ + UID(1120000000000493): { + roles: EntitiesByRoles{ + "/group/system/system_on/abc/role/roles_manage/": {}, + }, + }, + }, + raw: []byte(`{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"}],"/group/system/system_on/abc/role/tree_edit/":[{"scope":"/"}]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/"}]}}}`), + meta: Meta{ + Revision: "GYYDEMJUGBQWC", + BornTime: time.Unix(1612791978, 0), + }, + }, + }, + } + + for idx, c := range cases { + r, err := NewRoles([]byte(c.buf)) + if c.err == "" { + require.NoError(t, err, idx) + + r.meta.Applied = time.Time{} + for _, roles := range r.tvmRoles { + for _, v := range roles.roles { + v.subtree = subTree{} + } + } + for _, roles := range r.userRoles { + for _, v := range roles.roles { + v.subtree = subTree{} + } + } + + require.Equal(t, c.roles, *r, idx) + } else { + require.Error(t, err, idx) + require.Contains(t, err.Error(), c.err, idx) + } + } +} diff --git a/library/go/yandex/tvm/roles_test.go b/library/go/yandex/tvm/roles_test.go new file mode 100644 index 0000000000..d0c913984f --- /dev/null +++ b/library/go/yandex/tvm/roles_test.go @@ -0,0 +1,116 @@ +package tvm + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRolesPublicServiceTicket(t *testing.T) { + roles, err := NewRoles([]byte(`{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"}],"/group/system/system_on/abc/role/tree_edit/":[{"scope":"/"}]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/"}]}}}`)) + require.NoError(t, err) + + st := &CheckedServiceTicket{SrcID: 42} + require.Nil(t, roles.GetRolesForService(st)) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", nil)) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity: Entity{"scope": "/"}})) + + st = &CheckedServiceTicket{SrcID: 2012192} + r := roles.GetRolesForService(st) + require.NotNil(t, r) + require.EqualValues(t, + `{ + "/group/system/system_on/abc/role/impersonator/": [ + { + "scope": "/" + } + ], + "/group/system/system_on/abc/role/tree_edit/": [ + { + "scope": "/" + } + ] +}`, + r.DebugPrint(), + ) + require.Equal(t, 2, len(r.GetRoles())) + require.False(t, r.HasRole("/")) + require.True(t, r.HasRole("/group/system/system_on/abc/role/impersonator/")) + require.False(t, roles.CheckServiceRole(st, "/", nil)) + require.True(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", nil)) + require.False(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity: Entity{"scope": "kek"}})) + require.True(t, roles.CheckServiceRole(st, "/group/system/system_on/abc/role/impersonator/", &CheckServiceOptions{Entity{"scope": "/"}})) + require.Nil(t, r.GetEntitiesForRole("/")) + + en := r.GetEntitiesForRole("/group/system/system_on/abc/role/impersonator/") + require.NotNil(t, en) + require.False(t, en.ContainsExactEntity(Entity{"scope": "kek"})) + require.True(t, en.ContainsExactEntity(Entity{"scope": "/"})) + + require.Nil(t, en.GetEntitiesWithAttrs(Entity{"scope": "kek"})) + require.Equal(t, []Entity{{"scope": "/"}}, en.GetEntitiesWithAttrs(Entity{"scope": "/"})) +} + +func TestRolesPublicUserTicket(t *testing.T) { + roles, err := NewRoles([]byte(`{"revision":"GYYDEMJUGBQWC","born_date":1612791978,"tvm":{"2012192":{"/group/system/system_on/abc/role/impersonator/":[{"scope":"/"}],"/group/system/system_on/abc/role/tree_edit/":[{"scope":"/"}]}},"user":{"1120000000000493":{"/group/system/system_on/abc/role/roles_manage/":[{"scope":"/services/meta_infra/tools/jobjira/"},{"scope":"/services/meta_edu/infrastructure/"}]}}}`)) + require.NoError(t, err) + + ut := &CheckedUserTicket{DefaultUID: 42} + _, err = roles.GetRolesForUser(ut, nil) + require.EqualError(t, err, "user ticket must be from ProdYateam, got from Prod") + ut.Env = BlackboxProdYateam + + r, err := roles.GetRolesForUser(ut, nil) + require.NoError(t, err) + require.Nil(t, r) + ok, err := roles.CheckUserRole(ut, "/group/system/system_on/abc/role/impersonator/", nil) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/impersonator/", &CheckUserOptions{Entity: Entity{"scope": "/"}}) + require.NoError(t, err) + require.False(t, ok) + + ut = &CheckedUserTicket{DefaultUID: 1120000000000493, UIDs: []UID{42}, Env: BlackboxProdYateam} + r, err = roles.GetRolesForUser(ut, nil) + require.NoError(t, err) + require.NotNil(t, r) + require.EqualValues(t, + `{ + "/group/system/system_on/abc/role/roles_manage/": [ + { + "scope": "/services/meta_infra/tools/jobjira/" + }, + { + "scope": "/services/meta_edu/infrastructure/" + } + ] +}`, + r.DebugPrint(), + ) + require.Equal(t, 1, len(r.GetRoles())) + require.False(t, r.HasRole("/")) + require.True(t, r.HasRole("/group/system/system_on/abc/role/roles_manage/")) + ok, err = roles.CheckUserRole(ut, "/", nil) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", nil) + require.NoError(t, err) + require.True(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", &CheckUserOptions{Entity: Entity{"scope": "kek"}}) + require.NoError(t, err) + require.False(t, ok) + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", &CheckUserOptions{Entity: Entity{"scope": "/services/meta_infra/tools/jobjira/"}}) + require.NoError(t, err) + require.True(t, ok) + + ok, err = roles.CheckUserRole(ut, "/group/system/system_on/abc/role/roles_manage/", &CheckUserOptions{UID: UID(42)}) + require.NoError(t, err) + require.False(t, ok) + + ut = &CheckedUserTicket{DefaultUID: 0, UIDs: []UID{42}, Env: BlackboxProdYateam} + _, err = roles.GetRolesForUser(ut, nil) + require.EqualError(t, err, "default uid is 0 - it cannot have any role") + uid := UID(83) + _, err = roles.GetRolesForUser(ut, &uid) + require.EqualError(t, err, "'uid' must be in user ticket but it is not: 83") +} diff --git a/library/go/yandex/tvm/roles_types.go b/library/go/yandex/tvm/roles_types.go new file mode 100644 index 0000000000..d1bfb07b3c --- /dev/null +++ b/library/go/yandex/tvm/roles_types.go @@ -0,0 +1,30 @@ +package tvm + +import ( + "time" +) + +type Roles struct { + tvmRoles map[ClientID]*ConsumerRoles + userRoles map[UID]*ConsumerRoles + raw []byte + meta Meta +} + +type Meta struct { + Revision string + BornTime time.Time + Applied time.Time +} + +type ConsumerRoles struct { + roles EntitiesByRoles +} + +type EntitiesByRoles = map[string]*Entities + +type Entities struct { + subtree subTree +} + +type Entity = map[string]string diff --git a/library/go/yandex/tvm/service_ticket.go b/library/go/yandex/tvm/service_ticket.go new file mode 100644 index 0000000000..2341ba2b17 --- /dev/null +++ b/library/go/yandex/tvm/service_ticket.go @@ -0,0 +1,50 @@ +package tvm + +import ( + "fmt" +) + +// CheckedServiceTicket is service credential +type CheckedServiceTicket struct { + // SrcID is ID of request source service. You should check SrcID by yourself with your ACL. + SrcID ClientID + // IssuerUID is UID of developer who is debuging something, so he(she) issued CheckedServiceTicket with his(her) ssh-sign: + // it is grant_type=sshkey in tvm-api + // https://wiki.yandex-team.ru/passport/tvm2/debug/#sxoditvapizakrytoeserviceticketami. + IssuerUID UID + // DbgInfo is human readable data for debug purposes + DbgInfo string + // LogInfo is safe for logging part of ticket - it can be parsed later with `tvmknife parse_ticket -t ...` + LogInfo string +} + +func (t *CheckedServiceTicket) CheckSrcID(allowedSrcIDsMap map[uint32]struct{}) error { + if len(allowedSrcIDsMap) == 0 { + return nil + } + if _, allowed := allowedSrcIDsMap[uint32(t.SrcID)]; !allowed { + return &TicketError{ + Status: TicketInvalidSrcID, + Msg: fmt.Sprintf("service ticket srcID is not in allowed srcIDs: %v (actual: %v)", allowedSrcIDsMap, t.SrcID), + } + } + return nil +} + +func (t CheckedServiceTicket) String() string { + return fmt.Sprintf("%s (%s)", t.LogInfo, t.DbgInfo) +} + +type ServiceTicketACL func(ticket *CheckedServiceTicket) error + +func AllowAllServiceTickets() ServiceTicketACL { + return func(ticket *CheckedServiceTicket) error { + return nil + } +} + +func CheckServiceTicketSrcID(allowedSrcIDs map[uint32]struct{}) ServiceTicketACL { + return func(ticket *CheckedServiceTicket) error { + return ticket.CheckSrcID(allowedSrcIDs) + } +} diff --git a/library/go/yandex/tvm/tvm.go b/library/go/yandex/tvm/tvm.go new file mode 100644 index 0000000000..663589efd5 --- /dev/null +++ b/library/go/yandex/tvm/tvm.go @@ -0,0 +1,121 @@ +// This package defines interface which provides fast and cryptographically secure authorization tickets: https://wiki.yandex-team.ru/passport/tvm2/. +// +// Encoded ticket is a valid ASCII string: [0-9a-zA-Z_-:]+. +// +// This package defines interface. All libraries should depend on this package. +// Pure Go implementations of interface is located in library/go/yandex/tvm/tvmtool. +// CGO implementation is located in library/ticket_parser2/go/ticket_parser2. +package tvm + +import ( + "fmt" + "strings" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +// ClientID represents ID of the application. Another name - TvmID. +type ClientID uint32 + +// UID represents ID of the user in Passport. +type UID uint64 + +// BlackboxEnv describes environment of Passport: https://wiki.yandex-team.ru/passport/tvm2/user-ticket/#0-opredeljaemsjasokruzhenijami +type BlackboxEnv int + +// This constants must be in sync with EBlackboxEnv from library/cpp/tvmauth/checked_user_ticket.h +const ( + BlackboxProd BlackboxEnv = iota + BlackboxTest + BlackboxProdYateam + BlackboxTestYateam + BlackboxStress +) + +func (e BlackboxEnv) String() string { + switch e { + case BlackboxProd: + return "Prod" + case BlackboxTest: + return "Test" + case BlackboxProdYateam: + return "ProdYateam" + case BlackboxTestYateam: + return "TestYateam" + case BlackboxStress: + return "Stress" + default: + return fmt.Sprintf("Unknown%d", e) + } +} + +func BlackboxEnvFromString(envStr string) (BlackboxEnv, error) { + switch strings.ToLower(envStr) { + case "prod": + return BlackboxProd, nil + case "test": + return BlackboxTest, nil + case "prodyateam", "prod_yateam": + return BlackboxProdYateam, nil + case "testyateam", "test_yateam": + return BlackboxTestYateam, nil + case "stress": + return BlackboxStress, nil + default: + return BlackboxEnv(-1), xerrors.Errorf("blackbox env is unknown: '%s'", envStr) + } +} + +type TicketStatus int + +// This constants must be in sync with EStatus from library/cpp/tvmauth/ticket_status.h +const ( + TicketOk TicketStatus = iota + TicketExpired + TicketInvalidBlackboxEnv + TicketInvalidDst + TicketInvalidTicketType + TicketMalformed + TicketMissingKey + TicketSignBroken + TicketUnsupportedVersion + TicketNoRoles + + // Go-only statuses below + TicketStatusOther + TicketInvalidScopes + TicketInvalidSrcID +) + +func (s TicketStatus) String() string { + switch s { + case TicketOk: + return "Ok" + case TicketExpired: + return "Expired" + case TicketInvalidBlackboxEnv: + return "InvalidBlackboxEnv" + case TicketInvalidDst: + return "InvalidDst" + case TicketInvalidTicketType: + return "InvalidTicketType" + case TicketMalformed: + return "Malformed" + case TicketMissingKey: + return "MissingKey" + case TicketSignBroken: + return "SignBroken" + case TicketUnsupportedVersion: + return "UnsupportedVersion" + case TicketNoRoles: + return "NoRoles" + case TicketStatusOther: + return "Other" + case TicketInvalidScopes: + return "InvalidScopes" + case TicketInvalidSrcID: + return "InvalidSrcID" + default: + return fmt.Sprintf("Unknown%d", s) + } +} diff --git a/library/go/yandex/tvm/tvm_test.go b/library/go/yandex/tvm/tvm_test.go new file mode 100644 index 0000000000..3d8f9f0532 --- /dev/null +++ b/library/go/yandex/tvm/tvm_test.go @@ -0,0 +1,246 @@ +package tvm_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/yandex/tvm" +) + +func TestUserTicketCheckScopes(t *testing.T) { + cases := map[string]struct { + ticketScopes []string + requiredScopes []string + err bool + }{ + "wo_required_scopes": { + ticketScopes: []string{"bb:sessionid"}, + requiredScopes: nil, + err: false, + }, + "multiple_scopes_0": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"bb:sessionid", "test:test"}, + err: false, + }, + "multiple_scopes_1": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"test:test", "bb:sessionid"}, + err: false, + }, + "wo_scopes": { + ticketScopes: nil, + requiredScopes: []string{"bb:sessionid"}, + err: true, + }, + "invalid_scopes": { + ticketScopes: []string{"bb:sessionid"}, + requiredScopes: []string{"test:test"}, + err: true, + }, + "not_all_scopes": { + ticketScopes: []string{"bb:sessionid", "test:test1"}, + requiredScopes: []string{"bb:sessionid", "test:test"}, + err: true, + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + ticket := tvm.CheckedUserTicket{ + Scopes: testCase.ticketScopes, + } + err := ticket.CheckScopes(testCase.requiredScopes...) + if testCase.err { + require.Error(t, err) + require.IsType(t, &tvm.TicketError{}, err) + ticketErr := err.(*tvm.TicketError) + require.Equal(t, tvm.TicketInvalidScopes, ticketErr.Status) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestUserTicketCheckScopesAny(t *testing.T) { + cases := map[string]struct { + ticketScopes []string + requiredScopes []string + err bool + }{ + "wo_required_scopes": { + ticketScopes: []string{"bb:sessionid"}, + requiredScopes: nil, + err: false, + }, + "multiple_scopes_0": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"bb:sessionid"}, + err: false, + }, + "multiple_scopes_1": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"test:test"}, + err: false, + }, + "multiple_scopes_2": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"bb:sessionid", "test:test"}, + err: false, + }, + "multiple_scopes_3": { + ticketScopes: []string{"bb:sessionid", "test:test"}, + requiredScopes: []string{"test:test", "bb:sessionid"}, + err: false, + }, + "wo_scopes": { + ticketScopes: nil, + requiredScopes: []string{"bb:sessionid"}, + err: true, + }, + "invalid_scopes": { + ticketScopes: []string{"bb:sessionid"}, + requiredScopes: []string{"test:test"}, + err: true, + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + ticket := tvm.CheckedUserTicket{ + Scopes: testCase.ticketScopes, + } + err := ticket.CheckScopes(testCase.requiredScopes...) + if testCase.err { + require.Error(t, err) + require.IsType(t, &tvm.TicketError{}, err) + ticketErr := err.(*tvm.TicketError) + require.Equal(t, tvm.TicketInvalidScopes, ticketErr.Status) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestServiceTicketAllowedSrcIDs(t *testing.T) { + cases := map[string]struct { + srcID uint32 + allowedSrcIDs []uint32 + err bool + }{ + "empty_allow_list_allows_any_srcID": {srcID: 162, allowedSrcIDs: []uint32{}, err: false}, + "known_src_id_is_allowed": {srcID: 42, allowedSrcIDs: []uint32{42, 100500}, err: false}, + "unknown_src_id_is_not_allowed": {srcID: 404, allowedSrcIDs: []uint32{42, 100500}, err: true}, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + ticket := tvm.CheckedServiceTicket{ + SrcID: tvm.ClientID(testCase.srcID), + } + allowedSrcIDsMap := make(map[uint32]struct{}, len(testCase.allowedSrcIDs)) + for _, allowedSrcID := range testCase.allowedSrcIDs { + allowedSrcIDsMap[allowedSrcID] = struct{}{} + } + err := ticket.CheckSrcID(allowedSrcIDsMap) + if testCase.err { + require.Error(t, err) + require.IsType(t, &tvm.TicketError{}, err) + ticketErr := err.(*tvm.TicketError) + require.Equal(t, tvm.TicketInvalidSrcID, ticketErr.Status) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestTicketError_Is(t *testing.T) { + err1 := &tvm.TicketError{ + Status: tvm.TicketInvalidSrcID, + Msg: "uh oh", + } + err2 := &tvm.TicketError{ + Status: tvm.TicketInvalidSrcID, + Msg: "uh oh", + } + err3 := &tvm.TicketError{ + Status: tvm.TicketInvalidSrcID, + Msg: "other uh oh message", + } + err4 := &tvm.TicketError{ + Status: tvm.TicketExpired, + Msg: "uh oh", + } + err5 := &tvm.TicketError{ + Status: tvm.TicketMalformed, + Msg: "i am completely different", + } + var nilErr *tvm.TicketError = nil + + // ticketErrors are equal to themselves + require.True(t, err1.Is(err1)) + require.True(t, err2.Is(err2)) + require.True(t, nilErr.Is(nilErr)) + + // equal value ticketErrors are equal + require.True(t, err1.Is(err2)) + require.True(t, err2.Is(err1)) + // equal status ticketErrors are equal + require.True(t, err1.Is(err3)) + require.True(t, err1.Is(tvm.ErrTicketInvalidSrcID)) + require.True(t, err2.Is(tvm.ErrTicketInvalidSrcID)) + require.True(t, err3.Is(tvm.ErrTicketInvalidSrcID)) + require.True(t, err4.Is(tvm.ErrTicketExpired)) + require.True(t, err5.Is(tvm.ErrTicketMalformed)) + + // different status ticketErrors are not equal + require.False(t, err1.Is(err4)) + + // completely different ticketErrors are not equal + require.False(t, err1.Is(err5)) + + // non-nil ticketErrors are not equal to nil errors + require.False(t, err1.Is(nil)) + require.False(t, err2.Is(nil)) + + // non-nil ticketErrors are not equal to nil ticketErrors + require.False(t, err1.Is(nilErr)) + require.False(t, err2.Is(nilErr)) +} + +func TestBbEnvFromString(t *testing.T) { + type Case struct { + in string + env tvm.BlackboxEnv + err string + } + cases := []Case{ + {in: "prod", env: tvm.BlackboxProd}, + {in: "Prod", env: tvm.BlackboxProd}, + {in: "ProD", env: tvm.BlackboxProd}, + {in: "PROD", env: tvm.BlackboxProd}, + {in: "test", env: tvm.BlackboxTest}, + {in: "prod_yateam", env: tvm.BlackboxProdYateam}, + {in: "ProdYateam", env: tvm.BlackboxProdYateam}, + {in: "test_yateam", env: tvm.BlackboxTestYateam}, + {in: "TestYateam", env: tvm.BlackboxTestYateam}, + {in: "stress", env: tvm.BlackboxStress}, + {in: "", err: "blackbox env is unknown: ''"}, + {in: "kek", err: "blackbox env is unknown: 'kek'"}, + } + + for idx, c := range cases { + res, err := tvm.BlackboxEnvFromString(c.in) + + if c.err == "" { + require.NoError(t, err, idx) + require.Equal(t, c.env, res, idx) + } else { + require.EqualError(t, err, c.err, idx) + } + } +} diff --git a/library/go/yandex/tvm/tvmauth/apitest/.arcignore b/library/go/yandex/tvm/tvmauth/apitest/.arcignore new file mode 100644 index 0000000000..c8a6e77006 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/apitest/.arcignore @@ -0,0 +1 @@ +apitest diff --git a/library/go/yandex/tvm/tvmauth/apitest/client_test.go b/library/go/yandex/tvm/tvmauth/apitest/client_test.go new file mode 100644 index 0000000000..8868abe473 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/apitest/client_test.go @@ -0,0 +1,243 @@ +package apitest + +import ( + "context" + "io/ioutil" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + uzap "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmauth" +) + +func apiSettings(t testing.TB, client tvm.ClientID) tvmauth.TvmAPISettings { + var portStr []byte + portStr, err := ioutil.ReadFile("tvmapi.port") + require.NoError(t, err) + + var port int + port, err = strconv.Atoi(string(portStr)) + require.NoError(t, err) + env := tvm.BlackboxProd + + if client == 1000501 { + return tvmauth.TvmAPISettings{ + SelfID: 1000501, + + EnableServiceTicketChecking: true, + BlackboxEnv: &env, + + ServiceTicketOptions: tvmauth.NewIDsOptions( + "bAicxJVa5uVY7MjDlapthw", + []tvm.ClientID{1000502}), + + TVMHost: "localhost", + TVMPort: port, + } + } else if client == 1000502 { + return tvmauth.TvmAPISettings{ + SelfID: 1000502, + + EnableServiceTicketChecking: true, + BlackboxEnv: &env, + + ServiceTicketOptions: tvmauth.NewAliasesOptions( + "e5kL0vM3nP-nPf-388Hi6Q", + map[string]tvm.ClientID{ + "cl1000501": 1000501, + "cl1000503": 1000503, + }), + + TVMHost: "localhost", + TVMPort: port, + } + } else { + t.Fatalf("Bad client id: %d", client) + return tvmauth.TvmAPISettings{} + } +} + +func TestErrorPassing(t *testing.T) { + _, err := tvmauth.NewAPIClient(tvmauth.TvmAPISettings{}, &nop.Logger{}) + require.Error(t, err) +} + +func TestGetServiceTicketForID(t *testing.T) { + c1000501, err := tvmauth.NewAPIClient(apiSettings(t, 1000501), &nop.Logger{}) + require.NoError(t, err) + defer c1000501.Destroy() + + c1000502, err := tvmauth.NewAPIClient(apiSettings(t, 1000502), &nop.Logger{}) + require.NoError(t, err) + defer c1000502.Destroy() + + ticketStr, err := c1000501.GetServiceTicketForID(context.Background(), 1000502) + require.NoError(t, err) + + ticket, err := c1000502.CheckServiceTicket(context.Background(), ticketStr) + require.NoError(t, err) + require.Equal(t, tvm.ClientID(1000501), ticket.SrcID) + + ticketStrByAlias, err := c1000501.GetServiceTicketForAlias(context.Background(), "1000502") + require.NoError(t, err) + require.Equal(t, ticketStr, ticketStrByAlias) + + _, err = c1000501.CheckServiceTicket(context.Background(), ticketStr) + require.Error(t, err) + require.IsType(t, err, &tvm.TicketError{}) + require.Equal(t, tvm.TicketInvalidDst, err.(*tvm.TicketError).Status) + + _, err = c1000501.GetServiceTicketForID(context.Background(), 127) + require.Error(t, err) + require.IsType(t, err, &tvm.Error{}) + + ticketStr, err = c1000502.GetServiceTicketForID(context.Background(), 1000501) + require.NoError(t, err) + ticketStrByAlias, err = c1000502.GetServiceTicketForAlias(context.Background(), "cl1000501") + require.NoError(t, err) + require.Equal(t, ticketStr, ticketStrByAlias) + + _, err = c1000502.GetServiceTicketForAlias(context.Background(), "1000501") + require.Error(t, err) + require.IsType(t, err, &tvm.Error{}) +} + +func TestLogger(t *testing.T) { + logger, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + require.NoError(t, err) + + core, logs := observer.New(zap.ZapifyLevel(log.DebugLevel)) + logger.L = logger.L.WithOptions(uzap.WrapCore(func(_ zapcore.Core) zapcore.Core { + return core + })) + + c1000502, err := tvmauth.NewAPIClient(apiSettings(t, 1000502), logger) + require.NoError(t, err) + defer c1000502.Destroy() + + loggedEntries := logs.AllUntimed() + for idx := 0; len(loggedEntries) < 7 && idx < 250; idx++ { + time.Sleep(100 * time.Millisecond) + loggedEntries = logs.AllUntimed() + } + + var plainLog string + for _, le := range loggedEntries { + plainLog += le.Message + "\n" + } + + require.Contains( + t, + plainLog, + "Thread-worker started") +} + +func BenchmarkServiceTicket(b *testing.B) { + c1000501, err := tvmauth.NewAPIClient(apiSettings(b, 1000501), &nop.Logger{}) + require.NoError(b, err) + defer c1000501.Destroy() + + c1000502, err := tvmauth.NewAPIClient(apiSettings(b, 1000502), &nop.Logger{}) + require.NoError(b, err) + defer c1000502.Destroy() + + b.Run("GetServiceTicketForID", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := c1000501.GetServiceTicketForID(context.Background(), 1000502) + require.NoError(b, err) + } + }) + }) + + ticketStr, err := c1000501.GetServiceTicketForID(context.Background(), 1000502) + require.NoError(b, err) + + b.Run("CheckServiceTicket", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := c1000502.CheckServiceTicket(context.Background(), ticketStr) + require.NoError(b, err) + } + }) + }) +} + +const serviceTicketStr = "3:serv:CBAQ__________9_IggIlJEGELaIPQ:KC8zKTnoM7GQ8UkBixoAlDt7CAuNIO_6J4rzeqelj7wn7vCKBfsy1jlg2UIvBw0JKUUc6116s5aBw1-vr4BD1V0eh0z-k_CSGC4DKKlnBEEAwcpHRjOZUdW_5UJFe-l77KMObvZUPLckWUaQKybMSBYDGrAeo1TqHHmkumwSG5s" +const userTicketStr = "3:user:CAsQ__________9_GikKAgh7CgMIyAMQyAMaBmJiOmtlaxoLc29tZTpzY29wZXMg0oXYzAQoAA:LPpzn2ILhY1BHXA1a51mtU1emb2QSMH3UhTxsmL07iJ7m2AMc2xloXCKQOI7uK6JuLDf7aSWd9QQJpaRV0mfPzvFTnz2j78hvO3bY8KT_TshA3A-M5-t5gip8CfTVGPmEPwnuUhmKqAGkGSL-sCHyu1RIjHkGbJA250ThHHKgAY" + +func TestDebugInfo(t *testing.T) { + c1000502, err := tvmauth.NewAPIClient(apiSettings(t, 1000502), &nop.Logger{}) + require.NoError(t, err) + defer c1000502.Destroy() + + ticketS, err := c1000502.CheckServiceTicket(context.Background(), serviceTicketStr) + require.NoError(t, err) + require.Equal(t, tvm.ClientID(100500), ticketS.SrcID) + require.Equal(t, tvm.UID(0), ticketS.IssuerUID) + require.Equal(t, "ticket_type=serv;expiration_time=9223372036854775807;src=100500;dst=1000502;", ticketS.DbgInfo) + require.Equal(t, "3:serv:CBAQ__________9_IggIlJEGELaIPQ:", ticketS.LogInfo) + + ticketS, err = c1000502.CheckServiceTicket(context.Background(), serviceTicketStr[:len(serviceTicketStr)-1]) + require.Error(t, err) + require.IsType(t, err, &tvm.TicketError{}) + require.Equal(t, err.(*tvm.TicketError).Status, tvm.TicketSignBroken) + require.Equal(t, "ticket_type=serv;expiration_time=9223372036854775807;src=100500;dst=1000502;", ticketS.DbgInfo) + require.Equal(t, "3:serv:CBAQ__________9_IggIlJEGELaIPQ:", ticketS.LogInfo) + + ticketU, err := c1000502.CheckUserTicket(context.Background(), userTicketStr) + require.NoError(t, err) + require.Equal(t, []tvm.UID{123, 456}, ticketU.UIDs) + require.Equal(t, tvm.UID(456), ticketU.DefaultUID) + require.Equal(t, []string{"bb:kek", "some:scopes"}, ticketU.Scopes) + require.Equal(t, "ticket_type=user;expiration_time=9223372036854775807;scope=bb:kek;scope=some:scopes;default_uid=456;uid=123;uid=456;env=Prod;", ticketU.DbgInfo) + require.Equal(t, "3:user:CAsQ__________9_GikKAgh7CgMIyAMQyAMaBmJiOmtlaxoLc29tZTpzY29wZXMg0oXYzAQoAA:", ticketU.LogInfo) + + _, err = c1000502.CheckUserTicket(context.Background(), userTicketStr, tvm.WithBlackboxOverride(tvm.BlackboxProdYateam)) + require.Error(t, err) + require.IsType(t, err, &tvm.TicketError{}) + require.Equal(t, err.(*tvm.TicketError).Status, tvm.TicketInvalidBlackboxEnv) + + ticketU, err = c1000502.CheckUserTicket(context.Background(), userTicketStr[:len(userTicketStr)-1]) + require.Error(t, err) + require.IsType(t, err, &tvm.TicketError{}) + require.Equal(t, err.(*tvm.TicketError).Status, tvm.TicketSignBroken) + require.Equal(t, "ticket_type=user;expiration_time=9223372036854775807;scope=bb:kek;scope=some:scopes;default_uid=456;uid=123;uid=456;env=Prod;", ticketU.DbgInfo) + require.Equal(t, "3:user:CAsQ__________9_GikKAgh7CgMIyAMQyAMaBmJiOmtlaxoLc29tZTpzY29wZXMg0oXYzAQoAA:", ticketU.LogInfo) +} + +func TestUnittestClient(t *testing.T) { + _, err := tvmauth.NewUnittestClient(tvmauth.TvmUnittestSettings{}) + require.NoError(t, err) + + client, err := tvmauth.NewUnittestClient(tvmauth.TvmUnittestSettings{ + SelfID: 1000502, + }) + require.NoError(t, err) + + _, err = client.GetRoles(context.Background()) + require.ErrorContains(t, err, "Roles are not provided") + _, err = client.GetServiceTicketForID(context.Background(), tvm.ClientID(42)) + require.ErrorContains(t, err, "Destination '42' was not specified in settings") + + status, err := client.GetStatus(context.Background()) + require.NoError(t, err) + require.EqualValues(t, tvm.ClientOK, status.Status) + + st, err := client.CheckServiceTicket(context.Background(), serviceTicketStr) + require.NoError(t, err) + require.EqualValues(t, tvm.ClientID(100500), st.SrcID) + + ut, err := client.CheckUserTicket(context.Background(), userTicketStr) + require.NoError(t, err) + require.EqualValues(t, tvm.UID(456), ut.DefaultUID) +} diff --git a/library/go/yandex/tvm/tvmauth/client.go b/library/go/yandex/tvm/tvmauth/client.go new file mode 100644 index 0000000000..0282b2939f --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/client.go @@ -0,0 +1,509 @@ +//go:build cgo +// +build cgo + +package tvmauth + +// #include <stdlib.h> +// +// #include "tvm.h" +import "C" +import ( + "context" + "encoding/json" + "fmt" + "runtime" + "sync" + "unsafe" + + "a.yandex-team.ru/library/go/cgosem" + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/yandex/tvm" +) + +// NewIDsOptions creates options for fetching CheckedServiceTicket's with ClientID +func NewIDsOptions(secret string, dsts []tvm.ClientID) *TVMAPIOptions { + tmp := make(map[string]tvm.ClientID) + for _, dst := range dsts { + tmp[fmt.Sprintf("%d", dst)] = dst + } + + res, err := json.Marshal(tmp) + if err != nil { + panic(err) + } + + return &TVMAPIOptions{ + selfSecret: secret, + dstAliases: res, + } +} + +// NewAliasesOptions creates options for fetching CheckedServiceTicket's with alias+ClientID +func NewAliasesOptions(secret string, dsts map[string]tvm.ClientID) *TVMAPIOptions { + if dsts == nil { + dsts = make(map[string]tvm.ClientID) + } + + res, err := json.Marshal(dsts) + if err != nil { + panic(err) + } + + return &TVMAPIOptions{ + selfSecret: secret, + dstAliases: res, + } +} + +func (o *TvmAPISettings) pack(out *C.TVM_ApiSettings) { + out.SelfId = C.uint32_t(o.SelfID) + + if o.EnableServiceTicketChecking { + out.EnableServiceTicketChecking = 1 + } + + if o.BlackboxEnv != nil { + out.EnableUserTicketChecking = 1 + out.BlackboxEnv = C.int(*o.BlackboxEnv) + } + + if o.FetchRolesForIdmSystemSlug != "" { + o.fetchRolesForIdmSystemSlug = []byte(o.FetchRolesForIdmSystemSlug) + out.IdmSystemSlug = (*C.uchar)(&o.fetchRolesForIdmSystemSlug[0]) + out.IdmSystemSlugSize = C.int(len(o.fetchRolesForIdmSystemSlug)) + } + if o.DisableSrcCheck { + out.DisableSrcCheck = 1 + } + if o.DisableDefaultUIDCheck { + out.DisableDefaultUIDCheck = 1 + } + + if o.TVMHost != "" { + o.tvmHost = []byte(o.TVMHost) + out.TVMHost = (*C.uchar)(&o.tvmHost[0]) + out.TVMHostSize = C.int(len(o.tvmHost)) + } + out.TVMPort = C.int(o.TVMPort) + + if o.TiroleHost != "" { + o.tiroleHost = []byte(o.TiroleHost) + out.TiroleHost = (*C.uchar)(&o.tiroleHost[0]) + out.TiroleHostSize = C.int(len(o.tiroleHost)) + } + out.TirolePort = C.int(o.TirolePort) + out.TiroleTvmId = C.uint32_t(o.TiroleTvmID) + + if o.ServiceTicketOptions != nil { + if (o.ServiceTicketOptions.selfSecret != "") { + o.ServiceTicketOptions.selfSecretB = []byte(o.ServiceTicketOptions.selfSecret) + out.SelfSecret = (*C.uchar)(&o.ServiceTicketOptions.selfSecretB[0]) + out.SelfSecretSize = C.int(len(o.ServiceTicketOptions.selfSecretB)) + } + + if (len(o.ServiceTicketOptions.dstAliases) != 0) { + out.DstAliases = (*C.uchar)(&o.ServiceTicketOptions.dstAliases[0]) + out.DstAliasesSize = C.int(len(o.ServiceTicketOptions.dstAliases)) + } + } + + if o.DiskCacheDir != "" { + o.diskCacheDir = []byte(o.DiskCacheDir) + + out.DiskCacheDir = (*C.uchar)(&o.diskCacheDir[0]) + out.DiskCacheDirSize = C.int(len(o.diskCacheDir)) + } +} + +func (o *TvmToolSettings) pack(out *C.TVM_ToolSettings) { + if o.Alias != "" { + o.alias = []byte(o.Alias) + + out.Alias = (*C.uchar)(&o.alias[0]) + out.AliasSize = C.int(len(o.alias)) + } + + out.Port = C.int(o.Port) + + if o.Hostname != "" { + o.hostname = []byte(o.Hostname) + out.Hostname = (*C.uchar)(&o.hostname[0]) + out.HostnameSize = C.int(len(o.hostname)) + } + + if o.AuthToken != "" { + o.authToken = []byte(o.AuthToken) + out.AuthToken = (*C.uchar)(&o.authToken[0]) + out.AuthTokenSize = C.int(len(o.authToken)) + } + + if o.DisableSrcCheck { + out.DisableSrcCheck = 1 + } + if o.DisableDefaultUIDCheck { + out.DisableDefaultUIDCheck = 1 + } +} + +func (o *TvmUnittestSettings) pack(out *C.TVM_UnittestSettings) { + out.SelfId = C.uint32_t(o.SelfID) + out.BlackboxEnv = C.int(o.BlackboxEnv) +} + +// Destroy stops client and delete it from memory. +// Do not try to use client after destroying it +func (c *Client) Destroy() { + if c.handle == nil { + return + } + + C.TVM_DestroyClient(c.handle) + c.handle = nil + + if c.logger != nil { + unregisterLogger(*c.logger) + } +} + +func unpackString(s *C.TVM_String) string { + if s.Data == nil { + return "" + } + + return C.GoStringN(s.Data, s.Size) +} + +func unpackErr(err *C.TVM_Error) error { + msg := unpackString(&err.Message) + code := tvm.ErrorCode(err.Code) + + if code != 0 { + return &tvm.Error{Code: code, Retriable: err.Retriable != 0, Msg: msg} + } + + return nil +} + +func unpackScopes(scopes *C.TVM_String, scopeSize C.int) (s []string) { + if scopeSize == 0 { + return + } + + s = make([]string, int(scopeSize)) + scopesArr := (*[1 << 30]C.TVM_String)(unsafe.Pointer(scopes)) + + for i := 0; i < int(scopeSize); i++ { + s[i] = C.GoStringN(scopesArr[i].Data, scopesArr[i].Size) + } + + return +} + +func unpackStatus(status C.int) error { + if status == 0 { + return nil + } + + return &tvm.TicketError{Status: tvm.TicketStatus(status)} +} + +func unpackServiceTicket(t *C.TVM_ServiceTicket) (*tvm.CheckedServiceTicket, error) { + ticket := &tvm.CheckedServiceTicket{} + ticket.SrcID = tvm.ClientID(t.SrcId) + ticket.IssuerUID = tvm.UID(t.IssuerUid) + ticket.DbgInfo = unpackString(&t.DbgInfo) + ticket.LogInfo = unpackString(&t.LogInfo) + return ticket, unpackStatus(t.Status) +} + +func unpackUserTicket(t *C.TVM_UserTicket) (*tvm.CheckedUserTicket, error) { + ticket := &tvm.CheckedUserTicket{} + ticket.DefaultUID = tvm.UID(t.DefaultUid) + if t.UidsSize != 0 { + ticket.UIDs = make([]tvm.UID, int(t.UidsSize)) + uids := (*[1 << 30]C.uint64_t)(unsafe.Pointer(t.Uids)) + for i := 0; i < int(t.UidsSize); i++ { + ticket.UIDs[i] = tvm.UID(uids[i]) + } + } + + ticket.Env = tvm.BlackboxEnv(t.Env) + + ticket.Scopes = unpackScopes(t.Scopes, t.ScopesSize) + ticket.DbgInfo = unpackString(&t.DbgInfo) + ticket.LogInfo = unpackString(&t.LogInfo) + return ticket, unpackStatus(t.Status) +} + +func unpackClientStatus(s *C.TVM_ClientStatus) (status tvm.ClientStatusInfo) { + status.Status = tvm.ClientStatus(s.Status) + status.LastError = C.GoStringN(s.LastError.Data, s.LastError.Size) + + return +} + +// NewAPIClient creates client which uses https://tvm-api.yandex.net to get state +func NewAPIClient(options TvmAPISettings, log log.Logger) (*Client, error) { + var settings C.TVM_ApiSettings + options.pack(&settings) + + client := &Client{ + mutex: &sync.RWMutex{}, + } + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + loggerId := registerLogger(log) + client.logger = &loggerId + + var tvmErr C.TVM_Error + C.TVM_NewApiClient(settings, C.int(loggerId), &client.handle, &tvmErr, &pool) + + if err := unpackErr(&tvmErr); err != nil { + unregisterLogger(loggerId) + return nil, err + } + + runtime.SetFinalizer(client, (*Client).Destroy) + return client, nil +} + +// NewToolClient creates client uses local http-interface to get state: http://localhost/tvm/. +// Details: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/. +func NewToolClient(options TvmToolSettings, log log.Logger) (*Client, error) { + var settings C.TVM_ToolSettings + options.pack(&settings) + + client := &Client{ + mutex: &sync.RWMutex{}, + } + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + loggerId := registerLogger(log) + client.logger = &loggerId + + var tvmErr C.TVM_Error + C.TVM_NewToolClient(settings, C.int(loggerId), &client.handle, &tvmErr, &pool) + + if err := unpackErr(&tvmErr); err != nil { + unregisterLogger(loggerId) + return nil, err + } + + runtime.SetFinalizer(client, (*Client).Destroy) + return client, nil +} + +// NewUnittestClient creates client with mocked state. +func NewUnittestClient(options TvmUnittestSettings) (*Client, error) { + var settings C.TVM_UnittestSettings + options.pack(&settings) + + client := &Client{ + mutex: &sync.RWMutex{}, + } + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var tvmErr C.TVM_Error + C.TVM_NewUnittestClient(settings, &client.handle, &tvmErr, &pool) + + if err := unpackErr(&tvmErr); err != nil { + return nil, err + } + + runtime.SetFinalizer(client, (*Client).Destroy) + return client, nil +} + +// CheckServiceTicket always checks ticket with keys from memory +func (c *Client) CheckServiceTicket(ctx context.Context, ticketStr string) (*tvm.CheckedServiceTicket, error) { + defer cgosem.S.Acquire().Release() + + ticketBytes := []byte(ticketStr) + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var ticket C.TVM_ServiceTicket + var tvmErr C.TVM_Error + C.TVM_CheckServiceTicket( + c.handle, + (*C.uchar)(&ticketBytes[0]), C.int(len(ticketBytes)), + &ticket, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return nil, err + } + + return unpackServiceTicket(&ticket) +} + +// CheckUserTicket always checks ticket with keys from memory +func (c *Client) CheckUserTicket(ctx context.Context, ticketStr string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + defer cgosem.S.Acquire().Release() + + var options tvm.CheckUserTicketOptions + for _, opt := range opts { + opt(&options) + } + + ticketBytes := []byte(ticketStr) + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var bbEnv *C.int + var bbEnvOverrided C.int + if options.EnvOverride != nil { + bbEnvOverrided = C.int(*options.EnvOverride) + bbEnv = &bbEnvOverrided + } + + var ticket C.TVM_UserTicket + var tvmErr C.TVM_Error + C.TVM_CheckUserTicket( + c.handle, + (*C.uchar)(&ticketBytes[0]), C.int(len(ticketBytes)), + bbEnv, + &ticket, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return nil, err + } + + return unpackUserTicket(&ticket) +} + +// GetServiceTicketForAlias always returns ticket from memory +func (c *Client) GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) { + defer cgosem.S.Acquire().Release() + + aliasBytes := []byte(alias) + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var ticket *C.char + var tvmErr C.TVM_Error + C.TVM_GetServiceTicketForAlias( + c.handle, + (*C.uchar)(&aliasBytes[0]), C.int(len(aliasBytes)), + &ticket, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return "", err + } + + return C.GoString(ticket), nil +} + +// GetServiceTicketForID always returns ticket from memory +func (c *Client) GetServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (string, error) { + defer cgosem.S.Acquire().Release() + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var ticket *C.char + var tvmErr C.TVM_Error + C.TVM_GetServiceTicket( + c.handle, + C.uint32_t(dstID), + &ticket, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return "", err + } + + return C.GoString(ticket), nil +} + +// GetStatus returns current status of client. +// See detials: https://godoc.yandex-team.ru/pkg/a.yandex-team.ru/library/go/yandex/tvm/#Client +func (c *Client) GetStatus(ctx context.Context) (tvm.ClientStatusInfo, error) { + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + var status C.TVM_ClientStatus + var tvmErr C.TVM_Error + C.TVM_GetStatus(c.handle, &status, &tvmErr, &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return tvm.ClientStatusInfo{}, err + } + + return unpackClientStatus(&status), nil +} + +func (c *Client) GetRoles(ctx context.Context) (*tvm.Roles, error) { + defer cgosem.S.Acquire().Release() + + var pool C.TVM_MemPool + defer C.TVM_DestroyMemPool(&pool) + + currentRoles := c.getCurrentRoles() + var currentRevision []byte + var currentRevisionPtr *C.uchar + if currentRoles != nil { + currentRevision = []byte(currentRoles.GetMeta().Revision) + currentRevisionPtr = (*C.uchar)(¤tRevision[0]) + } + + var raw *C.char + var rawSize C.int + var tvmErr C.TVM_Error + C.TVM_GetRoles( + c.handle, + currentRevisionPtr, C.int(len(currentRevision)), + &raw, + &rawSize, + &tvmErr, + &pool) + runtime.KeepAlive(c) + + if err := unpackErr(&tvmErr); err != nil { + return nil, err + } + if raw == nil { + return currentRoles, nil + } + + c.mutex.Lock() + defer c.mutex.Unlock() + + if currentRoles != c.roles { + return c.roles, nil + } + + roles, err := tvm.NewRoles(C.GoBytes(unsafe.Pointer(raw), rawSize)) + if err != nil { + return nil, err + } + + c.roles = roles + return c.roles, nil +} + +func (c *Client) getCurrentRoles() *tvm.Roles { + c.mutex.RLock() + defer c.mutex.RUnlock() + return c.roles +} diff --git a/library/go/yandex/tvm/tvmauth/client_example_test.go b/library/go/yandex/tvm/tvmauth/client_example_test.go new file mode 100644 index 0000000000..babf8d51b1 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/client_example_test.go @@ -0,0 +1,182 @@ +package tvmauth_test + +import ( + "context" + "fmt" + + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmauth" +) + +func ExampleNewAPIClient_getServiceTicketsWithAliases() { + blackboxAlias := "blackbox" + datasyncAlias := "datasync" + + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + ServiceTicketOptions: tvmauth.NewAliasesOptions( + "bAicxJVa5uVY7MjDlapthw", + map[string]tvm.ClientID{ + blackboxAlias: 1000502, + datasyncAlias: 1000503, + }), + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.GetServiceTicketForAlias(context.Background(), blackboxAlias) + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) +} + +func ExampleNewAPIClient_getServiceTicketsWithID() { + blackboxID := tvm.ClientID(1000502) + datasyncID := tvm.ClientID(1000503) + + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + ServiceTicketOptions: tvmauth.NewIDsOptions( + "bAicxJVa5uVY7MjDlapthw", + []tvm.ClientID{ + blackboxID, + datasyncID, + }), + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.GetServiceTicketForID(context.Background(), blackboxID) + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) +} + +func ExampleNewAPIClient_checkServiceTicket() { + // allowed tvm consumers for your service + acl := map[tvm.ClientID]interface{}{} + + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + EnableServiceTicketChecking: true, + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + serviceTicketFromRequest := "kek" + + serviceTicketStruct, err := c.CheckServiceTicket(context.Background(), serviceTicketFromRequest) + if err != nil { + response := map[string]string{ + "error": "service ticket is invalid", + "desc": err.Error(), + "status": err.(*tvm.TicketError).Status.String(), + } + if serviceTicketStruct != nil { + response["debug_info"] = serviceTicketStruct.DbgInfo + } + panic(response) // return 403 + } + if _, ok := acl[serviceTicketStruct.SrcID]; !ok { + response := map[string]string{ + "error": fmt.Sprintf("tvm client id is not allowed: %d", serviceTicketStruct.SrcID), + } + panic(response) // return 403 + } + + // proceed... +} + +func ExampleNewAPIClient_checkUserTicket() { + env := tvm.BlackboxTest + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + EnableServiceTicketChecking: true, + BlackboxEnv: &env, + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + c, err := tvmauth.NewAPIClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + serviceTicketFromRequest := "kek" + userTicketFromRequest := "lol" + + _, _ = c.CheckServiceTicket(context.Background(), serviceTicketFromRequest) // See example for this method + + userTicketStruct, err := c.CheckUserTicket(context.Background(), userTicketFromRequest) + if err != nil { + response := map[string]string{ + "error": "user ticket is invalid", + "desc": err.Error(), + "status": err.(*tvm.TicketError).Status.String(), + } + if userTicketStruct != nil { + response["debug_info"] = userTicketStruct.DbgInfo + } + panic(response) // return 403 + } + + fmt.Printf("Got user in request: %d", userTicketStruct.DefaultUID) + // proceed... +} + +func ExampleNewAPIClient_createClientWithAllSettings() { + blackboxAlias := "blackbox" + datasyncAlias := "datasync" + + env := tvm.BlackboxTest + settings := tvmauth.TvmAPISettings{ + SelfID: 1000501, + ServiceTicketOptions: tvmauth.NewAliasesOptions( + "bAicxJVa5uVY7MjDlapthw", + map[string]tvm.ClientID{ + blackboxAlias: 1000502, + datasyncAlias: 1000503, + }), + EnableServiceTicketChecking: true, + BlackboxEnv: &env, + DiskCacheDir: "/var/tmp/cache/tvm/", + } + + _, _ = tvmauth.NewAPIClient(settings, &nop.Logger{}) +} + +func ExampleNewToolClient_getServiceTicketsWithAliases() { + // should be configured in tvmtool + blackboxAlias := "blackbox" + + settings := tvmauth.TvmToolSettings{ + Alias: "my_service", + Port: 18000, + AuthToken: "kek", + } + + c, err := tvmauth.NewToolClient(settings, &nop.Logger{}) + if err != nil { + panic(err) + } + + // ... + + serviceTicket, _ := c.GetServiceTicketForAlias(context.Background(), blackboxAlias) + fmt.Printf("Service ticket for visiting backend: %s", serviceTicket) + // please extrapolate other methods for this way of construction +} diff --git a/library/go/yandex/tvm/tvmauth/doc.go b/library/go/yandex/tvm/tvmauth/doc.go new file mode 100644 index 0000000000..ece7efd3ba --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/doc.go @@ -0,0 +1,10 @@ +// CGO implementation of tvm-interface based on ticket_parser2. +// +// Package allows you to get service/user TVM-tickets, as well as check them. +// This package provides client via tvm-api or tvmtool. +// Also this package provides the most efficient way for checking tickets regardless of the client construction way. +// All scenerios are provided without any request after construction. +// +// You should create client with NewAPIClient() or NewToolClient(). +// Also you need to check status of client with GetStatus(). +package tvmauth diff --git a/library/go/yandex/tvm/tvmauth/logger.go b/library/go/yandex/tvm/tvmauth/logger.go new file mode 100644 index 0000000000..3731b16b65 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/logger.go @@ -0,0 +1,77 @@ +//go:build cgo +// +build cgo + +package tvmauth + +import "C" +import ( + "fmt" + "sync" + + "a.yandex-team.ru/library/go/core/log" +) + +// CGO pointer rules state: +// +// Go code may pass a Go pointer to C provided the Go memory to which it points **does not contain any Go pointers**. +// +// Logger is an interface and contains pointer to implementation. That means, we are forbidden from +// passing Logger to C code. +// +// Instead, we put logger into a global map and pass key to the C code. +// +// This might seem inefficient, but we are not concerned with performance here, since the logger is not on the hot path anyway. + +var ( + loggersLock sync.Mutex + nextSlot int + loggers = map[int]log.Logger{} +) + +func registerLogger(l log.Logger) int { + loggersLock.Lock() + defer loggersLock.Unlock() + + i := nextSlot + nextSlot++ + loggers[i] = l + return i +} + +func unregisterLogger(i int) { + loggersLock.Lock() + defer loggersLock.Unlock() + + if _, ok := loggers[i]; !ok { + panic(fmt.Sprintf("attempt to unregister unknown logger %d", i)) + } + + delete(loggers, i) +} + +func findLogger(i int) log.Logger { + loggersLock.Lock() + defer loggersLock.Unlock() + + return loggers[i] +} + +//export TVM_WriteToLog +// +// TVM_WriteToLog is technical artifact +func TVM_WriteToLog(logger int, level int, msgData *C.char, msgSize C.int) { + l := findLogger(logger) + + msg := C.GoStringN(msgData, msgSize) + + switch level { + case 3: + l.Error(msg) + case 4: + l.Warn(msg) + case 6: + l.Info(msg) + default: + l.Debug(msg) + } +} diff --git a/library/go/yandex/tvm/tvmauth/tiroletest/client_test.go b/library/go/yandex/tvm/tvmauth/tiroletest/client_test.go new file mode 100644 index 0000000000..37e467e286 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tiroletest/client_test.go @@ -0,0 +1,338 @@ +package tiroletest + +import ( + "context" + "io/ioutil" + "strconv" + "testing" + + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmauth" +) + +func getPort(t *testing.T, filename string) int { + body, err := ioutil.ReadFile(filename) + require.NoError(t, err) + + res, err := strconv.Atoi(string(body)) + require.NoError(t, err, "port is invalid: ", filename) + + return res +} + +func createClientWithTirole(t *testing.T, disableSrcCheck bool, disableDefaultUIDCheck bool) *tvmauth.Client { + env := tvm.BlackboxProdYateam + client, err := tvmauth.NewAPIClient( + tvmauth.TvmAPISettings{ + SelfID: 1000502, + ServiceTicketOptions: tvmauth.NewIDsOptions("e5kL0vM3nP-nPf-388Hi6Q", nil), + DiskCacheDir: "./", + FetchRolesForIdmSystemSlug: "some_slug_2", + EnableServiceTicketChecking: true, + DisableSrcCheck: disableSrcCheck, + DisableDefaultUIDCheck: disableDefaultUIDCheck, + BlackboxEnv: &env, + TVMHost: "http://localhost", + TVMPort: getPort(t, "tvmapi.port"), + TiroleHost: "http://localhost", + TirolePort: getPort(t, "tirole.port"), + TiroleTvmID: 1000001, + }, + &nop.Logger{}, + ) + require.NoError(t, err) + + return client +} + +func createClientWithTvmtool(t *testing.T, disableSrcCheck bool, disableDefaultUIDCheck bool) *tvmauth.Client { + token, err := ioutil.ReadFile("tvmtool.authtoken") + require.NoError(t, err) + + client, err := tvmauth.NewToolClient( + tvmauth.TvmToolSettings{ + Alias: "me", + AuthToken: string(token), + DisableSrcCheck: disableSrcCheck, + DisableDefaultUIDCheck: disableDefaultUIDCheck, + Port: getPort(t, "tvmtool.port"), + }, + &nop.Logger{}, + ) + require.NoError(t, err) + + return client +} + +func checkServiceNoRoles(t *testing.T, clientsWithAutoCheck, clientsWithoutAutoCheck []tvm.Client) { + // src=1000000000: tvmknife unittest service -s 1000000000 -d 1000502 + stWithoutRoles := "3:serv:CBAQ__________9_IgoIgJTr3AMQtog9:Sv3SKuDQ4p-2419PKqc1vo9EC128K6Iv7LKck5SyliJZn5gTAqMDAwb9aYWHhf49HTR-Qmsjw4i_Lh-sNhge-JHWi5PTGFJm03CZHOCJG9Y0_G1pcgTfodtAsvDykMxLhiXGB4N84cGhVVqn1pFWz6SPmMeKUPulTt7qH1ifVtQ" + + ctx := context.Background() + + for _, cl := range clientsWithAutoCheck { + _, err := cl.CheckServiceTicket(ctx, stWithoutRoles) + require.EqualValues(t, + &tvm.TicketError{Status: tvm.TicketNoRoles}, + err, + ) + } + + for _, cl := range clientsWithoutAutoCheck { + st, err := cl.CheckServiceTicket(ctx, stWithoutRoles) + require.NoError(t, err) + + roles, err := cl.GetRoles(ctx) + require.NoError(t, err) + + res := roles.GetRolesForService(st) + require.Nil(t, res) + } +} + +func checkServiceHasRoles(t *testing.T, clientsWithAutoCheck, clientsWithoutAutoCheck []tvm.Client) { + // src=1000000001: tvmknife unittest service -s 1000000001 -d 1000502 + stWithRoles := "3:serv:CBAQ__________9_IgoIgZTr3AMQtog9:EyPympmoLBM6jyiQLcK8ummNmL5IUAdTvKM1do8ppuEgY6yHfto3s_WAKmP9Pf9EiNqPBe18HR7yKmVS7gvdFJY4gP4Ut51ejS-iBPlsbsApJOYTgodQPhkmjHVKIT0ub0pT3fWHQtapb8uimKpGcO6jCfopFQSVG04Ehj7a0jw" + + ctx := context.Background() + + check := func(cl tvm.Client) { + checked, err := cl.CheckServiceTicket(ctx, stWithRoles) + require.NoError(t, err) + + clientRoles, err := cl.GetRoles(ctx) + require.NoError(t, err) + + require.EqualValues(t, + `{ + "/role/service/read/": [], + "/role/service/write/": [ + { + "foo": "bar", + "kek": "lol" + } + ] +}`, + clientRoles.GetRolesForService(checked).DebugPrint(), + ) + + require.True(t, clientRoles.CheckServiceRole(checked, "/role/service/read/", nil)) + require.True(t, clientRoles.CheckServiceRole(checked, "/role/service/write/", nil)) + require.False(t, clientRoles.CheckServiceRole(checked, "/role/foo/", nil)) + + require.False(t, clientRoles.CheckServiceRole(checked, "/role/service/read/", &tvm.CheckServiceOptions{ + Entity: tvm.Entity{"foo": "bar", "kek": "lol"}, + })) + require.False(t, clientRoles.CheckServiceRole(checked, "/role/service/write/", &tvm.CheckServiceOptions{ + Entity: tvm.Entity{"kek": "lol"}, + })) + require.True(t, clientRoles.CheckServiceRole(checked, "/role/service/write/", &tvm.CheckServiceOptions{ + Entity: tvm.Entity{"foo": "bar", "kek": "lol"}, + })) + } + + for _, cl := range clientsWithAutoCheck { + check(cl) + } + for _, cl := range clientsWithoutAutoCheck { + check(cl) + } +} + +func checkUserNoRoles(t *testing.T, clientsWithAutoCheck, clientsWithoutAutoCheck []tvm.Client) { + // default_uid=1000000000: tvmknife unittest user -d 1000000000 --env prod_yateam + utWithoutRoles := "3:user:CAwQ__________9_GhYKBgiAlOvcAxCAlOvcAyDShdjMBCgC:LloRDlCZ4vd0IUTOj6MD1mxBPgGhS6EevnnWvHgyXmxc--2CVVkAtNKNZJqCJ6GtDY4nknEnYmWvEu6-MInibD-Uk6saI1DN-2Y3C1Wdsz2SJCq2OYgaqQsrM5PagdyP9PLrftkuV_ZluS_FUYebMXPzjJb0L0ALKByMPkCVWuk" + + ctx := context.Background() + + for _, cl := range clientsWithAutoCheck { + _, err := cl.CheckUserTicket(ctx, utWithoutRoles) + require.EqualValues(t, + &tvm.TicketError{Status: tvm.TicketNoRoles}, + err, + ) + } + + for _, cl := range clientsWithoutAutoCheck { + ut, err := cl.CheckUserTicket(ctx, utWithoutRoles) + require.NoError(t, err) + + roles, err := cl.GetRoles(ctx) + require.NoError(t, err) + + res, err := roles.GetRolesForUser(ut, nil) + require.NoError(t, err) + require.Nil(t, res) + } +} + +func checkUserHasRoles(t *testing.T, clientsWithAutoCheck, clientsWithoutAutoCheck []tvm.Client) { + // default_uid=1120000000000001: tvmknife unittest user -d 1120000000000001 --env prod_yateam + utWithRoles := "3:user:CAwQ__________9_GhwKCQiBgJiRpdT-ARCBgJiRpdT-ASDShdjMBCgC:SQV7Z9hDpZ_F62XGkSF6yr8PoZHezRp0ZxCINf_iAbT2rlEiO6j4UfLjzwn3EnRXkAOJxuAtTDCnHlrzdh3JgSKK7gciwPstdRT5GGTixBoUU9kI_UlxEbfGBX1DfuDsw_GFQ2eCLu4Svq6jC3ynuqQ41D2RKopYL8Bx8PDZKQc" + + ctx := context.Background() + + check := func(cl tvm.Client) { + checked, err := cl.CheckUserTicket(ctx, utWithRoles) + require.NoError(t, err) + + clientRoles, err := cl.GetRoles(ctx) + require.NoError(t, err) + + ut, err := clientRoles.GetRolesForUser(checked, nil) + require.NoError(t, err) + require.EqualValues(t, + `{ + "/role/user/read/": [ + { + "foo": "bar", + "kek": "lol" + } + ], + "/role/user/write/": [] +}`, + ut.DebugPrint(), + ) + + res, err := clientRoles.CheckUserRole(checked, "/role/user/write/", nil) + require.NoError(t, err) + require.True(t, res) + res, err = clientRoles.CheckUserRole(checked, "/role/user/read/", nil) + require.NoError(t, err) + require.True(t, res) + res, err = clientRoles.CheckUserRole(checked, "/role/foo/", nil) + require.NoError(t, err) + require.False(t, res) + + res, err = clientRoles.CheckUserRole(checked, "/role/user/write/", &tvm.CheckUserOptions{ + Entity: tvm.Entity{"foo": "bar", "kek": "lol"}, + }) + require.NoError(t, err) + require.False(t, res) + res, err = clientRoles.CheckUserRole(checked, "/role/user/read/", &tvm.CheckUserOptions{ + Entity: tvm.Entity{"kek": "lol"}, + }) + require.NoError(t, err) + require.False(t, res) + res, err = clientRoles.CheckUserRole(checked, "/role/user/read/", &tvm.CheckUserOptions{ + Entity: tvm.Entity{"foo": "bar", "kek": "lol"}, + }) + require.NoError(t, err) + require.True(t, res) + } + + for _, cl := range clientsWithAutoCheck { + check(cl) + } + for _, cl := range clientsWithoutAutoCheck { + check(cl) + } + +} + +func TestRolesFromTiroleCheckSrc_noRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTirole(t, false, true) + clientWithoutAutoCheck := createClientWithTirole(t, true, true) + + checkServiceNoRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTiroleCheckSrc_HasRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTirole(t, false, true) + clientWithoutAutoCheck := createClientWithTirole(t, true, true) + + checkServiceHasRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTiroleCheckDefaultUid_noRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTirole(t, true, false) + clientWithoutAutoCheck := createClientWithTirole(t, true, true) + + checkUserNoRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTiroleCheckDefaultUid_HasRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTirole(t, true, false) + clientWithoutAutoCheck := createClientWithTirole(t, true, true) + + checkUserHasRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTvmtoolCheckSrc_noRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTvmtool(t, false, true) + clientWithoutAutoCheck := createClientWithTvmtool(t, true, true) + + checkServiceNoRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTvmtoolCheckSrc_HasRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTvmtool(t, false, true) + clientWithoutAutoCheck := createClientWithTvmtool(t, true, true) + + checkServiceHasRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTvmtoolCheckDefaultUid_noRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTvmtool(t, true, false) + clientWithoutAutoCheck := createClientWithTvmtool(t, true, true) + + checkUserNoRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} + +func TestRolesFromTvmtoolCheckDefaultUid_HasRoles(t *testing.T) { + clientWithAutoCheck := createClientWithTvmtool(t, true, false) + clientWithoutAutoCheck := createClientWithTvmtool(t, true, true) + + checkUserHasRoles(t, + []tvm.Client{clientWithAutoCheck}, + []tvm.Client{clientWithoutAutoCheck}, + ) + + clientWithAutoCheck.Destroy() + clientWithoutAutoCheck.Destroy() +} diff --git a/library/go/yandex/tvm/tvmauth/tiroletest/roles/mapping.yaml b/library/go/yandex/tvm/tvmauth/tiroletest/roles/mapping.yaml new file mode 100644 index 0000000000..d2fcaead59 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tiroletest/roles/mapping.yaml @@ -0,0 +1,5 @@ +slugs: + some_slug_2: + tvmid: + - 1000502 + - 1000503 diff --git a/library/go/yandex/tvm/tvmauth/tiroletest/roles/some_slug_2.json b/library/go/yandex/tvm/tvmauth/tiroletest/roles/some_slug_2.json new file mode 100644 index 0000000000..84d85fae19 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tiroletest/roles/some_slug_2.json @@ -0,0 +1,22 @@ +{ + "revision": "some_revision_2", + "born_date": 1642160002, + "tvm": { + "1000000001": { + "/role/service/read/": [{}], + "/role/service/write/": [{ + "foo": "bar", + "kek": "lol" + }] + } + }, + "user": { + "1120000000000001": { + "/role/user/write/": [{}], + "/role/user/read/": [{ + "foo": "bar", + "kek": "lol" + }] + } + } +} diff --git a/library/go/yandex/tvm/tvmauth/tiroletest/tvmtool.cfg b/library/go/yandex/tvm/tvmauth/tiroletest/tvmtool.cfg new file mode 100644 index 0000000000..dbb8fcd458 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tiroletest/tvmtool.cfg @@ -0,0 +1,10 @@ +{ + "BbEnvType": 2, + "clients": { + "me": { + "secret": "fake_secret", + "self_tvm_id": 1000502, + "roles_for_idm_slug": "some_slug_2" + } + } +} diff --git a/library/go/yandex/tvm/tvmauth/tooltest/.arcignore b/library/go/yandex/tvm/tvmauth/tooltest/.arcignore new file mode 100644 index 0000000000..251ded04a5 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tooltest/.arcignore @@ -0,0 +1 @@ +tooltest diff --git a/library/go/yandex/tvm/tvmauth/tooltest/client_test.go b/library/go/yandex/tvm/tvmauth/tooltest/client_test.go new file mode 100644 index 0000000000..a8d68e55ee --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tooltest/client_test.go @@ -0,0 +1,57 @@ +package tooltest + +import ( + "context" + "io/ioutil" + "strconv" + "testing" + + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmauth" +) + +func recipeToolOptions(t *testing.T) tvmauth.TvmToolSettings { + var portStr, token []byte + portStr, err := ioutil.ReadFile("tvmtool.port") + require.NoError(t, err) + + var port int + port, err = strconv.Atoi(string(portStr)) + require.NoError(t, err) + + token, err = ioutil.ReadFile("tvmtool.authtoken") + require.NoError(t, err) + + return tvmauth.TvmToolSettings{Alias: "me", Port: port, AuthToken: string(token)} +} + +func TestToolClient(t *testing.T) { + c, err := tvmauth.NewToolClient(recipeToolOptions(t), &nop.Logger{}) + require.NoError(t, err) + defer c.Destroy() + + t.Run("GetServiceTicketForID", func(t *testing.T) { + _, err := c.GetServiceTicketForID(context.Background(), 100500) + require.NoError(t, err) + }) + + t.Run("GetInvalidTicket", func(t *testing.T) { + _, err := c.GetServiceTicketForID(context.Background(), 100999) + require.Error(t, err) + require.IsType(t, &tvm.Error{}, err) + require.Equal(t, tvm.ErrorBrokenTvmClientSettings, err.(*tvm.Error).Code) + }) + + t.Run("ClientStatus", func(t *testing.T) { + status, err := c.GetStatus(context.Background()) + require.NoError(t, err) + + t.Logf("Got client status: %v", status) + + require.Equal(t, tvm.ClientStatus(0), status.Status) + require.Equal(t, "OK", status.LastError) + }) +} diff --git a/library/go/yandex/tvm/tvmauth/tooltest/logger_test.go b/library/go/yandex/tvm/tvmauth/tooltest/logger_test.go new file mode 100644 index 0000000000..99e6a5835e --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tooltest/logger_test.go @@ -0,0 +1,34 @@ +package tooltest + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/yandex/tvm/tvmauth" +) + +type testLogger struct { + nop.Logger + + msgs []string +} + +func (l *testLogger) Info(msg string, fields ...log.Field) { + l.msgs = append(l.msgs, msg) +} + +func TestLogger(t *testing.T) { + var l testLogger + + c, err := tvmauth.NewToolClient(recipeToolOptions(t), &l) + require.NoError(t, err) + defer c.Destroy() + + time.Sleep(time.Second) + + require.NotEmpty(t, l.msgs) +} diff --git a/library/go/yandex/tvm/tvmauth/tvm.cpp b/library/go/yandex/tvm/tvmauth/tvm.cpp new file mode 100644 index 0000000000..b3d2070df0 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tvm.cpp @@ -0,0 +1,417 @@ +#include "tvm.h" + +#include "_cgo_export.h" + +#include <library/cpp/json/json_reader.h> +#include <library/cpp/tvmauth/client/facade.h> +#include <library/cpp/tvmauth/client/logger.h> +#include <library/cpp/tvmauth/client/mocked_updater.h> +#include <library/cpp/tvmauth/client/misc/utils.h> +#include <library/cpp/tvmauth/client/misc/api/settings.h> +#include <library/cpp/tvmauth/client/misc/roles/roles.h> + +using namespace NTvmAuth; + +void TVM_DestroyMemPool(TVM_MemPool* pool) { + auto freeStr = [](char*& str) { + if (str != nullptr) { + free(str); + str = nullptr; + } + }; + + freeStr(pool->ErrorStr); + + if (pool->Scopes != nullptr) { + free(reinterpret_cast<void*>(pool->Scopes)); + pool->Scopes = nullptr; + } + + if (pool->TicketStr != nullptr) { + delete reinterpret_cast<TString*>(pool->TicketStr); + pool->TicketStr = nullptr; + } + if (pool->RawRolesStr != nullptr) { + delete reinterpret_cast<TString*>(pool->RawRolesStr); + pool->RawRolesStr = nullptr; + } + + if (pool->CheckedUserTicket != nullptr) { + delete reinterpret_cast<TCheckedUserTicket*>(pool->CheckedUserTicket); + pool->CheckedUserTicket = nullptr; + } + + if (pool->CheckedServiceTicket != nullptr) { + delete reinterpret_cast<TCheckedServiceTicket*>(pool->CheckedServiceTicket); + pool->CheckedServiceTicket = nullptr; + } + + freeStr(pool->DbgInfo); + freeStr(pool->LogInfo); + freeStr(pool->LastError.Data); +} + +static void PackStr(TStringBuf in, TVM_String* out, char*& poolStr) noexcept { + out->Data = poolStr = reinterpret_cast<char*>(malloc(in.size())); + out->Size = in.size(); + memcpy(out->Data, in.data(), in.size()); +} + +static void UnpackSettings( + TVM_ApiSettings* in, + NTvmApi::TClientSettings* out) { + if (in->SelfId != 0) { + out->SelfTvmId = in->SelfId; + } + + if (in->EnableServiceTicketChecking != 0) { + out->CheckServiceTickets = true; + } + + if (in->EnableUserTicketChecking != 0) { + out->CheckUserTicketsWithBbEnv = static_cast<EBlackboxEnv>(in->BlackboxEnv); + } + + if (in->SelfSecret != nullptr) { + out->Secret = TString(reinterpret_cast<char*>(in->SelfSecret), in->SelfSecretSize); + } + + TStringBuf aliases(reinterpret_cast<char*>(in->DstAliases), in->DstAliasesSize); + if (aliases) { + NJson::TJsonValue doc; + Y_ENSURE(NJson::ReadJsonTree(aliases, &doc), "Invalid json: from go part: " << aliases); + Y_ENSURE(doc.IsMap(), "Dsts is not map: from go part: " << aliases); + + for (const auto& pair : doc.GetMap()) { + Y_ENSURE(pair.second.IsUInteger(), "dstID must be number"); + out->FetchServiceTicketsForDstsWithAliases.emplace(pair.first, pair.second.GetUInteger()); + } + } + + if (in->IdmSystemSlug != nullptr) { + out->FetchRolesForIdmSystemSlug = TString(reinterpret_cast<char*>(in->IdmSystemSlug), in->IdmSystemSlugSize); + out->ShouldCheckSrc = in->DisableSrcCheck == 0; + out->ShouldCheckDefaultUid = in->DisableDefaultUIDCheck == 0; + } + + if (in->TVMHost != nullptr) { + out->TvmHost = TString(reinterpret_cast<char*>(in->TVMHost), in->TVMHostSize); + out->TvmPort = in->TVMPort; + } + if (in->TiroleHost != nullptr) { + out->TiroleHost = TString(reinterpret_cast<char*>(in->TiroleHost), in->TiroleHostSize); + out->TirolePort = in->TirolePort; + } + if (in->TiroleTvmId != 0) { + out->TiroleTvmId = in->TiroleTvmId; + } + + if (in->DiskCacheDir != nullptr) { + out->DiskCacheDir = TString(reinterpret_cast<char*>(in->DiskCacheDir), in->DiskCacheDirSize); + } +} + +static void UnpackSettings( + TVM_ToolSettings* in, + NTvmTool::TClientSettings* out) { + if (in->Port != 0) { + out->SetPort(in->Port); + } + + if (in->HostnameSize != 0) { + out->SetHostname(TString(reinterpret_cast<char*>(in->Hostname), in->HostnameSize)); + } + + if (in->AuthTokenSize != 0) { + out->SetAuthToken(TString(reinterpret_cast<char*>(in->AuthToken), in->AuthTokenSize)); + } + + out->ShouldCheckSrc = in->DisableSrcCheck == 0; + out->ShouldCheckDefaultUid = in->DisableDefaultUIDCheck == 0; +} + +static void UnpackSettings( + TVM_UnittestSettings* in, + TMockedUpdater::TSettings* out) { + out->SelfTvmId = in->SelfId; + out->UserTicketEnv = static_cast<EBlackboxEnv>(in->BlackboxEnv); +} + +template <class TTicket> +static void PackScopes( + const TScopes& scopes, + TTicket* ticket, + TVM_MemPool* pool) { + if (scopes.empty()) { + return; + } + + pool->Scopes = ticket->Scopes = reinterpret_cast<TVM_String*>(malloc(scopes.size() * sizeof(TVM_String))); + + for (size_t i = 0; i < scopes.size(); i++) { + ticket->Scopes[i].Data = const_cast<char*>(scopes[i].data()); + ticket->Scopes[i].Size = scopes[i].size(); + } + ticket->ScopesSize = scopes.size(); +} + +static void PackUserTicket( + TCheckedUserTicket in, + TVM_UserTicket* out, + TVM_MemPool* pool, + TStringBuf originalStr) noexcept { + auto copy = new TCheckedUserTicket(std::move(in)); + pool->CheckedUserTicket = reinterpret_cast<void*>(copy); + + PackStr(copy->DebugInfo(), &out->DbgInfo, pool->DbgInfo); + PackStr(NUtils::RemoveTicketSignature(originalStr), &out->LogInfo, pool->LogInfo); + + out->Status = static_cast<int>(copy->GetStatus()); + if (out->Status != static_cast<int>(ETicketStatus::Ok)) { + return; + } + + out->DefaultUid = copy->GetDefaultUid(); + + const auto& uids = copy->GetUids(); + if (!uids.empty()) { + out->Uids = const_cast<TUid*>(uids.data()); + out->UidsSize = uids.size(); + } + + out->Env = static_cast<int>(copy->GetEnv()); + + PackScopes(copy->GetScopes(), out, pool); +} + +static void PackServiceTicket( + TCheckedServiceTicket in, + TVM_ServiceTicket* out, + TVM_MemPool* pool, + TStringBuf originalStr) noexcept { + auto copy = new TCheckedServiceTicket(std::move(in)); + pool->CheckedServiceTicket = reinterpret_cast<void*>(copy); + + PackStr(copy->DebugInfo(), &out->DbgInfo, pool->DbgInfo); + PackStr(NUtils::RemoveTicketSignature(originalStr), &out->LogInfo, pool->LogInfo); + + out->Status = static_cast<int>(copy->GetStatus()); + if (out->Status != static_cast<int>(ETicketStatus::Ok)) { + return; + } + + out->SrcId = copy->GetSrc(); + + auto issuer = copy->GetIssuerUid(); + if (issuer) { + out->IssuerUid = *issuer; + } +} + +template <class F> +static void CatchError(TVM_Error* err, TVM_MemPool* pool, const F& f) { + try { + f(); + } catch (const TMalformedTvmSecretException& ex) { + err->Code = 1; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TMalformedTvmKeysException& ex) { + err->Code = 2; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TEmptyTvmKeysException& ex) { + err->Code = 3; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TNotAllowedException& ex) { + err->Code = 4; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TBrokenTvmClientSettings& ex) { + err->Code = 5; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TMissingServiceTicket& ex) { + err->Code = 6; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TPermissionDenied& ex) { + err->Code = 7; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const TRetriableException& ex) { + err->Code = 8; + err->Retriable = 1; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } catch (const std::exception& ex) { + err->Code = 8; + PackStr(ex.what(), &err->Message, pool->ErrorStr); + } +} + +namespace { + class TGoLogger: public ILogger { + public: + TGoLogger(int loggerHandle) + : LoggerHandle_(loggerHandle) + { + } + + void Log(int lvl, const TString& msg) override { + TVM_WriteToLog(LoggerHandle_, lvl, const_cast<char*>(msg.data()), msg.size()); + } + + private: + int LoggerHandle_; + }; + +} + +extern "C" void TVM_NewApiClient( + TVM_ApiSettings settings, + int loggerHandle, + void** handle, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + NTvmApi::TClientSettings realSettings; + UnpackSettings(&settings, &realSettings); + + realSettings.LibVersionPrefix = "go_"; + + auto client = new TTvmClient(realSettings, MakeIntrusive<TGoLogger>(loggerHandle)); + *handle = static_cast<void*>(client); + }); +} + +extern "C" void TVM_NewToolClient( + TVM_ToolSettings settings, + int loggerHandle, + void** handle, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + TString alias(reinterpret_cast<char*>(settings.Alias), settings.AliasSize); + NTvmTool::TClientSettings realSettings(alias); + UnpackSettings(&settings, &realSettings); + + auto client = new TTvmClient(realSettings, MakeIntrusive<TGoLogger>(loggerHandle)); + *handle = static_cast<void*>(client); + }); +} + +extern "C" void TVM_NewUnittestClient( + TVM_UnittestSettings settings, + void** handle, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + TMockedUpdater::TSettings realSettings; + UnpackSettings(&settings, &realSettings); + + auto client = new TTvmClient(MakeIntrusiveConst<TMockedUpdater>(realSettings)); + *handle = static_cast<void*>(client); + }); +} + +extern "C" void TVM_DestroyClient(void* handle) { + delete static_cast<TTvmClient*>(handle); +} + +extern "C" void TVM_GetStatus( + void* handle, + TVM_ClientStatus* status, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + + TClientStatus s = client->GetStatus(); + status->Status = static_cast<int>(s.GetCode()); + + PackStr(s.GetLastError(), &status->LastError, pool->LastError.Data); + }); +} + +extern "C" void TVM_CheckUserTicket( + void* handle, + unsigned char* ticketStr, int ticketSize, + int* env, + TVM_UserTicket* ticket, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + TStringBuf str(reinterpret_cast<char*>(ticketStr), ticketSize); + + TMaybe<EBlackboxEnv> optEnv; + if (env) { + optEnv = (EBlackboxEnv)*env; + } + + auto userTicket = client->CheckUserTicket(str, optEnv); + PackUserTicket(std::move(userTicket), ticket, pool, str); + }); +} + +extern "C" void TVM_CheckServiceTicket( + void* handle, + unsigned char* ticketStr, int ticketSize, + TVM_ServiceTicket* ticket, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + TStringBuf str(reinterpret_cast<char*>(ticketStr), ticketSize); + auto serviceTicket = client->CheckServiceTicket(str); + PackServiceTicket(std::move(serviceTicket), ticket, pool, str); + }); +} + +extern "C" void TVM_GetServiceTicket( + void* handle, + ui32 dstId, + char** ticket, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + auto ticketPtr = new TString(client->GetServiceTicketFor(dstId)); + + pool->TicketStr = reinterpret_cast<void*>(ticketPtr); + *ticket = const_cast<char*>(ticketPtr->c_str()); + }); +} + +extern "C" void TVM_GetServiceTicketForAlias( + void* handle, + unsigned char* alias, int aliasSize, + char** ticket, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + auto ticketPtr = new TString(client->GetServiceTicketFor(TString((char*)alias, aliasSize))); + + pool->TicketStr = reinterpret_cast<void*>(ticketPtr); + *ticket = const_cast<char*>(ticketPtr->c_str()); + }); +} + +extern "C" void TVM_GetRoles( + void* handle, + unsigned char* currentRevision, int currentRevisionSize, + char** raw, + int* rawSize, + TVM_Error* err, + TVM_MemPool* pool) { + CatchError(err, pool, [&] { + auto client = static_cast<TTvmClient*>(handle); + NTvmAuth::NRoles::TRolesPtr roles = client->GetRoles(); + + if (currentRevision && + roles->GetMeta().Revision == TStringBuf(reinterpret_cast<char*>(currentRevision), currentRevisionSize)) { + return; + } + + auto rawPtr = new TString(roles->GetRaw()); + + pool->RawRolesStr = reinterpret_cast<void*>(rawPtr); + *raw = const_cast<char*>(rawPtr->c_str()); + *rawSize = rawPtr->size(); + }); +} diff --git a/library/go/yandex/tvm/tvmauth/tvm.h b/library/go/yandex/tvm/tvmauth/tvm.h new file mode 100644 index 0000000000..f7c7a5b2bc --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/tvm.h @@ -0,0 +1,192 @@ +#pragma once + +#include <util/system/types.h> + +#include <stdint.h> +#include <time.h> + +#ifdef __cplusplus +extern "C" { +#endif + + typedef struct _TVM_String { + char* Data; + int Size; + } TVM_String; + + // MemPool owns memory allocated by C. + typedef struct { + char* ErrorStr; + void* TicketStr; + void* RawRolesStr; + TVM_String* Scopes; + void* CheckedUserTicket; + void* CheckedServiceTicket; + char* DbgInfo; + char* LogInfo; + TVM_String LastError; + } TVM_MemPool; + + void TVM_DestroyMemPool(TVM_MemPool* pool); + + typedef struct { + int Code; + int Retriable; + + TVM_String Message; + } TVM_Error; + + typedef struct { + int Status; + + ui64 DefaultUid; + + ui64* Uids; + int UidsSize; + + int Env; + + TVM_String* Scopes; + int ScopesSize; + + TVM_String DbgInfo; + TVM_String LogInfo; + } TVM_UserTicket; + + typedef struct { + int Status; + + ui32 SrcId; + + ui64 IssuerUid; + + TVM_String DbgInfo; + TVM_String LogInfo; + } TVM_ServiceTicket; + + typedef struct { + ui32 SelfId; + + int EnableServiceTicketChecking; + + int EnableUserTicketChecking; + int BlackboxEnv; + + unsigned char* SelfSecret; + int SelfSecretSize; + unsigned char* DstAliases; + int DstAliasesSize; + + unsigned char* IdmSystemSlug; + int IdmSystemSlugSize; + int DisableSrcCheck; + int DisableDefaultUIDCheck; + + unsigned char* TVMHost; + int TVMHostSize; + int TVMPort; + unsigned char* TiroleHost; + int TiroleHostSize; + int TirolePort; + ui32 TiroleTvmId; + + unsigned char* DiskCacheDir; + int DiskCacheDirSize; + } TVM_ApiSettings; + + typedef struct { + unsigned char* Alias; + int AliasSize; + + int Port; + + unsigned char* Hostname; + int HostnameSize; + + unsigned char* AuthToken; + int AuthTokenSize; + + int DisableSrcCheck; + int DisableDefaultUIDCheck; + } TVM_ToolSettings; + + typedef struct { + ui32 SelfId; + int BlackboxEnv; + } TVM_UnittestSettings; + + typedef struct { + int Status; + TVM_String LastError; + } TVM_ClientStatus; + + // First argument must be passed by value. "Go code may pass a Go pointer to C + // provided the Go memory to which it points does not contain any Go pointers." + void TVM_NewApiClient( + TVM_ApiSettings settings, + int loggerHandle, + void** handle, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_NewToolClient( + TVM_ToolSettings settings, + int loggerHandle, + void** handle, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_NewUnittestClient( + TVM_UnittestSettings settings, + void** handle, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_DestroyClient(void* handle); + + void TVM_GetStatus( + void* handle, + TVM_ClientStatus* status, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_CheckUserTicket( + void* handle, + unsigned char* ticketStr, int ticketSize, + int* env, + TVM_UserTicket* ticket, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_CheckServiceTicket( + void* handle, + unsigned char* ticketStr, int ticketSize, + TVM_ServiceTicket* ticket, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_GetServiceTicket( + void* handle, + ui32 dstId, + char** ticket, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_GetServiceTicketForAlias( + void* handle, + unsigned char* alias, int aliasSize, + char** ticket, + TVM_Error* err, + TVM_MemPool* pool); + + void TVM_GetRoles( + void* handle, + unsigned char* currentRevision, int currentRevisionSize, + char** raw, + int* rawSize, + TVM_Error* err, + TVM_MemPool* pool); + +#ifdef __cplusplus +} +#endif diff --git a/library/go/yandex/tvm/tvmauth/types.go b/library/go/yandex/tvm/tvmauth/types.go new file mode 100644 index 0000000000..e9df007ad1 --- /dev/null +++ b/library/go/yandex/tvm/tvmauth/types.go @@ -0,0 +1,139 @@ +package tvmauth + +import ( + "sync" + "unsafe" + + "a.yandex-team.ru/library/go/yandex/tvm" +) + +// TvmAPISettings may be used to fetch data from tvm-api +type TvmAPISettings struct { + // SelfID is required for ServiceTicketOptions and EnableServiceTicketChecking + SelfID tvm.ClientID + + // ServiceTicketOptions provides info for fetching Service Tickets from tvm-api + // to allow you send them to your backends. + // + // WARNING: It is not way to provide authorization for incoming ServiceTickets! + // It is way only to send your ServiceTickets to your backend! + ServiceTicketOptions *TVMAPIOptions + + // EnableServiceTicketChecking enables fetching of public keys for signature checking + EnableServiceTicketChecking bool + + // BlackboxEnv with not nil value enables UserTicket checking + // and enables fetching of public keys for signature checking + BlackboxEnv *tvm.BlackboxEnv + + fetchRolesForIdmSystemSlug []byte + // Non-empty FetchRolesForIdmSystemSlug enables roles fetching from tirole + FetchRolesForIdmSystemSlug string + // By default, client checks src from ServiceTicket or default uid from UserTicket - + // to prevent you from forgetting to check it yourself. + // It does binary checks only: + // ticket gets status NoRoles, if there is no role for src or default uid. + // You need to check roles on your own if you have a non-binary role system or + // you have switched DisableSrcCheck/DisableDefaultUIDCheck + // + // You may need to disable this check in the following cases: + // - You use GetRoles() to provide verbose message (with revision). + // Double check may be inconsistent: + // binary check inside client uses revision of roles X - i.e. src 100500 has no role, + // exact check in your code uses revision of roles Y - i.e. src 100500 has some roles. + DisableSrcCheck bool + // See comment for DisableSrcCheck + DisableDefaultUIDCheck bool + + tvmHost []byte + // TVMHost should be used only in tests + TVMHost string + // TVMPort should be used only in tests + TVMPort int + + tiroleHost []byte + // TiroleHost should be used only in tests or for tirole-api-test.yandex.net + TiroleHost string + // TirolePort should be used only in tests + TirolePort int + // TiroleTvmID should be used only in tests or for tirole-api-test.yandex.net + TiroleTvmID tvm.ClientID + + // Directory for disk cache. + // Requires read/write permissions. Permissions will be checked before start. + // WARNING: The same directory can be used only: + // - for TVM clients with the same settings + // OR + // - for new client replacing previous - with another config. + // System user must be the same for processes with these clients inside. + // Implementation doesn't provide other scenarios. + DiskCacheDir string + diskCacheDir []byte +} + +// TVMAPIOptions is part of TvmAPISettings: allows to enable fetching of ServiceTickets +type TVMAPIOptions struct { + selfSecret string + selfSecretB []byte + dstAliases []byte +} + +// TvmToolSettings may be used to fetch data from tvmtool +type TvmToolSettings struct { + // Alias is required: self alias of your tvm ClientID + Alias string + alias []byte + + // By default, client checks src from ServiceTicket or default uid from UserTicket - + // to prevent you from forgetting to check it yourself. + // It does binary checks only: + // ticket gets status NoRoles, if there is no role for src or default uid. + // You need to check roles on your own if you have a non-binary role system or + // you have switched DisableSrcCheck/DisableDefaultUIDCheck + // + // You may need to disable this check in the following cases: + // - You use GetRoles() to provide verbose message (with revision). + // Double check may be inconsistent: + // binary check inside client uses revision of roles X - i.e. src 100500 has no role, + // exact check in your code uses revision of roles Y - i.e. src 100500 has some roles. + DisableSrcCheck bool + // See comment for DisableSrcCheck + DisableDefaultUIDCheck bool + + // Port will be detected with env["DEPLOY_TVM_TOOL_URL"] (provided with Yandex.Deploy), + // otherwise port == 1 (it is ok for Qloud) + Port int + // Hostname == "localhost" by default + Hostname string + hostname []byte + + // AuthToken is protection from SSRF. + // By default it is fetched from env: + // * TVMTOOL_LOCAL_AUTHTOKEN (provided with Yandex.Deploy) + // * QLOUD_TVM_TOKEN (provided with Qloud) + AuthToken string + authToken []byte +} + +type TvmUnittestSettings struct { + // SelfID is required for service ticket checking + SelfID tvm.ClientID + + // Service ticket checking is enabled by default + + // User ticket checking is enabled by default: choose required environment + BlackboxEnv tvm.BlackboxEnv + + // Other features are not supported yet +} + +// Client contains raw pointer for C++ object +type Client struct { + handle unsafe.Pointer + logger *int + + roles *tvm.Roles + mutex *sync.RWMutex +} + +var _ tvm.Client = (*Client)(nil) diff --git a/library/go/yandex/tvm/tvmtool/any.go b/library/go/yandex/tvm/tvmtool/any.go new file mode 100644 index 0000000000..5c394af771 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/any.go @@ -0,0 +1,37 @@ +package tvmtool + +import ( + "os" + + "a.yandex-team.ru/library/go/core/xerrors" +) + +const ( + LocalEndpointEnvKey = "TVMTOOL_URL" + LocalTokenEnvKey = "TVMTOOL_LOCAL_AUTHTOKEN" +) + +var ErrUnknownTvmtoolEnvironment = xerrors.NewSentinel("unknown tvmtool environment") + +// NewAnyClient method creates a new tvmtool client with environment auto-detection. +// You must reuse it to prevent connection/goroutines leakage. +func NewAnyClient(opts ...Option) (*Client, error) { + switch { + case os.Getenv(QloudEndpointEnvKey) != "": + // it's Qloud + return NewQloudClient(opts...) + case os.Getenv(DeployEndpointEnvKey) != "": + // it's Y.Deploy + return NewDeployClient(opts...) + case os.Getenv(LocalEndpointEnvKey) != "": + passedOpts := append( + []Option{ + WithAuthToken(os.Getenv(LocalTokenEnvKey)), + }, + opts..., + ) + return NewClient(os.Getenv(LocalEndpointEnvKey), passedOpts...) + default: + return nil, ErrUnknownTvmtoolEnvironment.WithFrame() + } +} diff --git a/library/go/yandex/tvm/tvmtool/any_example_test.go b/library/go/yandex/tvm/tvmtool/any_example_test.go new file mode 100644 index 0000000000..d5959426bc --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/any_example_test.go @@ -0,0 +1,70 @@ +package tvmtool_test + +import ( + "context" + "fmt" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewAnyClient_simple() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewAnyClient(tvmtool.WithLogger(zlog)) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.TODO(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} + +func ExampleNewAnyClient_custom() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewAnyClient( + tvmtool.WithSrc("second_app"), + tvmtool.WithLogger(zlog), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/clients_test.go b/library/go/yandex/tvm/tvmtool/clients_test.go new file mode 100644 index 0000000000..5bf34b93fd --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/clients_test.go @@ -0,0 +1,154 @@ +//go:build linux || darwin +// +build linux darwin + +package tvmtool_test + +import ( + "os" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +func TestNewClients(t *testing.T) { + type TestCase struct { + env map[string]string + willFail bool + expectedErr string + expectedBaseURI string + expectedAuthToken string + } + + cases := map[string]struct { + constructor func(opts ...tvmtool.Option) (*tvmtool.Client, error) + cases map[string]TestCase + }{ + "qloud": { + constructor: tvmtool.NewQloudClient, + cases: map[string]TestCase{ + "no-auth": { + willFail: true, + expectedErr: "empty auth token (looked at ENV[QLOUD_TVM_TOKEN])", + }, + "ok-default-origin": { + env: map[string]string{ + "QLOUD_TVM_TOKEN": "ok-default-origin-token", + }, + willFail: false, + expectedBaseURI: "http://localhost:1/tvm", + expectedAuthToken: "ok-default-origin-token", + }, + "ok-custom-origin": { + env: map[string]string{ + "QLOUD_TVM_INTERFACE_ORIGIN": "http://localhost:9000", + "QLOUD_TVM_TOKEN": "ok-custom-origin-token", + }, + willFail: false, + expectedBaseURI: "http://localhost:9000/tvm", + expectedAuthToken: "ok-custom-origin-token", + }, + }, + }, + "deploy": { + constructor: tvmtool.NewDeployClient, + cases: map[string]TestCase{ + "no-url": { + willFail: true, + expectedErr: "empty tvmtool url (looked at ENV[DEPLOY_TVM_TOOL_URL])", + }, + "no-auth": { + env: map[string]string{ + "DEPLOY_TVM_TOOL_URL": "http://localhost:2", + }, + willFail: true, + expectedErr: "empty auth token (looked at ENV[TVMTOOL_LOCAL_AUTHTOKEN])", + }, + "ok": { + env: map[string]string{ + "DEPLOY_TVM_TOOL_URL": "http://localhost:1337", + "TVMTOOL_LOCAL_AUTHTOKEN": "ok-token", + }, + willFail: false, + expectedBaseURI: "http://localhost:1337/tvm", + expectedAuthToken: "ok-token", + }, + }, + }, + "any": { + constructor: tvmtool.NewAnyClient, + cases: map[string]TestCase{ + "empty": { + willFail: true, + expectedErr: "unknown tvmtool environment", + }, + "ok-qloud": { + env: map[string]string{ + "QLOUD_TVM_INTERFACE_ORIGIN": "http://qloud:9000", + "QLOUD_TVM_TOKEN": "ok-qloud", + }, + expectedBaseURI: "http://qloud:9000/tvm", + expectedAuthToken: "ok-qloud", + }, + "ok-deploy": { + env: map[string]string{ + "DEPLOY_TVM_TOOL_URL": "http://deploy:1337", + "TVMTOOL_LOCAL_AUTHTOKEN": "ok-deploy", + }, + expectedBaseURI: "http://deploy:1337/tvm", + expectedAuthToken: "ok-deploy", + }, + "ok-local": { + env: map[string]string{ + "TVMTOOL_URL": "http://local:1338", + "TVMTOOL_LOCAL_AUTHTOKEN": "ok-local", + }, + willFail: false, + expectedBaseURI: "http://local:1338/tvm", + expectedAuthToken: "ok-local", + }, + }, + }, + } + + // NB! this checks are not thread safe, never use t.Parallel() and so on + for clientName, client := range cases { + t.Run(clientName, func(t *testing.T) { + for name, tc := range client.cases { + t.Run(name, func(t *testing.T) { + savedEnv := os.Environ() + defer func() { + os.Clearenv() + for _, env := range savedEnv { + parts := strings.SplitN(env, "=", 2) + err := os.Setenv(parts[0], parts[1]) + require.NoError(t, err) + } + }() + + os.Clearenv() + for key, val := range tc.env { + _ = os.Setenv(key, val) + } + + tvmClient, err := client.constructor() + if tc.willFail { + require.Error(t, err) + if tc.expectedErr != "" { + require.EqualError(t, err, tc.expectedErr) + } + + require.Nil(t, tvmClient) + } else { + require.NoError(t, err) + require.NotNil(t, tvmClient) + require.Equal(t, tc.expectedBaseURI, tvmClient.BaseURI()) + require.Equal(t, tc.expectedAuthToken, tvmClient.AuthToken()) + } + }) + } + }) + } +} diff --git a/library/go/yandex/tvm/tvmtool/deploy.go b/library/go/yandex/tvm/tvmtool/deploy.go new file mode 100644 index 0000000000..d7a2eac62b --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/deploy.go @@ -0,0 +1,31 @@ +package tvmtool + +import ( + "fmt" + "os" +) + +const ( + DeployEndpointEnvKey = "DEPLOY_TVM_TOOL_URL" + DeployTokenEnvKey = "TVMTOOL_LOCAL_AUTHTOKEN" +) + +// NewDeployClient method creates a new tvmtool client for Deploy environment. +// You must reuse it to prevent connection/goroutines leakage. +func NewDeployClient(opts ...Option) (*Client, error) { + baseURI := os.Getenv(DeployEndpointEnvKey) + if baseURI == "" { + return nil, fmt.Errorf("empty tvmtool url (looked at ENV[%s])", DeployEndpointEnvKey) + } + + authToken := os.Getenv(DeployTokenEnvKey) + if authToken == "" { + return nil, fmt.Errorf("empty auth token (looked at ENV[%s])", DeployTokenEnvKey) + } + + opts = append([]Option{WithAuthToken(authToken)}, opts...) + return NewClient( + baseURI, + opts..., + ) +} diff --git a/library/go/yandex/tvm/tvmtool/deploy_example_test.go b/library/go/yandex/tvm/tvmtool/deploy_example_test.go new file mode 100644 index 0000000000..d352336d58 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/deploy_example_test.go @@ -0,0 +1,70 @@ +package tvmtool_test + +import ( + "context" + "fmt" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewDeployClient_simple() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewDeployClient(tvmtool.WithLogger(zlog)) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.TODO(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} + +func ExampleNewDeployClient_custom() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewDeployClient( + tvmtool.WithSrc("second_app"), + tvmtool.WithLogger(zlog), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/doc.go b/library/go/yandex/tvm/tvmtool/doc.go new file mode 100644 index 0000000000..d46dca8132 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/doc.go @@ -0,0 +1,7 @@ +// Pure Go implementation of tvm-interface based on TVMTool client. +// +// https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/. +// Package allows you to get service/user TVM-tickets, as well as check them. +// This package can provide fast getting of service tickets (from cache), other cases lead to http request to localhost. +// Also this package provides TVM client for Qloud (NewQloudClient) and Yandex.Deploy (NewDeployClient) environments. +package tvmtool diff --git a/library/go/yandex/tvm/tvmtool/errors.go b/library/go/yandex/tvm/tvmtool/errors.go new file mode 100644 index 0000000000..f0b08a9878 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/errors.go @@ -0,0 +1,61 @@ +package tvmtool + +import ( + "fmt" + + "a.yandex-team.ru/library/go/yandex/tvm" +) + +// Generic TVM errors, before retry any request it check .Retriable field. +type Error = tvm.Error + +const ( + // ErrorAuthFail - auth failed, probably you provides invalid auth token + ErrorAuthFail = tvm.ErrorAuthFail + // ErrorBadRequest - tvmtool rejected our request, check .Msg for details + ErrorBadRequest = tvm.ErrorBadRequest + // ErrorOther - any other TVM-related errors, check .Msg for details + ErrorOther = tvm.ErrorOther +) + +// Ticket validation error +type TicketError = tvm.TicketError + +const ( + TicketErrorInvalidScopes = tvm.TicketInvalidScopes + TicketErrorOther = tvm.TicketStatusOther +) + +type PingCode uint32 + +const ( + PingCodeDie = iota + PingCodeWarning + PingCodeError + PingCodeOther +) + +func (e PingCode) String() string { + switch e { + case PingCodeDie: + return "HttpDie" + case PingCodeWarning: + return "Warning" + case PingCodeError: + return "Error" + case PingCodeOther: + return "Other" + default: + return fmt.Sprintf("Unknown%d", e) + } +} + +// Special ping error +type PingError struct { + Code PingCode + Err error +} + +func (e *PingError) Error() string { + return fmt.Sprintf("tvm: %s (code %s)", e.Err.Error(), e.Code) +} diff --git a/library/go/yandex/tvm/tvmtool/examples/check_tickets/main.go b/library/go/yandex/tvm/tvmtool/examples/check_tickets/main.go new file mode 100644 index 0000000000..95fcc0bd51 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/examples/check_tickets/main.go @@ -0,0 +1,68 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +var ( + baseURI = "http://localhost:3000" + srvTicket string + userTicket string +) + +func main() { + flag.StringVar(&baseURI, "tool-uri", baseURI, "TVM tool uri") + flag.StringVar(&srvTicket, "srv", "", "service ticket to check") + flag.StringVar(&userTicket, "usr", "", "user ticket to check") + flag.Parse() + + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + auth := os.Getenv("TVMTOOL_LOCAL_AUTHTOKEN") + if auth == "" { + zlog.Fatal("Please provide tvm-tool auth in env[TVMTOOL_LOCAL_AUTHTOKEN]") + return + } + + tvmClient, err := tvmtool.NewClient( + baseURI, + tvmtool.WithAuthToken(auth), + tvmtool.WithLogger(zlog), + ) + if err != nil { + zlog.Fatal("failed create tvm client", log.Error(err)) + return + } + defer tvmClient.Close() + + fmt.Printf("------ Check service ticket ------\n\n") + srvCheck, err := tvmClient.CheckServiceTicket(context.Background(), srvTicket) + if err != nil { + fmt.Printf("Failed\nTicket: %s\nError: %s\n", srvCheck, err) + } else { + fmt.Printf("OK\nInfo: %s\n", srvCheck) + } + + if userTicket == "" { + return + } + + fmt.Printf("\n------ Check user ticket result ------\n\n") + + usrCheck, err := tvmClient.CheckUserTicket(context.Background(), userTicket) + if err != nil { + fmt.Printf("Failed\nTicket: %s\nError: %s\n", usrCheck, err) + return + } + fmt.Printf("OK\nInfo: %s\n", usrCheck) +} diff --git a/library/go/yandex/tvm/tvmtool/examples/get_service_ticket/main.go b/library/go/yandex/tvm/tvmtool/examples/get_service_ticket/main.go new file mode 100644 index 0000000000..2abfca8bfb --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/examples/get_service_ticket/main.go @@ -0,0 +1,53 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +var ( + baseURI = "http://localhost:3000" + dst = "dst" +) + +func main() { + flag.StringVar(&baseURI, "tool-uri", baseURI, "TVM tool uri") + flag.StringVar(&dst, "dst", dst, "Destination TVM app (must be configured in tvm-tool)") + flag.Parse() + + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + auth := os.Getenv("TVMTOOL_LOCAL_AUTHTOKEN") + if auth == "" { + zlog.Fatal("Please provide tvm-tool auth in env[TVMTOOL_LOCAL_AUTHTOKEN]") + return + } + + tvmClient, err := tvmtool.NewClient( + baseURI, + tvmtool.WithAuthToken(auth), + tvmtool.WithLogger(zlog), + ) + if err != nil { + zlog.Fatal("failed create tvm client", log.Error(err)) + return + } + defer tvmClient.Close() + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), dst) + if err != nil { + zlog.Fatal("failed to get tvm ticket", log.String("dst", dst), log.Error(err)) + return + } + + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/gotest/tvmtool.conf.json b/library/go/yandex/tvm/tvmtool/gotest/tvmtool.conf.json new file mode 100644 index 0000000000..db768f5d53 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/gotest/tvmtool.conf.json @@ -0,0 +1,32 @@ +{ + "BbEnvType": 3, + "clients": { + "main": { + "secret": "fake_secret", + "self_tvm_id": 42, + "dsts": { + "he": { + "dst_id": 100500 + }, + "he_clone": { + "dst_id": 100500 + }, + "slave": { + "dst_id": 43 + }, + "self": { + "dst_id": 42 + } + } + }, + "slave": { + "secret": "fake_secret", + "self_tvm_id": 43, + "dsts": { + "he": { + "dst_id": 100500 + } + } + } + } +}
\ No newline at end of file diff --git a/library/go/yandex/tvm/tvmtool/internal/cache/cache.go b/library/go/yandex/tvm/tvmtool/internal/cache/cache.go new file mode 100644 index 0000000000..b625ca774f --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/internal/cache/cache.go @@ -0,0 +1,128 @@ +package cache + +import ( + "sync" + "time" + + "a.yandex-team.ru/library/go/yandex/tvm" +) + +const ( + Hit Status = iota + Miss + GonnaMissy +) + +type ( + Status int + + Cache struct { + ttl time.Duration + maxTTL time.Duration + tickets map[tvm.ClientID]entry + aliases map[string]tvm.ClientID + lock sync.RWMutex + } + + entry struct { + value *string + born time.Time + } +) + +func New(ttl, maxTTL time.Duration) *Cache { + return &Cache{ + ttl: ttl, + maxTTL: maxTTL, + tickets: make(map[tvm.ClientID]entry, 1), + aliases: make(map[string]tvm.ClientID, 1), + } +} + +func (c *Cache) Gc() { + now := time.Now() + + c.lock.Lock() + defer c.lock.Unlock() + for clientID, ticket := range c.tickets { + if ticket.born.Add(c.maxTTL).After(now) { + continue + } + + delete(c.tickets, clientID) + for alias, aClientID := range c.aliases { + if clientID == aClientID { + delete(c.aliases, alias) + } + } + } +} + +func (c *Cache) ClientIDs() []tvm.ClientID { + c.lock.RLock() + defer c.lock.RUnlock() + + clientIDs := make([]tvm.ClientID, 0, len(c.tickets)) + for clientID := range c.tickets { + clientIDs = append(clientIDs, clientID) + } + return clientIDs +} + +func (c *Cache) Aliases() []string { + c.lock.RLock() + defer c.lock.RUnlock() + + aliases := make([]string, 0, len(c.aliases)) + for alias := range c.aliases { + aliases = append(aliases, alias) + } + return aliases +} + +func (c *Cache) Load(clientID tvm.ClientID) (*string, Status) { + c.lock.RLock() + e, ok := c.tickets[clientID] + c.lock.RUnlock() + if !ok { + return nil, Miss + } + + now := time.Now() + exp := e.born.Add(c.ttl) + if exp.After(now) { + return e.value, Hit + } + + exp = e.born.Add(c.maxTTL) + if exp.After(now) { + return e.value, GonnaMissy + } + + c.lock.Lock() + delete(c.tickets, clientID) + c.lock.Unlock() + return nil, Miss +} + +func (c *Cache) LoadByAlias(alias string) (*string, Status) { + c.lock.RLock() + clientID, ok := c.aliases[alias] + c.lock.RUnlock() + if !ok { + return nil, Miss + } + + return c.Load(clientID) +} + +func (c *Cache) Store(clientID tvm.ClientID, alias string, value *string) { + c.lock.Lock() + defer c.lock.Unlock() + + c.aliases[alias] = clientID + c.tickets[clientID] = entry{ + value: value, + born: time.Now(), + } +} diff --git a/library/go/yandex/tvm/tvmtool/internal/cache/cache_test.go b/library/go/yandex/tvm/tvmtool/internal/cache/cache_test.go new file mode 100644 index 0000000000..d9a1780108 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/internal/cache/cache_test.go @@ -0,0 +1,125 @@ +package cache_test + +import ( + "sort" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool/internal/cache" +) + +var ( + testDst = "test_dst" + testDstAlias = "test_dst_alias" + testDstID = tvm.ClientID(1) + testValue = "test_val" +) + +func TestNewAtHour(t *testing.T) { + c := cache.New(time.Hour, 11*time.Hour) + assert.NotNil(t, c, "failed to create cache") +} + +func TestCache_Load(t *testing.T) { + + c := cache.New(time.Second, time.Hour) + c.Store(testDstID, testDst, &testValue) + // checking before + { + r, hit := c.Load(testDstID) + assert.Equal(t, cache.Hit, hit, "failed to get '%d' from cache before deadline", testDstID) + assert.NotNil(t, r, "failed to get '%d' from cache before deadline", testDstID) + assert.Equal(t, testValue, *r) + + r, hit = c.LoadByAlias(testDst) + assert.Equal(t, cache.Hit, hit, "failed to get '%s' from cache before deadline", testDst) + assert.NotNil(t, r, "failed to get %q from tickets before deadline", testDst) + assert.Equal(t, testValue, *r) + } + { + r, hit := c.Load(999833321) + assert.Equal(t, cache.Miss, hit, "got tickets for '999833321', but that key must be never existed") + assert.Nil(t, r, "got tickets for '999833321', but that key must be never existed") + + r, hit = c.LoadByAlias("kek") + assert.Equal(t, cache.Miss, hit, "got tickets for 'kek', but that key must be never existed") + assert.Nil(t, r, "got tickets for 'kek', but that key must be never existed") + } + + time.Sleep(3 * time.Second) + // checking after + { + r, hit := c.Load(testDstID) + assert.Equal(t, cache.GonnaMissy, hit) + assert.Equal(t, testValue, *r) + + r, hit = c.LoadByAlias(testDst) + assert.Equal(t, cache.GonnaMissy, hit) + assert.Equal(t, testValue, *r) + } +} + +func TestCache_Keys(t *testing.T) { + c := cache.New(time.Second, time.Hour) + c.Store(testDstID, testDst, &testValue) + c.Store(testDstID, testDstAlias, &testValue) + + t.Run("aliases", func(t *testing.T) { + aliases := c.Aliases() + sort.Strings(aliases) + require.Equal(t, 2, len(aliases), "not correct length of aliases") + require.EqualValues(t, []string{testDst, testDstAlias}, aliases) + }) + + t.Run("client_ids", func(t *testing.T) { + ids := c.ClientIDs() + require.Equal(t, 1, len(ids), "not correct length of client ids") + require.EqualValues(t, []tvm.ClientID{testDstID}, ids) + }) +} + +func TestCache_ExpiredKeys(t *testing.T) { + c := cache.New(time.Second, 10*time.Second) + c.Store(testDstID, testDst, &testValue) + c.Store(testDstID, testDstAlias, &testValue) + + time.Sleep(3 * time.Second) + c.Gc() + + var ( + newDst = "new_dst" + newDstID = tvm.ClientID(2) + ) + c.Store(newDstID, newDst, &testValue) + + t.Run("aliases", func(t *testing.T) { + aliases := c.Aliases() + require.Equal(t, 3, len(aliases), "not correct length of aliases") + require.ElementsMatch(t, []string{testDst, testDstAlias, newDst}, aliases) + }) + + t.Run("client_ids", func(t *testing.T) { + ids := c.ClientIDs() + require.Equal(t, 2, len(ids), "not correct length of client ids") + require.ElementsMatch(t, []tvm.ClientID{testDstID, newDstID}, ids) + }) + + time.Sleep(8 * time.Second) + c.Gc() + + t.Run("aliases", func(t *testing.T) { + aliases := c.Aliases() + require.Equal(t, 1, len(aliases), "not correct length of aliases") + require.ElementsMatch(t, []string{newDst}, aliases) + }) + + t.Run("client_ids", func(t *testing.T) { + ids := c.ClientIDs() + require.Equal(t, 1, len(ids), "not correct length of client ids") + require.ElementsMatch(t, []tvm.ClientID{newDstID}, ids) + }) +} diff --git a/library/go/yandex/tvm/tvmtool/opts.go b/library/go/yandex/tvm/tvmtool/opts.go new file mode 100644 index 0000000000..91d29139d8 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/opts.go @@ -0,0 +1,103 @@ +package tvmtool + +import ( + "context" + "net/http" + "strings" + "time" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/xerrors" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool/internal/cache" +) + +type ( + Option func(tool *Client) error +) + +// Source TVM client (id or alias) +// +// WARNING: id/alias must be configured in tvmtool. Documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#konfig +func WithSrc(src string) Option { + return func(tool *Client) error { + tool.src = src + return nil + } +} + +// Auth token +func WithAuthToken(token string) Option { + return func(tool *Client) error { + tool.authToken = token + return nil + } +} + +// Use custom HTTP client +func WithHTTPClient(client *http.Client) Option { + return func(tool *Client) error { + tool.ownHTTPClient = false + tool.httpClient = client + return nil + } +} + +// Enable or disable service tickets cache +// +// Enabled by default +func WithCacheEnabled(enabled bool) Option { + return func(tool *Client) error { + switch { + case enabled && tool.cache == nil: + tool.cache = cache.New(cacheTTL, cacheMaxTTL) + case !enabled: + tool.cache = nil + } + return nil + } +} + +// Overrides blackbox environment defined in config. +// +// Documentation about environment overriding: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/checkusr +func WithOverrideEnv(bbEnv tvm.BlackboxEnv) Option { + return func(tool *Client) error { + tool.bbEnv = strings.ToLower(bbEnv.String()) + return nil + } +} + +// WithLogger sets logger for tvm client. +func WithLogger(l log.Structured) Option { + return func(tool *Client) error { + tool.l = l + return nil + } +} + +// WithRefreshFrequency sets service tickets refresh frequency. +// Frequency must be lower chan cacheTTL (10 min) +// +// Default: 8 min +func WithRefreshFrequency(freq time.Duration) Option { + return func(tool *Client) error { + if freq > cacheTTL { + return xerrors.Errorf("refresh frequency must be lower than cacheTTL (%d > %d)", freq, cacheTTL) + } + + tool.refreshFreq = int64(freq.Seconds()) + return nil + } +} + +// WithBackgroundUpdate force Client to update all service ticket at background. +// You must manually cancel given ctx to stops refreshing. +// +// Default: disabled +func WithBackgroundUpdate(ctx context.Context) Option { + return func(tool *Client) error { + tool.bgCtx, tool.bgCancel = context.WithCancel(ctx) + return nil + } +} diff --git a/library/go/yandex/tvm/tvmtool/qloud.go b/library/go/yandex/tvm/tvmtool/qloud.go new file mode 100644 index 0000000000..4dcf0648db --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/qloud.go @@ -0,0 +1,32 @@ +package tvmtool + +import ( + "fmt" + "os" +) + +const ( + QloudEndpointEnvKey = "QLOUD_TVM_INTERFACE_ORIGIN" + QloudTokenEnvKey = "QLOUD_TVM_TOKEN" + QloudDefaultEndpoint = "http://localhost:1" +) + +// NewQloudClient method creates a new tvmtool client for Qloud environment. +// You must reuse it to prevent connection/goroutines leakage. +func NewQloudClient(opts ...Option) (*Client, error) { + baseURI := os.Getenv(QloudEndpointEnvKey) + if baseURI == "" { + baseURI = QloudDefaultEndpoint + } + + authToken := os.Getenv(QloudTokenEnvKey) + if authToken == "" { + return nil, fmt.Errorf("empty auth token (looked at ENV[%s])", QloudTokenEnvKey) + } + + opts = append([]Option{WithAuthToken(authToken)}, opts...) + return NewClient( + baseURI, + opts..., + ) +} diff --git a/library/go/yandex/tvm/tvmtool/qloud_example_test.go b/library/go/yandex/tvm/tvmtool/qloud_example_test.go new file mode 100644 index 0000000000..a6bfcbede6 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/qloud_example_test.go @@ -0,0 +1,71 @@ +package tvmtool_test + +import ( + "context" + "fmt" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewQloudClient_simple() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewQloudClient(tvmtool.WithLogger(zlog)) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} + +func ExampleNewQloudClient_custom() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewQloudClient( + tvmtool.WithSrc("second_app"), + tvmtool.WithOverrideEnv(tvm.BlackboxProd), + tvmtool.WithLogger(zlog), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/tool.go b/library/go/yandex/tvm/tvmtool/tool.go new file mode 100644 index 0000000000..0273902b6f --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool.go @@ -0,0 +1,530 @@ +package tvmtool + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "sync/atomic" + "time" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/nop" + "a.yandex-team.ru/library/go/core/xerrors" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool/internal/cache" +) + +const ( + dialTimeout = 100 * time.Millisecond + requestTimeout = 500 * time.Millisecond + keepAlive = 60 * time.Second + cacheTTL = 10 * time.Minute + cacheMaxTTL = 11 * time.Hour +) + +var _ tvm.Client = (*Client)(nil) + +type ( + Client struct { + lastSync int64 + baseURI string + src string + authToken string + bbEnv string + refreshFreq int64 + bgCtx context.Context + bgCancel context.CancelFunc + inFlightRefresh uint32 + cache *cache.Cache + pingRequest *http.Request + ownHTTPClient bool + httpClient *http.Client + l log.Structured + } + + ticketsResponse map[string]struct { + Error string `json:"error"` + Ticket string `json:"ticket"` + TvmID tvm.ClientID `json:"tvm_id"` + } + + checkSrvResponse struct { + SrcID tvm.ClientID `json:"src"` + Error string `json:"error"` + DbgInfo string `json:"debug_string"` + LogInfo string `json:"logging_string"` + } + + checkUserResponse struct { + DefaultUID tvm.UID `json:"default_uid"` + UIDs []tvm.UID `json:"uids"` + Scopes []string `json:"scopes"` + Error string `json:"error"` + DbgInfo string `json:"debug_string"` + LogInfo string `json:"logging_string"` + } +) + +// NewClient method creates a new tvmtool client. +// You must reuse it to prevent connection/goroutines leakage. +func NewClient(apiURI string, opts ...Option) (*Client, error) { + baseURI := strings.TrimRight(apiURI, "/") + "/tvm" + pingRequest, err := http.NewRequest("GET", baseURI+"/ping", nil) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to configure client: %w", err) + } + + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.DialContext = (&net.Dialer{ + Timeout: dialTimeout, + KeepAlive: keepAlive, + }).DialContext + + tool := &Client{ + baseURI: baseURI, + refreshFreq: 8 * 60, + cache: cache.New(cacheTTL, cacheMaxTTL), + pingRequest: pingRequest, + l: &nop.Logger{}, + ownHTTPClient: true, + httpClient: &http.Client{ + Transport: transport, + Timeout: requestTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + }, + } + + for _, opt := range opts { + if err := opt(tool); err != nil { + return nil, xerrors.Errorf("tvmtool: failed to configure client: %w", err) + } + } + + if tool.bgCtx != nil { + go tool.serviceTicketsRefreshLoop() + } + + return tool, nil +} + +// GetServiceTicketForAlias returns TVM service ticket for alias +// +// WARNING: alias must be configured in tvmtool +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/tickets +func (c *Client) GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) { + var ( + cachedTicket *string + cacheStatus = cache.Miss + ) + + if c.cache != nil { + c.refreshServiceTickets() + + if cachedTicket, cacheStatus = c.cache.LoadByAlias(alias); cacheStatus == cache.Hit { + return *cachedTicket, nil + } + } + + tickets, err := c.getServiceTickets(ctx, alias) + if err != nil { + if cachedTicket != nil && cacheStatus == cache.GonnaMissy { + return *cachedTicket, nil + } + return "", err + } + + entry, ok := tickets[alias] + if !ok { + return "", xerrors.Errorf("tvmtool: alias %q was not found in response", alias) + } + + if entry.Error != "" { + return "", &Error{Code: ErrorOther, Msg: entry.Error} + } + + ticket := entry.Ticket + if c.cache != nil { + c.cache.Store(entry.TvmID, alias, &ticket) + } + return ticket, nil +} + +// GetServiceTicketForID returns TVM service ticket for destination application id +// +// WARNING: id must be configured in tvmtool +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/tickets +func (c *Client) GetServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (string, error) { + var ( + cachedTicket *string + cacheStatus = cache.Miss + ) + + if c.cache != nil { + c.refreshServiceTickets() + + if cachedTicket, cacheStatus = c.cache.Load(dstID); cacheStatus == cache.Hit { + return *cachedTicket, nil + } + } + + alias := strconv.FormatUint(uint64(dstID), 10) + tickets, err := c.getServiceTickets(ctx, alias) + if err != nil { + if cachedTicket != nil && cacheStatus == cache.GonnaMissy { + return *cachedTicket, nil + } + return "", err + } + + entry, ok := tickets[alias] + if !ok { + // ok, let's find him + for candidateAlias, candidate := range tickets { + if candidate.TvmID == dstID { + entry = candidate + alias = candidateAlias + ok = true + break + } + } + + if !ok { + return "", xerrors.Errorf("tvmtool: dst %q was not found in response", alias) + } + } + + if entry.Error != "" { + return "", &Error{Code: ErrorOther, Msg: entry.Error} + } + + ticket := entry.Ticket + if c.cache != nil { + c.cache.Store(dstID, alias, &ticket) + } + return ticket, nil +} + +// Close stops background ticket updates (if configured) and closes idle connections. +func (c *Client) Close() { + if c.bgCancel != nil { + c.bgCancel() + } + + if c.ownHTTPClient { + c.httpClient.CloseIdleConnections() + } +} + +func (c *Client) refreshServiceTickets() { + if c.bgCtx != nil { + // service tickets will be updated at background in the separated goroutine + return + } + + now := time.Now().Unix() + if now-atomic.LoadInt64(&c.lastSync) > c.refreshFreq { + atomic.StoreInt64(&c.lastSync, now) + if atomic.CompareAndSwapUint32(&c.inFlightRefresh, 0, 1) { + go c.doServiceTicketsRefresh(context.Background()) + } + } +} + +func (c *Client) serviceTicketsRefreshLoop() { + var ticker = time.NewTicker(time.Duration(c.refreshFreq) * time.Second) + defer ticker.Stop() + for { + select { + case <-c.bgCtx.Done(): + return + case <-ticker.C: + c.doServiceTicketsRefresh(c.bgCtx) + } + } +} + +func (c *Client) doServiceTicketsRefresh(ctx context.Context) { + defer atomic.CompareAndSwapUint32(&c.inFlightRefresh, 1, 0) + + c.cache.Gc() + aliases := c.cache.Aliases() + if len(aliases) == 0 { + return + } + + c.l.Debug("tvmtool: service ticket update started") + defer c.l.Debug("tvmtool: service ticket update finished") + + // fast path: batch update, must work most of time + err := c.refreshServiceTicket(ctx, aliases...) + if err == nil { + return + } + + if tvmErr, ok := err.(*Error); ok && tvmErr.Code != ErrorBadRequest { + c.l.Error( + "tvmtool: failed to refresh all service tickets at background", + log.Strings("dsts", aliases), + log.Error(err), + ) + + // if we have non "bad request" error - something really terrible happens, nothing to do with it :( + // TODO(buglloc): implement adaptive refreshFreq based on errors? + return + } + + // slow path: trying to update service tickets one by one + c.l.Error( + "tvmtool: failed to refresh all service tickets at background, switched to slow path", + log.Strings("dsts", aliases), + log.Error(err), + ) + + for _, dst := range aliases { + if err := c.refreshServiceTicket(ctx, dst); err != nil { + c.l.Error( + "tvmtool: failed to refresh service ticket at background", + log.String("dst", dst), + log.Error(err), + ) + } + } +} + +func (c *Client) refreshServiceTicket(ctx context.Context, dsts ...string) error { + tickets, err := c.getServiceTickets(ctx, strings.Join(dsts, ",")) + if err != nil { + return err + } + + for _, dst := range dsts { + entry, ok := tickets[dst] + if !ok { + c.l.Error( + "tvmtool: destination was not found in tvmtool response", + log.String("dst", dst), + ) + continue + } + + if entry.Error != "" { + c.l.Error( + "tvmtool: failed to get service ticket for destination", + log.String("dst", dst), + log.String("err", entry.Error), + ) + continue + } + + c.cache.Store(entry.TvmID, dst, &entry.Ticket) + } + return nil +} + +func (c *Client) getServiceTickets(ctx context.Context, dst string) (ticketsResponse, error) { + params := url.Values{ + "dsts": {dst}, + } + if c.src != "" { + params.Set("src", c.src) + } + + req, err := http.NewRequest("GET", c.baseURI+"/tickets?"+params.Encode(), nil) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + req.Header.Set("Authorization", c.authToken) + + req = req.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + + var result ticketsResponse + err = readResponse(resp, &result) + return result, err +} + +// Check TVM service ticket +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/checksrv +func (c *Client) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + req, err := http.NewRequest("GET", c.baseURI+"/checksrv", nil) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + + if c.src != "" { + req.URL.RawQuery += "dst=" + url.QueryEscape(c.src) + } + req.Header.Set("Authorization", c.authToken) + req.Header.Set("X-Ya-Service-Ticket", ticket) + + req = req.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + + var result checkSrvResponse + if err = readResponse(resp, &result); err != nil { + return nil, err + } + + ticketInfo := &tvm.CheckedServiceTicket{ + SrcID: result.SrcID, + DbgInfo: result.DbgInfo, + LogInfo: result.LogInfo, + } + + if resp.StatusCode == http.StatusForbidden { + err = &TicketError{Status: TicketErrorOther, Msg: result.Error} + } + + return ticketInfo, err +} + +// Check TVM user ticket +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/checkusr +func (c *Client) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + for range opts { + panic("implement me") + } + + req, err := http.NewRequest("GET", c.baseURI+"/checkusr", nil) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + if c.bbEnv != "" { + req.URL.RawQuery += "override_env=" + url.QueryEscape(c.bbEnv) + } + req.Header.Set("Authorization", c.authToken) + req.Header.Set("X-Ya-User-Ticket", ticket) + + req = req.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, xerrors.Errorf("tvmtool: failed to call tvmtool: %w", err) + } + + var result checkUserResponse + if err = readResponse(resp, &result); err != nil { + return nil, err + } + + ticketInfo := &tvm.CheckedUserTicket{ + DefaultUID: result.DefaultUID, + UIDs: result.UIDs, + Scopes: result.Scopes, + DbgInfo: result.DbgInfo, + LogInfo: result.LogInfo, + } + + if resp.StatusCode == http.StatusForbidden { + err = &TicketError{Status: TicketErrorOther, Msg: result.Error} + } + + return ticketInfo, err +} + +// Checks TVMTool liveness +// +// TVMTool documentation: https://wiki.yandex-team.ru/passport/tvm2/tvm-daemon/#/tvm/ping +func (c *Client) GetStatus(ctx context.Context) (tvm.ClientStatusInfo, error) { + req := c.pingRequest.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return tvm.ClientStatusInfo{Status: tvm.ClientError}, + &PingError{Code: PingCodeDie, Err: err} + } + defer func() { _ = resp.Body.Close() }() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return tvm.ClientStatusInfo{Status: tvm.ClientError}, + &PingError{Code: PingCodeDie, Err: err} + } + + var status tvm.ClientStatusInfo + switch resp.StatusCode { + case http.StatusOK: + // OK! + status = tvm.ClientStatusInfo{Status: tvm.ClientOK} + err = nil + case http.StatusPartialContent: + status = tvm.ClientStatusInfo{Status: tvm.ClientWarning} + err = &PingError{Code: PingCodeWarning, Err: xerrors.New(string(body))} + case http.StatusInternalServerError: + status = tvm.ClientStatusInfo{Status: tvm.ClientError} + err = &PingError{Code: PingCodeError, Err: xerrors.New(string(body))} + default: + status = tvm.ClientStatusInfo{Status: tvm.ClientError} + err = &PingError{Code: PingCodeOther, Err: xerrors.Errorf("tvmtool: unexpected status: %d", resp.StatusCode)} + } + return status, err +} + +// Returns TVMTool version +func (c *Client) Version(ctx context.Context) (string, error) { + req := c.pingRequest.WithContext(ctx) + resp, err := c.httpClient.Do(req) + if err != nil { + return "", xerrors.Errorf("tvmtool: failed to call tmvtool: %w", err) + } + _, _ = ioutil.ReadAll(resp.Body) + _ = resp.Body.Close() + + return resp.Header.Get("Server"), nil +} + +func (c *Client) GetRoles(ctx context.Context) (*tvm.Roles, error) { + return nil, errors.New("not implemented") +} + +func readResponse(resp *http.Response, dst interface{}) error { + body, err := ioutil.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + return xerrors.Errorf("tvmtool: failed to read response: %w", err) + } + + switch resp.StatusCode { + case http.StatusOK, http.StatusForbidden: + // ok + return json.Unmarshal(body, dst) + case http.StatusUnauthorized: + return &Error{ + Code: ErrorAuthFail, + Msg: string(body), + } + case http.StatusBadRequest: + return &Error{ + Code: ErrorBadRequest, + Msg: string(body), + } + case http.StatusInternalServerError: + return &Error{ + Code: ErrorOther, + Msg: string(body), + Retriable: true, + } + default: + return &Error{ + Code: ErrorOther, + Msg: fmt.Sprintf("tvmtool: unexpected status: %d, msg: %s", resp.StatusCode, string(body)), + } + } +} diff --git a/library/go/yandex/tvm/tvmtool/tool_bg_update_test.go b/library/go/yandex/tvm/tvmtool/tool_bg_update_test.go new file mode 100644 index 0000000000..e1b9f114c0 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool_bg_update_test.go @@ -0,0 +1,354 @@ +package tvmtool_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +func newMockClient(upstream string, options ...tvmtool.Option) (*tvmtool.Client, error) { + zlog, _ := zap.New(zap.ConsoleConfig(log.DebugLevel)) + options = append(options, tvmtool.WithLogger(zlog), tvmtool.WithAuthToken("token")) + return tvmtool.NewClient(upstream, options...) +} + +// TestClientBackgroundUpdate_Updatable checks that TVMTool client updates tickets state +func TestClientBackgroundUpdate_Updatable(t *testing.T) { + type TestCase struct { + client func(ctx context.Context, t *testing.T, url string) *tvmtool.Client + } + cases := map[string]TestCase{ + "async": { + client: func(ctx context.Context, t *testing.T, url string) *tvmtool.Client { + tvmClient, err := newMockClient(url, tvmtool.WithRefreshFrequency(500*time.Millisecond)) + require.NoError(t, err) + return tvmClient + }, + }, + "background": { + client: func(ctx context.Context, t *testing.T, url string) *tvmtool.Client { + tvmClient, err := newMockClient( + url, + tvmtool.WithRefreshFrequency(1*time.Second), + tvmtool.WithBackgroundUpdate(ctx), + ) + require.NoError(t, err) + return tvmClient + }, + }, + } + + tester := func(name string, tc TestCase) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + var ( + testDstAlias = "test" + testDstID = tvm.ClientID(2002456) + testTicket = atomic.NewString("3:serv:original-test-ticket:signature") + testFooDstAlias = "test_foo" + testFooDstID = tvm.ClientID(2002457) + testFooTicket = atomic.NewString("3:serv:original-test-foo-ticket:signature") + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/tvm/tickets", r.URL.Path) + assert.Equal(t, "token", r.Header.Get("Authorization")) + switch r.URL.RawQuery { + case "dsts=test", "dsts=test_foo", "dsts=test%2Ctest_foo", "dsts=test_foo%2Ctest": + // ok + case "dsts=2002456", "dsts=2002457", "dsts=2002456%2C2002457", "dsts=2002457%2C2002456": + // ok + default: + t.Errorf("unknown tvm-request query: %q", r.URL.RawQuery) + } + + w.Header().Set("Content-Type", "application/json") + rsp := map[string]struct { + Ticket string `json:"ticket"` + TVMID tvm.ClientID `json:"tvm_id"` + }{ + testDstAlias: { + Ticket: testTicket.Load(), + TVMID: testDstID, + }, + testFooDstAlias: { + Ticket: testFooTicket.Load(), + TVMID: testFooDstID, + }, + } + + err := json.NewEncoder(w).Encode(rsp) + assert.NoError(t, err) + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tvmClient := tc.client(ctx, t, srv.URL) + + requestTickets := func(mustEquals bool) { + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), testDstAlias) + require.NoError(t, err) + if mustEquals { + require.Equal(t, testTicket.Load(), ticket) + } + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), testDstID) + require.NoError(t, err) + if mustEquals { + require.Equal(t, testTicket.Load(), ticket) + } + + ticket, err = tvmClient.GetServiceTicketForAlias(context.Background(), testFooDstAlias) + require.NoError(t, err) + if mustEquals { + require.Equal(t, testFooTicket.Load(), ticket) + } + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), testFooDstID) + require.NoError(t, err) + if mustEquals { + require.Equal(t, testFooTicket.Load(), ticket) + } + } + + // populate tickets cache + requestTickets(true) + + // now change tickets + newTicket := "3:serv:changed-test-ticket:signature" + testTicket.Store(newTicket) + testFooTicket.Store("3:serv:changed-test-foo-ticket:signature") + + // wait some time + time.Sleep(2 * time.Second) + + // request new tickets + requestTickets(false) + + // and wait updates some time + for idx := 0; idx < 250; idx++ { + time.Sleep(100 * time.Millisecond) + ticket, _ := tvmClient.GetServiceTicketForAlias(context.Background(), testDstAlias) + if ticket == newTicket { + break + } + } + + // now out tvmclient MUST returns new tickets + requestTickets(true) + }) + } + + for name, tc := range cases { + tester(name, tc) + } +} + +// TestClientBackgroundUpdate_NotTooOften checks that TVMTool client request tvmtool not too often +func TestClientBackgroundUpdate_NotTooOften(t *testing.T) { + type TestCase struct { + client func(ctx context.Context, t *testing.T, url string) *tvmtool.Client + } + cases := map[string]TestCase{ + "async": { + client: func(ctx context.Context, t *testing.T, url string) *tvmtool.Client { + tvmClient, err := newMockClient(url, tvmtool.WithRefreshFrequency(20*time.Second)) + require.NoError(t, err) + return tvmClient + }, + }, + "background": { + client: func(ctx context.Context, t *testing.T, url string) *tvmtool.Client { + tvmClient, err := newMockClient( + url, + tvmtool.WithRefreshFrequency(20*time.Second), + tvmtool.WithBackgroundUpdate(ctx), + ) + require.NoError(t, err) + return tvmClient + }, + }, + } + + tester := func(name string, tc TestCase) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + var ( + reqCount = atomic.NewUint32(0) + testDstAlias = "test" + testDstID = tvm.ClientID(2002456) + testTicket = "3:serv:original-test-ticket:signature" + testFooDstAlias = "test_foo" + testFooDstID = tvm.ClientID(2002457) + testFooTicket = "3:serv:original-test-foo-ticket:signature" + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqCount.Add(1) + assert.Equal(t, "/tvm/tickets", r.URL.Path) + assert.Equal(t, "token", r.Header.Get("Authorization")) + switch r.URL.RawQuery { + case "dsts=test", "dsts=test_foo", "dsts=test%2Ctest_foo", "dsts=test_foo%2Ctest": + // ok + case "dsts=2002456", "dsts=2002457", "dsts=2002456%2C2002457", "dsts=2002457%2C2002456": + // ok + default: + t.Errorf("unknown tvm-request query: %q", r.URL.RawQuery) + } + + w.Header().Set("Content-Type", "application/json") + rsp := map[string]struct { + Ticket string `json:"ticket"` + TVMID tvm.ClientID `json:"tvm_id"` + }{ + testDstAlias: { + Ticket: testTicket, + TVMID: testDstID, + }, + testFooDstAlias: { + Ticket: testFooTicket, + TVMID: testFooDstID, + }, + } + + err := json.NewEncoder(w).Encode(rsp) + assert.NoError(t, err) + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tvmClient := tc.client(ctx, t, srv.URL) + + requestTickets := func() { + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), testDstAlias) + require.NoError(t, err) + require.Equal(t, testTicket, ticket) + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), testDstID) + require.NoError(t, err) + require.Equal(t, testTicket, ticket) + + ticket, err = tvmClient.GetServiceTicketForAlias(context.Background(), testFooDstAlias) + require.NoError(t, err) + require.Equal(t, testFooTicket, ticket) + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), testFooDstID) + require.NoError(t, err) + require.Equal(t, testFooTicket, ticket) + } + + // populate cache + requestTickets() + + // requests tickets some time that lower than refresh frequency + for i := 0; i < 10; i++ { + requestTickets() + time.Sleep(200 * time.Millisecond) + } + + require.Equal(t, uint32(2), reqCount.Load(), "tvmtool client calls tvmtool too many times") + }) + } + + for name, tc := range cases { + tester(name, tc) + } +} + +func TestClient_RefreshFrequency(t *testing.T) { + cases := map[string]struct { + freq time.Duration + err bool + }{ + "too_high": { + freq: 20 * time.Minute, + err: true, + }, + "ok": { + freq: 2 * time.Minute, + err: false, + }, + } + + for name, cs := range cases { + t.Run(name, func(t *testing.T) { + _, err := tvmtool.NewClient("fake", tvmtool.WithRefreshFrequency(cs.freq)) + if cs.err { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestClient_MultipleAliases(t *testing.T) { + reqCount := atomic.NewUint32(0) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqCount.Add(1) + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ +"test": {"ticket": "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature","tvm_id": 2002456}, +"test_alias": {"ticket": "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature","tvm_id": 2002456} +}`)) + })) + defer srv.Close() + + bgCtx, bgCancel := context.WithCancel(context.Background()) + defer bgCancel() + + tvmClient, err := newMockClient( + srv.URL, + tvmtool.WithRefreshFrequency(2*time.Second), + tvmtool.WithBackgroundUpdate(bgCtx), + ) + require.NoError(t, err) + + requestTickets := func(t *testing.T) { + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "test") + require.NoError(t, err) + require.Equal(t, "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature", ticket) + + ticket, err = tvmClient.GetServiceTicketForAlias(context.Background(), "test_alias") + require.NoError(t, err) + require.Equal(t, "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature", ticket) + + ticket, err = tvmClient.GetServiceTicketForID(context.Background(), tvm.ClientID(2002456)) + require.NoError(t, err) + require.Equal(t, "3:serv:CNVRELOq1O0FIggIwON6EJiceg:signature", ticket) + } + + t.Run("first", requestTickets) + + t.Run("check_requests", func(t *testing.T) { + // reqCount must be 2 - one for each aliases + require.Equal(t, uint32(2), reqCount.Load()) + }) + + // now wait GC + reqCount.Store(0) + time.Sleep(3 * time.Second) + + t.Run("after_gc", requestTickets) + t.Run("check_requests", func(t *testing.T) { + // reqCount must be 1 + require.Equal(t, uint32(1), reqCount.Load()) + }) +} diff --git a/library/go/yandex/tvm/tvmtool/tool_example_test.go b/library/go/yandex/tvm/tvmtool/tool_example_test.go new file mode 100644 index 0000000000..f3e482de91 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool_example_test.go @@ -0,0 +1,81 @@ +package tvmtool_test + +import ( + "context" + "fmt" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +func ExampleNewClient() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + tvmClient, err := tvmtool.NewClient( + "http://localhost:9000", + tvmtool.WithAuthToken("auth-token"), + tvmtool.WithSrc("my-cool-app"), + tvmtool.WithLogger(zlog), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} + +func ExampleNewClient_backgroundServiceTicketsUpdate() { + zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel)) + if err != nil { + panic(err) + } + + bgCtx, bgCancel := context.WithCancel(context.Background()) + defer bgCancel() + + tvmClient, err := tvmtool.NewClient( + "http://localhost:9000", + tvmtool.WithAuthToken("auth-token"), + tvmtool.WithSrc("my-cool-app"), + tvmtool.WithLogger(zlog), + tvmtool.WithBackgroundUpdate(bgCtx), + ) + if err != nil { + panic(err) + } + + ticket, err := tvmClient.GetServiceTicketForAlias(context.Background(), "black-box") + if err != nil { + retryable := false + if tvmErr, ok := err.(*tvm.Error); ok { + retryable = tvmErr.Retriable + } + + zlog.Fatal( + "failed to get service ticket", + log.String("alias", "black-box"), + log.Error(err), + log.Bool("retryable", retryable), + ) + } + fmt.Printf("ticket: %s\n", ticket) +} diff --git a/library/go/yandex/tvm/tvmtool/tool_export_test.go b/library/go/yandex/tvm/tvmtool/tool_export_test.go new file mode 100644 index 0000000000..7981a2db72 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool_export_test.go @@ -0,0 +1,9 @@ +package tvmtool + +func (c *Client) BaseURI() string { + return c.baseURI +} + +func (c *Client) AuthToken() string { + return c.authToken +} diff --git a/library/go/yandex/tvm/tvmtool/tool_test.go b/library/go/yandex/tvm/tvmtool/tool_test.go new file mode 100644 index 0000000000..4329e1d101 --- /dev/null +++ b/library/go/yandex/tvm/tvmtool/tool_test.go @@ -0,0 +1,255 @@ +//go:build linux || darwin +// +build linux darwin + +// tvmtool recipe exists only for linux & darwin so we skip another OSes +package tvmtool_test + +import ( + "context" + "fmt" + "io/ioutil" + "regexp" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "a.yandex-team.ru/library/go/core/log" + "a.yandex-team.ru/library/go/core/log/zap" + "a.yandex-team.ru/library/go/yandex/tvm" + "a.yandex-team.ru/library/go/yandex/tvm/tvmtool" +) + +const ( + tvmToolPortFile = "tvmtool.port" + tvmToolAuthTokenFile = "tvmtool.authtoken" + userTicketFor1120000000038691 = "3:user" + + ":CA4Q__________9_GjUKCQijrpqRpdT-ARCjrpqRpdT-ARoMYmI6c2Vzc2lvbmlkGgl0ZXN0OnRlc3Qg0oXY" + + "zAQoAw:A-YI2yhoD7BbGU80_dKQ6vm7XADdvgD2QUFCeTI3XZ4MS4N8iENvsNDvYwsW89-vLQPv9pYqn8jxx" + + "awkvu_ZS2aAfpU8vXtnEHvzUQfes2kMjweRJE71cyX8B0VjENdXC5QAfGyK7Y0b4elTDJzw8b28Ro7IFFbNe" + + "qgcPInXndY" + serviceTicketFor41_42 = "3:serv:CBAQ__________9_IgQIKRAq" + + ":VVXL3wkhpBHB7OXSeG0IhqM5AP2CP-gJRD31ksAb-q7pmssBJKtPNbH34BSyLpBllmM1dgOfwL8ICUOGUA3l" + + "jOrwuxZ9H8ayfdrpM7q1-BVPE0sh0L9cd8lwZIW6yHejTe59s6wk1tG5MdSfncdaJpYiF3MwNHSRklNAkb6hx" + + "vg" + serviceTicketFor41_99 = "3:serv:CBAQ__________9_IgQIKRBj" + + ":PjJKDOsEk8VyxZFZwsVnKrW1bRyA82nGd0oIxnEFEf7DBTVZmNuxEejncDrMxnjkKwimrumV9POK4ptTo0ZPY" + + "6Du9zHR5QxekZYwDzFkECVrv9YT2QI03odwZJX8_WCpmlgI8hUog_9yZ5YCYxrQpWaOwDXx4T7VVMwH_Z9YTZk" +) + +var ( + srvTicketRe = regexp.MustCompile(`^3:serv:[A-Za-z0-9_\-]+:[A-Za-z0-9_\-]+$`) +) + +func newTvmToolClient(src string, authToken ...string) (*tvmtool.Client, error) { + raw, err := ioutil.ReadFile(tvmToolPortFile) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(string(raw)) + if err != nil { + return nil, err + } + + var auth string + if len(authToken) > 0 { + auth = authToken[0] + } else { + raw, err = ioutil.ReadFile(tvmToolAuthTokenFile) + if err != nil { + return nil, err + } + auth = string(raw) + } + + zlog, _ := zap.New(zap.ConsoleConfig(log.DebugLevel)) + + return tvmtool.NewClient( + fmt.Sprintf("http://localhost:%d", port), + tvmtool.WithAuthToken(auth), + tvmtool.WithCacheEnabled(false), + tvmtool.WithSrc(src), + tvmtool.WithLogger(zlog), + ) +} + +func TestNewClient(t *testing.T) { + client, err := newTvmToolClient("main") + require.NoError(t, err) + require.NotNil(t, client) +} + +func TestClient_GetStatus(t *testing.T) { + client, err := newTvmToolClient("main") + require.NoError(t, err) + status, err := client.GetStatus(context.Background()) + require.NoError(t, err, "ping must work") + require.Equal(t, tvm.ClientOK, status.Status) +} + +func TestClient_BadAuth(t *testing.T) { + badClient, err := newTvmToolClient("main", "fake-auth") + require.NoError(t, err) + + _, err = badClient.GetServiceTicketForAlias(context.Background(), "lala") + require.Error(t, err) + require.IsType(t, err, &tvmtool.Error{}) + srvTickerErr := err.(*tvmtool.Error) + require.Equal(t, tvmtool.ErrorAuthFail, srvTickerErr.Code) +} + +func TestClient_GetServiceTicket(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + ctx := context.Background() + + t.Run("invalid_alias", func(t *testing.T) { + // Ticket for invalid alias must fails + t.Parallel() + _, err := tvmClient.GetServiceTicketForAlias(ctx, "not_exists") + require.Error(t, err, "ticket for invalid alias must fails") + assert.IsType(t, err, &tvmtool.Error{}, "must return tvm err") + assert.EqualError(t, err, "tvm: can't find in config destination tvmid for src = 42, dstparam = not_exists (strconv) (code ErrorBadRequest)") + }) + + t.Run("invalid_dst_id", func(t *testing.T) { + // Ticket for invalid client id must fails + t.Parallel() + _, err := tvmClient.GetServiceTicketForID(ctx, 123123123) + require.Error(t, err, "ticket for invalid ID must fails") + assert.IsType(t, err, &tvmtool.Error{}, "must return tvm err") + assert.EqualError(t, err, "tvm: can't find in config destination tvmid for src = 42, dstparam = 123123123 (by number) (code ErrorBadRequest)") + }) + + t.Run("by_alias", func(t *testing.T) { + // Try to get ticket by alias + t.Parallel() + heTicketByAlias, err := tvmClient.GetServiceTicketForAlias(ctx, "he") + if assert.NoError(t, err, "failed to get srv ticket to 'he'") { + assert.Regexp(t, srvTicketRe, heTicketByAlias, "invalid 'he' srv ticket") + } + + heCloneTicketAlias, err := tvmClient.GetServiceTicketForAlias(ctx, "he_clone") + if assert.NoError(t, err, "failed to get srv ticket to 'he_clone'") { + assert.Regexp(t, srvTicketRe, heCloneTicketAlias, "invalid 'he_clone' srv ticket") + } + }) + + t.Run("by_dst_id", func(t *testing.T) { + // Try to get ticket by id + t.Parallel() + heTicketByID, err := tvmClient.GetServiceTicketForID(ctx, 100500) + if assert.NoError(t, err, "failed to get srv ticket to '100500'") { + assert.Regexp(t, srvTicketRe, heTicketByID, "invalid '100500' srv ticket") + } + }) +} + +func TestClient_CheckServiceTicket(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + ctx := context.Background() + t.Run("self_to_self", func(t *testing.T) { + t.Parallel() + + // Check from self to self + selfTicket, err := tvmClient.GetServiceTicketForAlias(ctx, "self") + require.NoError(t, err, "failed to get service ticket to 'self'") + assert.Regexp(t, srvTicketRe, selfTicket, "invalid 'self' srv ticket") + + // Now we can check srv ticket + ticketInfo, err := tvmClient.CheckServiceTicket(ctx, selfTicket) + require.NoError(t, err, "failed to check srv ticket main -> self") + + assert.Equal(t, tvm.ClientID(42), ticketInfo.SrcID) + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) + }) + + t.Run("to_another", func(t *testing.T) { + t.Parallel() + + // Check from another client (41) to self + ticketInfo, err := tvmClient.CheckServiceTicket(ctx, serviceTicketFor41_42) + require.NoError(t, err, "failed to check srv ticket 41 -> 42") + + assert.Equal(t, tvm.ClientID(41), ticketInfo.SrcID) + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) + }) + + t.Run("invalid_dst", func(t *testing.T) { + t.Parallel() + + // Check from another client (41) to invalid dst (99) + ticketInfo, err := tvmClient.CheckServiceTicket(ctx, serviceTicketFor41_99) + require.Error(t, err, "srv ticket for 41 -> 99 must fails") + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) + + ticketErr := err.(*tvmtool.TicketError) + require.IsType(t, err, &tvmtool.TicketError{}) + assert.Equal(t, tvmtool.TicketErrorOther, ticketErr.Status) + assert.Equal(t, "Wrong ticket dst, expected 42, got 99", ticketErr.Msg) + }) + + t.Run("broken", func(t *testing.T) { + t.Parallel() + + // Check with broken sign + _, err := tvmClient.CheckServiceTicket(ctx, "lalala") + require.Error(t, err, "srv ticket with broken sign must fails") + ticketErr := err.(*tvmtool.TicketError) + require.IsType(t, err, &tvmtool.TicketError{}) + assert.Equal(t, tvmtool.TicketErrorOther, ticketErr.Status) + assert.Equal(t, "invalid ticket format", ticketErr.Msg) + }) +} + +func TestClient_MultipleClients(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + slaveClient, err := newTvmToolClient("slave") + require.NoError(t, err) + + ctx := context.Background() + + ticket, err := tvmClient.GetServiceTicketForAlias(ctx, "slave") + require.NoError(t, err, "failed to get service ticket to 'slave'") + assert.Regexp(t, srvTicketRe, ticket, "invalid 'slave' srv ticket") + + ticketInfo, err := slaveClient.CheckServiceTicket(ctx, ticket) + require.NoError(t, err, "failed to check srv ticket main -> self") + + assert.Equal(t, tvm.ClientID(42), ticketInfo.SrcID) + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) +} + +func TestClient_CheckUserTicket(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + ticketInfo, err := tvmClient.CheckUserTicket(context.Background(), userTicketFor1120000000038691) + require.NoError(t, err, "failed to check user ticket") + + assert.Equal(t, tvm.UID(1120000000038691), ticketInfo.DefaultUID) + assert.Subset(t, []tvm.UID{1120000000038691}, ticketInfo.UIDs) + assert.Subset(t, []string{"bb:sessionid", "test:test"}, ticketInfo.Scopes) + assert.NotEmpty(t, ticketInfo.LogInfo) + assert.NotEmpty(t, ticketInfo.DbgInfo) +} + +func TestClient_Version(t *testing.T) { + tvmClient, err := newTvmToolClient("main") + require.NoError(t, err) + + version, err := tvmClient.Version(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, version) +} diff --git a/library/go/yandex/tvm/user_ticket.go b/library/go/yandex/tvm/user_ticket.go new file mode 100644 index 0000000000..e68e5e5032 --- /dev/null +++ b/library/go/yandex/tvm/user_ticket.go @@ -0,0 +1,122 @@ +package tvm + +import ( + "fmt" +) + +// CheckedUserTicket is short-lived user credential. +// +// CheckedUserTicket contains only valid users. +// Details: https://wiki.yandex-team.ru/passport/tvm2/user-ticket/#chtoestvusertickete +type CheckedUserTicket struct { + // DefaultUID is default user - maybe 0 + DefaultUID UID + // UIDs is array of valid users - never empty + UIDs []UID + // Env is blackbox environment which created this UserTicket - provides only tvmauth now + Env BlackboxEnv + // Scopes is array of scopes inherited from credential - never empty + Scopes []string + // DbgInfo is human readable data for debug purposes + DbgInfo string + // LogInfo is safe for logging part of ticket - it can be parsed later with `tvmknife parse_ticket -t ...` + LogInfo string +} + +func (t CheckedUserTicket) String() string { + return fmt.Sprintf("%s (%s)", t.LogInfo, t.DbgInfo) +} + +// CheckScopes verify that ALL needed scopes presents in the user ticket +func (t *CheckedUserTicket) CheckScopes(scopes ...string) error { + switch { + case len(scopes) == 0: + // ok, no scopes. no checks. no rules + return nil + case len(t.Scopes) == 0: + msg := fmt.Sprintf("user ticket doesn't contain expected scopes: %s (actual: nil)", scopes) + return &TicketError{Status: TicketInvalidScopes, Msg: msg} + default: + actualScopes := make(map[string]struct{}, len(t.Scopes)) + for _, s := range t.Scopes { + actualScopes[s] = struct{}{} + } + + for _, s := range scopes { + if _, found := actualScopes[s]; !found { + // exit on first nonexistent scope + msg := fmt.Sprintf( + "user ticket doesn't contain one of expected scopes: %s (actual: %s)", + scopes, t.Scopes, + ) + + return &TicketError{Status: TicketInvalidScopes, Msg: msg} + } + } + + return nil + } +} + +// CheckScopesAny verify that ANY of needed scopes presents in the user ticket +func (t *CheckedUserTicket) CheckScopesAny(scopes ...string) error { + switch { + case len(scopes) == 0: + // ok, no scopes. no checks. no rules + return nil + case len(t.Scopes) == 0: + msg := fmt.Sprintf("user ticket doesn't contain any of expected scopes: %s (actual: nil)", scopes) + return &TicketError{Status: TicketInvalidScopes, Msg: msg} + default: + actualScopes := make(map[string]struct{}, len(t.Scopes)) + for _, s := range t.Scopes { + actualScopes[s] = struct{}{} + } + + for _, s := range scopes { + if _, found := actualScopes[s]; found { + // exit on first valid scope + return nil + } + } + + msg := fmt.Sprintf( + "user ticket doesn't contain any of expected scopes: %s (actual: %s)", + scopes, t.Scopes, + ) + + return &TicketError{Status: TicketInvalidScopes, Msg: msg} + } +} + +type CheckUserTicketOptions struct { + EnvOverride *BlackboxEnv +} + +type CheckUserTicketOption func(*CheckUserTicketOptions) + +func WithBlackboxOverride(env BlackboxEnv) CheckUserTicketOption { + return func(opts *CheckUserTicketOptions) { + opts.EnvOverride = &env + } +} + +type UserTicketACL func(ticket *CheckedUserTicket) error + +func AllowAllUserTickets() UserTicketACL { + return func(ticket *CheckedUserTicket) error { + return nil + } +} + +func CheckAllUserTicketScopesPresent(scopes []string) UserTicketACL { + return func(ticket *CheckedUserTicket) error { + return ticket.CheckScopes(scopes...) + } +} + +func CheckAnyUserTicketScopesPresent(scopes []string) UserTicketACL { + return func(ticket *CheckedUserTicket) error { + return ticket.CheckScopesAny(scopes...) + } +} |