aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/aws/aws-sdk-go-v2/feature/ec2/imds/request_middleware.go
blob: c8abd64916c016c26adcab841babc372844903b6 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
package imds

import (
	"bytes"
	"context"
	"fmt"
	"io/ioutil"
	"net/url"
	"path"
	"time"

	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
	"github.com/aws/aws-sdk-go-v2/aws/retry"
	"github.com/aws/smithy-go/middleware"
	smithyhttp "github.com/aws/smithy-go/transport/http"
)

// addAPIRequestMiddleware registers the common middleware for a GET IMDS
// operation on stack via addRequestMiddleware. Unless options.disableAPIToken
// is set, it also inserts options.tokenProvider into two steps:
//   - in Finalize, after the retry attempt middleware;
//   - in Deserialize, before the "OperationDeserializer" middleware.
// The first registration error encountered is returned.
func addAPIRequestMiddleware(stack *middleware.Stack,
	options Options,
	getPath func(interface{}) (string, error),
	getOutput func(*smithyhttp.Response) (interface{}, error),
) (err error) {
	if err = addRequestMiddleware(stack, options, "GET", getPath, getOutput); err != nil {
		return err
	}

	if options.disableAPIToken {
		return nil
	}

	attemptID := (*retry.Attempt)(nil).ID()
	if err = stack.Finalize.Insert(options.tokenProvider, attemptID, middleware.After); err != nil {
		return err
	}
	return stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before)
}

// addRequestMiddleware registers the middleware shared by every IMDS
// operation on stack. It registers, in order:
//   - the "ec2-imds" SDK agent key;
//   - the operation timeout;
//   - the request serializer for method/getPath;
//   - the endpoint resolver;
//   - the response deserializer for getOutput;
//   - request/response logging;
//   - the logger;
//   - retry handling.
// Registration stops at, and returns, the first error.
func addRequestMiddleware(stack *middleware.Stack,
	options Options,
	method string,
	getPath func(interface{}) (string, error),
	getOutput func(*smithyhttp.Response) (interface{}, error),
) (err error) {
	logMode := options.ClientLogMode

	// Steps run in exactly this order. The endpoint resolver is inserted
	// relative to "OperationSerializer", which the step before it registers.
	steps := []func() error{
		// Feature metadata for the SDK user agent.
		func() error {
			return awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds")(stack)
		},
		// Default timeout applied when the context has no deadline.
		func() error {
			return stack.Initialize.Add(&operationTimeout{
				DefaultTimeout: defaultOperationTimeout,
			}, middleware.Before)
		},
		// Operation serializer: sets the HTTP method and URL path.
		func() error {
			return stack.Serialize.Add(&serializeRequest{
				GetPath: getPath,
				Method:  method,
			}, middleware.After)
		},
		// Endpoint resolver, placed ahead of the operation serializer.
		func() error {
			return stack.Serialize.Insert(&resolveEndpoint{
				Endpoint:     options.Endpoint,
				EndpointMode: options.EndpointMode,
			}, "OperationSerializer", middleware.Before)
		},
		// Operation deserializer.
		func() error {
			return stack.Deserialize.Add(&deserializeResponse{
				GetOutput: getOutput,
			}, middleware.After)
		},
		// Wire logging, driven by the client log mode.
		func() error {
			return stack.Deserialize.Add(&smithyhttp.RequestResponseLogger{
				LogRequest:          logMode.IsRequest(),
				LogRequestWithBody:  logMode.IsRequestWithBody(),
				LogResponse:         logMode.IsResponse(),
				LogResponseWithBody: logMode.IsResponseWithBody(),
			}, middleware.After)
		},
		func() error {
			return addSetLoggerMiddleware(stack, options)
		},
		// Retry support.
		func() error {
			return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{
				Retryer:          options.Retryer,
				LogRetryAttempts: logMode.IsRetries(),
			})
		},
	}

	for _, step := range steps {
		if err = step(); err != nil {
			return err
		}
	}
	return nil
}

// addSetLoggerMiddleware registers smithy's set-logger middleware on stack,
// configured with the logger from o.
func addSetLoggerMiddleware(stack *middleware.Stack, o Options) error {
	logger := o.Logger
	return middleware.AddSetLoggerMiddleware(stack, logger)
}

// serializeRequest is the operation serializer. It stamps the outgoing smithy
// HTTP request with Method and with the URL path that GetPath derives from
// the operation's input parameters.
type serializeRequest struct {
	GetPath func(interface{}) (string, error)
	Method  string
}

// ID returns the name this middleware is registered under.
func (*serializeRequest) ID() string {
	return "OperationSerializer"
}

// HandleSerialize sets the request method and URL path, then continues the
// chain. It fails when the transport request is not a *smithyhttp.Request or
// when GetPath returns an error (wrapped with %w).
func (m *serializeRequest) HandleSerialize(
	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
) (
	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
	req, isHTTP := in.Request.(*smithyhttp.Request)
	if !isHTTP {
		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
	}

	urlPath, pathErr := m.GetPath(in.Parameters)
	if pathErr != nil {
		return out, metadata, fmt.Errorf("unable to get request URL path, %w", pathErr)
	}

	httpReq := req.Request
	httpReq.URL.Path = urlPath
	httpReq.Method = m.Method

	return next.HandleSerialize(ctx, in)
}

// deserializeResponse is the operation deserializer. Once the rest of the
// deserialize chain has returned, it does the following:
//   - buffers the raw response body in memory;
//   - reports any non-2xx status as a *smithyhttp.ResponseError;
//   - otherwise stores GetOutput's value as the operation result.
type deserializeResponse struct {
	GetOutput func(*smithyhttp.Response) (interface{}, error)
}

