aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/aws/smithy-go/auth/bearer
diff options
context:
space:
mode:
authorvitalyisaev <vitalyisaev@ydb.tech>2023-12-12 21:55:07 +0300
committervitalyisaev <vitalyisaev@ydb.tech>2023-12-12 22:25:10 +0300
commit4967f99474a4040ba150eb04995de06342252718 (patch)
treec9c118836513a8fab6e9fcfb25be5d404338bca7 /vendor/github.com/aws/smithy-go/auth/bearer
parent2ce9cccb9b0bdd4cd7a3491dc5cbf8687cda51de (diff)
downloadydb-4967f99474a4040ba150eb04995de06342252718.tar.gz
YQ Connector: prepare code base for S3 integration
1. Кодовая база Коннектора переписана с помощью Go дженериков так, чтобы добавление нового источника данных (в частности S3 + csv) максимально переиспользовало имеющийся код (чтобы сохранялась логика нарезания на блоки данных, учёт трафика и пр.) 2. API Connector расширено для работы с S3, но ещё пока не протестировано.
Diffstat (limited to 'vendor/github.com/aws/smithy-go/auth/bearer')
-rw-r--r--vendor/github.com/aws/smithy-go/auth/bearer/docs.go3
-rw-r--r--vendor/github.com/aws/smithy-go/auth/bearer/gotest/ya.make5
-rw-r--r--vendor/github.com/aws/smithy-go/auth/bearer/middleware.go104
-rw-r--r--vendor/github.com/aws/smithy-go/auth/bearer/middleware_test.go78
-rw-r--r--vendor/github.com/aws/smithy-go/auth/bearer/token.go50
-rw-r--r--vendor/github.com/aws/smithy-go/auth/bearer/token_cache.go208
-rw-r--r--vendor/github.com/aws/smithy-go/auth/bearer/token_cache_test.go512
-rw-r--r--vendor/github.com/aws/smithy-go/auth/bearer/ya.make21
8 files changed, 981 insertions, 0 deletions
diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/docs.go b/vendor/github.com/aws/smithy-go/auth/bearer/docs.go
new file mode 100644
index 0000000000..1c9b9715cb
--- /dev/null
+++ b/vendor/github.com/aws/smithy-go/auth/bearer/docs.go
@@ -0,0 +1,3 @@
+// Package bearer provides middleware and utilities for authenticating API
+// operation calls with a Bearer Token.
+package bearer
diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/gotest/ya.make b/vendor/github.com/aws/smithy-go/auth/bearer/gotest/ya.make
new file mode 100644
index 0000000000..6cb218feab
--- /dev/null
+++ b/vendor/github.com/aws/smithy-go/auth/bearer/gotest/ya.make
@@ -0,0 +1,5 @@
+GO_TEST_FOR(vendor/github.com/aws/smithy-go/auth/bearer)
+
+LICENSE(Apache-2.0)
+
+END()
diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/middleware.go b/vendor/github.com/aws/smithy-go/auth/bearer/middleware.go
new file mode 100644
index 0000000000..8c7d720995
--- /dev/null
+++ b/vendor/github.com/aws/smithy-go/auth/bearer/middleware.go
@@ -0,0 +1,104 @@
+package bearer
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/aws/smithy-go/middleware"
+ smithyhttp "github.com/aws/smithy-go/transport/http"
+)
+
+// Message is the middleware stack's request transport message value.
+type Message interface{}
+
+// Signer provides an interface for implementations to decorate a request
+// message with a bearer token. The signer is responsible for validating the
+// message type is compatible with the signer.
+type Signer interface {
+ SignWithBearerToken(context.Context, Token, Message) (Message, error)
+}
+
+// AuthenticationMiddleware provides the Finalize middleware step for signing
+// an request message with a bearer token.
+type AuthenticationMiddleware struct {
+ signer Signer
+ tokenProvider TokenProvider
+}
+
+// AddAuthenticationMiddleware helper adds the AuthenticationMiddleware to the
+// middleware Stack in the Finalize step with the options provided.
+func AddAuthenticationMiddleware(s *middleware.Stack, signer Signer, tokenProvider TokenProvider) error {
+ return s.Finalize.Add(
+ NewAuthenticationMiddleware(signer, tokenProvider),
+ middleware.After,
+ )
+}
+
+// NewAuthenticationMiddleware returns an initialized AuthenticationMiddleware.
+func NewAuthenticationMiddleware(signer Signer, tokenProvider TokenProvider) *AuthenticationMiddleware {
+ return &AuthenticationMiddleware{
+ signer: signer,
+ tokenProvider: tokenProvider,
+ }
+}
+
+const authenticationMiddlewareID = "BearerTokenAuthentication"
+
+// ID returns the resolver identifier
+func (m *AuthenticationMiddleware) ID() string {
+ return authenticationMiddlewareID
+}
+
+// HandleFinalize implements the FinalizeMiddleware interface in order to
+// update the request with bearer token authentication.
+func (m *AuthenticationMiddleware) HandleFinalize(
+ ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
+) (
+ out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
+) {
+ token, err := m.tokenProvider.RetrieveBearerToken(ctx)
+ if err != nil {
+ return out, metadata, fmt.Errorf("failed AuthenticationMiddleware wrap message, %w", err)
+ }
+
+ signedMessage, err := m.signer.SignWithBearerToken(ctx, token, in.Request)
+ if err != nil {
+ return out, metadata, fmt.Errorf("failed AuthenticationMiddleware sign message, %w", err)
+ }
+
+ in.Request = signedMessage
+ return next.HandleFinalize(ctx, in)
+}
+
+// SignHTTPSMessage provides a bearer token authentication implementation that
+// will sign the message with the provided bearer token.
+//
+// Will fail if the message is not a smithy-go HTTP request or the request is
+// not HTTPS.
+type SignHTTPSMessage struct{}
+
+// NewSignHTTPSMessage returns an initialized signer for HTTP messages.
+func NewSignHTTPSMessage() *SignHTTPSMessage {
+ return &SignHTTPSMessage{}
+}
+
+// SignWithBearerToken returns a copy of the HTTP request with the bearer token
+// added via the "Authorization" header, per RFC 6750, https://datatracker.ietf.org/doc/html/rfc6750.
+//
+// Returns an error if the request's URL scheme is not HTTPS, or the request
+// message is not an smithy-go HTTP Request pointer type.
+func (SignHTTPSMessage) SignWithBearerToken(ctx context.Context, token Token, message Message) (Message, error) {
+ req, ok := message.(*smithyhttp.Request)
+ if !ok {
+ return nil, fmt.Errorf("expect smithy-go HTTP Request, got %T", message)
+ }
+
+ if !req.IsHTTPS() {
+ return nil, fmt.Errorf("bearer token with HTTP request requires HTTPS")
+ }
+
+ reqClone := req.Clone()
+ reqClone.Header.Set("Authorization", "Bearer "+token.Value)
+
+ return reqClone, nil
+}
diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/middleware_test.go b/vendor/github.com/aws/smithy-go/auth/bearer/middleware_test.go
new file mode 100644
index 0000000000..e9604f089b
--- /dev/null
+++ b/vendor/github.com/aws/smithy-go/auth/bearer/middleware_test.go
@@ -0,0 +1,78 @@
+package bearer
+
+import (
+ "context"
+ "net/http"
+ "net/url"
+ "strings"
+ "testing"
+
+ smithyhttp "github.com/aws/smithy-go/transport/http"
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+)
+
+func TestSignHTTPSMessage(t *testing.T) {
+ cases := map[string]struct {
+ message Message
+ token Token
+ expectMessage Message
+ expectErr string
+ }{
+ // Cases
+ "not smithyhttp.Request": {
+ message: struct{}{},
+ expectErr: "expect smithy-go HTTP Request",
+ },
+ "not https": {
+ message: func() Message {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ return r
+ }(),
+ expectErr: "requires HTTPS",
+ },
+ "success": {
+ message: func() Message {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ return r
+ }(),
+ token: Token{Value: "abc123"},
+ expectMessage: func() Message {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.Header.Set("Authorization", "Bearer abc123")
+ return r
+ }(),
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ ctx := context.Background()
+ signer := SignHTTPSMessage{}
+ message, err := signer.SignWithBearerToken(ctx, c.token, c.message)
+ if c.expectErr != "" {
+ if err == nil {
+ t.Fatalf("expect error, got none")
+ }
+ if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) {
+ t.Fatalf("expect %v in error %v", e, a)
+ }
+ return
+ } else if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+
+ options := []cmp.Option{
+ cmpopts.IgnoreUnexported(smithyhttp.Request{}),
+ cmpopts.IgnoreUnexported(http.Request{}),
+ }
+
+ if diff := cmp.Diff(c.expectMessage, message, options...); diff != "" {
+ t.Errorf("expect match\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/token.go b/vendor/github.com/aws/smithy-go/auth/bearer/token.go
new file mode 100644
index 0000000000..be260d4c76
--- /dev/null
+++ b/vendor/github.com/aws/smithy-go/auth/bearer/token.go
@@ -0,0 +1,50 @@
+package bearer
+
+import (
+ "context"
+ "time"
+)
+
+// Token provides a type wrapping a bearer token and expiration metadata.
+type Token struct {
+ Value string
+
+ CanExpire bool
+ Expires time.Time
+}
+
+// Expired returns if the token's Expires time is before or equal to the time
+// provided. If CanExpires is false, Expired will always return false.
+func (t Token) Expired(now time.Time) bool {
+ if !t.CanExpire {
+ return false
+ }
+ now = now.Round(0)
+ return now.Equal(t.Expires) || now.After(t.Expires)
+}
+
+// TokenProvider provides interface for retrieving bearer tokens.
+type TokenProvider interface {
+ RetrieveBearerToken(context.Context) (Token, error)
+}
+
+// TokenProviderFunc provides a helper utility to wrap a function as a type
+// that implements the TokenProvider interface.
+type TokenProviderFunc func(context.Context) (Token, error)
+
+// RetrieveBearerToken calls the wrapped function, returning the Token or
+// error.
+func (fn TokenProviderFunc) RetrieveBearerToken(ctx context.Context) (Token, error) {
+ return fn(ctx)
+}
+
+// StaticTokenProvider provides a utility for wrapping a static bearer token
+// value within an implementation of a token provider.
+type StaticTokenProvider struct {
+ Token Token
+}
+
+// RetrieveBearerToken returns the static token specified.
+func (s StaticTokenProvider) RetrieveBearerToken(context.Context) (Token, error) {
+ return s.Token, nil
+}
diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/token_cache.go b/vendor/github.com/aws/smithy-go/auth/bearer/token_cache.go
new file mode 100644
index 0000000000..223ddf52bb
--- /dev/null
+++ b/vendor/github.com/aws/smithy-go/auth/bearer/token_cache.go
@@ -0,0 +1,208 @@
+package bearer
+
+import (
+ "context"
+ "fmt"
+ "sync/atomic"
+ "time"
+
+ smithycontext "github.com/aws/smithy-go/context"
+ "github.com/aws/smithy-go/internal/sync/singleflight"
+)
+
+// package variable that can be override in unit tests.
+var timeNow = time.Now
+
+// TokenCacheOptions provides a set of optional configuration options for the
+// TokenCache TokenProvider.
+type TokenCacheOptions struct {
+ // The duration before the token will expire when the credentials will be
+ // refreshed. If DisableAsyncRefresh is true, the RetrieveBearerToken calls
+ // will be blocking.
+ //
+ // Asynchronous refreshes are deduplicated, and only one will be in-flight
+ // at a time. If the token expires while an asynchronous refresh is in
+ // flight, the next call to RetrieveBearerToken will block on that refresh
+ // to return.
+ RefreshBeforeExpires time.Duration
+
+ // The timeout the underlying TokenProvider's RetrieveBearerToken call must
+ // return within, or will be canceled. Defaults to 0, no timeout.
+ //
+ // If 0 timeout, its possible for the underlying tokenProvider's
+ // RetrieveBearerToken call to block forever. Preventing subsequent
+ // TokenCache attempts to refresh the token.
+ //
+ // If this timeout is reached all pending deduplicated calls to
+ // TokenCache RetrieveBearerToken will fail with an error.
+ RetrieveBearerTokenTimeout time.Duration
+
+ // The minimum duration between asynchronous refresh attempts. If the next
+ // asynchronous recent refresh attempt was within the minimum delay
+ // duration, the call to retrieve will return the current cached token, if
+ // not expired.
+ //
+ // The asynchronous retrieve is deduplicated across multiple calls when
+ // RetrieveBearerToken is called. The asynchronous retrieve is not a
+ // periodic task. It is only performed when the token has not yet expired,
+ // and the current item is within the RefreshBeforeExpires window, and the
+ // TokenCache's RetrieveBearerToken method is called.
+ //
+ // If 0, (default) there will be no minimum delay between asynchronous
+ // refresh attempts.
+ //
+ // If DisableAsyncRefresh is true, this option is ignored.
+ AsyncRefreshMinimumDelay time.Duration
+
+ // Sets if the TokenCache will attempt to refresh the token in the
+ // background asynchronously instead of blocking for credentials to be
+ // refreshed. If disabled token refresh will be blocking.
+ //
+ // The first call to RetrieveBearerToken will always be blocking, because
+ // there is no cached token.
+ DisableAsyncRefresh bool
+}
+
+// TokenCache provides an utility to cache Bearer Authentication tokens from a
+// wrapped TokenProvider. The TokenCache can be has options to configure the
+// cache's early and asynchronous refresh of the token.
+type TokenCache struct {
+ options TokenCacheOptions
+ provider TokenProvider
+
+ cachedToken atomic.Value
+ lastRefreshAttemptTime atomic.Value
+ sfGroup singleflight.Group
+}
+
+// NewTokenCache returns a initialized TokenCache that implements the
+// TokenProvider interface. Wrapping the provider passed in. Also taking a set
+// of optional functional option parameters to configure the token cache.
+func NewTokenCache(provider TokenProvider, optFns ...func(*TokenCacheOptions)) *TokenCache {
+ var options TokenCacheOptions
+ for _, fn := range optFns {
+ fn(&options)
+ }
+
+ return &TokenCache{
+ options: options,
+ provider: provider,
+ }
+}
+
+// RetrieveBearerToken returns the token if it could be obtained, or error if a
+// valid token could not be retrieved.
+//
+// The passed in Context's cancel/deadline/timeout will impacting only this
+// individual retrieve call and not any other already queued up calls. This
+// means underlying provider's RetrieveBearerToken calls could block for ever,
+// and not be canceled with the Context. Set RetrieveBearerTokenTimeout to
+// provide a timeout, preventing the underlying TokenProvider blocking forever.
+//
+// By default, if the passed in Context is canceled, all of its values will be
+// considered expired. The wrapped TokenProvider will not be able to lookup the
+// values from the Context once it is expired. This is done to protect against
+// expired values no longer being valid. To disable this behavior, use
+// smithy-go's context.WithPreserveExpiredValues to add a value to the Context
+// before calling RetrieveBearerToken to enable support for expired values.
+//
+// Without RetrieveBearerTokenTimeout there is the potential for a underlying
+// Provider's RetrieveBearerToken call to sit forever. Blocking in subsequent
+// attempts at refreshing the token.
+func (p *TokenCache) RetrieveBearerToken(ctx context.Context) (Token, error) {
+ cachedToken, ok := p.getCachedToken()
+ if !ok || cachedToken.Expired(timeNow()) {
+ return p.refreshBearerToken(ctx)
+ }
+
+ // Check if the token should be refreshed before it expires.
+ refreshToken := cachedToken.Expired(timeNow().Add(p.options.RefreshBeforeExpires))
+ if !refreshToken {
+ return cachedToken, nil
+ }
+
+ if p.options.DisableAsyncRefresh {
+ return p.refreshBearerToken(ctx)
+ }
+
+ p.tryAsyncRefresh(ctx)
+
+ return cachedToken, nil
+}
+
+// tryAsyncRefresh attempts to asynchronously refresh the token returning the
+// already cached token. If it AsyncRefreshMinimumDelay option is not zero, and
+// the duration since the last refresh is less than that value, nothing will be
+// done.
+func (p *TokenCache) tryAsyncRefresh(ctx context.Context) {
+ if p.options.AsyncRefreshMinimumDelay != 0 {
+ var lastRefreshAttempt time.Time
+ if v := p.lastRefreshAttemptTime.Load(); v != nil {
+ lastRefreshAttempt = v.(time.Time)
+ }
+
+ if timeNow().Before(lastRefreshAttempt.Add(p.options.AsyncRefreshMinimumDelay)) {
+ return
+ }
+ }
+
+ // Ignore the returned channel so this won't be blocking, and limit the
+ // number of additional goroutines created.
+ p.sfGroup.DoChan("async-refresh", func() (interface{}, error) {
+ res, err := p.refreshBearerToken(ctx)
+ if p.options.AsyncRefreshMinimumDelay != 0 {
+ var refreshAttempt time.Time
+ if err != nil {
+ refreshAttempt = timeNow()
+ }
+ p.lastRefreshAttemptTime.Store(refreshAttempt)
+ }
+
+ return res, err
+ })
+}
+
+func (p *TokenCache) refreshBearerToken(ctx context.Context) (Token, error) {
+ resCh := p.sfGroup.DoChan("refresh-token", func() (interface{}, error) {
+ ctx := smithycontext.WithSuppressCancel(ctx)
+ if v := p.options.RetrieveBearerTokenTimeout; v != 0 {
+ var cancel func()
+ ctx, cancel = context.WithTimeout(ctx, v)
+ defer cancel()
+ }
+ return p.singleRetrieve(ctx)
+ })
+
+ select {
+ case res := <-resCh:
+ return res.Val.(Token), res.Err
+ case <-ctx.Done():
+ return Token{}, fmt.Errorf("retrieve bearer token canceled, %w", ctx.Err())
+ }
+}
+
+func (p *TokenCache) singleRetrieve(ctx context.Context) (interface{}, error) {
+ token, err := p.provider.RetrieveBearerToken(ctx)
+ if err != nil {
+ return Token{}, fmt.Errorf("failed to retrieve bearer token, %w", err)
+ }
+
+ p.cachedToken.Store(&token)
+ return token, nil
+}
+
+// getCachedToken returns the currently cached token and true if found. Returns
+// false if no token is cached.
+func (p *TokenCache) getCachedToken() (Token, bool) {
+ v := p.cachedToken.Load()
+ if v == nil {
+ return Token{}, false
+ }
+
+ t := v.(*Token)
+ if t == nil || t.Value == "" {
+ return Token{}, false
+ }
+
+ return *t, true
+}
diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/token_cache_test.go b/vendor/github.com/aws/smithy-go/auth/bearer/token_cache_test.go
new file mode 100644
index 0000000000..3d56f7ee63
--- /dev/null
+++ b/vendor/github.com/aws/smithy-go/auth/bearer/token_cache_test.go
@@ -0,0 +1,512 @@
+package bearer
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+var _ TokenProvider = (*TokenCache)(nil)
+
+func TestTokenCache_cache(t *testing.T) {
+ expectToken := Token{
+ Value: "abc123",
+ }
+
+ var retrieveCalled bool
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ if retrieveCalled {
+ t.Fatalf("expect wrapped provider to be called once")
+ }
+ retrieveCalled = true
+ return expectToken, nil
+ }))
+
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+
+ for i := 0; i < 100; i++ {
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+ }
+}
+
+func TestTokenCache_cacheConcurrent(t *testing.T) {
+ expectToken := Token{
+ Value: "abc123",
+ }
+
+ var retrieveCalled bool
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ if retrieveCalled {
+ t.Fatalf("expect wrapped provider to be called once")
+ }
+ retrieveCalled = true
+ return expectToken, nil
+ }))
+
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+
+ for i := 0; i < 100; i++ {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ t.Parallel()
+
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestTokenCache_expired(t *testing.T) {
+ origTimeNow := timeNow
+ defer func() { timeNow = origTimeNow }()
+
+ timeNow = func() time.Time { return time.Time{} }
+
+ expectToken := Token{
+ Value: "abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(10 * time.Minute),
+ }
+ refreshedToken := Token{
+ Value: "refreshed-abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(30 * time.Minute),
+ }
+
+ retrievedCount := new(int32)
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ if atomic.AddInt32(retrievedCount, 1) > 1 {
+ return refreshedToken, nil
+ }
+ return expectToken, nil
+ }))
+
+ for i := 0; i < 10; i++ {
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+ }
+ if e, a := 1, int(atomic.LoadInt32(retrievedCount)); e != a {
+ t.Errorf("expect %v provider calls, got %v", e, a)
+ }
+
+ // Offset time for refresh
+ timeNow = func() time.Time {
+ return (time.Time{}).Add(10 * time.Minute)
+ }
+
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(refreshedToken, token); diff != "" {
+ t.Errorf("expect refreshed token match\n%s", diff)
+ }
+ if e, a := 2, int(atomic.LoadInt32(retrievedCount)); e != a {
+ t.Errorf("expect %v provider calls, got %v", e, a)
+ }
+}
+
+func TestTokenCache_cancelled(t *testing.T) {
+ providerRunning := make(chan struct{})
+ providerDone := make(chan struct{})
+ var onceClose sync.Once
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ onceClose.Do(func() { close(providerRunning) })
+
+ // Provider running never receives context cancel so that if the first
+ // retrieve call is canceled all subsequent retrieve callers won't get
+ // canceled as well.
+ select {
+ case <-providerDone:
+ return Token{Value: "abc123"}, nil
+ case <-ctx.Done():
+ return Token{}, fmt.Errorf("unexpected context canceled, %w", ctx.Err())
+ }
+ }))
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ // Retrieve that will have its context canceled, should return error, but
+ // underlying provider retrieve will continue to block in the background.
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ _, err := provider.RetrieveBearerToken(ctx)
+ if err == nil {
+ t.Errorf("expect error, got none")
+
+ } else if e, a := "unexpected context canceled", err.Error(); strings.Contains(a, e) {
+ t.Errorf("unexpected context canceled received, %v", err)
+
+ } else if e, a := "context canceled", err.Error(); !strings.Contains(a, e) {
+ t.Errorf("expect %v error in, %v", e, a)
+ }
+ }()
+
+ <-providerRunning
+
+ // Retrieve that will be added to existing single flight group, (or create
+ // a new group). Returning valid token.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Errorf("expect no error, got %v", err)
+ } else {
+ if diff := cmp.Diff(Token{Value: "abc123"}, token); diff != "" {
+ t.Errorf("expect token retrieve match\n%s", diff)
+ }
+ }
+ }()
+ close(providerDone)
+
+ wg.Wait()
+}
+
+func TestTokenCache_cancelledWithTimeout(t *testing.T) {
+ providerReady := make(chan struct{})
+ var providerReadCloseOnce sync.Once
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ providerReadCloseOnce.Do(func() { close(providerReady) })
+
+ <-ctx.Done()
+ return Token{}, fmt.Errorf("token retrieve timeout, %w", ctx.Err())
+ }), func(o *TokenCacheOptions) {
+ o.RetrieveBearerTokenTimeout = time.Millisecond
+ })
+
+ var wg sync.WaitGroup
+
+ // Spin up additional retrieves that will be deduplicated and block on the
+ // original retrieve call.
+ for i := 0; i < 5; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ <-providerReady
+
+ _, err := provider.RetrieveBearerToken(context.Background())
+ if err == nil {
+ t.Errorf("expect error, got none")
+
+ } else if e, a := "token retrieve timeout", err.Error(); !strings.Contains(a, e) {
+ t.Errorf("expect %v error in, %v", e, a)
+ }
+ }()
+ }
+
+ _, err := provider.RetrieveBearerToken(context.Background())
+ if err == nil {
+ t.Errorf("expect error, got none")
+
+ } else if e, a := "token retrieve timeout", err.Error(); !strings.Contains(a, e) {
+ t.Errorf("expect %v error in, %v", e, a)
+ }
+
+ wg.Wait()
+}
+
+func TestTokenCache_asyncRefresh(t *testing.T) {
+ origTimeNow := timeNow
+ defer func() { timeNow = origTimeNow }()
+
+ timeNow = func() time.Time { return time.Time{} }
+
+ expectToken := Token{
+ Value: "abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(10 * time.Minute),
+ }
+ refreshedToken := Token{
+ Value: "refreshed-abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(30 * time.Minute),
+ }
+
+ retrievedCount := new(int32)
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ c := atomic.AddInt32(retrievedCount, 1)
+ switch {
+ case c == 1:
+ return expectToken, nil
+ case c > 1 && c < 5:
+ return Token{}, fmt.Errorf("some error")
+ case c == 5:
+ return refreshedToken, nil
+ default:
+ return Token{}, fmt.Errorf("unexpected error")
+ }
+ }), func(o *TokenCacheOptions) {
+ o.RefreshBeforeExpires = 5 * time.Minute
+ })
+
+ // 1: Initial retrieve to cache token
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+
+ // 2-5: Offset time for subsequent calls to retrieve to trigger asynchronous
+ // refreshes.
+ timeNow = func() time.Time {
+ return (time.Time{}).Add(6 * time.Minute)
+ }
+
+ for i := 0; i < 4; i++ {
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+ }
+ // Wait for all async refreshes to complete
+ testWaitAsyncRefreshDone(provider)
+
+ if c := int(atomic.LoadInt32(retrievedCount)); c < 2 || c > 5 {
+ t.Fatalf("expect async refresh to be called [2,5) times, got, %v", c)
+ }
+
+ // Ensure enough retrieves have been done to trigger refresh.
+ if c := atomic.LoadInt32(retrievedCount); c != 5 {
+ atomic.StoreInt32(retrievedCount, 4)
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+ testWaitAsyncRefreshDone(provider)
+ }
+
+ // Last async refresh will succeed and update cached token, expect the next
+ // call to get refreshed token.
+ token, err = provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(refreshedToken, token); diff != "" {
+ t.Errorf("expect refreshed token match\n%s", diff)
+ }
+}
+
+func TestTokenCache_asyncRefreshWithMinDelay(t *testing.T) {
+ origTimeNow := timeNow
+ defer func() { timeNow = origTimeNow }()
+
+ timeNow = func() time.Time { return time.Time{} }
+
+ expectToken := Token{
+ Value: "abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(10 * time.Minute),
+ }
+ refreshedToken := Token{
+ Value: "refreshed-abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(30 * time.Minute),
+ }
+
+ retrievedCount := new(int32)
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ c := atomic.AddInt32(retrievedCount, 1)
+ switch {
+ case c == 1:
+ return expectToken, nil
+ case c > 1 && c < 5:
+ return Token{}, fmt.Errorf("some error")
+ case c == 5:
+ return refreshedToken, nil
+ default:
+ return Token{}, fmt.Errorf("unexpected error")
+ }
+ }), func(o *TokenCacheOptions) {
+ o.RefreshBeforeExpires = 5 * time.Minute
+ o.AsyncRefreshMinimumDelay = 30 * time.Second
+ })
+
+ // 1: Initial retrieve to cache token
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+
+ // 2-5: Offset time for subsequent calls to retrieve to trigger asynchronous
+ // refreshes.
+ timeNow = func() time.Time {
+ return (time.Time{}).Add(6 * time.Minute)
+ }
+
+ for i := 0; i < 4; i++ {
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+ // Wait for all async refreshes to complete ensure not deduped
+ testWaitAsyncRefreshDone(provider)
+ }
+
+ // Only a single refresh attempt is expected.
+ if e, a := 2, int(atomic.LoadInt32(retrievedCount)); e != a {
+ t.Fatalf("expect %v min async refresh, got %v", e, a)
+ }
+
+ // Move time forward to ensure another async refresh is triggered.
+ timeNow = func() time.Time { return (time.Time{}).Add(7 * time.Minute) }
+ // Make sure the next attempt refreshes the token
+ atomic.StoreInt32(retrievedCount, 4)
+
+ // Do async retrieve that will succeed refreshing in background.
+ token, err = provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+ // Wait for all async refreshes to complete ensure not deduped
+ testWaitAsyncRefreshDone(provider)
+
+ // Last async refresh will succeed and update cached token, expect the next
+ // call to get refreshed token.
+ token, err = provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(refreshedToken, token); diff != "" {
+ t.Errorf("expect refreshed token match\n%s", diff)
+ }
+}
+
+func TestTokenCache_disableAsyncRefresh(t *testing.T) {
+ origTimeNow := timeNow
+ defer func() { timeNow = origTimeNow }()
+
+ timeNow = func() time.Time { return time.Time{} }
+
+ expectToken := Token{
+ Value: "abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(10 * time.Minute),
+ }
+ refreshedToken := Token{
+ Value: "refreshed-abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(30 * time.Minute),
+ }
+
+ retrievedCount := new(int32)
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ c := atomic.AddInt32(retrievedCount, 1)
+ switch {
+ case c == 1:
+ return expectToken, nil
+ case c > 1 && c < 5:
+ return Token{}, fmt.Errorf("some error")
+ case c == 5:
+ return refreshedToken, nil
+ default:
+ return Token{}, fmt.Errorf("unexpected error")
+ }
+ }), func(o *TokenCacheOptions) {
+ o.RefreshBeforeExpires = 5 * time.Minute
+ o.DisableAsyncRefresh = true
+ })
+
+ // 1: Initial retrieve to cache token
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+
+ // Update time into refresh window before token expires
+ timeNow = func() time.Time {
+ return (time.Time{}).Add(6 * time.Minute)
+ }
+
+ for i := 0; i < 3; i++ {
+ _, err = provider.RetrieveBearerToken(context.Background())
+ if err == nil {
+ t.Fatalf("expect error, got none")
+ }
+ if e, a := "some error", err.Error(); !strings.Contains(a, e) {
+ t.Fatalf("expect %v error in %v", e, a)
+ }
+ if e, a := i+2, int(atomic.LoadInt32(retrievedCount)); e != a {
+ t.Fatalf("expect %v retrieveCount, got %v", e, a)
+ }
+ }
+ if e, a := 4, int(atomic.LoadInt32(retrievedCount)); e != a {
+ t.Fatalf("expect %v retrieveCount, got %v", e, a)
+ }
+
+ // Last refresh will succeed and update cached token, expect the next
+ // call to get refreshed token.
+ token, err = provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(refreshedToken, token); diff != "" {
+ t.Errorf("expect refreshed token match\n%s", diff)
+ }
+}
+
+func testWaitAsyncRefreshDone(provider *TokenCache) {
+ asyncResCh := provider.sfGroup.DoChan("async-refresh", func() (interface{}, error) {
+ return nil, nil
+ })
+ <-asyncResCh
+}
diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/ya.make b/vendor/github.com/aws/smithy-go/auth/bearer/ya.make
new file mode 100644
index 0000000000..3dfbfa7c07
--- /dev/null
+++ b/vendor/github.com/aws/smithy-go/auth/bearer/ya.make
@@ -0,0 +1,21 @@
+GO_LIBRARY()
+
+LICENSE(Apache-2.0)
+
+SRCS(
+ docs.go
+ middleware.go
+ token.go
+ token_cache.go
+)
+
+GO_TEST_SRCS(
+ middleware_test.go
+ token_cache_test.go
+)
+
+END()
+
+RECURSE(
+ gotest
+)