diff options
author | vitalyisaev <vitalyisaev@ydb.tech> | 2023-12-12 21:55:07 +0300 |
---|---|---|
committer | vitalyisaev <vitalyisaev@ydb.tech> | 2023-12-12 22:25:10 +0300 |
commit | 4967f99474a4040ba150eb04995de06342252718 (patch) | |
tree | c9c118836513a8fab6e9fcfb25be5d404338bca7 /vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum | |
parent | 2ce9cccb9b0bdd4cd7a3491dc5cbf8687cda51de (diff) | |
download | ydb-4967f99474a4040ba150eb04995de06342252718.tar.gz |
YQ Connector: prepare code base for S3 integration
1. Кодовая база Коннектора переписана с помощью Go дженериков так, чтобы добавление нового источника данных (в частности S3 + csv) максимально переиспользовало имеющийся код (чтобы сохранялась логика нарезания на блоки данных, учёт трафика и пр.)
2. API Connector расширено для работы с S3, но ещё пока не протестировано.
Diffstat (limited to 'vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum')
15 files changed, 4416 insertions, 0 deletions
diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms.go new file mode 100644 index 0000000000..a17041c35d --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms.go @@ -0,0 +1,323 @@ +package checksum + +import ( + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "hash" + "hash/crc32" + "io" + "strings" + "sync" +) + +// Algorithm represents the checksum algorithms supported +type Algorithm string + +// Enumeration values for supported checksum Algorithms. +const ( + // AlgorithmCRC32C represents CRC32C hash algorithm + AlgorithmCRC32C Algorithm = "CRC32C" + + // AlgorithmCRC32 represents CRC32 hash algorithm + AlgorithmCRC32 Algorithm = "CRC32" + + // AlgorithmSHA1 represents SHA1 hash algorithm + AlgorithmSHA1 Algorithm = "SHA1" + + // AlgorithmSHA256 represents SHA256 hash algorithm + AlgorithmSHA256 Algorithm = "SHA256" +) + +var supportedAlgorithms = []Algorithm{ + AlgorithmCRC32C, + AlgorithmCRC32, + AlgorithmSHA1, + AlgorithmSHA256, +} + +func (a Algorithm) String() string { return string(a) } + +// ParseAlgorithm attempts to parse the provided value into a checksum +// algorithm, matching without case. Returns the algorithm matched, or an error +// if the algorithm wasn't matched. +func ParseAlgorithm(v string) (Algorithm, error) { + for _, a := range supportedAlgorithms { + if strings.EqualFold(string(a), v) { + return a, nil + } + } + return "", fmt.Errorf("unknown checksum algorithm, %v", v) +} + +// FilterSupportedAlgorithms filters the set of algorithms, returning a slice +// of algorithms that are supported. +func FilterSupportedAlgorithms(vs []string) []Algorithm { + found := map[Algorithm]struct{}{} + + supported := make([]Algorithm, 0, len(supportedAlgorithms)) + for _, v := range vs { + for _, a := range supportedAlgorithms { + // Only consider algorithms that are supported + if !strings.EqualFold(v, string(a)) { + continue + } + // Ignore duplicate algorithms in list. + if _, ok := found[a]; ok { + continue + } + + supported = append(supported, a) + found[a] = struct{}{} + } + } + return supported +} + +// NewAlgorithmHash returns a hash.Hash for the checksum algorithm. Error is +// returned if the algorithm is unknown. +func NewAlgorithmHash(v Algorithm) (hash.Hash, error) { + switch v { + case AlgorithmSHA1: + return sha1.New(), nil + case AlgorithmSHA256: + return sha256.New(), nil + case AlgorithmCRC32: + return crc32.NewIEEE(), nil + case AlgorithmCRC32C: + return crc32.New(crc32.MakeTable(crc32.Castagnoli)), nil + default: + return nil, fmt.Errorf("unknown checksum algorithm, %v", v) + } +} + +// AlgorithmChecksumLength returns the length of the algorithm's checksum in +// bytes. If the algorithm is not known, an error is returned. +func AlgorithmChecksumLength(v Algorithm) (int, error) { + switch v { + case AlgorithmSHA1: + return sha1.Size, nil + case AlgorithmSHA256: + return sha256.Size, nil + case AlgorithmCRC32: + return crc32.Size, nil + case AlgorithmCRC32C: + return crc32.Size, nil + default: + return 0, fmt.Errorf("unknown checksum algorithm, %v", v) + } +} + +const awsChecksumHeaderPrefix = "x-amz-checksum-" + +// AlgorithmHTTPHeader returns the HTTP header for the algorithm's hash. +func AlgorithmHTTPHeader(v Algorithm) string { + return awsChecksumHeaderPrefix + strings.ToLower(string(v)) +} + +// base64EncodeHashSum computes base64 encoded checksum of a given running +// hash. The running hash must already have content written to it. Returns the +// byte slice of checksum and an error +func base64EncodeHashSum(h hash.Hash) []byte { + sum := h.Sum(nil) + sum64 := make([]byte, base64.StdEncoding.EncodedLen(len(sum))) + base64.StdEncoding.Encode(sum64, sum) + return sum64 +} + +// hexEncodeHashSum computes hex encoded checksum of a given running hash. The +// running hash must already have content written to it. Returns the byte slice +// of checksum and an error +func hexEncodeHashSum(h hash.Hash) []byte { + sum := h.Sum(nil) + sumHex := make([]byte, hex.EncodedLen(len(sum))) + hex.Encode(sumHex, sum) + return sumHex +} + +// computeMD5Checksum computes base64 MD5 checksum of an io.Reader's contents. +// Returns the byte slice of MD5 checksum and an error. +func computeMD5Checksum(r io.Reader) ([]byte, error) { + h := md5.New() + + // Copy errors may be assumed to be from the body. + if _, err := io.Copy(h, r); err != nil { + return nil, fmt.Errorf("failed compute MD5 hash of reader, %w", err) + } + + // Encode the MD5 checksum in base64. + return base64EncodeHashSum(h), nil +} + +// computeChecksumReader provides a reader wrapping an underlying io.Reader to +// compute the checksum of the stream's bytes. +type computeChecksumReader struct { + stream io.Reader + algorithm Algorithm + hasher hash.Hash + base64ChecksumLen int + + mux sync.RWMutex + lockedChecksum string + lockedErr error +} + +// newComputeChecksumReader returns a computeChecksumReader for the stream and +// algorithm specified. Returns error if unable to create the reader, or +// algorithm is unknown. +func newComputeChecksumReader(stream io.Reader, algorithm Algorithm) (*computeChecksumReader, error) { + hasher, err := NewAlgorithmHash(algorithm) + if err != nil { + return nil, err + } + + checksumLength, err := AlgorithmChecksumLength(algorithm) + if err != nil { + return nil, err + } + + return &computeChecksumReader{ + stream: io.TeeReader(stream, hasher), + algorithm: algorithm, + hasher: hasher, + base64ChecksumLen: base64.StdEncoding.EncodedLen(checksumLength), + }, nil +} + +// Read wraps the underlying reader. When the underlying reader returns EOF, +// the checksum of the reader will be computed, and can be retrieved with +// ChecksumBase64String. +func (r *computeChecksumReader) Read(p []byte) (int, error) { + n, err := r.stream.Read(p) + if err == nil { + return n, nil + } else if err != io.EOF { + r.mux.Lock() + defer r.mux.Unlock() + + r.lockedErr = err + return n, err + } + + b := base64EncodeHashSum(r.hasher) + + r.mux.Lock() + defer r.mux.Unlock() + + r.lockedChecksum = string(b) + + return n, err +} + +func (r *computeChecksumReader) Algorithm() Algorithm { + return r.algorithm +} + +// Base64ChecksumLength returns the base64 encoded length of the checksum for +// algorithm. +func (r *computeChecksumReader) Base64ChecksumLength() int { + return r.base64ChecksumLen +} + +// Base64Checksum returns the base64 checksum for the algorithm, or error if +// the underlying reader returned a non-EOF error. +// +// Safe to be called concurrently, but will return an error until after the +// underlying reader is returns EOF. +func (r *computeChecksumReader) Base64Checksum() (string, error) { + r.mux.RLock() + defer r.mux.RUnlock() + + if r.lockedErr != nil { + return "", r.lockedErr + } + + if r.lockedChecksum == "" { + return "", fmt.Errorf( + "checksum not available yet, called before reader returns EOF", + ) + } + + return r.lockedChecksum, nil +} + +// validateChecksumReader implements io.ReadCloser interface. The wrapper +// performs checksum validation when the underlying reader has been fully read. +type validateChecksumReader struct { + originalBody io.ReadCloser + body io.Reader + hasher hash.Hash + algorithm Algorithm + expectChecksum string +} + +// newValidateChecksumReader returns a configured io.ReadCloser that performs +// checksum validation when the underlying reader has been fully read. +func newValidateChecksumReader( + body io.ReadCloser, + algorithm Algorithm, + expectChecksum string, +) (*validateChecksumReader, error) { + hasher, err := NewAlgorithmHash(algorithm) + if err != nil { + return nil, err + } + + return &validateChecksumReader{ + originalBody: body, + body: io.TeeReader(body, hasher), + hasher: hasher, + algorithm: algorithm, + expectChecksum: expectChecksum, + }, nil +} + +// Read attempts to read from the underlying stream while also updating the +// running hash. If the underlying stream returns with an EOF error, the +// checksum of the stream will be collected, and compared against the expected +// checksum. If the checksums do not match, an error will be returned. +// +// If a non-EOF error occurs when reading the underlying stream, that error +// will be returned and the checksum for the stream will be discarded. +func (c *validateChecksumReader) Read(p []byte) (n int, err error) { + n, err = c.body.Read(p) + if err == io.EOF { + if checksumErr := c.validateChecksum(); checksumErr != nil { + return n, checksumErr + } + } + + return n, err +} + +// Close closes the underlying reader, returning any error that occurred in the +// underlying reader. +func (c *validateChecksumReader) Close() (err error) { + return c.originalBody.Close() +} + +func (c *validateChecksumReader) validateChecksum() error { + // Compute base64 encoded checksum hash of the payload's read bytes. + v := base64EncodeHashSum(c.hasher) + if e, a := c.expectChecksum, string(v); !strings.EqualFold(e, a) { + return validationError{ + Algorithm: c.algorithm, Expect: e, Actual: a, + } + } + + return nil +} + +type validationError struct { + Algorithm Algorithm + Expect string + Actual string +} + +func (v validationError) Error() string { + return fmt.Sprintf("checksum did not match: algorithm %v, expect %v, actual %v", + v.Algorithm, v.Expect, v.Actual) +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms_test.go new file mode 100644 index 0000000000..3f8b27018a --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms_test.go @@ -0,0 +1,470 @@ +//go:build go1.16 +// +build go1.16 + +package checksum + +import ( + "bytes" + "crypto/sha1" + "crypto/sha256" + "encoding/base64" + "fmt" + "hash/crc32" + "io" + "io/ioutil" + "strings" + "testing" + "testing/iotest" + + "github.com/google/go-cmp/cmp" +) + +func TestComputeChecksumReader(t *testing.T) { + cases := map[string]struct { + Input io.Reader + Algorithm Algorithm + ExpectErr string + ExpectChecksumLen int + + ExpectRead string + ExpectReadErr string + ExpectComputeErr string + ExpectChecksum string + }{ + "unknown algorithm": { + Input: bytes.NewBuffer(nil), + Algorithm: Algorithm("something"), + ExpectErr: "unknown checksum algorithm", + }, + "read error": { + Input: iotest.ErrReader(fmt.Errorf("some error")), + Algorithm: AlgorithmSHA256, + ExpectChecksumLen: base64.StdEncoding.EncodedLen(sha256.Size), + ExpectReadErr: "some error", + ExpectComputeErr: "some error", + }, + "crc32c": { + Input: strings.NewReader("hello world"), + Algorithm: AlgorithmCRC32C, + ExpectChecksumLen: base64.StdEncoding.EncodedLen(crc32.Size), + ExpectRead: "hello world", + ExpectChecksum: "yZRlqg==", + }, + "crc32": { + Input: strings.NewReader("hello world"), + Algorithm: AlgorithmCRC32, + ExpectChecksumLen: base64.StdEncoding.EncodedLen(crc32.Size), + ExpectRead: "hello world", + ExpectChecksum: "DUoRhQ==", + }, + "sha1": { + Input: strings.NewReader("hello world"), + Algorithm: AlgorithmSHA1, + ExpectChecksumLen: base64.StdEncoding.EncodedLen(sha1.Size), + ExpectRead: "hello world", + ExpectChecksum: "Kq5sNclPz7QV2+lfQIuc6R7oRu0=", + }, + "sha256": { + Input: strings.NewReader("hello world"), + Algorithm: AlgorithmSHA256, + ExpectChecksumLen: base64.StdEncoding.EncodedLen(sha256.Size), + ExpectRead: "hello world", + ExpectChecksum: "uU0nuZNNPgilLlLX2n2r+sSE7+N6U4DukIj3rOLvzek=", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + // Validate reader can be created as expected. + r, err := newComputeChecksumReader(c.Input, c.Algorithm) + if err == nil && len(c.ExpectErr) != 0 { + t.Fatalf("expect error %v, got none", c.ExpectErr) + } + if err != nil && len(c.ExpectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.ExpectErr) { + t.Fatalf("expect error to contain %v, got %v", c.ExpectErr, err) + } + if c.ExpectErr != "" { + return + } + + if e, a := c.Algorithm, r.Algorithm(); e != a { + t.Errorf("expect %v algorithm, got %v", e, a) + } + + // Validate expected checksum length. + if e, a := c.ExpectChecksumLen, r.Base64ChecksumLength(); e != a { + t.Errorf("expect %v checksum length, got %v", e, a) + } + + // Validate read reads underlying stream's bytes as expected. + b, err := ioutil.ReadAll(r) + if err == nil && len(c.ExpectReadErr) != 0 { + t.Fatalf("expect error %v, got none", c.ExpectReadErr) + } + if err != nil && len(c.ExpectReadErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.ExpectReadErr) { + t.Fatalf("expect error to contain %v, got %v", c.ExpectReadErr, err) + } + if len(c.ExpectReadErr) != 0 { + return + } + + if diff := cmp.Diff(string(c.ExpectRead), string(b)); diff != "" { + t.Errorf("expect read match, got\n%v", diff) + } + + // validate computed base64 + v, err := r.Base64Checksum() + if err == nil && len(c.ExpectComputeErr) != 0 { + t.Fatalf("expect error %v, got none", c.ExpectComputeErr) + } + if err != nil && len(c.ExpectComputeErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.ExpectComputeErr) { + t.Fatalf("expect error to contain %v, got %v", c.ExpectComputeErr, err) + } + if diff := cmp.Diff(c.ExpectChecksum, v); diff != "" { + t.Errorf("expect checksum match, got\n%v", diff) + } + if c.ExpectComputeErr != "" { + return + } + + if e, a := c.ExpectChecksumLen, len(v); e != a { + t.Errorf("expect computed checksum length %v, got %v", e, a) + } + }) + } +} + +func TestComputeChecksumReader_earlyGetChecksum(t *testing.T) { + r, err := newComputeChecksumReader(strings.NewReader("hello world"), AlgorithmCRC32C) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + v, err := r.Base64Checksum() + if err == nil { + t.Fatalf("expect error, got none") + } + if err != nil && !strings.Contains(err.Error(), "not available") { + t.Fatalf("expect error to match, got %v", err) + } + if v != "" { + t.Errorf("expect no checksum, got %v", err) + } +} + +// TODO race condition case with many reads, and get checksum + +func TestValidateChecksumReader(t *testing.T) { + cases := map[string]struct { + payload io.ReadCloser + algorithm Algorithm + checksum string + expectErr string + expectChecksumErr string + expectedBody []byte + }{ + "unknown algorithm": { + payload: ioutil.NopCloser(bytes.NewBuffer(nil)), + algorithm: Algorithm("something"), + expectErr: "unknown checksum algorithm", + }, + "empty body": { + payload: ioutil.NopCloser(bytes.NewReader([]byte(""))), + algorithm: AlgorithmCRC32, + checksum: "AAAAAA==", + expectedBody: []byte(""), + }, + "standard body": { + payload: ioutil.NopCloser(bytes.NewReader([]byte("Hello world"))), + algorithm: AlgorithmCRC32, + checksum: "i9aeUg==", + expectedBody: []byte("Hello world"), + }, + "checksum mismatch": { + payload: ioutil.NopCloser(bytes.NewReader([]byte("Hello world"))), + algorithm: AlgorithmCRC32, + checksum: "AAAAAA==", + expectedBody: []byte("Hello world"), + expectChecksumErr: "checksum did not match", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + response, err := newValidateChecksumReader(c.payload, c.algorithm, c.checksum) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if response == nil { + if c.expectedBody == nil { + return + } + t.Fatalf("expected non nil response, got nil") + } + + actualResponse, err := ioutil.ReadAll(response) + if err == nil && len(c.expectChecksumErr) != 0 { + t.Fatalf("expected error %v, got none", c.expectChecksumErr) + } + if err != nil && !strings.Contains(err.Error(), c.expectChecksumErr) { + t.Fatalf("expected error %v to contain %v", err.Error(), c.expectChecksumErr) + } + + if diff := cmp.Diff(c.expectedBody, actualResponse); len(diff) != 0 { + t.Fatalf("found diff comparing response body %v", diff) + } + + err = response.Close() + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + }) + } +} + +func TestComputeMD5Checksum(t *testing.T) { + cases := map[string]struct { + payload io.Reader + expectErr string + expectChecksum string + }{ + "empty payload": { + payload: strings.NewReader(""), + expectChecksum: "1B2M2Y8AsgTpgAmY7PhCfg==", + }, + "payload": { + payload: strings.NewReader("hello world"), + expectChecksum: "XrY7u+Ae7tCTyyK7j1rNww==", + }, + "error payload": { + payload: iotest.ErrReader(fmt.Errorf("some error")), + expectErr: "some error", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + actualChecksum, err := computeMD5Checksum(c.payload) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if e, a := c.expectChecksum, string(actualChecksum); !strings.EqualFold(e, a) { + t.Errorf("expect %v checksum, got %v", e, a) + } + }) + } +} + +func TestParseAlgorithm(t *testing.T) { + cases := map[string]struct { + Value string + expectAlgorithm Algorithm + expectErr string + }{ + "crc32c": { + Value: "crc32c", + expectAlgorithm: AlgorithmCRC32C, + }, + "CRC32C": { + Value: "CRC32C", + expectAlgorithm: AlgorithmCRC32C, + }, + "crc32": { + Value: "crc32", + expectAlgorithm: AlgorithmCRC32, + }, + "CRC32": { + Value: "CRC32", + expectAlgorithm: AlgorithmCRC32, + }, + "sha1": { + Value: "sha1", + expectAlgorithm: AlgorithmSHA1, + }, + "SHA1": { + Value: "SHA1", + expectAlgorithm: AlgorithmSHA1, + }, + "sha256": { + Value: "sha256", + expectAlgorithm: AlgorithmSHA256, + }, + "SHA256": { + Value: "SHA256", + expectAlgorithm: AlgorithmSHA256, + }, + "empty": { + Value: "", + expectErr: "unknown checksum algorithm", + }, + "unknown": { + Value: "unknown", + expectErr: "unknown checksum algorithm", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + // Asserts + algorithm, err := ParseAlgorithm(c.Value) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if e, a := c.expectAlgorithm, algorithm; e != a { + t.Errorf("expect %v algorithm, got %v", e, a) + } + }) + } +} + +func TestFilterSupportedAlgorithms(t *testing.T) { + cases := map[string]struct { + values []string + expectAlgorithms []Algorithm + }{ + "no algorithms": { + expectAlgorithms: []Algorithm{}, + }, + "no supported algorithms": { + values: []string{"abc", "123"}, + expectAlgorithms: []Algorithm{}, + }, + "duplicate algorithms": { + values: []string{"crc32", "crc32c", "crc32c"}, + expectAlgorithms: []Algorithm{ + AlgorithmCRC32, + AlgorithmCRC32C, + }, + }, + "preserve order": { + values: []string{"crc32", "crc32c", "sha1", "sha256"}, + expectAlgorithms: []Algorithm{ + AlgorithmCRC32, + AlgorithmCRC32C, + AlgorithmSHA1, + AlgorithmSHA256, + }, + }, + "preserve order 2": { + values: []string{"sha256", "crc32", "sha1", "crc32c"}, + expectAlgorithms: []Algorithm{ + AlgorithmSHA256, + AlgorithmCRC32, + AlgorithmSHA1, + AlgorithmCRC32C, + }, + }, + "mixed case": { + values: []string{"Crc32", "cRc32c", "shA1", "sHA256"}, + expectAlgorithms: []Algorithm{ + AlgorithmCRC32, + AlgorithmCRC32C, + AlgorithmSHA1, + AlgorithmSHA256, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + algorithms := FilterSupportedAlgorithms(c.values) + if diff := cmp.Diff(c.expectAlgorithms, algorithms); diff != "" { + t.Errorf("expect algorithms match\n%s", diff) + } + }) + } +} + +func TestAlgorithmChecksumLength(t *testing.T) { + cases := map[string]struct { + algorithm Algorithm + expectErr string + expectLength int + }{ + "empty": { + algorithm: "", + expectErr: "unknown checksum algorithm", + }, + "unknown": { + algorithm: "", + expectErr: "unknown checksum algorithm", + }, + "crc32": { + algorithm: AlgorithmCRC32, + expectLength: crc32.Size, + }, + "crc32c": { + algorithm: AlgorithmCRC32C, + expectLength: crc32.Size, + }, + "sha1": { + algorithm: AlgorithmSHA1, + expectLength: sha1.Size, + }, + "sha256": { + algorithm: AlgorithmSHA256, + expectLength: sha256.Size, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + l, err := AlgorithmChecksumLength(c.algorithm) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if e, a := c.expectLength, l; e != a { + t.Errorf("expect %v checksum length, got %v", e, a) + } + }) + } +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding.go new file mode 100644 index 0000000000..3bd320c437 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding.go @@ -0,0 +1,389 @@ +package checksum + +import ( + "bytes" + "fmt" + "io" + "strconv" + "strings" +) + +const ( + crlf = "\r\n" + + // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html + defaultChunkLength = 1024 * 64 + + awsTrailerHeaderName = "x-amz-trailer" + decodedContentLengthHeaderName = "x-amz-decoded-content-length" + + contentEncodingHeaderName = "content-encoding" + awsChunkedContentEncodingHeaderValue = "aws-chunked" + + trailerKeyValueSeparator = ":" +) + +var ( + crlfBytes = []byte(crlf) + finalChunkBytes = []byte("0" + crlf) +) + +type awsChunkedEncodingOptions struct { + // The total size of the stream. For unsigned encoding this implies that + // there will only be a single chunk containing the underlying payload, + // unless ChunkLength is also specified. + StreamLength int64 + + // Set of trailer key:value pairs that will be appended to the end of the + // payload after the end chunk has been written. + Trailers map[string]awsChunkedTrailerValue + + // The maximum size of each chunk to be sent. Default value of -1, signals + // that optimal chunk length will be used automatically. ChunkSize must be + // at least 8KB. + // + // If ChunkLength and StreamLength are both specified, the stream will be + // broken up into ChunkLength chunks. The encoded length of the aws-chunked + // encoding can still be determined as long as all trailers, if any, have a + // fixed length. + ChunkLength int +} + +type awsChunkedTrailerValue struct { + // Function to retrieve the value of the trailer. Will only be called after + // the underlying stream returns EOF error. + Get func() (string, error) + + // If the length of the value can be pre-determined, and is constant + // specify the length. A value of -1 means the length is unknown, or + // cannot be pre-determined. + Length int +} + +// awsChunkedEncoding provides a reader that wraps the payload such that +// payload is read as a single aws-chunk payload. This reader can only be used +// if the content length of payload is known. Content-Length is used as size of +// the single payload chunk. The final chunk and trailing checksum is appended +// at the end. +// +// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html#sigv4-chunked-body-definition +// +// Here is the aws-chunked payload stream as read from the awsChunkedEncoding +// if original request stream is "Hello world", and checksum hash used is SHA256 +// +// <b>\r\n +// Hello world\r\n +// 0\r\n +// x-amz-checksum-sha256:ZOyIygCyaOW6GjVnihtTFtIS9PNmskdyMlNKiuyjfzw=\r\n +// \r\n +type awsChunkedEncoding struct { + options awsChunkedEncodingOptions + + encodedStream io.Reader + trailerEncodedLength int +} + +// newUnsignedAWSChunkedEncoding returns a new awsChunkedEncoding configured +// for unsigned aws-chunked content encoding. Any additional trailers that need +// to be appended after the end chunk must be included as via Trailer +// callbacks. +func newUnsignedAWSChunkedEncoding( + stream io.Reader, + optFns ...func(*awsChunkedEncodingOptions), +) *awsChunkedEncoding { + options := awsChunkedEncodingOptions{ + Trailers: map[string]awsChunkedTrailerValue{}, + StreamLength: -1, + ChunkLength: -1, + } + for _, fn := range optFns { + fn(&options) + } + + var chunkReader io.Reader + if options.ChunkLength != -1 || options.StreamLength == -1 { + if options.ChunkLength == -1 { + options.ChunkLength = defaultChunkLength + } + chunkReader = newBufferedAWSChunkReader(stream, options.ChunkLength) + } else { + chunkReader = newUnsignedChunkReader(stream, options.StreamLength) + } + + trailerReader := newAWSChunkedTrailerReader(options.Trailers) + + return &awsChunkedEncoding{ + options: options, + encodedStream: io.MultiReader(chunkReader, + trailerReader, + bytes.NewBuffer(crlfBytes), + ), + trailerEncodedLength: trailerReader.EncodedLength(), + } +} + +// EncodedLength returns the final length of the aws-chunked content encoded +// stream if it can be determined without reading the underlying stream or lazy +// header values, otherwise -1 is returned. +func (e *awsChunkedEncoding) EncodedLength() int64 { + var length int64 + if e.options.StreamLength == -1 || e.trailerEncodedLength == -1 { + return -1 + } + + if e.options.StreamLength != 0 { + // If the stream length is known, and there is no chunk length specified, + // only a single chunk will be used. Otherwise the stream length needs to + // include the multiple chunk padding content. + if e.options.ChunkLength == -1 { + length += getUnsignedChunkBytesLength(e.options.StreamLength) + + } else { + // Compute chunk header and payload length + numChunks := e.options.StreamLength / int64(e.options.ChunkLength) + length += numChunks * getUnsignedChunkBytesLength(int64(e.options.ChunkLength)) + if remainder := e.options.StreamLength % int64(e.options.ChunkLength); remainder != 0 { + length += getUnsignedChunkBytesLength(remainder) + } + } + } + + // End chunk + length += int64(len(finalChunkBytes)) + + // Trailers + length += int64(e.trailerEncodedLength) + + // Encoding terminator + length += int64(len(crlf)) + + return length +} + +func getUnsignedChunkBytesLength(payloadLength int64) int64 { + payloadLengthStr := strconv.FormatInt(payloadLength, 16) + return int64(len(payloadLengthStr)) + int64(len(crlf)) + payloadLength + int64(len(crlf)) +} + +// HTTPHeaders returns the set of headers that must be included the request for +// aws-chunked to work. This includes the content-encoding: aws-chunked header. +// +// If there are multiple layered content encoding, the aws-chunked encoding +// must be appended to the previous layers the stream's encoding. The best way +// to do this is to append all header values returned to the HTTP request's set +// of headers. +func (e *awsChunkedEncoding) HTTPHeaders() map[string][]string { + headers := map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + } + + if len(e.options.Trailers) != 0 { + trailers := make([]string, 0, len(e.options.Trailers)) + for name := range e.options.Trailers { + trailers = append(trailers, strings.ToLower(name)) + } + headers[awsTrailerHeaderName] = trailers + } + + return headers +} + +func (e *awsChunkedEncoding) Read(b []byte) (n int, err error) { + return e.encodedStream.Read(b) +} + +// awsChunkedTrailerReader provides a lazy reader for reading of aws-chunked +// content encoded trailers. The trailer values will not be retrieved until the +// reader is read from. +type awsChunkedTrailerReader struct { + reader *bytes.Buffer + trailers map[string]awsChunkedTrailerValue + trailerEncodedLength int +} + +// newAWSChunkedTrailerReader returns an initialized awsChunkedTrailerReader to +// lazy reading aws-chunk content encoded trailers. +func newAWSChunkedTrailerReader(trailers map[string]awsChunkedTrailerValue) *awsChunkedTrailerReader { + return &awsChunkedTrailerReader{ + trailers: trailers, + trailerEncodedLength: trailerEncodedLength(trailers), + } +} + +func trailerEncodedLength(trailers map[string]awsChunkedTrailerValue) (length int) { + for name, trailer := range trailers { + length += len(name) + len(trailerKeyValueSeparator) + l := trailer.Length + if l == -1 { + return -1 + } + length += l + len(crlf) + } + + return length +} + +// EncodedLength returns the length of the encoded trailers if the length could +// be determined without retrieving the header values. Returns -1 if length is +// unknown. +func (r *awsChunkedTrailerReader) EncodedLength() (length int) { + return r.trailerEncodedLength +} + +// Read populates the passed in byte slice with bytes from the encoded +// trailers. Will lazy read header values first time Read is called. +func (r *awsChunkedTrailerReader) Read(p []byte) (int, error) { + if r.trailerEncodedLength == 0 { + return 0, io.EOF + } + + if r.reader == nil { + trailerLen := r.trailerEncodedLength + if r.trailerEncodedLength == -1 { + trailerLen = 0 + } + r.reader = bytes.NewBuffer(make([]byte, 0, trailerLen)) + for name, trailer := range r.trailers { + r.reader.WriteString(name) + r.reader.WriteString(trailerKeyValueSeparator) + v, err := trailer.Get() + if err != nil { + return 0, fmt.Errorf("failed to get trailer value, %w", err) + } + r.reader.WriteString(v) + r.reader.WriteString(crlf) + } + } + + return r.reader.Read(p) +} + +// newUnsignedChunkReader returns an io.Reader encoding the underlying reader +// as unsigned aws-chunked chunks. The returned reader will also include the +// end chunk, but not the aws-chunked final `crlf` segment so trailers can be +// added. +// +// If the payload size is -1 for unknown length the content will be buffered in +// defaultChunkLength chunks before wrapped in aws-chunked chunk encoding. +func newUnsignedChunkReader(reader io.Reader, payloadSize int64) io.Reader { + if payloadSize == -1 { + return newBufferedAWSChunkReader(reader, defaultChunkLength) + } + + var endChunk bytes.Buffer + if payloadSize == 0 { + endChunk.Write(finalChunkBytes) + return &endChunk + } + + endChunk.WriteString(crlf) + endChunk.Write(finalChunkBytes) + + var header bytes.Buffer + header.WriteString(strconv.FormatInt(payloadSize, 16)) + header.WriteString(crlf) + return io.MultiReader( + &header, + reader, + &endChunk, + ) +} + +// Provides a buffered aws-chunked chunk encoder of an underlying io.Reader. +// Will include end chunk, but not the aws-chunked final `crlf` segment so +// trailers can be added. +// +// Note does not implement support for chunk extensions, e.g. chunk signing. +type bufferedAWSChunkReader struct { + reader io.Reader + chunkSize int + chunkSizeStr string + + headerBuffer *bytes.Buffer + chunkBuffer *bytes.Buffer + + multiReader io.Reader + multiReaderLen int + endChunkDone bool +} + +// newBufferedAWSChunkReader returns an bufferedAWSChunkReader for reading +// aws-chunked encoded chunks. +func newBufferedAWSChunkReader(reader io.Reader, chunkSize int) *bufferedAWSChunkReader { + return &bufferedAWSChunkReader{ + reader: reader, + chunkSize: chunkSize, + chunkSizeStr: strconv.FormatInt(int64(chunkSize), 16), + + headerBuffer: bytes.NewBuffer(make([]byte, 0, 64)), + chunkBuffer: bytes.NewBuffer(make([]byte, 0, chunkSize+len(crlf))), + } +} + +// Read attempts to read from the underlying io.Reader writing aws-chunked +// chunk encoded bytes to p. When the underlying io.Reader has been completed +// read the end chunk will be available. Once the end chunk is read, the reader +// will return EOF. +func (r *bufferedAWSChunkReader) Read(p []byte) (n int, err error) { + if r.multiReaderLen == 0 && r.endChunkDone { + return 0, io.EOF + } + if r.multiReader == nil || r.multiReaderLen == 0 { + r.multiReader, r.multiReaderLen, err = r.newMultiReader() + if err != nil { + return 0, err + } + } + + n, err = r.multiReader.Read(p) + r.multiReaderLen -= n + + if err == io.EOF && !r.endChunkDone { + // Edge case handling when the multi-reader has been completely read, + // and returned an EOF, make sure that EOF only gets returned if the + // end chunk was included in the multi-reader. Otherwise, the next call + // to read will initialize the next chunk's multi-reader. + err = nil + } + return n, err +} + +// newMultiReader returns a new io.Reader for wrapping the next chunk. Will +// return an error if the underlying reader can not be read from. Will never +// return io.EOF. +func (r *bufferedAWSChunkReader) newMultiReader() (io.Reader, int, error) { + // io.Copy eats the io.EOF returned by io.LimitReader. Any error that + // occurs here is due to an actual read error. + n, err := io.Copy(r.chunkBuffer, io.LimitReader(r.reader, int64(r.chunkSize))) + if err != nil { + return nil, 0, err + } + if n == 0 { + // Early exit writing out only the end chunk. This does not include + // aws-chunk's final `crlf` so that trailers can still be added by + // upstream reader. + r.headerBuffer.Reset() + r.headerBuffer.WriteString("0") + r.headerBuffer.WriteString(crlf) + r.endChunkDone = true + + return r.headerBuffer, r.headerBuffer.Len(), nil + } + r.chunkBuffer.WriteString(crlf) + + chunkSizeStr := r.chunkSizeStr + if int(n) != r.chunkSize { + chunkSizeStr = strconv.FormatInt(n, 16) + } + + r.headerBuffer.Reset() + r.headerBuffer.WriteString(chunkSizeStr) + r.headerBuffer.WriteString(crlf) + + return io.MultiReader( + r.headerBuffer, + r.chunkBuffer, + ), r.headerBuffer.Len() + r.chunkBuffer.Len(), nil +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding_test.go new file mode 100644 index 0000000000..8e9ce3c8a9 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/aws_chunked_encoding_test.go @@ -0,0 +1,507 @@ +//go:build go1.16 +// +build go1.16 + +package checksum + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "strings" + "testing" + "testing/iotest" + + "github.com/google/go-cmp/cmp" +) + +func TestAWSChunkedEncoding(t *testing.T) { + cases := map[string]struct { + reader *awsChunkedEncoding + expectErr string + expectEncodedLength int64 + expectHTTPHeaders map[string][]string + expectPayload []byte + }{ + "empty payload fixed stream length": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader(""), + func(o *awsChunkedEncodingOptions) { + o.StreamLength = 0 + }), + expectEncodedLength: 5, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + }, + expectPayload: []byte("0\r\n\r\n"), + }, + "empty payload unknown stream length": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader("")), + expectEncodedLength: -1, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + }, + expectPayload: []byte("0\r\n\r\n"), + }, + "payload fixed stream length": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"), + func(o *awsChunkedEncodingOptions) { + o.StreamLength = 11 + }), + expectEncodedLength: 21, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + }, + expectPayload: []byte("b\r\nhello world\r\n0\r\n\r\n"), + }, + "payload unknown stream length": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world")), + expectEncodedLength: -1, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + }, + expectPayload: []byte("b\r\nhello world\r\n0\r\n\r\n"), + }, + "payload unknown stream length with chunk size": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"), + func(o *awsChunkedEncodingOptions) { + o.ChunkLength = 8 + }), + expectEncodedLength: -1, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + }, + expectPayload: []byte("8\r\nhello wo\r\n3\r\nrld\r\n0\r\n\r\n"), + }, + "payload fixed stream length with chunk size": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"), + func(o *awsChunkedEncodingOptions) { + o.StreamLength = 11 + o.ChunkLength = 8 + }), + expectEncodedLength: 26, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + }, + expectPayload: []byte("8\r\nhello wo\r\n3\r\nrld\r\n0\r\n\r\n"), + }, + "payload fixed stream length with fixed length trailer": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"), + func(o *awsChunkedEncodingOptions) { + o.StreamLength = 11 + o.Trailers = map[string]awsChunkedTrailerValue{ + "foo": { + Get: func() (string, error) { + return "abc123", nil + }, + Length: 6, + }, + } + }), + expectEncodedLength: 33, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + awsTrailerHeaderName: {"foo"}, + }, + expectPayload: []byte("b\r\nhello world\r\n0\r\nfoo:abc123\r\n\r\n"), + }, + "payload fixed stream length with unknown length trailer": { + reader: newUnsignedAWSChunkedEncoding(strings.NewReader("hello world"), + func(o *awsChunkedEncodingOptions) { + o.StreamLength = 11 + o.Trailers = map[string]awsChunkedTrailerValue{ + "foo": { + Get: func() (string, error) { + return "abc123", nil + }, + Length: -1, + }, + } + }), + expectEncodedLength: -1, + expectHTTPHeaders: map[string][]string{ + contentEncodingHeaderName: { + awsChunkedContentEncodingHeaderValue, + }, + awsTrailerHeaderName: {"foo"}, + }, + expectPayload: []byte("b\r\nhello world\r\n0\r\nfoo:abc123\r\n\r\n"), + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + if e, a := c.expectEncodedLength, c.reader.EncodedLength(); e != a { + t.Errorf("expect %v encoded length, got %v", e, a) + } + if diff := cmp.Diff(c.expectHTTPHeaders, c.reader.HTTPHeaders()); diff != "" { + t.Errorf("expect HTTP headers match\n%v", diff) + } + + actualPayload, err := ioutil.ReadAll(c.reader) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" { + t.Errorf("expect payload match\n%v", diff) + } + }) + } +} + +func TestUnsignedAWSChunkReader(t *testing.T) { + cases := map[string]struct { + payload interface { + io.Reader + Len() int + } + + expectPayload []byte + expectErr string + }{ + "empty body": { + payload: bytes.NewReader([]byte{}), + expectPayload: []byte("0\r\n"), + }, + "with body": { + payload: strings.NewReader("Hello world"), + expectPayload: []byte("b\r\nHello world\r\n0\r\n"), + }, + "large body": { + payload: bytes.NewBufferString("Hello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello world"), + expectPayload: []byte("205\r\nHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello world\r\n0\r\n"), + }, + "reader error": { + payload: newLimitReadLener(iotest.ErrReader(fmt.Errorf("some read error")), 128), + expectErr: "some read error", + }, + "unknown length reader": { + payload: newUnknownLenReader(io.LimitReader(byteReader('a'), defaultChunkLength*2)), + expectPayload: func() []byte { + reader := newBufferedAWSChunkReader( + io.LimitReader(byteReader('a'), defaultChunkLength*2), + defaultChunkLength, + ) + actualPayload, err := ioutil.ReadAll(reader) + if err != nil { + t.Fatalf("failed to create unknown length reader test data, %v", err) + } + return actualPayload + }(), + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + reader := newUnsignedChunkReader(c.payload, int64(c.payload.Len())) + + actualPayload, err := ioutil.ReadAll(reader) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" { + t.Errorf("expect payload match\n%v", diff) + } + }) + } +} + +func TestBufferedAWSChunkReader(t *testing.T) { + cases := map[string]struct { + payload io.Reader + readSize int + chunkSize int + + expectPayload []byte + expectErr string + }{ + "empty body": { + payload: bytes.NewReader([]byte{}), + chunkSize: 4, + expectPayload: []byte("0\r\n"), + }, + "with one chunk body": { + payload: strings.NewReader("Hello world"), + chunkSize: 20, + expectPayload: []byte("b\r\nHello world\r\n0\r\n"), + }, + "single byte read": { + payload: strings.NewReader("Hello world"), + chunkSize: 8, + readSize: 1, + expectPayload: []byte("8\r\nHello wo\r\n3\r\nrld\r\n0\r\n"), + }, + "single chunk and byte read": { + payload: strings.NewReader("Hello world"), + chunkSize: 1, + readSize: 1, + expectPayload: []byte("1\r\nH\r\n1\r\ne\r\n1\r\nl\r\n1\r\nl\r\n1\r\no\r\n1\r\n \r\n1\r\nw\r\n1\r\no\r\n1\r\nr\r\n1\r\nl\r\n1\r\nd\r\n0\r\n"), + }, + "with two chunk body": { + payload: strings.NewReader("Hello world"), + chunkSize: 8, + expectPayload: []byte("8\r\nHello wo\r\n3\r\nrld\r\n0\r\n"), + }, + "chunk size equal to read size": { + payload: strings.NewReader("Hello world"), + chunkSize: 512, + expectPayload: []byte("b\r\nHello world\r\n0\r\n"), + }, + "chunk size greater than read size": { + payload: strings.NewReader("Hello world"), + chunkSize: 1024, + expectPayload: []byte("b\r\nHello world\r\n0\r\n"), + }, + "payload size more than default read size, chunk size less than read size": { + payload: bytes.NewBufferString("Hello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello world"), + chunkSize: 500, + expectPayload: []byte("1f4\r\nHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello\r\n11\r\n worldHello world\r\n0\r\n"), + }, + "payload size more than default read size, chunk size equal to read size": { + payload: bytes.NewBufferString("Hello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello world"), + chunkSize: 512, + expectPayload: []byte("200\r\nHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello \r\n5\r\nworld\r\n0\r\n"), + }, + "payload size more than default read size, chunk size more than read size": { + payload: bytes.NewBufferString("Hello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello world"), + chunkSize: 1024, + expectPayload: []byte("205\r\nHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello worldHello " + + "worldHello worldHello worldHello worldHello worldHello world\r\n0\r\n"), + }, + "reader error": { + payload: iotest.ErrReader(fmt.Errorf("some read error")), + chunkSize: 128, + expectErr: "some read error", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + reader := newBufferedAWSChunkReader(c.payload, c.chunkSize) + + var actualPayload []byte + var err error + + if c.readSize != 0 { + for err == nil { + var n int + p := make([]byte, c.readSize) + n, err = reader.Read(p) + if n != 0 { + actualPayload = append(actualPayload, p[:n]...) + } + } + if err == io.EOF { + err = nil + } + } else { + actualPayload, err = ioutil.ReadAll(reader) + } + + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" { + t.Errorf("expect payload match\n%v", diff) + } + }) + } +} + +func TestAwsChunkedTrailerReader(t *testing.T) { + cases := map[string]struct { + reader *awsChunkedTrailerReader + + expectErr string + expectEncodedLength int + expectPayload []byte + }{ + "no trailers": { + reader: newAWSChunkedTrailerReader(nil), + expectPayload: []byte{}, + }, + "unknown length trailers": { + reader: newAWSChunkedTrailerReader(map[string]awsChunkedTrailerValue{ + "foo": { + Get: func() (string, error) { + return "abc123", nil + }, + Length: -1, + }, + }), + expectEncodedLength: -1, + expectPayload: []byte("foo:abc123\r\n"), + }, + "known length trailers": { + reader: newAWSChunkedTrailerReader(map[string]awsChunkedTrailerValue{ + "foo": { + Get: func() (string, error) { + return "abc123", nil + }, + Length: 6, + }, + }), + expectEncodedLength: 12, + expectPayload: []byte("foo:abc123\r\n"), + }, + "trailer error": { + reader: newAWSChunkedTrailerReader(map[string]awsChunkedTrailerValue{ + "foo": { + Get: func() (string, error) { + return "", fmt.Errorf("some error") + }, + Length: 6, + }, + }), + expectEncodedLength: 12, + expectErr: "failed to get trailer", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + if e, a := c.expectEncodedLength, c.reader.EncodedLength(); e != a { + t.Errorf("expect %v encoded length, got %v", e, a) + } + + actualPayload, err := ioutil.ReadAll(c.reader) + + // Asserts + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" { + t.Errorf("expect payload match\n%v", diff) + } + }) + } +} + +type limitReadLener struct { + io.Reader + length int +} + +func newLimitReadLener(r io.Reader, l int) *limitReadLener { + return &limitReadLener{ + Reader: io.LimitReader(r, int64(l)), + length: l, + } +} +func (r *limitReadLener) Len() int { + return r.length +} + +type unknownLenReader struct { + io.Reader +} + +func newUnknownLenReader(r io.Reader) *unknownLenReader { + return &unknownLenReader{ + Reader: r, + } +} +func (r *unknownLenReader) Len() int { + return -1 +} + +type byteReader byte + +func (r byteReader) Read(p []byte) (int, error) { + for i := 0; i < len(p); i++ { + p[i] = byte(r) + } + return len(p), nil +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/go_module_metadata.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/go_module_metadata.go new file mode 100644 index 0000000000..f591861505 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/go_module_metadata.go @@ -0,0 +1,6 @@ +// Code generated by internal/repotools/cmd/updatemodulemeta DO NOT EDIT. + +package checksum + +// goModuleVersion is the tagged release for this module +const goModuleVersion = "1.2.0" diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/gotest/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/gotest/ya.make new file mode 100644 index 0000000000..2a3d936a06 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/gotest/ya.make @@ -0,0 +1,5 @@ +GO_TEST_FOR(vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum) + +LICENSE(Apache-2.0) + +END() diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add.go new file mode 100644 index 0000000000..3e17d2216b --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add.go @@ -0,0 +1,185 @@ +package checksum + +import ( + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +// InputMiddlewareOptions provides the options for the request +// checksum middleware setup. +type InputMiddlewareOptions struct { + // GetAlgorithm is a function to get the checksum algorithm of the + // input payload from the input parameters. + // + // Given the input parameter value, the function must return the algorithm + // and true, or false if no algorithm is specified. + GetAlgorithm func(interface{}) (string, bool) + + // Forces the middleware to compute the input payload's checksum. The + // request will fail if the algorithm is not specified or unable to compute + // the checksum. + RequireChecksum bool + + // Enables support for wrapping the serialized input payload with a + // content-encoding: aws-check wrapper, and including a trailer for the + // algorithm's checksum value. + // + // The checksum will not be computed, nor added as trailing checksum, if + // the Algorithm's header is already set on the request. + EnableTrailingChecksum bool + + // Enables support for computing the SHA256 checksum of input payloads + // along with the algorithm specified checksum. Prevents downstream + // middleware handlers (computePayloadSHA256) re-reading the payload. + // + // The SHA256 payload checksum will only be used for computed for requests + // that are not TLS, or do not enable trailing checksums. + // + // The SHA256 payload hash will not be computed, if the Algorithm's header + // is already set on the request. + EnableComputeSHA256PayloadHash bool + + // Enables support for setting the aws-chunked decoded content length + // header for the decoded length of the underlying stream. Will only be set + // when used with trailing checksums, and aws-chunked content-encoding. + EnableDecodedContentLengthHeader bool +} + +// AddInputMiddleware adds the middleware for performing checksum computing +// of request payloads, and checksum validation of response payloads. +func AddInputMiddleware(stack *middleware.Stack, options InputMiddlewareOptions) (err error) { + // TODO ensure this works correctly with presigned URLs + + // Middleware stack: + // * (OK)(Initialize) --none-- + // * (OK)(Serialize) EndpointResolver + // * (OK)(Build) ComputeContentLength + // * (AD)(Build) Header ComputeInputPayloadChecksum + // * SIGNED Payload - If HTTP && not support trailing checksum + // * UNSIGNED Payload - If HTTPS && not support trailing checksum + // * (RM)(Build) ContentChecksum - OK to remove + // * (OK)(Build) ComputePayloadHash + // * v4.dynamicPayloadSigningMiddleware + // * v4.computePayloadSHA256 + // * v4.unsignedPayload + // (OK)(Build) Set computedPayloadHash header + // * (OK)(Finalize) Retry + // * (AD)(Finalize) Trailer ComputeInputPayloadChecksum, + // * Requires HTTPS && support trailing checksum + // * UNSIGNED Payload + // * Finalize run if HTTPS && support trailing checksum + // * (OK)(Finalize) Signing + // * (OK)(Deserialize) --none-- + + // Initial checksum configuration look up middleware + err = stack.Initialize.Add(&setupInputContext{ + GetAlgorithm: options.GetAlgorithm, + }, middleware.Before) + if err != nil { + return err + } + + stack.Build.Remove("ContentChecksum") + + // Create the compute checksum middleware that will be added as both a + // build and finalize handler. + inputChecksum := &computeInputPayloadChecksum{ + RequireChecksum: options.RequireChecksum, + EnableTrailingChecksum: options.EnableTrailingChecksum, + EnableComputePayloadHash: options.EnableComputeSHA256PayloadHash, + EnableDecodedContentLengthHeader: options.EnableDecodedContentLengthHeader, + } + + // Insert header checksum after ComputeContentLength middleware, must also + // be before the computePayloadHash middleware handlers. + err = stack.Build.Insert(inputChecksum, + (*smithyhttp.ComputeContentLength)(nil).ID(), + middleware.After) + if err != nil { + return err + } + + // If trailing checksum is not supported no need for finalize handler to be added. + if options.EnableTrailingChecksum { + err = stack.Finalize.Insert(inputChecksum, "Retry", middleware.After) + if err != nil { + return err + } + } + + return nil +} + +// RemoveInputMiddleware Removes the compute input payload checksum middleware +// handlers from the stack. +func RemoveInputMiddleware(stack *middleware.Stack) { + id := (*setupInputContext)(nil).ID() + stack.Initialize.Remove(id) + + id = (*computeInputPayloadChecksum)(nil).ID() + stack.Build.Remove(id) + stack.Finalize.Remove(id) +} + +// OutputMiddlewareOptions provides options for configuring output checksum +// validation middleware. +type OutputMiddlewareOptions struct { + // GetValidationMode is a function to get the checksum validation + // mode of the output payload from the input parameters. + // + // Given the input parameter value, the function must return the validation + // mode and true, or false if no mode is specified. + GetValidationMode func(interface{}) (string, bool) + + // The set of checksum algorithms that should be used for response payload + // checksum validation. The algorithm(s) used will be a union of the + // output's returned algorithms and this set. + // + // Only the first algorithm in the union is currently used. + ValidationAlgorithms []string + + // If set the middleware will ignore output multipart checksums. Otherwise + // an checksum format error will be returned by the middleware. + IgnoreMultipartValidation bool + + // When set the middleware will log when output does not have checksum or + // algorithm to validate. + LogValidationSkipped bool + + // When set the middleware will log when the output contains a multipart + // checksum that was, skipped and not validated. + LogMultipartValidationSkipped bool +} + +// AddOutputMiddleware adds the middleware for validating response payload's +// checksum. +func AddOutputMiddleware(stack *middleware.Stack, options OutputMiddlewareOptions) error { + err := stack.Initialize.Add(&setupOutputContext{ + GetValidationMode: options.GetValidationMode, + }, middleware.Before) + if err != nil { + return err + } + + // Resolve a supported priority order list of algorithms to validate. + algorithms := FilterSupportedAlgorithms(options.ValidationAlgorithms) + + m := &validateOutputPayloadChecksum{ + Algorithms: algorithms, + IgnoreMultipartValidation: options.IgnoreMultipartValidation, + LogMultipartValidationSkipped: options.LogMultipartValidationSkipped, + LogValidationSkipped: options.LogValidationSkipped, + } + + return stack.Deserialize.Add(m, middleware.After) +} + +// RemoveOutputMiddleware Removes the compute input payload checksum middleware +// handlers from the stack. +func RemoveOutputMiddleware(stack *middleware.Stack) { + id := (*setupOutputContext)(nil).ID() + stack.Initialize.Remove(id) + + id = (*validateOutputPayloadChecksum)(nil).ID() + stack.Deserialize.Remove(id) +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add_test.go new file mode 100644 index 0000000000..33f952ebe9 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_add_test.go @@ -0,0 +1,412 @@ +//go:build go1.16 +// +build go1.16 + +package checksum + +import ( + "context" + "testing" + + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/google/go-cmp/cmp" +) + +func TestAddInputMiddleware(t *testing.T) { + cases := map[string]struct { + options InputMiddlewareOptions + expectErr string + expectMiddleware []string + expectInitialize *setupInputContext + expectBuild *computeInputPayloadChecksum + expectFinalize *computeInputPayloadChecksum + }{ + "with trailing checksum": { + options: InputMiddlewareOptions{ + GetAlgorithm: func(interface{}) (string, bool) { + return string(AlgorithmCRC32), true + }, + EnableTrailingChecksum: true, + EnableComputeSHA256PayloadHash: true, + EnableDecodedContentLengthHeader: true, + }, + expectMiddleware: []string{ + "test", + "Initialize stack step", + "AWSChecksum:SetupInputContext", + "Serialize stack step", + "Build stack step", + "ComputeContentLength", + "AWSChecksum:ComputeInputPayloadChecksum", + "ComputePayloadHash", + "Finalize stack step", + "Retry", + "AWSChecksum:ComputeInputPayloadChecksum", + "Signing", + "Deserialize stack step", + }, + expectInitialize: &setupInputContext{ + GetAlgorithm: func(interface{}) (string, bool) { + return string(AlgorithmCRC32), true + }, + }, + expectBuild: &computeInputPayloadChecksum{ + EnableTrailingChecksum: true, + EnableComputePayloadHash: true, + EnableDecodedContentLengthHeader: true, + }, + }, + "with checksum required": { + options: InputMiddlewareOptions{ + GetAlgorithm: func(interface{}) (string, bool) { + return string(AlgorithmCRC32), true + }, + EnableTrailingChecksum: true, + RequireChecksum: true, + }, + expectMiddleware: []string{ + "test", + "Initialize stack step", + "AWSChecksum:SetupInputContext", + "Serialize stack step", + "Build stack step", + "ComputeContentLength", + "AWSChecksum:ComputeInputPayloadChecksum", + "ComputePayloadHash", + "Finalize stack step", + "Retry", + "AWSChecksum:ComputeInputPayloadChecksum", + "Signing", + "Deserialize stack step", + }, + expectInitialize: &setupInputContext{ + GetAlgorithm: func(interface{}) (string, bool) { + return string(AlgorithmCRC32), true + }, + }, + expectBuild: &computeInputPayloadChecksum{ + RequireChecksum: true, + EnableTrailingChecksum: true, + }, + }, + "no trailing checksum": { + options: InputMiddlewareOptions{ + GetAlgorithm: func(interface{}) (string, bool) { + return string(AlgorithmCRC32), true + }, + }, + expectMiddleware: []string{ + "test", + "Initialize stack step", + "AWSChecksum:SetupInputContext", + "Serialize stack step", + "Build stack step", + "ComputeContentLength", + "AWSChecksum:ComputeInputPayloadChecksum", + "ComputePayloadHash", + "Finalize stack step", + "Retry", + "Signing", + "Deserialize stack step", + }, + expectInitialize: &setupInputContext{ + GetAlgorithm: func(interface{}) (string, bool) { + return string(AlgorithmCRC32), true + }, + }, + expectBuild: &computeInputPayloadChecksum{}, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + stack := middleware.NewStack("test", smithyhttp.NewStackRequest) + + stack.Build.Add(nopBuildMiddleware("ComputeContentLength"), middleware.After) + stack.Build.Add(nopBuildMiddleware("ContentChecksum"), middleware.After) + stack.Build.Add(nopBuildMiddleware("ComputePayloadHash"), middleware.After) + stack.Finalize.Add(nopFinalizeMiddleware("Retry"), middleware.After) + stack.Finalize.Add(nopFinalizeMiddleware("Signing"), middleware.After) + + err := AddInputMiddleware(stack, c.options) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if diff := cmp.Diff(c.expectMiddleware, stack.List()); diff != "" { + t.Fatalf("expect stack list match:\n%s", diff) + } + + initializeMiddleware, ok := stack.Initialize.Get((*setupInputContext)(nil).ID()) + if e, a := (c.expectInitialize != nil), ok; e != a { + t.Errorf("expect initialize middleware %t, got %t", e, a) + } + if c.expectInitialize != nil && ok { + setupInput := initializeMiddleware.(*setupInputContext) + if e, a := c.options.GetAlgorithm != nil, setupInput.GetAlgorithm != nil; e != a { + t.Fatalf("expect GetAlgorithm %t, got %t", e, a) + } + expectAlgo, expectOK := c.options.GetAlgorithm(nil) + actualAlgo, actualOK := setupInput.GetAlgorithm(nil) + if e, a := expectAlgo, actualAlgo; e != a { + t.Errorf("expect %v algorithm, got %v", e, a) + } + if e, a := expectOK, actualOK; e != a { + t.Errorf("expect %v algorithm present, got %v", e, a) + } + } + + buildMiddleware, ok := stack.Build.Get((*computeInputPayloadChecksum)(nil).ID()) + if e, a := (c.expectBuild != nil), ok; e != a { + t.Errorf("expect build middleware %t, got %t", e, a) + } + var computeInput *computeInputPayloadChecksum + if c.expectBuild != nil && ok { + computeInput = buildMiddleware.(*computeInputPayloadChecksum) + if e, a := c.expectBuild.RequireChecksum, computeInput.RequireChecksum; e != a { + t.Errorf("expect %v require checksum, got %v", e, a) + } + if e, a := c.expectBuild.EnableTrailingChecksum, computeInput.EnableTrailingChecksum; e != a { + t.Errorf("expect %v enable trailing checksum, got %v", e, a) + } + if e, a := c.expectBuild.EnableComputePayloadHash, computeInput.EnableComputePayloadHash; e != a { + t.Errorf("expect %v enable compute payload hash, got %v", e, a) + } + if e, a := c.expectBuild.EnableDecodedContentLengthHeader, computeInput.EnableDecodedContentLengthHeader; e != a { + t.Errorf("expect %v enable decoded length header, got %v", e, a) + } + } + + if c.expectFinalize != nil && ok { + finalizeMiddleware, ok := stack.Build.Get((*computeInputPayloadChecksum)(nil).ID()) + if !ok { + t.Errorf("expect finalize middleware") + } + finalizeComputeInput := finalizeMiddleware.(*computeInputPayloadChecksum) + + if e, a := computeInput, finalizeComputeInput; e != a { + t.Errorf("expect build and finalize to be same value") + } + } + }) + } +} + +func TestRemoveInputMiddleware(t *testing.T) { + stack := middleware.NewStack("test", smithyhttp.NewStackRequest) + + stack.Build.Add(nopBuildMiddleware("ComputeContentLength"), middleware.After) + stack.Build.Add(nopBuildMiddleware("ContentChecksum"), middleware.After) + stack.Build.Add(nopBuildMiddleware("ComputePayloadHash"), middleware.After) + stack.Finalize.Add(nopFinalizeMiddleware("Retry"), middleware.After) + stack.Finalize.Add(nopFinalizeMiddleware("Signing"), middleware.After) + + err := AddInputMiddleware(stack, InputMiddlewareOptions{ + EnableTrailingChecksum: true, + }) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + RemoveInputMiddleware(stack) + + expectStack := []string{ + "test", + "Initialize stack step", + "Serialize stack step", + "Build stack step", + "ComputeContentLength", + "ComputePayloadHash", + "Finalize stack step", + "Retry", + "Signing", + "Deserialize stack step", + } + + if diff := cmp.Diff(expectStack, stack.List()); diff != "" { + t.Fatalf("expect stack list match:\n%s", diff) + } +} + +func TestAddOutputMiddleware(t *testing.T) { + cases := map[string]struct { + options OutputMiddlewareOptions + expectErr string + expectMiddleware []string + expectInitialize *setupOutputContext + expectDeserialize *validateOutputPayloadChecksum + }{ + "validate output": { + options: OutputMiddlewareOptions{ + GetValidationMode: func(interface{}) (string, bool) { + return "ENABLED", true + }, + ValidationAlgorithms: []string{ + "crc32", "sha1", "abc123", "crc32c", + }, + IgnoreMultipartValidation: true, + LogMultipartValidationSkipped: true, + LogValidationSkipped: true, + }, + expectMiddleware: []string{ + "test", + "Initialize stack step", + "AWSChecksum:SetupOutputContext", + "Serialize stack step", + "Build stack step", + "Finalize stack step", + "Deserialize stack step", + "AWSChecksum:ValidateOutputPayloadChecksum", + }, + expectInitialize: &setupOutputContext{ + GetValidationMode: func(interface{}) (string, bool) { + return "ENABLED", true + }, + }, + expectDeserialize: &validateOutputPayloadChecksum{ + Algorithms: []Algorithm{ + AlgorithmCRC32, AlgorithmSHA1, AlgorithmCRC32C, + }, + IgnoreMultipartValidation: true, + LogMultipartValidationSkipped: true, + LogValidationSkipped: true, + }, + }, + "validate options off": { + options: OutputMiddlewareOptions{ + GetValidationMode: func(interface{}) (string, bool) { + return "ENABLED", true + }, + ValidationAlgorithms: []string{ + "crc32", "sha1", "abc123", "crc32c", + }, + }, + expectMiddleware: []string{ + "test", + "Initialize stack step", + "AWSChecksum:SetupOutputContext", + "Serialize stack step", + "Build stack step", + "Finalize stack step", + "Deserialize stack step", + "AWSChecksum:ValidateOutputPayloadChecksum", + }, + expectInitialize: &setupOutputContext{ + GetValidationMode: func(interface{}) (string, bool) { + return "ENABLED", true + }, + }, + expectDeserialize: &validateOutputPayloadChecksum{ + Algorithms: []Algorithm{ + AlgorithmCRC32, AlgorithmSHA1, AlgorithmCRC32C, + }, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + stack := middleware.NewStack("test", smithyhttp.NewStackRequest) + + err := AddOutputMiddleware(stack, c.options) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if diff := cmp.Diff(c.expectMiddleware, stack.List()); diff != "" { + t.Fatalf("expect stack list match:\n%s", diff) + } + + initializeMiddleware, ok := stack.Initialize.Get((*setupOutputContext)(nil).ID()) + if e, a := (c.expectInitialize != nil), ok; e != a { + t.Errorf("expect initialize middleware %t, got %t", e, a) + } + if c.expectInitialize != nil && ok { + setupOutput := initializeMiddleware.(*setupOutputContext) + if e, a := c.options.GetValidationMode != nil, setupOutput.GetValidationMode != nil; e != a { + t.Fatalf("expect GetValidationMode %t, got %t", e, a) + } + expectMode, expectOK := c.options.GetValidationMode(nil) + actualMode, actualOK := setupOutput.GetValidationMode(nil) + if e, a := expectMode, actualMode; e != a { + t.Errorf("expect %v mode, got %v", e, a) + } + if e, a := expectOK, actualOK; e != a { + t.Errorf("expect %v mode present, got %v", e, a) + } + } + + deserializeMiddleware, ok := stack.Deserialize.Get((*validateOutputPayloadChecksum)(nil).ID()) + if e, a := (c.expectDeserialize != nil), ok; e != a { + t.Errorf("expect deserialize middleware %t, got %t", e, a) + } + if c.expectDeserialize != nil && ok { + validateOutput := deserializeMiddleware.(*validateOutputPayloadChecksum) + if diff := cmp.Diff(c.expectDeserialize.Algorithms, validateOutput.Algorithms); diff != "" { + t.Errorf("expect algorithms match:\n%s", diff) + } + if e, a := c.expectDeserialize.IgnoreMultipartValidation, validateOutput.IgnoreMultipartValidation; e != a { + t.Errorf("expect %v ignore multipart checksum, got %v", e, a) + } + if e, a := c.expectDeserialize.LogMultipartValidationSkipped, validateOutput.LogMultipartValidationSkipped; e != a { + t.Errorf("expect %v log multipart skipped, got %v", e, a) + } + if e, a := c.expectDeserialize.LogValidationSkipped, validateOutput.LogValidationSkipped; e != a { + t.Errorf("expect %v log validation skipped, got %v", e, a) + } + } + }) + } +} + +func TestRemoveOutputMiddleware(t *testing.T) { + stack := middleware.NewStack("test", smithyhttp.NewStackRequest) + + err := AddOutputMiddleware(stack, OutputMiddlewareOptions{}) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + RemoveOutputMiddleware(stack) + + expectStack := []string{ + "test", + "Initialize stack step", + "Serialize stack step", + "Build stack step", + "Finalize stack step", + "Deserialize stack step", + } + + if diff := cmp.Diff(expectStack, stack.List()); diff != "" { + t.Fatalf("expect stack list match:\n%s", diff) + } +} + +func setSerializedRequest(req *smithyhttp.Request) middleware.SerializeMiddleware { + return middleware.SerializeMiddlewareFunc("OperationSerializer", + func(ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler) ( + middleware.SerializeOutput, middleware.Metadata, error, + ) { + input.Request = req + return next.HandleSerialize(ctx, input) + }) +} + +func nopBuildMiddleware(id string) middleware.BuildMiddleware { + return middleware.BuildMiddlewareFunc(id, + func(ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler) ( + middleware.BuildOutput, middleware.Metadata, error, + ) { + return next.HandleBuild(ctx, input) + }) +} + +func nopFinalizeMiddleware(id string) middleware.FinalizeMiddleware { + return middleware.FinalizeMiddlewareFunc(id, + func(ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler) ( + middleware.FinalizeOutput, middleware.Metadata, error, + ) { + return next.HandleFinalize(ctx, input) + }) +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum.go new file mode 100644 index 0000000000..0b3c48912b --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum.go @@ -0,0 +1,479 @@ +package checksum + +import ( + "context" + "crypto/sha256" + "fmt" + "hash" + "io" + "strconv" + + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +const ( + contentMD5Header = "Content-Md5" + streamingUnsignedPayloadTrailerPayloadHash = "STREAMING-UNSIGNED-PAYLOAD-TRAILER" +) + +// computedInputChecksumsKey is the metadata key for recording the algorithm the +// checksum was computed for and the checksum value. +type computedInputChecksumsKey struct{} + +// GetComputedInputChecksums returns the map of checksum algorithm to their +// computed value stored in the middleware Metadata. Returns false if no values +// were stored in the Metadata. +func GetComputedInputChecksums(m middleware.Metadata) (map[string]string, bool) { + vs, ok := m.Get(computedInputChecksumsKey{}).(map[string]string) + return vs, ok +} + +// SetComputedInputChecksums stores the map of checksum algorithm to their +// computed value in the middleware Metadata. Overwrites any values that +// currently exist in the metadata. +func SetComputedInputChecksums(m *middleware.Metadata, vs map[string]string) { + m.Set(computedInputChecksumsKey{}, vs) +} + +// computeInputPayloadChecksum middleware computes payload checksum +type computeInputPayloadChecksum struct { + // Enables support for wrapping the serialized input payload with a + // content-encoding: aws-check wrapper, and including a trailer for the + // algorithm's checksum value. + // + // The checksum will not be computed, nor added as trailing checksum, if + // the Algorithm's header is already set on the request. + EnableTrailingChecksum bool + + // States that a checksum is required to be included for the operation. If + // Input does not specify a checksum, fallback to built in MD5 checksum is + // used. + // + // Replaces smithy-go's ContentChecksum middleware. + RequireChecksum bool + + // Enables support for computing the SHA256 checksum of input payloads + // along with the algorithm specified checksum. Prevents downstream + // middleware handlers (computePayloadSHA256) re-reading the payload. + // + // The SHA256 payload hash will only be used for computed for requests + // that are not TLS, or do not enable trailing checksums. + // + // The SHA256 payload hash will not be computed, if the Algorithm's header + // is already set on the request. + EnableComputePayloadHash bool + + // Enables support for setting the aws-chunked decoded content length + // header for the decoded length of the underlying stream. Will only be set + // when used with trailing checksums, and aws-chunked content-encoding. + EnableDecodedContentLengthHeader bool + + buildHandlerRun bool + deferToFinalizeHandler bool +} + +// ID provides the middleware's identifier. +func (m *computeInputPayloadChecksum) ID() string { + return "AWSChecksum:ComputeInputPayloadChecksum" +} + +type computeInputHeaderChecksumError struct { + Msg string + Err error +} + +func (e computeInputHeaderChecksumError) Error() string { + const intro = "compute input header checksum failed" + + if e.Err != nil { + return fmt.Sprintf("%s, %s, %v", intro, e.Msg, e.Err) + } + + return fmt.Sprintf("%s, %s", intro, e.Msg) +} +func (e computeInputHeaderChecksumError) Unwrap() error { return e.Err } + +// HandleBuild handles computing the payload's checksum, in the following cases: +// - Is HTTP, not HTTPS +// - RequireChecksum is true, and no checksums were specified via the Input +// - Trailing checksums are not supported +// +// The build handler must be inserted in the stack before ContentPayloadHash +// and after ComputeContentLength. +func (m *computeInputPayloadChecksum) HandleBuild( + ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler, +) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + m.buildHandlerRun = true + + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, computeInputHeaderChecksumError{ + Msg: fmt.Sprintf("unknown request type %T", req), + } + } + + var algorithm Algorithm + var checksum string + defer func() { + if algorithm == "" || checksum == "" || err != nil { + return + } + + // Record the checksum and algorithm that was computed + SetComputedInputChecksums(&metadata, map[string]string{ + string(algorithm): checksum, + }) + }() + + // If no algorithm was specified, and the operation requires a checksum, + // fallback to the legacy content MD5 checksum. + algorithm, ok, err = getInputAlgorithm(ctx) + if err != nil { + return out, metadata, err + } else if !ok { + if m.RequireChecksum { + checksum, err = setMD5Checksum(ctx, req) + if err != nil { + return out, metadata, computeInputHeaderChecksumError{ + Msg: "failed to compute stream's MD5 checksum", + Err: err, + } + } + algorithm = Algorithm("MD5") + } + return next.HandleBuild(ctx, in) + } + + // If the checksum header is already set nothing to do. + checksumHeader := AlgorithmHTTPHeader(algorithm) + if checksum = req.Header.Get(checksumHeader); checksum != "" { + return next.HandleBuild(ctx, in) + } + + computePayloadHash := m.EnableComputePayloadHash + if v := v4.GetPayloadHash(ctx); v != "" { + computePayloadHash = false + } + + stream := req.GetStream() + streamLength, err := getRequestStreamLength(req) + if err != nil { + return out, metadata, computeInputHeaderChecksumError{ + Msg: "failed to determine stream length", + Err: err, + } + } + + // If trailing checksums are supported, the request is HTTPS, and the + // stream is not nil or empty, there is nothing to do in the build stage. + // The checksum will be added to the request as a trailing checksum in the + // finalize handler. + // + // Nil and empty streams will always be handled as a request header, + // regardless if the operation supports trailing checksums or not. + if req.IsHTTPS() { + if stream != nil && streamLength != 0 && m.EnableTrailingChecksum { + if m.EnableComputePayloadHash { + // payload hash is set as header in Build middleware handler, + // ContentSHA256Header. + ctx = v4.SetPayloadHash(ctx, streamingUnsignedPayloadTrailerPayloadHash) + } + + m.deferToFinalizeHandler = true + return next.HandleBuild(ctx, in) + } + + // If trailing checksums are not enabled but protocol is still HTTPS + // disabling computing the payload hash. Downstream middleware handler + // (ComputetPayloadHash) will set the payload hash to unsigned payload, + // if signing was used. + computePayloadHash = false + } + + // Only seekable streams are supported for non-trailing checksums, because + // the stream needs to be rewound before the handler can continue. + if stream != nil && !req.IsStreamSeekable() { + return out, metadata, computeInputHeaderChecksumError{ + Msg: "unseekable stream is not supported without TLS and trailing checksum", + } + } + + var sha256Checksum string + checksum, sha256Checksum, err = computeStreamChecksum( + algorithm, stream, computePayloadHash) + if err != nil { + return out, metadata, computeInputHeaderChecksumError{ + Msg: "failed to compute stream checksum", + Err: err, + } + } + + if err := req.RewindStream(); err != nil { + return out, metadata, computeInputHeaderChecksumError{ + Msg: "failed to rewind stream", + Err: err, + } + } + + req.Header.Set(checksumHeader, checksum) + + if computePayloadHash { + ctx = v4.SetPayloadHash(ctx, sha256Checksum) + } + + return next.HandleBuild(ctx, in) +} + +type computeInputTrailingChecksumError struct { + Msg string + Err error +} + +func (e computeInputTrailingChecksumError) Error() string { + const intro = "compute input trailing checksum failed" + + if e.Err != nil { + return fmt.Sprintf("%s, %s, %v", intro, e.Msg, e.Err) + } + + return fmt.Sprintf("%s, %s", intro, e.Msg) +} +func (e computeInputTrailingChecksumError) Unwrap() error { return e.Err } + +// HandleFinalize handles computing the payload's checksum, in the following cases: +// - Is HTTPS, not HTTP +// - A checksum was specified via the Input +// - Trailing checksums are supported. +// +// The finalize handler must be inserted in the stack before Signing, and after Retry. +func (m *computeInputPayloadChecksum) HandleFinalize( + ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler, +) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, +) { + if !m.deferToFinalizeHandler { + if !m.buildHandlerRun { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "build handler was removed without also removing finalize handler", + } + } + return next.HandleFinalize(ctx, in) + } + + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, computeInputTrailingChecksumError{ + Msg: fmt.Sprintf("unknown request type %T", req), + } + } + + // Trailing checksums are only supported when TLS is enabled. + if !req.IsHTTPS() { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "HTTPS required", + } + } + + // If no algorithm was specified, there is nothing to do. + algorithm, ok, err := getInputAlgorithm(ctx) + if err != nil { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "failed to get algorithm", + Err: err, + } + } else if !ok { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "no algorithm specified", + } + } + + // If the checksum header is already set before finalize could run, there + // is nothing to do. + checksumHeader := AlgorithmHTTPHeader(algorithm) + if req.Header.Get(checksumHeader) != "" { + return next.HandleFinalize(ctx, in) + } + + stream := req.GetStream() + streamLength, err := getRequestStreamLength(req) + if err != nil { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "failed to determine stream length", + Err: err, + } + } + + if stream == nil || streamLength == 0 { + // Nil and empty streams are handled by the Build handler. They are not + // supported by the trailing checksums finalize handler. There is no + // benefit to sending them as trailers compared to headers. + return out, metadata, computeInputTrailingChecksumError{ + Msg: "nil or empty streams are not supported", + } + } + + checksumReader, err := newComputeChecksumReader(stream, algorithm) + if err != nil { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "failed to created checksum reader", + Err: err, + } + } + + awsChunkedReader := newUnsignedAWSChunkedEncoding(checksumReader, + func(o *awsChunkedEncodingOptions) { + o.Trailers[AlgorithmHTTPHeader(checksumReader.Algorithm())] = awsChunkedTrailerValue{ + Get: checksumReader.Base64Checksum, + Length: checksumReader.Base64ChecksumLength(), + } + o.StreamLength = streamLength + }) + + for key, values := range awsChunkedReader.HTTPHeaders() { + for _, value := range values { + req.Header.Add(key, value) + } + } + + // Setting the stream on the request will create a copy. The content length + // is not updated until after the request is copied to prevent impacting + // upstream middleware. + req, err = req.SetStream(awsChunkedReader) + if err != nil { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "failed updating request to trailing checksum wrapped stream", + Err: err, + } + } + req.ContentLength = awsChunkedReader.EncodedLength() + in.Request = req + + // Add decoded content length header if original stream's content length is known. + if streamLength != -1 && m.EnableDecodedContentLengthHeader { + req.Header.Set(decodedContentLengthHeaderName, strconv.FormatInt(streamLength, 10)) + } + + out, metadata, err = next.HandleFinalize(ctx, in) + if err == nil { + checksum, err := checksumReader.Base64Checksum() + if err != nil { + return out, metadata, fmt.Errorf("failed to get computed checksum, %w", err) + } + + // Record the checksum and algorithm that was computed + SetComputedInputChecksums(&metadata, map[string]string{ + string(algorithm): checksum, + }) + } + + return out, metadata, err +} + +func getInputAlgorithm(ctx context.Context) (Algorithm, bool, error) { + ctxAlgorithm := getContextInputAlgorithm(ctx) + if ctxAlgorithm == "" { + return "", false, nil + } + + algorithm, err := ParseAlgorithm(ctxAlgorithm) + if err != nil { + return "", false, fmt.Errorf( + "failed to parse algorithm, %w", err) + } + + return algorithm, true, nil +} + +func computeStreamChecksum(algorithm Algorithm, stream io.Reader, computePayloadHash bool) ( + checksum string, sha256Checksum string, err error, +) { + hasher, err := NewAlgorithmHash(algorithm) + if err != nil { + return "", "", fmt.Errorf( + "failed to get hasher for checksum algorithm, %w", err) + } + + var sha256Hasher hash.Hash + var batchHasher io.Writer = hasher + + // Compute payload hash for the protocol. To prevent another handler + // (computePayloadSHA256) re-reading body also compute the SHA256 for + // request signing. If configured checksum algorithm is SHA256, don't + // double wrap stream with another SHA256 hasher. + if computePayloadHash && algorithm != AlgorithmSHA256 { + sha256Hasher = sha256.New() + batchHasher = io.MultiWriter(hasher, sha256Hasher) + } + + if stream != nil { + if _, err = io.Copy(batchHasher, stream); err != nil { + return "", "", fmt.Errorf( + "failed to read stream to compute hash, %w", err) + } + } + + checksum = string(base64EncodeHashSum(hasher)) + if computePayloadHash { + if algorithm != AlgorithmSHA256 { + sha256Checksum = string(hexEncodeHashSum(sha256Hasher)) + } else { + sha256Checksum = string(hexEncodeHashSum(hasher)) + } + } + + return checksum, sha256Checksum, nil +} + +func getRequestStreamLength(req *smithyhttp.Request) (int64, error) { + if v := req.ContentLength; v > 0 { + return v, nil + } + + if length, ok, err := req.StreamLength(); err != nil { + return 0, fmt.Errorf("failed getting request stream's length, %w", err) + } else if ok { + return length, nil + } + + return -1, nil +} + +// setMD5Checksum computes the MD5 of the request payload and sets it to the +// Content-MD5 header. Returning the MD5 base64 encoded string or error. +// +// If the MD5 is already set as the Content-MD5 header, that value will be +// returned, and nothing else will be done. +// +// If the payload is empty, no MD5 will be computed. No error will be returned. +// Empty payloads do not have an MD5 value. +// +// Replaces the smithy-go middleware for httpChecksum trait. +func setMD5Checksum(ctx context.Context, req *smithyhttp.Request) (string, error) { + if v := req.Header.Get(contentMD5Header); len(v) != 0 { + return v, nil + } + stream := req.GetStream() + if stream == nil { + return "", nil + } + + if !req.IsStreamSeekable() { + return "", fmt.Errorf( + "unseekable stream is not supported for computing md5 checksum") + } + + v, err := computeMD5Checksum(stream) + if err != nil { + return "", err + } + if err := req.RewindStream(); err != nil { + return "", fmt.Errorf("failed to rewind stream after computing MD5 checksum, %w", err) + } + // set the 'Content-MD5' header + req.Header.Set(contentMD5Header, string(v)) + return string(v), nil +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum_test.go new file mode 100644 index 0000000000..7ad59e7315 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_compute_input_checksum_test.go @@ -0,0 +1,958 @@ +//go:build go1.16 +// +build go1.16 + +package checksum + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" + "testing" + "testing/iotest" + + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/smithy-go/logging" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/google/go-cmp/cmp" +) + +// TODO test cases: +// * Retry re-wrapping payload + +func TestComputeInputPayloadChecksum(t *testing.T) { + cases := map[string]map[string]struct { + optionsFn func(*computeInputPayloadChecksum) + initContext func(context.Context) context.Context + buildInput middleware.BuildInput + + expectErr string + expectBuildErr bool + expectFinalizeErr bool + expectReadErr bool + + expectHeader http.Header + expectContentLength int64 + expectPayload []byte + expectPayloadHash string + + expectChecksumMetadata map[string]string + + expectDeferToFinalize bool + expectLogged string + }{ + "no op": { + "checksum header set known length": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.Header.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "AAAAAA==") + r = requestMust(r.SetStream(strings.NewReader("hello world"))) + r.ContentLength = 11 + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "checksum header set unknown length": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.Header.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "AAAAAA==") + r = requestMust(r.SetStream(strings.NewReader("hello world"))) + r.ContentLength = -1 + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: -1, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "no algorithm": { + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r = requestMust(r.SetStream(strings.NewReader("hello world"))) + r.ContentLength = 11 + return r + }(), + }, + expectHeader: http.Header{}, + expectContentLength: 11, + expectPayload: []byte("hello world"), + }, + "nil stream no algorithm require checksum": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.RequireChecksum = true + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + return r + }(), + }, + expectContentLength: -1, + expectHeader: http.Header{}, + }, + }, + + "build handled": { + "http nil stream": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: -1, + expectPayloadHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "http empty stream": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 0 + r = requestMust(r.SetStream(strings.NewReader(""))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: 0, + expectPayload: []byte{}, + expectPayloadHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "https empty stream unseekable": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 0 + r = requestMust(r.SetStream(&bytes.Buffer{})) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: 0, + expectPayload: nil, + expectPayloadHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "http empty stream unseekable": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 0 + r = requestMust(r.SetStream(&bytes.Buffer{})) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: 0, + expectPayload: nil, + expectPayloadHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "https nil stream": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: -1, + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "https empty stream": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 0 + r = requestMust(r.SetStream(strings.NewReader(""))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: 0, + expectPayload: []byte{}, + expectChecksumMetadata: map[string]string{ + "CRC32": "AAAAAA==", + }, + }, + "http no algorithm require checksum": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.RequireChecksum = true + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Md5": []string{"XrY7u+Ae7tCTyyK7j1rNww=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "MD5": "XrY7u+Ae7tCTyyK7j1rNww==", + }, + }, + "http no algorithm require checksum header preset": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.RequireChecksum = true + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r.Header.Set("Content-MD5", "XrY7u+Ae7tCTyyK7j1rNww==") + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Md5": []string{"XrY7u+Ae7tCTyyK7j1rNww=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "MD5": "XrY7u+Ae7tCTyyK7j1rNww==", + }, + }, + "https no algorithm require checksum": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.RequireChecksum = true + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Md5": []string{"XrY7u+Ae7tCTyyK7j1rNww=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "MD5": "XrY7u+Ae7tCTyyK7j1rNww==", + }, + }, + "http seekable": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectPayloadHash: "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9", + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "http payload hash already set": { + initContext: func(ctx context.Context) context.Context { + ctx = setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + ctx = v4.SetPayloadHash(ctx, "somehash") + return ctx + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectPayloadHash: "somehash", + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "http seekable checksum matches payload hash": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmSHA256)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Sha256": []string{"uU0nuZNNPgilLlLX2n2r+sSE7+N6U4DukIj3rOLvzek="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectPayloadHash: "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9", + expectChecksumMetadata: map[string]string{ + "SHA256": "uU0nuZNNPgilLlLX2n2r+sSE7+N6U4DukIj3rOLvzek=", + }, + }, + "http payload hash disabled": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + optionsFn: func(o *computeInputPayloadChecksum) { + o.EnableComputePayloadHash = false + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "https no trailing checksum": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.EnableTrailingChecksum = false + }, + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "with content encoding set": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.EnableTrailingChecksum = false + }, + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r.Header.Set("Content-Encoding", "gzip") + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="}, + "Content-Encoding": []string{"gzip"}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + }, + + "build error": { + "unknown algorithm": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string("unknown")) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world")))) + return r + }(), + }, + expectErr: "failed to parse algorithm", + expectBuildErr: true, + }, + "no algorithm require checksum unseekable stream": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.RequireChecksum = true + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world")))) + return r + }(), + }, + expectErr: "unseekable stream is not supported", + expectBuildErr: true, + }, + "http unseekable stream": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world")))) + return r + }(), + }, + expectErr: "unseekable stream is not supported", + expectBuildErr: true, + }, + "http stream read error": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 128 + r = requestMust(r.SetStream(&mockReadSeeker{ + Reader: iotest.ErrReader(fmt.Errorf("read error")), + })) + return r + }(), + }, + expectErr: "failed to read stream to compute hash", + expectBuildErr: true, + }, + "http stream rewind error": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(&errSeekReader{ + Reader: strings.NewReader("hello world"), + })) + return r + }(), + }, + expectErr: "failed to rewind stream", + expectBuildErr: true, + }, + "https no trailing unseekable stream": { + optionsFn: func(o *computeInputPayloadChecksum) { + o.EnableTrailingChecksum = false + }, + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world")))) + return r + }(), + }, + expectErr: "unseekable stream is not supported", + expectBuildErr: true, + }, + }, + + "finalize handled": { + "https unseekable": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Encoding": []string{"aws-chunked"}, + "X-Amz-Decoded-Content-Length": []string{"11"}, + "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + }, + expectContentLength: 52, + expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER", + expectDeferToFinalize: true, + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "https unseekable unknown length": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = -1 + r = requestMust(r.SetStream(ioutil.NopCloser(bytes.NewBuffer([]byte("hello world"))))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Encoding": []string{"aws-chunked"}, + "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + }, + expectContentLength: -1, + expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER", + expectDeferToFinalize: true, + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "https seekable": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Encoding": []string{"aws-chunked"}, + "X-Amz-Decoded-Content-Length": []string{"11"}, + "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + }, + expectContentLength: 52, + expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER", + expectDeferToFinalize: true, + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "https seekable unknown length": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = -1 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Encoding": []string{"aws-chunked"}, + "X-Amz-Decoded-Content-Length": []string{"11"}, + "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + }, + expectContentLength: 52, + expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER", + expectDeferToFinalize: true, + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "https no compute payload hash": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + optionsFn: func(o *computeInputPayloadChecksum) { + o.EnableComputePayloadHash = false + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Encoding": []string{"aws-chunked"}, + "X-Amz-Decoded-Content-Length": []string{"11"}, + "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + }, + expectContentLength: 52, + expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectDeferToFinalize: true, + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "https no decode content length": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + optionsFn: func(o *computeInputPayloadChecksum) { + o.EnableDecodedContentLengthHeader = false + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "Content-Encoding": []string{"aws-chunked"}, + "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + }, + expectContentLength: 52, + expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER", + expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectDeferToFinalize: true, + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + "with content encoding set": { + initContext: func(ctx context.Context) context.Context { + return setContextInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r.Header.Set("Content-Encoding", "gzip") + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + "X-Amz-Decoded-Content-Length": []string{"11"}, + "Content-Encoding": []string{"gzip", "aws-chunked"}, + }, + expectContentLength: 52, + expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER", + expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectDeferToFinalize: true, + expectChecksumMetadata: map[string]string{ + "CRC32": "DUoRhQ==", + }, + }, + }, + } + + for name, cs := range cases { + t.Run(name, func(t *testing.T) { + for name, c := range cs { + t.Run(name, func(t *testing.T) { + m := &computeInputPayloadChecksum{ + EnableTrailingChecksum: true, + EnableComputePayloadHash: true, + EnableDecodedContentLengthHeader: true, + } + if c.optionsFn != nil { + c.optionsFn(m) + } + + ctx := context.Background() + var logged bytes.Buffer + logger := logging.LoggerFunc( + func(classification logging.Classification, format string, v ...interface{}) { + fmt.Fprintf(&logged, format, v...) + }, + ) + + stack := middleware.NewStack("test", smithyhttp.NewStackRequest) + middleware.AddSetLoggerMiddleware(stack, logger) + + //------------------------------ + // Build handler + //------------------------------ + // On return path validate any errors were expected. + stack.Build.Add(middleware.BuildMiddlewareFunc( + "build-assert", + func(ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, + ) { + // ignore initial build input for the test case's build input. + out, metadata, err = next.HandleBuild(ctx, c.buildInput) + if err == nil && c.expectBuildErr { + t.Fatalf("expect build error, got none") + } + + if !m.buildHandlerRun { + t.Fatalf("expect build handler run") + } + return out, metadata, err + }, + ), middleware.After) + + // Build middleware + stack.Build.Add(m, middleware.After) + + // Validate defer to finalize was performed as expected + stack.Build.Add(middleware.BuildMiddlewareFunc( + "assert-defer-to-finalize", + func(ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, + ) { + if e, a := c.expectDeferToFinalize, m.deferToFinalizeHandler; e != a { + t.Fatalf("expect %v defer to finalize, got %v", e, a) + } + return next.HandleBuild(ctx, input) + }, + ), middleware.After) + + //------------------------------ + // Finalize handler + //------------------------------ + if m.EnableTrailingChecksum { + // On return path assert any errors are expected. + stack.Finalize.Add(middleware.FinalizeMiddlewareFunc( + "build-assert", + func(ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + out, metadata, err = next.HandleFinalize(ctx, input) + if err == nil && c.expectFinalizeErr { + t.Fatalf("expect finalize error, got none") + } + + return out, metadata, err + }, + ), middleware.After) + + // Add finalize middleware + stack.Finalize.Add(m, middleware.After) + } + + //------------------------------ + // Request validation + //------------------------------ + validateRequestHandler := middleware.HandlerFunc( + func(ctx context.Context, input interface{}) ( + output interface{}, metadata middleware.Metadata, err error, + ) { + request := input.(*smithyhttp.Request) + + if diff := cmp.Diff(c.expectHeader, request.Header); diff != "" { + t.Errorf("expect header to match:\n%s", diff) + } + if e, a := c.expectContentLength, request.ContentLength; e != a { + t.Errorf("expect %v content length, got %v", e, a) + } + + stream := request.GetStream() + if e, a := stream != nil, c.expectPayload != nil; e != a { + t.Fatalf("expect nil payload %t, got %t", e, a) + } + if stream == nil { + return + } + + actualPayload, err := ioutil.ReadAll(stream) + if err == nil && c.expectReadErr { + t.Fatalf("expected read error, got none") + } + + if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" { + t.Errorf("expect payload match:\n%s", diff) + } + + payloadHash := v4.GetPayloadHash(ctx) + if e, a := c.expectPayloadHash, payloadHash; e != a { + t.Errorf("expect %v payload hash, got %v", e, a) + } + + return &smithyhttp.Response{}, metadata, nil + }, + ) + + if c.initContext != nil { + ctx = c.initContext(ctx) + } + _, metadata, err := stack.HandleMiddleware(ctx, struct{}{}, validateRequestHandler) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expected error: %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expected error %v to contain %v", err, c.expectErr) + } + if c.expectErr != "" { + return + } + + if c.expectLogged != "" { + if e, a := c.expectLogged, logged.String(); !strings.Contains(a, e) { + t.Errorf("expected %q logged in:\n%s", e, a) + } + } + + // assert computed input checksums metadata + computedMetadata, ok := GetComputedInputChecksums(metadata) + if e, a := ok, (c.expectChecksumMetadata != nil); e != a { + t.Fatalf("expect checksum metadata %t, got %t, %v", e, a, computedMetadata) + } + if c.expectChecksumMetadata != nil { + if diff := cmp.Diff(c.expectChecksumMetadata, computedMetadata); diff != "" { + t.Errorf("expect checksum metadata match\n%s", diff) + } + } + }) + } + }) + } +} + +type mockReadSeeker struct { + io.Reader +} + +func (r *mockReadSeeker) Seek(int64, int) (int64, error) { + return 0, nil +} + +type errSeekReader struct { + io.Reader +} + +func (r *errSeekReader) Seek(offset int64, whence int) (int64, error) { + if whence == io.SeekCurrent { + return 0, nil + } + + return 0, fmt.Errorf("seek failed") +} + +func requestMust(r *smithyhttp.Request, err error) *smithyhttp.Request { + if err != nil { + panic(err.Error()) + } + + return r +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context.go new file mode 100644 index 0000000000..f729525497 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context.go @@ -0,0 +1,117 @@ +package checksum + +import ( + "context" + + "github.com/aws/smithy-go/middleware" +) + +// setupChecksumContext is the initial middleware that looks up the input +// used to configure checksum behavior. This middleware must be executed before +// input validation step or any other checksum middleware. +type setupInputContext struct { + // GetAlgorithm is a function to get the checksum algorithm of the + // input payload from the input parameters. + // + // Given the input parameter value, the function must return the algorithm + // and true, or false if no algorithm is specified. + GetAlgorithm func(interface{}) (string, bool) +} + +// ID for the middleware +func (m *setupInputContext) ID() string { + return "AWSChecksum:SetupInputContext" +} + +// HandleInitialize initialization middleware that setups up the checksum +// context based on the input parameters provided in the stack. +func (m *setupInputContext) HandleInitialize( + ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler, +) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, +) { + // Check if validation algorithm is specified. + if m.GetAlgorithm != nil { + // check is input resource has a checksum algorithm + algorithm, ok := m.GetAlgorithm(in.Parameters) + if ok && len(algorithm) != 0 { + ctx = setContextInputAlgorithm(ctx, algorithm) + } + } + + return next.HandleInitialize(ctx, in) +} + +// inputAlgorithmKey is the key set on context used to identify, retrieves the +// request checksum algorithm if present on the context. +type inputAlgorithmKey struct{} + +// setContextInputAlgorithm sets the request checksum algorithm on the context. +// +// Scoped to stack values. +func setContextInputAlgorithm(ctx context.Context, value string) context.Context { + return middleware.WithStackValue(ctx, inputAlgorithmKey{}, value) +} + +// getContextInputAlgorithm returns the checksum algorithm from the context if +// one was specified. Empty string is returned if one is not specified. +// +// Scoped to stack values. +func getContextInputAlgorithm(ctx context.Context) (v string) { + v, _ = middleware.GetStackValue(ctx, inputAlgorithmKey{}).(string) + return v +} + +type setupOutputContext struct { + // GetValidationMode is a function to get the checksum validation + // mode of the output payload from the input parameters. + // + // Given the input parameter value, the function must return the validation + // mode and true, or false if no mode is specified. + GetValidationMode func(interface{}) (string, bool) +} + +// ID for the middleware +func (m *setupOutputContext) ID() string { + return "AWSChecksum:SetupOutputContext" +} + +// HandleInitialize initialization middleware that setups up the checksum +// context based on the input parameters provided in the stack. +func (m *setupOutputContext) HandleInitialize( + ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler, +) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, +) { + // Check if validation mode is specified. + if m.GetValidationMode != nil { + // check is input resource has a checksum algorithm + mode, ok := m.GetValidationMode(in.Parameters) + if ok && len(mode) != 0 { + ctx = setContextOutputValidationMode(ctx, mode) + } + } + + return next.HandleInitialize(ctx, in) +} + +// outputValidationModeKey is the key set on context used to identify if +// output checksum validation is enabled. +type outputValidationModeKey struct{} + +// setContextOutputValidationMode sets the request checksum +// algorithm on the context. +// +// Scoped to stack values. +func setContextOutputValidationMode(ctx context.Context, value string) context.Context { + return middleware.WithStackValue(ctx, outputValidationModeKey{}, value) +} + +// getContextOutputValidationMode returns response checksum validation state, +// if one was specified. Empty string is returned if one is not specified. +// +// Scoped to stack values. +func getContextOutputValidationMode(ctx context.Context) (v string) { + v, _ = middleware.GetStackValue(ctx, outputValidationModeKey{}).(string) + return v +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context_test.go new file mode 100644 index 0000000000..3235983bad --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_setup_context_test.go @@ -0,0 +1,143 @@ +//go:build go1.16 +// +build go1.16 + +package checksum + +import ( + "context" + "testing" + + "github.com/aws/smithy-go/middleware" +) + +func TestSetupInput(t *testing.T) { + type Params struct { + Value string + } + + cases := map[string]struct { + inputParams interface{} + getAlgorithm func(interface{}) (string, bool) + expectValue string + }{ + "nil accessor": { + expectValue: "", + }, + "found empty": { + inputParams: Params{Value: ""}, + getAlgorithm: func(v interface{}) (string, bool) { + vv := v.(Params) + return vv.Value, true + }, + expectValue: "", + }, + "found not set": { + inputParams: Params{Value: ""}, + getAlgorithm: func(v interface{}) (string, bool) { + return "", false + }, + expectValue: "", + }, + "found": { + inputParams: Params{Value: "abc123"}, + getAlgorithm: func(v interface{}) (string, bool) { + vv := v.(Params) + return vv.Value, true + }, + expectValue: "abc123", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + m := setupInputContext{ + GetAlgorithm: c.getAlgorithm, + } + + _, _, err := m.HandleInitialize(context.Background(), + middleware.InitializeInput{Parameters: c.inputParams}, + middleware.InitializeHandlerFunc( + func(ctx context.Context, input middleware.InitializeInput) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, + ) { + v := getContextInputAlgorithm(ctx) + if e, a := c.expectValue, v; e != a { + t.Errorf("expect value %v, got %v", e, a) + } + + return out, metadata, nil + }, + )) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + }) + } +} + +func TestSetupOutput(t *testing.T) { + type Params struct { + Value string + } + + cases := map[string]struct { + inputParams interface{} + getValidationMode func(interface{}) (string, bool) + expectValue string + }{ + "nil accessor": { + expectValue: "", + }, + "found empty": { + inputParams: Params{Value: ""}, + getValidationMode: func(v interface{}) (string, bool) { + vv := v.(Params) + return vv.Value, true + }, + expectValue: "", + }, + "found not set": { + inputParams: Params{Value: ""}, + getValidationMode: func(v interface{}) (string, bool) { + return "", false + }, + expectValue: "", + }, + "found": { + inputParams: Params{Value: "abc123"}, + getValidationMode: func(v interface{}) (string, bool) { + vv := v.(Params) + return vv.Value, true + }, + expectValue: "abc123", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + m := setupOutputContext{ + GetValidationMode: c.getValidationMode, + } + + _, _, err := m.HandleInitialize(context.Background(), + middleware.InitializeInput{Parameters: c.inputParams}, + middleware.InitializeHandlerFunc( + func(ctx context.Context, input middleware.InitializeInput) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, + ) { + v := getContextOutputValidationMode(ctx) + if e, a := c.expectValue, v; e != a { + t.Errorf("expect value %v, got %v", e, a) + } + + return out, metadata, nil + }, + )) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + }) + } +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output.go new file mode 100644 index 0000000000..9fde12d86d --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output.go @@ -0,0 +1,131 @@ +package checksum + +import ( + "context" + "fmt" + "strings" + + "github.com/aws/smithy-go" + "github.com/aws/smithy-go/logging" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +// outputValidationAlgorithmsUsedKey is the metadata key for indexing the algorithms +// that were used, by the middleware's validation. +type outputValidationAlgorithmsUsedKey struct{} + +// GetOutputValidationAlgorithmsUsed returns the checksum algorithms used +// stored in the middleware Metadata. Returns false if no algorithms were +// stored in the Metadata. +func GetOutputValidationAlgorithmsUsed(m middleware.Metadata) ([]string, bool) { + vs, ok := m.Get(outputValidationAlgorithmsUsedKey{}).([]string) + return vs, ok +} + +// SetOutputValidationAlgorithmsUsed stores the checksum algorithms used in the +// middleware Metadata. +func SetOutputValidationAlgorithmsUsed(m *middleware.Metadata, vs []string) { + m.Set(outputValidationAlgorithmsUsedKey{}, vs) +} + +// validateOutputPayloadChecksum middleware computes payload checksum of the +// received response and validates with checksum returned by the service. +type validateOutputPayloadChecksum struct { + // Algorithms represents a priority-ordered list of valid checksum + // algorithm that should be validated when present in HTTP response + // headers. + Algorithms []Algorithm + + // IgnoreMultipartValidation indicates multipart checksums ending with "-#" + // will be ignored. + IgnoreMultipartValidation bool + + // When set the middleware will log when output does not have checksum or + // algorithm to validate. + LogValidationSkipped bool + + // When set the middleware will log when the output contains a multipart + // checksum that was, skipped and not validated. + LogMultipartValidationSkipped bool +} + +func (m *validateOutputPayloadChecksum) ID() string { + return "AWSChecksum:ValidateOutputPayloadChecksum" +} + +// HandleDeserialize is a Deserialize middleware that wraps the HTTP response +// body with an io.ReadCloser that will validate the its checksum. +func (m *validateOutputPayloadChecksum) HandleDeserialize( + ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler, +) ( + out middleware.DeserializeOutput, metadata middleware.Metadata, err error, +) { + out, metadata, err = next.HandleDeserialize(ctx, in) + if err != nil { + return out, metadata, err + } + + // If there is no validation mode specified nothing is supported. + if mode := getContextOutputValidationMode(ctx); mode != "ENABLED" { + return out, metadata, err + } + + response, ok := out.RawResponse.(*smithyhttp.Response) + if !ok { + return out, metadata, &smithy.DeserializationError{ + Err: fmt.Errorf("unknown transport type %T", out.RawResponse), + } + } + + var expectedChecksum string + var algorithmToUse Algorithm + for _, algorithm := range m.Algorithms { + value := response.Header.Get(AlgorithmHTTPHeader(algorithm)) + if len(value) == 0 { + continue + } + + expectedChecksum = value + algorithmToUse = algorithm + } + + // TODO this must validate the validation mode is set to enabled. + + logger := middleware.GetLogger(ctx) + + // Skip validation if no checksum algorithm or checksum is available. + if len(expectedChecksum) == 0 || len(algorithmToUse) == 0 { + if m.LogValidationSkipped { + // TODO this probably should have more information about the + // operation output that won't be validated. + logger.Logf(logging.Warn, + "Response has no supported checksum. Not validating response payload.") + } + return out, metadata, nil + } + + // Ignore multipart validation + if m.IgnoreMultipartValidation && strings.Contains(expectedChecksum, "-") { + if m.LogMultipartValidationSkipped { + // TODO this probably should have more information about the + // operation output that won't be validated. + logger.Logf(logging.Warn, "Skipped validation of multipart checksum.") + } + return out, metadata, nil + } + + body, err := newValidateChecksumReader(response.Body, algorithmToUse, expectedChecksum) + if err != nil { + return out, metadata, fmt.Errorf("failed to create checksum validation reader, %w", err) + } + response.Body = body + + // Update the metadata to include the set of the checksum algorithms that + // will be validated. + SetOutputValidationAlgorithmsUsed(&metadata, []string{ + string(algorithmToUse), + }) + + return out, metadata, nil +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output_test.go b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output_test.go new file mode 100644 index 0000000000..e6cf7054e8 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/middleware_validate_output_test.go @@ -0,0 +1,263 @@ +//go:build go1.16 +// +build go1.16 + +package checksum + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "net/http" + "strings" + "testing" + "testing/iotest" + + "github.com/aws/smithy-go/logging" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/google/go-cmp/cmp" +) + +func TestValidateOutputPayloadChecksum(t *testing.T) { + cases := map[string]struct { + response *smithyhttp.Response + validateOptions func(*validateOutputPayloadChecksum) + modifyContext func(context.Context) context.Context + expectHaveAlgorithmsUsed bool + expectAlgorithmsUsed []string + expectErr string + expectReadErr string + expectLogged string + expectPayload []byte + }{ + "success": { + modifyContext: func(ctx context.Context) context.Context { + return setContextOutputValidationMode(ctx, "ENABLED") + }, + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "DUoRhQ==") + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("hello world")), + }, + }, + expectHaveAlgorithmsUsed: true, + expectAlgorithmsUsed: []string{"CRC32"}, + expectPayload: []byte("hello world"), + }, + "failure": { + modifyContext: func(ctx context.Context) context.Context { + return setContextOutputValidationMode(ctx, "ENABLED") + }, + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "AAAAAA==") + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("hello world")), + }, + }, + expectReadErr: "checksum did not match", + }, + "read error": { + modifyContext: func(ctx context.Context) context.Context { + return setContextOutputValidationMode(ctx, "ENABLED") + }, + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "AAAAAA==") + return h + }(), + Body: ioutil.NopCloser(iotest.ErrReader(fmt.Errorf("some read error"))), + }, + }, + expectReadErr: "some read error", + }, + "unsupported algorithm": { + modifyContext: func(ctx context.Context) context.Context { + return setContextOutputValidationMode(ctx, "ENABLED") + }, + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + h.Set(AlgorithmHTTPHeader("unsupported"), "AAAAAA==") + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("hello world")), + }, + }, + expectLogged: "no supported checksum", + expectPayload: []byte("hello world"), + }, + "no output validation model": { + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("hello world")), + }, + }, + expectPayload: []byte("hello world"), + }, + "unknown output validation model": { + modifyContext: func(ctx context.Context) context.Context { + return setContextOutputValidationMode(ctx, "something else") + }, + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("hello world")), + }, + }, + expectPayload: []byte("hello world"), + }, + "success ignore multipart checksum": { + modifyContext: func(ctx context.Context) context.Context { + return setContextOutputValidationMode(ctx, "ENABLED") + }, + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "DUoRhQ==") + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("hello world")), + }, + }, + validateOptions: func(o *validateOutputPayloadChecksum) { + o.IgnoreMultipartValidation = true + }, + expectHaveAlgorithmsUsed: true, + expectAlgorithmsUsed: []string{"CRC32"}, + expectPayload: []byte("hello world"), + }, + "success skip ignore multipart checksum": { + modifyContext: func(ctx context.Context) context.Context { + return setContextOutputValidationMode(ctx, "ENABLED") + }, + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + h.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "DUoRhQ==-12") + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("hello world")), + }, + }, + validateOptions: func(o *validateOutputPayloadChecksum) { + o.IgnoreMultipartValidation = true + }, + expectLogged: "Skipped validation of multipart checksum", + expectPayload: []byte("hello world"), + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var logged bytes.Buffer + ctx := middleware.SetLogger(context.Background(), logging.LoggerFunc( + func(classification logging.Classification, format string, v ...interface{}) { + fmt.Fprintf(&logged, format, v...) + })) + + if c.modifyContext != nil { + ctx = c.modifyContext(ctx) + } + + validateOutput := validateOutputPayloadChecksum{ + Algorithms: []Algorithm{ + AlgorithmSHA1, AlgorithmCRC32, AlgorithmCRC32C, + }, + LogValidationSkipped: true, + LogMultipartValidationSkipped: true, + } + if c.validateOptions != nil { + c.validateOptions(&validateOutput) + } + + out, meta, err := validateOutput.HandleDeserialize(ctx, + middleware.DeserializeInput{}, + middleware.DeserializeHandlerFunc( + func(ctx context.Context, input middleware.DeserializeInput) ( + out middleware.DeserializeOutput, metadata middleware.Metadata, err error, + ) { + out.RawResponse = c.response + return out, metadata, nil + }, + ), + ) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + response := out.RawResponse.(*smithyhttp.Response) + + actualPayload, err := ioutil.ReadAll(response.Body) + if err == nil && len(c.expectReadErr) != 0 { + t.Fatalf("expected read error: %v, got none", c.expectReadErr) + } + if err != nil && len(c.expectReadErr) == 0 { + t.Fatalf("expect no read error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectReadErr) { + t.Fatalf("expected read error %v to contain %v", err, c.expectReadErr) + } + if c.expectReadErr != "" { + return + } + + if e, a := c.expectLogged, logged.String(); !strings.Contains(a, e) || !((e == "") == (a == "")) { + t.Errorf("expected %q logged in:\n%s", e, a) + } + + if diff := cmp.Diff(string(c.expectPayload), string(actualPayload)); diff != "" { + t.Errorf("expect payload match:\n%s", diff) + } + + if err = response.Body.Close(); err != nil { + t.Errorf("expect no close error, got %v", err) + } + + values, ok := GetOutputValidationAlgorithmsUsed(meta) + if ok != c.expectHaveAlgorithmsUsed { + t.Errorf("expect metadata to contain algorithms used, %t", c.expectHaveAlgorithmsUsed) + } + if diff := cmp.Diff(c.expectAlgorithmsUsed, values); diff != "" { + t.Errorf("expect algorithms used to match\n%s", diff) + } + }) + } +} diff --git a/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/ya.make b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/ya.make new file mode 100644 index 0000000000..85e34f039f --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/ya.make @@ -0,0 +1,28 @@ +GO_LIBRARY() + +LICENSE(Apache-2.0) + +SRCS( + algorithms.go + aws_chunked_encoding.go + go_module_metadata.go + middleware_add.go + middleware_compute_input_checksum.go + middleware_setup_context.go + middleware_validate_output.go +) + +GO_TEST_SRCS( + algorithms_test.go + aws_chunked_encoding_test.go + middleware_add_test.go + middleware_compute_input_checksum_test.go + middleware_setup_context_test.go + middleware_validate_output_test.go +) + +END() + +RECURSE( + gotest +) |