author    nkozlovskiy <nmk@ydb.tech>  2023-09-29 12:24:06 +0300
committer nkozlovskiy <nmk@ydb.tech>  2023-09-29 12:41:34 +0300
commit    e0e3e1717e3d33762ce61950504f9637a6e669ed (patch)
tree      bca3ff6939b10ed60c3d5c12439963a1146b9711 /contrib/python/s3transfer/py3/tests/unit
parent    38f2c5852db84c7b4d83adfcb009eb61541d1ccd (diff)
download  ydb-e0e3e1717e3d33762ce61950504f9637a6e669ed.tar.gz
add ydb deps
Diffstat (limited to 'contrib/python/s3transfer/py3/tests/unit')
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/__init__.py            12
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py     452
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_compat.py        105
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_copies.py        207
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_crt.py           173
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_delete.py         67
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_download.py      999
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_futures.py       696
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_manager.py       143
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_processpool.py   728
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py    780
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_subscribers.py    91
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_tasks.py         833
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_upload.py        694
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_utils.py        1189
15 files changed, 7169 insertions, 0 deletions
diff --git a/contrib/python/s3transfer/py3/tests/unit/__init__.py b/contrib/python/s3transfer/py3/tests/unit/__init__.py
new file mode 100644
index 0000000000..79ef91c6a2
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/__init__.py
@@ -0,0 +1,12 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py b/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py
new file mode 100644
index 0000000000..b796f8f24c
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py
@@ -0,0 +1,452 @@
+# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import tempfile
+
+from s3transfer.bandwidth import (
+ BandwidthLimitedStream,
+ BandwidthLimiter,
+ BandwidthRateTracker,
+ ConsumptionScheduler,
+ LeakyBucket,
+ RequestExceededException,
+ RequestToken,
+ TimeUtils,
+)
+from s3transfer.futures import TransferCoordinator
+from __tests__ import mock, unittest
+
+
+class FixedIncrementalTickTimeUtils(TimeUtils):
+ def __init__(self, seconds_per_tick=1.0):
+ self._count = 0
+ self._seconds_per_tick = seconds_per_tick
+
+ def time(self):
+ current_count = self._count
+ self._count += self._seconds_per_tick
+ return current_count
+
+
+class TestTimeUtils(unittest.TestCase):
+ @mock.patch('time.time')
+ def test_time(self, mock_time):
+ mock_return_val = 1
+ mock_time.return_value = mock_return_val
+ time_utils = TimeUtils()
+ self.assertEqual(time_utils.time(), mock_return_val)
+
+ @mock.patch('time.sleep')
+ def test_sleep(self, mock_sleep):
+ time_utils = TimeUtils()
+ time_utils.sleep(1)
+ self.assertEqual(mock_sleep.call_args_list, [mock.call(1)])
+
+
+class BaseBandwidthLimitTest(unittest.TestCase):
+ def setUp(self):
+ self.leaky_bucket = mock.Mock(LeakyBucket)
+ self.time_utils = mock.Mock(TimeUtils)
+ self.tempdir = tempfile.mkdtemp()
+ self.content = b'a' * 1024 * 1024
+ self.filename = os.path.join(self.tempdir, 'myfile')
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+ self.coordinator = TransferCoordinator()
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def assert_consume_calls(self, amts):
+ expected_consume_args = [mock.call(amt, mock.ANY) for amt in amts]
+ self.assertEqual(
+ self.leaky_bucket.consume.call_args_list, expected_consume_args
+ )
+
+
+class TestBandwidthLimiter(BaseBandwidthLimitTest):
+ def setUp(self):
+ super().setUp()
+ self.bandwidth_limiter = BandwidthLimiter(self.leaky_bucket)
+
+ def test_get_bandwidth_limited_stream(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.bandwidth_limiter.get_bandwith_limited_stream(
+ f, self.coordinator
+ )
+ self.assertIsInstance(stream, BandwidthLimitedStream)
+ self.assertEqual(stream.read(len(self.content)), self.content)
+ self.assert_consume_calls(amts=[len(self.content)])
+
+ def test_get_disabled_bandwidth_limited_stream(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.bandwidth_limiter.get_bandwith_limited_stream(
+ f, self.coordinator, enabled=False
+ )
+ self.assertIsInstance(stream, BandwidthLimitedStream)
+ self.assertEqual(stream.read(len(self.content)), self.content)
+ self.leaky_bucket.consume.assert_not_called()
+
+
+class TestBandwidthLimitedStream(BaseBandwidthLimitTest):
+ def setUp(self):
+ super().setUp()
+ self.bytes_threshold = 1
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def get_bandwidth_limited_stream(self, f):
+ return BandwidthLimitedStream(
+ f,
+ self.leaky_bucket,
+ self.coordinator,
+ self.time_utils,
+ self.bytes_threshold,
+ )
+
+ def assert_sleep_calls(self, amts):
+ expected_sleep_args_list = [mock.call(amt) for amt in amts]
+ self.assertEqual(
+ self.time_utils.sleep.call_args_list, expected_sleep_args_list
+ )
+
+ def get_unique_consume_request_tokens(self):
+ return {
+ call_args[0][1]
+ for call_args in self.leaky_bucket.consume.call_args_list
+ }
+
+ def test_read(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ data = stream.read(len(self.content))
+ self.assertEqual(self.content, data)
+ self.assert_consume_calls(amts=[len(self.content)])
+ self.assert_sleep_calls(amts=[])
+
+ def test_retries_on_request_exceeded(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ retry_time = 1
+ amt_requested = len(self.content)
+ self.leaky_bucket.consume.side_effect = [
+ RequestExceededException(amt_requested, retry_time),
+ len(self.content),
+ ]
+ data = stream.read(len(self.content))
+ self.assertEqual(self.content, data)
+ self.assert_consume_calls(amts=[amt_requested, amt_requested])
+ self.assert_sleep_calls(amts=[retry_time])
+
+ def test_with_transfer_coordinator_exception(self):
+ self.coordinator.set_exception(ValueError())
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ with self.assertRaises(ValueError):
+ stream.read(len(self.content))
+
+ def test_read_when_bandwidth_limiting_disabled(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.disable_bandwidth_limiting()
+ data = stream.read(len(self.content))
+ self.assertEqual(self.content, data)
+ self.assertFalse(self.leaky_bucket.consume.called)
+
+ def test_read_toggle_disable_enable_bandwidth_limiting(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.disable_bandwidth_limiting()
+ data = stream.read(1)
+ self.assertEqual(self.content[:1], data)
+ self.assert_consume_calls(amts=[])
+ stream.enable_bandwidth_limiting()
+ data = stream.read(len(self.content) - 1)
+ self.assertEqual(self.content[1:], data)
+ self.assert_consume_calls(amts=[len(self.content) - 1])
+
+ def test_seek(self):
+ mock_fileobj = mock.Mock()
+ stream = self.get_bandwidth_limited_stream(mock_fileobj)
+ stream.seek(1)
+ self.assertEqual(mock_fileobj.seek.call_args_list, [mock.call(1, 0)])
+
+ def test_tell(self):
+ mock_fileobj = mock.Mock()
+ stream = self.get_bandwidth_limited_stream(mock_fileobj)
+ stream.tell()
+ self.assertEqual(mock_fileobj.tell.call_args_list, [mock.call()])
+
+ def test_close(self):
+ mock_fileobj = mock.Mock()
+ stream = self.get_bandwidth_limited_stream(mock_fileobj)
+ stream.close()
+ self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()])
+
+ def test_context_manager(self):
+ mock_fileobj = mock.Mock()
+ stream = self.get_bandwidth_limited_stream(mock_fileobj)
+ with stream as stream_handle:
+ self.assertIs(stream_handle, stream)
+ self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()])
+
+ def test_reuses_request_token(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.read(1)
+ stream.read(1)
+ self.assertEqual(len(self.get_unique_consume_request_tokens()), 1)
+
+ def test_request_tokens_unique_per_stream(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.read(1)
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.read(1)
+ self.assertEqual(len(self.get_unique_consume_request_tokens()), 2)
+
+ def test_call_consume_after_reaching_threshold(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(1), self.content[:1])
+ self.assert_consume_calls(amts=[])
+ self.assertEqual(stream.read(1), self.content[1:2])
+ self.assert_consume_calls(amts=[2])
+
+ def test_resets_after_reaching_threshold(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(2), self.content[:2])
+ self.assert_consume_calls(amts=[2])
+ self.assertEqual(stream.read(1), self.content[2:3])
+ self.assert_consume_calls(amts=[2])
+
+ def test_pending_bytes_seen_on_close(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(1), self.content[:1])
+ self.assert_consume_calls(amts=[])
+ stream.close()
+ self.assert_consume_calls(amts=[1])
+
+ def test_no_bytes_remaining_on_close(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(2), self.content[:2])
+ self.assert_consume_calls(amts=[2])
+ stream.close()
+ # There should have been no more consume() calls made
+ # as all bytes have been accounted for in the previous
+ # consume() call.
+ self.assert_consume_calls(amts=[2])
+
+ def test_disable_bandwidth_limiting_with_pending_bytes_seen_on_close(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(1), self.content[:1])
+ self.assert_consume_calls(amts=[])
+ stream.disable_bandwidth_limiting()
+ stream.close()
+ self.assert_consume_calls(amts=[])
+
+ def test_signal_transferring(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.signal_not_transferring()
+ data = stream.read(1)
+ self.assertEqual(self.content[:1], data)
+ self.assert_consume_calls(amts=[])
+ stream.signal_transferring()
+ data = stream.read(len(self.content) - 1)
+ self.assertEqual(self.content[1:], data)
+ self.assert_consume_calls(amts=[len(self.content) - 1])
+
+
+class TestLeakyBucket(unittest.TestCase):
+ def setUp(self):
+ self.max_rate = 1
+ self.time_now = 1.0
+ self.time_utils = mock.Mock(TimeUtils)
+ self.time_utils.time.return_value = self.time_now
+ self.scheduler = mock.Mock(ConsumptionScheduler)
+ self.scheduler.is_scheduled.return_value = False
+ self.rate_tracker = mock.Mock(BandwidthRateTracker)
+ self.leaky_bucket = LeakyBucket(
+ self.max_rate, self.time_utils, self.rate_tracker, self.scheduler
+ )
+
+ def set_projected_rate(self, rate):
+ self.rate_tracker.get_projected_rate.return_value = rate
+
+ def set_retry_time(self, retry_time):
+ self.scheduler.schedule_consumption.return_value = retry_time
+
+ def assert_recorded_consumed_amt(self, expected_amt):
+ self.assertEqual(
+ self.rate_tracker.record_consumption_rate.call_args,
+ mock.call(expected_amt, self.time_utils.time.return_value),
+ )
+
+ def assert_was_scheduled(self, amt, token):
+ self.assertEqual(
+ self.scheduler.schedule_consumption.call_args,
+ mock.call(amt, token, amt / (self.max_rate)),
+ )
+
+ def assert_nothing_scheduled(self):
+ self.assertFalse(self.scheduler.schedule_consumption.called)
+
+ def assert_processed_request_token(self, request_token):
+ self.assertEqual(
+ self.scheduler.process_scheduled_consumption.call_args,
+ mock.call(request_token),
+ )
+
+ def test_consume_under_max_rate(self):
+ amt = 1
+ self.set_projected_rate(self.max_rate / 2)
+ self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
+ self.assert_recorded_consumed_amt(amt)
+ self.assert_nothing_scheduled()
+
+ def test_consume_at_max_rate(self):
+ amt = 1
+ self.set_projected_rate(self.max_rate)
+ self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
+ self.assert_recorded_consumed_amt(amt)
+ self.assert_nothing_scheduled()
+
+ def test_consume_over_max_rate(self):
+ amt = 1
+ retry_time = 2.0
+ self.set_projected_rate(self.max_rate + 1)
+ self.set_retry_time(retry_time)
+ request_token = RequestToken()
+ try:
+ self.leaky_bucket.consume(amt, request_token)
+ self.fail('A RequestExceededException should have been thrown')
+ except RequestExceededException as e:
+ self.assertEqual(e.requested_amt, amt)
+ self.assertEqual(e.retry_time, retry_time)
+ self.assert_was_scheduled(amt, request_token)
+
+ def test_consume_with_scheduled_retry(self):
+ amt = 1
+ self.set_projected_rate(self.max_rate + 1)
+ self.scheduler.is_scheduled.return_value = True
+ request_token = RequestToken()
+ self.assertEqual(self.leaky_bucket.consume(amt, request_token), amt)
+ # Nothing new should have been scheduled but the request token
+ # should have been processed.
+ self.assert_nothing_scheduled()
+ self.assert_processed_request_token(request_token)
+
+
+class TestConsumptionScheduler(unittest.TestCase):
+ def setUp(self):
+ self.scheduler = ConsumptionScheduler()
+
+ def test_schedule_consumption(self):
+ token = RequestToken()
+ consume_time = 5
+ actual_wait_time = self.scheduler.schedule_consumption(
+ 1, token, consume_time
+ )
+ self.assertEqual(consume_time, actual_wait_time)
+
+ def test_schedule_consumption_for_multiple_requests(self):
+ token = RequestToken()
+ consume_time = 5
+ actual_wait_time = self.scheduler.schedule_consumption(
+ 1, token, consume_time
+ )
+ self.assertEqual(consume_time, actual_wait_time)
+
+ other_consume_time = 3
+ other_token = RequestToken()
+ next_wait_time = self.scheduler.schedule_consumption(
+ 1, other_token, other_consume_time
+ )
+
+ # This wait time should be the previous wait time plus its own
+ # desired wait time.
+ self.assertEqual(next_wait_time, consume_time + other_consume_time)
+
+ def test_is_scheduled(self):
+ token = RequestToken()
+ consume_time = 5
+ self.scheduler.schedule_consumption(1, token, consume_time)
+ self.assertTrue(self.scheduler.is_scheduled(token))
+
+ def test_is_not_scheduled(self):
+ self.assertFalse(self.scheduler.is_scheduled(RequestToken()))
+
+ def test_process_scheduled_consumption(self):
+ token = RequestToken()
+ consume_time = 5
+ self.scheduler.schedule_consumption(1, token, consume_time)
+ self.scheduler.process_scheduled_consumption(token)
+ self.assertFalse(self.scheduler.is_scheduled(token))
+ different_time = 7
+ # The previous consume time should have no effect on the next wait
+ # time as it has been completed.
+ self.assertEqual(
+ self.scheduler.schedule_consumption(1, token, different_time),
+ different_time,
+ )
+
+
+class TestBandwidthRateTracker(unittest.TestCase):
+ def setUp(self):
+ self.alpha = 0.8
+ self.rate_tracker = BandwidthRateTracker(self.alpha)
+
+ def test_current_rate_at_initializations(self):
+ self.assertEqual(self.rate_tracker.current_rate, 0.0)
+
+ def test_current_rate_after_one_recorded_point(self):
+ self.rate_tracker.record_consumption_rate(1, 1)
+ # There is no previous time point to diff against, so the
+ # current rate is 0.0.
+ self.assertEqual(self.rate_tracker.current_rate, 0.0)
+
+ def test_current_rate(self):
+ self.rate_tracker.record_consumption_rate(1, 1)
+ self.rate_tracker.record_consumption_rate(1, 2)
+ self.rate_tracker.record_consumption_rate(1, 3)
+ self.assertEqual(self.rate_tracker.current_rate, 0.96)
+
+ def test_get_projected_rate_at_initializations(self):
+ self.assertEqual(self.rate_tracker.get_projected_rate(1, 1), 0.0)
+
+ def test_get_projected_rate(self):
+ self.rate_tracker.record_consumption_rate(1, 1)
+ self.rate_tracker.record_consumption_rate(1, 2)
+ projected_rate = self.rate_tracker.get_projected_rate(1, 3)
+ self.assertEqual(projected_rate, 0.96)
+ self.rate_tracker.record_consumption_rate(1, 3)
+ self.assertEqual(self.rate_tracker.current_rate, projected_rate)
+
+ def test_get_projected_rate_for_same_timestamp(self):
+ self.rate_tracker.record_consumption_rate(1, 1)
+ self.assertEqual(
+ self.rate_tracker.get_projected_rate(1, 1), float('inf')
+ )
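
The 0.96 values asserted above follow from an exponentially weighted moving average: each interval's instantaneous rate (bytes divided by elapsed seconds) is blended into the running rate as alpha * new + (1 - alpha) * previous. A minimal sketch of that arithmetic (an illustration of the model the assertions encode, not s3transfer's implementation):

    alpha = 0.8
    rate = 0.0  # current_rate stays 0.0 until a second point exists
    # Three one-byte reads arrive one second apart, giving two intervals
    # with an instantaneous rate of 1 B/s each.
    for amt, elapsed in [(1, 1), (1, 1)]:
        rate = alpha * (amt / elapsed) + (1 - alpha) * rate
    print(rate)  # 0.96, the value asserted in test_current_rate
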
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_compat.py b/contrib/python/s3transfer/py3/tests/unit/test_compat.py
new file mode 100644
index 0000000000..78fdc25845
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_compat.py
@@ -0,0 +1,105 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import signal
+import tempfile
+from io import BytesIO
+
+from s3transfer.compat import BaseManager, readable, seekable
+from __tests__ import skip_if_windows, unittest
+
+
+class ErrorRaisingSeekWrapper:
+ """An object wrapper that throws an error when seeked on
+
+ :param fileobj: The fileobj that it wraps
+ :param exception: The exception to raise when seeked on.
+ """
+
+ def __init__(self, fileobj, exception):
+ self._fileobj = fileobj
+ self._exception = exception
+
+ def seek(self, offset, whence=0):
+ raise self._exception
+
+ def tell(self):
+ return self._fileobj.tell()
+
+
+class TestSeekable(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'foo')
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def test_seekable_fileobj(self):
+ with open(self.filename, 'w') as f:
+ self.assertTrue(seekable(f))
+
+ def test_non_file_like_obj(self):
+ # Fails because there is no seekable(), seek(), nor tell()
+ self.assertFalse(seekable(object()))
+
+ def test_non_seekable_ioerror(self):
+ # Should return False if IOError is thrown.
+ with open(self.filename, 'w') as f:
+ self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, IOError())))
+
+ def test_non_seekable_oserror(self):
+ # Should return False if OSError is thrown.
+ with open(self.filename, 'w') as f:
+ self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, OSError())))
+
+
+class TestReadable(unittest.TestCase):
+ def test_readable_fileobj(self):
+ with tempfile.TemporaryFile() as f:
+ self.assertTrue(readable(f))
+
+ def test_readable_file_like_obj(self):
+ self.assertTrue(readable(BytesIO()))
+
+ def test_non_file_like_obj(self):
+ self.assertFalse(readable(object()))
+
+
+class TestBaseManager(unittest.TestCase):
+ def create_pid_manager(self):
+ class PIDManager(BaseManager):
+ pass
+
+ PIDManager.register('getpid', os.getpid)
+ return PIDManager()
+
+ def get_pid(self, pid_manager):
+ pid = pid_manager.getpid()
+ # A proxy object is returned. The needed value is recovered by
+ # converting the proxy's string representation to an integer.
+ return int(str(pid))
+
+ @skip_if_windows('os.kill() with SIGINT not supported on Windows')
+ def test_can_provide_signal_handler_initializers_to_start(self):
+ manager = self.create_pid_manager()
+ manager.start(signal.signal, (signal.SIGINT, signal.SIG_IGN))
+ pid = self.get_pid(manager)
+ try:
+ os.kill(pid, signal.SIGINT)
+ except KeyboardInterrupt:
+ pass
+ # Try using the manager after the os.kill on the parent process. The
+ # manager should not have died and should still be usable.
+ self.assertEqual(pid, self.get_pid(manager))
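
The seekable() tests above pin down duck-typing behavior: a real file object passes, a bare object() fails, and a wrapper whose seek() raises IOError or OSError is reported as non-seekable. A minimal sketch consistent with that behavior (an assumption about the probe's shape; the real s3transfer.compat.seekable may differ in detail):

    def seekable_sketch(fileobj):
        try:
            # Probe with a relative no-op seek. Objects without seek()
            # raise AttributeError; broken seeks raise OSError (IOError
            # is an alias of OSError on Python 3).
            fileobj.seek(0, 1)
        except (AttributeError, OSError):
            return False
        return True
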
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_copies.py b/contrib/python/s3transfer/py3/tests/unit/test_copies.py
new file mode 100644
index 0000000000..4e6992a28b
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_copies.py
@@ -0,0 +1,207 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.copies import CopyObjectTask, CopyPartTask
+from __tests__ import BaseTaskTest, RecordingSubscriber
+
+
+class BaseCopyTaskTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.copy_source = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'}
+ self.extra_args = {}
+ self.callbacks = []
+ self.size = 5
+
+
+class TestCopyObjectTask(BaseCopyTaskTest):
+ def get_copy_task(self, **kwargs):
+ default_kwargs = {
+ 'client': self.client,
+ 'copy_source': self.copy_source,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': self.extra_args,
+ 'callbacks': self.callbacks,
+ 'size': self.size,
+ }
+ default_kwargs.update(kwargs)
+ return self.get_task(CopyObjectTask, main_kwargs=default_kwargs)
+
+ def test_main(self):
+ self.stubber.add_response(
+ 'copy_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ },
+ )
+ task = self.get_copy_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+
+ def test_extra_args(self):
+ self.extra_args['ACL'] = 'private'
+ self.stubber.add_response(
+ 'copy_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'ACL': 'private',
+ },
+ )
+ task = self.get_copy_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+
+ def test_callbacks_invoked(self):
+ subscriber = RecordingSubscriber()
+ self.callbacks.append(subscriber.on_progress)
+ self.stubber.add_response(
+ 'copy_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ },
+ )
+ task = self.get_copy_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(subscriber.calculate_bytes_seen(), self.size)
+
+
+class TestCopyPartTask(BaseCopyTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.copy_source_range = 'bytes=5-9'
+ self.extra_args['CopySourceRange'] = self.copy_source_range
+ self.upload_id = 'myuploadid'
+ self.part_number = 1
+ self.result_etag = 'my-etag'
+ self.checksum_sha1 = 'my-checksum_sha1'
+
+ def get_copy_task(self, **kwargs):
+ default_kwargs = {
+ 'client': self.client,
+ 'copy_source': self.copy_source,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'upload_id': self.upload_id,
+ 'part_number': self.part_number,
+ 'extra_args': self.extra_args,
+ 'callbacks': self.callbacks,
+ 'size': self.size,
+ }
+ default_kwargs.update(kwargs)
+ return self.get_task(CopyPartTask, main_kwargs=default_kwargs)
+
+ def test_main(self):
+ self.stubber.add_response(
+ 'upload_part_copy',
+ service_response={'CopyPartResult': {'ETag': self.result_etag}},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'UploadId': self.upload_id,
+ 'PartNumber': self.part_number,
+ 'CopySourceRange': self.copy_source_range,
+ },
+ )
+ task = self.get_copy_task()
+ self.assertEqual(
+ task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
+ )
+ self.stubber.assert_no_pending_responses()
+
+ def test_main_with_checksum(self):
+ self.stubber.add_response(
+ 'upload_part_copy',
+ service_response={
+ 'CopyPartResult': {
+ 'ETag': self.result_etag,
+ 'ChecksumSHA1': self.checksum_sha1,
+ }
+ },
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'UploadId': self.upload_id,
+ 'PartNumber': self.part_number,
+ 'CopySourceRange': self.copy_source_range,
+ },
+ )
+ task = self.get_copy_task(checksum_algorithm="sha1")
+ self.assertEqual(
+ task(),
+ {
+ 'PartNumber': self.part_number,
+ 'ETag': self.result_etag,
+ 'ChecksumSHA1': self.checksum_sha1,
+ },
+ )
+ self.stubber.assert_no_pending_responses()
+
+ def test_extra_args(self):
+ self.extra_args['RequestPayer'] = 'requester'
+ self.stubber.add_response(
+ 'upload_part_copy',
+ service_response={'CopyPartResult': {'ETag': self.result_etag}},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'UploadId': self.upload_id,
+ 'PartNumber': self.part_number,
+ 'CopySourceRange': self.copy_source_range,
+ 'RequestPayer': 'requester',
+ },
+ )
+ task = self.get_copy_task()
+ self.assertEqual(
+ task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
+ )
+ self.stubber.assert_no_pending_responses()
+
+ def test_callbacks_invoked(self):
+ subscriber = RecordingSubscriber()
+ self.callbacks.append(subscriber.on_progress)
+ self.stubber.add_response(
+ 'upload_part_copy',
+ service_response={'CopyPartResult': {'ETag': self.result_etag}},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'UploadId': self.upload_id,
+ 'PartNumber': self.part_number,
+ 'CopySourceRange': self.copy_source_range,
+ },
+ )
+ task = self.get_copy_task()
+ self.assertEqual(
+ task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
+ )
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(subscriber.calculate_bytes_seen(), self.size)
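
Every copy test above follows the same botocore Stubber pattern: queue a canned response together with the exact request parameters expected, run the task, then confirm nothing queued is left over. A self-contained example of that pattern outside the test harness (the bucket and key names are illustrative):

    import botocore.session
    from botocore.stub import Stubber

    client = botocore.session.get_session().create_client(
        's3', region_name='us-east-1'
    )
    stubber = Stubber(client)
    # The stubbed call must be made with exactly these parameters,
    # otherwise the stubber raises an error.
    stubber.add_response(
        'copy_object',
        service_response={},
        expected_params={
            'Bucket': 'mybucket',
            'Key': 'mykey',
            'CopySource': {'Bucket': 'src', 'Key': 'srckey'},
        },
    )
    with stubber:
        client.copy_object(
            Bucket='mybucket',
            Key='mykey',
            CopySource={'Bucket': 'src', 'Key': 'srckey'},
        )
    stubber.assert_no_pending_responses()
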
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_crt.py b/contrib/python/s3transfer/py3/tests/unit/test_crt.py
new file mode 100644
index 0000000000..8c32668eab
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_crt.py
@@ -0,0 +1,173 @@
+# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from botocore.credentials import CredentialResolver, ReadOnlyCredentials
+from botocore.session import Session
+
+from s3transfer.exceptions import TransferNotDoneError
+from s3transfer.utils import CallArgs
+from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest
+
+if HAS_CRT:
+ import awscrt.s3
+
+ import s3transfer.crt
+
+
+class CustomFutureException(Exception):
+ pass
+
+
+@requires_crt
+class TestBotocoreCRTRequestSerializer(unittest.TestCase):
+ def setUp(self):
+ self.region = 'us-west-2'
+ self.session = Session()
+ self.session.set_config_variable('region', self.region)
+ self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer(
+ self.session
+ )
+ self.bucket = "test_bucket"
+ self.key = "test_key"
+ self.files = FileCreator()
+ self.filename = self.files.create_file('myfile', 'my content')
+ self.expected_path = "/" + self.bucket + "/" + self.key
+ self.expected_host = "s3.%s.amazonaws.com" % (self.region)
+
+ def tearDown(self):
+ self.files.remove_all()
+
+ def test_upload_request(self):
+ callargs = CallArgs(
+ bucket=self.bucket,
+ key=self.key,
+ fileobj=self.filename,
+ extra_args={},
+ subscribers=[],
+ )
+ coordinator = s3transfer.crt.CRTTransferCoordinator()
+ future = s3transfer.crt.CRTTransferFuture(
+ s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
+ )
+ crt_request = self.request_serializer.serialize_http_request(
+ "put_object", future
+ )
+ self.assertEqual("PUT", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self.assertIsNone(crt_request.headers.get("Authorization"))
+
+ def test_download_request(self):
+ callargs = CallArgs(
+ bucket=self.bucket,
+ key=self.key,
+ fileobj=self.filename,
+ extra_args={},
+ subscribers=[],
+ )
+ coordinator = s3transfer.crt.CRTTransferCoordinator()
+ future = s3transfer.crt.CRTTransferFuture(
+ s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
+ )
+ crt_request = self.request_serializer.serialize_http_request(
+ "get_object", future
+ )
+ self.assertEqual("GET", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self.assertIsNone(crt_request.headers.get("Authorization"))
+
+ def test_delete_request(self):
+ callargs = CallArgs(
+ bucket=self.bucket, key=self.key, extra_args={}, subscribers=[]
+ )
+ coordinator = s3transfer.crt.CRTTransferCoordinator()
+ future = s3transfer.crt.CRTTransferFuture(
+ s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
+ )
+ crt_request = self.request_serializer.serialize_http_request(
+ "delete_object", future
+ )
+ self.assertEqual("DELETE", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self.assertIsNone(crt_request.headers.get("Authorization"))
+
+
+@requires_crt
+class TestCRTCredentialProviderAdapter(unittest.TestCase):
+ def setUp(self):
+ self.botocore_credential_provider = mock.Mock(CredentialResolver)
+ self.access_key = "access_key"
+ self.secret_key = "secret_key"
+ self.token = "token"
+ self.botocore_credential_provider.load_credentials.return_value.get_frozen_credentials.return_value = ReadOnlyCredentials(
+ self.access_key, self.secret_key, self.token
+ )
+
+ def _call_adapter_and_check(self, credentials_provider_adapter):
+ credentials = credentials_provider_adapter()
+ self.assertEqual(credentials.access_key_id, self.access_key)
+ self.assertEqual(credentials.secret_access_key, self.secret_key)
+ self.assertEqual(credentials.session_token, self.token)
+
+ def test_fetch_crt_credentials_successfully(self):
+ credentials_provider_adapter = (
+ s3transfer.crt.CRTCredentialProviderAdapter(
+ self.botocore_credential_provider
+ )
+ )
+ self._call_adapter_and_check(credentials_provider_adapter)
+
+ def test_load_credentials_once(self):
+ credentials_provider_adapter = (
+ s3transfer.crt.CRTCredentialProviderAdapter(
+ self.botocore_credential_provider
+ )
+ )
+ called_times = 5
+ for i in range(called_times):
+ self._call_adapter_and_check(credentials_provider_adapter)
+ # Assert that load_credentials on the botocore credential provider
+ # is only called once.
+ self.assertEqual(
+ self.botocore_credential_provider.load_credentials.call_count, 1
+ )
+
+
+@requires_crt
+class TestCRTTransferFuture(unittest.TestCase):
+ def setUp(self):
+ self.mock_s3_request = mock.Mock(awscrt.s3.S3RequestType)
+ self.mock_crt_future = mock.Mock(awscrt.s3.Future)
+ self.mock_s3_request.finished_future = self.mock_crt_future
+ self.coordinator = s3transfer.crt.CRTTransferCoordinator()
+ self.coordinator.set_s3_request(self.mock_s3_request)
+ self.future = s3transfer.crt.CRTTransferFuture(
+ coordinator=self.coordinator
+ )
+
+ def test_set_exception(self):
+ self.future.set_exception(CustomFutureException())
+ with self.assertRaises(CustomFutureException):
+ self.future.result()
+
+ def test_set_exception_raises_error_when_not_done(self):
+ self.mock_crt_future.done.return_value = False
+ with self.assertRaises(TransferNotDoneError):
+ self.future.set_exception(CustomFutureException())
+
+ def test_set_exception_can_override_previous_exception(self):
+ self.future.set_exception(Exception())
+ self.future.set_exception(CustomFutureException())
+ with self.assertRaises(CustomFutureException):
+ self.future.result()
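
test_load_credentials_once above asserts that the adapter resolves botocore credentials a single time and then reuses the result. A minimal sketch of that caching shape (a hypothetical class; the real CRTCredentialProviderAdapter is more involved):

    class CachingCredentialAdapter:
        def __init__(self, botocore_credential_provider):
            self._provider = botocore_credential_provider
            self._resolved = None

        def __call__(self):
            # Only the first call hits the provider, matching the
            # load_credentials.call_count == 1 assertion above.
            if self._resolved is None:
                self._resolved = self._provider.load_credentials()
            return self._resolved.get_frozen_credentials()
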
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_delete.py b/contrib/python/s3transfer/py3/tests/unit/test_delete.py
new file mode 100644
index 0000000000..23b77112f2
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_delete.py
@@ -0,0 +1,67 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.delete import DeleteObjectTask
+from __tests__ import BaseTaskTest
+
+
+class TestDeleteObjectTask(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.callbacks = []
+
+ def get_delete_task(self, **kwargs):
+ default_kwargs = {
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': self.extra_args,
+ }
+ default_kwargs.update(kwargs)
+ return self.get_task(DeleteObjectTask, main_kwargs=default_kwargs)
+
+ def test_main(self):
+ self.stubber.add_response(
+ 'delete_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ },
+ )
+ task = self.get_delete_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+
+ def test_extra_args(self):
+ self.extra_args['MFA'] = 'mfa-code'
+ self.extra_args['VersionId'] = '12345'
+ self.stubber.add_response(
+ 'delete_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ # These extra_args should be injected into the
+ # expected params for the delete_object call.
+ 'MFA': 'mfa-code',
+ 'VersionId': '12345',
+ },
+ )
+ task = self.get_delete_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_download.py b/contrib/python/s3transfer/py3/tests/unit/test_download.py
new file mode 100644
index 0000000000..e8b5fe1f86
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_download.py
@@ -0,0 +1,999 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import copy
+import os
+import shutil
+import tempfile
+from io import BytesIO
+
+from s3transfer.bandwidth import BandwidthLimiter
+from s3transfer.compat import SOCKET_ERROR
+from s3transfer.download import (
+ CompleteDownloadNOOPTask,
+ DeferQueue,
+ DownloadChunkIterator,
+ DownloadFilenameOutputManager,
+ DownloadNonSeekableOutputManager,
+ DownloadSeekableOutputManager,
+ DownloadSpecialFilenameOutputManager,
+ DownloadSubmissionTask,
+ GetObjectTask,
+ ImmediatelyWriteIOGetObjectTask,
+ IOCloseTask,
+ IORenameFileTask,
+ IOStreamingWriteTask,
+ IOWriteTask,
+)
+from s3transfer.exceptions import RetriesExceededError
+from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG, BoundedExecutor
+from s3transfer.utils import CallArgs, OSUtils
+from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ FileCreator,
+ NonSeekableWriter,
+ RecordingExecutor,
+ StreamWithError,
+ mock,
+ unittest,
+)
+
+
+class DownloadException(Exception):
+ pass
+
+
+class WriteCollector:
+ """A utility to collect information about writes and seeks"""
+
+ def __init__(self):
+ self._pos = 0
+ self.writes = []
+
+ def seek(self, pos, whence=0):
+ self._pos = pos
+
+ def write(self, data):
+ self.writes.append((self._pos, data))
+ self._pos += len(data)
+
+
+class AlwaysIndicatesSpecialFileOSUtils(OSUtils):
+ """OSUtil that always returns True for is_special_file"""
+
+ def is_special_file(self, filename):
+ return True
+
+
+class CancelledStreamWrapper:
+ """A wrapper to trigger a cancellation while stream reading
+
+ Forces the transfer coordinator to cancel after a certain number of reads
+ :param stream: The underlying stream to read from
+ :param transfer_coordinator: The coordinator for the transfer
+ :param num_reads: On which read to signal a cancellation. 0 is the first
+ read.
+ """
+
+ def __init__(self, stream, transfer_coordinator, num_reads=0):
+ self._stream = stream
+ self._transfer_coordinator = transfer_coordinator
+ self._num_reads = num_reads
+ self._count = 0
+
+ def read(self, *args, **kwargs):
+ if self._num_reads == self._count:
+ self._transfer_coordinator.cancel()
+ # The read result is discarded: the tests using this wrapper only
+ # exercise when the cancellation fires, not the data that is read.
+ self._stream.read(*args, **kwargs)
+ self._count += 1
+
+
+class BaseDownloadOutputManagerTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.osutil = OSUtils()
+
+ # Create a file to write to
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+
+ self.call_args = CallArgs(fileobj=self.filename)
+ self.future = self.get_transfer_future(self.call_args)
+ self.io_executor = BoundedExecutor(1000, 1)
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+
+
+class TestDownloadFilenameOutputManager(BaseDownloadOutputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.download_output_manager = DownloadFilenameOutputManager(
+ self.osutil,
+ self.transfer_coordinator,
+ io_executor=self.io_executor,
+ )
+
+ def test_is_compatible(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(
+ self.filename, self.osutil
+ )
+ )
+
+ def test_get_download_task_tag(self):
+ self.assertIsNone(self.download_output_manager.get_download_task_tag())
+
+ def test_get_fileobj_for_io_writes(self):
+ with self.download_output_manager.get_fileobj_for_io_writes(
+ self.future
+ ) as f:
+ # Ensure a file-like object is returned
+ self.assertTrue(hasattr(f, 'read'))
+ self.assertTrue(hasattr(f, 'seek'))
+ # Make sure the name of the file returned is not the same as the
+ # final filename as we should be writing to a temporary file.
+ self.assertNotEqual(f.name, self.filename)
+
+ def test_get_final_io_task(self):
+ ref_contents = b'my_contents'
+ with self.download_output_manager.get_fileobj_for_io_writes(
+ self.future
+ ) as f:
+ temp_filename = f.name
+ # Write some data to test that the data gets moved over to the
+ # final location.
+ f.write(ref_contents)
+ final_task = self.download_output_manager.get_final_io_task()
+ # Make sure it is the appropriate task.
+ self.assertIsInstance(final_task, IORenameFileTask)
+ final_task()
+ # Make sure the temp_file gets removed
+ self.assertFalse(os.path.exists(temp_filename))
+ # Make sure whatever was written to the temp file got moved to
+ # the final filename
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(f.read(), ref_contents)
+
+ def test_can_queue_file_io_task(self):
+ fileobj = WriteCollector()
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='foo', offset=0
+ )
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='bar', offset=3
+ )
+ self.io_executor.shutdown()
+ self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+ def test_get_file_io_write_task(self):
+ fileobj = WriteCollector()
+ io_write_task = self.download_output_manager.get_io_write_task(
+ fileobj=fileobj, data='foo', offset=3
+ )
+ self.assertIsInstance(io_write_task, IOWriteTask)
+
+ io_write_task()
+ self.assertEqual(fileobj.writes, [(3, 'foo')])
+
+
+class TestDownloadSpecialFilenameOutputManager(BaseDownloadOutputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.osutil = AlwaysIndicatesSpecialFileOSUtils()
+ self.download_output_manager = DownloadSpecialFilenameOutputManager(
+ self.osutil,
+ self.transfer_coordinator,
+ io_executor=self.io_executor,
+ )
+
+ def test_is_compatible_for_special_file(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(
+ self.filename, AlwaysIndicatesSpecialFileOSUtils()
+ )
+ )
+
+ def test_is_not_compatible_for_non_special_file(self):
+ self.assertFalse(
+ self.download_output_manager.is_compatible(
+ self.filename, OSUtils()
+ )
+ )
+
+ def test_get_fileobj_for_io_writes(self):
+ with self.download_output_manager.get_fileobj_for_io_writes(
+ self.future
+ ) as f:
+ # Ensure a file-like object is returned
+ self.assertTrue(hasattr(f, 'read'))
+ # Make sure the name of the file returned is the same as the
+ # final filename as we should not be writing to a temporary file.
+ self.assertEqual(f.name, self.filename)
+
+ def test_get_final_io_task(self):
+ self.assertIsInstance(
+ self.download_output_manager.get_final_io_task(), IOCloseTask
+ )
+
+ def test_can_queue_file_io_task(self):
+ fileobj = WriteCollector()
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='foo', offset=0
+ )
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='bar', offset=3
+ )
+ self.io_executor.shutdown()
+ self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+
+class TestDownloadSeekableOutputManager(BaseDownloadOutputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.download_output_manager = DownloadSeekableOutputManager(
+ self.osutil,
+ self.transfer_coordinator,
+ io_executor=self.io_executor,
+ )
+
+ # Create a fileobj to write to
+ self.fileobj = open(self.filename, 'wb')
+
+ self.call_args = CallArgs(fileobj=self.fileobj)
+ self.future = self.get_transfer_future(self.call_args)
+
+ def tearDown(self):
+ self.fileobj.close()
+ super().tearDown()
+
+ def test_is_compatible(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(
+ self.fileobj, self.osutil
+ )
+ )
+
+ def test_is_compatible_bytes_io(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(BytesIO(), self.osutil)
+ )
+
+ def test_not_compatible_for_non_filelike_obj(self):
+ self.assertFalse(
+ self.download_output_manager.is_compatible(object(), self.osutil)
+ )
+
+ def test_get_download_task_tag(self):
+ self.assertIsNone(self.download_output_manager.get_download_task_tag())
+
+ def test_get_fileobj_for_io_writes(self):
+ self.assertIs(
+ self.download_output_manager.get_fileobj_for_io_writes(
+ self.future
+ ),
+ self.fileobj,
+ )
+
+ def test_get_final_io_task(self):
+ self.assertIsInstance(
+ self.download_output_manager.get_final_io_task(),
+ CompleteDownloadNOOPTask,
+ )
+
+ def test_can_queue_file_io_task(self):
+ fileobj = WriteCollector()
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='foo', offset=0
+ )
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='bar', offset=3
+ )
+ self.io_executor.shutdown()
+ self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+ def test_get_file_io_write_task(self):
+ fileobj = WriteCollector()
+ io_write_task = self.download_output_manager.get_io_write_task(
+ fileobj=fileobj, data='foo', offset=3
+ )
+ self.assertIsInstance(io_write_task, IOWriteTask)
+
+ io_write_task()
+ self.assertEqual(fileobj.writes, [(3, 'foo')])
+
+
+class TestDownloadNonSeekableOutputManager(BaseDownloadOutputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.download_output_manager = DownloadNonSeekableOutputManager(
+ self.osutil, self.transfer_coordinator, io_executor=None
+ )
+
+ def test_is_compatible_with_seekable_stream(self):
+ with open(self.filename, 'wb') as f:
+ self.assertTrue(
+ self.download_output_manager.is_compatible(f, self.osutil)
+ )
+
+ def test_not_compatible_with_filename(self):
+ self.assertFalse(
+ self.download_output_manager.is_compatible(
+ self.filename, self.osutil
+ )
+ )
+
+ def test_compatible_with_non_seekable_stream(self):
+ class NonSeekable:
+ def write(self, data):
+ pass
+
+ f = NonSeekable()
+ self.assertTrue(
+ self.download_output_manager.is_compatible(f, self.osutil)
+ )
+
+ def test_is_compatible_with_bytesio(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(BytesIO(), self.osutil)
+ )
+
+ def test_get_download_task_tag(self):
+ self.assertIs(
+ self.download_output_manager.get_download_task_tag(),
+ IN_MEMORY_DOWNLOAD_TAG,
+ )
+
+ def test_submit_writes_from_internal_queue(self):
+ class FakeQueue:
+ def request_writes(self, offset, data):
+ return [
+ {'offset': 0, 'data': 'foo'},
+ {'offset': 3, 'data': 'bar'},
+ ]
+
+ q = FakeQueue()
+ io_executor = BoundedExecutor(1000, 1)
+ manager = DownloadNonSeekableOutputManager(
+ self.osutil,
+ self.transfer_coordinator,
+ io_executor=io_executor,
+ defer_queue=q,
+ )
+ fileobj = WriteCollector()
+ manager.queue_file_io_task(fileobj=fileobj, data='foo', offset=1)
+ io_executor.shutdown()
+ self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+ def test_get_file_io_write_task(self):
+ fileobj = WriteCollector()
+ io_write_task = self.download_output_manager.get_io_write_task(
+ fileobj=fileobj, data='foo', offset=1
+ )
+ self.assertIsInstance(io_write_task, IOStreamingWriteTask)
+
+ io_write_task()
+ self.assertEqual(fileobj.writes, [(0, 'foo')])
+
+
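# A minimal sketch of the reordering idea that FakeQueue stands in for
# in test_submit_writes_from_internal_queue above (illustrative only;
# s3transfer.download.DeferQueue is the real implementation). Writes
# can arrive out of order, but a non-seekable stream must receive them
# strictly by offset, so later chunks are buffered until the gap in
# front of them is filled.
class MiniDeferQueue:
    def __init__(self):
        self._next_offset = 0
        self._pending = {}  # offset -> data not yet writable

    def request_writes(self, offset, data):
        if offset < self._next_offset:
            return []  # retried data that was already written
        self._pending.setdefault(offset, data)
        ready = []
        while self._next_offset in self._pending:
            chunk = self._pending.pop(self._next_offset)
            ready.append({'offset': self._next_offset, 'data': chunk})
            self._next_offset += len(chunk)
        return ready

# MiniDeferQueue().request_writes(3, 'bar') returns [] because offset 0
# has not been seen; a follow-up request_writes(0, 'foo') then releases
# both writes in offset order.
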
+class TestDownloadSubmissionTask(BaseSubmissionTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.subscribers = []
+
+ # Create a stream to read from
+ self.content = b'my content'
+ self.stream = BytesIO(self.content)
+
+ self.call_args = self.get_call_args()
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.io_executor = BoundedExecutor(1000, 1)
+ self.submission_main_kwargs = {
+ 'client': self.client,
+ 'config': self.config,
+ 'osutil': self.osutil,
+ 'request_executor': self.executor,
+ 'io_executor': self.io_executor,
+ 'transfer_future': self.transfer_future,
+ }
+ self.submission_task = self.get_download_submission_task()
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+
+ def get_call_args(self, **kwargs):
+ default_call_args = {
+ 'fileobj': self.filename,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': self.extra_args,
+ 'subscribers': self.subscribers,
+ }
+ default_call_args.update(kwargs)
+ return CallArgs(**default_call_args)
+
+ def wrap_executor_in_recorder(self):
+ self.executor = RecordingExecutor(self.executor)
+ self.submission_main_kwargs['request_executor'] = self.executor
+
+ def use_fileobj_in_call_args(self, fileobj):
+ self.call_args = self.get_call_args(fileobj=fileobj)
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.submission_main_kwargs['transfer_future'] = self.transfer_future
+
+ def assert_tag_for_get_object(self, tag_value):
+ submissions_to_compare = self.executor.submissions
+ if len(submissions_to_compare) > 1:
+ # If it was a ranged get, make sure we do not include the join task.
+ submissions_to_compare = submissions_to_compare[:-1]
+ for submission in submissions_to_compare:
+ self.assertEqual(submission['tag'], tag_value)
+
+ def add_head_object_response(self):
+ self.stubber.add_response(
+ 'head_object', {'ContentLength': len(self.content)}
+ )
+
+ def add_get_responses(self):
+ chunksize = self.config.multipart_chunksize
+ for i in range(0, len(self.content), chunksize):
+ if i + chunksize > len(self.content):
+ stream = BytesIO(self.content[i:])
+ self.stubber.add_response('get_object', {'Body': stream})
+ else:
+ stream = BytesIO(self.content[i : i + chunksize])
+ self.stubber.add_response('get_object', {'Body': stream})
+
+ def configure_for_ranged_get(self):
+ self.config.multipart_threshold = 1
+ self.config.multipart_chunksize = 4
+
+ def get_download_submission_task(self):
+ return self.get_task(
+ DownloadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+
+ def wait_and_assert_completed_successfully(self, submission_task):
+ submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_submits_no_tag_for_get_object_filename(self):
+ self.wrap_executor_in_recorder()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+ # Make sure that no tag to limit the task was associated with
+ # the task submission.
+ self.assert_tag_for_get_object(None)
+
+ def test_submits_no_tag_for_ranged_get_filename(self):
+ self.wrap_executor_in_recorder()
+ self.configure_for_ranged_get()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+ # Make sure that no tag to limit the task was associated with
+ # the task submission.
+ self.assert_tag_for_get_object(None)
+
+ def test_submits_no_tag_for_get_object_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ with open(self.filename, 'wb') as f:
+ self.use_fileobj_in_call_args(f)
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+ # Make sure that no tag to limit the task was associated with
+ # the task submission.
+ self.assert_tag_for_get_object(None)
+
+ def test_submits_no_tag_for_ranged_get_object_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.configure_for_ranged_get()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ with open(self.filename, 'wb') as f:
+ self.use_fileobj_in_call_args(f)
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+ # Make sure that no tag to limit the task was associated with
+ # the task submission.
+ self.assert_tag_for_get_object(None)
+
+ def test_submits_tag_for_get_object_nonseekable_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ with open(self.filename, 'wb') as f:
+ self.use_fileobj_in_call_args(NonSeekableWriter(f))
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+ # Make sure the IN_MEMORY_DOWNLOAD_TAG was associated with
+ # the task submission.
+ self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG)
+
+ def test_submits_tag_for_ranged_get_object_nonseekable_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.configure_for_ranged_get()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ with open(self.filename, 'wb') as f:
+ self.use_fileobj_in_call_args(NonSeekableWriter(f))
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+ # Make sure the IN_MEMORY_DOWNLOAD_TAG was associated with
+ # the task submission.
+ self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG)
+
+
+class TestGetObjectTask(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.callbacks = []
+ self.max_attempts = 5
+ self.io_executor = BoundedExecutor(1000, 1)
+ self.content = b'my content'
+ self.stream = BytesIO(self.content)
+ self.fileobj = WriteCollector()
+ self.osutil = OSUtils()
+ self.io_chunksize = 64 * (1024**2)
+ self.task_cls = GetObjectTask
+ self.download_output_manager = DownloadSeekableOutputManager(
+ self.osutil, self.transfer_coordinator, self.io_executor
+ )
+
+ def get_download_task(self, **kwargs):
+ default_kwargs = {
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'fileobj': self.fileobj,
+ 'extra_args': self.extra_args,
+ 'callbacks': self.callbacks,
+ 'max_attempts': self.max_attempts,
+ 'download_output_manager': self.download_output_manager,
+ 'io_chunksize': self.io_chunksize,
+ }
+ default_kwargs.update(kwargs)
+ self.transfer_coordinator.set_status_to_queued()
+ return self.get_task(self.task_cls, main_kwargs=default_kwargs)
+
+ def assert_io_writes(self, expected_writes):
+ # Let the io executor process all of the writes before checking
+ # what writes were sent to it.
+ self.io_executor.shutdown()
+ self.assertEqual(self.fileobj.writes, expected_writes)
+
+ def test_main(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assert_io_writes([(0, self.content)])
+
+ def test_extra_args(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'Range': 'bytes=0-',
+ },
+ )
+ self.extra_args['Range'] = 'bytes=0-'
+ task = self.get_download_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assert_io_writes([(0, self.content)])
+
+ def test_control_chunk_size(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task(io_chunksize=1)
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ expected_contents = []
+ for i in range(len(self.content)):
+ expected_contents.append((i, bytes(self.content[i : i + 1])))
+
+ self.assert_io_writes(expected_contents)
+
+ def test_start_index(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task(start_index=5)
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assert_io_writes([(5, self.content)])
+
+ def test_uses_bandwidth_limiter(self):
+ bandwidth_limiter = mock.Mock(BandwidthLimiter)
+
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task(bandwidth_limiter=bandwidth_limiter)
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(
+ bandwidth_limiter.get_bandwith_limited_stream.call_args_list,
+ [mock.call(mock.ANY, self.transfer_coordinator)],
+ )
+
+ def test_retries_succeeds(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={
+ 'Body': StreamWithError(self.stream, SOCKET_ERROR)
+ },
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task()
+ task()
+
+        # The retryable error should not have affected the bytes placed
+        # into the io queue.
+ self.stubber.assert_no_pending_responses()
+ self.assert_io_writes([(0, self.content)])
+
+ def test_retries_failure(self):
+ for _ in range(self.max_attempts):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={
+ 'Body': StreamWithError(self.stream, SOCKET_ERROR)
+ },
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+
+ task = self.get_download_task()
+ task()
+ self.transfer_coordinator.announce_done()
+
+ # Should have failed out on a RetriesExceededError
+ with self.assertRaises(RetriesExceededError):
+ self.transfer_coordinator.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_retries_in_middle_of_streaming(self):
+ # After the first read a retryable error will be thrown
+ self.stubber.add_response(
+ 'get_object',
+ service_response={
+ 'Body': StreamWithError(
+ copy.deepcopy(self.stream), SOCKET_ERROR, 1
+ )
+ },
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task(io_chunksize=1)
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ expected_contents = []
+ # This is the content initially read in before the retry hit on the
+ # second read()
+ expected_contents.append((0, bytes(self.content[0:1])))
+
+ # The rest of the content should be the entire set of data partitioned
+ # out based on the one byte stream chunk size. Note the second
+ # element in the list should be a copy of the first element since
+ # a retryable exception happened in between.
+ for i in range(len(self.content)):
+ expected_contents.append((i, bytes(self.content[i : i + 1])))
+ self.assert_io_writes(expected_contents)
+
+ def test_cancels_out_of_queueing(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={
+ 'Body': CancelledStreamWrapper(
+ self.stream, self.transfer_coordinator
+ )
+ },
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ # Make sure that no contents were added to the queue because the task
+ # should have been canceled before trying to add the contents to the
+ # io queue.
+ self.assert_io_writes([])
+
+ def test_handles_callback_on_initial_error(self):
+        # We can't use the stubber for this because we need to raise one
+        # of the S3_RETRYABLE_DOWNLOAD_ERRORS, and the stubber only allows
+        # you to raise a ClientError.
+ self.client.get_object = mock.Mock(side_effect=SOCKET_ERROR())
+ task = self.get_download_task()
+ task()
+ self.transfer_coordinator.announce_done()
+ # Should have failed out on a RetriesExceededError because
+ # get_object keeps raising a socket error.
+ with self.assertRaises(RetriesExceededError):
+ self.transfer_coordinator.result()
+
+
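+# The retry tests above pin down GetObjectTask's contract: the response body
+# is read in io_chunksize pieces, each piece is handed off with its absolute
+# offset (start_index plus the bytes already read), and a retryable streaming
+# error restarts the whole GET, up to max_attempts, before surfacing
+# RetriesExceededError. Below is a minimal sketch of that loop under those
+# assumptions; stream_object and write are illustrative names, not the
+# actual s3transfer implementation.
+def stream_object(client, bucket, key, write, start_index=0,
+                  io_chunksize=64 * 1024, max_attempts=5):
+    from s3transfer.exceptions import RetriesExceededError
+    from s3transfer.utils import S3_RETRYABLE_DOWNLOAD_ERRORS
+
+    last_exception = None
+    for _ in range(max_attempts):
+        try:
+            body = client.get_object(Bucket=bucket, Key=key)['Body']
+            index = start_index
+            for chunk in iter(lambda: body.read(io_chunksize), b''):
+                write(index, chunk)  # a retry may rewrite earlier offsets
+                index += len(chunk)
+            return
+        except S3_RETRYABLE_DOWNLOAD_ERRORS as e:
+            last_exception = e
+    raise RetriesExceededError(last_exception)
+
+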
+class TestImmediatelyWriteIOGetObjectTask(TestGetObjectTask):
+ def setUp(self):
+ super().setUp()
+ self.task_cls = ImmediatelyWriteIOGetObjectTask
+        # When data is written out, it should not use the io executor at
+        # all. If it does use the io executor, that is a deviation from
+        # expected behavior, as the data should be written immediately to
+        # the file object once downloaded.
+ self.io_executor = None
+ self.download_output_manager = DownloadSeekableOutputManager(
+ self.osutil, self.transfer_coordinator, self.io_executor
+ )
+
+ def assert_io_writes(self, expected_writes):
+ self.assertEqual(self.fileobj.writes, expected_writes)
+
+
+class BaseIOTaskTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.files = FileCreator()
+ self.osutil = OSUtils()
+ self.temp_filename = os.path.join(self.files.rootdir, 'mytempfile')
+ self.final_filename = os.path.join(self.files.rootdir, 'myfile')
+
+ def tearDown(self):
+ super().tearDown()
+ self.files.remove_all()
+
+
+class TestIOStreamingWriteTask(BaseIOTaskTest):
+ def test_main(self):
+ with open(self.temp_filename, 'wb') as f:
+ task = self.get_task(
+ IOStreamingWriteTask,
+ main_kwargs={'fileobj': f, 'data': b'foobar'},
+ )
+ task()
+ task2 = self.get_task(
+ IOStreamingWriteTask,
+ main_kwargs={'fileobj': f, 'data': b'baz'},
+ )
+ task2()
+ with open(self.temp_filename, 'rb') as f:
+ # We should just have written to the file in the order
+ # the tasks were executed.
+ self.assertEqual(f.read(), b'foobarbaz')
+
+
+class TestIOWriteTask(BaseIOTaskTest):
+ def test_main(self):
+ with open(self.temp_filename, 'wb') as f:
+ # Write once to the file
+ task = self.get_task(
+ IOWriteTask,
+ main_kwargs={'fileobj': f, 'data': b'foo', 'offset': 0},
+ )
+ task()
+
+ # Write again to the file
+ task = self.get_task(
+ IOWriteTask,
+ main_kwargs={'fileobj': f, 'data': b'bar', 'offset': 3},
+ )
+ task()
+
+ with open(self.temp_filename, 'rb') as f:
+ self.assertEqual(f.read(), b'foobar')
+
+
+class TestIORenameFileTask(BaseIOTaskTest):
+ def test_main(self):
+ with open(self.temp_filename, 'wb') as f:
+ task = self.get_task(
+ IORenameFileTask,
+ main_kwargs={
+ 'fileobj': f,
+ 'final_filename': self.final_filename,
+ 'osutil': self.osutil,
+ },
+ )
+ task()
+ self.assertTrue(os.path.exists(self.final_filename))
+ self.assertFalse(os.path.exists(self.temp_filename))
+
+
+class TestIOCloseTask(BaseIOTaskTest):
+ def test_main(self):
+ with open(self.temp_filename, 'w') as f:
+ task = self.get_task(IOCloseTask, main_kwargs={'fileobj': f})
+ task()
+ self.assertTrue(f.closed)
+
+
+class TestDownloadChunkIterator(unittest.TestCase):
+ def test_iter(self):
+ content = b'my content'
+ body = BytesIO(content)
+ ref_chunks = []
+ for chunk in DownloadChunkIterator(body, len(content)):
+ ref_chunks.append(chunk)
+ self.assertEqual(ref_chunks, [b'my content'])
+
+ def test_iter_chunksize(self):
+ content = b'1234'
+ body = BytesIO(content)
+ ref_chunks = []
+ for chunk in DownloadChunkIterator(body, 3):
+ ref_chunks.append(chunk)
+ self.assertEqual(ref_chunks, [b'123', b'4'])
+
+ def test_empty_content(self):
+ body = BytesIO(b'')
+ ref_chunks = []
+ for chunk in DownloadChunkIterator(body, 3):
+ ref_chunks.append(chunk)
+ self.assertEqual(ref_chunks, [b''])
+
+
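+# The iterator contract verified above: yield successive reads of at most
+# chunksize bytes, and always yield at least one (possibly empty) chunk so
+# that a zero-byte object still triggers a write. A minimal generator with
+# the same behavior, offered as a sketch rather than the s3transfer
+# implementation:
+def iter_chunks(body, chunksize):
+    yielded_any = False
+    while True:
+        chunk = body.read(chunksize)
+        if chunk:
+            yielded_any = True
+            yield chunk
+        else:
+            if not yielded_any:
+                yield chunk  # preserve the single empty chunk for b''
+            return
+
+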
+class TestDeferQueue(unittest.TestCase):
+ def setUp(self):
+ self.q = DeferQueue()
+
+ def test_no_writes_when_not_lowest_block(self):
+ writes = self.q.request_writes(offset=1, data='bar')
+ self.assertEqual(writes, [])
+
+ def test_writes_returned_in_order(self):
+ self.assertEqual(self.q.request_writes(offset=3, data='d'), [])
+ self.assertEqual(self.q.request_writes(offset=2, data='c'), [])
+ self.assertEqual(self.q.request_writes(offset=1, data='b'), [])
+
+ # Everything at this point has been deferred, but as soon as we
+ # send offset=0, that will unlock offsets 0-3.
+ writes = self.q.request_writes(offset=0, data='a')
+ self.assertEqual(
+ writes,
+ [
+ {'offset': 0, 'data': 'a'},
+ {'offset': 1, 'data': 'b'},
+ {'offset': 2, 'data': 'c'},
+ {'offset': 3, 'data': 'd'},
+ ],
+ )
+
+ def test_unlocks_partial_range(self):
+ self.assertEqual(self.q.request_writes(offset=5, data='f'), [])
+ self.assertEqual(self.q.request_writes(offset=1, data='b'), [])
+
+ # offset=0 unlocks 0-1, but offset=5 still needs to see 2-4 first.
+ writes = self.q.request_writes(offset=0, data='a')
+ self.assertEqual(
+ writes,
+ [
+ {'offset': 0, 'data': 'a'},
+ {'offset': 1, 'data': 'b'},
+ ],
+ )
+
+ def test_data_can_be_any_size(self):
+ self.q.request_writes(offset=5, data='hello world')
+ writes = self.q.request_writes(offset=0, data='abcde')
+ self.assertEqual(
+ writes,
+ [
+ {'offset': 0, 'data': 'abcde'},
+ {'offset': 5, 'data': 'hello world'},
+ ],
+ )
+
+ def test_data_queued_in_order(self):
+ # This immediately gets returned because offset=0 is the
+ # next range we're waiting on.
+ writes = self.q.request_writes(offset=0, data='hello world')
+ self.assertEqual(writes, [{'offset': 0, 'data': 'hello world'}])
+        # Same thing here, but at the next expected offset (11).
+ writes = self.q.request_writes(offset=11, data='hello again')
+ self.assertEqual(writes, [{'offset': 11, 'data': 'hello again'}])
+
+ def test_writes_below_min_offset_are_ignored(self):
+ self.q.request_writes(offset=0, data='a')
+ self.q.request_writes(offset=1, data='b')
+ self.q.request_writes(offset=2, data='c')
+
+ # At this point we're expecting offset=3, so if a write
+ # comes in below 3, we ignore it.
+ self.assertEqual(self.q.request_writes(offset=0, data='a'), [])
+ self.assertEqual(self.q.request_writes(offset=1, data='b'), [])
+
+ self.assertEqual(
+ self.q.request_writes(offset=3, data='d'),
+ [{'offset': 3, 'data': 'd'}],
+ )
+
+ def test_duplicate_writes_are_ignored(self):
+ self.q.request_writes(offset=2, data='c')
+ self.q.request_writes(offset=1, data='b')
+
+ # We're still waiting for offset=0, but if
+ # a duplicate write comes in for offset=2/offset=1
+ # it's ignored. This gives "first one wins" behavior.
+ self.assertEqual(self.q.request_writes(offset=2, data='X'), [])
+ self.assertEqual(self.q.request_writes(offset=1, data='Y'), [])
+
+ self.assertEqual(
+ self.q.request_writes(offset=0, data='a'),
+ [
+ {'offset': 0, 'data': 'a'},
+                # Note we're seeing 'b', 'c' and not 'X', 'Y'.
+ {'offset': 1, 'data': 'b'},
+ {'offset': 2, 'data': 'c'},
+ ],
+ )
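+
+
+# Taken together, these tests specify DeferQueue: writes at or past the next
+# expected offset are buffered, everything contiguous from that offset is
+# released in order, and duplicate or already-written offsets are dropped
+# ("first one wins"). A compact re-implementation of just that contract, as
+# an illustration rather than the s3transfer class:
+class SimpleDeferQueue:
+    def __init__(self):
+        self._next_offset = 0
+        self._pending = {}
+
+    def request_writes(self, offset, data):
+        if offset < self._next_offset or offset in self._pending:
+            # Below the watermark or a duplicate: ignore it.
+            return []
+        self._pending[offset] = data
+        writes = []
+        while self._next_offset in self._pending:
+            next_data = self._pending.pop(self._next_offset)
+            writes.append({'offset': self._next_offset, 'data': next_data})
+            self._next_offset += len(next_data)
+        return writes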
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_futures.py b/contrib/python/s3transfer/py3/tests/unit/test_futures.py
new file mode 100644
index 0000000000..ca2888a654
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_futures.py
@@ -0,0 +1,696 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import sys
+import time
+import traceback
+from concurrent.futures import ThreadPoolExecutor
+
+from s3transfer.exceptions import (
+ CancelledError,
+ FatalError,
+ TransferNotDoneError,
+)
+from s3transfer.futures import (
+ BaseExecutor,
+ BoundedExecutor,
+ ExecutorFuture,
+ NonThreadedExecutor,
+ NonThreadedExecutorFuture,
+ TransferCoordinator,
+ TransferFuture,
+ TransferMeta,
+)
+from s3transfer.tasks import Task
+from s3transfer.utils import (
+ FunctionContainer,
+ NoResourcesAvailable,
+ TaskSemaphore,
+)
+from __tests__ import (
+ RecordingExecutor,
+ TransferCoordinatorWithInterrupt,
+ mock,
+ unittest,
+)
+
+
+def return_call_args(*args, **kwargs):
+ return args, kwargs
+
+
+def raise_exception(exception):
+ raise exception
+
+
+def get_exc_info(exception):
+ try:
+ raise_exception(exception)
+ except Exception:
+ return sys.exc_info()
+
+
+class RecordingTransferCoordinator(TransferCoordinator):
+ def __init__(self):
+ self.all_transfer_futures_ever_associated = set()
+ super().__init__()
+
+ def add_associated_future(self, future):
+ self.all_transfer_futures_ever_associated.add(future)
+ super().add_associated_future(future)
+
+
+class ReturnFooTask(Task):
+ def _main(self, **kwargs):
+ return 'foo'
+
+
+class SleepTask(Task):
+ def _main(self, sleep_time, **kwargs):
+ time.sleep(sleep_time)
+
+
+class TestTransferFuture(unittest.TestCase):
+ def setUp(self):
+ self.meta = TransferMeta()
+ self.coordinator = TransferCoordinator()
+ self.future = self._get_transfer_future()
+
+ def _get_transfer_future(self, **kwargs):
+ components = {
+ 'meta': self.meta,
+ 'coordinator': self.coordinator,
+ }
+ for component_name, component in kwargs.items():
+ components[component_name] = component
+ return TransferFuture(**components)
+
+ def test_meta(self):
+ self.assertIs(self.future.meta, self.meta)
+
+ def test_done(self):
+ self.assertFalse(self.future.done())
+ self.coordinator.set_result(None)
+ self.assertTrue(self.future.done())
+
+ def test_result(self):
+ result = 'foo'
+ self.coordinator.set_result(result)
+ self.coordinator.announce_done()
+ self.assertEqual(self.future.result(), result)
+
+ def test_keyboard_interrupt_on_result_does_not_block(self):
+ # This should raise a KeyboardInterrupt when result is called on it.
+ self.coordinator = TransferCoordinatorWithInterrupt()
+ self.future = self._get_transfer_future()
+
+ # result() should not block and immediately raise the keyboard
+ # interrupt exception.
+ with self.assertRaises(KeyboardInterrupt):
+ self.future.result()
+
+ def test_cancel(self):
+ self.future.cancel()
+ self.assertTrue(self.future.done())
+ self.assertEqual(self.coordinator.status, 'cancelled')
+
+ def test_set_exception(self):
+ # Set the result such that there is no exception
+ self.coordinator.set_result('result')
+ self.coordinator.announce_done()
+ self.assertEqual(self.future.result(), 'result')
+
+ self.future.set_exception(ValueError())
+ with self.assertRaises(ValueError):
+ self.future.result()
+
+ def test_set_exception_only_after_done(self):
+ with self.assertRaises(TransferNotDoneError):
+ self.future.set_exception(ValueError())
+
+ self.coordinator.set_result('result')
+ self.coordinator.announce_done()
+ self.future.set_exception(ValueError())
+ with self.assertRaises(ValueError):
+ self.future.result()
+
+
+class TestTransferMeta(unittest.TestCase):
+ def setUp(self):
+ self.transfer_meta = TransferMeta()
+
+ def test_size(self):
+ self.assertEqual(self.transfer_meta.size, None)
+ self.transfer_meta.provide_transfer_size(5)
+ self.assertEqual(self.transfer_meta.size, 5)
+
+ def test_call_args(self):
+ call_args = object()
+ transfer_meta = TransferMeta(call_args)
+        # Assert that the call args provided are the same as those returned.
+ self.assertIs(transfer_meta.call_args, call_args)
+
+ def test_transfer_id(self):
+ transfer_meta = TransferMeta(transfer_id=1)
+ self.assertEqual(transfer_meta.transfer_id, 1)
+
+ def test_user_context(self):
+ self.transfer_meta.user_context['foo'] = 'bar'
+ self.assertEqual(self.transfer_meta.user_context, {'foo': 'bar'})
+
+
+class TestTransferCoordinator(unittest.TestCase):
+ def setUp(self):
+ self.transfer_coordinator = TransferCoordinator()
+
+ def test_transfer_id(self):
+ transfer_coordinator = TransferCoordinator(transfer_id=1)
+ self.assertEqual(transfer_coordinator.transfer_id, 1)
+
+ def test_repr(self):
+ transfer_coordinator = TransferCoordinator(transfer_id=1)
+ self.assertEqual(
+ repr(transfer_coordinator), 'TransferCoordinator(transfer_id=1)'
+ )
+
+ def test_initial_status(self):
+ # A TransferCoordinator with no progress should have the status
+ # of not-started
+ self.assertEqual(self.transfer_coordinator.status, 'not-started')
+
+ def test_set_status_to_queued(self):
+ self.transfer_coordinator.set_status_to_queued()
+ self.assertEqual(self.transfer_coordinator.status, 'queued')
+
+ def test_cannot_set_status_to_queued_from_done_state(self):
+ self.transfer_coordinator.set_exception(RuntimeError)
+ with self.assertRaises(RuntimeError):
+ self.transfer_coordinator.set_status_to_queued()
+
+ def test_status_running(self):
+ self.transfer_coordinator.set_status_to_running()
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ def test_cannot_set_status_to_running_from_done_state(self):
+ self.transfer_coordinator.set_exception(RuntimeError)
+ with self.assertRaises(RuntimeError):
+ self.transfer_coordinator.set_status_to_running()
+
+ def test_set_result(self):
+ success_result = 'foo'
+ self.transfer_coordinator.set_result(success_result)
+ self.transfer_coordinator.announce_done()
+ # Setting result should result in a success state and the return value
+ # that was set.
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+ self.assertEqual(self.transfer_coordinator.result(), success_result)
+
+ def test_set_exception(self):
+ exception_result = RuntimeError
+ self.transfer_coordinator.set_exception(exception_result)
+ self.transfer_coordinator.announce_done()
+        # Setting an exception should result in a failed state, and
+        # result() should raise the exception that was set.
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+ self.assertEqual(self.transfer_coordinator.exception, exception_result)
+ with self.assertRaises(exception_result):
+ self.transfer_coordinator.result()
+
+ def test_exception_cannot_override_done_state(self):
+ self.transfer_coordinator.set_result('foo')
+ self.transfer_coordinator.set_exception(RuntimeError)
+        # Its status should be success even after the exception is set
+        # because success is a done state.
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+
+ def test_exception_can_override_done_state_with_override_flag(self):
+ self.transfer_coordinator.set_result('foo')
+ self.transfer_coordinator.set_exception(RuntimeError, override=True)
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ def test_cancel(self):
+ self.assertEqual(self.transfer_coordinator.status, 'not-started')
+ self.transfer_coordinator.cancel()
+        # This should set the state to cancelled, raise the CancelledError
+        # exception from result(), and set the done event so that result()
+        # no longer blocks.
+ self.assertEqual(self.transfer_coordinator.status, 'cancelled')
+ with self.assertRaises(CancelledError):
+ self.transfer_coordinator.result()
+
+ def test_cancel_can_run_done_callbacks_that_uses_result(self):
+ exceptions = []
+
+ def capture_exception(transfer_coordinator, captured_exceptions):
+ try:
+ transfer_coordinator.result()
+ except Exception as e:
+ captured_exceptions.append(e)
+
+ self.assertEqual(self.transfer_coordinator.status, 'not-started')
+ self.transfer_coordinator.add_done_callback(
+ capture_exception, self.transfer_coordinator, exceptions
+ )
+ self.transfer_coordinator.cancel()
+
+ self.assertEqual(len(exceptions), 1)
+ self.assertIsInstance(exceptions[0], CancelledError)
+
+ def test_cancel_with_message(self):
+ message = 'my message'
+ self.transfer_coordinator.cancel(message)
+ self.transfer_coordinator.announce_done()
+ with self.assertRaisesRegex(CancelledError, message):
+ self.transfer_coordinator.result()
+
+ def test_cancel_with_provided_exception(self):
+ message = 'my message'
+ self.transfer_coordinator.cancel(message, exc_type=FatalError)
+ self.transfer_coordinator.announce_done()
+ with self.assertRaisesRegex(FatalError, message):
+ self.transfer_coordinator.result()
+
+ def test_cancel_cannot_override_done_state(self):
+ self.transfer_coordinator.set_result('foo')
+ self.transfer_coordinator.cancel()
+        # Its status should be success even after cancel is called because
+        # success is a done state.
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+
+ def test_set_result_can_override_cancel(self):
+ self.transfer_coordinator.cancel()
+ # Result setting should override any cancel or set exception as this
+ # is always invoked by the final task.
+ self.transfer_coordinator.set_result('foo')
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+
+ def test_submit(self):
+ # Submit a callable to the transfer coordinator. It should submit it
+ # to the executor.
+ executor = RecordingExecutor(
+ BoundedExecutor(1, 1, {'my-tag': TaskSemaphore(1)})
+ )
+ task = ReturnFooTask(self.transfer_coordinator)
+ future = self.transfer_coordinator.submit(executor, task, tag='my-tag')
+ executor.shutdown()
+        # Make sure the future got submitted and executed, by checking the
+        # recorded submission (including the provided tag) and the future's
+        # result value.
+ self.assertEqual(
+ executor.submissions,
+ [{'block': True, 'tag': 'my-tag', 'task': task}],
+ )
+ self.assertEqual(future.result(), 'foo')
+
+ def test_association_and_disassociation_on_submit(self):
+ self.transfer_coordinator = RecordingTransferCoordinator()
+
+ # Submit a callable to the transfer coordinator.
+ executor = BoundedExecutor(1, 1)
+ task = ReturnFooTask(self.transfer_coordinator)
+ future = self.transfer_coordinator.submit(executor, task)
+ executor.shutdown()
+
+ # Make sure the future that got submitted was associated to the
+ # transfer future at some point.
+ self.assertEqual(
+ self.transfer_coordinator.all_transfer_futures_ever_associated,
+ {future},
+ )
+
+ # Make sure the future got disassociated once the future is now done
+ # by looking at the currently associated futures.
+ self.assertEqual(self.transfer_coordinator.associated_futures, set())
+
+ def test_done(self):
+        # These should result in a not-done state:
+        # not-started
+ self.assertFalse(self.transfer_coordinator.done())
+ # running
+ self.transfer_coordinator.set_status_to_running()
+ self.assertFalse(self.transfer_coordinator.done())
+
+        # These should result in a done state:
+ # failed
+ self.transfer_coordinator.set_exception(Exception)
+ self.assertTrue(self.transfer_coordinator.done())
+
+ # success
+ self.transfer_coordinator.set_result('foo')
+ self.assertTrue(self.transfer_coordinator.done())
+
+ # cancelled
+ self.transfer_coordinator.cancel()
+ self.assertTrue(self.transfer_coordinator.done())
+
+ def test_result_waits_until_done(self):
+ execution_order = []
+
+ def sleep_then_set_result(transfer_coordinator, execution_order):
+ time.sleep(0.05)
+ execution_order.append('setting_result')
+ transfer_coordinator.set_result(None)
+ self.transfer_coordinator.announce_done()
+
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ executor.submit(
+ sleep_then_set_result,
+ self.transfer_coordinator,
+ execution_order,
+ )
+ self.transfer_coordinator.result()
+ execution_order.append('after_result')
+
+ # The result() call should have waited until the other thread set
+ # the result after sleeping for 0.05 seconds.
+        self.assertEqual(execution_order, ['setting_result', 'after_result'])
+
+ def test_failure_cleanups(self):
+ args = (1, 2)
+ kwargs = {'foo': 'bar'}
+
+ second_args = (2, 4)
+ second_kwargs = {'biz': 'baz'}
+
+ self.transfer_coordinator.add_failure_cleanup(
+ return_call_args, *args, **kwargs
+ )
+ self.transfer_coordinator.add_failure_cleanup(
+ return_call_args, *second_args, **second_kwargs
+ )
+
+ # Ensure the callbacks got added.
+ self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 2)
+
+ result_list = []
+ # Ensure they will get called in the correct order.
+ for cleanup in self.transfer_coordinator.failure_cleanups:
+ result_list.append(cleanup())
+ self.assertEqual(
+ result_list, [(args, kwargs), (second_args, second_kwargs)]
+ )
+
+ def test_associated_futures(self):
+ first_future = object()
+ # Associate one future to the transfer
+ self.transfer_coordinator.add_associated_future(first_future)
+ associated_futures = self.transfer_coordinator.associated_futures
+ # The first future should be in the returned list of futures.
+ self.assertEqual(associated_futures, {first_future})
+
+ second_future = object()
+ # Associate another future to the transfer.
+ self.transfer_coordinator.add_associated_future(second_future)
+ # The association should not have mutated the returned list from
+ # before.
+ self.assertEqual(associated_futures, {first_future})
+
+ # Both futures should be in the returned list.
+ self.assertEqual(
+ self.transfer_coordinator.associated_futures,
+ {first_future, second_future},
+ )
+
+ def test_done_callbacks_on_done(self):
+ done_callback_invocations = []
+ callback = FunctionContainer(
+ done_callback_invocations.append, 'done callback called'
+ )
+
+ # Add the done callback to the transfer.
+ self.transfer_coordinator.add_done_callback(callback)
+
+ # Announce that the transfer is done. This should invoke the done
+ # callback.
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(done_callback_invocations, ['done callback called'])
+
+ # If done is announced again, we should not invoke the callback again
+ # because done has already been announced and thus the callback has
+        # been run as well.
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(done_callback_invocations, ['done callback called'])
+
+ def test_failure_cleanups_on_done(self):
+ cleanup_invocations = []
+ callback = FunctionContainer(
+ cleanup_invocations.append, 'cleanup called'
+ )
+
+ # Add the failure cleanup to the transfer.
+ self.transfer_coordinator.add_failure_cleanup(callback)
+
+ # Announce that the transfer is done. This should invoke the failure
+ # cleanup.
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(cleanup_invocations, ['cleanup called'])
+
+ # If done is announced again, we should not invoke the cleanup again
+ # because done has already been announced and thus the cleanup has
+        # been run as well.
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(cleanup_invocations, ['cleanup called'])
+
+
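+# Taken together, the tests above trace the coordinator lifecycle:
+# not-started -> queued -> running -> success/failed/cancelled, with
+# result() blocking until announce_done() fires the done callbacks and
+# failure cleanups exactly once. A short usage sketch of the happy path,
+# relying only on the calls exercised in these tests:
+def demo_coordinator_lifecycle():
+    coordinator = TransferCoordinator(transfer_id=42)
+    events = []
+    coordinator.add_done_callback(events.append, 'done')
+    coordinator.set_status_to_queued()
+    coordinator.set_status_to_running()
+    coordinator.set_result('payload')
+    coordinator.announce_done()  # runs the done callback exactly once
+    assert coordinator.status == 'success'
+    assert coordinator.result() == 'payload'
+    assert events == ['done']
+
+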
+class TestBoundedExecutor(unittest.TestCase):
+ def setUp(self):
+ self.coordinator = TransferCoordinator()
+ self.tag_semaphores = {}
+ self.executor = self.get_executor()
+
+ def get_executor(self, max_size=1, max_num_threads=1):
+ return BoundedExecutor(max_size, max_num_threads, self.tag_semaphores)
+
+ def get_task(self, task_cls, main_kwargs=None):
+ return task_cls(self.coordinator, main_kwargs=main_kwargs)
+
+ def get_sleep_task(self, sleep_time=0.01):
+ return self.get_task(SleepTask, main_kwargs={'sleep_time': sleep_time})
+
+ def add_semaphore(self, task_tag, count):
+ self.tag_semaphores[task_tag] = TaskSemaphore(count)
+
+ def assert_submit_would_block(self, task, tag=None):
+ with self.assertRaises(NoResourcesAvailable):
+ self.executor.submit(task, tag=tag, block=False)
+
+ def assert_submit_would_not_block(self, task, tag=None, **kwargs):
+ try:
+ self.executor.submit(task, tag=tag, block=False)
+ except NoResourcesAvailable:
+ self.fail(
+ 'Task {} should not have been blocked. Caused by:\n{}'.format(
+ task, traceback.format_exc()
+ )
+ )
+
+ def add_done_callback_to_future(self, future, fn, *args, **kwargs):
+ callback_for_future = FunctionContainer(fn, *args, **kwargs)
+ future.add_done_callback(callback_for_future)
+
+ def test_submit_single_task(self):
+ # Ensure we can submit a task to the executor
+ task = self.get_task(ReturnFooTask)
+ future = self.executor.submit(task)
+
+ # Ensure what we get back is a Future
+ self.assertIsInstance(future, ExecutorFuture)
+ # Ensure the callable got executed.
+ self.assertEqual(future.result(), 'foo')
+
+ @unittest.skipIf(
+ os.environ.get('USE_SERIAL_EXECUTOR'),
+ "Not supported with serial executor tests",
+ )
+ def test_executor_blocks_on_full_capacity(self):
+ first_task = self.get_sleep_task()
+ second_task = self.get_sleep_task()
+ self.executor.submit(first_task)
+        # The first task should be sleeping for a substantial period of
+        # time such that the submission of the second task will raise an
+        # error saying that it cannot be submitted, as the max capacity of
+        # the semaphore is one.
+ self.assert_submit_would_block(second_task)
+
+ def test_executor_clears_capacity_on_done_tasks(self):
+ first_task = self.get_sleep_task()
+ second_task = self.get_task(ReturnFooTask)
+
+ # Submit a task.
+ future = self.executor.submit(first_task)
+
+ # Submit a new task when the first task finishes. This should not get
+ # blocked because the first task should have finished clearing up
+ # capacity.
+ self.add_done_callback_to_future(
+ future, self.assert_submit_would_not_block, second_task
+ )
+
+ # Wait for it to complete.
+ self.executor.shutdown()
+
+ @unittest.skipIf(
+ os.environ.get('USE_SERIAL_EXECUTOR'),
+ "Not supported with serial executor tests",
+ )
+ def test_would_not_block_when_full_capacity_in_other_semaphore(self):
+ first_task = self.get_sleep_task()
+
+ # Now let's create a new task with a tag and so it uses different
+ # semaphore.
+ task_tag = 'other'
+ other_task = self.get_sleep_task()
+ self.add_semaphore(task_tag, 1)
+
+ # Submit the normal first task
+ self.executor.submit(first_task)
+
+        # Even though the first task should be sleeping for a substantial
+        # period of time, the submission of the second task should not
+        # raise an error because it uses a different semaphore.
+
+ # Another submission of the other task though should raise
+ # an exception as the capacity is equal to one for that tag.
+ self.assert_submit_would_block(other_task, task_tag)
+
+ def test_shutdown(self):
+ slow_task = self.get_sleep_task()
+ future = self.executor.submit(slow_task)
+ self.executor.shutdown()
+ # Ensure that the shutdown waits until the task is done
+ self.assertTrue(future.done())
+
+ @unittest.skipIf(
+ os.environ.get('USE_SERIAL_EXECUTOR'),
+ "Not supported with serial executor tests",
+ )
+ def test_shutdown_no_wait(self):
+ slow_task = self.get_sleep_task()
+ future = self.executor.submit(slow_task)
+ self.executor.shutdown(False)
+        # Ensure that the shutdown returns immediately even if the task is
+        # not done, which it should not be because it is slow.
+ self.assertFalse(future.done())
+
+ def test_replace_underlying_executor(self):
+ mocked_executor_cls = mock.Mock(BaseExecutor)
+ executor = BoundedExecutor(10, 1, {}, mocked_executor_cls)
+ executor.submit(self.get_task(ReturnFooTask))
+ self.assertTrue(mocked_executor_cls.return_value.submit.called)
+
+
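+# The capacity tests above hinge on per-tag semaphores: an untagged submit
+# draws from the executor's own capacity, while a tagged submit draws from
+# the TaskSemaphore registered for that tag, so saturating one tag never
+# blocks another. A usage sketch of that behavior (plain callables stand in
+# for Task objects here):
+def demo_tagged_submission(task, other_task):
+    executor = BoundedExecutor(1, 1, {'other': TaskSemaphore(1)})
+    executor.submit(task)  # consumes the default capacity
+    executor.submit(other_task, tag='other')  # separate 'other' capacity
+    try:
+        executor.submit(other_task, tag='other', block=False)
+    except NoResourcesAvailable:
+        # Raised while the previous 'other' submission still holds the
+        # tag's semaphore.
+        pass
+    executor.shutdown()
+
+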
+class TestExecutorFuture(unittest.TestCase):
+ def test_result(self):
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(return_call_args, 'foo', biz='baz')
+ wrapped_future = ExecutorFuture(future)
+ self.assertEqual(wrapped_future.result(), (('foo',), {'biz': 'baz'}))
+
+ def test_done(self):
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(return_call_args, 'foo', biz='baz')
+ wrapped_future = ExecutorFuture(future)
+ self.assertTrue(wrapped_future.done())
+
+ def test_add_done_callback(self):
+ done_callbacks = []
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(return_call_args, 'foo', biz='baz')
+ wrapped_future = ExecutorFuture(future)
+ wrapped_future.add_done_callback(
+ FunctionContainer(done_callbacks.append, 'called')
+ )
+ self.assertEqual(done_callbacks, ['called'])
+
+
+class TestNonThreadedExecutor(unittest.TestCase):
+ def test_submit(self):
+ executor = NonThreadedExecutor()
+ future = executor.submit(return_call_args, 1, 2, foo='bar')
+ self.assertIsInstance(future, NonThreadedExecutorFuture)
+ self.assertEqual(future.result(), ((1, 2), {'foo': 'bar'}))
+
+ def test_submit_with_exception(self):
+ executor = NonThreadedExecutor()
+ future = executor.submit(raise_exception, RuntimeError())
+ self.assertIsInstance(future, NonThreadedExecutorFuture)
+ with self.assertRaises(RuntimeError):
+ future.result()
+
+ def test_submit_with_exception_and_captures_info(self):
+ exception = ValueError('message')
+ tb = get_exc_info(exception)[2]
+ future = NonThreadedExecutor().submit(raise_exception, exception)
+ try:
+ future.result()
+ # An exception should have been raised
+ self.fail('Future should have raised a ValueError')
+ except ValueError:
+ actual_tb = sys.exc_info()[2]
+ last_frame = traceback.extract_tb(actual_tb)[-1]
+ last_expected_frame = traceback.extract_tb(tb)[-1]
+ self.assertEqual(last_frame, last_expected_frame)
+
+
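+# NonThreadedExecutor, as tested above, runs the callable inline during
+# submit() and returns an already-resolved future; exceptions are captured
+# together with their traceback and re-raised from result(). A short usage
+# sketch:
+def demo_non_threaded_executor():
+    executor = NonThreadedExecutor()
+    future = executor.submit(return_call_args, 1, foo='bar')
+    assert future.done()  # resolved before submit() even returns
+    assert future.result() == ((1,), {'foo': 'bar'})
+
+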
+class TestNonThreadedExecutorFuture(unittest.TestCase):
+ def setUp(self):
+ self.future = NonThreadedExecutorFuture()
+
+ def test_done_starts_false(self):
+ self.assertFalse(self.future.done())
+
+ def test_done_after_setting_result(self):
+ self.future.set_result('result')
+ self.assertTrue(self.future.done())
+
+ def test_done_after_setting_exception(self):
+ self.future.set_exception_info(Exception(), None)
+ self.assertTrue(self.future.done())
+
+ def test_result(self):
+ self.future.set_result('result')
+ self.assertEqual(self.future.result(), 'result')
+
+ def test_exception_result(self):
+ exception = ValueError('message')
+ self.future.set_exception_info(exception, None)
+ with self.assertRaisesRegex(ValueError, 'message'):
+ self.future.result()
+
+ def test_exception_result_doesnt_modify_last_frame(self):
+ exception = ValueError('message')
+ tb = get_exc_info(exception)[2]
+ self.future.set_exception_info(exception, tb)
+ try:
+ self.future.result()
+ # An exception should have been raised
+ self.fail()
+ except ValueError:
+ actual_tb = sys.exc_info()[2]
+ last_frame = traceback.extract_tb(actual_tb)[-1]
+ last_expected_frame = traceback.extract_tb(tb)[-1]
+ self.assertEqual(last_frame, last_expected_frame)
+
+ def test_done_callback(self):
+ done_futures = []
+ self.future.add_done_callback(done_futures.append)
+ self.assertEqual(done_futures, [])
+ self.future.set_result('result')
+ self.assertEqual(done_futures, [self.future])
+
+ def test_done_callback_after_done(self):
+ self.future.set_result('result')
+ done_futures = []
+ self.future.add_done_callback(done_futures.append)
+ self.assertEqual(done_futures, [self.future])
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_manager.py b/contrib/python/s3transfer/py3/tests/unit/test_manager.py
new file mode 100644
index 0000000000..fc3caa843f
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_manager.py
@@ -0,0 +1,143 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import time
+from concurrent.futures import ThreadPoolExecutor
+
+from s3transfer.exceptions import CancelledError, FatalError
+from s3transfer.futures import TransferCoordinator
+from s3transfer.manager import TransferConfig, TransferCoordinatorController
+from __tests__ import TransferCoordinatorWithInterrupt, unittest
+
+
+class FutureResultException(Exception):
+ pass
+
+
+class TestTransferConfig(unittest.TestCase):
+ def test_exception_on_zero_attr_value(self):
+ with self.assertRaises(ValueError):
+ TransferConfig(max_request_queue_size=0)
+
+
+class TestTransferCoordinatorController(unittest.TestCase):
+ def setUp(self):
+ self.coordinator_controller = TransferCoordinatorController()
+
+ def sleep_then_announce_done(self, transfer_coordinator, sleep_time):
+ time.sleep(sleep_time)
+ transfer_coordinator.set_result('done')
+ transfer_coordinator.announce_done()
+
+ def assert_coordinator_is_cancelled(self, transfer_coordinator):
+ self.assertEqual(transfer_coordinator.status, 'cancelled')
+
+ def test_add_transfer_coordinator(self):
+ transfer_coordinator = TransferCoordinator()
+ # Add the transfer coordinator
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+        # Ensure that it is tracked.
+ self.assertEqual(
+ self.coordinator_controller.tracked_transfer_coordinators,
+ {transfer_coordinator},
+ )
+
+ def test_remove_transfer_coordinator(self):
+ transfer_coordinator = TransferCoordinator()
+ # Add the coordinator
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+ # Now remove the coordinator
+ self.coordinator_controller.remove_transfer_coordinator(
+ transfer_coordinator
+ )
+ # Make sure that it is no longer getting tracked.
+ self.assertEqual(
+ self.coordinator_controller.tracked_transfer_coordinators, set()
+ )
+
+ def test_cancel(self):
+ transfer_coordinator = TransferCoordinator()
+ # Add the transfer coordinator
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+ # Cancel with the canceler
+ self.coordinator_controller.cancel()
+ # Check that coordinator got canceled
+ self.assert_coordinator_is_cancelled(transfer_coordinator)
+
+ def test_cancel_with_message(self):
+ message = 'my cancel message'
+ transfer_coordinator = TransferCoordinator()
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+ self.coordinator_controller.cancel(message)
+ transfer_coordinator.announce_done()
+ with self.assertRaisesRegex(CancelledError, message):
+ transfer_coordinator.result()
+
+ def test_cancel_with_provided_exception(self):
+ message = 'my cancel message'
+ transfer_coordinator = TransferCoordinator()
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+ self.coordinator_controller.cancel(message, exc_type=FatalError)
+ transfer_coordinator.announce_done()
+ with self.assertRaisesRegex(FatalError, message):
+ transfer_coordinator.result()
+
+ def test_wait_for_done_transfer_coordinators(self):
+ # Create a coordinator and add it to the canceler
+ transfer_coordinator = TransferCoordinator()
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+
+ sleep_time = 0.02
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ # In a separate thread sleep and then set the transfer coordinator
+ # to done after sleeping.
+ start_time = time.time()
+ executor.submit(
+ self.sleep_then_announce_done, transfer_coordinator, sleep_time
+ )
+ # Now call wait to wait for the transfer coordinator to be done.
+ self.coordinator_controller.wait()
+ end_time = time.time()
+ wait_time = end_time - start_time
+ # The time waited should not be less than the time it took to sleep in
+ # the separate thread because the wait ending should be dependent on
+ # the sleeping thread announcing that the transfer coordinator is done.
+ self.assertTrue(sleep_time <= wait_time)
+
+ def test_wait_does_not_propogate_exceptions_from_result(self):
+ transfer_coordinator = TransferCoordinator()
+ transfer_coordinator.set_exception(FutureResultException())
+ transfer_coordinator.announce_done()
+ try:
+ self.coordinator_controller.wait()
+ except FutureResultException as e:
+ self.fail('%s should not have been raised.' % e)
+
+ def test_wait_can_be_interrupted(self):
+ inject_interrupt_coordinator = TransferCoordinatorWithInterrupt()
+ self.coordinator_controller.add_transfer_coordinator(
+ inject_interrupt_coordinator
+ )
+ with self.assertRaises(KeyboardInterrupt):
+ self.coordinator_controller.wait()
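+
+
+# The controller tested above is a fan-out registry: every tracked
+# coordinator receives cancel(), and wait() blocks until each one announces
+# done, deliberately not re-raising exceptions from result(). A compact
+# usage sketch built only from the calls exercised here:
+def demo_controller_cancel():
+    controller = TransferCoordinatorController()
+    coordinator = TransferCoordinator()
+    controller.add_transfer_coordinator(coordinator)
+    controller.cancel('shutting down', exc_type=FatalError)
+    coordinator.announce_done()
+    controller.wait()  # returns; result() errors are not propagated
+    assert coordinator.status == 'cancelled'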
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_processpool.py b/contrib/python/s3transfer/py3/tests/unit/test_processpool.py
new file mode 100644
index 0000000000..d77b5e0240
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_processpool.py
@@ -0,0 +1,728 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import queue
+import signal
+import threading
+import time
+from io import BytesIO
+
+from botocore.client import BaseClient
+from botocore.config import Config
+from botocore.exceptions import ClientError, ReadTimeoutError
+
+from s3transfer.constants import PROCESS_USER_AGENT
+from s3transfer.exceptions import CancelledError, RetriesExceededError
+from s3transfer.processpool import (
+ SHUTDOWN_SIGNAL,
+ ClientFactory,
+ DownloadFileRequest,
+ GetObjectJob,
+ GetObjectSubmitter,
+ GetObjectWorker,
+ ProcessPoolDownloader,
+ ProcessPoolTransferFuture,
+ ProcessPoolTransferMeta,
+ ProcessTransferConfig,
+ TransferMonitor,
+ TransferState,
+ ignore_ctrl_c,
+)
+from s3transfer.utils import CallArgs, OSUtils
+from __tests__ import (
+ FileCreator,
+ StreamWithError,
+ StubbedClientTest,
+ mock,
+ skip_if_windows,
+ unittest,
+)
+
+
+class RenameFailingOSUtils(OSUtils):
+ def __init__(self, exception):
+ self.exception = exception
+
+ def rename_file(self, current_filename, new_filename):
+ raise self.exception
+
+
+class TestIgnoreCtrlC(unittest.TestCase):
+ @skip_if_windows('os.kill() with SIGINT not supported on Windows')
+ def test_ignore_ctrl_c(self):
+ with ignore_ctrl_c():
+ try:
+ os.kill(os.getpid(), signal.SIGINT)
+ except KeyboardInterrupt:
+ self.fail(
+ 'The ignore_ctrl_c context manager should have '
+ 'ignored the KeyboardInterrupt exception'
+ )
+
+
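+# The helper above temporarily installs SIG_IGN for SIGINT and restores the
+# previous handler on exit. A minimal sketch of that pattern (illustrative;
+# the real helper lives in s3transfer.processpool):
+import contextlib
+
+
+@contextlib.contextmanager
+def ignore_sigint_sketch():
+    original = signal.signal(signal.SIGINT, signal.SIG_IGN)
+    try:
+        yield
+    finally:
+        signal.signal(signal.SIGINT, original)
+
+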
+class TestProcessPoolDownloader(unittest.TestCase):
+ def test_uses_client_kwargs(self):
+ with mock.patch('s3transfer.processpool.ClientFactory') as factory:
+ ProcessPoolDownloader(client_kwargs={'region_name': 'myregion'})
+ self.assertEqual(
+ factory.call_args[0][0], {'region_name': 'myregion'}
+ )
+
+
+class TestProcessPoolTransferFuture(unittest.TestCase):
+ def setUp(self):
+ self.monitor = TransferMonitor()
+ self.transfer_id = self.monitor.notify_new_transfer()
+ self.meta = ProcessPoolTransferMeta(
+ transfer_id=self.transfer_id, call_args=CallArgs()
+ )
+ self.future = ProcessPoolTransferFuture(
+ monitor=self.monitor, meta=self.meta
+ )
+
+ def test_meta(self):
+ self.assertEqual(self.future.meta, self.meta)
+
+ def test_done(self):
+ self.assertFalse(self.future.done())
+ self.monitor.notify_done(self.transfer_id)
+ self.assertTrue(self.future.done())
+
+ def test_result(self):
+ self.monitor.notify_done(self.transfer_id)
+ self.assertIsNone(self.future.result())
+
+ def test_result_with_exception(self):
+ self.monitor.notify_exception(self.transfer_id, RuntimeError())
+ self.monitor.notify_done(self.transfer_id)
+ with self.assertRaises(RuntimeError):
+ self.future.result()
+
+ def test_result_with_keyboard_interrupt(self):
+ mock_monitor = mock.Mock(TransferMonitor)
+ mock_monitor._connect = mock.Mock()
+ mock_monitor.poll_for_result.side_effect = KeyboardInterrupt()
+ future = ProcessPoolTransferFuture(
+ monitor=mock_monitor, meta=self.meta
+ )
+ with self.assertRaises(KeyboardInterrupt):
+ future.result()
+ self.assertTrue(mock_monitor._connect.called)
+ self.assertTrue(mock_monitor.notify_exception.called)
+ call_args = mock_monitor.notify_exception.call_args[0]
+ self.assertEqual(call_args[0], self.transfer_id)
+ self.assertIsInstance(call_args[1], CancelledError)
+
+ def test_cancel(self):
+ self.future.cancel()
+ self.monitor.notify_done(self.transfer_id)
+ with self.assertRaises(CancelledError):
+ self.future.result()
+
+
+class TestProcessPoolTransferMeta(unittest.TestCase):
+ def test_transfer_id(self):
+ meta = ProcessPoolTransferMeta(1, CallArgs())
+ self.assertEqual(meta.transfer_id, 1)
+
+ def test_call_args(self):
+ call_args = CallArgs()
+ meta = ProcessPoolTransferMeta(1, call_args)
+ self.assertEqual(meta.call_args, call_args)
+
+ def test_user_context(self):
+ meta = ProcessPoolTransferMeta(1, CallArgs())
+ self.assertEqual(meta.user_context, {})
+ meta.user_context['mykey'] = 'myvalue'
+ self.assertEqual(meta.user_context, {'mykey': 'myvalue'})
+
+
+class TestClientFactory(unittest.TestCase):
+ def test_create_client(self):
+ client = ClientFactory().create_client()
+ self.assertIsInstance(client, BaseClient)
+ self.assertEqual(client.meta.service_model.service_name, 's3')
+ self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
+
+ def test_create_client_with_client_kwargs(self):
+ client = ClientFactory({'region_name': 'myregion'}).create_client()
+ self.assertEqual(client.meta.region_name, 'myregion')
+
+ def test_user_agent_with_config(self):
+ client = ClientFactory({'config': Config()}).create_client()
+ self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
+
+ def test_user_agent_with_existing_user_agent_extra(self):
+ config = Config(user_agent_extra='foo/1.0')
+ client = ClientFactory({'config': config}).create_client()
+ self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
+
+ def test_user_agent_with_existing_user_agent(self):
+ config = Config(user_agent='foo/1.0')
+ client = ClientFactory({'config': config}).create_client()
+ self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
+
+
+class TestTransferMonitor(unittest.TestCase):
+ def setUp(self):
+ self.monitor = TransferMonitor()
+ self.transfer_id = self.monitor.notify_new_transfer()
+
+ def test_notify_new_transfer_creates_new_state(self):
+ monitor = TransferMonitor()
+ transfer_id = monitor.notify_new_transfer()
+ self.assertFalse(monitor.is_done(transfer_id))
+ self.assertIsNone(monitor.get_exception(transfer_id))
+
+ def test_notify_new_transfer_increments_transfer_id(self):
+ monitor = TransferMonitor()
+ self.assertEqual(monitor.notify_new_transfer(), 0)
+ self.assertEqual(monitor.notify_new_transfer(), 1)
+
+ def test_notify_get_exception(self):
+ exception = Exception()
+ self.monitor.notify_exception(self.transfer_id, exception)
+ self.assertEqual(
+ self.monitor.get_exception(self.transfer_id), exception
+ )
+
+ def test_get_no_exception(self):
+ self.assertIsNone(self.monitor.get_exception(self.transfer_id))
+
+ def test_notify_jobs(self):
+ self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2)
+ self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1)
+ self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 0)
+
+ def test_notify_jobs_for_multiple_transfers(self):
+ self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2)
+ other_transfer_id = self.monitor.notify_new_transfer()
+ self.monitor.notify_expected_jobs_to_complete(other_transfer_id, 2)
+ self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1)
+ self.assertEqual(
+ self.monitor.notify_job_complete(other_transfer_id), 1
+ )
+
+ def test_done(self):
+ self.assertFalse(self.monitor.is_done(self.transfer_id))
+ self.monitor.notify_done(self.transfer_id)
+ self.assertTrue(self.monitor.is_done(self.transfer_id))
+
+ def test_poll_for_result(self):
+ self.monitor.notify_done(self.transfer_id)
+ self.assertIsNone(self.monitor.poll_for_result(self.transfer_id))
+
+ def test_poll_for_result_raises_error(self):
+ self.monitor.notify_exception(self.transfer_id, RuntimeError())
+ self.monitor.notify_done(self.transfer_id)
+ with self.assertRaises(RuntimeError):
+ self.monitor.poll_for_result(self.transfer_id)
+
+ def test_poll_for_result_waits_till_done(self):
+ event_order = []
+
+ def sleep_then_notify_done():
+ time.sleep(0.05)
+ event_order.append('notify_done')
+ self.monitor.notify_done(self.transfer_id)
+
+ t = threading.Thread(target=sleep_then_notify_done)
+ t.start()
+
+ self.monitor.poll_for_result(self.transfer_id)
+ event_order.append('done_polling')
+ self.assertEqual(event_order, ['notify_done', 'done_polling'])
+
+ def test_notify_cancel_all_in_progress(self):
+ monitor = TransferMonitor()
+ transfer_ids = []
+ for _ in range(10):
+ transfer_ids.append(monitor.notify_new_transfer())
+ monitor.notify_cancel_all_in_progress()
+ for transfer_id in transfer_ids:
+ self.assertIsInstance(
+ monitor.get_exception(transfer_id), CancelledError
+ )
+ # Cancelling a transfer does not mean it is done as there may
+ # be cleanup work left to do.
+ self.assertFalse(monitor.is_done(transfer_id))
+
+ def test_notify_cancel_does_not_affect_done_transfers(self):
+ self.monitor.notify_done(self.transfer_id)
+ self.monitor.notify_cancel_all_in_progress()
+ self.assertTrue(self.monitor.is_done(self.transfer_id))
+ self.assertIsNone(self.monitor.get_exception(self.transfer_id))
+
+
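+# The monitor tested above is the shared scoreboard for the process pool:
+# each transfer id tracks done/exception/jobs-remaining state, and
+# poll_for_result() blocks until notify_done(), then returns None or
+# re-raises the recorded exception. A short usage sketch:
+def demo_monitor_roundtrip():
+    monitor = TransferMonitor()
+    transfer_id = monitor.notify_new_transfer()
+    monitor.notify_expected_jobs_to_complete(transfer_id, 2)
+    assert monitor.notify_job_complete(transfer_id) == 1
+    assert monitor.notify_job_complete(transfer_id) == 0  # all jobs finished
+    monitor.notify_done(transfer_id)
+    assert monitor.poll_for_result(transfer_id) is None
+
+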
+class TestTransferState(unittest.TestCase):
+ def setUp(self):
+ self.state = TransferState()
+
+ def test_done(self):
+ self.assertFalse(self.state.done)
+ self.state.set_done()
+ self.assertTrue(self.state.done)
+
+ def test_waits_till_done_is_set(self):
+ event_order = []
+
+ def sleep_then_set_done():
+ time.sleep(0.05)
+ event_order.append('set_done')
+ self.state.set_done()
+
+ t = threading.Thread(target=sleep_then_set_done)
+ t.start()
+
+ self.state.wait_till_done()
+ event_order.append('done_waiting')
+ self.assertEqual(event_order, ['set_done', 'done_waiting'])
+
+ def test_exception(self):
+ exception = RuntimeError()
+ self.state.exception = exception
+ self.assertEqual(self.state.exception, exception)
+
+ def test_jobs_to_complete(self):
+ self.state.jobs_to_complete = 5
+ self.assertEqual(self.state.jobs_to_complete, 5)
+
+ def test_decrement_jobs_to_complete(self):
+ self.state.jobs_to_complete = 5
+ self.assertEqual(self.state.decrement_jobs_to_complete(), 4)
+
+
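+# The submitter tests below pin down how downloads are sharded: below
+# multipart_threshold a file is fetched with one plain GET, otherwise one
+# job is queued per multipart_chunksize slice, each carrying its offset and
+# a Range header whose last slice is open-ended. A sketch of that splitting
+# (split_into_ranges is an illustrative helper, not the submitter's code):
+def split_into_ranges(size, threshold, chunksize):
+    if size < threshold:
+        return [(0, {})]
+    jobs = []
+    for offset in range(0, size, chunksize):
+        if offset + chunksize < size:
+            range_header = 'bytes={}-{}'.format(
+                offset, offset + chunksize - 1
+            )
+        else:
+            range_header = 'bytes={}-'.format(offset)
+        jobs.append((offset, {'Range': range_header}))
+    return jobs
+
+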
+class TestGetObjectSubmitter(StubbedClientTest):
+ def setUp(self):
+ super().setUp()
+ self.transfer_config = ProcessTransferConfig()
+ self.client_factory = mock.Mock(ClientFactory)
+ self.client_factory.create_client.return_value = self.client
+ self.transfer_monitor = TransferMonitor()
+ self.osutil = mock.Mock(OSUtils)
+ self.download_request_queue = queue.Queue()
+ self.worker_queue = queue.Queue()
+ self.submitter = GetObjectSubmitter(
+ transfer_config=self.transfer_config,
+ client_factory=self.client_factory,
+ transfer_monitor=self.transfer_monitor,
+ osutil=self.osutil,
+ download_request_queue=self.download_request_queue,
+ worker_queue=self.worker_queue,
+ )
+ self.transfer_id = self.transfer_monitor.notify_new_transfer()
+ self.bucket = 'bucket'
+ self.key = 'key'
+ self.filename = 'myfile'
+ self.temp_filename = 'myfile.temp'
+ self.osutil.get_temp_filename.return_value = self.temp_filename
+ self.extra_args = {}
+ self.expected_size = None
+
+ def add_download_file_request(self, **override_kwargs):
+ kwargs = {
+ 'transfer_id': self.transfer_id,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'filename': self.filename,
+ 'extra_args': self.extra_args,
+ 'expected_size': self.expected_size,
+ }
+ kwargs.update(override_kwargs)
+ self.download_request_queue.put(DownloadFileRequest(**kwargs))
+
+ def add_shutdown(self):
+ self.download_request_queue.put(SHUTDOWN_SIGNAL)
+
+ def assert_submitted_get_object_jobs(self, expected_jobs):
+ actual_jobs = []
+ while not self.worker_queue.empty():
+ actual_jobs.append(self.worker_queue.get())
+ self.assertEqual(actual_jobs, expected_jobs)
+
+ def test_run_for_non_ranged_download(self):
+ self.add_download_file_request(expected_size=1)
+ self.add_shutdown()
+ self.submitter.run()
+ self.osutil.allocate.assert_called_with(self.temp_filename, 1)
+ self.assert_submitted_get_object_jobs(
+ [
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=0,
+ extra_args={},
+ filename=self.filename,
+ )
+ ]
+ )
+
+ def test_run_for_ranged_download(self):
+ self.transfer_config.multipart_chunksize = 2
+ self.transfer_config.multipart_threshold = 4
+ self.add_download_file_request(expected_size=4)
+ self.add_shutdown()
+ self.submitter.run()
+ self.osutil.allocate.assert_called_with(self.temp_filename, 4)
+ self.assert_submitted_get_object_jobs(
+ [
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=0,
+ extra_args={'Range': 'bytes=0-1'},
+ filename=self.filename,
+ ),
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=2,
+ extra_args={'Range': 'bytes=2-'},
+ filename=self.filename,
+ ),
+ ]
+ )
+
+ def test_run_when_expected_size_not_provided(self):
+ self.stubber.add_response(
+ 'head_object',
+ {'ContentLength': 1},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ self.add_download_file_request(expected_size=None)
+ self.add_shutdown()
+ self.submitter.run()
+ self.stubber.assert_no_pending_responses()
+ self.osutil.allocate.assert_called_with(self.temp_filename, 1)
+ self.assert_submitted_get_object_jobs(
+ [
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=0,
+ extra_args={},
+ filename=self.filename,
+ )
+ ]
+ )
+
+ def test_run_with_extra_args(self):
+ self.stubber.add_response(
+ 'head_object',
+ {'ContentLength': 1},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'VersionId': 'versionid',
+ },
+ )
+ self.add_download_file_request(
+ extra_args={'VersionId': 'versionid'}, expected_size=None
+ )
+ self.add_shutdown()
+ self.submitter.run()
+ self.stubber.assert_no_pending_responses()
+ self.osutil.allocate.assert_called_with(self.temp_filename, 1)
+ self.assert_submitted_get_object_jobs(
+ [
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=0,
+ extra_args={'VersionId': 'versionid'},
+ filename=self.filename,
+ )
+ ]
+ )
+
+ def test_run_with_exception(self):
+ self.stubber.add_client_error('head_object', 'NoSuchKey', 404)
+ self.add_download_file_request(expected_size=None)
+ self.add_shutdown()
+ self.submitter.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_submitted_get_object_jobs([])
+ self.assertIsInstance(
+ self.transfer_monitor.get_exception(self.transfer_id), ClientError
+ )
+
+ def test_run_with_error_in_allocating_temp_file(self):
+ self.osutil.allocate.side_effect = OSError()
+ self.add_download_file_request(expected_size=1)
+ self.add_shutdown()
+ self.submitter.run()
+ self.assert_submitted_get_object_jobs([])
+ self.assertIsInstance(
+ self.transfer_monitor.get_exception(self.transfer_id), OSError
+ )
+
+ @skip_if_windows('os.kill() with SIGINT not supported on Windows')
+ def test_submitter_cannot_be_killed(self):
+ self.add_download_file_request(expected_size=None)
+ self.add_shutdown()
+
+ def raise_ctrl_c(**kwargs):
+ os.kill(os.getpid(), signal.SIGINT)
+
+ mock_client = mock.Mock()
+ mock_client.head_object = raise_ctrl_c
+ self.client_factory.create_client.return_value = mock_client
+
+ try:
+ self.submitter.run()
+ except KeyboardInterrupt:
+ self.fail(
+                'The submitter should not have been killed by the '
+ 'KeyboardInterrupt'
+ )
+
+
+class TestGetObjectWorker(StubbedClientTest):
+ def setUp(self):
+ super().setUp()
+ self.files = FileCreator()
+ self.queue = queue.Queue()
+ self.client_factory = mock.Mock(ClientFactory)
+ self.client_factory.create_client.return_value = self.client
+ self.transfer_monitor = TransferMonitor()
+ self.osutil = OSUtils()
+ self.worker = GetObjectWorker(
+ queue=self.queue,
+ client_factory=self.client_factory,
+ transfer_monitor=self.transfer_monitor,
+ osutil=self.osutil,
+ )
+ self.transfer_id = self.transfer_monitor.notify_new_transfer()
+ self.bucket = 'bucket'
+ self.key = 'key'
+ self.remote_contents = b'my content'
+ self.temp_filename = self.files.create_file('tempfile', '')
+ self.extra_args = {}
+ self.offset = 0
+ self.final_filename = self.files.full_path('final_filename')
+ self.stream = BytesIO(self.remote_contents)
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1000
+ )
+
+ def tearDown(self):
+ super().tearDown()
+ self.files.remove_all()
+
+ def add_get_object_job(self, **override_kwargs):
+ kwargs = {
+ 'transfer_id': self.transfer_id,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'temp_filename': self.temp_filename,
+ 'extra_args': self.extra_args,
+ 'offset': self.offset,
+ 'filename': self.final_filename,
+ }
+ kwargs.update(override_kwargs)
+ self.queue.put(GetObjectJob(**kwargs))
+
+ def add_shutdown(self):
+ self.queue.put(SHUTDOWN_SIGNAL)
+
+ def add_stubbed_get_object_response(self, body=None, expected_params=None):
+ if body is None:
+ body = self.stream
+ get_object_response = {'Body': body}
+
+ if expected_params is None:
+ expected_params = {'Bucket': self.bucket, 'Key': self.key}
+
+ self.stubber.add_response(
+ 'get_object', get_object_response, expected_params
+ )
+
+ def assert_contents(self, filename, contents):
+ self.assertTrue(os.path.exists(filename))
+ with open(filename, 'rb') as f:
+ self.assertEqual(f.read(), contents)
+
+ def assert_does_not_exist(self, filename):
+ self.assertFalse(os.path.exists(filename))
+
+ def test_run_is_final_job(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.add_stubbed_get_object_response()
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_does_not_exist(self.temp_filename)
+ self.assert_contents(self.final_filename, self.remote_contents)
+
+ def test_run_jobs_is_not_final_job(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.add_stubbed_get_object_response()
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1000
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_contents(self.temp_filename, self.remote_contents)
+ self.assert_does_not_exist(self.final_filename)
+
+ def test_run_with_extra_args(self):
+ self.add_get_object_job(extra_args={'VersionId': 'versionid'})
+ self.add_shutdown()
+ self.add_stubbed_get_object_response(
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'VersionId': 'versionid',
+ }
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+
+ def test_run_with_offset(self):
+ offset = 1
+ self.add_get_object_job(offset=offset)
+ self.add_shutdown()
+ self.add_stubbed_get_object_response()
+
+ self.worker.run()
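+        # The body should have been written starting at the requested
+        # offset within the temp file.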
+ with open(self.temp_filename, 'rb') as f:
+ f.seek(offset)
+ self.assertEqual(f.read(), self.remote_contents)
+
+ def test_run_error_in_get_object(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.stubber.add_client_error('get_object', 'NoSuchKey', 404)
+ self.add_stubbed_get_object_response()
+
+ self.worker.run()
+ self.assertIsInstance(
+ self.transfer_monitor.get_exception(self.transfer_id), ClientError
+ )
+
+ def test_run_does_retries_for_get_object(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.add_stubbed_get_object_response(
+ body=StreamWithError(
+ self.stream, ReadTimeoutError(endpoint_url='')
+ )
+ )
+ self.add_stubbed_get_object_response()
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_contents(self.temp_filename, self.remote_contents)
+
+ def test_run_can_exhaust_retries_for_get_object(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ # 5 is the current setting for max number of GetObject attempts
+ for _ in range(5):
+ self.add_stubbed_get_object_response(
+ body=StreamWithError(
+ self.stream, ReadTimeoutError(endpoint_url='')
+ )
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assertIsInstance(
+ self.transfer_monitor.get_exception(self.transfer_id),
+ RetriesExceededError,
+ )
+
+ def test_run_skips_get_object_on_previous_exception(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.transfer_monitor.notify_exception(self.transfer_id, Exception())
+
+ self.worker.run()
+ # Note we did not add a stubbed response for get_object
+ self.stubber.assert_no_pending_responses()
+
+ def test_run_final_job_removes_file_on_previous_exception(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.transfer_monitor.notify_exception(self.transfer_id, Exception())
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_does_not_exist(self.temp_filename)
+ self.assert_does_not_exist(self.final_filename)
+
+ def test_run_fails_to_rename_file(self):
+ exception = OSError()
+ osutil = RenameFailingOSUtils(exception)
+ self.worker = GetObjectWorker(
+ queue=self.queue,
+ client_factory=self.client_factory,
+ transfer_monitor=self.transfer_monitor,
+ osutil=osutil,
+ )
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.add_stubbed_get_object_response()
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1
+ )
+
+ self.worker.run()
+ self.assertEqual(
+ self.transfer_monitor.get_exception(self.transfer_id), exception
+ )
+ self.assert_does_not_exist(self.temp_filename)
+ self.assert_does_not_exist(self.final_filename)
+
+ @skip_if_windows('os.kill() with SIGINT not supported on Windows')
+ def test_worker_cannot_be_killed(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1
+ )
+
+ def raise_ctrl_c(**kwargs):
+ os.kill(os.getpid(), signal.SIGINT)
+
+ mock_client = mock.Mock()
+ mock_client.get_object = raise_ctrl_c
+ self.client_factory.create_client.return_value = mock_client
+
+ try:
+ self.worker.run()
+ except KeyboardInterrupt:
+ self.fail(
+                'The worker should not have been killed by the '
+ 'KeyboardInterrupt'
+ )
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py b/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py
new file mode 100644
index 0000000000..35cf4a22dd
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py
@@ -0,0 +1,780 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import socket
+import tempfile
+from concurrent import futures
+from contextlib import closing
+from io import BytesIO, StringIO
+
+from s3transfer import (
+ MultipartDownloader,
+ MultipartUploader,
+ OSUtils,
+ QueueShutdownError,
+ ReadFileChunk,
+ S3Transfer,
+ ShutdownQueue,
+ StreamReaderProgress,
+ TransferConfig,
+ disable_upload_callbacks,
+ enable_upload_callbacks,
+ random_file_extension,
+)
+from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError
+from __tests__ import mock, unittest
+
+
+class InMemoryOSLayer(OSUtils):
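+    # An in-memory stand-in for OSUtils backed by a filename -> contents
+    # mapping, so transfers can be exercised without touching the disk.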
+ def __init__(self, filemap):
+ self.filemap = filemap
+
+ def get_file_size(self, filename):
+ return len(self.filemap[filename])
+
+ def open_file_chunk_reader(self, filename, start_byte, size, callback):
+ return closing(BytesIO(self.filemap[filename]))
+
+ def open(self, filename, mode):
+ if 'wb' in mode:
+ fileobj = BytesIO()
+ self.filemap[filename] = fileobj
+ return closing(fileobj)
+ else:
+ return closing(self.filemap[filename])
+
+ def remove_file(self, filename):
+ if filename in self.filemap:
+ del self.filemap[filename]
+
+ def rename_file(self, current_filename, new_filename):
+ if current_filename in self.filemap:
+ self.filemap[new_filename] = self.filemap.pop(current_filename)
+
+
+class SequentialExecutor:
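+    # A stand-in for a futures executor that runs submitted work
+    # synchronously, keeping the tests deterministic.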
+ def __init__(self, max_workers):
+ pass
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ pass
+
+ # The real map() interface actually takes *args, but we specifically do
+ # _not_ use this interface.
+ def map(self, function, args):
+ results = []
+ for arg in args:
+ results.append(function(arg))
+ return results
+
+ def submit(self, function):
+ future = futures.Future()
+ future.set_result(function())
+ return future
+
+
+class TestOSUtils(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def test_get_file_size(self):
+ with mock.patch('os.path.getsize') as m:
+ OSUtils().get_file_size('myfile')
+ m.assert_called_with('myfile')
+
+ def test_open_file_chunk_reader(self):
+ with mock.patch('s3transfer.ReadFileChunk') as m:
+ OSUtils().open_file_chunk_reader('myfile', 0, 100, None)
+ m.from_filename.assert_called_with(
+ 'myfile', 0, 100, None, enable_callback=False
+ )
+
+ def test_open_file(self):
+ fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w')
+ self.assertTrue(hasattr(fileobj, 'write'))
+
+ def test_remove_file_ignores_errors(self):
+ with mock.patch('os.remove') as remove:
+ remove.side_effect = OSError('fake error')
+ OSUtils().remove_file('foo')
+ remove.assert_called_with('foo')
+
+ def test_remove_file_proxies_remove_file(self):
+ with mock.patch('os.remove') as remove:
+ OSUtils().remove_file('foo')
+ remove.assert_called_with('foo')
+
+ def test_rename_file(self):
+ with mock.patch('s3transfer.compat.rename_file') as rename_file:
+ OSUtils().rename_file('foo', 'newfoo')
+ rename_file.assert_called_with('foo', 'newfoo')
+
+
+class TestReadFileChunk(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def test_read_entire_chunk(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=3
+ )
+ self.assertEqual(chunk.read(), b'one')
+ self.assertEqual(chunk.read(), b'')
+
+ def test_read_with_amount_size(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=11, chunk_size=4
+ )
+ self.assertEqual(chunk.read(1), b'f')
+ self.assertEqual(chunk.read(1), b'o')
+ self.assertEqual(chunk.read(1), b'u')
+ self.assertEqual(chunk.read(1), b'r')
+ self.assertEqual(chunk.read(1), b'')
+
+ def test_reset_stream_emulation(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=11, chunk_size=4
+ )
+ self.assertEqual(chunk.read(), b'four')
+ chunk.seek(0)
+ self.assertEqual(chunk.read(), b'four')
+
+ def test_read_past_end_of_file(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=36, chunk_size=100000
+ )
+ self.assertEqual(chunk.read(), b'ten')
+ self.assertEqual(chunk.read(), b'')
+ self.assertEqual(len(chunk), 3)
+
+ def test_tell_and_seek(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=36, chunk_size=100000
+ )
+ self.assertEqual(chunk.tell(), 0)
+ self.assertEqual(chunk.read(), b'ten')
+ self.assertEqual(chunk.tell(), 3)
+ chunk.seek(0)
+ self.assertEqual(chunk.tell(), 0)
+
+ def test_file_chunk_supports_context_manager(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'abc')
+ with ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=2
+ ) as chunk:
+ val = chunk.read()
+ self.assertEqual(val, b'ab')
+
+ def test_iter_is_always_empty(self):
+ # This tests the workaround for the httplib bug (see
+ # the source for more info).
+ filename = os.path.join(self.tempdir, 'foo')
+ open(filename, 'wb').close()
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=10
+ )
+ self.assertEqual(list(chunk), [])
+
+
+class TestReadFileChunkWithCallback(TestReadFileChunk):
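+    # Inherits and reruns all of the TestReadFileChunk cases, and adds
+    # callback-specific cases of its own.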
+ def setUp(self):
+ super().setUp()
+ self.filename = os.path.join(self.tempdir, 'foo')
+ with open(self.filename, 'wb') as f:
+ f.write(b'abc')
+ self.amounts_seen = []
+
+ def callback(self, amount):
+ self.amounts_seen.append(amount)
+
+ def test_callback_is_invoked_on_read(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename, start_byte=0, chunk_size=3, callback=self.callback
+ )
+ chunk.read(1)
+ chunk.read(1)
+ chunk.read(1)
+ self.assertEqual(self.amounts_seen, [1, 1, 1])
+
+ def test_callback_can_be_disabled(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename, start_byte=0, chunk_size=3, callback=self.callback
+ )
+ chunk.disable_callback()
+ # Now reading from the ReadFileChunk should not invoke
+ # the callback.
+ chunk.read()
+ self.assertEqual(self.amounts_seen, [])
+
+ def test_callback_will_also_be_triggered_by_seek(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename, start_byte=0, chunk_size=3, callback=self.callback
+ )
+ chunk.read(2)
+ chunk.seek(0)
+ chunk.read(2)
+ chunk.seek(1)
+ chunk.read(2)
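+        # Backwards seeks report negative amounts so that aggregated
+        # progress stays accurate after a rewind.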
+ self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2])
+
+
+class TestStreamReaderProgress(unittest.TestCase):
+ def test_proxies_to_wrapped_stream(self):
+ original_stream = StringIO('foobarbaz')
+ wrapped = StreamReaderProgress(original_stream)
+ self.assertEqual(wrapped.read(), 'foobarbaz')
+
+ def test_callback_invoked(self):
+ amounts_seen = []
+
+ def callback(amount):
+ amounts_seen.append(amount)
+
+ original_stream = StringIO('foobarbaz')
+ wrapped = StreamReaderProgress(original_stream, callback)
+ self.assertEqual(wrapped.read(), 'foobarbaz')
+ self.assertEqual(amounts_seen, [9])
+
+
+class TestMultipartUploader(unittest.TestCase):
+ def test_multipart_upload_uses_correct_client_calls(self):
+ client = mock.Mock()
+ uploader = MultipartUploader(
+ client,
+ TransferConfig(),
+ InMemoryOSLayer({'filename': b'foobar'}),
+ SequentialExecutor,
+ )
+ client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
+ client.upload_part.return_value = {'ETag': 'first'}
+
+ uploader.upload_file('filename', 'bucket', 'key', None, {})
+
+        # We need to check both the sequence of calls (create/upload/complete)
+        # and the params passed between the calls, including
+ # 1. The upload_id was plumbed through
+ # 2. The collected etags were added to the complete call.
+ client.create_multipart_upload.assert_called_with(
+ Bucket='bucket', Key='key'
+ )
+        # Should be a single part, since the file is below the chunksize.
+ client.upload_part.assert_called_with(
+ Body=mock.ANY,
+ Bucket='bucket',
+ UploadId='upload_id',
+ Key='key',
+ PartNumber=1,
+ )
+ client.complete_multipart_upload.assert_called_with(
+ MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]},
+ Bucket='bucket',
+ UploadId='upload_id',
+ Key='key',
+ )
+
+ def test_multipart_upload_injects_proper_kwargs(self):
+ client = mock.Mock()
+ uploader = MultipartUploader(
+ client,
+ TransferConfig(),
+ InMemoryOSLayer({'filename': b'foobar'}),
+ SequentialExecutor,
+ )
+ client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
+ client.upload_part.return_value = {'ETag': 'first'}
+
+ extra_args = {
+ 'SSECustomerKey': 'fakekey',
+ 'SSECustomerAlgorithm': 'AES256',
+ 'StorageClass': 'REDUCED_REDUNDANCY',
+ }
+ uploader.upload_file('filename', 'bucket', 'key', None, extra_args)
+
+ client.create_multipart_upload.assert_called_with(
+ Bucket='bucket',
+ Key='key',
+ # The initial call should inject all the storage class params.
+ SSECustomerKey='fakekey',
+ SSECustomerAlgorithm='AES256',
+ StorageClass='REDUCED_REDUNDANCY',
+ )
+ client.upload_part.assert_called_with(
+ Body=mock.ANY,
+ Bucket='bucket',
+ UploadId='upload_id',
+ Key='key',
+ PartNumber=1,
+ # We only have to forward certain **extra_args in subsequent
+ # UploadPart calls.
+ SSECustomerKey='fakekey',
+ SSECustomerAlgorithm='AES256',
+ )
+ client.complete_multipart_upload.assert_called_with(
+ MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]},
+ Bucket='bucket',
+ UploadId='upload_id',
+ Key='key',
+ )
+
+ def test_multipart_upload_is_aborted_on_error(self):
+ # If the create_multipart_upload succeeds and any upload_part
+ # fails, then abort_multipart_upload will be called.
+ client = mock.Mock()
+ uploader = MultipartUploader(
+ client,
+ TransferConfig(),
+ InMemoryOSLayer({'filename': b'foobar'}),
+ SequentialExecutor,
+ )
+ client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
+ client.upload_part.side_effect = Exception(
+ "Some kind of error occurred."
+ )
+
+ with self.assertRaises(S3UploadFailedError):
+ uploader.upload_file('filename', 'bucket', 'key', None, {})
+
+ client.abort_multipart_upload.assert_called_with(
+ Bucket='bucket', Key='key', UploadId='upload_id'
+ )
+
+
+class TestMultipartDownloader(unittest.TestCase):
+
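+    # Show full assertion diffs; the call_args_list comparisons are long.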
+ maxDiff = None
+
+ def test_multipart_download_uses_correct_client_calls(self):
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ client.get_object.return_value = {'Body': BytesIO(response_body)}
+
+ downloader = MultipartDownloader(
+ client, TransferConfig(), InMemoryOSLayer({}), SequentialExecutor
+ )
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ client.get_object.assert_called_with(
+ Range='bytes=0-', Bucket='bucket', Key='key'
+ )
+
+ def test_multipart_download_with_multiple_parts(self):
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ client.get_object.return_value = {'Body': BytesIO(response_body)}
+        # For testing purposes, we use a multipart threshold
+ # of 4 bytes and a chunksize of 4 bytes. Given b'foobarbaz',
+ # this should result in 3 calls. In python slices this would be:
+ # r[0:4], r[4:8], r[8:9]. But the Range param will be slightly
+ # different because they use inclusive ranges.
+ config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
+
+ downloader = MultipartDownloader(
+ client, config, InMemoryOSLayer({}), SequentialExecutor
+ )
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ # We're storing these in **extra because the assertEqual
+ # below is really about verifying we have the correct value
+ # for the Range param.
+ extra = {'Bucket': 'bucket', 'Key': 'key'}
+ self.assertEqual(
+ client.get_object.call_args_list,
+ # Note these are inclusive ranges.
+ [
+ mock.call(Range='bytes=0-3', **extra),
+ mock.call(Range='bytes=4-7', **extra),
+ mock.call(Range='bytes=8-', **extra),
+ ],
+ )
+
+ def test_retry_on_failures_from_stream_reads(self):
+ # If we get an exception during a call to the response body's .read()
+ # method, we should retry the request.
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ stream_with_errors = mock.Mock()
+ stream_with_errors.read.side_effect = [
+ socket.error("fake error"),
+ response_body,
+ ]
+ client.get_object.return_value = {'Body': stream_with_errors}
+ config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
+
+ downloader = MultipartDownloader(
+ client, config, InMemoryOSLayer({}), SequentialExecutor
+ )
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ # We're storing these in **extra because the assertEqual
+ # below is really about verifying we have the correct value
+ # for the Range param.
+ extra = {'Bucket': 'bucket', 'Key': 'key'}
+ self.assertEqual(
+ client.get_object.call_args_list,
+ # The first call to range=0-3 fails because of the
+ # side_effect above where we make the .read() raise a
+ # socket.error.
+ # The second call to range=0-3 then succeeds.
+ [
+ mock.call(Range='bytes=0-3', **extra),
+ mock.call(Range='bytes=0-3', **extra),
+ mock.call(Range='bytes=4-7', **extra),
+ mock.call(Range='bytes=8-', **extra),
+ ],
+ )
+
+ def test_exception_raised_on_exceeded_retries(self):
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ stream_with_errors = mock.Mock()
+ stream_with_errors.read.side_effect = socket.error("fake error")
+ client.get_object.return_value = {'Body': stream_with_errors}
+ config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
+
+ downloader = MultipartDownloader(
+ client, config, InMemoryOSLayer({}), SequentialExecutor
+ )
+ with self.assertRaises(RetriesExceededError):
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ def test_io_thread_failure_triggers_shutdown(self):
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ client.get_object.return_value = {'Body': BytesIO(response_body)}
+ os_layer = mock.Mock()
+ mock_fileobj = mock.MagicMock()
+ mock_fileobj.__enter__.return_value = mock_fileobj
+ mock_fileobj.write.side_effect = Exception("fake IO error")
+ os_layer.open.return_value = mock_fileobj
+
+ downloader = MultipartDownloader(
+ client, TransferConfig(), os_layer, SequentialExecutor
+ )
+ # We're verifying that the exception raised from the IO future
+ # propagates back up via download_file().
+ with self.assertRaisesRegex(Exception, "fake IO error"):
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ def test_download_futures_fail_triggers_shutdown(self):
+ class FailedDownloadParts(SequentialExecutor):
+ def __init__(self, max_workers):
+ self.is_first = True
+
+ def submit(self, function):
+ future = futures.Future()
+ if self.is_first:
+ # This is the download_parts_thread.
+ future.set_exception(
+ Exception("fake download parts error")
+ )
+ self.is_first = False
+ return future
+
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ client.get_object.return_value = {'Body': BytesIO(response_body)}
+
+ downloader = MultipartDownloader(
+ client, TransferConfig(), InMemoryOSLayer({}), FailedDownloadParts
+ )
+ with self.assertRaisesRegex(Exception, "fake download parts error"):
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+
+class TestS3Transfer(unittest.TestCase):
+ def setUp(self):
+ self.client = mock.Mock()
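+        # Patch random_file_extension so temporary download filenames
+        # are predictable (e.g. 'filename.RANDOM').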
+ self.random_file_patch = mock.patch('s3transfer.random_file_extension')
+ self.random_file = self.random_file_patch.start()
+ self.random_file.return_value = 'RANDOM'
+
+ def tearDown(self):
+ self.random_file_patch.stop()
+
+ def test_callback_handlers_register_on_put_item(self):
+ osutil = InMemoryOSLayer({'smallfile': b'foobar'})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ transfer.upload_file('smallfile', 'bucket', 'key')
+ events = self.client.meta.events
+ events.register_first.assert_called_with(
+ 'request-created.s3',
+ disable_upload_callbacks,
+ unique_id='s3upload-callback-disable',
+ )
+ events.register_last.assert_called_with(
+ 'request-created.s3',
+ enable_upload_callbacks,
+ unique_id='s3upload-callback-enable',
+ )
+
+ def test_upload_below_multipart_threshold_uses_put_object(self):
+ fake_files = {
+ 'smallfile': b'foobar',
+ }
+ osutil = InMemoryOSLayer(fake_files)
+ transfer = S3Transfer(self.client, osutil=osutil)
+ transfer.upload_file('smallfile', 'bucket', 'key')
+ self.client.put_object.assert_called_with(
+ Bucket='bucket', Key='key', Body=mock.ANY
+ )
+
+ def test_extra_args_on_uploaded_passed_to_api_call(self):
+ extra_args = {'ACL': 'public-read'}
+ fake_files = {'smallfile': b'hello world'}
+ osutil = InMemoryOSLayer(fake_files)
+ transfer = S3Transfer(self.client, osutil=osutil)
+ transfer.upload_file(
+ 'smallfile', 'bucket', 'key', extra_args=extra_args
+ )
+ self.client.put_object.assert_called_with(
+ Bucket='bucket', Key='key', Body=mock.ANY, ACL='public-read'
+ )
+
+ def test_uses_multipart_upload_when_over_threshold(self):
+ with mock.patch('s3transfer.MultipartUploader') as uploader:
+ fake_files = {
+ 'smallfile': b'foobar',
+ }
+ osutil = InMemoryOSLayer(fake_files)
+ config = TransferConfig(
+ multipart_threshold=2, multipart_chunksize=2
+ )
+ transfer = S3Transfer(self.client, osutil=osutil, config=config)
+ transfer.upload_file('smallfile', 'bucket', 'key')
+
+ uploader.return_value.upload_file.assert_called_with(
+ 'smallfile', 'bucket', 'key', None, {}
+ )
+
+ def test_uses_multipart_download_when_over_threshold(self):
+ with mock.patch('s3transfer.MultipartDownloader') as downloader:
+ osutil = InMemoryOSLayer({})
+ over_multipart_threshold = 100 * 1024 * 1024
+ transfer = S3Transfer(self.client, osutil=osutil)
+ callback = mock.sentinel.CALLBACK
+ self.client.head_object.return_value = {
+ 'ContentLength': over_multipart_threshold,
+ }
+ transfer.download_file(
+ 'bucket', 'key', 'filename', callback=callback
+ )
+
+ downloader.return_value.download_file.assert_called_with(
+ # Note how we're downloading to a temporary random file.
+ 'bucket',
+ 'key',
+ 'filename.RANDOM',
+ over_multipart_threshold,
+ {},
+ callback,
+ )
+
+ def test_download_file_with_invalid_extra_args(self):
+ below_threshold = 20
+ osutil = InMemoryOSLayer({})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ with self.assertRaises(ValueError):
+ transfer.download_file(
+ 'bucket',
+ 'key',
+ '/tmp/smallfile',
+ extra_args={'BadValue': 'foo'},
+ )
+
+ def test_upload_file_with_invalid_extra_args(self):
+ osutil = InMemoryOSLayer({})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ bad_args = {"WebsiteRedirectLocation": "/foo"}
+ with self.assertRaises(ValueError):
+ transfer.upload_file(
+ 'bucket', 'key', '/tmp/smallfile', extra_args=bad_args
+ )
+
+    def test_download_file_forwards_extra_args(self):
+ extra_args = {
+ 'SSECustomerKey': 'foo',
+ 'SSECustomerAlgorithm': 'AES256',
+ }
+ below_threshold = 20
+ osutil = InMemoryOSLayer({'smallfile': b'hello world'})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ self.client.get_object.return_value = {'Body': BytesIO(b'foobar')}
+ transfer.download_file(
+ 'bucket', 'key', '/tmp/smallfile', extra_args=extra_args
+ )
+
+ # Note that we need to invoke the HeadObject call
+        # and the GetObject call with the extra_args.
+ # This is necessary. Trying to HeadObject an SSE object
+ # will return a 400 if you don't provide the required
+ # params.
+ self.client.get_object.assert_called_with(
+ Bucket='bucket',
+ Key='key',
+ SSECustomerAlgorithm='AES256',
+ SSECustomerKey='foo',
+ )
+
+ def test_get_object_stream_is_retried_and_succeeds(self):
+ below_threshold = 20
+ osutil = InMemoryOSLayer({'smallfile': b'hello world'})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ self.client.get_object.side_effect = [
+ # First request fails.
+ socket.error("fake error"),
+ # Second succeeds.
+ {'Body': BytesIO(b'foobar')},
+ ]
+ transfer.download_file('bucket', 'key', '/tmp/smallfile')
+
+ self.assertEqual(self.client.get_object.call_count, 2)
+
+ def test_get_object_stream_uses_all_retries_and_errors_out(self):
+ below_threshold = 20
+ osutil = InMemoryOSLayer({})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ # Here we're raising an exception every single time, which
+ # will exhaust our retry count and propagate a
+ # RetriesExceededError.
+ self.client.get_object.side_effect = socket.error("fake error")
+ with self.assertRaises(RetriesExceededError):
+ transfer.download_file('bucket', 'key', 'smallfile')
+
+ self.assertEqual(self.client.get_object.call_count, 5)
+        # We should have also cleaned up the in-progress file
+ # we were downloading to.
+ self.assertEqual(osutil.filemap, {})
+
+ def test_download_below_multipart_threshold(self):
+ below_threshold = 20
+ osutil = InMemoryOSLayer({'smallfile': b'hello world'})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ self.client.get_object.return_value = {'Body': BytesIO(b'foobar')}
+ transfer.download_file('bucket', 'key', 'smallfile')
+
+ self.client.get_object.assert_called_with(Bucket='bucket', Key='key')
+
+ def test_can_create_with_just_client(self):
+ transfer = S3Transfer(client=mock.Mock())
+ self.assertIsInstance(transfer, S3Transfer)
+
+
+class TestShutdownQueue(unittest.TestCase):
+ def test_handles_normal_put_get_requests(self):
+ q = ShutdownQueue()
+ q.put('foo')
+ self.assertEqual(q.get(), 'foo')
+
+ def test_put_raises_error_on_shutdown(self):
+ q = ShutdownQueue()
+ q.trigger_shutdown()
+ with self.assertRaises(QueueShutdownError):
+ q.put('foo')
+
+
+class TestRandomFileExtension(unittest.TestCase):
+ def test_has_proper_length(self):
+ self.assertEqual(len(random_file_extension(num_digits=4)), 4)
+
+
+class TestCallbackHandlers(unittest.TestCase):
+ def setUp(self):
+ self.request = mock.Mock()
+
+ def test_disable_request_on_put_object(self):
+ disable_upload_callbacks(self.request, 'PutObject')
+ self.request.body.disable_callback.assert_called_with()
+
+ def test_disable_request_on_upload_part(self):
+ disable_upload_callbacks(self.request, 'UploadPart')
+ self.request.body.disable_callback.assert_called_with()
+
+ def test_enable_object_on_put_object(self):
+ enable_upload_callbacks(self.request, 'PutObject')
+ self.request.body.enable_callback.assert_called_with()
+
+ def test_enable_object_on_upload_part(self):
+ enable_upload_callbacks(self.request, 'UploadPart')
+ self.request.body.enable_callback.assert_called_with()
+
+ def test_dont_disable_if_missing_interface(self):
+ del self.request.body.disable_callback
+ disable_upload_callbacks(self.request, 'PutObject')
+ self.assertEqual(self.request.body.method_calls, [])
+
+ def test_dont_enable_if_missing_interface(self):
+ del self.request.body.enable_callback
+ enable_upload_callbacks(self.request, 'PutObject')
+ self.assertEqual(self.request.body.method_calls, [])
+
+ def test_dont_disable_if_wrong_operation(self):
+ disable_upload_callbacks(self.request, 'OtherOperation')
+ self.assertFalse(self.request.body.disable_callback.called)
+
+ def test_dont_enable_if_wrong_operation(self):
+ enable_upload_callbacks(self.request, 'OtherOperation')
+ self.assertFalse(self.request.body.enable_callback.called)
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py b/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py
new file mode 100644
index 0000000000..a26d3a548c
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py
@@ -0,0 +1,91 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.exceptions import InvalidSubscriberMethodError
+from s3transfer.subscribers import BaseSubscriber
+from __tests__ import unittest
+
+
+class ExtraMethodsSubscriber(BaseSubscriber):
+ def extra_method(self):
+ return 'called extra method'
+
+
+class NotCallableSubscriber(BaseSubscriber):
+ on_done = 'foo'
+
+
+class NoKwargsSubscriber(BaseSubscriber):
+ def on_done(self):
+ pass
+
+
+class OverrideMethodSubscriber(BaseSubscriber):
+ def on_queued(self, **kwargs):
+ return kwargs
+
+
+class OverrideConstructorSubscriber(BaseSubscriber):
+ def __init__(self, arg1, arg2):
+ self.arg1 = arg1
+ self.arg2 = arg2
+
+
+class TestSubscribers(unittest.TestCase):
+ def test_can_instantiate_base_subscriber(self):
+ try:
+ BaseSubscriber()
+ except InvalidSubscriberMethodError:
+ self.fail('BaseSubscriber should be instantiable')
+
+ def test_can_call_base_subscriber_method(self):
+ subscriber = BaseSubscriber()
+ try:
+ subscriber.on_done(future=None)
+ except Exception as e:
+ self.fail(
+ 'Should be able to call base class subscriber method. '
+ 'instead got: %s' % e
+ )
+
+ def test_subclass_can_have_and_call_additional_methods(self):
+ subscriber = ExtraMethodsSubscriber()
+ self.assertEqual(subscriber.extra_method(), 'called extra method')
+
+ def test_can_subclass_and_override_method_from_base_subscriber(self):
+ subscriber = OverrideMethodSubscriber()
+ # Make sure that the overridden method is called
+ self.assertEqual(subscriber.on_queued(foo='bar'), {'foo': 'bar'})
+
+ def test_can_subclass_and_override_constructor_from_base_class(self):
+ subscriber = OverrideConstructorSubscriber('foo', arg2='bar')
+ # Make sure you can create a custom constructor.
+ self.assertEqual(subscriber.arg1, 'foo')
+ self.assertEqual(subscriber.arg2, 'bar')
+
+ def test_invalid_arguments_in_constructor_of_subclass_subscriber(self):
+ # The override constructor should still have validation of
+ # constructor args.
+ with self.assertRaises(TypeError):
+ OverrideConstructorSubscriber()
+
+ def test_not_callable_in_subclass_subscriber_method(self):
+ with self.assertRaisesRegex(
+ InvalidSubscriberMethodError, 'must be callable'
+ ):
+ NotCallableSubscriber()
+
+ def test_no_kwargs_in_subclass_subscriber_method(self):
+ with self.assertRaisesRegex(
+ InvalidSubscriberMethodError, 'must accept keyword'
+ ):
+ NoKwargsSubscriber()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_tasks.py b/contrib/python/s3transfer/py3/tests/unit/test_tasks.py
new file mode 100644
index 0000000000..4f0bc4d1cc
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_tasks.py
@@ -0,0 +1,833 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from concurrent import futures
+from functools import partial
+from threading import Event
+
+from s3transfer.futures import BoundedExecutor, TransferCoordinator
+from s3transfer.subscribers import BaseSubscriber
+from s3transfer.tasks import (
+ CompleteMultipartUploadTask,
+ CreateMultipartUploadTask,
+ SubmissionTask,
+ Task,
+)
+from s3transfer.utils import CallArgs, FunctionContainer, get_callbacks
+from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ RecordingSubscriber,
+ unittest,
+)
+
+
+class TaskFailureException(Exception):
+ pass
+
+
+class SuccessTask(Task):
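+    # Returns a canned value, optionally invoking callbacks and
+    # registering failure cleanups along the way.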
+ def _main(
+ self, return_value='success', callbacks=None, failure_cleanups=None
+ ):
+ if callbacks:
+ for callback in callbacks:
+ callback()
+ if failure_cleanups:
+ for failure_cleanup in failure_cleanups:
+ self._transfer_coordinator.add_failure_cleanup(failure_cleanup)
+ return return_value
+
+
+class FailureTask(Task):
+ def _main(self, exception=TaskFailureException):
+ raise exception()
+
+
+class ReturnKwargsTask(Task):
+ def _main(self, **kwargs):
+ return kwargs
+
+
+class SubmitMoreTasksTask(Task):
+ def _main(self, executor, tasks_to_submit):
+ for task_to_submit in tasks_to_submit:
+ self._transfer_coordinator.submit(executor, task_to_submit)
+
+
+class NOOPSubmissionTask(SubmissionTask):
+ def _submit(self, transfer_future, **kwargs):
+ pass
+
+
+class ExceptionSubmissionTask(SubmissionTask):
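+    # Optionally submits more tasks and fires extra callbacks before
+    # raising, to exercise SubmissionTask's failure handling paths.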
+ def _submit(
+ self,
+ transfer_future,
+ executor=None,
+ tasks_to_submit=None,
+ additional_callbacks=None,
+ exception=TaskFailureException,
+ ):
+ if executor and tasks_to_submit:
+ for task_to_submit in tasks_to_submit:
+ self._transfer_coordinator.submit(executor, task_to_submit)
+ if additional_callbacks:
+ for callback in additional_callbacks:
+ callback()
+ raise exception()
+
+
+class StatusRecordingTransferCoordinator(TransferCoordinator):
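+    # Records every status transition so tests can assert on the exact
+    # order of lifecycle changes.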
+ def __init__(self, transfer_id=None):
+ super().__init__(transfer_id)
+ self.status_changes = [self._status]
+
+ def set_status_to_queued(self):
+ super().set_status_to_queued()
+ self._record_status_change()
+
+ def set_status_to_running(self):
+ super().set_status_to_running()
+ self._record_status_change()
+
+ def _record_status_change(self):
+ self.status_changes.append(self._status)
+
+
+class RecordingStateSubscriber(BaseSubscriber):
+ def __init__(self, transfer_coordinator):
+ self._transfer_coordinator = transfer_coordinator
+ self.status_during_on_queued = None
+
+ def on_queued(self, **kwargs):
+ self.status_during_on_queued = self._transfer_coordinator.status
+
+
+class TestSubmissionTask(BaseSubmissionTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.executor = BoundedExecutor(1000, 5)
+ self.call_args = CallArgs(subscribers=[])
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.main_kwargs = {'transfer_future': self.transfer_future}
+
+ def test_transitions_from_not_started_to_queued_to_running(self):
+ self.transfer_coordinator = StatusRecordingTransferCoordinator()
+ submission_task = self.get_task(
+ NOOPSubmissionTask, main_kwargs=self.main_kwargs
+ )
+        # Status should be not-started until the submission task is run.
+ self.assertEqual(self.transfer_coordinator.status, 'not-started')
+
+ submission_task()
+        # Once the submission task has been run, the status should be running.
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ # Ensure the transitions were as expected as well.
+ self.assertEqual(
+ self.transfer_coordinator.status_changes,
+ ['not-started', 'queued', 'running'],
+ )
+
+ def test_on_queued_callbacks(self):
+ submission_task = self.get_task(
+ NOOPSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ subscriber = RecordingSubscriber()
+ self.call_args.subscribers.append(subscriber)
+ submission_task()
+ # Make sure the on_queued callback of the subscriber is called.
+ self.assertEqual(
+ subscriber.on_queued_calls, [{'future': self.transfer_future}]
+ )
+
+ def test_on_queued_status_in_callbacks(self):
+ submission_task = self.get_task(
+ NOOPSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ subscriber = RecordingStateSubscriber(self.transfer_coordinator)
+ self.call_args.subscribers.append(subscriber)
+ submission_task()
+        # Make sure the status was 'queued' during the on_queued callback.
+ self.assertEqual(subscriber.status_during_on_queued, 'queued')
+
+ def test_sets_exception_from_submit(self):
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+ submission_task()
+
+ # Make sure the status of the future is failed
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the future propagates the exception encountered in the
+ # submission task.
+ with self.assertRaises(TaskFailureException):
+ self.transfer_future.result()
+
+ def test_catches_and_sets_keyboard_interrupt_exception_from_submit(self):
+ self.main_kwargs['exception'] = KeyboardInterrupt
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+ submission_task()
+
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+ with self.assertRaises(KeyboardInterrupt):
+ self.transfer_future.result()
+
+ def test_calls_done_callbacks_on_exception(self):
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ subscriber = RecordingSubscriber()
+ self.call_args.subscribers.append(subscriber)
+
+ # Add the done callback to the callbacks to be invoked when the
+ # transfer is done.
+ done_callbacks = get_callbacks(self.transfer_future, 'done')
+ for done_callback in done_callbacks:
+ self.transfer_coordinator.add_done_callback(done_callback)
+ submission_task()
+
+ # Make sure the task failed to start
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the on_done callback of the subscriber is called.
+ self.assertEqual(
+ subscriber.on_done_calls, [{'future': self.transfer_future}]
+ )
+
+ def test_calls_failure_cleanups_on_exception(self):
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ # Add the callback to the callbacks to be invoked when the
+ # transfer fails.
+ invocations_of_cleanup = []
+ cleanup_callback = FunctionContainer(
+ invocations_of_cleanup.append, 'cleanup happened'
+ )
+ self.transfer_coordinator.add_failure_cleanup(cleanup_callback)
+ submission_task()
+
+ # Make sure the task failed to start
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the cleanup was called.
+ self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
+
+    def test_done_callbacks_only_ran_once_on_exception(self):
+ # We want to be able to handle the case where the final task completes
+        # and announces done but there is an error in the submission task
+        # which will cause it to need to announce done as well. In this case,
+        # we do not want the done callbacks to be invoked more than once.
+
+ final_task = self.get_task(FailureTask, is_final=True)
+ self.main_kwargs['executor'] = self.executor
+ self.main_kwargs['tasks_to_submit'] = [final_task]
+
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ subscriber = RecordingSubscriber()
+ self.call_args.subscribers.append(subscriber)
+
+ # Add the done callback to the callbacks to be invoked when the
+ # transfer is done.
+ done_callbacks = get_callbacks(self.transfer_future, 'done')
+ for done_callback in done_callbacks:
+ self.transfer_coordinator.add_done_callback(done_callback)
+
+ submission_task()
+
+ # Make sure the task failed to start
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the on_done callback of the subscriber is called only once.
+ self.assertEqual(
+ subscriber.on_done_calls, [{'future': self.transfer_future}]
+ )
+
+    def test_failure_cleanups_only_ran_once_on_exception(self):
+ # We want to be able to handle the case where the final task completes
+        # and announces done but there is an error in the submission task
+ # which will cause it to need to announce done as well. In this case,
+ # we do not want the failure cleanups to be invoked more than once.
+
+ final_task = self.get_task(FailureTask, is_final=True)
+ self.main_kwargs['executor'] = self.executor
+ self.main_kwargs['tasks_to_submit'] = [final_task]
+
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ # Add the callback to the callbacks to be invoked when the
+ # transfer fails.
+ invocations_of_cleanup = []
+ cleanup_callback = FunctionContainer(
+ invocations_of_cleanup.append, 'cleanup happened'
+ )
+ self.transfer_coordinator.add_failure_cleanup(cleanup_callback)
+ submission_task()
+
+ # Make sure the task failed to start
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the cleanup was called only once.
+ self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
+
+ def test_handles_cleanups_submitted_in_other_tasks(self):
+ invocations_of_cleanup = []
+ event = Event()
+ cleanup_callback = FunctionContainer(
+ invocations_of_cleanup.append, 'cleanup happened'
+ )
+        # We want the cleanup to be added during the execution of the task and
+ # still be executed by the submission task when it fails.
+ task = self.get_task(
+ SuccessTask,
+ main_kwargs={
+ 'callbacks': [event.set],
+ 'failure_cleanups': [cleanup_callback],
+ },
+ )
+
+ self.main_kwargs['executor'] = self.executor
+ self.main_kwargs['tasks_to_submit'] = [task]
+ self.main_kwargs['additional_callbacks'] = [event.wait]
+
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ submission_task()
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the cleanup was called even though the callback got
+ # added in a completely different task.
+ self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
+
+ def test_waits_for_tasks_submitted_by_other_tasks_on_exception(self):
+ # In this test, we want to make sure that any tasks that may be
+ # submitted in another task complete before we start performing
+ # cleanups.
+ #
+ # This is tested by doing the following:
+ #
+        # ExceptionSubmissionTask
+ # |
+ # +--submits-->SubmitMoreTasksTask
+ # |
+ # +--submits-->SuccessTask
+ # |
+ # +-->sleeps-->adds failure cleanup
+ #
+        # In the end, the failure cleanup of the SuccessTask should be run
+        # when the ExceptionSubmissionTask fails. If the
+        # ExceptionSubmissionTask did not run the failure cleanup, it is most
+        # likely that it did not wait for the SuccessTask to complete, which
+        # it needs to because the ExceptionSubmissionTask does not know
+ # what failure cleanups it needs to run until all spawned tasks have
+ # completed.
+ invocations_of_cleanup = []
+ event = Event()
+ cleanup_callback = FunctionContainer(
+ invocations_of_cleanup.append, 'cleanup happened'
+ )
+
+ cleanup_task = self.get_task(
+ SuccessTask,
+ main_kwargs={
+ 'callbacks': [event.set],
+ 'failure_cleanups': [cleanup_callback],
+ },
+ )
+ task_for_submitting_cleanup_task = self.get_task(
+ SubmitMoreTasksTask,
+ main_kwargs={
+ 'executor': self.executor,
+ 'tasks_to_submit': [cleanup_task],
+ },
+ )
+
+ self.main_kwargs['executor'] = self.executor
+ self.main_kwargs['tasks_to_submit'] = [
+ task_for_submitting_cleanup_task
+ ]
+ self.main_kwargs['additional_callbacks'] = [event.wait]
+
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ submission_task()
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+ self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
+
+ def test_submission_task_announces_done_if_cancelled_before_main(self):
+ invocations_of_done = []
+ done_callback = FunctionContainer(
+ invocations_of_done.append, 'done announced'
+ )
+ self.transfer_coordinator.add_done_callback(done_callback)
+
+ self.transfer_coordinator.cancel()
+ submission_task = self.get_task(
+ NOOPSubmissionTask, main_kwargs=self.main_kwargs
+ )
+ submission_task()
+
+        # Because the submission task was cancelled before being run,
+        # it did not submit any extra tasks, so as a result it is
+        # responsible for making sure it announces done, as nothing else will.
+ self.assertEqual(invocations_of_done, ['done announced'])
+
+
+class TestTask(unittest.TestCase):
+ def setUp(self):
+ self.transfer_id = 1
+ self.transfer_coordinator = TransferCoordinator(
+ transfer_id=self.transfer_id
+ )
+
+ def test_repr(self):
+ main_kwargs = {'bucket': 'mybucket', 'param_to_not_include': 'foo'}
+ task = ReturnKwargsTask(
+ self.transfer_coordinator, main_kwargs=main_kwargs
+ )
+        # The repr should not include the other parameter because it is
+        # not a parameter that should appear in the repr.
+ self.assertEqual(
+ repr(task),
+ 'ReturnKwargsTask(transfer_id={}, {})'.format(
+ self.transfer_id, {'bucket': 'mybucket'}
+ ),
+ )
+
+ def test_transfer_id(self):
+ task = SuccessTask(self.transfer_coordinator)
+        # Make sure that the task's transfer id matches the id
+        # associated with the transfer coordinator.
+ self.assertEqual(task.transfer_id, self.transfer_id)
+
+ def test_context_status_transitioning_success(self):
+ # The status should be set to running.
+ self.transfer_coordinator.set_status_to_running()
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ # If a task is called, the status still should be running.
+ SuccessTask(self.transfer_coordinator)()
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ # Once the final task is called, the status should be set to success.
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+
+ def test_context_status_transitioning_failed(self):
+ self.transfer_coordinator.set_status_to_running()
+
+ SuccessTask(self.transfer_coordinator)()
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ # A failure task should result in the failed status
+ FailureTask(self.transfer_coordinator)()
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Even if the final task comes in and succeeds, it should stay failed.
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ def test_result_setting_for_success(self):
+ override_return = 'foo'
+ SuccessTask(self.transfer_coordinator)()
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': override_return},
+ is_final=True,
+ )()
+
+ # The return value for the transfer future should be of the final
+ # task.
+ self.assertEqual(self.transfer_coordinator.result(), override_return)
+
+ def test_result_setting_for_error(self):
+ FailureTask(self.transfer_coordinator)()
+
+        # If another failure comes in, the result should still raise the
+        # original exception when result() is eventually called.
+ FailureTask(
+ self.transfer_coordinator, main_kwargs={'exception': Exception}
+ )()
+
+ # Even if a success task comes along, the result of the future
+ # should be the original exception
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ with self.assertRaises(TaskFailureException):
+ self.transfer_coordinator.result()
+
+ def test_done_callbacks_success(self):
+ callback_results = []
+ SuccessTask(
+ self.transfer_coordinator,
+ done_callbacks=[
+ partial(callback_results.append, 'first'),
+ partial(callback_results.append, 'second'),
+ ],
+ )()
+ # For successful tasks, the done callbacks should get called.
+ self.assertEqual(callback_results, ['first', 'second'])
+
+ def test_done_callbacks_failure(self):
+ callback_results = []
+ FailureTask(
+ self.transfer_coordinator,
+ done_callbacks=[
+ partial(callback_results.append, 'first'),
+ partial(callback_results.append, 'second'),
+ ],
+ )()
+        # Even for failed tasks, the done callbacks should get called.
+ self.assertEqual(callback_results, ['first', 'second'])
+
+ # Callbacks should continue to be called even after a related failure
+ SuccessTask(
+ self.transfer_coordinator,
+ done_callbacks=[
+ partial(callback_results.append, 'third'),
+ partial(callback_results.append, 'fourth'),
+ ],
+ )()
+ self.assertEqual(
+ callback_results, ['first', 'second', 'third', 'fourth']
+ )
+
+ def test_failure_cleanups_on_failure(self):
+ callback_results = []
+ self.transfer_coordinator.add_failure_cleanup(
+ callback_results.append, 'first'
+ )
+ self.transfer_coordinator.add_failure_cleanup(
+ callback_results.append, 'second'
+ )
+ FailureTask(self.transfer_coordinator)()
+        # The failure callbacks should not have been called yet because
+        # this is not the final task.
+ self.assertEqual(callback_results, [])
+
+ # Now the failure callbacks should get called.
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ self.assertEqual(callback_results, ['first', 'second'])
+
+ def test_no_failure_cleanups_on_success(self):
+ callback_results = []
+ self.transfer_coordinator.add_failure_cleanup(
+ callback_results.append, 'first'
+ )
+ self.transfer_coordinator.add_failure_cleanup(
+ callback_results.append, 'second'
+ )
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ # The failure cleanups should not have been called because no task
+ # failed for the transfer context.
+ self.assertEqual(callback_results, [])
+
+ def test_passing_main_kwargs(self):
+ main_kwargs = {'foo': 'bar', 'baz': 'biz'}
+ ReturnKwargsTask(
+ self.transfer_coordinator, main_kwargs=main_kwargs, is_final=True
+ )()
+        # The kwargs should have been passed through to _main().
+ self.assertEqual(self.transfer_coordinator.result(), main_kwargs)
+
+ def test_passing_pending_kwargs_single_futures(self):
+ pending_kwargs = {}
+ ref_main_kwargs = {'foo': 'bar', 'baz': 'biz'}
+
+ # Pass some tasks to an executor
+ with futures.ThreadPoolExecutor(1) as executor:
+ pending_kwargs['foo'] = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': ref_main_kwargs['foo']},
+ )
+ )
+ pending_kwargs['baz'] = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': ref_main_kwargs['baz']},
+ )
+ )
+
+ # Create a task that depends on the tasks passed to the executor
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+ # The result should have the pending keyword arg values flushed
+ # out.
+ self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs)
+
+ def test_passing_pending_kwargs_list_of_futures(self):
+ pending_kwargs = {}
+ ref_main_kwargs = {'foo': ['first', 'second']}
+
+ # Pass some tasks to an executor
+ with futures.ThreadPoolExecutor(1) as executor:
+ first_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': ref_main_kwargs['foo'][0]},
+ )
+ )
+ second_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': ref_main_kwargs['foo'][1]},
+ )
+ )
+ # Make the pending keyword arg value a list
+ pending_kwargs['foo'] = [first_future, second_future]
+
+ # Create a task that depends on the tasks passed to the executor
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+ # The result should have the pending keyword arg values flushed
+ # out in the expected order.
+ self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs)
+
+ def test_passing_pending_and_non_pending_kwargs(self):
+ main_kwargs = {'nonpending_value': 'foo'}
+ pending_kwargs = {}
+ ref_main_kwargs = {
+ 'nonpending_value': 'foo',
+ 'pending_value': 'bar',
+ 'pending_list': ['first', 'second'],
+ }
+
+ # Create the pending tasks
+ with futures.ThreadPoolExecutor(1) as executor:
+ pending_kwargs['pending_value'] = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={
+ 'return_value': ref_main_kwargs['pending_value']
+ },
+ )
+ )
+
+ first_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={
+ 'return_value': ref_main_kwargs['pending_list'][0]
+ },
+ )
+ )
+ second_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={
+ 'return_value': ref_main_kwargs['pending_list'][1]
+ },
+ )
+ )
+ # Make the pending keyword arg value a list
+ pending_kwargs['pending_list'] = [first_future, second_future]
+
+ # Create a task that depends on the tasks passed to the executor
+ # and just regular nonpending kwargs.
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ main_kwargs=main_kwargs,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+ # The result should have all of the kwargs (both pending and
+ # nonpending)
+ self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs)
+
+ def test_single_failed_pending_future(self):
+ pending_kwargs = {}
+
+ # Pass some tasks to an executor. Make one successful and the other
+ # a failure.
+ with futures.ThreadPoolExecutor(1) as executor:
+ pending_kwargs['foo'] = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': 'bar'},
+ )
+ )
+ pending_kwargs['baz'] = executor.submit(
+ FailureTask(self.transfer_coordinator)
+ )
+
+ # Create a task that depends on the tasks passed to the executor
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+        # The end result should raise the exception from the failed
+        # pending future value
+ with self.assertRaises(TaskFailureException):
+ self.transfer_coordinator.result()
+
+ def test_single_failed_pending_future_in_list(self):
+ pending_kwargs = {}
+
+ # Pass some tasks to an executor. Make one successful and the other
+ # a failure.
+ with futures.ThreadPoolExecutor(1) as executor:
+ first_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': 'bar'},
+ )
+ )
+ second_future = executor.submit(
+ FailureTask(self.transfer_coordinator)
+ )
+
+ pending_kwargs['pending_list'] = [first_future, second_future]
+
+ # Create a task that depends on the tasks passed to the executor
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+        # The end result should raise the exception from the failed
+        # pending future value in the list
+ with self.assertRaises(TaskFailureException):
+ self.transfer_coordinator.result()
+
+
+class BaseMultipartTaskTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'foo'
+
+
+class TestCreateMultipartUploadTask(BaseMultipartTaskTest):
+ def test_main(self):
+ upload_id = 'foo'
+ extra_args = {'Metadata': {'foo': 'bar'}}
+ response = {'UploadId': upload_id}
+ task = self.get_task(
+ CreateMultipartUploadTask,
+ main_kwargs={
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': extra_args,
+ },
+ )
+ self.stubber.add_response(
+ method='create_multipart_upload',
+ service_response=response,
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'Metadata': {'foo': 'bar'},
+ },
+ )
+ result_id = task()
+ self.stubber.assert_no_pending_responses()
+ # Ensure the upload id returned is correct
+ self.assertEqual(upload_id, result_id)
+
+        # Make sure that the abort was added as a failure cleanup
+ self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 1)
+
+ # Make sure if it is called, it will abort correctly
+ self.stubber.add_response(
+ method='abort_multipart_upload',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': upload_id,
+ },
+ )
+ self.transfer_coordinator.failure_cleanups[0]()
+ self.stubber.assert_no_pending_responses()
+
+
+class TestCompleteMultipartUploadTask(BaseMultipartTaskTest):
+ def test_main(self):
+ upload_id = 'my-id'
+ parts = [{'ETag': 'etag', 'PartNumber': 0}]
+ task = self.get_task(
+ CompleteMultipartUploadTask,
+ main_kwargs={
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'upload_id': upload_id,
+ 'parts': parts,
+ 'extra_args': {},
+ },
+ )
+ self.stubber.add_response(
+ method='complete_multipart_upload',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': upload_id,
+ 'MultipartUpload': {'Parts': parts},
+ },
+ )
+ task()
+ self.stubber.assert_no_pending_responses()
+
+ def test_includes_extra_args(self):
+ upload_id = 'my-id'
+ parts = [{'ETag': 'etag', 'PartNumber': 0}]
+ task = self.get_task(
+ CompleteMultipartUploadTask,
+ main_kwargs={
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'upload_id': upload_id,
+ 'parts': parts,
+ 'extra_args': {'RequestPayer': 'requester'},
+ },
+ )
+ self.stubber.add_response(
+ method='complete_multipart_upload',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': upload_id,
+ 'MultipartUpload': {'Parts': parts},
+ 'RequestPayer': 'requester',
+ },
+ )
+ task()
+ self.stubber.assert_no_pending_responses()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_upload.py b/contrib/python/s3transfer/py3/tests/unit/test_upload.py
new file mode 100644
index 0000000000..1ac38b3616
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_upload.py
@@ -0,0 +1,694 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+
+import math
+import os
+import shutil
+import tempfile
+from io import BytesIO
+
+from botocore.stub import ANY
+
+from s3transfer.futures import IN_MEMORY_UPLOAD_TAG
+from s3transfer.manager import TransferConfig
+from s3transfer.upload import (
+ AggregatedProgressCallback,
+ InterruptReader,
+ PutObjectTask,
+ UploadFilenameInputManager,
+ UploadNonSeekableInputManager,
+ UploadPartTask,
+ UploadSeekableInputManager,
+ UploadSubmissionTask,
+)
+from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE, CallArgs, OSUtils
+from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ FileSizeProvider,
+ NonSeekableReader,
+ RecordingExecutor,
+ RecordingSubscriber,
+ unittest,
+)
+
+
+class InterruptionError(Exception):
+ pass
+
+
+class OSUtilsExceptionOnFileSize(OSUtils):
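+ # Test double that fails the test if the file size is ever looked
+ # up; used to verify that a caller-provided transfer size avoids a
+ # stat call.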
+ def get_file_size(self, filename):
+ raise AssertionError(
+ "The file %s should not have been stated" % filename
+ )
+
+
+class BaseUploadTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'foo'
+ self.osutil = OSUtils()
+
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+ self.content = b'my content'
+ self.subscribers = []
+
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+
+ # A list to keep track of all of the bodies sent over the wire
+ # and their order.
+ self.sent_bodies = []
+ self.client.meta.events.register(
+ 'before-parameter-build.s3.*', self.collect_body
+ )
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+
+ def collect_body(self, params, **kwargs):
+ if 'Body' in params:
+ self.sent_bodies.append(params['Body'].read())
+
+
+class TestAggregatedProgressCallback(unittest.TestCase):
+ def setUp(self):
+ self.aggregated_amounts = []
+ self.threshold = 3
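+ # Progress amounts are aggregated and only passed through to the
+ # wrapped callbacks once the running total reaches this threshold.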
+ self.aggregated_progress_callback = AggregatedProgressCallback(
+ [self.callback], self.threshold
+ )
+
+ def callback(self, bytes_transferred):
+ self.aggregated_amounts.append(bytes_transferred)
+
+ def test_under_threshold(self):
+ one_under_threshold_amount = self.threshold - 1
+ self.aggregated_progress_callback(one_under_threshold_amount)
+ self.assertEqual(self.aggregated_amounts, [])
+ self.aggregated_progress_callback(1)
+ self.assertEqual(self.aggregated_amounts, [self.threshold])
+
+ def test_at_threshold(self):
+ self.aggregated_progress_callback(self.threshold)
+ self.assertEqual(self.aggregated_amounts, [self.threshold])
+
+ def test_over_threshold(self):
+ over_threshold_amount = self.threshold + 1
+ self.aggregated_progress_callback(over_threshold_amount)
+ self.assertEqual(self.aggregated_amounts, [over_threshold_amount])
+
+ def test_flush(self):
+ under_threshold_amount = self.threshold - 1
+ self.aggregated_progress_callback(under_threshold_amount)
+ self.assertEqual(self.aggregated_amounts, [])
+ self.aggregated_progress_callback.flush()
+ self.assertEqual(self.aggregated_amounts, [under_threshold_amount])
+
+ def test_flush_with_nothing_to_flush(self):
+ under_threshold_amount = self.threshold - 1
+ self.aggregated_progress_callback(under_threshold_amount)
+ self.assertEqual(self.aggregated_amounts, [])
+ self.aggregated_progress_callback.flush()
+ self.assertEqual(self.aggregated_amounts, [under_threshold_amount])
+ # Flushing again should do nothing as it was just flushed
+ self.aggregated_progress_callback.flush()
+ self.assertEqual(self.aggregated_amounts, [under_threshold_amount])
+
+
+class TestInterruptReader(BaseUploadTest):
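+ # InterruptReader re-raises any exception stored on the transfer
+ # coordinator when read, allowing in-flight upload bodies to abort
+ # promptly once a transfer fails or is cancelled.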
+ def test_read_raises_exception(self):
+ with open(self.filename, 'rb') as f:
+ reader = InterruptReader(f, self.transfer_coordinator)
+ # Read some bytes to show it can be read.
+ self.assertEqual(reader.read(1), self.content[0:1])
+ # Then set an exception in the transfer coordinator
+ self.transfer_coordinator.set_exception(InterruptionError())
+ # The next read should have the exception propagate
+ with self.assertRaises(InterruptionError):
+ reader.read()
+
+ def test_seek(self):
+ with open(self.filename, 'rb') as f:
+ reader = InterruptReader(f, self.transfer_coordinator)
+ # Ensure it can seek correctly
+ reader.seek(1)
+ self.assertEqual(reader.read(1), self.content[1:2])
+
+ def test_tell(self):
+ with open(self.filename, 'rb') as f:
+ reader = InterruptReader(f, self.transfer_coordinator)
+ # Ensure it can tell correctly
+ reader.seek(1)
+ self.assertEqual(reader.tell(), 1)
+
+
+class BaseUploadInputManagerTest(BaseUploadTest):
+ def setUp(self):
+ super().setUp()
+ self.osutil = OSUtils()
+ self.config = TransferConfig()
+ self.recording_subscriber = RecordingSubscriber()
+ self.subscribers.append(self.recording_subscriber)
+
+ def _get_expected_body_for_part(self, part_number):
+ # A helper method for retrieving the expected body for a specific
+ # part number of the data
+ total_size = len(self.content)
+ chunk_size = self.config.multipart_chunksize
+ start_index = (part_number - 1) * chunk_size
+ end_index = part_number * chunk_size
+ if end_index >= total_size:
+ return self.content[start_index:]
+ return self.content[start_index:end_index]
+
+
+class TestUploadFilenameInputManager(BaseUploadInputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.upload_input_manager = UploadFilenameInputManager(
+ self.osutil, self.transfer_coordinator
+ )
+ self.call_args = CallArgs(
+ fileobj=self.filename, subscribers=self.subscribers
+ )
+ self.future = self.get_transfer_future(self.call_args)
+
+ def test_is_compatible(self):
+ self.assertTrue(
+ self.upload_input_manager.is_compatible(
+ self.future.meta.call_args.fileobj
+ )
+ )
+
+ def test_stores_bodies_in_memory_put_object(self):
+ self.assertFalse(
+ self.upload_input_manager.stores_body_in_memory('put_object')
+ )
+
+ def test_stores_bodies_in_memory_upload_part(self):
+ self.assertFalse(
+ self.upload_input_manager.stores_body_in_memory('upload_part')
+ )
+
+ def test_provide_transfer_size(self):
+ self.upload_input_manager.provide_transfer_size(self.future)
+ # The provided file size should be equal to the size of the
+ # contents of the file.
+ self.assertEqual(self.future.meta.size, len(self.content))
+
+ def test_requires_multipart_upload(self):
+ self.future.meta.provide_transfer_size(len(self.content))
+ # With the default multipart threshold, the length of the content
+ # should be smaller than the threshold thus not requiring a multipart
+ # transfer.
+ self.assertFalse(
+ self.upload_input_manager.requires_multipart_upload(
+ self.future, self.config
+ )
+ )
+ # Decreasing the threshold to that of the length of the content of
+ # the file should trigger the need for a multipart upload.
+ self.config.multipart_threshold = len(self.content)
+ self.assertTrue(
+ self.upload_input_manager.requires_multipart_upload(
+ self.future, self.config
+ )
+ )
+
+ def test_get_put_object_body(self):
+ self.future.meta.provide_transfer_size(len(self.content))
+ read_file_chunk = self.upload_input_manager.get_put_object_body(
+ self.future
+ )
+ read_file_chunk.enable_callback()
+ # The file-like object provided back should be the same as the content
+ # of the file.
+ with read_file_chunk:
+ self.assertEqual(read_file_chunk.read(), self.content)
+ # The file-like object should also have been wrapped with the
+ # on_queued callbacks to track the number of bytes being transferred.
+ self.assertEqual(
+ self.recording_subscriber.calculate_bytes_seen(), len(self.content)
+ )
+
+ def test_get_put_object_body_is_interruptable(self):
+ self.future.meta.provide_transfer_size(len(self.content))
+ read_file_chunk = self.upload_input_manager.get_put_object_body(
+ self.future
+ )
+
+ # Set an exception in the transfer coordinator
+ self.transfer_coordinator.set_exception(InterruptionError)
+ # Ensure the returned read file chunk can be interrupted with that
+ # error.
+ with self.assertRaises(InterruptionError):
+ read_file_chunk.read()
+
+ def test_yield_upload_part_bodies(self):
+ # Adjust the chunk size to something more granular for testing.
+ self.config.multipart_chunksize = 4
+ self.future.meta.provide_transfer_size(len(self.content))
+
+ # Get an iterator that will yield all of the bodies and their
+ # respective part number.
+ part_iterator = self.upload_input_manager.yield_upload_part_bodies(
+ self.future, self.config.multipart_chunksize
+ )
+ expected_part_number = 1
+ for part_number, read_file_chunk in part_iterator:
+ # Ensure that the part number is as expected
+ self.assertEqual(part_number, expected_part_number)
+ read_file_chunk.enable_callback()
+ # Ensure that the body is correct for that part.
+ with read_file_chunk:
+ self.assertEqual(
+ read_file_chunk.read(),
+ self._get_expected_body_for_part(part_number),
+ )
+ expected_part_number += 1
+
+ # All of the file-like objects should also have been wrapped with the
+ # on_queued callbacks to track the number of bytes being transferred.
+ self.assertEqual(
+ self.recording_subscriber.calculate_bytes_seen(), len(self.content)
+ )
+
+ def test_yield_upload_part_bodies_are_interruptable(self):
+ # Adjust the chunk size to something more granular for testing.
+ self.config.multipart_chunksize = 4
+ self.future.meta.provide_transfer_size(len(self.content))
+
+ # Get an iterator that will yield all of the bodies and their
+ # respective part number.
+ part_iterator = self.upload_input_manager.yield_upload_part_bodies(
+ self.future, self.config.multipart_chunksize
+ )
+
+ # Set an exception in the transfer coordinator
+ self.transfer_coordinator.set_exception(InterruptionError)
+ for _, read_file_chunk in part_iterator:
+ # Ensure that each read file chunk yielded can be interrupted
+ # with that error.
+ with self.assertRaises(InterruptionError):
+ read_file_chunk.read()
+
+
+class TestUploadSeekableInputManager(TestUploadFilenameInputManager):
+ def setUp(self):
+ super().setUp()
+ self.upload_input_manager = UploadSeekableInputManager(
+ self.osutil, self.transfer_coordinator
+ )
+ self.fileobj = open(self.filename, 'rb')
+ self.call_args = CallArgs(
+ fileobj=self.fileobj, subscribers=self.subscribers
+ )
+ self.future = self.get_transfer_future(self.call_args)
+
+ def tearDown(self):
+ self.fileobj.close()
+ super().tearDown()
+
+ def test_is_compatible_bytes_io(self):
+ self.assertTrue(self.upload_input_manager.is_compatible(BytesIO()))
+
+ def test_not_compatible_for_non_filelike_obj(self):
+ self.assertFalse(self.upload_input_manager.is_compatible(object()))
+
+ def test_stores_bodies_in_memory_upload_part(self):
+ self.assertTrue(
+ self.upload_input_manager.stores_body_in_memory('upload_part')
+ )
+
+ def test_get_put_object_body(self):
+ start_pos = 3
+ self.fileobj.seek(start_pos)
+ adjusted_size = len(self.content) - start_pos
+ self.future.meta.provide_transfer_size(adjusted_size)
+ read_file_chunk = self.upload_input_manager.get_put_object_body(
+ self.future
+ )
+
+ read_file_chunk.enable_callback()
+ # The fact that the file was seeked before being passed in should be
+ # reflected in the length and content of the read file chunk.
+ with read_file_chunk:
+ self.assertEqual(len(read_file_chunk), adjusted_size)
+ self.assertEqual(read_file_chunk.read(), self.content[start_pos:])
+ self.assertEqual(
+ self.recording_subscriber.calculate_bytes_seen(), adjusted_size
+ )
+
+
+class TestUploadNonSeekableInputManager(TestUploadFilenameInputManager):
+ def setUp(self):
+ super().setUp()
+ self.upload_input_manager = UploadNonSeekableInputManager(
+ self.osutil, self.transfer_coordinator
+ )
+ self.fileobj = NonSeekableReader(self.content)
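+ # NonSeekableReader presents the content as a stream that reports
+ # itself as unseekable, standing in for input such as a pipe or
+ # socket.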
+ self.call_args = CallArgs(
+ fileobj=self.fileobj, subscribers=self.subscribers
+ )
+ self.future = self.get_transfer_future(self.call_args)
+
+ def assert_multipart_parts(self):
+ """
+ Asserts that the input manager will generate a multipart upload
+ and that each part is in order and the correct size.
+ """
+ # Assert that a multipart upload is required.
+ self.assertTrue(
+ self.upload_input_manager.requires_multipart_upload(
+ self.future, self.config
+ )
+ )
+
+ # Get a list of all the parts that would be sent.
+ parts = list(
+ self.upload_input_manager.yield_upload_part_bodies(
+ self.future, self.config.multipart_chunksize
+ )
+ )
+
+ # Assert that the actual number of parts is what we would expect
+ # based on the configuration.
+ size = self.config.multipart_chunksize
+ num_parts = math.ceil(len(self.content) / size)
+ self.assertEqual(len(parts), num_parts)
+
+ # Run for every part but the last part.
+ for i, part in enumerate(parts[:-1]):
+ # Assert the part number is correct.
+ self.assertEqual(part[0], i + 1)
+ # Assert the part contains the right amount of data.
+ data = part[1].read()
+ self.assertEqual(len(data), size)
+
+ # Assert that the last part is the correct size.
+ expected_final_size = len(self.content) - ((num_parts - 1) * size)
+ final_part = parts[-1]
+ self.assertEqual(len(final_part[1].read()), expected_final_size)
+
+ # Assert that the last part has the correct part number.
+ self.assertEqual(final_part[0], len(parts))
+
+ def test_provide_transfer_size(self):
+ self.upload_input_manager.provide_transfer_size(self.future)
+ # There is no way to get the size without reading the entire body.
+ self.assertEqual(self.future.meta.size, None)
+
+ def test_stores_bodies_in_memory_upload_part(self):
+ self.assertTrue(
+ self.upload_input_manager.stores_body_in_memory('upload_part')
+ )
+
+ def test_stores_bodies_in_memory_put_object(self):
+ self.assertTrue(
+ self.upload_input_manager.stores_body_in_memory('put_object')
+ )
+
+ def test_initial_data_parts_threshold_lesser(self):
+ # threshold < size
+ self.config.multipart_chunksize = 4
+ self.config.multipart_threshold = 2
+ self.assert_multipart_parts()
+
+ def test_initial_data_parts_threshold_equal(self):
+ # threshold == size
+ self.config.multipart_chunksize = 4
+ self.config.multipart_threshold = 4
+ self.assert_multipart_parts()
+
+ def test_initial_data_parts_threshold_greater(self):
+ # threshold > size
+ self.config.multipart_chunksize = 4
+ self.config.multipart_threshold = 8
+ self.assert_multipart_parts()
+
+
+class TestUploadSubmissionTask(BaseSubmissionTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
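+ # Three chunks' worth of data: smaller than the multipart threshold
+ # configured below, so put_object is used by default, but splitting
+ # into exactly three parts (matching the three stubbed upload_part
+ # responses) when a test lowers the threshold.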
+ self.content = b'0' * (MIN_UPLOAD_CHUNKSIZE * 3)
+ self.config.multipart_chunksize = MIN_UPLOAD_CHUNKSIZE
+ self.config.multipart_threshold = MIN_UPLOAD_CHUNKSIZE * 5
+
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.subscribers = []
+
+ # A list to keep track of all of the bodies sent over the wire
+ # and their order.
+ self.sent_bodies = []
+ self.client.meta.events.register(
+ 'before-parameter-build.s3.*', self.collect_body
+ )
+
+ self.call_args = self.get_call_args()
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.submission_main_kwargs = {
+ 'client': self.client,
+ 'config': self.config,
+ 'osutil': self.osutil,
+ 'request_executor': self.executor,
+ 'transfer_future': self.transfer_future,
+ }
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+
+ def collect_body(self, params, **kwargs):
+ if 'Body' in params:
+ self.sent_bodies.append(params['Body'].read())
+
+ def get_call_args(self, **kwargs):
+ default_call_args = {
+ 'fileobj': self.filename,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': self.extra_args,
+ 'subscribers': self.subscribers,
+ }
+ default_call_args.update(kwargs)
+ return CallArgs(**default_call_args)
+
+ def add_multipart_upload_stubbed_responses(self):
+ self.stubber.add_response(
+ method='create_multipart_upload',
+ service_response={'UploadId': 'my-id'},
+ )
+ self.stubber.add_response(
+ method='upload_part', service_response={'ETag': 'etag-1'}
+ )
+ self.stubber.add_response(
+ method='upload_part', service_response={'ETag': 'etag-2'}
+ )
+ self.stubber.add_response(
+ method='upload_part', service_response={'ETag': 'etag-3'}
+ )
+ self.stubber.add_response(
+ method='complete_multipart_upload', service_response={}
+ )
+
+ def wrap_executor_in_recorder(self):
+ self.executor = RecordingExecutor(self.executor)
+ self.submission_main_kwargs['request_executor'] = self.executor
+
+ def use_fileobj_in_call_args(self, fileobj):
+ self.call_args = self.get_call_args(fileobj=fileobj)
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.submission_main_kwargs['transfer_future'] = self.transfer_future
+
+ def assert_tag_value_for_put_object(self, tag_value):
+ self.assertEqual(self.executor.submissions[0]['tag'], tag_value)
+
+ def assert_tag_value_for_upload_parts(self, tag_value):
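+ # The first and last submissions are the CreateMultipartUpload and
+ # CompleteMultipartUpload tasks, so only the upload_part submissions
+ # in between are checked.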
+ for submission in self.executor.submissions[1:-1]:
+ self.assertEqual(submission['tag'], tag_value)
+
+ def test_provide_file_size_on_put(self):
+ self.call_args.subscribers.append(FileSizeProvider(len(self.content)))
+ self.stubber.add_response(
+ method='put_object',
+ service_response={},
+ expected_params={
+ 'Body': ANY,
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ },
+ )
+
+ # With this submitter, it will fail to stat the file if a transfer
+ # size is not provided.
+ self.submission_main_kwargs['osutil'] = OSUtilsExceptionOnFileSize()
+
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(self.sent_bodies, [self.content])
+
+ def test_submits_no_tag_for_put_object_filename(self):
+ self.wrap_executor_in_recorder()
+ self.stubber.add_response('put_object', {})
+
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ # Make sure no tag to limit that task specifically was associated
+ # with that task submission.
+ self.assert_tag_value_for_put_object(None)
+
+ def test_submits_no_tag_for_multipart_filename(self):
+ self.wrap_executor_in_recorder()
+
+ # Set up for a multipart upload.
+ self.add_multipart_upload_stubbed_responses()
+ self.config.multipart_threshold = 1
+
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ # Make sure no tag to limit any of the upload part tasks was
+ # associated when they were submitted to the executor.
+ self.assert_tag_value_for_upload_parts(None)
+
+ def test_submits_no_tag_for_put_object_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.stubber.add_response('put_object', {})
+
+ with open(self.filename, 'rb') as f:
+ self.use_fileobj_in_call_args(f)
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ # Make sure no tag to limit that task specifically was associated
+ # with that task submission.
+ self.assert_tag_value_for_put_object(None)
+
+ def test_submits_tag_for_multipart_fileobj(self):
+ self.wrap_executor_in_recorder()
+
+ # Set up for a multipart upload.
+ self.add_multipart_upload_stubbed_responses()
+ self.config.multipart_threshold = 1
+
+ with open(self.filename, 'rb') as f:
+ self.use_fileobj_in_call_args(f)
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ # Make sure tags to limit all of the upload part tasks were
+ # associated when submitted to the executor, as these tasks will
+ # have chunks of data stored with them in memory.
+ self.assert_tag_value_for_upload_parts(IN_MEMORY_UPLOAD_TAG)
+
+
+class TestPutObjectTask(BaseUploadTest):
+ def test_main(self):
+ extra_args = {'Metadata': {'foo': 'bar'}}
+ with open(self.filename, 'rb') as fileobj:
+ task = self.get_task(
+ PutObjectTask,
+ main_kwargs={
+ 'client': self.client,
+ 'fileobj': fileobj,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': extra_args,
+ },
+ )
+ self.stubber.add_response(
+ method='put_object',
+ service_response={},
+ expected_params={
+ 'Body': ANY,
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'Metadata': {'foo': 'bar'},
+ },
+ )
+ task()
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(self.sent_bodies, [self.content])
+
+
+class TestUploadPartTask(BaseUploadTest):
+ def test_main(self):
+ extra_args = {'RequestPayer': 'requester'}
+ upload_id = 'my-id'
+ part_number = 1
+ etag = 'foo'
+ with open(self.filename, 'rb') as fileobj:
+ task = self.get_task(
+ UploadPartTask,
+ main_kwargs={
+ 'client': self.client,
+ 'fileobj': fileobj,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'upload_id': upload_id,
+ 'part_number': part_number,
+ 'extra_args': extra_args,
+ },
+ )
+ self.stubber.add_response(
+ method='upload_part',
+ service_response={'ETag': etag},
+ expected_params={
+ 'Body': ANY,
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': upload_id,
+ 'PartNumber': part_number,
+ 'RequestPayer': 'requester',
+ },
+ )
+ rval = task()
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(rval, {'ETag': etag, 'PartNumber': part_number})
+ self.assertEqual(self.sent_bodies, [self.content])
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_utils.py b/contrib/python/s3transfer/py3/tests/unit/test_utils.py
new file mode 100644
index 0000000000..217779943b
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_utils.py
@@ -0,0 +1,1189 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import io
+import os.path
+import random
+import re
+import shutil
+import tempfile
+import threading
+import time
+from io import BytesIO, StringIO
+
+from s3transfer.futures import TransferFuture, TransferMeta
+from s3transfer.utils import (
+ MAX_PARTS,
+ MAX_SINGLE_UPLOAD_SIZE,
+ MIN_UPLOAD_CHUNKSIZE,
+ CallArgs,
+ ChunksizeAdjuster,
+ CountCallbackInvoker,
+ DeferredOpenFile,
+ FunctionContainer,
+ NoResourcesAvailable,
+ OSUtils,
+ ReadFileChunk,
+ SlidingWindowSemaphore,
+ StreamReaderProgress,
+ TaskSemaphore,
+ calculate_num_parts,
+ calculate_range_parameter,
+ get_callbacks,
+ get_filtered_dict,
+ invoke_progress_callbacks,
+ random_file_extension,
+)
+from __tests__ import NonSeekableWriter, RecordingSubscriber, mock, unittest
+
+
+class TestGetCallbacks(unittest.TestCase):
+ def setUp(self):
+ self.subscriber = RecordingSubscriber()
+ self.second_subscriber = RecordingSubscriber()
+ self.call_args = CallArgs(
+ subscribers=[self.subscriber, self.second_subscriber]
+ )
+ self.transfer_meta = TransferMeta(self.call_args)
+ self.transfer_future = TransferFuture(self.transfer_meta)
+
+ def test_get_callbacks(self):
+ callbacks = get_callbacks(self.transfer_future, 'queued')
+ # Make sure two callbacks were added as both subscribers had
+ # an on_queued method.
+ self.assertEqual(len(callbacks), 2)
+
+ # Ensure that the callback was injected with the future by calling
+ # one of them and checking that the future was used in the call.
+ callbacks[0]()
+ self.assertEqual(
+ self.subscriber.on_queued_calls, [{'future': self.transfer_future}]
+ )
+
+ def test_get_callbacks_for_missing_type(self):
+ callbacks = get_callbacks(self.transfer_future, 'fake_state')
+ # There should be no callbacks as the subscribers will not have the
+ # on_fake_state method
+ self.assertEqual(len(callbacks), 0)
+
+
+class TestGetFilteredDict(unittest.TestCase):
+ def test_get_filtered_dict(self):
+ original = {'Include': 'IncludeValue', 'NotInclude': 'NotIncludeValue'}
+ whitelist = ['Include']
+ self.assertEqual(
+ get_filtered_dict(original, whitelist), {'Include': 'IncludeValue'}
+ )
+
+
+class TestCallArgs(unittest.TestCase):
+ def test_call_args(self):
+ call_args = CallArgs(foo='bar', biz='baz')
+ self.assertEqual(call_args.foo, 'bar')
+ self.assertEqual(call_args.biz, 'baz')
+
+
+class TestFunctionContainer(unittest.TestCase):
+ def get_args_kwargs(self, *args, **kwargs):
+ return args, kwargs
+
+ def test_call(self):
+ func_container = FunctionContainer(
+ self.get_args_kwargs, 'foo', bar='baz'
+ )
+ self.assertEqual(func_container(), (('foo',), {'bar': 'baz'}))
+
+ def test_repr(self):
+ func_container = FunctionContainer(
+ self.get_args_kwargs, 'foo', bar='baz'
+ )
+ self.assertEqual(
+ str(func_container),
+ 'Function: {} with args {} and kwargs {}'.format(
+ self.get_args_kwargs, ('foo',), {'bar': 'baz'}
+ ),
+ )
+
+
+class TestCountCallbackInvoker(unittest.TestCase):
+ def invoke_callback(self):
+ self.ref_results.append('callback invoked')
+
+ def assert_callback_invoked(self):
+ self.assertEqual(self.ref_results, ['callback invoked'])
+
+ def assert_callback_not_invoked(self):
+ self.assertEqual(self.ref_results, [])
+
+ def setUp(self):
+ self.ref_results = []
+ self.invoker = CountCallbackInvoker(self.invoke_callback)
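+ # The invoker fires its callback only after finalize() has been
+ # called and the increment/decrement count has dropped back to zero.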
+
+ def test_increment(self):
+ self.invoker.increment()
+ self.assertEqual(self.invoker.current_count, 1)
+
+ def test_decrement(self):
+ self.invoker.increment()
+ self.invoker.increment()
+ self.invoker.decrement()
+ self.assertEqual(self.invoker.current_count, 1)
+
+ def test_count_cannot_go_below_zero(self):
+ with self.assertRaises(RuntimeError):
+ self.invoker.decrement()
+
+ def test_callback_invoked_only_once_finalized(self):
+ self.invoker.increment()
+ self.invoker.decrement()
+ self.assert_callback_not_invoked()
+ self.invoker.finalize()
+ # Callback should only be invoked once finalized
+ self.assert_callback_invoked()
+
+ def test_callback_invoked_after_finalizing_and_count_reaching_zero(self):
+ self.invoker.increment()
+ self.invoker.finalize()
+ # Make sure that it does not get invoked immediately after
+ # finalizing as the count is currently one
+ self.assert_callback_not_invoked()
+ self.invoker.decrement()
+ self.assert_callback_invoked()
+
+ def test_cannot_increment_after_finalization(self):
+ self.invoker.finalize()
+ with self.assertRaises(RuntimeError):
+ self.invoker.increment()
+
+
+class TestRandomFileExtension(unittest.TestCase):
+ def test_has_proper_length(self):
+ self.assertEqual(len(random_file_extension(num_digits=4)), 4)
+
+
+class TestInvokeProgressCallbacks(unittest.TestCase):
+ def test_invoke_progress_callbacks(self):
+ recording_subscriber = RecordingSubscriber()
+ invoke_progress_callbacks([recording_subscriber.on_progress], 2)
+ self.assertEqual(recording_subscriber.calculate_bytes_seen(), 2)
+
+ def test_invoke_progress_callbacks_with_no_progress(self):
+ recording_subscriber = RecordingSubscriber()
+ invoke_progress_callbacks([recording_subscriber.on_progress], 0)
+ self.assertEqual(len(recording_subscriber.on_progress_calls), 0)
+
+
+class TestCalculateNumParts(unittest.TestCase):
+ def test_calculate_num_parts_divisible(self):
+ self.assertEqual(calculate_num_parts(size=4, part_size=2), 2)
+
+ def test_calculate_num_parts_not_divisible(self):
+ self.assertEqual(calculate_num_parts(size=3, part_size=2), 2)
+
+
+class TestCalculateRangeParameter(unittest.TestCase):
+ def setUp(self):
+ self.part_size = 5
+ self.part_index = 1
+ self.num_parts = 3
+
+ def test_calculate_range_parameter(self):
+ range_val = calculate_range_parameter(
+ self.part_size, self.part_index, self.num_parts
+ )
+ self.assertEqual(range_val, 'bytes=5-9')
+
+ def test_last_part_with_no_total_size(self):
+ range_val = calculate_range_parameter(
+ self.part_size, self.part_index, num_parts=2
+ )
+ self.assertEqual(range_val, 'bytes=5-')
+
+ def test_last_part_with_total_size(self):
+ range_val = calculate_range_parameter(
+ self.part_size, self.part_index, num_parts=2, total_size=8
+ )
+ self.assertEqual(range_val, 'bytes=5-7')
+
+
+class BaseUtilsTest(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'foo')
+ self.content = b'abc'
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+ self.amounts_seen = []
+ self.num_close_callback_calls = 0
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def callback(self, bytes_transferred):
+ self.amounts_seen.append(bytes_transferred)
+
+ def close_callback(self):
+ self.num_close_callback_calls += 1
+
+
+class TestOSUtils(BaseUtilsTest):
+ def test_get_file_size(self):
+ self.assertEqual(
+ OSUtils().get_file_size(self.filename), len(self.content)
+ )
+
+ def test_open_file_chunk_reader(self):
+ reader = OSUtils().open_file_chunk_reader(
+ self.filename, 0, 3, [self.callback]
+ )
+
+ # The returned reader should be a ReadFileChunk.
+ self.assertIsInstance(reader, ReadFileChunk)
+ # The content of the reader should be correct.
+ self.assertEqual(reader.read(), self.content)
+ # Callbacks should be disabled despite being passed in.
+ self.assertEqual(self.amounts_seen, [])
+
+ def test_open_file_chunk_reader_from_fileobj(self):
+ with open(self.filename, 'rb') as f:
+ reader = OSUtils().open_file_chunk_reader_from_fileobj(
+ f, len(self.content), len(self.content), [self.callback]
+ )
+
+ # The returned reader should be a ReadFileChunk.
+ self.assertIsInstance(reader, ReadFileChunk)
+ # The content of the reader should be correct.
+ self.assertEqual(reader.read(), self.content)
+ reader.close()
+ # Callbacks should be disabled despite being passed in.
+ self.assertEqual(self.amounts_seen, [])
+ self.assertEqual(self.num_close_callback_calls, 0)
+
+ def test_open_file(self):
+ fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w')
+ self.assertTrue(hasattr(fileobj, 'write'))
+
+ def test_remove_file_ignores_errors(self):
+ non_existent_file = os.path.join(self.tempdir, 'no-exist')
+ # This should not exist to start.
+ self.assertFalse(os.path.exists(non_existent_file))
+ try:
+ OSUtils().remove_file(non_existent_file)
+ except OSError as e:
+ self.fail('OSError should have been caught: %s' % e)
+
+ def test_remove_file_proxies_remove_file(self):
+ OSUtils().remove_file(self.filename)
+ self.assertFalse(os.path.exists(self.filename))
+
+ def test_rename_file(self):
+ new_filename = os.path.join(self.tempdir, 'newfoo')
+ OSUtils().rename_file(self.filename, new_filename)
+ self.assertFalse(os.path.exists(self.filename))
+ self.assertTrue(os.path.exists(new_filename))
+
+ def test_is_special_file_for_normal_file(self):
+ self.assertFalse(OSUtils().is_special_file(self.filename))
+
+ def test_is_special_file_for_non_existent_file(self):
+ non_existent_filename = os.path.join(self.tempdir, 'no-exist')
+ self.assertFalse(os.path.exists(non_existent_filename))
+ self.assertFalse(OSUtils().is_special_file(non_existent_filename))
+
+ def test_get_temp_filename(self):
+ filename = 'myfile'
+ self.assertIsNotNone(
+ re.match(
+ r'%s\.[0-9A-Fa-f]{8}$' % filename,
+ OSUtils().get_temp_filename(filename),
+ )
+ )
+
+ def test_get_temp_filename_len_255(self):
+ filename = 'a' * 255
+ temp_filename = OSUtils().get_temp_filename(filename)
+ self.assertLessEqual(len(temp_filename), 255)
+
+ def test_get_temp_filename_len_gt_255(self):
+ filename = 'a' * 280
+ temp_filename = OSUtils().get_temp_filename(filename)
+ self.assertLessEqual(len(temp_filename), 255)
+
+ def test_allocate(self):
+ truncate_size = 1
+ OSUtils().allocate(self.filename, truncate_size)
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(len(f.read()), truncate_size)
+
+ @mock.patch('s3transfer.utils.fallocate')
+ def test_allocate_with_io_error(self, mock_fallocate):
+ mock_fallocate.side_effect = IOError()
+ with self.assertRaises(IOError):
+ OSUtils().allocate(self.filename, 1)
+ self.assertFalse(os.path.exists(self.filename))
+
+ @mock.patch('s3transfer.utils.fallocate')
+ def test_allocate_with_os_error(self, mock_fallocate):
+ mock_fallocate.side_effect = OSError()
+ with self.assertRaises(OSError):
+ OSUtils().allocate(self.filename, 1)
+ self.assertFalse(os.path.exists(self.filename))
+
+
+class TestDeferredOpenFile(BaseUtilsTest):
+ def setUp(self):
+ super().setUp()
+ self.filename = os.path.join(self.tempdir, 'foo')
+ self.contents = b'my contents'
+ with open(self.filename, 'wb') as f:
+ f.write(self.contents)
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename, open_function=self.recording_open_function
+ )
+ self.open_call_args = []
+
+ def tearDown(self):
+ self.deferred_open_file.close()
+ super().tearDown()
+
+ def recording_open_function(self, filename, mode):
+ self.open_call_args.append((filename, mode))
+ return open(filename, mode)
+
+ def open_nonseekable(self, filename, mode):
+ self.open_call_args.append((filename, mode))
+ return NonSeekableWriter(BytesIO(self.content))
+
+ def test_instantiation_does_not_open_file(self):
+ DeferredOpenFile(
+ self.filename, open_function=self.recording_open_function
+ )
+ self.assertEqual(len(self.open_call_args), 0)
+
+ def test_name(self):
+ self.assertEqual(self.deferred_open_file.name, self.filename)
+
+ def test_read(self):
+ content = self.deferred_open_file.read(2)
+ self.assertEqual(content, self.contents[0:2])
+ content = self.deferred_open_file.read(2)
+ self.assertEqual(content, self.contents[2:4])
+ self.assertEqual(len(self.open_call_args), 1)
+
+ def test_write(self):
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename,
+ mode='wb',
+ open_function=self.recording_open_function,
+ )
+
+ write_content = b'foo'
+ self.deferred_open_file.write(write_content)
+ self.deferred_open_file.write(write_content)
+ self.deferred_open_file.close()
+ # Both of the writes should now be in the file.
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(f.read(), write_content * 2)
+ # Open should have only been called once.
+ self.assertEqual(len(self.open_call_args), 1)
+
+ def test_seek(self):
+ self.deferred_open_file.seek(2)
+ content = self.deferred_open_file.read(2)
+ self.assertEqual(content, self.contents[2:4])
+ self.assertEqual(len(self.open_call_args), 1)
+
+ def test_open_does_not_seek_with_zero_start_byte(self):
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename,
+ mode='wb',
+ start_byte=0,
+ open_function=self.open_nonseekable,
+ )
+
+ try:
+ # If this seeks, an UnsupportedOperation error will be raised.
+ self.deferred_open_file.write(b'data')
+ except io.UnsupportedOperation:
+ self.fail('DeferredOpenFile seeked upon opening')
+
+ def test_open_seeks_with_nonzero_start_byte(self):
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename,
+ mode='wb',
+ start_byte=5,
+ open_function=self.open_nonseekable,
+ )
+
+ # Since a non-seekable file is being opened, calling Seek will raise
+ # an UnsupportedOperation error.
+ with self.assertRaises(io.UnsupportedOperation):
+ self.deferred_open_file.write(b'data')
+
+ def test_tell(self):
+ self.deferred_open_file.tell()
+ # tell() should not have opened the file if it has not been seeked
+ # or read because we know the start bytes upfront.
+ self.assertEqual(len(self.open_call_args), 0)
+
+ self.deferred_open_file.seek(2)
+ self.assertEqual(self.deferred_open_file.tell(), 2)
+ self.assertEqual(len(self.open_call_args), 1)
+
+ def test_open_args(self):
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename,
+ mode='ab+',
+ open_function=self.recording_open_function,
+ )
+ # Force an open
+ self.deferred_open_file.write(b'data')
+ self.assertEqual(len(self.open_call_args), 1)
+ self.assertEqual(self.open_call_args[0], (self.filename, 'ab+'))
+
+ def test_context_handler(self):
+ with self.deferred_open_file:
+ self.assertEqual(len(self.open_call_args), 1)
+
+
+class TestReadFileChunk(BaseUtilsTest):
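+ # ReadFileChunk exposes a bounded window of a larger file as an
+ # independent file-like object with its own zero-based offsets.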
+ def test_read_entire_chunk(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=3
+ )
+ self.assertEqual(chunk.read(), b'one')
+ self.assertEqual(chunk.read(), b'')
+
+ def test_read_with_amount_size(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=11, chunk_size=4
+ )
+ self.assertEqual(chunk.read(1), b'f')
+ self.assertEqual(chunk.read(1), b'o')
+ self.assertEqual(chunk.read(1), b'u')
+ self.assertEqual(chunk.read(1), b'r')
+ self.assertEqual(chunk.read(1), b'')
+
+ def test_reset_stream_emulation(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=11, chunk_size=4
+ )
+ self.assertEqual(chunk.read(), b'four')
+ chunk.seek(0)
+ self.assertEqual(chunk.read(), b'four')
+
+ def test_read_past_end_of_file(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=36, chunk_size=100000
+ )
+ self.assertEqual(chunk.read(), b'ten')
+ self.assertEqual(chunk.read(), b'')
+ self.assertEqual(len(chunk), 3)
+
+ def test_tell_and_seek(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=36, chunk_size=100000
+ )
+ self.assertEqual(chunk.tell(), 0)
+ self.assertEqual(chunk.read(), b'ten')
+ self.assertEqual(chunk.tell(), 3)
+ chunk.seek(0)
+ self.assertEqual(chunk.tell(), 0)
+ chunk.seek(1, whence=1)
+ self.assertEqual(chunk.tell(), 1)
+ chunk.seek(-1, whence=1)
+ self.assertEqual(chunk.tell(), 0)
+ chunk.seek(-1, whence=2)
+ self.assertEqual(chunk.tell(), 2)
+
+ def test_tell_and_seek_boundaries(self):
+ # Test to ensure ReadFileChunk behaves the same as the
+ # Python standard library around seeking and reading out
+ # of bounds in a file object.
+ data = b'abcdefghij12345678klmnopqrst'
+ start_pos = 10
+ chunk_size = 8
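+ # The chunk therefore exposes data[10:18] == b'12345678'.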
+
+ # Create test file
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(data)
+
+ # The ReadFileChunk should cover only the numeric substring of the data.
+ file_objects = [
+ ReadFileChunk.from_filename(
+ filename, start_byte=start_pos, chunk_size=chunk_size
+ )
+ ]
+
+ # Uncomment next line to validate we match Python's io.BytesIO
+ # file_objects.append(io.BytesIO(data[start_pos:start_pos+chunk_size]))
+
+ for obj in file_objects:
+ self._assert_whence_start_behavior(obj)
+ self._assert_whence_end_behavior(obj)
+ self._assert_whence_relative_behavior(obj)
+ self._assert_boundary_behavior(obj)
+
+ def _assert_whence_start_behavior(self, file_obj):
+ self.assertEqual(file_obj.tell(), 0)
+
+ file_obj.seek(1, 0)
+ self.assertEqual(file_obj.tell(), 1)
+
+ file_obj.seek(1)
+ self.assertEqual(file_obj.tell(), 1)
+ self.assertEqual(file_obj.read(), b'2345678')
+
+ file_obj.seek(3, 0)
+ self.assertEqual(file_obj.tell(), 3)
+
+ file_obj.seek(0, 0)
+ self.assertEqual(file_obj.tell(), 0)
+
+ def _assert_whence_relative_behavior(self, file_obj):
+ self.assertEqual(file_obj.tell(), 0)
+
+ file_obj.seek(2, 1)
+ self.assertEqual(file_obj.tell(), 2)
+
+ file_obj.seek(1, 1)
+ self.assertEqual(file_obj.tell(), 3)
+ self.assertEqual(file_obj.read(), b'45678')
+
+ file_obj.seek(20, 1)
+ self.assertEqual(file_obj.tell(), 28)
+
+ file_obj.seek(-30, 1)
+ self.assertEqual(file_obj.tell(), 0)
+ self.assertEqual(file_obj.read(), b'12345678')
+
+ file_obj.seek(-8, 1)
+ self.assertEqual(file_obj.tell(), 0)
+
+ def _assert_whence_end_behavior(self, file_obj):
+ self.assertEqual(file_obj.tell(), 0)
+
+ file_obj.seek(-1, 2)
+ self.assertEqual(file_obj.tell(), 7)
+
+ file_obj.seek(1, 2)
+ self.assertEqual(file_obj.tell(), 9)
+
+ file_obj.seek(3, 2)
+ self.assertEqual(file_obj.tell(), 11)
+ self.assertEqual(file_obj.read(), b'')
+
+ file_obj.seek(-15, 2)
+ self.assertEqual(file_obj.tell(), 0)
+ self.assertEqual(file_obj.read(), b'12345678')
+
+ file_obj.seek(-8, 2)
+ self.assertEqual(file_obj.tell(), 0)
+
+ def _assert_boundary_behavior(self, file_obj):
+ # Verify we're at the start
+ self.assertEqual(file_obj.tell(), 0)
+
+ # Verify we can't move backwards beyond start of file
+ file_obj.seek(-10, 1)
+ self.assertEqual(file_obj.tell(), 0)
+
+ # Verify we *can* move after end of file, but return nothing
+ file_obj.seek(10, 2)
+ self.assertEqual(file_obj.tell(), 18)
+ self.assertEqual(file_obj.read(), b'')
+ self.assertEqual(file_obj.read(10), b'')
+
+ # Verify we can partially rewind
+ file_obj.seek(-12, 1)
+ self.assertEqual(file_obj.tell(), 6)
+ self.assertEqual(file_obj.read(), b'78')
+ self.assertEqual(file_obj.tell(), 8)
+
+ # Verify we can rewind to start
+ file_obj.seek(0)
+ self.assertEqual(file_obj.tell(), 0)
+
+ def test_file_chunk_supports_context_manager(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'abc')
+ with ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=2
+ ) as chunk:
+ val = chunk.read()
+ self.assertEqual(val, b'ab')
+
+ def test_iter_is_always_empty(self):
+ # This tests the workaround for the httplib bug (see
+ # the source for more info).
+ filename = os.path.join(self.tempdir, 'foo')
+ open(filename, 'wb').close()
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=10
+ )
+ self.assertEqual(list(chunk), [])
+
+ def test_callback_is_invoked_on_read(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback],
+ )
+ chunk.read(1)
+ chunk.read(1)
+ chunk.read(1)
+ self.assertEqual(self.amounts_seen, [1, 1, 1])
+
+ def test_all_callbacks_invoked_on_read(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback, self.callback],
+ )
+ chunk.read(1)
+ chunk.read(1)
+ chunk.read(1)
+ # The list should be twice as long because there are two callbacks
+ # recording the amount read.
+ self.assertEqual(self.amounts_seen, [1, 1, 1, 1, 1, 1])
+
+ def test_callback_can_be_disabled(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback],
+ )
+ chunk.disable_callback()
+ # Now reading from the ReadFileChunk should not invoke
+ # the callback.
+ chunk.read()
+ self.assertEqual(self.amounts_seen, [])
+
+ def test_callback_will_also_be_triggered_by_seek(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback],
+ )
+ chunk.read(2)
+ chunk.seek(0)
+ chunk.read(2)
+ chunk.seek(1)
+ chunk.read(2)
+ self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2])
+
+ def test_callback_triggered_by_out_of_bound_seeks(self):
+ data = b'abcdefghij1234567890klmnopqr'
+
+ # Create test file
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(data)
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=10, chunk_size=10, callbacks=[self.callback]
+ )
+
+ # Seek calls that generate "0" progress are skipped by
+ # invoke_progress_callbacks and won't appear in the list.
+ expected_callback_prog = [10, -5, 5, -1, 1, -1, 1, -5, 5, -10]
+
+ self._assert_out_of_bound_start_seek(chunk, expected_callback_prog)
+ self._assert_out_of_bound_relative_seek(chunk, expected_callback_prog)
+ self._assert_out_of_bound_end_seek(chunk, expected_callback_prog)
+
+ def _assert_out_of_bound_start_seek(self, chunk, expected):
+ # clear amounts_seen
+ self.amounts_seen = []
+ self.assertEqual(self.amounts_seen, [])
+
+ # (position, change)
+ chunk.seek(20) # (20, 10)
+ chunk.seek(5) # (5, -5)
+ chunk.seek(20) # (20, 5)
+ chunk.seek(9) # (9, -1)
+ chunk.seek(20) # (20, 1)
+ chunk.seek(11) # (11, 0)
+ chunk.seek(20) # (20, 0)
+ chunk.seek(9) # (9, -1)
+ chunk.seek(20) # (20, 1)
+ chunk.seek(5) # (5, -5)
+ chunk.seek(20) # (20, 5)
+ chunk.seek(0) # (0, -10)
+ chunk.seek(0) # (0, 0)
+
+ self.assertEqual(self.amounts_seen, expected)
+
+ def _assert_out_of_bound_relative_seek(self, chunk, expected):
+ # clear amounts_seen
+ self.amounts_seen = []
+ self.assertEqual(self.amounts_seen, [])
+
+ # (position, change)
+ chunk.seek(20, 1) # (20, 10)
+ chunk.seek(-15, 1) # (5, -5)
+ chunk.seek(15, 1) # (20, 5)
+ chunk.seek(-11, 1) # (9, -1)
+ chunk.seek(11, 1) # (20, 1)
+ chunk.seek(-9, 1) # (11, 0)
+ chunk.seek(9, 1) # (20, 0)
+ chunk.seek(-11, 1) # (9, -1)
+ chunk.seek(11, 1) # (20, 1)
+ chunk.seek(-15, 1) # (5, -5)
+ chunk.seek(15, 1) # (20, 5)
+ chunk.seek(-20, 1) # (0, -10)
+ chunk.seek(-1000, 1) # (0, 0)
+
+ self.assertEqual(self.amounts_seen, expected)
+
+ def _assert_out_of_bound_end_seek(self, chunk, expected):
+ # clear amounts_seen
+ self.amounts_seen = []
+ self.assertEqual(self.amounts_seen, [])
+
+ # (position, change)
+ chunk.seek(10, 2) # (20, 10)
+ chunk.seek(-5, 2) # (5, -5)
+ chunk.seek(10, 2) # (20, 5)
+ chunk.seek(-1, 2) # (9, -1)
+ chunk.seek(10, 2) # (20, 1)
+ chunk.seek(1, 2) # (11, 0)
+ chunk.seek(10, 2) # (20, 0)
+ chunk.seek(-1, 2) # (9, -1)
+ chunk.seek(10, 2) # (20, 1)
+ chunk.seek(-5, 2) # (5, -5)
+ chunk.seek(10, 2) # (20, 5)
+ chunk.seek(-10, 2) # (0, -10)
+ chunk.seek(-1000, 2) # (0, 0)
+
+ self.assertEqual(self.amounts_seen, expected)
+
+ def test_close_callbacks(self):
+ with open(self.filename) as f:
+ chunk = ReadFileChunk(
+ f,
+ chunk_size=1,
+ full_file_size=3,
+ close_callbacks=[self.close_callback],
+ )
+ chunk.close()
+ self.assertEqual(self.num_close_callback_calls, 1)
+
+ def test_close_callbacks_when_not_enabled(self):
+ with open(self.filename) as f:
+ chunk = ReadFileChunk(
+ f,
+ chunk_size=1,
+ full_file_size=3,
+ enable_callbacks=False,
+ close_callbacks=[self.close_callback],
+ )
+ chunk.close()
+ self.assertEqual(self.num_close_callback_calls, 0)
+
+ def test_close_callbacks_when_context_handler_is_used(self):
+ with open(self.filename) as f:
+ with ReadFileChunk(
+ f,
+ chunk_size=1,
+ full_file_size=3,
+ close_callbacks=[self.close_callback],
+ ) as chunk:
+ chunk.read(1)
+ self.assertEqual(self.num_close_callback_calls, 1)
+
+ def test_signal_transferring(self):
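+ # signal_not_transferring suppresses progress callbacks until
+ # signal_transferring re-enables them.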
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback],
+ )
+ chunk.signal_not_transferring()
+ chunk.read(1)
+ self.assertEqual(self.amounts_seen, [])
+ chunk.signal_transferring()
+ chunk.read(1)
+ self.assertEqual(self.amounts_seen, [1])
+
+ def test_signal_transferring_to_underlying_fileobj(self):
+ underlying_stream = mock.Mock()
+ underlying_stream.tell.return_value = 0
+ chunk = ReadFileChunk(underlying_stream, 3, 3)
+ chunk.signal_transferring()
+ self.assertTrue(underlying_stream.signal_transferring.called)
+
+ def test_no_call_signal_transferring_to_underlying_fileobj(self):
+ underlying_stream = mock.Mock(io.RawIOBase)
+ underlying_stream.tell.return_value = 0
+ chunk = ReadFileChunk(underlying_stream, 3, 3)
+ try:
+ chunk.signal_transferring()
+ except AttributeError:
+ self.fail(
+ 'The stream should not have tried to call signal_transferring '
+ 'to the underlying stream.'
+ )
+
+ def test_signal_not_transferring_to_underlying_fileobj(self):
+ underlying_stream = mock.Mock()
+ underlying_stream.tell.return_value = 0
+ chunk = ReadFileChunk(underlying_stream, 3, 3)
+ chunk.signal_not_transferring()
+ self.assertTrue(underlying_stream.signal_not_transferring.called)
+
+ def test_no_call_signal_not_transferring_to_underlying_fileobj(self):
+ underlying_stream = mock.Mock(io.RawIOBase)
+ underlying_stream.tell.return_value = 0
+ chunk = ReadFileChunk(underlying_stream, 3, 3)
+ try:
+ chunk.signal_not_transferring()
+ except AttributeError:
+ self.fail(
+ 'The stream should not have tried to call '
+ 'signal_not_transferring to the underlying stream.'
+ )
+
+
+class TestStreamReaderProgress(BaseUtilsTest):
+ def test_proxies_to_wrapped_stream(self):
+ original_stream = StringIO('foobarbaz')
+ wrapped = StreamReaderProgress(original_stream)
+ self.assertEqual(wrapped.read(), 'foobarbaz')
+
+ def test_callback_invoked(self):
+ original_stream = StringIO('foobarbaz')
+ wrapped = StreamReaderProgress(
+ original_stream, [self.callback, self.callback]
+ )
+ self.assertEqual(wrapped.read(), 'foobarbaz')
+ self.assertEqual(self.amounts_seen, [9, 9])
+
+
+class TestTaskSemaphore(unittest.TestCase):
+ def setUp(self):
+ self.semaphore = TaskSemaphore(1)
+
+ def test_should_block_at_max_capacity(self):
+ self.semaphore.acquire('a', blocking=False)
+ with self.assertRaises(NoResourcesAvailable):
+ self.semaphore.acquire('a', blocking=False)
+
+ def test_release_capacity(self):
+ acquire_token = self.semaphore.acquire('a', blocking=False)
+ self.semaphore.release('a', acquire_token)
+ try:
+ self.semaphore.acquire('a', blocking=False)
+ except NoResourcesAvailable:
+ self.fail(
+ 'The release of the semaphore should have allowed for '
+ 'the second acquire to not be blocked'
+ )
+
+
+class TestSlidingWindowSemaphore(unittest.TestCase):
+ # These tests use block=False so tests will fail
+ # instead of hanging the test runner in the case of
+ # incorrect behavior.
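+ # The semaphore hands out increasing sequence numbers per tag and
+ # only frees capacity once the lowest outstanding number is released.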
+ def test_acquire_release_basic_case(self):
+ sem = SlidingWindowSemaphore(1)
+ # Count is 1
+
+ num = sem.acquire('a', blocking=False)
+ self.assertEqual(num, 0)
+ sem.release('a', 0)
+ # Count now back to 1.
+
+ def test_can_acquire_release_multiple_times(self):
+ sem = SlidingWindowSemaphore(1)
+ num = sem.acquire('a', blocking=False)
+ self.assertEqual(num, 0)
+ sem.release('a', num)
+
+ num = sem.acquire('a', blocking=False)
+ self.assertEqual(num, 1)
+ sem.release('a', num)
+
+ def test_can_acquire_a_range(self):
+ sem = SlidingWindowSemaphore(3)
+ self.assertEqual(sem.acquire('a', blocking=False), 0)
+ self.assertEqual(sem.acquire('a', blocking=False), 1)
+ self.assertEqual(sem.acquire('a', blocking=False), 2)
+ sem.release('a', 0)
+ sem.release('a', 1)
+ sem.release('a', 2)
+ # Now we're reset so we should be able to acquire the same
+ # sequence again.
+ self.assertEqual(sem.acquire('a', blocking=False), 3)
+ self.assertEqual(sem.acquire('a', blocking=False), 4)
+ self.assertEqual(sem.acquire('a', blocking=False), 5)
+ self.assertEqual(sem.current_count(), 0)
+
+ def test_counter_release_only_on_min_element(self):
+ sem = SlidingWindowSemaphore(3)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ # The count only increases when we free the min
+ # element. This means if we're currently failing to
+ # acquire now:
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('a', blocking=False)
+
+ # Then freeing a non-min element:
+ sem.release('a', 1)
+
+ # doesn't change anything. We still fail to acquire.
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('a', blocking=False)
+ self.assertEqual(sem.current_count(), 0)
+
+ def test_raises_error_when_count_is_zero(self):
+ sem = SlidingWindowSemaphore(3)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ # Count is now 0 so trying to acquire should fail.
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('a', blocking=False)
+
+ def test_release_counters_can_increment_counter_repeatedly(self):
+ sem = SlidingWindowSemaphore(3)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ # These two releases don't increment the counter
+ # because we're waiting on 0.
+ sem.release('a', 1)
+ sem.release('a', 2)
+ self.assertEqual(sem.current_count(), 0)
+ # But as soon as we release 0, we free up 0, 1, and 2.
+ sem.release('a', 0)
+ self.assertEqual(sem.current_count(), 3)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ def test_error_to_release_unknown_tag(self):
+ sem = SlidingWindowSemaphore(3)
+ with self.assertRaises(ValueError):
+ sem.release('a', 0)
+
+ def test_can_track_multiple_tags(self):
+ sem = SlidingWindowSemaphore(3)
+ self.assertEqual(sem.acquire('a', blocking=False), 0)
+ self.assertEqual(sem.acquire('b', blocking=False), 0)
+ self.assertEqual(sem.acquire('a', blocking=False), 1)
+
+ # We're at our max of 3 even though 2 are for A and 1 is for B.
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('a', blocking=False)
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('b', blocking=False)
+
+ def test_can_handle_multiple_tags_released(self):
+ sem = SlidingWindowSemaphore(4)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('b', blocking=False)
+ sem.acquire('b', blocking=False)
+
+ sem.release('b', 1)
+ sem.release('a', 1)
+ self.assertEqual(sem.current_count(), 0)
+
+ sem.release('b', 0)
+ self.assertEqual(sem.acquire('a', blocking=False), 2)
+
+ sem.release('a', 0)
+ self.assertEqual(sem.acquire('b', blocking=False), 2)
+
+ def test_is_error_to_release_unknown_sequence_number(self):
+ sem = SlidingWindowSemaphore(3)
+ sem.acquire('a', blocking=False)
+ with self.assertRaises(ValueError):
+ sem.release('a', 1)
+
+ def test_is_error_to_double_release(self):
+ # This is different than other error tests because
+ # we're verifying we can reset the state after an
+ # acquire/release cycle.
+ sem = SlidingWindowSemaphore(2)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.release('a', 0)
+ sem.release('a', 1)
+ self.assertEqual(sem.current_count(), 2)
+ with self.assertRaises(ValueError):
+ sem.release('a', 0)
+
+ def test_can_check_in_partial_range(self):
+ sem = SlidingWindowSemaphore(4)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ sem.release('a', 1)
+ sem.release('a', 3)
+ sem.release('a', 0)
+ self.assertEqual(sem.current_count(), 2)
+
+
+class TestThreadingPropertiesForSlidingWindowSemaphore(unittest.TestCase):
+ # These tests focus on multithreaded properties of the range
+ # semaphore. Basic functionality is tested in TestSlidingWindowSemaphore.
+ def setUp(self):
+ self.threads = []
+
+ def tearDown(self):
+ self.join_threads()
+
+ def join_threads(self):
+ for thread in self.threads:
+ thread.join()
+ self.threads = []
+
+ def start_threads(self):
+ for thread in self.threads:
+ thread.start()
+
+ def test_acquire_blocks_until_release_is_called(self):
+ sem = SlidingWindowSemaphore(2)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ def acquire():
+ # This next call to acquire will block.
+ self.assertEqual(sem.acquire('a', blocking=True), 2)
+
+ t = threading.Thread(target=acquire)
+ self.threads.append(t)
+ # Starting the thread will block the sem.acquire()
+ # in the acquire function above.
+ t.start()
+ # This still will keep the thread blocked.
+ sem.release('a', 1)
+ # Releasing the min element will unblock the thread.
+ sem.release('a', 0)
+ t.join()
+ sem.release('a', 2)
+
+ def test_stress_invariants_random_order(self):
+ sem = SlidingWindowSemaphore(100)
+ for _ in range(10):
+ recorded = []
+ for _ in range(100):
+ recorded.append(sem.acquire('a', blocking=False))
+ # Release them in randomized order. As long as we
+ # eventually free all 100, we should have all the
+ # resources released.
+ random.shuffle(recorded)
+ for i in recorded:
+ sem.release('a', i)
+
+ # Everything's freed so should be back at count == 100
+ self.assertEqual(sem.current_count(), 100)
+
+ def test_blocking_stress(self):
+ sem = SlidingWindowSemaphore(5)
+ num_threads = 10
+ num_iterations = 50
+
+ def acquire():
+ for _ in range(num_iterations):
+ num = sem.acquire('a', blocking=True)
+ time.sleep(0.001)
+ sem.release('a', num)
+
+ for i in range(num_threads):
+ t = threading.Thread(target=acquire)
+ self.threads.append(t)
+ self.start_threads()
+ self.join_threads()
+ # Should have all the available resources freed.
+ self.assertEqual(sem.current_count(), 5)
+ # Should have acquired num_threads * num_iterations
+ self.assertEqual(
+ sem.acquire('a', blocking=False), num_threads * num_iterations
+ )
+
+
+class TestAdjustChunksize(unittest.TestCase):
+ def setUp(self):
+ self.adjuster = ChunksizeAdjuster()
+
+ def test_valid_chunksize(self):
+ chunksize = 7 * (1024**2)
+ file_size = 8 * (1024**2)
+ new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
+ self.assertEqual(new_size, chunksize)
+
+ def test_chunksize_below_minimum(self):
+ chunksize = MIN_UPLOAD_CHUNKSIZE - 1
+ file_size = 3 * MIN_UPLOAD_CHUNKSIZE
+ new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
+ self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE)
+
+ def test_chunksize_above_maximum(self):
+ chunksize = MAX_SINGLE_UPLOAD_SIZE + 1
+ file_size = MAX_SINGLE_UPLOAD_SIZE * 2
+ new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
+ self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE)
+
+ def test_chunksize_too_small(self):
+ chunksize = 7 * (1024**2)
+ file_size = 5 * (1024**4)
+ # If we try to upload a 5TB file, we'll need to use 896MB part
+ # sizes.
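+ # The adjuster doubles the chunksize until the part count fits:
+ # 7MB * 2**7 = 896MB gives roughly 5,852 parts, under MAX_PARTS
+ # (10,000).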
+ new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
+ self.assertEqual(new_size, 896 * (1024**2))
+ num_parts = file_size / new_size
+ self.assertLessEqual(num_parts, MAX_PARTS)
+
+ def test_unknown_file_size_with_valid_chunksize(self):
+ chunksize = 7 * (1024**2)
+ new_size = self.adjuster.adjust_chunksize(chunksize)
+ self.assertEqual(new_size, chunksize)
+
+ def test_unknown_file_size_below_minimum(self):
+ chunksize = MIN_UPLOAD_CHUNKSIZE - 1
+ new_size = self.adjuster.adjust_chunksize(chunksize)
+ self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE)
+
+ def test_unknown_file_size_above_maximum(self):
+ chunksize = MAX_SINGLE_UPLOAD_SIZE + 1
+ new_size = self.adjuster.adjust_chunksize(chunksize)
+ self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE)