diff options
author | vitalyisaev <vitalyisaev@ydb.tech> | 2023-12-12 21:55:07 +0300 |
---|---|---|
committer | vitalyisaev <vitalyisaev@ydb.tech> | 2023-12-12 22:25:10 +0300 |
commit | 4967f99474a4040ba150eb04995de06342252718 (patch) | |
tree | c9c118836513a8fab6e9fcfb25be5d404338bca7 /vendor/github.com/aws/smithy-go/auth/bearer | |
parent | 2ce9cccb9b0bdd4cd7a3491dc5cbf8687cda51de (diff) | |
download | ydb-4967f99474a4040ba150eb04995de06342252718.tar.gz |
YQ Connector: prepare code base for S3 integration
1. Кодовая база Коннектора переписана с помощью Go дженериков так, чтобы добавление нового источника данных (в частности S3 + csv) максимально переиспользовало имеющийся код (чтобы сохранялась логика нарезания на блоки данных, учёт трафика и пр.)
2. API Connector расширено для работы с S3, но ещё пока не протестировано.
Diffstat (limited to 'vendor/github.com/aws/smithy-go/auth/bearer')
8 files changed, 981 insertions, 0 deletions
diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/docs.go b/vendor/github.com/aws/smithy-go/auth/bearer/docs.go new file mode 100644 index 0000000000..1c9b9715cb --- /dev/null +++ b/vendor/github.com/aws/smithy-go/auth/bearer/docs.go @@ -0,0 +1,3 @@ +// Package bearer provides middleware and utilities for authenticating API +// operation calls with a Bearer Token. +package bearer diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/gotest/ya.make b/vendor/github.com/aws/smithy-go/auth/bearer/gotest/ya.make new file mode 100644 index 0000000000..6cb218feab --- /dev/null +++ b/vendor/github.com/aws/smithy-go/auth/bearer/gotest/ya.make @@ -0,0 +1,5 @@ +GO_TEST_FOR(vendor/github.com/aws/smithy-go/auth/bearer) + +LICENSE(Apache-2.0) + +END() diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/middleware.go b/vendor/github.com/aws/smithy-go/auth/bearer/middleware.go new file mode 100644 index 0000000000..8c7d720995 --- /dev/null +++ b/vendor/github.com/aws/smithy-go/auth/bearer/middleware.go @@ -0,0 +1,104 @@ +package bearer + +import ( + "context" + "fmt" + + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +// Message is the middleware stack's request transport message value. +type Message interface{} + +// Signer provides an interface for implementations to decorate a request +// message with a bearer token. The signer is responsible for validating the +// message type is compatible with the signer. +type Signer interface { + SignWithBearerToken(context.Context, Token, Message) (Message, error) +} + +// AuthenticationMiddleware provides the Finalize middleware step for signing +// an request message with a bearer token. +type AuthenticationMiddleware struct { + signer Signer + tokenProvider TokenProvider +} + +// AddAuthenticationMiddleware helper adds the AuthenticationMiddleware to the +// middleware Stack in the Finalize step with the options provided. +func AddAuthenticationMiddleware(s *middleware.Stack, signer Signer, tokenProvider TokenProvider) error { + return s.Finalize.Add( + NewAuthenticationMiddleware(signer, tokenProvider), + middleware.After, + ) +} + +// NewAuthenticationMiddleware returns an initialized AuthenticationMiddleware. +func NewAuthenticationMiddleware(signer Signer, tokenProvider TokenProvider) *AuthenticationMiddleware { + return &AuthenticationMiddleware{ + signer: signer, + tokenProvider: tokenProvider, + } +} + +const authenticationMiddlewareID = "BearerTokenAuthentication" + +// ID returns the resolver identifier +func (m *AuthenticationMiddleware) ID() string { + return authenticationMiddlewareID +} + +// HandleFinalize implements the FinalizeMiddleware interface in order to +// update the request with bearer token authentication. +func (m *AuthenticationMiddleware) HandleFinalize( + ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler, +) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + token, err := m.tokenProvider.RetrieveBearerToken(ctx) + if err != nil { + return out, metadata, fmt.Errorf("failed AuthenticationMiddleware wrap message, %w", err) + } + + signedMessage, err := m.signer.SignWithBearerToken(ctx, token, in.Request) + if err != nil { + return out, metadata, fmt.Errorf("failed AuthenticationMiddleware sign message, %w", err) + } + + in.Request = signedMessage + return next.HandleFinalize(ctx, in) +} + +// SignHTTPSMessage provides a bearer token authentication implementation that +// will sign the message with the provided bearer token. +// +// Will fail if the message is not a smithy-go HTTP request or the request is +// not HTTPS. +type SignHTTPSMessage struct{} + +// NewSignHTTPSMessage returns an initialized signer for HTTP messages. +func NewSignHTTPSMessage() *SignHTTPSMessage { + return &SignHTTPSMessage{} +} + +// SignWithBearerToken returns a copy of the HTTP request with the bearer token +// added via the "Authorization" header, per RFC 6750, https://datatracker.ietf.org/doc/html/rfc6750. +// +// Returns an error if the request's URL scheme is not HTTPS, or the request +// message is not an smithy-go HTTP Request pointer type. +func (SignHTTPSMessage) SignWithBearerToken(ctx context.Context, token Token, message Message) (Message, error) { + req, ok := message.(*smithyhttp.Request) + if !ok { + return nil, fmt.Errorf("expect smithy-go HTTP Request, got %T", message) + } + + if !req.IsHTTPS() { + return nil, fmt.Errorf("bearer token with HTTP request requires HTTPS") + } + + reqClone := req.Clone() + reqClone.Header.Set("Authorization", "Bearer "+token.Value) + + return reqClone, nil +} diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/middleware_test.go b/vendor/github.com/aws/smithy-go/auth/bearer/middleware_test.go new file mode 100644 index 0000000000..e9604f089b --- /dev/null +++ b/vendor/github.com/aws/smithy-go/auth/bearer/middleware_test.go @@ -0,0 +1,78 @@ +package bearer + +import ( + "context" + "net/http" + "net/url" + "strings" + "testing" + + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func TestSignHTTPSMessage(t *testing.T) { + cases := map[string]struct { + message Message + token Token + expectMessage Message + expectErr string + }{ + // Cases + "not smithyhttp.Request": { + message: struct{}{}, + expectErr: "expect smithy-go HTTP Request", + }, + "not https": { + message: func() Message { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + return r + }(), + expectErr: "requires HTTPS", + }, + "success": { + message: func() Message { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + return r + }(), + token: Token{Value: "abc123"}, + expectMessage: func() Message { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.Header.Set("Authorization", "Bearer abc123") + return r + }(), + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + signer := SignHTTPSMessage{} + message, err := signer.SignWithBearerToken(ctx, c.token, c.message) + if c.expectErr != "" { + if err == nil { + t.Fatalf("expect error, got none") + } + if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect %v in error %v", e, a) + } + return + } else if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + options := []cmp.Option{ + cmpopts.IgnoreUnexported(smithyhttp.Request{}), + cmpopts.IgnoreUnexported(http.Request{}), + } + + if diff := cmp.Diff(c.expectMessage, message, options...); diff != "" { + t.Errorf("expect match\n%s", diff) + } + }) + } +} diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/token.go b/vendor/github.com/aws/smithy-go/auth/bearer/token.go new file mode 100644 index 0000000000..be260d4c76 --- /dev/null +++ b/vendor/github.com/aws/smithy-go/auth/bearer/token.go @@ -0,0 +1,50 @@ +package bearer + +import ( + "context" + "time" +) + +// Token provides a type wrapping a bearer token and expiration metadata. +type Token struct { + Value string + + CanExpire bool + Expires time.Time +} + +// Expired returns if the token's Expires time is before or equal to the time +// provided. If CanExpires is false, Expired will always return false. +func (t Token) Expired(now time.Time) bool { + if !t.CanExpire { + return false + } + now = now.Round(0) + return now.Equal(t.Expires) || now.After(t.Expires) +} + +// TokenProvider provides interface for retrieving bearer tokens. +type TokenProvider interface { + RetrieveBearerToken(context.Context) (Token, error) +} + +// TokenProviderFunc provides a helper utility to wrap a function as a type +// that implements the TokenProvider interface. +type TokenProviderFunc func(context.Context) (Token, error) + +// RetrieveBearerToken calls the wrapped function, returning the Token or +// error. +func (fn TokenProviderFunc) RetrieveBearerToken(ctx context.Context) (Token, error) { + return fn(ctx) +} + +// StaticTokenProvider provides a utility for wrapping a static bearer token +// value within an implementation of a token provider. +type StaticTokenProvider struct { + Token Token +} + +// RetrieveBearerToken returns the static token specified. +func (s StaticTokenProvider) RetrieveBearerToken(context.Context) (Token, error) { + return s.Token, nil +} diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/token_cache.go b/vendor/github.com/aws/smithy-go/auth/bearer/token_cache.go new file mode 100644 index 0000000000..223ddf52bb --- /dev/null +++ b/vendor/github.com/aws/smithy-go/auth/bearer/token_cache.go @@ -0,0 +1,208 @@ +package bearer + +import ( + "context" + "fmt" + "sync/atomic" + "time" + + smithycontext "github.com/aws/smithy-go/context" + "github.com/aws/smithy-go/internal/sync/singleflight" +) + +// package variable that can be override in unit tests. +var timeNow = time.Now + +// TokenCacheOptions provides a set of optional configuration options for the +// TokenCache TokenProvider. +type TokenCacheOptions struct { + // The duration before the token will expire when the credentials will be + // refreshed. If DisableAsyncRefresh is true, the RetrieveBearerToken calls + // will be blocking. + // + // Asynchronous refreshes are deduplicated, and only one will be in-flight + // at a time. If the token expires while an asynchronous refresh is in + // flight, the next call to RetrieveBearerToken will block on that refresh + // to return. + RefreshBeforeExpires time.Duration + + // The timeout the underlying TokenProvider's RetrieveBearerToken call must + // return within, or will be canceled. Defaults to 0, no timeout. + // + // If 0 timeout, its possible for the underlying tokenProvider's + // RetrieveBearerToken call to block forever. Preventing subsequent + // TokenCache attempts to refresh the token. + // + // If this timeout is reached all pending deduplicated calls to + // TokenCache RetrieveBearerToken will fail with an error. + RetrieveBearerTokenTimeout time.Duration + + // The minimum duration between asynchronous refresh attempts. If the next + // asynchronous recent refresh attempt was within the minimum delay + // duration, the call to retrieve will return the current cached token, if + // not expired. + // + // The asynchronous retrieve is deduplicated across multiple calls when + // RetrieveBearerToken is called. The asynchronous retrieve is not a + // periodic task. It is only performed when the token has not yet expired, + // and the current item is within the RefreshBeforeExpires window, and the + // TokenCache's RetrieveBearerToken method is called. + // + // If 0, (default) there will be no minimum delay between asynchronous + // refresh attempts. + // + // If DisableAsyncRefresh is true, this option is ignored. + AsyncRefreshMinimumDelay time.Duration + + // Sets if the TokenCache will attempt to refresh the token in the + // background asynchronously instead of blocking for credentials to be + // refreshed. If disabled token refresh will be blocking. + // + // The first call to RetrieveBearerToken will always be blocking, because + // there is no cached token. + DisableAsyncRefresh bool +} + +// TokenCache provides an utility to cache Bearer Authentication tokens from a +// wrapped TokenProvider. The TokenCache can be has options to configure the +// cache's early and asynchronous refresh of the token. +type TokenCache struct { + options TokenCacheOptions + provider TokenProvider + + cachedToken atomic.Value + lastRefreshAttemptTime atomic.Value + sfGroup singleflight.Group +} + +// NewTokenCache returns a initialized TokenCache that implements the +// TokenProvider interface. Wrapping the provider passed in. Also taking a set +// of optional functional option parameters to configure the token cache. +func NewTokenCache(provider TokenProvider, optFns ...func(*TokenCacheOptions)) *TokenCache { + var options TokenCacheOptions + for _, fn := range optFns { + fn(&options) + } + + return &TokenCache{ + options: options, + provider: provider, + } +} + +// RetrieveBearerToken returns the token if it could be obtained, or error if a +// valid token could not be retrieved. +// +// The passed in Context's cancel/deadline/timeout will impacting only this +// individual retrieve call and not any other already queued up calls. This +// means underlying provider's RetrieveBearerToken calls could block for ever, +// and not be canceled with the Context. Set RetrieveBearerTokenTimeout to +// provide a timeout, preventing the underlying TokenProvider blocking forever. +// +// By default, if the passed in Context is canceled, all of its values will be +// considered expired. The wrapped TokenProvider will not be able to lookup the +// values from the Context once it is expired. This is done to protect against +// expired values no longer being valid. To disable this behavior, use +// smithy-go's context.WithPreserveExpiredValues to add a value to the Context +// before calling RetrieveBearerToken to enable support for expired values. +// +// Without RetrieveBearerTokenTimeout there is the potential for a underlying +// Provider's RetrieveBearerToken call to sit forever. Blocking in subsequent +// attempts at refreshing the token. +func (p *TokenCache) RetrieveBearerToken(ctx context.Context) (Token, error) { + cachedToken, ok := p.getCachedToken() + if !ok || cachedToken.Expired(timeNow()) { + return p.refreshBearerToken(ctx) + } + + // Check if the token should be refreshed before it expires. + refreshToken := cachedToken.Expired(timeNow().Add(p.options.RefreshBeforeExpires)) + if !refreshToken { + return cachedToken, nil + } + + if p.options.DisableAsyncRefresh { + return p.refreshBearerToken(ctx) + } + + p.tryAsyncRefresh(ctx) + + return cachedToken, nil +} + +// tryAsyncRefresh attempts to asynchronously refresh the token returning the +// already cached token. If it AsyncRefreshMinimumDelay option is not zero, and +// the duration since the last refresh is less than that value, nothing will be +// done. +func (p *TokenCache) tryAsyncRefresh(ctx context.Context) { + if p.options.AsyncRefreshMinimumDelay != 0 { + var lastRefreshAttempt time.Time + if v := p.lastRefreshAttemptTime.Load(); v != nil { + lastRefreshAttempt = v.(time.Time) + } + + if timeNow().Before(lastRefreshAttempt.Add(p.options.AsyncRefreshMinimumDelay)) { + return + } + } + + // Ignore the returned channel so this won't be blocking, and limit the + // number of additional goroutines created. + p.sfGroup.DoChan("async-refresh", func() (interface{}, error) { + res, err := p.refreshBearerToken(ctx) + if p.options.AsyncRefreshMinimumDelay != 0 { + var refreshAttempt time.Time + if err != nil { + refreshAttempt = timeNow() + } + p.lastRefreshAttemptTime.Store(refreshAttempt) + } + + return res, err + }) +} + +func (p *TokenCache) refreshBearerToken(ctx context.Context) (Token, error) { + resCh := p.sfGroup.DoChan("refresh-token", func() (interface{}, error) { + ctx := smithycontext.WithSuppressCancel(ctx) + if v := p.options.RetrieveBearerTokenTimeout; v != 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, v) + defer cancel() + } + return p.singleRetrieve(ctx) + }) + + select { + case res := <-resCh: + return res.Val.(Token), res.Err + case <-ctx.Done(): + return Token{}, fmt.Errorf("retrieve bearer token canceled, %w", ctx.Err()) + } +} + +func (p *TokenCache) singleRetrieve(ctx context.Context) (interface{}, error) { + token, err := p.provider.RetrieveBearerToken(ctx) + if err != nil { + return Token{}, fmt.Errorf("failed to retrieve bearer token, %w", err) + } + + p.cachedToken.Store(&token) + return token, nil +} + +// getCachedToken returns the currently cached token and true if found. Returns +// false if no token is cached. +func (p *TokenCache) getCachedToken() (Token, bool) { + v := p.cachedToken.Load() + if v == nil { + return Token{}, false + } + + t := v.(*Token) + if t == nil || t.Value == "" { + return Token{}, false + } + + return *t, true +} diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/token_cache_test.go b/vendor/github.com/aws/smithy-go/auth/bearer/token_cache_test.go new file mode 100644 index 0000000000..3d56f7ee63 --- /dev/null +++ b/vendor/github.com/aws/smithy-go/auth/bearer/token_cache_test.go @@ -0,0 +1,512 @@ +package bearer + +import ( + "context" + "fmt" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +var _ TokenProvider = (*TokenCache)(nil) + +func TestTokenCache_cache(t *testing.T) { + expectToken := Token{ + Value: "abc123", + } + + var retrieveCalled bool + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + if retrieveCalled { + t.Fatalf("expect wrapped provider to be called once") + } + retrieveCalled = true + return expectToken, nil + })) + + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + + for i := 0; i < 100; i++ { + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + } +} + +func TestTokenCache_cacheConcurrent(t *testing.T) { + expectToken := Token{ + Value: "abc123", + } + + var retrieveCalled bool + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + if retrieveCalled { + t.Fatalf("expect wrapped provider to be called once") + } + retrieveCalled = true + return expectToken, nil + })) + + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + + for i := 0; i < 100; i++ { + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + }) + } +} + +func TestTokenCache_expired(t *testing.T) { + origTimeNow := timeNow + defer func() { timeNow = origTimeNow }() + + timeNow = func() time.Time { return time.Time{} } + + expectToken := Token{ + Value: "abc123", + CanExpire: true, + Expires: timeNow().Add(10 * time.Minute), + } + refreshedToken := Token{ + Value: "refreshed-abc123", + CanExpire: true, + Expires: timeNow().Add(30 * time.Minute), + } + + retrievedCount := new(int32) + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + if atomic.AddInt32(retrievedCount, 1) > 1 { + return refreshedToken, nil + } + return expectToken, nil + })) + + for i := 0; i < 10; i++ { + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + } + if e, a := 1, int(atomic.LoadInt32(retrievedCount)); e != a { + t.Errorf("expect %v provider calls, got %v", e, a) + } + + // Offset time for refresh + timeNow = func() time.Time { + return (time.Time{}).Add(10 * time.Minute) + } + + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(refreshedToken, token); diff != "" { + t.Errorf("expect refreshed token match\n%s", diff) + } + if e, a := 2, int(atomic.LoadInt32(retrievedCount)); e != a { + t.Errorf("expect %v provider calls, got %v", e, a) + } +} + +func TestTokenCache_cancelled(t *testing.T) { + providerRunning := make(chan struct{}) + providerDone := make(chan struct{}) + var onceClose sync.Once + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + onceClose.Do(func() { close(providerRunning) }) + + // Provider running never receives context cancel so that if the first + // retrieve call is canceled all subsequent retrieve callers won't get + // canceled as well. + select { + case <-providerDone: + return Token{Value: "abc123"}, nil + case <-ctx.Done(): + return Token{}, fmt.Errorf("unexpected context canceled, %w", ctx.Err()) + } + })) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + // Retrieve that will have its context canceled, should return error, but + // underlying provider retrieve will continue to block in the background. + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + _, err := provider.RetrieveBearerToken(ctx) + if err == nil { + t.Errorf("expect error, got none") + + } else if e, a := "unexpected context canceled", err.Error(); strings.Contains(a, e) { + t.Errorf("unexpected context canceled received, %v", err) + + } else if e, a := "context canceled", err.Error(); !strings.Contains(a, e) { + t.Errorf("expect %v error in, %v", e, a) + } + }() + + <-providerRunning + + // Retrieve that will be added to existing single flight group, (or create + // a new group). Returning valid token. + wg.Add(1) + go func() { + defer wg.Done() + + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Errorf("expect no error, got %v", err) + } else { + if diff := cmp.Diff(Token{Value: "abc123"}, token); diff != "" { + t.Errorf("expect token retrieve match\n%s", diff) + } + } + }() + close(providerDone) + + wg.Wait() +} + +func TestTokenCache_cancelledWithTimeout(t *testing.T) { + providerReady := make(chan struct{}) + var providerReadCloseOnce sync.Once + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + providerReadCloseOnce.Do(func() { close(providerReady) }) + + <-ctx.Done() + return Token{}, fmt.Errorf("token retrieve timeout, %w", ctx.Err()) + }), func(o *TokenCacheOptions) { + o.RetrieveBearerTokenTimeout = time.Millisecond + }) + + var wg sync.WaitGroup + + // Spin up additional retrieves that will be deduplicated and block on the + // original retrieve call. + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-providerReady + + _, err := provider.RetrieveBearerToken(context.Background()) + if err == nil { + t.Errorf("expect error, got none") + + } else if e, a := "token retrieve timeout", err.Error(); !strings.Contains(a, e) { + t.Errorf("expect %v error in, %v", e, a) + } + }() + } + + _, err := provider.RetrieveBearerToken(context.Background()) + if err == nil { + t.Errorf("expect error, got none") + + } else if e, a := "token retrieve timeout", err.Error(); !strings.Contains(a, e) { + t.Errorf("expect %v error in, %v", e, a) + } + + wg.Wait() +} + +func TestTokenCache_asyncRefresh(t *testing.T) { + origTimeNow := timeNow + defer func() { timeNow = origTimeNow }() + + timeNow = func() time.Time { return time.Time{} } + + expectToken := Token{ + Value: "abc123", + CanExpire: true, + Expires: timeNow().Add(10 * time.Minute), + } + refreshedToken := Token{ + Value: "refreshed-abc123", + CanExpire: true, + Expires: timeNow().Add(30 * time.Minute), + } + + retrievedCount := new(int32) + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + c := atomic.AddInt32(retrievedCount, 1) + switch { + case c == 1: + return expectToken, nil + case c > 1 && c < 5: + return Token{}, fmt.Errorf("some error") + case c == 5: + return refreshedToken, nil + default: + return Token{}, fmt.Errorf("unexpected error") + } + }), func(o *TokenCacheOptions) { + o.RefreshBeforeExpires = 5 * time.Minute + }) + + // 1: Initial retrieve to cache token + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + + // 2-5: Offset time for subsequent calls to retrieve to trigger asynchronous + // refreshes. + timeNow = func() time.Time { + return (time.Time{}).Add(6 * time.Minute) + } + + for i := 0; i < 4; i++ { + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + } + // Wait for all async refreshes to complete + testWaitAsyncRefreshDone(provider) + + if c := int(atomic.LoadInt32(retrievedCount)); c < 2 || c > 5 { + t.Fatalf("expect async refresh to be called [2,5) times, got, %v", c) + } + + // Ensure enough retrieves have been done to trigger refresh. + if c := atomic.LoadInt32(retrievedCount); c != 5 { + atomic.StoreInt32(retrievedCount, 4) + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + testWaitAsyncRefreshDone(provider) + } + + // Last async refresh will succeed and update cached token, expect the next + // call to get refreshed token. + token, err = provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(refreshedToken, token); diff != "" { + t.Errorf("expect refreshed token match\n%s", diff) + } +} + +func TestTokenCache_asyncRefreshWithMinDelay(t *testing.T) { + origTimeNow := timeNow + defer func() { timeNow = origTimeNow }() + + timeNow = func() time.Time { return time.Time{} } + + expectToken := Token{ + Value: "abc123", + CanExpire: true, + Expires: timeNow().Add(10 * time.Minute), + } + refreshedToken := Token{ + Value: "refreshed-abc123", + CanExpire: true, + Expires: timeNow().Add(30 * time.Minute), + } + + retrievedCount := new(int32) + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + c := atomic.AddInt32(retrievedCount, 1) + switch { + case c == 1: + return expectToken, nil + case c > 1 && c < 5: + return Token{}, fmt.Errorf("some error") + case c == 5: + return refreshedToken, nil + default: + return Token{}, fmt.Errorf("unexpected error") + } + }), func(o *TokenCacheOptions) { + o.RefreshBeforeExpires = 5 * time.Minute + o.AsyncRefreshMinimumDelay = 30 * time.Second + }) + + // 1: Initial retrieve to cache token + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + + // 2-5: Offset time for subsequent calls to retrieve to trigger asynchronous + // refreshes. + timeNow = func() time.Time { + return (time.Time{}).Add(6 * time.Minute) + } + + for i := 0; i < 4; i++ { + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + // Wait for all async refreshes to complete ensure not deduped + testWaitAsyncRefreshDone(provider) + } + + // Only a single refresh attempt is expected. + if e, a := 2, int(atomic.LoadInt32(retrievedCount)); e != a { + t.Fatalf("expect %v min async refresh, got %v", e, a) + } + + // Move time forward to ensure another async refresh is triggered. + timeNow = func() time.Time { return (time.Time{}).Add(7 * time.Minute) } + // Make sure the next attempt refreshes the token + atomic.StoreInt32(retrievedCount, 4) + + // Do async retrieve that will succeed refreshing in background. + token, err = provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + // Wait for all async refreshes to complete ensure not deduped + testWaitAsyncRefreshDone(provider) + + // Last async refresh will succeed and update cached token, expect the next + // call to get refreshed token. + token, err = provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(refreshedToken, token); diff != "" { + t.Errorf("expect refreshed token match\n%s", diff) + } +} + +func TestTokenCache_disableAsyncRefresh(t *testing.T) { + origTimeNow := timeNow + defer func() { timeNow = origTimeNow }() + + timeNow = func() time.Time { return time.Time{} } + + expectToken := Token{ + Value: "abc123", + CanExpire: true, + Expires: timeNow().Add(10 * time.Minute), + } + refreshedToken := Token{ + Value: "refreshed-abc123", + CanExpire: true, + Expires: timeNow().Add(30 * time.Minute), + } + + retrievedCount := new(int32) + provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) { + c := atomic.AddInt32(retrievedCount, 1) + switch { + case c == 1: + return expectToken, nil + case c > 1 && c < 5: + return Token{}, fmt.Errorf("some error") + case c == 5: + return refreshedToken, nil + default: + return Token{}, fmt.Errorf("unexpected error") + } + }), func(o *TokenCacheOptions) { + o.RefreshBeforeExpires = 5 * time.Minute + o.DisableAsyncRefresh = true + }) + + // 1: Initial retrieve to cache token + token, err := provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(expectToken, token); diff != "" { + t.Errorf("expect token match\n%s", diff) + } + + // Update time into refresh window before token expires + timeNow = func() time.Time { + return (time.Time{}).Add(6 * time.Minute) + } + + for i := 0; i < 3; i++ { + _, err = provider.RetrieveBearerToken(context.Background()) + if err == nil { + t.Fatalf("expect error, got none") + } + if e, a := "some error", err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect %v error in %v", e, a) + } + if e, a := i+2, int(atomic.LoadInt32(retrievedCount)); e != a { + t.Fatalf("expect %v retrieveCount, got %v", e, a) + } + } + if e, a := 4, int(atomic.LoadInt32(retrievedCount)); e != a { + t.Fatalf("expect %v retrieveCount, got %v", e, a) + } + + // Last refresh will succeed and update cached token, expect the next + // call to get refreshed token. + token, err = provider.RetrieveBearerToken(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmp.Diff(refreshedToken, token); diff != "" { + t.Errorf("expect refreshed token match\n%s", diff) + } +} + +func testWaitAsyncRefreshDone(provider *TokenCache) { + asyncResCh := provider.sfGroup.DoChan("async-refresh", func() (interface{}, error) { + return nil, nil + }) + <-asyncResCh +} diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/ya.make b/vendor/github.com/aws/smithy-go/auth/bearer/ya.make new file mode 100644 index 0000000000..3dfbfa7c07 --- /dev/null +++ b/vendor/github.com/aws/smithy-go/auth/bearer/ya.make @@ -0,0 +1,21 @@ +GO_LIBRARY() + +LICENSE(Apache-2.0) + +SRCS( + docs.go + middleware.go + token.go + token_cache.go +) + +GO_TEST_SRCS( + middleware_test.go + token_cache_test.go +) + +END() + +RECURSE( + gotest +) |