aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum
diff options
context:
space:
mode:
authorvitalyisaev <vitalyisaev@ydb.tech>2023-12-12 21:55:07 +0300
committervitalyisaev <vitalyisaev@ydb.tech>2023-12-12 22:25:10 +0300
commit4967f99474a4040ba150eb04995de06342252718 (patch)
treec9c118836513a8fab6e9fcfb25be5d404338bca7 /vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum
parent2ce9cccb9b0bdd4cd7a3491dc5cbf8687cda51de (diff)
downloadydb-4967f99474a4040ba150eb04995de06342252718.tar.gz
YQ Connector: prepare code base for S3 integration
1. Кодовая база Коннектора переписана с помощью Go дженериков так, чтобы добавление нового источника данных (в частности S3 + csv) максимально переиспользовало имеющийся код (чтобы сохранялась логика нарезания на блоки данных, учёт трафика и пр.) 2. API Connector расширено для работы с S3, но ещё пока не протестировано.
Diffstat (limited to 'vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum')
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms.go323
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms_test.go470
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding.go389
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding_test.go507
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/go_module_metadata.go6
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/gotest/ya.make5
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add.go185
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add_test.go412
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum.go479
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum_test.go958
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context.go117
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context_test.go143
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output.go131
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output_test.go263
-rw-r--r--vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/ya.make28
15 files changed, 4416 insertions, 0 deletions
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms.go
new file mode 100644
index 0000000000..a17041c35d
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms.go
@@ -0,0 +1,323 @@
+package checksum
+
+import (
+ "crypto/md5"
+ "crypto/sha1"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/hex"
+ "fmt"
+ "hash"
+ "hash/crc32"
+ "io"
+ "strings"
+ "sync"
+)
+
+// Algorithm represents the checksum algorithms supported
+type Algorithm string
+
+// Enumeration values for supported checksum Algorithms.
+const (
+ // AlgorithmCRC32C represents CRC32C hash algorithm
+ AlgorithmCRC32C Algorithm = "CRC32C"
+
+ // AlgorithmCRC32 represents CRC32 hash algorithm
+ AlgorithmCRC32 Algorithm = "CRC32"
+
+ // AlgorithmSHA1 represents SHA1 hash algorithm
+ AlgorithmSHA1 Algorithm = "SHA1"
+
+ // AlgorithmSHA256 represents SHA256 hash algorithm
+ AlgorithmSHA256 Algorithm = "SHA256"
+)
+
+var supportedAlgorithms = []Algorithm{
+ AlgorithmCRC32C,
+ AlgorithmCRC32,
+ AlgorithmSHA1,
+ AlgorithmSHA256,
+}
+
+func (a Algorithm) String() string { return string(a) }
+
+// ParseAlgorithm attempts to parse the provided value into a checksum
+// algorithm, matching without case. Returns the algorithm matched, or an error
+// if the algorithm wasn't matched.
+func ParseAlgorithm(v string) (Algorithm, error) {
+ for _, a := range supportedAlgorithms {
+ if strings.EqualFold(string(a), v) {
+ return a, nil
+ }
+ }
+ return "", fmt.Errorf("unknown checksum algorithm, %v", v)
+}
+
+// FilterSupportedAlgorithms filters the set of algorithms, returning a slice
+// of algorithms that are supported.
+func FilterSupportedAlgorithms(vs []string) []Algorithm {
+ found := map[Algorithm]struct{}{}
+
+ supported := make([]Algorithm, 0, len(supportedAlgorithms))
+ for _, v := range vs {
+ for _, a := range supportedAlgorithms {
+ // Only consider algorithms that are supported
+ if !strings.EqualFold(v, string(a)) {
+ continue
+ }
+ // Ignore duplicate algorithms in list.
+ if _, ok := found[a]; ok {
+ continue
+ }
+
+ supported = append(supported, a)
+ found[a] = struct{}{}
+ }
+ }
+ return supported
+}
+
+// NewAlgorithmHash returns a hash.Hash for the checksum algorithm. Error is
+// returned if the algorithm is unknown.
+func NewAlgorithmHash(v Algorithm) (hash.Hash, error) {
+ switch v {
+ case AlgorithmSHA1:
+ return sha1.New(), nil
+ case AlgorithmSHA256:
+ return sha256.New(), nil
+ case AlgorithmCRC32:
+ return crc32.NewIEEE(), nil
+ case AlgorithmCRC32C:
+ return crc32.New(crc32.MakeTable(crc32.Castagnoli)), nil
+ default:
+ return nil, fmt.Errorf("unknown checksum algorithm, %v", v)
+ }
+}
+
+// AlgorithmChecksumLength returns the length of the algorithm's checksum in
+// bytes. If the algorithm is not known, an error is returned.
+func AlgorithmChecksumLength(v Algorithm) (int, error) {
+ switch v {
+ case AlgorithmSHA1:
+ return sha1.Size, nil
+ case AlgorithmSHA256:
+ return sha256.Size, nil
+ case AlgorithmCRC32:
+ return crc32.Size, nil
+ case AlgorithmCRC32C:
+ return crc32.Size, nil
+ default:
+ return 0, fmt.Errorf("unknown checksum algorithm, %v", v)
+ }
+}
+
+const awsChecksumHeaderPrefix = "x-amz-checksum-"
+
+// AlgorithmHTTPHeader returns the HTTP header for the algorithm's hash.
+func AlgorithmHTTPHeader(v Algorithm) string {
+ return awsChecksumHeaderPrefix + strings.ToLower(string(v))
+}
+
+// base64EncodeHashSum computes base64 encoded checksum of a given running
+// hash. The running hash must already have content written to it. Returns the
+// byte slice of checksum and an error
+func base64EncodeHashSum(h hash.Hash) []byte {
+ sum := h.Sum(nil)
+ sum64 := make([]byte, base64.StdEncoding.EncodedLen(len(sum)))
+ base64.StdEncoding.Encode(sum64, sum)
+ return sum64
+}
+
+// hexEncodeHashSum computes hex encoded checksum of a given running hash. The
+// running hash must already have content written to it. Returns the byte slice
+// of checksum and an error
+func hexEncodeHashSum(h hash.Hash) []byte {
+ sum := h.Sum(nil)
+ sumHex := make([]byte, hex.EncodedLen(len(sum)))
+ hex.Encode(sumHex, sum)
+ return sumHex
+}
+
+// computeMD5Checksum computes base64 MD5 checksum of an io.Reader's contents.
+// Returns the byte slice of MD5 checksum and an error.
+func computeMD5Checksum(r io.Reader) ([]byte, error) {
+ h := md5.New()
+
+ // Copy errors may be assumed to be from the body.
+ if _, err := io.Copy(h, r); err != nil {
+ return nil, fmt.Errorf("failed compute MD5 hash of reader, %w", err)
+ }
+
+ // Encode the MD5 checksum in base64.
+ return base64EncodeHashSum(h), nil
+}
+
+// computeChecksumReader provides a reader wrapping an underlying io.Reader to
+// compute the checksum of the stream's bytes.
+type computeChecksumReader struct {
+ stream io.Reader
+ algorithm Algorithm
+ hasher hash.Hash
+ base64ChecksumLen int
+
+ mux sync.RWMutex
+ lockedChecksum string
+ lockedErr error
+}
+
+// newComputeChecksumReader returns a computeChecksumReader for the stream and
+// algorithm specified. Returns error if unable to create the reader, or
+// algorithm is unknown.
+func newComputeChecksumReader(stream io.Reader, algorithm Algorithm) (*computeChecksumReader, error) {
+ hasher, err := NewAlgorithmHash(algorithm)
+ if err != nil {
+ return nil, err
+ }
+
+ checksumLength, err := AlgorithmChecksumLength(algorithm)
+ if err != nil {
+ return nil, err
+ }
+
+ return &computeChecksumReader{
+ stream: io.TeeReader(stream, hasher),
+ algorithm: algorithm,
+ hasher: hasher,
+ base64ChecksumLen: base64.StdEncoding.EncodedLen(checksumLength),
+ }, nil
+}
+
+// Read wraps the underlying reader. When the underlying reader returns EOF,
+// the checksum of the reader will be computed, and can be retrieved with
+// ChecksumBase64String.
+func (r *computeChecksumReader) Read(p []byte) (int, error) {
+ n, err := r.stream.Read(p)
+ if err == nil {
+ return n, nil
+ } else if err != io.EOF {
+ r.mux.Lock()
+ defer r.mux.Unlock()
+
+ r.lockedErr = err
+ return n, err
+ }
+
+ b := base64EncodeHashSum(r.hasher)
+
+ r.mux.Lock()
+ defer r.mux.Unlock()
+
+ r.lockedChecksum = string(b)
+
+ return n, err
+}
+
+func (r *computeChecksumReader) Algorithm() Algorithm {
+ return r.algorithm
+}
+
+// Base64ChecksumLength returns the base64 encoded length of the checksum for
+// algorithm.
+func (r *computeChecksumReader) Base64ChecksumLength() int {
+ return r.base64ChecksumLen
+}
+
+// Base64Checksum returns the base64 checksum for the algorithm, or error if
+// the underlying reader returned a non-EOF error.
+//
+// Safe to be called concurrently, but will return an error until after the
+// underlying reader is returns EOF.
+func (r *computeChecksumReader) Base64Checksum() (string, error) {
+ r.mux.RLock()
+ defer r.mux.RUnlock()
+
+ if r.lockedErr != nil {
+ return "", r.lockedErr
+ }
+
+ if r.lockedChecksum == "" {
+ return "", fmt.Errorf(
+ "checksum not available yet, called before reader returns EOF",
+ )
+ }
+
+ return r.lockedChecksum, nil
+}
+
+// validateChecksumReader implements io.ReadCloser interface. The wrapper
+// performs checksum validation when the underlying reader has been fully read.
+type validateChecksumReader struct {
+ originalBody io.ReadCloser
+ body io.Reader
+ hasher hash.Hash
+ algorithm Algorithm
+ expectChecksum string
+}
+
+// newValidateChecksumReader returns a configured io.ReadCloser that performs
+// checksum validation when the underlying reader has been fully read.
+func newValidateChecksumReader(
+ body io.ReadCloser,
+ algorithm Algorithm,
+ expectChecksum string,
+) (*validateChecksumReader, error) {
+ hasher, err := NewAlgorithmHash(algorithm)
+ if err != nil {
+ return nil, err
+ }
+
+ return &validateChecksumReader{
+ originalBody: body,
+ body: io.TeeReader(body, hasher),
+ hasher: hasher,
+ algorithm: algorithm,
+ expectChecksum: expectChecksum,
+ }, nil
+}
+
+// Read attempts to read from the underlying stream while also updating the
+// running hash. If the underlying stream returns with an EOF error, the
+// checksum of the stream will be collected, and compared against the expected
+// checksum. If the checksums do not match, an error will be returned.
+//
+// If a non-EOF error occurs when reading the underlying stream, that error
+// will be returned and the checksum for the stream will be discarded.
+func (c *validateChecksumReader) Read(p []byte) (n int, err error) {
+ n, err = c.body.Read(p)
+ if err == io.EOF {
+ if checksumErr := c.validateChecksum(); checksumErr != nil {
+ return n, checksumErr
+ }
+ }
+
+ return n, err
+}
+
+// Close closes the underlying reader, returning any error that occurred in the
+// underlying reader.
+func (c *validateChecksumReader) Close() (err error) {
+ return c.originalBody.Close()
+}
+
+func (c *validateChecksumReader) validateChecksum() error {
+ // Compute base64 encoded checksum hash of the payload's read bytes.
+ v := base64EncodeHashSum(c.hasher)
+ if e, a := c.expectChecksum, string(v); !strings.EqualFold(e, a) {
+ return validationError{
+ Algorithm: c.algorithm, Expect: e, Actual: a,
+ }
+ }
+
+ return nil
+}
+
+type validationError struct {
+ Algorithm Algorithm
+ Expect string
+ Actual string
+}
+
+func (v validationError) Error() string {
+ return fmt.Sprintf("checksum did not match: algorithm %v, expect %v, actual %v",
+ v.Algorithm, v.Expect, v.Actual)
+}
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms_test.go
new file mode 100644
index 0000000000..3f8b27018a
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms_test.go
@@ -0,0 +1,470 @@
+//go:build go1.16
+// +build go1.16
+
+package checksum
+
+import (
+ "bytes"
+ "crypto/sha1"
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+ "hash/crc32"
+ "io"
+ "io/ioutil"
+ "strings"
+ "testing"
+ "testing/iotest"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestComputeChecksumReader(t *testing.T) {
+ cases := map[string]struct {
+ Input io.Reader
+ Algorithm Algorithm
+ ExpectErr string
+ ExpectChecksumLen int
+
+ ExpectRead string
+ ExpectReadErr string
+ ExpectComputeErr string
+ ExpectChecksum string
+ }{
+ "unknown algorithm": {
+ Input: bytes.NewBuffer(nil),
+ Algorithm: Algorithm("something"),
+ ExpectErr: "unknown checksum algorithm",
+ },
+ "read error": {
+ Input: iotest.ErrReader(fmt.Errorf("some error")),
+ Algorithm: AlgorithmSHA256,
+ ExpectChecksumLen: base64.StdEncoding.EncodedLen(sha256.Size),
+ ExpectReadErr: "some error",
+ ExpectComputeErr: "some error",
+ },
+ "crc32c": {
+ Input: strings.NewReader("hello world"),
+ Algorithm: AlgorithmCRC32C,
+ ExpectChecksumLen: base64.StdEncoding.EncodedLen(crc32.Size),
+ ExpectRead: "hello world",
+ ExpectChecksum: "yZRlqg==",
+ },
+ "crc32": {
+ Input: strings.NewReader("hello world"),
+ Algorithm: AlgorithmCRC32,
+ ExpectChecksumLen: base64.StdEncoding.EncodedLen(crc32.Size),
+ ExpectRead: "hello world",
+ ExpectChecksum: "DUoRhQ==",
+ },
+ "sha1": {
+ Input: strings.NewReader("hello world"),
+ Algorithm: AlgorithmSHA1,
+ ExpectChecksumLen: base64.StdEncoding.EncodedLen(sha1.Size),
+ ExpectRead: "hello world",
+ ExpectChecksum: "Kq5sNclPz7QV2+lfQIuc6R7oRu0=",
+ },
+ "sha256": {
+ Input: strings.NewReader("hello world"),
+ Algorithm: AlgorithmSHA256,
+ ExpectChecksumLen: base64.StdEncoding.EncodedLen(sha256.Size),
+ ExpectRead: "hello world",
+ ExpectChecksum: "uU0nuZNNPgilLlLX2n2r+sSE7+N6U4DukIj3rOLvzek=",
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ // Validate reader can be created as expected.
+ r, err := newComputeChecksumReader(c.Input, c.Algorithm)
+ if err == nil && len(c.ExpectErr) != 0 {
+ t.Fatalf("expect error %v, got none", c.ExpectErr)
+ }
+ if err != nil && len(c.ExpectErr) == 0 {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.ExpectErr) {
+ t.Fatalf("expect error to contain %v, got %v", c.ExpectErr, err)
+ }
+ if c.ExpectErr != "" {
+ return
+ }
+
+ if e, a := c.Algorithm, r.Algorithm(); e != a {
+ t.Errorf("expect %v algorithm, got %v", e, a)
+ }
+
+ // Validate expected checksum length.
+ if e, a := c.ExpectChecksumLen, r.Base64ChecksumLength(); e != a {
+ t.Errorf("expect %v checksum length, got %v", e, a)
+ }
+
+ // Validate read reads underlying stream's bytes as expected.
+ b, err := ioutil.ReadAll(r)
+ if err == nil && len(c.ExpectReadErr) != 0 {
+ t.Fatalf("expect error %v, got none", c.ExpectReadErr)
+ }
+ if err != nil && len(c.ExpectReadErr) == 0 {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.ExpectReadErr) {
+ t.Fatalf("expect error to contain %v, got %v", c.ExpectReadErr, err)
+ }
+ if len(c.ExpectReadErr) != 0 {
+ return
+ }
+
+ if diff := cmp.Diff(string(c.ExpectRead), string(b)); diff != "" {
+ t.Errorf("expect read match, got\n%v", diff)
+ }
+
+ // validate computed base64
+ v, err := r.Base64Checksum()
+ if err == nil && len(c.ExpectComputeErr) != 0 {
+ t.Fatalf("expect error %v, got none", c.ExpectComputeErr)
+ }
+ if err != nil && len(c.ExpectComputeErr) == 0 {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.ExpectComputeErr) {
+ t.Fatalf("expect error to contain %v, got %v", c.ExpectComputeErr, err)
+ }
+ if diff := cmp.Diff(c.ExpectChecksum, v); diff != "" {
+ t.Errorf("expect checksum match, got\n%v", diff)
+ }
+ if c.ExpectComputeErr != "" {
+ return
+ }
+
+ if e, a := c.ExpectChecksumLen, len(v); e != a {
+ t.Errorf("expect computed checksum length %v, got %v", e, a)
+ }
+ })
+ }
+}
+
+func TestComputeChecksumReader_earlyGetChecksum(t *testing.T) {
+ r, err := newComputeChecksumReader(strings.NewReader("hello world"), AlgorithmCRC32C)
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+
+ v, err := r.Base64Checksum()
+ if err == nil {
+ t.Fatalf("expect error, got none")
+ }
+ if err != nil && !strings.Contains(err.Error(), "not available") {
+ t.Fatalf("expect error to match, got %v", err)
+ }
+ if v != "" {
+ t.Errorf("expect no checksum, got %v", err)
+ }
+}
+
+// TODO race condition case with many reads, and get checksum
+
+func TestValidateChecksumReader(t *testing.T) {
+ cases := map[string]struct {
+ payload io.ReadCloser
+ algorithm Algorithm
+ checksum string
+ expectErr string
+ expectChecksumErr string
+ expectedBody []byte
+ }{
+ "unknown algorithm": {
+ payload: ioutil.NopCloser(bytes.NewBuffer(nil)),
+ algorithm: Algorithm("something"),
+ expectErr: "unknown checksum algorithm",
+ },
+ "empty body": {
+ payload: ioutil.NopCloser(bytes.NewReader([]byte(""))),
+ algorithm: AlgorithmCRC32,
+ checksum: "AAAAAA==",
+ expectedBody: []byte(""),
+ },
+ "standard body": {
+ payload: ioutil.NopCloser(bytes.NewReader([]byte("Hello world"))),
+ algorithm: AlgorithmCRC32,
+ checksum: "i9aeUg==",
+ expectedBody: []byte("Hello world"),
+ },
+ "checksum mismatch": {
+ payload: ioutil.NopCloser(bytes.NewReader([]byte("Hello world"))),
+ algorithm: AlgorithmCRC32,
+ checksum: "AAAAAA==",
+ expectedBody: []byte("Hello world"),
+ expectChecksumErr: "checksum did not match",
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ response, err := newValidateChecksumReader(c.payload, c.algorithm, c.checksum)
+ if err == nil && len(c.expectErr) != 0 {
+ t.Fatalf("expect error %v, got none", c.expectErr)
+ }
+ if err != nil && len(c.expectErr) == 0 {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.expectErr) {
+ t.Fatalf("expect error to contain %v, got %v", c.expectErr, err)
+ }
+ if c.expectErr != "" {
+ return
+ }
+
+ if response == nil {
+ if c.expectedBody == nil {
+ return
+ }
+ t.Fatalf("expected non nil response, got nil")
+ }
+
+ actualResponse, err := ioutil.ReadAll(response)
+ if err == nil && len(c.expectChecksumErr) != 0 {
+ t.Fatalf("expected error %v, got none", c.expectChecksumErr)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.expectChecksumErr) {
+ t.Fatalf("expected error %v to contain %v", err.Error(), c.expectChecksumErr)
+ }
+
+ if diff := cmp.Diff(c.expectedBody, actualResponse); len(diff) != 0 {
+ t.Fatalf("found diff comparing response body %v", diff)
+ }
+
+ err = response.Close()
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ })
+ }
+}
+
+func TestComputeMD5Checksum(t *testing.T) {
+ cases := map[string]struct {
+ payload io.Reader
+ expectErr string
+ expectChecksum string
+ }{
+ "empty payload": {
+ payload: strings.NewReader(""),
+ expectChecksum: "1B2M2Y8AsgTpgAmY7PhCfg==",
+ },
+ "payload": {
+ payload: strings.NewReader("hello world"),
+ expectChecksum: "XrY7u+Ae7tCTyyK7j1rNww==",
+ },
+ "error payload": {
+ payload: iotest.ErrReader(fmt.Errorf("some error")),
+ expectErr: "some error",
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ actualChecksum, err := computeMD5Checksum(c.payload)
+ if err == nil && len(c.expectErr) != 0 {
+ t.Fatalf("expect error %v, got none", c.expectErr)
+ }
+ if err != nil && len(c.expectErr) == 0 {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.expectErr) {
+ t.Fatalf("expect error to contain %v, got %v", c.expectErr, err)
+ }
+ if c.expectErr != "" {
+ return
+ }
+
+ if e, a := c.expectChecksum, string(actualChecksum); !strings.EqualFold(e, a) {
+ t.Errorf("expect %v checksum, got %v", e, a)
+ }
+ })
+ }
+}
+
+func TestParseAlgorithm(t *testing.T) {
+ cases := map[string]struct {
+ Value string
+ expectAlgorithm Algorithm
+ expectErr string
+ }{
+ "crc32c": {
+ Value: "crc32c",
+ expectAlgorithm: AlgorithmCRC32C,
+ },
+ "CRC32C": {
+ Value: "CRC32C",
+ expectAlgorithm: AlgorithmCRC32C,
+ },
+ "crc32": {
+ Value: "crc32",
+ expectAlgorithm: AlgorithmCRC32,
+ },
+ "CRC32": {
+ Value: "CRC32",
+ expectAlgorithm: AlgorithmCRC32,
+ },
+ "sha1": {
+ Value: "sha1",
+ expectAlgorithm: AlgorithmSHA1,
+ },
+ "SHA1": {
+ Value: "SHA1",
+ expectAlgorithm: AlgorithmSHA1,
+ },
+ "sha256": {
+ Value: "sha256",
+ expectAlgorithm: AlgorithmSHA256,
+ },
+ "SHA256": {
+ Value: "SHA256",
+ expectAlgorithm: AlgorithmSHA256,
+ },
+ "empty": {
+ Value: "",
+ expectErr: "unknown checksum algorithm",
+ },
+ "unknown": {
+ Value: "unknown",
+ expectErr: "unknown checksum algorithm",
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ // Asserts
+ algorithm, err := ParseAlgorithm(c.Value)
+ if err == nil && len(c.expectErr) != 0 {
+ t.Fatalf("expect error %v, got none", c.expectErr)
+ }
+ if err != nil && len(c.expectErr) == 0 {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.expectErr) {
+ t.Fatalf("expect error to contain %v, got %v", c.expectErr, err)
+ }
+ if c.expectErr != "" {
+ return
+ }
+
+ if e, a := c.expectAlgorithm, algorithm; e != a {
+ t.Errorf("expect %v algorithm, got %v", e, a)
+ }
+ })
+ }
+}
+
+func TestFilterSupportedAlgorithms(t *testing.T) {
+ cases := map[string]struct {
+ values []string
+ expectAlgorithms []Algorithm
+ }{
+ "no algorithms": {
+ expectAlgorithms: []Algorithm{},
+ },
+ "no supported algorithms": {
+ values: []string{"abc", "123"},
+ expectAlgorithms: []Algorithm{},
+ },
+ "duplicate algorithms": {
+ values: []string{"crc32", "crc32c", "crc32c"},
+ expectAlgorithms: []Algorithm{
+ AlgorithmCRC32,
+ AlgorithmCRC32C,
+ },
+ },
+ "preserve order": {
+ values: []string{"crc32", "crc32c", "sha1", "sha256"},
+ expectAlgorithms: []Algorithm{
+ AlgorithmCRC32,
+ AlgorithmCRC32C,
+ AlgorithmSHA1,
+ AlgorithmSHA256,
+ },
+ },
+ "preserve order 2": {
+ values: []string{"sha256", "crc32", "sha1", "crc32c"},
+ expectAlgorithms: []Algorithm{
+ AlgorithmSHA256,
+ AlgorithmCRC32,
+ AlgorithmSHA1,
+ AlgorithmCRC32C,
+ },
+ },
+ "mixed case": {
+ values: []string{"Crc32", "cRc32c", "shA1", "sHA256"},
+ expectAlgorithms: []Algorithm{
+ AlgorithmCRC32,
+ AlgorithmCRC32C,
+ AlgorithmSHA1,
+ AlgorithmSHA256,
+ },
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ algorithms := FilterSupportedAlgorithms(c.values)
+ if diff := cmp.Diff(c.expectAlgorithms, algorithms); diff != "" {
+ t.Errorf("expect algorithms match\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestAlgorithmChecksumLength(t *testing.T) {
+ cases := map[string]struct {
+ algorithm Algorithm
+ expectErr string
+ expectLength int
+ }{
+ "empty": {
+ algorithm: "",
+ expectErr: "unknown checksum algorithm",
+ },
+ "unknown": {
+ algorithm: "",
+ expectErr: "unknown checksum algorithm",
+ },
+ "crc32": {
+ algorithm: AlgorithmCRC32,
+ expectLength: crc32.Size,
+ },
+ "crc32c": {
+ algorithm: AlgorithmCRC32C,
+ expectLength: crc32.Size,
+ },
+ "sha1": {
+ algorithm: AlgorithmSHA1,
+ expectLength: sha1.Size,
+ },
+ "sha256": {
+ algorithm: AlgorithmSHA256,
+ expectLength: sha256.Size,
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ l, err := AlgorithmChecksumLength(c.algorithm)
+ if err == nil && len(c.expectErr) != 0 {
+ t.Fatalf("expect error %v, got none", c.expectErr)
+ }
+ if err != nil && len(c.expectErr) == 0 {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.expectErr) {
+ t.Fatalf("expect error to contain %v, got %v", c.expectErr, err)
+ }
+ if c.expectErr != "" {
+ return
+ }
+
+ if e, a := c.expectLength, l; e != a {
+ t.Errorf("expect %v checksum length, got %v", e, a)
+ }
+ })
+ }
+}
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding.go
new file mode 100644
index 0000000000..3bd320c437
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding.go
@@ -0,0 +1,389 @@
+package checksum
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+)
+
+const (
+ crlf = "\r\n"
+
+ // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html
+ defaultChunkLength = 1024 * 64
+
+ awsTrailerHeaderName = "x-amz-trailer"
+ decodedContentLengthHeaderName = "x-amz-decoded-content-length"
+
+ contentEncodingHeaderName = "content-encoding"
+ awsChunkedContentEncodingHeaderValue = "aws-chunked"
+
+ trailerKeyValueSeparator = ":"
+)
+
+var (
+ crlfBytes = []byte(crlf)
+ finalChunkBytes = []byte("0" + crlf)
+)
+
+type awsChunkedEncodingOptions struct {
+ // The total size of the stream. For unsigned encoding this implies that
+ // there will only be a single chunk containing the underlying payload,
+ // unless ChunkLength is also specified.
+ StreamLength int64
+
+ // Set of trailer key:value pairs that will be appended to the end of the
+ // payload after the end chunk has been written.
+ Trailers map[string]awsChunkedTrailerValue
+
+ // The maximum size of each chunk to be sent. Default value of -1, signals
+ // that optimal chunk length will be used automatically. ChunkSize must be
+ // at least 8KB.
+ //
+ // If ChunkLength and StreamLength are both specified, the stream will be
+ // broken up into ChunkLength chunks. The encoded length of the aws-chunked
+ // encoding can still be determined as long as all trailers, if any, have a
+ // fixed length.
+ ChunkLength int
+}
+
+type awsChunkedTrailerValue struct {
+ // Function to retrieve the value of the trailer. Will only be called after
+ // the underlying stream returns EOF error.
+ Get func() (string, error)
+
+ // If the length of the value can be pre-determined, and is constant
+ // specify the length. A value of -1 means the length is unknown, or
+ // cannot be pre-determined.
+ Length int
+}
+
+// awsChunkedEncoding provides a reader that wraps the payload such that
+// payload is read as a single aws-chunk payload. This reader can only be used
+// if the content length of payload is known. Content-Length is used as size of
+// the single payload chunk. The final chunk and trailing checksum is appended
+// at the end.
+//
+// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html#sigv4-chunked-body-definition
+//
+// Here is the aws-chunked payload stream as read from the awsChunkedEncoding
+// if original request stream is "Hello world", and checksum hash used is SHA256
+//
+// <b>\r\n
+// Hello world\r\n
+// 0\r\n
+// x-amz-checksum-sha256:ZOyIygCyaOW6GjVnihtTFtIS9PNmskdyMlNKiuyjfzw=\r\n
+// \r\n
+type awsChunkedEncoding struct {
+ options awsChunkedEncodingOptions
+
+ encodedStream io.Reader
+ trailerEncodedLength int
+}
+
+// newUnsignedAWSChunkedEncoding returns a new awsChunkedEncoding configured
+// for unsigned aws-chunked content encoding. Any additional trailers that need
+// to be appended after the end chunk must be included as via Trailer
+// callbacks.
+func newUnsignedAWSChunkedEncoding(
+ stream io.Reader,
+ optFns ...func(*awsChunkedEncodingOptions),
+) *awsChunkedEncoding {
+ options := awsChunkedEncodingOptions{
+ Trailers: map[string]awsChunkedTrailerValue{},
+ StreamLength: -1,
+ ChunkLength: -1,
+ }
+ for _, fn := range optFns {
+ fn(&options)
+ }
+
+ var chunkReader io.Reader
+ if options.ChunkLength != -1 || options.StreamLength == -1 {
+ if options.ChunkLength == -1 {
+ options.ChunkLength = defaultChunkLength
+ }
+ chunkReader = newBufferedAWSChunkReader(stream, options.ChunkLength)
+ } else {
+ chunkReader = newUnsignedChunkReader(stream, options.StreamLength)
+ }
+
+ trailerReader := newAWSChunkedTrailerReader(options.Trailers)
+
+ return &awsChunkedEncoding{
+ options: options,
+ encodedStream: io.MultiReader(chunkReader,
+ trailerReader,
+ bytes.NewBuffer(crlfBytes),
+ ),
+ trailerEncodedLength: trailerReader.EncodedLength(),
+ }
+}
+
+// EncodedLength returns the final length of the aws-chunked content encoded
+// stream if it can be determined without reading the underlying stream or lazy
+// header values, otherwise -1 is returned.
+func (e *awsChunkedEncoding) EncodedLength() int64 {
+ var length int64
+ if e.options.StreamLength == -1 || e.trailerEncodedLength == -1 {
+ return -1
+ }
+
+ if e.options.StreamLength != 0 {
+ // If the stream length is known, and there is no chunk length specified,
+ // only a single chunk will be used. Otherwise the stream length needs to
+ // include the multiple chunk padding content.
+ if e.options.ChunkLength == -1 {
+ length += getUnsignedChunkBytesLength(e.options.StreamLength)
+
+ } else {
+ // Compute chunk header and payload length
+ numChunks := e.options.StreamLength / int64(e.options.ChunkLength)
+ length += numChunks * getUnsignedChunkBytesLength(int64(e.options.ChunkLength))
+ if remainder := e.options.StreamLength % int64(e.options.ChunkLength); remainder != 0 {
+ length += getUnsignedChunkBytesLength(remainder)
+ }
+ }
+ }
+
+ // End chunk
+ length += int64(len(finalChunkBytes))
+
+ // Trailers
+ length += int64(e.trailerEncodedLength)
+
+ // Encoding terminator
+ length += int64(len(crlf))
+
+ return length
+}
+
+func getUnsignedChunkBytesLength(payloadLength int64) int64 {
+ payloadLengthStr := strconv.FormatInt(payloadLength, 16)
+ return int64(len(payloadLengthStr)) + int64(len(crlf)) + payloadLength + int64(len(crlf))
+}
+
+// HTTPHeaders returns the set of headers that must be included the request for
+// aws-chunked to work. This includes the content-encoding: aws-chunked header.
+//
+// If there are multiple layered content encoding, the aws-chunked encoding
+// must be appended to the previous layers the stream's encoding. The best way
+// to do this is to append all header values returned to the HTTP request's set
+// of headers.
+func (e *awsChunkedEncoding) HTTPHeaders() map[string][]string {
+ headers := map[string][]string{
+ contentEncodingHeaderName: {
+ awsChunkedContentEncodingHeaderValue,
+ },
+ }
+
+ if len(e.options.Trailers) != 0 {
+ trailers := make([]string, 0, len(e.options.Trailers))
+ for name := range e.options.Trailers {
+ trailers = append(trailers, strings.ToLower(name))
+ }
+ headers[awsTrailerHeaderName] = trailers
+ }
+
+ return headers
+}
+
+func (e *awsChunkedEncoding) Read(b []byte) (n int, err error) {
+ return e.encodedStream.Read(b)
+}
+
+// awsChunkedTrailerReader provides a lazy reader for reading of aws-chunked
+// content encoded trailers. The trailer values will not be retrieved until the
+// reader is read from.
+type awsChunkedTrailerReader struct {
+ reader *bytes.Buffer
+ trailers map[string]awsChunkedTrailerValue
+ trailerEncodedLength int
+}
+
+// newAWSChunkedTrailerReader returns an initialized awsChunkedTrailerReader to
+// lazy reading aws-chunk content encoded trailers.
+func newAWSChunkedTrailerReader(trailers map[string]awsChunkedTrailerValue) *awsChunkedTrailerReader {
+ return &awsChunkedTrailerReader{
+ trailers: trailers,
+ trailerEncodedLength: trailerEncodedLength(trailers),
+ }
+}
+
+func trailerEncodedLength(trailers map[string]awsChunkedTrailerValue) (length int) {
+ for name, trailer := range trailers {
+ length += len(name) + len(trailerKeyValueSeparator)
+ l := trailer.Length
+ if l == -1 {
+ return -1
+ }
+ length += l + len(crlf)
+ }
+
+ return length
+}
+
+// EncodedLength returns the length of the encoded trailers if the length could
+// be determined without retrieving the header values. Returns -1 if length is
+// unknown.
+func (r *awsChunkedTrailerReader) EncodedLength() (length int) {
+ return r.trailerEncodedLength
+}
+
+// Read populates the passed in byte slice with bytes from the encoded
+// trailers. Will lazy read header values first time Read is called.
+func (r *awsChunkedTrailerReader) Read(p []byte) (int, error) {
+ if r.trailerEncodedLength == 0 {
+ return 0, io.EOF
+ }
+
+ if r.reader == nil {
+ trailerLen := r.trailerEncodedLength
+ if r.trailerEncodedLength == -1 {
+ trailerLen = 0
+ }
+ r.reader = bytes.NewBuffer(make([]byte, 0, trailerLen))
+ for name, trailer := range r.trailers {
+ r.reader.WriteString(name)
+ r.reader.WriteString(trailerKeyValueSeparator)
+ v, err := trailer.Get()
+ if err != nil {
+ return 0, fmt.Errorf("failed to get trailer value, %w", err)
+ }
+ r.reader.WriteString(v)
+ r.reader.WriteString(crlf)
+ }
+ }
+
+ return r.reader.Read(p)
+}
+
+// newUnsignedChunkReader returns an io.Reader encoding the underlying reader
+// as unsigned aws-chunked chunks. The returned reader will also include the
+// end chunk, but not the aws-chunked final `crlf` segment so trailers can be
+// added.
+//
+// If the payload size is -1 for unknown length the content will be buffered in
+// defaultChunkLength chunks before wrapped in aws-chunked chunk encoding.
+func newUnsignedChunkReader(reader io.Reader, payloadSize int64) io.Reader {
+ if payloadSize == -1 {
+ return newBufferedAWSChunkReader(reader, defaultChunkLength)
+ }
+
+ var endChunk bytes.Buffer
+ if payloadSize == 0 {
+ endChunk.Write(finalChunkBytes)
+ return &endChunk
+ }
+
+ endChunk.WriteString(crlf)
+ endChunk.Write(finalChunkBytes)
+
+ var header bytes.Buffer
+ header.WriteString(strconv.FormatInt(payloadSize, 16))
+ header.WriteString(crlf)
+ return io.MultiReader(
+ &header,
+ reader,
+ &endChunk,
+ )
+}
+
+// Provides a buffered aws-chunked chunk encoder of an underlying io.Reader.
+// Will include end chunk, but not the aws-chunked final `crlf` segment so
+// trailers can be added.
+//
+// Note does not implement support for chunk extensions, e.g. chunk signing.
+type bufferedAWSChunkReader struct {
+ reader io.Reader
+ chunkSize int
+ chunkSizeStr string
+
+ headerBuffer *bytes.Buffer
+ chunkBuffer *bytes.Buffer
+
+ multiReader io.Reader
+ multiReaderLen int
+ endChunkDone bool
+}
+
+// newBufferedAWSChunkReader returns an bufferedAWSChunkReader for reading
+// aws-chunked encoded chunks.
+func newBufferedAWSChunkReader(reader io.Reader, chunkSize int) *bufferedAWSChunkReader {
+ return &bufferedAWSChunkReader{
+ reader: reader,
+ chunkSize: chunkSize,
+ chunkSizeStr: strconv.FormatInt(int64(chunkSize), 16),
+
+ headerBuffer: bytes.NewBuffer(make([]byte, 0, 64)),
+ chunkBuffer: bytes.NewBuffer(make([]byte, 0, chunkSize+len(crlf))),
+ }
+}
+
+// Read attempts to read from the underlying io.Reader writing aws-chunked
+// chunk encoded bytes to p. When the underlying io.Reader has been completed
+// read the end chunk will be available. Once the end chunk is read, the reader
+// will return EOF.
+func (r *bufferedAWSChunkReader) Read(p []byte) (n int, err error) {
+ if r.multiReaderLen == 0 && r.endChunkDone {
+ return 0, io.EOF
+ }
+ if r.multiReader == nil || r.multiReaderLen == 0 {
+ r.multiReader, r.multiReaderLen, err = r.newMultiReader()
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ n, err = r.multiReader.Read(p)
+ r.multiReaderLen -= n
+
+ if err == io.EOF && !r.endChunkDone {
+ // Edge case handling when the multi-reader has been completely read,
+ // and returned an EOF, make sure that EOF only gets returned if the
+ // end chunk was included in the multi-reader. Otherwise, the next call
+ // to read will initialize the next chunk's multi-reader.
+ err = nil
+ }
+ return n, err
+}
+
+// newMultiReader returns a new io.Reader for wrapping the next chunk. Will
+// return an error if the underlying reader can not be read from. Will never
+// return io.EOF.
+func (r *bufferedAWSChunkReader) newMultiReader() (io.Reader, int, error) {
+ // io.Copy eats the io.EOF returned by io.LimitReader. Any error that
+ // occurs here is due to an actual read error.
+ n, err := io.Copy(r.chunkBuffer, io.LimitReader(r.reader, int64(r.chunkSize)))
+ if err != nil {
+ return nil, 0, err
+ }
+ if n == 0 {
+ // Early exit writing out only the end chunk. This does not include
+ // aws-chunk's final `crlf` so that trailers can still be added by
+ // upstream reader.
+ r.headerBuffer.Reset()
+ r.headerBuffer.WriteString("0")
+ r.headerBuffer.WriteString(crlf)
+ r.endChunkDone = true
+
+ return r.headerBuffer, r.headerBuffer.Len(), nil
+ }
+ r.chunkBuffer.WriteString(crlf)
+
+ chunkSizeStr := r.chunkSizeStr
+ if int(n) != r.chunkSize {
+ chunkSizeStr = strconv.FormatInt(n, 16)
+ }
+
+ r.headerBuffer.Reset()
+ r.headerBuffer.WriteString(chunkSizeStr)
+ r.headerBuffer.WriteString(crlf)
+
+ return io.MultiReader(
+ r.headerBuffer,
+ r.chunkBuffer,
+ ), r.headerBuffer.Len() + r.chunkBuffer.Len(), nil
+}
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding_test.go
new file mode 100644
index 0000000000..8e9ce3c8a9
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding_test.go
@@ -0,0 +1,507 @@
+//go:build go1.16
+// +build go1.16
+
+package checksum
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "strings"
+ "testing"
+ "testing/iotest"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestAWSChunkedEncoding(t *testing.T) {
+ cases := map[string]struct {
+ reader *awsChunkedEncoding
+ expectErr string
+ expectEncodedLength int64
+ expectHTTPHeaders map[string][]string
+ expectPayload []byte
+ }{
+ "empty payload fixed stream length": {
+ reader: newUnsignedAWSChunkedEncoding(strings.NewReader(""),
+ func(o *awsChunkedEncodingOptions) {
+ o.StreamLength = 0
+ }),
+ expectEncodedLength: 5,
+ expectHTTPHeaders: map[string][]string{
+ contentEncodingHeaderName: {
+ awsChunkedContentEncodingHeaderValue,
+ },
+ },
+ expectPayload: []byte("0\r\n\r\n"),
+ },
+ "empty payload unknown stream length": {
+ reader: newUnsignedAWSChunkedEncoding(strings.NewReader("")),
+ expectEncodedLength: -1,
+ expectHTTPHeaders: map[string][]string{
+ contentEncodingHeaderName: {
+ awsChunkedContentEncodingHeaderValue,
+ },
+ },
+ expectPayload: []byte("0\r\n\r\n"),
+ },
+ "payload fixed stream length": {
+ reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"),
+ func(o *awsChunkedEncodingOptions) {
+ o.StreamLength = 11
+ }),
+ expectEncodedLength: 21,
+ expectHTTPHeaders: map[string][]string{
+ contentEncodingHeaderName: {
+ awsChunkedContentEncodingHeaderValue,
+ },
+ },
+ expectPayload: []byte("b\r\nhello world\r\n0\r\n\r\n"),
+ },
+ "payload unknown stream length": {
+ reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world")),
+ expectEncodedLength: -1,
+ expectHTTPHeaders: map[string][]string{
+ contentEncodingHeaderName: {
+ awsChunkedContentEncodingHeaderValue,
+ },
+ },
+ expectPayload: []byte("b\r\nhello world\r\n0\r\n\r\n"),
+ },
+ "payload unknown stream length with chunk size": {
+ reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"),
+ func(o *awsChunkedEncodingOptions) {
+ o.ChunkLength = 8
+ }),
+ expectEncodedLength: -1,
+ expectHTTPHeaders: map[string][]string{
+ contentEncodingHeaderName: {
+ awsChunkedContentEncodingHeaderValue,
+ },
+ },
+ expectPayload: []byte("8\r\nhello wo\r\n3\r\nrld\r\n0\r\n\r\n"),
+ },
+ "payload fixed stream length with chunk size": {
+ reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"),
+ func(o *awsChunkedEncodingOptions) {
+ o.StreamLength = 11
+ o.ChunkLength = 8
+ }),
+ expectEncodedLength: 26,
+ expectHTTPHeaders: map[string][]string{
+ contentEncodingHeaderName: {
+ awsChunkedContentEncodingHeaderValue,
+ },
+ },
+ expectPayload: []byte("8\r\nhello wo\r\n3\r\nrld\r\n0\r\n\r\n"),
+ },
+ "payload fixed stream length with fixed length trailer": {
+ reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"),
+ func(o *awsChunkedEncodingOptions) {
+ o.StreamLength = 11
+ o.Trailers = map[string]awsChunkedTrailerValue{
+ "foo": {
+ Get: func() (string, error) {
+ return "abc123", nil
+ },
+ Length: 6,
+ },
+ }
+ }),
+ expectEncodedLength: 33,
+ expectHTTPHeaders: map[string][]string{
+ contentEncodingHeaderName: {
+ awsChunkedContentEncodingHeaderValue,
+ },
+ awsTrailerHeaderName: {"foo"},
+ },
+ expectPayload: []byte("b\r\nhello world\r\n0\r\nfoo:abc123\r\n\r\n"),
+ },
+ "payload fixed stream length with unknown length trailer": {
+ reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"),
+ func(o *awsChunkedEncodingOptions) {
+ o.StreamLength = 11
+ o.Trailers = map[string]awsChunkedTrailerValue{
+ "foo": {
+ Get: func() (string, error) {
+ return "abc123", nil
+ },
+ Length: -1,
+ },
+ }
+ }),
+ expectEncodedLength: -1,
+ expectHTTPHeaders: map[string][]string{
+ contentEncodingHeaderName: {
+ awsChunkedContentEncodingHeaderValue,
+ },
+ awsTrailerHeaderName: {"foo"},
+ },
+ expectPayload: []byte("b\r\nhello world\r\n0\r\nfoo:abc123\r\n\r\n"),
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ if e, a := c.expectEncodedLength, c.reader.EncodedLength(); e != a {
+ t.Errorf("expect %v encoded length, got %v", e, a)
+ }
+ if diff := cmp.Diff(c.expectHTTPHeaders, c.reader.HTTPHeaders()); diff != "" {
+ t.Errorf("expect HTTP headers match\n%v", diff)
+ }
+
+ actualPayload, err := ioutil.ReadAll(c.reader)
+ if err == nil && len(c.expectErr) != 0 {
+ t.Fatalf("expect error %v, got none", c.expectErr)
+ }
+ if err != nil && len(c.expectErr) == 0 {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.expectErr) {
+ t.Fatalf("expect error to contain %v, got %v", c.expectErr, err)
+ }
+ if c.expectErr != "" {
+ return
+ }
+
+ if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" {
+ t.Errorf("expect payload match\n%v", diff)
+ }
+ })
+ }
+}
+
+func TestUnsignedAWSChunkReader(t *testing.T) {
+ cases := map[string]struct {
+ payload interface {
+ io.Reader
+ Len() int
+ }
+
+ expectPayload []byte
+ expectErr string
+ }{
+ "empty body": {
+ payload: bytes.NewReader([]byte{}),
+ expectPayload: []byte("0\r\n"),
+ },
+ "with body": {
+ payload: strings.NewReader("Hello world"),
+ expectPayload: []byte("b\r\nHello world\r\n0\r\n"),
+ },
+ "large body": {
+ payload: bytes.NewBufferString("Hello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello world"),
+ expectPayload: []byte("205\r\nHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello world\r\n0\r\n"),
+ },
+ "reader error": {
+ payload: newLimitReadLener(iotest.ErrReader(fmt.Errorf("some read error")), 128),
+ expectErr: "some read error",
+ },
+ "unknown length reader": {
+ payload: newUnknownLenReader(io.LimitReader(byteReader('a'), defaultChunkLength*2)),
+ expectPayload: func() []byte {
+ reader := newBufferedAWSChunkReader(
+ io.LimitReader(byteReader('a'), defaultChunkLength*2),
+ defaultChunkLength,
+ )
+ actualPayload, err := ioutil.ReadAll(reader)
+ if err != nil {
+ t.Fatalf("failed to create unknown length reader test data, %v", err)
+ }
+ return actualPayload
+ }(),
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ reader := newUnsignedChunkReader(c.payload, int64(c.payload.Len()))
+
+ actualPayload, err := ioutil.ReadAll(reader)
+ if err == nil && len(c.expectErr) != 0 {
+ t.Fatalf("expect error %v, got none", c.expectErr)
+ }
+ if err != nil && len(c.expectErr) == 0 {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.expectErr) {
+ t.Fatalf("expect error to contain %v, got %v", c.expectErr, err)
+ }
+ if c.expectErr != "" {
+ return
+ }
+
+ if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" {
+ t.Errorf("expect payload match\n%v", diff)
+ }
+ })
+ }
+}
+
+func TestBufferedAWSChunkReader(t *testing.T) {
+ cases := map[string]struct {
+ payload io.Reader
+ readSize int
+ chunkSize int
+
+ expectPayload []byte
+ expectErr string
+ }{
+ "empty body": {
+ payload: bytes.NewReader([]byte{}),
+ chunkSize: 4,
+ expectPayload: []byte("0\r\n"),
+ },
+ "with one chunk body": {
+ payload: strings.NewReader("Hello world"),
+ chunkSize: 20,
+ expectPayload: []byte("b\r\nHello world\r\n0\r\n"),
+ },
+ "single byte read": {
+ payload: strings.NewReader("Hello world"),
+ chunkSize: 8,
+ readSize: 1,
+ expectPayload: []byte("8\r\nHello wo\r\n3\r\nrld\r\n0\r\n"),
+ },
+ "single chunk and byte read": {
+ payload: strings.NewReader("Hello world"),
+ chunkSize: 1,
+ readSize: 1,
+ expectPayload: []byte("1\r\nH\r\n1\r\ne\r\n1\r\nl\r\n1\r\nl\r\n1\r\no\r\n1\r\n \r\n1\r\nw\r\n1\r\no\r\n1\r\nr\r\n1\r\nl\r\n1\r\nd\r\n0\r\n"),
+ },
+ "with two chunk body": {
+ payload: strings.NewReader("Hello world"),
+ chunkSize: 8,
+ expectPayload: []byte("8\r\nHello wo\r\n3\r\nrld\r\n0\r\n"),
+ },
+ "chunk size equal to read size": {
+ payload: strings.NewReader("Hello world"),
+ chunkSize: 512,
+ expectPayload: []byte("b\r\nHello world\r\n0\r\n"),
+ },
+ "chunk size greater than read size": {
+ payload: strings.NewReader("Hello world"),
+ chunkSize: 1024,
+ expectPayload: []byte("b\r\nHello world\r\n0\r\n"),
+ },
+ "payload size more than default read size, chunk size less than read size": {
+ payload: bytes.NewBufferString("Hello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello world"),
+ chunkSize: 500,
+ expectPayload: []byte("1f4\r\nHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello\r\n11\r\n worldHello world\r\n0\r\n"),
+ },
+ "payload size more than default read size, chunk size equal to read size": {
+ payload: bytes.NewBufferString("Hello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello world"),
+ chunkSize: 512,
+ expectPayload: []byte("200\r\nHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello \r\n5\r\nworld\r\n0\r\n"),
+ },
+ "payload size more than default read size, chunk size more than read size": {
+ payload: bytes.NewBufferString("Hello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello world"),
+ chunkSize: 1024,
+ expectPayload: []byte("205\r\nHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " +
+ "worldHello worldHello worldHello worldHello worldHello world\r\n0\r\n"),
+ },
+ "reader error": {
+ payload: iotest.ErrReader(fmt.Errorf("some read error")),
+ chunkSize: 128,
+ expectErr: "some read error",
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ reader := newBufferedAWSChunkReader(c.payload, c.chunkSize)
+
+ var actualPayload []byte
+ var err error
+
+ if c.readSize != 0 {
+ for err == nil {
+ var n int
+ p := make([]byte, c.readSize)
+ n, err = reader.Read(p)
+ if n != 0 {
+ actualPayload = append(actualPayload, p[:n]...)
+ }
+ }
+ if err == io.EOF {
+ err = nil
+ }
+ } else {
+ actualPayload, err = ioutil.ReadAll(reader)
+ }
+
+ if err == nil && len(c.expectErr) != 0 {
+ t.Fatalf("expect error %v, got none", c.expectErr)
+ }
+ if err != nil && len(c.expectErr) == 0 {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.expectErr) {
+ t.Fatalf("expect error to contain %v, got %v", c.expectErr, err)
+ }
+ if c.expectErr != "" {
+ return
+ }
+
+ if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" {
+ t.Errorf("expect payload match\n%v", diff)
+ }
+ })
+ }
+}
+
+func TestAwsChunkedTrailerReader(t *testing.T) {
+ cases := map[string]struct {
+ reader *awsChunkedTrailerReader
+
+ expectErr string
+ expectEncodedLength int
+ expectPayload []byte
+ }{
+ "no trailers": {
+ reader: newAWSChunkedTrailerReader(nil),
+ expectPayload: []byte{},
+ },
+ "unknown length trailers": {
+ reader: newAWSChunkedTrailerReader(map[string]awsChunkedTrailerValue{
+ "foo": {
+ Get: func() (string, error) {
+ return "abc123", nil
+ },
+ Length: -1,
+ },
+ }),
+ expectEncodedLength: -1,
+ expectPayload: []byte("foo:abc123\r\n"),
+ },
+ "known length trailers": {
+ reader: newAWSChunkedTrailerReader(map[string]awsChunkedTrailerValue{
+ "foo": {
+ Get: func() (string, error) {
+ return "abc123", nil
+ },
+ Length: 6,
+ },
+ }),
+ expectEncodedLength: 12,
+ expectPayload: []byte("foo:abc123\r\n"),
+ },
+ "trailer error": {
+ reader: newAWSChunkedTrailerReader(map[string]awsChunkedTrailerValue{
+ "foo": {
+ Get: func() (string, error) {
+ return "", fmt.Errorf("some error")
+ },
+ Length: 6,
+ },
+ }),
+ expectEncodedLength: 12,
+ expectErr: "failed to get trailer",
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ if e, a := c.expectEncodedLength, c.reader.EncodedLength(); e != a {
+ t.Errorf("expect %v encoded length, got %v", e, a)
+ }
+
+ actualPayload, err := ioutil.ReadAll(c.reader)
+
+ // Asserts
+ if err == nil && len(c.expectErr) != 0 {
+ t.Fatalf("expect error %v, got none", c.expectErr)
+ }
+ if err != nil && len(c.expectErr) == 0 {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.expectErr) {
+ t.Fatalf("expect error to contain %v, got %v", c.expectErr, err)
+ }
+ if c.expectErr != "" {
+ return
+ }
+
+ if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" {
+ t.Errorf("expect payload match\n%v", diff)
+ }
+ })
+ }
+}
+
+type limitReadLener struct {
+ io.Reader
+ length int
+}
+
+func newLimitReadLener(r io.Reader, l int) *limitReadLener {
+ return &limitReadLener{
+ Reader: io.LimitReader(r, int64(l)),
+ length: l,
+ }
+}
+func (r *limitReadLener) Len() int {
+ return r.length
+}
+
+type unknownLenReader struct {
+ io.Reader
+}
+
+func newUnknownLenReader(r io.Reader) *unknownLenReader {
+ return &unknownLenReader{
+ Reader: r,
+ }
+}
+func (r *unknownLenReader) Len() int {
+ return -1
+}
+
+type byteReader byte
+
+func (r byteReader) Read(p []byte) (int, error) {
+ for i := 0; i < len(p); i++ {
+ p[i] = byte(r)
+ }
+ return len(p), nil
+}
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/go_module_metadata.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/go_module_metadata.go
new file mode 100644
index 0000000000..f591861505
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/go_module_metadata.go
@@ -0,0 +1,6 @@
+// Code generated by internal/repotools/cmd/updatemodulemeta DO NOT EDIT.
+
+package checksum
+
+// goModuleVersion is the tagged release for this module
+const goModuleVersion = "1.2.0"
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/gotest/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/gotest/ya.make
new file mode 100644
index 0000000000..2a3d936a06
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/gotest/ya.make
@@ -0,0 +1,5 @@
+GO_TEST_FOR(vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum)
+
+LICENSE(Apache-2.0)
+
+END()
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add.go
new file mode 100644
index 0000000000..3e17d2216b
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add.go
@@ -0,0 +1,185 @@
+package checksum
+
+import (
+ "github.com/aws/smithy-go/middleware"
+ smithyhttp "github.com/aws/smithy-go/transport/http"
+)
+
+// InputMiddlewareOptions provides the options for the request
+// checksum middleware setup.
+type InputMiddlewareOptions struct {
+ // GetAlgorithm is a function to get the checksum algorithm of the
+ // input payload from the input parameters.
+ //
+ // Given the input parameter value, the function must return the algorithm
+ // and true, or false if no algorithm is specified.
+ GetAlgorithm func(interface{}) (string, bool)
+
+ // Forces the middleware to compute the input payload's checksum. The
+ // request will fail if the algorithm is not specified or unable to compute
+ // the checksum.
+ RequireChecksum bool
+
+ // Enables support for wrapping the serialized input payload with a
+ // content-encoding: aws-check wrapper, and including a trailer for the
+ // algorithm's checksum value.
+ //
+ // The checksum will not be computed, nor added as trailing checksum, if
+ // the Algorithm's header is already set on the request.
+ EnableTrailingChecksum bool
+
+ // Enables support for computing the SHA256 checksum of input payloads
+ // along with the algorithm specified checksum. Prevents downstream
+ // middleware handlers (computePayloadSHA256) re-reading the payload.
+ //
+ // The SHA256 payload checksum will only be used for computed for requests
+ // that are not TLS, or do not enable trailing checksums.
+ //
+ // The SHA256 payload hash will not be computed, if the Algorithm's header
+ // is already set on the request.
+ EnableComputeSHA256PayloadHash bool
+
+ // Enables support for setting the aws-chunked decoded content length
+ // header for the decoded length of the underlying stream. Will only be set
+ // when used with trailing checksums, and aws-chunked content-encoding.
+ EnableDecodedContentLengthHeader bool
+}
+
+// AddInputMiddleware adds the middleware for performing checksum computing
+// of request payloads, and checksum validation of response payloads.
+func AddInputMiddleware(stack *middleware.Stack, options InputMiddlewareOptions) (err error) {
+ // TODO ensure this works correctly with presigned URLs
+
+ // Middleware stack:
+ // * (OK)(Initialize) --none--
+ // * (OK)(Serialize) EndpointResolver
+ // * (OK)(Build) ComputeContentLength
+ // * (AD)(Build) Header ComputeInputPayloadChecksum
+ // * SIGNED Payload - If HTTP && not support trailing checksum
+ // * UNSIGNED Payload - If HTTPS && not support trailing checksum
+ // * (RM)(Build) ContentChecksum - OK to remove
+ // * (OK)(Build) ComputePayloadHash
+ // * v4.dynamicPayloadSigningMiddleware
+ // * v4.computePayloadSHA256
+ // * v4.unsignedPayload
+ // (OK)(Build) Set computedPayloadHash header
+ // * (OK)(Finalize) Retry
+ // * (AD)(Finalize) Trailer ComputeInputPayloadChecksum,
+ // * Requires HTTPS && support trailing checksum
+ // * UNSIGNED Payload
+ // * Finalize run if HTTPS && support trailing checksum
+ // * (OK)(Finalize) Signing
+ // * (OK)(Deserialize) --none--
+
+ // Initial checksum configuration look up middleware
+ err = stack.Initialize.Add(&setupInputContext{
+ GetAlgorithm: options.GetAlgorithm,
+ }, middleware.Before)
+ if err != nil {
+ return err
+ }
+
+ stack.Build.Remove("ContentChecksum")
+
+ // Create the compute checksum middleware that will be added as both a
+ // build and finalize handler.
+ inputChecksum := &computeInputPayloadChecksum{
+ RequireChecksum: options.RequireChecksum,
+ EnableTrailingChecksum: options.EnableTrailingChecksum,
+ EnableComputePayloadHash: options.EnableComputeSHA256PayloadHash,
+ EnableDecodedContentLengthHeader: options.EnableDecodedContentLengthHeader,
+ }
+
+ // Insert header checksum after ComputeContentLength middleware, must also
+ // be before the computePayloadHash middleware handlers.
+ err = stack.Build.Insert(inputChecksum,
+ (*smithyhttp.ComputeContentLength)(nil).ID(),
+ middleware.After)
+ if err != nil {
+ return err
+ }
+
+ // If trailing checksum is not supported no need for finalize handler to be added.
+ if options.EnableTrailingChecksum {
+ err = stack.Finalize.Insert(inputChecksum, "Retry", middleware.After)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// RemoveInputMiddleware Removes the compute input payload checksum middleware
+// handlers from the stack.
+func RemoveInputMiddleware(stack *middleware.Stack) {
+ id := (*setupInputContext)(nil).ID()
+ stack.Initialize.Remove(id)
+
+ id = (*computeInputPayloadChecksum)(nil).ID()
+ stack.Build.Remove(id)
+ stack.Finalize.Remove(id)
+}
+
+// OutputMiddlewareOptions provides options for configuring output checksum
+// validation middleware.
+type OutputMiddlewareOptions struct {
+ // GetValidationMode is a function to get the checksum validation
+ // mode of the output payload from the input parameters.
+ //
+ // Given the input parameter value, the function must return the validation
+ // mode and true, or false if no mode is specified.
+ GetValidationMode func(interface{}) (string, bool)
+
+ // The set of checksum algorithms that should be used for response payload
+ // checksum validation. The algorithm(s) used will be a union of the
+ // output's returned algorithms and this set.
+ //
+ // Only the first algorithm in the union is currently used.
+ ValidationAlgorithms []string
+
+ // If set the middleware will ignore output multipart checksums. Otherwise
+ // an checksum format error will be returned by the middleware.
+ IgnoreMultipartValidation bool
+
+ // When set the middleware will log when output does not have checksum or
+ // algorithm to validate.
+ LogValidationSkipped bool
+
+ // When set the middleware will log when the output contains a multipart
+ // checksum that was, skipped and not validated.
+ LogMultipartValidationSkipped bool
+}
+
+// AddOutputMiddleware adds the middleware for validating response payload's
+// checksum.
+func AddOutputMiddleware(stack *middleware.Stack, options OutputMiddlewareOptions) error {
+ err := stack.Initialize.Add(&setupOutputContext{
+ GetValidationMode: options.GetValidationMode,
+ }, middleware.Before)
+ if err != nil {
+ return err
+ }
+
+ // Resolve a supported priority order list of algorithms to validate.
+ algorithms := FilterSupportedAlgorithms(options.ValidationAlgorithms)
+
+ m := &validateOutputPayloadChecksum{
+ Algorithms: algorithms,
+ IgnoreMultipartValidation: options.IgnoreMultipartValidation,
+ LogMultipartValidationSkipped: options.LogMultipartValidationSkipped,
+ LogValidationSkipped: options.LogValidationSkipped,
+ }
+
+ return stack.Deserialize.Add(m, middleware.After)
+}
+
+// RemoveOutputMiddleware Removes the compute input payload checksum middleware
+// handlers from the stack.
+func RemoveOutputMiddleware(stack *middleware.Stack) {
+ id := (*setupOutputContext)(nil).ID()
+ stack.Initialize.Remove(id)
+
+ id = (*validateOutputPayloadChecksum)(nil).ID()
+ stack.Deserialize.Remove(id)
+}
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add_test.go
new file mode 100644
index 0000000000..33f952ebe9
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add_test.go
@@ -0,0 +1,412 @@
+//go:build go1.16
+// +build go1.16
+
+package checksum
+
+import (
+ "context"
+ "testing"
+
+ "github.com/aws/smithy-go/middleware"
+ smithyhttp "github.com/aws/smithy-go/transport/http"
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestAddInputMiddleware(t *testing.T) {
+ cases := map[string]struct {
+ options InputMiddlewareOptions
+ expectErr string
+ expectMiddleware []string
+ expectInitialize *setupInputContext
+ expectBuild *computeInputPayloadChecksum
+ expectFinalize *computeInputPayloadChecksum
+ }{
+ "with trailing checksum": {
+ options: InputMiddlewareOptions{
+ GetAlgorithm: func(interface{}) (string, bool) {
+ return string(AlgorithmCRC32), true
+ },
+ EnableTrailingChecksum: true,
+ EnableComputeSHA256PayloadHash: true,
+ EnableDecodedContentLengthHeader: true,
+ },
+ expectMiddleware: []string{
+ "test",
+ "Initialize stack step",
+ "AWSChecksum:SetupInputContext",
+ "Serialize stack step",
+ "Build stack step",
+ "ComputeContentLength",
+ "AWSChecksum:ComputeInputPayloadChecksum",
+ "ComputePayloadHash",
+ "Finalize stack step",
+ "Retry",
+ "AWSChecksum:ComputeInputPayloadChecksum",
+ "Signing",
+ "Deserialize stack step",
+ },
+ expectInitialize: &setupInputContext{
+ GetAlgorithm: func(interface{}) (string, bool) {
+ return string(AlgorithmCRC32), true
+ },
+ },
+ expectBuild: &computeInputPayloadChecksum{
+ EnableTrailingChecksum: true,
+ EnableComputePayloadHash: true,
+ EnableDecodedContentLengthHeader: true,
+ },
+ },
+ "with checksum required": {
+ options: InputMiddlewareOptions{
+ GetAlgorithm: func(interface{}) (string, bool) {
+ return string(AlgorithmCRC32), true
+ },
+ EnableTrailingChecksum: true,
+ RequireChecksum: true,
+ },
+ expectMiddleware: []string{
+ "test",
+ "Initialize stack step",
+ "AWSChecksum:SetupInputContext",
+ "Serialize stack step",
+ "Build stack step",
+ "ComputeContentLength",
+ "AWSChecksum:ComputeInputPayloadChecksum",
+ "ComputePayloadHash",
+ "Finalize stack step",
+ "Retry",
+ "AWSChecksum:ComputeInputPayloadChecksum",
+ "Signing",
+ "Deserialize stack step",
+ },
+ expectInitialize: &setupInputContext{
+ GetAlgorithm: func(interface{}) (string, bool) {
+ return string(AlgorithmCRC32), true
+ },
+ },
+ expectBuild: &computeInputPayloadChecksum{
+ RequireChecksum: true,
+ EnableTrailingChecksum: true,
+ },
+ },
+ "no trailing checksum": {
+ options: InputMiddlewareOptions{
+ GetAlgorithm: func(interface{}) (string, bool) {
+ return string(AlgorithmCRC32), true
+ },
+ },
+ expectMiddleware: []string{
+ "test",
+ "Initialize stack step",
+ "AWSChecksum:SetupInputContext",
+ "Serialize stack step",
+ "Build stack step",
+ "ComputeContentLength",
+ "AWSChecksum:ComputeInputPayloadChecksum",
+ "ComputePayloadHash",
+ "Finalize stack step",
+ "Retry",
+ "Signing",
+ "Deserialize stack step",
+ },
+ expectInitialize: &setupInputContext{
+ GetAlgorithm: func(interface{}) (string, bool) {
+ return string(AlgorithmCRC32), true
+ },
+ },
+ expectBuild: &computeInputPayloadChecksum{},
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ stack := middleware.NewStack("test", smithyhttp.NewStackRequest)
+
+ stack.Build.Add(nopBuildMiddleware("ComputeContentLength"), middleware.After)
+ stack.Build.Add(nopBuildMiddleware("ContentChecksum"), middleware.After)
+ stack.Build.Add(nopBuildMiddleware("ComputePayloadHash"), middleware.After)
+ stack.Finalize.Add(nopFinalizeMiddleware("Retry"), middleware.After)
+ stack.Finalize.Add(nopFinalizeMiddleware("Signing"), middleware.After)
+
+ err := AddInputMiddleware(stack, c.options)
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+
+ if diff := cmp.Diff(c.expectMiddleware, stack.List()); diff != "" {
+ t.Fatalf("expect stack list match:\n%s", diff)
+ }
+
+ initializeMiddleware, ok := stack.Initialize.Get((*setupInputContext)(nil).ID())
+ if e, a := (c.expectInitialize != nil), ok; e != a {
+ t.Errorf("expect initialize middleware %t, got %t", e, a)
+ }
+ if c.expectInitialize != nil && ok {
+ setupInput := initializeMiddleware.(*setupInputContext)
+ if e, a := c.options.GetAlgorithm != nil, setupInput.GetAlgorithm != nil; e != a {
+ t.Fatalf("expect GetAlgorithm %t, got %t", e, a)
+ }
+ expectAlgo, expectOK := c.options.GetAlgorithm(nil)
+ actualAlgo, actualOK := setupInput.GetAlgorithm(nil)
+ if e, a := expectAlgo, actualAlgo; e != a {
+ t.Errorf("expect %v algorithm, got %v", e, a)
+ }
+ if e, a := expectOK, actualOK; e != a {
+ t.Errorf("expect %v algorithm present, got %v", e, a)
+ }
+ }
+
+ buildMiddleware, ok := stack.Build.Get((*computeInputPayloadChecksum)(nil).ID())
+ if e, a := (c.expectBuild != nil), ok; e != a {
+ t.Errorf("expect build middleware %t, got %t", e, a)
+ }
+ var computeInput *computeInputPayloadChecksum
+ if c.expectBuild != nil && ok {
+ computeInput = buildMiddleware.(*computeInputPayloadChecksum)
+ if e, a := c.expectBuild.RequireChecksum, computeInput.RequireChecksum; e != a {
+ t.Errorf("expect %v require checksum, got %v", e, a)
+ }
+ if e, a := c.expectBuild.EnableTrailingChecksum, computeInput.EnableTrailingChecksum; e != a {
+ t.Errorf("expect %v enable trailing checksum, got %v", e, a)
+ }
+ if e, a := c.expectBuild.EnableComputePayloadHash, computeInput.EnableComputePayloadHash; e != a {
+ t.Errorf("expect %v enable compute payload hash, got %v", e, a)
+ }
+ if e, a := c.expectBuild.EnableDecodedContentLengthHeader, computeInput.EnableDecodedContentLengthHeader; e != a {
+ t.Errorf("expect %v enable decoded length header, got %v", e, a)
+ }
+ }
+
+ if c.expectFinalize != nil && ok {
+ finalizeMiddleware, ok := stack.Build.Get((*computeInputPayloadChecksum)(nil).ID())
+ if !ok {
+ t.Errorf("expect finalize middleware")
+ }
+ finalizeComputeInput := finalizeMiddleware.(*computeInputPayloadChecksum)
+
+ if e, a := computeInput, finalizeComputeInput; e != a {
+ t.Errorf("expect build and finalize to be same value")
+ }
+ }
+ })
+ }
+}
+
+func TestRemoveInputMiddleware(t *testing.T) {
+ stack := middleware.NewStack("test", smithyhttp.NewStackRequest)
+
+ stack.Build.Add(nopBuildMiddleware("ComputeContentLength"), middleware.After)
+ stack.Build.Add(nopBuildMiddleware("ContentChecksum"), middleware.After)
+ stack.Build.Add(nopBuildMiddleware("ComputePayloadHash"), middleware.After)
+ stack.Finalize.Add(nopFinalizeMiddleware("Retry"), middleware.After)
+ stack.Finalize.Add(nopFinalizeMiddleware("Signing"), middleware.After)
+
+ err := AddInputMiddleware(stack, InputMiddlewareOptions{
+ EnableTrailingChecksum: true,
+ })
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+
+ RemoveInputMiddleware(stack)
+
+ expectStack := []string{
+ "test",
+ "Initialize stack step",
+ "Serialize stack step",
+ "Build stack step",
+ "ComputeContentLength",
+ "ComputePayloadHash",
+ "Finalize stack step",
+ "Retry",
+ "Signing",
+ "Deserialize stack step",
+ }
+
+ if diff := cmp.Diff(expectStack, stack.List()); diff != "" {
+ t.Fatalf("expect stack list match:\n%s", diff)
+ }
+}
+
+func TestAddOutputMiddleware(t *testing.T) {
+ cases := map[string]struct {
+ options OutputMiddlewareOptions
+ expectErr string
+ expectMiddleware []string
+ expectInitialize *setupOutputContext
+ expectDeserialize *validateOutputPayloadChecksum
+ }{
+ "validate output": {
+ options: OutputMiddlewareOptions{
+ GetValidationMode: func(interface{}) (string, bool) {
+ return "ENABLED", true
+ },
+ ValidationAlgorithms: []string{
+ "crc32", "sha1", "abc123", "crc32c",
+ },
+ IgnoreMultipartValidation: true,
+ LogMultipartValidationSkipped: true,
+ LogValidationSkipped: true,
+ },
+ expectMiddleware: []string{
+ "test",
+ "Initialize stack step",
+ "AWSChecksum:SetupOutputContext",
+ "Serialize stack step",
+ "Build stack step",
+ "Finalize stack step",
+ "Deserialize stack step",
+ "AWSChecksum:ValidateOutputPayloadChecksum",
+ },
+ expectInitialize: &setupOutputContext{
+ GetValidationMode: func(interface{}) (string, bool) {
+ return "ENABLED", true
+ },
+ },
+ expectDeserialize: &validateOutputPayloadChecksum{
+ Algorithms: []Algorithm{
+ AlgorithmCRC32, AlgorithmSHA1, AlgorithmCRC32C,
+ },
+ IgnoreMultipartValidation: true,
+ LogMultipartValidationSkipped: true,
+ LogValidationSkipped: true,
+ },
+ },
+ "validate options off": {
+ options: OutputMiddlewareOptions{
+ GetValidationMode: func(interface{}) (string, bool) {
+ return "ENABLED", true
+ },
+ ValidationAlgorithms: []string{
+ "crc32", "sha1", "abc123", "crc32c",
+ },
+ },
+ expectMiddleware: []string{
+ "test",
+ "Initialize stack step",
+ "AWSChecksum:SetupOutputContext",
+ "Serialize stack step",
+ "Build stack step",
+ "Finalize stack step",
+ "Deserialize stack step",
+ "AWSChecksum:ValidateOutputPayloadChecksum",
+ },
+ expectInitialize: &setupOutputContext{
+ GetValidationMode: func(interface{}) (string, bool) {
+ return "ENABLED", true
+ },
+ },
+ expectDeserialize: &validateOutputPayloadChecksum{
+ Algorithms: []Algorithm{
+ AlgorithmCRC32, AlgorithmSHA1, AlgorithmCRC32C,
+ },
+ },
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ stack := middleware.NewStack("test", smithyhttp.NewStackRequest)
+
+ err := AddOutputMiddleware(stack, c.options)
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+
+ if diff := cmp.Diff(c.expectMiddleware, stack.List()); diff != "" {
+ t.Fatalf("expect stack list match:\n%s", diff)
+ }
+
+ initializeMiddleware, ok := stack.Initialize.Get((*setupOutputContext)(nil).ID())
+ if e, a := (c.expectInitialize != nil), ok; e != a {
+ t.Errorf("expect initialize middleware %t, got %t", e, a)
+ }
+ if c.expectInitialize != nil && ok {
+ setupOutput := initializeMiddleware.(*setupOutputContext)
+ if e, a := c.options.GetValidationMode != nil, setupOutput.GetValidationMode != nil; e != a {
+ t.Fatalf("expect GetValidationMode %t, got %t", e, a)
+ }
+ expectMode, expectOK := c.options.GetValidationMode(nil)
+ actualMode, actualOK := setupOutput.GetValidationMode(nil)
+ if e, a := expectMode, actualMode; e != a {
+ t.Errorf("expect %v mode, got %v", e, a)
+ }
+ if e, a := expectOK, actualOK; e != a {
+ t.Errorf("expect %v mode present, got %v", e, a)
+ }
+ }
+
+ deserializeMiddleware, ok := stack.Deserialize.Get((*validateOutputPayloadChecksum)(nil).ID())
+ if e, a := (c.expectDeserialize != nil), ok; e != a {
+ t.Errorf("expect deserialize middleware %t, got %t", e, a)
+ }
+ if c.expectDeserialize != nil && ok {
+ validateOutput := deserializeMiddleware.(*validateOutputPayloadChecksum)
+ if diff := cmp.Diff(c.expectDeserialize.Algorithms, validateOutput.Algorithms); diff != "" {
+ t.Errorf("expect algorithms match:\n%s", diff)
+ }
+ if e, a := c.expectDeserialize.IgnoreMultipartValidation, validateOutput.IgnoreMultipartValidation; e != a {
+ t.Errorf("expect %v ignore multipart checksum, got %v", e, a)
+ }
+ if e, a := c.expectDeserialize.LogMultipartValidationSkipped, validateOutput.LogMultipartValidationSkipped; e != a {
+ t.Errorf("expect %v log multipart skipped, got %v", e, a)
+ }
+ if e, a := c.expectDeserialize.LogValidationSkipped, validateOutput.LogValidationSkipped; e != a {
+ t.Errorf("expect %v log validation skipped, got %v", e, a)
+ }
+ }
+ })
+ }
+}
+
+func TestRemoveOutputMiddleware(t *testing.T) {
+ stack := middleware.NewStack("test", smithyhttp.NewStackRequest)
+
+ err := AddOutputMiddleware(stack, OutputMiddlewareOptions{})
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+
+ RemoveOutputMiddleware(stack)
+
+ expectStack := []string{
+ "test",
+ "Initialize stack step",
+ "Serialize stack step",
+ "Build stack step",
+ "Finalize stack step",
+ "Deserialize stack step",
+ }
+
+ if diff := cmp.Diff(expectStack, stack.List()); diff != "" {
+ t.Fatalf("expect stack list match:\n%s", diff)
+ }
+}
+
+func setSerializedRequest(req *smithyhttp.Request) middleware.SerializeMiddleware {
+ return middleware.SerializeMiddlewareFunc("OperationSerializer",
+ func(ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler) (
+ middleware.SerializeOutput, middleware.Metadata, error,
+ ) {
+ input.Request = req
+ return next.HandleSerialize(ctx, input)
+ })
+}
+
+func nopBuildMiddleware(id string) middleware.BuildMiddleware {
+ return middleware.BuildMiddlewareFunc(id,
+ func(ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler) (
+ middleware.BuildOutput, middleware.Metadata, error,
+ ) {
+ return next.HandleBuild(ctx, input)
+ })
+}
+
+func nopFinalizeMiddleware(id string) middleware.FinalizeMiddleware {
+ return middleware.FinalizeMiddlewareFunc(id,
+ func(ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler) (
+ middleware.FinalizeOutput, middleware.Metadata, error,
+ ) {
+ return next.HandleFinalize(ctx, input)
+ })
+}
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum.go
new file mode 100644
index 0000000000..0b3c48912b
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum.go
@@ -0,0 +1,479 @@
+package checksum
+
+import (
+ "context"
+ "crypto/sha256"
+ "fmt"
+ "hash"
+ "io"
+ "strconv"
+
+ v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
+ "github.com/aws/smithy-go/middleware"
+ smithyhttp "github.com/aws/smithy-go/transport/http"
+)
+
+const (
+ contentMD5Header = "Content-Md5"
+ streamingUnsignedPayloadTrailerPayloadHash = "STREAMING-UNSIGNED-PAYLOAD-TRAILER"
+)
+
+// computedInputChecksumsKey is the metadata key for recording the algorithm the
+// checksum was computed for and the checksum value.
+type computedInputChecksumsKey struct{}
+
+// GetComputedInputChecksums returns the map of checksum algorithm to their
+// computed value stored in the middleware Metadata. Returns false if no values
+// were stored in the Metadata.
+func GetComputedInputChecksums(m middleware.Metadata) (map[string]string, bool) {
+ vs, ok := m.Get(computedInputChecksumsKey{}).(map[string]string)
+ return vs, ok
+}
+
+// SetComputedInputChecksums stores the map of checksum algorithm to their
+// computed value in the middleware Metadata. Overwrites any values that
+// currently exist in the metadata.
+func SetComputedInputChecksums(m *middleware.Metadata, vs map[string]string) {
+ m.Set(computedInputChecksumsKey{}, vs)
+}
+
+// computeInputPayloadChecksum middleware computes payload checksum
+type computeInputPayloadChecksum struct {
+ // Enables support for wrapping the serialized input payload with a
+ // content-encoding: aws-check wrapper, and including a trailer for the
+ // algorithm's checksum value.
+ //
+ // The checksum will not be computed, nor added as trailing checksum, if
+ // the Algorithm's header is already set on the request.
+ EnableTrailingChecksum bool
+
+ // States that a checksum is required to be included for the operation. If
+ // Input does not specify a checksum, fallback to built in MD5 checksum is
+ // used.
+ //
+ // Replaces smithy-go's ContentChecksum middleware.
+ RequireChecksum bool
+
+ // Enables support for computing the SHA256 checksum of input payloads
+ // along with the algorithm specified checksum. Prevents downstream
+ // middleware handlers (computePayloadSHA256) re-reading the payload.
+ //
+ // The SHA256 payload hash will only be used for computed for requests
+ // that are not TLS, or do not enable trailing checksums.
+ //
+ // The SHA256 payload hash will not be computed, if the Algorithm's header
+ // is already set on the request.
+ EnableComputePayloadHash bool
+
+ // Enables support for setting the aws-chunked decoded content length
+ // header for the decoded length of the underlying stream. Will only be set
+ // when used with trailing checksums, and aws-chunked content-encoding.
+ EnableDecodedContentLengthHeader bool
+
+ buildHandlerRun bool
+ deferToFinalizeHandler bool
+}
+
+// ID provides the middleware's identifier.
+func (m *computeInputPayloadChecksum) ID() string {
+ return "AWSChecksum:ComputeInputPayloadChecksum"
+}
+
+type computeInputHeaderChecksumError struct {
+ Msg string
+ Err error
+}
+
+func (e computeInputHeaderChecksumError) Error() string {
+ const intro = "compute input header checksum failed"
+
+ if e.Err != nil {
+ return fmt.Sprintf("%s, %s, %v", intro, e.Msg, e.Err)
+ }
+
+ return fmt.Sprintf("%s, %s", intro, e.Msg)
+}
+func (e computeInputHeaderChecksumError) Unwrap() error { return e.Err }
+
+// HandleBuild handles computing the payload's checksum, in the following cases:
+// - Is HTTP, not HTTPS
+// - RequireChecksum is true, and no checksums were specified via the Input
+// - Trailing checksums are not supported
+//
+// The build handler must be inserted in the stack before ContentPayloadHash
+// and after ComputeContentLength.
+func (m *computeInputPayloadChecksum) HandleBuild(
+ ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
+) (
+ out middleware.BuildOutput, metadata middleware.Metadata, err error,
+) {
+ m.buildHandlerRun = true
+
+ req, ok := in.Request.(*smithyhttp.Request)
+ if !ok {
+ return out, metadata, computeInputHeaderChecksumError{
+ Msg: fmt.Sprintf("unknown request type %T", req),
+ }
+ }
+
+ var algorithm Algorithm
+ var checksum string
+ defer func() {
+ if algorithm == "" || checksum == "" || err != nil {
+ return
+ }
+
+ // Record the checksum and algorithm that was computed
+ SetComputedInputChecksums(&metadata, map[string]string{
+ string(algorithm): checksum,
+ })
+ }()
+
+ // If no algorithm was specified, and the operation requires a checksum,
+ // fallback to the legacy content MD5 checksum.
+ algorithm, ok, err = getInputAlgorithm(ctx)
+ if err != nil {
+ return out, metadata, err
+ } else if !ok {
+ if m.RequireChecksum {
+ checksum, err = setMD5Checksum(ctx, req)
+ if err != nil {
+ return out, metadata, computeInputHeaderChecksumError{
+ Msg: "failed to compute stream's MD5 checksum",
+ Err: err,
+ }
+ }
+ algorithm = Algorithm("MD5")
+ }
+ return next.HandleBuild(ctx, in)
+ }
+
+ // If the checksum header is already set nothing to do.
+ checksumHeader := AlgorithmHTTPHeader(algorithm)
+ if checksum = req.Header.Get(checksumHeader); checksum != "" {
+ return next.HandleBuild(ctx, in)
+ }
+
+ computePayloadHash := m.EnableComputePayloadHash
+ if v := v4.GetPayloadHash(ctx); v != "" {
+ computePayloadHash = false
+ }
+
+ stream := req.GetStream()
+ streamLength, err := getRequestStreamLength(req)
+ if err != nil {
+ return out, metadata, computeInputHeaderChecksumError{
+ Msg: "failed to determine stream length",
+ Err: err,
+ }
+ }
+
+ // If trailing checksums are supported, the request is HTTPS, and the
+ // stream is not nil or empty, there is nothing to do in the build stage.
+ // The checksum will be added to the request as a trailing checksum in the
+ // finalize handler.
+ //
+ // Nil and empty streams will always be handled as a request header,
+ // regardless if the operation supports trailing checksums or not.
+ if req.IsHTTPS() {
+ if stream != nil && streamLength != 0 && m.EnableTrailingChecksum {
+ if m.EnableComputePayloadHash {
+ // payload hash is set as header in Build middleware handler,
+ // ContentSHA256Header.
+ ctx = v4.SetPayloadHash(ctx, streamingUnsignedPayloadTrailerPayloadHash)
+ }
+
+ m.deferToFinalizeHandler = true
+ return next.HandleBuild(ctx, in)
+ }
+
+ // If trailing checksums are not enabled but protocol is still HTTPS
+ // disabling computing the payload hash. Downstream middleware handler
+ // (ComputetPayloadHash) will set the payload hash to unsigned payload,
+ // if signing was used.
+ computePayloadHash = false
+ }
+
+ // Only seekable streams are supported for non-trailing checksums, because
+ // the stream needs to be rewound before the handler can continue.
+ if stream != nil && !req.IsStreamSeekable() {
+ return out, metadata, computeInputHeaderChecksumError{
+ Msg: "unseekable stream is not supported without TLS and trailing checksum",
+ }
+ }
+
+ var sha256Checksum string
+ checksum, sha256Checksum, err = computeStreamChecksum(
+ algorithm, stream, computePayloadHash)
+ if err != nil {
+ return out, metadata, computeInputHeaderChecksumError{
+ Msg: "failed to compute stream checksum",
+ Err: err,
+ }
+ }
+
+ if err := req.RewindStream(); err != nil {
+ return out, metadata, computeInputHeaderChecksumError{
+ Msg: "failed to rewind stream",
+ Err: err,
+ }
+ }
+
+ req.Header.Set(checksumHeader, checksum)
+
+ if computePayloadHash {
+ ctx = v4.SetPayloadHash(ctx, sha256Checksum)
+ }
+
+ return next.HandleBuild(ctx, in)
+}
+
+type computeInputTrailingChecksumError struct {
+ Msg string
+ Err error
+}
+
+func (e computeInputTrailingChecksumError) Error() string {
+ const intro = "compute input trailing checksum failed"
+
+ if e.Err != nil {
+ return fmt.Sprintf("%s, %s, %v", intro, e.Msg, e.Err)
+ }
+
+ return fmt.Sprintf("%s, %s", intro, e.Msg)
+}
+func (e computeInputTrailingChecksumError) Unwrap() error { return e.Err }
+
+// HandleFinalize handles computing the payload's checksum, in the following cases:
+// - Is HTTPS, not HTTP
+// - A checksum was specified via the Input
+// - Trailing checksums are supported.
+//
+// The finalize handler must be inserted in the stack before Signing, and after Retry.
+func (m *computeInputPayloadChecksum) HandleFinalize(
+ ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
+) (
+ out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
+) {
+ if !m.deferToFinalizeHandler {
+ if !m.buildHandlerRun {
+ return out, metadata, computeInputTrailingChecksumError{
+ Msg: "build handler was removed without also removing finalize handler",
+ }
+ }
+ return next.HandleFinalize(ctx, in)
+ }
+
+ req, ok := in.Request.(*smithyhttp.Request)
+ if !ok {
+ return out, metadata, computeInputTrailingChecksumError{
+ Msg: fmt.Sprintf("unknown request type %T", req),
+ }
+ }
+
+ // Trailing checksums are only supported when TLS is enabled.
+ if !req.IsHTTPS() {
+ return out, metadata, computeInputTrailingChecksumError{
+ Msg: "HTTPS required",
+ }
+ }
+
+ // If no algorithm was specified, there is nothing to do.
+ algorithm, ok, err := getInputAlgorithm(ctx)
+ if err != nil {
+ return out, metadata, computeInputTrailingChecksumError{
+ Msg: "failed to get algorithm",
+ Err: err,
+ }
+ } else if !ok {
+ return out, metadata, computeInputTrailingChecksumError{
+ Msg: "no algorithm specified",
+ }
+ }
+
+ // If the checksum header is already set before finalize could run, there
+ // is nothing to do.
+ checksumHeader := AlgorithmHTTPHeader(algorithm)
+ if req.Header.Get(checksumHeader) != "" {
+ return next.HandleFinalize(ctx, in)
+ }
+
+ stream := req.GetStream()
+ streamLength, err := getRequestStreamLength(req)
+ if err != nil {
+ return out, metadata, computeInputTrailingChecksumError{
+ Msg: "failed to determine stream length",
+ Err: err,
+ }
+ }
+
+ if stream == nil || streamLength == 0 {
+ // Nil and empty streams are handled by the Build handler. They are not
+ // supported by the trailing checksums finalize handler. There is no
+ // benefit to sending them as trailers compared to headers.
+ return out, metadata, computeInputTrailingChecksumError{
+ Msg: "nil or empty streams are not supported",
+ }
+ }
+
+ checksumReader, err := newComputeChecksumReader(stream, algorithm)
+ if err != nil {
+ return out, metadata, computeInputTrailingChecksumError{
+ Msg: "failed to created checksum reader",
+ Err: err,
+ }
+ }
+
+ awsChunkedReader := newUnsignedAWSChunkedEncoding(checksumReader,
+ func(o *awsChunkedEncodingOptions) {
+ o.Trailers[AlgorithmHTTPHeader(checksumReader.Algorithm())] = awsChunkedTrailerValue{
+ Get: checksumReader.Base64Checksum,
+ Length: checksumReader.Base64ChecksumLength(),
+ }
+ o.StreamLength = streamLength
+ })
+
+ for key, values := range awsChunkedReader.HTTPHeaders() {
+ for _, value := range values {
+ req.Header.Add(key, value)
+ }
+ }
+
+ // Setting the stream on the request will create a copy. The content length
+ // is not updated until after the request is copied to prevent impacting
+ // upstream middleware.
+ req, err = req.SetStream(awsChunkedReader)
+ if err != nil {
+ return out, metadata, computeInputTrailingChecksumError{
+ Msg: "failed updating request to trailing checksum wrapped stream",
+ Err: err,
+ }
+ }
+ req.ContentLength = awsChunkedReader.EncodedLength()
+ in.Request = req
+
+ // Add decoded content length header if original stream's content length is known.
+ if streamLength != -1 && m.EnableDecodedContentLengthHeader {
+ req.Header.Set(decodedContentLengthHeaderName, strconv.FormatInt(streamLength, 10))
+ }
+
+ out, metadata, err = next.HandleFinalize(ctx, in)
+ if err == nil {
+ checksum, err := checksumReader.Base64Checksum()
+ if err != nil {
+ return out, metadata, fmt.Errorf("failed to get computed checksum, %w", err)
+ }
+
+ // Record the checksum and algorithm that was computed
+ SetComputedInputChecksums(&metadata, map[string]string{
+ string(algorithm): checksum,
+ })
+ }
+
+ return out, metadata, err
+}
+
+func getInputAlgorithm(ctx context.Context) (Algorithm, bool, error) {
+ ctxAlgorithm := getContextInputAlgorithm(ctx)
+ if ctxAlgorithm == "" {
+ return "", false, nil
+ }
+
+ algorithm, err := ParseAlgorithm(ctxAlgorithm)
+ if err != nil {
+ return "", false, fmt.Errorf(
+ "failed to parse algorithm, %w", err)
+ }
+
+ return algorithm, true, nil
+}
+
+func computeStreamChecksum(algorithm Algorithm, stream io.Reader, computePayloadHash bool) (
+ checksum string, sha256Checksum string, err error,
+) {
+ hasher, err := NewAlgorithmHash(algorithm)
+ if err != nil {
+ return "", "", fmt.Errorf(
+ "failed to get hasher for checksum algorithm, %w", err)
+ }
+
+ var sha256Hasher hash.Hash
+ var batchHasher io.Writer = hasher
+
+ // Compute payload hash for the protocol. To prevent another handler
+ // (computePayloadSHA256) re-reading body also compute the SHA256 for
+ // request signing. If configured checksum algorithm is SHA256, don't
+ // double wrap stream with another SHA256 hasher.
+ if computePayloadHash && algorithm != AlgorithmSHA256 {
+ sha256Hasher = sha256.New()
+ batchHasher = io.MultiWriter(hasher, sha256Hasher)
+ }
+
+ if stream != nil {
+ if _, err = io.Copy(batchHasher, stream); err != nil {
+ return "", "", fmt.Errorf(
+ "failed to read stream to compute hash, %w", err)
+ }
+ }
+
+ checksum = string(base64EncodeHashSum(hasher))
+ if computePayloadHash {
+ if algorithm != AlgorithmSHA256 {
+ sha256Checksum = string(hexEncodeHashSum(sha256Hasher))
+ } else {
+ sha256Checksum = string(hexEncodeHashSum(hasher))
+ }
+ }
+
+ return checksum, sha256Checksum, nil
+}
+
+func getRequestStreamLength(req *smithyhttp.Request) (int64, error) {
+ if v := req.ContentLength; v > 0 {
+ return v, nil
+ }
+
+ if length, ok, err := req.StreamLength(); err != nil {
+ return 0, fmt.Errorf("failed getting request stream's length, %w", err)
+ } else if ok {
+ return length, nil
+ }
+
+ return -1, nil
+}
+
+// setMD5Checksum computes the MD5 of the request payload and sets it to the
+// Content-MD5 header. Returning the MD5 base64 encoded string or error.
+//
+// If the MD5 is already set as the Content-MD5 header, that value will be
+// returned, and nothing else will be done.
+//
+// If the payload is empty, no MD5 will be computed. No error will be returned.
+// Empty payloads do not have an MD5 value.
+//
+// Replaces the smithy-go middleware for httpChecksum trait.
+func setMD5Checksum(ctx context.Context, req *smithyhttp.Request) (string, error) {
+ if v := req.Header.Get(contentMD5Header); len(v) != 0 {
+ return v, nil
+ }
+ stream := req.GetStream()
+ if stream == nil {
+ return "", nil
+ }
+
+ if !req.IsStreamSeekable() {
+ return "", fmt.Errorf(
+ "unseekable stream is not supported for computing md5 checksum")
+ }
+
+ v, err := computeMD5Checksum(stream)
+ if err != nil {
+ return "", err
+ }
+ if err := req.RewindStream(); err != nil {
+ return "", fmt.Errorf("failed to rewind stream after computing MD5 checksum, %w", err)
+ }
+ // set the 'Content-MD5' header
+ req.Header.Set(contentMD5Header, string(v))
+ return string(v), nil
+}
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum_test.go
new file mode 100644
index 0000000000..7ad59e7315
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum_test.go
@@ -0,0 +1,958 @@
+//go:build go1.16
+// +build go1.16
+
+package checksum
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "strings"
+ "testing"
+ "testing/iotest"
+
+ v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
+ "github.com/aws/smithy-go/logging"
+ "github.com/aws/smithy-go/middleware"
+ smithyhttp "github.com/aws/smithy-go/transport/http"
+ "github.com/google/go-cmp/cmp"
+)
+
+// TODO test cases:
+// * Retry re-wrapping payload
+
+func TestComputeInputPayloadChecksum(t *testing.T) {
+ cases := map[string]map[string]struct {
+ optionsFn func(*computeInputPayloadChecksum)
+ initContext func(context.Context) context.Context
+ buildInput middleware.BuildInput
+
+ expectErr string
+ expectBuildErr bool
+ expectFinalizeErr bool
+ expectReadErr bool
+
+ expectHeader http.Header
+ expectContentLength int64
+ expectPayload []byte
+ expectPayloadHash string
+
+ expectChecksumMetadata map[string]string
+
+ expectDeferToFinalize bool
+ expectLogged string
+ }{
+ "no op": {
+ "checksum header set known length": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.Header.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "AAAAAA==")
+ r = requestMust(r.SetStream(strings.NewReader("hello world")))
+ r.ContentLength = 11
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Checksum-Crc32": []string{"AAAAAA=="},
+ },
+ expectContentLength: 11,
+ expectPayload: []byte("hello world"),
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "AAAAAA==",
+ },
+ },
+ "checksum header set unknown length": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.Header.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "AAAAAA==")
+ r = requestMust(r.SetStream(strings.NewReader("hello world")))
+ r.ContentLength = -1
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Checksum-Crc32": []string{"AAAAAA=="},
+ },
+ expectContentLength: -1,
+ expectPayload: []byte("hello world"),
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "AAAAAA==",
+ },
+ },
+ "no algorithm": {
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r = requestMust(r.SetStream(strings.NewReader("hello world")))
+ r.ContentLength = 11
+ return r
+ }(),
+ },
+ expectHeader: http.Header{},
+ expectContentLength: 11,
+ expectPayload: []byte("hello world"),
+ },
+ "nil stream no algorithm require checksum": {
+ optionsFn: func(o *computeInputPayloadChecksum) {
+ o.RequireChecksum = true
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ return r
+ }(),
+ },
+ expectContentLength: -1,
+ expectHeader: http.Header{},
+ },
+ },
+
+ "build handled": {
+ "http nil stream": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Checksum-Crc32": []string{"AAAAAA=="},
+ },
+ expectContentLength: -1,
+ expectPayloadHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "AAAAAA==",
+ },
+ },
+ "http empty stream": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ r.ContentLength = 0
+ r = requestMust(r.SetStream(strings.NewReader("")))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Checksum-Crc32": []string{"AAAAAA=="},
+ },
+ expectContentLength: 0,
+ expectPayload: []byte{},
+ expectPayloadHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "AAAAAA==",
+ },
+ },
+ "https empty stream unseekable": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.ContentLength = 0
+ r = requestMust(r.SetStream(&bytes.Buffer{}))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Checksum-Crc32": []string{"AAAAAA=="},
+ },
+ expectContentLength: 0,
+ expectPayload: nil,
+ expectPayloadHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "AAAAAA==",
+ },
+ },
+ "http empty stream unseekable": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ r.ContentLength = 0
+ r = requestMust(r.SetStream(&bytes.Buffer{}))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Checksum-Crc32": []string{"AAAAAA=="},
+ },
+ expectContentLength: 0,
+ expectPayload: nil,
+ expectPayloadHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "AAAAAA==",
+ },
+ },
+ "https nil stream": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Checksum-Crc32": []string{"AAAAAA=="},
+ },
+ expectContentLength: -1,
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "AAAAAA==",
+ },
+ },
+ "https empty stream": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.ContentLength = 0
+ r = requestMust(r.SetStream(strings.NewReader("")))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Checksum-Crc32": []string{"AAAAAA=="},
+ },
+ expectContentLength: 0,
+ expectPayload: []byte{},
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "AAAAAA==",
+ },
+ },
+ "http no algorithm require checksum": {
+ optionsFn: func(o *computeInputPayloadChecksum) {
+ o.RequireChecksum = true
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ r.ContentLength = 11
+ r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "Content-Md5": []string{"XrY7u+Ae7tCTyyK7j1rNww=="},
+ },
+ expectContentLength: 11,
+ expectPayload: []byte("hello world"),
+ expectChecksumMetadata: map[string]string{
+ "MD5": "XrY7u+Ae7tCTyyK7j1rNww==",
+ },
+ },
+ "http no algorithm require checksum header preset": {
+ optionsFn: func(o *computeInputPayloadChecksum) {
+ o.RequireChecksum = true
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ r.ContentLength = 11
+ r.Header.Set("Content-MD5", "XrY7u+Ae7tCTyyK7j1rNww==")
+ r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "Content-Md5": []string{"XrY7u+Ae7tCTyyK7j1rNww=="},
+ },
+ expectContentLength: 11,
+ expectPayload: []byte("hello world"),
+ expectChecksumMetadata: map[string]string{
+ "MD5": "XrY7u+Ae7tCTyyK7j1rNww==",
+ },
+ },
+ "https no algorithm require checksum": {
+ optionsFn: func(o *computeInputPayloadChecksum) {
+ o.RequireChecksum = true
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.ContentLength = 11
+ r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "Content-Md5": []string{"XrY7u+Ae7tCTyyK7j1rNww=="},
+ },
+ expectContentLength: 11,
+ expectPayload: []byte("hello world"),
+ expectChecksumMetadata: map[string]string{
+ "MD5": "XrY7u+Ae7tCTyyK7j1rNww==",
+ },
+ },
+ "http seekable": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ r.ContentLength = 11
+ r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="},
+ },
+ expectContentLength: 11,
+ expectPayload: []byte("hello world"),
+ expectPayloadHash: "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9",
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "DUoRhQ==",
+ },
+ },
+ "http payload hash already set": {
+ initContext: func(ctx context.Context) context.Context {
+ ctx = setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ ctx = v4.SetPayloadHash(ctx, "somehash")
+ return ctx
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ r.ContentLength = 11
+ r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="},
+ },
+ expectContentLength: 11,
+ expectPayload: []byte("hello world"),
+ expectPayloadHash: "somehash",
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "DUoRhQ==",
+ },
+ },
+ "http seekable checksum matches payload hash": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmSHA256))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ r.ContentLength = 11
+ r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Checksum-Sha256": []string{"uU0nuZNNPgilLlLX2n2r+sSE7+N6U4DukIj3rOLvzek="},
+ },
+ expectContentLength: 11,
+ expectPayload: []byte("hello world"),
+ expectPayloadHash: "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9",
+ expectChecksumMetadata: map[string]string{
+ "SHA256": "uU0nuZNNPgilLlLX2n2r+sSE7+N6U4DukIj3rOLvzek=",
+ },
+ },
+ "http payload hash disabled": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ optionsFn: func(o *computeInputPayloadChecksum) {
+ o.EnableComputePayloadHash = false
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ r.ContentLength = 11
+ r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="},
+ },
+ expectContentLength: 11,
+ expectPayload: []byte("hello world"),
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "DUoRhQ==",
+ },
+ },
+ "https no trailing checksum": {
+ optionsFn: func(o *computeInputPayloadChecksum) {
+ o.EnableTrailingChecksum = false
+ },
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.ContentLength = 11
+ r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="},
+ },
+ expectContentLength: 11,
+ expectPayload: []byte("hello world"),
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "DUoRhQ==",
+ },
+ },
+ "with content encoding set": {
+ optionsFn: func(o *computeInputPayloadChecksum) {
+ o.EnableTrailingChecksum = false
+ },
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.ContentLength = 11
+ r.Header.Set("Content-Encoding", "gzip")
+ r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="},
+ "Content-Encoding": []string{"gzip"},
+ },
+ expectContentLength: 11,
+ expectPayload: []byte("hello world"),
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "DUoRhQ==",
+ },
+ },
+ },
+
+ "build error": {
+ "unknown algorithm": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string("unknown"))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectErr: "failed to parse algorithm",
+ expectBuildErr: true,
+ },
+ "no algorithm require checksum unseekable stream": {
+ optionsFn: func(o *computeInputPayloadChecksum) {
+ o.RequireChecksum = true
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectErr: "unseekable stream is not supported",
+ expectBuildErr: true,
+ },
+ "http unseekable stream": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectErr: "unseekable stream is not supported",
+ expectBuildErr: true,
+ },
+ "http stream read error": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ r.ContentLength = 128
+ r = requestMust(r.SetStream(&mockReadSeeker{
+ Reader: iotest.ErrReader(fmt.Errorf("read error")),
+ }))
+ return r
+ }(),
+ },
+ expectErr: "failed to read stream to compute hash",
+ expectBuildErr: true,
+ },
+ "http stream rewind error": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("http://example.aws")
+ r.ContentLength = 11
+ r = requestMust(r.SetStream(&errSeekReader{
+ Reader: strings.NewReader("hello world"),
+ }))
+ return r
+ }(),
+ },
+ expectErr: "failed to rewind stream",
+ expectBuildErr: true,
+ },
+ "https no trailing unseekable stream": {
+ optionsFn: func(o *computeInputPayloadChecksum) {
+ o.EnableTrailingChecksum = false
+ },
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectErr: "unseekable stream is not supported",
+ expectBuildErr: true,
+ },
+ },
+
+ "finalize handled": {
+ "https unseekable": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.ContentLength = 11
+ r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "Content-Encoding": []string{"aws-chunked"},
+ "X-Amz-Decoded-Content-Length": []string{"11"},
+ "X-Amz-Trailer": []string{"x-amz-checksum-crc32"},
+ },
+ expectContentLength: 52,
+ expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"),
+ expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER",
+ expectDeferToFinalize: true,
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "DUoRhQ==",
+ },
+ },
+ "https unseekable unknown length": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.ContentLength = -1
+ r = requestMust(r.SetStream(ioutil.NopCloser(bytes.NewBuffer([]byte("hello world")))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "Content-Encoding": []string{"aws-chunked"},
+ "X-Amz-Trailer": []string{"x-amz-checksum-crc32"},
+ },
+ expectContentLength: -1,
+ expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"),
+ expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER",
+ expectDeferToFinalize: true,
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "DUoRhQ==",
+ },
+ },
+ "https seekable": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.ContentLength = 11
+ r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "Content-Encoding": []string{"aws-chunked"},
+ "X-Amz-Decoded-Content-Length": []string{"11"},
+ "X-Amz-Trailer": []string{"x-amz-checksum-crc32"},
+ },
+ expectContentLength: 52,
+ expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"),
+ expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER",
+ expectDeferToFinalize: true,
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "DUoRhQ==",
+ },
+ },
+ "https seekable unknown length": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.ContentLength = -1
+ r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "Content-Encoding": []string{"aws-chunked"},
+ "X-Amz-Decoded-Content-Length": []string{"11"},
+ "X-Amz-Trailer": []string{"x-amz-checksum-crc32"},
+ },
+ expectContentLength: 52,
+ expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"),
+ expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER",
+ expectDeferToFinalize: true,
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "DUoRhQ==",
+ },
+ },
+ "https no compute payload hash": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ optionsFn: func(o *computeInputPayloadChecksum) {
+ o.EnableComputePayloadHash = false
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.ContentLength = 11
+ r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "Content-Encoding": []string{"aws-chunked"},
+ "X-Amz-Decoded-Content-Length": []string{"11"},
+ "X-Amz-Trailer": []string{"x-amz-checksum-crc32"},
+ },
+ expectContentLength: 52,
+ expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"),
+ expectDeferToFinalize: true,
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "DUoRhQ==",
+ },
+ },
+ "https no decode content length": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ optionsFn: func(o *computeInputPayloadChecksum) {
+ o.EnableDecodedContentLengthHeader = false
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.ContentLength = 11
+ r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "Content-Encoding": []string{"aws-chunked"},
+ "X-Amz-Trailer": []string{"x-amz-checksum-crc32"},
+ },
+ expectContentLength: 52,
+ expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER",
+ expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"),
+ expectDeferToFinalize: true,
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "DUoRhQ==",
+ },
+ },
+ "with content encoding set": {
+ initContext: func(ctx context.Context) context.Context {
+ return setContextInputAlgorithm(ctx, string(AlgorithmCRC32))
+ },
+ buildInput: middleware.BuildInput{
+ Request: func() *smithyhttp.Request {
+ r := smithyhttp.NewStackRequest().(*smithyhttp.Request)
+ r.URL, _ = url.Parse("https://example.aws")
+ r.ContentLength = 11
+ r.Header.Set("Content-Encoding", "gzip")
+ r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world"))))
+ return r
+ }(),
+ },
+ expectHeader: http.Header{
+ "X-Amz-Trailer": []string{"x-amz-checksum-crc32"},
+ "X-Amz-Decoded-Content-Length": []string{"11"},
+ "Content-Encoding": []string{"gzip", "aws-chunked"},
+ },
+ expectContentLength: 52,
+ expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER",
+ expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"),
+ expectDeferToFinalize: true,
+ expectChecksumMetadata: map[string]string{
+ "CRC32": "DUoRhQ==",
+ },
+ },
+ },
+ }
+
+ for name, cs := range cases {
+ t.Run(name, func(t *testing.T) {
+ for name, c := range cs {
+ t.Run(name, func(t *testing.T) {
+ m := &computeInputPayloadChecksum{
+ EnableTrailingChecksum: true,
+ EnableComputePayloadHash: true,
+ EnableDecodedContentLengthHeader: true,
+ }
+ if c.optionsFn != nil {
+ c.optionsFn(m)
+ }
+
+ ctx := context.Background()
+ var logged bytes.Buffer
+ logger := logging.LoggerFunc(
+ func(classification logging.Classification, format string, v ...interface{}) {
+ fmt.Fprintf(&logged, format, v...)
+ },
+ )
+
+ stack := middleware.NewStack("test", smithyhttp.NewStackRequest)
+ middleware.AddSetLoggerMiddleware(stack, logger)
+
+ //------------------------------
+ // Build handler
+ //------------------------------
+ // On return path validate any errors were expected.
+ stack.Build.Add(middleware.BuildMiddlewareFunc(
+ "build-assert",
+ func(ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler) (
+ out middleware.BuildOutput, metadata middleware.Metadata, err error,
+ ) {
+ // ignore initial build input for the test case's build input.
+ out, metadata, err = next.HandleBuild(ctx, c.buildInput)
+ if err == nil && c.expectBuildErr {
+ t.Fatalf("expect build error, got none")
+ }
+
+ if !m.buildHandlerRun {
+ t.Fatalf("expect build handler run")
+ }
+ return out, metadata, err
+ },
+ ), middleware.After)
+
+ // Build middleware
+ stack.Build.Add(m, middleware.After)
+
+ // Validate defer to finalize was performed as expected
+ stack.Build.Add(middleware.BuildMiddlewareFunc(
+ "assert-defer-to-finalize",
+ func(ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler) (
+ out middleware.BuildOutput, metadata middleware.Metadata, err error,
+ ) {
+ if e, a := c.expectDeferToFinalize, m.deferToFinalizeHandler; e != a {
+ t.Fatalf("expect %v defer to finalize, got %v", e, a)
+ }
+ return next.HandleBuild(ctx, input)
+ },
+ ), middleware.After)
+
+ //------------------------------
+ // Finalize handler
+ //------------------------------
+ if m.EnableTrailingChecksum {
+ // On return path assert any errors are expected.
+ stack.Finalize.Add(middleware.FinalizeMiddlewareFunc(
+ "build-assert",
+ func(ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler) (
+ out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
+ ) {
+ out, metadata, err = next.HandleFinalize(ctx, input)
+ if err == nil && c.expectFinalizeErr {
+ t.Fatalf("expect finalize error, got none")
+ }
+
+ return out, metadata, err
+ },
+ ), middleware.After)
+
+ // Add finalize middleware
+ stack.Finalize.Add(m, middleware.After)
+ }
+
+ //------------------------------
+ // Request validation
+ //------------------------------
+ validateRequestHandler := middleware.HandlerFunc(
+ func(ctx context.Context, input interface{}) (
+ output interface{}, metadata middleware.Metadata, err error,
+ ) {
+ request := input.(*smithyhttp.Request)
+
+ if diff := cmp.Diff(c.expectHeader, request.Header); diff != "" {
+ t.Errorf("expect header to match:\n%s", diff)
+ }
+ if e, a := c.expectContentLength, request.ContentLength; e != a {
+ t.Errorf("expect %v content length, got %v", e, a)
+ }
+
+ stream := request.GetStream()
+ if e, a := stream != nil, c.expectPayload != nil; e != a {
+ t.Fatalf("expect nil payload %t, got %t", e, a)
+ }
+ if stream == nil {
+ return
+ }
+
+ actualPayload, err := ioutil.ReadAll(stream)
+ if err == nil && c.expectReadErr {
+ t.Fatalf("expected read error, got none")
+ }
+
+ if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" {
+ t.Errorf("expect payload match:\n%s", diff)
+ }
+
+ payloadHash := v4.GetPayloadHash(ctx)
+ if e, a := c.expectPayloadHash, payloadHash; e != a {
+ t.Errorf("expect %v payload hash, got %v", e, a)
+ }
+
+ return &smithyhttp.Response{}, metadata, nil
+ },
+ )
+
+ if c.initContext != nil {
+ ctx = c.initContext(ctx)
+ }
+ _, metadata, err := stack.HandleMiddleware(ctx, struct{}{}, validateRequestHandler)
+ if err == nil && len(c.expectErr) != 0 {
+ t.Fatalf("expected error: %v, got none", c.expectErr)
+ }
+ if err != nil && len(c.expectErr) == 0 {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.expectErr) {
+ t.Fatalf("expected error %v to contain %v", err, c.expectErr)
+ }
+ if c.expectErr != "" {
+ return
+ }
+
+ if c.expectLogged != "" {
+ if e, a := c.expectLogged, logged.String(); !strings.Contains(a, e) {
+ t.Errorf("expected %q logged in:\n%s", e, a)
+ }
+ }
+
+ // assert computed input checksums metadata
+ computedMetadata, ok := GetComputedInputChecksums(metadata)
+ if e, a := ok, (c.expectChecksumMetadata != nil); e != a {
+ t.Fatalf("expect checksum metadata %t, got %t, %v", e, a, computedMetadata)
+ }
+ if c.expectChecksumMetadata != nil {
+ if diff := cmp.Diff(c.expectChecksumMetadata, computedMetadata); diff != "" {
+ t.Errorf("expect checksum metadata match\n%s", diff)
+ }
+ }
+ })
+ }
+ })
+ }
+}
+
+type mockReadSeeker struct {
+ io.Reader
+}
+
+func (r *mockReadSeeker) Seek(int64, int) (int64, error) {
+ return 0, nil
+}
+
+type errSeekReader struct {
+ io.Reader
+}
+
+func (r *errSeekReader) Seek(offset int64, whence int) (int64, error) {
+ if whence == io.SeekCurrent {
+ return 0, nil
+ }
+
+ return 0, fmt.Errorf("seek failed")
+}
+
+func requestMust(r *smithyhttp.Request, err error) *smithyhttp.Request {
+ if err != nil {
+ panic(err.Error())
+ }
+
+ return r
+}
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context.go
new file mode 100644
index 0000000000..f729525497
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context.go
@@ -0,0 +1,117 @@
+package checksum
+
+import (
+ "context"
+
+ "github.com/aws/smithy-go/middleware"
+)
+
+// setupChecksumContext is the initial middleware that looks up the input
+// used to configure checksum behavior. This middleware must be executed before
+// input validation step or any other checksum middleware.
+type setupInputContext struct {
+ // GetAlgorithm is a function to get the checksum algorithm of the
+ // input payload from the input parameters.
+ //
+ // Given the input parameter value, the function must return the algorithm
+ // and true, or false if no algorithm is specified.
+ GetAlgorithm func(interface{}) (string, bool)
+}
+
+// ID for the middleware
+func (m *setupInputContext) ID() string {
+ return "AWSChecksum:SetupInputContext"
+}
+
+// HandleInitialize initialization middleware that setups up the checksum
+// context based on the input parameters provided in the stack.
+func (m *setupInputContext) HandleInitialize(
+ ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler,
+) (
+ out middleware.InitializeOutput, metadata middleware.Metadata, err error,
+) {
+ // Check if validation algorithm is specified.
+ if m.GetAlgorithm != nil {
+ // check is input resource has a checksum algorithm
+ algorithm, ok := m.GetAlgorithm(in.Parameters)
+ if ok && len(algorithm) != 0 {
+ ctx = setContextInputAlgorithm(ctx, algorithm)
+ }
+ }
+
+ return next.HandleInitialize(ctx, in)
+}
+
+// inputAlgorithmKey is the key set on context used to identify, retrieves the
+// request checksum algorithm if present on the context.
+type inputAlgorithmKey struct{}
+
+// setContextInputAlgorithm sets the request checksum algorithm on the context.
+//
+// Scoped to stack values.
+func setContextInputAlgorithm(ctx context.Context, value string) context.Context {
+ return middleware.WithStackValue(ctx, inputAlgorithmKey{}, value)
+}
+
+// getContextInputAlgorithm returns the checksum algorithm from the context if
+// one was specified. Empty string is returned if one is not specified.
+//
+// Scoped to stack values.
+func getContextInputAlgorithm(ctx context.Context) (v string) {
+ v, _ = middleware.GetStackValue(ctx, inputAlgorithmKey{}).(string)
+ return v
+}
+
+type setupOutputContext struct {
+ // GetValidationMode is a function to get the checksum validation
+ // mode of the output payload from the input parameters.
+ //
+ // Given the input parameter value, the function must return the validation
+ // mode and true, or false if no mode is specified.
+ GetValidationMode func(interface{}) (string, bool)
+}
+
+// ID for the middleware
+func (m *setupOutputContext) ID() string {
+ return "AWSChecksum:SetupOutputContext"
+}
+
+// HandleInitialize initialization middleware that setups up the checksum
+// context based on the input parameters provided in the stack.
+func (m *setupOutputContext) HandleInitialize(
+ ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler,
+) (
+ out middleware.InitializeOutput, metadata middleware.Metadata, err error,
+) {
+ // Check if validation mode is specified.
+ if m.GetValidationMode != nil {
+ // check is input resource has a checksum algorithm
+ mode, ok := m.GetValidationMode(in.Parameters)
+ if ok && len(mode) != 0 {
+ ctx = setContextOutputValidationMode(ctx, mode)
+ }
+ }
+
+ return next.HandleInitialize(ctx, in)
+}
+
+// outputValidationModeKey is the key set on context used to identify if
+// output checksum validation is enabled.
+type outputValidationModeKey struct{}
+
+// setContextOutputValidationMode sets the request checksum
+// algorithm on the context.
+//
+// Scoped to stack values.
+func setContextOutputValidationMode(ctx context.Context, value string) context.Context {
+ return middleware.WithStackValue(ctx, outputValidationModeKey{}, value)
+}
+
+// getContextOutputValidationMode returns response checksum validation state,
+// if one was specified. Empty string is returned if one is not specified.
+//
+// Scoped to stack values.
+func getContextOutputValidationMode(ctx context.Context) (v string) {
+ v, _ = middleware.GetStackValue(ctx, outputValidationModeKey{}).(string)
+ return v
+}
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context_test.go
new file mode 100644
index 0000000000..3235983bad
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context_test.go
@@ -0,0 +1,143 @@
+//go:build go1.16
+// +build go1.16
+
+package checksum
+
+import (
+ "context"
+ "testing"
+
+ "github.com/aws/smithy-go/middleware"
+)
+
+func TestSetupInput(t *testing.T) {
+ type Params struct {
+ Value string
+ }
+
+ cases := map[string]struct {
+ inputParams interface{}
+ getAlgorithm func(interface{}) (string, bool)
+ expectValue string
+ }{
+ "nil accessor": {
+ expectValue: "",
+ },
+ "found empty": {
+ inputParams: Params{Value: ""},
+ getAlgorithm: func(v interface{}) (string, bool) {
+ vv := v.(Params)
+ return vv.Value, true
+ },
+ expectValue: "",
+ },
+ "found not set": {
+ inputParams: Params{Value: ""},
+ getAlgorithm: func(v interface{}) (string, bool) {
+ return "", false
+ },
+ expectValue: "",
+ },
+ "found": {
+ inputParams: Params{Value: "abc123"},
+ getAlgorithm: func(v interface{}) (string, bool) {
+ vv := v.(Params)
+ return vv.Value, true
+ },
+ expectValue: "abc123",
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ m := setupInputContext{
+ GetAlgorithm: c.getAlgorithm,
+ }
+
+ _, _, err := m.HandleInitialize(context.Background(),
+ middleware.InitializeInput{Parameters: c.inputParams},
+ middleware.InitializeHandlerFunc(
+ func(ctx context.Context, input middleware.InitializeInput) (
+ out middleware.InitializeOutput, metadata middleware.Metadata, err error,
+ ) {
+ v := getContextInputAlgorithm(ctx)
+ if e, a := c.expectValue, v; e != a {
+ t.Errorf("expect value %v, got %v", e, a)
+ }
+
+ return out, metadata, nil
+ },
+ ))
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+
+ })
+ }
+}
+
+func TestSetupOutput(t *testing.T) {
+ type Params struct {
+ Value string
+ }
+
+ cases := map[string]struct {
+ inputParams interface{}
+ getValidationMode func(interface{}) (string, bool)
+ expectValue string
+ }{
+ "nil accessor": {
+ expectValue: "",
+ },
+ "found empty": {
+ inputParams: Params{Value: ""},
+ getValidationMode: func(v interface{}) (string, bool) {
+ vv := v.(Params)
+ return vv.Value, true
+ },
+ expectValue: "",
+ },
+ "found not set": {
+ inputParams: Params{Value: ""},
+ getValidationMode: func(v interface{}) (string, bool) {
+ return "", false
+ },
+ expectValue: "",
+ },
+ "found": {
+ inputParams: Params{Value: "abc123"},
+ getValidationMode: func(v interface{}) (string, bool) {
+ vv := v.(Params)
+ return vv.Value, true
+ },
+ expectValue: "abc123",
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ m := setupOutputContext{
+ GetValidationMode: c.getValidationMode,
+ }
+
+ _, _, err := m.HandleInitialize(context.Background(),
+ middleware.InitializeInput{Parameters: c.inputParams},
+ middleware.InitializeHandlerFunc(
+ func(ctx context.Context, input middleware.InitializeInput) (
+ out middleware.InitializeOutput, metadata middleware.Metadata, err error,
+ ) {
+ v := getContextOutputValidationMode(ctx)
+ if e, a := c.expectValue, v; e != a {
+ t.Errorf("expect value %v, got %v", e, a)
+ }
+
+ return out, metadata, nil
+ },
+ ))
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+
+ })
+ }
+}
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output.go
new file mode 100644
index 0000000000..9fde12d86d
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output.go
@@ -0,0 +1,131 @@
+package checksum
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "github.com/aws/smithy-go"
+ "github.com/aws/smithy-go/logging"
+ "github.com/aws/smithy-go/middleware"
+ smithyhttp "github.com/aws/smithy-go/transport/http"
+)
+
+// outputValidationAlgorithmsUsedKey is the metadata key for indexing the algorithms
+// that were used, by the middleware's validation.
+type outputValidationAlgorithmsUsedKey struct{}
+
+// GetOutputValidationAlgorithmsUsed returns the checksum algorithms used
+// stored in the middleware Metadata. Returns false if no algorithms were
+// stored in the Metadata.
+func GetOutputValidationAlgorithmsUsed(m middleware.Metadata) ([]string, bool) {
+ vs, ok := m.Get(outputValidationAlgorithmsUsedKey{}).([]string)
+ return vs, ok
+}
+
+// SetOutputValidationAlgorithmsUsed stores the checksum algorithms used in the
+// middleware Metadata.
+func SetOutputValidationAlgorithmsUsed(m *middleware.Metadata, vs []string) {
+ m.Set(outputValidationAlgorithmsUsedKey{}, vs)
+}
+
+// validateOutputPayloadChecksum middleware computes payload checksum of the
+// received response and validates with checksum returned by the service.
+type validateOutputPayloadChecksum struct {
+ // Algorithms represents a priority-ordered list of valid checksum
+ // algorithm that should be validated when present in HTTP response
+ // headers.
+ Algorithms []Algorithm
+
+ // IgnoreMultipartValidation indicates multipart checksums ending with "-#"
+ // will be ignored.
+ IgnoreMultipartValidation bool
+
+ // When set the middleware will log when output does not have checksum or
+ // algorithm to validate.
+ LogValidationSkipped bool
+
+ // When set the middleware will log when the output contains a multipart
+ // checksum that was, skipped and not validated.
+ LogMultipartValidationSkipped bool
+}
+
+func (m *validateOutputPayloadChecksum) ID() string {
+ return "AWSChecksum:ValidateOutputPayloadChecksum"
+}
+
+// HandleDeserialize is a Deserialize middleware that wraps the HTTP response
+// body with an io.ReadCloser that will validate the its checksum.
+func (m *validateOutputPayloadChecksum) HandleDeserialize(
+ ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
+) (
+ out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
+) {
+ out, metadata, err = next.HandleDeserialize(ctx, in)
+ if err != nil {
+ return out, metadata, err
+ }
+
+ // If there is no validation mode specified nothing is supported.
+ if mode := getContextOutputValidationMode(ctx); mode != "ENABLED" {
+ return out, metadata, err
+ }
+
+ response, ok := out.RawResponse.(*smithyhttp.Response)
+ if !ok {
+ return out, metadata, &smithy.DeserializationError{
+ Err: fmt.Errorf("unknown transport type %T", out.RawResponse),
+ }
+ }
+
+ var expectedChecksum string
+ var algorithmToUse Algorithm
+ for _, algorithm := range m.Algorithms {
+ value := response.Header.Get(AlgorithmHTTPHeader(algorithm))
+ if len(value) == 0 {
+ continue
+ }
+
+ expectedChecksum = value
+ algorithmToUse = algorithm
+ }
+
+ // TODO this must validate the validation mode is set to enabled.
+
+ logger := middleware.GetLogger(ctx)
+
+ // Skip validation if no checksum algorithm or checksum is available.
+ if len(expectedChecksum) == 0 || len(algorithmToUse) == 0 {
+ if m.LogValidationSkipped {
+ // TODO this probably should have more information about the
+ // operation output that won't be validated.
+ logger.Logf(logging.Warn,
+ "Response has no supported checksum. Not validating response payload.")
+ }
+ return out, metadata, nil
+ }
+
+ // Ignore multipart validation
+ if m.IgnoreMultipartValidation && strings.Contains(expectedChecksum, "-") {
+ if m.LogMultipartValidationSkipped {
+ // TODO this probably should have more information about the
+ // operation output that won't be validated.
+ logger.Logf(logging.Warn, "Skipped validation of multipart checksum.")
+ }
+ return out, metadata, nil
+ }
+
+ body, err := newValidateChecksumReader(response.Body, algorithmToUse, expectedChecksum)
+ if err != nil {
+ return out, metadata, fmt.Errorf("failed to create checksum validation reader, %w", err)
+ }
+ response.Body = body
+
+ // Update the metadata to include the set of the checksum algorithms that
+ // will be validated.
+ SetOutputValidationAlgorithmsUsed(&metadata, []string{
+ string(algorithmToUse),
+ })
+
+ return out, metadata, nil
+}
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output_test.go
new file mode 100644
index 0000000000..e6cf7054e8
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output_test.go
@@ -0,0 +1,263 @@
+//go:build go1.16
+// +build go1.16
+
+package checksum
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "strings"
+ "testing"
+ "testing/iotest"
+
+ "github.com/aws/smithy-go/logging"
+ "github.com/aws/smithy-go/middleware"
+ smithyhttp "github.com/aws/smithy-go/transport/http"
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestValidateOutputPayloadChecksum(t *testing.T) {
+ cases := map[string]struct {
+ response *smithyhttp.Response
+ validateOptions func(*validateOutputPayloadChecksum)
+ modifyContext func(context.Context) context.Context
+ expectHaveAlgorithmsUsed bool
+ expectAlgorithmsUsed []string
+ expectErr string
+ expectReadErr string
+ expectLogged string
+ expectPayload []byte
+ }{
+ "success": {
+ modifyContext: func(ctx context.Context) context.Context {
+ return setContextOutputValidationMode(ctx, "ENABLED")
+ },
+ response: &smithyhttp.Response{
+ Response: &http.Response{
+ StatusCode: 200,
+ Header: func() http.Header {
+ h := http.Header{}
+ h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "DUoRhQ==")
+ return h
+ }(),
+ Body: ioutil.NopCloser(strings.NewReader("hello world")),
+ },
+ },
+ expectHaveAlgorithmsUsed: true,
+ expectAlgorithmsUsed: []string{"CRC32"},
+ expectPayload: []byte("hello world"),
+ },
+ "failure": {
+ modifyContext: func(ctx context.Context) context.Context {
+ return setContextOutputValidationMode(ctx, "ENABLED")
+ },
+ response: &smithyhttp.Response{
+ Response: &http.Response{
+ StatusCode: 200,
+ Header: func() http.Header {
+ h := http.Header{}
+ h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "AAAAAA==")
+ return h
+ }(),
+ Body: ioutil.NopCloser(strings.NewReader("hello world")),
+ },
+ },
+ expectReadErr: "checksum did not match",
+ },
+ "read error": {
+ modifyContext: func(ctx context.Context) context.Context {
+ return setContextOutputValidationMode(ctx, "ENABLED")
+ },
+ response: &smithyhttp.Response{
+ Response: &http.Response{
+ StatusCode: 200,
+ Header: func() http.Header {
+ h := http.Header{}
+ h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "AAAAAA==")
+ return h
+ }(),
+ Body: ioutil.NopCloser(iotest.ErrReader(fmt.Errorf("some read error"))),
+ },
+ },
+ expectReadErr: "some read error",
+ },
+ "unsupported algorithm": {
+ modifyContext: func(ctx context.Context) context.Context {
+ return setContextOutputValidationMode(ctx, "ENABLED")
+ },
+ response: &smithyhttp.Response{
+ Response: &http.Response{
+ StatusCode: 200,
+ Header: func() http.Header {
+ h := http.Header{}
+ h.Set(AlgorithmHTTPHeader("unsupported"), "AAAAAA==")
+ return h
+ }(),
+ Body: ioutil.NopCloser(strings.NewReader("hello world")),
+ },
+ },
+ expectLogged: "no supported checksum",
+ expectPayload: []byte("hello world"),
+ },
+ "no output validation model": {
+ response: &smithyhttp.Response{
+ Response: &http.Response{
+ StatusCode: 200,
+ Header: func() http.Header {
+ h := http.Header{}
+ return h
+ }(),
+ Body: ioutil.NopCloser(strings.NewReader("hello world")),
+ },
+ },
+ expectPayload: []byte("hello world"),
+ },
+ "unknown output validation model": {
+ modifyContext: func(ctx context.Context) context.Context {
+ return setContextOutputValidationMode(ctx, "something else")
+ },
+ response: &smithyhttp.Response{
+ Response: &http.Response{
+ StatusCode: 200,
+ Header: func() http.Header {
+ h := http.Header{}
+ return h
+ }(),
+ Body: ioutil.NopCloser(strings.NewReader("hello world")),
+ },
+ },
+ expectPayload: []byte("hello world"),
+ },
+ "success ignore multipart checksum": {
+ modifyContext: func(ctx context.Context) context.Context {
+ return setContextOutputValidationMode(ctx, "ENABLED")
+ },
+ response: &smithyhttp.Response{
+ Response: &http.Response{
+ StatusCode: 200,
+ Header: func() http.Header {
+ h := http.Header{}
+ h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "DUoRhQ==")
+ return h
+ }(),
+ Body: ioutil.NopCloser(strings.NewReader("hello world")),
+ },
+ },
+ validateOptions: func(o *validateOutputPayloadChecksum) {
+ o.IgnoreMultipartValidation = true
+ },
+ expectHaveAlgorithmsUsed: true,
+ expectAlgorithmsUsed: []string{"CRC32"},
+ expectPayload: []byte("hello world"),
+ },
+ "success skip ignore multipart checksum": {
+ modifyContext: func(ctx context.Context) context.Context {
+ return setContextOutputValidationMode(ctx, "ENABLED")
+ },
+ response: &smithyhttp.Response{
+ Response: &http.Response{
+ StatusCode: 200,
+ Header: func() http.Header {
+ h := http.Header{}
+ h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "DUoRhQ==-12")
+ return h
+ }(),
+ Body: ioutil.NopCloser(strings.NewReader("hello world")),
+ },
+ },
+ validateOptions: func(o *validateOutputPayloadChecksum) {
+ o.IgnoreMultipartValidation = true
+ },
+ expectLogged: "Skipped validation of multipart checksum",
+ expectPayload: []byte("hello world"),
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ var logged bytes.Buffer
+ ctx := middleware.SetLogger(context.Background(), logging.LoggerFunc(
+ func(classification logging.Classification, format string, v ...interface{}) {
+ fmt.Fprintf(&logged, format, v...)
+ }))
+
+ if c.modifyContext != nil {
+ ctx = c.modifyContext(ctx)
+ }
+
+ validateOutput := validateOutputPayloadChecksum{
+ Algorithms: []Algorithm{
+ AlgorithmSHA1, AlgorithmCRC32, AlgorithmCRC32C,
+ },
+ LogValidationSkipped: true,
+ LogMultipartValidationSkipped: true,
+ }
+ if c.validateOptions != nil {
+ c.validateOptions(&validateOutput)
+ }
+
+ out, meta, err := validateOutput.HandleDeserialize(ctx,
+ middleware.DeserializeInput{},
+ middleware.DeserializeHandlerFunc(
+ func(ctx context.Context, input middleware.DeserializeInput) (
+ out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
+ ) {
+ out.RawResponse = c.response
+ return out, metadata, nil
+ },
+ ),
+ )
+ if err == nil && len(c.expectErr) != 0 {
+ t.Fatalf("expect error %v, got none", c.expectErr)
+ }
+ if err != nil && len(c.expectErr) == 0 {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.expectErr) {
+ t.Fatalf("expect error to contain %v, got %v", c.expectErr, err)
+ }
+ if c.expectErr != "" {
+ return
+ }
+
+ response := out.RawResponse.(*smithyhttp.Response)
+
+ actualPayload, err := ioutil.ReadAll(response.Body)
+ if err == nil && len(c.expectReadErr) != 0 {
+ t.Fatalf("expected read error: %v, got none", c.expectReadErr)
+ }
+ if err != nil && len(c.expectReadErr) == 0 {
+ t.Fatalf("expect no read error, got %v", err)
+ }
+ if err != nil && !strings.Contains(err.Error(), c.expectReadErr) {
+ t.Fatalf("expected read error %v to contain %v", err, c.expectReadErr)
+ }
+ if c.expectReadErr != "" {
+ return
+ }
+
+ if e, a := c.expectLogged, logged.String(); !strings.Contains(a, e) || !((e == "") == (a == "")) {
+ t.Errorf("expected %q logged in:\n%s", e, a)
+ }
+
+ if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" {
+ t.Errorf("expect payload match:\n%s", diff)
+ }
+
+ if err = response.Body.Close(); err != nil {
+ t.Errorf("expect no close error, got %v", err)
+ }
+
+ values, ok := GetOutputValidationAlgorithmsUsed(meta)
+ if ok != c.expectHaveAlgorithmsUsed {
+ t.Errorf("expect metadata to contain algorithms used, %t", c.expectHaveAlgorithmsUsed)
+ }
+ if diff := cmp.Diff(c.expectAlgorithmsUsed, values); diff != "" {
+ t.Errorf("expect algorithms used to match\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/ya.make
new file mode 100644
index 0000000000..85e34f039f
--- /dev/null
+++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/ya.make
@@ -0,0 +1,28 @@
+GO_LIBRARY()
+
+LICENSE(Apache-2.0)
+
+SRCS(
+ algorithms.go
+ aws_chunked_encoding.go
+ go_module_metadata.go
+ middleware_add.go
+ middleware_compute_input_checksum.go
+ middleware_setup_context.go
+ middleware_validate_output.go
+)
+
+GO_TEST_SRCS(
+ algorithms_test.go
+ aws_chunked_encoding_test.go
+ middleware_add_test.go
+ middleware_compute_input_checksum_test.go
+ middleware_setup_context_test.go
+ middleware_validate_output_test.go
+)
+
+END()
+
+RECURSE(
+ gotest
+)