aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/aws/aws-sdk-go-v2/service/internal/checksum/algorithms.go
blob: a17041c35d0710508e09f8bf18f799b87d9a0fbf (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
package checksum

import (
	"crypto/md5"
	"crypto/sha1"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"hash"
	"hash/crc32"
	"io"
	"strings"
	"sync"
)

// Algorithm names a checksum algorithm supported by this package. The
// underlying string is the algorithm's canonical name.
type Algorithm string

// Identifiers for each supported checksum algorithm.
const (
	// AlgorithmCRC32C identifies the CRC32C (Castagnoli) checksum.
	AlgorithmCRC32C = Algorithm("CRC32C")

	// AlgorithmCRC32 identifies the CRC32 (IEEE) checksum.
	AlgorithmCRC32 = Algorithm("CRC32")

	// AlgorithmSHA1 identifies the SHA1 digest.
	AlgorithmSHA1 = Algorithm("SHA1")

	// AlgorithmSHA256 identifies the SHA256 digest.
	AlgorithmSHA256 = Algorithm("SHA256")
)

// supportedAlgorithms lists every algorithm the package recognizes. It is the
// lookup table consulted by ParseAlgorithm and FilterSupportedAlgorithms.
var supportedAlgorithms = []Algorithm{
	AlgorithmCRC32C,
	AlgorithmCRC32,
	AlgorithmSHA1,
	AlgorithmSHA256,
}

// String returns the algorithm's name as a plain string.
func (a Algorithm) String() string {
	return string(a)
}

// ParseAlgorithm resolves v to one of the supported checksum algorithms,
// ignoring letter case. It returns the canonical Algorithm on a match, and an
// empty Algorithm with an error when v names no supported algorithm.
func ParseAlgorithm(v string) (Algorithm, error) {
	for i := range supportedAlgorithms {
		if candidate := supportedAlgorithms[i]; strings.EqualFold(v, string(candidate)) {
			return candidate, nil
		}
	}

	var none Algorithm
	return none, fmt.Errorf("unknown checksum algorithm, %v", v)
}

// FilterSupportedAlgorithms returns the supported algorithms named in vs.
// Names are matched without regard to case, unsupported names are dropped,
// and each algorithm appears at most once, in the order it is first seen in
// vs.
func FilterSupportedAlgorithms(vs []string) []Algorithm {
	seen := make(map[Algorithm]struct{})
	result := make([]Algorithm, 0, len(supportedAlgorithms))

	for _, name := range vs {
		for _, candidate := range supportedAlgorithms {
			// Skip candidates already collected as well as those
			// that do not match the requested name.
			_, dup := seen[candidate]
			if dup || !strings.EqualFold(name, string(candidate)) {
				continue
			}

			seen[candidate] = struct{}{}
			result = append(result, candidate)
		}
	}

	return result
}

// NewAlgorithmHash creates a fresh hash.Hash implementing the checksum
// algorithm v. An error is returned, along with a nil hash, when v is not one
// of the supported algorithms.
func NewAlgorithmHash(v Algorithm) (hash.Hash, error) {
	var h hash.Hash

	switch v {
	case AlgorithmCRC32:
		h = crc32.NewIEEE()
	case AlgorithmCRC32C:
		h = crc32.New(crc32.MakeTable(crc32.Castagnoli))
	case AlgorithmSHA1:
		h = sha1.New()
	case AlgorithmSHA256:
		h = sha256.New()
	default:
		return nil, fmt.Errorf("unknown checksum algorithm, %v", v)
	}

	return h, nil
}

// AlgorithmChecksumLength reports the size, in bytes, of the raw checksum
// produced by algorithm v. An error is returned when v is not one of the
// supported algorithms.
func AlgorithmChecksumLength(v Algorithm) (int, error) {
	switch v {
	case AlgorithmCRC32, AlgorithmCRC32C:
		// Both CRC variants share the crc32 package's checksum size.
		return crc32.Size, nil
	case AlgorithmSHA1:
		return sha1.Size, nil
	case AlgorithmSHA256:
		return sha256.Size, nil
	}

	return 0, fmt.Errorf("unknown checksum algorithm, %v", v)
}

// awsChecksumHeaderPrefix is the leading portion shared by every checksum
// header name built by AlgorithmHTTPHeader.
const awsChecksumHeaderPrefix = "x-amz-checksum-"

// AlgorithmHTTPHeader builds the name of the HTTP header carrying the checksum
// for algorithm v: the common prefix followed by the lower-cased algorithm
// name.
func AlgorithmHTTPHeader(v Algorithm) string {
	suffix := strings.ToLower(string(v))
	return awsChecksumHeaderPrefix + suffix
}

// base64EncodeHashSum returns the standard base64 encoding of the current sum
// of the running hash h. The sum covers whatever content has been written to
// h so far; h itself is not reset.
func base64EncodeHashSum(h hash.Hash) []byte {
	digest := h.Sum(nil)

	enc := base64.StdEncoding
	encoded := make([]byte, enc.EncodedLen(len(digest)))
	enc.Encode(encoded, digest)

	return encoded
}

// hexEncodeHashSum returns the hexadecimal encoding of the current sum of the
// running hash h. The sum covers whatever content has been written to h so
// far; h itself is not reset.
func hexEncodeHashSum(h hash.Hash) []byte {
	digest := h.Sum(nil)

	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest)

	return encoded
}