// ID returns the name this middleware is registered under.
func (*deserializeResponse) ID() string {
	return "OperationDeserializer"
}

// HandleDeserialize processes the raw HTTP response produced downstream.
// The original body is always closed. It is replaced on resp with an
// in-memory copy, so later consumers (GetOutput, or callers inspecting a
// ResponseError) can still read it.
func (m *deserializeResponse) HandleDeserialize(
	ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
) (
	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
) {
	if out, metadata, err = next.HandleDeserialize(ctx, in); err != nil {
		return out, metadata, err
	}

	resp, isHTTP := out.RawResponse.(*smithyhttp.Response)
	if !isHTTP {
		return out, metadata, fmt.Errorf(
			"unexpected transport response type, %T, want %T", out.RawResponse, resp)
	}
	rawBody := resp.Body
	defer rawBody.Close()

	// Drain the whole body now so the operation timeout's cleanup cannot race
	// a later read of the body.
	payload, readErr := ioutil.ReadAll(rawBody)
	if readErr != nil {
		return out, metadata, fmt.Errorf("read response body failed, %w", readErr)
	}
	resp.Body = ioutil.NopCloser(bytes.NewReader(payload))

	// Any status outside the range [200, 300) is an error.
	if code := resp.StatusCode; code < 200 || code >= 300 {
		return out, metadata, &smithyhttp.ResponseError{
			Response: resp,
			Err:      fmt.Errorf("request to EC2 IMDS failed"),
		}
	}

	result, outErr := m.GetOutput(resp)
	if outErr != nil {
		return out, metadata, fmt.Errorf(
			"unable to get deserialized result for response, %w", outErr,
		)
	}
	out.Result = result

	return out, metadata, nil
}

// resolveEndpoint points the outgoing request at the IMDS endpoint. A
// non-empty Endpoint wins. Otherwise EndpointMode selects the endpoint:
// IPv6 uses defaultIPv6Endpoint, while IPv4 or unset uses
// defaultIPv4Endpoint.
type resolveEndpoint struct {
	Endpoint     string
	EndpointMode EndpointModeState
}

// ID returns the name this middleware is registered under.
func (*resolveEndpoint) ID() string {
	return "ResolveEndpoint"
}

// HandleSerialize replaces the request URL with the parsed endpoint and
// continues the chain. It fails in three cases:
//   - the transport request is not a *smithyhttp.Request;
//   - EndpointMode is unrecognized;
//   - the endpoint does not parse as a URL.
func (m *resolveEndpoint) HandleSerialize(
	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
) (
	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
	req, isHTTP := in.Request.(*smithyhttp.Request)
	if !isHTTP {
		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
	}

	endpoint := m.Endpoint
	if endpoint == "" {
		switch m.EndpointMode {
		case EndpointModeStateIPv6:
			endpoint = defaultIPv6Endpoint
		case EndpointModeStateIPv4, EndpointModeStateUnset:
			endpoint = defaultIPv4Endpoint
		default:
			return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode")
		}
	}

	// The parse result is assigned before the error check, so a failed parse
	// leaves req.URL as whatever url.Parse returned.
	parsed, parseErr := url.Parse(endpoint)
	req.URL = parsed
	if parseErr != nil {
		return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", parseErr)
	}

	return next.HandleSerialize(ctx, in)
}

// defaultOperationTimeout is the timeout addRequestMiddleware gives the
// operationTimeout middleware. It applies only when the caller's context has
// no deadline.
const defaultOperationTimeout = 5 * time.Second

// operationTimeout is an Initialize-step middleware. When the incoming
// context has no deadline, it wraps that context with DefaultTimeout.
//
// The timeout's cancel function runs as soon as the next handler returns.
// The next middleware must therefore finish, including fully consuming
// anything tied to the context, before returning. Otherwise that work races
// the cancellation.
//
// A zero DefaultTimeout disables the default, and a context without a
// deadline is passed through unchanged.
type operationTimeout struct {
	DefaultTimeout time.Duration
}

// ID returns the name this middleware is registered under.
func (*operationTimeout) ID() string { return "OperationTimeout" }

// HandleInitialize calls the next handler. It passes a derived context
// bounded by DefaultTimeout when the incoming context has no deadline and
// DefaultTimeout is non-zero. Otherwise it passes the original context.
func (m *operationTimeout) HandleInitialize(
	ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
) (
	output middleware.InitializeOutput, metadata middleware.Metadata, err error,
) {
	_, hasDeadline := ctx.Deadline()
	if hasDeadline || m.DefaultTimeout == 0 {
		return next.HandleInitialize(ctx, input)
	}

	timeoutCtx, cancel := context.WithTimeout(ctx, m.DefaultTimeout)
	defer cancel()
	return next.HandleInitialize(timeoutCtx, input)
}

// appendURIPath joins add onto base with `/` separators, using path.Join, so
// the result is cleaned: duplicate slashes and `.`/`..` elements are
// resolved. If add ends with a trailing `/`, the result also ends with
// exactly one trailing `/`.
//
// When the joined path is already the root ("/"), no extra slash is
// appended. The previous implementation produced "//" for inputs such as
// ("/", "/") or ("", "/").
func appendURIPath(base, add string) string {
	reqPath := path.Join(base, add)
	if n := len(add); n > 0 && add[n-1] == '/' {
		// Avoid doubling the slash when the cleaned result already ends in
		// one (only possible when it is the root "/").
		if m := len(reqPath); m == 0 || reqPath[m-1] != '/' {
			reqPath += "/"
		}
	}
	return reqPath
}