author | nkozlovskiy <nmk@ydb.tech> | 2023-09-29 12:24:06 +0300
---|---|---
committer | nkozlovskiy <nmk@ydb.tech> | 2023-09-29 12:41:34 +0300
commit | e0e3e1717e3d33762ce61950504f9637a6e669ed (patch) |
tree | bca3ff6939b10ed60c3d5c12439963a1146b9711 /contrib/python/s3transfer/py3/tests/unit |
parent | 38f2c5852db84c7b4d83adfcb009eb61541d1ccd (diff) |
download | ydb-e0e3e1717e3d33762ce61950504f9637a6e669ed.tar.gz |
add ydb deps
Diffstat (limited to 'contrib/python/s3transfer/py3/tests/unit')
15 files changed, 7169 insertions, 0 deletions
diff --git a/contrib/python/s3transfer/py3/tests/unit/__init__.py b/contrib/python/s3transfer/py3/tests/unit/__init__.py
new file mode 100644
index 0000000000..79ef91c6a2
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/__init__.py
@@ -0,0 +1,12 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py b/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py
new file mode 100644
index 0000000000..b796f8f24c
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py
@@ -0,0 +1,452 @@
+# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import tempfile
+
+from s3transfer.bandwidth import (
+    BandwidthLimitedStream,
+    BandwidthLimiter,
+    BandwidthRateTracker,
+    ConsumptionScheduler,
+    LeakyBucket,
+    RequestExceededException,
+    RequestToken,
+    TimeUtils,
+)
+from s3transfer.futures import TransferCoordinator
+from __tests__ import mock, unittest
+
+
+class FixedIncrementalTickTimeUtils(TimeUtils):
+    def __init__(self, seconds_per_tick=1.0):
+        self._count = 0
+        self._seconds_per_tick = seconds_per_tick
+
+    def time(self):
+        current_count = self._count
+        self._count += self._seconds_per_tick
+        return current_count
+
+
+class TestTimeUtils(unittest.TestCase):
+    @mock.patch('time.time')
+    def test_time(self, mock_time):
+        mock_return_val = 1
+        mock_time.return_value = mock_return_val
+        time_utils = TimeUtils()
+        self.assertEqual(time_utils.time(), mock_return_val)
+
+    @mock.patch('time.sleep')
+    def test_sleep(self, mock_sleep):
+        time_utils = TimeUtils()
+        time_utils.sleep(1)
+        self.assertEqual(mock_sleep.call_args_list, [mock.call(1)])
+
+
+class BaseBandwidthLimitTest(unittest.TestCase):
+    def setUp(self):
+        self.leaky_bucket = mock.Mock(LeakyBucket)
+        self.time_utils = mock.Mock(TimeUtils)
+        self.tempdir = tempfile.mkdtemp()
+        self.content = b'a' * 1024 * 1024
+        self.filename = os.path.join(self.tempdir, 'myfile')
+        with open(self.filename, 'wb') as f:
+            f.write(self.content)
+        self.coordinator = TransferCoordinator()
+
+    def tearDown(self):
+        shutil.rmtree(self.tempdir)
+
+    def assert_consume_calls(self, amts):
+        expected_consume_args = [mock.call(amt, mock.ANY) for amt in amts]
+        self.assertEqual(
+            self.leaky_bucket.consume.call_args_list, expected_consume_args
+        )
+
+
+class TestBandwidthLimiter(BaseBandwidthLimitTest):
+    def setUp(self):
+        super().setUp()
+        self.bandwidth_limiter = BandwidthLimiter(self.leaky_bucket)
+
+    def test_get_bandwidth_limited_stream(self):
+        with open(self.filename, 'rb') as f:
+            stream = self.bandwidth_limiter.get_bandwith_limited_stream(
+                f, self.coordinator
+            )
+            self.assertIsInstance(stream, BandwidthLimitedStream)
+            self.assertEqual(stream.read(len(self.content)), self.content)
+            self.assert_consume_calls(amts=[len(self.content)])
+
+    def test_get_disabled_bandwidth_limited_stream(self):
+        with open(self.filename, 'rb') as f:
+            stream = self.bandwidth_limiter.get_bandwith_limited_stream(
+                f, self.coordinator, enabled=False
+            )
+            self.assertIsInstance(stream, BandwidthLimitedStream)
+            self.assertEqual(stream.read(len(self.content)), self.content)
+            self.leaky_bucket.consume.assert_not_called()
+
+
+class TestBandwidthLimitedStream(BaseBandwidthLimitTest):
+    def setUp(self):
+        super().setUp()
+        self.bytes_threshold = 1
+
+    def tearDown(self):
+        shutil.rmtree(self.tempdir)
+
+    def get_bandwidth_limited_stream(self, f):
+        return BandwidthLimitedStream(
+            f,
+            self.leaky_bucket,
+            self.coordinator,
+            self.time_utils,
+            self.bytes_threshold,
+        )
+
+    def assert_sleep_calls(self, amts):
+        expected_sleep_args_list = [mock.call(amt) for amt in amts]
+        self.assertEqual(
+            self.time_utils.sleep.call_args_list, expected_sleep_args_list
+        )
+
+    def get_unique_consume_request_tokens(self):
+        return {
+            call_args[0][1]
+            for call_args in self.leaky_bucket.consume.call_args_list
+        }
+
+    def test_read(self):
+        with open(self.filename, 'rb') as f:
+            stream = self.get_bandwidth_limited_stream(f)
+            data = stream.read(len(self.content))
+            self.assertEqual(self.content, data)
+            self.assert_consume_calls(amts=[len(self.content)])
+            self.assert_sleep_calls(amts=[])
+
+    def test_retries_on_request_exceeded(self):
+        with open(self.filename, 'rb') as f:
+            stream = self.get_bandwidth_limited_stream(f)
+            retry_time = 1
+            amt_requested = len(self.content)
+            self.leaky_bucket.consume.side_effect = [
+                RequestExceededException(amt_requested, retry_time),
+                len(self.content),
+            ]
+            data = stream.read(len(self.content))
+            self.assertEqual(self.content, data)
+            self.assert_consume_calls(amts=[amt_requested, amt_requested])
+            self.assert_sleep_calls(amts=[retry_time])
+
+    def test_with_transfer_coordinator_exception(self):
+        self.coordinator.set_exception(ValueError())
+        with open(self.filename, 'rb') as f:
+            stream = self.get_bandwidth_limited_stream(f)
+            with self.assertRaises(ValueError):
+                stream.read(len(self.content))
+
+    def test_read_when_bandwidth_limiting_disabled(self):
+        with open(self.filename, 'rb') as f:
+            stream = self.get_bandwidth_limited_stream(f)
+            stream.disable_bandwidth_limiting()
+            data = stream.read(len(self.content))
+            self.assertEqual(self.content, data)
+            self.assertFalse(self.leaky_bucket.consume.called)
+
+    def test_read_toggle_disable_enable_bandwidth_limiting(self):
+        with open(self.filename, 'rb') as f:
+            stream = self.get_bandwidth_limited_stream(f)
+            stream.disable_bandwidth_limiting()
+            data = stream.read(1)
+            self.assertEqual(self.content[:1], data)
+            self.assert_consume_calls(amts=[])
+            stream.enable_bandwidth_limiting()
+            data = stream.read(len(self.content) - 1)
+            self.assertEqual(self.content[1:], data)
+            self.assert_consume_calls(amts=[len(self.content) - 1])
+
+    def test_seek(self):
+        mock_fileobj = mock.Mock()
+        stream = self.get_bandwidth_limited_stream(mock_fileobj)
+        stream.seek(1)
+        self.assertEqual(mock_fileobj.seek.call_args_list, [mock.call(1, 0)])
+
+    def test_tell(self):
+        mock_fileobj = mock.Mock()
+        stream = self.get_bandwidth_limited_stream(mock_fileobj)
+        stream.tell()
+        self.assertEqual(mock_fileobj.tell.call_args_list, [mock.call()])
+
+    def test_close(self):
+        mock_fileobj = mock.Mock()
+        stream = self.get_bandwidth_limited_stream(mock_fileobj)
+        stream.close()
+        self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()])
+
+    def test_context_manager(self):
+        mock_fileobj = mock.Mock()
+        stream = self.get_bandwidth_limited_stream(mock_fileobj)
+        with stream as stream_handle:
+            self.assertIs(stream_handle, stream)
+        self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()])
+
+    def test_reuses_request_token(self):
+        with open(self.filename, 'rb') as f:
+            stream = self.get_bandwidth_limited_stream(f)
+            stream.read(1)
+            stream.read(1)
+        self.assertEqual(len(self.get_unique_consume_request_tokens()), 1)
+
+    def test_request_tokens_unique_per_stream(self):
+        with open(self.filename, 'rb') as f:
+            stream = self.get_bandwidth_limited_stream(f)
+            stream.read(1)
+        with open(self.filename, 'rb') as f:
+            stream = self.get_bandwidth_limited_stream(f)
+            stream.read(1)
+        self.assertEqual(len(self.get_unique_consume_request_tokens()), 2)
+
+    def test_call_consume_after_reaching_threshold(self):
+        self.bytes_threshold = 2
+        with open(self.filename, 'rb') as f:
+            stream = self.get_bandwidth_limited_stream(f)
+            self.assertEqual(stream.read(1), self.content[:1])
+            self.assert_consume_calls(amts=[])
+            self.assertEqual(stream.read(1), self.content[1:2])
+            self.assert_consume_calls(amts=[2])
+
+    def test_resets_after_reaching_threshold(self):
+        self.bytes_threshold = 2
+        with open(self.filename, 'rb') as f:
+            stream = self.get_bandwidth_limited_stream(f)
+            self.assertEqual(stream.read(2), self.content[:2])
+            self.assert_consume_calls(amts=[2])
+            self.assertEqual(stream.read(1), self.content[2:3])
+            self.assert_consume_calls(amts=[2])
+
+    def test_pending_bytes_seen_on_close(self):
+        self.bytes_threshold = 2
+        with open(self.filename, 'rb') as f:
+            stream = self.get_bandwidth_limited_stream(f)
+            self.assertEqual(stream.read(1), self.content[:1])
+            self.assert_consume_calls(amts=[])
+            stream.close()
+            self.assert_consume_calls(amts=[1])
+
+    def test_no_bytes_remaining_on(self):
+        self.bytes_threshold = 2
+        with open(self.filename, 'rb') as f:
+            stream = self.get_bandwidth_limited_stream(f)
+            self.assertEqual(stream.read(2), self.content[:2])
+            self.assert_consume_calls(amts=[2])
+            stream.close()
+            # There should have been no more consume() calls made
+            # as all bytes have been accounted for in the previous
+            # consume() call.
+            self.assert_consume_calls(amts=[2])
+
+    def test_disable_bandwidth_limiting_with_pending_bytes_seen_on_close(self):
+        self.bytes_threshold = 2
+        with open(self.filename, 'rb') as f:
+            stream = self.get_bandwidth_limited_stream(f)
+            self.assertEqual(stream.read(1), self.content[:1])
+            self.assert_consume_calls(amts=[])
+            stream.disable_bandwidth_limiting()
+            stream.close()
+            self.assert_consume_calls(amts=[])
+
+    def test_signal_transferring(self):
+        with open(self.filename, 'rb') as f:
+            stream = self.get_bandwidth_limited_stream(f)
+            stream.signal_not_transferring()
+            data = stream.read(1)
+            self.assertEqual(self.content[:1], data)
+            self.assert_consume_calls(amts=[])
+            stream.signal_transferring()
+            data = stream.read(len(self.content) - 1)
+            self.assertEqual(self.content[1:], data)
+            self.assert_consume_calls(amts=[len(self.content) - 1])
+
+
+class TestLeakyBucket(unittest.TestCase):
+    def setUp(self):
+        self.max_rate = 1
+        self.time_now = 1.0
+        self.time_utils = mock.Mock(TimeUtils)
+        self.time_utils.time.return_value = self.time_now
+        self.scheduler = mock.Mock(ConsumptionScheduler)
+        self.scheduler.is_scheduled.return_value = False
+        self.rate_tracker = mock.Mock(BandwidthRateTracker)
+        self.leaky_bucket = LeakyBucket(
+            self.max_rate, self.time_utils, self.rate_tracker, self.scheduler
+        )
+
+    def set_projected_rate(self, rate):
+        self.rate_tracker.get_projected_rate.return_value = rate
+
+    def set_retry_time(self, retry_time):
+        self.scheduler.schedule_consumption.return_value = retry_time
+
+    def assert_recorded_consumed_amt(self, expected_amt):
+        self.assertEqual(
+            self.rate_tracker.record_consumption_rate.call_args,
+            mock.call(expected_amt, self.time_utils.time.return_value),
+        )
+
+    def assert_was_scheduled(self, amt, token):
+        self.assertEqual(
+            self.scheduler.schedule_consumption.call_args,
+            mock.call(amt, token, amt / (self.max_rate)),
+        )
+
+    def assert_nothing_scheduled(self):
+        self.assertFalse(self.scheduler.schedule_consumption.called)
+
+    def assert_processed_request_token(self, request_token):
+        self.assertEqual(
+            self.scheduler.process_scheduled_consumption.call_args,
+            mock.call(request_token),
+        )
+
+    def test_consume_under_max_rate(self):
+        amt = 1
+        self.set_projected_rate(self.max_rate / 2)
+        self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
+        self.assert_recorded_consumed_amt(amt)
+        self.assert_nothing_scheduled()
+
+    def test_consume_at_max_rate(self):
+        amt = 1
+        self.set_projected_rate(self.max_rate)
+        self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
+        self.assert_recorded_consumed_amt(amt)
+        self.assert_nothing_scheduled()
+
+    def test_consume_over_max_rate(self):
+        amt = 1
+        retry_time = 2.0
+        self.set_projected_rate(self.max_rate + 1)
+        self.set_retry_time(retry_time)
+        request_token = RequestToken()
+        try:
+            self.leaky_bucket.consume(amt, request_token)
+            self.fail('A RequestExceededException should have been thrown')
+        except RequestExceededException as e:
+            self.assertEqual(e.requested_amt, amt)
+            self.assertEqual(e.retry_time, retry_time)
+        self.assert_was_scheduled(amt, request_token)
+
+    def test_consume_with_scheduled_retry(self):
+        amt = 1
+        self.set_projected_rate(self.max_rate + 1)
+        self.scheduler.is_scheduled.return_value = True
+        request_token = RequestToken()
+        self.assertEqual(self.leaky_bucket.consume(amt, request_token), amt)
+        # Nothing new should have been scheduled but the request token
+        # should have been processed.
+        self.assert_nothing_scheduled()
+        self.assert_processed_request_token(request_token)
+
+
+class TestConsumptionScheduler(unittest.TestCase):
+    def setUp(self):
+        self.scheduler = ConsumptionScheduler()
+
+    def test_schedule_consumption(self):
+        token = RequestToken()
+        consume_time = 5
+        actual_wait_time = self.scheduler.schedule_consumption(
+            1, token, consume_time
+        )
+        self.assertEqual(consume_time, actual_wait_time)
+
+    def test_schedule_consumption_for_multiple_requests(self):
+        token = RequestToken()
+        consume_time = 5
+        actual_wait_time = self.scheduler.schedule_consumption(
+            1, token, consume_time
+        )
+        self.assertEqual(consume_time, actual_wait_time)
+
+        other_consume_time = 3
+        other_token = RequestToken()
+        next_wait_time = self.scheduler.schedule_consumption(
+            1, other_token, other_consume_time
+        )
+
+        # This wait time should be the previous time plus its desired
+        # wait time.
+        self.assertEqual(next_wait_time, consume_time + other_consume_time)
+
+    def test_is_scheduled(self):
+        token = RequestToken()
+        consume_time = 5
+        self.scheduler.schedule_consumption(1, token, consume_time)
+        self.assertTrue(self.scheduler.is_scheduled(token))
+
+    def test_is_not_scheduled(self):
+        self.assertFalse(self.scheduler.is_scheduled(RequestToken()))
+
+    def test_process_scheduled_consumption(self):
+        token = RequestToken()
+        consume_time = 5
+        self.scheduler.schedule_consumption(1, token, consume_time)
+        self.scheduler.process_scheduled_consumption(token)
+        self.assertFalse(self.scheduler.is_scheduled(token))
+        different_time = 7
+        # The previous consume time should have no effect on the next wait
+        # time as it has been completed.
+        self.assertEqual(
+            self.scheduler.schedule_consumption(1, token, different_time),
+            different_time,
+        )
+
+
+class TestBandwidthRateTracker(unittest.TestCase):
+    def setUp(self):
+        self.alpha = 0.8
+        self.rate_tracker = BandwidthRateTracker(self.alpha)
+
+    def test_current_rate_at_initilizations(self):
+        self.assertEqual(self.rate_tracker.current_rate, 0.0)
+
+    def test_current_rate_after_one_recorded_point(self):
+        self.rate_tracker.record_consumption_rate(1, 1)
+        # There is no last time point to do a diff against so return a
+        # current rate of 0.0
+        self.assertEqual(self.rate_tracker.current_rate, 0.0)
+
+    def test_current_rate(self):
+        self.rate_tracker.record_consumption_rate(1, 1)
+        self.rate_tracker.record_consumption_rate(1, 2)
+        self.rate_tracker.record_consumption_rate(1, 3)
+        self.assertEqual(self.rate_tracker.current_rate, 0.96)
+
+    def test_get_projected_rate_at_initilizations(self):
+        self.assertEqual(self.rate_tracker.get_projected_rate(1, 1), 0.0)
+
+    def test_get_projected_rate(self):
+        self.rate_tracker.record_consumption_rate(1, 1)
+        self.rate_tracker.record_consumption_rate(1, 2)
+        projected_rate = self.rate_tracker.get_projected_rate(1, 3)
+        self.assertEqual(projected_rate, 0.96)
+        self.rate_tracker.record_consumption_rate(1, 3)
+        self.assertEqual(self.rate_tracker.current_rate, projected_rate)
+
+    def test_get_projected_rate_for_same_timestamp(self):
+        self.rate_tracker.record_consumption_rate(1, 1)
+        self.assertEqual(
+            self.rate_tracker.get_projected_rate(1, 1), float('inf')
+        )
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_compat.py b/contrib/python/s3transfer/py3/tests/unit/test_compat.py
new file mode 100644
index 0000000000..78fdc25845
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_compat.py
@@ -0,0 +1,105 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import signal
+import tempfile
+from io import BytesIO
+
+from s3transfer.compat import BaseManager, readable, seekable
+from __tests__ import skip_if_windows, unittest
+
+
+class ErrorRaisingSeekWrapper:
+    """An object wrapper that throws an error when seeked on
+
+    :param fileobj: The fileobj that it wraps
+    :param exception: The exception to raise when seeked on.
+    """
+
+    def __init__(self, fileobj, exception):
+        self._fileobj = fileobj
+        self._exception = exception
+
+    def seek(self, offset, whence=0):
+        raise self._exception
+
+    def tell(self):
+        return self._fileobj.tell()
+
+
+class TestSeekable(unittest.TestCase):
+    def setUp(self):
+        self.tempdir = tempfile.mkdtemp()
+        self.filename = os.path.join(self.tempdir, 'foo')
+
+    def tearDown(self):
+        shutil.rmtree(self.tempdir)
+
+    def test_seekable_fileobj(self):
+        with open(self.filename, 'w') as f:
+            self.assertTrue(seekable(f))
+
+    def test_non_file_like_obj(self):
+        # Fails because there is no seekable(), seek(), nor tell()
+        self.assertFalse(seekable(object()))
+
+    def test_non_seekable_ioerror(self):
+        # Should return False if IOError is thrown.
+        with open(self.filename, 'w') as f:
+            self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, IOError())))
+
+    def test_non_seekable_oserror(self):
+        # Should return False if OSError is thrown.
+        with open(self.filename, 'w') as f:
+            self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, OSError())))
+
+
+class TestReadable(unittest.TestCase):
+    def test_readable_fileobj(self):
+        with tempfile.TemporaryFile() as f:
+            self.assertTrue(readable(f))
+
+    def test_readable_file_like_obj(self):
+        self.assertTrue(readable(BytesIO()))
+
+    def test_non_file_like_obj(self):
+        self.assertFalse(readable(object()))
+
+
+class TestBaseManager(unittest.TestCase):
+    def create_pid_manager(self):
+        class PIDManager(BaseManager):
+            pass
+
+        PIDManager.register('getpid', os.getpid)
+        return PIDManager()
+
+    def get_pid(self, pid_manager):
+        pid = pid_manager.getpid()
+        # A proxy object is returned back. The needed value can be acquired
+        # from its repr by converting that to an integer.
+        return int(str(pid))
+
+    @skip_if_windows('os.kill() with SIGINT not supported on Windows')
+    def test_can_provide_signal_handler_initializers_to_start(self):
+        manager = self.create_pid_manager()
+        manager.start(signal.signal, (signal.SIGINT, signal.SIG_IGN))
+        pid = self.get_pid(manager)
+        try:
+            os.kill(pid, signal.SIGINT)
+        except KeyboardInterrupt:
+            pass
+        # Try using the manager after the os.kill on the parent process. The
+        # manager should not have died and should still be usable.
+        self.assertEqual(pid, self.get_pid(manager))
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_copies.py b/contrib/python/s3transfer/py3/tests/unit/test_copies.py
new file mode 100644
index 0000000000..4e6992a28b
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_copies.py
@@ -0,0 +1,207 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.copies import CopyObjectTask, CopyPartTask
+from __tests__ import BaseTaskTest, RecordingSubscriber
+
+
+class BaseCopyTaskTest(BaseTaskTest):
+    def setUp(self):
+        super().setUp()
+        self.bucket = 'mybucket'
+        self.key = 'mykey'
+        self.copy_source = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'}
+        self.extra_args = {}
+        self.callbacks = []
+        self.size = 5
+
+
+class TestCopyObjectTask(BaseCopyTaskTest):
+    def get_copy_task(self, **kwargs):
+        default_kwargs = {
+            'client': self.client,
+            'copy_source': self.copy_source,
+            'bucket': self.bucket,
+            'key': self.key,
+            'extra_args': self.extra_args,
+            'callbacks': self.callbacks,
+            'size': self.size,
+        }
+        default_kwargs.update(kwargs)
+        return self.get_task(CopyObjectTask, main_kwargs=default_kwargs)
+
+    def test_main(self):
+        self.stubber.add_response(
+            'copy_object',
+            service_response={},
+            expected_params={
+                'Bucket': self.bucket,
+                'Key': self.key,
+                'CopySource': self.copy_source,
+            },
+        )
+        task = self.get_copy_task()
+        task()
+
+        self.stubber.assert_no_pending_responses()
+
+    def test_extra_args(self):
+        self.extra_args['ACL'] = 'private'
+        self.stubber.add_response(
+            'copy_object',
+            service_response={},
+            expected_params={
+                'Bucket': self.bucket,
+                'Key': self.key,
+                'CopySource': self.copy_source,
+                'ACL': 'private',
+            },
+        )
+        task = self.get_copy_task()
+        task()
+
+        self.stubber.assert_no_pending_responses()
+
+    def test_callbacks_invoked(self):
+        subscriber = RecordingSubscriber()
+        self.callbacks.append(subscriber.on_progress)
+        self.stubber.add_response(
+            'copy_object',
+            service_response={},
+            expected_params={
+                'Bucket': self.bucket,
+                'Key': self.key,
+                'CopySource': self.copy_source,
+            },
+        )
+        task = self.get_copy_task()
+        task()
+
+        self.stubber.assert_no_pending_responses()
+        self.assertEqual(subscriber.calculate_bytes_seen(), self.size)
+
+
+class TestCopyPartTask(BaseCopyTaskTest):
+    def setUp(self):
+        super().setUp()
+        self.copy_source_range = 'bytes=5-9'
+        self.extra_args['CopySourceRange'] = self.copy_source_range
+        self.upload_id = 'myuploadid'
+        self.part_number = 1
+        self.result_etag = 'my-etag'
+        self.checksum_sha1 = 'my-checksum_sha1'
+
+    def get_copy_task(self, **kwargs):
+        default_kwargs = {
+            'client': self.client,
+            'copy_source': self.copy_source,
+            'bucket': self.bucket,
+            'key': self.key,
+            'upload_id': self.upload_id,
+            'part_number': self.part_number,
+            'extra_args': self.extra_args,
+            'callbacks': self.callbacks,
+            'size': self.size,
+        }
+        default_kwargs.update(kwargs)
+        return self.get_task(CopyPartTask, main_kwargs=default_kwargs)
+
+    def test_main(self):
+        self.stubber.add_response(
+            'upload_part_copy',
+            service_response={'CopyPartResult': {'ETag': self.result_etag}},
+            expected_params={
+                'Bucket': self.bucket,
+                'Key': self.key,
+                'CopySource': self.copy_source,
+                'UploadId': self.upload_id,
+                'PartNumber': self.part_number,
+                'CopySourceRange': self.copy_source_range,
+            },
+        )
+        task = self.get_copy_task()
+        self.assertEqual(
+            task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
+        )
+        self.stubber.assert_no_pending_responses()
+
+    def test_main_with_checksum(self):
+        self.stubber.add_response(
+            'upload_part_copy',
+            service_response={
+                'CopyPartResult': {
+                    'ETag': self.result_etag,
+                    'ChecksumSHA1': self.checksum_sha1,
+                }
+            },
+            expected_params={
+                'Bucket': self.bucket,
+                'Key': self.key,
+                'CopySource': self.copy_source,
+                'UploadId': self.upload_id,
+                'PartNumber': self.part_number,
+                'CopySourceRange': self.copy_source_range,
+            },
+        )
+        task = self.get_copy_task(checksum_algorithm="sha1")
+        self.assertEqual(
+            task(),
+            {
+                'PartNumber': self.part_number,
+                'ETag': self.result_etag,
+                'ChecksumSHA1': self.checksum_sha1,
+            },
+        )
+        self.stubber.assert_no_pending_responses()
+
+    def test_extra_args(self):
+        self.extra_args['RequestPayer'] = 'requester'
+        self.stubber.add_response(
+            'upload_part_copy',
+            service_response={'CopyPartResult': {'ETag': self.result_etag}},
+            expected_params={
+                'Bucket': self.bucket,
+                'Key': self.key,
+                'CopySource': self.copy_source,
+                'UploadId': self.upload_id,
+                'PartNumber': self.part_number,
+                'CopySourceRange': self.copy_source_range,
+                'RequestPayer': 'requester',
+            },
+        )
+        task = self.get_copy_task()
+        self.assertEqual(
+            task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
+        )
+        self.stubber.assert_no_pending_responses()
+
+    def test_callbacks_invoked(self):
+        subscriber = RecordingSubscriber()
+        self.callbacks.append(subscriber.on_progress)
+        self.stubber.add_response(
+            'upload_part_copy',
+            service_response={'CopyPartResult': {'ETag': self.result_etag}},
+            expected_params={
+                'Bucket': self.bucket,
+                'Key': self.key,
+                'CopySource': self.copy_source,
+                'UploadId': self.upload_id,
+                'PartNumber': self.part_number,
+                'CopySourceRange': self.copy_source_range,
+            },
+        )
+        task = self.get_copy_task()
+        self.assertEqual(
+            task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
+        )
+        self.stubber.assert_no_pending_responses()
+        self.assertEqual(subscriber.calculate_bytes_seen(), self.size)
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_crt.py b/contrib/python/s3transfer/py3/tests/unit/test_crt.py
new file mode 100644
index 0000000000..8c32668eab
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_crt.py
@@ -0,0 +1,173 @@
+# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from botocore.credentials import CredentialResolver, ReadOnlyCredentials
+from botocore.session import Session
+
+from s3transfer.exceptions import TransferNotDoneError
+from s3transfer.utils import CallArgs
+from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest
+
+if HAS_CRT:
+    import awscrt.s3
+
+    import s3transfer.crt
+
+
+class CustomFutureException(Exception):
+    pass
+
+
+@requires_crt
+class TestBotocoreCRTRequestSerializer(unittest.TestCase):
+    def setUp(self):
+        self.region = 'us-west-2'
+        self.session = Session()
+        self.session.set_config_variable('region', self.region)
+        self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer(
+            self.session
+        )
+        self.bucket = "test_bucket"
+        self.key = "test_key"
+        self.files = FileCreator()
+        self.filename = self.files.create_file('myfile', 'my content')
+        self.expected_path = "/" + self.bucket + "/" + self.key
+        self.expected_host = "s3.%s.amazonaws.com" % (self.region)
+
+    def tearDown(self):
+        self.files.remove_all()
+
+    def test_upload_request(self):
+        callargs = CallArgs(
+            bucket=self.bucket,
+            key=self.key,
+            fileobj=self.filename,
+            extra_args={},
+            subscribers=[],
+        )
+        coordinator = s3transfer.crt.CRTTransferCoordinator()
+        future = s3transfer.crt.CRTTransferFuture(
+            s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
+        )
+        crt_request = self.request_serializer.serialize_http_request(
+            "put_object", future
+        )
+        self.assertEqual("PUT", crt_request.method)
+        self.assertEqual(self.expected_path, crt_request.path)
+        self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+        self.assertIsNone(crt_request.headers.get("Authorization"))
+
+    def test_download_request(self):
+        callargs = CallArgs(
+            bucket=self.bucket,
+            key=self.key,
+            fileobj=self.filename,
+            extra_args={},
+            subscribers=[],
+        )
+        coordinator = s3transfer.crt.CRTTransferCoordinator()
+        future = s3transfer.crt.CRTTransferFuture(
+            s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
+        )
+        crt_request = self.request_serializer.serialize_http_request(
+            "get_object", future
+        )
+        self.assertEqual("GET", crt_request.method)
+        self.assertEqual(self.expected_path, crt_request.path)
+        self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+        self.assertIsNone(crt_request.headers.get("Authorization"))
+
+    def test_delete_request(self):
+        callargs = CallArgs(
+            bucket=self.bucket, key=self.key, extra_args={}, subscribers=[]
+        )
+        coordinator = s3transfer.crt.CRTTransferCoordinator()
+        future = s3transfer.crt.CRTTransferFuture(
+            s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
+        )
+        crt_request = self.request_serializer.serialize_http_request(
+            "delete_object", future
+        )
+        self.assertEqual("DELETE", crt_request.method)
+        self.assertEqual(self.expected_path, crt_request.path)
+        self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+        self.assertIsNone(crt_request.headers.get("Authorization"))
+
+
+@requires_crt
+class TestCRTCredentialProviderAdapter(unittest.TestCase):
+    def setUp(self):
+        self.botocore_credential_provider = mock.Mock(CredentialResolver)
+        self.access_key = "access_key"
+        self.secret_key = "secret_key"
+        self.token = "token"
+        self.botocore_credential_provider.load_credentials.return_value.get_frozen_credentials.return_value = ReadOnlyCredentials(
+            self.access_key, self.secret_key, self.token
+        )
+
+    def _call_adapter_and_check(self, credentails_provider_adapter):
+        credentials = credentails_provider_adapter()
+        self.assertEqual(credentials.access_key_id, self.access_key)
+        self.assertEqual(credentials.secret_access_key, self.secret_key)
+        self.assertEqual(credentials.session_token, self.token)
+
+    def test_fetch_crt_credentials_successfully(self):
+        credentails_provider_adapter = (
+            s3transfer.crt.CRTCredentialProviderAdapter(
+                self.botocore_credential_provider
+            )
+        )
+        self._call_adapter_and_check(credentails_provider_adapter)
+
+    def test_load_credentials_once(self):
+        credentails_provider_adapter = (
+            s3transfer.crt.CRTCredentialProviderAdapter(
+                self.botocore_credential_provider
+            )
+        )
+        called_times = 5
+        for i in range(called_times):
+            self._call_adapter_and_check(credentails_provider_adapter)
+        # Assert that load_credentials of the botocore credential provider
+        # will only be called once
+        self.assertEqual(
+            self.botocore_credential_provider.load_credentials.call_count, 1
+        )
+
+
+@requires_crt
+class TestCRTTransferFuture(unittest.TestCase):
+    def setUp(self):
+        self.mock_s3_request = mock.Mock(awscrt.s3.S3RequestType)
+        self.mock_crt_future = mock.Mock(awscrt.s3.Future)
+        self.mock_s3_request.finished_future = self.mock_crt_future
+        self.coordinator = s3transfer.crt.CRTTransferCoordinator()
+        self.coordinator.set_s3_request(self.mock_s3_request)
+        self.future = s3transfer.crt.CRTTransferFuture(
+            coordinator=self.coordinator
+        )
+
+    def test_set_exception(self):
+        self.future.set_exception(CustomFutureException())
+        with self.assertRaises(CustomFutureException):
+            self.future.result()
+
+    def test_set_exception_raises_error_when_not_done(self):
+        self.mock_crt_future.done.return_value = False
+        with self.assertRaises(TransferNotDoneError):
+            self.future.set_exception(CustomFutureException())
+
+    def test_set_exception_can_override_previous_exception(self):
+        self.future.set_exception(Exception())
+        self.future.set_exception(CustomFutureException())
+        with self.assertRaises(CustomFutureException):
+            self.future.result()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_delete.py b/contrib/python/s3transfer/py3/tests/unit/test_delete.py
new file mode 100644
index 0000000000..23b77112f2
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_delete.py
@@ -0,0 +1,67 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.delete import DeleteObjectTask
+from __tests__ import BaseTaskTest
+
+
+class TestDeleteObjectTask(BaseTaskTest):
+    def setUp(self):
+        super().setUp()
+        self.bucket = 'mybucket'
+        self.key = 'mykey'
+        self.extra_args = {}
+        self.callbacks = []
+
+    def get_delete_task(self, **kwargs):
+        default_kwargs = {
+            'client': self.client,
+            'bucket': self.bucket,
+            'key': self.key,
+            'extra_args': self.extra_args,
+        }
+        default_kwargs.update(kwargs)
+        return self.get_task(DeleteObjectTask, main_kwargs=default_kwargs)
+
+    def test_main(self):
+        self.stubber.add_response(
+            'delete_object',
+            service_response={},
+            expected_params={
+                'Bucket': self.bucket,
+                'Key': self.key,
+            },
+        )
+        task = self.get_delete_task()
+        task()
+
+        self.stubber.assert_no_pending_responses()
+
+    def test_extra_args(self):
+        self.extra_args['MFA'] = 'mfa-code'
+        self.extra_args['VersionId'] = '12345'
+        self.stubber.add_response(
+            'delete_object',
+            service_response={},
+            expected_params={
+                'Bucket': self.bucket,
+                'Key': self.key,
+                # These extra_args should be injected into the
+                # expected params for the delete_object call.
+                'MFA': 'mfa-code',
+                'VersionId': '12345',
+            },
+        )
+        task = self.get_delete_task()
+        task()
+
+        self.stubber.assert_no_pending_responses()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_download.py b/contrib/python/s3transfer/py3/tests/unit/test_download.py
new file mode 100644
index 0000000000..e8b5fe1f86
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_download.py
@@ -0,0 +1,999 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import copy
+import os
+import shutil
+import tempfile
+from io import BytesIO
+
+from s3transfer.bandwidth import BandwidthLimiter
+from s3transfer.compat import SOCKET_ERROR
+from s3transfer.download import (
+    CompleteDownloadNOOPTask,
+    DeferQueue,
+    DownloadChunkIterator,
+    DownloadFilenameOutputManager,
+    DownloadNonSeekableOutputManager,
+    DownloadSeekableOutputManager,
+    DownloadSpecialFilenameOutputManager,
+    DownloadSubmissionTask,
+    GetObjectTask,
+    ImmediatelyWriteIOGetObjectTask,
+    IOCloseTask,
+    IORenameFileTask,
+    IOStreamingWriteTask,
+    IOWriteTask,
+)
+from s3transfer.exceptions import RetriesExceededError
+from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG, BoundedExecutor
+from s3transfer.utils import CallArgs, OSUtils
+from __tests__ import (
+    BaseSubmissionTaskTest,
+    BaseTaskTest,
+    FileCreator,
+    NonSeekableWriter,
+    RecordingExecutor,
+    StreamWithError,
+    mock,
+    unittest,
+)
+
+
+class DownloadException(Exception):
+    pass
+
+
+class WriteCollector:
+    """A utility to collect information about writes and seeks"""
+
+    def __init__(self):
+        self._pos = 0
+        self.writes = []
+
+    def seek(self, pos, whence=0):
+        self._pos = pos
+
+    def write(self, data):
+        self.writes.append((self._pos, data))
+        self._pos += len(data)
+
+
+class AlwaysIndicatesSpecialFileOSUtils(OSUtils):
+    """OSUtil that always returns True for is_special_file"""
+
+    def is_special_file(self, filename):
+        return True
+
+
+class CancelledStreamWrapper:
+    """A wrapper to trigger a cancellation while stream reading
+
+    Forces the transfer coordinator to cancel after a certain amount of reads
+    :param stream: The underlying stream to read from
+    :param transfer_coordinator: The coordinator for the transfer
+    :param num_reads: On which read to signal a cancellation. 0 is the first
+        read.
+    """
+
+    def __init__(self, stream, transfer_coordinator, num_reads=0):
+        self._stream = stream
+        self._transfer_coordinator = transfer_coordinator
+        self._num_reads = num_reads
+        self._count = 0
+
+    def read(self, *args, **kwargs):
+        if self._num_reads == self._count:
+            self._transfer_coordinator.cancel()
+        self._stream.read(*args, **kwargs)
+        self._count += 1
+
+
+class BaseDownloadOutputManagerTest(BaseTaskTest):
+    def setUp(self):
+        super().setUp()
+        self.osutil = OSUtils()
+
+        # Create a file to write to
+        self.tempdir = tempfile.mkdtemp()
+        self.filename = os.path.join(self.tempdir, 'myfile')
+
+        self.call_args = CallArgs(fileobj=self.filename)
+        self.future = self.get_transfer_future(self.call_args)
+        self.io_executor = BoundedExecutor(1000, 1)
+
+    def tearDown(self):
+        super().tearDown()
+        shutil.rmtree(self.tempdir)
+
+
+class TestDownloadFilenameOutputManager(BaseDownloadOutputManagerTest):
+    def setUp(self):
+        super().setUp()
+        self.download_output_manager = DownloadFilenameOutputManager(
+            self.osutil,
+            self.transfer_coordinator,
+            io_executor=self.io_executor,
+        )
+
+    def test_is_compatible(self):
+        self.assertTrue(
+            self.download_output_manager.is_compatible(
+                self.filename, self.osutil
+            )
+        )
+
+    def test_get_download_task_tag(self):
+        self.assertIsNone(
+            self.download_output_manager.get_download_task_tag()
+        )
+
+    def test_get_fileobj_for_io_writes(self):
+        with self.download_output_manager.get_fileobj_for_io_writes(
+            self.future
+        ) as f:
+            # Ensure it is a file like object returned
+            self.assertTrue(hasattr(f, 'read'))
+            self.assertTrue(hasattr(f, 'seek'))
+            # Make sure the name of the file returned is not the same as the
+            # final filename as we should be writing to a temporary file.
+            self.assertNotEqual(f.name, self.filename)
+
+    def test_get_final_io_task(self):
+        ref_contents = b'my_contents'
+        with self.download_output_manager.get_fileobj_for_io_writes(
+            self.future
+        ) as f:
+            temp_filename = f.name
+            # Write some data to test that the data gets moved over to the
+            # final location.
+            f.write(ref_contents)
+            final_task = self.download_output_manager.get_final_io_task()
+            # Make sure it is the appropriate task.
+            self.assertIsInstance(final_task, IORenameFileTask)
+        final_task()
+        # Make sure the temp_file gets removed
+        self.assertFalse(os.path.exists(temp_filename))
+        # Make sure whatever was written to the temp file got moved to
+        # the final filename
+        with open(self.filename, 'rb') as f:
+            self.assertEqual(f.read(), ref_contents)
+
+    def test_can_queue_file_io_task(self):
+        fileobj = WriteCollector()
+        self.download_output_manager.queue_file_io_task(
+            fileobj=fileobj, data='foo', offset=0
+        )
+        self.download_output_manager.queue_file_io_task(
+            fileobj=fileobj, data='bar', offset=3
+        )
+        self.io_executor.shutdown()
+        self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+    def test_get_file_io_write_task(self):
+        fileobj = WriteCollector()
+        io_write_task = self.download_output_manager.get_io_write_task(
+            fileobj=fileobj, data='foo', offset=3
+        )
+        self.assertIsInstance(io_write_task, IOWriteTask)
+
+        io_write_task()
+        self.assertEqual(fileobj.writes, [(3, 'foo')])
+
+
+class TestDownloadSpecialFilenameOutputManager(BaseDownloadOutputManagerTest):
+    def setUp(self):
+        super().setUp()
+        self.osutil = AlwaysIndicatesSpecialFileOSUtils()
+        self.download_output_manager = DownloadSpecialFilenameOutputManager(
+            self.osutil,
+            self.transfer_coordinator,
+            io_executor=self.io_executor,
+        )
+
+    def test_is_compatible_for_special_file(self):
+        self.assertTrue(
+            self.download_output_manager.is_compatible(
+                self.filename, AlwaysIndicatesSpecialFileOSUtils()
+            )
+        )
+
+    def test_is_not_compatible_for_non_special_file(self):
+        self.assertFalse(
+            self.download_output_manager.is_compatible(
+                self.filename, OSUtils()
+            )
+        )
+
+    def test_get_fileobj_for_io_writes(self):
+        with self.download_output_manager.get_fileobj_for_io_writes(
+            self.future
+        ) as f:
+            # Ensure it is a file like object returned
+            self.assertTrue(hasattr(f, 'read'))
+            # Make sure the name of the file returned is the same as the
+            # final filename as we should not be writing to a temporary file.
+            self.assertEqual(f.name, self.filename)
+
+    def test_get_final_io_task(self):
+        self.assertIsInstance(
+            self.download_output_manager.get_final_io_task(), IOCloseTask
+        )
+
+    def test_can_queue_file_io_task(self):
+        fileobj = WriteCollector()
+        self.download_output_manager.queue_file_io_task(
+            fileobj=fileobj, data='foo', offset=0
+        )
+        self.download_output_manager.queue_file_io_task(
+            fileobj=fileobj, data='bar', offset=3
+        )
+        self.io_executor.shutdown()
+        self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+
+class TestDownloadSeekableOutputManager(BaseDownloadOutputManagerTest):
+    def setUp(self):
+        super().setUp()
+        self.download_output_manager = DownloadSeekableOutputManager(
+            self.osutil,
+            self.transfer_coordinator,
+            io_executor=self.io_executor,
+        )
+
+        # Create a fileobj to write to
+        self.fileobj = open(self.filename, 'wb')
+
+        self.call_args = CallArgs(fileobj=self.fileobj)
+        self.future = self.get_transfer_future(self.call_args)
+
+    def tearDown(self):
+        self.fileobj.close()
+        super().tearDown()
+
+    def test_is_compatible(self):
+        self.assertTrue(
+            self.download_output_manager.is_compatible(
+                self.fileobj, self.osutil
+            )
+        )
+
+    def test_is_compatible_bytes_io(self):
+        self.assertTrue(
+            self.download_output_manager.is_compatible(BytesIO(), self.osutil)
+        )
+
+    def test_not_compatible_for_non_filelike_obj(self):
+        self.assertFalse(
+            self.download_output_manager.is_compatible(object(), self.osutil)
+        )
+
+    def test_get_download_task_tag(self):
+        self.assertIsNone(
+            self.download_output_manager.get_download_task_tag()
+        )
+
+    def test_get_fileobj_for_io_writes(self):
+        self.assertIs(
+            self.download_output_manager.get_fileobj_for_io_writes(
+                self.future
+            ),
+            self.fileobj,
+        )
+
+    def test_get_final_io_task(self):
+        self.assertIsInstance(
+            self.download_output_manager.get_final_io_task(),
+            CompleteDownloadNOOPTask,
+        )
+
+    def test_can_queue_file_io_task(self):
+        fileobj = WriteCollector()
+        self.download_output_manager.queue_file_io_task(
+            fileobj=fileobj, data='foo', offset=0
+        )
+        self.download_output_manager.queue_file_io_task(
+            fileobj=fileobj, data='bar', offset=3
+        )
+        self.io_executor.shutdown()
+        self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+    def test_get_file_io_write_task(self):
+        fileobj = WriteCollector()
+        io_write_task = self.download_output_manager.get_io_write_task(
+            fileobj=fileobj, data='foo', offset=3
+        )
+        self.assertIsInstance(io_write_task, IOWriteTask)
+
+        io_write_task()
+        self.assertEqual(fileobj.writes, [(3, 'foo')])
+
+
+class TestDownloadNonSeekableOutputManager(BaseDownloadOutputManagerTest):
+    def setUp(self):
+        super().setUp()
+        self.download_output_manager = DownloadNonSeekableOutputManager(
+            self.osutil, self.transfer_coordinator, io_executor=None
+        )
+
+    def test_is_compatible_with_seekable_stream(self):
+        with open(self.filename, 'wb') as f:
+            self.assertTrue(
+                self.download_output_manager.is_compatible(f, self.osutil)
+            )
+
+    def test_not_compatible_with_filename(self):
+        self.assertFalse(
+            self.download_output_manager.is_compatible(
+                self.filename, self.osutil
+            )
+        )
+
+    def test_compatible_with_non_seekable_stream(self):
+        class NonSeekable:
+            def write(self, data):
+                pass
+
+        f = NonSeekable()
+        self.assertTrue(
+            self.download_output_manager.is_compatible(f, self.osutil)
+        )
+
+    def test_is_compatible_with_bytesio(self):
+        self.assertTrue(
+            self.download_output_manager.is_compatible(BytesIO(), self.osutil)
+        )
+
+    def test_get_download_task_tag(self):
+        self.assertIs(
+            self.download_output_manager.get_download_task_tag(),
+            IN_MEMORY_DOWNLOAD_TAG,
+        )
+
+    def test_submit_writes_from_internal_queue(self):
+        class FakeQueue:
+            def request_writes(self, offset, data):
+                return [
+                    {'offset': 0, 'data': 'foo'},
+                    {'offset': 3, 'data': 'bar'},
+                ]
+
+        q = FakeQueue()
+        io_executor = BoundedExecutor(1000, 1)
+        manager = DownloadNonSeekableOutputManager(
+            self.osutil,
+            self.transfer_coordinator,
+            io_executor=io_executor,
+            defer_queue=q,
+        )
+        fileobj = WriteCollector()
+        manager.queue_file_io_task(fileobj=fileobj, data='foo', offset=1)
+        io_executor.shutdown()
+        self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+    def test_get_file_io_write_task(self):
+        fileobj = WriteCollector()
+        io_write_task = self.download_output_manager.get_io_write_task(
+            fileobj=fileobj, data='foo', offset=1
+        )
+        self.assertIsInstance(io_write_task, IOStreamingWriteTask)
+
+        io_write_task()
+        self.assertEqual(fileobj.writes, [(0, 'foo')])
+
+
+class TestDownloadSubmissionTask(BaseSubmissionTaskTest):
+    def setUp(self):
+        super().setUp()
+        self.tempdir = tempfile.mkdtemp()
+        self.filename = os.path.join(self.tempdir, 'myfile')
+
+        self.bucket = 'mybucket'
+        self.key = 'mykey'
+        self.extra_args = {}
+        self.subscribers = []
+
+        # Create a stream to read from
+        self.content = b'my content'
+        self.stream = BytesIO(self.content)
+
+        # A list to keep track of all of the bodies sent over the wire
+        # and their order.
+
+        self.call_args = self.get_call_args()
+        self.transfer_future = self.get_transfer_future(self.call_args)
+        self.io_executor = BoundedExecutor(1000, 1)
+        self.submission_main_kwargs = {
+            'client': self.client,
+            'config': self.config,
+            'osutil': self.osutil,
+            'request_executor': self.executor,
+            'io_executor': self.io_executor,
+            'transfer_future': self.transfer_future,
+        }
+        self.submission_task = self.get_download_submission_task()
+
+    def tearDown(self):
+        super().tearDown()
+        shutil.rmtree(self.tempdir)
+
+    def get_call_args(self, **kwargs):
+        default_call_args = {
+            'fileobj': self.filename,
+            'bucket': self.bucket,
+            'key': self.key,
+            'extra_args': self.extra_args,
+            'subscribers': self.subscribers,
+        }
+        default_call_args.update(kwargs)
+        return CallArgs(**default_call_args)
+
+    def wrap_executor_in_recorder(self):
+        self.executor = RecordingExecutor(self.executor)
+        self.submission_main_kwargs['request_executor'] = self.executor
+
+    def use_fileobj_in_call_args(self, fileobj):
+        self.call_args = self.get_call_args(fileobj=fileobj)
+        self.transfer_future = self.get_transfer_future(self.call_args)
+        self.submission_main_kwargs['transfer_future'] = self.transfer_future
+
+    def assert_tag_for_get_object(self, tag_value):
+        submissions_to_compare = self.executor.submissions
+        if len(submissions_to_compare) > 1:
+            # If it was ranged get, make sure we do not include the join task.
+            submissions_to_compare = submissions_to_compare[:-1]
+        for submission in submissions_to_compare:
+            self.assertEqual(submission['tag'], tag_value)
+
+    def add_head_object_response(self):
+        self.stubber.add_response(
+            'head_object', {'ContentLength': len(self.content)}
+        )
+
+    def add_get_responses(self):
+        chunksize = self.config.multipart_chunksize
+        for i in range(0, len(self.content), chunksize):
+            if i + chunksize > len(self.content):
+                stream = BytesIO(self.content[i:])
+                self.stubber.add_response('get_object', {'Body': stream})
+            else:
+                stream = BytesIO(self.content[i : i + chunksize])
+                self.stubber.add_response('get_object', {'Body': stream})
+
+    def configure_for_ranged_get(self):
+        self.config.multipart_threshold = 1
+        self.config.multipart_chunksize = 4
+
+    def get_download_submission_task(self):
+        return self.get_task(
+            DownloadSubmissionTask, main_kwargs=self.submission_main_kwargs
+        )
+
+    def wait_and_assert_completed_successfully(self, submission_task):
+        submission_task()
+        self.transfer_future.result()
+        self.stubber.assert_no_pending_responses()
+
+    def test_submits_no_tag_for_get_object_filename(self):
+        self.wrap_executor_in_recorder()
+        self.add_head_object_response()
+        self.add_get_responses()
+
+        self.submission_task = self.get_download_submission_task()
+        self.wait_and_assert_completed_successfully(self.submission_task)
+
+        # Make sure no tag to limit the task was associated with that
+        # task submission.
+        self.assert_tag_for_get_object(None)
+
+    def test_submits_no_tag_for_ranged_get_filename(self):
+        self.wrap_executor_in_recorder()
+        self.configure_for_ranged_get()
+        self.add_head_object_response()
+        self.add_get_responses()
+
+        self.submission_task = self.get_download_submission_task()
+        self.wait_and_assert_completed_successfully(self.submission_task)
+
+        # Make sure no tag to limit the task was associated with that
+        # task submission.
+        self.assert_tag_for_get_object(None)
+
+    def test_submits_no_tag_for_get_object_fileobj(self):
+        self.wrap_executor_in_recorder()
+        self.add_head_object_response()
+        self.add_get_responses()
+
+        with open(self.filename, 'wb') as f:
+            self.use_fileobj_in_call_args(f)
+            self.submission_task = self.get_download_submission_task()
+            self.wait_and_assert_completed_successfully(self.submission_task)
+
+        # Make sure no tag to limit the task was associated with that
+        # task submission.
+        self.assert_tag_for_get_object(None)
+
+    def test_submits_no_tag_for_ranged_get_object_fileobj(self):
+        self.wrap_executor_in_recorder()
+        self.configure_for_ranged_get()
+        self.add_head_object_response()
+        self.add_get_responses()
+
+        with open(self.filename, 'wb') as f:
+            self.use_fileobj_in_call_args(f)
+            self.submission_task = self.get_download_submission_task()
+            self.wait_and_assert_completed_successfully(self.submission_task)
+
+        # Make sure no tag to limit the task was associated with that
+        # task submission.
+        self.assert_tag_for_get_object(None)
+
+    def tests_submits_tag_for_get_object_nonseekable_fileobj(self):
+        self.wrap_executor_in_recorder()
+        self.add_head_object_response()
+        self.add_get_responses()
+
+        with open(self.filename, 'wb') as f:
+            self.use_fileobj_in_call_args(NonSeekableWriter(f))
+            self.submission_task = self.get_download_submission_task()
+            self.wait_and_assert_completed_successfully(self.submission_task)
+
+        # Make sure the IN_MEMORY_DOWNLOAD_TAG was associated with that
+        # task submission.
+        self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG)
+
+    def tests_submits_tag_for_ranged_get_object_nonseekable_fileobj(self):
+        self.wrap_executor_in_recorder()
+        self.configure_for_ranged_get()
+        self.add_head_object_response()
+        self.add_get_responses()
+
+        with open(self.filename, 'wb') as f:
+            self.use_fileobj_in_call_args(NonSeekableWriter(f))
+            self.submission_task = self.get_download_submission_task()
+            self.wait_and_assert_completed_successfully(self.submission_task)
+
+        # Make sure the IN_MEMORY_DOWNLOAD_TAG was associated with that
+        # task submission.
+        self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG)
+
+
+class TestGetObjectTask(BaseTaskTest):
+    def setUp(self):
+        super().setUp()
+        self.bucket = 'mybucket'
+        self.key = 'mykey'
+        self.extra_args = {}
+        self.callbacks = []
+        self.max_attempts = 5
+        self.io_executor = BoundedExecutor(1000, 1)
+        self.content = b'my content'
+        self.stream = BytesIO(self.content)
+        self.fileobj = WriteCollector()
+        self.osutil = OSUtils()
+        self.io_chunksize = 64 * (1024**2)
+        self.task_cls = GetObjectTask
+        self.download_output_manager = DownloadSeekableOutputManager(
+            self.osutil, self.transfer_coordinator, self.io_executor
+        )
+
+    def get_download_task(self, **kwargs):
+        default_kwargs = {
+            'client': self.client,
+            'bucket': self.bucket,
+            'key': self.key,
+            'fileobj': self.fileobj,
+            'extra_args': self.extra_args,
+            'callbacks': self.callbacks,
+            'max_attempts': self.max_attempts,
+            'download_output_manager': self.download_output_manager,
+            'io_chunksize': self.io_chunksize,
+        }
+        default_kwargs.update(kwargs)
+        self.transfer_coordinator.set_status_to_queued()
+        return self.get_task(self.task_cls, main_kwargs=default_kwargs)
+
+    def assert_io_writes(self, expected_writes):
+        # Let the io executor process all of the writes before checking
+        # what writes were sent to it.
+        self.io_executor.shutdown()
+        self.assertEqual(self.fileobj.writes, expected_writes)
+
+    def test_main(self):
+        self.stubber.add_response(
+            'get_object',
+            service_response={'Body': self.stream},
+            expected_params={'Bucket': self.bucket, 'Key': self.key},
+        )
+        task = self.get_download_task()
+        task()
+
+        self.stubber.assert_no_pending_responses()
+        self.assert_io_writes([(0, self.content)])
+
+    def test_extra_args(self):
+        self.stubber.add_response(
+            'get_object',
+            service_response={'Body': self.stream},
+            expected_params={
+                'Bucket': self.bucket,
+                'Key': self.key,
+                'Range': 'bytes=0-',
+            },
+        )
+        self.extra_args['Range'] = 'bytes=0-'
+        task = self.get_download_task()
+        task()
+
+        self.stubber.assert_no_pending_responses()
+        self.assert_io_writes([(0, self.content)])
+
+    def test_control_chunk_size(self):
+        self.stubber.add_response(
+            'get_object',
+            service_response={'Body': self.stream},
+            expected_params={'Bucket': self.bucket, 'Key': self.key},
+        )
+        task = self.get_download_task(io_chunksize=1)
+        task()
+
+        self.stubber.assert_no_pending_responses()
+        expected_contents = []
+        for i in range(len(self.content)):
+            expected_contents.append((i, bytes(self.content[i : i + 1])))
+
+        self.assert_io_writes(expected_contents)
+
+    def test_start_index(self):
+        self.stubber.add_response(
+            'get_object',
+            service_response={'Body': self.stream},
+            expected_params={'Bucket': self.bucket, 'Key': self.key},
+        )
+        task = self.get_download_task(start_index=5)
+        task()
+
+        self.stubber.assert_no_pending_responses()
+        self.assert_io_writes([(5, self.content)])
+
+    def test_uses_bandwidth_limiter(self):
+        bandwidth_limiter = mock.Mock(BandwidthLimiter)
+
+        self.stubber.add_response(
+            'get_object',
+            service_response={'Body': self.stream},
+            expected_params={'Bucket': self.bucket, 'Key': self.key},
+        )
+        task = self.get_download_task(bandwidth_limiter=bandwidth_limiter)
+        task()
+
+        self.stubber.assert_no_pending_responses()
+        self.assertEqual(
+            bandwidth_limiter.get_bandwith_limited_stream.call_args_list,
+            [mock.call(mock.ANY, self.transfer_coordinator)],
+        )
+
+    def test_retries_succeeds(self):
+        self.stubber.add_response(
+            'get_object',
+            service_response={
+                'Body': StreamWithError(self.stream, SOCKET_ERROR)
+            },
+            expected_params={'Bucket': self.bucket, 'Key': self.key},
+        )
+        self.stubber.add_response(
+            'get_object',
+            service_response={'Body': self.stream},
+            expected_params={'Bucket': self.bucket, 'Key': self.key},
+        )
+        task = self.get_download_task()
+        task()
+
+        # Retryable error should not have affected the bytes placed into
+        # the io queue.
+        self.stubber.assert_no_pending_responses()
+        self.assert_io_writes([(0, self.content)])
+
+    def test_retries_failure(self):
+        for _ in range(self.max_attempts):
+            self.stubber.add_response(
+                'get_object',
+                service_response={
+                    'Body': StreamWithError(self.stream, SOCKET_ERROR)
+                },
+                expected_params={'Bucket': self.bucket, 'Key': self.key},
+            )
+
+        task = self.get_download_task()
+        task()
+        self.transfer_coordinator.announce_done()
+
+        # Should have failed out on a RetriesExceededError
+        with self.assertRaises(RetriesExceededError):
+            self.transfer_coordinator.result()
+        self.stubber.assert_no_pending_responses()
+
+    def test_retries_in_middle_of_streaming(self):
+        # After the first read a retryable error will be thrown
+        self.stubber.add_response(
+            'get_object',
+            service_response={
+                'Body': StreamWithError(
+                    copy.deepcopy(self.stream), SOCKET_ERROR, 1
+                )
+            },
+            expected_params={'Bucket': self.bucket, 'Key': self.key},
+        )
+        self.stubber.add_response(
+            'get_object',
+            service_response={'Body': self.stream},
+            expected_params={'Bucket': self.bucket, 'Key': self.key},
+        )
+        task = self.get_download_task(io_chunksize=1)
+        task()
+
+        self.stubber.assert_no_pending_responses()
+        expected_contents = []
+        # This is the content initially read in before the retry hit on the
+        # second read()
+        expected_contents.append((0, bytes(self.content[0:1])))
+
+        # The rest of the content should be the entire set of data partitioned
+        # out based on the one byte stream chunk size. Note the second
+        # element in the list should be a copy of the first element since
+        # a retryable exception happened in between.
+        for i in range(len(self.content)):
+            expected_contents.append((i, bytes(self.content[i : i + 1])))
+        self.assert_io_writes(expected_contents)
+
+    def test_cancels_out_of_queueing(self):
+        self.stubber.add_response(
+            'get_object',
+            service_response={
+                'Body': CancelledStreamWrapper(
+                    self.stream, self.transfer_coordinator
+                )
+            },
+            expected_params={'Bucket': self.bucket, 'Key': self.key},
+        )
+        task = self.get_download_task()
+        task()
+
+        self.stubber.assert_no_pending_responses()
+        # Make sure that no contents were added to the queue because the task
+        # should have been canceled before trying to add the contents to the
+        # io queue.
+        self.assert_io_writes([])
+
+    def test_handles_callback_on_initial_error(self):
+        # We can't use the stubber for this because we need to raise
+        # a S3_RETRYABLE_DOWNLOAD_ERRORS, and the stubber only allows
+        # you to raise a ClientError.
+        self.client.get_object = mock.Mock(side_effect=SOCKET_ERROR())
+        task = self.get_download_task()
+        task()
+        self.transfer_coordinator.announce_done()
+        # Should have failed out on a RetriesExceededError because
+        # get_object keeps raising a socket error.
+        with self.assertRaises(RetriesExceededError):
+            self.transfer_coordinator.result()
+
+
+class TestImmediatelyWriteIOGetObjectTask(TestGetObjectTask):
+    def setUp(self):
+        super().setUp()
+        self.task_cls = ImmediatelyWriteIOGetObjectTask
+        # When data is written out, it should not use the io executor at all.
+        # If it does use the io executor, that is a deviation from expected
+        # behavior, as the data should be written immediately to the file
+        # object once downloaded.
+ self.io_executor = None + self.download_output_manager = DownloadSeekableOutputManager( + self.osutil, self.transfer_coordinator, self.io_executor + ) + + def assert_io_writes(self, expected_writes): + self.assertEqual(self.fileobj.writes, expected_writes) + + +class BaseIOTaskTest(BaseTaskTest): + def setUp(self): + super().setUp() + self.files = FileCreator() + self.osutil = OSUtils() + self.temp_filename = os.path.join(self.files.rootdir, 'mytempfile') + self.final_filename = os.path.join(self.files.rootdir, 'myfile') + + def tearDown(self): + super().tearDown() + self.files.remove_all() + + +class TestIOStreamingWriteTask(BaseIOTaskTest): + def test_main(self): + with open(self.temp_filename, 'wb') as f: + task = self.get_task( + IOStreamingWriteTask, + main_kwargs={'fileobj': f, 'data': b'foobar'}, + ) + task() + task2 = self.get_task( + IOStreamingWriteTask, + main_kwargs={'fileobj': f, 'data': b'baz'}, + ) + task2() + with open(self.temp_filename, 'rb') as f: + # We should just have written to the file in the order + # the tasks were executed. + self.assertEqual(f.read(), b'foobarbaz') + + +class TestIOWriteTask(BaseIOTaskTest): + def test_main(self): + with open(self.temp_filename, 'wb') as f: + # Write once to the file + task = self.get_task( + IOWriteTask, + main_kwargs={'fileobj': f, 'data': b'foo', 'offset': 0}, + ) + task() + + # Write again to the file + task = self.get_task( + IOWriteTask, + main_kwargs={'fileobj': f, 'data': b'bar', 'offset': 3}, + ) + task() + + with open(self.temp_filename, 'rb') as f: + self.assertEqual(f.read(), b'foobar') + + +class TestIORenameFileTask(BaseIOTaskTest): + def test_main(self): + with open(self.temp_filename, 'wb') as f: + task = self.get_task( + IORenameFileTask, + main_kwargs={ + 'fileobj': f, + 'final_filename': self.final_filename, + 'osutil': self.osutil, + }, + ) + task() + self.assertTrue(os.path.exists(self.final_filename)) + self.assertFalse(os.path.exists(self.temp_filename)) + + +class TestIOCloseTask(BaseIOTaskTest): + def test_main(self): + with open(self.temp_filename, 'w') as f: + task = self.get_task(IOCloseTask, main_kwargs={'fileobj': f}) + task() + self.assertTrue(f.closed) + + +class TestDownloadChunkIterator(unittest.TestCase): + def test_iter(self): + content = b'my content' + body = BytesIO(content) + ref_chunks = [] + for chunk in DownloadChunkIterator(body, len(content)): + ref_chunks.append(chunk) + self.assertEqual(ref_chunks, [b'my content']) + + def test_iter_chunksize(self): + content = b'1234' + body = BytesIO(content) + ref_chunks = [] + for chunk in DownloadChunkIterator(body, 3): + ref_chunks.append(chunk) + self.assertEqual(ref_chunks, [b'123', b'4']) + + def test_empty_content(self): + body = BytesIO(b'') + ref_chunks = [] + for chunk in DownloadChunkIterator(body, 3): + ref_chunks.append(chunk) + self.assertEqual(ref_chunks, [b'']) + + +class TestDeferQueue(unittest.TestCase): + def setUp(self): + self.q = DeferQueue() + + def test_no_writes_when_not_lowest_block(self): + writes = self.q.request_writes(offset=1, data='bar') + self.assertEqual(writes, []) + + def test_writes_returned_in_order(self): + self.assertEqual(self.q.request_writes(offset=3, data='d'), []) + self.assertEqual(self.q.request_writes(offset=2, data='c'), []) + self.assertEqual(self.q.request_writes(offset=1, data='b'), []) + + # Everything at this point has been deferred, but as soon as we + # send offset=0, that will unlock offsets 0-3. 
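+        # (A rough sketch of the invariant DeferQueue maintains; the names
+        # here are illustrative, not the actual implementation:
+        #
+        #     buffered = {}       # offset -> data received out of order
+        #     next_offset = 0     # the only offset writable right now
+        #
+        # request_writes(offset, data) buffers anything above next_offset
+        # and, once next_offset arrives, flushes the contiguous run in
+        # offset order, which is exactly what the assertion below checks.)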
+ writes = self.q.request_writes(offset=0, data='a') + self.assertEqual( + writes, + [ + {'offset': 0, 'data': 'a'}, + {'offset': 1, 'data': 'b'}, + {'offset': 2, 'data': 'c'}, + {'offset': 3, 'data': 'd'}, + ], + ) + + def test_unlocks_partial_range(self): + self.assertEqual(self.q.request_writes(offset=5, data='f'), []) + self.assertEqual(self.q.request_writes(offset=1, data='b'), []) + + # offset=0 unlocks 0-1, but offset=5 still needs to see 2-4 first. + writes = self.q.request_writes(offset=0, data='a') + self.assertEqual( + writes, + [ + {'offset': 0, 'data': 'a'}, + {'offset': 1, 'data': 'b'}, + ], + ) + + def test_data_can_be_any_size(self): + self.q.request_writes(offset=5, data='hello world') + writes = self.q.request_writes(offset=0, data='abcde') + self.assertEqual( + writes, + [ + {'offset': 0, 'data': 'abcde'}, + {'offset': 5, 'data': 'hello world'}, + ], + ) + + def test_data_queued_in_order(self): + # This immediately gets returned because offset=0 is the + # next range we're waiting on. + writes = self.q.request_writes(offset=0, data='hello world') + self.assertEqual(writes, [{'offset': 0, 'data': 'hello world'}]) + # Same thing here but with offset + writes = self.q.request_writes(offset=11, data='hello again') + self.assertEqual(writes, [{'offset': 11, 'data': 'hello again'}]) + + def test_writes_below_min_offset_are_ignored(self): + self.q.request_writes(offset=0, data='a') + self.q.request_writes(offset=1, data='b') + self.q.request_writes(offset=2, data='c') + + # At this point we're expecting offset=3, so if a write + # comes in below 3, we ignore it. + self.assertEqual(self.q.request_writes(offset=0, data='a'), []) + self.assertEqual(self.q.request_writes(offset=1, data='b'), []) + + self.assertEqual( + self.q.request_writes(offset=3, data='d'), + [{'offset': 3, 'data': 'd'}], + ) + + def test_duplicate_writes_are_ignored(self): + self.q.request_writes(offset=2, data='c') + self.q.request_writes(offset=1, data='b') + + # We're still waiting for offset=0, but if + # a duplicate write comes in for offset=2/offset=1 + # it's ignored. This gives "first one wins" behavior. + self.assertEqual(self.q.request_writes(offset=2, data='X'), []) + self.assertEqual(self.q.request_writes(offset=1, data='Y'), []) + + self.assertEqual( + self.q.request_writes(offset=0, data='a'), + [ + {'offset': 0, 'data': 'a'}, + # Note we're seeing 'b' 'c', and not 'X', 'Y'. + {'offset': 1, 'data': 'b'}, + {'offset': 2, 'data': 'c'}, + ], + ) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_futures.py b/contrib/python/s3transfer/py3/tests/unit/test_futures.py new file mode 100644 index 0000000000..ca2888a654 --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/unit/test_futures.py @@ -0,0 +1,696 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+import os +import sys +import time +import traceback +from concurrent.futures import ThreadPoolExecutor + +from s3transfer.exceptions import ( + CancelledError, + FatalError, + TransferNotDoneError, +) +from s3transfer.futures import ( + BaseExecutor, + BoundedExecutor, + ExecutorFuture, + NonThreadedExecutor, + NonThreadedExecutorFuture, + TransferCoordinator, + TransferFuture, + TransferMeta, +) +from s3transfer.tasks import Task +from s3transfer.utils import ( + FunctionContainer, + NoResourcesAvailable, + TaskSemaphore, +) +from __tests__ import ( + RecordingExecutor, + TransferCoordinatorWithInterrupt, + mock, + unittest, +) + + +def return_call_args(*args, **kwargs): + return args, kwargs + + +def raise_exception(exception): + raise exception + + +def get_exc_info(exception): + try: + raise_exception(exception) + except Exception: + return sys.exc_info() + + +class RecordingTransferCoordinator(TransferCoordinator): + def __init__(self): + self.all_transfer_futures_ever_associated = set() + super().__init__() + + def add_associated_future(self, future): + self.all_transfer_futures_ever_associated.add(future) + super().add_associated_future(future) + + +class ReturnFooTask(Task): + def _main(self, **kwargs): + return 'foo' + + +class SleepTask(Task): + def _main(self, sleep_time, **kwargs): + time.sleep(sleep_time) + + +class TestTransferFuture(unittest.TestCase): + def setUp(self): + self.meta = TransferMeta() + self.coordinator = TransferCoordinator() + self.future = self._get_transfer_future() + + def _get_transfer_future(self, **kwargs): + components = { + 'meta': self.meta, + 'coordinator': self.coordinator, + } + for component_name, component in kwargs.items(): + components[component_name] = component + return TransferFuture(**components) + + def test_meta(self): + self.assertIs(self.future.meta, self.meta) + + def test_done(self): + self.assertFalse(self.future.done()) + self.coordinator.set_result(None) + self.assertTrue(self.future.done()) + + def test_result(self): + result = 'foo' + self.coordinator.set_result(result) + self.coordinator.announce_done() + self.assertEqual(self.future.result(), result) + + def test_keyboard_interrupt_on_result_does_not_block(self): + # This should raise a KeyboardInterrupt when result is called on it. + self.coordinator = TransferCoordinatorWithInterrupt() + self.future = self._get_transfer_future() + + # result() should not block and immediately raise the keyboard + # interrupt exception. 
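+        # (TransferCoordinatorWithInterrupt is a test double whose result()
+        # raises KeyboardInterrupt immediately instead of waiting.)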
+        with self.assertRaises(KeyboardInterrupt):
+            self.future.result()
+
+    def test_cancel(self):
+        self.future.cancel()
+        self.assertTrue(self.future.done())
+        self.assertEqual(self.coordinator.status, 'cancelled')
+
+    def test_set_exception(self):
+        # Set the result such that there is no exception
+        self.coordinator.set_result('result')
+        self.coordinator.announce_done()
+        self.assertEqual(self.future.result(), 'result')
+
+        self.future.set_exception(ValueError())
+        with self.assertRaises(ValueError):
+            self.future.result()
+
+    def test_set_exception_only_after_done(self):
+        with self.assertRaises(TransferNotDoneError):
+            self.future.set_exception(ValueError())
+
+        self.coordinator.set_result('result')
+        self.coordinator.announce_done()
+        self.future.set_exception(ValueError())
+        with self.assertRaises(ValueError):
+            self.future.result()
+
+
+class TestTransferMeta(unittest.TestCase):
+    def setUp(self):
+        self.transfer_meta = TransferMeta()
+
+    def test_size(self):
+        self.assertEqual(self.transfer_meta.size, None)
+        self.transfer_meta.provide_transfer_size(5)
+        self.assertEqual(self.transfer_meta.size, 5)
+
+    def test_call_args(self):
+        call_args = object()
+        transfer_meta = TransferMeta(call_args)
+        # Assert that the call args provided are the same as those returned
+        self.assertIs(transfer_meta.call_args, call_args)
+
+    def test_transfer_id(self):
+        transfer_meta = TransferMeta(transfer_id=1)
+        self.assertEqual(transfer_meta.transfer_id, 1)
+
+    def test_user_context(self):
+        self.transfer_meta.user_context['foo'] = 'bar'
+        self.assertEqual(self.transfer_meta.user_context, {'foo': 'bar'})
+
+
+class TestTransferCoordinator(unittest.TestCase):
+    def setUp(self):
+        self.transfer_coordinator = TransferCoordinator()
+
+    def test_transfer_id(self):
+        transfer_coordinator = TransferCoordinator(transfer_id=1)
+        self.assertEqual(transfer_coordinator.transfer_id, 1)
+
+    def test_repr(self):
+        transfer_coordinator = TransferCoordinator(transfer_id=1)
+        self.assertEqual(
+            repr(transfer_coordinator), 'TransferCoordinator(transfer_id=1)'
+        )
+
+    def test_initial_status(self):
+        # A TransferCoordinator with no progress should have the status
+        # of not-started
+        self.assertEqual(self.transfer_coordinator.status, 'not-started')
+
+    def test_set_status_to_queued(self):
+        self.transfer_coordinator.set_status_to_queued()
+        self.assertEqual(self.transfer_coordinator.status, 'queued')
+
+    def test_cannot_set_status_to_queued_from_done_state(self):
+        self.transfer_coordinator.set_exception(RuntimeError)
+        with self.assertRaises(RuntimeError):
+            self.transfer_coordinator.set_status_to_queued()
+
+    def test_status_running(self):
+        self.transfer_coordinator.set_status_to_running()
+        self.assertEqual(self.transfer_coordinator.status, 'running')
+
+    def test_cannot_set_status_to_running_from_done_state(self):
+        self.transfer_coordinator.set_exception(RuntimeError)
+        with self.assertRaises(RuntimeError):
+            self.transfer_coordinator.set_status_to_running()
+
+    def test_set_result(self):
+        success_result = 'foo'
+        self.transfer_coordinator.set_result(success_result)
+        self.transfer_coordinator.announce_done()
+        # Setting the result should produce a success state along with the
+        # return value that was set.
+        self.assertEqual(self.transfer_coordinator.status, 'success')
+        self.assertEqual(self.transfer_coordinator.result(), success_result)
+
+    def test_set_exception(self):
+        exception_result = RuntimeError
+        self.transfer_coordinator.set_exception(exception_result)
+        self.transfer_coordinator.announce_done()
+        # Setting an exception should result in a failed state, and the
+        # stored exception should be the one that was set.
+        self.assertEqual(self.transfer_coordinator.status, 'failed')
+        self.assertEqual(self.transfer_coordinator.exception, exception_result)
+        with self.assertRaises(exception_result):
+            self.transfer_coordinator.result()
+
+    def test_exception_cannot_override_done_state(self):
+        self.transfer_coordinator.set_result('foo')
+        self.transfer_coordinator.set_exception(RuntimeError)
+        # Its status should be success even after the exception is set
+        # because success is a done state.
+        self.assertEqual(self.transfer_coordinator.status, 'success')
+
+    def test_exception_can_override_done_state_with_override_flag(self):
+        self.transfer_coordinator.set_result('foo')
+        self.transfer_coordinator.set_exception(RuntimeError, override=True)
+        self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+    def test_cancel(self):
+        self.assertEqual(self.transfer_coordinator.status, 'not-started')
+        self.transfer_coordinator.cancel()
+        # This should set the state to cancelled, make result() raise the
+        # CancelledError exception, and set the done event so that result()
+        # no longer blocks.
+        self.assertEqual(self.transfer_coordinator.status, 'cancelled')
+        with self.assertRaises(CancelledError):
+            self.transfer_coordinator.result()
+
+    def test_cancel_can_run_done_callbacks_that_uses_result(self):
+        exceptions = []
+
+        def capture_exception(transfer_coordinator, captured_exceptions):
+            try:
+                transfer_coordinator.result()
+            except Exception as e:
+                captured_exceptions.append(e)
+
+        self.assertEqual(self.transfer_coordinator.status, 'not-started')
+        self.transfer_coordinator.add_done_callback(
+            capture_exception, self.transfer_coordinator, exceptions
+        )
+        self.transfer_coordinator.cancel()
+
+        self.assertEqual(len(exceptions), 1)
+        self.assertIsInstance(exceptions[0], CancelledError)
+
+    def test_cancel_with_message(self):
+        message = 'my message'
+        self.transfer_coordinator.cancel(message)
+        self.transfer_coordinator.announce_done()
+        with self.assertRaisesRegex(CancelledError, message):
+            self.transfer_coordinator.result()
+
+    def test_cancel_with_provided_exception(self):
+        message = 'my message'
+        self.transfer_coordinator.cancel(message, exc_type=FatalError)
+        self.transfer_coordinator.announce_done()
+        with self.assertRaisesRegex(FatalError, message):
+            self.transfer_coordinator.result()
+
+    def test_cancel_cannot_override_done_state(self):
+        self.transfer_coordinator.set_result('foo')
+        self.transfer_coordinator.cancel()
+        # Its status should be success even after cancel is called because
+        # success is a done state.
+        self.assertEqual(self.transfer_coordinator.status, 'success')
+
+    def test_set_result_can_override_cancel(self):
+        self.transfer_coordinator.cancel()
+        # Result setting should override any cancel or set exception as this
+        # is always invoked by the final task.
+        self.transfer_coordinator.set_result('foo')
+        self.transfer_coordinator.announce_done()
+        self.assertEqual(self.transfer_coordinator.status, 'success')
+
+    def test_submit(self):
+        # Submit a callable to the transfer coordinator. It should submit it
+        # to the executor.
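+        # (BoundedExecutor(1, 1, ...) below is a bounded executor with a
+        # queue size of one and a single worker thread; RecordingExecutor
+        # wraps it purely to record submissions for the assertions that
+        # follow.)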
+        executor = RecordingExecutor(
+            BoundedExecutor(1, 1, {'my-tag': TaskSemaphore(1)})
+        )
+        task = ReturnFooTask(self.transfer_coordinator)
+        future = self.transfer_coordinator.submit(executor, task, tag='my-tag')
+        executor.shutdown()
+        # Make sure the future got submitted and executed by checking the
+        # recorded submission, which should include the provided tag, and
+        # by checking its result value.
+        self.assertEqual(
+            executor.submissions,
+            [{'block': True, 'tag': 'my-tag', 'task': task}],
+        )
+        self.assertEqual(future.result(), 'foo')
+
+    def test_association_and_disassociation_on_submit(self):
+        self.transfer_coordinator = RecordingTransferCoordinator()
+
+        # Submit a callable to the transfer coordinator.
+        executor = BoundedExecutor(1, 1)
+        task = ReturnFooTask(self.transfer_coordinator)
+        future = self.transfer_coordinator.submit(executor, task)
+        executor.shutdown()
+
+        # Make sure the future that got submitted was associated to the
+        # transfer future at some point.
+        self.assertEqual(
+            self.transfer_coordinator.all_transfer_futures_ever_associated,
+            {future},
+        )
+
+        # Make sure the future got disassociated once the future is now done
+        # by looking at the currently associated futures.
+        self.assertEqual(self.transfer_coordinator.associated_futures, set())
+
+    def test_done(self):
+        # These should result in a not-done state:
+        # queued
+        self.assertFalse(self.transfer_coordinator.done())
+        # running
+        self.transfer_coordinator.set_status_to_running()
+        self.assertFalse(self.transfer_coordinator.done())
+
+        # These should result in a done state:
+        # failed
+        self.transfer_coordinator.set_exception(Exception)
+        self.assertTrue(self.transfer_coordinator.done())
+
+        # success
+        self.transfer_coordinator.set_result('foo')
+        self.assertTrue(self.transfer_coordinator.done())
+
+        # cancelled
+        self.transfer_coordinator.cancel()
+        self.assertTrue(self.transfer_coordinator.done())
+
+    def test_result_waits_until_done(self):
+        execution_order = []
+
+        def sleep_then_set_result(transfer_coordinator, execution_order):
+            time.sleep(0.05)
+            execution_order.append('setting_result')
+            transfer_coordinator.set_result(None)
+            self.transfer_coordinator.announce_done()
+
+        with ThreadPoolExecutor(max_workers=1) as executor:
+            executor.submit(
+                sleep_then_set_result,
+                self.transfer_coordinator,
+                execution_order,
+            )
+            self.transfer_coordinator.result()
+            execution_order.append('after_result')
+
+        # The result() call should have waited until the other thread set
+        # the result after sleeping for 0.05 seconds.
+        self.assertEqual(execution_order, ['setting_result', 'after_result'])
+
+    def test_failure_cleanups(self):
+        args = (1, 2)
+        kwargs = {'foo': 'bar'}
+
+        second_args = (2, 4)
+        second_kwargs = {'biz': 'baz'}
+
+        self.transfer_coordinator.add_failure_cleanup(
+            return_call_args, *args, **kwargs
+        )
+        self.transfer_coordinator.add_failure_cleanup(
+            return_call_args, *second_args, **second_kwargs
+        )
+
+        # Ensure the callbacks got added.
+        self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 2)
+
+        result_list = []
+        # Ensure they will get called in the correct order.
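+        # (Each cleanup is a FunctionContainer wrapping return_call_args, so
+        # invoking it returns the (args, kwargs) pair it was registered
+        # with, which is what makes the ordering observable here.)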
+        for cleanup in self.transfer_coordinator.failure_cleanups:
+            result_list.append(cleanup())
+        self.assertEqual(
+            result_list, [(args, kwargs), (second_args, second_kwargs)]
+        )
+
+    def test_associated_futures(self):
+        first_future = object()
+        # Associate one future to the transfer
+        self.transfer_coordinator.add_associated_future(first_future)
+        associated_futures = self.transfer_coordinator.associated_futures
+        # The first future should be in the returned set of futures.
+        self.assertEqual(associated_futures, {first_future})
+
+        second_future = object()
+        # Associate another future to the transfer.
+        self.transfer_coordinator.add_associated_future(second_future)
+        # The association should not have mutated the previously returned
+        # set of futures.
+        self.assertEqual(associated_futures, {first_future})
+
+        # Both futures should be in the set returned this time.
+        self.assertEqual(
+            self.transfer_coordinator.associated_futures,
+            {first_future, second_future},
+        )
+
+    def test_done_callbacks_on_done(self):
+        done_callback_invocations = []
+        callback = FunctionContainer(
+            done_callback_invocations.append, 'done callback called'
+        )
+
+        # Add the done callback to the transfer.
+        self.transfer_coordinator.add_done_callback(callback)
+
+        # Announce that the transfer is done. This should invoke the done
+        # callback.
+        self.transfer_coordinator.announce_done()
+        self.assertEqual(done_callback_invocations, ['done callback called'])
+
+        # If done is announced again, we should not invoke the callback
+        # again because done has already been announced and thus the
+        # callback has already been run.
+        self.transfer_coordinator.announce_done()
+        self.assertEqual(done_callback_invocations, ['done callback called'])
+
+    def test_failure_cleanups_on_done(self):
+        cleanup_invocations = []
+        callback = FunctionContainer(
+            cleanup_invocations.append, 'cleanup called'
+        )
+
+        # Add the failure cleanup to the transfer.
+        self.transfer_coordinator.add_failure_cleanup(callback)
+
+        # Announce that the transfer is done. This should invoke the failure
+        # cleanup.
+        self.transfer_coordinator.announce_done()
+        self.assertEqual(cleanup_invocations, ['cleanup called'])
+
+        # If done is announced again, we should not invoke the cleanup
+        # again because done has already been announced and thus the
+        # cleanup has already been run.
+        self.transfer_coordinator.announce_done()
+        self.assertEqual(cleanup_invocations, ['cleanup called'])
+
+
+class TestBoundedExecutor(unittest.TestCase):
+    def setUp(self):
+        self.coordinator = TransferCoordinator()
+        self.tag_semaphores = {}
+        self.executor = self.get_executor()
+
+    def get_executor(self, max_size=1, max_num_threads=1):
+        return BoundedExecutor(max_size, max_num_threads, self.tag_semaphores)
+
+    def get_task(self, task_cls, main_kwargs=None):
+        return task_cls(self.coordinator, main_kwargs=main_kwargs)
+
+    def get_sleep_task(self, sleep_time=0.01):
+        return self.get_task(SleepTask, main_kwargs={'sleep_time': sleep_time})
+
+    def add_semaphore(self, task_tag, count):
+        self.tag_semaphores[task_tag] = TaskSemaphore(count)
+
+    def assert_submit_would_block(self, task, tag=None):
+        with self.assertRaises(NoResourcesAvailable):
+            self.executor.submit(task, tag=tag, block=False)
+
+    def assert_submit_would_not_block(self, task, tag=None, **kwargs):
+        try:
+            self.executor.submit(task, tag=tag, block=False)
+        except NoResourcesAvailable:
+            self.fail(
+                'Task {} should not have been blocked. Caused by:\n{}'.format(
+                    task, traceback.format_exc()
+                )
+            )
+
+    def add_done_callback_to_future(self, future, fn, *args, **kwargs):
+        callback_for_future = FunctionContainer(fn, *args, **kwargs)
+        future.add_done_callback(callback_for_future)
+
+    def test_submit_single_task(self):
+        # Ensure we can submit a task to the executor
+        task = self.get_task(ReturnFooTask)
+        future = self.executor.submit(task)
+
+        # Ensure what we get back is a Future
+        self.assertIsInstance(future, ExecutorFuture)
+        # Ensure the callable got executed.
+        self.assertEqual(future.result(), 'foo')
+
+    @unittest.skipIf(
+        os.environ.get('USE_SERIAL_EXECUTOR'),
+        "Not supported with serial executor tests",
+    )
+    def test_executor_blocks_on_full_capacity(self):
+        first_task = self.get_sleep_task()
+        second_task = self.get_sleep_task()
+        self.executor.submit(first_task)
+        # The first task should be sleeping for a substantial period of
+        # time such that on the submission of the second task, it will
+        # raise an error saying that it cannot be submitted as the max
+        # capacity of the semaphore is one.
+        self.assert_submit_would_block(second_task)
+
+    def test_executor_clears_capacity_on_done_tasks(self):
+        first_task = self.get_sleep_task()
+        second_task = self.get_task(ReturnFooTask)
+
+        # Submit a task.
+        future = self.executor.submit(first_task)
+
+        # Submit a new task when the first task finishes. This should not get
+        # blocked because the first task should have finished, clearing up
+        # capacity.
+        self.add_done_callback_to_future(
+            future, self.assert_submit_would_not_block, second_task
+        )
+
+        # Wait for it to complete.
+        self.executor.shutdown()
+
+    @unittest.skipIf(
+        os.environ.get('USE_SERIAL_EXECUTOR'),
+        "Not supported with serial executor tests",
+    )
+    def test_would_not_block_when_full_capacity_in_other_semaphore(self):
+        first_task = self.get_sleep_task()
+
+        # Now let's create a new task with a tag so that it uses a different
+        # semaphore.
+        task_tag = 'other'
+        other_task = self.get_sleep_task()
+        self.add_semaphore(task_tag, 1)
+
+        # Submit the normal first task
+        self.executor.submit(first_task)
+
+        # Even though the first task should be sleeping for a substantial
+        # period of time, the submission of the second task should not
+        # raise an error because it uses a different semaphore
+        self.assert_submit_would_not_block(other_task, task_tag)
+
+        # Another submission of the other task, though, should raise
+        # an exception as the capacity is equal to one for that tag.
+        self.assert_submit_would_block(other_task, task_tag)
+
+    def test_shutdown(self):
+        slow_task = self.get_sleep_task()
+        future = self.executor.submit(slow_task)
+        self.executor.shutdown()
+        # Ensure that the shutdown waits until the task is done
+        self.assertTrue(future.done())
+
+    @unittest.skipIf(
+        os.environ.get('USE_SERIAL_EXECUTOR'),
+        "Not supported with serial executor tests",
+    )
+    def test_shutdown_no_wait(self):
+        slow_task = self.get_sleep_task()
+        future = self.executor.submit(slow_task)
+        self.executor.shutdown(False)
+        # Ensure that the shutdown returns immediately even if the task is
+        # not done, which it should not be because it is slow.
+ self.assertFalse(future.done()) + + def test_replace_underlying_executor(self): + mocked_executor_cls = mock.Mock(BaseExecutor) + executor = BoundedExecutor(10, 1, {}, mocked_executor_cls) + executor.submit(self.get_task(ReturnFooTask)) + self.assertTrue(mocked_executor_cls.return_value.submit.called) + + +class TestExecutorFuture(unittest.TestCase): + def test_result(self): + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(return_call_args, 'foo', biz='baz') + wrapped_future = ExecutorFuture(future) + self.assertEqual(wrapped_future.result(), (('foo',), {'biz': 'baz'})) + + def test_done(self): + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(return_call_args, 'foo', biz='baz') + wrapped_future = ExecutorFuture(future) + self.assertTrue(wrapped_future.done()) + + def test_add_done_callback(self): + done_callbacks = [] + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(return_call_args, 'foo', biz='baz') + wrapped_future = ExecutorFuture(future) + wrapped_future.add_done_callback( + FunctionContainer(done_callbacks.append, 'called') + ) + self.assertEqual(done_callbacks, ['called']) + + +class TestNonThreadedExecutor(unittest.TestCase): + def test_submit(self): + executor = NonThreadedExecutor() + future = executor.submit(return_call_args, 1, 2, foo='bar') + self.assertIsInstance(future, NonThreadedExecutorFuture) + self.assertEqual(future.result(), ((1, 2), {'foo': 'bar'})) + + def test_submit_with_exception(self): + executor = NonThreadedExecutor() + future = executor.submit(raise_exception, RuntimeError()) + self.assertIsInstance(future, NonThreadedExecutorFuture) + with self.assertRaises(RuntimeError): + future.result() + + def test_submit_with_exception_and_captures_info(self): + exception = ValueError('message') + tb = get_exc_info(exception)[2] + future = NonThreadedExecutor().submit(raise_exception, exception) + try: + future.result() + # An exception should have been raised + self.fail('Future should have raised a ValueError') + except ValueError: + actual_tb = sys.exc_info()[2] + last_frame = traceback.extract_tb(actual_tb)[-1] + last_expected_frame = traceback.extract_tb(tb)[-1] + self.assertEqual(last_frame, last_expected_frame) + + +class TestNonThreadedExecutorFuture(unittest.TestCase): + def setUp(self): + self.future = NonThreadedExecutorFuture() + + def test_done_starts_false(self): + self.assertFalse(self.future.done()) + + def test_done_after_setting_result(self): + self.future.set_result('result') + self.assertTrue(self.future.done()) + + def test_done_after_setting_exception(self): + self.future.set_exception_info(Exception(), None) + self.assertTrue(self.future.done()) + + def test_result(self): + self.future.set_result('result') + self.assertEqual(self.future.result(), 'result') + + def test_exception_result(self): + exception = ValueError('message') + self.future.set_exception_info(exception, None) + with self.assertRaisesRegex(ValueError, 'message'): + self.future.result() + + def test_exception_result_doesnt_modify_last_frame(self): + exception = ValueError('message') + tb = get_exc_info(exception)[2] + self.future.set_exception_info(exception, tb) + try: + self.future.result() + # An exception should have been raised + self.fail() + except ValueError: + actual_tb = sys.exc_info()[2] + last_frame = traceback.extract_tb(actual_tb)[-1] + last_expected_frame = traceback.extract_tb(tb)[-1] + self.assertEqual(last_frame, last_expected_frame) + + def test_done_callback(self): 
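+        # (Callbacks added before completion should fire when set_result is
+        # called; the next test covers adding a callback after completion.)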
+ done_futures = [] + self.future.add_done_callback(done_futures.append) + self.assertEqual(done_futures, []) + self.future.set_result('result') + self.assertEqual(done_futures, [self.future]) + + def test_done_callback_after_done(self): + self.future.set_result('result') + done_futures = [] + self.future.add_done_callback(done_futures.append) + self.assertEqual(done_futures, [self.future]) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_manager.py b/contrib/python/s3transfer/py3/tests/unit/test_manager.py new file mode 100644 index 0000000000..fc3caa843f --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/unit/test_manager.py @@ -0,0 +1,143 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import time +from concurrent.futures import ThreadPoolExecutor + +from s3transfer.exceptions import CancelledError, FatalError +from s3transfer.futures import TransferCoordinator +from s3transfer.manager import TransferConfig, TransferCoordinatorController +from __tests__ import TransferCoordinatorWithInterrupt, unittest + + +class FutureResultException(Exception): + pass + + +class TestTransferConfig(unittest.TestCase): + def test_exception_on_zero_attr_value(self): + with self.assertRaises(ValueError): + TransferConfig(max_request_queue_size=0) + + +class TestTransferCoordinatorController(unittest.TestCase): + def setUp(self): + self.coordinator_controller = TransferCoordinatorController() + + def sleep_then_announce_done(self, transfer_coordinator, sleep_time): + time.sleep(sleep_time) + transfer_coordinator.set_result('done') + transfer_coordinator.announce_done() + + def assert_coordinator_is_cancelled(self, transfer_coordinator): + self.assertEqual(transfer_coordinator.status, 'cancelled') + + def test_add_transfer_coordinator(self): + transfer_coordinator = TransferCoordinator() + # Add the transfer coordinator + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + # Ensure that is tracked. + self.assertEqual( + self.coordinator_controller.tracked_transfer_coordinators, + {transfer_coordinator}, + ) + + def test_remove_transfer_coordinator(self): + transfer_coordinator = TransferCoordinator() + # Add the coordinator + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + # Now remove the coordinator + self.coordinator_controller.remove_transfer_coordinator( + transfer_coordinator + ) + # Make sure that it is no longer getting tracked. 
+ self.assertEqual( + self.coordinator_controller.tracked_transfer_coordinators, set() + ) + + def test_cancel(self): + transfer_coordinator = TransferCoordinator() + # Add the transfer coordinator + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + # Cancel with the canceler + self.coordinator_controller.cancel() + # Check that coordinator got canceled + self.assert_coordinator_is_cancelled(transfer_coordinator) + + def test_cancel_with_message(self): + message = 'my cancel message' + transfer_coordinator = TransferCoordinator() + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + self.coordinator_controller.cancel(message) + transfer_coordinator.announce_done() + with self.assertRaisesRegex(CancelledError, message): + transfer_coordinator.result() + + def test_cancel_with_provided_exception(self): + message = 'my cancel message' + transfer_coordinator = TransferCoordinator() + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + self.coordinator_controller.cancel(message, exc_type=FatalError) + transfer_coordinator.announce_done() + with self.assertRaisesRegex(FatalError, message): + transfer_coordinator.result() + + def test_wait_for_done_transfer_coordinators(self): + # Create a coordinator and add it to the canceler + transfer_coordinator = TransferCoordinator() + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + + sleep_time = 0.02 + with ThreadPoolExecutor(max_workers=1) as executor: + # In a separate thread sleep and then set the transfer coordinator + # to done after sleeping. + start_time = time.time() + executor.submit( + self.sleep_then_announce_done, transfer_coordinator, sleep_time + ) + # Now call wait to wait for the transfer coordinator to be done. + self.coordinator_controller.wait() + end_time = time.time() + wait_time = end_time - start_time + # The time waited should not be less than the time it took to sleep in + # the separate thread because the wait ending should be dependent on + # the sleeping thread announcing that the transfer coordinator is done. + self.assertTrue(sleep_time <= wait_time) + + def test_wait_does_not_propogate_exceptions_from_result(self): + transfer_coordinator = TransferCoordinator() + transfer_coordinator.set_exception(FutureResultException()) + transfer_coordinator.announce_done() + try: + self.coordinator_controller.wait() + except FutureResultException as e: + self.fail('%s should not have been raised.' % e) + + def test_wait_can_be_interrupted(self): + inject_interrupt_coordinator = TransferCoordinatorWithInterrupt() + self.coordinator_controller.add_transfer_coordinator( + inject_interrupt_coordinator + ) + with self.assertRaises(KeyboardInterrupt): + self.coordinator_controller.wait() diff --git a/contrib/python/s3transfer/py3/tests/unit/test_processpool.py b/contrib/python/s3transfer/py3/tests/unit/test_processpool.py new file mode 100644 index 0000000000..d77b5e0240 --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/unit/test_processpool.py @@ -0,0 +1,728 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import os +import queue +import signal +import threading +import time +from io import BytesIO + +from botocore.client import BaseClient +from botocore.config import Config +from botocore.exceptions import ClientError, ReadTimeoutError + +from s3transfer.constants import PROCESS_USER_AGENT +from s3transfer.exceptions import CancelledError, RetriesExceededError +from s3transfer.processpool import ( + SHUTDOWN_SIGNAL, + ClientFactory, + DownloadFileRequest, + GetObjectJob, + GetObjectSubmitter, + GetObjectWorker, + ProcessPoolDownloader, + ProcessPoolTransferFuture, + ProcessPoolTransferMeta, + ProcessTransferConfig, + TransferMonitor, + TransferState, + ignore_ctrl_c, +) +from s3transfer.utils import CallArgs, OSUtils +from __tests__ import ( + FileCreator, + StreamWithError, + StubbedClientTest, + mock, + skip_if_windows, + unittest, +) + + +class RenameFailingOSUtils(OSUtils): + def __init__(self, exception): + self.exception = exception + + def rename_file(self, current_filename, new_filename): + raise self.exception + + +class TestIgnoreCtrlC(unittest.TestCase): + @skip_if_windows('os.kill() with SIGINT not supported on Windows') + def test_ignore_ctrl_c(self): + with ignore_ctrl_c(): + try: + os.kill(os.getpid(), signal.SIGINT) + except KeyboardInterrupt: + self.fail( + 'The ignore_ctrl_c context manager should have ' + 'ignored the KeyboardInterrupt exception' + ) + + +class TestProcessPoolDownloader(unittest.TestCase): + def test_uses_client_kwargs(self): + with mock.patch('s3transfer.processpool.ClientFactory') as factory: + ProcessPoolDownloader(client_kwargs={'region_name': 'myregion'}) + self.assertEqual( + factory.call_args[0][0], {'region_name': 'myregion'} + ) + + +class TestProcessPoolTransferFuture(unittest.TestCase): + def setUp(self): + self.monitor = TransferMonitor() + self.transfer_id = self.monitor.notify_new_transfer() + self.meta = ProcessPoolTransferMeta( + transfer_id=self.transfer_id, call_args=CallArgs() + ) + self.future = ProcessPoolTransferFuture( + monitor=self.monitor, meta=self.meta + ) + + def test_meta(self): + self.assertEqual(self.future.meta, self.meta) + + def test_done(self): + self.assertFalse(self.future.done()) + self.monitor.notify_done(self.transfer_id) + self.assertTrue(self.future.done()) + + def test_result(self): + self.monitor.notify_done(self.transfer_id) + self.assertIsNone(self.future.result()) + + def test_result_with_exception(self): + self.monitor.notify_exception(self.transfer_id, RuntimeError()) + self.monitor.notify_done(self.transfer_id) + with self.assertRaises(RuntimeError): + self.future.result() + + def test_result_with_keyboard_interrupt(self): + mock_monitor = mock.Mock(TransferMonitor) + mock_monitor._connect = mock.Mock() + mock_monitor.poll_for_result.side_effect = KeyboardInterrupt() + future = ProcessPoolTransferFuture( + monitor=mock_monitor, meta=self.meta + ) + with self.assertRaises(KeyboardInterrupt): + future.result() + self.assertTrue(mock_monitor._connect.called) + self.assertTrue(mock_monitor.notify_exception.called) + call_args = mock_monitor.notify_exception.call_args[0] + self.assertEqual(call_args[0], self.transfer_id) + self.assertIsInstance(call_args[1], CancelledError) + + def test_cancel(self): + self.future.cancel() + self.monitor.notify_done(self.transfer_id) + with 
self.assertRaises(CancelledError): + self.future.result() + + +class TestProcessPoolTransferMeta(unittest.TestCase): + def test_transfer_id(self): + meta = ProcessPoolTransferMeta(1, CallArgs()) + self.assertEqual(meta.transfer_id, 1) + + def test_call_args(self): + call_args = CallArgs() + meta = ProcessPoolTransferMeta(1, call_args) + self.assertEqual(meta.call_args, call_args) + + def test_user_context(self): + meta = ProcessPoolTransferMeta(1, CallArgs()) + self.assertEqual(meta.user_context, {}) + meta.user_context['mykey'] = 'myvalue' + self.assertEqual(meta.user_context, {'mykey': 'myvalue'}) + + +class TestClientFactory(unittest.TestCase): + def test_create_client(self): + client = ClientFactory().create_client() + self.assertIsInstance(client, BaseClient) + self.assertEqual(client.meta.service_model.service_name, 's3') + self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) + + def test_create_client_with_client_kwargs(self): + client = ClientFactory({'region_name': 'myregion'}).create_client() + self.assertEqual(client.meta.region_name, 'myregion') + + def test_user_agent_with_config(self): + client = ClientFactory({'config': Config()}).create_client() + self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) + + def test_user_agent_with_existing_user_agent_extra(self): + config = Config(user_agent_extra='foo/1.0') + client = ClientFactory({'config': config}).create_client() + self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) + + def test_user_agent_with_existing_user_agent(self): + config = Config(user_agent='foo/1.0') + client = ClientFactory({'config': config}).create_client() + self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) + + +class TestTransferMonitor(unittest.TestCase): + def setUp(self): + self.monitor = TransferMonitor() + self.transfer_id = self.monitor.notify_new_transfer() + + def test_notify_new_transfer_creates_new_state(self): + monitor = TransferMonitor() + transfer_id = monitor.notify_new_transfer() + self.assertFalse(monitor.is_done(transfer_id)) + self.assertIsNone(monitor.get_exception(transfer_id)) + + def test_notify_new_transfer_increments_transfer_id(self): + monitor = TransferMonitor() + self.assertEqual(monitor.notify_new_transfer(), 0) + self.assertEqual(monitor.notify_new_transfer(), 1) + + def test_notify_get_exception(self): + exception = Exception() + self.monitor.notify_exception(self.transfer_id, exception) + self.assertEqual( + self.monitor.get_exception(self.transfer_id), exception + ) + + def test_get_no_exception(self): + self.assertIsNone(self.monitor.get_exception(self.transfer_id)) + + def test_notify_jobs(self): + self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2) + self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1) + self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 0) + + def test_notify_jobs_for_multiple_transfers(self): + self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2) + other_transfer_id = self.monitor.notify_new_transfer() + self.monitor.notify_expected_jobs_to_complete(other_transfer_id, 2) + self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1) + self.assertEqual( + self.monitor.notify_job_complete(other_transfer_id), 1 + ) + + def test_done(self): + self.assertFalse(self.monitor.is_done(self.transfer_id)) + self.monitor.notify_done(self.transfer_id) + self.assertTrue(self.monitor.is_done(self.transfer_id)) + + def test_poll_for_result(self): + self.monitor.notify_done(self.transfer_id) 
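+        # (poll_for_result blocks until the transfer is marked done; since
+        # notify_done was just called, this should return None immediately.)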
+ self.assertIsNone(self.monitor.poll_for_result(self.transfer_id)) + + def test_poll_for_result_raises_error(self): + self.monitor.notify_exception(self.transfer_id, RuntimeError()) + self.monitor.notify_done(self.transfer_id) + with self.assertRaises(RuntimeError): + self.monitor.poll_for_result(self.transfer_id) + + def test_poll_for_result_waits_till_done(self): + event_order = [] + + def sleep_then_notify_done(): + time.sleep(0.05) + event_order.append('notify_done') + self.monitor.notify_done(self.transfer_id) + + t = threading.Thread(target=sleep_then_notify_done) + t.start() + + self.monitor.poll_for_result(self.transfer_id) + event_order.append('done_polling') + self.assertEqual(event_order, ['notify_done', 'done_polling']) + + def test_notify_cancel_all_in_progress(self): + monitor = TransferMonitor() + transfer_ids = [] + for _ in range(10): + transfer_ids.append(monitor.notify_new_transfer()) + monitor.notify_cancel_all_in_progress() + for transfer_id in transfer_ids: + self.assertIsInstance( + monitor.get_exception(transfer_id), CancelledError + ) + # Cancelling a transfer does not mean it is done as there may + # be cleanup work left to do. + self.assertFalse(monitor.is_done(transfer_id)) + + def test_notify_cancel_does_not_affect_done_transfers(self): + self.monitor.notify_done(self.transfer_id) + self.monitor.notify_cancel_all_in_progress() + self.assertTrue(self.monitor.is_done(self.transfer_id)) + self.assertIsNone(self.monitor.get_exception(self.transfer_id)) + + +class TestTransferState(unittest.TestCase): + def setUp(self): + self.state = TransferState() + + def test_done(self): + self.assertFalse(self.state.done) + self.state.set_done() + self.assertTrue(self.state.done) + + def test_waits_till_done_is_set(self): + event_order = [] + + def sleep_then_set_done(): + time.sleep(0.05) + event_order.append('set_done') + self.state.set_done() + + t = threading.Thread(target=sleep_then_set_done) + t.start() + + self.state.wait_till_done() + event_order.append('done_waiting') + self.assertEqual(event_order, ['set_done', 'done_waiting']) + + def test_exception(self): + exception = RuntimeError() + self.state.exception = exception + self.assertEqual(self.state.exception, exception) + + def test_jobs_to_complete(self): + self.state.jobs_to_complete = 5 + self.assertEqual(self.state.jobs_to_complete, 5) + + def test_decrement_jobs_to_complete(self): + self.state.jobs_to_complete = 5 + self.assertEqual(self.state.decrement_jobs_to_complete(), 4) + + +class TestGetObjectSubmitter(StubbedClientTest): + def setUp(self): + super().setUp() + self.transfer_config = ProcessTransferConfig() + self.client_factory = mock.Mock(ClientFactory) + self.client_factory.create_client.return_value = self.client + self.transfer_monitor = TransferMonitor() + self.osutil = mock.Mock(OSUtils) + self.download_request_queue = queue.Queue() + self.worker_queue = queue.Queue() + self.submitter = GetObjectSubmitter( + transfer_config=self.transfer_config, + client_factory=self.client_factory, + transfer_monitor=self.transfer_monitor, + osutil=self.osutil, + download_request_queue=self.download_request_queue, + worker_queue=self.worker_queue, + ) + self.transfer_id = self.transfer_monitor.notify_new_transfer() + self.bucket = 'bucket' + self.key = 'key' + self.filename = 'myfile' + self.temp_filename = 'myfile.temp' + self.osutil.get_temp_filename.return_value = self.temp_filename + self.extra_args = {} + self.expected_size = None + + def add_download_file_request(self, **override_kwargs): + kwargs = { + 
'transfer_id': self.transfer_id, + 'bucket': self.bucket, + 'key': self.key, + 'filename': self.filename, + 'extra_args': self.extra_args, + 'expected_size': self.expected_size, + } + kwargs.update(override_kwargs) + self.download_request_queue.put(DownloadFileRequest(**kwargs)) + + def add_shutdown(self): + self.download_request_queue.put(SHUTDOWN_SIGNAL) + + def assert_submitted_get_object_jobs(self, expected_jobs): + actual_jobs = [] + while not self.worker_queue.empty(): + actual_jobs.append(self.worker_queue.get()) + self.assertEqual(actual_jobs, expected_jobs) + + def test_run_for_non_ranged_download(self): + self.add_download_file_request(expected_size=1) + self.add_shutdown() + self.submitter.run() + self.osutil.allocate.assert_called_with(self.temp_filename, 1) + self.assert_submitted_get_object_jobs( + [ + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=0, + extra_args={}, + filename=self.filename, + ) + ] + ) + + def test_run_for_ranged_download(self): + self.transfer_config.multipart_chunksize = 2 + self.transfer_config.multipart_threshold = 4 + self.add_download_file_request(expected_size=4) + self.add_shutdown() + self.submitter.run() + self.osutil.allocate.assert_called_with(self.temp_filename, 4) + self.assert_submitted_get_object_jobs( + [ + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=0, + extra_args={'Range': 'bytes=0-1'}, + filename=self.filename, + ), + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=2, + extra_args={'Range': 'bytes=2-'}, + filename=self.filename, + ), + ] + ) + + def test_run_when_expected_size_not_provided(self): + self.stubber.add_response( + 'head_object', + {'ContentLength': 1}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + self.add_download_file_request(expected_size=None) + self.add_shutdown() + self.submitter.run() + self.stubber.assert_no_pending_responses() + self.osutil.allocate.assert_called_with(self.temp_filename, 1) + self.assert_submitted_get_object_jobs( + [ + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=0, + extra_args={}, + filename=self.filename, + ) + ] + ) + + def test_run_with_extra_args(self): + self.stubber.add_response( + 'head_object', + {'ContentLength': 1}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'VersionId': 'versionid', + }, + ) + self.add_download_file_request( + extra_args={'VersionId': 'versionid'}, expected_size=None + ) + self.add_shutdown() + self.submitter.run() + self.stubber.assert_no_pending_responses() + self.osutil.allocate.assert_called_with(self.temp_filename, 1) + self.assert_submitted_get_object_jobs( + [ + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=0, + extra_args={'VersionId': 'versionid'}, + filename=self.filename, + ) + ] + ) + + def test_run_with_exception(self): + self.stubber.add_client_error('head_object', 'NoSuchKey', 404) + self.add_download_file_request(expected_size=None) + self.add_shutdown() + self.submitter.run() + self.stubber.assert_no_pending_responses() + self.assert_submitted_get_object_jobs([]) + self.assertIsInstance( + self.transfer_monitor.get_exception(self.transfer_id), ClientError + ) + + def 
test_run_with_error_in_allocating_temp_file(self): + self.osutil.allocate.side_effect = OSError() + self.add_download_file_request(expected_size=1) + self.add_shutdown() + self.submitter.run() + self.assert_submitted_get_object_jobs([]) + self.assertIsInstance( + self.transfer_monitor.get_exception(self.transfer_id), OSError + ) + + @skip_if_windows('os.kill() with SIGINT not supported on Windows') + def test_submitter_cannot_be_killed(self): + self.add_download_file_request(expected_size=None) + self.add_shutdown() + + def raise_ctrl_c(**kwargs): + os.kill(os.getpid(), signal.SIGINT) + + mock_client = mock.Mock() + mock_client.head_object = raise_ctrl_c + self.client_factory.create_client.return_value = mock_client + + try: + self.submitter.run() + except KeyboardInterrupt: + self.fail( + 'The submitter should have not been killed by the ' + 'KeyboardInterrupt' + ) + + +class TestGetObjectWorker(StubbedClientTest): + def setUp(self): + super().setUp() + self.files = FileCreator() + self.queue = queue.Queue() + self.client_factory = mock.Mock(ClientFactory) + self.client_factory.create_client.return_value = self.client + self.transfer_monitor = TransferMonitor() + self.osutil = OSUtils() + self.worker = GetObjectWorker( + queue=self.queue, + client_factory=self.client_factory, + transfer_monitor=self.transfer_monitor, + osutil=self.osutil, + ) + self.transfer_id = self.transfer_monitor.notify_new_transfer() + self.bucket = 'bucket' + self.key = 'key' + self.remote_contents = b'my content' + self.temp_filename = self.files.create_file('tempfile', '') + self.extra_args = {} + self.offset = 0 + self.final_filename = self.files.full_path('final_filename') + self.stream = BytesIO(self.remote_contents) + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1000 + ) + + def tearDown(self): + super().tearDown() + self.files.remove_all() + + def add_get_object_job(self, **override_kwargs): + kwargs = { + 'transfer_id': self.transfer_id, + 'bucket': self.bucket, + 'key': self.key, + 'temp_filename': self.temp_filename, + 'extra_args': self.extra_args, + 'offset': self.offset, + 'filename': self.final_filename, + } + kwargs.update(override_kwargs) + self.queue.put(GetObjectJob(**kwargs)) + + def add_shutdown(self): + self.queue.put(SHUTDOWN_SIGNAL) + + def add_stubbed_get_object_response(self, body=None, expected_params=None): + if body is None: + body = self.stream + get_object_response = {'Body': body} + + if expected_params is None: + expected_params = {'Bucket': self.bucket, 'Key': self.key} + + self.stubber.add_response( + 'get_object', get_object_response, expected_params + ) + + def assert_contents(self, filename, contents): + self.assertTrue(os.path.exists(filename)) + with open(filename, 'rb') as f: + self.assertEqual(f.read(), contents) + + def assert_does_not_exist(self, filename): + self.assertFalse(os.path.exists(filename)) + + def test_run_is_final_job(self): + self.add_get_object_job() + self.add_shutdown() + self.add_stubbed_get_object_response() + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1 + ) + + self.worker.run() + self.stubber.assert_no_pending_responses() + self.assert_does_not_exist(self.temp_filename) + self.assert_contents(self.final_filename, self.remote_contents) + + def test_run_jobs_is_not_final_job(self): + self.add_get_object_job() + self.add_shutdown() + self.add_stubbed_get_object_response() + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1000 + ) + + self.worker.run() + 
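+        # (1000 jobs are still expected for this transfer, so this was not
+        # the final job; the data must stay in the temp file and no rename
+        # to the final filename should have happened yet.)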
self.stubber.assert_no_pending_responses() + self.assert_contents(self.temp_filename, self.remote_contents) + self.assert_does_not_exist(self.final_filename) + + def test_run_with_extra_args(self): + self.add_get_object_job(extra_args={'VersionId': 'versionid'}) + self.add_shutdown() + self.add_stubbed_get_object_response( + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'VersionId': 'versionid', + } + ) + + self.worker.run() + self.stubber.assert_no_pending_responses() + + def test_run_with_offset(self): + offset = 1 + self.add_get_object_job(offset=offset) + self.add_shutdown() + self.add_stubbed_get_object_response() + + self.worker.run() + with open(self.temp_filename, 'rb') as f: + f.seek(offset) + self.assertEqual(f.read(), self.remote_contents) + + def test_run_error_in_get_object(self): + self.add_get_object_job() + self.add_shutdown() + self.stubber.add_client_error('get_object', 'NoSuchKey', 404) + self.add_stubbed_get_object_response() + + self.worker.run() + self.assertIsInstance( + self.transfer_monitor.get_exception(self.transfer_id), ClientError + ) + + def test_run_does_retries_for_get_object(self): + self.add_get_object_job() + self.add_shutdown() + self.add_stubbed_get_object_response( + body=StreamWithError( + self.stream, ReadTimeoutError(endpoint_url='') + ) + ) + self.add_stubbed_get_object_response() + + self.worker.run() + self.stubber.assert_no_pending_responses() + self.assert_contents(self.temp_filename, self.remote_contents) + + def test_run_can_exhaust_retries_for_get_object(self): + self.add_get_object_job() + self.add_shutdown() + # 5 is the current setting for max number of GetObject attempts + for _ in range(5): + self.add_stubbed_get_object_response( + body=StreamWithError( + self.stream, ReadTimeoutError(endpoint_url='') + ) + ) + + self.worker.run() + self.stubber.assert_no_pending_responses() + self.assertIsInstance( + self.transfer_monitor.get_exception(self.transfer_id), + RetriesExceededError, + ) + + def test_run_skips_get_object_on_previous_exception(self): + self.add_get_object_job() + self.add_shutdown() + self.transfer_monitor.notify_exception(self.transfer_id, Exception()) + + self.worker.run() + # Note we did not add a stubbed response for get_object + self.stubber.assert_no_pending_responses() + + def test_run_final_job_removes_file_on_previous_exception(self): + self.add_get_object_job() + self.add_shutdown() + self.transfer_monitor.notify_exception(self.transfer_id, Exception()) + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1 + ) + + self.worker.run() + self.stubber.assert_no_pending_responses() + self.assert_does_not_exist(self.temp_filename) + self.assert_does_not_exist(self.final_filename) + + def test_run_fails_to_rename_file(self): + exception = OSError() + osutil = RenameFailingOSUtils(exception) + self.worker = GetObjectWorker( + queue=self.queue, + client_factory=self.client_factory, + transfer_monitor=self.transfer_monitor, + osutil=osutil, + ) + self.add_get_object_job() + self.add_shutdown() + self.add_stubbed_get_object_response() + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1 + ) + + self.worker.run() + self.assertEqual( + self.transfer_monitor.get_exception(self.transfer_id), exception + ) + self.assert_does_not_exist(self.temp_filename) + self.assert_does_not_exist(self.final_filename) + + @skip_if_windows('os.kill() with SIGINT not supported on Windows') + def test_worker_cannot_be_killed(self): + self.add_get_object_job() + self.add_shutdown() 
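+        # (The job below triggers a get_object that sends SIGINT to this
+        # very process; the worker's ignore_ctrl_c handling is expected to
+        # swallow it so that run() returns normally.)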
+ self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1 + ) + + def raise_ctrl_c(**kwargs): + os.kill(os.getpid(), signal.SIGINT) + + mock_client = mock.Mock() + mock_client.get_object = raise_ctrl_c + self.client_factory.create_client.return_value = mock_client + + try: + self.worker.run() + except KeyboardInterrupt: + self.fail( + 'The worker should not have been killed by the ' + 'KeyboardInterrupt' + ) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py b/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py new file mode 100644 index 0000000000..35cf4a22dd --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py @@ -0,0 +1,780 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import os +import shutil +import socket +import tempfile +from concurrent import futures +from contextlib import closing +from io import BytesIO, StringIO + +from s3transfer import ( + MultipartDownloader, + MultipartUploader, + OSUtils, + QueueShutdownError, + ReadFileChunk, + S3Transfer, + ShutdownQueue, + StreamReaderProgress, + TransferConfig, + disable_upload_callbacks, + enable_upload_callbacks, + random_file_extension, +) +from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError +from __tests__ import mock, unittest + + +class InMemoryOSLayer(OSUtils): + def __init__(self, filemap): + self.filemap = filemap + + def get_file_size(self, filename): + return len(self.filemap[filename]) + + def open_file_chunk_reader(self, filename, start_byte, size, callback): + return closing(BytesIO(self.filemap[filename])) + + def open(self, filename, mode): + if 'wb' in mode: + fileobj = BytesIO() + self.filemap[filename] = fileobj + return closing(fileobj) + else: + return closing(self.filemap[filename]) + + def remove_file(self, filename): + if filename in self.filemap: + del self.filemap[filename] + + def rename_file(self, current_filename, new_filename): + if current_filename in self.filemap: + self.filemap[new_filename] = self.filemap.pop(current_filename) + + +class SequentialExecutor: + def __init__(self, max_workers): + pass + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + pass + + # The real map() interface actually takes *args, but we specifically do + # _not_ use this interface.
+ def map(self, function, args): + results = [] + for arg in args: + results.append(function(arg)) + return results + + def submit(self, function): + future = futures.Future() + future.set_result(function()) + return future + + +class TestOSUtils(unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def test_get_file_size(self): + with mock.patch('os.path.getsize') as m: + OSUtils().get_file_size('myfile') + m.assert_called_with('myfile') + + def test_open_file_chunk_reader(self): + with mock.patch('s3transfer.ReadFileChunk') as m: + OSUtils().open_file_chunk_reader('myfile', 0, 100, None) + m.from_filename.assert_called_with( + 'myfile', 0, 100, None, enable_callback=False + ) + + def test_open_file(self): + fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w') + self.assertTrue(hasattr(fileobj, 'write')) + + def test_remove_file_ignores_errors(self): + with mock.patch('os.remove') as remove: + remove.side_effect = OSError('fake error') + OSUtils().remove_file('foo') + remove.assert_called_with('foo') + + def test_remove_file_proxies_remove_file(self): + with mock.patch('os.remove') as remove: + OSUtils().remove_file('foo') + remove.assert_called_with('foo') + + def test_rename_file(self): + with mock.patch('s3transfer.compat.rename_file') as rename_file: + OSUtils().rename_file('foo', 'newfoo') + rename_file.assert_called_with('foo', 'newfoo') + + +class TestReadFileChunk(unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def test_read_entire_chunk(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=3 + ) + self.assertEqual(chunk.read(), b'one') + self.assertEqual(chunk.read(), b'') + + def test_read_with_amount_size(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=11, chunk_size=4 + ) + self.assertEqual(chunk.read(1), b'f') + self.assertEqual(chunk.read(1), b'o') + self.assertEqual(chunk.read(1), b'u') + self.assertEqual(chunk.read(1), b'r') + self.assertEqual(chunk.read(1), b'') + + def test_reset_stream_emulation(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=11, chunk_size=4 + ) + self.assertEqual(chunk.read(), b'four') + chunk.seek(0) + self.assertEqual(chunk.read(), b'four') + + def test_read_past_end_of_file(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=36, chunk_size=100000 + ) + self.assertEqual(chunk.read(), b'ten') + self.assertEqual(chunk.read(), b'') + self.assertEqual(len(chunk), 3) + + def test_tell_and_seek(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=36, chunk_size=100000 + ) + self.assertEqual(chunk.tell(), 0) + self.assertEqual(chunk.read(), b'ten') + self.assertEqual(chunk.tell(), 3) + chunk.seek(0) + 
self.assertEqual(chunk.tell(), 0) + + def test_file_chunk_supports_context_manager(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'abc') + with ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=2 + ) as chunk: + val = chunk.read() + self.assertEqual(val, b'ab') + + def test_iter_is_always_empty(self): + # This tests the workaround for the httplib bug (see + # the source for more info). + filename = os.path.join(self.tempdir, 'foo') + open(filename, 'wb').close() + chunk = ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=10 + ) + self.assertEqual(list(chunk), []) + + +class TestReadFileChunkWithCallback(TestReadFileChunk): + def setUp(self): + super().setUp() + self.filename = os.path.join(self.tempdir, 'foo') + with open(self.filename, 'wb') as f: + f.write(b'abc') + self.amounts_seen = [] + + def callback(self, amount): + self.amounts_seen.append(amount) + + def test_callback_is_invoked_on_read(self): + chunk = ReadFileChunk.from_filename( + self.filename, start_byte=0, chunk_size=3, callback=self.callback + ) + chunk.read(1) + chunk.read(1) + chunk.read(1) + self.assertEqual(self.amounts_seen, [1, 1, 1]) + + def test_callback_can_be_disabled(self): + chunk = ReadFileChunk.from_filename( + self.filename, start_byte=0, chunk_size=3, callback=self.callback + ) + chunk.disable_callback() + # Now reading from the ReadFileChunk should not invoke + # the callback. + chunk.read() + self.assertEqual(self.amounts_seen, []) + + def test_callback_will_also_be_triggered_by_seek(self): + chunk = ReadFileChunk.from_filename( + self.filename, start_byte=0, chunk_size=3, callback=self.callback + ) + chunk.read(2) + chunk.seek(0) + chunk.read(2) + chunk.seek(1) + chunk.read(2) + self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2]) + + +class TestStreamReaderProgress(unittest.TestCase): + def test_proxies_to_wrapped_stream(self): + original_stream = StringIO('foobarbaz') + wrapped = StreamReaderProgress(original_stream) + self.assertEqual(wrapped.read(), 'foobarbaz') + + def test_callback_invoked(self): + amounts_seen = [] + + def callback(amount): + amounts_seen.append(amount) + + original_stream = StringIO('foobarbaz') + wrapped = StreamReaderProgress(original_stream, callback) + self.assertEqual(wrapped.read(), 'foobarbaz') + self.assertEqual(amounts_seen, [9]) + + +class TestMultipartUploader(unittest.TestCase): + def test_multipart_upload_uses_correct_client_calls(self): + client = mock.Mock() + uploader = MultipartUploader( + client, + TransferConfig(), + InMemoryOSLayer({'filename': b'foobar'}), + SequentialExecutor, + ) + client.create_multipart_upload.return_value = {'UploadId': 'upload_id'} + client.upload_part.return_value = {'ETag': 'first'} + + uploader.upload_file('filename', 'bucket', 'key', None, {}) + + # We need to check both the sequence of calls (create/upload/complete) + # as well as the params passed between the calls, including + # 1. The upload_id was plumbed through + # 2. The collected etags were added to the complete call. + client.create_multipart_upload.assert_called_with( + Bucket='bucket', Key='key' + ) + # Should be a single part (b'foobar' is below the default chunksize).
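The assertions that follow pin down the standard three-step multipart protocol (create, upload parts, complete). As a standalone sketch of the same flow against a plain boto3-style client (the bucket/key names are illustrative, not part of the test):

def multipart_upload_single_part(client, bucket, key, body):
    # Step 1: start the upload and remember the returned UploadId.
    upload_id = client.create_multipart_upload(Bucket=bucket, Key=key)['UploadId']
    # Step 2: upload each part, collecting an ETag per part number.
    part = client.upload_part(
        Bucket=bucket, Key=key, UploadId=upload_id, PartNumber=1, Body=body
    )
    # Step 3: complete with the accumulated part number/ETag pairs.
    client.complete_multipart_upload(
        Bucket=bucket,
        Key=key,
        UploadId=upload_id,
        MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': part['ETag']}]},
    )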
+ client.upload_part.assert_called_with( + Body=mock.ANY, + Bucket='bucket', + UploadId='upload_id', + Key='key', + PartNumber=1, + ) + client.complete_multipart_upload.assert_called_with( + MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]}, + Bucket='bucket', + UploadId='upload_id', + Key='key', + ) + + def test_multipart_upload_injects_proper_kwargs(self): + client = mock.Mock() + uploader = MultipartUploader( + client, + TransferConfig(), + InMemoryOSLayer({'filename': b'foobar'}), + SequentialExecutor, + ) + client.create_multipart_upload.return_value = {'UploadId': 'upload_id'} + client.upload_part.return_value = {'ETag': 'first'} + + extra_args = { + 'SSECustomerKey': 'fakekey', + 'SSECustomerAlgorithm': 'AES256', + 'StorageClass': 'REDUCED_REDUNDANCY', + } + uploader.upload_file('filename', 'bucket', 'key', None, extra_args) + + client.create_multipart_upload.assert_called_with( + Bucket='bucket', + Key='key', + # The initial call should inject all the storage class params. + SSECustomerKey='fakekey', + SSECustomerAlgorithm='AES256', + StorageClass='REDUCED_REDUNDANCY', + ) + client.upload_part.assert_called_with( + Body=mock.ANY, + Bucket='bucket', + UploadId='upload_id', + Key='key', + PartNumber=1, + # We only have to forward certain **extra_args in subsequent + # UploadPart calls. + SSECustomerKey='fakekey', + SSECustomerAlgorithm='AES256', + ) + client.complete_multipart_upload.assert_called_with( + MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]}, + Bucket='bucket', + UploadId='upload_id', + Key='key', + ) + + def test_multipart_upload_is_aborted_on_error(self): + # If the create_multipart_upload succeeds and any upload_part + # fails, then abort_multipart_upload will be called. + client = mock.Mock() + uploader = MultipartUploader( + client, + TransferConfig(), + InMemoryOSLayer({'filename': b'foobar'}), + SequentialExecutor, + ) + client.create_multipart_upload.return_value = {'UploadId': 'upload_id'} + client.upload_part.side_effect = Exception( + "Some kind of error occurred." + ) + + with self.assertRaises(S3UploadFailedError): + uploader.upload_file('filename', 'bucket', 'key', None, {}) + + client.abort_multipart_upload.assert_called_with( + Bucket='bucket', Key='key', UploadId='upload_id' + ) + + +class TestMultipartDownloader(unittest.TestCase): + + maxDiff = None + + def test_multipart_download_uses_correct_client_calls(self): + client = mock.Mock() + response_body = b'foobarbaz' + client.get_object.return_value = {'Body': BytesIO(response_body)} + + downloader = MultipartDownloader( + client, TransferConfig(), InMemoryOSLayer({}), SequentialExecutor + ) + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + client.get_object.assert_called_with( + Range='bytes=0-', Bucket='bucket', Key='key' + ) + + def test_multipart_download_with_multiple_parts(self): + client = mock.Mock() + response_body = b'foobarbaz' + client.get_object.return_value = {'Body': BytesIO(response_body)} + # For testing purposes, we're testing with a multipart threshold + # of 4 bytes and a chunksize of 4 bytes. Given b'foobarbaz', + # this should result in 3 calls. In python slices this would be: + # r[0:4], r[4:8], r[8:9]. But the Range param will be slightly + # different because they use inclusive ranges. 
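The inclusive-range arithmetic described in the comment above can be made concrete with a small standalone helper (an illustration, not library code); note that the downloader leaves the final range open-ended:

def range_headers(total_size, chunksize):
    # HTTP ranges are inclusive, so a 4-byte chunk starting at offset 0
    # becomes 'bytes=0-3'; the last part is requested open-ended.
    num_parts = -(-total_size // chunksize)  # ceiling division
    for i in range(num_parts):
        start = i * chunksize
        if i == num_parts - 1:
            yield f'bytes={start}-'
        else:
            yield f'bytes={start}-{start + chunksize - 1}'

assert list(range_headers(9, 4)) == ['bytes=0-3', 'bytes=4-7', 'bytes=8-']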
+ config = TransferConfig(multipart_threshold=4, multipart_chunksize=4) + + downloader = MultipartDownloader( + client, config, InMemoryOSLayer({}), SequentialExecutor + ) + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + # We're storing these in **extra because the assertEqual + # below is really about verifying we have the correct value + # for the Range param. + extra = {'Bucket': 'bucket', 'Key': 'key'} + self.assertEqual( + client.get_object.call_args_list, + # Note these are inclusive ranges. + [ + mock.call(Range='bytes=0-3', **extra), + mock.call(Range='bytes=4-7', **extra), + mock.call(Range='bytes=8-', **extra), + ], + ) + + def test_retry_on_failures_from_stream_reads(self): + # If we get an exception during a call to the response body's .read() + # method, we should retry the request. + client = mock.Mock() + response_body = b'foobarbaz' + stream_with_errors = mock.Mock() + stream_with_errors.read.side_effect = [ + socket.error("fake error"), + response_body, + ] + client.get_object.return_value = {'Body': stream_with_errors} + config = TransferConfig(multipart_threshold=4, multipart_chunksize=4) + + downloader = MultipartDownloader( + client, config, InMemoryOSLayer({}), SequentialExecutor + ) + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + # We're storing these in **extra because the assertEqual + # below is really about verifying we have the correct value + # for the Range param. + extra = {'Bucket': 'bucket', 'Key': 'key'} + self.assertEqual( + client.get_object.call_args_list, + # The first call to range=0-3 fails because of the + # side_effect above where we make the .read() raise a + # socket.error. + # The second call to range=0-3 then succeeds. + [ + mock.call(Range='bytes=0-3', **extra), + mock.call(Range='bytes=0-3', **extra), + mock.call(Range='bytes=4-7', **extra), + mock.call(Range='bytes=8-', **extra), + ], + ) + + def test_exception_raised_on_exceeded_retries(self): + client = mock.Mock() + response_body = b'foobarbaz' + stream_with_errors = mock.Mock() + stream_with_errors.read.side_effect = socket.error("fake error") + client.get_object.return_value = {'Body': stream_with_errors} + config = TransferConfig(multipart_threshold=4, multipart_chunksize=4) + + downloader = MultipartDownloader( + client, config, InMemoryOSLayer({}), SequentialExecutor + ) + with self.assertRaises(RetriesExceededError): + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + def test_io_thread_failure_triggers_shutdown(self): + client = mock.Mock() + response_body = b'foobarbaz' + client.get_object.return_value = {'Body': BytesIO(response_body)} + os_layer = mock.Mock() + mock_fileobj = mock.MagicMock() + mock_fileobj.__enter__.return_value = mock_fileobj + mock_fileobj.write.side_effect = Exception("fake IO error") + os_layer.open.return_value = mock_fileobj + + downloader = MultipartDownloader( + client, TransferConfig(), os_layer, SequentialExecutor + ) + # We're verifying that the exception raised from the IO future + # propagates back up via download_file(). 
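The two retry tests above exercise the usual bounded-retry loop around a streaming read: if .read() on the response body fails, the whole ranged GET is reissued, and the transfer only fails once the attempt budget is spent. A minimal sketch of that pattern (the helper name and the 5-attempt budget mirror the tests, not the library's exact code):

import socket

from s3transfer.exceptions import RetriesExceededError

def read_body_with_retries(make_request, max_attempts=5):
    # Reissue the request if the streaming read fails mid-body; raise
    # RetriesExceededError once the attempt budget is exhausted.
    last_error = None
    for _ in range(max_attempts):
        response = make_request()
        try:
            return response['Body'].read()
        except socket.error as e:
            last_error = e
    raise RetriesExceededError(last_error)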
+ with self.assertRaisesRegex(Exception, "fake IO error"): + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + def test_download_futures_fail_triggers_shutdown(self): + class FailedDownloadParts(SequentialExecutor): + def __init__(self, max_workers): + self.is_first = True + + def submit(self, function): + future = futures.Future() + if self.is_first: + # This is the download_parts_thread. + future.set_exception( + Exception("fake download parts error") + ) + self.is_first = False + return future + + client = mock.Mock() + response_body = b'foobarbaz' + client.get_object.return_value = {'Body': BytesIO(response_body)} + + downloader = MultipartDownloader( + client, TransferConfig(), InMemoryOSLayer({}), FailedDownloadParts + ) + with self.assertRaisesRegex(Exception, "fake download parts error"): + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + +class TestS3Transfer(unittest.TestCase): + def setUp(self): + self.client = mock.Mock() + self.random_file_patch = mock.patch('s3transfer.random_file_extension') + self.random_file = self.random_file_patch.start() + self.random_file.return_value = 'RANDOM' + + def tearDown(self): + self.random_file_patch.stop() + + def test_callback_handlers_register_on_put_item(self): + osutil = InMemoryOSLayer({'smallfile': b'foobar'}) + transfer = S3Transfer(self.client, osutil=osutil) + transfer.upload_file('smallfile', 'bucket', 'key') + events = self.client.meta.events + events.register_first.assert_called_with( + 'request-created.s3', + disable_upload_callbacks, + unique_id='s3upload-callback-disable', + ) + events.register_last.assert_called_with( + 'request-created.s3', + enable_upload_callbacks, + unique_id='s3upload-callback-enable', + ) + + def test_upload_below_multipart_threshold_uses_put_object(self): + fake_files = { + 'smallfile': b'foobar', + } + osutil = InMemoryOSLayer(fake_files) + transfer = S3Transfer(self.client, osutil=osutil) + transfer.upload_file('smallfile', 'bucket', 'key') + self.client.put_object.assert_called_with( + Bucket='bucket', Key='key', Body=mock.ANY + ) + + def test_extra_args_on_uploaded_passed_to_api_call(self): + extra_args = {'ACL': 'public-read'} + fake_files = {'smallfile': b'hello world'} + osutil = InMemoryOSLayer(fake_files) + transfer = S3Transfer(self.client, osutil=osutil) + transfer.upload_file( + 'smallfile', 'bucket', 'key', extra_args=extra_args + ) + self.client.put_object.assert_called_with( + Bucket='bucket', Key='key', Body=mock.ANY, ACL='public-read' + ) + + def test_uses_multipart_upload_when_over_threshold(self): + with mock.patch('s3transfer.MultipartUploader') as uploader: + fake_files = { + 'smallfile': b'foobar', + } + osutil = InMemoryOSLayer(fake_files) + config = TransferConfig( + multipart_threshold=2, multipart_chunksize=2 + ) + transfer = S3Transfer(self.client, osutil=osutil, config=config) + transfer.upload_file('smallfile', 'bucket', 'key') + + uploader.return_value.upload_file.assert_called_with( + 'smallfile', 'bucket', 'key', None, {} + ) + + def test_uses_multipart_download_when_over_threshold(self): + with mock.patch('s3transfer.MultipartDownloader') as downloader: + osutil = InMemoryOSLayer({}) + over_multipart_threshold = 100 * 1024 * 1024 + transfer = S3Transfer(self.client, osutil=osutil) + callback = mock.sentinel.CALLBACK + self.client.head_object.return_value = { + 'ContentLength': over_multipart_threshold, + } + transfer.download_file( + 'bucket', 'key', 'filename', callback=callback + ) + + 
downloader.return_value.download_file.assert_called_with( + # Note how we're downloading to a temporary random file. + 'bucket', + 'key', + 'filename.RANDOM', + over_multipart_threshold, + {}, + callback, + ) + + def test_download_file_with_invalid_extra_args(self): + below_threshold = 20 + osutil = InMemoryOSLayer({}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + with self.assertRaises(ValueError): + transfer.download_file( + 'bucket', + 'key', + '/tmp/smallfile', + extra_args={'BadValue': 'foo'}, + ) + + def test_upload_file_with_invalid_extra_args(self): + osutil = InMemoryOSLayer({}) + transfer = S3Transfer(self.client, osutil=osutil) + bad_args = {"WebsiteRedirectLocation": "/foo"} + with self.assertRaises(ValueError): + transfer.upload_file( + 'bucket', 'key', '/tmp/smallfile', extra_args=bad_args + ) + + def test_download_file_fowards_extra_args(self): + extra_args = { + 'SSECustomerKey': 'foo', + 'SSECustomerAlgorithm': 'AES256', + } + below_threshold = 20 + osutil = InMemoryOSLayer({'smallfile': b'hello world'}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + self.client.get_object.return_value = {'Body': BytesIO(b'foobar')} + transfer.download_file( + 'bucket', 'key', '/tmp/smallfile', extra_args=extra_args + ) + + # Note that we need to invoke the HeadObject call + # and the GetObject call with the extra_args. + # This is necessary. Trying to HeadObject an SSE object + # will return a 400 if you don't provide the required + # params. + self.client.get_object.assert_called_with( + Bucket='bucket', + Key='key', + SSECustomerAlgorithm='AES256', + SSECustomerKey='foo', + ) + + def test_get_object_stream_is_retried_and_succeeds(self): + below_threshold = 20 + osutil = InMemoryOSLayer({'smallfile': b'hello world'}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + self.client.get_object.side_effect = [ + # First request fails. + socket.error("fake error"), + # Second succeeds. + {'Body': BytesIO(b'foobar')}, + ] + transfer.download_file('bucket', 'key', '/tmp/smallfile') + + self.assertEqual(self.client.get_object.call_count, 2) + + def test_get_object_stream_uses_all_retries_and_errors_out(self): + below_threshold = 20 + osutil = InMemoryOSLayer({}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + # Here we're raising an exception every single time, which + # will exhaust our retry count and propagate a + # RetriesExceededError. + self.client.get_object.side_effect = socket.error("fake error") + with self.assertRaises(RetriesExceededError): + transfer.download_file('bucket', 'key', 'smallfile') + + self.assertEqual(self.client.get_object.call_count, 5) + # We should have also cleaned up the in-progress file + # we were downloading to.
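The cleanup assertion coming up relies on S3Transfer's download discipline: write to a temporary side file next to the destination, rename it into place on success, and remove it on failure, so a failed transfer never leaves a partial object at the final path. Roughly (the helper names here are illustrative):

import os

from s3transfer import random_file_extension

def download_to_temp_then_rename(do_download, filename, osutil):
    # 'filename.RANDOM' in the earlier test comes from this pattern.
    temp_filename = filename + os.extsep + random_file_extension()
    try:
        do_download(temp_filename)
        osutil.rename_file(temp_filename, filename)
    except Exception:
        osutil.remove_file(temp_filename)
        raise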
+ self.assertEqual(osutil.filemap, {}) + + def test_download_below_multipart_threshold(self): + below_threshold = 20 + osutil = InMemoryOSLayer({'smallfile': b'hello world'}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + self.client.get_object.return_value = {'Body': BytesIO(b'foobar')} + transfer.download_file('bucket', 'key', 'smallfile') + + self.client.get_object.assert_called_with(Bucket='bucket', Key='key') + + def test_can_create_with_just_client(self): + transfer = S3Transfer(client=mock.Mock()) + self.assertIsInstance(transfer, S3Transfer) + + +class TestShutdownQueue(unittest.TestCase): + def test_handles_normal_put_get_requests(self): + q = ShutdownQueue() + q.put('foo') + self.assertEqual(q.get(), 'foo') + + def test_put_raises_error_on_shutdown(self): + q = ShutdownQueue() + q.trigger_shutdown() + with self.assertRaises(QueueShutdownError): + q.put('foo') + + +class TestRandomFileExtension(unittest.TestCase): + def test_has_proper_length(self): + self.assertEqual(len(random_file_extension(num_digits=4)), 4) + + +class TestCallbackHandlers(unittest.TestCase): + def setUp(self): + self.request = mock.Mock() + + def test_disable_request_on_put_object(self): + disable_upload_callbacks(self.request, 'PutObject') + self.request.body.disable_callback.assert_called_with() + + def test_disable_request_on_upload_part(self): + disable_upload_callbacks(self.request, 'UploadPart') + self.request.body.disable_callback.assert_called_with() + + def test_enable_object_on_put_object(self): + enable_upload_callbacks(self.request, 'PutObject') + self.request.body.enable_callback.assert_called_with() + + def test_enable_object_on_upload_part(self): + enable_upload_callbacks(self.request, 'UploadPart') + self.request.body.enable_callback.assert_called_with() + + def test_dont_disable_if_missing_interface(self): + del self.request.body.disable_callback + disable_upload_callbacks(self.request, 'PutObject') + self.assertEqual(self.request.body.method_calls, []) + + def test_dont_enable_if_missing_interface(self): + del self.request.body.enable_callback + enable_upload_callbacks(self.request, 'PutObject') + self.assertEqual(self.request.body.method_calls, []) + + def test_dont_disable_if_wrong_operation(self): + disable_upload_callbacks(self.request, 'OtherOperation') + self.assertFalse(self.request.body.disable_callback.called) + + def test_dont_enable_if_wrong_operation(self): + enable_upload_callbacks(self.request, 'OtherOperation') + self.assertFalse(self.request.body.enable_callback.called) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py b/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py new file mode 100644 index 0000000000..a26d3a548c --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py @@ -0,0 +1,91 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
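The TestCallbackHandlers cases that close out test_s3transfer.py above pin down the hook behavior: only upload operations are touched, and only when the request body actually exposes the callback interface. A sketch consistent with those tests (an inference from the assertions, not necessarily the exact library code):

def disable_upload_callbacks(request, operation_name, **kwargs):
    # Gate on the operation name and tolerate bodies without the
    # ReadFileChunk-style interface.
    if operation_name in ('PutObject', 'UploadPart') and hasattr(
        request.body, 'disable_callback'
    ):
        request.body.disable_callback()

def enable_upload_callbacks(request, operation_name, **kwargs):
    if operation_name in ('PutObject', 'UploadPart') and hasattr(
        request.body, 'enable_callback'
    ):
        request.body.enable_callback()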
+from s3transfer.exceptions import InvalidSubscriberMethodError +from s3transfer.subscribers import BaseSubscriber +from __tests__ import unittest + + +class ExtraMethodsSubscriber(BaseSubscriber): + def extra_method(self): + return 'called extra method' + + +class NotCallableSubscriber(BaseSubscriber): + on_done = 'foo' + + +class NoKwargsSubscriber(BaseSubscriber): + def on_done(self): + pass + + +class OverrideMethodSubscriber(BaseSubscriber): + def on_queued(self, **kwargs): + return kwargs + + +class OverrideConstructorSubscriber(BaseSubscriber): + def __init__(self, arg1, arg2): + self.arg1 = arg1 + self.arg2 = arg2 + + +class TestSubscribers(unittest.TestCase): + def test_can_instantiate_base_subscriber(self): + try: + BaseSubscriber() + except InvalidSubscriberMethodError: + self.fail('BaseSubscriber should be instantiable') + + def test_can_call_base_subscriber_method(self): + subscriber = BaseSubscriber() + try: + subscriber.on_done(future=None) + except Exception as e: + self.fail( + 'Should be able to call base class subscriber method. ' + 'instead got: %s' % e + ) + + def test_subclass_can_have_and_call_additional_methods(self): + subscriber = ExtraMethodsSubscriber() + self.assertEqual(subscriber.extra_method(), 'called extra method') + + def test_can_subclass_and_override_method_from_base_subscriber(self): + subscriber = OverrideMethodSubscriber() + # Make sure that the overridden method is called + self.assertEqual(subscriber.on_queued(foo='bar'), {'foo': 'bar'}) + + def test_can_subclass_and_override_constructor_from_base_class(self): + subscriber = OverrideConstructorSubscriber('foo', arg2='bar') + # Make sure you can create a custom constructor. + self.assertEqual(subscriber.arg1, 'foo') + self.assertEqual(subscriber.arg2, 'bar') + + def test_invalid_arguments_in_constructor_of_subclass_subscriber(self): + # The override constructor should still have validation of + # constructor args. + with self.assertRaises(TypeError): + OverrideConstructorSubscriber() + + def test_not_callable_in_subclass_subscriber_method(self): + with self.assertRaisesRegex( + InvalidSubscriberMethodError, 'must be callable' + ): + NotCallableSubscriber() + + def test_no_kwargs_in_subclass_subscriber_method(self): + with self.assertRaisesRegex( + InvalidSubscriberMethodError, 'must accept keyword' + ): + NoKwargsSubscriber() diff --git a/contrib/python/s3transfer/py3/tests/unit/test_tasks.py b/contrib/python/s3transfer/py3/tests/unit/test_tasks.py new file mode 100644 index 0000000000..4f0bc4d1cc --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/unit/test_tasks.py @@ -0,0 +1,833 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
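As the subscriber tests above imply, a valid subscriber only needs callable on_* hooks that accept keyword arguments; everything else about the class is up to the author. For example, a minimal progress-printing subscriber might look like this (illustrative only):

from s3transfer.subscribers import BaseSubscriber

class PrintingSubscriber(BaseSubscriber):
    # Each hook takes keyword arguments, which is exactly what the
    # validation exercised by NoKwargsSubscriber above checks for.
    def on_queued(self, future, **kwargs):
        print('queued:', future.meta.call_args.fileobj)

    def on_progress(self, future, bytes_transferred, **kwargs):
        print('progress:', bytes_transferred)

    def on_done(self, future, **kwargs):
        print('done')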
+from concurrent import futures +from functools import partial +from threading import Event + +from s3transfer.futures import BoundedExecutor, TransferCoordinator +from s3transfer.subscribers import BaseSubscriber +from s3transfer.tasks import ( + CompleteMultipartUploadTask, + CreateMultipartUploadTask, + SubmissionTask, + Task, +) +from s3transfer.utils import CallArgs, FunctionContainer, get_callbacks +from __tests__ import ( + BaseSubmissionTaskTest, + BaseTaskTest, + RecordingSubscriber, + unittest, +) + + +class TaskFailureException(Exception): + pass + + +class SuccessTask(Task): + def _main( + self, return_value='success', callbacks=None, failure_cleanups=None + ): + if callbacks: + for callback in callbacks: + callback() + if failure_cleanups: + for failure_cleanup in failure_cleanups: + self._transfer_coordinator.add_failure_cleanup(failure_cleanup) + return return_value + + +class FailureTask(Task): + def _main(self, exception=TaskFailureException): + raise exception() + + +class ReturnKwargsTask(Task): + def _main(self, **kwargs): + return kwargs + + +class SubmitMoreTasksTask(Task): + def _main(self, executor, tasks_to_submit): + for task_to_submit in tasks_to_submit: + self._transfer_coordinator.submit(executor, task_to_submit) + + +class NOOPSubmissionTask(SubmissionTask): + def _submit(self, transfer_future, **kwargs): + pass + + +class ExceptionSubmissionTask(SubmissionTask): + def _submit( + self, + transfer_future, + executor=None, + tasks_to_submit=None, + additional_callbacks=None, + exception=TaskFailureException, + ): + if executor and tasks_to_submit: + for task_to_submit in tasks_to_submit: + self._transfer_coordinator.submit(executor, task_to_submit) + if additional_callbacks: + for callback in additional_callbacks: + callback() + raise exception() + + +class StatusRecordingTransferCoordinator(TransferCoordinator): + def __init__(self, transfer_id=None): + super().__init__(transfer_id) + self.status_changes = [self._status] + + def set_status_to_queued(self): + super().set_status_to_queued() + self._record_status_change() + + def set_status_to_running(self): + super().set_status_to_running() + self._record_status_change() + + def _record_status_change(self): + self.status_changes.append(self._status) + + +class RecordingStateSubscriber(BaseSubscriber): + def __init__(self, transfer_coordinator): + self._transfer_coordinator = transfer_coordinator + self.status_during_on_queued = None + + def on_queued(self, **kwargs): + self.status_during_on_queued = self._transfer_coordinator.status + + +class TestSubmissionTask(BaseSubmissionTaskTest): + def setUp(self): + super().setUp() + self.executor = BoundedExecutor(1000, 5) + self.call_args = CallArgs(subscribers=[]) + self.transfer_future = self.get_transfer_future(self.call_args) + self.main_kwargs = {'transfer_future': self.transfer_future} + + def test_transitions_from_not_started_to_queued_to_running(self): + self.transfer_coordinator = StatusRecordingTransferCoordinator() + submission_task = self.get_task( + NOOPSubmissionTask, main_kwargs=self.main_kwargs + ) + # The status should stay 'not-started' until the submission task + # has been run. + self.assertEqual(self.transfer_coordinator.status, 'not-started') + + submission_task() + # Once the submission task has been run, the status should now be + # running. + self.assertEqual(self.transfer_coordinator.status, 'running') + + # Ensure the transitions were as expected as well.
+ self.assertEqual( + self.transfer_coordinator.status_changes, + ['not-started', 'queued', 'running'], + ) + + def test_on_queued_callbacks(self): + submission_task = self.get_task( + NOOPSubmissionTask, main_kwargs=self.main_kwargs + ) + + subscriber = RecordingSubscriber() + self.call_args.subscribers.append(subscriber) + submission_task() + # Make sure the on_queued callback of the subscriber is called. + self.assertEqual( + subscriber.on_queued_calls, [{'future': self.transfer_future}] + ) + + def test_on_queued_status_in_callbacks(self): + submission_task = self.get_task( + NOOPSubmissionTask, main_kwargs=self.main_kwargs + ) + + subscriber = RecordingStateSubscriber(self.transfer_coordinator) + self.call_args.subscribers.append(subscriber) + submission_task() + # Make sure the status was queued during on_queued callback. + self.assertEqual(subscriber.status_during_on_queued, 'queued') + + def test_sets_exception_from_submit(self): + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + submission_task() + + # Make sure the status of the future is failed + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the future propagates the exception encountered in the + # submission task. + with self.assertRaises(TaskFailureException): + self.transfer_future.result() + + def test_catches_and_sets_keyboard_interrupt_exception_from_submit(self): + self.main_kwargs['exception'] = KeyboardInterrupt + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + submission_task() + + self.assertEqual(self.transfer_coordinator.status, 'failed') + with self.assertRaises(KeyboardInterrupt): + self.transfer_future.result() + + def test_calls_done_callbacks_on_exception(self): + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + subscriber = RecordingSubscriber() + self.call_args.subscribers.append(subscriber) + + # Add the done callback to the callbacks to be invoked when the + # transfer is done. + done_callbacks = get_callbacks(self.transfer_future, 'done') + for done_callback in done_callbacks: + self.transfer_coordinator.add_done_callback(done_callback) + submission_task() + + # Make sure the task failed to start + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the on_done callback of the subscriber is called. + self.assertEqual( + subscriber.on_done_calls, [{'future': self.transfer_future}] + ) + + def test_calls_failure_cleanups_on_exception(self): + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + # Add the callback to the callbacks to be invoked when the + # transfer fails. + invocations_of_cleanup = [] + cleanup_callback = FunctionContainer( + invocations_of_cleanup.append, 'cleanup happened' + ) + self.transfer_coordinator.add_failure_cleanup(cleanup_callback) + submission_task() + + # Make sure the task failed to start + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the cleanup was called. + self.assertEqual(invocations_of_cleanup, ['cleanup happened']) + + def test_cleanups_only_ran_once_on_exception(self): + # We want to be able to handle the case where the final task completes + # and announces done but there is an error in the submission task + # which will cause it to need to announce done as well. In this case, + # we do not want the done callbacks to be invoked more than once.
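These tests lean heavily on s3transfer.utils.FunctionContainer, which behaves like functools.partial: it binds a function to arguments now and invokes it later with no arguments. A minimal equivalent for orientation (an assumption sketched from its usage here, not the library's exact code):

class FunctionContainer:
    def __init__(self, func, *args, **kwargs):
        self._func = func
        self._args = args
        self._kwargs = kwargs

    def __call__(self):
        # Invoke the bound function with the captured arguments.
        return self._func(*self._args, **self._kwargs)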
+ + final_task = self.get_task(FailureTask, is_final=True) + self.main_kwargs['executor'] = self.executor + self.main_kwargs['tasks_to_submit'] = [final_task] + + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + subscriber = RecordingSubscriber() + self.call_args.subscribers.append(subscriber) + + # Add the done callback to the callbacks to be invoked when the + # transfer is done. + done_callbacks = get_callbacks(self.transfer_future, 'done') + for done_callback in done_callbacks: + self.transfer_coordinator.add_done_callback(done_callback) + + submission_task() + + # Make sure the task failed to start + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the on_done callback of the subscriber is called only once. + self.assertEqual( + subscriber.on_done_calls, [{'future': self.transfer_future}] + ) + + def test_done_callbacks_only_ran_once_on_exception(self): + # We want to be able to handle the case where the final task completes + # and announces done but there is an error in the submission task + # which will cause it to need to announce done as well. In this case, + # we do not want the failure cleanups to be invoked more than once. + + final_task = self.get_task(FailureTask, is_final=True) + self.main_kwargs['executor'] = self.executor + self.main_kwargs['tasks_to_submit'] = [final_task] + + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + # Add the callback to the callbacks to be invoked when the + # transfer fails. + invocations_of_cleanup = [] + cleanup_callback = FunctionContainer( + invocations_of_cleanup.append, 'cleanup happened' + ) + self.transfer_coordinator.add_failure_cleanup(cleanup_callback) + submission_task() + + # Make sure the task failed to start + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the cleanup was called only once. + self.assertEqual(invocations_of_cleanup, ['cleanup happened']) + + def test_handles_cleanups_submitted_in_other_tasks(self): + invocations_of_cleanup = [] + event = Event() + cleanup_callback = FunctionContainer( + invocations_of_cleanup.append, 'cleanup happened' + ) + # We want the cleanup to be added in the execution of the task and + # still be executed by the submission task when it fails. + task = self.get_task( + SuccessTask, + main_kwargs={ + 'callbacks': [event.set], + 'failure_cleanups': [cleanup_callback], + }, + ) + + self.main_kwargs['executor'] = self.executor + self.main_kwargs['tasks_to_submit'] = [task] + self.main_kwargs['additional_callbacks'] = [event.wait] + + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + submission_task() + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the cleanup was called even though the callback got + # added in a completely different task. + self.assertEqual(invocations_of_cleanup, ['cleanup happened']) + + def test_waits_for_tasks_submitted_by_other_tasks_on_exception(self): + # In this test, we want to make sure that any tasks that may be + # submitted in another task complete before we start performing + # cleanups. + # + # This is tested by doing the following: + # + # ExceptionSubmissionTask + # | + # +--submits-->SubmitMoreTasksTask + # | + # +--submits-->SuccessTask + # | + # +-->sleeps-->adds failure cleanup + # + # In the end, the failure cleanup of the SuccessTask should be run + # when the ExceptionSubmissionTask fails.
If the + # ExceptionSubmissionTask did not run the failure cleanup it is most + # likely that it did not wait for the SuccessTask to complete, which + # it needs to because the ExceptionSubmissionTask does not know + # what failure cleanups it needs to run until all spawned tasks have + # completed. + invocations_of_cleanup = [] + event = Event() + cleanup_callback = FunctionContainer( + invocations_of_cleanup.append, 'cleanup happened' + ) + + cleanup_task = self.get_task( + SuccessTask, + main_kwargs={ + 'callbacks': [event.set], + 'failure_cleanups': [cleanup_callback], + }, + ) + task_for_submitting_cleanup_task = self.get_task( + SubmitMoreTasksTask, + main_kwargs={ + 'executor': self.executor, + 'tasks_to_submit': [cleanup_task], + }, + ) + + self.main_kwargs['executor'] = self.executor + self.main_kwargs['tasks_to_submit'] = [ + task_for_submitting_cleanup_task + ] + self.main_kwargs['additional_callbacks'] = [event.wait] + + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + submission_task() + self.assertEqual(self.transfer_coordinator.status, 'failed') + self.assertEqual(invocations_of_cleanup, ['cleanup happened']) + + def test_submission_task_announces_done_if_cancelled_before_main(self): + invocations_of_done = [] + done_callback = FunctionContainer( + invocations_of_done.append, 'done announced' + ) + self.transfer_coordinator.add_done_callback(done_callback) + + self.transfer_coordinator.cancel() + submission_task = self.get_task( + NOOPSubmissionTask, main_kwargs=self.main_kwargs + ) + submission_task() + + # Because the submission task was cancelled before being run, + # it did not submit any extra tasks, so as a result it is responsible + # for making sure it announces done, as nothing else will. + self.assertEqual(invocations_of_done, ['done announced']) + + +class TestTask(unittest.TestCase): + def setUp(self): + self.transfer_id = 1 + self.transfer_coordinator = TransferCoordinator( + transfer_id=self.transfer_id + ) + + def test_repr(self): + main_kwargs = {'bucket': 'mybucket', 'param_to_not_include': 'foo'} + task = ReturnKwargsTask( + self.transfer_coordinator, main_kwargs=main_kwargs + ) + # The repr should not include the other parameter because it is not + # a desired parameter to include. + self.assertEqual( + repr(task), + 'ReturnKwargsTask(transfer_id={}, {})'.format( + self.transfer_id, {'bucket': 'mybucket'} + ), + ) + + def test_transfer_id(self): + task = SuccessTask(self.transfer_coordinator) + # Make sure that the id is the one associated with the transfer + # coordinator. + self.assertEqual(task.transfer_id, self.transfer_id) + + def test_context_status_transitioning_success(self): + # The status should be set to running. + self.transfer_coordinator.set_status_to_running() + self.assertEqual(self.transfer_coordinator.status, 'running') + + # If a task is called, the status should still be running. + SuccessTask(self.transfer_coordinator)() + self.assertEqual(self.transfer_coordinator.status, 'running') + + # Once the final task is called, the status should be set to success.
+ SuccessTask(self.transfer_coordinator, is_final=True)() + self.assertEqual(self.transfer_coordinator.status, 'success') + + def test_context_status_transitioning_failed(self): + self.transfer_coordinator.set_status_to_running() + + SuccessTask(self.transfer_coordinator)() + self.assertEqual(self.transfer_coordinator.status, 'running') + + # A failure task should result in the failed status + FailureTask(self.transfer_coordinator)() + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Even if the final task comes in and succeeds, it should stay failed. + SuccessTask(self.transfer_coordinator, is_final=True)() + self.assertEqual(self.transfer_coordinator.status, 'failed') + + def test_result_setting_for_success(self): + override_return = 'foo' + SuccessTask(self.transfer_coordinator)() + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': override_return}, + is_final=True, + )() + + # The return value for the transfer future should be that of the + # final task. + self.assertEqual(self.transfer_coordinator.result(), override_return) + + def test_result_setting_for_error(self): + FailureTask(self.transfer_coordinator)() + + # If another failure comes in, the result should still throw the + # original exception when result() is eventually called. + FailureTask( + self.transfer_coordinator, main_kwargs={'exception': Exception} + )() + + # Even if a success task comes along, the result of the future + # should be the original exception + SuccessTask(self.transfer_coordinator, is_final=True)() + with self.assertRaises(TaskFailureException): + self.transfer_coordinator.result() + + def test_done_callbacks_success(self): + callback_results = [] + SuccessTask( + self.transfer_coordinator, + done_callbacks=[ + partial(callback_results.append, 'first'), + partial(callback_results.append, 'second'), + ], + )() + # For successful tasks, the done callbacks should get called. + self.assertEqual(callback_results, ['first', 'second']) + + def test_done_callbacks_failure(self): + callback_results = [] + FailureTask( + self.transfer_coordinator, + done_callbacks=[ + partial(callback_results.append, 'first'), + partial(callback_results.append, 'second'), + ], + )() + # Even for failed tasks, the done callbacks should get called. + self.assertEqual(callback_results, ['first', 'second']) + + # Callbacks should continue to be called even after a related failure + SuccessTask( + self.transfer_coordinator, + done_callbacks=[ + partial(callback_results.append, 'third'), + partial(callback_results.append, 'fourth'), + ], + )() + self.assertEqual( + callback_results, ['first', 'second', 'third', 'fourth'] + ) + + def test_failure_cleanups_on_failure(self): + callback_results = [] + self.transfer_coordinator.add_failure_cleanup( + callback_results.append, 'first' + ) + self.transfer_coordinator.add_failure_cleanup( + callback_results.append, 'second' + ) + FailureTask(self.transfer_coordinator)() + # The failure callbacks should not have been called yet because it + # is not the last task + self.assertEqual(callback_results, []) + + # Now the failure callbacks should get called.
+ SuccessTask(self.transfer_coordinator, is_final=True)() + self.assertEqual(callback_results, ['first', 'second']) + + def test_no_failure_cleanups_on_success(self): + callback_results = [] + self.transfer_coordinator.add_failure_cleanup( + callback_results.append, 'first' + ) + self.transfer_coordinator.add_failure_cleanup( + callback_results.append, 'second' + ) + SuccessTask(self.transfer_coordinator, is_final=True)() + # The failure cleanups should not have been called because no task + # failed for the transfer context. + self.assertEqual(callback_results, []) + + def test_passing_main_kwargs(self): + main_kwargs = {'foo': 'bar', 'baz': 'biz'} + ReturnKwargsTask( + self.transfer_coordinator, main_kwargs=main_kwargs, is_final=True + )() + # The kwargs should have been passed to the main() + self.assertEqual(self.transfer_coordinator.result(), main_kwargs) + + def test_passing_pending_kwargs_single_futures(self): + pending_kwargs = {} + ref_main_kwargs = {'foo': 'bar', 'baz': 'biz'} + + # Pass some tasks to an executor + with futures.ThreadPoolExecutor(1) as executor: + pending_kwargs['foo'] = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': ref_main_kwargs['foo']}, + ) + ) + pending_kwargs['baz'] = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': ref_main_kwargs['baz']}, + ) + ) + + # Create a task that depends on the tasks passed to the executor + ReturnKwargsTask( + self.transfer_coordinator, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The result should have the pending keyword arg values flushed + # out. + self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs) + + def test_passing_pending_kwargs_list_of_futures(self): + pending_kwargs = {} + ref_main_kwargs = {'foo': ['first', 'second']} + + # Pass some tasks to an executor + with futures.ThreadPoolExecutor(1) as executor: + first_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': ref_main_kwargs['foo'][0]}, + ) + ) + second_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': ref_main_kwargs['foo'][1]}, + ) + ) + # Make the pending keyword arg value a list + pending_kwargs['foo'] = [first_future, second_future] + + # Create a task that depends on the tasks passed to the executor + ReturnKwargsTask( + self.transfer_coordinator, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The result should have the pending keyword arg values flushed + # out in the expected order. 
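What the pending-kwargs tests exercise is a resolution step inside Task: before _main runs, every future (or list of futures) in pending_main_kwargs is waited on and replaced with its result. Conceptually it behaves like the following sketch (the names are illustrative, not the actual implementation):

def resolve_pending_main_kwargs(pending_main_kwargs):
    # Block on each pending future and substitute its result, preserving
    # order for list-valued entries. A failed future raises here, which
    # is how the failure tests further down surface their exception.
    resolved = {}
    for name, value in pending_main_kwargs.items():
        if isinstance(value, list):
            resolved[name] = [future.result() for future in value]
        else:
            resolved[name] = value.result()
    return resolved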
+ self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs) + + def test_passing_pending_and_non_pending_kwargs(self): + main_kwargs = {'nonpending_value': 'foo'} + pending_kwargs = {} + ref_main_kwargs = { + 'nonpending_value': 'foo', + 'pending_value': 'bar', + 'pending_list': ['first', 'second'], + } + + # Create the pending tasks + with futures.ThreadPoolExecutor(1) as executor: + pending_kwargs['pending_value'] = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={ + 'return_value': ref_main_kwargs['pending_value'] + }, + ) + ) + + first_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={ + 'return_value': ref_main_kwargs['pending_list'][0] + }, + ) + ) + second_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={ + 'return_value': ref_main_kwargs['pending_list'][1] + }, + ) + ) + # Make the pending keyword arg value a list + pending_kwargs['pending_list'] = [first_future, second_future] + + # Create a task that depends on the tasks passed to the executor + # and just regular nonpending kwargs. + ReturnKwargsTask( + self.transfer_coordinator, + main_kwargs=main_kwargs, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The result should have all of the kwargs (both pending and + # nonpending) + self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs) + + def test_single_failed_pending_future(self): + pending_kwargs = {} + + # Pass some tasks to an executor. Make one successful and the other + # a failure. + with futures.ThreadPoolExecutor(1) as executor: + pending_kwargs['foo'] = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': 'bar'}, + ) + ) + pending_kwargs['baz'] = executor.submit( + FailureTask(self.transfer_coordinator) + ) + + # Create a task that depends on the tasks passed to the executor + ReturnKwargsTask( + self.transfer_coordinator, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The end result should raise the exception from the initial + # pending future value + with self.assertRaises(TaskFailureException): + self.transfer_coordinator.result() + + def test_single_failed_pending_future_in_list(self): + pending_kwargs = {} + + # Pass some tasks to an executor. Make one successful and the other + # a failure. 
+ with futures.ThreadPoolExecutor(1) as executor: + first_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': 'bar'}, + ) + ) + second_future = executor.submit( + FailureTask(self.transfer_coordinator) + ) + + pending_kwargs['pending_list'] = [first_future, second_future] + + # Create a task that depends on the tasks passed to the executor + ReturnKwargsTask( + self.transfer_coordinator, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The end result should raise the exception from the initial + # pending future value in the list + with self.assertRaises(TaskFailureException): + self.transfer_coordinator.result() + + +class BaseMultipartTaskTest(BaseTaskTest): + def setUp(self): + super().setUp() + self.bucket = 'mybucket' + self.key = 'foo' + + +class TestCreateMultipartUploadTask(BaseMultipartTaskTest): + def test_main(self): + upload_id = 'foo' + extra_args = {'Metadata': {'foo': 'bar'}} + response = {'UploadId': upload_id} + task = self.get_task( + CreateMultipartUploadTask, + main_kwargs={ + 'client': self.client, + 'bucket': self.bucket, + 'key': self.key, + 'extra_args': extra_args, + }, + ) + self.stubber.add_response( + method='create_multipart_upload', + service_response=response, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'Metadata': {'foo': 'bar'}, + }, + ) + result_id = task() + self.stubber.assert_no_pending_responses() + # Ensure the upload id returned is correct + self.assertEqual(upload_id, result_id) + + # Make sure that the abort was added as a cleanup failure + self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 1) + + # Make sure if it is called, it will abort correctly + self.stubber.add_response( + method='abort_multipart_upload', + service_response={}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': upload_id, + }, + ) + self.transfer_coordinator.failure_cleanups[0]() + self.stubber.assert_no_pending_responses() + + +class TestCompleteMultipartUploadTask(BaseMultipartTaskTest): + def test_main(self): + upload_id = 'my-id' + parts = [{'ETag': 'etag', 'PartNumber': 0}] + task = self.get_task( + CompleteMultipartUploadTask, + main_kwargs={ + 'client': self.client, + 'bucket': self.bucket, + 'key': self.key, + 'upload_id': upload_id, + 'parts': parts, + 'extra_args': {}, + }, + ) + self.stubber.add_response( + method='complete_multipart_upload', + service_response={}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': upload_id, + 'MultipartUpload': {'Parts': parts}, + }, + ) + task() + self.stubber.assert_no_pending_responses() + + def test_includes_extra_args(self): + upload_id = 'my-id' + parts = [{'ETag': 'etag', 'PartNumber': 0}] + task = self.get_task( + CompleteMultipartUploadTask, + main_kwargs={ + 'client': self.client, + 'bucket': self.bucket, + 'key': self.key, + 'upload_id': upload_id, + 'parts': parts, + 'extra_args': {'RequestPayer': 'requester'}, + }, + ) + self.stubber.add_response( + method='complete_multipart_upload', + service_response={}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': upload_id, + 'MultipartUpload': {'Parts': parts}, + 'RequestPayer': 'requester', + }, + ) + task() + self.stubber.assert_no_pending_responses() diff --git a/contrib/python/s3transfer/py3/tests/unit/test_upload.py b/contrib/python/s3transfer/py3/tests/unit/test_upload.py new file mode 100644 index 0000000000..1ac38b3616 --- /dev/null +++ 
b/contrib/python/s3transfer/py3/tests/unit/test_upload.py @@ -0,0 +1,694 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import math +import os +import shutil +import tempfile +from io import BytesIO + +from botocore.stub import ANY + +from s3transfer.futures import IN_MEMORY_UPLOAD_TAG +from s3transfer.manager import TransferConfig +from s3transfer.upload import ( + AggregatedProgressCallback, + InterruptReader, + PutObjectTask, + UploadFilenameInputManager, + UploadNonSeekableInputManager, + UploadPartTask, + UploadSeekableInputManager, + UploadSubmissionTask, +) +from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE, CallArgs, OSUtils +from __tests__ import ( + BaseSubmissionTaskTest, + BaseTaskTest, + FileSizeProvider, + NonSeekableReader, + RecordingExecutor, + RecordingSubscriber, + unittest, +) + + +class InterruptionError(Exception): + pass + + +class OSUtilsExceptionOnFileSize(OSUtils): + def get_file_size(self, filename): + raise AssertionError( + "The file %s should not have been stat'ed" % filename + ) + + +class BaseUploadTest(BaseTaskTest): + def setUp(self): + super().setUp() + self.bucket = 'mybucket' + self.key = 'foo' + self.osutil = OSUtils() + + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, 'myfile') + self.content = b'my content' + self.subscribers = [] + + with open(self.filename, 'wb') as f: + f.write(self.content) + + # A list to keep track of all of the bodies sent over the wire + # and their order.
+ self.sent_bodies = [] + self.client.meta.events.register( + 'before-parameter-build.s3.*', self.collect_body + ) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tempdir) + + def collect_body(self, params, **kwargs): + if 'Body' in params: + self.sent_bodies.append(params['Body'].read()) + + +class TestAggregatedProgressCallback(unittest.TestCase): + def setUp(self): + self.aggregated_amounts = [] + self.threshold = 3 + self.aggregated_progress_callback = AggregatedProgressCallback( + [self.callback], self.threshold + ) + + def callback(self, bytes_transferred): + self.aggregated_amounts.append(bytes_transferred) + + def test_under_threshold(self): + one_under_threshold_amount = self.threshold - 1 + self.aggregated_progress_callback(one_under_threshold_amount) + self.assertEqual(self.aggregated_amounts, []) + self.aggregated_progress_callback(1) + self.assertEqual(self.aggregated_amounts, [self.threshold]) + + def test_at_threshold(self): + self.aggregated_progress_callback(self.threshold) + self.assertEqual(self.aggregated_amounts, [self.threshold]) + + def test_over_threshold(self): + over_threshold_amount = self.threshold + 1 + self.aggregated_progress_callback(over_threshold_amount) + self.assertEqual(self.aggregated_amounts, [over_threshold_amount]) + + def test_flush(self): + under_threshold_amount = self.threshold - 1 + self.aggregated_progress_callback(under_threshold_amount) + self.assertEqual(self.aggregated_amounts, []) + self.aggregated_progress_callback.flush() + self.assertEqual(self.aggregated_amounts, [under_threshold_amount]) + + def test_flush_with_nothing_to_flush(self): + under_threshold_amount = self.threshold - 1 + self.aggregated_progress_callback(under_threshold_amount) + self.assertEqual(self.aggregated_amounts, []) + self.aggregated_progress_callback.flush() + self.assertEqual(self.aggregated_amounts, [under_threshold_amount]) + # Flushing again should do nothing as it was just flushed + self.aggregated_progress_callback.flush() + self.assertEqual(self.aggregated_amounts, [under_threshold_amount]) + + +class TestInterruptReader(BaseUploadTest): + def test_read_raises_exception(self): + with open(self.filename, 'rb') as f: + reader = InterruptReader(f, self.transfer_coordinator) + # Read some bytes to show it can be read. 
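Stepping back to the TestAggregatedProgressCallback cases above: they describe a callback that buffers byte counts and only fans out to the real callbacks once a threshold is crossed (or on an explicit flush). A minimal implementation consistent with those tests (a sketch, not necessarily the library's code):

class BufferedProgressCallback:
    def __init__(self, callbacks, threshold):
        self._callbacks = callbacks
        self._threshold = threshold
        self._bytes_seen = 0

    def __call__(self, bytes_transferred):
        self._bytes_seen += bytes_transferred
        if self._bytes_seen >= self._threshold:
            self._trigger()

    def flush(self):
        # Flushing twice in a row is a no-op, as the tests require.
        if self._bytes_seen > 0:
            self._trigger()

    def _trigger(self):
        for callback in self._callbacks:
            callback(self._bytes_seen)
        self._bytes_seen = 0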
+            self.assertEqual(reader.read(1), self.content[0:1])
+            # Then set an exception in the transfer coordinator
+            self.transfer_coordinator.set_exception(InterruptionError())
+            # The next read should have the exception propagate
+            with self.assertRaises(InterruptionError):
+                reader.read()
+
+    def test_seek(self):
+        with open(self.filename, 'rb') as f:
+            reader = InterruptReader(f, self.transfer_coordinator)
+            # Ensure it can seek correctly
+            reader.seek(1)
+            self.assertEqual(reader.read(1), self.content[1:2])
+
+    def test_tell(self):
+        with open(self.filename, 'rb') as f:
+            reader = InterruptReader(f, self.transfer_coordinator)
+            # Ensure it can tell correctly
+            reader.seek(1)
+            self.assertEqual(reader.tell(), 1)
+
+
+class BaseUploadInputManagerTest(BaseUploadTest):
+    def setUp(self):
+        super().setUp()
+        self.osutil = OSUtils()
+        self.config = TransferConfig()
+        self.recording_subscriber = RecordingSubscriber()
+        self.subscribers.append(self.recording_subscriber)
+
+    def _get_expected_body_for_part(self, part_number):
+        # A helper method for retrieving the expected body for a specific
+        # part number of the data
+        total_size = len(self.content)
+        chunk_size = self.config.multipart_chunksize
+        start_index = (part_number - 1) * chunk_size
+        end_index = part_number * chunk_size
+        if end_index >= total_size:
+            return self.content[start_index:]
+        return self.content[start_index:end_index]
+
+
+class TestUploadFilenameInputManager(BaseUploadInputManagerTest):
+    def setUp(self):
+        super().setUp()
+        self.upload_input_manager = UploadFilenameInputManager(
+            self.osutil, self.transfer_coordinator
+        )
+        self.call_args = CallArgs(
+            fileobj=self.filename, subscribers=self.subscribers
+        )
+        self.future = self.get_transfer_future(self.call_args)
+
+    def test_is_compatible(self):
+        self.assertTrue(
+            self.upload_input_manager.is_compatible(
+                self.future.meta.call_args.fileobj
+            )
+        )
+
+    def test_stores_bodies_in_memory_put_object(self):
+        self.assertFalse(
+            self.upload_input_manager.stores_body_in_memory('put_object')
+        )
+
+    def test_stores_bodies_in_memory_upload_part(self):
+        self.assertFalse(
+            self.upload_input_manager.stores_body_in_memory('upload_part')
+        )
+
+    def test_provide_transfer_size(self):
+        self.upload_input_manager.provide_transfer_size(self.future)
+        # The provided file size should be equal to the size of the contents
+        # of the file.
+        self.assertEqual(self.future.meta.size, len(self.content))
+
+    def test_requires_multipart_upload(self):
+        self.future.meta.provide_transfer_size(len(self.content))
+        # With the default multipart threshold, the length of the content
+        # should be smaller than the threshold, thus not requiring a
+        # multipart transfer.
+        self.assertFalse(
+            self.upload_input_manager.requires_multipart_upload(
+                self.future, self.config
+            )
+        )
+        # Decreasing the threshold to the length of the file's content
+        # should trigger the need for a multipart upload.
+        self.config.multipart_threshold = len(self.content)
+        self.assertTrue(
+            self.upload_input_manager.requires_multipart_upload(
+                self.future, self.config
+            )
+        )
+
+    def test_get_put_object_body(self):
+        self.future.meta.provide_transfer_size(len(self.content))
+        read_file_chunk = self.upload_input_manager.get_put_object_body(
+            self.future
+        )
+        read_file_chunk.enable_callback()
+        # The file-like object provided back should be the same as the content
+        # of the file.
+        with read_file_chunk:
+            self.assertEqual(read_file_chunk.read(), self.content)
+        # The file-like object should also have been wrapped with the
+        # on_queued callbacks to track the number of bytes being transferred.
+        self.assertEqual(
+            self.recording_subscriber.calculate_bytes_seen(), len(self.content)
+        )
+
+    def test_get_put_object_body_is_interruptable(self):
+        self.future.meta.provide_transfer_size(len(self.content))
+        read_file_chunk = self.upload_input_manager.get_put_object_body(
+            self.future
+        )
+
+        # Set an exception in the transfer coordinator
+        self.transfer_coordinator.set_exception(InterruptionError)
+        # Ensure the returned read file chunk can be interrupted with that
+        # error.
+        with self.assertRaises(InterruptionError):
+            read_file_chunk.read()
+
+    def test_yield_upload_part_bodies(self):
+        # Adjust the chunk size to something more granular for testing.
+        self.config.multipart_chunksize = 4
+        self.future.meta.provide_transfer_size(len(self.content))
+
+        # Get an iterator that will yield all of the bodies and their
+        # respective part number.
+        part_iterator = self.upload_input_manager.yield_upload_part_bodies(
+            self.future, self.config.multipart_chunksize
+        )
+        expected_part_number = 1
+        for part_number, read_file_chunk in part_iterator:
+            # Ensure that the part number is as expected
+            self.assertEqual(part_number, expected_part_number)
+            read_file_chunk.enable_callback()
+            # Ensure that the body is correct for that part.
+            with read_file_chunk:
+                self.assertEqual(
+                    read_file_chunk.read(),
+                    self._get_expected_body_for_part(part_number),
+                )
+            expected_part_number += 1
+
+        # All of the file-like objects should also have been wrapped with the
+        # on_queued callbacks to track the number of bytes being transferred.
+        self.assertEqual(
+            self.recording_subscriber.calculate_bytes_seen(), len(self.content)
+        )
+
+    def test_yield_upload_part_bodies_are_interruptable(self):
+        # Adjust the chunk size to something more granular for testing.
+        self.config.multipart_chunksize = 4
+        self.future.meta.provide_transfer_size(len(self.content))
+
+        # Get an iterator that will yield all of the bodies and their
+        # respective part number.
+        part_iterator = self.upload_input_manager.yield_upload_part_bodies(
+            self.future, self.config.multipart_chunksize
+        )
+
+        # Set an exception in the transfer coordinator
+        self.transfer_coordinator.set_exception(InterruptionError)
+        for _, read_file_chunk in part_iterator:
+            # Ensure that each read file chunk yielded can be interrupted
+            # with that error.
+            with self.assertRaises(InterruptionError):
+                read_file_chunk.read()
+
+
+class TestUploadSeekableInputManager(TestUploadFilenameInputManager):
+    def setUp(self):
+        super().setUp()
+        self.upload_input_manager = UploadSeekableInputManager(
+            self.osutil, self.transfer_coordinator
+        )
+        self.fileobj = open(self.filename, 'rb')
+        self.call_args = CallArgs(
+            fileobj=self.fileobj, subscribers=self.subscribers
+        )
+        self.future = self.get_transfer_future(self.call_args)
+
+    def tearDown(self):
+        self.fileobj.close()
+        super().tearDown()
+
+    def test_is_compatible_bytes_io(self):
+        self.assertTrue(self.upload_input_manager.is_compatible(BytesIO()))
+
+    def test_not_compatible_for_non_filelike_obj(self):
+        self.assertFalse(self.upload_input_manager.is_compatible(object()))
+
+    def test_stores_bodies_in_memory_upload_part(self):
+        self.assertTrue(
+            self.upload_input_manager.stores_body_in_memory('upload_part')
+        )
+
+    def test_get_put_object_body(self):
+        start_pos = 3
+        self.fileobj.seek(start_pos)
+        adjusted_size = len(self.content) - start_pos
+        self.future.meta.provide_transfer_size(adjusted_size)
+        read_file_chunk = self.upload_input_manager.get_put_object_body(
+            self.future
+        )
+
+        read_file_chunk.enable_callback()
+        # The fact that the file was seeked to a nonzero start position
+        # should be reflected in the length and content of the read file
+        # chunk.
+        with read_file_chunk:
+            self.assertEqual(len(read_file_chunk), adjusted_size)
+            self.assertEqual(read_file_chunk.read(), self.content[start_pos:])
+        self.assertEqual(
+            self.recording_subscriber.calculate_bytes_seen(), adjusted_size
+        )
+
+
+class TestUploadNonSeekableInputManager(TestUploadFilenameInputManager):
+    def setUp(self):
+        super().setUp()
+        self.upload_input_manager = UploadNonSeekableInputManager(
+            self.osutil, self.transfer_coordinator
+        )
+        self.fileobj = NonSeekableReader(self.content)
+        self.call_args = CallArgs(
+            fileobj=self.fileobj, subscribers=self.subscribers
+        )
+        self.future = self.get_transfer_future(self.call_args)
+
+    def assert_multipart_parts(self):
+        """
+        Asserts that the input manager will generate a multipart upload
+        and that each part is in order and the correct size.
+        """
+        # Assert that a multipart upload is required.
+        self.assertTrue(
+            self.upload_input_manager.requires_multipart_upload(
+                self.future, self.config
+            )
+        )
+
+        # Get a list of all the parts that would be sent.
+        parts = list(
+            self.upload_input_manager.yield_upload_part_bodies(
+                self.future, self.config.multipart_chunksize
+            )
+        )
+
+        # Assert that the actual number of parts is what we would expect
+        # based on the configuration.
+        size = self.config.multipart_chunksize
+        num_parts = math.ceil(len(self.content) / size)
+        self.assertEqual(len(parts), num_parts)
+
+        # Run for every part but the last part.
+        for i, part in enumerate(parts[:-1]):
+            # Assert the part number is correct.
+            self.assertEqual(part[0], i + 1)
+            # Assert the part contains the right amount of data.
+            data = part[1].read()
+            self.assertEqual(len(data), size)
+
+        # Assert that the last part is the correct size.
+        expected_final_size = len(self.content) - ((num_parts - 1) * size)
+        final_part = parts[-1]
+        self.assertEqual(len(final_part[1].read()), expected_final_size)
+
+        # Assert that the last part has the correct part number.
+        self.assertEqual(final_part[0], len(parts))
+
+    def test_provide_transfer_size(self):
+        self.upload_input_manager.provide_transfer_size(self.future)
+        # There is no way to get the size without reading the entire body.
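+        # (With the size unknown, the manager instead buffers data as it
+        # reads it to choose between a single put_object call and a
+        # multipart upload; the threshold tests below exercise that choice.)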
+ self.assertEqual(self.future.meta.size, None) + + def test_stores_bodies_in_memory_upload_part(self): + self.assertTrue( + self.upload_input_manager.stores_body_in_memory('upload_part') + ) + + def test_stores_bodies_in_memory_put_object(self): + self.assertTrue( + self.upload_input_manager.stores_body_in_memory('put_object') + ) + + def test_initial_data_parts_threshold_lesser(self): + # threshold < size + self.config.multipart_chunksize = 4 + self.config.multipart_threshold = 2 + self.assert_multipart_parts() + + def test_initial_data_parts_threshold_equal(self): + # threshold == size + self.config.multipart_chunksize = 4 + self.config.multipart_threshold = 4 + self.assert_multipart_parts() + + def test_initial_data_parts_threshold_greater(self): + # threshold > size + self.config.multipart_chunksize = 4 + self.config.multipart_threshold = 8 + self.assert_multipart_parts() + + +class TestUploadSubmissionTask(BaseSubmissionTaskTest): + def setUp(self): + super().setUp() + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, 'myfile') + self.content = b'0' * (MIN_UPLOAD_CHUNKSIZE * 3) + self.config.multipart_chunksize = MIN_UPLOAD_CHUNKSIZE + self.config.multipart_threshold = MIN_UPLOAD_CHUNKSIZE * 5 + + with open(self.filename, 'wb') as f: + f.write(self.content) + + self.bucket = 'mybucket' + self.key = 'mykey' + self.extra_args = {} + self.subscribers = [] + + # A list to keep track of all of the bodies sent over the wire + # and their order. + self.sent_bodies = [] + self.client.meta.events.register( + 'before-parameter-build.s3.*', self.collect_body + ) + + self.call_args = self.get_call_args() + self.transfer_future = self.get_transfer_future(self.call_args) + self.submission_main_kwargs = { + 'client': self.client, + 'config': self.config, + 'osutil': self.osutil, + 'request_executor': self.executor, + 'transfer_future': self.transfer_future, + } + self.submission_task = self.get_task( + UploadSubmissionTask, main_kwargs=self.submission_main_kwargs + ) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tempdir) + + def collect_body(self, params, **kwargs): + if 'Body' in params: + self.sent_bodies.append(params['Body'].read()) + + def get_call_args(self, **kwargs): + default_call_args = { + 'fileobj': self.filename, + 'bucket': self.bucket, + 'key': self.key, + 'extra_args': self.extra_args, + 'subscribers': self.subscribers, + } + default_call_args.update(kwargs) + return CallArgs(**default_call_args) + + def add_multipart_upload_stubbed_responses(self): + self.stubber.add_response( + method='create_multipart_upload', + service_response={'UploadId': 'my-id'}, + ) + self.stubber.add_response( + method='upload_part', service_response={'ETag': 'etag-1'} + ) + self.stubber.add_response( + method='upload_part', service_response={'ETag': 'etag-2'} + ) + self.stubber.add_response( + method='upload_part', service_response={'ETag': 'etag-3'} + ) + self.stubber.add_response( + method='complete_multipart_upload', service_response={} + ) + + def wrap_executor_in_recorder(self): + self.executor = RecordingExecutor(self.executor) + self.submission_main_kwargs['request_executor'] = self.executor + + def use_fileobj_in_call_args(self, fileobj): + self.call_args = self.get_call_args(fileobj=fileobj) + self.transfer_future = self.get_transfer_future(self.call_args) + self.submission_main_kwargs['transfer_future'] = self.transfer_future + + def assert_tag_value_for_put_object(self, tag_value): + self.assertEqual(self.executor.submissions[0]['tag'], 
tag_value)
+
+    def assert_tag_value_for_upload_parts(self, tag_value):
+        for submission in self.executor.submissions[1:-1]:
+            self.assertEqual(submission['tag'], tag_value)
+
+    def test_provide_file_size_on_put(self):
+        self.call_args.subscribers.append(FileSizeProvider(len(self.content)))
+        self.stubber.add_response(
+            method='put_object',
+            service_response={},
+            expected_params={
+                'Body': ANY,
+                'Bucket': self.bucket,
+                'Key': self.key,
+            },
+        )
+
+        # With this submitter, it will fail to stat the file if a transfer
+        # size is not provided.
+        self.submission_main_kwargs['osutil'] = OSUtilsExceptionOnFileSize()
+
+        self.submission_task = self.get_task(
+            UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+        )
+        self.submission_task()
+        self.transfer_future.result()
+        self.stubber.assert_no_pending_responses()
+        self.assertEqual(self.sent_bodies, [self.content])
+
+    def test_submits_no_tag_for_put_object_filename(self):
+        self.wrap_executor_in_recorder()
+        self.stubber.add_response('put_object', {})
+
+        self.submission_task = self.get_task(
+            UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+        )
+        self.submission_task()
+        self.transfer_future.result()
+        self.stubber.assert_no_pending_responses()
+
+        # Make sure no tag to limit that task specifically was associated
+        # with that task submission.
+        self.assert_tag_value_for_put_object(None)
+
+    def test_submits_no_tag_for_multipart_filename(self):
+        self.wrap_executor_in_recorder()
+
+        # Set up for a multipart upload.
+        self.add_multipart_upload_stubbed_responses()
+        self.config.multipart_threshold = 1
+
+        self.submission_task = self.get_task(
+            UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+        )
+        self.submission_task()
+        self.transfer_future.result()
+        self.stubber.assert_no_pending_responses()
+
+        # Make sure no tag to limit any of the upload part tasks was
+        # associated when they were submitted to the executor.
+        self.assert_tag_value_for_upload_parts(None)
+
+    def test_submits_no_tag_for_put_object_fileobj(self):
+        self.wrap_executor_in_recorder()
+        self.stubber.add_response('put_object', {})
+
+        with open(self.filename, 'rb') as f:
+            self.use_fileobj_in_call_args(f)
+            self.submission_task = self.get_task(
+                UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+            )
+            self.submission_task()
+            self.transfer_future.result()
+            self.stubber.assert_no_pending_responses()
+
+        # Make sure no tag to limit that task specifically was associated
+        # with that task submission.
+        self.assert_tag_value_for_put_object(None)
+
+    def test_submits_tag_for_multipart_fileobj(self):
+        self.wrap_executor_in_recorder()
+
+        # Set up for a multipart upload.
+        self.add_multipart_upload_stubbed_responses()
+        self.config.multipart_threshold = 1
+
+        with open(self.filename, 'rb') as f:
+            self.use_fileobj_in_call_args(f)
+            self.submission_task = self.get_task(
+                UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+            )
+            self.submission_task()
+            self.transfer_future.result()
+            self.stubber.assert_no_pending_responses()
+
+        # Make sure tags to limit all of the upload part tasks were
+        # associated when they were submitted to the executor, as these
+        # tasks will have chunks of data stored with them in memory.
+ self.assert_tag_value_for_upload_parts(IN_MEMORY_UPLOAD_TAG) + + +class TestPutObjectTask(BaseUploadTest): + def test_main(self): + extra_args = {'Metadata': {'foo': 'bar'}} + with open(self.filename, 'rb') as fileobj: + task = self.get_task( + PutObjectTask, + main_kwargs={ + 'client': self.client, + 'fileobj': fileobj, + 'bucket': self.bucket, + 'key': self.key, + 'extra_args': extra_args, + }, + ) + self.stubber.add_response( + method='put_object', + service_response={}, + expected_params={ + 'Body': ANY, + 'Bucket': self.bucket, + 'Key': self.key, + 'Metadata': {'foo': 'bar'}, + }, + ) + task() + self.stubber.assert_no_pending_responses() + self.assertEqual(self.sent_bodies, [self.content]) + + +class TestUploadPartTask(BaseUploadTest): + def test_main(self): + extra_args = {'RequestPayer': 'requester'} + upload_id = 'my-id' + part_number = 1 + etag = 'foo' + with open(self.filename, 'rb') as fileobj: + task = self.get_task( + UploadPartTask, + main_kwargs={ + 'client': self.client, + 'fileobj': fileobj, + 'bucket': self.bucket, + 'key': self.key, + 'upload_id': upload_id, + 'part_number': part_number, + 'extra_args': extra_args, + }, + ) + self.stubber.add_response( + method='upload_part', + service_response={'ETag': etag}, + expected_params={ + 'Body': ANY, + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': upload_id, + 'PartNumber': part_number, + 'RequestPayer': 'requester', + }, + ) + rval = task() + self.stubber.assert_no_pending_responses() + self.assertEqual(rval, {'ETag': etag, 'PartNumber': part_number}) + self.assertEqual(self.sent_bodies, [self.content]) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_utils.py b/contrib/python/s3transfer/py3/tests/unit/test_utils.py new file mode 100644 index 0000000000..217779943b --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/unit/test_utils.py @@ -0,0 +1,1189 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+import io +import os.path +import random +import re +import shutil +import tempfile +import threading +import time +from io import BytesIO, StringIO + +from s3transfer.futures import TransferFuture, TransferMeta +from s3transfer.utils import ( + MAX_PARTS, + MAX_SINGLE_UPLOAD_SIZE, + MIN_UPLOAD_CHUNKSIZE, + CallArgs, + ChunksizeAdjuster, + CountCallbackInvoker, + DeferredOpenFile, + FunctionContainer, + NoResourcesAvailable, + OSUtils, + ReadFileChunk, + SlidingWindowSemaphore, + StreamReaderProgress, + TaskSemaphore, + calculate_num_parts, + calculate_range_parameter, + get_callbacks, + get_filtered_dict, + invoke_progress_callbacks, + random_file_extension, +) +from __tests__ import NonSeekableWriter, RecordingSubscriber, mock, unittest + + +class TestGetCallbacks(unittest.TestCase): + def setUp(self): + self.subscriber = RecordingSubscriber() + self.second_subscriber = RecordingSubscriber() + self.call_args = CallArgs( + subscribers=[self.subscriber, self.second_subscriber] + ) + self.transfer_meta = TransferMeta(self.call_args) + self.transfer_future = TransferFuture(self.transfer_meta) + + def test_get_callbacks(self): + callbacks = get_callbacks(self.transfer_future, 'queued') + # Make sure two callbacks were added as both subscribers had + # an on_queued method. + self.assertEqual(len(callbacks), 2) + + # Ensure that the callback was injected with the future by calling + # one of them and checking that the future was used in the call. + callbacks[0]() + self.assertEqual( + self.subscriber.on_queued_calls, [{'future': self.transfer_future}] + ) + + def test_get_callbacks_for_missing_type(self): + callbacks = get_callbacks(self.transfer_future, 'fake_state') + # There should be no callbacks as the subscribers will not have the + # on_fake_state method + self.assertEqual(len(callbacks), 0) + + +class TestGetFilteredDict(unittest.TestCase): + def test_get_filtered_dict(self): + original = {'Include': 'IncludeValue', 'NotInlude': 'NotIncludeValue'} + whitelist = ['Include'] + self.assertEqual( + get_filtered_dict(original, whitelist), {'Include': 'IncludeValue'} + ) + + +class TestCallArgs(unittest.TestCase): + def test_call_args(self): + call_args = CallArgs(foo='bar', biz='baz') + self.assertEqual(call_args.foo, 'bar') + self.assertEqual(call_args.biz, 'baz') + + +class TestFunctionContainer(unittest.TestCase): + def get_args_kwargs(self, *args, **kwargs): + return args, kwargs + + def test_call(self): + func_container = FunctionContainer( + self.get_args_kwargs, 'foo', bar='baz' + ) + self.assertEqual(func_container(), (('foo',), {'bar': 'baz'})) + + def test_repr(self): + func_container = FunctionContainer( + self.get_args_kwargs, 'foo', bar='baz' + ) + self.assertEqual( + str(func_container), + 'Function: {} with args {} and kwargs {}'.format( + self.get_args_kwargs, ('foo',), {'bar': 'baz'} + ), + ) + + +class TestCountCallbackInvoker(unittest.TestCase): + def invoke_callback(self): + self.ref_results.append('callback invoked') + + def assert_callback_invoked(self): + self.assertEqual(self.ref_results, ['callback invoked']) + + def assert_callback_not_invoked(self): + self.assertEqual(self.ref_results, []) + + def setUp(self): + self.ref_results = [] + self.invoker = CountCallbackInvoker(self.invoke_callback) + + def test_increment(self): + self.invoker.increment() + self.assertEqual(self.invoker.current_count, 1) + + def test_decrement(self): + self.invoker.increment() + self.invoker.increment() + self.invoker.decrement() + self.assertEqual(self.invoker.current_count, 1) + + def 
test_count_cannot_go_below_zero(self): + with self.assertRaises(RuntimeError): + self.invoker.decrement() + + def test_callback_invoked_only_once_finalized(self): + self.invoker.increment() + self.invoker.decrement() + self.assert_callback_not_invoked() + self.invoker.finalize() + # Callback should only be invoked once finalized + self.assert_callback_invoked() + + def test_callback_invoked_after_finalizing_and_count_reaching_zero(self): + self.invoker.increment() + self.invoker.finalize() + # Make sure that it does not get invoked immediately after + # finalizing as the count is currently one + self.assert_callback_not_invoked() + self.invoker.decrement() + self.assert_callback_invoked() + + def test_cannot_increment_after_finalization(self): + self.invoker.finalize() + with self.assertRaises(RuntimeError): + self.invoker.increment() + + +class TestRandomFileExtension(unittest.TestCase): + def test_has_proper_length(self): + self.assertEqual(len(random_file_extension(num_digits=4)), 4) + + +class TestInvokeProgressCallbacks(unittest.TestCase): + def test_invoke_progress_callbacks(self): + recording_subscriber = RecordingSubscriber() + invoke_progress_callbacks([recording_subscriber.on_progress], 2) + self.assertEqual(recording_subscriber.calculate_bytes_seen(), 2) + + def test_invoke_progress_callbacks_with_no_progress(self): + recording_subscriber = RecordingSubscriber() + invoke_progress_callbacks([recording_subscriber.on_progress], 0) + self.assertEqual(len(recording_subscriber.on_progress_calls), 0) + + +class TestCalculateNumParts(unittest.TestCase): + def test_calculate_num_parts_divisible(self): + self.assertEqual(calculate_num_parts(size=4, part_size=2), 2) + + def test_calculate_num_parts_not_divisible(self): + self.assertEqual(calculate_num_parts(size=3, part_size=2), 2) + + +class TestCalculateRangeParameter(unittest.TestCase): + def setUp(self): + self.part_size = 5 + self.part_index = 1 + self.num_parts = 3 + + def test_calculate_range_paramter(self): + range_val = calculate_range_parameter( + self.part_size, self.part_index, self.num_parts + ) + self.assertEqual(range_val, 'bytes=5-9') + + def test_last_part_with_no_total_size(self): + range_val = calculate_range_parameter( + self.part_size, self.part_index, num_parts=2 + ) + self.assertEqual(range_val, 'bytes=5-') + + def test_last_part_with_total_size(self): + range_val = calculate_range_parameter( + self.part_size, self.part_index, num_parts=2, total_size=8 + ) + self.assertEqual(range_val, 'bytes=5-7') + + +class BaseUtilsTest(unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, 'foo') + self.content = b'abc' + with open(self.filename, 'wb') as f: + f.write(self.content) + self.amounts_seen = [] + self.num_close_callback_calls = 0 + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def callback(self, bytes_transferred): + self.amounts_seen.append(bytes_transferred) + + def close_callback(self): + self.num_close_callback_calls += 1 + + +class TestOSUtils(BaseUtilsTest): + def test_get_file_size(self): + self.assertEqual( + OSUtils().get_file_size(self.filename), len(self.content) + ) + + def test_open_file_chunk_reader(self): + reader = OSUtils().open_file_chunk_reader( + self.filename, 0, 3, [self.callback] + ) + + # The returned reader should be a ReadFileChunk. + self.assertIsInstance(reader, ReadFileChunk) + # The content of the reader should be correct. 
+        self.assertEqual(reader.read(), self.content)
+        # Callbacks should be disabled despite being passed in.
+        self.assertEqual(self.amounts_seen, [])
+
+    def test_open_file_chunk_reader_from_fileobj(self):
+        with open(self.filename, 'rb') as f:
+            reader = OSUtils().open_file_chunk_reader_from_fileobj(
+                f, len(self.content), len(self.content), [self.callback]
+            )
+
+            # The returned reader should be a ReadFileChunk.
+            self.assertIsInstance(reader, ReadFileChunk)
+            # The content of the reader should be correct.
+            self.assertEqual(reader.read(), self.content)
+            reader.close()
+            # Callbacks should be disabled despite being passed in.
+            self.assertEqual(self.amounts_seen, [])
+            self.assertEqual(self.num_close_callback_calls, 0)
+
+    def test_open_file(self):
+        fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w')
+        self.assertTrue(hasattr(fileobj, 'write'))
+
+    def test_remove_file_ignores_errors(self):
+        non_existent_file = os.path.join(self.tempdir, 'no-exist')
+        # This should not exist to start.
+        self.assertFalse(os.path.exists(non_existent_file))
+        try:
+            OSUtils().remove_file(non_existent_file)
+        except OSError as e:
+            self.fail('OSError should have been caught: %s' % e)
+
+    def test_remove_file_proxies_remove_file(self):
+        OSUtils().remove_file(self.filename)
+        self.assertFalse(os.path.exists(self.filename))
+
+    def test_rename_file(self):
+        new_filename = os.path.join(self.tempdir, 'newfoo')
+        OSUtils().rename_file(self.filename, new_filename)
+        self.assertFalse(os.path.exists(self.filename))
+        self.assertTrue(os.path.exists(new_filename))
+
+    def test_is_special_file_for_normal_file(self):
+        self.assertFalse(OSUtils().is_special_file(self.filename))
+
+    def test_is_special_file_for_non_existant_file(self):
+        non_existant_filename = os.path.join(self.tempdir, 'no-exist')
+        self.assertFalse(os.path.exists(non_existant_filename))
+        self.assertFalse(OSUtils().is_special_file(non_existant_filename))
+
+    def test_get_temp_filename(self):
+        filename = 'myfile'
+        self.assertIsNotNone(
+            re.match(
+                r'%s\.[0-9A-Fa-f]{8}$' % filename,
+                OSUtils().get_temp_filename(filename),
+            )
+        )
+
+    def test_get_temp_filename_len_255(self):
+        filename = 'a' * 255
+        temp_filename = OSUtils().get_temp_filename(filename)
+        self.assertLessEqual(len(temp_filename), 255)
+
+    def test_get_temp_filename_len_gt_255(self):
+        filename = 'a' * 280
+        temp_filename = OSUtils().get_temp_filename(filename)
+        self.assertLessEqual(len(temp_filename), 255)
+
+    def test_allocate(self):
+        truncate_size = 1
+        OSUtils().allocate(self.filename, truncate_size)
+        with open(self.filename, 'rb') as f:
+            self.assertEqual(len(f.read()), truncate_size)
+
+    @mock.patch('s3transfer.utils.fallocate')
+    def test_allocate_with_io_error(self, mock_fallocate):
+        mock_fallocate.side_effect = IOError()
+        with self.assertRaises(IOError):
+            OSUtils().allocate(self.filename, 1)
+        self.assertFalse(os.path.exists(self.filename))
+
+    @mock.patch('s3transfer.utils.fallocate')
+    def test_allocate_with_os_error(self, mock_fallocate):
+        mock_fallocate.side_effect = OSError()
+        with self.assertRaises(OSError):
+            OSUtils().allocate(self.filename, 1)
+        self.assertFalse(os.path.exists(self.filename))
+
+
+class TestDeferredOpenFile(BaseUtilsTest):
+    def setUp(self):
+        super().setUp()
+        self.filename = os.path.join(self.tempdir, 'foo')
+        self.contents = b'my contents'
+        with open(self.filename, 'wb') as f:
+            f.write(self.contents)
+        self.deferred_open_file = DeferredOpenFile(
+            self.filename, open_function=self.recording_open_function
+        )
+
self.open_call_args = [] + + def tearDown(self): + self.deferred_open_file.close() + super().tearDown() + + def recording_open_function(self, filename, mode): + self.open_call_args.append((filename, mode)) + return open(filename, mode) + + def open_nonseekable(self, filename, mode): + self.open_call_args.append((filename, mode)) + return NonSeekableWriter(BytesIO(self.content)) + + def test_instantiation_does_not_open_file(self): + DeferredOpenFile( + self.filename, open_function=self.recording_open_function + ) + self.assertEqual(len(self.open_call_args), 0) + + def test_name(self): + self.assertEqual(self.deferred_open_file.name, self.filename) + + def test_read(self): + content = self.deferred_open_file.read(2) + self.assertEqual(content, self.contents[0:2]) + content = self.deferred_open_file.read(2) + self.assertEqual(content, self.contents[2:4]) + self.assertEqual(len(self.open_call_args), 1) + + def test_write(self): + self.deferred_open_file = DeferredOpenFile( + self.filename, + mode='wb', + open_function=self.recording_open_function, + ) + + write_content = b'foo' + self.deferred_open_file.write(write_content) + self.deferred_open_file.write(write_content) + self.deferred_open_file.close() + # Both of the writes should now be in the file. + with open(self.filename, 'rb') as f: + self.assertEqual(f.read(), write_content * 2) + # Open should have only been called once. + self.assertEqual(len(self.open_call_args), 1) + + def test_seek(self): + self.deferred_open_file.seek(2) + content = self.deferred_open_file.read(2) + self.assertEqual(content, self.contents[2:4]) + self.assertEqual(len(self.open_call_args), 1) + + def test_open_does_not_seek_with_zero_start_byte(self): + self.deferred_open_file = DeferredOpenFile( + self.filename, + mode='wb', + start_byte=0, + open_function=self.open_nonseekable, + ) + + try: + # If this seeks, an UnsupportedOperation error will be raised. + self.deferred_open_file.write(b'data') + except io.UnsupportedOperation: + self.fail('DeferredOpenFile seeked upon opening') + + def test_open_seeks_with_nonzero_start_byte(self): + self.deferred_open_file = DeferredOpenFile( + self.filename, + mode='wb', + start_byte=5, + open_function=self.open_nonseekable, + ) + + # Since a non-seekable file is being opened, calling Seek will raise + # an UnsupportedOperation error. + with self.assertRaises(io.UnsupportedOperation): + self.deferred_open_file.write(b'data') + + def test_tell(self): + self.deferred_open_file.tell() + # tell() should not have opened the file if it has not been seeked + # or read because we know the start bytes upfront. 
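+        # (DeferredOpenFile tracks its position internally until an
+        # operation actually requires the real file handle, which keeps
+        # the number of open file descriptors down for queued transfers.)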
+ self.assertEqual(len(self.open_call_args), 0) + + self.deferred_open_file.seek(2) + self.assertEqual(self.deferred_open_file.tell(), 2) + self.assertEqual(len(self.open_call_args), 1) + + def test_open_args(self): + self.deferred_open_file = DeferredOpenFile( + self.filename, + mode='ab+', + open_function=self.recording_open_function, + ) + # Force an open + self.deferred_open_file.write(b'data') + self.assertEqual(len(self.open_call_args), 1) + self.assertEqual(self.open_call_args[0], (self.filename, 'ab+')) + + def test_context_handler(self): + with self.deferred_open_file: + self.assertEqual(len(self.open_call_args), 1) + + +class TestReadFileChunk(BaseUtilsTest): + def test_read_entire_chunk(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=3 + ) + self.assertEqual(chunk.read(), b'one') + self.assertEqual(chunk.read(), b'') + + def test_read_with_amount_size(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=11, chunk_size=4 + ) + self.assertEqual(chunk.read(1), b'f') + self.assertEqual(chunk.read(1), b'o') + self.assertEqual(chunk.read(1), b'u') + self.assertEqual(chunk.read(1), b'r') + self.assertEqual(chunk.read(1), b'') + + def test_reset_stream_emulation(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=11, chunk_size=4 + ) + self.assertEqual(chunk.read(), b'four') + chunk.seek(0) + self.assertEqual(chunk.read(), b'four') + + def test_read_past_end_of_file(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=36, chunk_size=100000 + ) + self.assertEqual(chunk.read(), b'ten') + self.assertEqual(chunk.read(), b'') + self.assertEqual(len(chunk), 3) + + def test_tell_and_seek(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=36, chunk_size=100000 + ) + self.assertEqual(chunk.tell(), 0) + self.assertEqual(chunk.read(), b'ten') + self.assertEqual(chunk.tell(), 3) + chunk.seek(0) + self.assertEqual(chunk.tell(), 0) + chunk.seek(1, whence=1) + self.assertEqual(chunk.tell(), 1) + chunk.seek(-1, whence=1) + self.assertEqual(chunk.tell(), 0) + chunk.seek(-1, whence=2) + self.assertEqual(chunk.tell(), 2) + + def test_tell_and_seek_boundaries(self): + # Test to ensure ReadFileChunk behaves the same as the + # Python standard library around seeking and reading out + # of bounds in a file object. 
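+        # (The chunk below is an 8-byte window at offset 10 of the file,
+        # i.e. the digit run b'12345678', so any read or seek past either
+        # edge is easy to spot.)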
+        data = b'abcdefghij12345678klmnopqrst'
+        start_pos = 10
+        chunk_size = 8
+
+        # Create test file
+        filename = os.path.join(self.tempdir, 'foo')
+        with open(filename, 'wb') as f:
+            f.write(data)
+
+        # The ReadFileChunk should cover only the numeric substring.
+        file_objects = [
+            ReadFileChunk.from_filename(
+                filename, start_byte=start_pos, chunk_size=chunk_size
+            )
+        ]
+
+        # Uncomment next line to validate we match Python's io.BytesIO
+        # file_objects.append(io.BytesIO(data[start_pos:start_pos+chunk_size]))
+
+        for obj in file_objects:
+            self._assert_whence_start_behavior(obj)
+            self._assert_whence_end_behavior(obj)
+            self._assert_whence_relative_behavior(obj)
+            self._assert_boundary_behavior(obj)
+
+    def _assert_whence_start_behavior(self, file_obj):
+        self.assertEqual(file_obj.tell(), 0)
+
+        file_obj.seek(1, 0)
+        self.assertEqual(file_obj.tell(), 1)
+
+        file_obj.seek(1)
+        self.assertEqual(file_obj.tell(), 1)
+        self.assertEqual(file_obj.read(), b'2345678')
+
+        file_obj.seek(3, 0)
+        self.assertEqual(file_obj.tell(), 3)
+
+        file_obj.seek(0, 0)
+        self.assertEqual(file_obj.tell(), 0)
+
+    def _assert_whence_relative_behavior(self, file_obj):
+        self.assertEqual(file_obj.tell(), 0)
+
+        file_obj.seek(2, 1)
+        self.assertEqual(file_obj.tell(), 2)
+
+        file_obj.seek(1, 1)
+        self.assertEqual(file_obj.tell(), 3)
+        self.assertEqual(file_obj.read(), b'45678')
+
+        file_obj.seek(20, 1)
+        self.assertEqual(file_obj.tell(), 28)
+
+        file_obj.seek(-30, 1)
+        self.assertEqual(file_obj.tell(), 0)
+        self.assertEqual(file_obj.read(), b'12345678')
+
+        file_obj.seek(-8, 1)
+        self.assertEqual(file_obj.tell(), 0)
+
+    def _assert_whence_end_behavior(self, file_obj):
+        self.assertEqual(file_obj.tell(), 0)
+
+        file_obj.seek(-1, 2)
+        self.assertEqual(file_obj.tell(), 7)
+
+        file_obj.seek(1, 2)
+        self.assertEqual(file_obj.tell(), 9)
+
+        file_obj.seek(3, 2)
+        self.assertEqual(file_obj.tell(), 11)
+        self.assertEqual(file_obj.read(), b'')
+
+        file_obj.seek(-15, 2)
+        self.assertEqual(file_obj.tell(), 0)
+        self.assertEqual(file_obj.read(), b'12345678')
+
+        file_obj.seek(-8, 2)
+        self.assertEqual(file_obj.tell(), 0)
+
+    def _assert_boundary_behavior(self, file_obj):
+        # Verify we're at the start
+        self.assertEqual(file_obj.tell(), 0)
+
+        # Verify we can't move backwards beyond start of file
+        file_obj.seek(-10, 1)
+        self.assertEqual(file_obj.tell(), 0)
+
+        # Verify we *can* move after end of file, but return nothing
+        file_obj.seek(10, 2)
+        self.assertEqual(file_obj.tell(), 18)
+        self.assertEqual(file_obj.read(), b'')
+        self.assertEqual(file_obj.read(10), b'')
+
+        # Verify we can partially rewind
+        file_obj.seek(-12, 1)
+        self.assertEqual(file_obj.tell(), 6)
+        self.assertEqual(file_obj.read(), b'78')
+        self.assertEqual(file_obj.tell(), 8)
+
+        # Verify we can rewind to start
+        file_obj.seek(0)
+        self.assertEqual(file_obj.tell(), 0)
+
+    def test_file_chunk_supports_context_manager(self):
+        filename = os.path.join(self.tempdir, 'foo')
+        with open(filename, 'wb') as f:
+            f.write(b'abc')
+        with ReadFileChunk.from_filename(
+            filename, start_byte=0, chunk_size=2
+        ) as chunk:
+            val = chunk.read()
+            self.assertEqual(val, b'ab')
+
+    def test_iter_is_always_empty(self):
+        # This tests the workaround for the httplib bug (see
+        # the source for more info).
+ filename = os.path.join(self.tempdir, 'foo') + open(filename, 'wb').close() + chunk = ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=10 + ) + self.assertEqual(list(chunk), []) + + def test_callback_is_invoked_on_read(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback], + ) + chunk.read(1) + chunk.read(1) + chunk.read(1) + self.assertEqual(self.amounts_seen, [1, 1, 1]) + + def test_all_callbacks_invoked_on_read(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback, self.callback], + ) + chunk.read(1) + chunk.read(1) + chunk.read(1) + # The list should be twice as long because there are two callbacks + # recording the amount read. + self.assertEqual(self.amounts_seen, [1, 1, 1, 1, 1, 1]) + + def test_callback_can_be_disabled(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback], + ) + chunk.disable_callback() + # Now reading from the ReadFileChunk should not invoke + # the callback. + chunk.read() + self.assertEqual(self.amounts_seen, []) + + def test_callback_will_also_be_triggered_by_seek(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback], + ) + chunk.read(2) + chunk.seek(0) + chunk.read(2) + chunk.seek(1) + chunk.read(2) + self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2]) + + def test_callback_triggered_by_out_of_bound_seeks(self): + data = b'abcdefghij1234567890klmnopqr' + + # Create test file + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(data) + chunk = ReadFileChunk.from_filename( + filename, start_byte=10, chunk_size=10, callbacks=[self.callback] + ) + + # Seek calls that generate "0" progress are skipped by + # invoke_progress_callbacks and won't appear in the list. 
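+        # (Each helper below performs the same thirteen seeks; the three
+        # zero-delta moves produce no callback, leaving the ten amounts
+        # listed here.)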
+ expected_callback_prog = [10, -5, 5, -1, 1, -1, 1, -5, 5, -10] + + self._assert_out_of_bound_start_seek(chunk, expected_callback_prog) + self._assert_out_of_bound_relative_seek(chunk, expected_callback_prog) + self._assert_out_of_bound_end_seek(chunk, expected_callback_prog) + + def _assert_out_of_bound_start_seek(self, chunk, expected): + # clear amounts_seen + self.amounts_seen = [] + self.assertEqual(self.amounts_seen, []) + + # (position, change) + chunk.seek(20) # (20, 10) + chunk.seek(5) # (5, -5) + chunk.seek(20) # (20, 5) + chunk.seek(9) # (9, -1) + chunk.seek(20) # (20, 1) + chunk.seek(11) # (11, 0) + chunk.seek(20) # (20, 0) + chunk.seek(9) # (9, -1) + chunk.seek(20) # (20, 1) + chunk.seek(5) # (5, -5) + chunk.seek(20) # (20, 5) + chunk.seek(0) # (0, -10) + chunk.seek(0) # (0, 0) + + self.assertEqual(self.amounts_seen, expected) + + def _assert_out_of_bound_relative_seek(self, chunk, expected): + # clear amounts_seen + self.amounts_seen = [] + self.assertEqual(self.amounts_seen, []) + + # (position, change) + chunk.seek(20, 1) # (20, 10) + chunk.seek(-15, 1) # (5, -5) + chunk.seek(15, 1) # (20, 5) + chunk.seek(-11, 1) # (9, -1) + chunk.seek(11, 1) # (20, 1) + chunk.seek(-9, 1) # (11, 0) + chunk.seek(9, 1) # (20, 0) + chunk.seek(-11, 1) # (9, -1) + chunk.seek(11, 1) # (20, 1) + chunk.seek(-15, 1) # (5, -5) + chunk.seek(15, 1) # (20, 5) + chunk.seek(-20, 1) # (0, -10) + chunk.seek(-1000, 1) # (0, 0) + + self.assertEqual(self.amounts_seen, expected) + + def _assert_out_of_bound_end_seek(self, chunk, expected): + # clear amounts_seen + self.amounts_seen = [] + self.assertEqual(self.amounts_seen, []) + + # (position, change) + chunk.seek(10, 2) # (20, 10) + chunk.seek(-5, 2) # (5, -5) + chunk.seek(10, 2) # (20, 5) + chunk.seek(-1, 2) # (9, -1) + chunk.seek(10, 2) # (20, 1) + chunk.seek(1, 2) # (11, 0) + chunk.seek(10, 2) # (20, 0) + chunk.seek(-1, 2) # (9, -1) + chunk.seek(10, 2) # (20, 1) + chunk.seek(-5, 2) # (5, -5) + chunk.seek(10, 2) # (20, 5) + chunk.seek(-10, 2) # (0, -10) + chunk.seek(-1000, 2) # (0, 0) + + self.assertEqual(self.amounts_seen, expected) + + def test_close_callbacks(self): + with open(self.filename) as f: + chunk = ReadFileChunk( + f, + chunk_size=1, + full_file_size=3, + close_callbacks=[self.close_callback], + ) + chunk.close() + self.assertEqual(self.num_close_callback_calls, 1) + + def test_close_callbacks_when_not_enabled(self): + with open(self.filename) as f: + chunk = ReadFileChunk( + f, + chunk_size=1, + full_file_size=3, + enable_callbacks=False, + close_callbacks=[self.close_callback], + ) + chunk.close() + self.assertEqual(self.num_close_callback_calls, 0) + + def test_close_callbacks_when_context_handler_is_used(self): + with open(self.filename) as f: + with ReadFileChunk( + f, + chunk_size=1, + full_file_size=3, + close_callbacks=[self.close_callback], + ) as chunk: + chunk.read(1) + self.assertEqual(self.num_close_callback_calls, 1) + + def test_signal_transferring(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback], + ) + chunk.signal_not_transferring() + chunk.read(1) + self.assertEqual(self.amounts_seen, []) + chunk.signal_transferring() + chunk.read(1) + self.assertEqual(self.amounts_seen, [1]) + + def test_signal_transferring_to_underlying_fileobj(self): + underlying_stream = mock.Mock() + underlying_stream.tell.return_value = 0 + chunk = ReadFileChunk(underlying_stream, 3, 3) + chunk.signal_transferring() + self.assertTrue(underlying_stream.signal_transferring.called) + 
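+    # (ReadFileChunk only forwards these signals when the wrapped fileobj
+    # actually defines the method; the "no_call" tests below pin that down
+    # with a mock specced to io.RawIOBase, which lacks the signal_* methods.)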
+    def test_no_call_signal_transferring_to_underlying_fileobj(self):
+        underlying_stream = mock.Mock(io.RawIOBase)
+        underlying_stream.tell.return_value = 0
+        chunk = ReadFileChunk(underlying_stream, 3, 3)
+        try:
+            chunk.signal_transferring()
+        except AttributeError:
+            self.fail(
+                'The stream should not have tried to call '
+                'signal_transferring on the underlying stream.'
+            )
+
+    def test_signal_not_transferring_to_underlying_fileobj(self):
+        underlying_stream = mock.Mock()
+        underlying_stream.tell.return_value = 0
+        chunk = ReadFileChunk(underlying_stream, 3, 3)
+        chunk.signal_not_transferring()
+        self.assertTrue(underlying_stream.signal_not_transferring.called)
+
+    def test_no_call_signal_not_transferring_to_underlying_fileobj(self):
+        underlying_stream = mock.Mock(io.RawIOBase)
+        underlying_stream.tell.return_value = 0
+        chunk = ReadFileChunk(underlying_stream, 3, 3)
+        try:
+            chunk.signal_not_transferring()
+        except AttributeError:
+            self.fail(
+                'The stream should not have tried to call '
+                'signal_not_transferring on the underlying stream.'
+            )
+
+
+class TestStreamReaderProgress(BaseUtilsTest):
+    def test_proxies_to_wrapped_stream(self):
+        original_stream = StringIO('foobarbaz')
+        wrapped = StreamReaderProgress(original_stream)
+        self.assertEqual(wrapped.read(), 'foobarbaz')
+
+    def test_callback_invoked(self):
+        original_stream = StringIO('foobarbaz')
+        wrapped = StreamReaderProgress(
+            original_stream, [self.callback, self.callback]
+        )
+        self.assertEqual(wrapped.read(), 'foobarbaz')
+        self.assertEqual(self.amounts_seen, [9, 9])
+
+
+class TestTaskSemaphore(unittest.TestCase):
+    def setUp(self):
+        self.semaphore = TaskSemaphore(1)
+
+    def test_should_block_at_max_capacity(self):
+        self.semaphore.acquire('a', blocking=False)
+        with self.assertRaises(NoResourcesAvailable):
+            self.semaphore.acquire('a', blocking=False)
+
+    def test_release_capacity(self):
+        acquire_token = self.semaphore.acquire('a', blocking=False)
+        self.semaphore.release('a', acquire_token)
+        try:
+            self.semaphore.acquire('a', blocking=False)
+        except NoResourcesAvailable:
+            self.fail(
+                'The release of the semaphore should have allowed the '
+                'second acquire to proceed without blocking'
+            )
+
+
+class TestSlidingWindowSemaphore(unittest.TestCase):
+    # These tests use blocking=False so that incorrect behavior fails the
+    # tests instead of hanging the test runner.
+    def test_acquire_release_basic_case(self):
+        sem = SlidingWindowSemaphore(1)
+        # Count is 1
+
+        num = sem.acquire('a', blocking=False)
+        self.assertEqual(num, 0)
+        sem.release('a', 0)
+        # Count now back to 1.
+
+    def test_can_acquire_release_multiple_times(self):
+        sem = SlidingWindowSemaphore(1)
+        num = sem.acquire('a', blocking=False)
+        self.assertEqual(num, 0)
+        sem.release('a', num)
+
+        num = sem.acquire('a', blocking=False)
+        self.assertEqual(num, 1)
+        sem.release('a', num)
+
+    def test_can_acquire_a_range(self):
+        sem = SlidingWindowSemaphore(3)
+        self.assertEqual(sem.acquire('a', blocking=False), 0)
+        self.assertEqual(sem.acquire('a', blocking=False), 1)
+        self.assertEqual(sem.acquire('a', blocking=False), 2)
+        sem.release('a', 0)
+        sem.release('a', 1)
+        sem.release('a', 2)
+        # Now we're reset so we should be able to acquire the same
+        # sequence again.
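+        # (The window keeps sliding: sequence numbers always increase, and
+        # capacity is only returned once the lowest outstanding number is
+        # released; see test_counter_release_only_on_min_element below.)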
+ self.assertEqual(sem.acquire('a', blocking=False), 3) + self.assertEqual(sem.acquire('a', blocking=False), 4) + self.assertEqual(sem.acquire('a', blocking=False), 5) + self.assertEqual(sem.current_count(), 0) + + def test_counter_release_only_on_min_element(self): + sem = SlidingWindowSemaphore(3) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + + # The count only increases when we free the min + # element. This means if we're currently failing to + # acquire now: + with self.assertRaises(NoResourcesAvailable): + sem.acquire('a', blocking=False) + + # Then freeing a non-min element: + sem.release('a', 1) + + # doesn't change anything. We still fail to acquire. + with self.assertRaises(NoResourcesAvailable): + sem.acquire('a', blocking=False) + self.assertEqual(sem.current_count(), 0) + + def test_raises_error_when_count_is_zero(self): + sem = SlidingWindowSemaphore(3) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + + # Count is now 0 so trying to acquire should fail. + with self.assertRaises(NoResourcesAvailable): + sem.acquire('a', blocking=False) + + def test_release_counters_can_increment_counter_repeatedly(self): + sem = SlidingWindowSemaphore(3) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + + # These two releases don't increment the counter + # because we're waiting on 0. + sem.release('a', 1) + sem.release('a', 2) + self.assertEqual(sem.current_count(), 0) + # But as soon as we release 0, we free up 0, 1, and 2. + sem.release('a', 0) + self.assertEqual(sem.current_count(), 3) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + + def test_error_to_release_unknown_tag(self): + sem = SlidingWindowSemaphore(3) + with self.assertRaises(ValueError): + sem.release('a', 0) + + def test_can_track_multiple_tags(self): + sem = SlidingWindowSemaphore(3) + self.assertEqual(sem.acquire('a', blocking=False), 0) + self.assertEqual(sem.acquire('b', blocking=False), 0) + self.assertEqual(sem.acquire('a', blocking=False), 1) + + # We're at our max of 3 even though 2 are for A and 1 is for B. + with self.assertRaises(NoResourcesAvailable): + sem.acquire('a', blocking=False) + with self.assertRaises(NoResourcesAvailable): + sem.acquire('b', blocking=False) + + def test_can_handle_multiple_tags_released(self): + sem = SlidingWindowSemaphore(4) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('b', blocking=False) + sem.acquire('b', blocking=False) + + sem.release('b', 1) + sem.release('a', 1) + self.assertEqual(sem.current_count(), 0) + + sem.release('b', 0) + self.assertEqual(sem.acquire('a', blocking=False), 2) + + sem.release('a', 0) + self.assertEqual(sem.acquire('b', blocking=False), 2) + + def test_is_error_to_release_unknown_sequence_number(self): + sem = SlidingWindowSemaphore(3) + sem.acquire('a', blocking=False) + with self.assertRaises(ValueError): + sem.release('a', 1) + + def test_is_error_to_double_release(self): + # This is different than other error tests because + # we're verifying we can reset the state after an + # acquire/release cycle. 
+        sem = SlidingWindowSemaphore(2)
+        sem.acquire('a', blocking=False)
+        sem.acquire('a', blocking=False)
+        sem.release('a', 0)
+        sem.release('a', 1)
+        self.assertEqual(sem.current_count(), 2)
+        with self.assertRaises(ValueError):
+            sem.release('a', 0)
+
+    def test_can_check_in_partial_range(self):
+        sem = SlidingWindowSemaphore(4)
+        sem.acquire('a', blocking=False)
+        sem.acquire('a', blocking=False)
+        sem.acquire('a', blocking=False)
+        sem.acquire('a', blocking=False)
+
+        sem.release('a', 1)
+        sem.release('a', 3)
+        sem.release('a', 0)
+        self.assertEqual(sem.current_count(), 2)
+
+
+class TestThreadingPropertiesForSlidingWindowSemaphore(unittest.TestCase):
+    # These tests focus on multithreaded properties of the range
+    # semaphore. Basic functionality is tested in TestSlidingWindowSemaphore.
+    def setUp(self):
+        self.threads = []
+
+    def tearDown(self):
+        self.join_threads()
+
+    def join_threads(self):
+        for thread in self.threads:
+            thread.join()
+        self.threads = []
+
+    def start_threads(self):
+        for thread in self.threads:
+            thread.start()
+
+    def test_acquire_blocks_until_release_is_called(self):
+        sem = SlidingWindowSemaphore(2)
+        sem.acquire('a', blocking=False)
+        sem.acquire('a', blocking=False)
+
+        def acquire():
+            # This next call to acquire will block.
+            self.assertEqual(sem.acquire('a', blocking=True), 2)
+
+        t = threading.Thread(target=acquire)
+        self.threads.append(t)
+        # Starting the thread will block the sem.acquire()
+        # in the acquire function above.
+        t.start()
+        # This will still keep the thread blocked.
+        sem.release('a', 1)
+        # Releasing the min element will unblock the thread.
+        sem.release('a', 0)
+        t.join()
+        sem.release('a', 2)
+
+    def test_stress_invariants_random_order(self):
+        sem = SlidingWindowSemaphore(100)
+        for _ in range(10):
+            recorded = []
+            for _ in range(100):
+                recorded.append(sem.acquire('a', blocking=False))
+            # Release them in randomized order. As long as we
+            # eventually free all 100, we should have all the
+            # resources released.
+            random.shuffle(recorded)
+            for i in recorded:
+                sem.release('a', i)
+
+            # Everything's freed, so we should be back at count == 100.
+            self.assertEqual(sem.current_count(), 100)
+
+    def test_blocking_stress(self):
+        sem = SlidingWindowSemaphore(5)
+        num_threads = 10
+        num_iterations = 50
+
+        def acquire():
+            for _ in range(num_iterations):
+                num = sem.acquire('a', blocking=True)
+                time.sleep(0.001)
+                sem.release('a', num)
+
+        for i in range(num_threads):
+            t = threading.Thread(target=acquire)
+            self.threads.append(t)
+        self.start_threads()
+        self.join_threads()
+        # Should have all the available resources freed.
+ self.assertEqual(sem.current_count(), 5) + # Should have acquired num_threads * num_iterations + self.assertEqual( + sem.acquire('a', blocking=False), num_threads * num_iterations + ) + + +class TestAdjustChunksize(unittest.TestCase): + def setUp(self): + self.adjuster = ChunksizeAdjuster() + + def test_valid_chunksize(self): + chunksize = 7 * (1024**2) + file_size = 8 * (1024**2) + new_size = self.adjuster.adjust_chunksize(chunksize, file_size) + self.assertEqual(new_size, chunksize) + + def test_chunksize_below_minimum(self): + chunksize = MIN_UPLOAD_CHUNKSIZE - 1 + file_size = 3 * MIN_UPLOAD_CHUNKSIZE + new_size = self.adjuster.adjust_chunksize(chunksize, file_size) + self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE) + + def test_chunksize_above_maximum(self): + chunksize = MAX_SINGLE_UPLOAD_SIZE + 1 + file_size = MAX_SINGLE_UPLOAD_SIZE * 2 + new_size = self.adjuster.adjust_chunksize(chunksize, file_size) + self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE) + + def test_chunksize_too_small(self): + chunksize = 7 * (1024**2) + file_size = 5 * (1024**4) + # If we try to upload a 5TB file, we'll need to use 896MB part + # sizes. + new_size = self.adjuster.adjust_chunksize(chunksize, file_size) + self.assertEqual(new_size, 896 * (1024**2)) + num_parts = file_size / new_size + self.assertLessEqual(num_parts, MAX_PARTS) + + def test_unknown_file_size_with_valid_chunksize(self): + chunksize = 7 * (1024**2) + new_size = self.adjuster.adjust_chunksize(chunksize) + self.assertEqual(new_size, chunksize) + + def test_unknown_file_size_below_minimum(self): + chunksize = MIN_UPLOAD_CHUNKSIZE - 1 + new_size = self.adjuster.adjust_chunksize(chunksize) + self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE) + + def test_unknown_file_size_above_maximum(self): + chunksize = MAX_SINGLE_UPLOAD_SIZE + 1 + new_size = self.adjuster.adjust_chunksize(chunksize) + self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE) |
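The expectations in TestAdjustChunksize mirror S3's hard limits: at most 10,000 parts per multipart upload, and a part size between 5 MiB and 5 GiB. As a reading aid, here is a minimal sketch of a chunksize adjuster consistent with those expectations. It is an illustration of the doubling arithmetic the tests encode (7 MiB doubled seven times is 896 MiB), not the s3transfer implementation; the constants simply restate the S3 limits named above.

    import math

    MAX_PARTS = 10000                     # S3: at most 10,000 parts per upload
    MIN_UPLOAD_CHUNKSIZE = 5 * 1024**2    # S3: 5 MiB minimum part size
    MAX_SINGLE_UPLOAD_SIZE = 5 * 1024**3  # S3: 5 GiB maximum part size

    def adjust_chunksize(chunksize, file_size=None):
        # Clamp the requested chunksize to the per-part bounds first.
        chunksize = max(chunksize, MIN_UPLOAD_CHUNKSIZE)
        chunksize = min(chunksize, MAX_SINGLE_UPLOAD_SIZE)
        if file_size is None:
            return chunksize
        # Double until the whole file fits within MAX_PARTS parts.
        while math.ceil(file_size / chunksize) > MAX_PARTS:
            chunksize *= 2
        return chunksize

    # 7 MiB must double seven times (to 896 MiB) before a 5 TiB file fits
    # in 10,000 parts, matching test_chunksize_too_small above.
    assert adjust_chunksize(7 * 1024**2, 5 * 1024**4) == 896 * 1024**2
    # Unknown file size: only the min/max clamping applies.
    assert adjust_chunksize(MIN_UPLOAD_CHUNKSIZE - 1) == MIN_UPLOAD_CHUNKSIZE

A production version would also re-clamp after the doubling loop so a pathologically large file size cannot push the part size past the 5 GiB ceiling; the sketch omits that for brevity.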