// computeMD5Checksum consumes r until it is exhausted and returns the base64
// encoded MD5 checksum of everything read. If reading from r fails, the
// failure is wrapped and returned with a nil checksum.
func computeMD5Checksum(r io.Reader) ([]byte, error) {
	digest := md5.New()

	// Any failure reported by the copy is attributed to the reader.
	_, err := io.Copy(digest, r)
	if err != nil {
		return nil, fmt.Errorf("failed compute MD5 hash of reader, %w", err)
	}

	return base64EncodeHashSum(digest), nil
}

// computeChecksumReader provides a reader wrapping an underlying io.Reader to
// compute the checksum of the stream's bytes.
//
// The checksum becomes available through Base64Checksum once Read has
// observed io.EOF from the underlying stream.
type computeChecksumReader struct {
	// Set once by newComputeChecksumReader: stream tees the wrapped reader
	// into hasher, and base64ChecksumLen is the base64 encoded length of
	// algorithm's checksum.
	stream            io.Reader
	algorithm         Algorithm
	hasher            hash.Hash
	base64ChecksumLen int

	// mux guards lockedChecksum and lockedErr, which are written by Read and
	// read by Base64Checksum.
	mux            sync.RWMutex
	lockedChecksum string
	lockedErr      error
}

// newComputeChecksumReader builds a computeChecksumReader that hashes the
// bytes read from stream using algorithm. It fails when algorithm is not a
// supported checksum algorithm.
func newComputeChecksumReader(stream io.Reader, algorithm Algorithm) (*computeChecksumReader, error) {
	hasher, err := NewAlgorithmHash(algorithm)
	if err != nil {
		return nil, err
	}

	rawLen, err := AlgorithmChecksumLength(algorithm)
	if err != nil {
		return nil, err
	}

	reader := new(computeChecksumReader)
	reader.algorithm = algorithm
	reader.hasher = hasher
	// Every byte read through the tee is also written into the hasher.
	reader.stream = io.TeeReader(stream, hasher)
	reader.base64ChecksumLen = base64.StdEncoding.EncodedLen(rawLen)

	return reader, nil
}

// Read reads from the wrapped stream, feeding every byte read into the
// running hash. When the stream reports io.EOF the base64 checksum is
// computed and stored, after which it can be retrieved with Base64Checksum.
// Any other error from the stream is recorded and returned as-is.
func (r *computeChecksumReader) Read(p []byte) (int, error) {
	n, err := r.stream.Read(p)

	switch err {
	case nil:
		return n, nil

	case io.EOF:
		// Encode outside the lock; only the store needs protection.
		sum := string(base64EncodeHashSum(r.hasher))

		r.mux.Lock()
		r.lockedChecksum = sum
		r.mux.Unlock()

	default:
		r.mux.Lock()
		r.lockedErr = err
		r.mux.Unlock()
	}

	return n, err
}

// Algorithm returns the checksum algorithm the reader was created with.
func (r *computeChecksumReader) Algorithm() Algorithm { return r.algorithm }

