diff options
author | vitalyisaev <vitalyisaev@ydb.tech> | 2023-12-12 21:55:07 +0300 |
---|---|---|
committer | vitalyisaev <vitalyisaev@ydb.tech> | 2023-12-12 22:25:10 +0300 |
commit | 4967f99474a4040ba150eb04995de06342252718 (patch) | |
tree | c9c118836513a8fab6e9fcfb25be5d404338bca7 /vendor/github.com/aws/aws-sdk-go-v2/service/internal | |
parent | 2ce9cccb9b0bdd4cd7a3491dc5cbf8687cda51de (diff) | |
download | ydb-4967f99474a4040ba150eb04995de06342252718.tar.gz |
YQ Connector: prepare code base for S3 integration
1. Кодовая база Коннектора переписана с помощью Go дженериков так, чтобы добавление нового источника данных (в частности S3 + csv) максимально переиспользовало имеющийся код (чтобы сохранялась логика нарезания на блоки данных, учёт трафика и пр.)
2. API Connector расширено для работы с S3, но ещё пока не протестировано.
Diffstat (limited to 'vendor/github.com/aws/aws-sdk-go-v2/service/internal')
56 files changed, 7164 insertions, 0 deletions
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/accept_encoding_gzip.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/accept_encoding_gzip.go new file mode 100644 index 0000000000..3f451fc9b4 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/accept_encoding_gzip.go @@ -0,0 +1,176 @@ +package acceptencoding + +import ( + "compress/gzip" + "context" + "fmt" + "io" + + "github.com/aws/smithy-go" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +const acceptEncodingHeaderKey = "Accept-Encoding" +const contentEncodingHeaderKey = "Content-Encoding" + +// AddAcceptEncodingGzipOptions provides the options for the +// AddAcceptEncodingGzip middleware setup. +type AddAcceptEncodingGzipOptions struct { + Enable bool +} + +// AddAcceptEncodingGzip explicitly adds handling for accept-encoding GZIP +// middleware to the operation stack. This allows checksums to be correctly +// computed without disabling GZIP support. +func AddAcceptEncodingGzip(stack *middleware.Stack, options AddAcceptEncodingGzipOptions) error { + if options.Enable { + if err := stack.Finalize.Add(&EnableGzip{}, middleware.Before); err != nil { + return err + } + if err := stack.Deserialize.Insert(&DecompressGzip{}, "OperationDeserializer", middleware.After); err != nil { + return err + } + return nil + } + + return stack.Finalize.Add(&DisableGzip{}, middleware.Before) +} + +// DisableGzip provides the middleware that will +// disable the underlying http client automatically enabling for gzip +// decompress content-encoding support. +type DisableGzip struct{} + +// ID returns the id for the middleware. +func (*DisableGzip) ID() string { + return "DisableAcceptEncodingGzip" +} + +// HandleFinalize implements the FinalizeMiddleware interface. +func (*DisableGzip) HandleFinalize( + ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler, +) ( + output middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + req, ok := input.Request.(*smithyhttp.Request) + if !ok { + return output, metadata, &smithy.SerializationError{ + Err: fmt.Errorf("unknown request type %T", input.Request), + } + } + + // Explicitly enable gzip support, this will prevent the http client from + // auto extracting the zipped content. + req.Header.Set(acceptEncodingHeaderKey, "identity") + + return next.HandleFinalize(ctx, input) +} + +// EnableGzip provides a middleware to enable support for +// gzip responses, with manual decompression. This prevents the underlying HTTP +// client from performing the gzip decompression automatically. +type EnableGzip struct{} + +// ID returns the id for the middleware. +func (*EnableGzip) ID() string { + return "AcceptEncodingGzip" +} + +// HandleFinalize implements the FinalizeMiddleware interface. +func (*EnableGzip) HandleFinalize( + ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler, +) ( + output middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + req, ok := input.Request.(*smithyhttp.Request) + if !ok { + return output, metadata, &smithy.SerializationError{ + Err: fmt.Errorf("unknown request type %T", input.Request), + } + } + + // Explicitly enable gzip support, this will prevent the http client from + // auto extracting the zipped content. + req.Header.Set(acceptEncodingHeaderKey, "gzip") + + return next.HandleFinalize(ctx, input) +} + +// DecompressGzip provides the middleware for decompressing a gzip +// response from the service. +type DecompressGzip struct{} + +// ID returns the id for the middleware. +func (*DecompressGzip) ID() string { + return "DecompressGzip" +} + +// HandleDeserialize implements the DeserializeMiddlware interface. +func (*DecompressGzip) HandleDeserialize( + ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler, +) ( + output middleware.DeserializeOutput, metadata middleware.Metadata, err error, +) { + output, metadata, err = next.HandleDeserialize(ctx, input) + if err != nil { + return output, metadata, err + } + + resp, ok := output.RawResponse.(*smithyhttp.Response) + if !ok { + return output, metadata, &smithy.DeserializationError{ + Err: fmt.Errorf("unknown response type %T", output.RawResponse), + } + } + if v := resp.Header.Get(contentEncodingHeaderKey); v != "gzip" { + return output, metadata, err + } + + // Clear content length since it will no longer be valid once the response + // body is decompressed. + resp.Header.Del("Content-Length") + resp.ContentLength = -1 + + resp.Body = wrapGzipReader(resp.Body) + + return output, metadata, err +} + +type gzipReader struct { + reader io.ReadCloser + gzip *gzip.Reader +} + +func wrapGzipReader(reader io.ReadCloser) *gzipReader { + return &gzipReader{ + reader: reader, + } +} + +// Read wraps the gzip reader around the underlying io.Reader to extract the +// response bytes on the fly. +func (g *gzipReader) Read(b []byte) (n int, err error) { + if g.gzip == nil { + g.gzip, err = gzip.NewReader(g.reader) + if err != nil { + g.gzip = nil // ensure uninitialized gzip value isn't used in close. + return 0, fmt.Errorf("failed to decompress gzip response, %w", err) + } + } + + return g.gzip.Read(b) +} + +func (g *gzipReader) Close() error { + if g.gzip == nil { + return nil + } + + if err := g.gzip.Close(); err != nil { + g.reader.Close() + return fmt.Errorf("failed to decompress gzip response, %w", err) + } + + return g.reader.Close() +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/accept_encoding_gzip_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/accept_encoding_gzip_test.go new file mode 100644 index 0000000000..71fb7e26d5 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/accept_encoding_gzip_test.go @@ -0,0 +1,215 @@ +package acceptencoding + +import ( + "bytes" + "compress/gzip" + "context" + "encoding/hex" + "io" + "io/ioutil" + "net/http" + "testing" + + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +func TestAddAcceptEncodingGzip(t *testing.T) { + cases := map[string]struct { + Enable bool + }{ + "disabled": { + Enable: false, + }, + "enabled": { + Enable: true, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + stack := middleware.NewStack("test", smithyhttp.NewStackRequest) + + stack.Deserialize.Add(&stubOpDeserializer{}, middleware.After) + + AddAcceptEncodingGzip(stack, AddAcceptEncodingGzipOptions{ + Enable: c.Enable, + }) + + id := "OperationDeserializer" + if m, ok := stack.Deserialize.Get(id); !ok || m == nil { + t.Fatalf("expect %s not to be removed", id) + } + + if c.Enable { + id = (*EnableGzip)(nil).ID() + if m, ok := stack.Finalize.Get(id); !ok || m == nil { + t.Fatalf("expect %s to be present.", id) + } + + id = (*DecompressGzip)(nil).ID() + if m, ok := stack.Deserialize.Get(id); !ok || m == nil { + t.Fatalf("expect %s to be present.", id) + } + return + } + id = (*EnableGzip)(nil).ID() + if m, ok := stack.Finalize.Get(id); ok || m != nil { + t.Fatalf("expect %s not to be present.", id) + } + + id = (*DecompressGzip)(nil).ID() + if m, ok := stack.Deserialize.Get(id); ok || m != nil { + t.Fatalf("expect %s not to be present.", id) + } + }) + } +} + +func TestAcceptEncodingGzipMiddleware(t *testing.T) { + m := &EnableGzip{} + + _, _, err := m.HandleFinalize(context.Background(), + middleware.FinalizeInput{ + Request: smithyhttp.NewStackRequest(), + }, + middleware.FinalizeHandlerFunc( + func(ctx context.Context, input middleware.FinalizeInput) ( + output middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + req, ok := input.Request.(*smithyhttp.Request) + if !ok || req == nil { + t.Fatalf("expect smithy request, got %T", input.Request) + } + + actual := req.Header.Get(acceptEncodingHeaderKey) + if e, a := "gzip", actual; e != a { + t.Errorf("expect %v accept-encoding, got %v", e, a) + } + + return output, metadata, err + }), + ) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } +} + +func TestDecompressGzipMiddleware(t *testing.T) { + cases := map[string]struct { + Response *smithyhttp.Response + ExpectBody []byte + ExpectContentLength int64 + }{ + "not compressed": { + Response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: http.Header{}, + ContentLength: 2, + Body: &wasClosedReadCloser{ + Reader: bytes.NewBuffer([]byte(`{}`)), + }, + }, + }, + ExpectBody: []byte(`{}`), + ExpectContentLength: 2, + }, + "compressed": { + Response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: http.Header{ + contentEncodingHeaderKey: []string{"gzip"}, + }, + ContentLength: 10, + Body: func() io.ReadCloser { + var buf bytes.Buffer + w := gzip.NewWriter(&buf) + w.Write([]byte(`{}`)) + w.Close() + + return &wasClosedReadCloser{Reader: &buf} + }(), + }, + }, + ExpectBody: []byte(`{}`), + ExpectContentLength: -1, // Length empty because was decompressed + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + m := &DecompressGzip{} + + var origRespBody io.Reader + output, _, err := m.HandleDeserialize(context.Background(), + middleware.DeserializeInput{}, + middleware.DeserializeHandlerFunc( + func(ctx context.Context, input middleware.DeserializeInput) ( + output middleware.DeserializeOutput, metadata middleware.Metadata, err error, + ) { + output.RawResponse = c.Response + origRespBody = c.Response.Body + return output, metadata, err + }), + ) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + resp, ok := output.RawResponse.(*smithyhttp.Response) + if !ok || resp == nil { + t.Fatalf("expect smithy request, got %T", output.RawResponse) + } + + if e, a := c.ExpectContentLength, resp.ContentLength; e != a { + t.Errorf("expect %v content-length, got %v", e, a) + } + + actual, err := ioutil.ReadAll(resp.Body) + if e, a := c.ExpectBody, actual; !bytes.Equal(e, a) { + t.Errorf("expect body equal\nexpect:\n%s\nactual:\n%s", + hex.Dump(e), hex.Dump(a)) + } + + if err := resp.Body.Close(); err != nil { + t.Fatalf("expect no close error, got %v", err) + } + + if c, ok := origRespBody.(interface{ WasClosed() bool }); ok { + if !c.WasClosed() { + t.Errorf("expect original reader closed, but was not") + } + } + }) + } +} + +type stubOpDeserializer struct{} + +func (*stubOpDeserializer) ID() string { return "OperationDeserializer" } +func (*stubOpDeserializer) HandleDeserialize( + ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler, +) ( + output middleware.DeserializeOutput, metadata middleware.Metadata, err error, +) { + return next.HandleDeserialize(ctx, input) +} + +type wasClosedReadCloser struct { + io.Reader + closed bool +} + +func (c *wasClosedReadCloser) WasClosed() bool { + return c.closed +} + +func (c *wasClosedReadCloser) Close() error { + c.closed = true + if v, ok := c.Reader.(io.Closer); ok { + return v.Close() + } + return nil +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/doc.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/doc.go new file mode 100644 index 0000000000..7056d9bf6f --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/doc.go @@ -0,0 +1,22 @@ +/* +Package acceptencoding provides customizations associated with Accept Encoding Header. + +# Accept encoding gzip + +The Go HTTP client automatically supports accept-encoding and content-encoding +gzip by default. This default behavior is not desired by the SDK, and prevents +validating the response body's checksum. To prevent this the SDK must manually +control usage of content-encoding gzip. + +To control content-encoding, the SDK must always set the `Accept-Encoding` +header to a value. This prevents the HTTP client from using gzip automatically. +When gzip is enabled on the API client, the SDK's customization will control +decompressing the gzip data in order to not break the checksum validation. When +gzip is disabled, the API client will disable gzip, preventing the HTTP +client's default behavior. + +An `EnableAcceptEncodingGzip` option may or may not be present depending on the client using +the below middleware. The option if present can be used to enable auto decompressing +gzip by the SDK. +*/ +package acceptencoding diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/go_module_metadata.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/go_module_metadata.go new file mode 100644 index 0000000000..e57bbab9fe --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/go_module_metadata.go @@ -0,0 +1,6 @@ +// Code generated by internal/repotools/cmd/updatemodulemeta DO NOT EDIT. + +package acceptencoding + +// goModuleVersion is the tagged release for this module +const goModuleVersion = "1.10.0" diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/gotest/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/gotest/ya.make new file mode 100644 index 0000000000..87dff7ba39 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/gotest/ya.make @@ -0,0 +1,5 @@ +GO_TEST_FOR(vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) + +LICENSE(Apache-2.0) + +END() diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/ya.make new file mode 100644 index 0000000000..9e18eb7a5a --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding/ya.make @@ -0,0 +1,17 @@ +GO_LIBRARY() + +LICENSE(Apache-2.0) + +SRCS( + accept_encoding_gzip.go + doc.go + go_module_metadata.go +) + +GO_TEST_SRCS(accept_encoding_gzip_test.go) + +END() + +RECURSE( + gotest +) diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms.go new file mode 100644 index 0000000000..a17041c35d --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms.go @@ -0,0 +1,323 @@ +package checksum + +import ( + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "hash" + "hash/crc32" + "io" + "strings" + "sync" +) + +// Algorithm represents the checksum algorithms supported +type Algorithm string + +// Enumeration values for supported checksum Algorithms. +const ( + // AlgorithmCRC32C represents CRC32C hash algorithm + AlgorithmCRC32C Algorithm = "CRC32C" + + // AlgorithmCRC32 represents CRC32 hash algorithm + AlgorithmCRC32 Algorithm = "CRC32" + + // AlgorithmSHA1 represents SHA1 hash algorithm + AlgorithmSHA1 Algorithm = "SHA1" + + // AlgorithmSHA256 represents SHA256 hash algorithm + AlgorithmSHA256 Algorithm = "SHA256" +) + +var supportedAlgorithms = []Algorithm{ + AlgorithmCRC32C, + AlgorithmCRC32, + AlgorithmSHA1, + AlgorithmSHA256, +} + +func (a Algorithm) String() string { return string(a) } + +// ParseAlgorithm attempts to parse the provided value into a checksum +// algorithm, matching without case. Returns the algorithm matched, or an error +// if the algorithm wasn't matched. +func ParseAlgorithm(v string) (Algorithm, error) { + for _, a := range supportedAlgorithms { + if strings.EqualFold(string(a), v) { + return a, nil + } + } + return "", fmt.Errorf("unknown checksum algorithm, %v", v) +} + +// FilterSupportedAlgorithms filters the set of algorithms, returning a slice +// of algorithms that are supported. +func FilterSupportedAlgorithms(vs []string) []Algorithm { + found := map[Algorithm]struct{}{} + + supported := make([]Algorithm, 0, len(supportedAlgorithms)) + for _, v := range vs { + for _, a := range supportedAlgorithms { + // Only consider algorithms that are supported + if !strings.EqualFold(v, string(a)) { + continue + } + // Ignore duplicate algorithms in list. + if _, ok := found[a]; ok { + continue + } + + supported = append(supported, a) + found[a] = struct{}{} + } + } + return supported +} + +// NewAlgorithmHash returns a hash.Hash for the checksum algorithm. Error is +// returned if the algorithm is unknown. +func NewAlgorithmHash(v Algorithm) (hash.Hash, error) { + switch v { + case AlgorithmSHA1: + return sha1.New(), nil + case AlgorithmSHA256: + return sha256.New(), nil + case AlgorithmCRC32: + return crc32.NewIEEE(), nil + case AlgorithmCRC32C: + return crc32.New(crc32.MakeTable(crc32.Castagnoli)), nil + default: + return nil, fmt.Errorf("unknown checksum algorithm, %v", v) + } +} + +// AlgorithmChecksumLength returns the length of the algorithm's checksum in +// bytes. If the algorithm is not known, an error is returned. +func AlgorithmChecksumLength(v Algorithm) (int, error) { + switch v { + case AlgorithmSHA1: + return sha1.Size, nil + case AlgorithmSHA256: + return sha256.Size, nil + case AlgorithmCRC32: + return crc32.Size, nil + case AlgorithmCRC32C: + return crc32.Size, nil + default: + return 0, fmt.Errorf("unknown checksum algorithm, %v", v) + } +} + +const awsChecksumHeaderPrefix = "x-amz-checksum-" + +// AlgorithmHTTPHeader returns the HTTP header for the algorithm's hash. +func AlgorithmHTTPHeader(v Algorithm) string { + return awsChecksumHeaderPrefix + strings.ToLower(string(v)) +} + +// base64EncodeHashSum computes base64 encoded checksum of a given running +// hash. The running hash must already have content written to it. Returns the +// byte slice of checksum and an error +func base64EncodeHashSum(h hash.Hash) []byte { + sum := h.Sum(nil) + sum64 := make([]byte, base64.StdEncoding.EncodedLen(len(sum))) + base64.StdEncoding.Encode(sum64, sum) + return sum64 +} + +// hexEncodeHashSum computes hex encoded checksum of a given running hash. The +// running hash must already have content written to it. Returns the byte slice +// of checksum and an error +func hexEncodeHashSum(h hash.Hash) []byte { + sum := h.Sum(nil) + sumHex := make([]byte, hex.EncodedLen(len(sum))) + hex.Encode(sumHex, sum) + return sumHex +} + +// computeMD5Checksum computes base64 MD5 checksum of an io.Reader's contents. +// Returns the byte slice of MD5 checksum and an error. +func computeMD5Checksum(r io.Reader) ([]byte, error) { + h := md5.New() + + // Copy errors may be assumed to be from the body. + if _, err := io.Copy(h, r); err != nil { + return nil, fmt.Errorf("failed compute MD5 hash of reader, %w", err) + } + + // Encode the MD5 checksum in base64. + return base64EncodeHashSum(h), nil +} + +// computeChecksumReader provides a reader wrapping an underlying io.Reader to +// compute the checksum of the stream's bytes. +type computeChecksumReader struct { + stream io.Reader + algorithm Algorithm + hasher hash.Hash + base64ChecksumLen int + + mux sync.RWMutex + lockedChecksum string + lockedErr error +} + +// newComputeChecksumReader returns a computeChecksumReader for the stream and +// algorithm specified. Returns error if unable to create the reader, or +// algorithm is unknown. +func newComputeChecksumReader(stream io.Reader, algorithm Algorithm) (*computeChecksumReader, error) { + hasher, err := NewAlgorithmHash(algorithm) + if err != nil { + return nil, err + } + + checksumLength, err := AlgorithmChecksumLength(algorithm) + if err != nil { + return nil, err + } + + return &computeChecksumReader{ + stream: io.TeeReader(stream, hasher), + algorithm: algorithm, + hasher: hasher, + base64ChecksumLen: base64.StdEncoding.EncodedLen(checksumLength), + }, nil +} + +// Read wraps the underlying reader. When the underlying reader returns EOF, +// the checksum of the reader will be computed, and can be retrieved with +// ChecksumBase64String. +func (r *computeChecksumReader) Read(p []byte) (int, error) { + n, err := r.stream.Read(p) + if err == nil { + return n, nil + } else if err != io.EOF { + r.mux.Lock() + defer r.mux.Unlock() + + r.lockedErr = err + return n, err + } + + b := base64EncodeHashSum(r.hasher) + + r.mux.Lock() + defer r.mux.Unlock() + + r.lockedChecksum = string(b) + + return n, err +} + +func (r *computeChecksumReader) Algorithm() Algorithm { + return r.algorithm +} + +// Base64ChecksumLength returns the base64 encoded length of the checksum for +// algorithm. +func (r *computeChecksumReader) Base64ChecksumLength() int { + return r.base64ChecksumLen +} + +// Base64Checksum returns the base64 checksum for the algorithm, or error if +// the underlying reader returned a non-EOF error. +// +// Safe to be called concurrently, but will return an error until after the +// underlying reader is returns EOF. +func (r *computeChecksumReader) Base64Checksum() (string, error) { + r.mux.RLock() + defer r.mux.RUnlock() + + if r.lockedErr != nil { + return "", r.lockedErr + } + + if r.lockedChecksum == "" { + return "", fmt.Errorf( + "checksum not available yet, called before reader returns EOF", + ) + } + + return r.lockedChecksum, nil +} + +// validateChecksumReader implements io.ReadCloser interface. The wrapper +// performs checksum validation when the underlying reader has been fully read. +type validateChecksumReader struct { + originalBody io.ReadCloser + body io.Reader + hasher hash.Hash + algorithm Algorithm + expectChecksum string +} + +// newValidateChecksumReader returns a configured io.ReadCloser that performs +// checksum validation when the underlying reader has been fully read. +func newValidateChecksumReader( + body io.ReadCloser, + algorithm Algorithm, + expectChecksum string, +) (*validateChecksumReader, error) { + hasher, err := NewAlgorithmHash(algorithm) + if err != nil { + return nil, err + } + + return &validateChecksumReader{ + originalBody: body, + body: io.TeeReader(body, hasher), + hasher: hasher, + algorithm: algorithm, + expectChecksum: expectChecksum, + }, nil +} + +// Read attempts to read from the underlying stream while also updating the +// running hash. If the underlying stream returns with an EOF error, the +// checksum of the stream will be collected, and compared against the expected +// checksum. If the checksums do not match, an error will be returned. +// +// If a non-EOF error occurs when reading the underlying stream, that error +// will be returned and the checksum for the stream will be discarded. +func (c *validateChecksumReader) Read(p []byte) (n int, err error) { + n, err = c.body.Read(p) + if err == io.EOF { + if checksumErr := c.validateChecksum(); checksumErr != nil { + return n, checksumErr + } + } + + return n, err +} + +// Close closes the underlying reader, returning any error that occurred in the +// underlying reader. +func (c *validateChecksumReader) Close() (err error) { + return c.originalBody.Close() +} + +func (c *validateChecksumReader) validateChecksum() error { + // Compute base64 encoded checksum hash of the payload's read bytes. + v := base64EncodeHashSum(c.hasher) + if e, a := c.expectChecksum, string(v); !strings.EqualFold(e, a) { + return validationError{ + Algorithm: c.algorithm, Expect: e, Actual: a, + } + } + + return nil +} + +type validationError struct { + Algorithm Algorithm + Expect string + Actual string +} + +func (v validationError) Error() string { + return fmt.Sprintf("checksum did not match: algorithm %v, expect %v, actual %v", + v.Algorithm, v.Expect, v.Actual) +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms_test.go new file mode 100644 index 0000000000..3f8b27018a --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms_test.go @@ -0,0 +1,470 @@ +//go:build go1.16 +// +build go1.16 + +package checksum + +import ( + "bytes" + "crypto/sha1" + "crypto/sha256" + "encoding/base64" + "fmt" + "hash/crc32" + "io" + "io/ioutil" + "strings" + "testing" + "testing/iotest" + + "github.com/google/go-cmp/cmp" +) + +func TestComputeChecksumReader(t *testing.T) { + cases := map[string]struct { + Input io.Reader + Algorithm Algorithm + ExpectErr string + ExpectChecksumLen int + + ExpectRead string + ExpectReadErr string + ExpectComputeErr string + ExpectChecksum string + }{ + "unknown algorithm": { + Input: bytes.NewBuffer(nil), + Algorithm: Algorithm("something"), + ExpectErr: "unknown checksum algorithm", + }, + "read error": { + Input: iotest.ErrReader(fmt.Errorf("some error")), + Algorithm: AlgorithmSHA256, + ExpectChecksumLen: base64.StdEncoding.EncodedLen(sha256.Size), + ExpectReadErr: "some error", + ExpectComputeErr: "some error", + }, + "crc32c": { + Input: strings.NewReader("hello world"), + Algorithm: AlgorithmCRC32C, + ExpectChecksumLen: base64.StdEncoding.EncodedLen(crc32.Size), + ExpectRead: "hello world", + ExpectChecksum: "yZRlqg==", + }, + "crc32": { + Input: strings.NewReader("hello world"), + Algorithm: AlgorithmCRC32, + ExpectChecksumLen: base64.StdEncoding.EncodedLen(crc32.Size), + ExpectRead: "hello world", + ExpectChecksum: "DUoRhQ==", + }, + "sha1": { + Input: strings.NewReader("hello world"), + Algorithm: AlgorithmSHA1, + ExpectChecksumLen: base64.StdEncoding.EncodedLen(sha1.Size), + ExpectRead: "hello world", + ExpectChecksum: "Kq5sNclPz7QV2+lfQIuc6R7oRu0=", + }, + "sha256": { + Input: strings.NewReader("hello world"), + Algorithm: AlgorithmSHA256, + ExpectChecksumLen: base64.StdEncoding.EncodedLen(sha256.Size), + ExpectRead: "hello world", + ExpectChecksum: "uU0nuZNNPgilLlLX2n2r+sSE7+N6U4DukIj3rOLvzek=", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + // Validate reader can be created as expected. + r, err := newComputeChecksumReader(c.Input, c.Algorithm) + if err == nil && len(c.ExpectErr) != 0 { + t.Fatalf("expect error %v, got none", c.ExpectErr) + } + if err != nil && len(c.ExpectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.ExpectErr) { + t.Fatalf("expect error to contain %v, got %v", c.ExpectErr, err) + } + if c.ExpectErr != "" { + return + } + + if e, a := c.Algorithm, r.Algorithm(); e != a { + t.Errorf("expect %v algorithm, got %v", e, a) + } + + // Validate expected checksum length. + if e, a := c.ExpectChecksumLen, r.Base64ChecksumLength(); e != a { + t.Errorf("expect %v checksum length, got %v", e, a) + } + + // Validate read reads underlying stream's bytes as expected. + b, err := ioutil.ReadAll(r) + if err == nil && len(c.ExpectReadErr) != 0 { + t.Fatalf("expect error %v, got none", c.ExpectReadErr) + } + if err != nil && len(c.ExpectReadErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.ExpectReadErr) { + t.Fatalf("expect error to contain %v, got %v", c.ExpectReadErr, err) + } + if len(c.ExpectReadErr) != 0 { + return + } + + if diff := cmp.Diff(string(c.ExpectRead), string(b)); diff != "" { + t.Errorf("expect read match, got\n%v", diff) + } + + // validate computed base64 + v, err := r.Base64Checksum() + if err == nil && len(c.ExpectComputeErr) != 0 { + t.Fatalf("expect error %v, got none", c.ExpectComputeErr) + } + if err != nil && len(c.ExpectComputeErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.ExpectComputeErr) { + t.Fatalf("expect error to contain %v, got %v", c.ExpectComputeErr, err) + } + if diff := cmp.Diff(c.ExpectChecksum, v); diff != "" { + t.Errorf("expect checksum match, got\n%v", diff) + } + if c.ExpectComputeErr != "" { + return + } + + if e, a := c.ExpectChecksumLen, len(v); e != a { + t.Errorf("expect computed checksum length %v, got %v", e, a) + } + }) + } +} + +func TestComputeChecksumReader_earlyGetChecksum(t *testing.T) { + r, err := newComputeChecksumReader(strings.NewReader("hello world"), AlgorithmCRC32C) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + v, err := r.Base64Checksum() + if err == nil { + t.Fatalf("expect error, got none") + } + if err != nil && !strings.Contains(err.Error(), "not available") { + t.Fatalf("expect error to match, got %v", err) + } + if v != "" { + t.Errorf("expect no checksum, got %v", err) + } +} + +// TODO race condition case with many reads, and get checksum + +func TestValidateChecksumReader(t *testing.T) { + cases := map[string]struct { + payload io.ReadCloser + algorithm Algorithm + checksum string + expectErr string + expectChecksumErr string + expectedBody []byte + }{ + "unknown algorithm": { + payload: ioutil.NopCloser(bytes.NewBuffer(nil)), + algorithm: Algorithm("something"), + expectErr: "unknown checksum algorithm", + }, + "empty body": { + payload: ioutil.NopCloser(bytes.NewReader([]byte(""))), + algorithm: AlgorithmCRC32, + checksum: "AAAAAA==", + expectedBody: []byte(""), + }, + "standard body": { + payload: ioutil.NopCloser(bytes.NewReader([]byte("Hello world"))), + algorithm: AlgorithmCRC32, + checksum: "i9aeUg==", + expectedBody: []byte("Hello world"), + }, + "checksum mismatch": { + payload: ioutil.NopCloser(bytes.NewReader([]byte("Hello world"))), + algorithm: AlgorithmCRC32, + checksum: "AAAAAA==", + expectedBody: []byte("Hello world"), + expectChecksumErr: "checksum did not match", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + response, err := newValidateChecksumReader(c.payload, c.algorithm, c.checksum) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if response == nil { + if c.expectedBody == nil { + return + } + t.Fatalf("expected non nil response, got nil") + } + + actualResponse, err := ioutil.ReadAll(response) + if err == nil && len(c.expectChecksumErr) != 0 { + t.Fatalf("expected error %v, got none", c.expectChecksumErr) + } + if err != nil && !strings.Contains(err.Error(), c.expectChecksumErr) { + t.Fatalf("expected error %v to contain %v", err.Error(), c.expectChecksumErr) + } + + if diff := cmp.Diff(c.expectedBody, actualResponse); len(diff) != 0 { + t.Fatalf("found diff comparing response body %v", diff) + } + + err = response.Close() + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + }) + } +} + +func TestComputeMD5Checksum(t *testing.T) { + cases := map[string]struct { + payload io.Reader + expectErr string + expectChecksum string + }{ + "empty payload": { + payload: strings.NewReader(""), + expectChecksum: "1B2M2Y8AsgTpgAmY7PhCfg==", + }, + "payload": { + payload: strings.NewReader("hello world"), + expectChecksum: "XrY7u+Ae7tCTyyK7j1rNww==", + }, + "error payload": { + payload: iotest.ErrReader(fmt.Errorf("some error")), + expectErr: "some error", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + actualChecksum, err := computeMD5Checksum(c.payload) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if e, a := c.expectChecksum, string(actualChecksum); !strings.EqualFold(e, a) { + t.Errorf("expect %v checksum, got %v", e, a) + } + }) + } +} + +func TestParseAlgorithm(t *testing.T) { + cases := map[string]struct { + Value string + expectAlgorithm Algorithm + expectErr string + }{ + "crc32c": { + Value: "crc32c", + expectAlgorithm: AlgorithmCRC32C, + }, + "CRC32C": { + Value: "CRC32C", + expectAlgorithm: AlgorithmCRC32C, + }, + "crc32": { + Value: "crc32", + expectAlgorithm: AlgorithmCRC32, + }, + "CRC32": { + Value: "CRC32", + expectAlgorithm: AlgorithmCRC32, + }, + "sha1": { + Value: "sha1", + expectAlgorithm: AlgorithmSHA1, + }, + "SHA1": { + Value: "SHA1", + expectAlgorithm: AlgorithmSHA1, + }, + "sha256": { + Value: "sha256", + expectAlgorithm: AlgorithmSHA256, + }, + "SHA256": { + Value: "SHA256", + expectAlgorithm: AlgorithmSHA256, + }, + "empty": { + Value: "", + expectErr: "unknown checksum algorithm", + }, + "unknown": { + Value: "unknown", + expectErr: "unknown checksum algorithm", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + // Asserts + algorithm, err := ParseAlgorithm(c.Value) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if e, a := c.expectAlgorithm, algorithm; e != a { + t.Errorf("expect %v algorithm, got %v", e, a) + } + }) + } +} + +func TestFilterSupportedAlgorithms(t *testing.T) { + cases := map[string]struct { + values []string + expectAlgorithms []Algorithm + }{ + "no algorithms": { + expectAlgorithms: []Algorithm{}, + }, + "no supported algorithms": { + values: []string{"abc", "123"}, + expectAlgorithms: []Algorithm{}, + }, + "duplicate algorithms": { + values: []string{"crc32", "crc32c", "crc32c"}, + expectAlgorithms: []Algorithm{ + AlgorithmCRC32, + AlgorithmCRC32C, + }, + }, + "preserve order": { + values: []string{"crc32", "crc32c", "sha1", "sha256"}, + expectAlgorithms: []Algorithm{ + AlgorithmCRC32, + AlgorithmCRC32C, + AlgorithmSHA1, + AlgorithmSHA256, + }, + }, + "preserve order 2": { + values: []string{"sha256", "crc32", "sha1", "crc32c"}, + expectAlgorithms: []Algorithm{ + AlgorithmSHA256, + AlgorithmCRC32, + AlgorithmSHA1, + AlgorithmCRC32C, + }, + }, + "mixed case": { + values: []string{"Crc32", "cRc32c", "shA1", "sHA256"}, + expectAlgorithms: []Algorithm{ + AlgorithmCRC32, + AlgorithmCRC32C, + AlgorithmSHA1, + AlgorithmSHA256, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + algorithms := FilterSupportedAlgorithms(c.values) + if diff := cmp.Diff(c.expectAlgorithms, algorithms); diff != "" { + t.Errorf("expect algorithms match\n%s", diff) + } + }) + } +} + +func TestAlgorithmChecksumLength(t *testing.T) { + cases := map[string]struct { + algorithm Algorithm + expectErr string + expectLength int + }{ + "empty": { + algorithm: "", + expectErr: "unknown checksum algorithm", + }, + "unknown": { + algorithm: "", + expectErr: "unknown checksum algorithm", + }, + "crc32": { + algorithm: AlgorithmCRC32, + expectLength: crc32.Size, + }, + "crc32c": { + algorithm: AlgorithmCRC32C, + expectLength: crc32.Size, + }, + "sha1": { + algorithm: AlgorithmSHA1, + expectLength: sha1.Size, + }, + "sha256": { + algorithm: AlgorithmSHA256, + expectLength: sha256.Size, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + l, err := AlgorithmChecksumLength(c.algorithm) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if e, a := c.expectLength, l; e != a { + t.Errorf("expect %v checksum length, got %v", e, a) + } + }) + } +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding.go new file mode 100644 index 0000000000..3bd320c437 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding.go @@ -0,0 +1,389 @@ +package checksum + +import ( + "bytes" + "fmt" + "io" + "strconv" + "strings" +) + +const ( + crlf = "\r\n" + + // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html + defaultChunkLength = 1024 * 64 + + awsTrailerHeaderName = "x-amz-trailer" + decodedContentLengthHeaderName = "x-amz-decoded-content-length" + + contentEncodingHeaderName = "content-encoding" + awsChunkedContentEncodingHeaderValue = "aws-chunked" + + trailerKeyValueSeparator = ":" +) + +var ( + crlfBytes = []byte(crlf) + finalChunkBytes = []byte("0" + crlf) +) + +type awsChunkedEncodingOptions struct { + // The total size of the stream. For unsigned encoding this implies that + // there will only be a single chunk containing the underlying payload, + // unless ChunkLength is also specified. + StreamLength int64 + + // Set of trailer key:value pairs that will be appended to the end of the + // payload after the end chunk has been written. + Trailers map[string]awsChunkedTrailerValue + + // The maximum size of each chunk to be sent. Default value of -1, signals + // that optimal chunk length will be used automatically. ChunkSize must be + // at least 8KB. + // + // If ChunkLength and StreamLength are both specified, the stream will be + // broken up into ChunkLength chunks. The encoded length of the aws-chunked + // encoding can still be determined as long as all trailers, if any, have a + // fixed length. + ChunkLength int +} + +type awsChunkedTrailerValue struct { + // Function to retrieve the value of the trailer. Will only be called after + // the underlying stream returns EOF error. + Get func() (string, error) + + // If the length of the value can be pre-determined, and is constant + // specify the length. A value of -1 means the length is unknown, or + // cannot be pre-determined. + Length int +} + +// awsChunkedEncoding provides a reader that wraps the payload such that +// payload is read as a single aws-chunk payload. This reader can only be used +// if the content length of payload is known. Content-Length is used as size of +// the single payload chunk. The final chunk and trailing checksum is appended +// at the end. +// +// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html#sigv4-chunked-body-definition +// +// Here is the aws-chunked payload stream as read from the awsChunkedEncoding +// if original request stream is "Hello world", and checksum hash used is SHA256 +// +// <b>\r\n +// Hello world\r\n +// 0\r\n +// x-amz-checksum-sha256:ZOyIygCyaOW6GjVnihtTFtIS9PNmskdyMlNKiuyjfzw=\r\n +// \r\n +type awsChunkedEncoding struct { + options awsChunkedEncodingOptions + + encodedStream io.Reader + trailerEncodedLength int +} + +// newUnsignedAWSChunkedEncoding returns a new awsChunkedEncoding configured +// for unsigned aws-chunked content encoding. Any additional trailers that need +// to be appended after the end chunk must be included as via Trailer +// callbacks. +func newUnsignedAWSChunkedEncoding( + stream io.Reader, + optFns ...func(*awsChunkedEncodingOptions), +) *awsChunkedEncoding { + options := awsChunkedEncodingOptions{ + Trailers: map[string]awsChunkedTrailerValue{}, + StreamLength: -1, + ChunkLength: -1, + } + for _, fn := range optFns { + fn(&options) + } + + var chunkReader io.Reader + if options.ChunkLength != -1 || options.StreamLength == -1 { + if options.ChunkLength == -1 { + options.ChunkLength = defaultChunkLength + } + chunkReader = newBufferedAWSChunkReader(stream, options.ChunkLength) + } else { + chunkReader = newUnsignedChunkReader(stream, options.StreamLength) + } + + trailerReader := newAWSChunkedTrailerReader(options.Trailers) + + return &awsChunkedEncoding{ + options: options, + encodedStream: io.MultiReader(chunkReader, + trailerReader, + bytes.NewBuffer(crlfBytes), + ), + trailerEncodedLength: trailerReader.EncodedLength(), + } +} + +// EncodedLength returns the final length of the aws-chunked content encoded +// stream if it can be determined without reading the underlying stream or lazy +// header values, otherwise -1 is returned. +func (e *awsChunkedEncoding) EncodedLength() int64 { + var length int64 + if e.options.StreamLength == -1 || e.trailerEncodedLength == -1 { + return -1 + } + + if e.options.StreamLength != 0 { + // If the stream length is known, and there is no chunk length specified, + // only a single chunk will be used. Otherwise the stream length needs to + // include the multiple chunk padding content. + if e.options.ChunkLength == -1 { + length += getUnsignedChunkBytesLength(e.options.StreamLength) + + } else { + // Compute chunk header and payload length + numChunks := e.options.StreamLength / int64(e.options.ChunkLength) + length += numChunks * getUnsignedChunkBytesLength(int64(e.options.ChunkLength)) + if remainder := e.options.StreamLength % int64(e.options.ChunkLength); remainder != 0 { + length += getUnsignedChunkBytesLength(remainder) + } + } + } + + // End chunk + length += int64(len(finalChunkBytes)) + + // Trailers + length += int64(e.trailerEncodedLength) + + // Encoding terminator + length += int64(len(crlf)) + + return length +} + +func getUnsignedChunkBytesLength(payloadLength int64) int64 { + payloadLengthStr := strconv.FormatInt(payloadLength, 16) + return int64(len(payloadLengthStr)) + int64(len(crlf)) + payloadLength + int64(len(crlf)) +} + +// HTTPHeaders returns the set of headers that must be included the request for +// aws-chunked to work. This includes the content-encoding: aws-chunked header. +// +// If there are multiple layered content encoding, the aws-chunked encoding +// must be appended to the previous layers the stream's encoding. The best way +// to do this is to append all header values returned to the HTTP request's set +// of headers. +func (e *awsChunkedEncoding) HTTPHeaders() map[string][]string { + headers := map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + } + + if len(e.options.Trailers) != 0 { + trailers := make([]string, 0, len(e.options.Trailers)) + for name := range e.options.Trailers { + trailers = append(trailers, strings.ToLower(name)) + } + headers[awsTrailerHeaderName] = trailers + } + + return headers +} + +func (e *awsChunkedEncoding) Read(b []byte) (n int, err error) { + return e.encodedStream.Read(b) +} + +// awsChunkedTrailerReader provides a lazy reader for reading of aws-chunked +// content encoded trailers. The trailer values will not be retrieved until the +// reader is read from. +type awsChunkedTrailerReader struct { + reader *bytes.Buffer + trailers map[string]awsChunkedTrailerValue + trailerEncodedLength int +} + +// newAWSChunkedTrailerReader returns an initialized awsChunkedTrailerReader to +// lazy reading aws-chunk content encoded trailers. +func newAWSChunkedTrailerReader(trailers map[string]awsChunkedTrailerValue) *awsChunkedTrailerReader { + return &awsChunkedTrailerReader{ + trailers: trailers, + trailerEncodedLength: trailerEncodedLength(trailers), + } +} + +func trailerEncodedLength(trailers map[string]awsChunkedTrailerValue) (length int) { + for name, trailer := range trailers { + length += len(name) + len(trailerKeyValueSeparator) + l := trailer.Length + if l == -1 { + return -1 + } + length += l + len(crlf) + } + + return length +} + +// EncodedLength returns the length of the encoded trailers if the length could +// be determined without retrieving the header values. Returns -1 if length is +// unknown. +func (r *awsChunkedTrailerReader) EncodedLength() (length int) { + return r.trailerEncodedLength +} + +// Read populates the passed in byte slice with bytes from the encoded +// trailers. Will lazy read header values first time Read is called. +func (r *awsChunkedTrailerReader) Read(p []byte) (int, error) { + if r.trailerEncodedLength == 0 { + return 0, io.EOF + } + + if r.reader == nil { + trailerLen := r.trailerEncodedLength + if r.trailerEncodedLength == -1 { + trailerLen = 0 + } + r.reader = bytes.NewBuffer(make([]byte, 0, trailerLen)) + for name, trailer := range r.trailers { + r.reader.WriteString(name) + r.reader.WriteString(trailerKeyValueSeparator) + v, err := trailer.Get() + if err != nil { + return 0, fmt.Errorf("failed to get trailer value, %w", err) + } + r.reader.WriteString(v) + r.reader.WriteString(crlf) + } + } + + return r.reader.Read(p) +} + +// newUnsignedChunkReader returns an io.Reader encoding the underlying reader +// as unsigned aws-chunked chunks. The returned reader will also include the +// end chunk, but not the aws-chunked final `crlf` segment so trailers can be +// added. +// +// If the payload size is -1 for unknown length the content will be buffered in +// defaultChunkLength chunks before wrapped in aws-chunked chunk encoding. +func newUnsignedChunkReader(reader io.Reader, payloadSize int64) io.Reader { + if payloadSize == -1 { + return newBufferedAWSChunkReader(reader, defaultChunkLength) + } + + var endChunk bytes.Buffer + if payloadSize == 0 { + endChunk.Write(finalChunkBytes) + return &endChunk + } + + endChunk.WriteString(crlf) + endChunk.Write(finalChunkBytes) + + var header bytes.Buffer + header.WriteString(strconv.FormatInt(payloadSize, 16)) + header.WriteString(crlf) + return io.MultiReader( + &header, + reader, + &endChunk, + ) +} + +// Provides a buffered aws-chunked chunk encoder of an underlying io.Reader. +// Will include end chunk, but not the aws-chunked final `crlf` segment so +// trailers can be added. +// +// Note does not implement support for chunk extensions, e.g. chunk signing. +type bufferedAWSChunkReader struct { + reader io.Reader + chunkSize int + chunkSizeStr string + + headerBuffer *bytes.Buffer + chunkBuffer *bytes.Buffer + + multiReader io.Reader + multiReaderLen int + endChunkDone bool +} + +// newBufferedAWSChunkReader returns an bufferedAWSChunkReader for reading +// aws-chunked encoded chunks. +func newBufferedAWSChunkReader(reader io.Reader, chunkSize int) *bufferedAWSChunkReader { + return &bufferedAWSChunkReader{ + reader: reader, + chunkSize: chunkSize, + chunkSizeStr: strconv.FormatInt(int64(chunkSize), 16), + + headerBuffer: bytes.NewBuffer(make([]byte, 0, 64)), + chunkBuffer: bytes.NewBuffer(make([]byte, 0, chunkSize+len(crlf))), + } +} + +// Read attempts to read from the underlying io.Reader writing aws-chunked +// chunk encoded bytes to p. When the underlying io.Reader has been completed +// read the end chunk will be available. Once the end chunk is read, the reader +// will return EOF. +func (r *bufferedAWSChunkReader) Read(p []byte) (n int, err error) { + if r.multiReaderLen == 0 && r.endChunkDone { + return 0, io.EOF + } + if r.multiReader == nil || r.multiReaderLen == 0 { + r.multiReader, r.multiReaderLen, err = r.newMultiReader() + if err != nil { + return 0, err + } + } + + n, err = r.multiReader.Read(p) + r.multiReaderLen -= n + + if err == io.EOF && !r.endChunkDone { + // Edge case handling when the multi-reader has been completely read, + // and returned an EOF, make sure that EOF only gets returned if the + // end chunk was included in the multi-reader. Otherwise, the next call + // to read will initialize the next chunk's multi-reader. + err = nil + } + return n, err +} + +// newMultiReader returns a new io.Reader for wrapping the next chunk. Will +// return an error if the underlying reader can not be read from. Will never +// return io.EOF. +func (r *bufferedAWSChunkReader) newMultiReader() (io.Reader, int, error) { + // io.Copy eats the io.EOF returned by io.LimitReader. Any error that + // occurs here is due to an actual read error. + n, err := io.Copy(r.chunkBuffer, io.LimitReader(r.reader, int64(r.chunkSize))) + if err != nil { + return nil, 0, err + } + if n == 0 { + // Early exit writing out only the end chunk. This does not include + // aws-chunk's final `crlf` so that trailers can still be added by + // upstream reader. + r.headerBuffer.Reset() + r.headerBuffer.WriteString("0") + r.headerBuffer.WriteString(crlf) + r.endChunkDone = true + + return r.headerBuffer, r.headerBuffer.Len(), nil + } + r.chunkBuffer.WriteString(crlf) + + chunkSizeStr := r.chunkSizeStr + if int(n) != r.chunkSize { + chunkSizeStr = strconv.FormatInt(n, 16) + } + + r.headerBuffer.Reset() + r.headerBuffer.WriteString(chunkSizeStr) + r.headerBuffer.WriteString(crlf) + + return io.MultiReader( + r.headerBuffer, + r.chunkBuffer, + ), r.headerBuffer.Len() + r.chunkBuffer.Len(), nil +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding_test.go new file mode 100644 index 0000000000..8e9ce3c8a9 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding_test.go @@ -0,0 +1,507 @@ +//go:build go1.16 +// +build go1.16 + +package checksum + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "strings" + "testing" + "testing/iotest" + + "github.com/google/go-cmp/cmp" +) + +func TestAWSChunkedEncoding(t *testing.T) { + cases := map[string]struct { + reader *awsChunkedEncoding + expectErr string + expectEncodedLength int64 + expectHTTPHeaders map[string][]string + expectPayload []byte + }{ + "empty payload fixed stream length": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader(""), + func(o *awsChunkedEncodingOptions) { + o.StreamLength = 0 + }), + expectEncodedLength: 5, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + }, + expectPayload: []byte("0\r\n\r\n"), + }, + "empty payload unknown stream length": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader("")), + expectEncodedLength: -1, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + }, + expectPayload: []byte("0\r\n\r\n"), + }, + "payload fixed stream length": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"), + func(o *awsChunkedEncodingOptions) { + o.StreamLength = 11 + }), + expectEncodedLength: 21, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + }, + expectPayload: []byte("b\r\nhello world\r\n0\r\n\r\n"), + }, + "payload unknown stream length": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world")), + expectEncodedLength: -1, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + }, + expectPayload: []byte("b\r\nhello world\r\n0\r\n\r\n"), + }, + "payload unknown stream length with chunk size": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"), + func(o *awsChunkedEncodingOptions) { + o.ChunkLength = 8 + }), + expectEncodedLength: -1, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + }, + expectPayload: []byte("8\r\nhello wo\r\n3\r\nrld\r\n0\r\n\r\n"), + }, + "payload fixed stream length with chunk size": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"), + func(o *awsChunkedEncodingOptions) { + o.StreamLength = 11 + o.ChunkLength = 8 + }), + expectEncodedLength: 26, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + }, + expectPayload: []byte("8\r\nhello wo\r\n3\r\nrld\r\n0\r\n\r\n"), + }, + "payload fixed stream length with fixed length trailer": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"), + func(o *awsChunkedEncodingOptions) { + o.StreamLength = 11 + o.Trailers = map[string]awsChunkedTrailerValue{ + "foo": { + Get: func() (string, error) { + return "abc123", nil + }, + Length: 6, + }, + } + }), + expectEncodedLength: 33, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + awsTrailerHeaderName: {"foo"}, + }, + expectPayload: []byte("b\r\nhello world\r\n0\r\nfoo:abc123\r\n\r\n"), + }, + "payload fixed stream length with unknown length trailer": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"), + func(o *awsChunkedEncodingOptions) { + o.StreamLength = 11 + o.Trailers = map[string]awsChunkedTrailerValue{ + "foo": { + Get: func() (string, error) { + return "abc123", nil + }, + Length: -1, + }, + } + }), + expectEncodedLength: -1, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + awsTrailerHeaderName: {"foo"}, + }, + expectPayload: []byte("b\r\nhello world\r\n0\r\nfoo:abc123\r\n\r\n"), + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + if e, a := c.expectEncodedLength, c.reader.EncodedLength(); e != a { + t.Errorf("expect %v encoded length, got %v", e, a) + } + if diff := cmp.Diff(c.expectHTTPHeaders, c.reader.HTTPHeaders()); diff != "" { + t.Errorf("expect HTTP headers match\n%v", diff) + } + + actualPayload, err := ioutil.ReadAll(c.reader) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" { + t.Errorf("expect payload match\n%v", diff) + } + }) + } +} + +func TestUnsignedAWSChunkReader(t *testing.T) { + cases := map[string]struct { + payload interface { + io.Reader + Len() int + } + + expectPayload []byte + expectErr string + }{ + "empty body": { + payload: bytes.NewReader([]byte{}), + expectPayload: []byte("0\r\n"), + }, + "with body": { + payload: strings.NewReader("Hello world"), + expectPayload: []byte("b\r\nHello world\r\n0\r\n"), + }, + "large body": { + payload: bytes.NewBufferString("Hello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello world"), + expectPayload: []byte("205\r\nHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello world\r\n0\r\n"), + }, + "reader error": { + payload: newLimitReadLener(iotest.ErrReader(fmt.Errorf("some read error")), 128), + expectErr: "some read error", + }, + "unknown length reader": { + payload: newUnknownLenReader(io.LimitReader(byteReader('a'), defaultChunkLength*2)), + expectPayload: func() []byte { + reader := newBufferedAWSChunkReader( + io.LimitReader(byteReader('a'), defaultChunkLength*2), + defaultChunkLength, + ) + actualPayload, err := ioutil.ReadAll(reader) + if err != nil { + t.Fatalf("failed to create unknown length reader test data, %v", err) + } + return actualPayload + }(), + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + reader := newUnsignedChunkReader(c.payload, int64(c.payload.Len())) + + actualPayload, err := ioutil.ReadAll(reader) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" { + t.Errorf("expect payload match\n%v", diff) + } + }) + } +} + +func TestBufferedAWSChunkReader(t *testing.T) { + cases := map[string]struct { + payload io.Reader + readSize int + chunkSize int + + expectPayload []byte + expectErr string + }{ + "empty body": { + payload: bytes.NewReader([]byte{}), + chunkSize: 4, + expectPayload: []byte("0\r\n"), + }, + "with one chunk body": { + payload: strings.NewReader("Hello world"), + chunkSize: 20, + expectPayload: []byte("b\r\nHello world\r\n0\r\n"), + }, + "single byte read": { + payload: strings.NewReader("Hello world"), + chunkSize: 8, + readSize: 1, + expectPayload: []byte("8\r\nHello wo\r\n3\r\nrld\r\n0\r\n"), + }, + "single chunk and byte read": { + payload: strings.NewReader("Hello world"), + chunkSize: 1, + readSize: 1, + expectPayload: []byte("1\r\nH\r\n1\r\ne\r\n1\r\nl\r\n1\r\nl\r\n1\r\no\r\n1\r\n \r\n1\r\nw\r\n1\r\no\r\n1\r\nr\r\n1\r\nl\r\n1\r\nd\r\n0\r\n"), + }, + "with two chunk body": { + payload: strings.NewReader("Hello world"), + chunkSize: 8, + expectPayload: []byte("8\r\nHello wo\r\n3\r\nrld\r\n0\r\n"), + }, + "chunk size equal to read size": { + payload: strings.NewReader("Hello world"), + chunkSize: 512, + expectPayload: []byte("b\r\nHello world\r\n0\r\n"), + }, + "chunk size greater than read size": { + payload: strings.NewReader("Hello world"), + chunkSize: 1024, + expectPayload: []byte("b\r\nHello world\r\n0\r\n"), + }, + "payload size more than default read size, chunk size less than read size": { + payload: bytes.NewBufferString("Hello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello world"), + chunkSize: 500, + expectPayload: []byte("1f4\r\nHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello\r\n11\r\n worldHello world\r\n0\r\n"), + }, + "payload size more than default read size, chunk size equal to read size": { + payload: bytes.NewBufferString("Hello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello world"), + chunkSize: 512, + expectPayload: []byte("200\r\nHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello \r\n5\r\nworld\r\n0\r\n"), + }, + "payload size more than default read size, chunk size more than read size": { + payload: bytes.NewBufferString("Hello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello world"), + chunkSize: 1024, + expectPayload: []byte("205\r\nHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello world\r\n0\r\n"), + }, + "reader error": { + payload: iotest.ErrReader(fmt.Errorf("some read error")), + chunkSize: 128, + expectErr: "some read error", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + reader := newBufferedAWSChunkReader(c.payload, c.chunkSize) + + var actualPayload []byte + var err error + + if c.readSize != 0 { + for err == nil { + var n int + p := make([]byte, c.readSize) + n, err = reader.Read(p) + if n != 0 { + actualPayload = append(actualPayload, p[:n]...) + } + } + if err == io.EOF { + err = nil + } + } else { + actualPayload, err = ioutil.ReadAll(reader) + } + + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" { + t.Errorf("expect payload match\n%v", diff) + } + }) + } +} + +func TestAwsChunkedTrailerReader(t *testing.T) { + cases := map[string]struct { + reader *awsChunkedTrailerReader + + expectErr string + expectEncodedLength int + expectPayload []byte + }{ + "no trailers": { + reader: newAWSChunkedTrailerReader(nil), + expectPayload: []byte{}, + }, + "unknown length trailers": { + reader: newAWSChunkedTrailerReader(map[string]awsChunkedTrailerValue{ + "foo": { + Get: func() (string, error) { + return "abc123", nil + }, + Length: -1, + }, + }), + expectEncodedLength: -1, + expectPayload: []byte("foo:abc123\r\n"), + }, + "known length trailers": { + reader: newAWSChunkedTrailerReader(map[string]awsChunkedTrailerValue{ + "foo": { + Get: func() (string, error) { + return "abc123", nil + }, + Length: 6, + }, + }), + expectEncodedLength: 12, + expectPayload: []byte("foo:abc123\r\n"), + }, + "trailer error": { + reader: newAWSChunkedTrailerReader(map[string]awsChunkedTrailerValue{ + "foo": { + Get: func() (string, error) { + return "", fmt.Errorf("some error") + }, + Length: 6, + }, + }), + expectEncodedLength: 12, + expectErr: "failed to get trailer", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + if e, a := c.expectEncodedLength, c.reader.EncodedLength(); e != a { + t.Errorf("expect %v encoded length, got %v", e, a) + } + + actualPayload, err := ioutil.ReadAll(c.reader) + + // Asserts + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" { + t.Errorf("expect payload match\n%v", diff) + } + }) + } +} + +type limitReadLener struct { + io.Reader + length int +} + +func newLimitReadLener(r io.Reader, l int) *limitReadLener { + return &limitReadLener{ + Reader: io.LimitReader(r, int64(l)), + length: l, + } +} +func (r *limitReadLener) Len() int { + return r.length +} + +type unknownLenReader struct { + io.Reader +} + +func newUnknownLenReader(r io.Reader) *unknownLenReader { + return &unknownLenReader{ + Reader: r, + } +} +func (r *unknownLenReader) Len() int { + return -1 +} + +type byteReader byte + +func (r byteReader) Read(p []byte) (int, error) { + for i := 0; i < len(p); i++ { + p[i] = byte(r) + } + return len(p), nil +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/go_module_metadata.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/go_module_metadata.go new file mode 100644 index 0000000000..f591861505 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/go_module_metadata.go @@ -0,0 +1,6 @@ +// Code generated by internal/repotools/cmd/updatemodulemeta DO NOT EDIT. + +package checksum + +// goModuleVersion is the tagged release for this module +const goModuleVersion = "1.2.0" diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/gotest/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/gotest/ya.make new file mode 100644 index 0000000000..2a3d936a06 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/gotest/ya.make @@ -0,0 +1,5 @@ +GO_TEST_FOR(vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum) + +LICENSE(Apache-2.0) + +END() diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add.go new file mode 100644 index 0000000000..3e17d2216b --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add.go @@ -0,0 +1,185 @@ +package checksum + +import ( + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +// InputMiddlewareOptions provides the options for the request +// checksum middleware setup. +type InputMiddlewareOptions struct { + // GetAlgorithm is a function to get the checksum algorithm of the + // input payload from the input parameters. + // + // Given the input parameter value, the function must return the algorithm + // and true, or false if no algorithm is specified. + GetAlgorithm func(interface{}) (string, bool) + + // Forces the middleware to compute the input payload's checksum. The + // request will fail if the algorithm is not specified or unable to compute + // the checksum. + RequireChecksum bool + + // Enables support for wrapping the serialized input payload with a + // content-encoding: aws-check wrapper, and including a trailer for the + // algorithm's checksum value. + // + // The checksum will not be computed, nor added as trailing checksum, if + // the Algorithm's header is already set on the request. + EnableTrailingChecksum bool + + // Enables support for computing the SHA256 checksum of input payloads + // along with the algorithm specified checksum. Prevents downstream + // middleware handlers (computePayloadSHA256) re-reading the payload. + // + // The SHA256 payload checksum will only be used for computed for requests + // that are not TLS, or do not enable trailing checksums. + // + // The SHA256 payload hash will not be computed, if the Algorithm's header + // is already set on the request. + EnableComputeSHA256PayloadHash bool + + // Enables support for setting the aws-chunked decoded content length + // header for the decoded length of the underlying stream. Will only be set + // when used with trailing checksums, and aws-chunked content-encoding. + EnableDecodedContentLengthHeader bool +} + +// AddInputMiddleware adds the middleware for performing checksum computing +// of request payloads, and checksum validation of response payloads. +func AddInputMiddleware(stack *middleware.Stack, options InputMiddlewareOptions) (err error) { + // TODO ensure this works correctly with presigned URLs + + // Middleware stack: + // * (OK)(Initialize) --none-- + // * (OK)(Serialize) EndpointResolver + // * (OK)(Build) ComputeContentLength + // * (AD)(Build) Header ComputeInputPayloadChecksum + // * SIGNED Payload - If HTTP && not support trailing checksum + // * UNSIGNED Payload - If HTTPS && not support trailing checksum + // * (RM)(Build) ContentChecksum - OK to remove + // * (OK)(Build) ComputePayloadHash + // * v4.dynamicPayloadSigningMiddleware + // * v4.computePayloadSHA256 + // * v4.unsignedPayload + // (OK)(Build) Set computedPayloadHash header + // * (OK)(Finalize) Retry + // * (AD)(Finalize) Trailer ComputeInputPayloadChecksum, + // * Requires HTTPS && support trailing checksum + // * UNSIGNED Payload + // * Finalize run if HTTPS && support trailing checksum + // * (OK)(Finalize) Signing + // * (OK)(Deserialize) --none-- + + // Initial checksum configuration look up middleware + err = stack.Initialize.Add(&setupInputContext{ + GetAlgorithm: options.GetAlgorithm, + }, middleware.Before) + if err != nil { + return err + } + + stack.Build.Remove("ContentChecksum") + + // Create the compute checksum middleware that will be added as both a + // build and finalize handler. + inputChecksum := &computeInputPayloadChecksum{ + RequireChecksum: options.RequireChecksum, + EnableTrailingChecksum: options.EnableTrailingChecksum, + EnableComputePayloadHash: options.EnableComputeSHA256PayloadHash, + EnableDecodedContentLengthHeader: options.EnableDecodedContentLengthHeader, + } + + // Insert header checksum after ComputeContentLength middleware, must also + // be before the computePayloadHash middleware handlers. + err = stack.Build.Insert(inputChecksum, + (*smithyhttp.ComputeContentLength)(nil).ID(), + middleware.After) + if err != nil { + return err + } + + // If trailing checksum is not supported no need for finalize handler to be added. + if options.EnableTrailingChecksum { + err = stack.Finalize.Insert(inputChecksum, "Retry", middleware.After) + if err != nil { + return err + } + } + + return nil +} + +// RemoveInputMiddleware Removes the compute input payload checksum middleware +// handlers from the stack. +func RemoveInputMiddleware(stack *middleware.Stack) { + id := (*setupInputContext)(nil).ID() + stack.Initialize.Remove(id) + + id = (*computeInputPayloadChecksum)(nil).ID() + stack.Build.Remove(id) + stack.Finalize.Remove(id) +} + +// OutputMiddlewareOptions provides options for configuring output checksum +// validation middleware. +type OutputMiddlewareOptions struct { + // GetValidationMode is a function to get the checksum validation + // mode of the output payload from the input parameters. + // + // Given the input parameter value, the function must return the validation + // mode and true, or false if no mode is specified. + GetValidationMode func(interface{}) (string, bool) + + // The set of checksum algorithms that should be used for response payload + // checksum validation. The algorithm(s) used will be a union of the + // output's returned algorithms and this set. + // + // Only the first algorithm in the union is currently used. + ValidationAlgorithms []string + + // If set the middleware will ignore output multipart checksums. Otherwise + // an checksum format error will be returned by the middleware. + IgnoreMultipartValidation bool + + // When set the middleware will log when output does not have checksum or + // algorithm to validate. + LogValidationSkipped bool + + // When set the middleware will log when the output contains a multipart + // checksum that was, skipped and not validated. + LogMultipartValidationSkipped bool +} + +// AddOutputMiddleware adds the middleware for validating response payload's +// checksum. +func AddOutputMiddleware(stack *middleware.Stack, options OutputMiddlewareOptions) error { + err := stack.Initialize.Add(&setupOutputContext{ + GetValidationMode: options.GetValidationMode, + }, middleware.Before) + if err != nil { + return err + } + + // Resolve a supported priority order list of algorithms to validate. + algorithms := FilterSupportedAlgorithms(options.ValidationAlgorithms) + + m := &validateOutputPayloadChecksum{ + Algorithms: algorithms, + IgnoreMultipartValidation: options.IgnoreMultipartValidation, + LogMultipartValidationSkipped: options.LogMultipartValidationSkipped, + LogValidationSkipped: options.LogValidationSkipped, + } + + return stack.Deserialize.Add(m, middleware.After) +} + +// RemoveOutputMiddleware Removes the compute input payload checksum middleware +// handlers from the stack. +func RemoveOutputMiddleware(stack *middleware.Stack) { + id := (*setupOutputContext)(nil).ID() + stack.Initialize.Remove(id) + + id = (*validateOutputPayloadChecksum)(nil).ID() + stack.Deserialize.Remove(id) +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add_test.go new file mode 100644 index 0000000000..33f952ebe9 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add_test.go @@ -0,0 +1,412 @@ +//go:build go1.16 +// +build go1.16 + +package checksum + +import ( + "context" + "testing" + + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/google/go-cmp/cmp" +) + +func TestAddInputMiddleware(t *testing.T) { + cases := map[string]struct { + options InputMiddlewareOptions + expectErr string + expectMiddleware []string + expectInitialize *setupInputContext + expectBuild *computeInputPayloadChecksum + expectFinalize *computeInputPayloadChecksum + }{ + "with trailing checksum": { + options: InputMiddlewareOptions{ + GetAlgorithm: func(interface{}) (string, bool) { + return string(AlgorithmCRC32), true + }, + EnableTrailingChecksum: true, + EnableComputeSHA256PayloadHash: true, + EnableDecodedContentLengthHeader: true, + }, + expectMiddleware: []string{ + "test", + "Initialize stack step", + "AWSChecksum:SetupInputContext", + "Serialize stack step", + "Build stack step", + "ComputeContentLength", + "AWSChecksum:ComputeInputPayloadChecksum", + "ComputePayloadHash", + "Finalize stack step", + "Retry", + "AWSChecksum:ComputeInputPayloadChecksum", + "Signing", + "Deserialize stack step", + }, + expectInitialize: &setupInputContext{ + GetAlgorithm: func(interface{}) (string, bool) { + return string(AlgorithmCRC32), true + }, + }, + expectBuild: &computeInputPayloadChecksum{ + EnableTrailingChecksum: true, + EnableComputePayloadHash: true, + EnableDecodedContentLengthHeader: true, + }, + }, + "with checksum required": { + options: InputMiddlewareOptions{ + GetAlgorithm: func(interface{}) (string, bool) { + return string(AlgorithmCRC32), true + }, + EnableTrailingChecksum: true, + RequireChecksum: true, + }, + expectMiddleware: []string{ + "test", + "Initialize stack step", + "AWSChecksum:SetupInputContext", + "Serialize stack step", + "Build stack step", + "ComputeContentLength", + "AWSChecksum:ComputeInputPayloadChecksum", + "ComputePayloadHash", + "Finalize stack step", + "Retry", + "AWSChecksum:ComputeInputPayloadChecksum", + "Signing", + "Deserialize stack step", + }, + expectInitialize: &setupInputContext{ + GetAlgorithm: func(interface{}) (string, bool) { + return string(AlgorithmCRC32), true + }, + }, + expectBuild: &computeInputPayloadChecksum{ + RequireChecksum: true, + EnableTrailingChecksum: true, + }, + }, + "no trailing checksum": { + options: InputMiddlewareOptions{ + GetAlgorithm: func(interface{}) (string, bool) { + return string(AlgorithmCRC32), true + }, + }, + expectMiddleware: []string{ + "test", + "Initialize stack step", + "AWSChecksum:SetupInputContext", + "Serialize stack step", + "Build stack step", + "ComputeContentLength", + "AWSChecksum:ComputeInputPayloadChecksum", + "ComputePayloadHash", + "Finalize stack step", + "Retry", + "Signing", + "Deserialize stack step", + }, + expectInitialize: &setupInputContext{ + GetAlgorithm: func(interface{}) (string, bool) { + return string(AlgorithmCRC32), true + }, + }, + expectBuild: &computeInputPayloadChecksum{}, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + stack := middleware.NewStack("test", smithyhttp.NewStackRequest) + + stack.Build.Add(nopBuildMiddleware("ComputeContentLength"), middleware.After) + stack.Build.Add(nopBuildMiddleware("ContentChecksum"), middleware.After) + stack.Build.Add(nopBuildMiddleware("ComputePayloadHash"), middleware.After) + stack.Finalize.Add(nopFinalizeMiddleware("Retry"), middleware.After) + stack.Finalize.Add(nopFinalizeMiddleware("Signing"), middleware.After) + + err := AddInputMiddleware(stack, c.options) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if diff := cmp.Diff(c.expectMiddleware, stack.List()); diff != "" { + t.Fatalf("expect stack list match:\n%s", diff) + } + + initializeMiddleware, ok := stack.Initialize.Get((*setupInputContext)(nil).ID()) + if e, a := (c.expectInitialize != nil), ok; e != a { + t.Errorf("expect initialize middleware %t, got %t", e, a) + } + if c.expectInitialize != nil && ok { + setupInput := initializeMiddleware.(*setupInputContext) + if e, a := c.options.GetAlgorithm != nil, setupInput.GetAlgorithm != nil; e != a { + t.Fatalf("expect GetAlgorithm %t, got %t", e, a) + } + expectAlgo, expectOK := c.options.GetAlgorithm(nil) + actualAlgo, actualOK := setupInput.GetAlgorithm(nil) + if e, a := expectAlgo, actualAlgo; e != a { + t.Errorf("expect %v algorithm, got %v", e, a) + } + if e, a := expectOK, actualOK; e != a { + t.Errorf("expect %v algorithm present, got %v", e, a) + } + } + + buildMiddleware, ok := stack.Build.Get((*computeInputPayloadChecksum)(nil).ID()) + if e, a := (c.expectBuild != nil), ok; e != a { + t.Errorf("expect build middleware %t, got %t", e, a) + } + var computeInput *computeInputPayloadChecksum + if c.expectBuild != nil && ok { + computeInput = buildMiddleware.(*computeInputPayloadChecksum) + if e, a := c.expectBuild.RequireChecksum, computeInput.RequireChecksum; e != a { + t.Errorf("expect %v require checksum, got %v", e, a) + } + if e, a := c.expectBuild.EnableTrailingChecksum, computeInput.EnableTrailingChecksum; e != a { + t.Errorf("expect %v enable trailing checksum, got %v", e, a) + } + if e, a := c.expectBuild.EnableComputePayloadHash, computeInput.EnableComputePayloadHash; e != a { + t.Errorf("expect %v enable compute payload hash, got %v", e, a) + } + if e, a := c.expectBuild.EnableDecodedContentLengthHeader, computeInput.EnableDecodedContentLengthHeader; e != a { + t.Errorf("expect %v enable decoded length header, got %v", e, a) + } + } + + if c.expectFinalize != nil && ok { + finalizeMiddleware, ok := stack.Build.Get((*computeInputPayloadChecksum)(nil).ID()) + if !ok { + t.Errorf("expect finalize middleware") + } + finalizeComputeInput := finalizeMiddleware.(*computeInputPayloadChecksum) + + if e, a := computeInput, finalizeComputeInput; e != a { + t.Errorf("expect build and finalize to be same value") + } + } + }) + } +} + +func TestRemoveInputMiddleware(t *testing.T) { + stack := middleware.NewStack("test", smithyhttp.NewStackRequest) + + stack.Build.Add(nopBuildMiddleware("ComputeContentLength"), middleware.After) + stack.Build.Add(nopBuildMiddleware("ContentChecksum"), middleware.After) + stack.Build.Add(nopBuildMiddleware("ComputePayloadHash"), middleware.After) + stack.Finalize.Add(nopFinalizeMiddleware("Retry"), middleware.After) + stack.Finalize.Add(nopFinalizeMiddleware("Signing"), middleware.After) + + err := AddInputMiddleware(stack, InputMiddlewareOptions{ + EnableTrailingChecksum: true, + }) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + RemoveInputMiddleware(stack) + + expectStack := []string{ + "test", + "Initialize stack step", + "Serialize stack step", + "Build stack step", + "ComputeContentLength", + "ComputePayloadHash", + "Finalize stack step", + "Retry", + "Signing", + "Deserialize stack step", + } + + if diff := cmp.Diff(expectStack, stack.List()); diff != "" { + t.Fatalf("expect stack list match:\n%s", diff) + } +} + +func TestAddOutputMiddleware(t *testing.T) { + cases := map[string]struct { + options OutputMiddlewareOptions + expectErr string + expectMiddleware []string + expectInitialize *setupOutputContext + expectDeserialize *validateOutputPayloadChecksum + }{ + "validate output": { + options: OutputMiddlewareOptions{ + GetValidationMode: func(interface{}) (string, bool) { + return "ENABLED", true + }, + ValidationAlgorithms: []string{ + "crc32", "sha1", "abc123", "crc32c", + }, + IgnoreMultipartValidation: true, + LogMultipartValidationSkipped: true, + LogValidationSkipped: true, + }, + expectMiddleware: []string{ + "test", + "Initialize stack step", + "AWSChecksum:SetupOutputContext", + "Serialize stack step", + "Build stack step", + "Finalize stack step", + "Deserialize stack step", + "AWSChecksum:ValidateOutputPayloadChecksum", + }, + expectInitialize: &setupOutputContext{ + GetValidationMode: func(interface{}) (string, bool) { + return "ENABLED", true + }, + }, + expectDeserialize: &validateOutputPayloadChecksum{ + Algorithms: []Algorithm{ + AlgorithmCRC32, AlgorithmSHA1, AlgorithmCRC32C, + }, + IgnoreMultipartValidation: true, + LogMultipartValidationSkipped: true, + LogValidationSkipped: true, + }, + }, + "validate options off": { + options: OutputMiddlewareOptions{ + GetValidationMode: func(interface{}) (string, bool) { + return "ENABLED", true + }, + ValidationAlgorithms: []string{ + "crc32", "sha1", "abc123", "crc32c", + }, + }, + expectMiddleware: []string{ + "test", + "Initialize stack step", + "AWSChecksum:SetupOutputContext", + "Serialize stack step", + "Build stack step", + "Finalize stack step", + "Deserialize stack step", + "AWSChecksum:ValidateOutputPayloadChecksum", + }, + expectInitialize: &setupOutputContext{ + GetValidationMode: func(interface{}) (string, bool) { + return "ENABLED", true + }, + }, + expectDeserialize: &validateOutputPayloadChecksum{ + Algorithms: []Algorithm{ + AlgorithmCRC32, AlgorithmSHA1, AlgorithmCRC32C, + }, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + stack := middleware.NewStack("test", smithyhttp.NewStackRequest) + + err := AddOutputMiddleware(stack, c.options) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if diff := cmp.Diff(c.expectMiddleware, stack.List()); diff != "" { + t.Fatalf("expect stack list match:\n%s", diff) + } + + initializeMiddleware, ok := stack.Initialize.Get((*setupOutputContext)(nil).ID()) + if e, a := (c.expectInitialize != nil), ok; e != a { + t.Errorf("expect initialize middleware %t, got %t", e, a) + } + if c.expectInitialize != nil && ok { + setupOutput := initializeMiddleware.(*setupOutputContext) + if e, a := c.options.GetValidationMode != nil, setupOutput.GetValidationMode != nil; e != a { + t.Fatalf("expect GetValidationMode %t, got %t", e, a) + } + expectMode, expectOK := c.options.GetValidationMode(nil) + actualMode, actualOK := setupOutput.GetValidationMode(nil) + if e, a := expectMode, actualMode; e != a { + t.Errorf("expect %v mode, got %v", e, a) + } + if e, a := expectOK, actualOK; e != a { + t.Errorf("expect %v mode present, got %v", e, a) + } + } + + deserializeMiddleware, ok := stack.Deserialize.Get((*validateOutputPayloadChecksum)(nil).ID()) + if e, a := (c.expectDeserialize != nil), ok; e != a { + t.Errorf("expect deserialize middleware %t, got %t", e, a) + } + if c.expectDeserialize != nil && ok { + validateOutput := deserializeMiddleware.(*validateOutputPayloadChecksum) + if diff := cmp.Diff(c.expectDeserialize.Algorithms, validateOutput.Algorithms); diff != "" { + t.Errorf("expect algorithms match:\n%s", diff) + } + if e, a := c.expectDeserialize.IgnoreMultipartValidation, validateOutput.IgnoreMultipartValidation; e != a { + t.Errorf("expect %v ignore multipart checksum, got %v", e, a) + } + if e, a := c.expectDeserialize.LogMultipartValidationSkipped, validateOutput.LogMultipartValidationSkipped; e != a { + t.Errorf("expect %v log multipart skipped, got %v", e, a) + } + if e, a := c.expectDeserialize.LogValidationSkipped, validateOutput.LogValidationSkipped; e != a { + t.Errorf("expect %v log validation skipped, got %v", e, a) + } + } + }) + } +} + +func TestRemoveOutputMiddleware(t *testing.T) { + stack := middleware.NewStack("test", smithyhttp.NewStackRequest) + + err := AddOutputMiddleware(stack, OutputMiddlewareOptions{}) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + RemoveOutputMiddleware(stack) + + expectStack := []string{ + "test", + "Initialize stack step", + "Serialize stack step", + "Build stack step", + "Finalize stack step", + "Deserialize stack step", + } + + if diff := cmp.Diff(expectStack, stack.List()); diff != "" { + t.Fatalf("expect stack list match:\n%s", diff) + } +} + +func setSerializedRequest(req *smithyhttp.Request) middleware.SerializeMiddleware { + return middleware.SerializeMiddlewareFunc("OperationSerializer", + func(ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler) ( + middleware.SerializeOutput, middleware.Metadata, error, + ) { + input.Request = req + return next.HandleSerialize(ctx, input) + }) +} + +func nopBuildMiddleware(id string) middleware.BuildMiddleware { + return middleware.BuildMiddlewareFunc(id, + func(ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler) ( + middleware.BuildOutput, middleware.Metadata, error, + ) { + return next.HandleBuild(ctx, input) + }) +} + +func nopFinalizeMiddleware(id string) middleware.FinalizeMiddleware { + return middleware.FinalizeMiddlewareFunc(id, + func(ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler) ( + middleware.FinalizeOutput, middleware.Metadata, error, + ) { + return next.HandleFinalize(ctx, input) + }) +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum.go new file mode 100644 index 0000000000..0b3c48912b --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum.go @@ -0,0 +1,479 @@ +package checksum + +import ( + "context" + "crypto/sha256" + "fmt" + "hash" + "io" + "strconv" + + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +const ( + contentMD5Header = "Content-Md5" + streamingUnsignedPayloadTrailerPayloadHash = "STREAMING-UNSIGNED-PAYLOAD-TRAILER" +) + +// computedInputChecksumsKey is the metadata key for recording the algorithm the +// checksum was computed for and the checksum value. +type computedInputChecksumsKey struct{} + +// GetComputedInputChecksums returns the map of checksum algorithm to their +// computed value stored in the middleware Metadata. Returns false if no values +// were stored in the Metadata. +func GetComputedInputChecksums(m middleware.Metadata) (map[string]string, bool) { + vs, ok := m.Get(computedInputChecksumsKey{}).(map[string]string) + return vs, ok +} + +// SetComputedInputChecksums stores the map of checksum algorithm to their +// computed value in the middleware Metadata. Overwrites any values that +// currently exist in the metadata. +func SetComputedInputChecksums(m *middleware.Metadata, vs map[string]string) { + m.Set(computedInputChecksumsKey{}, vs) +} + +// computeInputPayloadChecksum middleware computes payload checksum +type computeInputPayloadChecksum struct { + // Enables support for wrapping the serialized input payload with a + // content-encoding: aws-check wrapper, and including a trailer for the + // algorithm's checksum value. + // + // The checksum will not be computed, nor added as trailing checksum, if + // the Algorithm's header is already set on the request. + EnableTrailingChecksum bool + + // States that a checksum is required to be included for the operation. If + // Input does not specify a checksum, fallback to built in MD5 checksum is + // used. + // + // Replaces smithy-go's ContentChecksum middleware. + RequireChecksum bool + + // Enables support for computing the SHA256 checksum of input payloads + // along with the algorithm specified checksum. Prevents downstream + // middleware handlers (computePayloadSHA256) re-reading the payload. + // + // The SHA256 payload hash will only be used for computed for requests + // that are not TLS, or do not enable trailing checksums. + // + // The SHA256 payload hash will not be computed, if the Algorithm's header + // is already set on the request. + EnableComputePayloadHash bool + + // Enables support for setting the aws-chunked decoded content length + // header for the decoded length of the underlying stream. Will only be set + // when used with trailing checksums, and aws-chunked content-encoding. + EnableDecodedContentLengthHeader bool + + buildHandlerRun bool + deferToFinalizeHandler bool +} + +// ID provides the middleware's identifier. +func (m *computeInputPayloadChecksum) ID() string { + return "AWSChecksum:ComputeInputPayloadChecksum" +} + +type computeInputHeaderChecksumError struct { + Msg string + Err error +} + +func (e computeInputHeaderChecksumError) Error() string { + const intro = "compute input header checksum failed" + + if e.Err != nil { + return fmt.Sprintf("%s, %s, %v", intro, e.Msg, e.Err) + } + + return fmt.Sprintf("%s, %s", intro, e.Msg) +} +func (e computeInputHeaderChecksumError) Unwrap() error { return e.Err } + +// HandleBuild handles computing the payload's checksum, in the following cases: +// - Is HTTP, not HTTPS +// - RequireChecksum is true, and no checksums were specified via the Input +// - Trailing checksums are not supported +// +// The build handler must be inserted in the stack before ContentPayloadHash +// and after ComputeContentLength. +func (m *computeInputPayloadChecksum) HandleBuild( + ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler, +) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + m.buildHandlerRun = true + + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, computeInputHeaderChecksumError{ + Msg: fmt.Sprintf("unknown request type %T", req), + } + } + + var algorithm Algorithm + var checksum string + defer func() { + if algorithm == "" || checksum == "" || err != nil { + return + } + + // Record the checksum and algorithm that was computed + SetComputedInputChecksums(&metadata, map[string]string{ + string(algorithm): checksum, + }) + }() + + // If no algorithm was specified, and the operation requires a checksum, + // fallback to the legacy content MD5 checksum. + algorithm, ok, err = getInputAlgorithm(ctx) + if err != nil { + return out, metadata, err + } else if !ok { + if m.RequireChecksum { + checksum, err = setMD5Checksum(ctx, req) + if err != nil { + return out, metadata, computeInputHeaderChecksumError{ + Msg: "failed to compute stream's MD5 checksum", + Err: err, + } + } + algorithm = Algorithm("MD5") + } + return next.HandleBuild(ctx, in) + } + + // If the checksum header is already set nothing to do. + checksumHeader := AlgorithmHTTPHeader(algorithm) + if checksum = req.Header.Get(checksumHeader); checksum != "" { + return next.HandleBuild(ctx, in) + } + + computePayloadHash := m.EnableComputePayloadHash + if v := v4.GetPayloadHash(ctx); v != "" { + computePayloadHash = false + } + + stream := req.GetStream() + streamLength, err := getRequestStreamLength(req) + if err != nil { + return out, metadata, computeInputHeaderChecksumError{ + Msg: "failed to determine stream length", + Err: err, + } + } + + // If trailing checksums are supported, the request is HTTPS, and the + // stream is not nil or empty, there is nothing to do in the build stage. + // The checksum will be added to the request as a trailing checksum in the + // finalize handler. + // + // Nil and empty streams will always be handled as a request header, + // regardless if the operation supports trailing checksums or not. + if req.IsHTTPS() { + if stream != nil && streamLength != 0 && m.EnableTrailingChecksum { + if m.EnableComputePayloadHash { + // payload hash is set as header in Build middleware handler, + // ContentSHA256Header. + ctx = v4.SetPayloadHash(ctx, streamingUnsignedPayloadTrailerPayloadHash) + } + + m.deferToFinalizeHandler = true + return next.HandleBuild(ctx, in) + } + + // If trailing checksums are not enabled but protocol is still HTTPS + // disabling computing the payload hash. Downstream middleware handler + // (ComputetPayloadHash) will set the payload hash to unsigned payload, + // if signing was used. + computePayloadHash = false + } + + // Only seekable streams are supported for non-trailing checksums, because + // the stream needs to be rewound before the handler can continue. + if stream != nil && !req.IsStreamSeekable() { + return out, metadata, computeInputHeaderChecksumError{ + Msg: "unseekable stream is not supported without TLS and trailing checksum", + } + } + + var sha256Checksum string + checksum, sha256Checksum, err = computeStreamChecksum( + algorithm, stream, computePayloadHash) + if err != nil { + return out, metadata, computeInputHeaderChecksumError{ + Msg: "failed to compute stream checksum", + Err: err, + } + } + + if err := req.RewindStream(); err != nil { + return out, metadata, computeInputHeaderChecksumError{ + Msg: "failed to rewind stream", + Err: err, + } + } + + req.Header.Set(checksumHeader, checksum) + + if computePayloadHash { + ctx = v4.SetPayloadHash(ctx, sha256Checksum) + } + + return next.HandleBuild(ctx, in) +} + +type computeInputTrailingChecksumError struct { + Msg string + Err error +} + +func (e computeInputTrailingChecksumError) Error() string { + const intro = "compute input trailing checksum failed" + + if e.Err != nil { + return fmt.Sprintf("%s, %s, %v", intro, e.Msg, e.Err) + } + + return fmt.Sprintf("%s, %s", intro, e.Msg) +} +func (e computeInputTrailingChecksumError) Unwrap() error { return e.Err } + +// HandleFinalize handles computing the payload's checksum, in the following cases: +// - Is HTTPS, not HTTP +// - A checksum was specified via the Input +// - Trailing checksums are supported. +// +// The finalize handler must be inserted in the stack before Signing, and after Retry. +func (m *computeInputPayloadChecksum) HandleFinalize( + ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler, +) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + if !m.deferToFinalizeHandler { + if !m.buildHandlerRun { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "build handler was removed without also removing finalize handler", + } + } + return next.HandleFinalize(ctx, in) + } + + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, computeInputTrailingChecksumError{ + Msg: fmt.Sprintf("unknown request type %T", req), + } + } + + // Trailing checksums are only supported when TLS is enabled. + if !req.IsHTTPS() { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "HTTPS required", + } + } + + // If no algorithm was specified, there is nothing to do. + algorithm, ok, err := getInputAlgorithm(ctx) + if err != nil { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "failed to get algorithm", + Err: err, + } + } else if !ok { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "no algorithm specified", + } + } + + // If the checksum header is already set before finalize could run, there + // is nothing to do. + checksumHeader := AlgorithmHTTPHeader(algorithm) + if req.Header.Get(checksumHeader) != "" { + return next.HandleFinalize(ctx, in) + } + + stream := req.GetStream() + streamLength, err := getRequestStreamLength(req) + if err != nil { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "failed to determine stream length", + Err: err, + } + } + + if stream == nil || streamLength == 0 { + // Nil and empty streams are handled by the Build handler. They are not + // supported by the trailing checksums finalize handler. There is no + // benefit to sending them as trailers compared to headers. + return out, metadata, computeInputTrailingChecksumError{ + Msg: "nil or empty streams are not supported", + } + } + + checksumReader, err := newComputeChecksumReader(stream, algorithm) + if err != nil { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "failed to created checksum reader", + Err: err, + } + } + + awsChunkedReader := newUnsignedAWSChunkedEncoding(checksumReader, + func(o *awsChunkedEncodingOptions) { + o.Trailers[AlgorithmHTTPHeader(checksumReader.Algorithm())] = awsChunkedTrailerValue{ + Get: checksumReader.Base64Checksum, + Length: checksumReader.Base64ChecksumLength(), + } + o.StreamLength = streamLength + }) + + for key, values := range awsChunkedReader.HTTPHeaders() { + for _, value := range values { + req.Header.Add(key, value) + } + } + + // Setting the stream on the request will create a copy. The content length + // is not updated until after the request is copied to prevent impacting + // upstream middleware. + req, err = req.SetStream(awsChunkedReader) + if err != nil { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "failed updating request to trailing checksum wrapped stream", + Err: err, + } + } + req.ContentLength = awsChunkedReader.EncodedLength() + in.Request = req + + // Add decoded content length header if original stream's content length is known. + if streamLength != -1 && m.EnableDecodedContentLengthHeader { + req.Header.Set(decodedContentLengthHeaderName, strconv.FormatInt(streamLength, 10)) + } + + out, metadata, err = next.HandleFinalize(ctx, in) + if err == nil { + checksum, err := checksumReader.Base64Checksum() + if err != nil { + return out, metadata, fmt.Errorf("failed to get computed checksum, %w", err) + } + + // Record the checksum and algorithm that was computed + SetComputedInputChecksums(&metadata, map[string]string{ + string(algorithm): checksum, + }) + } + + return out, metadata, err +} + +func getInputAlgorithm(ctx context.Context) (Algorithm, bool, error) { + ctxAlgorithm := getContextInputAlgorithm(ctx) + if ctxAlgorithm == "" { + return "", false, nil + } + + algorithm, err := ParseAlgorithm(ctxAlgorithm) + if err != nil { + return "", false, fmt.Errorf( + "failed to parse algorithm, %w", err) + } + + return algorithm, true, nil +} + +func computeStreamChecksum(algorithm Algorithm, stream io.Reader, computePayloadHash bool) ( + checksum string, sha256Checksum string, err error, +) { + hasher, err := NewAlgorithmHash(algorithm) + if err != nil { + return "", "", fmt.Errorf( + "failed to get hasher for checksum algorithm, %w", err) + } + + var sha256Hasher hash.Hash + var batchHasher io.Writer = hasher + + // Compute payload hash for the protocol. To prevent another handler + // (computePayloadSHA256) re-reading body also compute the SHA256 for + // request signing. If configured checksum algorithm is SHA256, don't + // double wrap stream with another SHA256 hasher. + if computePayloadHash && algorithm != AlgorithmSHA256 { + sha256Hasher = sha256.New() + batchHasher = io.MultiWriter(hasher, sha256Hasher) + } + + if stream != nil { + if _, err = io.Copy(batchHasher, stream); err != nil { + return "", "", fmt.Errorf( + "failed to read stream to compute hash, %w", err) + } + } + + checksum = string(base64EncodeHashSum(hasher)) + if computePayloadHash { + if algorithm != AlgorithmSHA256 { + sha256Checksum = string(hexEncodeHashSum(sha256Hasher)) + } else { + sha256Checksum = string(hexEncodeHashSum(hasher)) + } + } + + return checksum, sha256Checksum, nil +} + +func getRequestStreamLength(req *smithyhttp.Request) (int64, error) { + if v := req.ContentLength; v > 0 { + return v, nil + } + + if length, ok, err := req.StreamLength(); err != nil { + return 0, fmt.Errorf("failed getting request stream's length, %w", err) + } else if ok { + return length, nil + } + + return -1, nil +} + +// setMD5Checksum computes the MD5 of the request payload and sets it to the +// Content-MD5 header. Returning the MD5 base64 encoded string or error. +// +// If the MD5 is already set as the Content-MD5 header, that value will be +// returned, and nothing else will be done. +// +// If the payload is empty, no MD5 will be computed. No error will be returned. +// Empty payloads do not have an MD5 value. +// +// Replaces the smithy-go middleware for httpChecksum trait. +func setMD5Checksum(ctx context.Context, req *smithyhttp.Request) (string, error) { + if v := req.Header.Get(contentMD5Header); len(v) != 0 { + return v, nil + } + stream := req.GetStream() + if stream == nil { + return "", nil + } + + if !req.IsStreamSeekable() { + return "", fmt.Errorf( + "unseekable stream is not supported for computing md5 checksum") + } + + v, err := computeMD5Checksum(stream) + if err != nil { + return "", err + } + if err := req.RewindStream(); err != nil { + return "", fmt.Errorf("failed to rewind stream after computing MD5 checksum, %w", err) + } + // set the 'Content-MD5' header + req.Header.Set(contentMD5Header, string(v)) + return string(v), nil +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum_test.go new file mode 100644 index 0000000000..7ad59e7315 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum_test.go @@ -0,0 +1,958 @@ +//go:build go1.16 +// +build go1.16 + +package checksum + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" + "testing" + "testing/iotest" + + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/smithy-go/logging" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/google/go-cmp/cmp" +) + +// TODO test cases: +// * Retry re-wrapping payload + +func TestComputeInputPayloadChecksum(t *testing.T) { + cases := map[string]map[string]struct { + optionsFn func(*computeInputPayloadChecksum) + initContext func(context.Context) context.Context + buildInput middleware.BuildInput + + expectErr string + expectBuildErr bool + expectFinalizeErr bool + expectReadErr bool + + expectHeader http.Header + expectContentLength int64 + expectPayload []byte + expectPayloadHash string + + expectChecksumMetadata map[string]string + + expectDeferToFinalize bool + expectLogged string + }{ + "no op": { + "checksum header set known length": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.Header.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "AAAAAA==") + r = requestMust(r.SetStream(strings.NewReader("hello world"))) + r.ContentLength = 11 + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "checksum header set unknown length": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.Header.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "AAAAAA==") + r = requestMust(r.SetStream(strings.NewReader("hello world"))) + r.ContentLength = -1 + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: -1, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "no algorithm": { + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r = requestMust(r.SetStream(strings.NewReader("hello world"))) + r.ContentLength = 11 + return r + }(), + }, + expectHeader: http.Header{}, + expectContentLength: 11, + expectPayload: []byte("hello world"), + }, + "nil stream no algorithm require checksum": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.RequireChecksum = true + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + return r + }(), + }, + expectContentLength: -1, + expectHeader: http.Header{}, + }, + }, + + "build handled": { + "http nil stream": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: -1, + expectPayloadHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "http empty stream": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 0 + r = requestMust(r.SetStream(strings.NewReader(""))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: 0, + expectPayload: []byte{}, + expectPayloadHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "https empty stream unseekable": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 0 + r = requestMust(r.SetStream(&bytes.Buffer{})) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: 0, + expectPayload: nil, + expectPayloadHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "http empty stream unseekable": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 0 + r = requestMust(r.SetStream(&bytes.Buffer{})) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: 0, + expectPayload: nil, + expectPayloadHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "https nil stream": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: -1, + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "https empty stream": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 0 + r = requestMust(r.SetStream(strings.NewReader(""))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: 0, + expectPayload: []byte{}, + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "http no algorithm require checksum": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.RequireChecksum = true + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Md5": []string{"XrY7u+Ae7tCTyyK7j1rNww=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "MD5": "XrY7u+Ae7tCTyyK7j1rNww==", + }, + }, + "http no algorithm require checksum header preset": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.RequireChecksum = true + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r.Header.Set("Content-MD5", "XrY7u+Ae7tCTyyK7j1rNww==") + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Md5": []string{"XrY7u+Ae7tCTyyK7j1rNww=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "MD5": "XrY7u+Ae7tCTyyK7j1rNww==", + }, + }, + "https no algorithm require checksum": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.RequireChecksum = true + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Md5": []string{"XrY7u+Ae7tCTyyK7j1rNww=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "MD5": "XrY7u+Ae7tCTyyK7j1rNww==", + }, + }, + "http seekable": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectPayloadHash: "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9", + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "http payload hash already set": { + initContext: func(ctx context.Context) context.Context { + ctx = setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + ctx = v4.SetPayloadHash(ctx, "somehash") + return ctx + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectPayloadHash: "somehash", + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "http seekable checksum matches payload hash": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmSHA256)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Sha256": []string{"uU0nuZNNPgilLlLX2n2r+sSE7+N6U4DukIj3rOLvzek="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectPayloadHash: "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9", + expectChecksumMetadata: map[string]string{ + "SHA256": "uU0nuZNNPgilLlLX2n2r+sSE7+N6U4DukIj3rOLvzek=", + }, + }, + "http payload hash disabled": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + optionsFn: func(o *computeInputPayloadChecksum) { + o.EnableComputePayloadHash = false + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "https no trailing checksum": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.EnableTrailingChecksum = false + }, + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "with content encoding set": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.EnableTrailingChecksum = false + }, + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r.Header.Set("Content-Encoding", "gzip") + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="}, + "Content-Encoding": []string{"gzip"}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + }, + + "build error": { + "unknown algorithm": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string("unknown")) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world")))) + return r + }(), + }, + expectErr: "failed to parse algorithm", + expectBuildErr: true, + }, + "no algorithm require checksum unseekable stream": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.RequireChecksum = true + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world")))) + return r + }(), + }, + expectErr: "unseekable stream is not supported", + expectBuildErr: true, + }, + "http unseekable stream": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world")))) + return r + }(), + }, + expectErr: "unseekable stream is not supported", + expectBuildErr: true, + }, + "http stream read error": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 128 + r = requestMust(r.SetStream(&mockReadSeeker{ + Reader: iotest.ErrReader(fmt.Errorf("read error")), + })) + return r + }(), + }, + expectErr: "failed to read stream to compute hash", + expectBuildErr: true, + }, + "http stream rewind error": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(&errSeekReader{ + Reader: strings.NewReader("hello world"), + })) + return r + }(), + }, + expectErr: "failed to rewind stream", + expectBuildErr: true, + }, + "https no trailing unseekable stream": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.EnableTrailingChecksum = false + }, + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world")))) + return r + }(), + }, + expectErr: "unseekable stream is not supported", + expectBuildErr: true, + }, + }, + + "finalize handled": { + "https unseekable": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Encoding": []string{"aws-chunked"}, + "X-Amz-Decoded-Content-Length": []string{"11"}, + "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + }, + expectContentLength: 52, + expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER", + expectDeferToFinalize: true, + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "https unseekable unknown length": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = -1 + r = requestMust(r.SetStream(ioutil.NopCloser(bytes.NewBuffer([]byte("hello world"))))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Encoding": []string{"aws-chunked"}, + "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + }, + expectContentLength: -1, + expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER", + expectDeferToFinalize: true, + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "https seekable": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Encoding": []string{"aws-chunked"}, + "X-Amz-Decoded-Content-Length": []string{"11"}, + "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + }, + expectContentLength: 52, + expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER", + expectDeferToFinalize: true, + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "https seekable unknown length": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = -1 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Encoding": []string{"aws-chunked"}, + "X-Amz-Decoded-Content-Length": []string{"11"}, + "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + }, + expectContentLength: 52, + expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER", + expectDeferToFinalize: true, + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "https no compute payload hash": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + optionsFn: func(o *computeInputPayloadChecksum) { + o.EnableComputePayloadHash = false + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Encoding": []string{"aws-chunked"}, + "X-Amz-Decoded-Content-Length": []string{"11"}, + "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + }, + expectContentLength: 52, + expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectDeferToFinalize: true, + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "https no decode content length": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + optionsFn: func(o *computeInputPayloadChecksum) { + o.EnableDecodedContentLengthHeader = false + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Encoding": []string{"aws-chunked"}, + "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + }, + expectContentLength: 52, + expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER", + expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectDeferToFinalize: true, + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "with content encoding set": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r.Header.Set("Content-Encoding", "gzip") + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + "X-Amz-Decoded-Content-Length": []string{"11"}, + "Content-Encoding": []string{"gzip", "aws-chunked"}, + }, + expectContentLength: 52, + expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER", + expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectDeferToFinalize: true, + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + }, + } + + for name, cs := range cases { + t.Run(name, func(t *testing.T) { + for name, c := range cs { + t.Run(name, func(t *testing.T) { + m := &computeInputPayloadChecksum{ + EnableTrailingChecksum: true, + EnableComputePayloadHash: true, + EnableDecodedContentLengthHeader: true, + } + if c.optionsFn != nil { + c.optionsFn(m) + } + + ctx := context.Background() + var logged bytes.Buffer + logger := logging.LoggerFunc( + func(classification logging.Classification, format string, v ...interface{}) { + fmt.Fprintf(&logged, format, v...) + }, + ) + + stack := middleware.NewStack("test", smithyhttp.NewStackRequest) + middleware.AddSetLoggerMiddleware(stack, logger) + + //------------------------------ + // Build handler + //------------------------------ + // On return path validate any errors were expected. + stack.Build.Add(middleware.BuildMiddlewareFunc( + "build-assert", + func(ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, + ) { + // ignore initial build input for the test case's build input. + out, metadata, err = next.HandleBuild(ctx, c.buildInput) + if err == nil && c.expectBuildErr { + t.Fatalf("expect build error, got none") + } + + if !m.buildHandlerRun { + t.Fatalf("expect build handler run") + } + return out, metadata, err + }, + ), middleware.After) + + // Build middleware + stack.Build.Add(m, middleware.After) + + // Validate defer to finalize was performed as expected + stack.Build.Add(middleware.BuildMiddlewareFunc( + "assert-defer-to-finalize", + func(ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, + ) { + if e, a := c.expectDeferToFinalize, m.deferToFinalizeHandler; e != a { + t.Fatalf("expect %v defer to finalize, got %v", e, a) + } + return next.HandleBuild(ctx, input) + }, + ), middleware.After) + + //------------------------------ + // Finalize handler + //------------------------------ + if m.EnableTrailingChecksum { + // On return path assert any errors are expected. + stack.Finalize.Add(middleware.FinalizeMiddlewareFunc( + "build-assert", + func(ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + out, metadata, err = next.HandleFinalize(ctx, input) + if err == nil && c.expectFinalizeErr { + t.Fatalf("expect finalize error, got none") + } + + return out, metadata, err + }, + ), middleware.After) + + // Add finalize middleware + stack.Finalize.Add(m, middleware.After) + } + + //------------------------------ + // Request validation + //------------------------------ + validateRequestHandler := middleware.HandlerFunc( + func(ctx context.Context, input interface{}) ( + output interface{}, metadata middleware.Metadata, err error, + ) { + request := input.(*smithyhttp.Request) + + if diff := cmp.Diff(c.expectHeader, request.Header); diff != "" { + t.Errorf("expect header to match:\n%s", diff) + } + if e, a := c.expectContentLength, request.ContentLength; e != a { + t.Errorf("expect %v content length, got %v", e, a) + } + + stream := request.GetStream() + if e, a := stream != nil, c.expectPayload != nil; e != a { + t.Fatalf("expect nil payload %t, got %t", e, a) + } + if stream == nil { + return + } + + actualPayload, err := ioutil.ReadAll(stream) + if err == nil && c.expectReadErr { + t.Fatalf("expected read error, got none") + } + + if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" { + t.Errorf("expect payload match:\n%s", diff) + } + + payloadHash := v4.GetPayloadHash(ctx) + if e, a := c.expectPayloadHash, payloadHash; e != a { + t.Errorf("expect %v payload hash, got %v", e, a) + } + + return &smithyhttp.Response{}, metadata, nil + }, + ) + + if c.initContext != nil { + ctx = c.initContext(ctx) + } + _, metadata, err := stack.HandleMiddleware(ctx, struct{}{}, validateRequestHandler) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expected error: %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expected error %v to contain %v", err, c.expectErr) + } + if c.expectErr != "" { + return + } + + if c.expectLogged != "" { + if e, a := c.expectLogged, logged.String(); !strings.Contains(a, e) { + t.Errorf("expected %q logged in:\n%s", e, a) + } + } + + // assert computed input checksums metadata + computedMetadata, ok := GetComputedInputChecksums(metadata) + if e, a := ok, (c.expectChecksumMetadata != nil); e != a { + t.Fatalf("expect checksum metadata %t, got %t, %v", e, a, computedMetadata) + } + if c.expectChecksumMetadata != nil { + if diff := cmp.Diff(c.expectChecksumMetadata, computedMetadata); diff != "" { + t.Errorf("expect checksum metadata match\n%s", diff) + } + } + }) + } + }) + } +} + +type mockReadSeeker struct { + io.Reader +} + +func (r *mockReadSeeker) Seek(int64, int) (int64, error) { + return 0, nil +} + +type errSeekReader struct { + io.Reader +} + +func (r *errSeekReader) Seek(offset int64, whence int) (int64, error) { + if whence == io.SeekCurrent { + return 0, nil + } + + return 0, fmt.Errorf("seek failed") +} + +func requestMust(r *smithyhttp.Request, err error) *smithyhttp.Request { + if err != nil { + panic(err.Error()) + } + + return r +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context.go new file mode 100644 index 0000000000..f729525497 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context.go @@ -0,0 +1,117 @@ +package checksum + +import ( + "context" + + "github.com/aws/smithy-go/middleware" +) + +// setupChecksumContext is the initial middleware that looks up the input +// used to configure checksum behavior. This middleware must be executed before +// input validation step or any other checksum middleware. +type setupInputContext struct { + // GetAlgorithm is a function to get the checksum algorithm of the + // input payload from the input parameters. + // + // Given the input parameter value, the function must return the algorithm + // and true, or false if no algorithm is specified. + GetAlgorithm func(interface{}) (string, bool) +} + +// ID for the middleware +func (m *setupInputContext) ID() string { + return "AWSChecksum:SetupInputContext" +} + +// HandleInitialize initialization middleware that setups up the checksum +// context based on the input parameters provided in the stack. +func (m *setupInputContext) HandleInitialize( + ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler, +) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, +) { + // Check if validation algorithm is specified. + if m.GetAlgorithm != nil { + // check is input resource has a checksum algorithm + algorithm, ok := m.GetAlgorithm(in.Parameters) + if ok && len(algorithm) != 0 { + ctx = setContextInputAlgorithm(ctx, algorithm) + } + } + + return next.HandleInitialize(ctx, in) +} + +// inputAlgorithmKey is the key set on context used to identify, retrieves the +// request checksum algorithm if present on the context. +type inputAlgorithmKey struct{} + +// setContextInputAlgorithm sets the request checksum algorithm on the context. +// +// Scoped to stack values. +func setContextInputAlgorithm(ctx context.Context, value string) context.Context { + return middleware.WithStackValue(ctx, inputAlgorithmKey{}, value) +} + +// getContextInputAlgorithm returns the checksum algorithm from the context if +// one was specified. Empty string is returned if one is not specified. +// +// Scoped to stack values. +func getContextInputAlgorithm(ctx context.Context) (v string) { + v, _ = middleware.GetStackValue(ctx, inputAlgorithmKey{}).(string) + return v +} + +type setupOutputContext struct { + // GetValidationMode is a function to get the checksum validation + // mode of the output payload from the input parameters. + // + // Given the input parameter value, the function must return the validation + // mode and true, or false if no mode is specified. + GetValidationMode func(interface{}) (string, bool) +} + +// ID for the middleware +func (m *setupOutputContext) ID() string { + return "AWSChecksum:SetupOutputContext" +} + +// HandleInitialize initialization middleware that setups up the checksum +// context based on the input parameters provided in the stack. +func (m *setupOutputContext) HandleInitialize( + ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler, +) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, +) { + // Check if validation mode is specified. + if m.GetValidationMode != nil { + // check is input resource has a checksum algorithm + mode, ok := m.GetValidationMode(in.Parameters) + if ok && len(mode) != 0 { + ctx = setContextOutputValidationMode(ctx, mode) + } + } + + return next.HandleInitialize(ctx, in) +} + +// outputValidationModeKey is the key set on context used to identify if +// output checksum validation is enabled. +type outputValidationModeKey struct{} + +// setContextOutputValidationMode sets the request checksum +// algorithm on the context. +// +// Scoped to stack values. +func setContextOutputValidationMode(ctx context.Context, value string) context.Context { + return middleware.WithStackValue(ctx, outputValidationModeKey{}, value) +} + +// getContextOutputValidationMode returns response checksum validation state, +// if one was specified. Empty string is returned if one is not specified. +// +// Scoped to stack values. +func getContextOutputValidationMode(ctx context.Context) (v string) { + v, _ = middleware.GetStackValue(ctx, outputValidationModeKey{}).(string) + return v +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context_test.go new file mode 100644 index 0000000000..3235983bad --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context_test.go @@ -0,0 +1,143 @@ +//go:build go1.16 +// +build go1.16 + +package checksum + +import ( + "context" + "testing" + + "github.com/aws/smithy-go/middleware" +) + +func TestSetupInput(t *testing.T) { + type Params struct { + Value string + } + + cases := map[string]struct { + inputParams interface{} + getAlgorithm func(interface{}) (string, bool) + expectValue string + }{ + "nil accessor": { + expectValue: "", + }, + "found empty": { + inputParams: Params{Value: ""}, + getAlgorithm: func(v interface{}) (string, bool) { + vv := v.(Params) + return vv.Value, true + }, + expectValue: "", + }, + "found not set": { + inputParams: Params{Value: ""}, + getAlgorithm: func(v interface{}) (string, bool) { + return "", false + }, + expectValue: "", + }, + "found": { + inputParams: Params{Value: "abc123"}, + getAlgorithm: func(v interface{}) (string, bool) { + vv := v.(Params) + return vv.Value, true + }, + expectValue: "abc123", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + m := setupInputContext{ + GetAlgorithm: c.getAlgorithm, + } + + _, _, err := m.HandleInitialize(context.Background(), + middleware.InitializeInput{Parameters: c.inputParams}, + middleware.InitializeHandlerFunc( + func(ctx context.Context, input middleware.InitializeInput) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, + ) { + v := getContextInputAlgorithm(ctx) + if e, a := c.expectValue, v; e != a { + t.Errorf("expect value %v, got %v", e, a) + } + + return out, metadata, nil + }, + )) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + }) + } +} + +func TestSetupOutput(t *testing.T) { + type Params struct { + Value string + } + + cases := map[string]struct { + inputParams interface{} + getValidationMode func(interface{}) (string, bool) + expectValue string + }{ + "nil accessor": { + expectValue: "", + }, + "found empty": { + inputParams: Params{Value: ""}, + getValidationMode: func(v interface{}) (string, bool) { + vv := v.(Params) + return vv.Value, true + }, + expectValue: "", + }, + "found not set": { + inputParams: Params{Value: ""}, + getValidationMode: func(v interface{}) (string, bool) { + return "", false + }, + expectValue: "", + }, + "found": { + inputParams: Params{Value: "abc123"}, + getValidationMode: func(v interface{}) (string, bool) { + vv := v.(Params) + return vv.Value, true + }, + expectValue: "abc123", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + m := setupOutputContext{ + GetValidationMode: c.getValidationMode, + } + + _, _, err := m.HandleInitialize(context.Background(), + middleware.InitializeInput{Parameters: c.inputParams}, + middleware.InitializeHandlerFunc( + func(ctx context.Context, input middleware.InitializeInput) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, + ) { + v := getContextOutputValidationMode(ctx) + if e, a := c.expectValue, v; e != a { + t.Errorf("expect value %v, got %v", e, a) + } + + return out, metadata, nil + }, + )) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + }) + } +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output.go new file mode 100644 index 0000000000..9fde12d86d --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output.go @@ -0,0 +1,131 @@ +package checksum + +import ( + "context" + "fmt" + "strings" + + "github.com/aws/smithy-go" + "github.com/aws/smithy-go/logging" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +// outputValidationAlgorithmsUsedKey is the metadata key for indexing the algorithms +// that were used, by the middleware's validation. +type outputValidationAlgorithmsUsedKey struct{} + +// GetOutputValidationAlgorithmsUsed returns the checksum algorithms used +// stored in the middleware Metadata. Returns false if no algorithms were +// stored in the Metadata. +func GetOutputValidationAlgorithmsUsed(m middleware.Metadata) ([]string, bool) { + vs, ok := m.Get(outputValidationAlgorithmsUsedKey{}).([]string) + return vs, ok +} + +// SetOutputValidationAlgorithmsUsed stores the checksum algorithms used in the +// middleware Metadata. +func SetOutputValidationAlgorithmsUsed(m *middleware.Metadata, vs []string) { + m.Set(outputValidationAlgorithmsUsedKey{}, vs) +} + +// validateOutputPayloadChecksum middleware computes payload checksum of the +// received response and validates with checksum returned by the service. +type validateOutputPayloadChecksum struct { + // Algorithms represents a priority-ordered list of valid checksum + // algorithm that should be validated when present in HTTP response + // headers. + Algorithms []Algorithm + + // IgnoreMultipartValidation indicates multipart checksums ending with "-#" + // will be ignored. + IgnoreMultipartValidation bool + + // When set the middleware will log when output does not have checksum or + // algorithm to validate. + LogValidationSkipped bool + + // When set the middleware will log when the output contains a multipart + // checksum that was, skipped and not validated. + LogMultipartValidationSkipped bool +} + +func (m *validateOutputPayloadChecksum) ID() string { + return "AWSChecksum:ValidateOutputPayloadChecksum" +} + +// HandleDeserialize is a Deserialize middleware that wraps the HTTP response +// body with an io.ReadCloser that will validate the its checksum. +func (m *validateOutputPayloadChecksum) HandleDeserialize( + ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler, +) ( + out middleware.DeserializeOutput, metadata middleware.Metadata, err error, +) { + out, metadata, err = next.HandleDeserialize(ctx, in) + if err != nil { + return out, metadata, err + } + + // If there is no validation mode specified nothing is supported. + if mode := getContextOutputValidationMode(ctx); mode != "ENABLED" { + return out, metadata, err + } + + response, ok := out.RawResponse.(*smithyhttp.Response) + if !ok { + return out, metadata, &smithy.DeserializationError{ + Err: fmt.Errorf("unknown transport type %T", out.RawResponse), + } + } + + var expectedChecksum string + var algorithmToUse Algorithm + for _, algorithm := range m.Algorithms { + value := response.Header.Get(AlgorithmHTTPHeader(algorithm)) + if len(value) == 0 { + continue + } + + expectedChecksum = value + algorithmToUse = algorithm + } + + // TODO this must validate the validation mode is set to enabled. + + logger := middleware.GetLogger(ctx) + + // Skip validation if no checksum algorithm or checksum is available. + if len(expectedChecksum) == 0 || len(algorithmToUse) == 0 { + if m.LogValidationSkipped { + // TODO this probably should have more information about the + // operation output that won't be validated. + logger.Logf(logging.Warn, + "Response has no supported checksum. Not validating response payload.") + } + return out, metadata, nil + } + + // Ignore multipart validation + if m.IgnoreMultipartValidation && strings.Contains(expectedChecksum, "-") { + if m.LogMultipartValidationSkipped { + // TODO this probably should have more information about the + // operation output that won't be validated. + logger.Logf(logging.Warn, "Skipped validation of multipart checksum.") + } + return out, metadata, nil + } + + body, err := newValidateChecksumReader(response.Body, algorithmToUse, expectedChecksum) + if err != nil { + return out, metadata, fmt.Errorf("failed to create checksum validation reader, %w", err) + } + response.Body = body + + // Update the metadata to include the set of the checksum algorithms that + // will be validated. + SetOutputValidationAlgorithmsUsed(&metadata, []string{ + string(algorithmToUse), + }) + + return out, metadata, nil +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output_test.go new file mode 100644 index 0000000000..e6cf7054e8 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output_test.go @@ -0,0 +1,263 @@ +//go:build go1.16 +// +build go1.16 + +package checksum + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "net/http" + "strings" + "testing" + "testing/iotest" + + "github.com/aws/smithy-go/logging" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/google/go-cmp/cmp" +) + +func TestValidateOutputPayloadChecksum(t *testing.T) { + cases := map[string]struct { + response *smithyhttp.Response + validateOptions func(*validateOutputPayloadChecksum) + modifyContext func(context.Context) context.Context + expectHaveAlgorithmsUsed bool + expectAlgorithmsUsed []string + expectErr string + expectReadErr string + expectLogged string + expectPayload []byte + }{ + "success": { + modifyContext: func(ctx context.Context) context.Context { + return setContextOutputValidationMode(ctx, "ENABLED") + }, + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "DUoRhQ==") + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("hello world")), + }, + }, + expectHaveAlgorithmsUsed: true, + expectAlgorithmsUsed: []string{"CRC32"}, + expectPayload: []byte("hello world"), + }, + "failure": { + modifyContext: func(ctx context.Context) context.Context { + return setContextOutputValidationMode(ctx, "ENABLED") + }, + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "AAAAAA==") + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("hello world")), + }, + }, + expectReadErr: "checksum did not match", + }, + "read error": { + modifyContext: func(ctx context.Context) context.Context { + return setContextOutputValidationMode(ctx, "ENABLED") + }, + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "AAAAAA==") + return h + }(), + Body: ioutil.NopCloser(iotest.ErrReader(fmt.Errorf("some read error"))), + }, + }, + expectReadErr: "some read error", + }, + "unsupported algorithm": { + modifyContext: func(ctx context.Context) context.Context { + return setContextOutputValidationMode(ctx, "ENABLED") + }, + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + h.Set(AlgorithmHTTPHeader("unsupported"), "AAAAAA==") + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("hello world")), + }, + }, + expectLogged: "no supported checksum", + expectPayload: []byte("hello world"), + }, + "no output validation model": { + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("hello world")), + }, + }, + expectPayload: []byte("hello world"), + }, + "unknown output validation model": { + modifyContext: func(ctx context.Context) context.Context { + return setContextOutputValidationMode(ctx, "something else") + }, + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("hello world")), + }, + }, + expectPayload: []byte("hello world"), + }, + "success ignore multipart checksum": { + modifyContext: func(ctx context.Context) context.Context { + return setContextOutputValidationMode(ctx, "ENABLED") + }, + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "DUoRhQ==") + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("hello world")), + }, + }, + validateOptions: func(o *validateOutputPayloadChecksum) { + o.IgnoreMultipartValidation = true + }, + expectHaveAlgorithmsUsed: true, + expectAlgorithmsUsed: []string{"CRC32"}, + expectPayload: []byte("hello world"), + }, + "success skip ignore multipart checksum": { + modifyContext: func(ctx context.Context) context.Context { + return setContextOutputValidationMode(ctx, "ENABLED") + }, + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "DUoRhQ==-12") + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("hello world")), + }, + }, + validateOptions: func(o *validateOutputPayloadChecksum) { + o.IgnoreMultipartValidation = true + }, + expectLogged: "Skipped validation of multipart checksum", + expectPayload: []byte("hello world"), + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var logged bytes.Buffer + ctx := middleware.SetLogger(context.Background(), logging.LoggerFunc( + func(classification logging.Classification, format string, v ...interface{}) { + fmt.Fprintf(&logged, format, v...) + })) + + if c.modifyContext != nil { + ctx = c.modifyContext(ctx) + } + + validateOutput := validateOutputPayloadChecksum{ + Algorithms: []Algorithm{ + AlgorithmSHA1, AlgorithmCRC32, AlgorithmCRC32C, + }, + LogValidationSkipped: true, + LogMultipartValidationSkipped: true, + } + if c.validateOptions != nil { + c.validateOptions(&validateOutput) + } + + out, meta, err := validateOutput.HandleDeserialize(ctx, + middleware.DeserializeInput{}, + middleware.DeserializeHandlerFunc( + func(ctx context.Context, input middleware.DeserializeInput) ( + out middleware.DeserializeOutput, metadata middleware.Metadata, err error, + ) { + out.RawResponse = c.response + return out, metadata, nil + }, + ), + ) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + response := out.RawResponse.(*smithyhttp.Response) + + actualPayload, err := ioutil.ReadAll(response.Body) + if err == nil && len(c.expectReadErr) != 0 { + t.Fatalf("expected read error: %v, got none", c.expectReadErr) + } + if err != nil && len(c.expectReadErr) == 0 { + t.Fatalf("expect no read error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectReadErr) { + t.Fatalf("expected read error %v to contain %v", err, c.expectReadErr) + } + if c.expectReadErr != "" { + return + } + + if e, a := c.expectLogged, logged.String(); !strings.Contains(a, e) || !((e == "") == (a == "")) { + t.Errorf("expected %q logged in:\n%s", e, a) + } + + if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" { + t.Errorf("expect payload match:\n%s", diff) + } + + if err = response.Body.Close(); err != nil { + t.Errorf("expect no close error, got %v", err) + } + + values, ok := GetOutputValidationAlgorithmsUsed(meta) + if ok != c.expectHaveAlgorithmsUsed { + t.Errorf("expect metadata to contain algorithms used, %t", c.expectHaveAlgorithmsUsed) + } + if diff := cmp.Diff(c.expectAlgorithmsUsed, values); diff != "" { + t.Errorf("expect algorithms used to match\n%s", diff) + } + }) + } +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/ya.make new file mode 100644 index 0000000000..85e34f039f --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/ya.make @@ -0,0 +1,28 @@ +GO_LIBRARY() + +LICENSE(Apache-2.0) + +SRCS( + algorithms.go + aws_chunked_encoding.go + go_module_metadata.go + middleware_add.go + middleware_compute_input_checksum.go + middleware_setup_context.go + middleware_validate_output.go +) + +GO_TEST_SRCS( + algorithms_test.go + aws_chunked_encoding_test.go + middleware_add_test.go + middleware_compute_input_checksum_test.go + middleware_setup_context_test.go + middleware_validate_output_test.go +) + +END() + +RECURSE( + gotest +) diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/context.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/context.go new file mode 100644 index 0000000000..cc919701a0 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/context.go @@ -0,0 +1,48 @@ +package presignedurl + +import ( + "context" + + "github.com/aws/smithy-go/middleware" +) + +// WithIsPresigning adds the isPresigning sentinel value to a context to signal +// that the middleware stack is using the presign flow. +// +// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues +// to clear all stack values. +func WithIsPresigning(ctx context.Context) context.Context { + return middleware.WithStackValue(ctx, isPresigningKey{}, true) +} + +// GetIsPresigning returns if the context contains the isPresigning sentinel +// value for presigning flows. +// +// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues +// to clear all stack values. +func GetIsPresigning(ctx context.Context) bool { + v, _ := middleware.GetStackValue(ctx, isPresigningKey{}).(bool) + return v +} + +type isPresigningKey struct{} + +// AddAsIsPresigingMiddleware adds a middleware to the head of the stack that +// will update the stack's context to be flagged as being invoked for the +// purpose of presigning. +func AddAsIsPresigingMiddleware(stack *middleware.Stack) error { + return stack.Initialize.Add(asIsPresigningMiddleware{}, middleware.Before) +} + +type asIsPresigningMiddleware struct{} + +func (asIsPresigningMiddleware) ID() string { return "AsIsPresigningMiddleware" } + +func (asIsPresigningMiddleware) HandleInitialize( + ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler, +) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, +) { + ctx = WithIsPresigning(ctx) + return next.HandleInitialize(ctx, in) +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/doc.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/doc.go new file mode 100644 index 0000000000..1b85375cf8 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/doc.go @@ -0,0 +1,3 @@ +// Package presignedurl provides the customizations for API clients to fill in +// presigned URLs into input parameters. +package presignedurl diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/go_module_metadata.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/go_module_metadata.go new file mode 100644 index 0000000000..e7168c626e --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/go_module_metadata.go @@ -0,0 +1,6 @@ +// Code generated by internal/repotools/cmd/updatemodulemeta DO NOT EDIT. + +package presignedurl + +// goModuleVersion is the tagged release for this module +const goModuleVersion = "1.10.0" diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/gotest/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/gotest/ya.make new file mode 100644 index 0000000000..715df102eb --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/gotest/ya.make @@ -0,0 +1,5 @@ +GO_TEST_FOR(vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) + +LICENSE(Apache-2.0) + +END() diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/middleware.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/middleware.go new file mode 100644 index 0000000000..1e2f5c8122 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/middleware.go @@ -0,0 +1,110 @@ +package presignedurl + +import ( + "context" + "fmt" + + awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + + "github.com/aws/smithy-go/middleware" +) + +// URLPresigner provides the interface to presign the input parameters in to a +// presigned URL. +type URLPresigner interface { + // PresignURL presigns a URL. + PresignURL(ctx context.Context, srcRegion string, params interface{}) (*v4.PresignedHTTPRequest, error) +} + +// ParameterAccessor provides an collection of accessor to for retrieving and +// setting the values needed to PresignedURL generation +type ParameterAccessor struct { + // GetPresignedURL accessor points to a function that retrieves a presigned url if present + GetPresignedURL func(interface{}) (string, bool, error) + + // GetSourceRegion accessor points to a function that retrieves source region for presigned url + GetSourceRegion func(interface{}) (string, bool, error) + + // CopyInput accessor points to a function that takes in an input, and returns a copy. + CopyInput func(interface{}) (interface{}, error) + + // SetDestinationRegion accessor points to a function that sets destination region on api input struct + SetDestinationRegion func(interface{}, string) error + + // SetPresignedURL accessor points to a function that sets presigned url on api input struct + SetPresignedURL func(interface{}, string) error +} + +// Options provides the set of options needed by the presigned URL middleware. +type Options struct { + // Accessor are the parameter accessors used by this middleware + Accessor ParameterAccessor + + // Presigner is the URLPresigner used by the middleware + Presigner URLPresigner +} + +// AddMiddleware adds the Presign URL middleware to the middleware stack. +func AddMiddleware(stack *middleware.Stack, opts Options) error { + return stack.Initialize.Add(&presign{options: opts}, middleware.Before) +} + +// RemoveMiddleware removes the Presign URL middleware from the stack. +func RemoveMiddleware(stack *middleware.Stack) error { + _, err := stack.Initialize.Remove((*presign)(nil).ID()) + return err +} + +type presign struct { + options Options +} + +func (m *presign) ID() string { return "Presign" } + +func (m *presign) HandleInitialize( + ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler, +) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, +) { + // If PresignedURL is already set ignore middleware. + if _, ok, err := m.options.Accessor.GetPresignedURL(input.Parameters); err != nil { + return out, metadata, fmt.Errorf("presign middleware failed, %w", err) + } else if ok { + return next.HandleInitialize(ctx, input) + } + + // If have source region is not set ignore middleware. + srcRegion, ok, err := m.options.Accessor.GetSourceRegion(input.Parameters) + if err != nil { + return out, metadata, fmt.Errorf("presign middleware failed, %w", err) + } else if !ok || len(srcRegion) == 0 { + return next.HandleInitialize(ctx, input) + } + + // Create a copy of the original input so the destination region value can + // be added. This ensures that value does not leak into the original + // request parameters. + paramCpy, err := m.options.Accessor.CopyInput(input.Parameters) + if err != nil { + return out, metadata, fmt.Errorf("unable to create presigned URL, %w", err) + } + + // Destination region is the API client's configured region. + dstRegion := awsmiddleware.GetRegion(ctx) + if err = m.options.Accessor.SetDestinationRegion(paramCpy, dstRegion); err != nil { + return out, metadata, fmt.Errorf("presign middleware failed, %w", err) + } + + presignedReq, err := m.options.Presigner.PresignURL(ctx, srcRegion, paramCpy) + if err != nil { + return out, metadata, fmt.Errorf("unable to create presigned URL, %w", err) + } + + // Update the original input with the presigned URL value. + if err = m.options.Accessor.SetPresignedURL(input.Parameters, presignedReq.URL); err != nil { + return out, metadata, fmt.Errorf("presign middleware failed, %w", err) + } + + return next.HandleInitialize(ctx, input) +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/middleware_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/middleware_test.go new file mode 100644 index 0000000000..56ba1d9205 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/middleware_test.go @@ -0,0 +1,151 @@ +package presignedurl + +import ( + "context" + "net/http" + "strings" + "testing" + + awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/google/go-cmp/cmp" +) + +func TestPresignMiddleware(t *testing.T) { + cases := map[string]struct { + Input *mockURLPresignInput + + ExpectInput *mockURLPresignInput + ExpectErr string + }{ + "no source": { + Input: &mockURLPresignInput{}, + ExpectInput: &mockURLPresignInput{}, + }, + "with presigned URL": { + Input: &mockURLPresignInput{ + SourceRegion: "source-region", + PresignedURL: "https://example.amazonaws.com/someURL", + }, + ExpectInput: &mockURLPresignInput{ + SourceRegion: "source-region", + PresignedURL: "https://example.amazonaws.com/someURL", + }, + }, + "with source": { + Input: &mockURLPresignInput{ + SourceRegion: "source-region", + }, + ExpectInput: &mockURLPresignInput{ + SourceRegion: "source-region", + PresignedURL: "https://example.source-region.amazonaws.com/?DestinationRegion=mock-region", + }, + }, + "matching source destination region": { + Input: &mockURLPresignInput{ + SourceRegion: "mock-region", + }, + ExpectInput: &mockURLPresignInput{ + SourceRegion: "mock-region", + PresignedURL: "https://example.mock-region.amazonaws.com/?DestinationRegion=mock-region", + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + stack := middleware.NewStack(name, smithyhttp.NewStackRequest) + + stack.Initialize.Add(&awsmiddleware.RegisterServiceMetadata{ + Region: "mock-region", + }, middleware.After) + + stack.Initialize.Add(&presign{options: getURLPresignMiddlewareOptions()}, middleware.After) + + stack.Initialize.Add(middleware.InitializeMiddlewareFunc(name+"_verifyParams", + func(ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, + ) { + input := in.Parameters.(*mockURLPresignInput) + if diff := cmp.Diff(c.ExpectInput, input); len(diff) != 0 { + t.Errorf("expect input to be updated\n%s", diff) + } + + return next.HandleInitialize(ctx, in) + }, + ), middleware.After) + + handler := middleware.DecorateHandler(smithyhttp.NewClientHandler(smithyhttp.NopClient{}), stack) + _, _, err := handler.Handle(context.Background(), c.Input) + if len(c.ExpectErr) != 0 { + if err == nil { + t.Fatalf("expect error, got none") + } + if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error to contain %v, got %v", e, a) + } + return + } + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + }) + } +} + +func getURLPresignMiddlewareOptions() Options { + return Options{ + Accessor: ParameterAccessor{ + GetPresignedURL: func(c interface{}) (string, bool, error) { + presignURL := c.(*mockURLPresignInput).PresignedURL + if len(presignURL) != 0 { + return presignURL, true, nil + } + return "", false, nil + }, + GetSourceRegion: func(c interface{}) (string, bool, error) { + srcRegion := c.(*mockURLPresignInput).SourceRegion + if len(srcRegion) != 0 { + return srcRegion, true, nil + } + return "", false, nil + }, + CopyInput: func(c interface{}) (interface{}, error) { + input := *(c.(*mockURLPresignInput)) + return &input, nil + }, + SetDestinationRegion: func(c interface{}, v string) error { + c.(*mockURLPresignInput).DestinationRegion = v + return nil + }, + SetPresignedURL: func(c interface{}, v string) error { + c.(*mockURLPresignInput).PresignedURL = v + return nil + }, + }, + Presigner: &mockURLPresigner{}, + } +} + +type mockURLPresignInput struct { + SourceRegion string + DestinationRegion string + PresignedURL string +} + +type mockURLPresigner struct{} + +func (*mockURLPresigner) PresignURL(ctx context.Context, srcRegion string, params interface{}) ( + req *v4.PresignedHTTPRequest, err error, +) { + in := params.(*mockURLPresignInput) + + return &v4.PresignedHTTPRequest{ + URL: "https://example." + srcRegion + ".amazonaws.com/?DestinationRegion=" + in.DestinationRegion, + Method: "GET", + SignedHeader: http.Header{}, + }, nil +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/ya.make new file mode 100644 index 0000000000..266a0d246b --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url/ya.make @@ -0,0 +1,18 @@ +GO_LIBRARY() + +LICENSE(Apache-2.0) + +SRCS( + context.go + doc.go + go_module_metadata.go + middleware.go +) + +GO_TEST_SRCS(middleware_test.go) + +END() + +RECURSE( + gotest +) diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/accesspoint_arn.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/accesspoint_arn.go new file mode 100644 index 0000000000..ec290b2135 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/accesspoint_arn.go @@ -0,0 +1,53 @@ +package arn + +import ( + "strings" + + "github.com/aws/aws-sdk-go-v2/aws/arn" +) + +// AccessPointARN provides representation +type AccessPointARN struct { + arn.ARN + AccessPointName string +} + +// GetARN returns the base ARN for the Access Point resource +func (a AccessPointARN) GetARN() arn.ARN { + return a.ARN +} + +// ParseAccessPointResource attempts to parse the ARN's resource as an +// AccessPoint resource. +// +// Supported Access point resource format: +// - Access point format: arn:{partition}:s3:{region}:{accountId}:accesspoint/{accesspointName} +// - example: arn:aws:s3:us-west-2:012345678901:accesspoint/myaccesspoint +func ParseAccessPointResource(a arn.ARN, resParts []string) (AccessPointARN, error) { + if isFIPS(a.Region) { + return AccessPointARN{}, InvalidARNError{ARN: a, Reason: "FIPS region not allowed in ARN"} + } + if len(a.AccountID) == 0 { + return AccessPointARN{}, InvalidARNError{ARN: a, Reason: "account-id not set"} + } + if len(resParts) == 0 { + return AccessPointARN{}, InvalidARNError{ARN: a, Reason: "resource-id not set"} + } + if len(resParts) > 1 { + return AccessPointARN{}, InvalidARNError{ARN: a, Reason: "sub resource not supported"} + } + + resID := resParts[0] + if len(strings.TrimSpace(resID)) == 0 { + return AccessPointARN{}, InvalidARNError{ARN: a, Reason: "resource-id not set"} + } + + return AccessPointARN{ + ARN: a, + AccessPointName: resID, + }, nil +} + +func isFIPS(region string) bool { + return strings.HasPrefix(region, "fips-") || strings.HasSuffix(region, "-fips") +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/accesspoint_arn_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/accesspoint_arn_test.go new file mode 100644 index 0000000000..51221b20f3 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/accesspoint_arn_test.go @@ -0,0 +1,118 @@ +package arn + +import ( + "reflect" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws/arn" +) + +func TestParseAccessPointResource(t *testing.T) { + cases := map[string]struct { + ARN arn.ARN + ExpectErr string + ExpectARN AccessPointARN + }{ + "account-id not set": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3", + Region: "us-west-2", + Resource: "accesspoint/myendpoint", + }, + ExpectErr: "account-id not set", + }, + "resource-id not set": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "accesspoint", + }, + ExpectErr: "resource-id not set", + }, + "resource-id empty": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "accesspoint:", + }, + ExpectErr: "resource-id not set", + }, + "resource not supported": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "accesspoint/endpoint/object/key", + }, + ExpectErr: "sub resource not supported", + }, + "valid resource-id": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "accesspoint/endpoint", + }, + ExpectARN: AccessPointARN{ + ARN: arn.ARN{ + Partition: "aws", + Service: "s3", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "accesspoint/endpoint", + }, + AccessPointName: "endpoint", + }, + }, + "invalid FIPS pseudo region in ARN (prefix)": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3", + Region: "fips-us-west-2", + AccountID: "012345678901", + Resource: "accesspoint/endpoint", + }, + ExpectErr: "FIPS region not allowed in ARN", + }, + "invalid FIPS pseudo region in ARN (suffix)": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3", + Region: "us-west-2-fips", + AccountID: "012345678901", + Resource: "accesspoint/endpoint", + }, + ExpectErr: "FIPS region not allowed in ARN", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + resParts := SplitResource(c.ARN.Resource) + a, err := ParseAccessPointResource(c.ARN, resParts[1:]) + + if len(c.ExpectErr) == 0 && err != nil { + t.Fatalf("expect no error but got %v", err) + } else if len(c.ExpectErr) != 0 && err == nil { + t.Fatalf("expect error %q, but got nil", c.ExpectErr) + } else if len(c.ExpectErr) != 0 && err != nil { + if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error %q, got %q", e, a) + } + return + } + + if e, a := c.ExpectARN, a; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } + }) + } +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/arn.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/arn.go new file mode 100644 index 0000000000..06e1a3addd --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/arn.go @@ -0,0 +1,85 @@ +package arn + +import ( + "fmt" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws/arn" +) + +var supportedServiceARN = []string{ + "s3", + "s3-outposts", + "s3-object-lambda", +} + +func isSupportedServiceARN(service string) bool { + for _, name := range supportedServiceARN { + if name == service { + return true + } + } + return false +} + +// Resource provides the interfaces abstracting ARNs of specific resource +// types. +type Resource interface { + GetARN() arn.ARN + String() string +} + +// ResourceParser provides the function for parsing an ARN's resource +// component into a typed resource. +type ResourceParser func(arn.ARN) (Resource, error) + +// ParseResource parses an AWS ARN into a typed resource for the S3 API. +func ParseResource(a arn.ARN, resParser ResourceParser) (resARN Resource, err error) { + if len(a.Partition) == 0 { + return nil, InvalidARNError{ARN: a, Reason: "partition not set"} + } + + if !isSupportedServiceARN(a.Service) { + return nil, InvalidARNError{ARN: a, Reason: "service is not supported"} + } + + if len(a.Resource) == 0 { + return nil, InvalidARNError{ARN: a, Reason: "resource not set"} + } + + return resParser(a) +} + +// SplitResource splits the resource components by the ARN resource delimiters. +func SplitResource(v string) []string { + var parts []string + var offset int + + for offset <= len(v) { + idx := strings.IndexAny(v[offset:], "/:") + if idx < 0 { + parts = append(parts, v[offset:]) + break + } + parts = append(parts, v[offset:idx+offset]) + offset += idx + 1 + } + + return parts +} + +// IsARN returns whether the given string is an ARN +func IsARN(s string) bool { + return arn.IsARN(s) +} + +// InvalidARNError provides the error for an invalid ARN error. +type InvalidARNError struct { + ARN arn.ARN + Reason string +} + +// Error returns a string denoting the occurred InvalidARNError +func (e InvalidARNError) Error() string { + return fmt.Sprintf("invalid Amazon %s ARN, %s, %s", e.ARN.Service, e.Reason, e.ARN.String()) +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/arn_member.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/arn_member.go new file mode 100644 index 0000000000..9a3258e15a --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/arn_member.go @@ -0,0 +1,32 @@ +package arn + +import "fmt" + +// arnable is implemented by the relevant S3/S3Control +// operations which have members that may need ARN +// processing. +type arnable interface { + SetARNMember(string) error + GetARNMember() (*string, bool) +} + +// GetARNField would be called during middleware execution +// to retrieve a member value that is an ARN in need of +// processing. +func GetARNField(input interface{}) (*string, bool) { + v, ok := input.(arnable) + if !ok { + return nil, false + } + return v.GetARNMember() +} + +// SetARNField would called during middleware exeuction +// to set a member value that required ARN processing. +func SetARNField(input interface{}, v string) error { + params, ok := input.(arnable) + if !ok { + return fmt.Errorf("Params does not contain arn field member") + } + return params.SetARNMember(v) +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/arn_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/arn_test.go new file mode 100644 index 0000000000..db7beaaf8f --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/arn_test.go @@ -0,0 +1,170 @@ +package arn + +import ( + "reflect" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws/arn" +) + +func TestParseResource(t *testing.T) { + cases := map[string]struct { + Input string + MappedResources map[string]func(arn.ARN, []string) (Resource, error) + Expect Resource + ExpectErr string + }{ + "Empty ARN": { + Input: "", + ExpectErr: "arn: invalid prefix", + }, + "No Partition": { + Input: "arn::sqs:us-west-2:012345678901:accesspoint", + ExpectErr: "partition not set", + }, + "Not S3 ARN": { + Input: "arn:aws:sqs:us-west-2:012345678901:accesspoint", + ExpectErr: "service is not supported", + }, + "No Resource": { + Input: "arn:aws:s3:us-west-2:012345678901:", + ExpectErr: "resource not set", + }, + "Unknown Resource Type": { + Input: "arn:aws:s3:us-west-2:012345678901:myresource", + ExpectErr: "unknown resource type", + }, + "Unknown BucketARN Resource Type": { + Input: "arn:aws:s3:us-west-2:012345678901:bucket_name:mybucket", + ExpectErr: "unknown resource type", + }, + "Unknown Resource Type with Resource and Sub-Resource": { + Input: "arn:aws:s3:us-west-2:012345678901:somethingnew:myresource/subresource", + ExpectErr: "unknown resource type", + }, + "Access Point with sub resource": { + Input: "arn:aws:s3:us-west-2:012345678901:accesspoint:myresource/subresource", + MappedResources: map[string]func(arn.ARN, []string) (Resource, error){ + "accesspoint": func(a arn.ARN, parts []string) (Resource, error) { + return ParseAccessPointResource(a, parts) + }, + }, + ExpectErr: "resource not supported", + }, + "AccessPoint Resource Type": { + Input: "arn:aws:s3:us-west-2:012345678901:accesspoint:myendpoint", + MappedResources: map[string]func(arn.ARN, []string) (Resource, error){ + "accesspoint": func(a arn.ARN, parts []string) (Resource, error) { + return ParseAccessPointResource(a, parts) + }, + }, + Expect: AccessPointARN{ + ARN: arn.ARN{ + Partition: "aws", + Service: "s3", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "accesspoint:myendpoint", + }, + AccessPointName: "myendpoint", + }, + }, + "AccessPoint Resource Type With Path Syntax": { + Input: "arn:aws:s3:us-west-2:012345678901:accesspoint/myendpoint", + MappedResources: map[string]func(arn.ARN, []string) (Resource, error){ + "accesspoint": func(a arn.ARN, parts []string) (Resource, error) { + return ParseAccessPointResource(a, parts) + }, + }, + Expect: AccessPointARN{ + ARN: arn.ARN{ + Partition: "aws", + Service: "s3", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "accesspoint/myendpoint", + }, + AccessPointName: "myendpoint", + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var parsed Resource + arn, err := arn.Parse(c.Input) + if err == nil { + parsed, err = ParseResource(arn, mappedResourceParser(c.MappedResources)) + } + + if len(c.ExpectErr) == 0 && err != nil { + t.Fatalf("expect no error but got %v", err) + } else if len(c.ExpectErr) != 0 && err == nil { + t.Fatalf("expect error but got nil") + } else if len(c.ExpectErr) != 0 && err != nil { + if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error %q, got %q", e, a) + } + return + } + + if e, a := c.Expect, parsed; !reflect.DeepEqual(e, a) { + t.Errorf("Expect %v, got %v", e, a) + } + }) + } +} + +func mappedResourceParser(kinds map[string]func(arn.ARN, []string) (Resource, error)) ResourceParser { + return func(a arn.ARN) (Resource, error) { + parts := SplitResource(a.Resource) + + fn, ok := kinds[parts[0]] + if !ok { + return nil, InvalidARNError{ARN: a, Reason: "unknown resource type"} + } + return fn(a, parts[1:]) + } +} + +func TestSplitResource(t *testing.T) { + cases := []struct { + Input string + Expect []string + }{ + { + Input: "accesspoint:myendpoint", + Expect: []string{"accesspoint", "myendpoint"}, + }, + { + Input: "accesspoint/myendpoint", + Expect: []string{"accesspoint", "myendpoint"}, + }, + { + Input: "accesspoint", + Expect: []string{"accesspoint"}, + }, + { + Input: "accesspoint:", + Expect: []string{"accesspoint", ""}, + }, + { + Input: "accesspoint: ", + Expect: []string{"accesspoint", " "}, + }, + { + Input: "accesspoint:endpoint/object/key", + Expect: []string{"accesspoint", "endpoint", "object", "key"}, + }, + } + + for _, c := range cases { + t.Run(c.Input, func(t *testing.T) { + parts := SplitResource(c.Input) + if e, a := c.Expect, parts; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } + }) + } +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/gotest/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/gotest/ya.make new file mode 100644 index 0000000000..0db9a0c5c8 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/gotest/ya.make @@ -0,0 +1,5 @@ +GO_TEST_FOR(vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn) + +LICENSE(Apache-2.0) + +END() diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/outpost_arn.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/outpost_arn.go new file mode 100644 index 0000000000..e06a302857 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/outpost_arn.go @@ -0,0 +1,128 @@ +package arn + +import ( + "strings" + + "github.com/aws/aws-sdk-go-v2/aws/arn" +) + +// OutpostARN interface that should be satisfied by outpost ARNs +type OutpostARN interface { + Resource + GetOutpostID() string +} + +// ParseOutpostARNResource will parse a provided ARNs resource using the appropriate ARN format +// and return a specific OutpostARN type +// +// Currently supported outpost ARN formats: +// * Outpost AccessPoint ARN format: +// - ARN format: arn:{partition}:s3-outposts:{region}:{accountId}:outpost/{outpostId}/accesspoint/{accesspointName} +// - example: arn:aws:s3-outposts:us-west-2:012345678901:outpost/op-1234567890123456/accesspoint/myaccesspoint +// +// * Outpost Bucket ARN format: +// - ARN format: arn:{partition}:s3-outposts:{region}:{accountId}:outpost/{outpostId}/bucket/{bucketName} +// - example: arn:aws:s3-outposts:us-west-2:012345678901:outpost/op-1234567890123456/bucket/mybucket +// +// Other outpost ARN formats may be supported and added in the future. +func ParseOutpostARNResource(a arn.ARN, resParts []string) (OutpostARN, error) { + if len(a.Region) == 0 { + return nil, InvalidARNError{ARN: a, Reason: "region not set"} + } + + if isFIPS(a.Region) { + return nil, InvalidARNError{ARN: a, Reason: "FIPS region not allowed in ARN"} + } + + if len(a.AccountID) == 0 { + return nil, InvalidARNError{ARN: a, Reason: "account-id not set"} + } + + // verify if outpost id is present and valid + if len(resParts) == 0 || len(strings.TrimSpace(resParts[0])) == 0 { + return nil, InvalidARNError{ARN: a, Reason: "outpost resource-id not set"} + } + + // verify possible resource type exists + if len(resParts) < 3 { + return nil, InvalidARNError{ + ARN: a, Reason: "incomplete outpost resource type. Expected bucket or access-point resource to be present", + } + } + + // Since we know this is a OutpostARN fetch outpostID + outpostID := strings.TrimSpace(resParts[0]) + + switch resParts[1] { + case "accesspoint": + accesspointARN, err := ParseAccessPointResource(a, resParts[2:]) + if err != nil { + return OutpostAccessPointARN{}, err + } + return OutpostAccessPointARN{ + AccessPointARN: accesspointARN, + OutpostID: outpostID, + }, nil + + case "bucket": + bucketName, err := parseBucketResource(a, resParts[2:]) + if err != nil { + return nil, err + } + return OutpostBucketARN{ + ARN: a, + BucketName: bucketName, + OutpostID: outpostID, + }, nil + + default: + return nil, InvalidARNError{ARN: a, Reason: "unknown resource set for outpost ARN"} + } +} + +// OutpostAccessPointARN represents outpost access point ARN. +type OutpostAccessPointARN struct { + AccessPointARN + OutpostID string +} + +// GetOutpostID returns the outpost id of outpost access point arn +func (o OutpostAccessPointARN) GetOutpostID() string { + return o.OutpostID +} + +// OutpostBucketARN represents the outpost bucket ARN. +type OutpostBucketARN struct { + arn.ARN + BucketName string + OutpostID string +} + +// GetOutpostID returns the outpost id of outpost bucket arn +func (o OutpostBucketARN) GetOutpostID() string { + return o.OutpostID +} + +// GetARN retrives the base ARN from outpost bucket ARN resource +func (o OutpostBucketARN) GetARN() arn.ARN { + return o.ARN +} + +// parseBucketResource attempts to parse the ARN's bucket resource and retrieve the +// bucket resource id. +// +// parseBucketResource only parses the bucket resource id. +func parseBucketResource(a arn.ARN, resParts []string) (bucketName string, err error) { + if len(resParts) == 0 { + return bucketName, InvalidARNError{ARN: a, Reason: "bucket resource-id not set"} + } + if len(resParts) > 1 { + return bucketName, InvalidARNError{ARN: a, Reason: "sub resource not supported"} + } + + bucketName = strings.TrimSpace(resParts[0]) + if len(bucketName) == 0 { + return bucketName, InvalidARNError{ARN: a, Reason: "bucket resource-id not set"} + } + return bucketName, err +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/outpost_arn_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/outpost_arn_test.go new file mode 100644 index 0000000000..b21d4bd5d7 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/outpost_arn_test.go @@ -0,0 +1,291 @@ +package arn + +import ( + "reflect" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws/arn" +) + +func TestParseOutpostAccessPointARNResource(t *testing.T) { + cases := map[string]struct { + ARN arn.ARN + ExpectErr string + ExpectARN OutpostAccessPointARN + }{ + "region not set": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + AccountID: "012345678901", + Resource: "outpost/myoutpost/accesspoint/myendpoint", + }, + ExpectErr: "region not set", + }, + "account-id not set": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + Region: "us-west-2", + Resource: "outpost/myoutpost/accesspoint/myendpoint", + }, + ExpectErr: "account-id not set", + }, + "resource-id not set": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "myoutpost", + }, + ExpectErr: "resource-id not set", + }, + "resource-id empty": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "outpost:", + }, + ExpectErr: "resource-id not set", + }, + "resource not supported": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "outpost/myoutpost/accesspoint/endpoint/object/key", + }, + ExpectErr: "sub resource not supported", + }, + "access-point not defined": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "outpost/myoutpost/endpoint/object/key", + }, + ExpectErr: "unknown resource set for outpost ARN", + }, + "valid resource-id": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "outpost/myoutpost/accesspoint/myaccesspoint", + }, + ExpectARN: OutpostAccessPointARN{ + AccessPointARN: AccessPointARN{ + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "outpost/myoutpost/accesspoint/myaccesspoint", + }, + AccessPointName: "myaccesspoint", + }, + OutpostID: "myoutpost", + }, + }, + "invalid FIPS pseudo region in ARN (prefix)": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + Region: "fips-us-west-2", + AccountID: "012345678901", + Resource: "outpost/myoutpost/accesspoint/myendpoint", + }, + ExpectErr: "FIPS region not allowed in ARN", + }, + "invalid FIPS pseudo region in ARN (suffix)": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + Region: "us-west-2-fips", + AccountID: "012345678901", + Resource: "outpost/myoutpost/accesspoint/myendpoint", + }, + ExpectErr: "FIPS region not allowed in ARN", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + resParts := SplitResource(c.ARN.Resource) + a, err := ParseOutpostARNResource(c.ARN, resParts[1:]) + + if len(c.ExpectErr) == 0 && err != nil { + t.Fatalf("expect no error but got %v", err) + } else if len(c.ExpectErr) != 0 && err == nil { + t.Fatalf("expect error %q, but got nil", c.ExpectErr) + } else if len(c.ExpectErr) != 0 && err != nil { + if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error %q, got %q", e, a) + } + return + } + + if e, a := c.ExpectARN, a; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } + }) + } +} + +func TestParseOutpostBucketARNResource(t *testing.T) { + cases := map[string]struct { + ARN arn.ARN + ExpectErr string + ExpectARN OutpostBucketARN + }{ + "region not set": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + AccountID: "012345678901", + Resource: "outpost/myoutpost/bucket/mybucket", + }, + ExpectErr: "region not set", + }, + "resource-id empty": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "outpost:", + }, + ExpectErr: "resource-id not set", + }, + "resource not supported": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "outpost/myoutpost/bucket/mybucket/object/key", + }, + ExpectErr: "sub resource not supported", + }, + "bucket not defined": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "outpost/myoutpost/endpoint/object/key", + }, + ExpectErr: "unknown resource set for outpost ARN", + }, + "valid resource-id": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "outpost/myoutpost/bucket/mybucket", + }, + ExpectARN: OutpostBucketARN{ + ARN: arn.ARN{ + Partition: "aws", + Service: "s3-outposts", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "outpost/myoutpost/bucket/mybucket", + }, + BucketName: "mybucket", + OutpostID: "myoutpost", + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + resParts := SplitResource(c.ARN.Resource) + a, err := ParseOutpostARNResource(c.ARN, resParts[1:]) + + if len(c.ExpectErr) == 0 && err != nil { + t.Fatalf("expect no error but got %v", err) + } else if len(c.ExpectErr) != 0 && err == nil { + t.Fatalf("expect error %q, but got nil", c.ExpectErr) + } else if len(c.ExpectErr) != 0 && err != nil { + if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error %q, got %q", e, a) + } + return + } + + if e, a := c.ExpectARN, a; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } + }) + } +} + +func TestParseBucketResource(t *testing.T) { + cases := map[string]struct { + ARN arn.ARN + ExpectErr string + ExpectBucketName string + }{ + "resource-id empty": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "bucket:", + }, + ExpectErr: "bucket resource-id not set", + }, + "resource not supported": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "bucket/mybucket/object/key", + }, + ExpectErr: "sub resource not supported", + }, + "valid resource-id": { + ARN: arn.ARN{ + Partition: "aws", + Service: "s3", + Region: "us-west-2", + AccountID: "012345678901", + Resource: "bucket/mybucket", + }, + ExpectBucketName: "mybucket", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + resParts := SplitResource(c.ARN.Resource) + a, err := parseBucketResource(c.ARN, resParts[1:]) + + if len(c.ExpectErr) == 0 && err != nil { + t.Fatalf("expect no error but got %v", err) + } else if len(c.ExpectErr) != 0 && err == nil { + t.Fatalf("expect error %q, but got nil", c.ExpectErr) + } else if len(c.ExpectErr) != 0 && err != nil { + if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error %q, got %q", e, a) + } + return + } + + if e, a := c.ExpectBucketName, a; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } + }) + } +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/s3_object_lambda_arn.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/s3_object_lambda_arn.go new file mode 100644 index 0000000000..513154cc0e --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/s3_object_lambda_arn.go @@ -0,0 +1,15 @@ +package arn + +// S3ObjectLambdaARN represents an ARN for the s3-object-lambda service +type S3ObjectLambdaARN interface { + Resource + + isS3ObjectLambdasARN() +} + +// S3ObjectLambdaAccessPointARN is an S3ObjectLambdaARN for the Access Point resource type +type S3ObjectLambdaAccessPointARN struct { + AccessPointARN +} + +func (s S3ObjectLambdaAccessPointARN) isS3ObjectLambdasARN() {} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/ya.make new file mode 100644 index 0000000000..5e1a74cbc2 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn/ya.make @@ -0,0 +1,23 @@ +GO_LIBRARY() + +LICENSE(Apache-2.0) + +SRCS( + accesspoint_arn.go + arn.go + arn_member.go + outpost_arn.go + s3_object_lambda_arn.go +) + +GO_TEST_SRCS( + accesspoint_arn_test.go + arn_test.go + outpost_arn_test.go +) + +END() + +RECURSE( + gotest +) diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn_lookup.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn_lookup.go new file mode 100644 index 0000000000..b51532085f --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn_lookup.go @@ -0,0 +1,73 @@ +package s3shared + +import ( + "context" + "fmt" + + "github.com/aws/smithy-go/middleware" + + "github.com/aws/aws-sdk-go-v2/aws/arn" +) + +// ARNLookup is the initial middleware that looks up if an arn is provided. +// This middleware is responsible for fetching ARN from a arnable field, and registering the ARN on +// middleware context. This middleware must be executed before input validation step or any other +// arn processing middleware. +type ARNLookup struct { + + // GetARNValue takes in a input interface and returns a ptr to string and a bool + GetARNValue func(interface{}) (*string, bool) +} + +// ID for the middleware +func (m *ARNLookup) ID() string { + return "S3Shared:ARNLookup" +} + +// HandleInitialize handles the behavior of this initialize step +func (m *ARNLookup) HandleInitialize(ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, +) { + // check if GetARNValue is supported + if m.GetARNValue == nil { + return next.HandleInitialize(ctx, in) + } + + // check is input resource is an ARN; if not go to next + v, ok := m.GetARNValue(in.Parameters) + if !ok || v == nil || !arn.IsARN(*v) { + return next.HandleInitialize(ctx, in) + } + + // if ARN process ResourceRequest and put it on ctx + av, err := arn.Parse(*v) + if err != nil { + return out, metadata, fmt.Errorf("error parsing arn: %w", err) + } + // set parsed arn on context + ctx = setARNResourceOnContext(ctx, av) + + return next.HandleInitialize(ctx, in) +} + +// arnResourceKey is the key set on context used to identify, retrive an ARN resource +// if present on the context. +type arnResourceKey struct{} + +// SetARNResourceOnContext sets the S3 ARN on the context. +// +// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues +// to clear all stack values. +func setARNResourceOnContext(ctx context.Context, value arn.ARN) context.Context { + return middleware.WithStackValue(ctx, arnResourceKey{}, value) +} + +// GetARNResourceFromContext returns an ARN from context and a bool indicating +// presence of ARN on ctx. +// +// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues +// to clear all stack values. +func GetARNResourceFromContext(ctx context.Context) (arn.ARN, bool) { + v, ok := middleware.GetStackValue(ctx, arnResourceKey{}).(arn.ARN) + return v, ok +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/config/config.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/config/config.go new file mode 100644 index 0000000000..b5d31f5c57 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/config/config.go @@ -0,0 +1,41 @@ +package config + +import "context" + +// UseARNRegionProvider is an interface for retrieving external configuration value for UseARNRegion +type UseARNRegionProvider interface { + GetS3UseARNRegion(ctx context.Context) (value bool, found bool, err error) +} + +// DisableMultiRegionAccessPointsProvider is an interface for retrieving external configuration value for DisableMultiRegionAccessPoints +type DisableMultiRegionAccessPointsProvider interface { + GetS3DisableMultiRegionAccessPoints(ctx context.Context) (value bool, found bool, err error) +} + +// ResolveUseARNRegion extracts the first instance of a UseARNRegion from the config slice. +// Additionally returns a boolean to indicate if the value was found in provided configs, and error if one is encountered. +func ResolveUseARNRegion(ctx context.Context, configs []interface{}) (value bool, found bool, err error) { + for _, cfg := range configs { + if p, ok := cfg.(UseARNRegionProvider); ok { + value, found, err = p.GetS3UseARNRegion(ctx) + if err != nil || found { + break + } + } + } + return +} + +// ResolveDisableMultiRegionAccessPoints extracts the first instance of a DisableMultiRegionAccessPoints from the config slice. +// Additionally returns a boolean to indicate if the value was found in provided configs, and error if one is encountered. +func ResolveDisableMultiRegionAccessPoints(ctx context.Context, configs []interface{}) (value bool, found bool, err error) { + for _, cfg := range configs { + if p, ok := cfg.(DisableMultiRegionAccessPointsProvider); ok { + value, found, err = p.GetS3DisableMultiRegionAccessPoints(ctx) + if err != nil || found { + break + } + } + } + return +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/config/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/config/ya.make new file mode 100644 index 0000000000..165270885d --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/config/ya.make @@ -0,0 +1,9 @@ +GO_LIBRARY() + +LICENSE(Apache-2.0) + +SRCS( + config.go +) + +END() diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/endpoint_error.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/endpoint_error.go new file mode 100644 index 0000000000..aa0c3714e2 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/endpoint_error.go @@ -0,0 +1,183 @@ +package s3shared + +import ( + "fmt" + + "github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn" +) + +// TODO: fix these error statements to be relevant to v2 sdk + +const ( + invalidARNErrorErrCode = "InvalidARNError" + configurationErrorErrCode = "ConfigurationError" +) + +// InvalidARNError denotes the error for Invalid ARN +type InvalidARNError struct { + message string + resource arn.Resource + origErr error +} + +// Error returns the InvalidARN error string +func (e InvalidARNError) Error() string { + var extra string + if e.resource != nil { + extra = "ARN: " + e.resource.String() + } + msg := invalidARNErrorErrCode + " : " + e.message + if extra != "" { + msg = msg + "\n\t" + extra + } + + return msg +} + +// OrigErr is the original error wrapped by Invalid ARN Error +func (e InvalidARNError) Unwrap() error { + return e.origErr +} + +// NewInvalidARNError denotes invalid arn error +func NewInvalidARNError(resource arn.Resource, err error) InvalidARNError { + return InvalidARNError{ + message: "invalid ARN", + origErr: err, + resource: resource, + } +} + +// NewInvalidARNWithUnsupportedPartitionError ARN not supported for the target partition +func NewInvalidARNWithUnsupportedPartitionError(resource arn.Resource, err error) InvalidARNError { + return InvalidARNError{ + message: "resource ARN not supported for the target ARN partition", + origErr: err, + resource: resource, + } +} + +// NewInvalidARNWithFIPSError ARN not supported for FIPS region +// +// Deprecated: FIPS will not appear in the ARN region component. +func NewInvalidARNWithFIPSError(resource arn.Resource, err error) InvalidARNError { + return InvalidARNError{ + message: "resource ARN not supported for FIPS region", + resource: resource, + origErr: err, + } +} + +// ConfigurationError is used to denote a client configuration error +type ConfigurationError struct { + message string + resource arn.Resource + clientPartitionID string + clientRegion string + origErr error +} + +// Error returns the Configuration error string +func (e ConfigurationError) Error() string { + extra := fmt.Sprintf("ARN: %s, client partition: %s, client region: %s", + e.resource, e.clientPartitionID, e.clientRegion) + + msg := configurationErrorErrCode + " : " + e.message + if extra != "" { + msg = msg + "\n\t" + extra + } + return msg +} + +// OrigErr is the original error wrapped by Configuration Error +func (e ConfigurationError) Unwrap() error { + return e.origErr +} + +// NewClientPartitionMismatchError stub +func NewClientPartitionMismatchError(resource arn.Resource, clientPartitionID, clientRegion string, err error) ConfigurationError { + return ConfigurationError{ + message: "client partition does not match provided ARN partition", + origErr: err, + resource: resource, + clientPartitionID: clientPartitionID, + clientRegion: clientRegion, + } +} + +// NewClientRegionMismatchError denotes cross region access error +func NewClientRegionMismatchError(resource arn.Resource, clientPartitionID, clientRegion string, err error) ConfigurationError { + return ConfigurationError{ + message: "client region does not match provided ARN region", + origErr: err, + resource: resource, + clientPartitionID: clientPartitionID, + clientRegion: clientRegion, + } +} + +// NewFailedToResolveEndpointError denotes endpoint resolving error +func NewFailedToResolveEndpointError(resource arn.Resource, clientPartitionID, clientRegion string, err error) ConfigurationError { + return ConfigurationError{ + message: "endpoint resolver failed to find an endpoint for the provided ARN region", + origErr: err, + resource: resource, + clientPartitionID: clientPartitionID, + clientRegion: clientRegion, + } +} + +// NewClientConfiguredForFIPSError denotes client config error for unsupported cross region FIPS access +func NewClientConfiguredForFIPSError(resource arn.Resource, clientPartitionID, clientRegion string, err error) ConfigurationError { + return ConfigurationError{ + message: "client configured for fips but cross-region resource ARN provided", + origErr: err, + resource: resource, + clientPartitionID: clientPartitionID, + clientRegion: clientRegion, + } +} + +// NewFIPSConfigurationError denotes a configuration error when a client or request is configured for FIPS +func NewFIPSConfigurationError(resource arn.Resource, clientPartitionID, clientRegion string, err error) ConfigurationError { + return ConfigurationError{ + message: "use of ARN is not supported when client or request is configured for FIPS", + origErr: err, + resource: resource, + clientPartitionID: clientPartitionID, + clientRegion: clientRegion, + } +} + +// NewClientConfiguredForAccelerateError denotes client config error for unsupported S3 accelerate +func NewClientConfiguredForAccelerateError(resource arn.Resource, clientPartitionID, clientRegion string, err error) ConfigurationError { + return ConfigurationError{ + message: "client configured for S3 Accelerate but is not supported with resource ARN", + origErr: err, + resource: resource, + clientPartitionID: clientPartitionID, + clientRegion: clientRegion, + } +} + +// NewClientConfiguredForCrossRegionFIPSError denotes client config error for unsupported cross region FIPS request +func NewClientConfiguredForCrossRegionFIPSError(resource arn.Resource, clientPartitionID, clientRegion string, err error) ConfigurationError { + return ConfigurationError{ + message: "client configured for FIPS with cross-region enabled but is supported with cross-region resource ARN", + origErr: err, + resource: resource, + clientPartitionID: clientPartitionID, + clientRegion: clientRegion, + } +} + +// NewClientConfiguredForDualStackError denotes client config error for unsupported S3 Dual-stack +func NewClientConfiguredForDualStackError(resource arn.Resource, clientPartitionID, clientRegion string, err error) ConfigurationError { + return ConfigurationError{ + message: "client configured for S3 Dual-stack but is not supported with resource ARN", + origErr: err, + resource: resource, + clientPartitionID: clientPartitionID, + clientRegion: clientRegion, + } +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/go_module_metadata.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/go_module_metadata.go new file mode 100644 index 0000000000..fc7e409600 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/go_module_metadata.go @@ -0,0 +1,6 @@ +// Code generated by internal/repotools/cmd/updatemodulemeta DO NOT EDIT. + +package s3shared + +// goModuleVersion is the tagged release for this module +const goModuleVersion = "1.16.0" diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/gotest/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/gotest/ya.make new file mode 100644 index 0000000000..1ceac56026 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/gotest/ya.make @@ -0,0 +1,5 @@ +GO_TEST_FOR(vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared) + +LICENSE(Apache-2.0) + +END() diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/host_id.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/host_id.go new file mode 100644 index 0000000000..85b60d2a1b --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/host_id.go @@ -0,0 +1,29 @@ +package s3shared + +import ( + "github.com/aws/smithy-go/middleware" +) + +// hostID is used to retrieve host id from response metadata +type hostID struct { +} + +// SetHostIDMetadata sets the provided host id over middleware metadata +func SetHostIDMetadata(metadata *middleware.Metadata, id string) { + metadata.Set(hostID{}, id) +} + +// GetHostIDMetadata retrieves the host id from middleware metadata +// returns host id as string along with a boolean indicating presence of +// hostId on middleware metadata. +func GetHostIDMetadata(metadata middleware.Metadata) (string, bool) { + if !metadata.Has(hostID{}) { + return "", false + } + + v, ok := metadata.Get(hostID{}).(string) + if !ok { + return "", true + } + return v, true +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/metadata.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/metadata.go new file mode 100644 index 0000000000..f02604cb62 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/metadata.go @@ -0,0 +1,28 @@ +package s3shared + +import ( + "context" + + "github.com/aws/smithy-go/middleware" +) + +// clonedInputKey used to denote if request input was cloned. +type clonedInputKey struct{} + +// SetClonedInputKey sets a key on context to denote input was cloned previously. +// +// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues +// to clear all stack values. +func SetClonedInputKey(ctx context.Context, value bool) context.Context { + return middleware.WithStackValue(ctx, clonedInputKey{}, value) +} + +// IsClonedInput retrieves if context key for cloned input was set. +// If set, we can infer that the reuqest input was cloned previously. +// +// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues +// to clear all stack values. +func IsClonedInput(ctx context.Context) bool { + v, _ := middleware.GetStackValue(ctx, clonedInputKey{}).(bool) + return v +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/metadata_retriever.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/metadata_retriever.go new file mode 100644 index 0000000000..f52f2f11e9 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/metadata_retriever.go @@ -0,0 +1,52 @@ +package s3shared + +import ( + "context" + + awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +const metadataRetrieverID = "S3MetadataRetriever" + +// AddMetadataRetrieverMiddleware adds request id, host id retriever middleware +func AddMetadataRetrieverMiddleware(stack *middleware.Stack) error { + // add metadata retriever middleware before operation deserializers so that it can retrieve metadata such as + // host id, request id from response header returned by operation deserializers + return stack.Deserialize.Insert(&metadataRetriever{}, "OperationDeserializer", middleware.Before) +} + +type metadataRetriever struct { +} + +// ID returns the middleware identifier +func (m *metadataRetriever) ID() string { + return metadataRetrieverID +} + +func (m *metadataRetriever) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) ( + out middleware.DeserializeOutput, metadata middleware.Metadata, err error, +) { + out, metadata, err = next.HandleDeserialize(ctx, in) + + resp, ok := out.RawResponse.(*smithyhttp.Response) + if !ok { + // No raw response to wrap with. + return out, metadata, err + } + + // check for header for Request id + if v := resp.Header.Get("X-Amz-Request-Id"); len(v) != 0 { + // set reqID on metadata for successful responses. + awsmiddleware.SetRequestIDMetadata(&metadata, v) + } + + // look up host-id + if v := resp.Header.Get("X-Amz-Id-2"); len(v) != 0 { + // set reqID on metadata for successful responses. + SetHostIDMetadata(&metadata, v) + } + + return out, metadata, err +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/resource_request.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/resource_request.go new file mode 100644 index 0000000000..bee8da3fe3 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/resource_request.go @@ -0,0 +1,77 @@ +package s3shared + +import ( + "fmt" + "strings" + + awsarn "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn" +) + +// ResourceRequest represents an ARN resource and api request metadata +type ResourceRequest struct { + Resource arn.Resource + // RequestRegion is the region configured on the request config + RequestRegion string + + // SigningRegion is the signing region resolved for the request + SigningRegion string + + // PartitionID is the resolved partition id for the provided request region + PartitionID string + + // UseARNRegion indicates if client should use the region provided in an ARN resource + UseARNRegion bool + + // UseFIPS indicates if te client is configured for FIPS + UseFIPS bool +} + +// ARN returns the resource ARN +func (r ResourceRequest) ARN() awsarn.ARN { + return r.Resource.GetARN() +} + +// ResourceConfiguredForFIPS returns true if resource ARNs region is FIPS +// +// Deprecated: FIPS will not be present in the ARN region +func (r ResourceRequest) ResourceConfiguredForFIPS() bool { + return IsFIPS(r.ARN().Region) +} + +// AllowCrossRegion returns a bool value to denote if S3UseARNRegion flag is set +func (r ResourceRequest) AllowCrossRegion() bool { + return r.UseARNRegion +} + +// IsCrossPartition returns true if request is configured for region of another partition, than +// the partition that resource ARN region resolves to. IsCrossPartition will not return an error, +// if request is not configured with a specific partition id. This might happen if customer provides +// custom endpoint url, but does not associate a partition id with it. +func (r ResourceRequest) IsCrossPartition() (bool, error) { + rv := r.PartitionID + if len(rv) == 0 { + return false, nil + } + + av := r.Resource.GetARN().Partition + if len(av) == 0 { + return false, fmt.Errorf("no partition id for provided ARN") + } + + return !strings.EqualFold(rv, av), nil +} + +// IsCrossRegion returns true if request signing region is not same as arn region +func (r ResourceRequest) IsCrossRegion() bool { + v := r.SigningRegion + return !strings.EqualFold(v, r.Resource.GetARN().Region) +} + +// IsFIPS returns true if region is a fips pseudo-region +// +// Deprecated: FIPS should be specified via EndpointOptions. +func IsFIPS(region string) bool { + return strings.HasPrefix(region, "fips-") || + strings.HasSuffix(region, "-fips") +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/response_error.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/response_error.go new file mode 100644 index 0000000000..8573362430 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/response_error.go @@ -0,0 +1,33 @@ +package s3shared + +import ( + "errors" + "fmt" + + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" +) + +// ResponseError provides the HTTP centric error type wrapping the underlying error +// with the HTTP response value and the deserialized RequestID. +type ResponseError struct { + *awshttp.ResponseError + + // HostID associated with response error + HostID string +} + +// ServiceHostID returns the host id associated with Response Error +func (e *ResponseError) ServiceHostID() string { return e.HostID } + +// Error returns the formatted error +func (e *ResponseError) Error() string { + return fmt.Sprintf( + "https response error StatusCode: %d, RequestID: %s, HostID: %s, %v", + e.Response.StatusCode, e.RequestID, e.HostID, e.Err) +} + +// As populates target and returns true if the type of target is a error type that +// the ResponseError embeds, (e.g.S3 HTTP ResponseError) +func (e *ResponseError) As(target interface{}) bool { + return errors.As(e.ResponseError, target) +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/response_error_middleware.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/response_error_middleware.go new file mode 100644 index 0000000000..5435762450 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/response_error_middleware.go @@ -0,0 +1,60 @@ +package s3shared + +import ( + "context" + + awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +// AddResponseErrorMiddleware adds response error wrapper middleware +func AddResponseErrorMiddleware(stack *middleware.Stack) error { + // add error wrapper middleware before request id retriever middleware so that it can wrap the error response + // returned by operation deserializers + return stack.Deserialize.Insert(&errorWrapper{}, metadataRetrieverID, middleware.Before) +} + +type errorWrapper struct { +} + +// ID returns the middleware identifier +func (m *errorWrapper) ID() string { + return "ResponseErrorWrapper" +} + +func (m *errorWrapper) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) ( + out middleware.DeserializeOutput, metadata middleware.Metadata, err error, +) { + out, metadata, err = next.HandleDeserialize(ctx, in) + if err == nil { + // Nothing to do when there is no error. + return out, metadata, err + } + + resp, ok := out.RawResponse.(*smithyhttp.Response) + if !ok { + // No raw response to wrap with. + return out, metadata, err + } + + // look for request id in metadata + reqID, _ := awsmiddleware.GetRequestIDMetadata(metadata) + // look for host id in metadata + hostID, _ := GetHostIDMetadata(metadata) + + // Wrap the returned smithy error with the request id retrieved from the metadata + err = &ResponseError{ + ResponseError: &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: resp, + Err: err, + }, + RequestID: reqID, + }, + HostID: hostID, + } + + return out, metadata, err +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/s3100continue.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/s3100continue.go new file mode 100644 index 0000000000..0f43ec0d4f --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/s3100continue.go @@ -0,0 +1,54 @@ +package s3shared + +import ( + "context" + "fmt" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +const s3100ContinueID = "S3100Continue" +const default100ContinueThresholdBytes int64 = 1024 * 1024 * 2 + +// Add100Continue add middleware, which adds {Expect: 100-continue} header for s3 client HTTP PUT request larger than 2MB +// or with unknown size streaming bodies, during operation builder step +func Add100Continue(stack *middleware.Stack, continueHeaderThresholdBytes int64) error { + return stack.Build.Add(&s3100Continue{ + continueHeaderThresholdBytes: continueHeaderThresholdBytes, + }, middleware.After) +} + +type s3100Continue struct { + continueHeaderThresholdBytes int64 +} + +// ID returns the middleware identifier +func (m *s3100Continue) ID() string { + return s3100ContinueID +} + +func (m *s3100Continue) HandleBuild( + ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler, +) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + sizeLimit := default100ContinueThresholdBytes + switch { + case m.continueHeaderThresholdBytes == -1: + return next.HandleBuild(ctx, in) + case m.continueHeaderThresholdBytes > 0: + sizeLimit = m.continueHeaderThresholdBytes + default: + } + + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, fmt.Errorf("unknown request type %T", req) + } + + if req.ContentLength == -1 || (req.ContentLength == 0 && req.Body != nil) || req.ContentLength >= sizeLimit { + req.Header.Set("Expect", "100-continue") + } + + return next.HandleBuild(ctx, in) +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/s3100continue_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/s3100continue_test.go new file mode 100644 index 0000000000..db815c30bd --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/s3100continue_test.go @@ -0,0 +1,96 @@ +package s3shared + +import ( + "context" + "github.com/aws/aws-sdk-go-v2/internal/awstesting" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "testing" +) + +// unit test for service/internal/s3shared/s3100continue.go +func TestAdd100ContinueHttpHeader(t *testing.T) { + const HeaderKey = "Expect" + HeaderValue := "100-continue" + + cases := map[string]struct { + ContentLength int64 + Body *awstesting.ReadCloser + ExpectValueFound string + ContinueHeaderThresholdBytes int64 + }{ + "http request smaller than default 2MB": { + ContentLength: 1, + Body: &awstesting.ReadCloser{Size: 1}, + ExpectValueFound: "", + }, + "http request smaller than configured threshold": { + ContentLength: 1024 * 1024 * 2, + Body: &awstesting.ReadCloser{Size: 1024 * 1024 * 2}, + ExpectValueFound: "", + ContinueHeaderThresholdBytes: 1024 * 1024 * 3, + }, + "http request larger than default 2MB": { + ContentLength: 1024 * 1024 * 3, + Body: &awstesting.ReadCloser{Size: 1024 * 1024 * 3}, + ExpectValueFound: HeaderValue, + }, + "http request larger than configured threshold": { + ContentLength: 1024 * 1024 * 4, + Body: &awstesting.ReadCloser{Size: 1024 * 1024 * 4}, + ExpectValueFound: HeaderValue, + ContinueHeaderThresholdBytes: 1024 * 1024 * 3, + }, + "http put request with unknown -1 ContentLength": { + ContentLength: -1, + Body: &awstesting.ReadCloser{Size: 1024 * 1024 * 10}, + ExpectValueFound: HeaderValue, + }, + "http put request with 0 ContentLength but unknown non-nil body": { + ContentLength: 0, + Body: &awstesting.ReadCloser{Size: 1024 * 1024 * 3}, + ExpectValueFound: HeaderValue, + }, + "http put request with unknown -1 ContentLength and configured threshold": { + ContentLength: -1, + Body: &awstesting.ReadCloser{Size: 1024 * 1024 * 3}, + ExpectValueFound: HeaderValue, + ContinueHeaderThresholdBytes: 1024 * 1024 * 10, + }, + "http put request with continue header disabled": { + ContentLength: 1024 * 1024 * 3, + Body: &awstesting.ReadCloser{Size: 1024 * 1024 * 3}, + ExpectValueFound: "", + ContinueHeaderThresholdBytes: -1, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var err error + req := smithyhttp.NewStackRequest().(*smithyhttp.Request) + + req.ContentLength = c.ContentLength + req.Body = c.Body + var updatedRequest *smithyhttp.Request + m := s3100Continue{ + continueHeaderThresholdBytes: c.ContinueHeaderThresholdBytes, + } + _, _, err = m.HandleBuild(context.Background(), + middleware.BuildInput{Request: req}, + middleware.BuildHandlerFunc(func(ctx context.Context, input middleware.BuildInput) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error) { + updatedRequest = input.Request.(*smithyhttp.Request) + return out, metadata, nil + }), + ) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if e, a := c.ExpectValueFound, updatedRequest.Header.Get(HeaderKey); e != a { + t.Errorf("expect header value %v found, got %v", e, a) + } + }) + } +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/update_endpoint.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/update_endpoint.go new file mode 100644 index 0000000000..22773199f6 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/update_endpoint.go @@ -0,0 +1,78 @@ +package s3shared + +import ( + "context" + "fmt" + "strings" + + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + + awsmiddle "github.com/aws/aws-sdk-go-v2/aws/middleware" +) + +// EnableDualstack represents middleware struct for enabling dualstack support +// +// Deprecated: See EndpointResolverOptions' UseDualStackEndpoint support +type EnableDualstack struct { + // UseDualstack indicates if dualstack endpoint resolving is to be enabled + UseDualstack bool + + // DefaultServiceID is the service id prefix used in endpoint resolving + // by default service-id is 's3' and 's3-control' for service s3, s3control. + DefaultServiceID string +} + +// ID returns the middleware ID. +func (*EnableDualstack) ID() string { + return "EnableDualstack" +} + +// HandleSerialize handles serializer middleware behavior when middleware is executed +func (u *EnableDualstack) HandleSerialize( + ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler, +) ( + out middleware.SerializeOutput, metadata middleware.Metadata, err error, +) { + + // check for host name immutable property + if smithyhttp.GetHostnameImmutable(ctx) { + return next.HandleSerialize(ctx, in) + } + + serviceID := awsmiddle.GetServiceID(ctx) + + // s3-control may be represented as `S3 Control` as in model + if serviceID == "S3 Control" { + serviceID = "s3-control" + } + + if len(serviceID) == 0 { + // default service id + serviceID = u.DefaultServiceID + } + + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, fmt.Errorf("unknown request type %T", req) + } + + if u.UseDualstack { + parts := strings.Split(req.URL.Host, ".") + if len(parts) < 3 { + return out, metadata, fmt.Errorf("unable to update endpoint host for dualstack, hostname invalid, %s", req.URL.Host) + } + + for i := 0; i+1 < len(parts); i++ { + if strings.EqualFold(parts[i], serviceID) { + parts[i] = parts[i] + ".dualstack" + break + } + } + + // construct the url host + req.URL.Host = strings.Join(parts, ".") + } + + return next.HandleSerialize(ctx, in) +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/xml_utils.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/xml_utils.go new file mode 100644 index 0000000000..65fd07e000 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/xml_utils.go @@ -0,0 +1,89 @@ +package s3shared + +import ( + "encoding/xml" + "fmt" + "io" + "net/http" + "strings" +) + +// ErrorComponents represents the error response fields +// that will be deserialized from an xml error response body +type ErrorComponents struct { + Code string `xml:"Code"` + Message string `xml:"Message"` + RequestID string `xml:"RequestId"` + HostID string `xml:"HostId"` +} + +// GetUnwrappedErrorResponseComponents returns the error fields from an xml error response body +func GetUnwrappedErrorResponseComponents(r io.Reader) (ErrorComponents, error) { + var errComponents ErrorComponents + if err := xml.NewDecoder(r).Decode(&errComponents); err != nil && err != io.EOF { + return ErrorComponents{}, fmt.Errorf("error while deserializing xml error response : %w", err) + } + return errComponents, nil +} + +// GetWrappedErrorResponseComponents returns the error fields from an xml error response body +// in which error code, and message are wrapped by a <Error> tag +func GetWrappedErrorResponseComponents(r io.Reader) (ErrorComponents, error) { + var errComponents struct { + Code string `xml:"Error>Code"` + Message string `xml:"Error>Message"` + RequestID string `xml:"RequestId"` + HostID string `xml:"HostId"` + } + + if err := xml.NewDecoder(r).Decode(&errComponents); err != nil && err != io.EOF { + return ErrorComponents{}, fmt.Errorf("error while deserializing xml error response : %w", err) + } + + return ErrorComponents{ + Code: errComponents.Code, + Message: errComponents.Message, + RequestID: errComponents.RequestID, + HostID: errComponents.HostID, + }, nil +} + +// GetErrorResponseComponents retrieves error components according to passed in options +func GetErrorResponseComponents(r io.Reader, options ErrorResponseDeserializerOptions) (ErrorComponents, error) { + var errComponents ErrorComponents + var err error + + if options.IsWrappedWithErrorTag { + errComponents, err = GetWrappedErrorResponseComponents(r) + } else { + errComponents, err = GetUnwrappedErrorResponseComponents(r) + } + + if err != nil { + return ErrorComponents{}, err + } + + // If an error code or message is not retrieved, it is derived from the http status code + // eg, for S3 service, we derive err code and message, if none is found + if options.UseStatusCode && len(errComponents.Code) == 0 && + len(errComponents.Message) == 0 { + // derive code and message from status code + statusText := http.StatusText(options.StatusCode) + errComponents.Code = strings.Replace(statusText, " ", "", -1) + errComponents.Message = statusText + } + return errComponents, nil +} + +// ErrorResponseDeserializerOptions represents error response deserializer options for s3 and s3-control service +type ErrorResponseDeserializerOptions struct { + // UseStatusCode denotes if status code should be used to retrieve error code, msg + UseStatusCode bool + + // StatusCode is status code of error response + StatusCode int + + //IsWrappedWithErrorTag represents if error response's code, msg is wrapped within an + // additional <Error> tag + IsWrappedWithErrorTag bool +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/xml_utils_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/xml_utils_test.go new file mode 100644 index 0000000000..6d7efaabf9 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/xml_utils_test.go @@ -0,0 +1,102 @@ +package s3shared + +import ( + "strings" + "testing" +) + +func TestGetResponseErrorCode(t *testing.T) { + const xmlErrorResponse = `<Error> + <Type>Sender</Type> + <Code>InvalidGreeting</Code> + <Message>Hi</Message> + <HostId>bar-id</HostId> + <RequestId>foo-id</RequestId> +</Error>` + + const wrappedXMLErrorResponse = `<ErrorResponse><Error> + <Type>Sender</Type> + <Code>InvalidGreeting</Code> + <Message>Hi</Message> +</Error> + <HostId>bar-id</HostId> + <RequestId>foo-id</RequestId> +</ErrorResponse>` + + cases := map[string]struct { + getErr func() (ErrorComponents, error) + expectedErrorCode string + expectedErrorMessage string + expectedErrorRequestID string + expectedErrorHostID string + }{ + "standard xml error": { + getErr: func() (ErrorComponents, error) { + errResp := strings.NewReader(xmlErrorResponse) + return GetErrorResponseComponents(errResp, ErrorResponseDeserializerOptions{ + UseStatusCode: false, + StatusCode: 0, + IsWrappedWithErrorTag: false, + }) + }, + expectedErrorCode: "InvalidGreeting", + expectedErrorMessage: "Hi", + expectedErrorRequestID: "foo-id", + expectedErrorHostID: "bar-id", + }, + + "s3 no response body": { + getErr: func() (ErrorComponents, error) { + errResp := strings.NewReader("") + return GetErrorResponseComponents(errResp, ErrorResponseDeserializerOptions{ + UseStatusCode: true, + StatusCode: 400, + }) + }, + expectedErrorCode: "BadRequest", + expectedErrorMessage: "Bad Request", + }, + "s3control no response body": { + getErr: func() (ErrorComponents, error) { + errResp := strings.NewReader("") + return GetErrorResponseComponents(errResp, ErrorResponseDeserializerOptions{ + IsWrappedWithErrorTag: true, + }) + }, + }, + "s3control standard response body": { + getErr: func() (ErrorComponents, error) { + errResp := strings.NewReader(wrappedXMLErrorResponse) + return GetErrorResponseComponents(errResp, ErrorResponseDeserializerOptions{ + IsWrappedWithErrorTag: true, + }) + }, + expectedErrorCode: "InvalidGreeting", + expectedErrorMessage: "Hi", + expectedErrorRequestID: "foo-id", + expectedErrorHostID: "bar-id", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + ec, err := c.getErr() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if e, a := c.expectedErrorCode, ec.Code; !strings.EqualFold(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + if e, a := c.expectedErrorMessage, ec.Message; !strings.EqualFold(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + if e, a := c.expectedErrorRequestID, ec.RequestID; !strings.EqualFold(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + if e, a := c.expectedErrorHostID, ec.HostID; !strings.EqualFold(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + }) + } +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/ya.make new file mode 100644 index 0000000000..512ad83e48 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/s3shared/ya.make @@ -0,0 +1,31 @@ +GO_LIBRARY() + +LICENSE(Apache-2.0) + +SRCS( + arn_lookup.go + endpoint_error.go + go_module_metadata.go + host_id.go + metadata.go + metadata_retriever.go + resource_request.go + response_error.go + response_error_middleware.go + s3100continue.go + update_endpoint.go + xml_utils.go +) + +GO_TEST_SRCS( + s3100continue_test.go + xml_utils_test.go +) + +END() + +RECURSE( + arn + config + gotest +) |