aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/aws/smithy-go/transport/http/client_test.go
blob: 274b3f8db9d87532515bbda563f286dd9565d41d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package http

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"testing"

	smithy "github.com/aws/smithy-go"
)

func TestClientHandler_Handle(t *testing.T) {
	cases := map[string]struct {
		Context   context.Context
		Client    ClientDo
		ExpectErr func(error) error
	}{
		"no error": {
			Context: context.Background(),
			Client: ClientDoFunc(func(*http.Request) (*http.Response, error) {
				return &http.Response{}, nil
			}),
		},
		"send error": {
			Context: context.Background(),
			Client: ClientDoFunc(func(*http.Request) (*http.Response, error) {
				return nil, fmt.Errorf("some error")
			}),
			ExpectErr: func(err error) error {
				var sendError *RequestSendError
				if !errors.As(err, &sendError) {
					return fmt.Errorf("expect error to be %T, %v", sendError, err)
				}

				var cancelError *smithy.CanceledError
				if errors.As(err, &cancelError) {
					return fmt.Errorf("expect error to not be %T, %v", cancelError, err)
				}

				return nil
			},
		},
		"canceled error": {
			Context: func() context.Context {
				ctx, fn := context.WithCancel(context.Background())
				fn()
				return ctx
			}(),
			Client: ClientDoFunc(func(*http.Request) (*http.Response, error) {
				return nil, fmt.Errorf("some error")
			}),
			ExpectErr: func(err error) error {
				var sendError *RequestSendError
				if errors.As(err, &sendError) {
					return fmt.Errorf("expect error to not be %T, %v", sendError, err)
				}

				var cancelError *smithy.CanceledError
				if !errors.As(err, &cancelError) {
					return fmt.Errorf("expect error to be %T, %v", cancelError, err)
				}

				return nil
			},
		},
	}

	for name, c := range cases {
		t.Run(name, func(t *testing.T) {
			handler := NewClientHandler(c.Client)
			resp, _, err := handler.Handle(c.Context, NewStackRequest())

			if c.ExpectErr != nil {
				if err == nil {
					t.Fatalf("expect error, got none")
				}
				if err = c.ExpectErr(err); err != nil {
					t.Fatalf("expect error match failed, %v", err)
				}
				return
			}
			if err != nil {
				t.Fatalf("expect no error, got %v", err)
			}

			if _, ok := resp.(*Response); !ok {
				t.Fatalf("expect Response type, got %T", resp)
			}
		})
	}

}