// Base64ChecksumLength reports how many bytes the base64 encoded checksum of
// the reader's algorithm occupies.
func (r *computeChecksumReader) Base64ChecksumLength() int {
	length := r.base64ChecksumLen
	return length
}

// Base64Checksum returns the base64 encoded checksum computed by Read. If the
// underlying reader failed with a non-EOF error, that error is returned
// instead.
//
// It is safe to call concurrently with Read, but it returns an error until
// the underlying reader has returned EOF.
func (r *computeChecksumReader) Base64Checksum() (string, error) {
	r.mux.RLock()
	defer r.mux.RUnlock()

	switch {
	case r.lockedErr != nil:
		return "", r.lockedErr
	case r.lockedChecksum == "":
		return "", fmt.Errorf(
			"checksum not available yet, called before reader returns EOF",
		)
	}

	return r.lockedChecksum, nil
}

// validateChecksumReader implements io.ReadCloser interface. The wrapper
// performs checksum validation when the underlying reader has been fully read.
type validateChecksumReader struct {
	originalBody   io.ReadCloser // wrapped body; closed by Close
	body           io.Reader     // originalBody teed into hasher; read by Read
	hasher         hash.Hash     // running hash of the bytes read through body
	algorithm      Algorithm     // algorithm reported in a validationError
	expectChecksum string        // base64 checksum the stream is expected to produce
}

// newValidateChecksumReader wraps body in a validateChecksumReader that
// hashes the bytes read using algorithm and, once body has been fully read,
// compares the result against expectChecksum. It fails when algorithm is not
// a supported checksum algorithm.
func newValidateChecksumReader(
	body io.ReadCloser,
	algorithm Algorithm,
	expectChecksum string,
) (*validateChecksumReader, error) {
	hasher, err := NewAlgorithmHash(algorithm)
	if err != nil {
		return nil, err
	}

	reader := new(validateChecksumReader)
	reader.originalBody = body
	// Every byte read through the tee is also written into the hasher.
	reader.body = io.TeeReader(body, hasher)
	reader.hasher = hasher
	reader.algorithm = algorithm
	reader.expectChecksum = expectChecksum

	return reader, nil
}

// Read reads from the underlying stream, updating the running hash with the
// bytes read. Once the stream reports io.EOF, the checksum of everything read
// is compared against the expected checksum; on a mismatch the validation
// error is returned in place of io.EOF.
//
// Any non-EOF error from the underlying stream is returned unchanged, and no
// checksum validation is performed for that call.
func (c *validateChecksumReader) Read(p []byte) (n int, err error) {
	n, err = c.body.Read(p)
	if err != io.EOF {
		return n, err
	}

	// End of stream: the hash now covers the complete payload.
	if mismatch := c.validateChecksum(); mismatch != nil {
		return n, mismatch
	}

	return n, err
}

// Close closes the original body the reader was created with and returns
// whatever error that close reports.
func (c *validateChecksumReader) Close() (err error) {
	err = c.originalBody.Close()
	return err
}

// validateChecksum compares the base64 encoded checksum of the bytes hashed
// so far against the expected checksum. It returns nil when the two are
// identical, and otherwise a validationError carrying the algorithm, the
// expected value, and the actual value.
//
// The comparison is exact. Base64 is a case-sensitive encoding, so two values
// that differ only in letter case encode different checksum bytes and must
// not be treated as a match.
func (c *validateChecksumReader) validateChecksum() error {
	// Compute base64 encoded checksum hash of the payload's read bytes.
	actual := string(base64EncodeHashSum(c.hasher))
	if actual == c.expectChecksum {
		return nil
	}

	return validationError{
		Algorithm: c.algorithm,
		Expect:    c.expectChecksum,
		Actual:    actual,
	}
}

// validationError reports that the checksum computed for a stream did not
// match the checksum that was expected for it.
type validationError struct {
	Algorithm Algorithm
	Expect    string
	Actual    string
}

// Error formats the mismatch, naming the algorithm along with the expected
// and actual checksum values.
func (v validationError) Error() string {
	const format = "checksum did not match: algorithm %v, expect %v, actual %v"
	return fmt.Sprintf(format, v.Algorithm, v.Expect, v.Actual)
}