aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/aws/smithy-go/auth/bearer/token_cache_test.go
diff options
context:
space:
mode:
authorvitalyisaev <vitalyisaev@ydb.tech>2023-12-12 21:55:07 +0300
committervitalyisaev <vitalyisaev@ydb.tech>2023-12-12 22:25:10 +0300
commit4967f99474a4040ba150eb04995de06342252718 (patch)
treec9c118836513a8fab6e9fcfb25be5d404338bca7 /vendor/github.com/aws/smithy-go/auth/bearer/token_cache_test.go
parent2ce9cccb9b0bdd4cd7a3491dc5cbf8687cda51de (diff)
downloadydb-4967f99474a4040ba150eb04995de06342252718.tar.gz
YQ Connector: prepare code base for S3 integration
1. Кодовая база Коннектора переписана с помощью Go дженериков так, чтобы добавление нового источника данных (в частности S3 + csv) максимально переиспользовало имеющийся код (чтобы сохранялась логика нарезания на блоки данных, учёт трафика и пр.) 2. API Connector расширено для работы с S3, но ещё пока не протестировано.
Diffstat (limited to 'vendor/github.com/aws/smithy-go/auth/bearer/token_cache_test.go')
-rw-r--r--vendor/github.com/aws/smithy-go/auth/bearer/token_cache_test.go512
1 files changed, 512 insertions, 0 deletions
diff --git a/vendor/github.com/aws/smithy-go/auth/bearer/token_cache_test.go b/vendor/github.com/aws/smithy-go/auth/bearer/token_cache_test.go
new file mode 100644
index 0000000000..3d56f7ee63
--- /dev/null
+++ b/vendor/github.com/aws/smithy-go/auth/bearer/token_cache_test.go
@@ -0,0 +1,512 @@
+package bearer
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+var _ TokenProvider = (*TokenCache)(nil)
+
+func TestTokenCache_cache(t *testing.T) {
+ expectToken := Token{
+ Value: "abc123",
+ }
+
+ var retrieveCalled bool
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ if retrieveCalled {
+ t.Fatalf("expect wrapped provider to be called once")
+ }
+ retrieveCalled = true
+ return expectToken, nil
+ }))
+
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+
+ for i := 0; i < 100; i++ {
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+ }
+}
+
+func TestTokenCache_cacheConcurrent(t *testing.T) {
+ expectToken := Token{
+ Value: "abc123",
+ }
+
+ var retrieveCalled bool
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ if retrieveCalled {
+ t.Fatalf("expect wrapped provider to be called once")
+ }
+ retrieveCalled = true
+ return expectToken, nil
+ }))
+
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+
+ for i := 0; i < 100; i++ {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ t.Parallel()
+
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestTokenCache_expired(t *testing.T) {
+ origTimeNow := timeNow
+ defer func() { timeNow = origTimeNow }()
+
+ timeNow = func() time.Time { return time.Time{} }
+
+ expectToken := Token{
+ Value: "abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(10 * time.Minute),
+ }
+ refreshedToken := Token{
+ Value: "refreshed-abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(30 * time.Minute),
+ }
+
+ retrievedCount := new(int32)
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ if atomic.AddInt32(retrievedCount, 1) > 1 {
+ return refreshedToken, nil
+ }
+ return expectToken, nil
+ }))
+
+ for i := 0; i < 10; i++ {
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+ }
+ if e, a := 1, int(atomic.LoadInt32(retrievedCount)); e != a {
+ t.Errorf("expect %v provider calls, got %v", e, a)
+ }
+
+ // Offset time for refresh
+ timeNow = func() time.Time {
+ return (time.Time{}).Add(10 * time.Minute)
+ }
+
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(refreshedToken, token); diff != "" {
+ t.Errorf("expect refreshed token match\n%s", diff)
+ }
+ if e, a := 2, int(atomic.LoadInt32(retrievedCount)); e != a {
+ t.Errorf("expect %v provider calls, got %v", e, a)
+ }
+}
+
+func TestTokenCache_cancelled(t *testing.T) {
+ providerRunning := make(chan struct{})
+ providerDone := make(chan struct{})
+ var onceClose sync.Once
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ onceClose.Do(func() { close(providerRunning) })
+
+ // Provider running never receives context cancel so that if the first
+ // retrieve call is canceled all subsequent retrieve callers won't get
+ // canceled as well.
+ select {
+ case <-providerDone:
+ return Token{Value: "abc123"}, nil
+ case <-ctx.Done():
+ return Token{}, fmt.Errorf("unexpected context canceled, %w", ctx.Err())
+ }
+ }))
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ // Retrieve that will have its context canceled, should return error, but
+ // underlying provider retrieve will continue to block in the background.
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ _, err := provider.RetrieveBearerToken(ctx)
+ if err == nil {
+ t.Errorf("expect error, got none")
+
+ } else if e, a := "unexpected context canceled", err.Error(); strings.Contains(a, e) {
+ t.Errorf("unexpected context canceled received, %v", err)
+
+ } else if e, a := "context canceled", err.Error(); !strings.Contains(a, e) {
+ t.Errorf("expect %v error in, %v", e, a)
+ }
+ }()
+
+ <-providerRunning
+
+ // Retrieve that will be added to existing single flight group, (or create
+ // a new group). Returning valid token.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Errorf("expect no error, got %v", err)
+ } else {
+ if diff := cmp.Diff(Token{Value: "abc123"}, token); diff != "" {
+ t.Errorf("expect token retrieve match\n%s", diff)
+ }
+ }
+ }()
+ close(providerDone)
+
+ wg.Wait()
+}
+
+func TestTokenCache_cancelledWithTimeout(t *testing.T) {
+ providerReady := make(chan struct{})
+ var providerReadCloseOnce sync.Once
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ providerReadCloseOnce.Do(func() { close(providerReady) })
+
+ <-ctx.Done()
+ return Token{}, fmt.Errorf("token retrieve timeout, %w", ctx.Err())
+ }), func(o *TokenCacheOptions) {
+ o.RetrieveBearerTokenTimeout = time.Millisecond
+ })
+
+ var wg sync.WaitGroup
+
+ // Spin up additional retrieves that will be deduplicated and block on the
+ // original retrieve call.
+ for i := 0; i < 5; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ <-providerReady
+
+ _, err := provider.RetrieveBearerToken(context.Background())
+ if err == nil {
+ t.Errorf("expect error, got none")
+
+ } else if e, a := "token retrieve timeout", err.Error(); !strings.Contains(a, e) {
+ t.Errorf("expect %v error in, %v", e, a)
+ }
+ }()
+ }
+
+ _, err := provider.RetrieveBearerToken(context.Background())
+ if err == nil {
+ t.Errorf("expect error, got none")
+
+ } else if e, a := "token retrieve timeout", err.Error(); !strings.Contains(a, e) {
+ t.Errorf("expect %v error in, %v", e, a)
+ }
+
+ wg.Wait()
+}
+
+func TestTokenCache_asyncRefresh(t *testing.T) {
+ origTimeNow := timeNow
+ defer func() { timeNow = origTimeNow }()
+
+ timeNow = func() time.Time { return time.Time{} }
+
+ expectToken := Token{
+ Value: "abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(10 * time.Minute),
+ }
+ refreshedToken := Token{
+ Value: "refreshed-abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(30 * time.Minute),
+ }
+
+ retrievedCount := new(int32)
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ c := atomic.AddInt32(retrievedCount, 1)
+ switch {
+ case c == 1:
+ return expectToken, nil
+ case c > 1 && c < 5:
+ return Token{}, fmt.Errorf("some error")
+ case c == 5:
+ return refreshedToken, nil
+ default:
+ return Token{}, fmt.Errorf("unexpected error")
+ }
+ }), func(o *TokenCacheOptions) {
+ o.RefreshBeforeExpires = 5 * time.Minute
+ })
+
+ // 1: Initial retrieve to cache token
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+
+ // 2-5: Offset time for subsequent calls to retrieve to trigger asynchronous
+ // refreshes.
+ timeNow = func() time.Time {
+ return (time.Time{}).Add(6 * time.Minute)
+ }
+
+ for i := 0; i < 4; i++ {
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+ }
+ // Wait for all async refreshes to complete
+ testWaitAsyncRefreshDone(provider)
+
+ if c := int(atomic.LoadInt32(retrievedCount)); c < 2 || c > 5 {
+ t.Fatalf("expect async refresh to be called [2,5) times, got, %v", c)
+ }
+
+ // Ensure enough retrieves have been done to trigger refresh.
+ if c := atomic.LoadInt32(retrievedCount); c != 5 {
+ atomic.StoreInt32(retrievedCount, 4)
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+ testWaitAsyncRefreshDone(provider)
+ }
+
+ // Last async refresh will succeed and update cached token, expect the next
+ // call to get refreshed token.
+ token, err = provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(refreshedToken, token); diff != "" {
+ t.Errorf("expect refreshed token match\n%s", diff)
+ }
+}
+
+func TestTokenCache_asyncRefreshWithMinDelay(t *testing.T) {
+ origTimeNow := timeNow
+ defer func() { timeNow = origTimeNow }()
+
+ timeNow = func() time.Time { return time.Time{} }
+
+ expectToken := Token{
+ Value: "abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(10 * time.Minute),
+ }
+ refreshedToken := Token{
+ Value: "refreshed-abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(30 * time.Minute),
+ }
+
+ retrievedCount := new(int32)
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ c := atomic.AddInt32(retrievedCount, 1)
+ switch {
+ case c == 1:
+ return expectToken, nil
+ case c > 1 && c < 5:
+ return Token{}, fmt.Errorf("some error")
+ case c == 5:
+ return refreshedToken, nil
+ default:
+ return Token{}, fmt.Errorf("unexpected error")
+ }
+ }), func(o *TokenCacheOptions) {
+ o.RefreshBeforeExpires = 5 * time.Minute
+ o.AsyncRefreshMinimumDelay = 30 * time.Second
+ })
+
+ // 1: Initial retrieve to cache token
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+
+ // 2-5: Offset time for subsequent calls to retrieve to trigger asynchronous
+ // refreshes.
+ timeNow = func() time.Time {
+ return (time.Time{}).Add(6 * time.Minute)
+ }
+
+ for i := 0; i < 4; i++ {
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+ // Wait for all async refreshes to complete ensure not deduped
+ testWaitAsyncRefreshDone(provider)
+ }
+
+ // Only a single refresh attempt is expected.
+ if e, a := 2, int(atomic.LoadInt32(retrievedCount)); e != a {
+ t.Fatalf("expect %v min async refresh, got %v", e, a)
+ }
+
+ // Move time forward to ensure another async refresh is triggered.
+ timeNow = func() time.Time { return (time.Time{}).Add(7 * time.Minute) }
+ // Make sure the next attempt refreshes the token
+ atomic.StoreInt32(retrievedCount, 4)
+
+ // Do async retrieve that will succeed refreshing in background.
+ token, err = provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+ // Wait for all async refreshes to complete ensure not deduped
+ testWaitAsyncRefreshDone(provider)
+
+ // Last async refresh will succeed and update cached token, expect the next
+ // call to get refreshed token.
+ token, err = provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(refreshedToken, token); diff != "" {
+ t.Errorf("expect refreshed token match\n%s", diff)
+ }
+}
+
+func TestTokenCache_disableAsyncRefresh(t *testing.T) {
+ origTimeNow := timeNow
+ defer func() { timeNow = origTimeNow }()
+
+ timeNow = func() time.Time { return time.Time{} }
+
+ expectToken := Token{
+ Value: "abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(10 * time.Minute),
+ }
+ refreshedToken := Token{
+ Value: "refreshed-abc123",
+ CanExpire: true,
+ Expires: timeNow().Add(30 * time.Minute),
+ }
+
+ retrievedCount := new(int32)
+ provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
+ c := atomic.AddInt32(retrievedCount, 1)
+ switch {
+ case c == 1:
+ return expectToken, nil
+ case c > 1 && c < 5:
+ return Token{}, fmt.Errorf("some error")
+ case c == 5:
+ return refreshedToken, nil
+ default:
+ return Token{}, fmt.Errorf("unexpected error")
+ }
+ }), func(o *TokenCacheOptions) {
+ o.RefreshBeforeExpires = 5 * time.Minute
+ o.DisableAsyncRefresh = true
+ })
+
+ // 1: Initial retrieve to cache token
+ token, err := provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(expectToken, token); diff != "" {
+ t.Errorf("expect token match\n%s", diff)
+ }
+
+ // Update time into refresh window before token expires
+ timeNow = func() time.Time {
+ return (time.Time{}).Add(6 * time.Minute)
+ }
+
+ for i := 0; i < 3; i++ {
+ _, err = provider.RetrieveBearerToken(context.Background())
+ if err == nil {
+ t.Fatalf("expect error, got none")
+ }
+ if e, a := "some error", err.Error(); !strings.Contains(a, e) {
+ t.Fatalf("expect %v error in %v", e, a)
+ }
+ if e, a := i+2, int(atomic.LoadInt32(retrievedCount)); e != a {
+ t.Fatalf("expect %v retrieveCount, got %v", e, a)
+ }
+ }
+ if e, a := 4, int(atomic.LoadInt32(retrievedCount)); e != a {
+ t.Fatalf("expect %v retrieveCount, got %v", e, a)
+ }
+
+ // Last refresh will succeed and update cached token, expect the next
+ // call to get refreshed token.
+ token, err = provider.RetrieveBearerToken(context.Background())
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if diff := cmp.Diff(refreshedToken, token); diff != "" {
+ t.Errorf("expect refreshed token match\n%s", diff)
+ }
+}
+
+func testWaitAsyncRefreshDone(provider *TokenCache) {
+ asyncResCh := provider.sfGroup.DoChan("async-refresh", func() (interface{}, error) {
+ return nil, nil
+ })
+ <-asyncResCh
+}