aboutsummaryrefslogtreecommitdiffstats
path: root/library/go/yandex/tvm/cachedtvm/client.go
blob: ed7d51d8d155a68919a333ecc3e25b4b31ca4349 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package cachedtvm

import (
	"context"
	"fmt"
	"time"

	"github.com/karlseguin/ccache/v2"
	"github.com/ydb-platform/ydb/library/go/yandex/tvm"
)

// Cache defaults and the upper TTL bounds enforced by NewClient.
const (
	// DefaultTTL is the TTL NewClient gives the service-ticket cache when no
	// OptionServiceTicket is supplied.
	DefaultTTL = time.Minute

	// DefaultMaxItems is the size limit NewClient gives the service-ticket
	// cache when no OptionServiceTicket is supplied.
	DefaultMaxItems = 100

	// MaxServiceTicketTTL is the largest TTL NewClient accepts in an
	// OptionServiceTicket.
	MaxServiceTicketTTL = 5 * time.Minute

	// MaxUserTicketTTL is the largest TTL NewClient accepts in an
	// OptionUserTicket.
	MaxUserTicketTTL = time.Minute
)

// CachedClient is a tvm.Client that caches the results of ticket checks.
//
// Methods not defined on CachedClient are served by the embedded tvm.Client.
// serviceTicketCache backs CheckServiceTicket and is always built by
// NewClient. userTicketCache is built only when NewClient receives an
// OptionUserTicket and is otherwise left zero-valued. userTicketFn is what
// CheckUserTicket actually calls: the wrapped client's CheckUserTicket, or
// cacheCheckUserTicket when user-ticket caching is enabled.
//
// Construct it with NewClient and release it with Close.
type CachedClient struct {
	tvm.Client
	serviceTicketCache cache
	userTicketCache    cache
	userTicketFn       func(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error)
}

// NewClient wraps tvmClient with caching of ticket checks.
//
// Service-ticket checks are always cached. The cache uses DefaultTTL and
// DefaultMaxItems unless an OptionServiceTicket supplies other values.
// User-ticket checks are cached only when an OptionUserTicket is supplied;
// otherwise CheckUserTicket goes straight to tvmClient. When the same kind of
// option is given more than once, the last one wins.
//
// NewClient returns an error when an OptionServiceTicket's TTL exceeds
// MaxServiceTicketTTL or an OptionUserTicket's TTL exceeds MaxUserTicketTTL.
// It panics on an Option type it does not recognise. All options are
// validated before any cache is built, so neither an error nor a panic leaves
// a half-built cache behind. Call Close on the returned client when done.
func NewClient(tvmClient tvm.Client, opts ...Option) (*CachedClient, error) {
	newCache := func(o cacheOptions) cache {
		return cache{
			Cache: ccache.New(
				ccache.Configure().MaxSize(o.maxItems),
			),
			ttl: o.ttl,
		}
	}

	// Resolve and validate every option first and only then build the caches.
	// ccache.New starts a background worker goroutine. Building a cache that a
	// later option replaces, or that is abandoned on an error return, would
	// leak that goroutine.
	serviceOpts := cacheOptions{
		ttl:      DefaultTTL,
		maxItems: DefaultMaxItems,
	}
	var userOpts cacheOptions
	cacheUserTickets := false

	for _, opt := range opts {
		switch o := opt.(type) {
		case OptionServiceTicket:
			if o.ttl > MaxServiceTicketTTL {
				return nil, fmt.Errorf("maximum TTL for check service ticket exceed: %s > %s", o.ttl, MaxServiceTicketTTL)
			}

			serviceOpts = o.cacheOptions
		case OptionUserTicket:
			if o.ttl > MaxUserTicketTTL {
				return nil, fmt.Errorf("maximum TTL for check user ticket exceed: %s > %s", o.ttl, MaxUserTicketTTL)
			}

			userOpts = o.cacheOptions
			cacheUserTickets = true
		default:
			panic(fmt.Sprintf("unexpected cache option: %T", o))
		}
	}

	out := &CachedClient{
		Client:             tvmClient,
		serviceTicketCache: newCache(serviceOpts),
		userTicketFn:       tvmClient.CheckUserTicket,
	}

	// Without OptionUserTicket, userTicketCache stays zero-valued and user
	// tickets are checked directly by tvmClient.
	if cacheUserTickets {
		out.userTicketCache = newCache(userOpts)
		out.userTicketFn = out.cacheCheckUserTicket
	}

	return out, nil
}

// CheckServiceTicket checks a service ticket through serviceTicketCache, keyed
// by the raw ticket string. On a cache miss, the wrapped client's
// CheckServiceTicket is called with the ctx of the current call. On error it
// returns nil and the error.
func (c *CachedClient) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) {
	loadTicket := func() (interface{}, error) {
		return c.Client.CheckServiceTicket(ctx, ticket)
	}

	item, err := c.serviceTicketCache.Fetch(ticket, loadTicket)
	if err != nil {
		return nil, err
	}

	checked := item.Value().(*tvm.CheckedServiceTicket)
	return checked, nil
}

// CheckUserTicket checks a user ticket by dispatching to userTicketFn. NewClient
// sets that function to the wrapped client's CheckUserTicket (uncached) or, when
// an OptionUserTicket was supplied, to the caching implementation.
func (c *CachedClient) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) {
	check := c.userTicketFn
	return check(ctx, ticket, opts...)
}

// cacheCheckUserTicket checks a user ticket through userTicketCache. On a
// cache miss it calls the wrapped client's CheckUserTicket with the same ctx,
// ticket and options. On error it returns nil and the error. NewClient installs
// it as userTicketFn only when user-ticket caching is enabled.
func (c *CachedClient) cacheCheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) {
	// cacheKey derives the cache key for one check. An EnvOverride is folded
	// into the key so checks under different overrides are cached separately.
	//
	// The two key shapes are kept disjoint:
	//   - ":<ticket>" when there is no override;
	//   - "<env>:<ticket>" when there is one (<env> is a decimal number and
	//     never starts with ':').
	// If the bare ticket were used when there is no override, a caller could
	// pass the string "<env>:<ticket>" and hit the entry cached for <ticket>
	// under that <env>.
	//
	// NOTE(review): only EnvOverride is part of the key. If other fields of
	// tvm.CheckUserTicketOptions affect the result, they must be added here.
	cacheKey := func(ticket string, opts ...tvm.CheckUserTicketOption) string {
		var options tvm.CheckUserTicketOptions
		for _, opt := range opts {
			opt(&options)
		}

		if options.EnvOverride == nil {
			return ":" + ticket
		}

		return fmt.Sprintf("%d:%s", *options.EnvOverride, ticket)
	}

	out, err := c.userTicketCache.Fetch(cacheKey(ticket, opts...), func() (interface{}, error) {
		return c.Client.CheckUserTicket(ctx, ticket, opts...)
	})

	if err != nil {
		return nil, err
	}

	return out.Value().(*tvm.CheckedUserTicket), nil
}

// Close stops the client's ticket caches. A cache that was never built (its
// embedded *ccache.Cache is nil) is skipped. This happens for
// userTicketCache when NewClient got no OptionUserTicket, and for both caches
// on a zero-value CachedClient.
//
// NOTE(review): whether cache.Stop itself tolerates a nil *ccache.Cache is not
// visible here. The explicit checks keep Close safe either way.
func (c *CachedClient) Close() {
	if c.serviceTicketCache.Cache != nil {
		c.serviceTicketCache.Stop()
	}

	if c.userTicketCache.Cache != nil {
		c.userTicketCache.Stop()
	}
}