aboutsummaryrefslogtreecommitdiffstats
path: root/library/go/yandex/tvm/cachedtvm
diff options
context:
space:
mode:
authorqrort <qrort@yandex-team.com>2022-11-30 23:47:12 +0300
committerqrort <qrort@yandex-team.com>2022-11-30 23:47:12 +0300
commit22f8ae0e3f5d68b92aecccdf96c1d841a0334311 (patch)
treebffa27765faf54126ad44bcafa89fadecb7a73d7 /library/go/yandex/tvm/cachedtvm
parent332b99e2173f0425444abb759eebcb2fafaa9209 (diff)
downloadydb-22f8ae0e3f5d68b92aecccdf96c1d841a0334311.tar.gz
validate canons without yatest_common
Diffstat (limited to 'library/go/yandex/tvm/cachedtvm')
-rw-r--r--library/go/yandex/tvm/cachedtvm/cache.go22
-rw-r--r--library/go/yandex/tvm/cachedtvm/client.go117
-rw-r--r--library/go/yandex/tvm/cachedtvm/client_example_test.go40
-rw-r--r--library/go/yandex/tvm/cachedtvm/client_test.go195
-rw-r--r--library/go/yandex/tvm/cachedtvm/opts.go40
5 files changed, 414 insertions, 0 deletions
diff --git a/library/go/yandex/tvm/cachedtvm/cache.go b/library/go/yandex/tvm/cachedtvm/cache.go
new file mode 100644
index 0000000000..a04e2baf8a
--- /dev/null
+++ b/library/go/yandex/tvm/cachedtvm/cache.go
@@ -0,0 +1,22 @@
+package cachedtvm
+
+import (
+ "time"
+
+ "github.com/karlseguin/ccache/v2"
+)
+
+type cache struct {
+ *ccache.Cache
+ ttl time.Duration
+}
+
+func (c *cache) Fetch(key string, fn func() (interface{}, error)) (*ccache.Item, error) {
+ return c.Cache.Fetch(key, c.ttl, fn)
+}
+
+func (c *cache) Stop() {
+ if c.Cache != nil {
+ c.Cache.Stop()
+ }
+}
diff --git a/library/go/yandex/tvm/cachedtvm/client.go b/library/go/yandex/tvm/cachedtvm/client.go
new file mode 100644
index 0000000000..503c973e8c
--- /dev/null
+++ b/library/go/yandex/tvm/cachedtvm/client.go
@@ -0,0 +1,117 @@
+package cachedtvm
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/karlseguin/ccache/v2"
+
+ "a.yandex-team.ru/library/go/yandex/tvm"
+)
+
+const (
+ DefaultTTL = 1 * time.Minute
+ DefaultMaxItems = 100
+ MaxServiceTicketTTL = 5 * time.Minute
+ MaxUserTicketTTL = 1 * time.Minute
+)
+
+type CachedClient struct {
+ tvm.Client
+ serviceTicketCache cache
+ userTicketCache cache
+ userTicketFn func(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error)
+}
+
+func NewClient(tvmClient tvm.Client, opts ...Option) (*CachedClient, error) {
+ newCache := func(o cacheOptions) cache {
+ return cache{
+ Cache: ccache.New(
+ ccache.Configure().MaxSize(o.maxItems),
+ ),
+ ttl: o.ttl,
+ }
+ }
+
+ out := &CachedClient{
+ Client: tvmClient,
+ serviceTicketCache: newCache(cacheOptions{
+ ttl: DefaultTTL,
+ maxItems: DefaultMaxItems,
+ }),
+ userTicketFn: tvmClient.CheckUserTicket,
+ }
+
+ for _, opt := range opts {
+ switch o := opt.(type) {
+ case OptionServiceTicket:
+ if o.ttl > MaxServiceTicketTTL {
+ return nil, fmt.Errorf("maximum TTL for check service ticket exceed: %s > %s", o.ttl, MaxServiceTicketTTL)
+ }
+
+ out.serviceTicketCache = newCache(o.cacheOptions)
+ case OptionUserTicket:
+ if o.ttl > MaxUserTicketTTL {
+ return nil, fmt.Errorf("maximum TTL for check user ticket exceed: %s > %s", o.ttl, MaxUserTicketTTL)
+ }
+
+ out.userTicketFn = out.cacheCheckUserTicket
+ out.userTicketCache = newCache(o.cacheOptions)
+ default:
+ panic(fmt.Sprintf("unexpected cache option: %T", o))
+ }
+ }
+
+ return out, nil
+}
+
+func (c *CachedClient) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) {
+ out, err := c.serviceTicketCache.Fetch(ticket, func() (interface{}, error) {
+ return c.Client.CheckServiceTicket(ctx, ticket)
+ })
+
+ if err != nil {
+ return nil, err
+ }
+
+ return out.Value().(*tvm.CheckedServiceTicket), nil
+}
+
+func (c *CachedClient) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) {
+ return c.userTicketFn(ctx, ticket, opts...)
+}
+
+func (c *CachedClient) cacheCheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) {
+ cacheKey := func(ticket string, opts ...tvm.CheckUserTicketOption) string {
+ if len(opts) == 0 {
+ return ticket
+ }
+
+ var options tvm.CheckUserTicketOptions
+ for _, opt := range opts {
+ opt(&options)
+ }
+
+ if options.EnvOverride == nil {
+ return ticket
+ }
+
+ return fmt.Sprintf("%d:%s", *options.EnvOverride, ticket)
+ }
+
+ out, err := c.userTicketCache.Fetch(cacheKey(ticket, opts...), func() (interface{}, error) {
+ return c.Client.CheckUserTicket(ctx, ticket, opts...)
+ })
+
+ if err != nil {
+ return nil, err
+ }
+
+ return out.Value().(*tvm.CheckedUserTicket), nil
+}
+
+func (c *CachedClient) Close() {
+ c.serviceTicketCache.Stop()
+ c.userTicketCache.Stop()
+}
diff --git a/library/go/yandex/tvm/cachedtvm/client_example_test.go b/library/go/yandex/tvm/cachedtvm/client_example_test.go
new file mode 100644
index 0000000000..a95b1674a3
--- /dev/null
+++ b/library/go/yandex/tvm/cachedtvm/client_example_test.go
@@ -0,0 +1,40 @@
+package cachedtvm_test
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "a.yandex-team.ru/library/go/core/log"
+ "a.yandex-team.ru/library/go/core/log/zap"
+ "a.yandex-team.ru/library/go/yandex/tvm/cachedtvm"
+ "a.yandex-team.ru/library/go/yandex/tvm/tvmtool"
+)
+
+func ExampleNewClient_checkServiceTicket() {
+ zlog, err := zap.New(zap.ConsoleConfig(log.DebugLevel))
+ if err != nil {
+ panic(err)
+ }
+
+ tvmClient, err := tvmtool.NewAnyClient(tvmtool.WithLogger(zlog))
+ if err != nil {
+ panic(err)
+ }
+
+ cachedTvmClient, err := cachedtvm.NewClient(
+ tvmClient,
+ cachedtvm.WithCheckServiceTicket(1*time.Minute, 1000),
+ )
+ if err != nil {
+ panic(err)
+ }
+ defer cachedTvmClient.Close()
+
+ ticketInfo, err := cachedTvmClient.CheckServiceTicket(context.TODO(), "3:serv:....")
+ if err != nil {
+ panic(err)
+ }
+
+ fmt.Println("ticket info: ", ticketInfo.LogInfo)
+}
diff --git a/library/go/yandex/tvm/cachedtvm/client_test.go b/library/go/yandex/tvm/cachedtvm/client_test.go
new file mode 100644
index 0000000000..a3c3081e30
--- /dev/null
+++ b/library/go/yandex/tvm/cachedtvm/client_test.go
@@ -0,0 +1,195 @@
+package cachedtvm_test
+
+import (
+ "context"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "a.yandex-team.ru/library/go/yandex/tvm"
+ "a.yandex-team.ru/library/go/yandex/tvm/cachedtvm"
+)
+
+const (
+ checkPasses = 5
+)
+
+type mockTvmClient struct {
+ tvm.Client
+ checkServiceTicketCalls int
+ checkUserTicketCalls int
+}
+
+func (c *mockTvmClient) CheckServiceTicket(_ context.Context, ticket string) (*tvm.CheckedServiceTicket, error) {
+ defer func() { c.checkServiceTicketCalls++ }()
+
+ return &tvm.CheckedServiceTicket{
+ LogInfo: ticket,
+ IssuerUID: tvm.UID(c.checkServiceTicketCalls),
+ }, nil
+}
+
+func (c *mockTvmClient) CheckUserTicket(_ context.Context, ticket string, _ ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) {
+ defer func() { c.checkUserTicketCalls++ }()
+
+ return &tvm.CheckedUserTicket{
+ LogInfo: ticket,
+ DefaultUID: tvm.UID(c.checkUserTicketCalls),
+ }, nil
+}
+
+func (c *mockTvmClient) GetServiceTicketForAlias(_ context.Context, alias string) (string, error) {
+ return alias, nil
+}
+
+func checkServiceTickets(t *testing.T, client tvm.Client, equal bool) {
+ var prev *tvm.CheckedServiceTicket
+ for i := 0; i < checkPasses; i++ {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ cur, err := client.CheckServiceTicket(context.Background(), "3:serv:tst")
+ require.NoError(t, err)
+
+ if prev == nil {
+ return
+ }
+
+ if equal {
+ require.Equal(t, *prev, *cur)
+ } else {
+ require.NotEqual(t, *prev, *cur)
+ }
+ })
+ }
+}
+
+func runEqualServiceTickets(client tvm.Client) func(t *testing.T) {
+ return func(t *testing.T) {
+ checkServiceTickets(t, client, true)
+ }
+}
+
+func runNotEqualServiceTickets(client tvm.Client) func(t *testing.T) {
+ return func(t *testing.T) {
+ checkServiceTickets(t, client, false)
+ }
+}
+
+func checkUserTickets(t *testing.T, client tvm.Client, equal bool) {
+ var prev *tvm.CheckedServiceTicket
+ for i := 0; i < checkPasses; i++ {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ cur, err := client.CheckUserTicket(context.Background(), "3:user:tst")
+ require.NoError(t, err)
+
+ if prev == nil {
+ return
+ }
+
+ if equal {
+ require.Equal(t, *prev, *cur)
+ } else {
+ require.NotEqual(t, *prev, *cur)
+ }
+ })
+ }
+}
+
+func runEqualUserTickets(client tvm.Client) func(t *testing.T) {
+ return func(t *testing.T) {
+ checkUserTickets(t, client, true)
+ }
+}
+
+func runNotEqualUserTickets(client tvm.Client) func(t *testing.T) {
+ return func(t *testing.T) {
+ checkUserTickets(t, client, false)
+ }
+}
+func TestDefaultBehavior(t *testing.T) {
+ nestedClient := &mockTvmClient{}
+ client, err := cachedtvm.NewClient(nestedClient)
+ require.NoError(t, err)
+
+ t.Run("first_pass_srv", runEqualServiceTickets(client))
+ t.Run("first_pass_usr", runNotEqualUserTickets(client))
+
+ require.Equal(t, 1, nestedClient.checkServiceTicketCalls)
+ require.Equal(t, checkPasses, nestedClient.checkUserTicketCalls)
+
+ ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst")
+ require.NoError(t, err)
+ require.Equal(t, "tst", ticket)
+}
+
+func TestCheckServiceTicket(t *testing.T) {
+ nestedClient := &mockTvmClient{}
+ client, err := cachedtvm.NewClient(nestedClient, cachedtvm.WithCheckServiceTicket(10*time.Second, 10))
+ require.NoError(t, err)
+
+ t.Run("first_pass_srv", runEqualServiceTickets(client))
+ t.Run("first_pass_usr", runNotEqualUserTickets(client))
+ time.Sleep(20 * time.Second)
+ t.Run("second_pass_srv", runEqualServiceTickets(client))
+ t.Run("second_pass_usr", runNotEqualUserTickets(client))
+
+ require.Equal(t, 2, nestedClient.checkServiceTicketCalls)
+ require.Equal(t, 2*checkPasses, nestedClient.checkUserTicketCalls)
+
+ ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst")
+ require.NoError(t, err)
+ require.Equal(t, "tst", ticket)
+}
+
+func TestCheckUserTicket(t *testing.T) {
+ nestedClient := &mockTvmClient{}
+ client, err := cachedtvm.NewClient(nestedClient, cachedtvm.WithCheckUserTicket(10*time.Second, 10))
+ require.NoError(t, err)
+
+ t.Run("first_pass_usr", runEqualUserTickets(client))
+ time.Sleep(20 * time.Second)
+ t.Run("second_pass_usr", runEqualUserTickets(client))
+ require.Equal(t, 2, nestedClient.checkUserTicketCalls)
+
+ ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst")
+ require.NoError(t, err)
+ require.Equal(t, "tst", ticket)
+}
+
+func TestCheckServiceAndUserTicket(t *testing.T) {
+ nestedClient := &mockTvmClient{}
+ client, err := cachedtvm.NewClient(nestedClient,
+ cachedtvm.WithCheckServiceTicket(10*time.Second, 10),
+ cachedtvm.WithCheckUserTicket(10*time.Second, 10),
+ )
+ require.NoError(t, err)
+
+ t.Run("first_pass_srv", runEqualServiceTickets(client))
+ t.Run("first_pass_usr", runEqualUserTickets(client))
+ time.Sleep(20 * time.Second)
+ t.Run("second_pass_srv", runEqualServiceTickets(client))
+ t.Run("second_pass_usr", runEqualUserTickets(client))
+
+ require.Equal(t, 2, nestedClient.checkUserTicketCalls)
+ require.Equal(t, 2, nestedClient.checkServiceTicketCalls)
+
+ ticket, err := client.GetServiceTicketForAlias(context.Background(), "tst")
+ require.NoError(t, err)
+ require.Equal(t, "tst", ticket)
+}
+
+func TestErrors(t *testing.T) {
+ cases := []cachedtvm.Option{
+ cachedtvm.WithCheckServiceTicket(12*time.Hour, 1),
+ cachedtvm.WithCheckUserTicket(30*time.Minute, 1),
+ }
+
+ nestedClient := &mockTvmClient{}
+ for i, tc := range cases {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ _, err := cachedtvm.NewClient(nestedClient, tc)
+ require.Error(t, err)
+ })
+ }
+}
diff --git a/library/go/yandex/tvm/cachedtvm/opts.go b/library/go/yandex/tvm/cachedtvm/opts.go
new file mode 100644
index 0000000000..0df9dfa89e
--- /dev/null
+++ b/library/go/yandex/tvm/cachedtvm/opts.go
@@ -0,0 +1,40 @@
+package cachedtvm
+
+import "time"
+
+type (
+ Option interface{ isCachedOption() }
+
+ cacheOptions struct {
+ ttl time.Duration
+ maxItems int64
+ }
+
+ OptionServiceTicket struct {
+ Option
+ cacheOptions
+ }
+
+ OptionUserTicket struct {
+ Option
+ cacheOptions
+ }
+)
+
+func WithCheckServiceTicket(ttl time.Duration, maxSize int) Option {
+ return OptionServiceTicket{
+ cacheOptions: cacheOptions{
+ ttl: ttl,
+ maxItems: int64(maxSize),
+ },
+ }
+}
+
+func WithCheckUserTicket(ttl time.Duration, maxSize int) Option {
+ return OptionUserTicket{
+ cacheOptions: cacheOptions{
+ ttl: ttl,
+ maxItems: int64(maxSize),
+ },
+ }
+}