aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/s3transfer/py3
diff options
context:
space:
mode:
authorDevtools Arcadia <arcadia-devtools@yandex-team.ru>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /contrib/python/s3transfer/py3
downloadydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'contrib/python/s3transfer/py3')
-rw-r--r--contrib/python/s3transfer/py3/.dist-info/METADATA42
-rw-r--r--contrib/python/s3transfer/py3/.dist-info/top_level.txt1
-rw-r--r--contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml2
-rw-r--r--contrib/python/s3transfer/py3/LICENSE.txt202
-rw-r--r--contrib/python/s3transfer/py3/NOTICE.txt2
-rw-r--r--contrib/python/s3transfer/py3/README.rst13
-rw-r--r--contrib/python/s3transfer/py3/patches/01-fix-tests.patch242
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/__init__.py875
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/bandwidth.py439
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/compat.py94
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/constants.py29
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/copies.py368
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/crt.py644
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/delete.py71
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/download.py790
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/exceptions.py37
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/futures.py606
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/manager.py727
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/processpool.py1008
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/subscribers.py92
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/tasks.py387
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/upload.py795
-rw-r--r--contrib/python/s3transfer/py3/s3transfer/utils.py802
-rw-r--r--contrib/python/s3transfer/py3/tests/__init__.py531
-rw-r--r--contrib/python/s3transfer/py3/tests/functional/__init__.py12
-rw-r--r--contrib/python/s3transfer/py3/tests/functional/test_copy.py554
-rw-r--r--contrib/python/s3transfer/py3/tests/functional/test_crt.py267
-rw-r--r--contrib/python/s3transfer/py3/tests/functional/test_delete.py76
-rw-r--r--contrib/python/s3transfer/py3/tests/functional/test_download.py497
-rw-r--r--contrib/python/s3transfer/py3/tests/functional/test_manager.py191
-rw-r--r--contrib/python/s3transfer/py3/tests/functional/test_processpool.py281
-rw-r--r--contrib/python/s3transfer/py3/tests/functional/test_upload.py538
-rw-r--r--contrib/python/s3transfer/py3/tests/functional/test_utils.py41
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/__init__.py12
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py452
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/test_compat.py105
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/test_copies.py177
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/test_crt.py173
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/test_delete.py67
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/test_download.py999
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/test_futures.py696
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/test_manager.py143
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/test_processpool.py728
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py780
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/test_subscribers.py91
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/test_tasks.py833
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/test_upload.py694
-rw-r--r--contrib/python/s3transfer/py3/tests/unit/test_utils.py1189
-rw-r--r--contrib/python/s3transfer/py3/tests/ya.make44
-rw-r--r--contrib/python/s3transfer/py3/ya.make51
50 files changed, 18490 insertions, 0 deletions
diff --git a/contrib/python/s3transfer/py3/.dist-info/METADATA b/contrib/python/s3transfer/py3/.dist-info/METADATA
new file mode 100644
index 0000000000..7d635068d7
--- /dev/null
+++ b/contrib/python/s3transfer/py3/.dist-info/METADATA
@@ -0,0 +1,42 @@
+Metadata-Version: 2.1
+Name: s3transfer
+Version: 0.5.1
+Summary: An Amazon S3 Transfer Manager
+Home-page: https://github.com/boto/s3transfer
+Author: Amazon Web Services
+Author-email: kyknapp1@gmail.com
+License: Apache License 2.0
+Platform: UNKNOWN
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: Natural Language :: English
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.6
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Requires-Python: >= 3.6
+License-File: LICENSE.txt
+License-File: NOTICE.txt
+Requires-Dist: botocore (<2.0a.0,>=1.12.36)
+Provides-Extra: crt
+Requires-Dist: botocore[crt] (<2.0a.0,>=1.20.29) ; extra == 'crt'
+
+=====================================================
+s3transfer - An Amazon S3 Transfer Manager for Python
+=====================================================
+
+S3transfer is a Python library for managing Amazon S3 transfers.
+This project is maintained and published by Amazon Web Services.
+
+.. note::
+
+ This project is not currently GA. If you are planning to use this code in
+ production, make sure to lock to a minor version as interfaces may break
+ from minor version to minor version. For a basic, stable interface of
+ s3transfer, try the interfaces exposed in `boto3 <https://boto3.readthedocs.io/en/latest/guide/s3.html#using-the-transfer-manager>`__
+
+
diff --git a/contrib/python/s3transfer/py3/.dist-info/top_level.txt b/contrib/python/s3transfer/py3/.dist-info/top_level.txt
new file mode 100644
index 0000000000..572c6a92fb
--- /dev/null
+++ b/contrib/python/s3transfer/py3/.dist-info/top_level.txt
@@ -0,0 +1 @@
+s3transfer
diff --git a/contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml b/contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml
new file mode 100644
index 0000000000..f2f140fb3c
--- /dev/null
+++ b/contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml
@@ -0,0 +1,2 @@
+exclude:
+- tests/integration/.+
diff --git a/contrib/python/s3transfer/py3/LICENSE.txt b/contrib/python/s3transfer/py3/LICENSE.txt
new file mode 100644
index 0000000000..d645695673
--- /dev/null
+++ b/contrib/python/s3transfer/py3/LICENSE.txt
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/contrib/python/s3transfer/py3/NOTICE.txt b/contrib/python/s3transfer/py3/NOTICE.txt
new file mode 100644
index 0000000000..3e616fdf0c
--- /dev/null
+++ b/contrib/python/s3transfer/py3/NOTICE.txt
@@ -0,0 +1,2 @@
+s3transfer
+Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
diff --git a/contrib/python/s3transfer/py3/README.rst b/contrib/python/s3transfer/py3/README.rst
new file mode 100644
index 0000000000..441029109e
--- /dev/null
+++ b/contrib/python/s3transfer/py3/README.rst
@@ -0,0 +1,13 @@
+=====================================================
+s3transfer - An Amazon S3 Transfer Manager for Python
+=====================================================
+
+S3transfer is a Python library for managing Amazon S3 transfers.
+This project is maintained and published by Amazon Web Services.
+
+.. note::
+
+ This project is not currently GA. If you are planning to use this code in
+ production, make sure to lock to a minor version as interfaces may break
+ from minor version to minor version. For a basic, stable interface of
+ s3transfer, try the interfaces exposed in `boto3 <https://boto3.readthedocs.io/en/latest/guide/s3.html#using-the-transfer-manager>`__
diff --git a/contrib/python/s3transfer/py3/patches/01-fix-tests.patch b/contrib/python/s3transfer/py3/patches/01-fix-tests.patch
new file mode 100644
index 0000000000..aa8d3fab4e
--- /dev/null
+++ b/contrib/python/s3transfer/py3/patches/01-fix-tests.patch
@@ -0,0 +1,242 @@
+--- contrib/python/s3transfer/py3/tests/functional/test_copy.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_copy.py (working tree)
+@@ -15,7 +15,7 @@ from botocore.stub import Stubber
+
+ from s3transfer.manager import TransferConfig, TransferManager
+ from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE
+-from tests import BaseGeneralInterfaceTest, FileSizeProvider
++from __tests__ import BaseGeneralInterfaceTest, FileSizeProvider
+
+
+ class BaseCopyTest(BaseGeneralInterfaceTest):
+--- contrib/python/s3transfer/py3/tests/functional/test_crt.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_crt.py (working tree)
+@@ -18,7 +18,7 @@ from concurrent.futures import Future
+ from botocore.session import Session
+
+ from s3transfer.subscribers import BaseSubscriber
+-from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest
++from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest
+
+ if HAS_CRT:
+ import awscrt
+--- contrib/python/s3transfer/py3/tests/functional/test_delete.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_delete.py (working tree)
+@@ -11,7 +11,7 @@
+ # ANY KIND, either express or implied. See the License for the specific
+ # language governing permissions and limitations under the License.
+ from s3transfer.manager import TransferManager
+-from tests import BaseGeneralInterfaceTest
++from __tests__ import BaseGeneralInterfaceTest
+
+
+ class TestDeleteObject(BaseGeneralInterfaceTest):
+--- contrib/python/s3transfer/py3/tests/functional/test_download.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_download.py (working tree)
+@@ -23,7 +23,7 @@ from botocore.exceptions import ClientError
+ from s3transfer.compat import SOCKET_ERROR
+ from s3transfer.exceptions import RetriesExceededError
+ from s3transfer.manager import TransferConfig, TransferManager
+-from tests import (
++from __tests__ import (
+ BaseGeneralInterfaceTest,
+ FileSizeProvider,
+ NonSeekableWriter,
+--- contrib/python/s3transfer/py3/tests/functional/test_manager.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_manager.py (working tree)
+@@ -17,7 +17,7 @@ from botocore.awsrequest import create_request_object
+ from s3transfer.exceptions import CancelledError, FatalError
+ from s3transfer.futures import BaseExecutor
+ from s3transfer.manager import TransferConfig, TransferManager
+-from tests import StubbedClientTest, mock, skip_if_using_serial_implementation
++from __tests__ import StubbedClientTest, mock, skip_if_using_serial_implementation
+
+
+ class ArbitraryException(Exception):
+--- contrib/python/s3transfer/py3/tests/functional/test_processpool.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_processpool.py (working tree)
+@@ -21,7 +21,7 @@ from botocore.stub import Stubber
+
+ from s3transfer.exceptions import CancelledError
+ from s3transfer.processpool import ProcessPoolDownloader, ProcessTransferConfig
+-from tests import FileCreator, mock, unittest
++from __tests__ import FileCreator, mock, unittest
+
+
+ class StubbedClient:
+--- contrib/python/s3transfer/py3/tests/functional/test_upload.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_upload.py (working tree)
+@@ -23,7 +23,7 @@ from botocore.stub import ANY
+
+ from s3transfer.manager import TransferConfig, TransferManager
+ from s3transfer.utils import ChunksizeAdjuster
+-from tests import (
++from __tests__ import (
+ BaseGeneralInterfaceTest,
+ NonSeekableReader,
+ RecordingOSUtils,
+--- contrib/python/s3transfer/py3/tests/functional/test_utils.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_utils.py (working tree)
+@@ -16,7 +16,7 @@ import socket
+ import tempfile
+
+ from s3transfer.utils import OSUtils
+-from tests import skip_if_windows, unittest
++from __tests__ import skip_if_windows, unittest
+
+
+ @skip_if_windows('Windows does not support UNIX special files')
+--- contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py (working tree)
+@@ -25,7 +25,7 @@ from s3transfer.bandwidth import (
+ TimeUtils,
+ )
+ from s3transfer.futures import TransferCoordinator
+-from tests import mock, unittest
++from __tests__ import mock, unittest
+
+
+ class FixedIncrementalTickTimeUtils(TimeUtils):
+--- contrib/python/s3transfer/py3/tests/unit/test_compat.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_compat.py (working tree)
+@@ -17,7 +17,7 @@ import tempfile
+ from io import BytesIO
+
+ from s3transfer.compat import BaseManager, readable, seekable
+-from tests import skip_if_windows, unittest
++from __tests__ import skip_if_windows, unittest
+
+
+ class ErrorRaisingSeekWrapper:
+--- contrib/python/s3transfer/py3/tests/unit/test_copies.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_copies.py (working tree)
+@@ -11,7 +11,7 @@
+ # ANY KIND, either express or implied. See the License for the specific
+ # language governing permissions and limitations under the License.
+ from s3transfer.copies import CopyObjectTask, CopyPartTask
+-from tests import BaseTaskTest, RecordingSubscriber
++from __tests__ import BaseTaskTest, RecordingSubscriber
+
+
+ class BaseCopyTaskTest(BaseTaskTest):
+--- contrib/python/s3transfer/py3/tests/unit/test_crt.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_crt.py (working tree)
+@@ -15,7 +15,7 @@ from botocore.session import Session
+
+ from s3transfer.exceptions import TransferNotDoneError
+ from s3transfer.utils import CallArgs
+-from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest
++from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest
+
+ if HAS_CRT:
+ import awscrt.s3
+--- contrib/python/s3transfer/py3/tests/unit/test_delete.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_delete.py (working tree)
+@@ -11,7 +11,7 @@
+ # ANY KIND, either express or implied. See the License for the specific
+ # language governing permissions and limitations under the License.
+ from s3transfer.delete import DeleteObjectTask
+-from tests import BaseTaskTest
++from __tests__ import BaseTaskTest
+
+
+ class TestDeleteObjectTask(BaseTaskTest):
+--- contrib/python/s3transfer/py3/tests/unit/test_download.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_download.py (working tree)
+@@ -37,7 +37,7 @@ from s3transfer.download import (
+ from s3transfer.exceptions import RetriesExceededError
+ from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG, BoundedExecutor
+ from s3transfer.utils import CallArgs, OSUtils
+-from tests import (
++from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ FileCreator,
+--- contrib/python/s3transfer/py3/tests/unit/test_futures.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_futures.py (working tree)
+@@ -37,7 +37,7 @@ from s3transfer.utils import (
+ NoResourcesAvailable,
+ TaskSemaphore,
+ )
+-from tests import (
++from __tests__ import (
+ RecordingExecutor,
+ TransferCoordinatorWithInterrupt,
+ mock,
+--- contrib/python/s3transfer/py3/tests/unit/test_manager.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_manager.py (working tree)
+@@ -16,7 +16,7 @@ from concurrent.futures import ThreadPoolExecutor
+ from s3transfer.exceptions import CancelledError, FatalError
+ from s3transfer.futures import TransferCoordinator
+ from s3transfer.manager import TransferConfig, TransferCoordinatorController
+-from tests import TransferCoordinatorWithInterrupt, unittest
++from __tests__ import TransferCoordinatorWithInterrupt, unittest
+
+
+ class FutureResultException(Exception):
+--- contrib/python/s3transfer/py3/tests/unit/test_processpool.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_processpool.py (working tree)
+@@ -39,7 +39,7 @@ from s3transfer.processpool import (
+ ignore_ctrl_c,
+ )
+ from s3transfer.utils import CallArgs, OSUtils
+-from tests import (
++from __tests__ import (
+ FileCreator,
+ StreamWithError,
+ StubbedClientTest,
+--- contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py (working tree)
+@@ -33,7 +33,7 @@ from s3transfer import (
+ random_file_extension,
+ )
+ from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError
+-from tests import mock, unittest
++from __tests__ import mock, unittest
+
+
+ class InMemoryOSLayer(OSUtils):
+--- contrib/python/s3transfer/py3/tests/unit/test_subscribers.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_subscribers.py (working tree)
+@@ -12,7 +12,7 @@
+ # language governing permissions and limitations under the License.
+ from s3transfer.exceptions import InvalidSubscriberMethodError
+ from s3transfer.subscribers import BaseSubscriber
+-from tests import unittest
++from __tests__ import unittest
+
+
+ class ExtraMethodsSubscriber(BaseSubscriber):
+--- contrib/python/s3transfer/py3/tests/unit/test_tasks.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_tasks.py (working tree)
+@@ -23,7 +23,7 @@ from s3transfer.tasks import (
+ Task,
+ )
+ from s3transfer.utils import CallArgs, FunctionContainer, get_callbacks
+-from tests import (
++from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ RecordingSubscriber,
+--- contrib/python/s3transfer/py3/tests/unit/test_upload.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_upload.py (working tree)
+@@ -32,7 +32,7 @@ from s3transfer.upload import (
+ UploadSubmissionTask,
+ )
+ from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE, CallArgs, OSUtils
+-from tests import (
++from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ FileSizeProvider,
+--- contrib/python/s3transfer/py3/tests/unit/test_utils.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_utils.py (working tree)
+@@ -43,7 +43,7 @@ from s3transfer.utils import (
+ invoke_progress_callbacks,
+ random_file_extension,
+ )
+-from tests import NonSeekableWriter, RecordingSubscriber, mock, unittest
++from __tests__ import NonSeekableWriter, RecordingSubscriber, mock, unittest
+
+
+ class TestGetCallbacks(unittest.TestCase):
diff --git a/contrib/python/s3transfer/py3/s3transfer/__init__.py b/contrib/python/s3transfer/py3/s3transfer/__init__.py
new file mode 100644
index 0000000000..1a749c712e
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/__init__.py
@@ -0,0 +1,875 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Abstractions over S3's upload/download operations.
+
+This module provides high level abstractions for efficient
+uploads/downloads. It handles several things for the user:
+
+* Automatically switching to multipart transfers when
+ a file is over a specific size threshold
+* Uploading/downloading a file in parallel
+* Throttling based on max bandwidth
+* Progress callbacks to monitor transfers
+* Retries. While botocore handles retries for streaming uploads,
+ it is not possible for it to handle retries for streaming
+ downloads. This module handles retries for both cases so
+ you don't need to implement any retry logic yourself.
+
+This module has a reasonable set of defaults. It also allows you
+to configure many aspects of the transfer process including:
+
+* Multipart threshold size
+* Max parallel downloads
+* Max bandwidth
+* Socket timeouts
+* Retry amounts
+
+There is no support for s3->s3 multipart copies at this
+time.
+
+
+.. _ref_s3transfer_usage:
+
+Usage
+=====
+
+The simplest way to use this module is:
+
+.. code-block:: python
+
+ client = boto3.client('s3', 'us-west-2')
+ transfer = S3Transfer(client)
+ # Upload /tmp/myfile to s3://bucket/key
+ transfer.upload_file('/tmp/myfile', 'bucket', 'key')
+
+ # Download s3://bucket/key to /tmp/myfile
+ transfer.download_file('bucket', 'key', '/tmp/myfile')
+
+The ``upload_file`` and ``download_file`` methods also accept
+``**kwargs``, which will be forwarded through to the corresponding
+client operation. Here are a few examples using ``upload_file``::
+
+ # Making the object public
+ transfer.upload_file('/tmp/myfile', 'bucket', 'key',
+ extra_args={'ACL': 'public-read'})
+
+ # Setting metadata
+ transfer.upload_file('/tmp/myfile', 'bucket', 'key',
+ extra_args={'Metadata': {'a': 'b', 'c': 'd'}})
+
+ # Setting content type
+ transfer.upload_file('/tmp/myfile.json', 'bucket', 'key',
+ extra_args={'ContentType': "application/json"})
+
+
+The ``S3Transfer`` class also supports progress callbacks so you can
+provide transfer progress to users. Both the ``upload_file`` and
+``download_file`` methods take an optional ``callback`` parameter.
+Here's an example of how to print a simple progress percentage
+to the user:
+
+.. code-block:: python
+
+ class ProgressPercentage(object):
+ def __init__(self, filename):
+ self._filename = filename
+ self._size = float(os.path.getsize(filename))
+ self._seen_so_far = 0
+ self._lock = threading.Lock()
+
+ def __call__(self, bytes_amount):
+ # To simplify we'll assume this is hooked up
+ # to a single filename.
+ with self._lock:
+ self._seen_so_far += bytes_amount
+ percentage = (self._seen_so_far / self._size) * 100
+ sys.stdout.write(
+ "\r%s %s / %s (%.2f%%)" % (self._filename, self._seen_so_far,
+ self._size, percentage))
+ sys.stdout.flush()
+
+
+ transfer = S3Transfer(boto3.client('s3', 'us-west-2'))
+ # Upload /tmp/myfile to s3://bucket/key and print upload progress.
+ transfer.upload_file('/tmp/myfile', 'bucket', 'key',
+ callback=ProgressPercentage('/tmp/myfile'))
+
+
+
+You can also provide a TransferConfig object to the S3Transfer
+object that gives you more fine grained control over the
+transfer. For example:
+
+.. code-block:: python
+
+ client = boto3.client('s3', 'us-west-2')
+ config = TransferConfig(
+ multipart_threshold=8 * 1024 * 1024,
+ max_concurrency=10,
+ num_download_attempts=10,
+ )
+ transfer = S3Transfer(client, config)
+ transfer.upload_file('/tmp/foo', 'bucket', 'key')
+
+
+"""
+import concurrent.futures
+import functools
+import logging
+import math
+import os
+import queue
+import random
+import socket
+import string
+import threading
+
+from botocore.compat import six # noqa: F401
+from botocore.exceptions import IncompleteReadError
+from botocore.vendored.requests.packages.urllib3.exceptions import (
+ ReadTimeoutError,
+)
+
+import s3transfer.compat
+from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError
+
+__author__ = 'Amazon Web Services'
+__version__ = '0.5.1'
+
+
+class NullHandler(logging.Handler):
+ def emit(self, record):
+ pass
+
+
+logger = logging.getLogger(__name__)
+logger.addHandler(NullHandler())
+
+MB = 1024 * 1024
+SHUTDOWN_SENTINEL = object()
+
+
+def random_file_extension(num_digits=8):
+ return ''.join(random.choice(string.hexdigits) for _ in range(num_digits))
+
+
+def disable_upload_callbacks(request, operation_name, **kwargs):
+ if operation_name in ['PutObject', 'UploadPart'] and hasattr(
+ request.body, 'disable_callback'
+ ):
+ request.body.disable_callback()
+
+
+def enable_upload_callbacks(request, operation_name, **kwargs):
+ if operation_name in ['PutObject', 'UploadPart'] and hasattr(
+ request.body, 'enable_callback'
+ ):
+ request.body.enable_callback()
+
+
+class QueueShutdownError(Exception):
+ pass
+
+
+class ReadFileChunk:
+ def __init__(
+ self,
+ fileobj,
+ start_byte,
+ chunk_size,
+ full_file_size,
+ callback=None,
+ enable_callback=True,
+ ):
+ """
+
+ Given a file object shown below:
+
+ |___________________________________________________|
+ 0 | | full_file_size
+ |----chunk_size---|
+ start_byte
+
+ :type fileobj: file
+ :param fileobj: File like object
+
+ :type start_byte: int
+ :param start_byte: The first byte from which to start reading.
+
+ :type chunk_size: int
+ :param chunk_size: The max chunk size to read. Trying to read
+ pass the end of the chunk size will behave like you've
+ reached the end of the file.
+
+ :type full_file_size: int
+ :param full_file_size: The entire content length associated
+ with ``fileobj``.
+
+ :type callback: function(amount_read)
+ :param callback: Called whenever data is read from this object.
+
+ """
+ self._fileobj = fileobj
+ self._start_byte = start_byte
+ self._size = self._calculate_file_size(
+ self._fileobj,
+ requested_size=chunk_size,
+ start_byte=start_byte,
+ actual_file_size=full_file_size,
+ )
+ self._fileobj.seek(self._start_byte)
+ self._amount_read = 0
+ self._callback = callback
+ self._callback_enabled = enable_callback
+
+ @classmethod
+ def from_filename(
+ cls,
+ filename,
+ start_byte,
+ chunk_size,
+ callback=None,
+ enable_callback=True,
+ ):
+ """Convenience factory function to create from a filename.
+
+ :type start_byte: int
+ :param start_byte: The first byte from which to start reading.
+
+ :type chunk_size: int
+ :param chunk_size: The max chunk size to read. Trying to read
+ pass the end of the chunk size will behave like you've
+ reached the end of the file.
+
+ :type full_file_size: int
+ :param full_file_size: The entire content length associated
+ with ``fileobj``.
+
+ :type callback: function(amount_read)
+ :param callback: Called whenever data is read from this object.
+
+ :type enable_callback: bool
+ :param enable_callback: Indicate whether to invoke callback
+ during read() calls.
+
+ :rtype: ``ReadFileChunk``
+ :return: A new instance of ``ReadFileChunk``
+
+ """
+ f = open(filename, 'rb')
+ file_size = os.fstat(f.fileno()).st_size
+ return cls(
+ f, start_byte, chunk_size, file_size, callback, enable_callback
+ )
+
+ def _calculate_file_size(
+ self, fileobj, requested_size, start_byte, actual_file_size
+ ):
+ max_chunk_size = actual_file_size - start_byte
+ return min(max_chunk_size, requested_size)
+
+ def read(self, amount=None):
+ if amount is None:
+ amount_to_read = self._size - self._amount_read
+ else:
+ amount_to_read = min(self._size - self._amount_read, amount)
+ data = self._fileobj.read(amount_to_read)
+ self._amount_read += len(data)
+ if self._callback is not None and self._callback_enabled:
+ self._callback(len(data))
+ return data
+
+ def enable_callback(self):
+ self._callback_enabled = True
+
+ def disable_callback(self):
+ self._callback_enabled = False
+
+ def seek(self, where):
+ self._fileobj.seek(self._start_byte + where)
+ if self._callback is not None and self._callback_enabled:
+ # To also rewind the callback() for an accurate progress report
+ self._callback(where - self._amount_read)
+ self._amount_read = where
+
+ def close(self):
+ self._fileobj.close()
+
+ def tell(self):
+ return self._amount_read
+
+ def __len__(self):
+ # __len__ is defined because requests will try to determine the length
+ # of the stream to set a content length. In the normal case
+ # of the file it will just stat the file, but we need to change that
+ # behavior. By providing a __len__, requests will use that instead
+ # of stat'ing the file.
+ return self._size
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ self.close()
+
+ def __iter__(self):
+ # This is a workaround for http://bugs.python.org/issue17575
+ # Basically httplib will try to iterate over the contents, even
+ # if its a file like object. This wasn't noticed because we've
+ # already exhausted the stream so iterating over the file immediately
+ # stops, which is what we're simulating here.
+ return iter([])
+
+
+class StreamReaderProgress:
+ """Wrapper for a read only stream that adds progress callbacks."""
+
+ def __init__(self, stream, callback=None):
+ self._stream = stream
+ self._callback = callback
+
+ def read(self, *args, **kwargs):
+ value = self._stream.read(*args, **kwargs)
+ if self._callback is not None:
+ self._callback(len(value))
+ return value
+
+
+class OSUtils:
+ def get_file_size(self, filename):
+ return os.path.getsize(filename)
+
+ def open_file_chunk_reader(self, filename, start_byte, size, callback):
+ return ReadFileChunk.from_filename(
+ filename, start_byte, size, callback, enable_callback=False
+ )
+
+ def open(self, filename, mode):
+ return open(filename, mode)
+
+ def remove_file(self, filename):
+ """Remove a file, noop if file does not exist."""
+ # Unlike os.remove, if the file does not exist,
+ # then this method does nothing.
+ try:
+ os.remove(filename)
+ except OSError:
+ pass
+
+ def rename_file(self, current_filename, new_filename):
+ s3transfer.compat.rename_file(current_filename, new_filename)
+
+
+class MultipartUploader:
+ # These are the extra_args that need to be forwarded onto
+ # subsequent upload_parts.
+ UPLOAD_PART_ARGS = [
+ 'SSECustomerKey',
+ 'SSECustomerAlgorithm',
+ 'SSECustomerKeyMD5',
+ 'RequestPayer',
+ ]
+
+ def __init__(
+ self,
+ client,
+ config,
+ osutil,
+ executor_cls=concurrent.futures.ThreadPoolExecutor,
+ ):
+ self._client = client
+ self._config = config
+ self._os = osutil
+ self._executor_cls = executor_cls
+
+ def _extra_upload_part_args(self, extra_args):
+ # Only the args in UPLOAD_PART_ARGS actually need to be passed
+ # onto the upload_part calls.
+ upload_parts_args = {}
+ for key, value in extra_args.items():
+ if key in self.UPLOAD_PART_ARGS:
+ upload_parts_args[key] = value
+ return upload_parts_args
+
+ def upload_file(self, filename, bucket, key, callback, extra_args):
+ response = self._client.create_multipart_upload(
+ Bucket=bucket, Key=key, **extra_args
+ )
+ upload_id = response['UploadId']
+ try:
+ parts = self._upload_parts(
+ upload_id, filename, bucket, key, callback, extra_args
+ )
+ except Exception as e:
+ logger.debug(
+ "Exception raised while uploading parts, "
+ "aborting multipart upload.",
+ exc_info=True,
+ )
+ self._client.abort_multipart_upload(
+ Bucket=bucket, Key=key, UploadId=upload_id
+ )
+ raise S3UploadFailedError(
+ "Failed to upload {} to {}: {}".format(
+ filename, '/'.join([bucket, key]), e
+ )
+ )
+ self._client.complete_multipart_upload(
+ Bucket=bucket,
+ Key=key,
+ UploadId=upload_id,
+ MultipartUpload={'Parts': parts},
+ )
+
+ def _upload_parts(
+ self, upload_id, filename, bucket, key, callback, extra_args
+ ):
+ upload_parts_extra_args = self._extra_upload_part_args(extra_args)
+ parts = []
+ part_size = self._config.multipart_chunksize
+ num_parts = int(
+ math.ceil(self._os.get_file_size(filename) / float(part_size))
+ )
+ max_workers = self._config.max_concurrency
+ with self._executor_cls(max_workers=max_workers) as executor:
+ upload_partial = functools.partial(
+ self._upload_one_part,
+ filename,
+ bucket,
+ key,
+ upload_id,
+ part_size,
+ upload_parts_extra_args,
+ callback,
+ )
+ for part in executor.map(upload_partial, range(1, num_parts + 1)):
+ parts.append(part)
+ return parts
+
+ def _upload_one_part(
+ self,
+ filename,
+ bucket,
+ key,
+ upload_id,
+ part_size,
+ extra_args,
+ callback,
+ part_number,
+ ):
+ open_chunk_reader = self._os.open_file_chunk_reader
+ with open_chunk_reader(
+ filename, part_size * (part_number - 1), part_size, callback
+ ) as body:
+ response = self._client.upload_part(
+ Bucket=bucket,
+ Key=key,
+ UploadId=upload_id,
+ PartNumber=part_number,
+ Body=body,
+ **extra_args,
+ )
+ etag = response['ETag']
+ return {'ETag': etag, 'PartNumber': part_number}
+
+
+class ShutdownQueue(queue.Queue):
+ """A queue implementation that can be shutdown.
+
+ Shutting down a queue means that this class adds a
+ trigger_shutdown method that will trigger all subsequent
+ calls to put() to fail with a ``QueueShutdownError``.
+
+ It purposefully deviates from queue.Queue, and is *not* meant
+ to be a drop in replacement for ``queue.Queue``.
+
+ """
+
+ def _init(self, maxsize):
+ self._shutdown = False
+ self._shutdown_lock = threading.Lock()
+ # queue.Queue is an old style class so we don't use super().
+ return queue.Queue._init(self, maxsize)
+
+ def trigger_shutdown(self):
+ with self._shutdown_lock:
+ self._shutdown = True
+ logger.debug("The IO queue is now shutdown.")
+
+ def put(self, item):
+ # Note: this is not sufficient, it's still possible to deadlock!
+ # Need to hook into the condition vars used by this class.
+ with self._shutdown_lock:
+ if self._shutdown:
+ raise QueueShutdownError(
+ "Cannot put item to queue when " "queue has been shutdown."
+ )
+ return queue.Queue.put(self, item)
+
+
+class MultipartDownloader:
+ def __init__(
+ self,
+ client,
+ config,
+ osutil,
+ executor_cls=concurrent.futures.ThreadPoolExecutor,
+ ):
+ self._client = client
+ self._config = config
+ self._os = osutil
+ self._executor_cls = executor_cls
+ self._ioqueue = ShutdownQueue(self._config.max_io_queue)
+
+ def download_file(
+ self, bucket, key, filename, object_size, extra_args, callback=None
+ ):
+ with self._executor_cls(max_workers=2) as controller:
+ # 1 thread for the future that manages the uploading of files
+ # 1 thread for the future that manages IO writes.
+ download_parts_handler = functools.partial(
+ self._download_file_as_future,
+ bucket,
+ key,
+ filename,
+ object_size,
+ callback,
+ )
+ parts_future = controller.submit(download_parts_handler)
+
+ io_writes_handler = functools.partial(
+ self._perform_io_writes, filename
+ )
+ io_future = controller.submit(io_writes_handler)
+ results = concurrent.futures.wait(
+ [parts_future, io_future],
+ return_when=concurrent.futures.FIRST_EXCEPTION,
+ )
+ self._process_future_results(results)
+
+ def _process_future_results(self, futures):
+ finished, unfinished = futures
+ for future in finished:
+ future.result()
+
+ def _download_file_as_future(
+ self, bucket, key, filename, object_size, callback
+ ):
+ part_size = self._config.multipart_chunksize
+ num_parts = int(math.ceil(object_size / float(part_size)))
+ max_workers = self._config.max_concurrency
+ download_partial = functools.partial(
+ self._download_range,
+ bucket,
+ key,
+ filename,
+ part_size,
+ num_parts,
+ callback,
+ )
+ try:
+ with self._executor_cls(max_workers=max_workers) as executor:
+ list(executor.map(download_partial, range(num_parts)))
+ finally:
+ self._ioqueue.put(SHUTDOWN_SENTINEL)
+
+ def _calculate_range_param(self, part_size, part_index, num_parts):
+ start_range = part_index * part_size
+ if part_index == num_parts - 1:
+ end_range = ''
+ else:
+ end_range = start_range + part_size - 1
+ range_param = f'bytes={start_range}-{end_range}'
+ return range_param
+
+ def _download_range(
+ self, bucket, key, filename, part_size, num_parts, callback, part_index
+ ):
+ try:
+ range_param = self._calculate_range_param(
+ part_size, part_index, num_parts
+ )
+
+ max_attempts = self._config.num_download_attempts
+ last_exception = None
+ for i in range(max_attempts):
+ try:
+ logger.debug("Making get_object call.")
+ response = self._client.get_object(
+ Bucket=bucket, Key=key, Range=range_param
+ )
+ streaming_body = StreamReaderProgress(
+ response['Body'], callback
+ )
+ buffer_size = 1024 * 16
+ current_index = part_size * part_index
+ for chunk in iter(
+ lambda: streaming_body.read(buffer_size), b''
+ ):
+ self._ioqueue.put((current_index, chunk))
+ current_index += len(chunk)
+ return
+ except (
+ socket.timeout,
+ OSError,
+ ReadTimeoutError,
+ IncompleteReadError,
+ ) as e:
+ logger.debug(
+ "Retrying exception caught (%s), "
+ "retrying request, (attempt %s / %s)",
+ e,
+ i,
+ max_attempts,
+ exc_info=True,
+ )
+ last_exception = e
+ continue
+ raise RetriesExceededError(last_exception)
+ finally:
+ logger.debug("EXITING _download_range for part: %s", part_index)
+
+ def _perform_io_writes(self, filename):
+ with self._os.open(filename, 'wb') as f:
+ while True:
+ task = self._ioqueue.get()
+ if task is SHUTDOWN_SENTINEL:
+ logger.debug(
+ "Shutdown sentinel received in IO handler, "
+ "shutting down IO handler."
+ )
+ return
+ else:
+ try:
+ offset, data = task
+ f.seek(offset)
+ f.write(data)
+ except Exception as e:
+ logger.debug(
+ "Caught exception in IO thread: %s",
+ e,
+ exc_info=True,
+ )
+ self._ioqueue.trigger_shutdown()
+ raise
+
+
+class TransferConfig:
+ def __init__(
+ self,
+ multipart_threshold=8 * MB,
+ max_concurrency=10,
+ multipart_chunksize=8 * MB,
+ num_download_attempts=5,
+ max_io_queue=100,
+ ):
+ self.multipart_threshold = multipart_threshold
+ self.max_concurrency = max_concurrency
+ self.multipart_chunksize = multipart_chunksize
+ self.num_download_attempts = num_download_attempts
+ self.max_io_queue = max_io_queue
+
+
+class S3Transfer:
+
+ ALLOWED_DOWNLOAD_ARGS = [
+ 'VersionId',
+ 'SSECustomerAlgorithm',
+ 'SSECustomerKey',
+ 'SSECustomerKeyMD5',
+ 'RequestPayer',
+ ]
+
+ ALLOWED_UPLOAD_ARGS = [
+ 'ACL',
+ 'CacheControl',
+ 'ContentDisposition',
+ 'ContentEncoding',
+ 'ContentLanguage',
+ 'ContentType',
+ 'Expires',
+ 'GrantFullControl',
+ 'GrantRead',
+ 'GrantReadACP',
+ 'GrantWriteACL',
+ 'Metadata',
+ 'RequestPayer',
+ 'ServerSideEncryption',
+ 'StorageClass',
+ 'SSECustomerAlgorithm',
+ 'SSECustomerKey',
+ 'SSECustomerKeyMD5',
+ 'SSEKMSKeyId',
+ 'SSEKMSEncryptionContext',
+ 'Tagging',
+ ]
+
+ def __init__(self, client, config=None, osutil=None):
+ self._client = client
+ if config is None:
+ config = TransferConfig()
+ self._config = config
+ if osutil is None:
+ osutil = OSUtils()
+ self._osutil = osutil
+
+ def upload_file(
+ self, filename, bucket, key, callback=None, extra_args=None
+ ):
+ """Upload a file to an S3 object.
+
+ Variants have also been injected into S3 client, Bucket and Object.
+ You don't have to use S3Transfer.upload_file() directly.
+ """
+ if extra_args is None:
+ extra_args = {}
+ self._validate_all_known_args(extra_args, self.ALLOWED_UPLOAD_ARGS)
+ events = self._client.meta.events
+ events.register_first(
+ 'request-created.s3',
+ disable_upload_callbacks,
+ unique_id='s3upload-callback-disable',
+ )
+ events.register_last(
+ 'request-created.s3',
+ enable_upload_callbacks,
+ unique_id='s3upload-callback-enable',
+ )
+ if (
+ self._osutil.get_file_size(filename)
+ >= self._config.multipart_threshold
+ ):
+ self._multipart_upload(filename, bucket, key, callback, extra_args)
+ else:
+ self._put_object(filename, bucket, key, callback, extra_args)
+
+ def _put_object(self, filename, bucket, key, callback, extra_args):
+ # We're using open_file_chunk_reader so we can take advantage of the
+ # progress callback functionality.
+ open_chunk_reader = self._osutil.open_file_chunk_reader
+ with open_chunk_reader(
+ filename,
+ 0,
+ self._osutil.get_file_size(filename),
+ callback=callback,
+ ) as body:
+ self._client.put_object(
+ Bucket=bucket, Key=key, Body=body, **extra_args
+ )
+
+ def download_file(
+ self, bucket, key, filename, extra_args=None, callback=None
+ ):
+ """Download an S3 object to a file.
+
+ Variants have also been injected into S3 client, Bucket and Object.
+ You don't have to use S3Transfer.download_file() directly.
+ """
+ # This method will issue a ``head_object`` request to determine
+ # the size of the S3 object. This is used to determine if the
+ # object is downloaded in parallel.
+ if extra_args is None:
+ extra_args = {}
+ self._validate_all_known_args(extra_args, self.ALLOWED_DOWNLOAD_ARGS)
+ object_size = self._object_size(bucket, key, extra_args)
+ temp_filename = filename + os.extsep + random_file_extension()
+ try:
+ self._download_file(
+ bucket, key, temp_filename, object_size, extra_args, callback
+ )
+ except Exception:
+ logger.debug(
+ "Exception caught in download_file, removing partial "
+ "file: %s",
+ temp_filename,
+ exc_info=True,
+ )
+ self._osutil.remove_file(temp_filename)
+ raise
+ else:
+ self._osutil.rename_file(temp_filename, filename)
+
+ def _download_file(
+ self, bucket, key, filename, object_size, extra_args, callback
+ ):
+ if object_size >= self._config.multipart_threshold:
+ self._ranged_download(
+ bucket, key, filename, object_size, extra_args, callback
+ )
+ else:
+ self._get_object(bucket, key, filename, extra_args, callback)
+
+ def _validate_all_known_args(self, actual, allowed):
+ for kwarg in actual:
+ if kwarg not in allowed:
+ raise ValueError(
+ "Invalid extra_args key '%s', "
+ "must be one of: %s" % (kwarg, ', '.join(allowed))
+ )
+
+ def _ranged_download(
+ self, bucket, key, filename, object_size, extra_args, callback
+ ):
+ downloader = MultipartDownloader(
+ self._client, self._config, self._osutil
+ )
+ downloader.download_file(
+ bucket, key, filename, object_size, extra_args, callback
+ )
+
+ def _get_object(self, bucket, key, filename, extra_args, callback):
+ # precondition: num_download_attempts > 0
+ max_attempts = self._config.num_download_attempts
+ last_exception = None
+ for i in range(max_attempts):
+ try:
+ return self._do_get_object(
+ bucket, key, filename, extra_args, callback
+ )
+ except (
+ socket.timeout,
+ OSError,
+ ReadTimeoutError,
+ IncompleteReadError,
+ ) as e:
+ # TODO: we need a way to reset the callback if the
+ # download failed.
+ logger.debug(
+ "Retrying exception caught (%s), "
+ "retrying request, (attempt %s / %s)",
+ e,
+ i,
+ max_attempts,
+ exc_info=True,
+ )
+ last_exception = e
+ continue
+ raise RetriesExceededError(last_exception)
+
+ def _do_get_object(self, bucket, key, filename, extra_args, callback):
+ response = self._client.get_object(
+ Bucket=bucket, Key=key, **extra_args
+ )
+ streaming_body = StreamReaderProgress(response['Body'], callback)
+ with self._osutil.open(filename, 'wb') as f:
+ for chunk in iter(lambda: streaming_body.read(8192), b''):
+ f.write(chunk)
+
+ def _object_size(self, bucket, key, extra_args):
+ return self._client.head_object(Bucket=bucket, Key=key, **extra_args)[
+ 'ContentLength'
+ ]
+
+ def _multipart_upload(self, filename, bucket, key, callback, extra_args):
+ uploader = MultipartUploader(self._client, self._config, self._osutil)
+ uploader.upload_file(filename, bucket, key, callback, extra_args)
diff --git a/contrib/python/s3transfer/py3/s3transfer/bandwidth.py b/contrib/python/s3transfer/py3/s3transfer/bandwidth.py
new file mode 100644
index 0000000000..9bac5885e1
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/bandwidth.py
@@ -0,0 +1,439 @@
+# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import threading
+import time
+
+
+class RequestExceededException(Exception):
+ def __init__(self, requested_amt, retry_time):
+ """Error when requested amount exceeds what is allowed
+
+ The request that raised this error should be retried after waiting
+ the time specified by ``retry_time``.
+
+ :type requested_amt: int
+ :param requested_amt: The originally requested byte amount
+
+ :type retry_time: float
+ :param retry_time: The length in time to wait to retry for the
+ requested amount
+ """
+ self.requested_amt = requested_amt
+ self.retry_time = retry_time
+ msg = 'Request amount {} exceeded the amount available. Retry in {}'.format(
+ requested_amt, retry_time
+ )
+ super().__init__(msg)
+
+
+class RequestToken:
+ """A token to pass as an identifier when consuming from the LeakyBucket"""
+
+ pass
+
+
+class TimeUtils:
+ def time(self):
+ """Get the current time back
+
+ :rtype: float
+ :returns: The current time in seconds
+ """
+ return time.time()
+
+ def sleep(self, value):
+ """Sleep for a designated time
+
+ :type value: float
+ :param value: The time to sleep for in seconds
+ """
+ return time.sleep(value)
+
+
+class BandwidthLimiter:
+ def __init__(self, leaky_bucket, time_utils=None):
+ """Limits bandwidth for shared S3 transfers
+
+ :type leaky_bucket: LeakyBucket
+ :param leaky_bucket: The leaky bucket to use limit bandwidth
+
+ :type time_utils: TimeUtils
+ :param time_utils: Time utility to use for interacting with time.
+ """
+ self._leaky_bucket = leaky_bucket
+ self._time_utils = time_utils
+ if time_utils is None:
+ self._time_utils = TimeUtils()
+
+ def get_bandwith_limited_stream(
+ self, fileobj, transfer_coordinator, enabled=True
+ ):
+ """Wraps a fileobj in a bandwidth limited stream wrapper
+
+ :type fileobj: file-like obj
+ :param fileobj: The file-like obj to wrap
+
+ :type transfer_coordinator: s3transfer.futures.TransferCoordinator
+ param transfer_coordinator: The coordinator for the general transfer
+ that the wrapped stream is a part of
+
+ :type enabled: boolean
+ :param enabled: Whether bandwidth limiting should be enabled to start
+ """
+ stream = BandwidthLimitedStream(
+ fileobj, self._leaky_bucket, transfer_coordinator, self._time_utils
+ )
+ if not enabled:
+ stream.disable_bandwidth_limiting()
+ return stream
+
+
+class BandwidthLimitedStream:
+ def __init__(
+ self,
+ fileobj,
+ leaky_bucket,
+ transfer_coordinator,
+ time_utils=None,
+ bytes_threshold=256 * 1024,
+ ):
+ """Limits bandwidth for reads on a wrapped stream
+
+ :type fileobj: file-like object
+ :param fileobj: The file like object to wrap
+
+ :type leaky_bucket: LeakyBucket
+ :param leaky_bucket: The leaky bucket to use to throttle reads on
+ the stream
+
+ :type transfer_coordinator: s3transfer.futures.TransferCoordinator
+ param transfer_coordinator: The coordinator for the general transfer
+ that the wrapped stream is a part of
+
+ :type time_utils: TimeUtils
+ :param time_utils: The time utility to use for interacting with time
+ """
+ self._fileobj = fileobj
+ self._leaky_bucket = leaky_bucket
+ self._transfer_coordinator = transfer_coordinator
+ self._time_utils = time_utils
+ if time_utils is None:
+ self._time_utils = TimeUtils()
+ self._bandwidth_limiting_enabled = True
+ self._request_token = RequestToken()
+ self._bytes_seen = 0
+ self._bytes_threshold = bytes_threshold
+
+ def enable_bandwidth_limiting(self):
+ """Enable bandwidth limiting on reads to the stream"""
+ self._bandwidth_limiting_enabled = True
+
+ def disable_bandwidth_limiting(self):
+ """Disable bandwidth limiting on reads to the stream"""
+ self._bandwidth_limiting_enabled = False
+
+ def read(self, amount):
+ """Read a specified amount
+
+ Reads will only be throttled if bandwidth limiting is enabled.
+ """
+ if not self._bandwidth_limiting_enabled:
+ return self._fileobj.read(amount)
+
+ # We do not want to be calling consume on every read as the read
+ # amounts can be small causing the lock of the leaky bucket to
+ # introduce noticeable overhead. So instead we keep track of
+ # how many bytes we have seen and only call consume once we pass a
+ # certain threshold.
+ self._bytes_seen += amount
+ if self._bytes_seen < self._bytes_threshold:
+ return self._fileobj.read(amount)
+
+ self._consume_through_leaky_bucket()
+ return self._fileobj.read(amount)
+
+ def _consume_through_leaky_bucket(self):
+ # NOTE: If the read amount on the stream are high, it will result
+ # in large bursty behavior as there is not an interface for partial
+ # reads. However given the read's on this abstraction are at most 256KB
+ # (via downloads), it reduces the burstiness to be small KB bursts at
+ # worst.
+ while not self._transfer_coordinator.exception:
+ try:
+ self._leaky_bucket.consume(
+ self._bytes_seen, self._request_token
+ )
+ self._bytes_seen = 0
+ return
+ except RequestExceededException as e:
+ self._time_utils.sleep(e.retry_time)
+ else:
+ raise self._transfer_coordinator.exception
+
+ def signal_transferring(self):
+ """Signal that data being read is being transferred to S3"""
+ self.enable_bandwidth_limiting()
+
+ def signal_not_transferring(self):
+ """Signal that data being read is not being transferred to S3"""
+ self.disable_bandwidth_limiting()
+
+ def seek(self, where, whence=0):
+ self._fileobj.seek(where, whence)
+
+ def tell(self):
+ return self._fileobj.tell()
+
+ def close(self):
+ if self._bandwidth_limiting_enabled and self._bytes_seen:
+ # This handles the case where the file is small enough to never
+ # trigger the threshold and thus is never subjugated to the
+ # leaky bucket on read(). This specifically happens for small
+ # uploads. So instead to account for those bytes, have
+ # it go through the leaky bucket when the file gets closed.
+ self._consume_through_leaky_bucket()
+ self._fileobj.close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ self.close()
+
+
+class LeakyBucket:
+ def __init__(
+ self,
+ max_rate,
+ time_utils=None,
+ rate_tracker=None,
+ consumption_scheduler=None,
+ ):
+ """A leaky bucket abstraction to limit bandwidth consumption
+
+ :type rate: int
+ :type rate: The maximum rate to allow. This rate is in terms of
+ bytes per second.
+
+ :type time_utils: TimeUtils
+ :param time_utils: The time utility to use for interacting with time
+
+ :type rate_tracker: BandwidthRateTracker
+ :param rate_tracker: Tracks bandwidth consumption
+
+ :type consumption_scheduler: ConsumptionScheduler
+ :param consumption_scheduler: Schedules consumption retries when
+ necessary
+ """
+ self._max_rate = float(max_rate)
+ self._time_utils = time_utils
+ if time_utils is None:
+ self._time_utils = TimeUtils()
+ self._lock = threading.Lock()
+ self._rate_tracker = rate_tracker
+ if rate_tracker is None:
+ self._rate_tracker = BandwidthRateTracker()
+ self._consumption_scheduler = consumption_scheduler
+ if consumption_scheduler is None:
+ self._consumption_scheduler = ConsumptionScheduler()
+
+ def consume(self, amt, request_token):
+ """Consume an a requested amount
+
+ :type amt: int
+ :param amt: The amount of bytes to request to consume
+
+ :type request_token: RequestToken
+ :param request_token: The token associated to the consumption
+ request that is used to identify the request. So if a
+ RequestExceededException is raised the token should be used
+ in subsequent retry consume() request.
+
+ :raises RequestExceededException: If the consumption amount would
+ exceed the maximum allocated bandwidth
+
+ :rtype: int
+ :returns: The amount consumed
+ """
+ with self._lock:
+ time_now = self._time_utils.time()
+ if self._consumption_scheduler.is_scheduled(request_token):
+ return self._release_requested_amt_for_scheduled_request(
+ amt, request_token, time_now
+ )
+ elif self._projected_to_exceed_max_rate(amt, time_now):
+ self._raise_request_exceeded_exception(
+ amt, request_token, time_now
+ )
+ else:
+ return self._release_requested_amt(amt, time_now)
+
+ def _projected_to_exceed_max_rate(self, amt, time_now):
+ projected_rate = self._rate_tracker.get_projected_rate(amt, time_now)
+ return projected_rate > self._max_rate
+
+ def _release_requested_amt_for_scheduled_request(
+ self, amt, request_token, time_now
+ ):
+ self._consumption_scheduler.process_scheduled_consumption(
+ request_token
+ )
+ return self._release_requested_amt(amt, time_now)
+
+ def _raise_request_exceeded_exception(self, amt, request_token, time_now):
+ allocated_time = amt / float(self._max_rate)
+ retry_time = self._consumption_scheduler.schedule_consumption(
+ amt, request_token, allocated_time
+ )
+ raise RequestExceededException(
+ requested_amt=amt, retry_time=retry_time
+ )
+
+ def _release_requested_amt(self, amt, time_now):
+ self._rate_tracker.record_consumption_rate(amt, time_now)
+ return amt
+
+
+class ConsumptionScheduler:
+ def __init__(self):
+ """Schedules when to consume a desired amount"""
+ self._tokens_to_scheduled_consumption = {}
+ self._total_wait = 0
+
+ def is_scheduled(self, token):
+ """Indicates if a consumption request has been scheduled
+
+ :type token: RequestToken
+ :param token: The token associated to the consumption
+ request that is used to identify the request.
+ """
+ return token in self._tokens_to_scheduled_consumption
+
+ def schedule_consumption(self, amt, token, time_to_consume):
+ """Schedules a wait time to be able to consume an amount
+
+ :type amt: int
+ :param amt: The amount of bytes scheduled to be consumed
+
+ :type token: RequestToken
+ :param token: The token associated to the consumption
+ request that is used to identify the request.
+
+ :type time_to_consume: float
+ :param time_to_consume: The desired time it should take for that
+ specific request amount to be consumed in regardless of previously
+ scheduled consumption requests
+
+ :rtype: float
+ :returns: The amount of time to wait for the specific request before
+ actually consuming the specified amount.
+ """
+ self._total_wait += time_to_consume
+ self._tokens_to_scheduled_consumption[token] = {
+ 'wait_duration': self._total_wait,
+ 'time_to_consume': time_to_consume,
+ }
+ return self._total_wait
+
+ def process_scheduled_consumption(self, token):
+ """Processes a scheduled consumption request that has completed
+
+ :type token: RequestToken
+ :param token: The token associated to the consumption
+ request that is used to identify the request.
+ """
+ scheduled_retry = self._tokens_to_scheduled_consumption.pop(token)
+ self._total_wait = max(
+ self._total_wait - scheduled_retry['time_to_consume'], 0
+ )
+
+
+class BandwidthRateTracker:
+ def __init__(self, alpha=0.8):
+ """Tracks the rate of bandwidth consumption
+
+ :type a: float
+ :param a: The constant to use in calculating the exponentional moving
+ average of the bandwidth rate. Specifically it is used in the
+ following calculation:
+
+ current_rate = alpha * new_rate + (1 - alpha) * current_rate
+
+ This value of this constant should be between 0 and 1.
+ """
+ self._alpha = alpha
+ self._last_time = None
+ self._current_rate = None
+
+ @property
+ def current_rate(self):
+ """The current transfer rate
+
+ :rtype: float
+ :returns: The current tracked transfer rate
+ """
+ if self._last_time is None:
+ return 0.0
+ return self._current_rate
+
+ def get_projected_rate(self, amt, time_at_consumption):
+ """Get the projected rate using a provided amount and time
+
+ :type amt: int
+ :param amt: The proposed amount to consume
+
+ :type time_at_consumption: float
+ :param time_at_consumption: The proposed time to consume at
+
+ :rtype: float
+ :returns: The consumption rate if that amt and time were consumed
+ """
+ if self._last_time is None:
+ return 0.0
+ return self._calculate_exponential_moving_average_rate(
+ amt, time_at_consumption
+ )
+
+ def record_consumption_rate(self, amt, time_at_consumption):
+ """Record the consumption rate based off amount and time point
+
+ :type amt: int
+ :param amt: The amount that got consumed
+
+ :type time_at_consumption: float
+ :param time_at_consumption: The time at which the amount was consumed
+ """
+ if self._last_time is None:
+ self._last_time = time_at_consumption
+ self._current_rate = 0.0
+ return
+ self._current_rate = self._calculate_exponential_moving_average_rate(
+ amt, time_at_consumption
+ )
+ self._last_time = time_at_consumption
+
+ def _calculate_rate(self, amt, time_at_consumption):
+ time_delta = time_at_consumption - self._last_time
+ if time_delta <= 0:
+ # While it is really unlikely to see this in an actual transfer,
+ # we do not want to be returning back a negative rate or try to
+ # divide the amount by zero. So instead return back an infinite
+ # rate as the time delta is infinitesimally small.
+ return float('inf')
+ return amt / (time_delta)
+
+ def _calculate_exponential_moving_average_rate(
+ self, amt, time_at_consumption
+ ):
+ new_rate = self._calculate_rate(amt, time_at_consumption)
+ return self._alpha * new_rate + (1 - self._alpha) * self._current_rate
diff --git a/contrib/python/s3transfer/py3/s3transfer/compat.py b/contrib/python/s3transfer/py3/s3transfer/compat.py
new file mode 100644
index 0000000000..68267ad0e2
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/compat.py
@@ -0,0 +1,94 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import errno
+import inspect
+import os
+import socket
+import sys
+
+from botocore.compat import six
+
+if sys.platform.startswith('win'):
+ def rename_file(current_filename, new_filename):
+ try:
+ os.remove(new_filename)
+ except OSError as e:
+ if not e.errno == errno.ENOENT:
+ # We only want to a ignore trying to remove
+ # a file that does not exist. If it fails
+ # for any other reason we should be propagating
+ # that exception.
+ raise
+ os.rename(current_filename, new_filename)
+else:
+ rename_file = os.rename
+
+
+def accepts_kwargs(func):
+ return inspect.getfullargspec(func)[2]
+
+
+# In python 3, socket.error is OSError, which is too general
+# for what we want (i.e FileNotFoundError is a subclass of OSError).
+# In python 3, all the socket related errors are in a newly created
+# ConnectionError.
+SOCKET_ERROR = ConnectionError
+MAXINT = None
+
+
+def seekable(fileobj):
+ """Backwards compat function to determine if a fileobj is seekable
+
+ :param fileobj: The file-like object to determine if seekable
+
+ :returns: True, if seekable. False, otherwise.
+ """
+ # If the fileobj has a seekable attr, try calling the seekable()
+ # method on it.
+ if hasattr(fileobj, 'seekable'):
+ return fileobj.seekable()
+ # If there is no seekable attr, check if the object can be seeked
+ # or telled. If it can, try to seek to the current position.
+ elif hasattr(fileobj, 'seek') and hasattr(fileobj, 'tell'):
+ try:
+ fileobj.seek(0, 1)
+ return True
+ except OSError:
+ # If an io related error was thrown then it is not seekable.
+ return False
+ # Else, the fileobj is not seekable
+ return False
+
+
+def readable(fileobj):
+ """Determines whether or not a file-like object is readable.
+
+ :param fileobj: The file-like object to determine if readable
+
+ :returns: True, if readable. False otherwise.
+ """
+ if hasattr(fileobj, 'readable'):
+ return fileobj.readable()
+
+ return hasattr(fileobj, 'read')
+
+
+def fallocate(fileobj, size):
+ if hasattr(os, 'posix_fallocate'):
+ os.posix_fallocate(fileobj.fileno(), 0, size)
+ else:
+ fileobj.truncate(size)
+
+
+# Import at end of file to avoid circular dependencies
+from multiprocessing.managers import BaseManager # noqa: F401,E402
diff --git a/contrib/python/s3transfer/py3/s3transfer/constants.py b/contrib/python/s3transfer/py3/s3transfer/constants.py
new file mode 100644
index 0000000000..ba35bc72e9
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/constants.py
@@ -0,0 +1,29 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import s3transfer
+
+KB = 1024
+MB = KB * KB
+GB = MB * KB
+
+ALLOWED_DOWNLOAD_ARGS = [
+ 'VersionId',
+ 'SSECustomerAlgorithm',
+ 'SSECustomerKey',
+ 'SSECustomerKeyMD5',
+ 'RequestPayer',
+ 'ExpectedBucketOwner',
+]
+
+USER_AGENT = 's3transfer/%s' % s3transfer.__version__
+PROCESS_USER_AGENT = '%s processpool' % USER_AGENT
diff --git a/contrib/python/s3transfer/py3/s3transfer/copies.py b/contrib/python/s3transfer/py3/s3transfer/copies.py
new file mode 100644
index 0000000000..a1dfdc8ba3
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/copies.py
@@ -0,0 +1,368 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import copy
+import math
+
+from s3transfer.tasks import (
+ CompleteMultipartUploadTask,
+ CreateMultipartUploadTask,
+ SubmissionTask,
+ Task,
+)
+from s3transfer.utils import (
+ ChunksizeAdjuster,
+ calculate_range_parameter,
+ get_callbacks,
+ get_filtered_dict,
+)
+
+
+class CopySubmissionTask(SubmissionTask):
+ """Task for submitting tasks to execute a copy"""
+
+ EXTRA_ARGS_TO_HEAD_ARGS_MAPPING = {
+ 'CopySourceIfMatch': 'IfMatch',
+ 'CopySourceIfModifiedSince': 'IfModifiedSince',
+ 'CopySourceIfNoneMatch': 'IfNoneMatch',
+ 'CopySourceIfUnmodifiedSince': 'IfUnmodifiedSince',
+ 'CopySourceSSECustomerKey': 'SSECustomerKey',
+ 'CopySourceSSECustomerAlgorithm': 'SSECustomerAlgorithm',
+ 'CopySourceSSECustomerKeyMD5': 'SSECustomerKeyMD5',
+ 'RequestPayer': 'RequestPayer',
+ 'ExpectedBucketOwner': 'ExpectedBucketOwner',
+ }
+
+ UPLOAD_PART_COPY_ARGS = [
+ 'CopySourceIfMatch',
+ 'CopySourceIfModifiedSince',
+ 'CopySourceIfNoneMatch',
+ 'CopySourceIfUnmodifiedSince',
+ 'CopySourceSSECustomerKey',
+ 'CopySourceSSECustomerAlgorithm',
+ 'CopySourceSSECustomerKeyMD5',
+ 'SSECustomerKey',
+ 'SSECustomerAlgorithm',
+ 'SSECustomerKeyMD5',
+ 'RequestPayer',
+ 'ExpectedBucketOwner',
+ ]
+
+ CREATE_MULTIPART_ARGS_BLACKLIST = [
+ 'CopySourceIfMatch',
+ 'CopySourceIfModifiedSince',
+ 'CopySourceIfNoneMatch',
+ 'CopySourceIfUnmodifiedSince',
+ 'CopySourceSSECustomerKey',
+ 'CopySourceSSECustomerAlgorithm',
+ 'CopySourceSSECustomerKeyMD5',
+ 'MetadataDirective',
+ 'TaggingDirective',
+ ]
+
+ COMPLETE_MULTIPART_ARGS = ['RequestPayer', 'ExpectedBucketOwner']
+
+ def _submit(
+ self, client, config, osutil, request_executor, transfer_future
+ ):
+ """
+ :param client: The client associated with the transfer manager
+
+ :type config: s3transfer.manager.TransferConfig
+ :param config: The transfer config associated with the transfer
+ manager
+
+ :type osutil: s3transfer.utils.OSUtil
+ :param osutil: The os utility associated to the transfer manager
+
+ :type request_executor: s3transfer.futures.BoundedExecutor
+ :param request_executor: The request executor associated with the
+ transfer manager
+
+ :type transfer_future: s3transfer.futures.TransferFuture
+ :param transfer_future: The transfer future associated with the
+ transfer request that tasks are being submitted for
+ """
+ # Determine the size if it was not provided
+ if transfer_future.meta.size is None:
+ # If a size was not provided figure out the size for the
+ # user. Note that we will only use the client provided to
+ # the TransferManager. If the object is outside of the region
+ # of the client, they may have to provide the file size themselves
+ # with a completely new client.
+ call_args = transfer_future.meta.call_args
+ head_object_request = (
+ self._get_head_object_request_from_copy_source(
+ call_args.copy_source
+ )
+ )
+ extra_args = call_args.extra_args
+
+ # Map any values that may be used in the head object that is
+ # used in the copy object
+ for param, value in extra_args.items():
+ if param in self.EXTRA_ARGS_TO_HEAD_ARGS_MAPPING:
+ head_object_request[
+ self.EXTRA_ARGS_TO_HEAD_ARGS_MAPPING[param]
+ ] = value
+
+ response = call_args.source_client.head_object(
+ **head_object_request
+ )
+ transfer_future.meta.provide_transfer_size(
+ response['ContentLength']
+ )
+
+ # If it is greater than threshold do a multipart copy, otherwise
+ # do a regular copy object.
+ if transfer_future.meta.size < config.multipart_threshold:
+ self._submit_copy_request(
+ client, config, osutil, request_executor, transfer_future
+ )
+ else:
+ self._submit_multipart_request(
+ client, config, osutil, request_executor, transfer_future
+ )
+
+ def _submit_copy_request(
+ self, client, config, osutil, request_executor, transfer_future
+ ):
+ call_args = transfer_future.meta.call_args
+
+ # Get the needed progress callbacks for the task
+ progress_callbacks = get_callbacks(transfer_future, 'progress')
+
+ # Submit the request of a single copy.
+ self._transfer_coordinator.submit(
+ request_executor,
+ CopyObjectTask(
+ transfer_coordinator=self._transfer_coordinator,
+ main_kwargs={
+ 'client': client,
+ 'copy_source': call_args.copy_source,
+ 'bucket': call_args.bucket,
+ 'key': call_args.key,
+ 'extra_args': call_args.extra_args,
+ 'callbacks': progress_callbacks,
+ 'size': transfer_future.meta.size,
+ },
+ is_final=True,
+ ),
+ )
+
+ def _submit_multipart_request(
+ self, client, config, osutil, request_executor, transfer_future
+ ):
+ call_args = transfer_future.meta.call_args
+
+ # Submit the request to create a multipart upload and make sure it
+ # does not include any of the arguments used for copy part.
+ create_multipart_extra_args = {}
+ for param, val in call_args.extra_args.items():
+ if param not in self.CREATE_MULTIPART_ARGS_BLACKLIST:
+ create_multipart_extra_args[param] = val
+
+ create_multipart_future = self._transfer_coordinator.submit(
+ request_executor,
+ CreateMultipartUploadTask(
+ transfer_coordinator=self._transfer_coordinator,
+ main_kwargs={
+ 'client': client,
+ 'bucket': call_args.bucket,
+ 'key': call_args.key,
+ 'extra_args': create_multipart_extra_args,
+ },
+ ),
+ )
+
+ # Determine how many parts are needed based on filesize and
+ # desired chunksize.
+ part_size = config.multipart_chunksize
+ adjuster = ChunksizeAdjuster()
+ part_size = adjuster.adjust_chunksize(
+ part_size, transfer_future.meta.size
+ )
+ num_parts = int(
+ math.ceil(transfer_future.meta.size / float(part_size))
+ )
+
+ # Submit requests to upload the parts of the file.
+ part_futures = []
+ progress_callbacks = get_callbacks(transfer_future, 'progress')
+
+ for part_number in range(1, num_parts + 1):
+ extra_part_args = self._extra_upload_part_args(
+ call_args.extra_args
+ )
+ # The part number for upload part starts at 1 while the
+ # range parameter starts at zero, so just subtract 1 off of
+ # the part number
+ extra_part_args['CopySourceRange'] = calculate_range_parameter(
+ part_size,
+ part_number - 1,
+ num_parts,
+ transfer_future.meta.size,
+ )
+ # Get the size of the part copy as well for the progress
+ # callbacks.
+ size = self._get_transfer_size(
+ part_size,
+ part_number - 1,
+ num_parts,
+ transfer_future.meta.size,
+ )
+ part_futures.append(
+ self._transfer_coordinator.submit(
+ request_executor,
+ CopyPartTask(
+ transfer_coordinator=self._transfer_coordinator,
+ main_kwargs={
+ 'client': client,
+ 'copy_source': call_args.copy_source,
+ 'bucket': call_args.bucket,
+ 'key': call_args.key,
+ 'part_number': part_number,
+ 'extra_args': extra_part_args,
+ 'callbacks': progress_callbacks,
+ 'size': size,
+ },
+ pending_main_kwargs={
+ 'upload_id': create_multipart_future
+ },
+ ),
+ )
+ )
+
+ complete_multipart_extra_args = self._extra_complete_multipart_args(
+ call_args.extra_args
+ )
+ # Submit the request to complete the multipart upload.
+ self._transfer_coordinator.submit(
+ request_executor,
+ CompleteMultipartUploadTask(
+ transfer_coordinator=self._transfer_coordinator,
+ main_kwargs={
+ 'client': client,
+ 'bucket': call_args.bucket,
+ 'key': call_args.key,
+ 'extra_args': complete_multipart_extra_args,
+ },
+ pending_main_kwargs={
+ 'upload_id': create_multipart_future,
+ 'parts': part_futures,
+ },
+ is_final=True,
+ ),
+ )
+
+ def _get_head_object_request_from_copy_source(self, copy_source):
+ if isinstance(copy_source, dict):
+ return copy.copy(copy_source)
+ else:
+ raise TypeError(
+ 'Expecting dictionary formatted: '
+ '{"Bucket": bucket_name, "Key": key} '
+ 'but got %s or type %s.' % (copy_source, type(copy_source))
+ )
+
+ def _extra_upload_part_args(self, extra_args):
+ # Only the args in COPY_PART_ARGS actually need to be passed
+ # onto the upload_part_copy calls.
+ return get_filtered_dict(extra_args, self.UPLOAD_PART_COPY_ARGS)
+
+ def _extra_complete_multipart_args(self, extra_args):
+ return get_filtered_dict(extra_args, self.COMPLETE_MULTIPART_ARGS)
+
+ def _get_transfer_size(
+ self, part_size, part_index, num_parts, total_transfer_size
+ ):
+ if part_index == num_parts - 1:
+ # The last part may be different in size then the rest of the
+ # parts.
+ return total_transfer_size - (part_index * part_size)
+ return part_size
+
+
+class CopyObjectTask(Task):
+ """Task to do a nonmultipart copy"""
+
+ def _main(
+ self, client, copy_source, bucket, key, extra_args, callbacks, size
+ ):
+ """
+ :param client: The client to use when calling PutObject
+ :param copy_source: The CopySource parameter to use
+ :param bucket: The name of the bucket to copy to
+ :param key: The name of the key to copy to
+ :param extra_args: A dictionary of any extra arguments that may be
+ used in the upload.
+ :param callbacks: List of callbacks to call after copy
+ :param size: The size of the transfer. This value is passed into
+ the callbacks
+
+ """
+ client.copy_object(
+ CopySource=copy_source, Bucket=bucket, Key=key, **extra_args
+ )
+ for callback in callbacks:
+ callback(bytes_transferred=size)
+
+
+class CopyPartTask(Task):
+ """Task to upload a part in a multipart copy"""
+
+ def _main(
+ self,
+ client,
+ copy_source,
+ bucket,
+ key,
+ upload_id,
+ part_number,
+ extra_args,
+ callbacks,
+ size,
+ ):
+ """
+ :param client: The client to use when calling PutObject
+ :param copy_source: The CopySource parameter to use
+ :param bucket: The name of the bucket to upload to
+ :param key: The name of the key to upload to
+ :param upload_id: The id of the upload
+ :param part_number: The number representing the part of the multipart
+ upload
+ :param extra_args: A dictionary of any extra arguments that may be
+ used in the upload.
+ :param callbacks: List of callbacks to call after copy part
+ :param size: The size of the transfer. This value is passed into
+ the callbacks
+
+ :rtype: dict
+ :returns: A dictionary representing a part::
+
+ {'Etag': etag_value, 'PartNumber': part_number}
+
+ This value can be appended to a list to be used to complete
+ the multipart upload.
+ """
+ response = client.upload_part_copy(
+ CopySource=copy_source,
+ Bucket=bucket,
+ Key=key,
+ UploadId=upload_id,
+ PartNumber=part_number,
+ **extra_args
+ )
+ for callback in callbacks:
+ callback(bytes_transferred=size)
+ etag = response['CopyPartResult']['ETag']
+ return {'ETag': etag, 'PartNumber': part_number}
diff --git a/contrib/python/s3transfer/py3/s3transfer/crt.py b/contrib/python/s3transfer/py3/s3transfer/crt.py
new file mode 100644
index 0000000000..7b5d130136
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/crt.py
@@ -0,0 +1,644 @@
+# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import logging
+import threading
+from io import BytesIO
+
+import awscrt.http
+import botocore.awsrequest
+import botocore.session
+from awscrt.auth import AwsCredentials, AwsCredentialsProvider
+from awscrt.io import (
+ ClientBootstrap,
+ ClientTlsContext,
+ DefaultHostResolver,
+ EventLoopGroup,
+ TlsContextOptions,
+)
+from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType
+from botocore import UNSIGNED
+from botocore.compat import urlsplit
+from botocore.config import Config
+from botocore.exceptions import NoCredentialsError
+
+from s3transfer.constants import GB, MB
+from s3transfer.exceptions import TransferNotDoneError
+from s3transfer.futures import BaseTransferFuture, BaseTransferMeta
+from s3transfer.utils import CallArgs, OSUtils, get_callbacks
+
+logger = logging.getLogger(__name__)
+
+
+class CRTCredentialProviderAdapter:
+ def __init__(self, botocore_credential_provider):
+ self._botocore_credential_provider = botocore_credential_provider
+ self._loaded_credentials = None
+ self._lock = threading.Lock()
+
+ def __call__(self):
+ credentials = self._get_credentials().get_frozen_credentials()
+ return AwsCredentials(
+ credentials.access_key, credentials.secret_key, credentials.token
+ )
+
+ def _get_credentials(self):
+ with self._lock:
+ if self._loaded_credentials is None:
+ loaded_creds = (
+ self._botocore_credential_provider.load_credentials()
+ )
+ if loaded_creds is None:
+ raise NoCredentialsError()
+ self._loaded_credentials = loaded_creds
+ return self._loaded_credentials
+
+
+def create_s3_crt_client(
+ region,
+ botocore_credential_provider=None,
+ num_threads=None,
+ target_throughput=5 * GB / 8,
+ part_size=8 * MB,
+ use_ssl=True,
+ verify=None,
+):
+ """
+ :type region: str
+ :param region: The region used for signing
+
+ :type botocore_credential_provider:
+ Optional[botocore.credentials.CredentialResolver]
+ :param botocore_credential_provider: Provide credentials for CRT
+ to sign the request if not set, the request will not be signed
+
+ :type num_threads: Optional[int]
+ :param num_threads: Number of worker threads generated. Default
+ is the number of processors in the machine.
+
+ :type target_throughput: Optional[int]
+ :param target_throughput: Throughput target in Bytes.
+ Default is 0.625 GB/s (which translates to 5 Gb/s).
+
+ :type part_size: Optional[int]
+ :param part_size: Size, in Bytes, of parts that files will be downloaded
+ or uploaded in.
+
+ :type use_ssl: boolean
+ :param use_ssl: Whether or not to use SSL. By default, SSL is used.
+ Note that not all services support non-ssl connections.
+
+ :type verify: Optional[boolean/string]
+ :param verify: Whether or not to verify SSL certificates.
+ By default SSL certificates are verified. You can provide the
+ following values:
+
+ * False - do not validate SSL certificates. SSL will still be
+ used (unless use_ssl is False), but SSL certificates
+ will not be verified.
+ * path/to/cert/bundle.pem - A filename of the CA cert bundle to
+ use. Specify this argument if you want to use a custom CA cert
+ bundle instead of the default one on your system.
+ """
+
+ event_loop_group = EventLoopGroup(num_threads)
+ host_resolver = DefaultHostResolver(event_loop_group)
+ bootstrap = ClientBootstrap(event_loop_group, host_resolver)
+ provider = None
+ tls_connection_options = None
+
+ tls_mode = (
+ S3RequestTlsMode.ENABLED if use_ssl else S3RequestTlsMode.DISABLED
+ )
+ if verify is not None:
+ tls_ctx_options = TlsContextOptions()
+ if verify:
+ tls_ctx_options.override_default_trust_store_from_path(
+ ca_filepath=verify
+ )
+ else:
+ tls_ctx_options.verify_peer = False
+ client_tls_option = ClientTlsContext(tls_ctx_options)
+ tls_connection_options = client_tls_option.new_connection_options()
+ if botocore_credential_provider:
+ credentails_provider_adapter = CRTCredentialProviderAdapter(
+ botocore_credential_provider
+ )
+ provider = AwsCredentialsProvider.new_delegate(
+ credentails_provider_adapter
+ )
+
+ target_gbps = target_throughput * 8 / GB
+ return S3Client(
+ bootstrap=bootstrap,
+ region=region,
+ credential_provider=provider,
+ part_size=part_size,
+ tls_mode=tls_mode,
+ tls_connection_options=tls_connection_options,
+ throughput_target_gbps=target_gbps,
+ )
+
+
+class CRTTransferManager:
+ def __init__(self, crt_s3_client, crt_request_serializer, osutil=None):
+ """A transfer manager interface for Amazon S3 on CRT s3 client.
+
+ :type crt_s3_client: awscrt.s3.S3Client
+ :param crt_s3_client: The CRT s3 client, handling all the
+ HTTP requests and functions under then hood
+
+ :type crt_request_serializer: s3transfer.crt.BaseCRTRequestSerializer
+ :param crt_request_serializer: Serializer, generates unsigned crt HTTP
+ request.
+
+ :type osutil: s3transfer.utils.OSUtils
+ :param osutil: OSUtils object to use for os-related behavior when
+ using with transfer manager.
+ """
+ if osutil is None:
+ self._osutil = OSUtils()
+ self._crt_s3_client = crt_s3_client
+ self._s3_args_creator = S3ClientArgsCreator(
+ crt_request_serializer, self._osutil
+ )
+ self._future_coordinators = []
+ self._semaphore = threading.Semaphore(128) # not configurable
+ # A counter to create unique id's for each transfer submitted.
+ self._id_counter = 0
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, *args):
+ cancel = False
+ if exc_type:
+ cancel = True
+ self._shutdown(cancel)
+
+ def download(
+ self, bucket, key, fileobj, extra_args=None, subscribers=None
+ ):
+ if extra_args is None:
+ extra_args = {}
+ if subscribers is None:
+ subscribers = {}
+ callargs = CallArgs(
+ bucket=bucket,
+ key=key,
+ fileobj=fileobj,
+ extra_args=extra_args,
+ subscribers=subscribers,
+ )
+ return self._submit_transfer("get_object", callargs)
+
+ def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None):
+ if extra_args is None:
+ extra_args = {}
+ if subscribers is None:
+ subscribers = {}
+ callargs = CallArgs(
+ bucket=bucket,
+ key=key,
+ fileobj=fileobj,
+ extra_args=extra_args,
+ subscribers=subscribers,
+ )
+ return self._submit_transfer("put_object", callargs)
+
+ def delete(self, bucket, key, extra_args=None, subscribers=None):
+ if extra_args is None:
+ extra_args = {}
+ if subscribers is None:
+ subscribers = {}
+ callargs = CallArgs(
+ bucket=bucket,
+ key=key,
+ extra_args=extra_args,
+ subscribers=subscribers,
+ )
+ return self._submit_transfer("delete_object", callargs)
+
+ def shutdown(self, cancel=False):
+ self._shutdown(cancel)
+
+ def _cancel_transfers(self):
+ for coordinator in self._future_coordinators:
+ if not coordinator.done():
+ coordinator.cancel()
+
+ def _finish_transfers(self):
+ for coordinator in self._future_coordinators:
+ coordinator.result()
+
+ def _wait_transfers_done(self):
+ for coordinator in self._future_coordinators:
+ coordinator.wait_until_on_done_callbacks_complete()
+
+ def _shutdown(self, cancel=False):
+ if cancel:
+ self._cancel_transfers()
+ try:
+ self._finish_transfers()
+
+ except KeyboardInterrupt:
+ self._cancel_transfers()
+ except Exception:
+ pass
+ finally:
+ self._wait_transfers_done()
+
+ def _release_semaphore(self, **kwargs):
+ self._semaphore.release()
+
+ def _submit_transfer(self, request_type, call_args):
+ on_done_after_calls = [self._release_semaphore]
+ coordinator = CRTTransferCoordinator(transfer_id=self._id_counter)
+ components = {
+ 'meta': CRTTransferMeta(self._id_counter, call_args),
+ 'coordinator': coordinator,
+ }
+ future = CRTTransferFuture(**components)
+ afterdone = AfterDoneHandler(coordinator)
+ on_done_after_calls.append(afterdone)
+
+ try:
+ self._semaphore.acquire()
+ on_queued = self._s3_args_creator.get_crt_callback(
+ future, 'queued'
+ )
+ on_queued()
+ crt_callargs = self._s3_args_creator.get_make_request_args(
+ request_type,
+ call_args,
+ coordinator,
+ future,
+ on_done_after_calls,
+ )
+ crt_s3_request = self._crt_s3_client.make_request(**crt_callargs)
+ except Exception as e:
+ coordinator.set_exception(e, True)
+ on_done = self._s3_args_creator.get_crt_callback(
+ future, 'done', after_subscribers=on_done_after_calls
+ )
+ on_done(error=e)
+ else:
+ coordinator.set_s3_request(crt_s3_request)
+ self._future_coordinators.append(coordinator)
+
+ self._id_counter += 1
+ return future
+
+
+class CRTTransferMeta(BaseTransferMeta):
+ """Holds metadata about the CRTTransferFuture"""
+
+ def __init__(self, transfer_id=None, call_args=None):
+ self._transfer_id = transfer_id
+ self._call_args = call_args
+ self._user_context = {}
+
+ @property
+ def call_args(self):
+ return self._call_args
+
+ @property
+ def transfer_id(self):
+ return self._transfer_id
+
+ @property
+ def user_context(self):
+ return self._user_context
+
+
+class CRTTransferFuture(BaseTransferFuture):
+ def __init__(self, meta=None, coordinator=None):
+ """The future associated to a submitted transfer request via CRT S3 client
+
+ :type meta: s3transfer.crt.CRTTransferMeta
+ :param meta: The metadata associated to the transfer future.
+
+ :type coordinator: s3transfer.crt.CRTTransferCoordinator
+ :param coordinator: The coordinator associated to the transfer future.
+ """
+ self._meta = meta
+ if meta is None:
+ self._meta = CRTTransferMeta()
+ self._coordinator = coordinator
+
+ @property
+ def meta(self):
+ return self._meta
+
+ def done(self):
+ return self._coordinator.done()
+
+ def result(self, timeout=None):
+ self._coordinator.result(timeout)
+
+ def cancel(self):
+ self._coordinator.cancel()
+
+ def set_exception(self, exception):
+ """Sets the exception on the future."""
+ if not self.done():
+ raise TransferNotDoneError(
+ 'set_exception can only be called once the transfer is '
+ 'complete.'
+ )
+ self._coordinator.set_exception(exception, override=True)
+
+
+class BaseCRTRequestSerializer:
+ def serialize_http_request(self, transfer_type, future):
+ """Serialize CRT HTTP requests.
+
+ :type transfer_type: string
+ :param transfer_type: the type of transfer made,
+ e.g 'put_object', 'get_object', 'delete_object'
+
+ :type future: s3transfer.crt.CRTTransferFuture
+
+ :rtype: awscrt.http.HttpRequest
+ :returns: An unsigned HTTP request to be used for the CRT S3 client
+ """
+ raise NotImplementedError('serialize_http_request()')
+
+
+class BotocoreCRTRequestSerializer(BaseCRTRequestSerializer):
+ def __init__(self, session, client_kwargs=None):
+ """Serialize CRT HTTP request using botocore logic
+ It also takes into account configuration from both the session
+ and any keyword arguments that could be passed to
+ `Session.create_client()` when serializing the request.
+
+ :type session: botocore.session.Session
+
+ :type client_kwargs: Optional[Dict[str, str]])
+ :param client_kwargs: The kwargs for the botocore
+ s3 client initialization.
+ """
+ self._session = session
+ if client_kwargs is None:
+ client_kwargs = {}
+ self._resolve_client_config(session, client_kwargs)
+ self._client = session.create_client(**client_kwargs)
+ self._client.meta.events.register(
+ 'request-created.s3.*', self._capture_http_request
+ )
+ self._client.meta.events.register(
+ 'after-call.s3.*', self._change_response_to_serialized_http_request
+ )
+ self._client.meta.events.register(
+ 'before-send.s3.*', self._make_fake_http_response
+ )
+
+ def _resolve_client_config(self, session, client_kwargs):
+ user_provided_config = None
+ if session.get_default_client_config():
+ user_provided_config = session.get_default_client_config()
+ if 'config' in client_kwargs:
+ user_provided_config = client_kwargs['config']
+
+ client_config = Config(signature_version=UNSIGNED)
+ if user_provided_config:
+ client_config = user_provided_config.merge(client_config)
+ client_kwargs['config'] = client_config
+ client_kwargs["service_name"] = "s3"
+
+ def _crt_request_from_aws_request(self, aws_request):
+ url_parts = urlsplit(aws_request.url)
+ crt_path = url_parts.path
+ if url_parts.query:
+ crt_path = f'{crt_path}?{url_parts.query}'
+ headers_list = []
+ for name, value in aws_request.headers.items():
+ if isinstance(value, str):
+ headers_list.append((name, value))
+ else:
+ headers_list.append((name, str(value, 'utf-8')))
+
+ crt_headers = awscrt.http.HttpHeaders(headers_list)
+ # CRT requires body (if it exists) to be an I/O stream.
+ crt_body_stream = None
+ if aws_request.body:
+ if hasattr(aws_request.body, 'seek'):
+ crt_body_stream = aws_request.body
+ else:
+ crt_body_stream = BytesIO(aws_request.body)
+
+ crt_request = awscrt.http.HttpRequest(
+ method=aws_request.method,
+ path=crt_path,
+ headers=crt_headers,
+ body_stream=crt_body_stream,
+ )
+ return crt_request
+
+ def _convert_to_crt_http_request(self, botocore_http_request):
+ # Logic that does CRTUtils.crt_request_from_aws_request
+ crt_request = self._crt_request_from_aws_request(botocore_http_request)
+ if crt_request.headers.get("host") is None:
+ # If host is not set, set it for the request before using CRT s3
+ url_parts = urlsplit(botocore_http_request.url)
+ crt_request.headers.set("host", url_parts.netloc)
+ if crt_request.headers.get('Content-MD5') is not None:
+ crt_request.headers.remove("Content-MD5")
+ return crt_request
+
+ def _capture_http_request(self, request, **kwargs):
+ request.context['http_request'] = request
+
+ def _change_response_to_serialized_http_request(
+ self, context, parsed, **kwargs
+ ):
+ request = context['http_request']
+ parsed['HTTPRequest'] = request.prepare()
+
+ def _make_fake_http_response(self, request, **kwargs):
+ return botocore.awsrequest.AWSResponse(
+ None,
+ 200,
+ {},
+ FakeRawResponse(b""),
+ )
+
+ def _get_botocore_http_request(self, client_method, call_args):
+ return getattr(self._client, client_method)(
+ Bucket=call_args.bucket, Key=call_args.key, **call_args.extra_args
+ )['HTTPRequest']
+
+ def serialize_http_request(self, transfer_type, future):
+ botocore_http_request = self._get_botocore_http_request(
+ transfer_type, future.meta.call_args
+ )
+ crt_request = self._convert_to_crt_http_request(botocore_http_request)
+ return crt_request
+
+
+class FakeRawResponse(BytesIO):
+ def stream(self, amt=1024, decode_content=None):
+ while True:
+ chunk = self.read(amt)
+ if not chunk:
+ break
+ yield chunk
+
+
+class CRTTransferCoordinator:
+ """A helper class for managing CRTTransferFuture"""
+
+ def __init__(self, transfer_id=None, s3_request=None):
+ self.transfer_id = transfer_id
+ self._s3_request = s3_request
+ self._lock = threading.Lock()
+ self._exception = None
+ self._crt_future = None
+ self._done_event = threading.Event()
+
+ @property
+ def s3_request(self):
+ return self._s3_request
+
+ def set_done_callbacks_complete(self):
+ self._done_event.set()
+
+ def wait_until_on_done_callbacks_complete(self, timeout=None):
+ self._done_event.wait(timeout)
+
+ def set_exception(self, exception, override=False):
+ with self._lock:
+ if not self.done() or override:
+ self._exception = exception
+
+ def cancel(self):
+ if self._s3_request:
+ self._s3_request.cancel()
+
+ def result(self, timeout=None):
+ if self._exception:
+ raise self._exception
+ try:
+ self._crt_future.result(timeout)
+ except KeyboardInterrupt:
+ self.cancel()
+ raise
+ finally:
+ if self._s3_request:
+ self._s3_request = None
+ self._crt_future.result(timeout)
+
+ def done(self):
+ if self._crt_future is None:
+ return False
+ return self._crt_future.done()
+
+ def set_s3_request(self, s3_request):
+ self._s3_request = s3_request
+ self._crt_future = self._s3_request.finished_future
+
+
+class S3ClientArgsCreator:
+ def __init__(self, crt_request_serializer, os_utils):
+ self._request_serializer = crt_request_serializer
+ self._os_utils = os_utils
+
+ def get_make_request_args(
+ self, request_type, call_args, coordinator, future, on_done_after_calls
+ ):
+ recv_filepath = None
+ send_filepath = None
+ s3_meta_request_type = getattr(
+ S3RequestType, request_type.upper(), S3RequestType.DEFAULT
+ )
+ on_done_before_calls = []
+ if s3_meta_request_type == S3RequestType.GET_OBJECT:
+ final_filepath = call_args.fileobj
+ recv_filepath = self._os_utils.get_temp_filename(final_filepath)
+ file_ondone_call = RenameTempFileHandler(
+ coordinator, final_filepath, recv_filepath, self._os_utils
+ )
+ on_done_before_calls.append(file_ondone_call)
+ elif s3_meta_request_type == S3RequestType.PUT_OBJECT:
+ send_filepath = call_args.fileobj
+ data_len = self._os_utils.get_file_size(send_filepath)
+ call_args.extra_args["ContentLength"] = data_len
+
+ crt_request = self._request_serializer.serialize_http_request(
+ request_type, future
+ )
+
+ return {
+ 'request': crt_request,
+ 'type': s3_meta_request_type,
+ 'recv_filepath': recv_filepath,
+ 'send_filepath': send_filepath,
+ 'on_done': self.get_crt_callback(
+ future, 'done', on_done_before_calls, on_done_after_calls
+ ),
+ 'on_progress': self.get_crt_callback(future, 'progress'),
+ }
+
+ def get_crt_callback(
+ self,
+ future,
+ callback_type,
+ before_subscribers=None,
+ after_subscribers=None,
+ ):
+ def invoke_all_callbacks(*args, **kwargs):
+ callbacks_list = []
+ if before_subscribers is not None:
+ callbacks_list += before_subscribers
+ callbacks_list += get_callbacks(future, callback_type)
+ if after_subscribers is not None:
+ callbacks_list += after_subscribers
+ for callback in callbacks_list:
+ # The get_callbacks helper will set the first augment
+ # by keyword, the other augments need to be set by keyword
+ # as well
+ if callback_type == "progress":
+ callback(bytes_transferred=args[0])
+ else:
+ callback(*args, **kwargs)
+
+ return invoke_all_callbacks
+
+
+class RenameTempFileHandler:
+ def __init__(self, coordinator, final_filename, temp_filename, osutil):
+ self._coordinator = coordinator
+ self._final_filename = final_filename
+ self._temp_filename = temp_filename
+ self._osutil = osutil
+
+ def __call__(self, **kwargs):
+ error = kwargs['error']
+ if error:
+ self._osutil.remove_file(self._temp_filename)
+ else:
+ try:
+ self._osutil.rename_file(
+ self._temp_filename, self._final_filename
+ )
+ except Exception as e:
+ self._osutil.remove_file(self._temp_filename)
+ # the CRT future has done already at this point
+ self._coordinator.set_exception(e)
+
+
+class AfterDoneHandler:
+ def __init__(self, coordinator):
+ self._coordinator = coordinator
+
+ def __call__(self, **kwargs):
+ self._coordinator.set_done_callbacks_complete()
diff --git a/contrib/python/s3transfer/py3/s3transfer/delete.py b/contrib/python/s3transfer/py3/s3transfer/delete.py
new file mode 100644
index 0000000000..74084d312a
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/delete.py
@@ -0,0 +1,71 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.tasks import SubmissionTask, Task
+
+
+class DeleteSubmissionTask(SubmissionTask):
+ """Task for submitting tasks to execute an object deletion."""
+
+ def _submit(self, client, request_executor, transfer_future, **kwargs):
+ """
+ :param client: The client associated with the transfer manager
+
+ :type config: s3transfer.manager.TransferConfig
+ :param config: The transfer config associated with the transfer
+ manager
+
+ :type osutil: s3transfer.utils.OSUtil
+ :param osutil: The os utility associated to the transfer manager
+
+ :type request_executor: s3transfer.futures.BoundedExecutor
+ :param request_executor: The request executor associated with the
+ transfer manager
+
+ :type transfer_future: s3transfer.futures.TransferFuture
+ :param transfer_future: The transfer future associated with the
+ transfer request that tasks are being submitted for
+ """
+ call_args = transfer_future.meta.call_args
+
+ self._transfer_coordinator.submit(
+ request_executor,
+ DeleteObjectTask(
+ transfer_coordinator=self._transfer_coordinator,
+ main_kwargs={
+ 'client': client,
+ 'bucket': call_args.bucket,
+ 'key': call_args.key,
+ 'extra_args': call_args.extra_args,
+ },
+ is_final=True,
+ ),
+ )
+
+
+class DeleteObjectTask(Task):
+ def _main(self, client, bucket, key, extra_args):
+ """
+
+ :param client: The S3 client to use when calling DeleteObject
+
+ :type bucket: str
+ :param bucket: The name of the bucket.
+
+ :type key: str
+ :param key: The name of the object to delete.
+
+ :type extra_args: dict
+ :param extra_args: Extra arguments to pass to the DeleteObject call.
+
+ """
+ client.delete_object(Bucket=bucket, Key=key, **extra_args)
diff --git a/contrib/python/s3transfer/py3/s3transfer/download.py b/contrib/python/s3transfer/py3/s3transfer/download.py
new file mode 100644
index 0000000000..dc8980d4ed
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/download.py
@@ -0,0 +1,790 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import heapq
+import logging
+import threading
+
+from s3transfer.compat import seekable
+from s3transfer.exceptions import RetriesExceededError
+from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG
+from s3transfer.tasks import SubmissionTask, Task
+from s3transfer.utils import (
+ S3_RETRYABLE_DOWNLOAD_ERRORS,
+ CountCallbackInvoker,
+ DeferredOpenFile,
+ FunctionContainer,
+ StreamReaderProgress,
+ calculate_num_parts,
+ calculate_range_parameter,
+ get_callbacks,
+ invoke_progress_callbacks,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class DownloadOutputManager:
+ """Base manager class for handling various types of files for downloads
+
+ This class is typically used for the DownloadSubmissionTask class to help
+ determine the following:
+
+ * Provides the fileobj to write to downloads to
+ * Get a task to complete once everything downloaded has been written
+
+ The answers/implementations differ for the various types of file outputs
+ that may be accepted. All implementations must subclass and override
+ public methods from this class.
+ """
+
+ def __init__(self, osutil, transfer_coordinator, io_executor):
+ self._osutil = osutil
+ self._transfer_coordinator = transfer_coordinator
+ self._io_executor = io_executor
+
+ @classmethod
+ def is_compatible(cls, download_target, osutil):
+ """Determines if the target for the download is compatible with manager
+
+ :param download_target: The target for which the upload will write
+ data to.
+
+ :param osutil: The os utility to be used for the transfer
+
+ :returns: True if the manager can handle the type of target specified
+ otherwise returns False.
+ """
+ raise NotImplementedError('must implement is_compatible()')
+
+ def get_download_task_tag(self):
+ """Get the tag (if any) to associate all GetObjectTasks
+
+ :rtype: s3transfer.futures.TaskTag
+ :returns: The tag to associate all GetObjectTasks with
+ """
+ return None
+
+ def get_fileobj_for_io_writes(self, transfer_future):
+ """Get file-like object to use for io writes in the io executor
+
+ :type transfer_future: s3transfer.futures.TransferFuture
+ :param transfer_future: The future associated with upload request
+
+ returns: A file-like object to write to
+ """
+ raise NotImplementedError('must implement get_fileobj_for_io_writes()')
+
+ def queue_file_io_task(self, fileobj, data, offset):
+ """Queue IO write for submission to the IO executor.
+
+ This method accepts an IO executor and information about the
+ downloaded data, and handles submitting this to the IO executor.
+
+ This method may defer submission to the IO executor if necessary.
+
+ """
+ self._transfer_coordinator.submit(
+ self._io_executor, self.get_io_write_task(fileobj, data, offset)
+ )
+
+ def get_io_write_task(self, fileobj, data, offset):
+ """Get an IO write task for the requested set of data
+
+ This task can be ran immediately or be submitted to the IO executor
+ for it to run.
+
+ :type fileobj: file-like object
+ :param fileobj: The file-like object to write to
+
+ :type data: bytes
+ :param data: The data to write out
+
+ :type offset: integer
+ :param offset: The offset to write the data to in the file-like object
+
+ :returns: An IO task to be used to write data to a file-like object
+ """
+ return IOWriteTask(
+ self._transfer_coordinator,
+ main_kwargs={
+ 'fileobj': fileobj,
+ 'data': data,
+ 'offset': offset,
+ },
+ )
+
+ def get_final_io_task(self):
+ """Get the final io task to complete the download
+
+ This is needed because based on the architecture of the TransferManager
+ the final tasks will be sent to the IO executor, but the executor
+ needs a final task for it to signal that the transfer is done and
+ all done callbacks can be run.
+
+ :rtype: s3transfer.tasks.Task
+ :returns: A final task to completed in the io executor
+ """
+ raise NotImplementedError('must implement get_final_io_task()')
+
+ def _get_fileobj_from_filename(self, filename):
+ f = DeferredOpenFile(
+ filename, mode='wb', open_function=self._osutil.open
+ )
+ # Make sure the file gets closed and we remove the temporary file
+ # if anything goes wrong during the process.
+ self._transfer_coordinator.add_failure_cleanup(f.close)
+ return f
+
+
+class DownloadFilenameOutputManager(DownloadOutputManager):
+ def __init__(self, osutil, transfer_coordinator, io_executor):
+ super().__init__(osutil, transfer_coordinator, io_executor)
+ self._final_filename = None
+ self._temp_filename = None
+ self._temp_fileobj = None
+
+ @classmethod
+ def is_compatible(cls, download_target, osutil):
+ return isinstance(download_target, str)
+
+ def get_fileobj_for_io_writes(self, transfer_future):
+ fileobj = transfer_future.meta.call_args.fileobj
+ self._final_filename = fileobj
+ self._temp_filename = self._osutil.get_temp_filename(fileobj)
+ self._temp_fileobj = self._get_temp_fileobj()
+ return self._temp_fileobj
+
+ def get_final_io_task(self):
+ # A task to rename the file from the temporary file to its final
+ # location is needed. This should be the last task needed to complete
+ # the download.
+ return IORenameFileTask(
+ transfer_coordinator=self._transfer_coordinator,
+ main_kwargs={
+ 'fileobj': self._temp_fileobj,
+ 'final_filename': self._final_filename,
+ 'osutil': self._osutil,
+ },
+ is_final=True,
+ )
+
+ def _get_temp_fileobj(self):
+ f = self._get_fileobj_from_filename(self._temp_filename)
+ self._transfer_coordinator.add_failure_cleanup(
+ self._osutil.remove_file, self._temp_filename
+ )
+ return f
+
+
+class DownloadSeekableOutputManager(DownloadOutputManager):
+ @classmethod
+ def is_compatible(cls, download_target, osutil):
+ return seekable(download_target)
+
+ def get_fileobj_for_io_writes(self, transfer_future):
+ # Return the fileobj provided to the future.
+ return transfer_future.meta.call_args.fileobj
+
+ def get_final_io_task(self):
+ # This task will serve the purpose of signaling when all of the io
+ # writes have finished so done callbacks can be called.
+ return CompleteDownloadNOOPTask(
+ transfer_coordinator=self._transfer_coordinator
+ )
+
+
+class DownloadNonSeekableOutputManager(DownloadOutputManager):
+ def __init__(
+ self, osutil, transfer_coordinator, io_executor, defer_queue=None
+ ):
+ super().__init__(osutil, transfer_coordinator, io_executor)
+ if defer_queue is None:
+ defer_queue = DeferQueue()
+ self._defer_queue = defer_queue
+ self._io_submit_lock = threading.Lock()
+
+ @classmethod
+ def is_compatible(cls, download_target, osutil):
+ return hasattr(download_target, 'write')
+
+ def get_download_task_tag(self):
+ return IN_MEMORY_DOWNLOAD_TAG
+
+ def get_fileobj_for_io_writes(self, transfer_future):
+ return transfer_future.meta.call_args.fileobj
+
+ def get_final_io_task(self):
+ return CompleteDownloadNOOPTask(
+ transfer_coordinator=self._transfer_coordinator
+ )
+
+ def queue_file_io_task(self, fileobj, data, offset):
+ with self._io_submit_lock:
+ writes = self._defer_queue.request_writes(offset, data)
+ for write in writes:
+ data = write['data']
+ logger.debug(
+ "Queueing IO offset %s for fileobj: %s",
+ write['offset'],
+ fileobj,
+ )
+ super().queue_file_io_task(fileobj, data, offset)
+
+ def get_io_write_task(self, fileobj, data, offset):
+ return IOStreamingWriteTask(
+ self._transfer_coordinator,
+ main_kwargs={
+ 'fileobj': fileobj,
+ 'data': data,
+ },
+ )
+
+
+class DownloadSpecialFilenameOutputManager(DownloadNonSeekableOutputManager):
+ def __init__(
+ self, osutil, transfer_coordinator, io_executor, defer_queue=None
+ ):
+ super().__init__(
+ osutil, transfer_coordinator, io_executor, defer_queue
+ )
+ self._fileobj = None
+
+ @classmethod
+ def is_compatible(cls, download_target, osutil):
+ return isinstance(download_target, str) and osutil.is_special_file(
+ download_target
+ )
+
+ def get_fileobj_for_io_writes(self, transfer_future):
+ filename = transfer_future.meta.call_args.fileobj
+ self._fileobj = self._get_fileobj_from_filename(filename)
+ return self._fileobj
+
+ def get_final_io_task(self):
+ # Make sure the file gets closed once the transfer is done.
+ return IOCloseTask(
+ transfer_coordinator=self._transfer_coordinator,
+ is_final=True,
+ main_kwargs={'fileobj': self._fileobj},
+ )
+
+
+class DownloadSubmissionTask(SubmissionTask):
+ """Task for submitting tasks to execute a download"""
+
+ def _get_download_output_manager_cls(self, transfer_future, osutil):
+ """Retrieves a class for managing output for a download
+
+ :type transfer_future: s3transfer.futures.TransferFuture
+ :param transfer_future: The transfer future for the request
+
+ :type osutil: s3transfer.utils.OSUtils
+ :param osutil: The os utility associated to the transfer
+
+ :rtype: class of DownloadOutputManager
+ :returns: The appropriate class to use for managing a specific type of
+ input for downloads.
+ """
+ download_manager_resolver_chain = [
+ DownloadSpecialFilenameOutputManager,
+ DownloadFilenameOutputManager,
+ DownloadSeekableOutputManager,
+ DownloadNonSeekableOutputManager,
+ ]
+
+ fileobj = transfer_future.meta.call_args.fileobj
+ for download_manager_cls in download_manager_resolver_chain:
+ if download_manager_cls.is_compatible(fileobj, osutil):
+ return download_manager_cls
+ raise RuntimeError(
+ 'Output {} of type: {} is not supported.'.format(
+ fileobj, type(fileobj)
+ )
+ )
+
+ def _submit(
+ self,
+ client,
+ config,
+ osutil,
+ request_executor,
+ io_executor,
+ transfer_future,
+ bandwidth_limiter=None,
+ ):
+ """
+ :param client: The client associated with the transfer manager
+
+ :type config: s3transfer.manager.TransferConfig
+ :param config: The transfer config associated with the transfer
+ manager
+
+ :type osutil: s3transfer.utils.OSUtil
+ :param osutil: The os utility associated to the transfer manager
+
+ :type request_executor: s3transfer.futures.BoundedExecutor
+ :param request_executor: The request executor associated with the
+ transfer manager
+
+ :type io_executor: s3transfer.futures.BoundedExecutor
+ :param io_executor: The io executor associated with the
+ transfer manager
+
+ :type transfer_future: s3transfer.futures.TransferFuture
+ :param transfer_future: The transfer future associated with the
+ transfer request that tasks are being submitted for
+
+ :type bandwidth_limiter: s3transfer.bandwidth.BandwidthLimiter
+ :param bandwidth_limiter: The bandwidth limiter to use when
+ downloading streams
+ """
+ if transfer_future.meta.size is None:
+ # If a size was not provided figure out the size for the
+ # user.
+ response = client.head_object(
+ Bucket=transfer_future.meta.call_args.bucket,
+ Key=transfer_future.meta.call_args.key,
+ **transfer_future.meta.call_args.extra_args,
+ )
+ transfer_future.meta.provide_transfer_size(
+ response['ContentLength']
+ )
+
+ download_output_manager = self._get_download_output_manager_cls(
+ transfer_future, osutil
+ )(osutil, self._transfer_coordinator, io_executor)
+
+ # If it is greater than threshold do a ranged download, otherwise
+ # do a regular GetObject download.
+ if transfer_future.meta.size < config.multipart_threshold:
+ self._submit_download_request(
+ client,
+ config,
+ osutil,
+ request_executor,
+ io_executor,
+ download_output_manager,
+ transfer_future,
+ bandwidth_limiter,
+ )
+ else:
+ self._submit_ranged_download_request(
+ client,
+ config,
+ osutil,
+ request_executor,
+ io_executor,
+ download_output_manager,
+ transfer_future,
+ bandwidth_limiter,
+ )
+
+ def _submit_download_request(
+ self,
+ client,
+ config,
+ osutil,
+ request_executor,
+ io_executor,
+ download_output_manager,
+ transfer_future,
+ bandwidth_limiter,
+ ):
+ call_args = transfer_future.meta.call_args
+
+ # Get a handle to the file that will be used for writing downloaded
+ # contents
+ fileobj = download_output_manager.get_fileobj_for_io_writes(
+ transfer_future
+ )
+
+ # Get the needed callbacks for the task
+ progress_callbacks = get_callbacks(transfer_future, 'progress')
+
+ # Get any associated tags for the get object task.
+ get_object_tag = download_output_manager.get_download_task_tag()
+
+ # Get the final io task to run once the download is complete.
+ final_task = download_output_manager.get_final_io_task()
+
+ # Submit the task to download the object.
+ self._transfer_coordinator.submit(
+ request_executor,
+ ImmediatelyWriteIOGetObjectTask(
+ transfer_coordinator=self._transfer_coordinator,
+ main_kwargs={
+ 'client': client,
+ 'bucket': call_args.bucket,
+ 'key': call_args.key,
+ 'fileobj': fileobj,
+ 'extra_args': call_args.extra_args,
+ 'callbacks': progress_callbacks,
+ 'max_attempts': config.num_download_attempts,
+ 'download_output_manager': download_output_manager,
+ 'io_chunksize': config.io_chunksize,
+ 'bandwidth_limiter': bandwidth_limiter,
+ },
+ done_callbacks=[final_task],
+ ),
+ tag=get_object_tag,
+ )
+
+ def _submit_ranged_download_request(
+ self,
+ client,
+ config,
+ osutil,
+ request_executor,
+ io_executor,
+ download_output_manager,
+ transfer_future,
+ bandwidth_limiter,
+ ):
+ call_args = transfer_future.meta.call_args
+
+ # Get the needed progress callbacks for the task
+ progress_callbacks = get_callbacks(transfer_future, 'progress')
+
+ # Get a handle to the file that will be used for writing downloaded
+ # contents
+ fileobj = download_output_manager.get_fileobj_for_io_writes(
+ transfer_future
+ )
+
+ # Determine the number of parts
+ part_size = config.multipart_chunksize
+ num_parts = calculate_num_parts(transfer_future.meta.size, part_size)
+
+ # Get any associated tags for the get object task.
+ get_object_tag = download_output_manager.get_download_task_tag()
+
+ # Callback invoker to submit the final io task once all downloads
+ # are complete.
+ finalize_download_invoker = CountCallbackInvoker(
+ self._get_final_io_task_submission_callback(
+ download_output_manager, io_executor
+ )
+ )
+ for i in range(num_parts):
+ # Calculate the range parameter
+ range_parameter = calculate_range_parameter(
+ part_size, i, num_parts
+ )
+
+ # Inject the Range parameter to the parameters to be passed in
+ # as extra args
+ extra_args = {'Range': range_parameter}
+ extra_args.update(call_args.extra_args)
+ finalize_download_invoker.increment()
+ # Submit the ranged downloads
+ self._transfer_coordinator.submit(
+ request_executor,
+ GetObjectTask(
+ transfer_coordinator=self._transfer_coordinator,
+ main_kwargs={
+ 'client': client,
+ 'bucket': call_args.bucket,
+ 'key': call_args.key,
+ 'fileobj': fileobj,
+ 'extra_args': extra_args,
+ 'callbacks': progress_callbacks,
+ 'max_attempts': config.num_download_attempts,
+ 'start_index': i * part_size,
+ 'download_output_manager': download_output_manager,
+ 'io_chunksize': config.io_chunksize,
+ 'bandwidth_limiter': bandwidth_limiter,
+ },
+ done_callbacks=[finalize_download_invoker.decrement],
+ ),
+ tag=get_object_tag,
+ )
+ finalize_download_invoker.finalize()
+
+ def _get_final_io_task_submission_callback(
+ self, download_manager, io_executor
+ ):
+ final_task = download_manager.get_final_io_task()
+ return FunctionContainer(
+ self._transfer_coordinator.submit, io_executor, final_task
+ )
+
+ def _calculate_range_param(self, part_size, part_index, num_parts):
+ # Used to calculate the Range parameter
+ start_range = part_index * part_size
+ if part_index == num_parts - 1:
+ end_range = ''
+ else:
+ end_range = start_range + part_size - 1
+ range_param = f'bytes={start_range}-{end_range}'
+ return range_param
+
+
+class GetObjectTask(Task):
+ def _main(
+ self,
+ client,
+ bucket,
+ key,
+ fileobj,
+ extra_args,
+ callbacks,
+ max_attempts,
+ download_output_manager,
+ io_chunksize,
+ start_index=0,
+ bandwidth_limiter=None,
+ ):
+ """Downloads an object and places content into io queue
+
+ :param client: The client to use when calling GetObject
+ :param bucket: The bucket to download from
+ :param key: The key to download from
+ :param fileobj: The file handle to write content to
+ :param exta_args: Any extra arguments to include in GetObject request
+ :param callbacks: List of progress callbacks to invoke on download
+ :param max_attempts: The number of retries to do when downloading
+ :param download_output_manager: The download output manager associated
+ with the current download.
+ :param io_chunksize: The size of each io chunk to read from the
+ download stream and queue in the io queue.
+ :param start_index: The location in the file to start writing the
+ content of the key to.
+ :param bandwidth_limiter: The bandwidth limiter to use when throttling
+ the downloading of data in streams.
+ """
+ last_exception = None
+ for i in range(max_attempts):
+ try:
+ current_index = start_index
+ response = client.get_object(
+ Bucket=bucket, Key=key, **extra_args
+ )
+ streaming_body = StreamReaderProgress(
+ response['Body'], callbacks
+ )
+ if bandwidth_limiter:
+ streaming_body = (
+ bandwidth_limiter.get_bandwith_limited_stream(
+ streaming_body, self._transfer_coordinator
+ )
+ )
+
+ chunks = DownloadChunkIterator(streaming_body, io_chunksize)
+ for chunk in chunks:
+ # If the transfer is done because of a cancellation
+ # or error somewhere else, stop trying to submit more
+ # data to be written and break out of the download.
+ if not self._transfer_coordinator.done():
+ self._handle_io(
+ download_output_manager,
+ fileobj,
+ chunk,
+ current_index,
+ )
+ current_index += len(chunk)
+ else:
+ return
+ return
+ except S3_RETRYABLE_DOWNLOAD_ERRORS as e:
+ logger.debug(
+ "Retrying exception caught (%s), "
+ "retrying request, (attempt %s / %s)",
+ e,
+ i,
+ max_attempts,
+ exc_info=True,
+ )
+ last_exception = e
+ # Also invoke the progress callbacks to indicate that we
+ # are trying to download the stream again and all progress
+ # for this GetObject has been lost.
+ invoke_progress_callbacks(
+ callbacks, start_index - current_index
+ )
+ continue
+ raise RetriesExceededError(last_exception)
+
+ def _handle_io(self, download_output_manager, fileobj, chunk, index):
+ download_output_manager.queue_file_io_task(fileobj, chunk, index)
+
+
+class ImmediatelyWriteIOGetObjectTask(GetObjectTask):
+ """GetObjectTask that immediately writes to the provided file object
+
+ This is useful for downloads where it is known only one thread is
+ downloading the object so there is no reason to go through the
+ overhead of using an IO queue and executor.
+ """
+
+ def _handle_io(self, download_output_manager, fileobj, chunk, index):
+ task = download_output_manager.get_io_write_task(fileobj, chunk, index)
+ task()
+
+
+class IOWriteTask(Task):
+ def _main(self, fileobj, data, offset):
+ """Pulls off an io queue to write contents to a file
+
+ :param fileobj: The file handle to write content to
+ :param data: The data to write
+ :param offset: The offset to write the data to.
+ """
+ fileobj.seek(offset)
+ fileobj.write(data)
+
+
+class IOStreamingWriteTask(Task):
+ """Task for writing data to a non-seekable stream."""
+
+ def _main(self, fileobj, data):
+ """Write data to a fileobj.
+
+ Data will be written directly to the fileobj without
+ any prior seeking.
+
+ :param fileobj: The fileobj to write content to
+ :param data: The data to write
+
+ """
+ fileobj.write(data)
+
+
+class IORenameFileTask(Task):
+ """A task to rename a temporary file to its final filename
+
+ :param fileobj: The file handle that content was written to.
+ :param final_filename: The final name of the file to rename to
+ upon completion of writing the contents.
+ :param osutil: OS utility
+ """
+
+ def _main(self, fileobj, final_filename, osutil):
+ fileobj.close()
+ osutil.rename_file(fileobj.name, final_filename)
+
+
+class IOCloseTask(Task):
+ """A task to close out a file once the download is complete.
+
+ :param fileobj: The fileobj to close.
+ """
+
+ def _main(self, fileobj):
+ fileobj.close()
+
+
+class CompleteDownloadNOOPTask(Task):
+ """A NOOP task to serve as an indicator that the download is complete
+
+ Note that the default for is_final is set to True because this should
+ always be the last task.
+ """
+
+ def __init__(
+ self,
+ transfer_coordinator,
+ main_kwargs=None,
+ pending_main_kwargs=None,
+ done_callbacks=None,
+ is_final=True,
+ ):
+ super().__init__(
+ transfer_coordinator=transfer_coordinator,
+ main_kwargs=main_kwargs,
+ pending_main_kwargs=pending_main_kwargs,
+ done_callbacks=done_callbacks,
+ is_final=is_final,
+ )
+
+ def _main(self):
+ pass
+
+
+class DownloadChunkIterator:
+ def __init__(self, body, chunksize):
+ """Iterator to chunk out a downloaded S3 stream
+
+ :param body: A readable file-like object
+ :param chunksize: The amount to read each time
+ """
+ self._body = body
+ self._chunksize = chunksize
+ self._num_reads = 0
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ chunk = self._body.read(self._chunksize)
+ self._num_reads += 1
+ if chunk:
+ return chunk
+ elif self._num_reads == 1:
+ # Even though the response may have not had any
+ # content, we still want to account for an empty object's
+ # existence so return the empty chunk for that initial
+ # read.
+ return chunk
+ raise StopIteration()
+
+ next = __next__
+
+
+class DeferQueue:
+ """IO queue that defers write requests until they are queued sequentially.
+
+ This class is used to track IO data for a *single* fileobj.
+
+ You can send data to this queue, and it will defer any IO write requests
+ until it has the next contiguous block available (starting at 0).
+
+ """
+
+ def __init__(self):
+ self._writes = []
+ self._pending_offsets = set()
+ self._next_offset = 0
+
+ def request_writes(self, offset, data):
+ """Request any available writes given new incoming data.
+
+ You call this method by providing new data along with the
+ offset associated with the data. If that new data unlocks
+ any contiguous writes that can now be submitted, this
+ method will return all applicable writes.
+
+ This is done with 1 method call so you don't have to
+ make two method calls (put(), get()) which acquires a lock
+ each method call.
+
+ """
+ if offset < self._next_offset:
+ # This is a request for a write that we've already
+ # seen. This can happen in the event of a retry
+ # where if we retry at at offset N/2, we'll requeue
+ # offsets 0-N/2 again.
+ return []
+ writes = []
+ if offset in self._pending_offsets:
+ # We've already queued this offset so this request is
+ # a duplicate. In this case we should ignore
+ # this request and prefer what's already queued.
+ return []
+ heapq.heappush(self._writes, (offset, data))
+ self._pending_offsets.add(offset)
+ while self._writes and self._writes[0][0] == self._next_offset:
+ next_write = heapq.heappop(self._writes)
+ writes.append({'offset': next_write[0], 'data': next_write[1]})
+ self._pending_offsets.remove(next_write[0])
+ self._next_offset += len(next_write[1])
+ return writes
diff --git a/contrib/python/s3transfer/py3/s3transfer/exceptions.py b/contrib/python/s3transfer/py3/s3transfer/exceptions.py
new file mode 100644
index 0000000000..6150fe650d
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/exceptions.py
@@ -0,0 +1,37 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from concurrent.futures import CancelledError
+
+
+class RetriesExceededError(Exception):
+ def __init__(self, last_exception, msg='Max Retries Exceeded'):
+ super().__init__(msg)
+ self.last_exception = last_exception
+
+
+class S3UploadFailedError(Exception):
+ pass
+
+
+class InvalidSubscriberMethodError(Exception):
+ pass
+
+
+class TransferNotDoneError(Exception):
+ pass
+
+
+class FatalError(CancelledError):
+ """A CancelledError raised from an error in the TransferManager"""
+
+ pass
diff --git a/contrib/python/s3transfer/py3/s3transfer/futures.py b/contrib/python/s3transfer/py3/s3transfer/futures.py
new file mode 100644
index 0000000000..39e071fb60
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/futures.py
@@ -0,0 +1,606 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import copy
+import logging
+import sys
+import threading
+from collections import namedtuple
+from concurrent import futures
+
+from s3transfer.compat import MAXINT
+from s3transfer.exceptions import CancelledError, TransferNotDoneError
+from s3transfer.utils import FunctionContainer, TaskSemaphore
+
+logger = logging.getLogger(__name__)
+
+
+class BaseTransferFuture:
+ @property
+ def meta(self):
+ """The metadata associated to the TransferFuture"""
+ raise NotImplementedError('meta')
+
+ def done(self):
+ """Determines if a TransferFuture has completed
+
+ :returns: True if completed. False, otherwise.
+ """
+ raise NotImplementedError('done()')
+
+ def result(self):
+ """Waits until TransferFuture is done and returns the result
+
+ If the TransferFuture succeeded, it will return the result. If the
+ TransferFuture failed, it will raise the exception associated to the
+ failure.
+ """
+ raise NotImplementedError('result()')
+
+ def cancel(self):
+ """Cancels the request associated with the TransferFuture"""
+ raise NotImplementedError('cancel()')
+
+
+class BaseTransferMeta:
+ @property
+ def call_args(self):
+ """The call args used in the transfer request"""
+ raise NotImplementedError('call_args')
+
+ @property
+ def transfer_id(self):
+ """The unique id of the transfer"""
+ raise NotImplementedError('transfer_id')
+
+ @property
+ def user_context(self):
+ """A dictionary that requesters can store data in"""
+ raise NotImplementedError('user_context')
+
+
+class TransferFuture(BaseTransferFuture):
+ def __init__(self, meta=None, coordinator=None):
+ """The future associated to a submitted transfer request
+
+ :type meta: TransferMeta
+ :param meta: The metadata associated to the request. This object
+ is visible to the requester.
+
+ :type coordinator: TransferCoordinator
+ :param coordinator: The coordinator associated to the request. This
+ object is not visible to the requester.
+ """
+ self._meta = meta
+ if meta is None:
+ self._meta = TransferMeta()
+
+ self._coordinator = coordinator
+ if coordinator is None:
+ self._coordinator = TransferCoordinator()
+
+ @property
+ def meta(self):
+ return self._meta
+
+ def done(self):
+ return self._coordinator.done()
+
+ def result(self):
+ try:
+ # Usually the result() method blocks until the transfer is done,
+ # however if a KeyboardInterrupt is raised we want want to exit
+ # out of this and propagate the exception.
+ return self._coordinator.result()
+ except KeyboardInterrupt as e:
+ self.cancel()
+ raise e
+
+ def cancel(self):
+ self._coordinator.cancel()
+
+ def set_exception(self, exception):
+ """Sets the exception on the future."""
+ if not self.done():
+ raise TransferNotDoneError(
+ 'set_exception can only be called once the transfer is '
+ 'complete.'
+ )
+ self._coordinator.set_exception(exception, override=True)
+
+
+class TransferMeta(BaseTransferMeta):
+ """Holds metadata about the TransferFuture"""
+
+ def __init__(self, call_args=None, transfer_id=None):
+ self._call_args = call_args
+ self._transfer_id = transfer_id
+ self._size = None
+ self._user_context = {}
+
+ @property
+ def call_args(self):
+ """The call args used in the transfer request"""
+ return self._call_args
+
+ @property
+ def transfer_id(self):
+ """The unique id of the transfer"""
+ return self._transfer_id
+
+ @property
+ def size(self):
+ """The size of the transfer request if known"""
+ return self._size
+
+ @property
+ def user_context(self):
+ """A dictionary that requesters can store data in"""
+ return self._user_context
+
+ def provide_transfer_size(self, size):
+ """A method to provide the size of a transfer request
+
+ By providing this value, the TransferManager will not try to
+ call HeadObject or use the use OS to determine the size of the
+ transfer.
+ """
+ self._size = size
+
+
+class TransferCoordinator:
+ """A helper class for managing TransferFuture"""
+
+ def __init__(self, transfer_id=None):
+ self.transfer_id = transfer_id
+ self._status = 'not-started'
+ self._result = None
+ self._exception = None
+ self._associated_futures = set()
+ self._failure_cleanups = []
+ self._done_callbacks = []
+ self._done_event = threading.Event()
+ self._lock = threading.Lock()
+ self._associated_futures_lock = threading.Lock()
+ self._done_callbacks_lock = threading.Lock()
+ self._failure_cleanups_lock = threading.Lock()
+
+ def __repr__(self):
+ return '{}(transfer_id={})'.format(
+ self.__class__.__name__, self.transfer_id
+ )
+
+ @property
+ def exception(self):
+ return self._exception
+
+ @property
+ def associated_futures(self):
+ """The list of futures associated to the inprogress TransferFuture
+
+ Once the transfer finishes this list becomes empty as the transfer
+ is considered done and there should be no running futures left.
+ """
+ with self._associated_futures_lock:
+ # We return a copy of the list because we do not want to
+ # processing the returned list while another thread is adding
+ # more futures to the actual list.
+ return copy.copy(self._associated_futures)
+
+ @property
+ def failure_cleanups(self):
+ """The list of callbacks to call when the TransferFuture fails"""
+ return self._failure_cleanups
+
+ @property
+ def status(self):
+ """The status of the TransferFuture
+
+ The currently supported states are:
+ * not-started - Has yet to start. If in this state, a transfer
+ can be canceled immediately and nothing will happen.
+ * queued - SubmissionTask is about to submit tasks
+ * running - Is inprogress. In-progress as of now means that
+ the SubmissionTask that runs the transfer is being executed. So
+ there is no guarantee any transfer requests had been made to
+ S3 if this state is reached.
+ * cancelled - Was cancelled
+ * failed - An exception other than CancelledError was thrown
+ * success - No exceptions were thrown and is done.
+ """
+ return self._status
+
+ def set_result(self, result):
+ """Set a result for the TransferFuture
+
+ Implies that the TransferFuture succeeded. This will always set a
+ result because it is invoked on the final task where there is only
+ ever one final task and it is ran at the very end of a transfer
+ process. So if a result is being set for this final task, the transfer
+ succeeded even if something came a long and canceled the transfer
+ on the final task.
+ """
+ with self._lock:
+ self._exception = None
+ self._result = result
+ self._status = 'success'
+
+ def set_exception(self, exception, override=False):
+ """Set an exception for the TransferFuture
+
+ Implies the TransferFuture failed.
+
+ :param exception: The exception that cause the transfer to fail.
+ :param override: If True, override any existing state.
+ """
+ with self._lock:
+ if not self.done() or override:
+ self._exception = exception
+ self._status = 'failed'
+
+ def result(self):
+ """Waits until TransferFuture is done and returns the result
+
+ If the TransferFuture succeeded, it will return the result. If the
+ TransferFuture failed, it will raise the exception associated to the
+ failure.
+ """
+ # Doing a wait() with no timeout cannot be interrupted in python2 but
+ # can be interrupted in python3 so we just wait with the largest
+ # possible value integer value, which is on the scale of billions of
+ # years...
+ self._done_event.wait(MAXINT)
+
+ # Once done waiting, raise an exception if present or return the
+ # final result.
+ if self._exception:
+ raise self._exception
+ return self._result
+
+ def cancel(self, msg='', exc_type=CancelledError):
+ """Cancels the TransferFuture
+
+ :param msg: The message to attach to the cancellation
+ :param exc_type: The type of exception to set for the cancellation
+ """
+ with self._lock:
+ if not self.done():
+ should_announce_done = False
+ logger.debug('%s cancel(%s) called', self, msg)
+ self._exception = exc_type(msg)
+ if self._status == 'not-started':
+ should_announce_done = True
+ self._status = 'cancelled'
+ if should_announce_done:
+ self.announce_done()
+
+ def set_status_to_queued(self):
+ """Sets the TransferFutrue's status to running"""
+ self._transition_to_non_done_state('queued')
+
+ def set_status_to_running(self):
+ """Sets the TransferFuture's status to running"""
+ self._transition_to_non_done_state('running')
+
+ def _transition_to_non_done_state(self, desired_state):
+ with self._lock:
+ if self.done():
+ raise RuntimeError(
+ 'Unable to transition from done state %s to non-done '
+ 'state %s.' % (self.status, desired_state)
+ )
+ self._status = desired_state
+
+ def submit(self, executor, task, tag=None):
+ """Submits a task to a provided executor
+
+ :type executor: s3transfer.futures.BoundedExecutor
+ :param executor: The executor to submit the callable to
+
+ :type task: s3transfer.tasks.Task
+ :param task: The task to submit to the executor
+
+ :type tag: s3transfer.futures.TaskTag
+ :param tag: A tag to associate to the submitted task
+
+ :rtype: concurrent.futures.Future
+ :returns: A future representing the submitted task
+ """
+ logger.debug(
+ "Submitting task {} to executor {} for transfer request: {}.".format(
+ task, executor, self.transfer_id
+ )
+ )
+ future = executor.submit(task, tag=tag)
+ # Add this created future to the list of associated future just
+ # in case it is needed during cleanups.
+ self.add_associated_future(future)
+ future.add_done_callback(
+ FunctionContainer(self.remove_associated_future, future)
+ )
+ return future
+
+ def done(self):
+ """Determines if a TransferFuture has completed
+
+ :returns: False if status is equal to 'failed', 'cancelled', or
+ 'success'. True, otherwise
+ """
+ return self.status in ['failed', 'cancelled', 'success']
+
+ def add_associated_future(self, future):
+ """Adds a future to be associated with the TransferFuture"""
+ with self._associated_futures_lock:
+ self._associated_futures.add(future)
+
+ def remove_associated_future(self, future):
+ """Removes a future's association to the TransferFuture"""
+ with self._associated_futures_lock:
+ self._associated_futures.remove(future)
+
+ def add_done_callback(self, function, *args, **kwargs):
+ """Add a done callback to be invoked when transfer is done"""
+ with self._done_callbacks_lock:
+ self._done_callbacks.append(
+ FunctionContainer(function, *args, **kwargs)
+ )
+
+ def add_failure_cleanup(self, function, *args, **kwargs):
+ """Adds a callback to call upon failure"""
+ with self._failure_cleanups_lock:
+ self._failure_cleanups.append(
+ FunctionContainer(function, *args, **kwargs)
+ )
+
+ def announce_done(self):
+ """Announce that future is done running and run associated callbacks
+
+ This will run any failure cleanups if the transfer failed if not
+ they have not been run, allows the result() to be unblocked, and will
+ run any done callbacks associated to the TransferFuture if they have
+ not already been ran.
+ """
+ if self.status != 'success':
+ self._run_failure_cleanups()
+ self._done_event.set()
+ self._run_done_callbacks()
+
+ def _run_done_callbacks(self):
+ # Run the callbacks and remove the callbacks from the internal
+ # list so they do not get ran again if done is announced more than
+ # once.
+ with self._done_callbacks_lock:
+ self._run_callbacks(self._done_callbacks)
+ self._done_callbacks = []
+
+ def _run_failure_cleanups(self):
+ # Run the cleanup callbacks and remove the callbacks from the internal
+ # list so they do not get ran again if done is announced more than
+ # once.
+ with self._failure_cleanups_lock:
+ self._run_callbacks(self.failure_cleanups)
+ self._failure_cleanups = []
+
+ def _run_callbacks(self, callbacks):
+ for callback in callbacks:
+ self._run_callback(callback)
+
+ def _run_callback(self, callback):
+ try:
+ callback()
+ # We do not want a callback interrupting the process, especially
+ # in the failure cleanups. So log and catch, the exception.
+ except Exception:
+ logger.debug("Exception raised in %s." % callback, exc_info=True)
+
+
+class BoundedExecutor:
+ EXECUTOR_CLS = futures.ThreadPoolExecutor
+
+ def __init__(
+ self, max_size, max_num_threads, tag_semaphores=None, executor_cls=None
+ ):
+ """An executor implementation that has a maximum queued up tasks
+
+ The executor will block if the number of tasks that have been
+ submitted and is currently working on is past its maximum.
+
+ :params max_size: The maximum number of inflight futures. An inflight
+ future means that the task is either queued up or is currently
+ being executed. A size of None or 0 means that the executor will
+ have no bound in terms of the number of inflight futures.
+
+ :params max_num_threads: The maximum number of threads the executor
+ uses.
+
+ :type tag_semaphores: dict
+ :params tag_semaphores: A dictionary where the key is the name of the
+ tag and the value is the semaphore to use when limiting the
+ number of tasks the executor is processing at a time.
+
+ :type executor_cls: BaseExecutor
+ :param underlying_executor_cls: The executor class that
+ get bounded by this executor. If None is provided, the
+ concurrent.futures.ThreadPoolExecutor class is used.
+ """
+ self._max_num_threads = max_num_threads
+ if executor_cls is None:
+ executor_cls = self.EXECUTOR_CLS
+ self._executor = executor_cls(max_workers=self._max_num_threads)
+ self._semaphore = TaskSemaphore(max_size)
+ self._tag_semaphores = tag_semaphores
+
+ def submit(self, task, tag=None, block=True):
+ """Submit a task to complete
+
+ :type task: s3transfer.tasks.Task
+ :param task: The task to run __call__ on
+
+
+ :type tag: s3transfer.futures.TaskTag
+ :param tag: An optional tag to associate to the task. This
+ is used to override which semaphore to use.
+
+ :type block: boolean
+ :param block: True if to wait till it is possible to submit a task.
+ False, if not to wait and raise an error if not able to submit
+ a task.
+
+ :returns: The future associated to the submitted task
+ """
+ semaphore = self._semaphore
+ # If a tag was provided, use the semaphore associated to that
+ # tag.
+ if tag:
+ semaphore = self._tag_semaphores[tag]
+
+ # Call acquire on the semaphore.
+ acquire_token = semaphore.acquire(task.transfer_id, block)
+ # Create a callback to invoke when task is done in order to call
+ # release on the semaphore.
+ release_callback = FunctionContainer(
+ semaphore.release, task.transfer_id, acquire_token
+ )
+ # Submit the task to the underlying executor.
+ future = ExecutorFuture(self._executor.submit(task))
+ # Add the Semaphore.release() callback to the future such that
+ # it is invoked once the future completes.
+ future.add_done_callback(release_callback)
+ return future
+
+ def shutdown(self, wait=True):
+ self._executor.shutdown(wait)
+
+
+class ExecutorFuture:
+ def __init__(self, future):
+ """A future returned from the executor
+
+ Currently, it is just a wrapper around a concurrent.futures.Future.
+ However, this can eventually grow to implement the needed functionality
+ of concurrent.futures.Future if we move off of the library and not
+ affect the rest of the codebase.
+
+ :type future: concurrent.futures.Future
+ :param future: The underlying future
+ """
+ self._future = future
+
+ def result(self):
+ return self._future.result()
+
+ def add_done_callback(self, fn):
+ """Adds a callback to be completed once future is done
+
+ :param fn: A callable that takes no arguments. Note that is different
+ than concurrent.futures.Future.add_done_callback that requires
+ a single argument for the future.
+ """
+ # The done callback for concurrent.futures.Future will always pass a
+ # the future in as the only argument. So we need to create the
+ # proper signature wrapper that will invoke the callback provided.
+ def done_callback(future_passed_to_callback):
+ return fn()
+
+ self._future.add_done_callback(done_callback)
+
+ def done(self):
+ return self._future.done()
+
+
+class BaseExecutor:
+ """Base Executor class implementation needed to work with s3transfer"""
+
+ def __init__(self, max_workers=None):
+ pass
+
+ def submit(self, fn, *args, **kwargs):
+ raise NotImplementedError('submit()')
+
+ def shutdown(self, wait=True):
+ raise NotImplementedError('shutdown()')
+
+
+class NonThreadedExecutor(BaseExecutor):
+ """A drop-in replacement non-threaded version of ThreadPoolExecutor"""
+
+ def submit(self, fn, *args, **kwargs):
+ future = NonThreadedExecutorFuture()
+ try:
+ result = fn(*args, **kwargs)
+ future.set_result(result)
+ except Exception:
+ e, tb = sys.exc_info()[1:]
+ logger.debug(
+ 'Setting exception for %s to %s with traceback %s',
+ future,
+ e,
+ tb,
+ )
+ future.set_exception_info(e, tb)
+ return future
+
+ def shutdown(self, wait=True):
+ pass
+
+
+class NonThreadedExecutorFuture:
+ """The Future returned from NonThreadedExecutor
+
+ Note that this future is **not** thread-safe as it is being used
+ from the context of a non-threaded environment.
+ """
+
+ def __init__(self):
+ self._result = None
+ self._exception = None
+ self._traceback = None
+ self._done = False
+ self._done_callbacks = []
+
+ def set_result(self, result):
+ self._result = result
+ self._set_done()
+
+ def set_exception_info(self, exception, traceback):
+ self._exception = exception
+ self._traceback = traceback
+ self._set_done()
+
+ def result(self, timeout=None):
+ if self._exception:
+ raise self._exception.with_traceback(self._traceback)
+ return self._result
+
+ def _set_done(self):
+ self._done = True
+ for done_callback in self._done_callbacks:
+ self._invoke_done_callback(done_callback)
+ self._done_callbacks = []
+
+ def _invoke_done_callback(self, done_callback):
+ return done_callback(self)
+
+ def done(self):
+ return self._done
+
+ def add_done_callback(self, fn):
+ if self._done:
+ self._invoke_done_callback(fn)
+ else:
+ self._done_callbacks.append(fn)
+
+
+TaskTag = namedtuple('TaskTag', ['name'])
+
+IN_MEMORY_UPLOAD_TAG = TaskTag('in_memory_upload')
+IN_MEMORY_DOWNLOAD_TAG = TaskTag('in_memory_download')
diff --git a/contrib/python/s3transfer/py3/s3transfer/manager.py b/contrib/python/s3transfer/py3/s3transfer/manager.py
new file mode 100644
index 0000000000..ff6afa12c1
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/manager.py
@@ -0,0 +1,727 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import copy
+import logging
+import re
+import threading
+
+from s3transfer.bandwidth import BandwidthLimiter, LeakyBucket
+from s3transfer.constants import ALLOWED_DOWNLOAD_ARGS, KB, MB
+from s3transfer.copies import CopySubmissionTask
+from s3transfer.delete import DeleteSubmissionTask
+from s3transfer.download import DownloadSubmissionTask
+from s3transfer.exceptions import CancelledError, FatalError
+from s3transfer.futures import (
+ IN_MEMORY_DOWNLOAD_TAG,
+ IN_MEMORY_UPLOAD_TAG,
+ BoundedExecutor,
+ TransferCoordinator,
+ TransferFuture,
+ TransferMeta,
+)
+from s3transfer.upload import UploadSubmissionTask
+from s3transfer.utils import (
+ CallArgs,
+ OSUtils,
+ SlidingWindowSemaphore,
+ TaskSemaphore,
+ get_callbacks,
+ signal_not_transferring,
+ signal_transferring,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class TransferConfig:
+ def __init__(
+ self,
+ multipart_threshold=8 * MB,
+ multipart_chunksize=8 * MB,
+ max_request_concurrency=10,
+ max_submission_concurrency=5,
+ max_request_queue_size=1000,
+ max_submission_queue_size=1000,
+ max_io_queue_size=1000,
+ io_chunksize=256 * KB,
+ num_download_attempts=5,
+ max_in_memory_upload_chunks=10,
+ max_in_memory_download_chunks=10,
+ max_bandwidth=None,
+ ):
+ """Configurations for the transfer manager
+
+ :param multipart_threshold: The threshold for which multipart
+ transfers occur.
+
+ :param max_request_concurrency: The maximum number of S3 API
+ transfer-related requests that can happen at a time.
+
+ :param max_submission_concurrency: The maximum number of threads
+ processing a call to a TransferManager method. Processing a
+ call usually entails determining which S3 API requests that need
+ to be enqueued, but does **not** entail making any of the
+ S3 API data transferring requests needed to perform the transfer.
+ The threads controlled by ``max_request_concurrency`` is
+ responsible for that.
+
+ :param multipart_chunksize: The size of each transfer if a request
+ becomes a multipart transfer.
+
+ :param max_request_queue_size: The maximum amount of S3 API requests
+ that can be queued at a time.
+
+ :param max_submission_queue_size: The maximum amount of
+ TransferManager method calls that can be queued at a time.
+
+ :param max_io_queue_size: The maximum amount of read parts that
+ can be queued to be written to disk per download. The default
+ size for each elementin this queue is 8 KB.
+
+ :param io_chunksize: The max size of each chunk in the io queue.
+ Currently, this is size used when reading from the downloaded
+ stream as well.
+
+ :param num_download_attempts: The number of download attempts that
+ will be tried upon errors with downloading an object in S3. Note
+ that these retries account for errors that occur when streaming
+ down the data from s3 (i.e. socket errors and read timeouts that
+ occur after receiving an OK response from s3).
+ Other retryable exceptions such as throttling errors and 5xx errors
+ are already retried by botocore (this default is 5). The
+ ``num_download_attempts`` does not take into account the
+ number of exceptions retried by botocore.
+
+ :param max_in_memory_upload_chunks: The number of chunks that can
+ be stored in memory at a time for all ongoing upload requests.
+ This pertains to chunks of data that need to be stored in memory
+ during an upload if the data is sourced from a file-like object.
+ The total maximum memory footprint due to a in-memory upload
+ chunks is roughly equal to:
+
+ max_in_memory_upload_chunks * multipart_chunksize
+ + max_submission_concurrency * multipart_chunksize
+
+ ``max_submission_concurrency`` has an affect on this value because
+ for each thread pulling data off of a file-like object, they may
+ be waiting with a single read chunk to be submitted for upload
+ because the ``max_in_memory_upload_chunks`` value has been reached
+ by the threads making the upload request.
+
+ :param max_in_memory_download_chunks: The number of chunks that can
+ be buffered in memory and **not** in the io queue at a time for all
+ ongoing download requests. This pertains specifically to file-like
+ objects that cannot be seeked. The total maximum memory footprint
+ due to a in-memory download chunks is roughly equal to:
+
+ max_in_memory_download_chunks * multipart_chunksize
+
+ :param max_bandwidth: The maximum bandwidth that will be consumed
+ in uploading and downloading file content. The value is in terms of
+ bytes per second.
+ """
+ self.multipart_threshold = multipart_threshold
+ self.multipart_chunksize = multipart_chunksize
+ self.max_request_concurrency = max_request_concurrency
+ self.max_submission_concurrency = max_submission_concurrency
+ self.max_request_queue_size = max_request_queue_size
+ self.max_submission_queue_size = max_submission_queue_size
+ self.max_io_queue_size = max_io_queue_size
+ self.io_chunksize = io_chunksize
+ self.num_download_attempts = num_download_attempts
+ self.max_in_memory_upload_chunks = max_in_memory_upload_chunks
+ self.max_in_memory_download_chunks = max_in_memory_download_chunks
+ self.max_bandwidth = max_bandwidth
+ self._validate_attrs_are_nonzero()
+
+ def _validate_attrs_are_nonzero(self):
+ for attr, attr_val in self.__dict__.items():
+ if attr_val is not None and attr_val <= 0:
+ raise ValueError(
+ 'Provided parameter %s of value %s must be greater than '
+ '0.' % (attr, attr_val)
+ )
+
+
+class TransferManager:
+ ALLOWED_DOWNLOAD_ARGS = ALLOWED_DOWNLOAD_ARGS
+
+ ALLOWED_UPLOAD_ARGS = [
+ 'ACL',
+ 'CacheControl',
+ 'ContentDisposition',
+ 'ContentEncoding',
+ 'ContentLanguage',
+ 'ContentType',
+ 'ExpectedBucketOwner',
+ 'Expires',
+ 'GrantFullControl',
+ 'GrantRead',
+ 'GrantReadACP',
+ 'GrantWriteACP',
+ 'Metadata',
+ 'RequestPayer',
+ 'ServerSideEncryption',
+ 'StorageClass',
+ 'SSECustomerAlgorithm',
+ 'SSECustomerKey',
+ 'SSECustomerKeyMD5',
+ 'SSEKMSKeyId',
+ 'SSEKMSEncryptionContext',
+ 'Tagging',
+ 'WebsiteRedirectLocation',
+ ]
+
+ ALLOWED_COPY_ARGS = ALLOWED_UPLOAD_ARGS + [
+ 'CopySourceIfMatch',
+ 'CopySourceIfModifiedSince',
+ 'CopySourceIfNoneMatch',
+ 'CopySourceIfUnmodifiedSince',
+ 'CopySourceSSECustomerAlgorithm',
+ 'CopySourceSSECustomerKey',
+ 'CopySourceSSECustomerKeyMD5',
+ 'MetadataDirective',
+ 'TaggingDirective',
+ ]
+
+ ALLOWED_DELETE_ARGS = [
+ 'MFA',
+ 'VersionId',
+ 'RequestPayer',
+ 'ExpectedBucketOwner',
+ ]
+
+ VALIDATE_SUPPORTED_BUCKET_VALUES = True
+
+ _UNSUPPORTED_BUCKET_PATTERNS = {
+ 'S3 Object Lambda': re.compile(
+ r'^arn:(aws).*:s3-object-lambda:[a-z\-0-9]+:[0-9]{12}:'
+ r'accesspoint[/:][a-zA-Z0-9\-]{1,63}'
+ ),
+ }
+
+ def __init__(self, client, config=None, osutil=None, executor_cls=None):
+ """A transfer manager interface for Amazon S3
+
+ :param client: Client to be used by the manager
+ :param config: TransferConfig to associate specific configurations
+ :param osutil: OSUtils object to use for os-related behavior when
+ using with transfer manager.
+
+ :type executor_cls: s3transfer.futures.BaseExecutor
+ :param executor_cls: The class of executor to use with the transfer
+ manager. By default, concurrent.futures.ThreadPoolExecutor is used.
+ """
+ self._client = client
+ self._config = config
+ if config is None:
+ self._config = TransferConfig()
+ self._osutil = osutil
+ if osutil is None:
+ self._osutil = OSUtils()
+ self._coordinator_controller = TransferCoordinatorController()
+ # A counter to create unique id's for each transfer submitted.
+ self._id_counter = 0
+
+ # The executor responsible for making S3 API transfer requests
+ self._request_executor = BoundedExecutor(
+ max_size=self._config.max_request_queue_size,
+ max_num_threads=self._config.max_request_concurrency,
+ tag_semaphores={
+ IN_MEMORY_UPLOAD_TAG: TaskSemaphore(
+ self._config.max_in_memory_upload_chunks
+ ),
+ IN_MEMORY_DOWNLOAD_TAG: SlidingWindowSemaphore(
+ self._config.max_in_memory_download_chunks
+ ),
+ },
+ executor_cls=executor_cls,
+ )
+
+ # The executor responsible for submitting the necessary tasks to
+ # perform the desired transfer
+ self._submission_executor = BoundedExecutor(
+ max_size=self._config.max_submission_queue_size,
+ max_num_threads=self._config.max_submission_concurrency,
+ executor_cls=executor_cls,
+ )
+
+ # There is one thread available for writing to disk. It will handle
+ # downloads for all files.
+ self._io_executor = BoundedExecutor(
+ max_size=self._config.max_io_queue_size,
+ max_num_threads=1,
+ executor_cls=executor_cls,
+ )
+
+ # The component responsible for limiting bandwidth usage if it
+ # is configured.
+ self._bandwidth_limiter = None
+ if self._config.max_bandwidth is not None:
+ logger.debug(
+ 'Setting max_bandwidth to %s', self._config.max_bandwidth
+ )
+ leaky_bucket = LeakyBucket(self._config.max_bandwidth)
+ self._bandwidth_limiter = BandwidthLimiter(leaky_bucket)
+
+ self._register_handlers()
+
+ @property
+ def client(self):
+ return self._client
+
+ @property
+ def config(self):
+ return self._config
+
+ def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None):
+ """Uploads a file to S3
+
+ :type fileobj: str or seekable file-like object
+ :param fileobj: The name of a file to upload or a seekable file-like
+ object to upload. It is recommended to use a filename because
+ file-like objects may result in higher memory usage.
+
+ :type bucket: str
+ :param bucket: The name of the bucket to upload to
+
+ :type key: str
+ :param key: The name of the key to upload to
+
+ :type extra_args: dict
+ :param extra_args: Extra arguments that may be passed to the
+ client operation
+
+ :type subscribers: list(s3transfer.subscribers.BaseSubscriber)
+ :param subscribers: The list of subscribers to be invoked in the
+ order provided based on the event emit during the process of
+ the transfer request.
+
+ :rtype: s3transfer.futures.TransferFuture
+ :returns: Transfer future representing the upload
+ """
+ if extra_args is None:
+ extra_args = {}
+ if subscribers is None:
+ subscribers = []
+ self._validate_all_known_args(extra_args, self.ALLOWED_UPLOAD_ARGS)
+ self._validate_if_bucket_supported(bucket)
+ call_args = CallArgs(
+ fileobj=fileobj,
+ bucket=bucket,
+ key=key,
+ extra_args=extra_args,
+ subscribers=subscribers,
+ )
+ extra_main_kwargs = {}
+ if self._bandwidth_limiter:
+ extra_main_kwargs['bandwidth_limiter'] = self._bandwidth_limiter
+ return self._submit_transfer(
+ call_args, UploadSubmissionTask, extra_main_kwargs
+ )
+
+ def download(
+ self, bucket, key, fileobj, extra_args=None, subscribers=None
+ ):
+ """Downloads a file from S3
+
+ :type bucket: str
+ :param bucket: The name of the bucket to download from
+
+ :type key: str
+ :param key: The name of the key to download from
+
+ :type fileobj: str or seekable file-like object
+ :param fileobj: The name of a file to download or a seekable file-like
+ object to download. It is recommended to use a filename because
+ file-like objects may result in higher memory usage.
+
+ :type extra_args: dict
+ :param extra_args: Extra arguments that may be passed to the
+ client operation
+
+ :type subscribers: list(s3transfer.subscribers.BaseSubscriber)
+ :param subscribers: The list of subscribers to be invoked in the
+ order provided based on the event emit during the process of
+ the transfer request.
+
+ :rtype: s3transfer.futures.TransferFuture
+ :returns: Transfer future representing the download
+ """
+ if extra_args is None:
+ extra_args = {}
+ if subscribers is None:
+ subscribers = []
+ self._validate_all_known_args(extra_args, self.ALLOWED_DOWNLOAD_ARGS)
+ self._validate_if_bucket_supported(bucket)
+ call_args = CallArgs(
+ bucket=bucket,
+ key=key,
+ fileobj=fileobj,
+ extra_args=extra_args,
+ subscribers=subscribers,
+ )
+ extra_main_kwargs = {'io_executor': self._io_executor}
+ if self._bandwidth_limiter:
+ extra_main_kwargs['bandwidth_limiter'] = self._bandwidth_limiter
+ return self._submit_transfer(
+ call_args, DownloadSubmissionTask, extra_main_kwargs
+ )
+
+ def copy(
+ self,
+ copy_source,
+ bucket,
+ key,
+ extra_args=None,
+ subscribers=None,
+ source_client=None,
+ ):
+ """Copies a file in S3
+
+ :type copy_source: dict
+ :param copy_source: The name of the source bucket, key name of the
+ source object, and optional version ID of the source object. The
+ dictionary format is:
+ ``{'Bucket': 'bucket', 'Key': 'key', 'VersionId': 'id'}``. Note
+ that the ``VersionId`` key is optional and may be omitted.
+
+ :type bucket: str
+ :param bucket: The name of the bucket to copy to
+
+ :type key: str
+ :param key: The name of the key to copy to
+
+ :type extra_args: dict
+ :param extra_args: Extra arguments that may be passed to the
+ client operation
+
+ :type subscribers: a list of subscribers
+ :param subscribers: The list of subscribers to be invoked in the
+ order provided based on the event emit during the process of
+ the transfer request.
+
+ :type source_client: botocore or boto3 Client
+ :param source_client: The client to be used for operation that
+ may happen at the source object. For example, this client is
+ used for the head_object that determines the size of the copy.
+ If no client is provided, the transfer manager's client is used
+ as the client for the source object.
+
+ :rtype: s3transfer.futures.TransferFuture
+ :returns: Transfer future representing the copy
+ """
+ if extra_args is None:
+ extra_args = {}
+ if subscribers is None:
+ subscribers = []
+ if source_client is None:
+ source_client = self._client
+ self._validate_all_known_args(extra_args, self.ALLOWED_COPY_ARGS)
+ if isinstance(copy_source, dict):
+ self._validate_if_bucket_supported(copy_source.get('Bucket'))
+ self._validate_if_bucket_supported(bucket)
+ call_args = CallArgs(
+ copy_source=copy_source,
+ bucket=bucket,
+ key=key,
+ extra_args=extra_args,
+ subscribers=subscribers,
+ source_client=source_client,
+ )
+ return self._submit_transfer(call_args, CopySubmissionTask)
+
+ def delete(self, bucket, key, extra_args=None, subscribers=None):
+ """Delete an S3 object.
+
+ :type bucket: str
+ :param bucket: The name of the bucket.
+
+ :type key: str
+ :param key: The name of the S3 object to delete.
+
+ :type extra_args: dict
+ :param extra_args: Extra arguments that may be passed to the
+ DeleteObject call.
+
+ :type subscribers: list
+ :param subscribers: A list of subscribers to be invoked during the
+ process of the transfer request. Note that the ``on_progress``
+ callback is not invoked during object deletion.
+
+ :rtype: s3transfer.futures.TransferFuture
+ :return: Transfer future representing the deletion.
+
+ """
+ if extra_args is None:
+ extra_args = {}
+ if subscribers is None:
+ subscribers = []
+ self._validate_all_known_args(extra_args, self.ALLOWED_DELETE_ARGS)
+ self._validate_if_bucket_supported(bucket)
+ call_args = CallArgs(
+ bucket=bucket,
+ key=key,
+ extra_args=extra_args,
+ subscribers=subscribers,
+ )
+ return self._submit_transfer(call_args, DeleteSubmissionTask)
+
+ def _validate_if_bucket_supported(self, bucket):
+ # s3 high level operations don't support some resources
+ # (eg. S3 Object Lambda) only direct API calls are available
+ # for such resources
+ if self.VALIDATE_SUPPORTED_BUCKET_VALUES:
+ for resource, pattern in self._UNSUPPORTED_BUCKET_PATTERNS.items():
+ match = pattern.match(bucket)
+ if match:
+ raise ValueError(
+ 'TransferManager methods do not support %s '
+ 'resource. Use direct client calls instead.' % resource
+ )
+
+ def _validate_all_known_args(self, actual, allowed):
+ for kwarg in actual:
+ if kwarg not in allowed:
+ raise ValueError(
+ "Invalid extra_args key '%s', "
+ "must be one of: %s" % (kwarg, ', '.join(allowed))
+ )
+
+ def _submit_transfer(
+ self, call_args, submission_task_cls, extra_main_kwargs=None
+ ):
+ if not extra_main_kwargs:
+ extra_main_kwargs = {}
+
+ # Create a TransferFuture to return back to the user
+ transfer_future, components = self._get_future_with_components(
+ call_args
+ )
+
+ # Add any provided done callbacks to the created transfer future
+ # to be invoked on the transfer future being complete.
+ for callback in get_callbacks(transfer_future, 'done'):
+ components['coordinator'].add_done_callback(callback)
+
+ # Get the main kwargs needed to instantiate the submission task
+ main_kwargs = self._get_submission_task_main_kwargs(
+ transfer_future, extra_main_kwargs
+ )
+
+ # Submit a SubmissionTask that will submit all of the necessary
+ # tasks needed to complete the S3 transfer.
+ self._submission_executor.submit(
+ submission_task_cls(
+ transfer_coordinator=components['coordinator'],
+ main_kwargs=main_kwargs,
+ )
+ )
+
+ # Increment the unique id counter for future transfer requests
+ self._id_counter += 1
+
+ return transfer_future
+
+ def _get_future_with_components(self, call_args):
+ transfer_id = self._id_counter
+ # Creates a new transfer future along with its components
+ transfer_coordinator = TransferCoordinator(transfer_id=transfer_id)
+ # Track the transfer coordinator for transfers to manage.
+ self._coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+ # Also make sure that the transfer coordinator is removed once
+ # the transfer completes so it does not stick around in memory.
+ transfer_coordinator.add_done_callback(
+ self._coordinator_controller.remove_transfer_coordinator,
+ transfer_coordinator,
+ )
+ components = {
+ 'meta': TransferMeta(call_args, transfer_id=transfer_id),
+ 'coordinator': transfer_coordinator,
+ }
+ transfer_future = TransferFuture(**components)
+ return transfer_future, components
+
+ def _get_submission_task_main_kwargs(
+ self, transfer_future, extra_main_kwargs
+ ):
+ main_kwargs = {
+ 'client': self._client,
+ 'config': self._config,
+ 'osutil': self._osutil,
+ 'request_executor': self._request_executor,
+ 'transfer_future': transfer_future,
+ }
+ main_kwargs.update(extra_main_kwargs)
+ return main_kwargs
+
+ def _register_handlers(self):
+ # Register handlers to enable/disable callbacks on uploads.
+ event_name = 'request-created.s3'
+ self._client.meta.events.register_first(
+ event_name,
+ signal_not_transferring,
+ unique_id='s3upload-not-transferring',
+ )
+ self._client.meta.events.register_last(
+ event_name, signal_transferring, unique_id='s3upload-transferring'
+ )
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, *args):
+ cancel = False
+ cancel_msg = ''
+ cancel_exc_type = FatalError
+ # If a exception was raised in the context handler, signal to cancel
+ # all of the inprogress futures in the shutdown.
+ if exc_type:
+ cancel = True
+ cancel_msg = str(exc_value)
+ if not cancel_msg:
+ cancel_msg = repr(exc_value)
+ # If it was a KeyboardInterrupt, the cancellation was initiated
+ # by the user.
+ if isinstance(exc_value, KeyboardInterrupt):
+ cancel_exc_type = CancelledError
+ self._shutdown(cancel, cancel_msg, cancel_exc_type)
+
+ def shutdown(self, cancel=False, cancel_msg=''):
+ """Shutdown the TransferManager
+
+ It will wait till all transfers complete before it completely shuts
+ down.
+
+ :type cancel: boolean
+ :param cancel: If True, calls TransferFuture.cancel() for
+ all in-progress in transfers. This is useful if you want the
+ shutdown to happen quicker.
+
+ :type cancel_msg: str
+ :param cancel_msg: The message to specify if canceling all in-progress
+ transfers.
+ """
+ self._shutdown(cancel, cancel, cancel_msg)
+
+ def _shutdown(self, cancel, cancel_msg, exc_type=CancelledError):
+ if cancel:
+ # Cancel all in-flight transfers if requested, before waiting
+ # for them to complete.
+ self._coordinator_controller.cancel(cancel_msg, exc_type)
+ try:
+ # Wait until there are no more in-progress transfers. This is
+ # wrapped in a try statement because this can be interrupted
+ # with a KeyboardInterrupt that needs to be caught.
+ self._coordinator_controller.wait()
+ except KeyboardInterrupt:
+ # If not errors were raised in the try block, the cancel should
+ # have no coordinators it needs to run cancel on. If there was
+ # an error raised in the try statement we want to cancel all of
+ # the inflight transfers before shutting down to speed that
+ # process up.
+ self._coordinator_controller.cancel('KeyboardInterrupt()')
+ raise
+ finally:
+ # Shutdown all of the executors.
+ self._submission_executor.shutdown()
+ self._request_executor.shutdown()
+ self._io_executor.shutdown()
+
+
+class TransferCoordinatorController:
+ def __init__(self):
+ """Abstraction to control all transfer coordinators
+
+ This abstraction allows the manager to wait for inprogress transfers
+ to complete and cancel all inprogress transfers.
+ """
+ self._lock = threading.Lock()
+ self._tracked_transfer_coordinators = set()
+
+ @property
+ def tracked_transfer_coordinators(self):
+ """The set of transfer coordinators being tracked"""
+ with self._lock:
+ # We return a copy because the set is mutable and if you were to
+ # iterate over the set, it may be changing in length due to
+ # additions and removals of transfer coordinators.
+ return copy.copy(self._tracked_transfer_coordinators)
+
+ def add_transfer_coordinator(self, transfer_coordinator):
+ """Adds a transfer coordinator of a transfer to be canceled if needed
+
+ :type transfer_coordinator: s3transfer.futures.TransferCoordinator
+ :param transfer_coordinator: The transfer coordinator for the
+ particular transfer
+ """
+ with self._lock:
+ self._tracked_transfer_coordinators.add(transfer_coordinator)
+
+ def remove_transfer_coordinator(self, transfer_coordinator):
+ """Remove a transfer coordinator from cancellation consideration
+
+ Typically, this method is invoked by the transfer coordinator itself
+ to remove its self when it completes its transfer.
+
+ :type transfer_coordinator: s3transfer.futures.TransferCoordinator
+ :param transfer_coordinator: The transfer coordinator for the
+ particular transfer
+ """
+ with self._lock:
+ self._tracked_transfer_coordinators.remove(transfer_coordinator)
+
+ def cancel(self, msg='', exc_type=CancelledError):
+ """Cancels all inprogress transfers
+
+ This cancels the inprogress transfers by calling cancel() on all
+ tracked transfer coordinators.
+
+ :param msg: The message to pass on to each transfer coordinator that
+ gets cancelled.
+
+ :param exc_type: The type of exception to set for the cancellation
+ """
+ for transfer_coordinator in self.tracked_transfer_coordinators:
+ transfer_coordinator.cancel(msg, exc_type)
+
+ def wait(self):
+ """Wait until there are no more inprogress transfers
+
+ This will not stop when failures are encountered and not propagate any
+ of these errors from failed transfers, but it can be interrupted with
+ a KeyboardInterrupt.
+ """
+ try:
+ transfer_coordinator = None
+ for transfer_coordinator in self.tracked_transfer_coordinators:
+ transfer_coordinator.result()
+ except KeyboardInterrupt:
+ logger.debug('Received KeyboardInterrupt in wait()')
+ # If Keyboard interrupt is raised while waiting for
+ # the result, then exit out of the wait and raise the
+ # exception
+ if transfer_coordinator:
+ logger.debug(
+ 'On KeyboardInterrupt was waiting for %s',
+ transfer_coordinator,
+ )
+ raise
+ except Exception:
+ # A general exception could have been thrown because
+ # of result(). We just want to ignore this and continue
+ # because we at least know that the transfer coordinator
+ # has completed.
+ pass
diff --git a/contrib/python/s3transfer/py3/s3transfer/processpool.py b/contrib/python/s3transfer/py3/s3transfer/processpool.py
new file mode 100644
index 0000000000..017eeb4499
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/processpool.py
@@ -0,0 +1,1008 @@
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Speeds up S3 throughput by using processes
+
+Getting Started
+===============
+
+The :class:`ProcessPoolDownloader` can be used to download a single file by
+calling :meth:`ProcessPoolDownloader.download_file`:
+
+.. code:: python
+
+ from s3transfer.processpool import ProcessPoolDownloader
+
+ with ProcessPoolDownloader() as downloader:
+ downloader.download_file('mybucket', 'mykey', 'myfile')
+
+
+This snippet downloads the S3 object located in the bucket ``mybucket`` at the
+key ``mykey`` to the local file ``myfile``. Any errors encountered during the
+transfer are not propagated. To determine if a transfer succeeded or
+failed, use the `Futures`_ interface.
+
+
+The :class:`ProcessPoolDownloader` can be used to download multiple files as
+well:
+
+.. code:: python
+
+ from s3transfer.processpool import ProcessPoolDownloader
+
+ with ProcessPoolDownloader() as downloader:
+ downloader.download_file('mybucket', 'mykey', 'myfile')
+ downloader.download_file('mybucket', 'myotherkey', 'myotherfile')
+
+
+When running this snippet, the downloading of ``mykey`` and ``myotherkey``
+happen in parallel. The first ``download_file`` call does not block the
+second ``download_file`` call. The snippet blocks when exiting
+the context manager and blocks until both downloads are complete.
+
+Alternatively, the ``ProcessPoolDownloader`` can be instantiated
+and explicitly be shutdown using :meth:`ProcessPoolDownloader.shutdown`:
+
+.. code:: python
+
+ from s3transfer.processpool import ProcessPoolDownloader
+
+ downloader = ProcessPoolDownloader()
+ downloader.download_file('mybucket', 'mykey', 'myfile')
+ downloader.download_file('mybucket', 'myotherkey', 'myotherfile')
+ downloader.shutdown()
+
+
+For this code snippet, the call to ``shutdown`` blocks until both
+downloads are complete.
+
+
+Additional Parameters
+=====================
+
+Additional parameters can be provided to the ``download_file`` method:
+
+* ``extra_args``: A dictionary containing any additional client arguments
+ to include in the
+ `GetObject <https://botocore.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.get_object>`_
+ API request. For example:
+
+ .. code:: python
+
+ from s3transfer.processpool import ProcessPoolDownloader
+
+ with ProcessPoolDownloader() as downloader:
+ downloader.download_file(
+ 'mybucket', 'mykey', 'myfile',
+ extra_args={'VersionId': 'myversion'})
+
+
+* ``expected_size``: By default, the downloader will make a HeadObject
+ call to determine the size of the object. To opt-out of this additional
+ API call, you can provide the size of the object in bytes:
+
+ .. code:: python
+
+ from s3transfer.processpool import ProcessPoolDownloader
+
+ MB = 1024 * 1024
+ with ProcessPoolDownloader() as downloader:
+ downloader.download_file(
+ 'mybucket', 'mykey', 'myfile', expected_size=2 * MB)
+
+
+Futures
+=======
+
+When ``download_file`` is called, it immediately returns a
+:class:`ProcessPoolTransferFuture`. The future can be used to poll the state
+of a particular transfer. To get the result of the download,
+call :meth:`ProcessPoolTransferFuture.result`. The method blocks
+until the transfer completes, whether it succeeds or fails. For example:
+
+.. code:: python
+
+ from s3transfer.processpool import ProcessPoolDownloader
+
+ with ProcessPoolDownloader() as downloader:
+ future = downloader.download_file('mybucket', 'mykey', 'myfile')
+ print(future.result())
+
+
+If the download succeeds, the future returns ``None``:
+
+.. code:: python
+
+ None
+
+
+If the download fails, the exception causing the failure is raised. For
+example, if ``mykey`` did not exist, the following error would be raised
+
+
+.. code:: python
+
+ botocore.exceptions.ClientError: An error occurred (404) when calling the HeadObject operation: Not Found
+
+
+.. note::
+
+ :meth:`ProcessPoolTransferFuture.result` can only be called while the
+ ``ProcessPoolDownloader`` is running (e.g. before calling ``shutdown`` or
+ inside the context manager).
+
+
+Process Pool Configuration
+==========================
+
+By default, the downloader has the following configuration options:
+
+* ``multipart_threshold``: The threshold size for performing ranged downloads
+ in bytes. By default, ranged downloads happen for S3 objects that are
+ greater than or equal to 8 MB in size.
+
+* ``multipart_chunksize``: The size of each ranged download in bytes. By
+ default, the size of each ranged download is 8 MB.
+
+* ``max_request_processes``: The maximum number of processes used to download
+ S3 objects. By default, the maximum is 10 processes.
+
+
+To change the default configuration, use the :class:`ProcessTransferConfig`:
+
+.. code:: python
+
+ from s3transfer.processpool import ProcessPoolDownloader
+ from s3transfer.processpool import ProcessTransferConfig
+
+ config = ProcessTransferConfig(
+ multipart_threshold=64 * 1024 * 1024, # 64 MB
+ max_request_processes=50
+ )
+ downloader = ProcessPoolDownloader(config=config)
+
+
+Client Configuration
+====================
+
+The process pool downloader creates ``botocore`` clients on your behalf. In
+order to affect how the client is created, pass the keyword arguments
+that would have been used in the :meth:`botocore.Session.create_client` call:
+
+.. code:: python
+
+
+ from s3transfer.processpool import ProcessPoolDownloader
+ from s3transfer.processpool import ProcessTransferConfig
+
+ downloader = ProcessPoolDownloader(
+ client_kwargs={'region_name': 'us-west-2'})
+
+
+This snippet ensures that all clients created by the ``ProcessPoolDownloader``
+are using ``us-west-2`` as their region.
+
+"""
+import collections
+import contextlib
+import logging
+import multiprocessing
+import signal
+import threading
+from copy import deepcopy
+
+import botocore.session
+from botocore.config import Config
+
+from s3transfer.compat import MAXINT, BaseManager
+from s3transfer.constants import ALLOWED_DOWNLOAD_ARGS, MB, PROCESS_USER_AGENT
+from s3transfer.exceptions import CancelledError, RetriesExceededError
+from s3transfer.futures import BaseTransferFuture, BaseTransferMeta
+from s3transfer.utils import (
+ S3_RETRYABLE_DOWNLOAD_ERRORS,
+ CallArgs,
+ OSUtils,
+ calculate_num_parts,
+ calculate_range_parameter,
+)
+
+logger = logging.getLogger(__name__)
+
+SHUTDOWN_SIGNAL = 'SHUTDOWN'
+
+# The DownloadFileRequest tuple is submitted from the ProcessPoolDownloader
+# to the GetObjectSubmitter in order for the submitter to begin submitting
+# GetObjectJobs to the GetObjectWorkers.
+DownloadFileRequest = collections.namedtuple(
+ 'DownloadFileRequest',
+ [
+ 'transfer_id', # The unique id for the transfer
+ 'bucket', # The bucket to download the object from
+ 'key', # The key to download the object from
+ 'filename', # The user-requested download location
+ 'extra_args', # Extra arguments to provide to client calls
+ 'expected_size', # The user-provided expected size of the download
+ ],
+)
+
+# The GetObjectJob tuple is submitted from the GetObjectSubmitter
+# to the GetObjectWorkers to download the file or parts of the file.
+GetObjectJob = collections.namedtuple(
+ 'GetObjectJob',
+ [
+ 'transfer_id', # The unique id for the transfer
+ 'bucket', # The bucket to download the object from
+ 'key', # The key to download the object from
+ 'temp_filename', # The temporary file to write the content to via
+ # completed GetObject calls.
+ 'extra_args', # Extra arguments to provide to the GetObject call
+ 'offset', # The offset to write the content for the temp file.
+ 'filename', # The user-requested download location. The worker
+ # of final GetObjectJob will move the file located at
+ # temp_filename to the location of filename.
+ ],
+)
+
+
+@contextlib.contextmanager
+def ignore_ctrl_c():
+ original_handler = _add_ignore_handler_for_interrupts()
+ yield
+ signal.signal(signal.SIGINT, original_handler)
+
+
+def _add_ignore_handler_for_interrupts():
+ # Windows is unable to pickle signal.signal directly so it needs to
+ # be wrapped in a function defined at the module level
+ return signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+
+class ProcessTransferConfig:
+ def __init__(
+ self,
+ multipart_threshold=8 * MB,
+ multipart_chunksize=8 * MB,
+ max_request_processes=10,
+ ):
+ """Configuration for the ProcessPoolDownloader
+
+ :param multipart_threshold: The threshold for which ranged downloads
+ occur.
+
+ :param multipart_chunksize: The chunk size of each ranged download.
+
+ :param max_request_processes: The maximum number of processes that
+ will be making S3 API transfer-related requests at a time.
+ """
+ self.multipart_threshold = multipart_threshold
+ self.multipart_chunksize = multipart_chunksize
+ self.max_request_processes = max_request_processes
+
+
+class ProcessPoolDownloader:
+ def __init__(self, client_kwargs=None, config=None):
+ """Downloads S3 objects using process pools
+
+ :type client_kwargs: dict
+ :param client_kwargs: The keyword arguments to provide when
+ instantiating S3 clients. The arguments must match the keyword
+ arguments provided to the
+ `botocore.session.Session.create_client()` method.
+
+ :type config: ProcessTransferConfig
+ :param config: Configuration for the downloader
+ """
+ if client_kwargs is None:
+ client_kwargs = {}
+ self._client_factory = ClientFactory(client_kwargs)
+
+ self._transfer_config = config
+ if config is None:
+ self._transfer_config = ProcessTransferConfig()
+
+ self._download_request_queue = multiprocessing.Queue(1000)
+ self._worker_queue = multiprocessing.Queue(1000)
+ self._osutil = OSUtils()
+
+ self._started = False
+ self._start_lock = threading.Lock()
+
+ # These below are initialized in the start() method
+ self._manager = None
+ self._transfer_monitor = None
+ self._submitter = None
+ self._workers = []
+
+ def download_file(
+ self, bucket, key, filename, extra_args=None, expected_size=None
+ ):
+ """Downloads the object's contents to a file
+
+ :type bucket: str
+ :param bucket: The name of the bucket to download from
+
+ :type key: str
+ :param key: The name of the key to download from
+
+ :type filename: str
+ :param filename: The name of a file to download to.
+
+ :type extra_args: dict
+ :param extra_args: Extra arguments that may be passed to the
+ client operation
+
+ :type expected_size: int
+ :param expected_size: The expected size in bytes of the download. If
+ provided, the downloader will not call HeadObject to determine the
+ object's size and use the provided value instead. The size is
+ needed to determine whether to do a multipart download.
+
+ :rtype: s3transfer.futures.TransferFuture
+ :returns: Transfer future representing the download
+ """
+ self._start_if_needed()
+ if extra_args is None:
+ extra_args = {}
+ self._validate_all_known_args(extra_args)
+ transfer_id = self._transfer_monitor.notify_new_transfer()
+ download_file_request = DownloadFileRequest(
+ transfer_id=transfer_id,
+ bucket=bucket,
+ key=key,
+ filename=filename,
+ extra_args=extra_args,
+ expected_size=expected_size,
+ )
+ logger.debug(
+ 'Submitting download file request: %s.', download_file_request
+ )
+ self._download_request_queue.put(download_file_request)
+ call_args = CallArgs(
+ bucket=bucket,
+ key=key,
+ filename=filename,
+ extra_args=extra_args,
+ expected_size=expected_size,
+ )
+ future = self._get_transfer_future(transfer_id, call_args)
+ return future
+
+ def shutdown(self):
+ """Shutdown the downloader
+
+ It will wait till all downloads are complete before returning.
+ """
+ self._shutdown_if_needed()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, *args):
+ if isinstance(exc_value, KeyboardInterrupt):
+ if self._transfer_monitor is not None:
+ self._transfer_monitor.notify_cancel_all_in_progress()
+ self.shutdown()
+
+ def _start_if_needed(self):
+ with self._start_lock:
+ if not self._started:
+ self._start()
+
+ def _start(self):
+ self._start_transfer_monitor_manager()
+ self._start_submitter()
+ self._start_get_object_workers()
+ self._started = True
+
+ def _validate_all_known_args(self, provided):
+ for kwarg in provided:
+ if kwarg not in ALLOWED_DOWNLOAD_ARGS:
+ download_args = ', '.join(ALLOWED_DOWNLOAD_ARGS)
+ raise ValueError(
+ f"Invalid extra_args key '{kwarg}', "
+ f"must be one of: {download_args}"
+ )
+
+ def _get_transfer_future(self, transfer_id, call_args):
+ meta = ProcessPoolTransferMeta(
+ call_args=call_args, transfer_id=transfer_id
+ )
+ future = ProcessPoolTransferFuture(
+ monitor=self._transfer_monitor, meta=meta
+ )
+ return future
+
+ def _start_transfer_monitor_manager(self):
+ logger.debug('Starting the TransferMonitorManager.')
+ self._manager = TransferMonitorManager()
+ # We do not want Ctrl-C's to cause the manager to shutdown immediately
+ # as worker processes will still need to communicate with it when they
+ # are shutting down. So instead we ignore Ctrl-C and let the manager
+ # be explicitly shutdown when shutting down the downloader.
+ self._manager.start(_add_ignore_handler_for_interrupts)
+ self._transfer_monitor = self._manager.TransferMonitor()
+
+ def _start_submitter(self):
+ logger.debug('Starting the GetObjectSubmitter.')
+ self._submitter = GetObjectSubmitter(
+ transfer_config=self._transfer_config,
+ client_factory=self._client_factory,
+ transfer_monitor=self._transfer_monitor,
+ osutil=self._osutil,
+ download_request_queue=self._download_request_queue,
+ worker_queue=self._worker_queue,
+ )
+ self._submitter.start()
+
+ def _start_get_object_workers(self):
+ logger.debug(
+ 'Starting %s GetObjectWorkers.',
+ self._transfer_config.max_request_processes,
+ )
+ for _ in range(self._transfer_config.max_request_processes):
+ worker = GetObjectWorker(
+ queue=self._worker_queue,
+ client_factory=self._client_factory,
+ transfer_monitor=self._transfer_monitor,
+ osutil=self._osutil,
+ )
+ worker.start()
+ self._workers.append(worker)
+
+ def _shutdown_if_needed(self):
+ with self._start_lock:
+ if self._started:
+ self._shutdown()
+
+ def _shutdown(self):
+ self._shutdown_submitter()
+ self._shutdown_get_object_workers()
+ self._shutdown_transfer_monitor_manager()
+ self._started = False
+
+ def _shutdown_transfer_monitor_manager(self):
+ logger.debug('Shutting down the TransferMonitorManager.')
+ self._manager.shutdown()
+
+ def _shutdown_submitter(self):
+ logger.debug('Shutting down the GetObjectSubmitter.')
+ self._download_request_queue.put(SHUTDOWN_SIGNAL)
+ self._submitter.join()
+
+ def _shutdown_get_object_workers(self):
+ logger.debug('Shutting down the GetObjectWorkers.')
+ for _ in self._workers:
+ self._worker_queue.put(SHUTDOWN_SIGNAL)
+ for worker in self._workers:
+ worker.join()
+
+
+class ProcessPoolTransferFuture(BaseTransferFuture):
+ def __init__(self, monitor, meta):
+ """The future associated to a submitted process pool transfer request
+
+ :type monitor: TransferMonitor
+ :param monitor: The monitor associated to the process pool downloader
+
+ :type meta: ProcessPoolTransferMeta
+ :param meta: The metadata associated to the request. This object
+ is visible to the requester.
+ """
+ self._monitor = monitor
+ self._meta = meta
+
+ @property
+ def meta(self):
+ return self._meta
+
+ def done(self):
+ return self._monitor.is_done(self._meta.transfer_id)
+
+ def result(self):
+ try:
+ return self._monitor.poll_for_result(self._meta.transfer_id)
+ except KeyboardInterrupt:
+ # For the multiprocessing Manager, a thread is given a single
+ # connection to reuse in communicating between the thread in the
+ # main process and the Manager's process. If a Ctrl-C happens when
+ # polling for the result, it will make the main thread stop trying
+ # to receive from the connection, but the Manager process will not
+ # know that the main process has stopped trying to receive and
+ # will not close the connection. As a result if another message is
+ # sent to the Manager process, the listener in the Manager
+ # processes will not process the new message as it is still trying
+ # trying to process the previous message (that was Ctrl-C'd) and
+ # thus cause the thread in the main process to hang on its send.
+ # The only way around this is to create a new connection and send
+ # messages from that new connection instead.
+ self._monitor._connect()
+ self.cancel()
+ raise
+
+ def cancel(self):
+ self._monitor.notify_exception(
+ self._meta.transfer_id, CancelledError()
+ )
+
+
+class ProcessPoolTransferMeta(BaseTransferMeta):
+ """Holds metadata about the ProcessPoolTransferFuture"""
+
+ def __init__(self, transfer_id, call_args):
+ self._transfer_id = transfer_id
+ self._call_args = call_args
+ self._user_context = {}
+
+ @property
+ def call_args(self):
+ return self._call_args
+
+ @property
+ def transfer_id(self):
+ return self._transfer_id
+
+ @property
+ def user_context(self):
+ return self._user_context
+
+
+class ClientFactory:
+ def __init__(self, client_kwargs=None):
+ """Creates S3 clients for processes
+
+ Botocore sessions and clients are not pickleable so they cannot be
+ inherited across Process boundaries. Instead, they must be instantiated
+ once a process is running.
+ """
+ self._client_kwargs = client_kwargs
+ if self._client_kwargs is None:
+ self._client_kwargs = {}
+
+ client_config = deepcopy(self._client_kwargs.get('config', Config()))
+ if not client_config.user_agent_extra:
+ client_config.user_agent_extra = PROCESS_USER_AGENT
+ else:
+ client_config.user_agent_extra += " " + PROCESS_USER_AGENT
+ self._client_kwargs['config'] = client_config
+
+ def create_client(self):
+ """Create a botocore S3 client"""
+ return botocore.session.Session().create_client(
+ 's3', **self._client_kwargs
+ )
+
+
+class TransferMonitor:
+ def __init__(self):
+ """Monitors transfers for cross-process communication
+
+ Notifications can be sent to the monitor and information can be
+ retrieved from the monitor for a particular transfer. This abstraction
+ is ran in a ``multiprocessing.managers.BaseManager`` in order to be
+ shared across processes.
+ """
+ # TODO: Add logic that removes the TransferState if the transfer is
+ # marked as done and the reference to the future is no longer being
+ # held onto. Without this logic, this dictionary will continue to
+ # grow in size with no limit.
+ self._transfer_states = {}
+ self._id_count = 0
+ self._init_lock = threading.Lock()
+
+ def notify_new_transfer(self):
+ with self._init_lock:
+ transfer_id = self._id_count
+ self._transfer_states[transfer_id] = TransferState()
+ self._id_count += 1
+ return transfer_id
+
+ def is_done(self, transfer_id):
+ """Determine a particular transfer is complete
+
+ :param transfer_id: Unique identifier for the transfer
+ :return: True, if done. False, otherwise.
+ """
+ return self._transfer_states[transfer_id].done
+
+ def notify_done(self, transfer_id):
+ """Notify a particular transfer is complete
+
+ :param transfer_id: Unique identifier for the transfer
+ """
+ self._transfer_states[transfer_id].set_done()
+
+ def poll_for_result(self, transfer_id):
+ """Poll for the result of a transfer
+
+ :param transfer_id: Unique identifier for the transfer
+ :return: If the transfer succeeded, it will return the result. If the
+ transfer failed, it will raise the exception associated to the
+ failure.
+ """
+ self._transfer_states[transfer_id].wait_till_done()
+ exception = self._transfer_states[transfer_id].exception
+ if exception:
+ raise exception
+ return None
+
+ def notify_exception(self, transfer_id, exception):
+ """Notify an exception was encountered for a transfer
+
+ :param transfer_id: Unique identifier for the transfer
+ :param exception: The exception encountered for that transfer
+ """
+ # TODO: Not all exceptions are pickleable so if we are running
+ # this in a multiprocessing.BaseManager we will want to
+ # make sure to update this signature to ensure pickleability of the
+ # arguments or have the ProxyObject do the serialization.
+ self._transfer_states[transfer_id].exception = exception
+
+ def notify_cancel_all_in_progress(self):
+ for transfer_state in self._transfer_states.values():
+ if not transfer_state.done:
+ transfer_state.exception = CancelledError()
+
+ def get_exception(self, transfer_id):
+ """Retrieve the exception encountered for the transfer
+
+ :param transfer_id: Unique identifier for the transfer
+ :return: The exception encountered for that transfer. Otherwise
+ if there were no exceptions, returns None.
+ """
+ return self._transfer_states[transfer_id].exception
+
+ def notify_expected_jobs_to_complete(self, transfer_id, num_jobs):
+ """Notify the amount of jobs expected for a transfer
+
+ :param transfer_id: Unique identifier for the transfer
+ :param num_jobs: The number of jobs to complete the transfer
+ """
+ self._transfer_states[transfer_id].jobs_to_complete = num_jobs
+
+ def notify_job_complete(self, transfer_id):
+ """Notify that a single job is completed for a transfer
+
+ :param transfer_id: Unique identifier for the transfer
+ :return: The number of jobs remaining to complete the transfer
+ """
+ return self._transfer_states[transfer_id].decrement_jobs_to_complete()
+
+
+class TransferState:
+ """Represents the current state of an individual transfer"""
+
+ # NOTE: Ideally the TransferState object would be used directly by the
+ # various different abstractions in the ProcessPoolDownloader and remove
+ # the need for the TransferMonitor. However, it would then impose the
+ # constraint that two hops are required to make or get any changes in the
+ # state of a transfer across processes: one hop to get a proxy object for
+ # the TransferState and then a second hop to communicate calling the
+ # specific TransferState method.
+ def __init__(self):
+ self._exception = None
+ self._done_event = threading.Event()
+ self._job_lock = threading.Lock()
+ self._jobs_to_complete = 0
+
+ @property
+ def done(self):
+ return self._done_event.is_set()
+
+ def set_done(self):
+ self._done_event.set()
+
+ def wait_till_done(self):
+ self._done_event.wait(MAXINT)
+
+ @property
+ def exception(self):
+ return self._exception
+
+ @exception.setter
+ def exception(self, val):
+ self._exception = val
+
+ @property
+ def jobs_to_complete(self):
+ return self._jobs_to_complete
+
+ @jobs_to_complete.setter
+ def jobs_to_complete(self, val):
+ self._jobs_to_complete = val
+
+ def decrement_jobs_to_complete(self):
+ with self._job_lock:
+ self._jobs_to_complete -= 1
+ return self._jobs_to_complete
+
+
+class TransferMonitorManager(BaseManager):
+ pass
+
+
+TransferMonitorManager.register('TransferMonitor', TransferMonitor)
+
+
+class BaseS3TransferProcess(multiprocessing.Process):
+ def __init__(self, client_factory):
+ super().__init__()
+ self._client_factory = client_factory
+ self._client = None
+
+ def run(self):
+ # Clients are not pickleable so their instantiation cannot happen
+ # in the __init__ for processes that are created under the
+ # spawn method.
+ self._client = self._client_factory.create_client()
+ with ignore_ctrl_c():
+ # By default these processes are ran as child processes to the
+ # main process. Any Ctrl-c encountered in the main process is
+ # propagated to the child process and interrupt it at any time.
+ # To avoid any potentially bad states caused from an interrupt
+ # (i.e. a transfer failing to notify its done or making the
+ # communication protocol become out of sync with the
+ # TransferMonitor), we ignore all Ctrl-C's and allow the main
+ # process to notify these child processes when to stop processing
+ # jobs.
+ self._do_run()
+
+ def _do_run(self):
+ raise NotImplementedError('_do_run()')
+
+
+class GetObjectSubmitter(BaseS3TransferProcess):
+ def __init__(
+ self,
+ transfer_config,
+ client_factory,
+ transfer_monitor,
+ osutil,
+ download_request_queue,
+ worker_queue,
+ ):
+ """Submit GetObjectJobs to fulfill a download file request
+
+ :param transfer_config: Configuration for transfers.
+ :param client_factory: ClientFactory for creating S3 clients.
+ :param transfer_monitor: Monitor for notifying and retrieving state
+ of transfer.
+ :param osutil: OSUtils object to use for os-related behavior when
+ performing the transfer.
+ :param download_request_queue: Queue to retrieve download file
+ requests.
+ :param worker_queue: Queue to submit GetObjectJobs for workers
+ to perform.
+ """
+ super().__init__(client_factory)
+ self._transfer_config = transfer_config
+ self._transfer_monitor = transfer_monitor
+ self._osutil = osutil
+ self._download_request_queue = download_request_queue
+ self._worker_queue = worker_queue
+
+ def _do_run(self):
+ while True:
+ download_file_request = self._download_request_queue.get()
+ if download_file_request == SHUTDOWN_SIGNAL:
+ logger.debug('Submitter shutdown signal received.')
+ return
+ try:
+ self._submit_get_object_jobs(download_file_request)
+ except Exception as e:
+ logger.debug(
+ 'Exception caught when submitting jobs for '
+ 'download file request %s: %s',
+ download_file_request,
+ e,
+ exc_info=True,
+ )
+ self._transfer_monitor.notify_exception(
+ download_file_request.transfer_id, e
+ )
+ self._transfer_monitor.notify_done(
+ download_file_request.transfer_id
+ )
+
+ def _submit_get_object_jobs(self, download_file_request):
+ size = self._get_size(download_file_request)
+ temp_filename = self._allocate_temp_file(download_file_request, size)
+ if size < self._transfer_config.multipart_threshold:
+ self._submit_single_get_object_job(
+ download_file_request, temp_filename
+ )
+ else:
+ self._submit_ranged_get_object_jobs(
+ download_file_request, temp_filename, size
+ )
+
+ def _get_size(self, download_file_request):
+ expected_size = download_file_request.expected_size
+ if expected_size is None:
+ expected_size = self._client.head_object(
+ Bucket=download_file_request.bucket,
+ Key=download_file_request.key,
+ **download_file_request.extra_args,
+ )['ContentLength']
+ return expected_size
+
+ def _allocate_temp_file(self, download_file_request, size):
+ temp_filename = self._osutil.get_temp_filename(
+ download_file_request.filename
+ )
+ self._osutil.allocate(temp_filename, size)
+ return temp_filename
+
+ def _submit_single_get_object_job(
+ self, download_file_request, temp_filename
+ ):
+ self._notify_jobs_to_complete(download_file_request.transfer_id, 1)
+ self._submit_get_object_job(
+ transfer_id=download_file_request.transfer_id,
+ bucket=download_file_request.bucket,
+ key=download_file_request.key,
+ temp_filename=temp_filename,
+ offset=0,
+ extra_args=download_file_request.extra_args,
+ filename=download_file_request.filename,
+ )
+
+ def _submit_ranged_get_object_jobs(
+ self, download_file_request, temp_filename, size
+ ):
+ part_size = self._transfer_config.multipart_chunksize
+ num_parts = calculate_num_parts(size, part_size)
+ self._notify_jobs_to_complete(
+ download_file_request.transfer_id, num_parts
+ )
+ for i in range(num_parts):
+ offset = i * part_size
+ range_parameter = calculate_range_parameter(
+ part_size, i, num_parts
+ )
+ get_object_kwargs = {'Range': range_parameter}
+ get_object_kwargs.update(download_file_request.extra_args)
+ self._submit_get_object_job(
+ transfer_id=download_file_request.transfer_id,
+ bucket=download_file_request.bucket,
+ key=download_file_request.key,
+ temp_filename=temp_filename,
+ offset=offset,
+ extra_args=get_object_kwargs,
+ filename=download_file_request.filename,
+ )
+
+ def _submit_get_object_job(self, **get_object_job_kwargs):
+ self._worker_queue.put(GetObjectJob(**get_object_job_kwargs))
+
+ def _notify_jobs_to_complete(self, transfer_id, jobs_to_complete):
+ logger.debug(
+ 'Notifying %s job(s) to complete for transfer_id %s.',
+ jobs_to_complete,
+ transfer_id,
+ )
+ self._transfer_monitor.notify_expected_jobs_to_complete(
+ transfer_id, jobs_to_complete
+ )
+
+
+class GetObjectWorker(BaseS3TransferProcess):
+ # TODO: It may make sense to expose these class variables as configuration
+ # options if users want to tweak them.
+ _MAX_ATTEMPTS = 5
+ _IO_CHUNKSIZE = 2 * MB
+
+ def __init__(self, queue, client_factory, transfer_monitor, osutil):
+ """Fulfills GetObjectJobs
+
+ Downloads the S3 object, writes it to the specified file, and
+ renames the file to its final location if it completes the final
+ job for a particular transfer.
+
+ :param queue: Queue for retrieving GetObjectJob's
+ :param client_factory: ClientFactory for creating S3 clients
+ :param transfer_monitor: Monitor for notifying
+ :param osutil: OSUtils object to use for os-related behavior when
+ performing the transfer.
+ """
+ super().__init__(client_factory)
+ self._queue = queue
+ self._client_factory = client_factory
+ self._transfer_monitor = transfer_monitor
+ self._osutil = osutil
+
+ def _do_run(self):
+ while True:
+ job = self._queue.get()
+ if job == SHUTDOWN_SIGNAL:
+ logger.debug('Worker shutdown signal received.')
+ return
+ if not self._transfer_monitor.get_exception(job.transfer_id):
+ self._run_get_object_job(job)
+ else:
+ logger.debug(
+ 'Skipping get object job %s because there was a previous '
+ 'exception.',
+ job,
+ )
+ remaining = self._transfer_monitor.notify_job_complete(
+ job.transfer_id
+ )
+ logger.debug(
+ '%s jobs remaining for transfer_id %s.',
+ remaining,
+ job.transfer_id,
+ )
+ if not remaining:
+ self._finalize_download(
+ job.transfer_id, job.temp_filename, job.filename
+ )
+
+ def _run_get_object_job(self, job):
+ try:
+ self._do_get_object(
+ bucket=job.bucket,
+ key=job.key,
+ temp_filename=job.temp_filename,
+ extra_args=job.extra_args,
+ offset=job.offset,
+ )
+ except Exception as e:
+ logger.debug(
+ 'Exception caught when downloading object for '
+ 'get object job %s: %s',
+ job,
+ e,
+ exc_info=True,
+ )
+ self._transfer_monitor.notify_exception(job.transfer_id, e)
+
+ def _do_get_object(self, bucket, key, extra_args, temp_filename, offset):
+ last_exception = None
+ for i in range(self._MAX_ATTEMPTS):
+ try:
+ response = self._client.get_object(
+ Bucket=bucket, Key=key, **extra_args
+ )
+ self._write_to_file(temp_filename, offset, response['Body'])
+ return
+ except S3_RETRYABLE_DOWNLOAD_ERRORS as e:
+ logger.debug(
+ 'Retrying exception caught (%s), '
+ 'retrying request, (attempt %s / %s)',
+ e,
+ i + 1,
+ self._MAX_ATTEMPTS,
+ exc_info=True,
+ )
+ last_exception = e
+ raise RetriesExceededError(last_exception)
+
+ def _write_to_file(self, filename, offset, body):
+ with open(filename, 'rb+') as f:
+ f.seek(offset)
+ chunks = iter(lambda: body.read(self._IO_CHUNKSIZE), b'')
+ for chunk in chunks:
+ f.write(chunk)
+
+ def _finalize_download(self, transfer_id, temp_filename, filename):
+ if self._transfer_monitor.get_exception(transfer_id):
+ self._osutil.remove_file(temp_filename)
+ else:
+ self._do_file_rename(transfer_id, temp_filename, filename)
+ self._transfer_monitor.notify_done(transfer_id)
+
+ def _do_file_rename(self, transfer_id, temp_filename, filename):
+ try:
+ self._osutil.rename_file(temp_filename, filename)
+ except Exception as e:
+ self._transfer_monitor.notify_exception(transfer_id, e)
+ self._osutil.remove_file(temp_filename)
diff --git a/contrib/python/s3transfer/py3/s3transfer/subscribers.py b/contrib/python/s3transfer/py3/s3transfer/subscribers.py
new file mode 100644
index 0000000000..cf0dbaa0d7
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/subscribers.py
@@ -0,0 +1,92 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.compat import accepts_kwargs
+from s3transfer.exceptions import InvalidSubscriberMethodError
+
+
+class BaseSubscriber:
+ """The base subscriber class
+
+ It is recommended that all subscriber implementations subclass and then
+ override the subscription methods (i.e. on_{subsribe_type}() methods).
+ """
+
+ VALID_SUBSCRIBER_TYPES = ['queued', 'progress', 'done']
+
+ def __new__(cls, *args, **kwargs):
+ cls._validate_subscriber_methods()
+ return super().__new__(cls)
+
+ @classmethod
+ def _validate_subscriber_methods(cls):
+ for subscriber_type in cls.VALID_SUBSCRIBER_TYPES:
+ subscriber_method = getattr(cls, 'on_' + subscriber_type)
+ if not callable(subscriber_method):
+ raise InvalidSubscriberMethodError(
+ 'Subscriber method %s must be callable.'
+ % subscriber_method
+ )
+
+ if not accepts_kwargs(subscriber_method):
+ raise InvalidSubscriberMethodError(
+ 'Subscriber method %s must accept keyword '
+ 'arguments (**kwargs)' % subscriber_method
+ )
+
+ def on_queued(self, future, **kwargs):
+ """Callback to be invoked when transfer request gets queued
+
+ This callback can be useful for:
+
+ * Keeping track of how many transfers have been requested
+ * Providing the expected transfer size through
+ future.meta.provide_transfer_size() so a HeadObject would not
+ need to be made for copies and downloads.
+
+ :type future: s3transfer.futures.TransferFuture
+ :param future: The TransferFuture representing the requested transfer.
+ """
+ pass
+
+ def on_progress(self, future, bytes_transferred, **kwargs):
+ """Callback to be invoked when progress is made on transfer
+
+ This callback can be useful for:
+
+ * Recording and displaying progress
+
+ :type future: s3transfer.futures.TransferFuture
+ :param future: The TransferFuture representing the requested transfer.
+
+ :type bytes_transferred: int
+ :param bytes_transferred: The number of bytes transferred for that
+ invocation of the callback. Note that a negative amount can be
+ provided, which usually indicates that an in-progress request
+ needed to be retried and thus progress was rewound.
+ """
+ pass
+
+ def on_done(self, future, **kwargs):
+ """Callback to be invoked once a transfer is done
+
+ This callback can be useful for:
+
+ * Recording and displaying whether the transfer succeeded or
+ failed using future.result()
+ * Running some task after the transfer completed like changing
+ the last modified time of a downloaded file.
+
+ :type future: s3transfer.futures.TransferFuture
+ :param future: The TransferFuture representing the requested transfer.
+ """
+ pass
diff --git a/contrib/python/s3transfer/py3/s3transfer/tasks.py b/contrib/python/s3transfer/py3/s3transfer/tasks.py
new file mode 100644
index 0000000000..1bad981264
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/tasks.py
@@ -0,0 +1,387 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import copy
+import logging
+
+from s3transfer.utils import get_callbacks
+
+logger = logging.getLogger(__name__)
+
+
+class Task:
+ """A task associated to a TransferFuture request
+
+ This is a base class for other classes to subclass from. All subclassed
+ classes must implement the main() method.
+ """
+
+ def __init__(
+ self,
+ transfer_coordinator,
+ main_kwargs=None,
+ pending_main_kwargs=None,
+ done_callbacks=None,
+ is_final=False,
+ ):
+ """
+ :type transfer_coordinator: s3transfer.futures.TransferCoordinator
+ :param transfer_coordinator: The context associated to the
+ TransferFuture for which this Task is associated with.
+
+ :type main_kwargs: dict
+ :param main_kwargs: The keyword args that can be immediately supplied
+ to the _main() method of the task
+
+ :type pending_main_kwargs: dict
+ :param pending_main_kwargs: The keyword args that are depended upon
+ by the result from a dependent future(s). The result returned by
+ the future(s) will be used as the value for the keyword argument
+ when _main() is called. The values for each key can be:
+ * a single future - Once completed, its value will be the
+ result of that single future
+ * a list of futures - Once all of the futures complete, the
+ value used will be a list of each completed future result
+ value in order of when they were originally supplied.
+
+ :type done_callbacks: list of callbacks
+ :param done_callbacks: A list of callbacks to call once the task is
+ done completing. Each callback will be called with no arguments
+ and will be called no matter if the task succeeds or an exception
+ is raised.
+
+ :type is_final: boolean
+ :param is_final: True, to indicate that this task is the final task
+ for the TransferFuture request. By setting this value to True, it
+ will set the result of the entire TransferFuture to the result
+ returned by this task's main() method.
+ """
+ self._transfer_coordinator = transfer_coordinator
+
+ self._main_kwargs = main_kwargs
+ if self._main_kwargs is None:
+ self._main_kwargs = {}
+
+ self._pending_main_kwargs = pending_main_kwargs
+ if pending_main_kwargs is None:
+ self._pending_main_kwargs = {}
+
+ self._done_callbacks = done_callbacks
+ if self._done_callbacks is None:
+ self._done_callbacks = []
+
+ self._is_final = is_final
+
+ def __repr__(self):
+ # These are the general main_kwarg parameters that we want to
+ # display in the repr.
+ params_to_display = [
+ 'bucket',
+ 'key',
+ 'part_number',
+ 'final_filename',
+ 'transfer_future',
+ 'offset',
+ 'extra_args',
+ ]
+ main_kwargs_to_display = self._get_kwargs_with_params_to_include(
+ self._main_kwargs, params_to_display
+ )
+ return '{}(transfer_id={}, {})'.format(
+ self.__class__.__name__,
+ self._transfer_coordinator.transfer_id,
+ main_kwargs_to_display,
+ )
+
+ @property
+ def transfer_id(self):
+ """The id for the transfer request that the task belongs to"""
+ return self._transfer_coordinator.transfer_id
+
+ def _get_kwargs_with_params_to_include(self, kwargs, include):
+ filtered_kwargs = {}
+ for param in include:
+ if param in kwargs:
+ filtered_kwargs[param] = kwargs[param]
+ return filtered_kwargs
+
+ def _get_kwargs_with_params_to_exclude(self, kwargs, exclude):
+ filtered_kwargs = {}
+ for param, value in kwargs.items():
+ if param in exclude:
+ continue
+ filtered_kwargs[param] = value
+ return filtered_kwargs
+
+ def __call__(self):
+ """The callable to use when submitting a Task to an executor"""
+ try:
+ # Wait for all of futures this task depends on.
+ self._wait_on_dependent_futures()
+ # Gather up all of the main keyword arguments for main().
+ # This includes the immediately provided main_kwargs and
+ # the values for pending_main_kwargs that source from the return
+ # values from the task's dependent futures.
+ kwargs = self._get_all_main_kwargs()
+ # If the task is not done (really only if some other related
+ # task to the TransferFuture had failed) then execute the task's
+ # main() method.
+ if not self._transfer_coordinator.done():
+ return self._execute_main(kwargs)
+ except Exception as e:
+ self._log_and_set_exception(e)
+ finally:
+ # Run any done callbacks associated to the task no matter what.
+ for done_callback in self._done_callbacks:
+ done_callback()
+
+ if self._is_final:
+ # If this is the final task announce that it is done if results
+ # are waiting on its completion.
+ self._transfer_coordinator.announce_done()
+
+ def _execute_main(self, kwargs):
+ # Do not display keyword args that should not be printed, especially
+ # if they are going to make the logs hard to follow.
+ params_to_exclude = ['data']
+ kwargs_to_display = self._get_kwargs_with_params_to_exclude(
+ kwargs, params_to_exclude
+ )
+ # Log what is about to be executed.
+ logger.debug(f"Executing task {self} with kwargs {kwargs_to_display}")
+
+ return_value = self._main(**kwargs)
+ # If the task is the final task, then set the TransferFuture's
+ # value to the return value from main().
+ if self._is_final:
+ self._transfer_coordinator.set_result(return_value)
+ return return_value
+
+ def _log_and_set_exception(self, exception):
+ # If an exception is ever thrown than set the exception for the
+ # entire TransferFuture.
+ logger.debug("Exception raised.", exc_info=True)
+ self._transfer_coordinator.set_exception(exception)
+
+ def _main(self, **kwargs):
+ """The method that will be ran in the executor
+
+ This method must be implemented by subclasses from Task. main() can
+ be implemented with any arguments decided upon by the subclass.
+ """
+ raise NotImplementedError('_main() must be implemented')
+
+ def _wait_on_dependent_futures(self):
+ # Gather all of the futures into that main() depends on.
+ futures_to_wait_on = []
+ for _, future in self._pending_main_kwargs.items():
+ # If the pending main keyword arg is a list then extend the list.
+ if isinstance(future, list):
+ futures_to_wait_on.extend(future)
+ # If the pending main keyword arg is a future append it to the list.
+ else:
+ futures_to_wait_on.append(future)
+ # Now wait for all of the futures to complete.
+ self._wait_until_all_complete(futures_to_wait_on)
+
+ def _wait_until_all_complete(self, futures):
+ # This is a basic implementation of the concurrent.futures.wait()
+ #
+ # concurrent.futures.wait() is not used instead because of this
+ # reported issue: https://bugs.python.org/issue20319.
+ # The issue would occasionally cause multipart uploads to hang
+ # when wait() was called. With this approach, it avoids the
+ # concurrency bug by removing any association with concurrent.futures
+ # implementation of waiters.
+ logger.debug(
+ '%s about to wait for the following futures %s', self, futures
+ )
+ for future in futures:
+ try:
+ logger.debug('%s about to wait for %s', self, future)
+ future.result()
+ except Exception:
+ # result() can also produce exceptions. We want to ignore
+ # these to be deferred to error handling down the road.
+ pass
+ logger.debug('%s done waiting for dependent futures', self)
+
+ def _get_all_main_kwargs(self):
+ # Copy over all of the kwargs that we know is available.
+ kwargs = copy.copy(self._main_kwargs)
+
+ # Iterate through the kwargs whose values are pending on the result
+ # of a future.
+ for key, pending_value in self._pending_main_kwargs.items():
+ # If the value is a list of futures, iterate though the list
+ # appending on the result from each future.
+ if isinstance(pending_value, list):
+ result = []
+ for future in pending_value:
+ result.append(future.result())
+ # Otherwise if the pending_value is a future, just wait for it.
+ else:
+ result = pending_value.result()
+ # Add the retrieved value to the kwargs to be sent to the
+ # main() call.
+ kwargs[key] = result
+ return kwargs
+
+
+class SubmissionTask(Task):
+ """A base class for any submission task
+
+ Submission tasks are the top-level task used to submit a series of tasks
+ to execute a particular transfer.
+ """
+
+ def _main(self, transfer_future, **kwargs):
+ """
+ :type transfer_future: s3transfer.futures.TransferFuture
+ :param transfer_future: The transfer future associated with the
+ transfer request that tasks are being submitted for
+
+ :param kwargs: Any additional kwargs that you may want to pass
+ to the _submit() method
+ """
+ try:
+ self._transfer_coordinator.set_status_to_queued()
+
+ # Before submitting any tasks, run all of the on_queued callbacks
+ on_queued_callbacks = get_callbacks(transfer_future, 'queued')
+ for on_queued_callback in on_queued_callbacks:
+ on_queued_callback()
+
+ # Once callbacks have been ran set the status to running.
+ self._transfer_coordinator.set_status_to_running()
+
+ # Call the submit method to start submitting tasks to execute the
+ # transfer.
+ self._submit(transfer_future=transfer_future, **kwargs)
+ except BaseException as e:
+ # If there was an exception raised during the submission of task
+ # there is a chance that the final task that signals if a transfer
+ # is done and too run the cleanup may never have been submitted in
+ # the first place so we need to account accordingly.
+ #
+ # Note that BaseException is caught, instead of Exception, because
+ # for some implementations of executors, specifically the serial
+ # implementation, the SubmissionTask is directly exposed to
+ # KeyboardInterupts and so needs to cleanup and signal done
+ # for those as well.
+
+ # Set the exception, that caused the process to fail.
+ self._log_and_set_exception(e)
+
+ # Wait for all possibly associated futures that may have spawned
+ # from this submission task have finished before we announce the
+ # transfer done.
+ self._wait_for_all_submitted_futures_to_complete()
+
+ # Announce the transfer as done, which will run any cleanups
+ # and done callbacks as well.
+ self._transfer_coordinator.announce_done()
+
+ def _submit(self, transfer_future, **kwargs):
+ """The submission method to be implemented
+
+ :type transfer_future: s3transfer.futures.TransferFuture
+ :param transfer_future: The transfer future associated with the
+ transfer request that tasks are being submitted for
+
+ :param kwargs: Any additional keyword arguments you want to be passed
+ in
+ """
+ raise NotImplementedError('_submit() must be implemented')
+
+ def _wait_for_all_submitted_futures_to_complete(self):
+ # We want to wait for all futures that were submitted to
+ # complete as we do not want the cleanup callbacks or done callbacks
+ # to be called to early. The main problem is any task that was
+ # submitted may have submitted even more during its process and so
+ # we need to account accordingly.
+
+ # First get all of the futures that were submitted up to this point.
+ submitted_futures = self._transfer_coordinator.associated_futures
+ while submitted_futures:
+ # Wait for those futures to complete.
+ self._wait_until_all_complete(submitted_futures)
+ # However, more futures may have been submitted as we waited so
+ # we need to check again for any more associated futures.
+ possibly_more_submitted_futures = (
+ self._transfer_coordinator.associated_futures
+ )
+ # If the current list of submitted futures is equal to the
+ # the list of associated futures for when after the wait completes,
+ # we can ensure no more futures were submitted in waiting on
+ # the current list of futures to complete ultimately meaning all
+ # futures that may have spawned from the original submission task
+ # have completed.
+ if submitted_futures == possibly_more_submitted_futures:
+ break
+ submitted_futures = possibly_more_submitted_futures
+
+
+class CreateMultipartUploadTask(Task):
+ """Task to initiate a multipart upload"""
+
+ def _main(self, client, bucket, key, extra_args):
+ """
+ :param client: The client to use when calling CreateMultipartUpload
+ :param bucket: The name of the bucket to upload to
+ :param key: The name of the key to upload to
+ :param extra_args: A dictionary of any extra arguments that may be
+ used in the initialization.
+
+ :returns: The upload id of the multipart upload
+ """
+ # Create the multipart upload.
+ response = client.create_multipart_upload(
+ Bucket=bucket, Key=key, **extra_args
+ )
+ upload_id = response['UploadId']
+
+ # Add a cleanup if the multipart upload fails at any point.
+ self._transfer_coordinator.add_failure_cleanup(
+ client.abort_multipart_upload,
+ Bucket=bucket,
+ Key=key,
+ UploadId=upload_id,
+ )
+ return upload_id
+
+
+class CompleteMultipartUploadTask(Task):
+ """Task to complete a multipart upload"""
+
+ def _main(self, client, bucket, key, upload_id, parts, extra_args):
+ """
+ :param client: The client to use when calling CompleteMultipartUpload
+ :param bucket: The name of the bucket to upload to
+ :param key: The name of the key to upload to
+ :param upload_id: The id of the upload
+ :param parts: A list of parts to use to complete the multipart upload::
+
+ [{'Etag': etag_value, 'PartNumber': part_number}, ...]
+
+ Each element in the list consists of a return value from
+ ``UploadPartTask.main()``.
+ :param extra_args: A dictionary of any extra arguments that may be
+ used in completing the multipart transfer.
+ """
+ client.complete_multipart_upload(
+ Bucket=bucket,
+ Key=key,
+ UploadId=upload_id,
+ MultipartUpload={'Parts': parts},
+ **extra_args,
+ )
diff --git a/contrib/python/s3transfer/py3/s3transfer/upload.py b/contrib/python/s3transfer/py3/s3transfer/upload.py
new file mode 100644
index 0000000000..31ade051d7
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/upload.py
@@ -0,0 +1,795 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import math
+from io import BytesIO
+
+from s3transfer.compat import readable, seekable
+from s3transfer.futures import IN_MEMORY_UPLOAD_TAG
+from s3transfer.tasks import (
+ CompleteMultipartUploadTask,
+ CreateMultipartUploadTask,
+ SubmissionTask,
+ Task,
+)
+from s3transfer.utils import (
+ ChunksizeAdjuster,
+ DeferredOpenFile,
+ get_callbacks,
+ get_filtered_dict,
+)
+
+
+class AggregatedProgressCallback:
+ def __init__(self, callbacks, threshold=1024 * 256):
+ """Aggregates progress updates for every provided progress callback
+
+ :type callbacks: A list of functions that accepts bytes_transferred
+ as a single argument
+ :param callbacks: The callbacks to invoke when threshold is reached
+
+ :type threshold: int
+ :param threshold: The progress threshold in which to take the
+ aggregated progress and invoke the progress callback with that
+ aggregated progress total
+ """
+ self._callbacks = callbacks
+ self._threshold = threshold
+ self._bytes_seen = 0
+
+ def __call__(self, bytes_transferred):
+ self._bytes_seen += bytes_transferred
+ if self._bytes_seen >= self._threshold:
+ self._trigger_callbacks()
+
+ def flush(self):
+ """Flushes out any progress that has not been sent to its callbacks"""
+ if self._bytes_seen > 0:
+ self._trigger_callbacks()
+
+ def _trigger_callbacks(self):
+ for callback in self._callbacks:
+ callback(bytes_transferred=self._bytes_seen)
+ self._bytes_seen = 0
+
+
+class InterruptReader:
+ """Wrapper that can interrupt reading using an error
+
+ It uses a transfer coordinator to propagate an error if it notices
+ that a read is being made while the file is being read from.
+
+ :type fileobj: file-like obj
+ :param fileobj: The file-like object to read from
+
+ :type transfer_coordinator: s3transfer.futures.TransferCoordinator
+ :param transfer_coordinator: The transfer coordinator to use if the
+ reader needs to be interrupted.
+ """
+
+ def __init__(self, fileobj, transfer_coordinator):
+ self._fileobj = fileobj
+ self._transfer_coordinator = transfer_coordinator
+
+ def read(self, amount=None):
+ # If there is an exception, then raise the exception.
+ # We raise an error instead of returning no bytes because for
+ # requests where the content length and md5 was sent, it will
+ # cause md5 mismatches and retries as there was no indication that
+ # the stream being read from encountered any issues.
+ if self._transfer_coordinator.exception:
+ raise self._transfer_coordinator.exception
+ return self._fileobj.read(amount)
+
+ def seek(self, where, whence=0):
+ self._fileobj.seek(where, whence)
+
+ def tell(self):
+ return self._fileobj.tell()
+
+ def close(self):
+ self._fileobj.close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ self.close()
+
+
+class UploadInputManager:
+ """Base manager class for handling various types of files for uploads
+
+ This class is typically used for the UploadSubmissionTask class to help
+ determine the following:
+
+ * How to determine the size of the file
+ * How to determine if a multipart upload is required
+ * How to retrieve the body for a PutObject
+ * How to retrieve the bodies for a set of UploadParts
+
+ The answers/implementations differ for the various types of file inputs
+ that may be accepted. All implementations must subclass and override
+ public methods from this class.
+ """
+
+ def __init__(self, osutil, transfer_coordinator, bandwidth_limiter=None):
+ self._osutil = osutil
+ self._transfer_coordinator = transfer_coordinator
+ self._bandwidth_limiter = bandwidth_limiter
+
+ @classmethod
+ def is_compatible(cls, upload_source):
+ """Determines if the source for the upload is compatible with manager
+
+ :param upload_source: The source for which the upload will pull data
+ from.
+
+ :returns: True if the manager can handle the type of source specified
+ otherwise returns False.
+ """
+ raise NotImplementedError('must implement _is_compatible()')
+
+ def stores_body_in_memory(self, operation_name):
+ """Whether the body it provides are stored in-memory
+
+ :type operation_name: str
+ :param operation_name: The name of the client operation that the body
+ is being used for. Valid operation_names are ``put_object`` and
+ ``upload_part``.
+
+ :rtype: boolean
+ :returns: True if the body returned by the manager will be stored in
+ memory. False if the manager will not directly store the body in
+ memory.
+ """
+ raise NotImplementedError('must implement store_body_in_memory()')
+
+ def provide_transfer_size(self, transfer_future):
+ """Provides the transfer size of an upload
+
+ :type transfer_future: s3transfer.futures.TransferFuture
+ :param transfer_future: The future associated with upload request
+ """
+ raise NotImplementedError('must implement provide_transfer_size()')
+
+ def requires_multipart_upload(self, transfer_future, config):
+ """Determines where a multipart upload is required
+
+ :type transfer_future: s3transfer.futures.TransferFuture
+ :param transfer_future: The future associated with upload request
+
+ :type config: s3transfer.manager.TransferConfig
+ :param config: The config associated to the transfer manager
+
+ :rtype: boolean
+ :returns: True, if the upload should be multipart based on
+ configuration and size. False, otherwise.
+ """
+ raise NotImplementedError('must implement requires_multipart_upload()')
+
+ def get_put_object_body(self, transfer_future):
+ """Returns the body to use for PutObject
+
+ :type transfer_future: s3transfer.futures.TransferFuture
+ :param transfer_future: The future associated with upload request
+
+ :type config: s3transfer.manager.TransferConfig
+ :param config: The config associated to the transfer manager
+
+ :rtype: s3transfer.utils.ReadFileChunk
+ :returns: A ReadFileChunk including all progress callbacks
+ associated with the transfer future.
+ """
+ raise NotImplementedError('must implement get_put_object_body()')
+
+ def yield_upload_part_bodies(self, transfer_future, chunksize):
+ """Yields the part number and body to use for each UploadPart
+
+ :type transfer_future: s3transfer.futures.TransferFuture
+ :param transfer_future: The future associated with upload request
+
+ :type chunksize: int
+ :param chunksize: The chunksize to use for this upload.
+
+ :rtype: int, s3transfer.utils.ReadFileChunk
+ :returns: Yields the part number and the ReadFileChunk including all
+ progress callbacks associated with the transfer future for that
+ specific yielded part.
+ """
+ raise NotImplementedError('must implement yield_upload_part_bodies()')
+
+ def _wrap_fileobj(self, fileobj):
+ fileobj = InterruptReader(fileobj, self._transfer_coordinator)
+ if self._bandwidth_limiter:
+ fileobj = self._bandwidth_limiter.get_bandwith_limited_stream(
+ fileobj, self._transfer_coordinator, enabled=False
+ )
+ return fileobj
+
+ def _get_progress_callbacks(self, transfer_future):
+ callbacks = get_callbacks(transfer_future, 'progress')
+ # We only want to be wrapping the callbacks if there are callbacks to
+ # invoke because we do not want to be doing any unnecessary work if
+ # there are no callbacks to invoke.
+ if callbacks:
+ return [AggregatedProgressCallback(callbacks)]
+ return []
+
+ def _get_close_callbacks(self, aggregated_progress_callbacks):
+ return [callback.flush for callback in aggregated_progress_callbacks]
+
+
+class UploadFilenameInputManager(UploadInputManager):
+ """Upload utility for filenames"""
+
+ @classmethod
+ def is_compatible(cls, upload_source):
+ return isinstance(upload_source, str)
+
+ def stores_body_in_memory(self, operation_name):
+ return False
+
+ def provide_transfer_size(self, transfer_future):
+ transfer_future.meta.provide_transfer_size(
+ self._osutil.get_file_size(transfer_future.meta.call_args.fileobj)
+ )
+
+ def requires_multipart_upload(self, transfer_future, config):
+ return transfer_future.meta.size >= config.multipart_threshold
+
+ def get_put_object_body(self, transfer_future):
+ # Get a file-like object for the given input
+ fileobj, full_size = self._get_put_object_fileobj_with_full_size(
+ transfer_future
+ )
+
+ # Wrap fileobj with interrupt reader that will quickly cancel
+ # uploads if needed instead of having to wait for the socket
+ # to completely read all of the data.
+ fileobj = self._wrap_fileobj(fileobj)
+
+ callbacks = self._get_progress_callbacks(transfer_future)
+ close_callbacks = self._get_close_callbacks(callbacks)
+ size = transfer_future.meta.size
+ # Return the file-like object wrapped into a ReadFileChunk to get
+ # progress.
+ return self._osutil.open_file_chunk_reader_from_fileobj(
+ fileobj=fileobj,
+ chunk_size=size,
+ full_file_size=full_size,
+ callbacks=callbacks,
+ close_callbacks=close_callbacks,
+ )
+
+ def yield_upload_part_bodies(self, transfer_future, chunksize):
+ full_file_size = transfer_future.meta.size
+ num_parts = self._get_num_parts(transfer_future, chunksize)
+ for part_number in range(1, num_parts + 1):
+ callbacks = self._get_progress_callbacks(transfer_future)
+ close_callbacks = self._get_close_callbacks(callbacks)
+ start_byte = chunksize * (part_number - 1)
+ # Get a file-like object for that part and the size of the full
+ # file size for the associated file-like object for that part.
+ fileobj, full_size = self._get_upload_part_fileobj_with_full_size(
+ transfer_future.meta.call_args.fileobj,
+ start_byte=start_byte,
+ part_size=chunksize,
+ full_file_size=full_file_size,
+ )
+
+ # Wrap fileobj with interrupt reader that will quickly cancel
+ # uploads if needed instead of having to wait for the socket
+ # to completely read all of the data.
+ fileobj = self._wrap_fileobj(fileobj)
+
+ # Wrap the file-like object into a ReadFileChunk to get progress.
+ read_file_chunk = self._osutil.open_file_chunk_reader_from_fileobj(
+ fileobj=fileobj,
+ chunk_size=chunksize,
+ full_file_size=full_size,
+ callbacks=callbacks,
+ close_callbacks=close_callbacks,
+ )
+ yield part_number, read_file_chunk
+
+ def _get_deferred_open_file(self, fileobj, start_byte):
+ fileobj = DeferredOpenFile(
+ fileobj, start_byte, open_function=self._osutil.open
+ )
+ return fileobj
+
+ def _get_put_object_fileobj_with_full_size(self, transfer_future):
+ fileobj = transfer_future.meta.call_args.fileobj
+ size = transfer_future.meta.size
+ return self._get_deferred_open_file(fileobj, 0), size
+
+ def _get_upload_part_fileobj_with_full_size(self, fileobj, **kwargs):
+ start_byte = kwargs['start_byte']
+ full_size = kwargs['full_file_size']
+ return self._get_deferred_open_file(fileobj, start_byte), full_size
+
+ def _get_num_parts(self, transfer_future, part_size):
+ return int(math.ceil(transfer_future.meta.size / float(part_size)))
+
+
+class UploadSeekableInputManager(UploadFilenameInputManager):
+ """Upload utility for an open file object"""
+
+ @classmethod
+ def is_compatible(cls, upload_source):
+ return readable(upload_source) and seekable(upload_source)
+
+ def stores_body_in_memory(self, operation_name):
+ if operation_name == 'put_object':
+ return False
+ else:
+ return True
+
+ def provide_transfer_size(self, transfer_future):
+ fileobj = transfer_future.meta.call_args.fileobj
+ # To determine size, first determine the starting position
+ # Seek to the end and then find the difference in the length
+ # between the end and start positions.
+ start_position = fileobj.tell()
+ fileobj.seek(0, 2)
+ end_position = fileobj.tell()
+ fileobj.seek(start_position)
+ transfer_future.meta.provide_transfer_size(
+ end_position - start_position
+ )
+
+ def _get_upload_part_fileobj_with_full_size(self, fileobj, **kwargs):
+ # Note: It is unfortunate that in order to do a multithreaded
+ # multipart upload we cannot simply copy the filelike object
+ # since there is not really a mechanism in python (i.e. os.dup
+ # points to the same OS filehandle which causes concurrency
+ # issues). So instead we need to read from the fileobj and
+ # chunk the data out to separate file-like objects in memory.
+ data = fileobj.read(kwargs['part_size'])
+ # We return the length of the data instead of the full_file_size
+ # because we partitioned the data into separate BytesIO objects
+ # meaning the BytesIO object has no knowledge of its start position
+ # relative the input source nor access to the rest of the input
+ # source. So we must treat it as its own standalone file.
+ return BytesIO(data), len(data)
+
+ def _get_put_object_fileobj_with_full_size(self, transfer_future):
+ fileobj = transfer_future.meta.call_args.fileobj
+ # The current position needs to be taken into account when retrieving
+ # the full size of the file.
+ size = fileobj.tell() + transfer_future.meta.size
+ return fileobj, size
+
+
+class UploadNonSeekableInputManager(UploadInputManager):
+ """Upload utility for a file-like object that cannot seek."""
+
+ def __init__(self, osutil, transfer_coordinator, bandwidth_limiter=None):
+ super().__init__(osutil, transfer_coordinator, bandwidth_limiter)
+ self._initial_data = b''
+
+ @classmethod
+ def is_compatible(cls, upload_source):
+ return readable(upload_source)
+
+ def stores_body_in_memory(self, operation_name):
+ return True
+
+ def provide_transfer_size(self, transfer_future):
+ # No-op because there is no way to do this short of reading the entire
+ # body into memory.
+ return
+
+ def requires_multipart_upload(self, transfer_future, config):
+ # If the user has set the size, we can use that.
+ if transfer_future.meta.size is not None:
+ return transfer_future.meta.size >= config.multipart_threshold
+
+ # This is tricky to determine in this case because we can't know how
+ # large the input is. So to figure it out, we read data into memory
+ # up until the threshold and compare how much data was actually read
+ # against the threshold.
+ fileobj = transfer_future.meta.call_args.fileobj
+ threshold = config.multipart_threshold
+ self._initial_data = self._read(fileobj, threshold, False)
+ if len(self._initial_data) < threshold:
+ return False
+ else:
+ return True
+
+ def get_put_object_body(self, transfer_future):
+ callbacks = self._get_progress_callbacks(transfer_future)
+ close_callbacks = self._get_close_callbacks(callbacks)
+ fileobj = transfer_future.meta.call_args.fileobj
+
+ body = self._wrap_data(
+ self._initial_data + fileobj.read(), callbacks, close_callbacks
+ )
+
+ # Zero out the stored data so we don't have additional copies
+ # hanging around in memory.
+ self._initial_data = None
+ return body
+
+ def yield_upload_part_bodies(self, transfer_future, chunksize):
+ file_object = transfer_future.meta.call_args.fileobj
+ part_number = 0
+
+ # Continue reading parts from the file-like object until it is empty.
+ while True:
+ callbacks = self._get_progress_callbacks(transfer_future)
+ close_callbacks = self._get_close_callbacks(callbacks)
+ part_number += 1
+ part_content = self._read(file_object, chunksize)
+ if not part_content:
+ break
+ part_object = self._wrap_data(
+ part_content, callbacks, close_callbacks
+ )
+
+ # Zero out part_content to avoid hanging on to additional data.
+ part_content = None
+ yield part_number, part_object
+
+ def _read(self, fileobj, amount, truncate=True):
+ """
+ Reads a specific amount of data from a stream and returns it. If there
+ is any data in initial_data, that will be popped out first.
+
+ :type fileobj: A file-like object that implements read
+ :param fileobj: The stream to read from.
+
+ :type amount: int
+ :param amount: The number of bytes to read from the stream.
+
+ :type truncate: bool
+ :param truncate: Whether or not to truncate initial_data after
+ reading from it.
+
+ :return: Generator which generates part bodies from the initial data.
+ """
+ # If the the initial data is empty, we simply read from the fileobj
+ if len(self._initial_data) == 0:
+ return fileobj.read(amount)
+
+ # If the requested number of bytes is less than the amount of
+ # initial data, pull entirely from initial data.
+ if amount <= len(self._initial_data):
+ data = self._initial_data[:amount]
+ # Truncate initial data so we don't hang onto the data longer
+ # than we need.
+ if truncate:
+ self._initial_data = self._initial_data[amount:]
+ return data
+
+ # At this point there is some initial data left, but not enough to
+ # satisfy the number of bytes requested. Pull out the remaining
+ # initial data and read the rest from the fileobj.
+ amount_to_read = amount - len(self._initial_data)
+ data = self._initial_data + fileobj.read(amount_to_read)
+
+ # Zero out initial data so we don't hang onto the data any more.
+ if truncate:
+ self._initial_data = b''
+ return data
+
+ def _wrap_data(self, data, callbacks, close_callbacks):
+ """
+ Wraps data with the interrupt reader and the file chunk reader.
+
+ :type data: bytes
+ :param data: The data to wrap.
+
+ :type callbacks: list
+ :param callbacks: The callbacks associated with the transfer future.
+
+ :type close_callbacks: list
+ :param close_callbacks: The callbacks to be called when closing the
+ wrapper for the data.
+
+ :return: Fully wrapped data.
+ """
+ fileobj = self._wrap_fileobj(BytesIO(data))
+ return self._osutil.open_file_chunk_reader_from_fileobj(
+ fileobj=fileobj,
+ chunk_size=len(data),
+ full_file_size=len(data),
+ callbacks=callbacks,
+ close_callbacks=close_callbacks,
+ )
+
+
+class UploadSubmissionTask(SubmissionTask):
+ """Task for submitting tasks to execute an upload"""
+
+ UPLOAD_PART_ARGS = [
+ 'SSECustomerKey',
+ 'SSECustomerAlgorithm',
+ 'SSECustomerKeyMD5',
+ 'RequestPayer',
+ 'ExpectedBucketOwner',
+ ]
+
+ COMPLETE_MULTIPART_ARGS = ['RequestPayer', 'ExpectedBucketOwner']
+
+ def _get_upload_input_manager_cls(self, transfer_future):
+ """Retrieves a class for managing input for an upload based on file type
+
+ :type transfer_future: s3transfer.futures.TransferFuture
+ :param transfer_future: The transfer future for the request
+
+ :rtype: class of UploadInputManager
+ :returns: The appropriate class to use for managing a specific type of
+ input for uploads.
+ """
+ upload_manager_resolver_chain = [
+ UploadFilenameInputManager,
+ UploadSeekableInputManager,
+ UploadNonSeekableInputManager,
+ ]
+
+ fileobj = transfer_future.meta.call_args.fileobj
+ for upload_manager_cls in upload_manager_resolver_chain:
+ if upload_manager_cls.is_compatible(fileobj):
+ return upload_manager_cls
+ raise RuntimeError(
+ 'Input {} of type: {} is not supported.'.format(
+ fileobj, type(fileobj)
+ )
+ )
+
+ def _submit(
+ self,
+ client,
+ config,
+ osutil,
+ request_executor,
+ transfer_future,
+ bandwidth_limiter=None,
+ ):
+ """
+ :param client: The client associated with the transfer manager
+
+ :type config: s3transfer.manager.TransferConfig
+ :param config: The transfer config associated with the transfer
+ manager
+
+ :type osutil: s3transfer.utils.OSUtil
+ :param osutil: The os utility associated to the transfer manager
+
+ :type request_executor: s3transfer.futures.BoundedExecutor
+ :param request_executor: The request executor associated with the
+ transfer manager
+
+ :type transfer_future: s3transfer.futures.TransferFuture
+ :param transfer_future: The transfer future associated with the
+ transfer request that tasks are being submitted for
+ """
+ upload_input_manager = self._get_upload_input_manager_cls(
+ transfer_future
+ )(osutil, self._transfer_coordinator, bandwidth_limiter)
+
+ # Determine the size if it was not provided
+ if transfer_future.meta.size is None:
+ upload_input_manager.provide_transfer_size(transfer_future)
+
+ # Do a multipart upload if needed, otherwise do a regular put object.
+ if not upload_input_manager.requires_multipart_upload(
+ transfer_future, config
+ ):
+ self._submit_upload_request(
+ client,
+ config,
+ osutil,
+ request_executor,
+ transfer_future,
+ upload_input_manager,
+ )
+ else:
+ self._submit_multipart_request(
+ client,
+ config,
+ osutil,
+ request_executor,
+ transfer_future,
+ upload_input_manager,
+ )
+
+ def _submit_upload_request(
+ self,
+ client,
+ config,
+ osutil,
+ request_executor,
+ transfer_future,
+ upload_input_manager,
+ ):
+ call_args = transfer_future.meta.call_args
+
+ # Get any tags that need to be associated to the put object task
+ put_object_tag = self._get_upload_task_tag(
+ upload_input_manager, 'put_object'
+ )
+
+ # Submit the request of a single upload.
+ self._transfer_coordinator.submit(
+ request_executor,
+ PutObjectTask(
+ transfer_coordinator=self._transfer_coordinator,
+ main_kwargs={
+ 'client': client,
+ 'fileobj': upload_input_manager.get_put_object_body(
+ transfer_future
+ ),
+ 'bucket': call_args.bucket,
+ 'key': call_args.key,
+ 'extra_args': call_args.extra_args,
+ },
+ is_final=True,
+ ),
+ tag=put_object_tag,
+ )
+
+ def _submit_multipart_request(
+ self,
+ client,
+ config,
+ osutil,
+ request_executor,
+ transfer_future,
+ upload_input_manager,
+ ):
+ call_args = transfer_future.meta.call_args
+
+ # Submit the request to create a multipart upload.
+ create_multipart_future = self._transfer_coordinator.submit(
+ request_executor,
+ CreateMultipartUploadTask(
+ transfer_coordinator=self._transfer_coordinator,
+ main_kwargs={
+ 'client': client,
+ 'bucket': call_args.bucket,
+ 'key': call_args.key,
+ 'extra_args': call_args.extra_args,
+ },
+ ),
+ )
+
+ # Submit requests to upload the parts of the file.
+ part_futures = []
+ extra_part_args = self._extra_upload_part_args(call_args.extra_args)
+
+ # Get any tags that need to be associated to the submitted task
+ # for upload the data
+ upload_part_tag = self._get_upload_task_tag(
+ upload_input_manager, 'upload_part'
+ )
+
+ size = transfer_future.meta.size
+ adjuster = ChunksizeAdjuster()
+ chunksize = adjuster.adjust_chunksize(config.multipart_chunksize, size)
+ part_iterator = upload_input_manager.yield_upload_part_bodies(
+ transfer_future, chunksize
+ )
+
+ for part_number, fileobj in part_iterator:
+ part_futures.append(
+ self._transfer_coordinator.submit(
+ request_executor,
+ UploadPartTask(
+ transfer_coordinator=self._transfer_coordinator,
+ main_kwargs={
+ 'client': client,
+ 'fileobj': fileobj,
+ 'bucket': call_args.bucket,
+ 'key': call_args.key,
+ 'part_number': part_number,
+ 'extra_args': extra_part_args,
+ },
+ pending_main_kwargs={
+ 'upload_id': create_multipart_future
+ },
+ ),
+ tag=upload_part_tag,
+ )
+ )
+
+ complete_multipart_extra_args = self._extra_complete_multipart_args(
+ call_args.extra_args
+ )
+ # Submit the request to complete the multipart upload.
+ self._transfer_coordinator.submit(
+ request_executor,
+ CompleteMultipartUploadTask(
+ transfer_coordinator=self._transfer_coordinator,
+ main_kwargs={
+ 'client': client,
+ 'bucket': call_args.bucket,
+ 'key': call_args.key,
+ 'extra_args': complete_multipart_extra_args,
+ },
+ pending_main_kwargs={
+ 'upload_id': create_multipart_future,
+ 'parts': part_futures,
+ },
+ is_final=True,
+ ),
+ )
+
+ def _extra_upload_part_args(self, extra_args):
+ # Only the args in UPLOAD_PART_ARGS actually need to be passed
+ # onto the upload_part calls.
+ return get_filtered_dict(extra_args, self.UPLOAD_PART_ARGS)
+
+ def _extra_complete_multipart_args(self, extra_args):
+ return get_filtered_dict(extra_args, self.COMPLETE_MULTIPART_ARGS)
+
+ def _get_upload_task_tag(self, upload_input_manager, operation_name):
+ tag = None
+ if upload_input_manager.stores_body_in_memory(operation_name):
+ tag = IN_MEMORY_UPLOAD_TAG
+ return tag
+
+
+class PutObjectTask(Task):
+ """Task to do a nonmultipart upload"""
+
+ def _main(self, client, fileobj, bucket, key, extra_args):
+ """
+ :param client: The client to use when calling PutObject
+ :param fileobj: The file to upload.
+ :param bucket: The name of the bucket to upload to
+ :param key: The name of the key to upload to
+ :param extra_args: A dictionary of any extra arguments that may be
+ used in the upload.
+ """
+ with fileobj as body:
+ client.put_object(Bucket=bucket, Key=key, Body=body, **extra_args)
+
+
+class UploadPartTask(Task):
+ """Task to upload a part in a multipart upload"""
+
+ def _main(
+ self, client, fileobj, bucket, key, upload_id, part_number, extra_args
+ ):
+ """
+ :param client: The client to use when calling PutObject
+ :param fileobj: The file to upload.
+ :param bucket: The name of the bucket to upload to
+ :param key: The name of the key to upload to
+ :param upload_id: The id of the upload
+ :param part_number: The number representing the part of the multipart
+ upload
+ :param extra_args: A dictionary of any extra arguments that may be
+ used in the upload.
+
+ :rtype: dict
+ :returns: A dictionary representing a part::
+
+ {'Etag': etag_value, 'PartNumber': part_number}
+
+ This value can be appended to a list to be used to complete
+ the multipart upload.
+ """
+ with fileobj as body:
+ response = client.upload_part(
+ Bucket=bucket,
+ Key=key,
+ UploadId=upload_id,
+ PartNumber=part_number,
+ Body=body,
+ **extra_args
+ )
+ etag = response['ETag']
+ return {'ETag': etag, 'PartNumber': part_number}
diff --git a/contrib/python/s3transfer/py3/s3transfer/utils.py b/contrib/python/s3transfer/py3/s3transfer/utils.py
new file mode 100644
index 0000000000..ba881c67dd
--- /dev/null
+++ b/contrib/python/s3transfer/py3/s3transfer/utils.py
@@ -0,0 +1,802 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import functools
+import logging
+import math
+import os
+import random
+import socket
+import stat
+import string
+import threading
+from collections import defaultdict
+
+from botocore.exceptions import IncompleteReadError, ReadTimeoutError
+
+from s3transfer.compat import SOCKET_ERROR, fallocate, rename_file
+
+MAX_PARTS = 10000
+# The maximum file size you can upload via S3 per request.
+# See: http://docs.aws.amazon.com/AmazonS3/latest/dev/UploadingObjects.html
+# and: http://docs.aws.amazon.com/AmazonS3/latest/dev/qfacts.html
+MAX_SINGLE_UPLOAD_SIZE = 5 * (1024 ** 3)
+MIN_UPLOAD_CHUNKSIZE = 5 * (1024 ** 2)
+logger = logging.getLogger(__name__)
+
+
+S3_RETRYABLE_DOWNLOAD_ERRORS = (
+ socket.timeout,
+ SOCKET_ERROR,
+ ReadTimeoutError,
+ IncompleteReadError,
+)
+
+
+def random_file_extension(num_digits=8):
+ return ''.join(random.choice(string.hexdigits) for _ in range(num_digits))
+
+
+def signal_not_transferring(request, operation_name, **kwargs):
+ if operation_name in ['PutObject', 'UploadPart'] and hasattr(
+ request.body, 'signal_not_transferring'
+ ):
+ request.body.signal_not_transferring()
+
+
+def signal_transferring(request, operation_name, **kwargs):
+ if operation_name in ['PutObject', 'UploadPart'] and hasattr(
+ request.body, 'signal_transferring'
+ ):
+ request.body.signal_transferring()
+
+
+def calculate_num_parts(size, part_size):
+ return int(math.ceil(size / float(part_size)))
+
+
+def calculate_range_parameter(
+ part_size, part_index, num_parts, total_size=None
+):
+ """Calculate the range parameter for multipart downloads/copies
+
+ :type part_size: int
+ :param part_size: The size of the part
+
+ :type part_index: int
+ :param part_index: The index for which this parts starts. This index starts
+ at zero
+
+ :type num_parts: int
+ :param num_parts: The total number of parts in the transfer
+
+ :returns: The value to use for Range parameter on downloads or
+ the CopySourceRange parameter for copies
+ """
+ # Used to calculate the Range parameter
+ start_range = part_index * part_size
+ if part_index == num_parts - 1:
+ end_range = ''
+ if total_size is not None:
+ end_range = str(total_size - 1)
+ else:
+ end_range = start_range + part_size - 1
+ range_param = f'bytes={start_range}-{end_range}'
+ return range_param
+
+
+def get_callbacks(transfer_future, callback_type):
+ """Retrieves callbacks from a subscriber
+
+ :type transfer_future: s3transfer.futures.TransferFuture
+ :param transfer_future: The transfer future the subscriber is associated
+ to.
+
+ :type callback_type: str
+ :param callback_type: The type of callback to retrieve from the subscriber.
+ Valid types include:
+ * 'queued'
+ * 'progress'
+ * 'done'
+
+ :returns: A list of callbacks for the type specified. All callbacks are
+ preinjected with the transfer future.
+ """
+ callbacks = []
+ for subscriber in transfer_future.meta.call_args.subscribers:
+ callback_name = 'on_' + callback_type
+ if hasattr(subscriber, callback_name):
+ callbacks.append(
+ functools.partial(
+ getattr(subscriber, callback_name), future=transfer_future
+ )
+ )
+ return callbacks
+
+
+def invoke_progress_callbacks(callbacks, bytes_transferred):
+ """Calls all progress callbacks
+
+ :param callbacks: A list of progress callbacks to invoke
+ :param bytes_transferred: The number of bytes transferred. This is passed
+ to the callbacks. If no bytes were transferred the callbacks will not
+ be invoked because no progress was achieved. It is also possible
+ to receive a negative amount which comes from retrying a transfer
+ request.
+ """
+ # Only invoke the callbacks if bytes were actually transferred.
+ if bytes_transferred:
+ for callback in callbacks:
+ callback(bytes_transferred=bytes_transferred)
+
+
+def get_filtered_dict(original_dict, whitelisted_keys):
+ """Gets a dictionary filtered by whitelisted keys
+
+ :param original_dict: The original dictionary of arguments to source keys
+ and values.
+ :param whitelisted_key: A list of keys to include in the filtered
+ dictionary.
+
+ :returns: A dictionary containing key/values from the original dictionary
+ whose key was included in the whitelist
+ """
+ filtered_dict = {}
+ for key, value in original_dict.items():
+ if key in whitelisted_keys:
+ filtered_dict[key] = value
+ return filtered_dict
+
+
+class CallArgs:
+ def __init__(self, **kwargs):
+ """A class that records call arguments
+
+ The call arguments must be passed as keyword arguments. It will set
+ each keyword argument as an attribute of the object along with its
+ associated value.
+ """
+ for arg, value in kwargs.items():
+ setattr(self, arg, value)
+
+
+class FunctionContainer:
+ """An object that contains a function and any args or kwargs to call it
+
+ When called the provided function will be called with provided args
+ and kwargs.
+ """
+
+ def __init__(self, func, *args, **kwargs):
+ self._func = func
+ self._args = args
+ self._kwargs = kwargs
+
+ def __repr__(self):
+ return 'Function: {} with args {} and kwargs {}'.format(
+ self._func, self._args, self._kwargs
+ )
+
+ def __call__(self):
+ return self._func(*self._args, **self._kwargs)
+
+
+class CountCallbackInvoker:
+ """An abstraction to invoke a callback when a shared count reaches zero
+
+ :param callback: Callback invoke when finalized count reaches zero
+ """
+
+ def __init__(self, callback):
+ self._lock = threading.Lock()
+ self._callback = callback
+ self._count = 0
+ self._is_finalized = False
+
+ @property
+ def current_count(self):
+ with self._lock:
+ return self._count
+
+ def increment(self):
+ """Increment the count by one"""
+ with self._lock:
+ if self._is_finalized:
+ raise RuntimeError(
+ 'Counter has been finalized it can no longer be '
+ 'incremented.'
+ )
+ self._count += 1
+
+ def decrement(self):
+ """Decrement the count by one"""
+ with self._lock:
+ if self._count == 0:
+ raise RuntimeError(
+ 'Counter is at zero. It cannot dip below zero'
+ )
+ self._count -= 1
+ if self._is_finalized and self._count == 0:
+ self._callback()
+
+ def finalize(self):
+ """Finalize the counter
+
+ Once finalized, the counter never be incremented and the callback
+ can be invoked once the count reaches zero
+ """
+ with self._lock:
+ self._is_finalized = True
+ if self._count == 0:
+ self._callback()
+
+
+class OSUtils:
+ _MAX_FILENAME_LEN = 255
+
+ def get_file_size(self, filename):
+ return os.path.getsize(filename)
+
+ def open_file_chunk_reader(self, filename, start_byte, size, callbacks):
+ return ReadFileChunk.from_filename(
+ filename, start_byte, size, callbacks, enable_callbacks=False
+ )
+
+ def open_file_chunk_reader_from_fileobj(
+ self,
+ fileobj,
+ chunk_size,
+ full_file_size,
+ callbacks,
+ close_callbacks=None,
+ ):
+ return ReadFileChunk(
+ fileobj,
+ chunk_size,
+ full_file_size,
+ callbacks=callbacks,
+ enable_callbacks=False,
+ close_callbacks=close_callbacks,
+ )
+
+ def open(self, filename, mode):
+ return open(filename, mode)
+
+ def remove_file(self, filename):
+ """Remove a file, noop if file does not exist."""
+ # Unlike os.remove, if the file does not exist,
+ # then this method does nothing.
+ try:
+ os.remove(filename)
+ except OSError:
+ pass
+
+ def rename_file(self, current_filename, new_filename):
+ rename_file(current_filename, new_filename)
+
+ def is_special_file(cls, filename):
+ """Checks to see if a file is a special UNIX file.
+
+ It checks if the file is a character special device, block special
+ device, FIFO, or socket.
+
+ :param filename: Name of the file
+
+ :returns: True if the file is a special file. False, if is not.
+ """
+ # If it does not exist, it must be a new file so it cannot be
+ # a special file.
+ if not os.path.exists(filename):
+ return False
+ mode = os.stat(filename).st_mode
+ # Character special device.
+ if stat.S_ISCHR(mode):
+ return True
+ # Block special device
+ if stat.S_ISBLK(mode):
+ return True
+ # Named pipe / FIFO
+ if stat.S_ISFIFO(mode):
+ return True
+ # Socket.
+ if stat.S_ISSOCK(mode):
+ return True
+ return False
+
+ def get_temp_filename(self, filename):
+ suffix = os.extsep + random_file_extension()
+ path = os.path.dirname(filename)
+ name = os.path.basename(filename)
+ temp_filename = name[: self._MAX_FILENAME_LEN - len(suffix)] + suffix
+ return os.path.join(path, temp_filename)
+
+ def allocate(self, filename, size):
+ try:
+ with self.open(filename, 'wb') as f:
+ fallocate(f, size)
+ except OSError:
+ self.remove_file(filename)
+ raise
+
+
+class DeferredOpenFile:
+ def __init__(self, filename, start_byte=0, mode='rb', open_function=open):
+ """A class that defers the opening of a file till needed
+
+ This is useful for deferring opening of a file till it is needed
+ in a separate thread, as there is a limit of how many open files
+ there can be in a single thread for most operating systems. The
+ file gets opened in the following methods: ``read()``, ``seek()``,
+ and ``__enter__()``
+
+ :type filename: str
+ :param filename: The name of the file to open
+
+ :type start_byte: int
+ :param start_byte: The byte to seek to when the file is opened.
+
+ :type mode: str
+ :param mode: The mode to use to open the file
+
+ :type open_function: function
+ :param open_function: The function to use to open the file
+ """
+ self._filename = filename
+ self._fileobj = None
+ self._start_byte = start_byte
+ self._mode = mode
+ self._open_function = open_function
+
+ def _open_if_needed(self):
+ if self._fileobj is None:
+ self._fileobj = self._open_function(self._filename, self._mode)
+ if self._start_byte != 0:
+ self._fileobj.seek(self._start_byte)
+
+ @property
+ def name(self):
+ return self._filename
+
+ def read(self, amount=None):
+ self._open_if_needed()
+ return self._fileobj.read(amount)
+
+ def write(self, data):
+ self._open_if_needed()
+ self._fileobj.write(data)
+
+ def seek(self, where, whence=0):
+ self._open_if_needed()
+ self._fileobj.seek(where, whence)
+
+ def tell(self):
+ if self._fileobj is None:
+ return self._start_byte
+ return self._fileobj.tell()
+
+ def close(self):
+ if self._fileobj:
+ self._fileobj.close()
+
+ def __enter__(self):
+ self._open_if_needed()
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ self.close()
+
+
+class ReadFileChunk:
+ def __init__(
+ self,
+ fileobj,
+ chunk_size,
+ full_file_size,
+ callbacks=None,
+ enable_callbacks=True,
+ close_callbacks=None,
+ ):
+ """
+
+ Given a file object shown below::
+
+ |___________________________________________________|
+ 0 | | full_file_size
+ |----chunk_size---|
+ f.tell()
+
+ :type fileobj: file
+ :param fileobj: File like object
+
+ :type chunk_size: int
+ :param chunk_size: The max chunk size to read. Trying to read
+ pass the end of the chunk size will behave like you've
+ reached the end of the file.
+
+ :type full_file_size: int
+ :param full_file_size: The entire content length associated
+ with ``fileobj``.
+
+ :type callbacks: A list of function(amount_read)
+ :param callbacks: Called whenever data is read from this object in the
+ order provided.
+
+ :type enable_callbacks: boolean
+ :param enable_callbacks: True if to run callbacks. Otherwise, do not
+ run callbacks
+
+ :type close_callbacks: A list of function()
+ :param close_callbacks: Called when close is called. The function
+ should take no arguments.
+ """
+ self._fileobj = fileobj
+ self._start_byte = self._fileobj.tell()
+ self._size = self._calculate_file_size(
+ self._fileobj,
+ requested_size=chunk_size,
+ start_byte=self._start_byte,
+ actual_file_size=full_file_size,
+ )
+ # _amount_read represents the position in the chunk and may exceed
+ # the chunk size, but won't allow reads out of bounds.
+ self._amount_read = 0
+ self._callbacks = callbacks
+ if callbacks is None:
+ self._callbacks = []
+ self._callbacks_enabled = enable_callbacks
+ self._close_callbacks = close_callbacks
+ if close_callbacks is None:
+ self._close_callbacks = close_callbacks
+
+ @classmethod
+ def from_filename(
+ cls,
+ filename,
+ start_byte,
+ chunk_size,
+ callbacks=None,
+ enable_callbacks=True,
+ ):
+ """Convenience factory function to create from a filename.
+
+ :type start_byte: int
+ :param start_byte: The first byte from which to start reading.
+
+ :type chunk_size: int
+ :param chunk_size: The max chunk size to read. Trying to read
+ pass the end of the chunk size will behave like you've
+ reached the end of the file.
+
+ :type full_file_size: int
+ :param full_file_size: The entire content length associated
+ with ``fileobj``.
+
+ :type callbacks: function(amount_read)
+ :param callbacks: Called whenever data is read from this object.
+
+ :type enable_callbacks: bool
+ :param enable_callbacks: Indicate whether to invoke callback
+ during read() calls.
+
+ :rtype: ``ReadFileChunk``
+ :return: A new instance of ``ReadFileChunk``
+
+ """
+ f = open(filename, 'rb')
+ f.seek(start_byte)
+ file_size = os.fstat(f.fileno()).st_size
+ return cls(f, chunk_size, file_size, callbacks, enable_callbacks)
+
+ def _calculate_file_size(
+ self, fileobj, requested_size, start_byte, actual_file_size
+ ):
+ max_chunk_size = actual_file_size - start_byte
+ return min(max_chunk_size, requested_size)
+
+ def read(self, amount=None):
+ amount_left = max(self._size - self._amount_read, 0)
+ if amount is None:
+ amount_to_read = amount_left
+ else:
+ amount_to_read = min(amount_left, amount)
+ data = self._fileobj.read(amount_to_read)
+ self._amount_read += len(data)
+ if self._callbacks is not None and self._callbacks_enabled:
+ invoke_progress_callbacks(self._callbacks, len(data))
+ return data
+
+ def signal_transferring(self):
+ self.enable_callback()
+ if hasattr(self._fileobj, 'signal_transferring'):
+ self._fileobj.signal_transferring()
+
+ def signal_not_transferring(self):
+ self.disable_callback()
+ if hasattr(self._fileobj, 'signal_not_transferring'):
+ self._fileobj.signal_not_transferring()
+
+ def enable_callback(self):
+ self._callbacks_enabled = True
+
+ def disable_callback(self):
+ self._callbacks_enabled = False
+
+ def seek(self, where, whence=0):
+ if whence not in (0, 1, 2):
+ # Mimic io's error for invalid whence values
+ raise ValueError(f"invalid whence ({whence}, should be 0, 1 or 2)")
+
+ # Recalculate where based on chunk attributes so seek from file
+ # start (whence=0) is always used
+ where += self._start_byte
+ if whence == 1:
+ where += self._amount_read
+ elif whence == 2:
+ where += self._size
+
+ self._fileobj.seek(max(where, self._start_byte))
+ if self._callbacks is not None and self._callbacks_enabled:
+ # To also rewind the callback() for an accurate progress report
+ bounded_where = max(min(where - self._start_byte, self._size), 0)
+ bounded_amount_read = min(self._amount_read, self._size)
+ amount = bounded_where - bounded_amount_read
+ invoke_progress_callbacks(
+ self._callbacks, bytes_transferred=amount
+ )
+ self._amount_read = max(where - self._start_byte, 0)
+
+ def close(self):
+ if self._close_callbacks is not None and self._callbacks_enabled:
+ for callback in self._close_callbacks:
+ callback()
+ self._fileobj.close()
+
+ def tell(self):
+ return self._amount_read
+
+ def __len__(self):
+ # __len__ is defined because requests will try to determine the length
+ # of the stream to set a content length. In the normal case
+ # of the file it will just stat the file, but we need to change that
+ # behavior. By providing a __len__, requests will use that instead
+ # of stat'ing the file.
+ return self._size
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ self.close()
+
+ def __iter__(self):
+ # This is a workaround for http://bugs.python.org/issue17575
+ # Basically httplib will try to iterate over the contents, even
+ # if its a file like object. This wasn't noticed because we've
+ # already exhausted the stream so iterating over the file immediately
+ # stops, which is what we're simulating here.
+ return iter([])
+
+
+class StreamReaderProgress:
+ """Wrapper for a read only stream that adds progress callbacks."""
+
+ def __init__(self, stream, callbacks=None):
+ self._stream = stream
+ self._callbacks = callbacks
+ if callbacks is None:
+ self._callbacks = []
+
+ def read(self, *args, **kwargs):
+ value = self._stream.read(*args, **kwargs)
+ invoke_progress_callbacks(self._callbacks, len(value))
+ return value
+
+
+class NoResourcesAvailable(Exception):
+ pass
+
+
+class TaskSemaphore:
+ def __init__(self, count):
+ """A semaphore for the purpose of limiting the number of tasks
+
+ :param count: The size of semaphore
+ """
+ self._semaphore = threading.Semaphore(count)
+
+ def acquire(self, tag, blocking=True):
+ """Acquire the semaphore
+
+ :param tag: A tag identifying what is acquiring the semaphore. Note
+ that this is not really needed to directly use this class but is
+ needed for API compatibility with the SlidingWindowSemaphore
+ implementation.
+ :param block: If True, block until it can be acquired. If False,
+ do not block and raise an exception if cannot be acquired.
+
+ :returns: A token (can be None) to use when releasing the semaphore
+ """
+ logger.debug("Acquiring %s", tag)
+ if not self._semaphore.acquire(blocking):
+ raise NoResourcesAvailable("Cannot acquire tag '%s'" % tag)
+
+ def release(self, tag, acquire_token):
+ """Release the semaphore
+
+ :param tag: A tag identifying what is releasing the semaphore
+ :param acquire_token: The token returned from when the semaphore was
+ acquired. Note that this is not really needed to directly use this
+ class but is needed for API compatibility with the
+ SlidingWindowSemaphore implementation.
+ """
+ logger.debug(f"Releasing acquire {tag}/{acquire_token}")
+ self._semaphore.release()
+
+
+class SlidingWindowSemaphore(TaskSemaphore):
+ """A semaphore used to coordinate sequential resource access.
+
+ This class is similar to the stdlib BoundedSemaphore:
+
+ * It's initialized with a count.
+ * Each call to ``acquire()`` decrements the counter.
+ * If the count is at zero, then ``acquire()`` will either block until the
+ count increases, or if ``blocking=False``, then it will raise
+ a NoResourcesAvailable exception indicating that it failed to acquire the
+ semaphore.
+
+ The main difference is that this semaphore is used to limit
+ access to a resource that requires sequential access. For example,
+ if I want to access resource R that has 20 subresources R_0 - R_19,
+ this semaphore can also enforce that you only have a max range of
+ 10 at any given point in time. You must also specify a tag name
+ when you acquire the semaphore. The sliding window semantics apply
+ on a per tag basis. The internal count will only be incremented
+ when the minimum sequence number for a tag is released.
+
+ """
+
+ def __init__(self, count):
+ self._count = count
+ # Dict[tag, next_sequence_number].
+ self._tag_sequences = defaultdict(int)
+ self._lowest_sequence = {}
+ self._lock = threading.Lock()
+ self._condition = threading.Condition(self._lock)
+ # Dict[tag, List[sequence_number]]
+ self._pending_release = {}
+
+ def current_count(self):
+ with self._lock:
+ return self._count
+
+ def acquire(self, tag, blocking=True):
+ logger.debug("Acquiring %s", tag)
+ self._condition.acquire()
+ try:
+ if self._count == 0:
+ if not blocking:
+ raise NoResourcesAvailable("Cannot acquire tag '%s'" % tag)
+ else:
+ while self._count == 0:
+ self._condition.wait()
+ # self._count is no longer zero.
+ # First, check if this is the first time we're seeing this tag.
+ sequence_number = self._tag_sequences[tag]
+ if sequence_number == 0:
+ # First time seeing the tag, so record we're at 0.
+ self._lowest_sequence[tag] = sequence_number
+ self._tag_sequences[tag] += 1
+ self._count -= 1
+ return sequence_number
+ finally:
+ self._condition.release()
+
+ def release(self, tag, acquire_token):
+ sequence_number = acquire_token
+ logger.debug("Releasing acquire %s/%s", tag, sequence_number)
+ self._condition.acquire()
+ try:
+ if tag not in self._tag_sequences:
+ raise ValueError("Attempted to release unknown tag: %s" % tag)
+ max_sequence = self._tag_sequences[tag]
+ if self._lowest_sequence[tag] == sequence_number:
+ # We can immediately process this request and free up
+ # resources.
+ self._lowest_sequence[tag] += 1
+ self._count += 1
+ self._condition.notify()
+ queued = self._pending_release.get(tag, [])
+ while queued:
+ if self._lowest_sequence[tag] == queued[-1]:
+ queued.pop()
+ self._lowest_sequence[tag] += 1
+ self._count += 1
+ else:
+ break
+ elif self._lowest_sequence[tag] < sequence_number < max_sequence:
+ # We can't do anything right now because we're still waiting
+ # for the min sequence for the tag to be released. We have
+ # to queue this for pending release.
+ self._pending_release.setdefault(tag, []).append(
+ sequence_number
+ )
+ self._pending_release[tag].sort(reverse=True)
+ else:
+ raise ValueError(
+ "Attempted to release unknown sequence number "
+ "%s for tag: %s" % (sequence_number, tag)
+ )
+ finally:
+ self._condition.release()
+
+
+class ChunksizeAdjuster:
+ def __init__(
+ self,
+ max_size=MAX_SINGLE_UPLOAD_SIZE,
+ min_size=MIN_UPLOAD_CHUNKSIZE,
+ max_parts=MAX_PARTS,
+ ):
+ self.max_size = max_size
+ self.min_size = min_size
+ self.max_parts = max_parts
+
+ def adjust_chunksize(self, current_chunksize, file_size=None):
+ """Get a chunksize close to current that fits within all S3 limits.
+
+ :type current_chunksize: int
+ :param current_chunksize: The currently configured chunksize.
+
+ :type file_size: int or None
+ :param file_size: The size of the file to upload. This might be None
+ if the object being transferred has an unknown size.
+
+ :returns: A valid chunksize that fits within configured limits.
+ """
+ chunksize = current_chunksize
+ if file_size is not None:
+ chunksize = self._adjust_for_max_parts(chunksize, file_size)
+ return self._adjust_for_chunksize_limits(chunksize)
+
+ def _adjust_for_chunksize_limits(self, current_chunksize):
+ if current_chunksize > self.max_size:
+ logger.debug(
+ "Chunksize greater than maximum chunksize. "
+ "Setting to %s from %s." % (self.max_size, current_chunksize)
+ )
+ return self.max_size
+ elif current_chunksize < self.min_size:
+ logger.debug(
+ "Chunksize less than minimum chunksize. "
+ "Setting to %s from %s." % (self.min_size, current_chunksize)
+ )
+ return self.min_size
+ else:
+ return current_chunksize
+
+ def _adjust_for_max_parts(self, current_chunksize, file_size):
+ chunksize = current_chunksize
+ num_parts = int(math.ceil(file_size / float(chunksize)))
+
+ while num_parts > self.max_parts:
+ chunksize *= 2
+ num_parts = int(math.ceil(file_size / float(chunksize)))
+
+ if chunksize != current_chunksize:
+ logger.debug(
+ "Chunksize would result in the number of parts exceeding the "
+ "maximum. Setting to %s from %s."
+ % (chunksize, current_chunksize)
+ )
+
+ return chunksize
diff --git a/contrib/python/s3transfer/py3/tests/__init__.py b/contrib/python/s3transfer/py3/tests/__init__.py
new file mode 100644
index 0000000000..e36c4936bf
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/__init__.py
@@ -0,0 +1,531 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import hashlib
+import io
+import math
+import os
+import platform
+import shutil
+import string
+import tempfile
+import unittest
+from unittest import mock # noqa: F401
+
+import botocore.session
+from botocore.stub import Stubber
+
+from s3transfer.futures import (
+ IN_MEMORY_DOWNLOAD_TAG,
+ IN_MEMORY_UPLOAD_TAG,
+ BoundedExecutor,
+ NonThreadedExecutor,
+ TransferCoordinator,
+ TransferFuture,
+ TransferMeta,
+)
+from s3transfer.manager import TransferConfig
+from s3transfer.subscribers import BaseSubscriber
+from s3transfer.utils import (
+ CallArgs,
+ OSUtils,
+ SlidingWindowSemaphore,
+ TaskSemaphore,
+)
+
+ORIGINAL_EXECUTOR_CLS = BoundedExecutor.EXECUTOR_CLS
+# Detect if CRT is available for use
+try:
+ import awscrt.s3 # noqa: F401
+
+ HAS_CRT = True
+except ImportError:
+ HAS_CRT = False
+
+
+def setup_package():
+ if is_serial_implementation():
+ BoundedExecutor.EXECUTOR_CLS = NonThreadedExecutor
+
+
+def teardown_package():
+ BoundedExecutor.EXECUTOR_CLS = ORIGINAL_EXECUTOR_CLS
+
+
+def is_serial_implementation():
+ return os.environ.get('USE_SERIAL_EXECUTOR', False)
+
+
+def assert_files_equal(first, second):
+ if os.path.getsize(first) != os.path.getsize(second):
+ raise AssertionError(f"Files are not equal: {first}, {second}")
+ first_md5 = md5_checksum(first)
+ second_md5 = md5_checksum(second)
+ if first_md5 != second_md5:
+ raise AssertionError(
+ "Files are not equal: {}(md5={}) != {}(md5={})".format(
+ first, first_md5, second, second_md5
+ )
+ )
+
+
+def md5_checksum(filename):
+ checksum = hashlib.md5()
+ with open(filename, 'rb') as f:
+ for chunk in iter(lambda: f.read(8192), b''):
+ checksum.update(chunk)
+ return checksum.hexdigest()
+
+
+def random_bucket_name(prefix='s3transfer', num_chars=10):
+ base = string.ascii_lowercase + string.digits
+ random_bytes = bytearray(os.urandom(num_chars))
+ return prefix + ''.join([base[b % len(base)] for b in random_bytes])
+
+
+def skip_if_windows(reason):
+ """Decorator to skip tests that should not be run on windows.
+
+ Example usage:
+
+ @skip_if_windows("Not valid")
+ def test_some_non_windows_stuff(self):
+ self.assertEqual(...)
+
+ """
+
+ def decorator(func):
+ return unittest.skipIf(
+ platform.system() not in ['Darwin', 'Linux'], reason
+ )(func)
+
+ return decorator
+
+
+def skip_if_using_serial_implementation(reason):
+ """Decorator to skip tests when running as the serial implementation"""
+
+ def decorator(func):
+ return unittest.skipIf(is_serial_implementation(), reason)(func)
+
+ return decorator
+
+
+def requires_crt(cls, reason=None):
+ if reason is None:
+ reason = "Test requires awscrt to be installed."
+ return unittest.skipIf(not HAS_CRT, reason)(cls)
+
+
+class StreamWithError:
+ """A wrapper to simulate errors while reading from a stream
+
+ :param stream: The underlying stream to read from
+ :param exception_type: The exception type to throw
+ :param num_reads: The number of times to allow a read before raising
+ the exception. A value of zero indicates to raise the error on the
+ first read.
+ """
+
+ def __init__(self, stream, exception_type, num_reads=0):
+ self._stream = stream
+ self._exception_type = exception_type
+ self._num_reads = num_reads
+ self._count = 0
+
+ def read(self, n=-1):
+ if self._count == self._num_reads:
+ raise self._exception_type
+ self._count += 1
+ return self._stream.read(n)
+
+
+class FileSizeProvider:
+ def __init__(self, file_size):
+ self.file_size = file_size
+
+ def on_queued(self, future, **kwargs):
+ future.meta.provide_transfer_size(self.file_size)
+
+
+class FileCreator:
+ def __init__(self):
+ self.rootdir = tempfile.mkdtemp()
+
+ def remove_all(self):
+ shutil.rmtree(self.rootdir)
+
+ def create_file(self, filename, contents, mode='w'):
+ """Creates a file in a tmpdir
+ ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
+ It will be translated into a full path in a tmp dir.
+ ``mode`` is the mode the file should be opened either as ``w`` or
+ `wb``.
+ Returns the full path to the file.
+ """
+ full_path = os.path.join(self.rootdir, filename)
+ if not os.path.isdir(os.path.dirname(full_path)):
+ os.makedirs(os.path.dirname(full_path))
+ with open(full_path, mode) as f:
+ f.write(contents)
+ return full_path
+
+ def create_file_with_size(self, filename, filesize):
+ filename = self.create_file(filename, contents='')
+ chunksize = 8192
+ with open(filename, 'wb') as f:
+ for i in range(int(math.ceil(filesize / float(chunksize)))):
+ f.write(b'a' * chunksize)
+ return filename
+
+ def append_file(self, filename, contents):
+ """Append contents to a file
+ ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
+ It will be translated into a full path in a tmp dir.
+ Returns the full path to the file.
+ """
+ full_path = os.path.join(self.rootdir, filename)
+ if not os.path.isdir(os.path.dirname(full_path)):
+ os.makedirs(os.path.dirname(full_path))
+ with open(full_path, 'a') as f:
+ f.write(contents)
+ return full_path
+
+ def full_path(self, filename):
+ """Translate relative path to full path in temp dir.
+ f.full_path('foo/bar.txt') -> /tmp/asdfasd/foo/bar.txt
+ """
+ return os.path.join(self.rootdir, filename)
+
+
+class RecordingOSUtils(OSUtils):
+ """An OSUtil abstraction that records openings and renamings"""
+
+ def __init__(self):
+ super().__init__()
+ self.open_records = []
+ self.rename_records = []
+
+ def open(self, filename, mode):
+ self.open_records.append((filename, mode))
+ return super().open(filename, mode)
+
+ def rename_file(self, current_filename, new_filename):
+ self.rename_records.append((current_filename, new_filename))
+ super().rename_file(current_filename, new_filename)
+
+
+class RecordingSubscriber(BaseSubscriber):
+ def __init__(self):
+ self.on_queued_calls = []
+ self.on_progress_calls = []
+ self.on_done_calls = []
+
+ def on_queued(self, **kwargs):
+ self.on_queued_calls.append(kwargs)
+
+ def on_progress(self, **kwargs):
+ self.on_progress_calls.append(kwargs)
+
+ def on_done(self, **kwargs):
+ self.on_done_calls.append(kwargs)
+
+ def calculate_bytes_seen(self, **kwargs):
+ amount_seen = 0
+ for call in self.on_progress_calls:
+ amount_seen += call['bytes_transferred']
+ return amount_seen
+
+
+class TransferCoordinatorWithInterrupt(TransferCoordinator):
+ """Used to inject keyboard interrupts"""
+
+ def result(self):
+ raise KeyboardInterrupt()
+
+
+class RecordingExecutor:
+ """A wrapper on an executor to record calls made to submit()
+
+ You can access the submissions property to receive a list of dictionaries
+ that represents all submissions where the dictionary is formatted::
+
+ {
+ 'fn': function
+ 'args': positional args (as tuple)
+ 'kwargs': keyword args (as dict)
+ }
+ """
+
+ def __init__(self, executor):
+ self._executor = executor
+ self.submissions = []
+
+ def submit(self, task, tag=None, block=True):
+ future = self._executor.submit(task, tag, block)
+ self.submissions.append({'task': task, 'tag': tag, 'block': block})
+ return future
+
+ def shutdown(self):
+ self._executor.shutdown()
+
+
+class StubbedClientTest(unittest.TestCase):
+ def setUp(self):
+ self.session = botocore.session.get_session()
+ self.region = 'us-west-2'
+ self.client = self.session.create_client(
+ 's3',
+ self.region,
+ aws_access_key_id='foo',
+ aws_secret_access_key='bar',
+ )
+ self.stubber = Stubber(self.client)
+ self.stubber.activate()
+
+ def tearDown(self):
+ self.stubber.deactivate()
+
+ def reset_stubber_with_new_client(self, override_client_kwargs):
+ client_kwargs = {
+ 'service_name': 's3',
+ 'region_name': self.region,
+ 'aws_access_key_id': 'foo',
+ 'aws_secret_access_key': 'bar',
+ }
+ client_kwargs.update(override_client_kwargs)
+ self.client = self.session.create_client(**client_kwargs)
+ self.stubber = Stubber(self.client)
+ self.stubber.activate()
+
+
+class BaseTaskTest(StubbedClientTest):
+ def setUp(self):
+ super().setUp()
+ self.transfer_coordinator = TransferCoordinator()
+
+ def get_task(self, task_cls, **kwargs):
+ if 'transfer_coordinator' not in kwargs:
+ kwargs['transfer_coordinator'] = self.transfer_coordinator
+ return task_cls(**kwargs)
+
+ def get_transfer_future(self, call_args=None):
+ return TransferFuture(
+ meta=TransferMeta(call_args), coordinator=self.transfer_coordinator
+ )
+
+
+class BaseSubmissionTaskTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.config = TransferConfig()
+ self.osutil = OSUtils()
+ self.executor = BoundedExecutor(
+ 1000,
+ 1,
+ {
+ IN_MEMORY_UPLOAD_TAG: TaskSemaphore(10),
+ IN_MEMORY_DOWNLOAD_TAG: SlidingWindowSemaphore(10),
+ },
+ )
+
+ def tearDown(self):
+ super().tearDown()
+ self.executor.shutdown()
+
+
+class BaseGeneralInterfaceTest(StubbedClientTest):
+ """A general test class to ensure consistency across TransferManger methods
+
+ This test should never be called and should be subclassed from to pick up
+ the various tests that all TransferManager method must pass from a
+ functionality standpoint.
+ """
+
+ __test__ = False
+
+ def manager(self):
+ """The transfer manager to use"""
+ raise NotImplementedError('method is not implemented')
+
+ @property
+ def method(self):
+ """The transfer manager method to invoke i.e. upload()"""
+ raise NotImplementedError('method is not implemented')
+
+ def create_call_kwargs(self):
+ """The kwargs to be passed to the transfer manager method"""
+ raise NotImplementedError('create_call_kwargs is not implemented')
+
+ def create_invalid_extra_args(self):
+ """A value for extra_args that will cause validation errors"""
+ raise NotImplementedError(
+ 'create_invalid_extra_args is not implemented'
+ )
+
+ def create_stubbed_responses(self):
+ """A list of stubbed responses that will cause the request to succeed
+
+ The elements of this list is a dictionary that will be used as key
+ word arguments to botocore.Stubber.add_response(). For example::
+
+ [{'method': 'put_object', 'service_response': {}}]
+ """
+ raise NotImplementedError(
+ 'create_stubbed_responses is not implemented'
+ )
+
+ def create_expected_progress_callback_info(self):
+ """A list of kwargs expected to be passed to each progress callback
+
+ Note that the future kwargs does not need to be added to each
+ dictionary provided in the list. This is injected for you. An example
+ is::
+
+ [
+ {'bytes_transferred': 4},
+ {'bytes_transferred': 4},
+ {'bytes_transferred': 2}
+ ]
+
+ This indicates that the progress callback will be called three
+ times and pass along the specified keyword arguments and corresponding
+ values.
+ """
+ raise NotImplementedError(
+ 'create_expected_progress_callback_info is not implemented'
+ )
+
+ def _setup_default_stubbed_responses(self):
+ for stubbed_response in self.create_stubbed_responses():
+ self.stubber.add_response(**stubbed_response)
+
+ def test_returns_future_with_meta(self):
+ self._setup_default_stubbed_responses()
+ future = self.method(**self.create_call_kwargs())
+ # The result is called so we ensure that the entire process executes
+ # before we try to clean up resources in the tearDown.
+ future.result()
+
+ # Assert the return value is a future with metadata associated to it.
+ self.assertIsInstance(future, TransferFuture)
+ self.assertIsInstance(future.meta, TransferMeta)
+
+ def test_returns_correct_call_args(self):
+ self._setup_default_stubbed_responses()
+ call_kwargs = self.create_call_kwargs()
+ future = self.method(**call_kwargs)
+ # The result is called so we ensure that the entire process executes
+ # before we try to clean up resources in the tearDown.
+ future.result()
+
+ # Assert that there are call args associated to the metadata
+ self.assertIsInstance(future.meta.call_args, CallArgs)
+ # Assert that all of the arguments passed to the method exist and
+ # are of the correct value in call_args.
+ for param, value in call_kwargs.items():
+ self.assertEqual(value, getattr(future.meta.call_args, param))
+
+ def test_has_transfer_id_associated_to_future(self):
+ self._setup_default_stubbed_responses()
+ call_kwargs = self.create_call_kwargs()
+ future = self.method(**call_kwargs)
+ # The result is called so we ensure that the entire process executes
+ # before we try to clean up resources in the tearDown.
+ future.result()
+
+ # Assert that an transfer id was associated to the future.
+ # Since there is only one transfer request is made for that transfer
+ # manager the id will be zero since it will be the first transfer
+ # request made for that transfer manager.
+ self.assertEqual(future.meta.transfer_id, 0)
+
+ # If we make a second request, the transfer id should have incremented
+ # by one for that new TransferFuture.
+ self._setup_default_stubbed_responses()
+ future = self.method(**call_kwargs)
+ future.result()
+ self.assertEqual(future.meta.transfer_id, 1)
+
+ def test_invalid_extra_args(self):
+ with self.assertRaisesRegex(ValueError, 'Invalid extra_args'):
+ self.method(
+ extra_args=self.create_invalid_extra_args(),
+ **self.create_call_kwargs(),
+ )
+
+ def test_for_callback_kwargs_correctness(self):
+ # Add the stubbed responses before invoking the method
+ self._setup_default_stubbed_responses()
+
+ subscriber = RecordingSubscriber()
+ future = self.method(
+ subscribers=[subscriber], **self.create_call_kwargs()
+ )
+ # We call shutdown instead of result on future because the future
+ # could be finished but the done callback could still be going.
+ # The manager's shutdown method ensures everything completes.
+ self.manager.shutdown()
+
+ # Assert the various subscribers were called with the
+ # expected kwargs
+ expected_progress_calls = self.create_expected_progress_callback_info()
+ for expected_progress_call in expected_progress_calls:
+ expected_progress_call['future'] = future
+
+ self.assertEqual(subscriber.on_queued_calls, [{'future': future}])
+ self.assertEqual(subscriber.on_progress_calls, expected_progress_calls)
+ self.assertEqual(subscriber.on_done_calls, [{'future': future}])
+
+
+class NonSeekableReader(io.RawIOBase):
+ def __init__(self, b=b''):
+ super().__init__()
+ self._data = io.BytesIO(b)
+
+ def seekable(self):
+ return False
+
+ def writable(self):
+ return False
+
+ def readable(self):
+ return True
+
+ def write(self, b):
+ # This is needed because python will not always return the correct
+ # kind of error even though writeable returns False.
+ raise io.UnsupportedOperation("write")
+
+ def read(self, n=-1):
+ return self._data.read(n)
+
+
+class NonSeekableWriter(io.RawIOBase):
+ def __init__(self, fileobj):
+ super().__init__()
+ self._fileobj = fileobj
+
+ def seekable(self):
+ return False
+
+ def writable(self):
+ return True
+
+ def readable(self):
+ return False
+
+ def write(self, b):
+ self._fileobj.write(b)
+
+ def read(self, n=-1):
+ raise io.UnsupportedOperation("read")
diff --git a/contrib/python/s3transfer/py3/tests/functional/__init__.py b/contrib/python/s3transfer/py3/tests/functional/__init__.py
new file mode 100644
index 0000000000..fa58dbdb55
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/functional/__init__.py
@@ -0,0 +1,12 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_copy.py b/contrib/python/s3transfer/py3/tests/functional/test_copy.py
new file mode 100644
index 0000000000..801c9003bb
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/functional/test_copy.py
@@ -0,0 +1,554 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from botocore.exceptions import ClientError
+from botocore.stub import Stubber
+
+from s3transfer.manager import TransferConfig, TransferManager
+from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE
+from __tests__ import BaseGeneralInterfaceTest, FileSizeProvider
+
+
+class BaseCopyTest(BaseGeneralInterfaceTest):
+ def setUp(self):
+ super().setUp()
+ self.config = TransferConfig(
+ max_request_concurrency=1,
+ multipart_chunksize=MIN_UPLOAD_CHUNKSIZE,
+ multipart_threshold=MIN_UPLOAD_CHUNKSIZE * 4,
+ )
+ self._manager = TransferManager(self.client, self.config)
+
+ # Initialize some default arguments
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.copy_source = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'}
+ self.extra_args = {}
+ self.subscribers = []
+
+ self.half_chunksize = int(MIN_UPLOAD_CHUNKSIZE / 2)
+ self.content = b'0' * (2 * MIN_UPLOAD_CHUNKSIZE + self.half_chunksize)
+
+ @property
+ def manager(self):
+ return self._manager
+
+ @property
+ def method(self):
+ return self.manager.copy
+
+ def create_call_kwargs(self):
+ return {
+ 'copy_source': self.copy_source,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ }
+
+ def create_invalid_extra_args(self):
+ return {'Foo': 'bar'}
+
+ def create_stubbed_responses(self):
+ return [
+ {
+ 'method': 'head_object',
+ 'service_response': {'ContentLength': len(self.content)},
+ },
+ {'method': 'copy_object', 'service_response': {}},
+ ]
+
+ def create_expected_progress_callback_info(self):
+ return [
+ {'bytes_transferred': len(self.content)},
+ ]
+
+ def add_head_object_response(self, expected_params=None, stubber=None):
+ if not stubber:
+ stubber = self.stubber
+ head_response = self.create_stubbed_responses()[0]
+ if expected_params:
+ head_response['expected_params'] = expected_params
+ stubber.add_response(**head_response)
+
+ def add_successful_copy_responses(
+ self,
+ expected_copy_params=None,
+ expected_create_mpu_params=None,
+ expected_complete_mpu_params=None,
+ ):
+
+ # Add all responses needed to do the copy of the object.
+ # Should account for both ranged and nonranged downloads.
+ stubbed_responses = self.create_stubbed_responses()[1:]
+
+ # If the length of copy responses is greater than one then it is
+ # a multipart copy.
+ copy_responses = stubbed_responses[0:1]
+ if len(stubbed_responses) > 1:
+ copy_responses = stubbed_responses[1:-1]
+
+ # Add the expected create multipart upload params.
+ if expected_create_mpu_params:
+ stubbed_responses[0][
+ 'expected_params'
+ ] = expected_create_mpu_params
+
+ # Add any expected copy parameters.
+ if expected_copy_params:
+ for i, copy_response in enumerate(copy_responses):
+ if isinstance(expected_copy_params, list):
+ copy_response['expected_params'] = expected_copy_params[i]
+ else:
+ copy_response['expected_params'] = expected_copy_params
+
+ # Add the expected complete multipart upload params.
+ if expected_complete_mpu_params:
+ stubbed_responses[-1][
+ 'expected_params'
+ ] = expected_complete_mpu_params
+
+ # Add the responses to the stubber.
+ for stubbed_response in stubbed_responses:
+ self.stubber.add_response(**stubbed_response)
+
+ def test_can_provide_file_size(self):
+ self.add_successful_copy_responses()
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['subscribers'] = [FileSizeProvider(len(self.content))]
+
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+
+ # The HeadObject should have not happened and should have been able
+ # to successfully copy the file.
+ self.stubber.assert_no_pending_responses()
+
+ def test_provide_copy_source_as_dict(self):
+ self.copy_source['VersionId'] = 'mysourceversionid'
+ expected_params = {
+ 'Bucket': 'mysourcebucket',
+ 'Key': 'mysourcekey',
+ 'VersionId': 'mysourceversionid',
+ }
+
+ self.add_head_object_response(expected_params=expected_params)
+ self.add_successful_copy_responses()
+
+ future = self.manager.copy(**self.create_call_kwargs())
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_invalid_copy_source(self):
+ self.copy_source = ['bucket', 'key']
+ future = self.manager.copy(**self.create_call_kwargs())
+ with self.assertRaises(TypeError):
+ future.result()
+
+ def test_provide_copy_source_client(self):
+ source_client = self.session.create_client(
+ 's3',
+ 'eu-central-1',
+ aws_access_key_id='foo',
+ aws_secret_access_key='bar',
+ )
+ source_stubber = Stubber(source_client)
+ source_stubber.activate()
+ self.addCleanup(source_stubber.deactivate)
+
+ self.add_head_object_response(stubber=source_stubber)
+ self.add_successful_copy_responses()
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['source_client'] = source_client
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+
+ # Make sure that all of the responses were properly
+ # used for both clients.
+ source_stubber.assert_no_pending_responses()
+ self.stubber.assert_no_pending_responses()
+
+
+class TestNonMultipartCopy(BaseCopyTest):
+ __test__ = True
+
+ def test_copy(self):
+ expected_head_params = {
+ 'Bucket': 'mysourcebucket',
+ 'Key': 'mysourcekey',
+ }
+ expected_copy_object = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ }
+ self.add_head_object_response(expected_params=expected_head_params)
+ self.add_successful_copy_responses(
+ expected_copy_params=expected_copy_object
+ )
+
+ future = self.manager.copy(**self.create_call_kwargs())
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_copy_with_extra_args(self):
+ self.extra_args['MetadataDirective'] = 'REPLACE'
+
+ expected_head_params = {
+ 'Bucket': 'mysourcebucket',
+ 'Key': 'mysourcekey',
+ }
+ expected_copy_object = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'MetadataDirective': 'REPLACE',
+ }
+
+ self.add_head_object_response(expected_params=expected_head_params)
+ self.add_successful_copy_responses(
+ expected_copy_params=expected_copy_object
+ )
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['extra_args'] = self.extra_args
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_copy_maps_extra_args_to_head_object(self):
+ self.extra_args['CopySourceSSECustomerAlgorithm'] = 'AES256'
+
+ expected_head_params = {
+ 'Bucket': 'mysourcebucket',
+ 'Key': 'mysourcekey',
+ 'SSECustomerAlgorithm': 'AES256',
+ }
+ expected_copy_object = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'CopySourceSSECustomerAlgorithm': 'AES256',
+ }
+
+ self.add_head_object_response(expected_params=expected_head_params)
+ self.add_successful_copy_responses(
+ expected_copy_params=expected_copy_object
+ )
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['extra_args'] = self.extra_args
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_allowed_copy_params_are_valid(self):
+ op_model = self.client.meta.service_model.operation_model('CopyObject')
+ for allowed_upload_arg in self._manager.ALLOWED_COPY_ARGS:
+ self.assertIn(allowed_upload_arg, op_model.input_shape.members)
+
+ def test_copy_with_tagging(self):
+ extra_args = {'Tagging': 'tag1=val1', 'TaggingDirective': 'REPLACE'}
+ self.add_head_object_response()
+ self.add_successful_copy_responses(
+ expected_copy_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'Tagging': 'tag1=val1',
+ 'TaggingDirective': 'REPLACE',
+ }
+ )
+ future = self.manager.copy(
+ self.copy_source, self.bucket, self.key, extra_args
+ )
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_raise_exception_on_s3_object_lambda_resource(self):
+ s3_object_lambda_arn = (
+ 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
+ 'accesspoint:my-accesspoint'
+ )
+ with self.assertRaisesRegex(ValueError, 'methods do not support'):
+ self.manager.copy(self.copy_source, s3_object_lambda_arn, self.key)
+
+ def test_raise_exception_on_s3_object_lambda_resource_as_source(self):
+ source = {
+ 'Bucket': 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
+ 'accesspoint:my-accesspoint'
+ }
+ with self.assertRaisesRegex(ValueError, 'methods do not support'):
+ self.manager.copy(source, self.bucket, self.key)
+
+
+class TestMultipartCopy(BaseCopyTest):
+ __test__ = True
+
+ def setUp(self):
+ super().setUp()
+ self.config = TransferConfig(
+ max_request_concurrency=1,
+ multipart_threshold=1,
+ multipart_chunksize=4,
+ )
+ self._manager = TransferManager(self.client, self.config)
+
+ def create_stubbed_responses(self):
+ return [
+ {
+ 'method': 'head_object',
+ 'service_response': {'ContentLength': len(self.content)},
+ },
+ {
+ 'method': 'create_multipart_upload',
+ 'service_response': {'UploadId': 'my-upload-id'},
+ },
+ {
+ 'method': 'upload_part_copy',
+ 'service_response': {'CopyPartResult': {'ETag': 'etag-1'}},
+ },
+ {
+ 'method': 'upload_part_copy',
+ 'service_response': {'CopyPartResult': {'ETag': 'etag-2'}},
+ },
+ {
+ 'method': 'upload_part_copy',
+ 'service_response': {'CopyPartResult': {'ETag': 'etag-3'}},
+ },
+ {'method': 'complete_multipart_upload', 'service_response': {}},
+ ]
+
+ def create_expected_progress_callback_info(self):
+ # Note that last read is from the empty sentinel indicating
+ # that the stream is done.
+ return [
+ {'bytes_transferred': MIN_UPLOAD_CHUNKSIZE},
+ {'bytes_transferred': MIN_UPLOAD_CHUNKSIZE},
+ {'bytes_transferred': self.half_chunksize},
+ ]
+
+ def add_create_multipart_upload_response(self):
+ self.stubber.add_response(**self.create_stubbed_responses()[1])
+
+ def _get_expected_params(self):
+ upload_id = 'my-upload-id'
+
+ # Add expected parameters to the head object
+ expected_head_params = {
+ 'Bucket': 'mysourcebucket',
+ 'Key': 'mysourcekey',
+ }
+
+ # Add expected parameters for the create multipart
+ expected_create_mpu_params = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ }
+
+ expected_copy_params = []
+ # Add expected parameters to the copy part
+ ranges = [
+ 'bytes=0-5242879',
+ 'bytes=5242880-10485759',
+ 'bytes=10485760-13107199',
+ ]
+ for i, range_val in enumerate(ranges):
+ expected_copy_params.append(
+ {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'UploadId': upload_id,
+ 'PartNumber': i + 1,
+ 'CopySourceRange': range_val,
+ }
+ )
+
+ # Add expected parameters for the complete multipart
+ expected_complete_mpu_params = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': upload_id,
+ 'MultipartUpload': {
+ 'Parts': [
+ {'ETag': 'etag-1', 'PartNumber': 1},
+ {'ETag': 'etag-2', 'PartNumber': 2},
+ {'ETag': 'etag-3', 'PartNumber': 3},
+ ]
+ },
+ }
+
+ return expected_head_params, {
+ 'expected_create_mpu_params': expected_create_mpu_params,
+ 'expected_copy_params': expected_copy_params,
+ 'expected_complete_mpu_params': expected_complete_mpu_params,
+ }
+
+ def _add_params_to_expected_params(
+ self, add_copy_kwargs, operation_types, new_params
+ ):
+
+ expected_params_to_update = []
+ for operation_type in operation_types:
+ add_copy_kwargs_key = 'expected_' + operation_type + '_params'
+ expected_params = add_copy_kwargs[add_copy_kwargs_key]
+ if isinstance(expected_params, list):
+ expected_params_to_update.extend(expected_params)
+ else:
+ expected_params_to_update.append(expected_params)
+
+ for expected_params in expected_params_to_update:
+ expected_params.update(new_params)
+
+ def test_copy(self):
+ head_params, add_copy_kwargs = self._get_expected_params()
+ self.add_head_object_response(expected_params=head_params)
+ self.add_successful_copy_responses(**add_copy_kwargs)
+
+ future = self.manager.copy(**self.create_call_kwargs())
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_copy_with_extra_args(self):
+ # This extra argument should be added to the head object,
+ # the create multipart upload, and upload part copy.
+ self.extra_args['RequestPayer'] = 'requester'
+
+ head_params, add_copy_kwargs = self._get_expected_params()
+ head_params.update(self.extra_args)
+ self.add_head_object_response(expected_params=head_params)
+
+ self._add_params_to_expected_params(
+ add_copy_kwargs,
+ ['create_mpu', 'copy', 'complete_mpu'],
+ self.extra_args,
+ )
+ self.add_successful_copy_responses(**add_copy_kwargs)
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['extra_args'] = self.extra_args
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_copy_blacklists_args_to_create_multipart(self):
+ # This argument can never be used for multipart uploads
+ self.extra_args['MetadataDirective'] = 'COPY'
+
+ head_params, add_copy_kwargs = self._get_expected_params()
+ self.add_head_object_response(expected_params=head_params)
+ self.add_successful_copy_responses(**add_copy_kwargs)
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['extra_args'] = self.extra_args
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_copy_args_to_only_create_multipart(self):
+ self.extra_args['ACL'] = 'private'
+
+ head_params, add_copy_kwargs = self._get_expected_params()
+ self.add_head_object_response(expected_params=head_params)
+
+ self._add_params_to_expected_params(
+ add_copy_kwargs, ['create_mpu'], self.extra_args
+ )
+ self.add_successful_copy_responses(**add_copy_kwargs)
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['extra_args'] = self.extra_args
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_copy_passes_args_to_create_multipart_and_upload_part(self):
+ # This will only be used for the complete multipart upload
+ # and upload part.
+ self.extra_args['SSECustomerAlgorithm'] = 'AES256'
+
+ head_params, add_copy_kwargs = self._get_expected_params()
+ self.add_head_object_response(expected_params=head_params)
+
+ self._add_params_to_expected_params(
+ add_copy_kwargs, ['create_mpu', 'copy'], self.extra_args
+ )
+ self.add_successful_copy_responses(**add_copy_kwargs)
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['extra_args'] = self.extra_args
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_copy_maps_extra_args_to_head_object(self):
+ self.extra_args['CopySourceSSECustomerAlgorithm'] = 'AES256'
+
+ head_params, add_copy_kwargs = self._get_expected_params()
+
+ # The CopySourceSSECustomerAlgorithm needs to get mapped to
+ # SSECustomerAlgorithm for HeadObject
+ head_params['SSECustomerAlgorithm'] = 'AES256'
+ self.add_head_object_response(expected_params=head_params)
+
+ # However, it needs to remain the same for UploadPartCopy.
+ self._add_params_to_expected_params(
+ add_copy_kwargs, ['copy'], self.extra_args
+ )
+ self.add_successful_copy_responses(**add_copy_kwargs)
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['extra_args'] = self.extra_args
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_abort_on_failure(self):
+ # First add the head object and create multipart upload
+ self.add_head_object_response()
+ self.add_create_multipart_upload_response()
+
+ # Cause an error on upload_part_copy
+ self.stubber.add_client_error('upload_part_copy', 'ArbitraryFailure')
+
+ # Add the abort multipart to ensure it gets cleaned up on failure
+ self.stubber.add_response(
+ 'abort_multipart_upload',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': 'my-upload-id',
+ },
+ )
+
+ future = self.manager.copy(**self.create_call_kwargs())
+ with self.assertRaisesRegex(ClientError, 'ArbitraryFailure'):
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_mp_copy_with_tagging_directive(self):
+ extra_args = {'Tagging': 'tag1=val1', 'TaggingDirective': 'REPLACE'}
+ self.add_head_object_response()
+ self.add_successful_copy_responses(
+ expected_create_mpu_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'Tagging': 'tag1=val1',
+ }
+ )
+ future = self.manager.copy(
+ self.copy_source, self.bucket, self.key, extra_args
+ )
+ future.result()
+ self.stubber.assert_no_pending_responses()
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_crt.py b/contrib/python/s3transfer/py3/tests/functional/test_crt.py
new file mode 100644
index 0000000000..fad0f4b23b
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/functional/test_crt.py
@@ -0,0 +1,267 @@
+# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import fnmatch
+import threading
+import time
+from concurrent.futures import Future
+
+from botocore.session import Session
+
+from s3transfer.subscribers import BaseSubscriber
+from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest
+
+if HAS_CRT:
+ import awscrt
+
+ import s3transfer.crt
+
+
+class submitThread(threading.Thread):
+ def __init__(self, transfer_manager, futures, callargs):
+ threading.Thread.__init__(self)
+ self._transfer_manager = transfer_manager
+ self._futures = futures
+ self._callargs = callargs
+
+ def run(self):
+ self._futures.append(self._transfer_manager.download(*self._callargs))
+
+
+class RecordingSubscriber(BaseSubscriber):
+ def __init__(self):
+ self.on_queued_called = False
+ self.on_done_called = False
+ self.bytes_transferred = 0
+ self.on_queued_future = None
+ self.on_done_future = None
+
+ def on_queued(self, future, **kwargs):
+ self.on_queued_called = True
+ self.on_queued_future = future
+
+ def on_done(self, future, **kwargs):
+ self.on_done_called = True
+ self.on_done_future = future
+
+
+@requires_crt
+class TestCRTTransferManager(unittest.TestCase):
+ def setUp(self):
+ self.region = 'us-west-2'
+ self.bucket = "test_bucket"
+ self.key = "test_key"
+ self.files = FileCreator()
+ self.filename = self.files.create_file('myfile', 'my content')
+ self.expected_path = "/" + self.bucket + "/" + self.key
+ self.expected_host = "s3.%s.amazonaws.com" % (self.region)
+ self.s3_request = mock.Mock(awscrt.s3.S3Request)
+ self.s3_crt_client = mock.Mock(awscrt.s3.S3Client)
+ self.s3_crt_client.make_request.return_value = self.s3_request
+ self.session = Session()
+ self.session.set_config_variable('region', self.region)
+ self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer(
+ self.session
+ )
+ self.transfer_manager = s3transfer.crt.CRTTransferManager(
+ crt_s3_client=self.s3_crt_client,
+ crt_request_serializer=self.request_serializer,
+ )
+ self.record_subscriber = RecordingSubscriber()
+
+ def tearDown(self):
+ self.files.remove_all()
+
+ def _assert_subscribers_called(self, expected_future=None):
+ self.assertTrue(self.record_subscriber.on_queued_called)
+ self.assertTrue(self.record_subscriber.on_done_called)
+ if expected_future:
+ self.assertIs(
+ self.record_subscriber.on_queued_future, expected_future
+ )
+ self.assertIs(
+ self.record_subscriber.on_done_future, expected_future
+ )
+
+ def _invoke_done_callbacks(self, **kwargs):
+ callargs = self.s3_crt_client.make_request.call_args
+ callargs_kwargs = callargs[1]
+ on_done = callargs_kwargs["on_done"]
+ on_done(error=None)
+
+ def _simulate_file_download(self, recv_filepath):
+ self.files.create_file(recv_filepath, "fake response")
+
+ def _simulate_make_request_side_effect(self, **kwargs):
+ if kwargs.get('recv_filepath'):
+ self._simulate_file_download(kwargs['recv_filepath'])
+ self._invoke_done_callbacks()
+ return mock.DEFAULT
+
+ def test_upload(self):
+ self.s3_crt_client.make_request.side_effect = (
+ self._simulate_make_request_side_effect
+ )
+ future = self.transfer_manager.upload(
+ self.filename, self.bucket, self.key, {}, [self.record_subscriber]
+ )
+ future.result()
+
+ callargs = self.s3_crt_client.make_request.call_args
+ callargs_kwargs = callargs[1]
+ self.assertEqual(callargs_kwargs["send_filepath"], self.filename)
+ self.assertIsNone(callargs_kwargs["recv_filepath"])
+ self.assertEqual(
+ callargs_kwargs["type"], awscrt.s3.S3RequestType.PUT_OBJECT
+ )
+ crt_request = callargs_kwargs["request"]
+ self.assertEqual("PUT", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self._assert_subscribers_called(future)
+
+ def test_download(self):
+ self.s3_crt_client.make_request.side_effect = (
+ self._simulate_make_request_side_effect
+ )
+ future = self.transfer_manager.download(
+ self.bucket, self.key, self.filename, {}, [self.record_subscriber]
+ )
+ future.result()
+
+ callargs = self.s3_crt_client.make_request.call_args
+ callargs_kwargs = callargs[1]
+ # the recv_filepath will be set to a temporary file path with some
+ # random suffix
+ self.assertTrue(
+ fnmatch.fnmatch(
+ callargs_kwargs["recv_filepath"],
+ f'{self.filename}.*',
+ )
+ )
+ self.assertIsNone(callargs_kwargs["send_filepath"])
+ self.assertEqual(
+ callargs_kwargs["type"], awscrt.s3.S3RequestType.GET_OBJECT
+ )
+ crt_request = callargs_kwargs["request"]
+ self.assertEqual("GET", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self._assert_subscribers_called(future)
+ with open(self.filename, 'rb') as f:
+ # Check the fake response overwrites the file because of download
+ self.assertEqual(f.read(), b'fake response')
+
+ def test_delete(self):
+ self.s3_crt_client.make_request.side_effect = (
+ self._simulate_make_request_side_effect
+ )
+ future = self.transfer_manager.delete(
+ self.bucket, self.key, {}, [self.record_subscriber]
+ )
+ future.result()
+
+ callargs = self.s3_crt_client.make_request.call_args
+ callargs_kwargs = callargs[1]
+ self.assertIsNone(callargs_kwargs["send_filepath"])
+ self.assertIsNone(callargs_kwargs["recv_filepath"])
+ self.assertEqual(
+ callargs_kwargs["type"], awscrt.s3.S3RequestType.DEFAULT
+ )
+ crt_request = callargs_kwargs["request"]
+ self.assertEqual("DELETE", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self._assert_subscribers_called(future)
+
+ def test_blocks_when_max_requests_processes_reached(self):
+ futures = []
+ callargs = (self.bucket, self.key, self.filename, {}, [])
+ max_request_processes = 128 # the hard coded max processes
+ all_concurrent = max_request_processes + 1
+ threads = []
+ for i in range(0, all_concurrent):
+ thread = submitThread(self.transfer_manager, futures, callargs)
+ thread.start()
+ threads.append(thread)
+ # Sleep until the expected max requests has been reached
+ while len(futures) < max_request_processes:
+ time.sleep(0.05)
+ self.assertLessEqual(
+ self.s3_crt_client.make_request.call_count, max_request_processes
+ )
+ # Release lock
+ callargs = self.s3_crt_client.make_request.call_args
+ callargs_kwargs = callargs[1]
+ on_done = callargs_kwargs["on_done"]
+ on_done(error=None)
+ for thread in threads:
+ thread.join()
+ self.assertEqual(
+ self.s3_crt_client.make_request.call_count, all_concurrent
+ )
+
+ def _cancel_function(self):
+ self.cancel_called = True
+ self.s3_request.finished_future.set_exception(
+ awscrt.exceptions.from_code(0)
+ )
+ self._invoke_done_callbacks()
+
+ def test_cancel(self):
+ self.s3_request.finished_future = Future()
+ self.cancel_called = False
+ self.s3_request.cancel = self._cancel_function
+ try:
+ with self.transfer_manager:
+ future = self.transfer_manager.upload(
+ self.filename, self.bucket, self.key, {}, []
+ )
+ raise KeyboardInterrupt()
+ except KeyboardInterrupt:
+ pass
+
+ with self.assertRaises(awscrt.exceptions.AwsCrtError):
+ future.result()
+ self.assertTrue(self.cancel_called)
+
+ def test_serializer_error_handling(self):
+ class SerializationException(Exception):
+ pass
+
+ class ExceptionRaisingSerializer(
+ s3transfer.crt.BaseCRTRequestSerializer
+ ):
+ def serialize_http_request(self, transfer_type, future):
+ raise SerializationException()
+
+ not_impl_serializer = ExceptionRaisingSerializer()
+ transfer_manager = s3transfer.crt.CRTTransferManager(
+ crt_s3_client=self.s3_crt_client,
+ crt_request_serializer=not_impl_serializer,
+ )
+ future = transfer_manager.upload(
+ self.filename, self.bucket, self.key, {}, []
+ )
+
+ with self.assertRaises(SerializationException):
+ future.result()
+
+ def test_crt_s3_client_error_handling(self):
+ self.s3_crt_client.make_request.side_effect = (
+ awscrt.exceptions.from_code(0)
+ )
+ future = self.transfer_manager.upload(
+ self.filename, self.bucket, self.key, {}, []
+ )
+ with self.assertRaises(awscrt.exceptions.AwsCrtError):
+ future.result()
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_delete.py b/contrib/python/s3transfer/py3/tests/functional/test_delete.py
new file mode 100644
index 0000000000..28587a47a4
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/functional/test_delete.py
@@ -0,0 +1,76 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.manager import TransferManager
+from __tests__ import BaseGeneralInterfaceTest
+
+
+class TestDeleteObject(BaseGeneralInterfaceTest):
+
+ __test__ = True
+
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.manager = TransferManager(self.client)
+
+ @property
+ def method(self):
+ """The transfer manager method to invoke i.e. upload()"""
+ return self.manager.delete
+
+ def create_call_kwargs(self):
+ """The kwargs to be passed to the transfer manager method"""
+ return {
+ 'bucket': self.bucket,
+ 'key': self.key,
+ }
+
+ def create_invalid_extra_args(self):
+ return {
+ 'BadKwargs': True,
+ }
+
+ def create_stubbed_responses(self):
+ """A list of stubbed responses that will cause the request to succeed
+
+ The elements of this list is a dictionary that will be used as key
+ word arguments to botocore.Stubber.add_response(). For example::
+
+ [{'method': 'put_object', 'service_response': {}}]
+ """
+ return [
+ {
+ 'method': 'delete_object',
+ 'service_response': {},
+ 'expected_params': {'Bucket': self.bucket, 'Key': self.key},
+ }
+ ]
+
+ def create_expected_progress_callback_info(self):
+ return []
+
+ def test_known_allowed_args_in_input_shape(self):
+ op_model = self.client.meta.service_model.operation_model(
+ 'DeleteObject'
+ )
+ for allowed_arg in self.manager.ALLOWED_DELETE_ARGS:
+ self.assertIn(allowed_arg, op_model.input_shape.members)
+
+ def test_raise_exception_on_s3_object_lambda_resource(self):
+ s3_object_lambda_arn = (
+ 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
+ 'accesspoint:my-accesspoint'
+ )
+ with self.assertRaisesRegex(ValueError, 'methods do not support'):
+ self.manager.delete(s3_object_lambda_arn, self.key)
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_download.py b/contrib/python/s3transfer/py3/tests/functional/test_download.py
new file mode 100644
index 0000000000..64a8a1309d
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/functional/test_download.py
@@ -0,0 +1,497 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import copy
+import glob
+import os
+import shutil
+import tempfile
+import time
+from io import BytesIO
+
+from botocore.exceptions import ClientError
+
+from s3transfer.compat import SOCKET_ERROR
+from s3transfer.exceptions import RetriesExceededError
+from s3transfer.manager import TransferConfig, TransferManager
+from __tests__ import (
+ BaseGeneralInterfaceTest,
+ FileSizeProvider,
+ NonSeekableWriter,
+ RecordingOSUtils,
+ RecordingSubscriber,
+ StreamWithError,
+ skip_if_using_serial_implementation,
+ skip_if_windows,
+)
+
+
+class BaseDownloadTest(BaseGeneralInterfaceTest):
+ def setUp(self):
+ super().setUp()
+ self.config = TransferConfig(max_request_concurrency=1)
+ self._manager = TransferManager(self.client, self.config)
+
+ # Create a temporary directory to write to
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+
+ # Initialize some default arguments
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.subscribers = []
+
+ # Create a stream to read from
+ self.content = b'my content'
+ self.stream = BytesIO(self.content)
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+
+ @property
+ def manager(self):
+ return self._manager
+
+ @property
+ def method(self):
+ return self.manager.download
+
+ def create_call_kwargs(self):
+ return {
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'fileobj': self.filename,
+ }
+
+ def create_invalid_extra_args(self):
+ return {'Foo': 'bar'}
+
+ def create_stubbed_responses(self):
+ # We want to make sure the beginning of the stream is always used
+ # in case this gets called twice.
+ self.stream.seek(0)
+ return [
+ {
+ 'method': 'head_object',
+ 'service_response': {'ContentLength': len(self.content)},
+ },
+ {
+ 'method': 'get_object',
+ 'service_response': {'Body': self.stream},
+ },
+ ]
+
+ def create_expected_progress_callback_info(self):
+ # Note that last read is from the empty sentinel indicating
+ # that the stream is done.
+ return [{'bytes_transferred': 10}]
+
+ def add_head_object_response(self, expected_params=None):
+ head_response = self.create_stubbed_responses()[0]
+ if expected_params:
+ head_response['expected_params'] = expected_params
+ self.stubber.add_response(**head_response)
+
+ def add_successful_get_object_responses(
+ self, expected_params=None, expected_ranges=None
+ ):
+ # Add all get_object responses needed to complete the download.
+ # Should account for both ranged and nonranged downloads.
+ for i, stubbed_response in enumerate(
+ self.create_stubbed_responses()[1:]
+ ):
+ if expected_params:
+ stubbed_response['expected_params'] = copy.deepcopy(
+ expected_params
+ )
+ if expected_ranges:
+ stubbed_response['expected_params'][
+ 'Range'
+ ] = expected_ranges[i]
+ self.stubber.add_response(**stubbed_response)
+
+ def add_n_retryable_get_object_responses(self, n, num_reads=0):
+ for _ in range(n):
+ self.stubber.add_response(
+ method='get_object',
+ service_response={
+ 'Body': StreamWithError(
+ copy.deepcopy(self.stream), SOCKET_ERROR, num_reads
+ )
+ },
+ )
+
+ def test_download_temporary_file_does_not_exist(self):
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ future = self.manager.download(**self.create_call_kwargs())
+ future.result()
+ # Make sure the file exists
+ self.assertTrue(os.path.exists(self.filename))
+ # Make sure the random temporary file does not exist
+ possible_matches = glob.glob('%s*' % self.filename + os.extsep)
+ self.assertEqual(possible_matches, [])
+
+ def test_download_for_fileobj(self):
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ with open(self.filename, 'wb') as f:
+ future = self.manager.download(
+ self.bucket, self.key, f, self.extra_args
+ )
+ future.result()
+
+ # Ensure that the contents are correct
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
+
+ def test_download_for_seekable_filelike_obj(self):
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ # Create a file-like object to test. In this case, it is a BytesIO
+ # object.
+ bytes_io = BytesIO()
+
+ future = self.manager.download(
+ self.bucket, self.key, bytes_io, self.extra_args
+ )
+ future.result()
+
+ # Ensure that the contents are correct
+ bytes_io.seek(0)
+ self.assertEqual(self.content, bytes_io.read())
+
+ def test_download_for_nonseekable_filelike_obj(self):
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ with open(self.filename, 'wb') as f:
+ future = self.manager.download(
+ self.bucket, self.key, NonSeekableWriter(f), self.extra_args
+ )
+ future.result()
+
+ # Ensure that the contents are correct
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
+
+ def test_download_cleanup_on_failure(self):
+ self.add_head_object_response()
+
+ # Throw an error on the download
+ self.stubber.add_client_error('get_object')
+
+ future = self.manager.download(**self.create_call_kwargs())
+
+ with self.assertRaises(ClientError):
+ future.result()
+ # Make sure the actual file and the temporary do not exist
+ # by globbing for the file and any of its extensions
+ possible_matches = glob.glob('%s*' % self.filename)
+ self.assertEqual(possible_matches, [])
+
+ def test_download_with_nonexistent_directory(self):
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['fileobj'] = os.path.join(
+ self.tempdir, 'missing-directory', 'myfile'
+ )
+ future = self.manager.download(**call_kwargs)
+ with self.assertRaises(IOError):
+ future.result()
+
+ def test_retries_and_succeeds(self):
+ self.add_head_object_response()
+ # Insert a response that will trigger a retry.
+ self.add_n_retryable_get_object_responses(1)
+ # Add the normal responses to simulate the download proceeding
+ # as normal after the retry.
+ self.add_successful_get_object_responses()
+
+ future = self.manager.download(**self.create_call_kwargs())
+ future.result()
+
+ # The retry should have been consumed and the process should have
+ # continued using the successful responses.
+ self.stubber.assert_no_pending_responses()
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
+
+ def test_retry_failure(self):
+ self.add_head_object_response()
+
+ max_retries = 3
+ self.config.num_download_attempts = max_retries
+ self._manager = TransferManager(self.client, self.config)
+ # Add responses that fill up the maximum number of retries.
+ self.add_n_retryable_get_object_responses(max_retries)
+
+ future = self.manager.download(**self.create_call_kwargs())
+
+ # A retry exceeded error should have happened.
+ with self.assertRaises(RetriesExceededError):
+ future.result()
+
+ # All of the retries should have been used up.
+ self.stubber.assert_no_pending_responses()
+
+ def test_retry_rewinds_callbacks(self):
+ self.add_head_object_response()
+ # Insert a response that will trigger a retry after one read of the
+ # stream has been made.
+ self.add_n_retryable_get_object_responses(1, num_reads=1)
+ # Add the normal responses to simulate the download proceeding
+ # as normal after the retry.
+ self.add_successful_get_object_responses()
+
+ recorder_subscriber = RecordingSubscriber()
+ # Set the streaming to a size that is smaller than the data we
+ # currently provide to it to simulate rewinds of callbacks.
+ self.config.io_chunksize = 3
+ future = self.manager.download(
+ subscribers=[recorder_subscriber], **self.create_call_kwargs()
+ )
+ future.result()
+
+ # Ensure that there is no more remaining responses and that contents
+ # are correct.
+ self.stubber.assert_no_pending_responses()
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
+
+ # Assert that the number of bytes seen is equal to the length of
+ # downloaded content.
+ self.assertEqual(
+ recorder_subscriber.calculate_bytes_seen(), len(self.content)
+ )
+
+ # Also ensure that the second progress invocation was negative three
+ # because a retry happened on the second read of the stream and we
+ # know that the chunk size for each read is 3.
+ progress_byte_amts = [
+ call['bytes_transferred']
+ for call in recorder_subscriber.on_progress_calls
+ ]
+ self.assertEqual(-3, progress_byte_amts[1])
+
+ def test_can_provide_file_size(self):
+ self.add_successful_get_object_responses()
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['subscribers'] = [FileSizeProvider(len(self.content))]
+
+ future = self.manager.download(**call_kwargs)
+ future.result()
+
+ # The HeadObject should have not happened and should have been able
+ # to successfully download the file.
+ self.stubber.assert_no_pending_responses()
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
+
+ def test_uses_provided_osutil(self):
+ osutil = RecordingOSUtils()
+ # Use the recording os utility for the transfer manager
+ self._manager = TransferManager(self.client, self.config, osutil)
+
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ future = self.manager.download(**self.create_call_kwargs())
+ future.result()
+ # The osutil should have had its open() method invoked when opening
+ # a temporary file and its rename_file() method invoked when the
+ # the temporary file was moved to its final location.
+ self.assertEqual(len(osutil.open_records), 1)
+ self.assertEqual(len(osutil.rename_records), 1)
+
+ @skip_if_windows('Windows does not support UNIX special files')
+ @skip_if_using_serial_implementation(
+ 'A separate thread is needed to read from the fifo'
+ )
+ def test_download_for_fifo_file(self):
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ # Create the fifo file
+ os.mkfifo(self.filename)
+
+ future = self.manager.download(
+ self.bucket, self.key, self.filename, self.extra_args
+ )
+
+ # The call to open a fifo will block until there is both a reader
+ # and a writer, so we need to open it for reading after we've
+ # started the transfer.
+ with open(self.filename, 'rb') as fifo:
+ future.result()
+ self.assertEqual(fifo.read(), self.content)
+
+ def test_raise_exception_on_s3_object_lambda_resource(self):
+ s3_object_lambda_arn = (
+ 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
+ 'accesspoint:my-accesspoint'
+ )
+ with self.assertRaisesRegex(ValueError, 'methods do not support'):
+ self.manager.download(
+ s3_object_lambda_arn, self.key, self.filename, self.extra_args
+ )
+
+
+class TestNonRangedDownload(BaseDownloadTest):
+ # TODO: If you want to add tests outside of this test class and still
+ # subclass from BaseDownloadTest you need to set ``__test__ = True``. If
+ # you do not, your tests will not get picked up by the test runner! This
+ # needs to be done until we find a better way to ignore running test cases
+ # from the general test base class, which we do not want ran.
+ __test__ = True
+
+ def test_download(self):
+ self.extra_args['RequestPayer'] = 'requester'
+ expected_params = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'RequestPayer': 'requester',
+ }
+ self.add_head_object_response(expected_params)
+ self.add_successful_get_object_responses(expected_params)
+ future = self.manager.download(
+ self.bucket, self.key, self.filename, self.extra_args
+ )
+ future.result()
+
+ # Ensure that the contents are correct
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
+
+ def test_allowed_copy_params_are_valid(self):
+ op_model = self.client.meta.service_model.operation_model('GetObject')
+ for allowed_upload_arg in self._manager.ALLOWED_DOWNLOAD_ARGS:
+ self.assertIn(allowed_upload_arg, op_model.input_shape.members)
+
+ def test_download_empty_object(self):
+ self.content = b''
+ self.stream = BytesIO(self.content)
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+ future = self.manager.download(
+ self.bucket, self.key, self.filename, self.extra_args
+ )
+ future.result()
+
+ # Ensure that the empty file exists
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(b'', f.read())
+
+ def test_uses_bandwidth_limiter(self):
+ self.content = b'a' * 1024 * 1024
+ self.stream = BytesIO(self.content)
+ self.config = TransferConfig(
+ max_request_concurrency=1, max_bandwidth=len(self.content) / 2
+ )
+ self._manager = TransferManager(self.client, self.config)
+
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ start = time.time()
+ future = self.manager.download(
+ self.bucket, self.key, self.filename, self.extra_args
+ )
+ future.result()
+ # This is just a smoke test to make sure that the limiter is
+ # being used and not necessary its exactness. So we set the maximum
+ # bandwidth to len(content)/2 per sec and make sure that it is
+ # noticeably slower. Ideally it will take more than two seconds, but
+ # given tracking at the beginning of transfers are not entirely
+ # accurate setting at the initial start of a transfer, we give us
+ # some flexibility by setting the expected time to half of the
+ # theoretical time to take.
+ self.assertGreaterEqual(time.time() - start, 1)
+
+ # Ensure that the contents are correct
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
+
+
+class TestRangedDownload(BaseDownloadTest):
+ # TODO: If you want to add tests outside of this test class and still
+ # subclass from BaseDownloadTest you need to set ``__test__ = True``. If
+ # you do not, your tests will not get picked up by the test runner! This
+ # needs to be done until we find a better way to ignore running test cases
+ # from the general test base class, which we do not want ran.
+ __test__ = True
+
+ def setUp(self):
+ super().setUp()
+ self.config = TransferConfig(
+ max_request_concurrency=1,
+ multipart_threshold=1,
+ multipart_chunksize=4,
+ )
+ self._manager = TransferManager(self.client, self.config)
+
+ def create_stubbed_responses(self):
+ return [
+ {
+ 'method': 'head_object',
+ 'service_response': {'ContentLength': len(self.content)},
+ },
+ {
+ 'method': 'get_object',
+ 'service_response': {'Body': BytesIO(self.content[0:4])},
+ },
+ {
+ 'method': 'get_object',
+ 'service_response': {'Body': BytesIO(self.content[4:8])},
+ },
+ {
+ 'method': 'get_object',
+ 'service_response': {'Body': BytesIO(self.content[8:])},
+ },
+ ]
+
+ def create_expected_progress_callback_info(self):
+ return [
+ {'bytes_transferred': 4},
+ {'bytes_transferred': 4},
+ {'bytes_transferred': 2},
+ ]
+
+ def test_download(self):
+ self.extra_args['RequestPayer'] = 'requester'
+ expected_params = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'RequestPayer': 'requester',
+ }
+ expected_ranges = ['bytes=0-3', 'bytes=4-7', 'bytes=8-']
+ self.add_head_object_response(expected_params)
+ self.add_successful_get_object_responses(
+ expected_params, expected_ranges
+ )
+
+ future = self.manager.download(
+ self.bucket, self.key, self.filename, self.extra_args
+ )
+ future.result()
+
+ # Ensure that the contents are correct
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_manager.py b/contrib/python/s3transfer/py3/tests/functional/test_manager.py
new file mode 100644
index 0000000000..1c980e7bc6
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/functional/test_manager.py
@@ -0,0 +1,191 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from io import BytesIO
+
+from botocore.awsrequest import create_request_object
+
+from s3transfer.exceptions import CancelledError, FatalError
+from s3transfer.futures import BaseExecutor
+from s3transfer.manager import TransferConfig, TransferManager
+from __tests__ import StubbedClientTest, mock, skip_if_using_serial_implementation
+
+
+class ArbitraryException(Exception):
+ pass
+
+
+class SignalTransferringBody(BytesIO):
+ """A mocked body with the ability to signal when transfers occur"""
+
+ def __init__(self):
+ super().__init__()
+ self.signal_transferring_call_count = 0
+ self.signal_not_transferring_call_count = 0
+
+ def signal_transferring(self):
+ self.signal_transferring_call_count += 1
+
+ def signal_not_transferring(self):
+ self.signal_not_transferring_call_count += 1
+
+ def seek(self, where, whence=0):
+ pass
+
+ def tell(self):
+ return 0
+
+ def read(self, amount=0):
+ return b''
+
+
+class TestTransferManager(StubbedClientTest):
+ @skip_if_using_serial_implementation(
+ 'Exception is thrown once all transfers are submitted. '
+ 'However for the serial implementation, transfers are performed '
+ 'in main thread meaning all transfers will complete before the '
+ 'exception being thrown.'
+ )
+ def test_error_in_context_manager_cancels_incomplete_transfers(self):
+ # The purpose of this test is to make sure if an error is raised
+ # in the body of the context manager, incomplete transfers will
+ # be cancelled with value of the exception wrapped by a CancelledError
+
+ # NOTE: The fact that delete() was chosen to test this is arbitrary
+ # other than it is the easiet to set up for the stubber.
+ # The specific operation is not important to the purpose of this test.
+ num_transfers = 100
+ futures = []
+ ref_exception_msg = 'arbitrary exception'
+
+ for _ in range(num_transfers):
+ self.stubber.add_response('delete_object', {})
+
+ manager = TransferManager(
+ self.client,
+ TransferConfig(
+ max_request_concurrency=1, max_submission_concurrency=1
+ ),
+ )
+ try:
+ with manager:
+ for i in range(num_transfers):
+ futures.append(manager.delete('mybucket', 'mykey'))
+ raise ArbitraryException(ref_exception_msg)
+ except ArbitraryException:
+ # At least one of the submitted futures should have been
+ # cancelled.
+ with self.assertRaisesRegex(FatalError, ref_exception_msg):
+ for future in futures:
+ future.result()
+
+ @skip_if_using_serial_implementation(
+ 'Exception is thrown once all transfers are submitted. '
+ 'However for the serial implementation, transfers are performed '
+ 'in main thread meaning all transfers will complete before the '
+ 'exception being thrown.'
+ )
+ def test_cntrl_c_in_context_manager_cancels_incomplete_transfers(self):
+ # The purpose of this test is to make sure if an error is raised
+ # in the body of the context manager, incomplete transfers will
+ # be cancelled with value of the exception wrapped by a CancelledError
+
+ # NOTE: The fact that delete() was chosen to test this is arbitrary
+ # other than it is the easiet to set up for the stubber.
+ # The specific operation is not important to the purpose of this test.
+ num_transfers = 100
+ futures = []
+
+ for _ in range(num_transfers):
+ self.stubber.add_response('delete_object', {})
+
+ manager = TransferManager(
+ self.client,
+ TransferConfig(
+ max_request_concurrency=1, max_submission_concurrency=1
+ ),
+ )
+ try:
+ with manager:
+ for i in range(num_transfers):
+ futures.append(manager.delete('mybucket', 'mykey'))
+ raise KeyboardInterrupt()
+ except KeyboardInterrupt:
+ # At least one of the submitted futures should have been
+ # cancelled.
+ with self.assertRaisesRegex(CancelledError, 'KeyboardInterrupt()'):
+ for future in futures:
+ future.result()
+
+ def test_enable_disable_callbacks_only_ever_registered_once(self):
+ body = SignalTransferringBody()
+ request = create_request_object(
+ {
+ 'method': 'PUT',
+ 'url': 'https://s3.amazonaws.com',
+ 'body': body,
+ 'headers': {},
+ 'context': {},
+ }
+ )
+ # Create two TransferManager's using the same client
+ TransferManager(self.client)
+ TransferManager(self.client)
+ self.client.meta.events.emit(
+ 'request-created.s3', request=request, operation_name='PutObject'
+ )
+ # The client should have only have the enable/disable callback
+ # handlers registered once depite being used for two different
+ # TransferManagers.
+ self.assertEqual(
+ body.signal_transferring_call_count,
+ 1,
+ 'The enable_callback() should have only ever been registered once',
+ )
+ self.assertEqual(
+ body.signal_not_transferring_call_count,
+ 1,
+ 'The disable_callback() should have only ever been registered '
+ 'once',
+ )
+
+ def test_use_custom_executor_implementation(self):
+ mocked_executor_cls = mock.Mock(BaseExecutor)
+ transfer_manager = TransferManager(
+ self.client, executor_cls=mocked_executor_cls
+ )
+ transfer_manager.delete('bucket', 'key')
+ self.assertTrue(mocked_executor_cls.return_value.submit.called)
+
+ def test_unicode_exception_in_context_manager(self):
+ with self.assertRaises(ArbitraryException):
+ with TransferManager(self.client):
+ raise ArbitraryException('\u2713')
+
+ def test_client_property(self):
+ manager = TransferManager(self.client)
+ self.assertIs(manager.client, self.client)
+
+ def test_config_property(self):
+ config = TransferConfig()
+ manager = TransferManager(self.client, config)
+ self.assertIs(manager.config, config)
+
+ def test_can_disable_bucket_validation(self):
+ s3_object_lambda_arn = (
+ 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
+ 'accesspoint:my-accesspoint'
+ )
+ config = TransferConfig()
+ manager = TransferManager(self.client, config)
+ manager.VALIDATE_SUPPORTED_BUCKET_VALUES = False
+ manager.delete(s3_object_lambda_arn, 'my-key')
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_processpool.py b/contrib/python/s3transfer/py3/tests/functional/test_processpool.py
new file mode 100644
index 0000000000..1396c919f2
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/functional/test_processpool.py
@@ -0,0 +1,281 @@
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import glob
+import os
+from io import BytesIO
+from multiprocessing.managers import BaseManager
+
+import botocore.exceptions
+import botocore.session
+from botocore.stub import Stubber
+
+from s3transfer.exceptions import CancelledError
+from s3transfer.processpool import ProcessPoolDownloader, ProcessTransferConfig
+from __tests__ import FileCreator, mock, unittest
+
+
+class StubbedClient:
+ def __init__(self):
+ self._client = botocore.session.get_session().create_client(
+ 's3',
+ 'us-west-2',
+ aws_access_key_id='foo',
+ aws_secret_access_key='bar',
+ )
+ self._stubber = Stubber(self._client)
+ self._stubber.activate()
+ self._caught_stubber_errors = []
+
+ def get_object(self, **kwargs):
+ return self._client.get_object(**kwargs)
+
+ def head_object(self, **kwargs):
+ return self._client.head_object(**kwargs)
+
+ def add_response(self, *args, **kwargs):
+ self._stubber.add_response(*args, **kwargs)
+
+ def add_client_error(self, *args, **kwargs):
+ self._stubber.add_client_error(*args, **kwargs)
+
+
+class StubbedClientManager(BaseManager):
+ pass
+
+
+StubbedClientManager.register('StubbedClient', StubbedClient)
+
+
+# Ideally a Mock would be used here. However, they cannot be pickled
+# for Windows. So instead we define a factory class at the module level that
+# can return a stubbed client we initialized in the setUp.
+class StubbedClientFactory:
+ def __init__(self, stubbed_client):
+ self._stubbed_client = stubbed_client
+
+ def __call__(self, *args, **kwargs):
+ # The __call__ is defined so we can provide an instance of the
+ # StubbedClientFactory to mock.patch() and have the instance be
+ # returned when the patched class is instantiated.
+ return self
+
+ def create_client(self):
+ return self._stubbed_client
+
+
+class TestProcessPoolDownloader(unittest.TestCase):
+ def setUp(self):
+ # The stubbed client needs to run in a manager to be shared across
+ # processes and have it properly consume the stubbed response across
+ # processes.
+ self.manager = StubbedClientManager()
+ self.manager.start()
+ self.stubbed_client = self.manager.StubbedClient()
+ self.stubbed_client_factory = StubbedClientFactory(self.stubbed_client)
+
+ self.client_factory_patch = mock.patch(
+ 's3transfer.processpool.ClientFactory', self.stubbed_client_factory
+ )
+ self.client_factory_patch.start()
+ self.files = FileCreator()
+
+ self.config = ProcessTransferConfig(max_request_processes=1)
+ self.downloader = ProcessPoolDownloader(config=self.config)
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.filename = self.files.full_path('filename')
+ self.remote_contents = b'my content'
+ self.stream = BytesIO(self.remote_contents)
+
+ def tearDown(self):
+ self.manager.shutdown()
+ self.client_factory_patch.stop()
+ self.files.remove_all()
+
+ def assert_contents(self, filename, expected_contents):
+ self.assertTrue(os.path.exists(filename))
+ with open(filename, 'rb') as f:
+ self.assertEqual(f.read(), expected_contents)
+
+ def test_download_file(self):
+ self.stubbed_client.add_response(
+ 'head_object', {'ContentLength': len(self.remote_contents)}
+ )
+ self.stubbed_client.add_response('get_object', {'Body': self.stream})
+ with self.downloader:
+ self.downloader.download_file(self.bucket, self.key, self.filename)
+ self.assert_contents(self.filename, self.remote_contents)
+
+ def test_download_multiple_files(self):
+ self.stubbed_client.add_response('get_object', {'Body': self.stream})
+ self.stubbed_client.add_response(
+ 'get_object', {'Body': BytesIO(self.remote_contents)}
+ )
+ with self.downloader:
+ self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ expected_size=len(self.remote_contents),
+ )
+ other_file = self.files.full_path('filename2')
+ self.downloader.download_file(
+ self.bucket,
+ self.key,
+ other_file,
+ expected_size=len(self.remote_contents),
+ )
+ self.assert_contents(self.filename, self.remote_contents)
+ self.assert_contents(other_file, self.remote_contents)
+
+ def test_download_file_ranged_download(self):
+ half_of_content_length = int(len(self.remote_contents) / 2)
+ self.stubbed_client.add_response(
+ 'head_object', {'ContentLength': len(self.remote_contents)}
+ )
+ self.stubbed_client.add_response(
+ 'get_object',
+ {'Body': BytesIO(self.remote_contents[:half_of_content_length])},
+ )
+ self.stubbed_client.add_response(
+ 'get_object',
+ {'Body': BytesIO(self.remote_contents[half_of_content_length:])},
+ )
+ downloader = ProcessPoolDownloader(
+ config=ProcessTransferConfig(
+ multipart_chunksize=half_of_content_length,
+ multipart_threshold=half_of_content_length,
+ max_request_processes=1,
+ )
+ )
+ with downloader:
+ downloader.download_file(self.bucket, self.key, self.filename)
+ self.assert_contents(self.filename, self.remote_contents)
+
+ def test_download_file_extra_args(self):
+ self.stubbed_client.add_response(
+ 'head_object',
+ {'ContentLength': len(self.remote_contents)},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'VersionId': 'versionid',
+ },
+ )
+ self.stubbed_client.add_response(
+ 'get_object',
+ {'Body': self.stream},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'VersionId': 'versionid',
+ },
+ )
+ with self.downloader:
+ self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ extra_args={'VersionId': 'versionid'},
+ )
+ self.assert_contents(self.filename, self.remote_contents)
+
+ def test_download_file_expected_size(self):
+ self.stubbed_client.add_response('get_object', {'Body': self.stream})
+ with self.downloader:
+ self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ expected_size=len(self.remote_contents),
+ )
+ self.assert_contents(self.filename, self.remote_contents)
+
+ def test_cleans_up_tempfile_on_failure(self):
+ self.stubbed_client.add_client_error('get_object', 'NoSuchKey')
+ with self.downloader:
+ self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ expected_size=len(self.remote_contents),
+ )
+ self.assertFalse(os.path.exists(self.filename))
+ # Any tempfile should have been erased as well
+ possible_matches = glob.glob('%s*' % self.filename + os.extsep)
+ self.assertEqual(possible_matches, [])
+
+ def test_validates_extra_args(self):
+ with self.downloader:
+ with self.assertRaises(ValueError):
+ self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ extra_args={'NotSupported': 'NotSupported'},
+ )
+
+ def test_result_with_success(self):
+ self.stubbed_client.add_response('get_object', {'Body': self.stream})
+ with self.downloader:
+ future = self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ expected_size=len(self.remote_contents),
+ )
+ self.assertIsNone(future.result())
+
+ def test_result_with_exception(self):
+ self.stubbed_client.add_client_error('get_object', 'NoSuchKey')
+ with self.downloader:
+ future = self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ expected_size=len(self.remote_contents),
+ )
+ with self.assertRaises(botocore.exceptions.ClientError):
+ future.result()
+
+ def test_result_with_cancel(self):
+ self.stubbed_client.add_response('get_object', {'Body': self.stream})
+ with self.downloader:
+ future = self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ expected_size=len(self.remote_contents),
+ )
+ future.cancel()
+ with self.assertRaises(CancelledError):
+ future.result()
+
+ def test_shutdown_with_no_downloads(self):
+ downloader = ProcessPoolDownloader()
+ try:
+ downloader.shutdown()
+ except AttributeError:
+ self.fail(
+ 'The downloader should be able to be shutdown even though '
+ 'the downloader was never started.'
+ )
+
+ def test_shutdown_with_no_downloads_and_ctrl_c(self):
+ # Special shutdown logic happens if a KeyboardInterrupt is raised in
+ # the context manager. However, this logic can not happen if the
+ # downloader was never started. So a KeyboardInterrupt should be
+ # the only exception propagated.
+ with self.assertRaises(KeyboardInterrupt):
+ with self.downloader:
+ raise KeyboardInterrupt()
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_upload.py b/contrib/python/s3transfer/py3/tests/functional/test_upload.py
new file mode 100644
index 0000000000..4f294e85ad
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/functional/test_upload.py
@@ -0,0 +1,538 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import tempfile
+import time
+from io import BytesIO
+
+from botocore.awsrequest import AWSRequest
+from botocore.client import Config
+from botocore.exceptions import ClientError
+from botocore.stub import ANY
+
+from s3transfer.manager import TransferConfig, TransferManager
+from s3transfer.utils import ChunksizeAdjuster
+from __tests__ import (
+ BaseGeneralInterfaceTest,
+ NonSeekableReader,
+ RecordingOSUtils,
+ RecordingSubscriber,
+ mock,
+)
+
+
+class BaseUploadTest(BaseGeneralInterfaceTest):
+ def setUp(self):
+ super().setUp()
+ # TODO: We do not want to use the real MIN_UPLOAD_CHUNKSIZE
+ # when we're adjusting parts.
+ # This is really wasteful and fails CI builds because self.contents
+ # would normally use 10MB+ of memory.
+ # Until there's an API to configure this, we're patching this with
+ # a min size of 1. We can't patch MIN_UPLOAD_CHUNKSIZE directly
+ # because it's already bound to a default value in the
+ # chunksize adjuster. Instead we need to patch out the
+ # chunksize adjuster class.
+ self.adjuster_patch = mock.patch(
+ 's3transfer.upload.ChunksizeAdjuster',
+ lambda: ChunksizeAdjuster(min_size=1),
+ )
+ self.adjuster_patch.start()
+ self.config = TransferConfig(max_request_concurrency=1)
+ self._manager = TransferManager(self.client, self.config)
+
+ # Create a temporary directory with files to read from
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+ self.content = b'my content'
+
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+
+ # Initialize some default arguments
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.subscribers = []
+
+ # A list to keep track of all of the bodies sent over the wire
+ # and their order.
+ self.sent_bodies = []
+ self.client.meta.events.register(
+ 'before-parameter-build.s3.*', self.collect_body
+ )
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+ self.adjuster_patch.stop()
+
+ def collect_body(self, params, model, **kwargs):
+ # A handler to simulate the reading of the body including the
+ # request-created event that signals to simulate the progress
+ # callbacks
+ if 'Body' in params:
+ # TODO: This is not ideal. Need to figure out a better idea of
+ # simulating reading of the request across the wire to trigger
+ # progress callbacks
+ request = AWSRequest(
+ method='PUT',
+ url='https://s3.amazonaws.com',
+ data=params['Body'],
+ )
+ self.client.meta.events.emit(
+ 'request-created.s3.%s' % model.name,
+ request=request,
+ operation_name=model.name,
+ )
+ self.sent_bodies.append(self._stream_body(params['Body']))
+
+ def _stream_body(self, body):
+ read_amt = 8 * 1024
+ data = body.read(read_amt)
+ collected_body = data
+ while data:
+ data = body.read(read_amt)
+ collected_body += data
+ return collected_body
+
+ @property
+ def manager(self):
+ return self._manager
+
+ @property
+ def method(self):
+ return self.manager.upload
+
+ def create_call_kwargs(self):
+ return {
+ 'fileobj': self.filename,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ }
+
+ def create_invalid_extra_args(self):
+ return {'Foo': 'bar'}
+
+ def create_stubbed_responses(self):
+ return [{'method': 'put_object', 'service_response': {}}]
+
+ def create_expected_progress_callback_info(self):
+ return [{'bytes_transferred': 10}]
+
+ def assert_expected_client_calls_were_correct(self):
+ # We assert that expected client calls were made by ensuring that
+ # there are no more pending responses. If there are no more pending
+ # responses, then all stubbed responses were consumed.
+ self.stubber.assert_no_pending_responses()
+
+
+class TestNonMultipartUpload(BaseUploadTest):
+ __test__ = True
+
+ def add_put_object_response_with_default_expected_params(
+ self, extra_expected_params=None
+ ):
+ expected_params = {'Body': ANY, 'Bucket': self.bucket, 'Key': self.key}
+ if extra_expected_params:
+ expected_params.update(extra_expected_params)
+ upload_response = self.create_stubbed_responses()[0]
+ upload_response['expected_params'] = expected_params
+ self.stubber.add_response(**upload_response)
+
+ def assert_put_object_body_was_correct(self):
+ self.assertEqual(self.sent_bodies, [self.content])
+
+ def test_upload(self):
+ self.extra_args['RequestPayer'] = 'requester'
+ self.add_put_object_response_with_default_expected_params(
+ extra_expected_params={'RequestPayer': 'requester'}
+ )
+ future = self.manager.upload(
+ self.filename, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assert_put_object_body_was_correct()
+
+ def test_upload_for_fileobj(self):
+ self.add_put_object_response_with_default_expected_params()
+ with open(self.filename, 'rb') as f:
+ future = self.manager.upload(
+ f, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assert_put_object_body_was_correct()
+
+ def test_upload_for_seekable_filelike_obj(self):
+ self.add_put_object_response_with_default_expected_params()
+ bytes_io = BytesIO(self.content)
+ future = self.manager.upload(
+ bytes_io, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assert_put_object_body_was_correct()
+
+ def test_upload_for_seekable_filelike_obj_that_has_been_seeked(self):
+ self.add_put_object_response_with_default_expected_params()
+ bytes_io = BytesIO(self.content)
+ seek_pos = 5
+ bytes_io.seek(seek_pos)
+ future = self.manager.upload(
+ bytes_io, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assertEqual(b''.join(self.sent_bodies), self.content[seek_pos:])
+
+ def test_upload_for_non_seekable_filelike_obj(self):
+ self.add_put_object_response_with_default_expected_params()
+ body = NonSeekableReader(self.content)
+ future = self.manager.upload(
+ body, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assert_put_object_body_was_correct()
+
+ def test_sigv4_progress_callbacks_invoked_once(self):
+ # Reset the client and manager to use sigv4
+ self.reset_stubber_with_new_client(
+ {'config': Config(signature_version='s3v4')}
+ )
+ self.client.meta.events.register(
+ 'before-parameter-build.s3.*', self.collect_body
+ )
+ self._manager = TransferManager(self.client, self.config)
+
+ # Add the stubbed response.
+ self.add_put_object_response_with_default_expected_params()
+
+ subscriber = RecordingSubscriber()
+ future = self.manager.upload(
+ self.filename, self.bucket, self.key, subscribers=[subscriber]
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+
+ # The amount of bytes seen should be the same as the file size
+ self.assertEqual(subscriber.calculate_bytes_seen(), len(self.content))
+
+ def test_uses_provided_osutil(self):
+ osutil = RecordingOSUtils()
+ # Use the recording os utility for the transfer manager
+ self._manager = TransferManager(self.client, self.config, osutil)
+
+ self.add_put_object_response_with_default_expected_params()
+
+ future = self.manager.upload(self.filename, self.bucket, self.key)
+ future.result()
+
+ # The upload should have used the os utility. We check this by making
+ # sure that the recorded opens are as expected.
+ expected_opens = [(self.filename, 'rb')]
+ self.assertEqual(osutil.open_records, expected_opens)
+
+ def test_allowed_upload_params_are_valid(self):
+ op_model = self.client.meta.service_model.operation_model('PutObject')
+ for allowed_upload_arg in self._manager.ALLOWED_UPLOAD_ARGS:
+ self.assertIn(allowed_upload_arg, op_model.input_shape.members)
+
+ def test_upload_with_bandwidth_limiter(self):
+ self.content = b'a' * 1024 * 1024
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+ self.config = TransferConfig(
+ max_request_concurrency=1, max_bandwidth=len(self.content) / 2
+ )
+ self._manager = TransferManager(self.client, self.config)
+
+ self.add_put_object_response_with_default_expected_params()
+ start = time.time()
+ future = self.manager.upload(self.filename, self.bucket, self.key)
+ future.result()
+ # This is just a smoke test to make sure that the limiter is
+ # being used and not necessary its exactness. So we set the maximum
+ # bandwidth to len(content)/2 per sec and make sure that it is
+ # noticeably slower. Ideally it will take more than two seconds, but
+ # given tracking at the beginning of transfers are not entirely
+ # accurate setting at the initial start of a transfer, we give us
+ # some flexibility by setting the expected time to half of the
+ # theoretical time to take.
+ self.assertGreaterEqual(time.time() - start, 1)
+
+ self.assert_expected_client_calls_were_correct()
+ self.assert_put_object_body_was_correct()
+
+ def test_raise_exception_on_s3_object_lambda_resource(self):
+ s3_object_lambda_arn = (
+ 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
+ 'accesspoint:my-accesspoint'
+ )
+ with self.assertRaisesRegex(ValueError, 'methods do not support'):
+ self.manager.upload(self.filename, s3_object_lambda_arn, self.key)
+
+
+class TestMultipartUpload(BaseUploadTest):
+ __test__ = True
+
+ def setUp(self):
+ super().setUp()
+ self.chunksize = 4
+ self.config = TransferConfig(
+ max_request_concurrency=1,
+ multipart_threshold=1,
+ multipart_chunksize=self.chunksize,
+ )
+ self._manager = TransferManager(self.client, self.config)
+ self.multipart_id = 'my-upload-id'
+
+ def create_stubbed_responses(self):
+ return [
+ {
+ 'method': 'create_multipart_upload',
+ 'service_response': {'UploadId': self.multipart_id},
+ },
+ {'method': 'upload_part', 'service_response': {'ETag': 'etag-1'}},
+ {'method': 'upload_part', 'service_response': {'ETag': 'etag-2'}},
+ {'method': 'upload_part', 'service_response': {'ETag': 'etag-3'}},
+ {'method': 'complete_multipart_upload', 'service_response': {}},
+ ]
+
+ def create_expected_progress_callback_info(self):
+ return [
+ {'bytes_transferred': 4},
+ {'bytes_transferred': 4},
+ {'bytes_transferred': 2},
+ ]
+
+ def assert_upload_part_bodies_were_correct(self):
+ expected_contents = []
+ for i in range(0, len(self.content), self.chunksize):
+ end_i = i + self.chunksize
+ if end_i > len(self.content):
+ expected_contents.append(self.content[i:])
+ else:
+ expected_contents.append(self.content[i:end_i])
+ self.assertEqual(self.sent_bodies, expected_contents)
+
+ def add_create_multipart_response_with_default_expected_params(
+ self, extra_expected_params=None
+ ):
+ expected_params = {'Bucket': self.bucket, 'Key': self.key}
+ if extra_expected_params:
+ expected_params.update(extra_expected_params)
+ response = self.create_stubbed_responses()[0]
+ response['expected_params'] = expected_params
+ self.stubber.add_response(**response)
+
+ def add_upload_part_responses_with_default_expected_params(
+ self, extra_expected_params=None
+ ):
+ num_parts = 3
+ upload_part_responses = self.create_stubbed_responses()[1:-1]
+ for i in range(num_parts):
+ upload_part_response = upload_part_responses[i]
+ expected_params = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': self.multipart_id,
+ 'Body': ANY,
+ 'PartNumber': i + 1,
+ }
+ if extra_expected_params:
+ expected_params.update(extra_expected_params)
+ upload_part_response['expected_params'] = expected_params
+ self.stubber.add_response(**upload_part_response)
+
+ def add_complete_multipart_response_with_default_expected_params(
+ self, extra_expected_params=None
+ ):
+ expected_params = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': self.multipart_id,
+ 'MultipartUpload': {
+ 'Parts': [
+ {'ETag': 'etag-1', 'PartNumber': 1},
+ {'ETag': 'etag-2', 'PartNumber': 2},
+ {'ETag': 'etag-3', 'PartNumber': 3},
+ ]
+ },
+ }
+ if extra_expected_params:
+ expected_params.update(extra_expected_params)
+ response = self.create_stubbed_responses()[-1]
+ response['expected_params'] = expected_params
+ self.stubber.add_response(**response)
+
+ def test_upload(self):
+ self.extra_args['RequestPayer'] = 'requester'
+
+ # Add requester pays to the create multipart upload and upload parts.
+ self.add_create_multipart_response_with_default_expected_params(
+ extra_expected_params={'RequestPayer': 'requester'}
+ )
+ self.add_upload_part_responses_with_default_expected_params(
+ extra_expected_params={'RequestPayer': 'requester'}
+ )
+ self.add_complete_multipart_response_with_default_expected_params(
+ extra_expected_params={'RequestPayer': 'requester'}
+ )
+
+ future = self.manager.upload(
+ self.filename, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+
+ def test_upload_for_fileobj(self):
+ self.add_create_multipart_response_with_default_expected_params()
+ self.add_upload_part_responses_with_default_expected_params()
+ self.add_complete_multipart_response_with_default_expected_params()
+ with open(self.filename, 'rb') as f:
+ future = self.manager.upload(
+ f, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assert_upload_part_bodies_were_correct()
+
+ def test_upload_for_seekable_filelike_obj(self):
+ self.add_create_multipart_response_with_default_expected_params()
+ self.add_upload_part_responses_with_default_expected_params()
+ self.add_complete_multipart_response_with_default_expected_params()
+ bytes_io = BytesIO(self.content)
+ future = self.manager.upload(
+ bytes_io, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assert_upload_part_bodies_were_correct()
+
+ def test_upload_for_seekable_filelike_obj_that_has_been_seeked(self):
+ self.add_create_multipart_response_with_default_expected_params()
+ self.add_upload_part_responses_with_default_expected_params()
+ self.add_complete_multipart_response_with_default_expected_params()
+ bytes_io = BytesIO(self.content)
+ seek_pos = 1
+ bytes_io.seek(seek_pos)
+ future = self.manager.upload(
+ bytes_io, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assertEqual(b''.join(self.sent_bodies), self.content[seek_pos:])
+
+ def test_upload_for_non_seekable_filelike_obj(self):
+ self.add_create_multipart_response_with_default_expected_params()
+ self.add_upload_part_responses_with_default_expected_params()
+ self.add_complete_multipart_response_with_default_expected_params()
+ stream = NonSeekableReader(self.content)
+ future = self.manager.upload(
+ stream, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assert_upload_part_bodies_were_correct()
+
+ def test_limits_in_memory_chunks_for_fileobj(self):
+ # Limit the maximum in memory chunks to one but make number of
+ # threads more than one. This means that the upload will have to
+ # happen sequentially despite having many threads available because
+ # data is sequentially partitioned into chunks in memory and since
+ # there can only every be one in memory chunk, each upload part will
+ # have to happen one at a time.
+ self.config.max_request_concurrency = 10
+ self.config.max_in_memory_upload_chunks = 1
+ self._manager = TransferManager(self.client, self.config)
+
+ # Add some default stubbed responses.
+ # These responses are added in order of part number so if the
+ # multipart upload is not done sequentially, which it should because
+ # we limit the in memory upload chunks to one, the stubber will
+ # raise exceptions for mismatching parameters for partNumber when
+ # once the upload() method is called on the transfer manager.
+ # If there is a mismatch, the stubber error will propagate on
+ # the future.result()
+ self.add_create_multipart_response_with_default_expected_params()
+ self.add_upload_part_responses_with_default_expected_params()
+ self.add_complete_multipart_response_with_default_expected_params()
+ with open(self.filename, 'rb') as f:
+ future = self.manager.upload(
+ f, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+
+ # Make sure that the stubber had all of its stubbed responses consumed.
+ self.assert_expected_client_calls_were_correct()
+ # Ensure the contents were uploaded in sequentially order by checking
+ # the sent contents were in order.
+ self.assert_upload_part_bodies_were_correct()
+
+ def test_upload_failure_invokes_abort(self):
+ self.stubber.add_response(
+ method='create_multipart_upload',
+ service_response={'UploadId': self.multipart_id},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ self.stubber.add_response(
+ method='upload_part',
+ service_response={'ETag': 'etag-1'},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Body': ANY,
+ 'Key': self.key,
+ 'UploadId': self.multipart_id,
+ 'PartNumber': 1,
+ },
+ )
+ # With the upload part failing this should immediately initiate
+ # an abort multipart with no more upload parts called.
+ self.stubber.add_client_error(method='upload_part')
+
+ self.stubber.add_response(
+ method='abort_multipart_upload',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': self.multipart_id,
+ },
+ )
+
+ future = self.manager.upload(self.filename, self.bucket, self.key)
+ # The exception should get propagated to the future and not be
+ # a cancelled error or something.
+ with self.assertRaises(ClientError):
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+
+ def test_upload_passes_select_extra_args(self):
+ self.extra_args['Metadata'] = {'foo': 'bar'}
+
+ # Add metadata to expected create multipart upload call
+ self.add_create_multipart_response_with_default_expected_params(
+ extra_expected_params={'Metadata': {'foo': 'bar'}}
+ )
+ self.add_upload_part_responses_with_default_expected_params()
+ self.add_complete_multipart_response_with_default_expected_params()
+
+ future = self.manager.upload(
+ self.filename, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_utils.py b/contrib/python/s3transfer/py3/tests/functional/test_utils.py
new file mode 100644
index 0000000000..fd4a232ecc
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/functional/test_utils.py
@@ -0,0 +1,41 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import socket
+import tempfile
+
+from s3transfer.utils import OSUtils
+from __tests__ import skip_if_windows, unittest
+
+
+@skip_if_windows('Windows does not support UNIX special files')
+class TestOSUtilsSpecialFiles(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def test_character_device(self):
+ self.assertTrue(OSUtils().is_special_file('/dev/null'))
+
+ def test_fifo(self):
+ os.mkfifo(self.filename)
+ self.assertTrue(OSUtils().is_special_file(self.filename))
+
+ def test_socket(self):
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ sock.bind(self.filename)
+ self.assertTrue(OSUtils().is_special_file(self.filename))
diff --git a/contrib/python/s3transfer/py3/tests/unit/__init__.py b/contrib/python/s3transfer/py3/tests/unit/__init__.py
new file mode 100644
index 0000000000..79ef91c6a2
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/__init__.py
@@ -0,0 +1,12 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py b/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py
new file mode 100644
index 0000000000..b796f8f24c
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py
@@ -0,0 +1,452 @@
+# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import tempfile
+
+from s3transfer.bandwidth import (
+ BandwidthLimitedStream,
+ BandwidthLimiter,
+ BandwidthRateTracker,
+ ConsumptionScheduler,
+ LeakyBucket,
+ RequestExceededException,
+ RequestToken,
+ TimeUtils,
+)
+from s3transfer.futures import TransferCoordinator
+from __tests__ import mock, unittest
+
+
+class FixedIncrementalTickTimeUtils(TimeUtils):
+ def __init__(self, seconds_per_tick=1.0):
+ self._count = 0
+ self._seconds_per_tick = seconds_per_tick
+
+ def time(self):
+ current_count = self._count
+ self._count += self._seconds_per_tick
+ return current_count
+
+
+class TestTimeUtils(unittest.TestCase):
+ @mock.patch('time.time')
+ def test_time(self, mock_time):
+ mock_return_val = 1
+ mock_time.return_value = mock_return_val
+ time_utils = TimeUtils()
+ self.assertEqual(time_utils.time(), mock_return_val)
+
+ @mock.patch('time.sleep')
+ def test_sleep(self, mock_sleep):
+ time_utils = TimeUtils()
+ time_utils.sleep(1)
+ self.assertEqual(mock_sleep.call_args_list, [mock.call(1)])
+
+
+class BaseBandwidthLimitTest(unittest.TestCase):
+ def setUp(self):
+ self.leaky_bucket = mock.Mock(LeakyBucket)
+ self.time_utils = mock.Mock(TimeUtils)
+ self.tempdir = tempfile.mkdtemp()
+ self.content = b'a' * 1024 * 1024
+ self.filename = os.path.join(self.tempdir, 'myfile')
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+ self.coordinator = TransferCoordinator()
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def assert_consume_calls(self, amts):
+ expected_consume_args = [mock.call(amt, mock.ANY) for amt in amts]
+ self.assertEqual(
+ self.leaky_bucket.consume.call_args_list, expected_consume_args
+ )
+
+
+class TestBandwidthLimiter(BaseBandwidthLimitTest):
+ def setUp(self):
+ super().setUp()
+ self.bandwidth_limiter = BandwidthLimiter(self.leaky_bucket)
+
+ def test_get_bandwidth_limited_stream(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.bandwidth_limiter.get_bandwith_limited_stream(
+ f, self.coordinator
+ )
+ self.assertIsInstance(stream, BandwidthLimitedStream)
+ self.assertEqual(stream.read(len(self.content)), self.content)
+ self.assert_consume_calls(amts=[len(self.content)])
+
+ def test_get_disabled_bandwidth_limited_stream(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.bandwidth_limiter.get_bandwith_limited_stream(
+ f, self.coordinator, enabled=False
+ )
+ self.assertIsInstance(stream, BandwidthLimitedStream)
+ self.assertEqual(stream.read(len(self.content)), self.content)
+ self.leaky_bucket.consume.assert_not_called()
+
+
+class TestBandwidthLimitedStream(BaseBandwidthLimitTest):
+ def setUp(self):
+ super().setUp()
+ self.bytes_threshold = 1
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def get_bandwidth_limited_stream(self, f):
+ return BandwidthLimitedStream(
+ f,
+ self.leaky_bucket,
+ self.coordinator,
+ self.time_utils,
+ self.bytes_threshold,
+ )
+
+ def assert_sleep_calls(self, amts):
+ expected_sleep_args_list = [mock.call(amt) for amt in amts]
+ self.assertEqual(
+ self.time_utils.sleep.call_args_list, expected_sleep_args_list
+ )
+
+ def get_unique_consume_request_tokens(self):
+ return {
+ call_args[0][1]
+ for call_args in self.leaky_bucket.consume.call_args_list
+ }
+
+ def test_read(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ data = stream.read(len(self.content))
+ self.assertEqual(self.content, data)
+ self.assert_consume_calls(amts=[len(self.content)])
+ self.assert_sleep_calls(amts=[])
+
+ def test_retries_on_request_exceeded(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ retry_time = 1
+ amt_requested = len(self.content)
+ self.leaky_bucket.consume.side_effect = [
+ RequestExceededException(amt_requested, retry_time),
+ len(self.content),
+ ]
+ data = stream.read(len(self.content))
+ self.assertEqual(self.content, data)
+ self.assert_consume_calls(amts=[amt_requested, amt_requested])
+ self.assert_sleep_calls(amts=[retry_time])
+
+ def test_with_transfer_coordinator_exception(self):
+ self.coordinator.set_exception(ValueError())
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ with self.assertRaises(ValueError):
+ stream.read(len(self.content))
+
+ def test_read_when_bandwidth_limiting_disabled(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.disable_bandwidth_limiting()
+ data = stream.read(len(self.content))
+ self.assertEqual(self.content, data)
+ self.assertFalse(self.leaky_bucket.consume.called)
+
+ def test_read_toggle_disable_enable_bandwidth_limiting(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.disable_bandwidth_limiting()
+ data = stream.read(1)
+ self.assertEqual(self.content[:1], data)
+ self.assert_consume_calls(amts=[])
+ stream.enable_bandwidth_limiting()
+ data = stream.read(len(self.content) - 1)
+ self.assertEqual(self.content[1:], data)
+ self.assert_consume_calls(amts=[len(self.content) - 1])
+
+ def test_seek(self):
+ mock_fileobj = mock.Mock()
+ stream = self.get_bandwidth_limited_stream(mock_fileobj)
+ stream.seek(1)
+ self.assertEqual(mock_fileobj.seek.call_args_list, [mock.call(1, 0)])
+
+ def test_tell(self):
+ mock_fileobj = mock.Mock()
+ stream = self.get_bandwidth_limited_stream(mock_fileobj)
+ stream.tell()
+ self.assertEqual(mock_fileobj.tell.call_args_list, [mock.call()])
+
+ def test_close(self):
+ mock_fileobj = mock.Mock()
+ stream = self.get_bandwidth_limited_stream(mock_fileobj)
+ stream.close()
+ self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()])
+
+ def test_context_manager(self):
+ mock_fileobj = mock.Mock()
+ stream = self.get_bandwidth_limited_stream(mock_fileobj)
+ with stream as stream_handle:
+ self.assertIs(stream_handle, stream)
+ self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()])
+
+ def test_reuses_request_token(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.read(1)
+ stream.read(1)
+ self.assertEqual(len(self.get_unique_consume_request_tokens()), 1)
+
+ def test_request_tokens_unique_per_stream(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.read(1)
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.read(1)
+ self.assertEqual(len(self.get_unique_consume_request_tokens()), 2)
+
+ def test_call_consume_after_reaching_threshold(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(1), self.content[:1])
+ self.assert_consume_calls(amts=[])
+ self.assertEqual(stream.read(1), self.content[1:2])
+ self.assert_consume_calls(amts=[2])
+
+ def test_resets_after_reaching_threshold(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(2), self.content[:2])
+ self.assert_consume_calls(amts=[2])
+ self.assertEqual(stream.read(1), self.content[2:3])
+ self.assert_consume_calls(amts=[2])
+
+ def test_pending_bytes_seen_on_close(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(1), self.content[:1])
+ self.assert_consume_calls(amts=[])
+ stream.close()
+ self.assert_consume_calls(amts=[1])
+
+ def test_no_bytes_remaining_on(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(2), self.content[:2])
+ self.assert_consume_calls(amts=[2])
+ stream.close()
+ # There should have been no more consume() calls made
+ # as all bytes have been accounted for in the previous
+ # consume() call.
+ self.assert_consume_calls(amts=[2])
+
+ def test_disable_bandwidth_limiting_with_pending_bytes_seen_on_close(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(1), self.content[:1])
+ self.assert_consume_calls(amts=[])
+ stream.disable_bandwidth_limiting()
+ stream.close()
+ self.assert_consume_calls(amts=[])
+
+ def test_signal_transferring(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.signal_not_transferring()
+ data = stream.read(1)
+ self.assertEqual(self.content[:1], data)
+ self.assert_consume_calls(amts=[])
+ stream.signal_transferring()
+ data = stream.read(len(self.content) - 1)
+ self.assertEqual(self.content[1:], data)
+ self.assert_consume_calls(amts=[len(self.content) - 1])
+
+
+class TestLeakyBucket(unittest.TestCase):
+ def setUp(self):
+ self.max_rate = 1
+ self.time_now = 1.0
+ self.time_utils = mock.Mock(TimeUtils)
+ self.time_utils.time.return_value = self.time_now
+ self.scheduler = mock.Mock(ConsumptionScheduler)
+ self.scheduler.is_scheduled.return_value = False
+ self.rate_tracker = mock.Mock(BandwidthRateTracker)
+ self.leaky_bucket = LeakyBucket(
+ self.max_rate, self.time_utils, self.rate_tracker, self.scheduler
+ )
+
+ def set_projected_rate(self, rate):
+ self.rate_tracker.get_projected_rate.return_value = rate
+
+ def set_retry_time(self, retry_time):
+ self.scheduler.schedule_consumption.return_value = retry_time
+
+ def assert_recorded_consumed_amt(self, expected_amt):
+ self.assertEqual(
+ self.rate_tracker.record_consumption_rate.call_args,
+ mock.call(expected_amt, self.time_utils.time.return_value),
+ )
+
+ def assert_was_scheduled(self, amt, token):
+ self.assertEqual(
+ self.scheduler.schedule_consumption.call_args,
+ mock.call(amt, token, amt / (self.max_rate)),
+ )
+
+ def assert_nothing_scheduled(self):
+ self.assertFalse(self.scheduler.schedule_consumption.called)
+
+ def assert_processed_request_token(self, request_token):
+ self.assertEqual(
+ self.scheduler.process_scheduled_consumption.call_args,
+ mock.call(request_token),
+ )
+
+ def test_consume_under_max_rate(self):
+ amt = 1
+ self.set_projected_rate(self.max_rate / 2)
+ self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
+ self.assert_recorded_consumed_amt(amt)
+ self.assert_nothing_scheduled()
+
+ def test_consume_at_max_rate(self):
+ amt = 1
+ self.set_projected_rate(self.max_rate)
+ self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
+ self.assert_recorded_consumed_amt(amt)
+ self.assert_nothing_scheduled()
+
+ def test_consume_over_max_rate(self):
+ amt = 1
+ retry_time = 2.0
+ self.set_projected_rate(self.max_rate + 1)
+ self.set_retry_time(retry_time)
+ request_token = RequestToken()
+ try:
+ self.leaky_bucket.consume(amt, request_token)
+ self.fail('A RequestExceededException should have been thrown')
+ except RequestExceededException as e:
+ self.assertEqual(e.requested_amt, amt)
+ self.assertEqual(e.retry_time, retry_time)
+ self.assert_was_scheduled(amt, request_token)
+
+ def test_consume_with_scheduled_retry(self):
+ amt = 1
+ self.set_projected_rate(self.max_rate + 1)
+ self.scheduler.is_scheduled.return_value = True
+ request_token = RequestToken()
+ self.assertEqual(self.leaky_bucket.consume(amt, request_token), amt)
+ # Nothing new should have been scheduled but the request token
+ # should have been processed.
+ self.assert_nothing_scheduled()
+ self.assert_processed_request_token(request_token)
+
+
+class TestConsumptionScheduler(unittest.TestCase):
+ def setUp(self):
+ self.scheduler = ConsumptionScheduler()
+
+ def test_schedule_consumption(self):
+ token = RequestToken()
+ consume_time = 5
+ actual_wait_time = self.scheduler.schedule_consumption(
+ 1, token, consume_time
+ )
+ self.assertEqual(consume_time, actual_wait_time)
+
+ def test_schedule_consumption_for_multiple_requests(self):
+ token = RequestToken()
+ consume_time = 5
+ actual_wait_time = self.scheduler.schedule_consumption(
+ 1, token, consume_time
+ )
+ self.assertEqual(consume_time, actual_wait_time)
+
+ other_consume_time = 3
+ other_token = RequestToken()
+ next_wait_time = self.scheduler.schedule_consumption(
+ 1, other_token, other_consume_time
+ )
+
+ # This wait time should be the previous time plus its desired
+ # wait time
+ self.assertEqual(next_wait_time, consume_time + other_consume_time)
+
+ def test_is_scheduled(self):
+ token = RequestToken()
+ consume_time = 5
+ self.scheduler.schedule_consumption(1, token, consume_time)
+ self.assertTrue(self.scheduler.is_scheduled(token))
+
+ def test_is_not_scheduled(self):
+ self.assertFalse(self.scheduler.is_scheduled(RequestToken()))
+
+ def test_process_scheduled_consumption(self):
+ token = RequestToken()
+ consume_time = 5
+ self.scheduler.schedule_consumption(1, token, consume_time)
+ self.scheduler.process_scheduled_consumption(token)
+ self.assertFalse(self.scheduler.is_scheduled(token))
+ different_time = 7
+ # The previous consume time should have no affect on the next wait tim
+ # as it has been completed.
+ self.assertEqual(
+ self.scheduler.schedule_consumption(1, token, different_time),
+ different_time,
+ )
+
+
+class TestBandwidthRateTracker(unittest.TestCase):
+ def setUp(self):
+ self.alpha = 0.8
+ self.rate_tracker = BandwidthRateTracker(self.alpha)
+
+ def test_current_rate_at_initilizations(self):
+ self.assertEqual(self.rate_tracker.current_rate, 0.0)
+
+ def test_current_rate_after_one_recorded_point(self):
+ self.rate_tracker.record_consumption_rate(1, 1)
+ # There is no last time point to do a diff against so return a
+ # current rate of 0.0
+ self.assertEqual(self.rate_tracker.current_rate, 0.0)
+
+ def test_current_rate(self):
+ self.rate_tracker.record_consumption_rate(1, 1)
+ self.rate_tracker.record_consumption_rate(1, 2)
+ self.rate_tracker.record_consumption_rate(1, 3)
+ self.assertEqual(self.rate_tracker.current_rate, 0.96)
+
+ def test_get_projected_rate_at_initilizations(self):
+ self.assertEqual(self.rate_tracker.get_projected_rate(1, 1), 0.0)
+
+ def test_get_projected_rate(self):
+ self.rate_tracker.record_consumption_rate(1, 1)
+ self.rate_tracker.record_consumption_rate(1, 2)
+ projected_rate = self.rate_tracker.get_projected_rate(1, 3)
+ self.assertEqual(projected_rate, 0.96)
+ self.rate_tracker.record_consumption_rate(1, 3)
+ self.assertEqual(self.rate_tracker.current_rate, projected_rate)
+
+ def test_get_projected_rate_for_same_timestamp(self):
+ self.rate_tracker.record_consumption_rate(1, 1)
+ self.assertEqual(
+ self.rate_tracker.get_projected_rate(1, 1), float('inf')
+ )
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_compat.py b/contrib/python/s3transfer/py3/tests/unit/test_compat.py
new file mode 100644
index 0000000000..78fdc25845
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_compat.py
@@ -0,0 +1,105 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import signal
+import tempfile
+from io import BytesIO
+
+from s3transfer.compat import BaseManager, readable, seekable
+from __tests__ import skip_if_windows, unittest
+
+
+class ErrorRaisingSeekWrapper:
+ """An object wrapper that throws an error when seeked on
+
+ :param fileobj: The fileobj that it wraps
+ :param exception: The exception to raise when seeked on.
+ """
+
+ def __init__(self, fileobj, exception):
+ self._fileobj = fileobj
+ self._exception = exception
+
+ def seek(self, offset, whence=0):
+ raise self._exception
+
+ def tell(self):
+ return self._fileobj.tell()
+
+
+class TestSeekable(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'foo')
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def test_seekable_fileobj(self):
+ with open(self.filename, 'w') as f:
+ self.assertTrue(seekable(f))
+
+ def test_non_file_like_obj(self):
+ # Fails because there is no seekable(), seek(), nor tell()
+ self.assertFalse(seekable(object()))
+
+ def test_non_seekable_ioerror(self):
+ # Should return False if IOError is thrown.
+ with open(self.filename, 'w') as f:
+ self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, IOError())))
+
+ def test_non_seekable_oserror(self):
+ # Should return False if OSError is thrown.
+ with open(self.filename, 'w') as f:
+ self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, OSError())))
+
+
+class TestReadable(unittest.TestCase):
+ def test_readable_fileobj(self):
+ with tempfile.TemporaryFile() as f:
+ self.assertTrue(readable(f))
+
+ def test_readable_file_like_obj(self):
+ self.assertTrue(readable(BytesIO()))
+
+ def test_non_file_like_obj(self):
+ self.assertFalse(readable(object()))
+
+
+class TestBaseManager(unittest.TestCase):
+ def create_pid_manager(self):
+ class PIDManager(BaseManager):
+ pass
+
+ PIDManager.register('getpid', os.getpid)
+ return PIDManager()
+
+ def get_pid(self, pid_manager):
+ pid = pid_manager.getpid()
+ # A proxy object is returned back. The needed value can be acquired
+ # from the repr and converting that to an integer
+ return int(str(pid))
+
+ @skip_if_windows('os.kill() with SIGINT not supported on Windows')
+ def test_can_provide_signal_handler_initializers_to_start(self):
+ manager = self.create_pid_manager()
+ manager.start(signal.signal, (signal.SIGINT, signal.SIG_IGN))
+ pid = self.get_pid(manager)
+ try:
+ os.kill(pid, signal.SIGINT)
+ except KeyboardInterrupt:
+ pass
+ # Try using the manager after the os.kill on the parent process. The
+ # manager should not have died and should still be usable.
+ self.assertEqual(pid, self.get_pid(manager))
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_copies.py b/contrib/python/s3transfer/py3/tests/unit/test_copies.py
new file mode 100644
index 0000000000..3681f69b94
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_copies.py
@@ -0,0 +1,177 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.copies import CopyObjectTask, CopyPartTask
+from __tests__ import BaseTaskTest, RecordingSubscriber
+
+
+class BaseCopyTaskTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.copy_source = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'}
+ self.extra_args = {}
+ self.callbacks = []
+ self.size = 5
+
+
+class TestCopyObjectTask(BaseCopyTaskTest):
+ def get_copy_task(self, **kwargs):
+ default_kwargs = {
+ 'client': self.client,
+ 'copy_source': self.copy_source,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': self.extra_args,
+ 'callbacks': self.callbacks,
+ 'size': self.size,
+ }
+ default_kwargs.update(kwargs)
+ return self.get_task(CopyObjectTask, main_kwargs=default_kwargs)
+
+ def test_main(self):
+ self.stubber.add_response(
+ 'copy_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ },
+ )
+ task = self.get_copy_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+
+ def test_extra_args(self):
+ self.extra_args['ACL'] = 'private'
+ self.stubber.add_response(
+ 'copy_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'ACL': 'private',
+ },
+ )
+ task = self.get_copy_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+
+ def test_callbacks_invoked(self):
+ subscriber = RecordingSubscriber()
+ self.callbacks.append(subscriber.on_progress)
+ self.stubber.add_response(
+ 'copy_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ },
+ )
+ task = self.get_copy_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(subscriber.calculate_bytes_seen(), self.size)
+
+
+class TestCopyPartTask(BaseCopyTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.copy_source_range = 'bytes=5-9'
+ self.extra_args['CopySourceRange'] = self.copy_source_range
+ self.upload_id = 'myuploadid'
+ self.part_number = 1
+ self.result_etag = 'my-etag'
+
+ def get_copy_task(self, **kwargs):
+ default_kwargs = {
+ 'client': self.client,
+ 'copy_source': self.copy_source,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'upload_id': self.upload_id,
+ 'part_number': self.part_number,
+ 'extra_args': self.extra_args,
+ 'callbacks': self.callbacks,
+ 'size': self.size,
+ }
+ default_kwargs.update(kwargs)
+ return self.get_task(CopyPartTask, main_kwargs=default_kwargs)
+
+ def test_main(self):
+ self.stubber.add_response(
+ 'upload_part_copy',
+ service_response={'CopyPartResult': {'ETag': self.result_etag}},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'UploadId': self.upload_id,
+ 'PartNumber': self.part_number,
+ 'CopySourceRange': self.copy_source_range,
+ },
+ )
+ task = self.get_copy_task()
+ self.assertEqual(
+ task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
+ )
+ self.stubber.assert_no_pending_responses()
+
+ def test_extra_args(self):
+ self.extra_args['RequestPayer'] = 'requester'
+ self.stubber.add_response(
+ 'upload_part_copy',
+ service_response={'CopyPartResult': {'ETag': self.result_etag}},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'UploadId': self.upload_id,
+ 'PartNumber': self.part_number,
+ 'CopySourceRange': self.copy_source_range,
+ 'RequestPayer': 'requester',
+ },
+ )
+ task = self.get_copy_task()
+ self.assertEqual(
+ task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
+ )
+ self.stubber.assert_no_pending_responses()
+
+ def test_callbacks_invoked(self):
+ subscriber = RecordingSubscriber()
+ self.callbacks.append(subscriber.on_progress)
+ self.stubber.add_response(
+ 'upload_part_copy',
+ service_response={'CopyPartResult': {'ETag': self.result_etag}},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'UploadId': self.upload_id,
+ 'PartNumber': self.part_number,
+ 'CopySourceRange': self.copy_source_range,
+ },
+ )
+ task = self.get_copy_task()
+ self.assertEqual(
+ task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
+ )
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(subscriber.calculate_bytes_seen(), self.size)
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_crt.py b/contrib/python/s3transfer/py3/tests/unit/test_crt.py
new file mode 100644
index 0000000000..8c32668eab
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_crt.py
@@ -0,0 +1,173 @@
+# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from botocore.credentials import CredentialResolver, ReadOnlyCredentials
+from botocore.session import Session
+
+from s3transfer.exceptions import TransferNotDoneError
+from s3transfer.utils import CallArgs
+from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest
+
+if HAS_CRT:
+ import awscrt.s3
+
+ import s3transfer.crt
+
+
+class CustomFutureException(Exception):
+ pass
+
+
+@requires_crt
+class TestBotocoreCRTRequestSerializer(unittest.TestCase):
+ def setUp(self):
+ self.region = 'us-west-2'
+ self.session = Session()
+ self.session.set_config_variable('region', self.region)
+ self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer(
+ self.session
+ )
+ self.bucket = "test_bucket"
+ self.key = "test_key"
+ self.files = FileCreator()
+ self.filename = self.files.create_file('myfile', 'my content')
+ self.expected_path = "/" + self.bucket + "/" + self.key
+ self.expected_host = "s3.%s.amazonaws.com" % (self.region)
+
+ def tearDown(self):
+ self.files.remove_all()
+
+ def test_upload_request(self):
+ callargs = CallArgs(
+ bucket=self.bucket,
+ key=self.key,
+ fileobj=self.filename,
+ extra_args={},
+ subscribers=[],
+ )
+ coordinator = s3transfer.crt.CRTTransferCoordinator()
+ future = s3transfer.crt.CRTTransferFuture(
+ s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
+ )
+ crt_request = self.request_serializer.serialize_http_request(
+ "put_object", future
+ )
+ self.assertEqual("PUT", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self.assertIsNone(crt_request.headers.get("Authorization"))
+
+ def test_download_request(self):
+ callargs = CallArgs(
+ bucket=self.bucket,
+ key=self.key,
+ fileobj=self.filename,
+ extra_args={},
+ subscribers=[],
+ )
+ coordinator = s3transfer.crt.CRTTransferCoordinator()
+ future = s3transfer.crt.CRTTransferFuture(
+ s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
+ )
+ crt_request = self.request_serializer.serialize_http_request(
+ "get_object", future
+ )
+ self.assertEqual("GET", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self.assertIsNone(crt_request.headers.get("Authorization"))
+
+ def test_delete_request(self):
+ callargs = CallArgs(
+ bucket=self.bucket, key=self.key, extra_args={}, subscribers=[]
+ )
+ coordinator = s3transfer.crt.CRTTransferCoordinator()
+ future = s3transfer.crt.CRTTransferFuture(
+ s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
+ )
+ crt_request = self.request_serializer.serialize_http_request(
+ "delete_object", future
+ )
+ self.assertEqual("DELETE", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self.assertIsNone(crt_request.headers.get("Authorization"))
+
+
+@requires_crt
+class TestCRTCredentialProviderAdapter(unittest.TestCase):
+ def setUp(self):
+ self.botocore_credential_provider = mock.Mock(CredentialResolver)
+ self.access_key = "access_key"
+ self.secret_key = "secret_key"
+ self.token = "token"
+ self.botocore_credential_provider.load_credentials.return_value.get_frozen_credentials.return_value = ReadOnlyCredentials(
+ self.access_key, self.secret_key, self.token
+ )
+
+ def _call_adapter_and_check(self, credentails_provider_adapter):
+ credentials = credentails_provider_adapter()
+ self.assertEqual(credentials.access_key_id, self.access_key)
+ self.assertEqual(credentials.secret_access_key, self.secret_key)
+ self.assertEqual(credentials.session_token, self.token)
+
+ def test_fetch_crt_credentials_successfully(self):
+ credentails_provider_adapter = (
+ s3transfer.crt.CRTCredentialProviderAdapter(
+ self.botocore_credential_provider
+ )
+ )
+ self._call_adapter_and_check(credentails_provider_adapter)
+
+ def test_load_credentials_once(self):
+ credentails_provider_adapter = (
+ s3transfer.crt.CRTCredentialProviderAdapter(
+ self.botocore_credential_provider
+ )
+ )
+ called_times = 5
+ for i in range(called_times):
+ self._call_adapter_and_check(credentails_provider_adapter)
+ # Assert that the load_credentails of botocore credential provider
+ # will only be called once
+ self.assertEqual(
+ self.botocore_credential_provider.load_credentials.call_count, 1
+ )
+
+
+@requires_crt
+class TestCRTTransferFuture(unittest.TestCase):
+ def setUp(self):
+ self.mock_s3_request = mock.Mock(awscrt.s3.S3RequestType)
+ self.mock_crt_future = mock.Mock(awscrt.s3.Future)
+ self.mock_s3_request.finished_future = self.mock_crt_future
+ self.coordinator = s3transfer.crt.CRTTransferCoordinator()
+ self.coordinator.set_s3_request(self.mock_s3_request)
+ self.future = s3transfer.crt.CRTTransferFuture(
+ coordinator=self.coordinator
+ )
+
+ def test_set_exception(self):
+ self.future.set_exception(CustomFutureException())
+ with self.assertRaises(CustomFutureException):
+ self.future.result()
+
+ def test_set_exception_raises_error_when_not_done(self):
+ self.mock_crt_future.done.return_value = False
+ with self.assertRaises(TransferNotDoneError):
+ self.future.set_exception(CustomFutureException())
+
+ def test_set_exception_can_override_previous_exception(self):
+ self.future.set_exception(Exception())
+ self.future.set_exception(CustomFutureException())
+ with self.assertRaises(CustomFutureException):
+ self.future.result()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_delete.py b/contrib/python/s3transfer/py3/tests/unit/test_delete.py
new file mode 100644
index 0000000000..23b77112f2
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_delete.py
@@ -0,0 +1,67 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.delete import DeleteObjectTask
+from __tests__ import BaseTaskTest
+
+
+class TestDeleteObjectTask(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.callbacks = []
+
+ def get_delete_task(self, **kwargs):
+ default_kwargs = {
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': self.extra_args,
+ }
+ default_kwargs.update(kwargs)
+ return self.get_task(DeleteObjectTask, main_kwargs=default_kwargs)
+
+ def test_main(self):
+ self.stubber.add_response(
+ 'delete_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ },
+ )
+ task = self.get_delete_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+
+ def test_extra_args(self):
+ self.extra_args['MFA'] = 'mfa-code'
+ self.extra_args['VersionId'] = '12345'
+ self.stubber.add_response(
+ 'delete_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ # These extra_args should be injected into the
+ # expected params for the delete_object call.
+ 'MFA': 'mfa-code',
+ 'VersionId': '12345',
+ },
+ )
+ task = self.get_delete_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_download.py b/contrib/python/s3transfer/py3/tests/unit/test_download.py
new file mode 100644
index 0000000000..2bd095f867
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_download.py
@@ -0,0 +1,999 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import copy
+import os
+import shutil
+import tempfile
+from io import BytesIO
+
+from s3transfer.bandwidth import BandwidthLimiter
+from s3transfer.compat import SOCKET_ERROR
+from s3transfer.download import (
+ CompleteDownloadNOOPTask,
+ DeferQueue,
+ DownloadChunkIterator,
+ DownloadFilenameOutputManager,
+ DownloadNonSeekableOutputManager,
+ DownloadSeekableOutputManager,
+ DownloadSpecialFilenameOutputManager,
+ DownloadSubmissionTask,
+ GetObjectTask,
+ ImmediatelyWriteIOGetObjectTask,
+ IOCloseTask,
+ IORenameFileTask,
+ IOStreamingWriteTask,
+ IOWriteTask,
+)
+from s3transfer.exceptions import RetriesExceededError
+from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG, BoundedExecutor
+from s3transfer.utils import CallArgs, OSUtils
+from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ FileCreator,
+ NonSeekableWriter,
+ RecordingExecutor,
+ StreamWithError,
+ mock,
+ unittest,
+)
+
+
+class DownloadException(Exception):
+ pass
+
+
+class WriteCollector:
+ """A utility to collect information about writes and seeks"""
+
+ def __init__(self):
+ self._pos = 0
+ self.writes = []
+
+ def seek(self, pos, whence=0):
+ self._pos = pos
+
+ def write(self, data):
+ self.writes.append((self._pos, data))
+ self._pos += len(data)
+
+
+class AlwaysIndicatesSpecialFileOSUtils(OSUtils):
+ """OSUtil that always returns True for is_special_file"""
+
+ def is_special_file(self, filename):
+ return True
+
+
+class CancelledStreamWrapper:
+ """A wrapper to trigger a cancellation while stream reading
+
+ Forces the transfer coordinator to cancel after a certain amount of reads
+ :param stream: The underlying stream to read from
+ :param transfer_coordinator: The coordinator for the transfer
+ :param num_reads: On which read to signal a cancellation. 0 is the first
+ read.
+ """
+
+ def __init__(self, stream, transfer_coordinator, num_reads=0):
+ self._stream = stream
+ self._transfer_coordinator = transfer_coordinator
+ self._num_reads = num_reads
+ self._count = 0
+
+ def read(self, *args, **kwargs):
+ if self._num_reads == self._count:
+ self._transfer_coordinator.cancel()
+ self._stream.read(*args, **kwargs)
+ self._count += 1
+
+
+class BaseDownloadOutputManagerTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.osutil = OSUtils()
+
+ # Create a file to write to
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+
+ self.call_args = CallArgs(fileobj=self.filename)
+ self.future = self.get_transfer_future(self.call_args)
+ self.io_executor = BoundedExecutor(1000, 1)
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+
+
+class TestDownloadFilenameOutputManager(BaseDownloadOutputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.download_output_manager = DownloadFilenameOutputManager(
+ self.osutil,
+ self.transfer_coordinator,
+ io_executor=self.io_executor,
+ )
+
+ def test_is_compatible(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(
+ self.filename, self.osutil
+ )
+ )
+
+ def test_get_download_task_tag(self):
+ self.assertIsNone(self.download_output_manager.get_download_task_tag())
+
+ def test_get_fileobj_for_io_writes(self):
+ with self.download_output_manager.get_fileobj_for_io_writes(
+ self.future
+ ) as f:
+ # Ensure it is a file like object returned
+ self.assertTrue(hasattr(f, 'read'))
+ self.assertTrue(hasattr(f, 'seek'))
+ # Make sure the name of the file returned is not the same as the
+ # final filename as we should be writing to a temporary file.
+ self.assertNotEqual(f.name, self.filename)
+
+ def test_get_final_io_task(self):
+ ref_contents = b'my_contents'
+ with self.download_output_manager.get_fileobj_for_io_writes(
+ self.future
+ ) as f:
+ temp_filename = f.name
+ # Write some data to test that the data gets moved over to the
+ # final location.
+ f.write(ref_contents)
+ final_task = self.download_output_manager.get_final_io_task()
+ # Make sure it is the appropriate task.
+ self.assertIsInstance(final_task, IORenameFileTask)
+ final_task()
+ # Make sure the temp_file gets removed
+ self.assertFalse(os.path.exists(temp_filename))
+ # Make sure what ever was written to the temp file got moved to
+ # the final filename
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(f.read(), ref_contents)
+
+ def test_can_queue_file_io_task(self):
+ fileobj = WriteCollector()
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='foo', offset=0
+ )
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='bar', offset=3
+ )
+ self.io_executor.shutdown()
+ self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+ def test_get_file_io_write_task(self):
+ fileobj = WriteCollector()
+ io_write_task = self.download_output_manager.get_io_write_task(
+ fileobj=fileobj, data='foo', offset=3
+ )
+ self.assertIsInstance(io_write_task, IOWriteTask)
+
+ io_write_task()
+ self.assertEqual(fileobj.writes, [(3, 'foo')])
+
+
+class TestDownloadSpecialFilenameOutputManager(BaseDownloadOutputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.osutil = AlwaysIndicatesSpecialFileOSUtils()
+ self.download_output_manager = DownloadSpecialFilenameOutputManager(
+ self.osutil,
+ self.transfer_coordinator,
+ io_executor=self.io_executor,
+ )
+
+ def test_is_compatible_for_special_file(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(
+ self.filename, AlwaysIndicatesSpecialFileOSUtils()
+ )
+ )
+
+ def test_is_not_compatible_for_non_special_file(self):
+ self.assertFalse(
+ self.download_output_manager.is_compatible(
+ self.filename, OSUtils()
+ )
+ )
+
+ def test_get_fileobj_for_io_writes(self):
+ with self.download_output_manager.get_fileobj_for_io_writes(
+ self.future
+ ) as f:
+ # Ensure it is a file like object returned
+ self.assertTrue(hasattr(f, 'read'))
+ # Make sure the name of the file returned is the same as the
+ # final filename as we should not be writing to a temporary file.
+ self.assertEqual(f.name, self.filename)
+
+ def test_get_final_io_task(self):
+ self.assertIsInstance(
+ self.download_output_manager.get_final_io_task(), IOCloseTask
+ )
+
+ def test_can_queue_file_io_task(self):
+ fileobj = WriteCollector()
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='foo', offset=0
+ )
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='bar', offset=3
+ )
+ self.io_executor.shutdown()
+ self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+
+class TestDownloadSeekableOutputManager(BaseDownloadOutputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.download_output_manager = DownloadSeekableOutputManager(
+ self.osutil,
+ self.transfer_coordinator,
+ io_executor=self.io_executor,
+ )
+
+ # Create a fileobj to write to
+ self.fileobj = open(self.filename, 'wb')
+
+ self.call_args = CallArgs(fileobj=self.fileobj)
+ self.future = self.get_transfer_future(self.call_args)
+
+ def tearDown(self):
+ self.fileobj.close()
+ super().tearDown()
+
+ def test_is_compatible(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(
+ self.fileobj, self.osutil
+ )
+ )
+
+ def test_is_compatible_bytes_io(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(BytesIO(), self.osutil)
+ )
+
+ def test_not_compatible_for_non_filelike_obj(self):
+ self.assertFalse(
+ self.download_output_manager.is_compatible(object(), self.osutil)
+ )
+
+ def test_get_download_task_tag(self):
+ self.assertIsNone(self.download_output_manager.get_download_task_tag())
+
+ def test_get_fileobj_for_io_writes(self):
+ self.assertIs(
+ self.download_output_manager.get_fileobj_for_io_writes(
+ self.future
+ ),
+ self.fileobj,
+ )
+
+ def test_get_final_io_task(self):
+ self.assertIsInstance(
+ self.download_output_manager.get_final_io_task(),
+ CompleteDownloadNOOPTask,
+ )
+
+ def test_can_queue_file_io_task(self):
+ fileobj = WriteCollector()
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='foo', offset=0
+ )
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='bar', offset=3
+ )
+ self.io_executor.shutdown()
+ self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+ def test_get_file_io_write_task(self):
+ fileobj = WriteCollector()
+ io_write_task = self.download_output_manager.get_io_write_task(
+ fileobj=fileobj, data='foo', offset=3
+ )
+ self.assertIsInstance(io_write_task, IOWriteTask)
+
+ io_write_task()
+ self.assertEqual(fileobj.writes, [(3, 'foo')])
+
+
+class TestDownloadNonSeekableOutputManager(BaseDownloadOutputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.download_output_manager = DownloadNonSeekableOutputManager(
+ self.osutil, self.transfer_coordinator, io_executor=None
+ )
+
+ def test_is_compatible_with_seekable_stream(self):
+ with open(self.filename, 'wb') as f:
+ self.assertTrue(
+ self.download_output_manager.is_compatible(f, self.osutil)
+ )
+
+ def test_not_compatible_with_filename(self):
+ self.assertFalse(
+ self.download_output_manager.is_compatible(
+ self.filename, self.osutil
+ )
+ )
+
+ def test_compatible_with_non_seekable_stream(self):
+ class NonSeekable:
+ def write(self, data):
+ pass
+
+ f = NonSeekable()
+ self.assertTrue(
+ self.download_output_manager.is_compatible(f, self.osutil)
+ )
+
+ def test_is_compatible_with_bytesio(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(BytesIO(), self.osutil)
+ )
+
+ def test_get_download_task_tag(self):
+ self.assertIs(
+ self.download_output_manager.get_download_task_tag(),
+ IN_MEMORY_DOWNLOAD_TAG,
+ )
+
+ def test_submit_writes_from_internal_queue(self):
+ class FakeQueue:
+ def request_writes(self, offset, data):
+ return [
+ {'offset': 0, 'data': 'foo'},
+ {'offset': 3, 'data': 'bar'},
+ ]
+
+ q = FakeQueue()
+ io_executor = BoundedExecutor(1000, 1)
+ manager = DownloadNonSeekableOutputManager(
+ self.osutil,
+ self.transfer_coordinator,
+ io_executor=io_executor,
+ defer_queue=q,
+ )
+ fileobj = WriteCollector()
+ manager.queue_file_io_task(fileobj=fileobj, data='foo', offset=1)
+ io_executor.shutdown()
+ self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+ def test_get_file_io_write_task(self):
+ fileobj = WriteCollector()
+ io_write_task = self.download_output_manager.get_io_write_task(
+ fileobj=fileobj, data='foo', offset=1
+ )
+ self.assertIsInstance(io_write_task, IOStreamingWriteTask)
+
+ io_write_task()
+ self.assertEqual(fileobj.writes, [(0, 'foo')])
+
+
+class TestDownloadSubmissionTask(BaseSubmissionTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.subscribers = []
+
+ # Create a stream to read from
+ self.content = b'my content'
+ self.stream = BytesIO(self.content)
+
+ # A list to keep track of all of the bodies sent over the wire
+ # and their order.
+
+ self.call_args = self.get_call_args()
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.io_executor = BoundedExecutor(1000, 1)
+ self.submission_main_kwargs = {
+ 'client': self.client,
+ 'config': self.config,
+ 'osutil': self.osutil,
+ 'request_executor': self.executor,
+ 'io_executor': self.io_executor,
+ 'transfer_future': self.transfer_future,
+ }
+ self.submission_task = self.get_download_submission_task()
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+
+ def get_call_args(self, **kwargs):
+ default_call_args = {
+ 'fileobj': self.filename,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': self.extra_args,
+ 'subscribers': self.subscribers,
+ }
+ default_call_args.update(kwargs)
+ return CallArgs(**default_call_args)
+
+ def wrap_executor_in_recorder(self):
+ self.executor = RecordingExecutor(self.executor)
+ self.submission_main_kwargs['request_executor'] = self.executor
+
+ def use_fileobj_in_call_args(self, fileobj):
+ self.call_args = self.get_call_args(fileobj=fileobj)
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.submission_main_kwargs['transfer_future'] = self.transfer_future
+
+ def assert_tag_for_get_object(self, tag_value):
+ submissions_to_compare = self.executor.submissions
+ if len(submissions_to_compare) > 1:
+ # If it was ranged get, make sure we do not include the join task.
+ submissions_to_compare = submissions_to_compare[:-1]
+ for submission in submissions_to_compare:
+ self.assertEqual(submission['tag'], tag_value)
+
+ def add_head_object_response(self):
+ self.stubber.add_response(
+ 'head_object', {'ContentLength': len(self.content)}
+ )
+
+ def add_get_responses(self):
+ chunksize = self.config.multipart_chunksize
+ for i in range(0, len(self.content), chunksize):
+ if i + chunksize > len(self.content):
+ stream = BytesIO(self.content[i:])
+ self.stubber.add_response('get_object', {'Body': stream})
+ else:
+ stream = BytesIO(self.content[i : i + chunksize])
+ self.stubber.add_response('get_object', {'Body': stream})
+
+ def configure_for_ranged_get(self):
+ self.config.multipart_threshold = 1
+ self.config.multipart_chunksize = 4
+
+ def get_download_submission_task(self):
+ return self.get_task(
+ DownloadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+
+ def wait_and_assert_completed_successfully(self, submission_task):
+ submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_submits_no_tag_for_get_object_filename(self):
+ self.wrap_executor_in_recorder()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+ # Make sure no tag to limit that task specifically was not associated
+ # to that task submission.
+ self.assert_tag_for_get_object(None)
+
+ def test_submits_no_tag_for_ranged_get_filename(self):
+ self.wrap_executor_in_recorder()
+ self.configure_for_ranged_get()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+ # Make sure no tag to limit that task specifically was not associated
+ # to that task submission.
+ self.assert_tag_for_get_object(None)
+
+ def test_submits_no_tag_for_get_object_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ with open(self.filename, 'wb') as f:
+ self.use_fileobj_in_call_args(f)
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+ # Make sure no tag to limit that task specifically was not associated
+ # to that task submission.
+ self.assert_tag_for_get_object(None)
+
+ def test_submits_no_tag_for_ranged_get_object_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.configure_for_ranged_get()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ with open(self.filename, 'wb') as f:
+ self.use_fileobj_in_call_args(f)
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+ # Make sure no tag to limit that task specifically was not associated
+ # to that task submission.
+ self.assert_tag_for_get_object(None)
+
+ def tests_submits_tag_for_get_object_nonseekable_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ with open(self.filename, 'wb') as f:
+ self.use_fileobj_in_call_args(NonSeekableWriter(f))
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+ # Make sure no tag to limit that task specifically was not associated
+ # to that task submission.
+ self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG)
+
+ def tests_submits_tag_for_ranged_get_object_nonseekable_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.configure_for_ranged_get()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ with open(self.filename, 'wb') as f:
+ self.use_fileobj_in_call_args(NonSeekableWriter(f))
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+ # Make sure no tag to limit that task specifically was not associated
+ # to that task submission.
+ self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG)
+
+
+class TestGetObjectTask(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.callbacks = []
+ self.max_attempts = 5
+ self.io_executor = BoundedExecutor(1000, 1)
+ self.content = b'my content'
+ self.stream = BytesIO(self.content)
+ self.fileobj = WriteCollector()
+ self.osutil = OSUtils()
+ self.io_chunksize = 64 * (1024 ** 2)
+ self.task_cls = GetObjectTask
+ self.download_output_manager = DownloadSeekableOutputManager(
+ self.osutil, self.transfer_coordinator, self.io_executor
+ )
+
+ def get_download_task(self, **kwargs):
+ default_kwargs = {
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'fileobj': self.fileobj,
+ 'extra_args': self.extra_args,
+ 'callbacks': self.callbacks,
+ 'max_attempts': self.max_attempts,
+ 'download_output_manager': self.download_output_manager,
+ 'io_chunksize': self.io_chunksize,
+ }
+ default_kwargs.update(kwargs)
+ self.transfer_coordinator.set_status_to_queued()
+ return self.get_task(self.task_cls, main_kwargs=default_kwargs)
+
+ def assert_io_writes(self, expected_writes):
+ # Let the io executor process all of the writes before checking
+ # what writes were sent to it.
+ self.io_executor.shutdown()
+ self.assertEqual(self.fileobj.writes, expected_writes)
+
+ def test_main(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assert_io_writes([(0, self.content)])
+
+ def test_extra_args(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'Range': 'bytes=0-',
+ },
+ )
+ self.extra_args['Range'] = 'bytes=0-'
+ task = self.get_download_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assert_io_writes([(0, self.content)])
+
+ def test_control_chunk_size(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task(io_chunksize=1)
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ expected_contents = []
+ for i in range(len(self.content)):
+ expected_contents.append((i, bytes(self.content[i : i + 1])))
+
+ self.assert_io_writes(expected_contents)
+
+ def test_start_index(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task(start_index=5)
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assert_io_writes([(5, self.content)])
+
+ def test_uses_bandwidth_limiter(self):
+ bandwidth_limiter = mock.Mock(BandwidthLimiter)
+
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task(bandwidth_limiter=bandwidth_limiter)
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(
+ bandwidth_limiter.get_bandwith_limited_stream.call_args_list,
+ [mock.call(mock.ANY, self.transfer_coordinator)],
+ )
+
+ def test_retries_succeeds(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={
+ 'Body': StreamWithError(self.stream, SOCKET_ERROR)
+ },
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task()
+ task()
+
+ # Retryable error should have not affected the bytes placed into
+ # the io queue.
+ self.stubber.assert_no_pending_responses()
+ self.assert_io_writes([(0, self.content)])
+
+ def test_retries_failure(self):
+ for _ in range(self.max_attempts):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={
+ 'Body': StreamWithError(self.stream, SOCKET_ERROR)
+ },
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+
+ task = self.get_download_task()
+ task()
+ self.transfer_coordinator.announce_done()
+
+ # Should have failed out on a RetriesExceededError
+ with self.assertRaises(RetriesExceededError):
+ self.transfer_coordinator.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_retries_in_middle_of_streaming(self):
+ # After the first read a retryable error will be thrown
+ self.stubber.add_response(
+ 'get_object',
+ service_response={
+ 'Body': StreamWithError(
+ copy.deepcopy(self.stream), SOCKET_ERROR, 1
+ )
+ },
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task(io_chunksize=1)
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ expected_contents = []
+ # This is the content initially read in before the retry hit on the
+ # second read()
+ expected_contents.append((0, bytes(self.content[0:1])))
+
+ # The rest of the content should be the entire set of data partitioned
+ # out based on the one byte stream chunk size. Note the second
+ # element in the list should be a copy of the first element since
+ # a retryable exception happened in between.
+ for i in range(len(self.content)):
+ expected_contents.append((i, bytes(self.content[i : i + 1])))
+ self.assert_io_writes(expected_contents)
+
+ def test_cancels_out_of_queueing(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={
+ 'Body': CancelledStreamWrapper(
+ self.stream, self.transfer_coordinator
+ )
+ },
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ # Make sure that no contents were added to the queue because the task
+ # should have been canceled before trying to add the contents to the
+ # io queue.
+ self.assert_io_writes([])
+
+ def test_handles_callback_on_initial_error(self):
+ # We can't use the stubber for this because we need to raise
+ # a S3_RETRYABLE_DOWNLOAD_ERRORS, and the stubber only allows
+ # you to raise a ClientError.
+ self.client.get_object = mock.Mock(side_effect=SOCKET_ERROR())
+ task = self.get_download_task()
+ task()
+ self.transfer_coordinator.announce_done()
+ # Should have failed out on a RetriesExceededError because
+ # get_object keeps raising a socket error.
+ with self.assertRaises(RetriesExceededError):
+ self.transfer_coordinator.result()
+
+
+class TestImmediatelyWriteIOGetObjectTask(TestGetObjectTask):
+ def setUp(self):
+ super().setUp()
+ self.task_cls = ImmediatelyWriteIOGetObjectTask
+ # When data is written out, it should not use the io executor at all
+ # if it does use the io executor that is a deviation from expected
+ # behavior as the data should be written immediately to the file
+ # object once downloaded.
+ self.io_executor = None
+ self.download_output_manager = DownloadSeekableOutputManager(
+ self.osutil, self.transfer_coordinator, self.io_executor
+ )
+
+ def assert_io_writes(self, expected_writes):
+ self.assertEqual(self.fileobj.writes, expected_writes)
+
+
+class BaseIOTaskTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.files = FileCreator()
+ self.osutil = OSUtils()
+ self.temp_filename = os.path.join(self.files.rootdir, 'mytempfile')
+ self.final_filename = os.path.join(self.files.rootdir, 'myfile')
+
+ def tearDown(self):
+ super().tearDown()
+ self.files.remove_all()
+
+
+class TestIOStreamingWriteTask(BaseIOTaskTest):
+ def test_main(self):
+ with open(self.temp_filename, 'wb') as f:
+ task = self.get_task(
+ IOStreamingWriteTask,
+ main_kwargs={'fileobj': f, 'data': b'foobar'},
+ )
+ task()
+ task2 = self.get_task(
+ IOStreamingWriteTask,
+ main_kwargs={'fileobj': f, 'data': b'baz'},
+ )
+ task2()
+ with open(self.temp_filename, 'rb') as f:
+ # We should just have written to the file in the order
+ # the tasks were executed.
+ self.assertEqual(f.read(), b'foobarbaz')
+
+
+class TestIOWriteTask(BaseIOTaskTest):
+ def test_main(self):
+ with open(self.temp_filename, 'wb') as f:
+ # Write once to the file
+ task = self.get_task(
+ IOWriteTask,
+ main_kwargs={'fileobj': f, 'data': b'foo', 'offset': 0},
+ )
+ task()
+
+ # Write again to the file
+ task = self.get_task(
+ IOWriteTask,
+ main_kwargs={'fileobj': f, 'data': b'bar', 'offset': 3},
+ )
+ task()
+
+ with open(self.temp_filename, 'rb') as f:
+ self.assertEqual(f.read(), b'foobar')
+
+
+class TestIORenameFileTask(BaseIOTaskTest):
+ def test_main(self):
+ with open(self.temp_filename, 'wb') as f:
+ task = self.get_task(
+ IORenameFileTask,
+ main_kwargs={
+ 'fileobj': f,
+ 'final_filename': self.final_filename,
+ 'osutil': self.osutil,
+ },
+ )
+ task()
+ self.assertTrue(os.path.exists(self.final_filename))
+ self.assertFalse(os.path.exists(self.temp_filename))
+
+
+class TestIOCloseTask(BaseIOTaskTest):
+ def test_main(self):
+ with open(self.temp_filename, 'w') as f:
+ task = self.get_task(IOCloseTask, main_kwargs={'fileobj': f})
+ task()
+ self.assertTrue(f.closed)
+
+
+class TestDownloadChunkIterator(unittest.TestCase):
+ def test_iter(self):
+ content = b'my content'
+ body = BytesIO(content)
+ ref_chunks = []
+ for chunk in DownloadChunkIterator(body, len(content)):
+ ref_chunks.append(chunk)
+ self.assertEqual(ref_chunks, [b'my content'])
+
+ def test_iter_chunksize(self):
+ content = b'1234'
+ body = BytesIO(content)
+ ref_chunks = []
+ for chunk in DownloadChunkIterator(body, 3):
+ ref_chunks.append(chunk)
+ self.assertEqual(ref_chunks, [b'123', b'4'])
+
+ def test_empty_content(self):
+ body = BytesIO(b'')
+ ref_chunks = []
+ for chunk in DownloadChunkIterator(body, 3):
+ ref_chunks.append(chunk)
+ self.assertEqual(ref_chunks, [b''])
+
+
+class TestDeferQueue(unittest.TestCase):
+ def setUp(self):
+ self.q = DeferQueue()
+
+ def test_no_writes_when_not_lowest_block(self):
+ writes = self.q.request_writes(offset=1, data='bar')
+ self.assertEqual(writes, [])
+
+ def test_writes_returned_in_order(self):
+ self.assertEqual(self.q.request_writes(offset=3, data='d'), [])
+ self.assertEqual(self.q.request_writes(offset=2, data='c'), [])
+ self.assertEqual(self.q.request_writes(offset=1, data='b'), [])
+
+ # Everything at this point has been deferred, but as soon as we
+ # send offset=0, that will unlock offsets 0-3.
+ writes = self.q.request_writes(offset=0, data='a')
+ self.assertEqual(
+ writes,
+ [
+ {'offset': 0, 'data': 'a'},
+ {'offset': 1, 'data': 'b'},
+ {'offset': 2, 'data': 'c'},
+ {'offset': 3, 'data': 'd'},
+ ],
+ )
+
+ def test_unlocks_partial_range(self):
+ self.assertEqual(self.q.request_writes(offset=5, data='f'), [])
+ self.assertEqual(self.q.request_writes(offset=1, data='b'), [])
+
+ # offset=0 unlocks 0-1, but offset=5 still needs to see 2-4 first.
+ writes = self.q.request_writes(offset=0, data='a')
+ self.assertEqual(
+ writes,
+ [
+ {'offset': 0, 'data': 'a'},
+ {'offset': 1, 'data': 'b'},
+ ],
+ )
+
+ def test_data_can_be_any_size(self):
+ self.q.request_writes(offset=5, data='hello world')
+ writes = self.q.request_writes(offset=0, data='abcde')
+ self.assertEqual(
+ writes,
+ [
+ {'offset': 0, 'data': 'abcde'},
+ {'offset': 5, 'data': 'hello world'},
+ ],
+ )
+
+ def test_data_queued_in_order(self):
+ # This immediately gets returned because offset=0 is the
+ # next range we're waiting on.
+ writes = self.q.request_writes(offset=0, data='hello world')
+ self.assertEqual(writes, [{'offset': 0, 'data': 'hello world'}])
+ # Same thing here but with offset
+ writes = self.q.request_writes(offset=11, data='hello again')
+ self.assertEqual(writes, [{'offset': 11, 'data': 'hello again'}])
+
+ def test_writes_below_min_offset_are_ignored(self):
+ self.q.request_writes(offset=0, data='a')
+ self.q.request_writes(offset=1, data='b')
+ self.q.request_writes(offset=2, data='c')
+
+ # At this point we're expecting offset=3, so if a write
+ # comes in below 3, we ignore it.
+ self.assertEqual(self.q.request_writes(offset=0, data='a'), [])
+ self.assertEqual(self.q.request_writes(offset=1, data='b'), [])
+
+ self.assertEqual(
+ self.q.request_writes(offset=3, data='d'),
+ [{'offset': 3, 'data': 'd'}],
+ )
+
+ def test_duplicate_writes_are_ignored(self):
+ self.q.request_writes(offset=2, data='c')
+ self.q.request_writes(offset=1, data='b')
+
+ # We're still waiting for offset=0, but if
+ # a duplicate write comes in for offset=2/offset=1
+ # it's ignored. This gives "first one wins" behavior.
+ self.assertEqual(self.q.request_writes(offset=2, data='X'), [])
+ self.assertEqual(self.q.request_writes(offset=1, data='Y'), [])
+
+ self.assertEqual(
+ self.q.request_writes(offset=0, data='a'),
+ [
+ {'offset': 0, 'data': 'a'},
+ # Note we're seeing 'b' 'c', and not 'X', 'Y'.
+ {'offset': 1, 'data': 'b'},
+ {'offset': 2, 'data': 'c'},
+ ],
+ )
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_futures.py b/contrib/python/s3transfer/py3/tests/unit/test_futures.py
new file mode 100644
index 0000000000..ca2888a654
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_futures.py
@@ -0,0 +1,696 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import sys
+import time
+import traceback
+from concurrent.futures import ThreadPoolExecutor
+
+from s3transfer.exceptions import (
+ CancelledError,
+ FatalError,
+ TransferNotDoneError,
+)
+from s3transfer.futures import (
+ BaseExecutor,
+ BoundedExecutor,
+ ExecutorFuture,
+ NonThreadedExecutor,
+ NonThreadedExecutorFuture,
+ TransferCoordinator,
+ TransferFuture,
+ TransferMeta,
+)
+from s3transfer.tasks import Task
+from s3transfer.utils import (
+ FunctionContainer,
+ NoResourcesAvailable,
+ TaskSemaphore,
+)
+from __tests__ import (
+ RecordingExecutor,
+ TransferCoordinatorWithInterrupt,
+ mock,
+ unittest,
+)
+
+
+def return_call_args(*args, **kwargs):
+ return args, kwargs
+
+
+def raise_exception(exception):
+ raise exception
+
+
+def get_exc_info(exception):
+ try:
+ raise_exception(exception)
+ except Exception:
+ return sys.exc_info()
+
+
+class RecordingTransferCoordinator(TransferCoordinator):
+ def __init__(self):
+ self.all_transfer_futures_ever_associated = set()
+ super().__init__()
+
+ def add_associated_future(self, future):
+ self.all_transfer_futures_ever_associated.add(future)
+ super().add_associated_future(future)
+
+
+class ReturnFooTask(Task):
+ def _main(self, **kwargs):
+ return 'foo'
+
+
+class SleepTask(Task):
+ def _main(self, sleep_time, **kwargs):
+ time.sleep(sleep_time)
+
+
+class TestTransferFuture(unittest.TestCase):
+ def setUp(self):
+ self.meta = TransferMeta()
+ self.coordinator = TransferCoordinator()
+ self.future = self._get_transfer_future()
+
+ def _get_transfer_future(self, **kwargs):
+ components = {
+ 'meta': self.meta,
+ 'coordinator': self.coordinator,
+ }
+ for component_name, component in kwargs.items():
+ components[component_name] = component
+ return TransferFuture(**components)
+
+ def test_meta(self):
+ self.assertIs(self.future.meta, self.meta)
+
+ def test_done(self):
+ self.assertFalse(self.future.done())
+ self.coordinator.set_result(None)
+ self.assertTrue(self.future.done())
+
+ def test_result(self):
+ result = 'foo'
+ self.coordinator.set_result(result)
+ self.coordinator.announce_done()
+ self.assertEqual(self.future.result(), result)
+
+ def test_keyboard_interrupt_on_result_does_not_block(self):
+ # This should raise a KeyboardInterrupt when result is called on it.
+ self.coordinator = TransferCoordinatorWithInterrupt()
+ self.future = self._get_transfer_future()
+
+ # result() should not block and immediately raise the keyboard
+ # interrupt exception.
+ with self.assertRaises(KeyboardInterrupt):
+ self.future.result()
+
+ def test_cancel(self):
+ self.future.cancel()
+ self.assertTrue(self.future.done())
+ self.assertEqual(self.coordinator.status, 'cancelled')
+
+ def test_set_exception(self):
+ # Set the result such that there is no exception
+ self.coordinator.set_result('result')
+ self.coordinator.announce_done()
+ self.assertEqual(self.future.result(), 'result')
+
+ self.future.set_exception(ValueError())
+ with self.assertRaises(ValueError):
+ self.future.result()
+
+ def test_set_exception_only_after_done(self):
+ with self.assertRaises(TransferNotDoneError):
+ self.future.set_exception(ValueError())
+
+ self.coordinator.set_result('result')
+ self.coordinator.announce_done()
+ self.future.set_exception(ValueError())
+ with self.assertRaises(ValueError):
+ self.future.result()
+
+
+class TestTransferMeta(unittest.TestCase):
+ def setUp(self):
+ self.transfer_meta = TransferMeta()
+
+ def test_size(self):
+ self.assertEqual(self.transfer_meta.size, None)
+ self.transfer_meta.provide_transfer_size(5)
+ self.assertEqual(self.transfer_meta.size, 5)
+
+ def test_call_args(self):
+ call_args = object()
+ transfer_meta = TransferMeta(call_args)
+ # Assert the that call args provided is the same as is returned
+ self.assertIs(transfer_meta.call_args, call_args)
+
+ def test_transfer_id(self):
+ transfer_meta = TransferMeta(transfer_id=1)
+ self.assertEqual(transfer_meta.transfer_id, 1)
+
+ def test_user_context(self):
+ self.transfer_meta.user_context['foo'] = 'bar'
+ self.assertEqual(self.transfer_meta.user_context, {'foo': 'bar'})
+
+
+class TestTransferCoordinator(unittest.TestCase):
+ def setUp(self):
+ self.transfer_coordinator = TransferCoordinator()
+
+ def test_transfer_id(self):
+ transfer_coordinator = TransferCoordinator(transfer_id=1)
+ self.assertEqual(transfer_coordinator.transfer_id, 1)
+
+ def test_repr(self):
+ transfer_coordinator = TransferCoordinator(transfer_id=1)
+ self.assertEqual(
+ repr(transfer_coordinator), 'TransferCoordinator(transfer_id=1)'
+ )
+
+ def test_initial_status(self):
+ # A TransferCoordinator with no progress should have the status
+ # of not-started
+ self.assertEqual(self.transfer_coordinator.status, 'not-started')
+
+ def test_set_status_to_queued(self):
+ self.transfer_coordinator.set_status_to_queued()
+ self.assertEqual(self.transfer_coordinator.status, 'queued')
+
+ def test_cannot_set_status_to_queued_from_done_state(self):
+ self.transfer_coordinator.set_exception(RuntimeError)
+ with self.assertRaises(RuntimeError):
+ self.transfer_coordinator.set_status_to_queued()
+
+ def test_status_running(self):
+ self.transfer_coordinator.set_status_to_running()
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ def test_cannot_set_status_to_running_from_done_state(self):
+ self.transfer_coordinator.set_exception(RuntimeError)
+ with self.assertRaises(RuntimeError):
+ self.transfer_coordinator.set_status_to_running()
+
+ def test_set_result(self):
+ success_result = 'foo'
+ self.transfer_coordinator.set_result(success_result)
+ self.transfer_coordinator.announce_done()
+ # Setting result should result in a success state and the return value
+ # that was set.
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+ self.assertEqual(self.transfer_coordinator.result(), success_result)
+
+ def test_set_exception(self):
+ exception_result = RuntimeError
+ self.transfer_coordinator.set_exception(exception_result)
+ self.transfer_coordinator.announce_done()
+ # Setting an exception should result in a failed state and the return
+ # value should be the raised exception
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+ self.assertEqual(self.transfer_coordinator.exception, exception_result)
+ with self.assertRaises(exception_result):
+ self.transfer_coordinator.result()
+
+ def test_exception_cannot_override_done_state(self):
+ self.transfer_coordinator.set_result('foo')
+ self.transfer_coordinator.set_exception(RuntimeError)
+ # It status should be success even after the exception is set because
+ # success is a done state.
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+
+ def test_exception_can_override_done_state_with_override_flag(self):
+ self.transfer_coordinator.set_result('foo')
+ self.transfer_coordinator.set_exception(RuntimeError, override=True)
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ def test_cancel(self):
+ self.assertEqual(self.transfer_coordinator.status, 'not-started')
+ self.transfer_coordinator.cancel()
+ # This should set the state to cancelled and raise the CancelledError
+ # exception and should have also set the done event so that result()
+ # is no longer set.
+ self.assertEqual(self.transfer_coordinator.status, 'cancelled')
+ with self.assertRaises(CancelledError):
+ self.transfer_coordinator.result()
+
+ def test_cancel_can_run_done_callbacks_that_uses_result(self):
+ exceptions = []
+
+ def capture_exception(transfer_coordinator, captured_exceptions):
+ try:
+ transfer_coordinator.result()
+ except Exception as e:
+ captured_exceptions.append(e)
+
+ self.assertEqual(self.transfer_coordinator.status, 'not-started')
+ self.transfer_coordinator.add_done_callback(
+ capture_exception, self.transfer_coordinator, exceptions
+ )
+ self.transfer_coordinator.cancel()
+
+ self.assertEqual(len(exceptions), 1)
+ self.assertIsInstance(exceptions[0], CancelledError)
+
+ def test_cancel_with_message(self):
+ message = 'my message'
+ self.transfer_coordinator.cancel(message)
+ self.transfer_coordinator.announce_done()
+ with self.assertRaisesRegex(CancelledError, message):
+ self.transfer_coordinator.result()
+
+ def test_cancel_with_provided_exception(self):
+ message = 'my message'
+ self.transfer_coordinator.cancel(message, exc_type=FatalError)
+ self.transfer_coordinator.announce_done()
+ with self.assertRaisesRegex(FatalError, message):
+ self.transfer_coordinator.result()
+
+ def test_cancel_cannot_override_done_state(self):
+ self.transfer_coordinator.set_result('foo')
+ self.transfer_coordinator.cancel()
+ # It status should be success even after cancel is called because
+ # success is a done state.
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+
+ def test_set_result_can_override_cancel(self):
+ self.transfer_coordinator.cancel()
+ # Result setting should override any cancel or set exception as this
+ # is always invoked by the final task.
+ self.transfer_coordinator.set_result('foo')
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+
+ def test_submit(self):
+ # Submit a callable to the transfer coordinator. It should submit it
+ # to the executor.
+ executor = RecordingExecutor(
+ BoundedExecutor(1, 1, {'my-tag': TaskSemaphore(1)})
+ )
+ task = ReturnFooTask(self.transfer_coordinator)
+ future = self.transfer_coordinator.submit(executor, task, tag='my-tag')
+ executor.shutdown()
+ # Make sure the future got submit and executed as well by checking its
+ # result value which should include the provided future tag.
+ self.assertEqual(
+ executor.submissions,
+ [{'block': True, 'tag': 'my-tag', 'task': task}],
+ )
+ self.assertEqual(future.result(), 'foo')
+
+ def test_association_and_disassociation_on_submit(self):
+ self.transfer_coordinator = RecordingTransferCoordinator()
+
+ # Submit a callable to the transfer coordinator.
+ executor = BoundedExecutor(1, 1)
+ task = ReturnFooTask(self.transfer_coordinator)
+ future = self.transfer_coordinator.submit(executor, task)
+ executor.shutdown()
+
+ # Make sure the future that got submitted was associated to the
+ # transfer future at some point.
+ self.assertEqual(
+ self.transfer_coordinator.all_transfer_futures_ever_associated,
+ {future},
+ )
+
+ # Make sure the future got disassociated once the future is now done
+ # by looking at the currently associated futures.
+ self.assertEqual(self.transfer_coordinator.associated_futures, set())
+
+ def test_done(self):
+ # These should result in not done state:
+ # queued
+ self.assertFalse(self.transfer_coordinator.done())
+ # running
+ self.transfer_coordinator.set_status_to_running()
+ self.assertFalse(self.transfer_coordinator.done())
+
+ # These should result in done state:
+ # failed
+ self.transfer_coordinator.set_exception(Exception)
+ self.assertTrue(self.transfer_coordinator.done())
+
+ # success
+ self.transfer_coordinator.set_result('foo')
+ self.assertTrue(self.transfer_coordinator.done())
+
+ # cancelled
+ self.transfer_coordinator.cancel()
+ self.assertTrue(self.transfer_coordinator.done())
+
+ def test_result_waits_until_done(self):
+ execution_order = []
+
+ def sleep_then_set_result(transfer_coordinator, execution_order):
+ time.sleep(0.05)
+ execution_order.append('setting_result')
+ transfer_coordinator.set_result(None)
+ self.transfer_coordinator.announce_done()
+
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ executor.submit(
+ sleep_then_set_result,
+ self.transfer_coordinator,
+ execution_order,
+ )
+ self.transfer_coordinator.result()
+ execution_order.append('after_result')
+
+ # The result() call should have waited until the other thread set
+ # the result after sleeping for 0.05 seconds.
+ self.assertTrue(execution_order, ['setting_result', 'after_result'])
+
+ def test_failure_cleanups(self):
+ args = (1, 2)
+ kwargs = {'foo': 'bar'}
+
+ second_args = (2, 4)
+ second_kwargs = {'biz': 'baz'}
+
+ self.transfer_coordinator.add_failure_cleanup(
+ return_call_args, *args, **kwargs
+ )
+ self.transfer_coordinator.add_failure_cleanup(
+ return_call_args, *second_args, **second_kwargs
+ )
+
+ # Ensure the callbacks got added.
+ self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 2)
+
+ result_list = []
+ # Ensure they will get called in the correct order.
+ for cleanup in self.transfer_coordinator.failure_cleanups:
+ result_list.append(cleanup())
+ self.assertEqual(
+ result_list, [(args, kwargs), (second_args, second_kwargs)]
+ )
+
+ def test_associated_futures(self):
+ first_future = object()
+ # Associate one future to the transfer
+ self.transfer_coordinator.add_associated_future(first_future)
+ associated_futures = self.transfer_coordinator.associated_futures
+ # The first future should be in the returned list of futures.
+ self.assertEqual(associated_futures, {first_future})
+
+ second_future = object()
+ # Associate another future to the transfer.
+ self.transfer_coordinator.add_associated_future(second_future)
+ # The association should not have mutated the returned list from
+ # before.
+ self.assertEqual(associated_futures, {first_future})
+
+ # Both futures should be in the returned list.
+ self.assertEqual(
+ self.transfer_coordinator.associated_futures,
+ {first_future, second_future},
+ )
+
+ def test_done_callbacks_on_done(self):
+ done_callback_invocations = []
+ callback = FunctionContainer(
+ done_callback_invocations.append, 'done callback called'
+ )
+
+ # Add the done callback to the transfer.
+ self.transfer_coordinator.add_done_callback(callback)
+
+ # Announce that the transfer is done. This should invoke the done
+ # callback.
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(done_callback_invocations, ['done callback called'])
+
+ # If done is announced again, we should not invoke the callback again
+ # because done has already been announced and thus the callback has
+ # been ran as well.
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(done_callback_invocations, ['done callback called'])
+
+ def test_failure_cleanups_on_done(self):
+ cleanup_invocations = []
+ callback = FunctionContainer(
+ cleanup_invocations.append, 'cleanup called'
+ )
+
+ # Add the failure cleanup to the transfer.
+ self.transfer_coordinator.add_failure_cleanup(callback)
+
+ # Announce that the transfer is done. This should invoke the failure
+ # cleanup.
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(cleanup_invocations, ['cleanup called'])
+
+ # If done is announced again, we should not invoke the cleanup again
+ # because done has already been announced and thus the cleanup has
+ # been ran as well.
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(cleanup_invocations, ['cleanup called'])
+
+
+class TestBoundedExecutor(unittest.TestCase):
+ def setUp(self):
+ self.coordinator = TransferCoordinator()
+ self.tag_semaphores = {}
+ self.executor = self.get_executor()
+
+ def get_executor(self, max_size=1, max_num_threads=1):
+ return BoundedExecutor(max_size, max_num_threads, self.tag_semaphores)
+
+ def get_task(self, task_cls, main_kwargs=None):
+ return task_cls(self.coordinator, main_kwargs=main_kwargs)
+
+ def get_sleep_task(self, sleep_time=0.01):
+ return self.get_task(SleepTask, main_kwargs={'sleep_time': sleep_time})
+
+ def add_semaphore(self, task_tag, count):
+ self.tag_semaphores[task_tag] = TaskSemaphore(count)
+
+ def assert_submit_would_block(self, task, tag=None):
+ with self.assertRaises(NoResourcesAvailable):
+ self.executor.submit(task, tag=tag, block=False)
+
+ def assert_submit_would_not_block(self, task, tag=None, **kwargs):
+ try:
+ self.executor.submit(task, tag=tag, block=False)
+ except NoResourcesAvailable:
+ self.fail(
+ 'Task {} should not have been blocked. Caused by:\n{}'.format(
+ task, traceback.format_exc()
+ )
+ )
+
+ def add_done_callback_to_future(self, future, fn, *args, **kwargs):
+ callback_for_future = FunctionContainer(fn, *args, **kwargs)
+ future.add_done_callback(callback_for_future)
+
+ def test_submit_single_task(self):
+ # Ensure we can submit a task to the executor
+ task = self.get_task(ReturnFooTask)
+ future = self.executor.submit(task)
+
+ # Ensure what we get back is a Future
+ self.assertIsInstance(future, ExecutorFuture)
+ # Ensure the callable got executed.
+ self.assertEqual(future.result(), 'foo')
+
+ @unittest.skipIf(
+ os.environ.get('USE_SERIAL_EXECUTOR'),
+ "Not supported with serial executor tests",
+ )
+ def test_executor_blocks_on_full_capacity(self):
+ first_task = self.get_sleep_task()
+ second_task = self.get_sleep_task()
+ self.executor.submit(first_task)
+ # The first task should be sleeping for a substantial period of
+ # time such that on the submission of the second task, it will
+ # raise an error saying that it cannot be submitted as the max
+ # capacity of the semaphore is one.
+ self.assert_submit_would_block(second_task)
+
+ def test_executor_clears_capacity_on_done_tasks(self):
+ first_task = self.get_sleep_task()
+ second_task = self.get_task(ReturnFooTask)
+
+ # Submit a task.
+ future = self.executor.submit(first_task)
+
+ # Submit a new task when the first task finishes. This should not get
+ # blocked because the first task should have finished clearing up
+ # capacity.
+ self.add_done_callback_to_future(
+ future, self.assert_submit_would_not_block, second_task
+ )
+
+ # Wait for it to complete.
+ self.executor.shutdown()
+
+ @unittest.skipIf(
+ os.environ.get('USE_SERIAL_EXECUTOR'),
+ "Not supported with serial executor tests",
+ )
+ def test_would_not_block_when_full_capacity_in_other_semaphore(self):
+ first_task = self.get_sleep_task()
+
+ # Now let's create a new task with a tag and so it uses different
+ # semaphore.
+ task_tag = 'other'
+ other_task = self.get_sleep_task()
+ self.add_semaphore(task_tag, 1)
+
+ # Submit the normal first task
+ self.executor.submit(first_task)
+
+ # Even though The first task should be sleeping for a substantial
+ # period of time, the submission of the second task should not
+ # raise an error because it should use a different semaphore
+ self.assert_submit_would_not_block(other_task, task_tag)
+
+ # Another submission of the other task though should raise
+ # an exception as the capacity is equal to one for that tag.
+ self.assert_submit_would_block(other_task, task_tag)
+
+ def test_shutdown(self):
+ slow_task = self.get_sleep_task()
+ future = self.executor.submit(slow_task)
+ self.executor.shutdown()
+ # Ensure that the shutdown waits until the task is done
+ self.assertTrue(future.done())
+
+ @unittest.skipIf(
+ os.environ.get('USE_SERIAL_EXECUTOR'),
+ "Not supported with serial executor tests",
+ )
+ def test_shutdown_no_wait(self):
+ slow_task = self.get_sleep_task()
+ future = self.executor.submit(slow_task)
+ self.executor.shutdown(False)
+ # Ensure that the shutdown returns immediately even if the task is
+ # not done, which it should not be because it it slow.
+ self.assertFalse(future.done())
+
+ def test_replace_underlying_executor(self):
+ mocked_executor_cls = mock.Mock(BaseExecutor)
+ executor = BoundedExecutor(10, 1, {}, mocked_executor_cls)
+ executor.submit(self.get_task(ReturnFooTask))
+ self.assertTrue(mocked_executor_cls.return_value.submit.called)
+
+
+class TestExecutorFuture(unittest.TestCase):
+ def test_result(self):
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(return_call_args, 'foo', biz='baz')
+ wrapped_future = ExecutorFuture(future)
+ self.assertEqual(wrapped_future.result(), (('foo',), {'biz': 'baz'}))
+
+ def test_done(self):
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(return_call_args, 'foo', biz='baz')
+ wrapped_future = ExecutorFuture(future)
+ self.assertTrue(wrapped_future.done())
+
+ def test_add_done_callback(self):
+ done_callbacks = []
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(return_call_args, 'foo', biz='baz')
+ wrapped_future = ExecutorFuture(future)
+ wrapped_future.add_done_callback(
+ FunctionContainer(done_callbacks.append, 'called')
+ )
+ self.assertEqual(done_callbacks, ['called'])
+
+
+class TestNonThreadedExecutor(unittest.TestCase):
+ def test_submit(self):
+ executor = NonThreadedExecutor()
+ future = executor.submit(return_call_args, 1, 2, foo='bar')
+ self.assertIsInstance(future, NonThreadedExecutorFuture)
+ self.assertEqual(future.result(), ((1, 2), {'foo': 'bar'}))
+
+ def test_submit_with_exception(self):
+ executor = NonThreadedExecutor()
+ future = executor.submit(raise_exception, RuntimeError())
+ self.assertIsInstance(future, NonThreadedExecutorFuture)
+ with self.assertRaises(RuntimeError):
+ future.result()
+
+ def test_submit_with_exception_and_captures_info(self):
+ exception = ValueError('message')
+ tb = get_exc_info(exception)[2]
+ future = NonThreadedExecutor().submit(raise_exception, exception)
+ try:
+ future.result()
+ # An exception should have been raised
+ self.fail('Future should have raised a ValueError')
+ except ValueError:
+ actual_tb = sys.exc_info()[2]
+ last_frame = traceback.extract_tb(actual_tb)[-1]
+ last_expected_frame = traceback.extract_tb(tb)[-1]
+ self.assertEqual(last_frame, last_expected_frame)
+
+
+class TestNonThreadedExecutorFuture(unittest.TestCase):
+ def setUp(self):
+ self.future = NonThreadedExecutorFuture()
+
+ def test_done_starts_false(self):
+ self.assertFalse(self.future.done())
+
+ def test_done_after_setting_result(self):
+ self.future.set_result('result')
+ self.assertTrue(self.future.done())
+
+ def test_done_after_setting_exception(self):
+ self.future.set_exception_info(Exception(), None)
+ self.assertTrue(self.future.done())
+
+ def test_result(self):
+ self.future.set_result('result')
+ self.assertEqual(self.future.result(), 'result')
+
+ def test_exception_result(self):
+ exception = ValueError('message')
+ self.future.set_exception_info(exception, None)
+ with self.assertRaisesRegex(ValueError, 'message'):
+ self.future.result()
+
+ def test_exception_result_doesnt_modify_last_frame(self):
+ exception = ValueError('message')
+ tb = get_exc_info(exception)[2]
+ self.future.set_exception_info(exception, tb)
+ try:
+ self.future.result()
+ # An exception should have been raised
+ self.fail()
+ except ValueError:
+ actual_tb = sys.exc_info()[2]
+ last_frame = traceback.extract_tb(actual_tb)[-1]
+ last_expected_frame = traceback.extract_tb(tb)[-1]
+ self.assertEqual(last_frame, last_expected_frame)
+
+ def test_done_callback(self):
+ done_futures = []
+ self.future.add_done_callback(done_futures.append)
+ self.assertEqual(done_futures, [])
+ self.future.set_result('result')
+ self.assertEqual(done_futures, [self.future])
+
+ def test_done_callback_after_done(self):
+ self.future.set_result('result')
+ done_futures = []
+ self.future.add_done_callback(done_futures.append)
+ self.assertEqual(done_futures, [self.future])
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_manager.py b/contrib/python/s3transfer/py3/tests/unit/test_manager.py
new file mode 100644
index 0000000000..fc3caa843f
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_manager.py
@@ -0,0 +1,143 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import time
+from concurrent.futures import ThreadPoolExecutor
+
+from s3transfer.exceptions import CancelledError, FatalError
+from s3transfer.futures import TransferCoordinator
+from s3transfer.manager import TransferConfig, TransferCoordinatorController
+from __tests__ import TransferCoordinatorWithInterrupt, unittest
+
+
+class FutureResultException(Exception):
+ pass
+
+
+class TestTransferConfig(unittest.TestCase):
+ def test_exception_on_zero_attr_value(self):
+ with self.assertRaises(ValueError):
+ TransferConfig(max_request_queue_size=0)
+
+
+class TestTransferCoordinatorController(unittest.TestCase):
+ def setUp(self):
+ self.coordinator_controller = TransferCoordinatorController()
+
+ def sleep_then_announce_done(self, transfer_coordinator, sleep_time):
+ time.sleep(sleep_time)
+ transfer_coordinator.set_result('done')
+ transfer_coordinator.announce_done()
+
+ def assert_coordinator_is_cancelled(self, transfer_coordinator):
+ self.assertEqual(transfer_coordinator.status, 'cancelled')
+
+ def test_add_transfer_coordinator(self):
+ transfer_coordinator = TransferCoordinator()
+ # Add the transfer coordinator
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+ # Ensure that is tracked.
+ self.assertEqual(
+ self.coordinator_controller.tracked_transfer_coordinators,
+ {transfer_coordinator},
+ )
+
+ def test_remove_transfer_coordinator(self):
+ transfer_coordinator = TransferCoordinator()
+ # Add the coordinator
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+ # Now remove the coordinator
+ self.coordinator_controller.remove_transfer_coordinator(
+ transfer_coordinator
+ )
+ # Make sure that it is no longer getting tracked.
+ self.assertEqual(
+ self.coordinator_controller.tracked_transfer_coordinators, set()
+ )
+
+ def test_cancel(self):
+ transfer_coordinator = TransferCoordinator()
+ # Add the transfer coordinator
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+ # Cancel with the canceler
+ self.coordinator_controller.cancel()
+ # Check that coordinator got canceled
+ self.assert_coordinator_is_cancelled(transfer_coordinator)
+
+ def test_cancel_with_message(self):
+ message = 'my cancel message'
+ transfer_coordinator = TransferCoordinator()
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+ self.coordinator_controller.cancel(message)
+ transfer_coordinator.announce_done()
+ with self.assertRaisesRegex(CancelledError, message):
+ transfer_coordinator.result()
+
+ def test_cancel_with_provided_exception(self):
+ message = 'my cancel message'
+ transfer_coordinator = TransferCoordinator()
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+ self.coordinator_controller.cancel(message, exc_type=FatalError)
+ transfer_coordinator.announce_done()
+ with self.assertRaisesRegex(FatalError, message):
+ transfer_coordinator.result()
+
+ def test_wait_for_done_transfer_coordinators(self):
+ # Create a coordinator and add it to the canceler
+ transfer_coordinator = TransferCoordinator()
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+
+ sleep_time = 0.02
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ # In a separate thread sleep and then set the transfer coordinator
+ # to done after sleeping.
+ start_time = time.time()
+ executor.submit(
+ self.sleep_then_announce_done, transfer_coordinator, sleep_time
+ )
+ # Now call wait to wait for the transfer coordinator to be done.
+ self.coordinator_controller.wait()
+ end_time = time.time()
+ wait_time = end_time - start_time
+ # The time waited should not be less than the time it took to sleep in
+ # the separate thread because the wait ending should be dependent on
+ # the sleeping thread announcing that the transfer coordinator is done.
+ self.assertTrue(sleep_time <= wait_time)
+
+ def test_wait_does_not_propogate_exceptions_from_result(self):
+ transfer_coordinator = TransferCoordinator()
+ transfer_coordinator.set_exception(FutureResultException())
+ transfer_coordinator.announce_done()
+ try:
+ self.coordinator_controller.wait()
+ except FutureResultException as e:
+ self.fail('%s should not have been raised.' % e)
+
+ def test_wait_can_be_interrupted(self):
+ inject_interrupt_coordinator = TransferCoordinatorWithInterrupt()
+ self.coordinator_controller.add_transfer_coordinator(
+ inject_interrupt_coordinator
+ )
+ with self.assertRaises(KeyboardInterrupt):
+ self.coordinator_controller.wait()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_processpool.py b/contrib/python/s3transfer/py3/tests/unit/test_processpool.py
new file mode 100644
index 0000000000..d77b5e0240
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_processpool.py
@@ -0,0 +1,728 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import queue
+import signal
+import threading
+import time
+from io import BytesIO
+
+from botocore.client import BaseClient
+from botocore.config import Config
+from botocore.exceptions import ClientError, ReadTimeoutError
+
+from s3transfer.constants import PROCESS_USER_AGENT
+from s3transfer.exceptions import CancelledError, RetriesExceededError
+from s3transfer.processpool import (
+ SHUTDOWN_SIGNAL,
+ ClientFactory,
+ DownloadFileRequest,
+ GetObjectJob,
+ GetObjectSubmitter,
+ GetObjectWorker,
+ ProcessPoolDownloader,
+ ProcessPoolTransferFuture,
+ ProcessPoolTransferMeta,
+ ProcessTransferConfig,
+ TransferMonitor,
+ TransferState,
+ ignore_ctrl_c,
+)
+from s3transfer.utils import CallArgs, OSUtils
+from __tests__ import (
+ FileCreator,
+ StreamWithError,
+ StubbedClientTest,
+ mock,
+ skip_if_windows,
+ unittest,
+)
+
+
+class RenameFailingOSUtils(OSUtils):
+ def __init__(self, exception):
+ self.exception = exception
+
+ def rename_file(self, current_filename, new_filename):
+ raise self.exception
+
+
+class TestIgnoreCtrlC(unittest.TestCase):
+ @skip_if_windows('os.kill() with SIGINT not supported on Windows')
+ def test_ignore_ctrl_c(self):
+ with ignore_ctrl_c():
+ try:
+ os.kill(os.getpid(), signal.SIGINT)
+ except KeyboardInterrupt:
+ self.fail(
+ 'The ignore_ctrl_c context manager should have '
+ 'ignored the KeyboardInterrupt exception'
+ )
+
+
+class TestProcessPoolDownloader(unittest.TestCase):
+ def test_uses_client_kwargs(self):
+ with mock.patch('s3transfer.processpool.ClientFactory') as factory:
+ ProcessPoolDownloader(client_kwargs={'region_name': 'myregion'})
+ self.assertEqual(
+ factory.call_args[0][0], {'region_name': 'myregion'}
+ )
+
+
+class TestProcessPoolTransferFuture(unittest.TestCase):
+ def setUp(self):
+ self.monitor = TransferMonitor()
+ self.transfer_id = self.monitor.notify_new_transfer()
+ self.meta = ProcessPoolTransferMeta(
+ transfer_id=self.transfer_id, call_args=CallArgs()
+ )
+ self.future = ProcessPoolTransferFuture(
+ monitor=self.monitor, meta=self.meta
+ )
+
+ def test_meta(self):
+ self.assertEqual(self.future.meta, self.meta)
+
+ def test_done(self):
+ self.assertFalse(self.future.done())
+ self.monitor.notify_done(self.transfer_id)
+ self.assertTrue(self.future.done())
+
+ def test_result(self):
+ self.monitor.notify_done(self.transfer_id)
+ self.assertIsNone(self.future.result())
+
+ def test_result_with_exception(self):
+ self.monitor.notify_exception(self.transfer_id, RuntimeError())
+ self.monitor.notify_done(self.transfer_id)
+ with self.assertRaises(RuntimeError):
+ self.future.result()
+
+ def test_result_with_keyboard_interrupt(self):
+ mock_monitor = mock.Mock(TransferMonitor)
+ mock_monitor._connect = mock.Mock()
+ mock_monitor.poll_for_result.side_effect = KeyboardInterrupt()
+ future = ProcessPoolTransferFuture(
+ monitor=mock_monitor, meta=self.meta
+ )
+ with self.assertRaises(KeyboardInterrupt):
+ future.result()
+ self.assertTrue(mock_monitor._connect.called)
+ self.assertTrue(mock_monitor.notify_exception.called)
+ call_args = mock_monitor.notify_exception.call_args[0]
+ self.assertEqual(call_args[0], self.transfer_id)
+ self.assertIsInstance(call_args[1], CancelledError)
+
+ def test_cancel(self):
+ self.future.cancel()
+ self.monitor.notify_done(self.transfer_id)
+ with self.assertRaises(CancelledError):
+ self.future.result()
+
+
+class TestProcessPoolTransferMeta(unittest.TestCase):
+ def test_transfer_id(self):
+ meta = ProcessPoolTransferMeta(1, CallArgs())
+ self.assertEqual(meta.transfer_id, 1)
+
+ def test_call_args(self):
+ call_args = CallArgs()
+ meta = ProcessPoolTransferMeta(1, call_args)
+ self.assertEqual(meta.call_args, call_args)
+
+ def test_user_context(self):
+ meta = ProcessPoolTransferMeta(1, CallArgs())
+ self.assertEqual(meta.user_context, {})
+ meta.user_context['mykey'] = 'myvalue'
+ self.assertEqual(meta.user_context, {'mykey': 'myvalue'})
+
+
+class TestClientFactory(unittest.TestCase):
+ def test_create_client(self):
+ client = ClientFactory().create_client()
+ self.assertIsInstance(client, BaseClient)
+ self.assertEqual(client.meta.service_model.service_name, 's3')
+ self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
+
+ def test_create_client_with_client_kwargs(self):
+ client = ClientFactory({'region_name': 'myregion'}).create_client()
+ self.assertEqual(client.meta.region_name, 'myregion')
+
+ def test_user_agent_with_config(self):
+ client = ClientFactory({'config': Config()}).create_client()
+ self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
+
+ def test_user_agent_with_existing_user_agent_extra(self):
+ config = Config(user_agent_extra='foo/1.0')
+ client = ClientFactory({'config': config}).create_client()
+ self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
+
+ def test_user_agent_with_existing_user_agent(self):
+ config = Config(user_agent='foo/1.0')
+ client = ClientFactory({'config': config}).create_client()
+ self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
+
+
+class TestTransferMonitor(unittest.TestCase):
+ def setUp(self):
+ self.monitor = TransferMonitor()
+ self.transfer_id = self.monitor.notify_new_transfer()
+
+ def test_notify_new_transfer_creates_new_state(self):
+ monitor = TransferMonitor()
+ transfer_id = monitor.notify_new_transfer()
+ self.assertFalse(monitor.is_done(transfer_id))
+ self.assertIsNone(monitor.get_exception(transfer_id))
+
+ def test_notify_new_transfer_increments_transfer_id(self):
+ monitor = TransferMonitor()
+ self.assertEqual(monitor.notify_new_transfer(), 0)
+ self.assertEqual(monitor.notify_new_transfer(), 1)
+
+ def test_notify_get_exception(self):
+ exception = Exception()
+ self.monitor.notify_exception(self.transfer_id, exception)
+ self.assertEqual(
+ self.monitor.get_exception(self.transfer_id), exception
+ )
+
+ def test_get_no_exception(self):
+ self.assertIsNone(self.monitor.get_exception(self.transfer_id))
+
+ def test_notify_jobs(self):
+ self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2)
+ self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1)
+ self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 0)
+
+ def test_notify_jobs_for_multiple_transfers(self):
+ self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2)
+ other_transfer_id = self.monitor.notify_new_transfer()
+ self.monitor.notify_expected_jobs_to_complete(other_transfer_id, 2)
+ self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1)
+ self.assertEqual(
+ self.monitor.notify_job_complete(other_transfer_id), 1
+ )
+
+ def test_done(self):
+ self.assertFalse(self.monitor.is_done(self.transfer_id))
+ self.monitor.notify_done(self.transfer_id)
+ self.assertTrue(self.monitor.is_done(self.transfer_id))
+
+ def test_poll_for_result(self):
+ self.monitor.notify_done(self.transfer_id)
+ self.assertIsNone(self.monitor.poll_for_result(self.transfer_id))
+
+ def test_poll_for_result_raises_error(self):
+ self.monitor.notify_exception(self.transfer_id, RuntimeError())
+ self.monitor.notify_done(self.transfer_id)
+ with self.assertRaises(RuntimeError):
+ self.monitor.poll_for_result(self.transfer_id)
+
+ def test_poll_for_result_waits_till_done(self):
+ event_order = []
+
+ def sleep_then_notify_done():
+ time.sleep(0.05)
+ event_order.append('notify_done')
+ self.monitor.notify_done(self.transfer_id)
+
+ t = threading.Thread(target=sleep_then_notify_done)
+ t.start()
+
+ self.monitor.poll_for_result(self.transfer_id)
+ event_order.append('done_polling')
+ self.assertEqual(event_order, ['notify_done', 'done_polling'])
+
+ def test_notify_cancel_all_in_progress(self):
+ monitor = TransferMonitor()
+ transfer_ids = []
+ for _ in range(10):
+ transfer_ids.append(monitor.notify_new_transfer())
+ monitor.notify_cancel_all_in_progress()
+ for transfer_id in transfer_ids:
+ self.assertIsInstance(
+ monitor.get_exception(transfer_id), CancelledError
+ )
+ # Cancelling a transfer does not mean it is done as there may
+ # be cleanup work left to do.
+ self.assertFalse(monitor.is_done(transfer_id))
+
+ def test_notify_cancel_does_not_affect_done_transfers(self):
+ self.monitor.notify_done(self.transfer_id)
+ self.monitor.notify_cancel_all_in_progress()
+ self.assertTrue(self.monitor.is_done(self.transfer_id))
+ self.assertIsNone(self.monitor.get_exception(self.transfer_id))
+
+
+class TestTransferState(unittest.TestCase):
+ def setUp(self):
+ self.state = TransferState()
+
+ def test_done(self):
+ self.assertFalse(self.state.done)
+ self.state.set_done()
+ self.assertTrue(self.state.done)
+
+ def test_waits_till_done_is_set(self):
+ event_order = []
+
+ def sleep_then_set_done():
+ time.sleep(0.05)
+ event_order.append('set_done')
+ self.state.set_done()
+
+ t = threading.Thread(target=sleep_then_set_done)
+ t.start()
+
+ self.state.wait_till_done()
+ event_order.append('done_waiting')
+ self.assertEqual(event_order, ['set_done', 'done_waiting'])
+
+ def test_exception(self):
+ exception = RuntimeError()
+ self.state.exception = exception
+ self.assertEqual(self.state.exception, exception)
+
+ def test_jobs_to_complete(self):
+ self.state.jobs_to_complete = 5
+ self.assertEqual(self.state.jobs_to_complete, 5)
+
+ def test_decrement_jobs_to_complete(self):
+ self.state.jobs_to_complete = 5
+ self.assertEqual(self.state.decrement_jobs_to_complete(), 4)
+
+
+class TestGetObjectSubmitter(StubbedClientTest):
+ def setUp(self):
+ super().setUp()
+ self.transfer_config = ProcessTransferConfig()
+ self.client_factory = mock.Mock(ClientFactory)
+ self.client_factory.create_client.return_value = self.client
+ self.transfer_monitor = TransferMonitor()
+ self.osutil = mock.Mock(OSUtils)
+ self.download_request_queue = queue.Queue()
+ self.worker_queue = queue.Queue()
+ self.submitter = GetObjectSubmitter(
+ transfer_config=self.transfer_config,
+ client_factory=self.client_factory,
+ transfer_monitor=self.transfer_monitor,
+ osutil=self.osutil,
+ download_request_queue=self.download_request_queue,
+ worker_queue=self.worker_queue,
+ )
+ self.transfer_id = self.transfer_monitor.notify_new_transfer()
+ self.bucket = 'bucket'
+ self.key = 'key'
+ self.filename = 'myfile'
+ self.temp_filename = 'myfile.temp'
+ self.osutil.get_temp_filename.return_value = self.temp_filename
+ self.extra_args = {}
+ self.expected_size = None
+
+ def add_download_file_request(self, **override_kwargs):
+ kwargs = {
+ 'transfer_id': self.transfer_id,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'filename': self.filename,
+ 'extra_args': self.extra_args,
+ 'expected_size': self.expected_size,
+ }
+ kwargs.update(override_kwargs)
+ self.download_request_queue.put(DownloadFileRequest(**kwargs))
+
+ def add_shutdown(self):
+ self.download_request_queue.put(SHUTDOWN_SIGNAL)
+
+ def assert_submitted_get_object_jobs(self, expected_jobs):
+ actual_jobs = []
+ while not self.worker_queue.empty():
+ actual_jobs.append(self.worker_queue.get())
+ self.assertEqual(actual_jobs, expected_jobs)
+
+ def test_run_for_non_ranged_download(self):
+ self.add_download_file_request(expected_size=1)
+ self.add_shutdown()
+ self.submitter.run()
+ self.osutil.allocate.assert_called_with(self.temp_filename, 1)
+ self.assert_submitted_get_object_jobs(
+ [
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=0,
+ extra_args={},
+ filename=self.filename,
+ )
+ ]
+ )
+
+ def test_run_for_ranged_download(self):
+ self.transfer_config.multipart_chunksize = 2
+ self.transfer_config.multipart_threshold = 4
+ self.add_download_file_request(expected_size=4)
+ self.add_shutdown()
+ self.submitter.run()
+ self.osutil.allocate.assert_called_with(self.temp_filename, 4)
+ self.assert_submitted_get_object_jobs(
+ [
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=0,
+ extra_args={'Range': 'bytes=0-1'},
+ filename=self.filename,
+ ),
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=2,
+ extra_args={'Range': 'bytes=2-'},
+ filename=self.filename,
+ ),
+ ]
+ )
+
+ def test_run_when_expected_size_not_provided(self):
+ self.stubber.add_response(
+ 'head_object',
+ {'ContentLength': 1},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ self.add_download_file_request(expected_size=None)
+ self.add_shutdown()
+ self.submitter.run()
+ self.stubber.assert_no_pending_responses()
+ self.osutil.allocate.assert_called_with(self.temp_filename, 1)
+ self.assert_submitted_get_object_jobs(
+ [
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=0,
+ extra_args={},
+ filename=self.filename,
+ )
+ ]
+ )
+
+ def test_run_with_extra_args(self):
+ self.stubber.add_response(
+ 'head_object',
+ {'ContentLength': 1},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'VersionId': 'versionid',
+ },
+ )
+ self.add_download_file_request(
+ extra_args={'VersionId': 'versionid'}, expected_size=None
+ )
+ self.add_shutdown()
+ self.submitter.run()
+ self.stubber.assert_no_pending_responses()
+ self.osutil.allocate.assert_called_with(self.temp_filename, 1)
+ self.assert_submitted_get_object_jobs(
+ [
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=0,
+ extra_args={'VersionId': 'versionid'},
+ filename=self.filename,
+ )
+ ]
+ )
+
+ def test_run_with_exception(self):
+ self.stubber.add_client_error('head_object', 'NoSuchKey', 404)
+ self.add_download_file_request(expected_size=None)
+ self.add_shutdown()
+ self.submitter.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_submitted_get_object_jobs([])
+ self.assertIsInstance(
+ self.transfer_monitor.get_exception(self.transfer_id), ClientError
+ )
+
+ def test_run_with_error_in_allocating_temp_file(self):
+ self.osutil.allocate.side_effect = OSError()
+ self.add_download_file_request(expected_size=1)
+ self.add_shutdown()
+ self.submitter.run()
+ self.assert_submitted_get_object_jobs([])
+ self.assertIsInstance(
+ self.transfer_monitor.get_exception(self.transfer_id), OSError
+ )
+
+ @skip_if_windows('os.kill() with SIGINT not supported on Windows')
+ def test_submitter_cannot_be_killed(self):
+ self.add_download_file_request(expected_size=None)
+ self.add_shutdown()
+
+ def raise_ctrl_c(**kwargs):
+ os.kill(os.getpid(), signal.SIGINT)
+
+ mock_client = mock.Mock()
+ mock_client.head_object = raise_ctrl_c
+ self.client_factory.create_client.return_value = mock_client
+
+ try:
+ self.submitter.run()
+ except KeyboardInterrupt:
+ self.fail(
+ 'The submitter should have not been killed by the '
+ 'KeyboardInterrupt'
+ )
+
+
+class TestGetObjectWorker(StubbedClientTest):
+ def setUp(self):
+ super().setUp()
+ self.files = FileCreator()
+ self.queue = queue.Queue()
+ self.client_factory = mock.Mock(ClientFactory)
+ self.client_factory.create_client.return_value = self.client
+ self.transfer_monitor = TransferMonitor()
+ self.osutil = OSUtils()
+ self.worker = GetObjectWorker(
+ queue=self.queue,
+ client_factory=self.client_factory,
+ transfer_monitor=self.transfer_monitor,
+ osutil=self.osutil,
+ )
+ self.transfer_id = self.transfer_monitor.notify_new_transfer()
+ self.bucket = 'bucket'
+ self.key = 'key'
+ self.remote_contents = b'my content'
+ self.temp_filename = self.files.create_file('tempfile', '')
+ self.extra_args = {}
+ self.offset = 0
+ self.final_filename = self.files.full_path('final_filename')
+ self.stream = BytesIO(self.remote_contents)
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1000
+ )
+
+ def tearDown(self):
+ super().tearDown()
+ self.files.remove_all()
+
+ def add_get_object_job(self, **override_kwargs):
+ kwargs = {
+ 'transfer_id': self.transfer_id,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'temp_filename': self.temp_filename,
+ 'extra_args': self.extra_args,
+ 'offset': self.offset,
+ 'filename': self.final_filename,
+ }
+ kwargs.update(override_kwargs)
+ self.queue.put(GetObjectJob(**kwargs))
+
+ def add_shutdown(self):
+ self.queue.put(SHUTDOWN_SIGNAL)
+
+ def add_stubbed_get_object_response(self, body=None, expected_params=None):
+ if body is None:
+ body = self.stream
+ get_object_response = {'Body': body}
+
+ if expected_params is None:
+ expected_params = {'Bucket': self.bucket, 'Key': self.key}
+
+ self.stubber.add_response(
+ 'get_object', get_object_response, expected_params
+ )
+
+ def assert_contents(self, filename, contents):
+ self.assertTrue(os.path.exists(filename))
+ with open(filename, 'rb') as f:
+ self.assertEqual(f.read(), contents)
+
+ def assert_does_not_exist(self, filename):
+ self.assertFalse(os.path.exists(filename))
+
+ def test_run_is_final_job(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.add_stubbed_get_object_response()
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_does_not_exist(self.temp_filename)
+ self.assert_contents(self.final_filename, self.remote_contents)
+
+ def test_run_jobs_is_not_final_job(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.add_stubbed_get_object_response()
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1000
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_contents(self.temp_filename, self.remote_contents)
+ self.assert_does_not_exist(self.final_filename)
+
+ def test_run_with_extra_args(self):
+ self.add_get_object_job(extra_args={'VersionId': 'versionid'})
+ self.add_shutdown()
+ self.add_stubbed_get_object_response(
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'VersionId': 'versionid',
+ }
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+
+ def test_run_with_offset(self):
+ offset = 1
+ self.add_get_object_job(offset=offset)
+ self.add_shutdown()
+ self.add_stubbed_get_object_response()
+
+ self.worker.run()
+ with open(self.temp_filename, 'rb') as f:
+ f.seek(offset)
+ self.assertEqual(f.read(), self.remote_contents)
+
+ def test_run_error_in_get_object(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.stubber.add_client_error('get_object', 'NoSuchKey', 404)
+ self.add_stubbed_get_object_response()
+
+ self.worker.run()
+ self.assertIsInstance(
+ self.transfer_monitor.get_exception(self.transfer_id), ClientError
+ )
+
+ def test_run_does_retries_for_get_object(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.add_stubbed_get_object_response(
+ body=StreamWithError(
+ self.stream, ReadTimeoutError(endpoint_url='')
+ )
+ )
+ self.add_stubbed_get_object_response()
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_contents(self.temp_filename, self.remote_contents)
+
+ def test_run_can_exhaust_retries_for_get_object(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ # 5 is the current setting for max number of GetObject attempts
+ for _ in range(5):
+ self.add_stubbed_get_object_response(
+ body=StreamWithError(
+ self.stream, ReadTimeoutError(endpoint_url='')
+ )
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assertIsInstance(
+ self.transfer_monitor.get_exception(self.transfer_id),
+ RetriesExceededError,
+ )
+
+ def test_run_skips_get_object_on_previous_exception(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.transfer_monitor.notify_exception(self.transfer_id, Exception())
+
+ self.worker.run()
+ # Note we did not add a stubbed response for get_object
+ self.stubber.assert_no_pending_responses()
+
+ def test_run_final_job_removes_file_on_previous_exception(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.transfer_monitor.notify_exception(self.transfer_id, Exception())
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_does_not_exist(self.temp_filename)
+ self.assert_does_not_exist(self.final_filename)
+
+ def test_run_fails_to_rename_file(self):
+ exception = OSError()
+ osutil = RenameFailingOSUtils(exception)
+ self.worker = GetObjectWorker(
+ queue=self.queue,
+ client_factory=self.client_factory,
+ transfer_monitor=self.transfer_monitor,
+ osutil=osutil,
+ )
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.add_stubbed_get_object_response()
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1
+ )
+
+ self.worker.run()
+ self.assertEqual(
+ self.transfer_monitor.get_exception(self.transfer_id), exception
+ )
+ self.assert_does_not_exist(self.temp_filename)
+ self.assert_does_not_exist(self.final_filename)
+
+ @skip_if_windows('os.kill() with SIGINT not supported on Windows')
+ def test_worker_cannot_be_killed(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1
+ )
+
+ def raise_ctrl_c(**kwargs):
+ os.kill(os.getpid(), signal.SIGINT)
+
+ mock_client = mock.Mock()
+ mock_client.get_object = raise_ctrl_c
+ self.client_factory.create_client.return_value = mock_client
+
+ try:
+ self.worker.run()
+ except KeyboardInterrupt:
+ self.fail(
+ 'The worker should have not been killed by the '
+ 'KeyboardInterrupt'
+ )
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py b/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py
new file mode 100644
index 0000000000..35cf4a22dd
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py
@@ -0,0 +1,780 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import socket
+import tempfile
+from concurrent import futures
+from contextlib import closing
+from io import BytesIO, StringIO
+
+from s3transfer import (
+ MultipartDownloader,
+ MultipartUploader,
+ OSUtils,
+ QueueShutdownError,
+ ReadFileChunk,
+ S3Transfer,
+ ShutdownQueue,
+ StreamReaderProgress,
+ TransferConfig,
+ disable_upload_callbacks,
+ enable_upload_callbacks,
+ random_file_extension,
+)
+from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError
+from __tests__ import mock, unittest
+
+
+class InMemoryOSLayer(OSUtils):
+ def __init__(self, filemap):
+ self.filemap = filemap
+
+ def get_file_size(self, filename):
+ return len(self.filemap[filename])
+
+ def open_file_chunk_reader(self, filename, start_byte, size, callback):
+ return closing(BytesIO(self.filemap[filename]))
+
+ def open(self, filename, mode):
+ if 'wb' in mode:
+ fileobj = BytesIO()
+ self.filemap[filename] = fileobj
+ return closing(fileobj)
+ else:
+ return closing(self.filemap[filename])
+
+ def remove_file(self, filename):
+ if filename in self.filemap:
+ del self.filemap[filename]
+
+ def rename_file(self, current_filename, new_filename):
+ if current_filename in self.filemap:
+ self.filemap[new_filename] = self.filemap.pop(current_filename)
+
+
+class SequentialExecutor:
+ def __init__(self, max_workers):
+ pass
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ pass
+
+ # The real map() interface actually takes *args, but we specifically do
+ # _not_ use this interface.
+ def map(self, function, args):
+ results = []
+ for arg in args:
+ results.append(function(arg))
+ return results
+
+ def submit(self, function):
+ future = futures.Future()
+ future.set_result(function())
+ return future
+
+
+class TestOSUtils(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def test_get_file_size(self):
+ with mock.patch('os.path.getsize') as m:
+ OSUtils().get_file_size('myfile')
+ m.assert_called_with('myfile')
+
+ def test_open_file_chunk_reader(self):
+ with mock.patch('s3transfer.ReadFileChunk') as m:
+ OSUtils().open_file_chunk_reader('myfile', 0, 100, None)
+ m.from_filename.assert_called_with(
+ 'myfile', 0, 100, None, enable_callback=False
+ )
+
+ def test_open_file(self):
+ fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w')
+ self.assertTrue(hasattr(fileobj, 'write'))
+
+ def test_remove_file_ignores_errors(self):
+ with mock.patch('os.remove') as remove:
+ remove.side_effect = OSError('fake error')
+ OSUtils().remove_file('foo')
+ remove.assert_called_with('foo')
+
+ def test_remove_file_proxies_remove_file(self):
+ with mock.patch('os.remove') as remove:
+ OSUtils().remove_file('foo')
+ remove.assert_called_with('foo')
+
+ def test_rename_file(self):
+ with mock.patch('s3transfer.compat.rename_file') as rename_file:
+ OSUtils().rename_file('foo', 'newfoo')
+ rename_file.assert_called_with('foo', 'newfoo')
+
+
+class TestReadFileChunk(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def test_read_entire_chunk(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=3
+ )
+ self.assertEqual(chunk.read(), b'one')
+ self.assertEqual(chunk.read(), b'')
+
+ def test_read_with_amount_size(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=11, chunk_size=4
+ )
+ self.assertEqual(chunk.read(1), b'f')
+ self.assertEqual(chunk.read(1), b'o')
+ self.assertEqual(chunk.read(1), b'u')
+ self.assertEqual(chunk.read(1), b'r')
+ self.assertEqual(chunk.read(1), b'')
+
+ def test_reset_stream_emulation(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=11, chunk_size=4
+ )
+ self.assertEqual(chunk.read(), b'four')
+ chunk.seek(0)
+ self.assertEqual(chunk.read(), b'four')
+
+ def test_read_past_end_of_file(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=36, chunk_size=100000
+ )
+ self.assertEqual(chunk.read(), b'ten')
+ self.assertEqual(chunk.read(), b'')
+ self.assertEqual(len(chunk), 3)
+
+ def test_tell_and_seek(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=36, chunk_size=100000
+ )
+ self.assertEqual(chunk.tell(), 0)
+ self.assertEqual(chunk.read(), b'ten')
+ self.assertEqual(chunk.tell(), 3)
+ chunk.seek(0)
+ self.assertEqual(chunk.tell(), 0)
+
+ def test_file_chunk_supports_context_manager(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'abc')
+ with ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=2
+ ) as chunk:
+ val = chunk.read()
+ self.assertEqual(val, b'ab')
+
+ def test_iter_is_always_empty(self):
+ # This tests the workaround for the httplib bug (see
+ # the source for more info).
+ filename = os.path.join(self.tempdir, 'foo')
+ open(filename, 'wb').close()
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=10
+ )
+ self.assertEqual(list(chunk), [])
+
+
+class TestReadFileChunkWithCallback(TestReadFileChunk):
+ def setUp(self):
+ super().setUp()
+ self.filename = os.path.join(self.tempdir, 'foo')
+ with open(self.filename, 'wb') as f:
+ f.write(b'abc')
+ self.amounts_seen = []
+
+ def callback(self, amount):
+ self.amounts_seen.append(amount)
+
+ def test_callback_is_invoked_on_read(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename, start_byte=0, chunk_size=3, callback=self.callback
+ )
+ chunk.read(1)
+ chunk.read(1)
+ chunk.read(1)
+ self.assertEqual(self.amounts_seen, [1, 1, 1])
+
+ def test_callback_can_be_disabled(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename, start_byte=0, chunk_size=3, callback=self.callback
+ )
+ chunk.disable_callback()
+ # Now reading from the ReadFileChunk should not invoke
+ # the callback.
+ chunk.read()
+ self.assertEqual(self.amounts_seen, [])
+
+ def test_callback_will_also_be_triggered_by_seek(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename, start_byte=0, chunk_size=3, callback=self.callback
+ )
+ chunk.read(2)
+ chunk.seek(0)
+ chunk.read(2)
+ chunk.seek(1)
+ chunk.read(2)
+ self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2])
+
+
+class TestStreamReaderProgress(unittest.TestCase):
+ def test_proxies_to_wrapped_stream(self):
+ original_stream = StringIO('foobarbaz')
+ wrapped = StreamReaderProgress(original_stream)
+ self.assertEqual(wrapped.read(), 'foobarbaz')
+
+ def test_callback_invoked(self):
+ amounts_seen = []
+
+ def callback(amount):
+ amounts_seen.append(amount)
+
+ original_stream = StringIO('foobarbaz')
+ wrapped = StreamReaderProgress(original_stream, callback)
+ self.assertEqual(wrapped.read(), 'foobarbaz')
+ self.assertEqual(amounts_seen, [9])
+
+
+class TestMultipartUploader(unittest.TestCase):
+ def test_multipart_upload_uses_correct_client_calls(self):
+ client = mock.Mock()
+ uploader = MultipartUploader(
+ client,
+ TransferConfig(),
+ InMemoryOSLayer({'filename': b'foobar'}),
+ SequentialExecutor,
+ )
+ client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
+ client.upload_part.return_value = {'ETag': 'first'}
+
+ uploader.upload_file('filename', 'bucket', 'key', None, {})
+
+ # We need to check both the sequence of calls (create/upload/complete)
+ # as well as the params passed between the calls, including
+ # 1. The upload_id was plumbed through
+ # 2. The collected etags were added to the complete call.
+ client.create_multipart_upload.assert_called_with(
+ Bucket='bucket', Key='key'
+ )
+ # Should be two parts.
+ client.upload_part.assert_called_with(
+ Body=mock.ANY,
+ Bucket='bucket',
+ UploadId='upload_id',
+ Key='key',
+ PartNumber=1,
+ )
+ client.complete_multipart_upload.assert_called_with(
+ MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]},
+ Bucket='bucket',
+ UploadId='upload_id',
+ Key='key',
+ )
+
+ def test_multipart_upload_injects_proper_kwargs(self):
+ client = mock.Mock()
+ uploader = MultipartUploader(
+ client,
+ TransferConfig(),
+ InMemoryOSLayer({'filename': b'foobar'}),
+ SequentialExecutor,
+ )
+ client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
+ client.upload_part.return_value = {'ETag': 'first'}
+
+ extra_args = {
+ 'SSECustomerKey': 'fakekey',
+ 'SSECustomerAlgorithm': 'AES256',
+ 'StorageClass': 'REDUCED_REDUNDANCY',
+ }
+ uploader.upload_file('filename', 'bucket', 'key', None, extra_args)
+
+ client.create_multipart_upload.assert_called_with(
+ Bucket='bucket',
+ Key='key',
+ # The initial call should inject all the storage class params.
+ SSECustomerKey='fakekey',
+ SSECustomerAlgorithm='AES256',
+ StorageClass='REDUCED_REDUNDANCY',
+ )
+ client.upload_part.assert_called_with(
+ Body=mock.ANY,
+ Bucket='bucket',
+ UploadId='upload_id',
+ Key='key',
+ PartNumber=1,
+ # We only have to forward certain **extra_args in subsequent
+ # UploadPart calls.
+ SSECustomerKey='fakekey',
+ SSECustomerAlgorithm='AES256',
+ )
+ client.complete_multipart_upload.assert_called_with(
+ MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]},
+ Bucket='bucket',
+ UploadId='upload_id',
+ Key='key',
+ )
+
+ def test_multipart_upload_is_aborted_on_error(self):
+ # If the create_multipart_upload succeeds and any upload_part
+ # fails, then abort_multipart_upload will be called.
+ client = mock.Mock()
+ uploader = MultipartUploader(
+ client,
+ TransferConfig(),
+ InMemoryOSLayer({'filename': b'foobar'}),
+ SequentialExecutor,
+ )
+ client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
+ client.upload_part.side_effect = Exception(
+ "Some kind of error occurred."
+ )
+
+ with self.assertRaises(S3UploadFailedError):
+ uploader.upload_file('filename', 'bucket', 'key', None, {})
+
+ client.abort_multipart_upload.assert_called_with(
+ Bucket='bucket', Key='key', UploadId='upload_id'
+ )
+
+
+class TestMultipartDownloader(unittest.TestCase):
+
+ maxDiff = None
+
+ def test_multipart_download_uses_correct_client_calls(self):
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ client.get_object.return_value = {'Body': BytesIO(response_body)}
+
+ downloader = MultipartDownloader(
+ client, TransferConfig(), InMemoryOSLayer({}), SequentialExecutor
+ )
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ client.get_object.assert_called_with(
+ Range='bytes=0-', Bucket='bucket', Key='key'
+ )
+
+ def test_multipart_download_with_multiple_parts(self):
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ client.get_object.return_value = {'Body': BytesIO(response_body)}
+ # For testing purposes, we're testing with a multipart threshold
+ # of 4 bytes and a chunksize of 4 bytes. Given b'foobarbaz',
+ # this should result in 3 calls. In python slices this would be:
+ # r[0:4], r[4:8], r[8:9]. But the Range param will be slightly
+ # different because they use inclusive ranges.
+ config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
+
+ downloader = MultipartDownloader(
+ client, config, InMemoryOSLayer({}), SequentialExecutor
+ )
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ # We're storing these in **extra because the assertEqual
+ # below is really about verifying we have the correct value
+ # for the Range param.
+ extra = {'Bucket': 'bucket', 'Key': 'key'}
+ self.assertEqual(
+ client.get_object.call_args_list,
+ # Note these are inclusive ranges.
+ [
+ mock.call(Range='bytes=0-3', **extra),
+ mock.call(Range='bytes=4-7', **extra),
+ mock.call(Range='bytes=8-', **extra),
+ ],
+ )
+
+ def test_retry_on_failures_from_stream_reads(self):
+ # If we get an exception during a call to the response body's .read()
+ # method, we should retry the request.
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ stream_with_errors = mock.Mock()
+ stream_with_errors.read.side_effect = [
+ socket.error("fake error"),
+ response_body,
+ ]
+ client.get_object.return_value = {'Body': stream_with_errors}
+ config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
+
+ downloader = MultipartDownloader(
+ client, config, InMemoryOSLayer({}), SequentialExecutor
+ )
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ # We're storing these in **extra because the assertEqual
+ # below is really about verifying we have the correct value
+ # for the Range param.
+ extra = {'Bucket': 'bucket', 'Key': 'key'}
+ self.assertEqual(
+ client.get_object.call_args_list,
+ # The first call to range=0-3 fails because of the
+ # side_effect above where we make the .read() raise a
+ # socket.error.
+ # The second call to range=0-3 then succeeds.
+ [
+ mock.call(Range='bytes=0-3', **extra),
+ mock.call(Range='bytes=0-3', **extra),
+ mock.call(Range='bytes=4-7', **extra),
+ mock.call(Range='bytes=8-', **extra),
+ ],
+ )
+
+ def test_exception_raised_on_exceeded_retries(self):
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ stream_with_errors = mock.Mock()
+ stream_with_errors.read.side_effect = socket.error("fake error")
+ client.get_object.return_value = {'Body': stream_with_errors}
+ config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
+
+ downloader = MultipartDownloader(
+ client, config, InMemoryOSLayer({}), SequentialExecutor
+ )
+ with self.assertRaises(RetriesExceededError):
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ def test_io_thread_failure_triggers_shutdown(self):
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ client.get_object.return_value = {'Body': BytesIO(response_body)}
+ os_layer = mock.Mock()
+ mock_fileobj = mock.MagicMock()
+ mock_fileobj.__enter__.return_value = mock_fileobj
+ mock_fileobj.write.side_effect = Exception("fake IO error")
+ os_layer.open.return_value = mock_fileobj
+
+ downloader = MultipartDownloader(
+ client, TransferConfig(), os_layer, SequentialExecutor
+ )
+ # We're verifying that the exception raised from the IO future
+ # propagates back up via download_file().
+ with self.assertRaisesRegex(Exception, "fake IO error"):
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ def test_download_futures_fail_triggers_shutdown(self):
+ class FailedDownloadParts(SequentialExecutor):
+ def __init__(self, max_workers):
+ self.is_first = True
+
+ def submit(self, function):
+ future = futures.Future()
+ if self.is_first:
+ # This is the download_parts_thread.
+ future.set_exception(
+ Exception("fake download parts error")
+ )
+ self.is_first = False
+ return future
+
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ client.get_object.return_value = {'Body': BytesIO(response_body)}
+
+ downloader = MultipartDownloader(
+ client, TransferConfig(), InMemoryOSLayer({}), FailedDownloadParts
+ )
+ with self.assertRaisesRegex(Exception, "fake download parts error"):
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+
+class TestS3Transfer(unittest.TestCase):
+ def setUp(self):
+ self.client = mock.Mock()
+ self.random_file_patch = mock.patch('s3transfer.random_file_extension')
+ self.random_file = self.random_file_patch.start()
+ self.random_file.return_value = 'RANDOM'
+
+ def tearDown(self):
+ self.random_file_patch.stop()
+
+ def test_callback_handlers_register_on_put_item(self):
+ osutil = InMemoryOSLayer({'smallfile': b'foobar'})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ transfer.upload_file('smallfile', 'bucket', 'key')
+ events = self.client.meta.events
+ events.register_first.assert_called_with(
+ 'request-created.s3',
+ disable_upload_callbacks,
+ unique_id='s3upload-callback-disable',
+ )
+ events.register_last.assert_called_with(
+ 'request-created.s3',
+ enable_upload_callbacks,
+ unique_id='s3upload-callback-enable',
+ )
+
+ def test_upload_below_multipart_threshold_uses_put_object(self):
+ fake_files = {
+ 'smallfile': b'foobar',
+ }
+ osutil = InMemoryOSLayer(fake_files)
+ transfer = S3Transfer(self.client, osutil=osutil)
+ transfer.upload_file('smallfile', 'bucket', 'key')
+ self.client.put_object.assert_called_with(
+ Bucket='bucket', Key='key', Body=mock.ANY
+ )
+
+ def test_extra_args_on_uploaded_passed_to_api_call(self):
+ extra_args = {'ACL': 'public-read'}
+ fake_files = {'smallfile': b'hello world'}
+ osutil = InMemoryOSLayer(fake_files)
+ transfer = S3Transfer(self.client, osutil=osutil)
+ transfer.upload_file(
+ 'smallfile', 'bucket', 'key', extra_args=extra_args
+ )
+ self.client.put_object.assert_called_with(
+ Bucket='bucket', Key='key', Body=mock.ANY, ACL='public-read'
+ )
+
+ def test_uses_multipart_upload_when_over_threshold(self):
+ with mock.patch('s3transfer.MultipartUploader') as uploader:
+ fake_files = {
+ 'smallfile': b'foobar',
+ }
+ osutil = InMemoryOSLayer(fake_files)
+ config = TransferConfig(
+ multipart_threshold=2, multipart_chunksize=2
+ )
+ transfer = S3Transfer(self.client, osutil=osutil, config=config)
+ transfer.upload_file('smallfile', 'bucket', 'key')
+
+ uploader.return_value.upload_file.assert_called_with(
+ 'smallfile', 'bucket', 'key', None, {}
+ )
+
+ def test_uses_multipart_download_when_over_threshold(self):
+ with mock.patch('s3transfer.MultipartDownloader') as downloader:
+ osutil = InMemoryOSLayer({})
+ over_multipart_threshold = 100 * 1024 * 1024
+ transfer = S3Transfer(self.client, osutil=osutil)
+ callback = mock.sentinel.CALLBACK
+ self.client.head_object.return_value = {
+ 'ContentLength': over_multipart_threshold,
+ }
+ transfer.download_file(
+ 'bucket', 'key', 'filename', callback=callback
+ )
+
+ downloader.return_value.download_file.assert_called_with(
+ # Note how we're downloading to a temporary random file.
+ 'bucket',
+ 'key',
+ 'filename.RANDOM',
+ over_multipart_threshold,
+ {},
+ callback,
+ )
+
+ def test_download_file_with_invalid_extra_args(self):
+ below_threshold = 20
+ osutil = InMemoryOSLayer({})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ with self.assertRaises(ValueError):
+ transfer.download_file(
+ 'bucket',
+ 'key',
+ '/tmp/smallfile',
+ extra_args={'BadValue': 'foo'},
+ )
+
+ def test_upload_file_with_invalid_extra_args(self):
+ osutil = InMemoryOSLayer({})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ bad_args = {"WebsiteRedirectLocation": "/foo"}
+ with self.assertRaises(ValueError):
+ transfer.upload_file(
+ 'bucket', 'key', '/tmp/smallfile', extra_args=bad_args
+ )
+
+ def test_download_file_fowards_extra_args(self):
+ extra_args = {
+ 'SSECustomerKey': 'foo',
+ 'SSECustomerAlgorithm': 'AES256',
+ }
+ below_threshold = 20
+ osutil = InMemoryOSLayer({'smallfile': b'hello world'})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ self.client.get_object.return_value = {'Body': BytesIO(b'foobar')}
+ transfer.download_file(
+ 'bucket', 'key', '/tmp/smallfile', extra_args=extra_args
+ )
+
+ # Note that we need to invoke the HeadObject call
+ # and the PutObject call with the extra_args.
+ # This is necessary. Trying to HeadObject an SSE object
+ # will return a 400 if you don't provide the required
+ # params.
+ self.client.get_object.assert_called_with(
+ Bucket='bucket',
+ Key='key',
+ SSECustomerAlgorithm='AES256',
+ SSECustomerKey='foo',
+ )
+
+ def test_get_object_stream_is_retried_and_succeeds(self):
+ below_threshold = 20
+ osutil = InMemoryOSLayer({'smallfile': b'hello world'})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ self.client.get_object.side_effect = [
+ # First request fails.
+ socket.error("fake error"),
+ # Second succeeds.
+ {'Body': BytesIO(b'foobar')},
+ ]
+ transfer.download_file('bucket', 'key', '/tmp/smallfile')
+
+ self.assertEqual(self.client.get_object.call_count, 2)
+
+ def test_get_object_stream_uses_all_retries_and_errors_out(self):
+ below_threshold = 20
+ osutil = InMemoryOSLayer({})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ # Here we're raising an exception every single time, which
+ # will exhaust our retry count and propagate a
+ # RetriesExceededError.
+ self.client.get_object.side_effect = socket.error("fake error")
+ with self.assertRaises(RetriesExceededError):
+ transfer.download_file('bucket', 'key', 'smallfile')
+
+ self.assertEqual(self.client.get_object.call_count, 5)
+ # We should have also cleaned up the in progress file
+ # we were downloading to.
+ self.assertEqual(osutil.filemap, {})
+
+ def test_download_below_multipart_threshold(self):
+ below_threshold = 20
+ osutil = InMemoryOSLayer({'smallfile': b'hello world'})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ self.client.get_object.return_value = {'Body': BytesIO(b'foobar')}
+ transfer.download_file('bucket', 'key', 'smallfile')
+
+ self.client.get_object.assert_called_with(Bucket='bucket', Key='key')
+
+ def test_can_create_with_just_client(self):
+ transfer = S3Transfer(client=mock.Mock())
+ self.assertIsInstance(transfer, S3Transfer)
+
+
+class TestShutdownQueue(unittest.TestCase):
+ def test_handles_normal_put_get_requests(self):
+ q = ShutdownQueue()
+ q.put('foo')
+ self.assertEqual(q.get(), 'foo')
+
+ def test_put_raises_error_on_shutdown(self):
+ q = ShutdownQueue()
+ q.trigger_shutdown()
+ with self.assertRaises(QueueShutdownError):
+ q.put('foo')
+
+
+class TestRandomFileExtension(unittest.TestCase):
+ def test_has_proper_length(self):
+ self.assertEqual(len(random_file_extension(num_digits=4)), 4)
+
+
+class TestCallbackHandlers(unittest.TestCase):
+ def setUp(self):
+ self.request = mock.Mock()
+
+ def test_disable_request_on_put_object(self):
+ disable_upload_callbacks(self.request, 'PutObject')
+ self.request.body.disable_callback.assert_called_with()
+
+ def test_disable_request_on_upload_part(self):
+ disable_upload_callbacks(self.request, 'UploadPart')
+ self.request.body.disable_callback.assert_called_with()
+
+ def test_enable_object_on_put_object(self):
+ enable_upload_callbacks(self.request, 'PutObject')
+ self.request.body.enable_callback.assert_called_with()
+
+ def test_enable_object_on_upload_part(self):
+ enable_upload_callbacks(self.request, 'UploadPart')
+ self.request.body.enable_callback.assert_called_with()
+
+ def test_dont_disable_if_missing_interface(self):
+ del self.request.body.disable_callback
+ disable_upload_callbacks(self.request, 'PutObject')
+ self.assertEqual(self.request.body.method_calls, [])
+
+ def test_dont_enable_if_missing_interface(self):
+ del self.request.body.enable_callback
+ enable_upload_callbacks(self.request, 'PutObject')
+ self.assertEqual(self.request.body.method_calls, [])
+
+ def test_dont_disable_if_wrong_operation(self):
+ disable_upload_callbacks(self.request, 'OtherOperation')
+ self.assertFalse(self.request.body.disable_callback.called)
+
+ def test_dont_enable_if_wrong_operation(self):
+ enable_upload_callbacks(self.request, 'OtherOperation')
+ self.assertFalse(self.request.body.enable_callback.called)
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py b/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py
new file mode 100644
index 0000000000..a26d3a548c
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py
@@ -0,0 +1,91 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.exceptions import InvalidSubscriberMethodError
+from s3transfer.subscribers import BaseSubscriber
+from __tests__ import unittest
+
+
+class ExtraMethodsSubscriber(BaseSubscriber):
+ def extra_method(self):
+ return 'called extra method'
+
+
+class NotCallableSubscriber(BaseSubscriber):
+ on_done = 'foo'
+
+
+class NoKwargsSubscriber(BaseSubscriber):
+ def on_done(self):
+ pass
+
+
+class OverrideMethodSubscriber(BaseSubscriber):
+ def on_queued(self, **kwargs):
+ return kwargs
+
+
+class OverrideConstructorSubscriber(BaseSubscriber):
+ def __init__(self, arg1, arg2):
+ self.arg1 = arg1
+ self.arg2 = arg2
+
+
+class TestSubscribers(unittest.TestCase):
+ def test_can_instantiate_base_subscriber(self):
+ try:
+ BaseSubscriber()
+ except InvalidSubscriberMethodError:
+ self.fail('BaseSubscriber should be instantiable')
+
+ def test_can_call_base_subscriber_method(self):
+ subscriber = BaseSubscriber()
+ try:
+ subscriber.on_done(future=None)
+ except Exception as e:
+ self.fail(
+ 'Should be able to call base class subscriber method. '
+ 'instead got: %s' % e
+ )
+
+ def test_subclass_can_have_and_call_additional_methods(self):
+ subscriber = ExtraMethodsSubscriber()
+ self.assertEqual(subscriber.extra_method(), 'called extra method')
+
+ def test_can_subclass_and_override_method_from_base_subscriber(self):
+ subscriber = OverrideMethodSubscriber()
+ # Make sure that the overridden method is called
+ self.assertEqual(subscriber.on_queued(foo='bar'), {'foo': 'bar'})
+
+ def test_can_subclass_and_override_constructor_from_base_class(self):
+ subscriber = OverrideConstructorSubscriber('foo', arg2='bar')
+ # Make sure you can create a custom constructor.
+ self.assertEqual(subscriber.arg1, 'foo')
+ self.assertEqual(subscriber.arg2, 'bar')
+
+ def test_invalid_arguments_in_constructor_of_subclass_subscriber(self):
+ # The override constructor should still have validation of
+ # constructor args.
+ with self.assertRaises(TypeError):
+ OverrideConstructorSubscriber()
+
+ def test_not_callable_in_subclass_subscriber_method(self):
+ with self.assertRaisesRegex(
+ InvalidSubscriberMethodError, 'must be callable'
+ ):
+ NotCallableSubscriber()
+
+ def test_no_kwargs_in_subclass_subscriber_method(self):
+ with self.assertRaisesRegex(
+ InvalidSubscriberMethodError, 'must accept keyword'
+ ):
+ NoKwargsSubscriber()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_tasks.py b/contrib/python/s3transfer/py3/tests/unit/test_tasks.py
new file mode 100644
index 0000000000..4f0bc4d1cc
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_tasks.py
@@ -0,0 +1,833 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from concurrent import futures
+from functools import partial
+from threading import Event
+
+from s3transfer.futures import BoundedExecutor, TransferCoordinator
+from s3transfer.subscribers import BaseSubscriber
+from s3transfer.tasks import (
+ CompleteMultipartUploadTask,
+ CreateMultipartUploadTask,
+ SubmissionTask,
+ Task,
+)
+from s3transfer.utils import CallArgs, FunctionContainer, get_callbacks
+from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ RecordingSubscriber,
+ unittest,
+)
+
+
+class TaskFailureException(Exception):
+ pass
+
+
+class SuccessTask(Task):
+ def _main(
+ self, return_value='success', callbacks=None, failure_cleanups=None
+ ):
+ if callbacks:
+ for callback in callbacks:
+ callback()
+ if failure_cleanups:
+ for failure_cleanup in failure_cleanups:
+ self._transfer_coordinator.add_failure_cleanup(failure_cleanup)
+ return return_value
+
+
+class FailureTask(Task):
+ def _main(self, exception=TaskFailureException):
+ raise exception()
+
+
+class ReturnKwargsTask(Task):
+ def _main(self, **kwargs):
+ return kwargs
+
+
+class SubmitMoreTasksTask(Task):
+ def _main(self, executor, tasks_to_submit):
+ for task_to_submit in tasks_to_submit:
+ self._transfer_coordinator.submit(executor, task_to_submit)
+
+
+class NOOPSubmissionTask(SubmissionTask):
+ def _submit(self, transfer_future, **kwargs):
+ pass
+
+
+class ExceptionSubmissionTask(SubmissionTask):
+ def _submit(
+ self,
+ transfer_future,
+ executor=None,
+ tasks_to_submit=None,
+ additional_callbacks=None,
+ exception=TaskFailureException,
+ ):
+ if executor and tasks_to_submit:
+ for task_to_submit in tasks_to_submit:
+ self._transfer_coordinator.submit(executor, task_to_submit)
+ if additional_callbacks:
+ for callback in additional_callbacks:
+ callback()
+ raise exception()
+
+
+class StatusRecordingTransferCoordinator(TransferCoordinator):
+ def __init__(self, transfer_id=None):
+ super().__init__(transfer_id)
+ self.status_changes = [self._status]
+
+ def set_status_to_queued(self):
+ super().set_status_to_queued()
+ self._record_status_change()
+
+ def set_status_to_running(self):
+ super().set_status_to_running()
+ self._record_status_change()
+
+ def _record_status_change(self):
+ self.status_changes.append(self._status)
+
+
+class RecordingStateSubscriber(BaseSubscriber):
+ def __init__(self, transfer_coordinator):
+ self._transfer_coordinator = transfer_coordinator
+ self.status_during_on_queued = None
+
+ def on_queued(self, **kwargs):
+ self.status_during_on_queued = self._transfer_coordinator.status
+
+
+class TestSubmissionTask(BaseSubmissionTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.executor = BoundedExecutor(1000, 5)
+ self.call_args = CallArgs(subscribers=[])
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.main_kwargs = {'transfer_future': self.transfer_future}
+
+ def test_transitions_from_not_started_to_queued_to_running(self):
+ self.transfer_coordinator = StatusRecordingTransferCoordinator()
+ submission_task = self.get_task(
+ NOOPSubmissionTask, main_kwargs=self.main_kwargs
+ )
+ # Status should be queued until submission task has been ran.
+ self.assertEqual(self.transfer_coordinator.status, 'not-started')
+
+ submission_task()
+ # Once submission task has been ran, the status should now be running.
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ # Ensure the transitions were as expected as well.
+ self.assertEqual(
+ self.transfer_coordinator.status_changes,
+ ['not-started', 'queued', 'running'],
+ )
+
+ def test_on_queued_callbacks(self):
+ submission_task = self.get_task(
+ NOOPSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ subscriber = RecordingSubscriber()
+ self.call_args.subscribers.append(subscriber)
+ submission_task()
+ # Make sure the on_queued callback of the subscriber is called.
+ self.assertEqual(
+ subscriber.on_queued_calls, [{'future': self.transfer_future}]
+ )
+
+ def test_on_queued_status_in_callbacks(self):
+ submission_task = self.get_task(
+ NOOPSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ subscriber = RecordingStateSubscriber(self.transfer_coordinator)
+ self.call_args.subscribers.append(subscriber)
+ submission_task()
+ # Make sure the status was queued during on_queued callback.
+ self.assertEqual(subscriber.status_during_on_queued, 'queued')
+
+ def test_sets_exception_from_submit(self):
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+ submission_task()
+
+ # Make sure the status of the future is failed
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the future propagates the exception encountered in the
+ # submission task.
+ with self.assertRaises(TaskFailureException):
+ self.transfer_future.result()
+
+ def test_catches_and_sets_keyboard_interrupt_exception_from_submit(self):
+ self.main_kwargs['exception'] = KeyboardInterrupt
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+ submission_task()
+
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+ with self.assertRaises(KeyboardInterrupt):
+ self.transfer_future.result()
+
+ def test_calls_done_callbacks_on_exception(self):
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ subscriber = RecordingSubscriber()
+ self.call_args.subscribers.append(subscriber)
+
+ # Add the done callback to the callbacks to be invoked when the
+ # transfer is done.
+ done_callbacks = get_callbacks(self.transfer_future, 'done')
+ for done_callback in done_callbacks:
+ self.transfer_coordinator.add_done_callback(done_callback)
+ submission_task()
+
+ # Make sure the task failed to start
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the on_done callback of the subscriber is called.
+ self.assertEqual(
+ subscriber.on_done_calls, [{'future': self.transfer_future}]
+ )
+
+ def test_calls_failure_cleanups_on_exception(self):
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ # Add the callback to the callbacks to be invoked when the
+ # transfer fails.
+ invocations_of_cleanup = []
+ cleanup_callback = FunctionContainer(
+ invocations_of_cleanup.append, 'cleanup happened'
+ )
+ self.transfer_coordinator.add_failure_cleanup(cleanup_callback)
+ submission_task()
+
+ # Make sure the task failed to start
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the cleanup was called.
+ self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
+
+ def test_cleanups_only_ran_once_on_exception(self):
+ # We want to be able to handle the case where the final task completes
+ # and anounces done but there is an error in the submission task
+ # which will cause it to need to announce done as well. In this case,
+ # we do not want the done callbacks to be invoke more than once.
+
+ final_task = self.get_task(FailureTask, is_final=True)
+ self.main_kwargs['executor'] = self.executor
+ self.main_kwargs['tasks_to_submit'] = [final_task]
+
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ subscriber = RecordingSubscriber()
+ self.call_args.subscribers.append(subscriber)
+
+ # Add the done callback to the callbacks to be invoked when the
+ # transfer is done.
+ done_callbacks = get_callbacks(self.transfer_future, 'done')
+ for done_callback in done_callbacks:
+ self.transfer_coordinator.add_done_callback(done_callback)
+
+ submission_task()
+
+ # Make sure the task failed to start
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the on_done callback of the subscriber is called only once.
+ self.assertEqual(
+ subscriber.on_done_calls, [{'future': self.transfer_future}]
+ )
+
+ def test_done_callbacks_only_ran_once_on_exception(self):
+ # We want to be able to handle the case where the final task completes
+ # and anounces done but there is an error in the submission task
+ # which will cause it to need to announce done as well. In this case,
+ # we do not want the failure cleanups to be invoked more than once.
+
+ final_task = self.get_task(FailureTask, is_final=True)
+ self.main_kwargs['executor'] = self.executor
+ self.main_kwargs['tasks_to_submit'] = [final_task]
+
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ # Add the callback to the callbacks to be invoked when the
+ # transfer fails.
+ invocations_of_cleanup = []
+ cleanup_callback = FunctionContainer(
+ invocations_of_cleanup.append, 'cleanup happened'
+ )
+ self.transfer_coordinator.add_failure_cleanup(cleanup_callback)
+ submission_task()
+
+ # Make sure the task failed to start
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the cleanup was called only once.
+ self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
+
+ def test_handles_cleanups_submitted_in_other_tasks(self):
+ invocations_of_cleanup = []
+ event = Event()
+ cleanup_callback = FunctionContainer(
+ invocations_of_cleanup.append, 'cleanup happened'
+ )
+ # We want the cleanup to be added in the execution of the task and
+ # still be executed by the submission task when it fails.
+ task = self.get_task(
+ SuccessTask,
+ main_kwargs={
+ 'callbacks': [event.set],
+ 'failure_cleanups': [cleanup_callback],
+ },
+ )
+
+ self.main_kwargs['executor'] = self.executor
+ self.main_kwargs['tasks_to_submit'] = [task]
+ self.main_kwargs['additional_callbacks'] = [event.wait]
+
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ submission_task()
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the cleanup was called even though the callback got
+ # added in a completely different task.
+ self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
+
+ def test_waits_for_tasks_submitted_by_other_tasks_on_exception(self):
+ # In this test, we want to make sure that any tasks that may be
+ # submitted in another task complete before we start performing
+ # cleanups.
+ #
+ # This is tested by doing the following:
+ #
+ # ExecutionSubmissionTask
+ # |
+ # +--submits-->SubmitMoreTasksTask
+ # |
+ # +--submits-->SuccessTask
+ # |
+ # +-->sleeps-->adds failure cleanup
+ #
+ # In the end, the failure cleanup of the SuccessTask should be ran
+ # when the ExecutionSubmissionTask fails. If the
+ # ExeceptionSubmissionTask did not run the failure cleanup it is most
+ # likely that it did not wait for the SuccessTask to complete, which
+ # it needs to because the ExeceptionSubmissionTask does not know
+ # what failure cleanups it needs to run until all spawned tasks have
+ # completed.
+ invocations_of_cleanup = []
+ event = Event()
+ cleanup_callback = FunctionContainer(
+ invocations_of_cleanup.append, 'cleanup happened'
+ )
+
+ cleanup_task = self.get_task(
+ SuccessTask,
+ main_kwargs={
+ 'callbacks': [event.set],
+ 'failure_cleanups': [cleanup_callback],
+ },
+ )
+ task_for_submitting_cleanup_task = self.get_task(
+ SubmitMoreTasksTask,
+ main_kwargs={
+ 'executor': self.executor,
+ 'tasks_to_submit': [cleanup_task],
+ },
+ )
+
+ self.main_kwargs['executor'] = self.executor
+ self.main_kwargs['tasks_to_submit'] = [
+ task_for_submitting_cleanup_task
+ ]
+ self.main_kwargs['additional_callbacks'] = [event.wait]
+
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ submission_task()
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+ self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
+
+ def test_submission_task_announces_done_if_cancelled_before_main(self):
+ invocations_of_done = []
+ done_callback = FunctionContainer(
+ invocations_of_done.append, 'done announced'
+ )
+ self.transfer_coordinator.add_done_callback(done_callback)
+
+ self.transfer_coordinator.cancel()
+ submission_task = self.get_task(
+ NOOPSubmissionTask, main_kwargs=self.main_kwargs
+ )
+ submission_task()
+
+ # Because the submission task was cancelled before being run
+ # it did not submit any extra tasks so a result it is responsible
+ # for making sure it announces done as nothing else will.
+ self.assertEqual(invocations_of_done, ['done announced'])
+
+
+class TestTask(unittest.TestCase):
+ def setUp(self):
+ self.transfer_id = 1
+ self.transfer_coordinator = TransferCoordinator(
+ transfer_id=self.transfer_id
+ )
+
+ def test_repr(self):
+ main_kwargs = {'bucket': 'mybucket', 'param_to_not_include': 'foo'}
+ task = ReturnKwargsTask(
+ self.transfer_coordinator, main_kwargs=main_kwargs
+ )
+ # The repr should not include the other parameter because it is not
+ # a desired parameter to include.
+ self.assertEqual(
+ repr(task),
+ 'ReturnKwargsTask(transfer_id={}, {})'.format(
+ self.transfer_id, {'bucket': 'mybucket'}
+ ),
+ )
+
+ def test_transfer_id(self):
+ task = SuccessTask(self.transfer_coordinator)
+ # Make sure that the id is the one provided to the id associated
+ # to the transfer coordinator.
+ self.assertEqual(task.transfer_id, self.transfer_id)
+
+ def test_context_status_transitioning_success(self):
+ # The status should be set to running.
+ self.transfer_coordinator.set_status_to_running()
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ # If a task is called, the status still should be running.
+ SuccessTask(self.transfer_coordinator)()
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ # Once the final task is called, the status should be set to success.
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+
+ def test_context_status_transitioning_failed(self):
+ self.transfer_coordinator.set_status_to_running()
+
+ SuccessTask(self.transfer_coordinator)()
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ # A failure task should result in the failed status
+ FailureTask(self.transfer_coordinator)()
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Even if the final task comes in and succeeds, it should stay failed.
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ def test_result_setting_for_success(self):
+ override_return = 'foo'
+ SuccessTask(self.transfer_coordinator)()
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': override_return},
+ is_final=True,
+ )()
+
+ # The return value for the transfer future should be of the final
+ # task.
+ self.assertEqual(self.transfer_coordinator.result(), override_return)
+
+ def test_result_setting_for_error(self):
+ FailureTask(self.transfer_coordinator)()
+
+ # If another failure comes in, the result should still throw the
+ # original exception when result() is eventually called.
+ FailureTask(
+ self.transfer_coordinator, main_kwargs={'exception': Exception}
+ )()
+
+ # Even if a success task comes along, the result of the future
+ # should be the original exception
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ with self.assertRaises(TaskFailureException):
+ self.transfer_coordinator.result()
+
+ def test_done_callbacks_success(self):
+ callback_results = []
+ SuccessTask(
+ self.transfer_coordinator,
+ done_callbacks=[
+ partial(callback_results.append, 'first'),
+ partial(callback_results.append, 'second'),
+ ],
+ )()
+ # For successful tasks, the done callbacks should get called.
+ self.assertEqual(callback_results, ['first', 'second'])
+
+ def test_done_callbacks_failure(self):
+ callback_results = []
+ FailureTask(
+ self.transfer_coordinator,
+ done_callbacks=[
+ partial(callback_results.append, 'first'),
+ partial(callback_results.append, 'second'),
+ ],
+ )()
+ # For even failed tasks, the done callbacks should get called.
+ self.assertEqual(callback_results, ['first', 'second'])
+
+ # Callbacks should continue to be called even after a related failure
+ SuccessTask(
+ self.transfer_coordinator,
+ done_callbacks=[
+ partial(callback_results.append, 'third'),
+ partial(callback_results.append, 'fourth'),
+ ],
+ )()
+ self.assertEqual(
+ callback_results, ['first', 'second', 'third', 'fourth']
+ )
+
+ def test_failure_cleanups_on_failure(self):
+ callback_results = []
+ self.transfer_coordinator.add_failure_cleanup(
+ callback_results.append, 'first'
+ )
+ self.transfer_coordinator.add_failure_cleanup(
+ callback_results.append, 'second'
+ )
+ FailureTask(self.transfer_coordinator)()
+ # The failure callbacks should have not been called yet because it
+ # is not the last task
+ self.assertEqual(callback_results, [])
+
+ # Now the failure callbacks should get called.
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ self.assertEqual(callback_results, ['first', 'second'])
+
+ def test_no_failure_cleanups_on_success(self):
+ callback_results = []
+ self.transfer_coordinator.add_failure_cleanup(
+ callback_results.append, 'first'
+ )
+ self.transfer_coordinator.add_failure_cleanup(
+ callback_results.append, 'second'
+ )
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ # The failure cleanups should not have been called because no task
+ # failed for the transfer context.
+ self.assertEqual(callback_results, [])
+
+ def test_passing_main_kwargs(self):
+ main_kwargs = {'foo': 'bar', 'baz': 'biz'}
+ ReturnKwargsTask(
+ self.transfer_coordinator, main_kwargs=main_kwargs, is_final=True
+ )()
+ # The kwargs should have been passed to the main()
+ self.assertEqual(self.transfer_coordinator.result(), main_kwargs)
+
+ def test_passing_pending_kwargs_single_futures(self):
+ pending_kwargs = {}
+ ref_main_kwargs = {'foo': 'bar', 'baz': 'biz'}
+
+ # Pass some tasks to an executor
+ with futures.ThreadPoolExecutor(1) as executor:
+ pending_kwargs['foo'] = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': ref_main_kwargs['foo']},
+ )
+ )
+ pending_kwargs['baz'] = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': ref_main_kwargs['baz']},
+ )
+ )
+
+ # Create a task that depends on the tasks passed to the executor
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+ # The result should have the pending keyword arg values flushed
+ # out.
+ self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs)
+
+ def test_passing_pending_kwargs_list_of_futures(self):
+ pending_kwargs = {}
+ ref_main_kwargs = {'foo': ['first', 'second']}
+
+ # Pass some tasks to an executor
+ with futures.ThreadPoolExecutor(1) as executor:
+ first_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': ref_main_kwargs['foo'][0]},
+ )
+ )
+ second_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': ref_main_kwargs['foo'][1]},
+ )
+ )
+ # Make the pending keyword arg value a list
+ pending_kwargs['foo'] = [first_future, second_future]
+
+ # Create a task that depends on the tasks passed to the executor
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+ # The result should have the pending keyword arg values flushed
+ # out in the expected order.
+ self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs)
+
+ def test_passing_pending_and_non_pending_kwargs(self):
+ main_kwargs = {'nonpending_value': 'foo'}
+ pending_kwargs = {}
+ ref_main_kwargs = {
+ 'nonpending_value': 'foo',
+ 'pending_value': 'bar',
+ 'pending_list': ['first', 'second'],
+ }
+
+ # Create the pending tasks
+ with futures.ThreadPoolExecutor(1) as executor:
+ pending_kwargs['pending_value'] = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={
+ 'return_value': ref_main_kwargs['pending_value']
+ },
+ )
+ )
+
+ first_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={
+ 'return_value': ref_main_kwargs['pending_list'][0]
+ },
+ )
+ )
+ second_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={
+ 'return_value': ref_main_kwargs['pending_list'][1]
+ },
+ )
+ )
+ # Make the pending keyword arg value a list
+ pending_kwargs['pending_list'] = [first_future, second_future]
+
+ # Create a task that depends on the tasks passed to the executor
+ # and just regular nonpending kwargs.
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ main_kwargs=main_kwargs,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+ # The result should have all of the kwargs (both pending and
+ # nonpending)
+ self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs)
+
+ def test_single_failed_pending_future(self):
+ pending_kwargs = {}
+
+ # Pass some tasks to an executor. Make one successful and the other
+ # a failure.
+ with futures.ThreadPoolExecutor(1) as executor:
+ pending_kwargs['foo'] = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': 'bar'},
+ )
+ )
+ pending_kwargs['baz'] = executor.submit(
+ FailureTask(self.transfer_coordinator)
+ )
+
+ # Create a task that depends on the tasks passed to the executor
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+ # The end result should raise the exception from the initial
+ # pending future value
+ with self.assertRaises(TaskFailureException):
+ self.transfer_coordinator.result()
+
+ def test_single_failed_pending_future_in_list(self):
+ pending_kwargs = {}
+
+ # Pass some tasks to an executor. Make one successful and the other
+ # a failure.
+ with futures.ThreadPoolExecutor(1) as executor:
+ first_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': 'bar'},
+ )
+ )
+ second_future = executor.submit(
+ FailureTask(self.transfer_coordinator)
+ )
+
+ pending_kwargs['pending_list'] = [first_future, second_future]
+
+ # Create a task that depends on the tasks passed to the executor
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+ # The end result should raise the exception from the initial
+ # pending future value in the list
+ with self.assertRaises(TaskFailureException):
+ self.transfer_coordinator.result()
+
+
+class BaseMultipartTaskTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'foo'
+
+
+class TestCreateMultipartUploadTask(BaseMultipartTaskTest):
+ def test_main(self):
+ upload_id = 'foo'
+ extra_args = {'Metadata': {'foo': 'bar'}}
+ response = {'UploadId': upload_id}
+ task = self.get_task(
+ CreateMultipartUploadTask,
+ main_kwargs={
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': extra_args,
+ },
+ )
+ self.stubber.add_response(
+ method='create_multipart_upload',
+ service_response=response,
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'Metadata': {'foo': 'bar'},
+ },
+ )
+ result_id = task()
+ self.stubber.assert_no_pending_responses()
+ # Ensure the upload id returned is correct
+ self.assertEqual(upload_id, result_id)
+
+ # Make sure that the abort was added as a cleanup failure
+ self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 1)
+
+ # Make sure if it is called, it will abort correctly
+ self.stubber.add_response(
+ method='abort_multipart_upload',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': upload_id,
+ },
+ )
+ self.transfer_coordinator.failure_cleanups[0]()
+ self.stubber.assert_no_pending_responses()
+
+
+class TestCompleteMultipartUploadTask(BaseMultipartTaskTest):
+ def test_main(self):
+ upload_id = 'my-id'
+ parts = [{'ETag': 'etag', 'PartNumber': 0}]
+ task = self.get_task(
+ CompleteMultipartUploadTask,
+ main_kwargs={
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'upload_id': upload_id,
+ 'parts': parts,
+ 'extra_args': {},
+ },
+ )
+ self.stubber.add_response(
+ method='complete_multipart_upload',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': upload_id,
+ 'MultipartUpload': {'Parts': parts},
+ },
+ )
+ task()
+ self.stubber.assert_no_pending_responses()
+
+ def test_includes_extra_args(self):
+ upload_id = 'my-id'
+ parts = [{'ETag': 'etag', 'PartNumber': 0}]
+ task = self.get_task(
+ CompleteMultipartUploadTask,
+ main_kwargs={
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'upload_id': upload_id,
+ 'parts': parts,
+ 'extra_args': {'RequestPayer': 'requester'},
+ },
+ )
+ self.stubber.add_response(
+ method='complete_multipart_upload',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': upload_id,
+ 'MultipartUpload': {'Parts': parts},
+ 'RequestPayer': 'requester',
+ },
+ )
+ task()
+ self.stubber.assert_no_pending_responses()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_upload.py b/contrib/python/s3transfer/py3/tests/unit/test_upload.py
new file mode 100644
index 0000000000..1ac38b3616
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_upload.py
@@ -0,0 +1,694 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+
+import math
+import os
+import shutil
+import tempfile
+from io import BytesIO
+
+from botocore.stub import ANY
+
+from s3transfer.futures import IN_MEMORY_UPLOAD_TAG
+from s3transfer.manager import TransferConfig
+from s3transfer.upload import (
+ AggregatedProgressCallback,
+ InterruptReader,
+ PutObjectTask,
+ UploadFilenameInputManager,
+ UploadNonSeekableInputManager,
+ UploadPartTask,
+ UploadSeekableInputManager,
+ UploadSubmissionTask,
+)
+from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE, CallArgs, OSUtils
+from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ FileSizeProvider,
+ NonSeekableReader,
+ RecordingExecutor,
+ RecordingSubscriber,
+ unittest,
+)
+
+
+class InterruptionError(Exception):
+ pass
+
+
+class OSUtilsExceptionOnFileSize(OSUtils):
+ def get_file_size(self, filename):
+ raise AssertionError(
+ "The file %s should not have been stated" % filename
+ )
+
+
+class BaseUploadTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'foo'
+ self.osutil = OSUtils()
+
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+ self.content = b'my content'
+ self.subscribers = []
+
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+
+ # A list to keep track of all of the bodies sent over the wire
+ # and their order.
+ self.sent_bodies = []
+ self.client.meta.events.register(
+ 'before-parameter-build.s3.*', self.collect_body
+ )
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+
+ def collect_body(self, params, **kwargs):
+ if 'Body' in params:
+ self.sent_bodies.append(params['Body'].read())
+
+
+class TestAggregatedProgressCallback(unittest.TestCase):
+ def setUp(self):
+ self.aggregated_amounts = []
+ self.threshold = 3
+ self.aggregated_progress_callback = AggregatedProgressCallback(
+ [self.callback], self.threshold
+ )
+
+ def callback(self, bytes_transferred):
+ self.aggregated_amounts.append(bytes_transferred)
+
+ def test_under_threshold(self):
+ one_under_threshold_amount = self.threshold - 1
+ self.aggregated_progress_callback(one_under_threshold_amount)
+ self.assertEqual(self.aggregated_amounts, [])
+ self.aggregated_progress_callback(1)
+ self.assertEqual(self.aggregated_amounts, [self.threshold])
+
+ def test_at_threshold(self):
+ self.aggregated_progress_callback(self.threshold)
+ self.assertEqual(self.aggregated_amounts, [self.threshold])
+
+ def test_over_threshold(self):
+ over_threshold_amount = self.threshold + 1
+ self.aggregated_progress_callback(over_threshold_amount)
+ self.assertEqual(self.aggregated_amounts, [over_threshold_amount])
+
+ def test_flush(self):
+ under_threshold_amount = self.threshold - 1
+ self.aggregated_progress_callback(under_threshold_amount)
+ self.assertEqual(self.aggregated_amounts, [])
+ self.aggregated_progress_callback.flush()
+ self.assertEqual(self.aggregated_amounts, [under_threshold_amount])
+
+ def test_flush_with_nothing_to_flush(self):
+ under_threshold_amount = self.threshold - 1
+ self.aggregated_progress_callback(under_threshold_amount)
+ self.assertEqual(self.aggregated_amounts, [])
+ self.aggregated_progress_callback.flush()
+ self.assertEqual(self.aggregated_amounts, [under_threshold_amount])
+ # Flushing again should do nothing as it was just flushed
+ self.aggregated_progress_callback.flush()
+ self.assertEqual(self.aggregated_amounts, [under_threshold_amount])
+
+
+class TestInterruptReader(BaseUploadTest):
+ def test_read_raises_exception(self):
+ with open(self.filename, 'rb') as f:
+ reader = InterruptReader(f, self.transfer_coordinator)
+ # Read some bytes to show it can be read.
+ self.assertEqual(reader.read(1), self.content[0:1])
+ # Then set an exception in the transfer coordinator
+ self.transfer_coordinator.set_exception(InterruptionError())
+ # The next read should have the exception propograte
+ with self.assertRaises(InterruptionError):
+ reader.read()
+
+ def test_seek(self):
+ with open(self.filename, 'rb') as f:
+ reader = InterruptReader(f, self.transfer_coordinator)
+ # Ensure it can seek correctly
+ reader.seek(1)
+ self.assertEqual(reader.read(1), self.content[1:2])
+
+ def test_tell(self):
+ with open(self.filename, 'rb') as f:
+ reader = InterruptReader(f, self.transfer_coordinator)
+ # Ensure it can tell correctly
+ reader.seek(1)
+ self.assertEqual(reader.tell(), 1)
+
+
+class BaseUploadInputManagerTest(BaseUploadTest):
+ def setUp(self):
+ super().setUp()
+ self.osutil = OSUtils()
+ self.config = TransferConfig()
+ self.recording_subscriber = RecordingSubscriber()
+ self.subscribers.append(self.recording_subscriber)
+
+ def _get_expected_body_for_part(self, part_number):
+ # A helper method for retrieving the expected body for a specific
+ # part number of the data
+ total_size = len(self.content)
+ chunk_size = self.config.multipart_chunksize
+ start_index = (part_number - 1) * chunk_size
+ end_index = part_number * chunk_size
+ if end_index >= total_size:
+ return self.content[start_index:]
+ return self.content[start_index:end_index]
+
+
+class TestUploadFilenameInputManager(BaseUploadInputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.upload_input_manager = UploadFilenameInputManager(
+ self.osutil, self.transfer_coordinator
+ )
+ self.call_args = CallArgs(
+ fileobj=self.filename, subscribers=self.subscribers
+ )
+ self.future = self.get_transfer_future(self.call_args)
+
+ def test_is_compatible(self):
+ self.assertTrue(
+ self.upload_input_manager.is_compatible(
+ self.future.meta.call_args.fileobj
+ )
+ )
+
+ def test_stores_bodies_in_memory_put_object(self):
+ self.assertFalse(
+ self.upload_input_manager.stores_body_in_memory('put_object')
+ )
+
+ def test_stores_bodies_in_memory_upload_part(self):
+ self.assertFalse(
+ self.upload_input_manager.stores_body_in_memory('upload_part')
+ )
+
+ def test_provide_transfer_size(self):
+ self.upload_input_manager.provide_transfer_size(self.future)
+ # The provided file size should be equal to size of the contents of
+ # the file.
+ self.assertEqual(self.future.meta.size, len(self.content))
+
+ def test_requires_multipart_upload(self):
+ self.future.meta.provide_transfer_size(len(self.content))
+ # With the default multipart threshold, the length of the content
+ # should be smaller than the threshold thus not requiring a multipart
+ # transfer.
+ self.assertFalse(
+ self.upload_input_manager.requires_multipart_upload(
+ self.future, self.config
+ )
+ )
+ # Decreasing the threshold to that of the length of the content of
+ # the file should trigger the need for a multipart upload.
+ self.config.multipart_threshold = len(self.content)
+ self.assertTrue(
+ self.upload_input_manager.requires_multipart_upload(
+ self.future, self.config
+ )
+ )
+
+ def test_get_put_object_body(self):
+ self.future.meta.provide_transfer_size(len(self.content))
+ read_file_chunk = self.upload_input_manager.get_put_object_body(
+ self.future
+ )
+ read_file_chunk.enable_callback()
+ # The file-like object provided back should be the same as the content
+ # of the file.
+ with read_file_chunk:
+ self.assertEqual(read_file_chunk.read(), self.content)
+ # The file-like object should also have been wrapped with the
+ # on_queued callbacks to track the amount of bytes being transferred.
+ self.assertEqual(
+ self.recording_subscriber.calculate_bytes_seen(), len(self.content)
+ )
+
+ def test_get_put_object_body_is_interruptable(self):
+ self.future.meta.provide_transfer_size(len(self.content))
+ read_file_chunk = self.upload_input_manager.get_put_object_body(
+ self.future
+ )
+
+ # Set an exception in the transfer coordinator
+ self.transfer_coordinator.set_exception(InterruptionError)
+ # Ensure the returned read file chunk can be interrupted with that
+ # error.
+ with self.assertRaises(InterruptionError):
+ read_file_chunk.read()
+
+ def test_yield_upload_part_bodies(self):
+ # Adjust the chunk size to something more grainular for testing.
+ self.config.multipart_chunksize = 4
+ self.future.meta.provide_transfer_size(len(self.content))
+
+ # Get an iterator that will yield all of the bodies and their
+ # respective part number.
+ part_iterator = self.upload_input_manager.yield_upload_part_bodies(
+ self.future, self.config.multipart_chunksize
+ )
+ expected_part_number = 1
+ for part_number, read_file_chunk in part_iterator:
+ # Ensure that the part number is as expected
+ self.assertEqual(part_number, expected_part_number)
+ read_file_chunk.enable_callback()
+ # Ensure that the body is correct for that part.
+ with read_file_chunk:
+ self.assertEqual(
+ read_file_chunk.read(),
+ self._get_expected_body_for_part(part_number),
+ )
+ expected_part_number += 1
+
+ # All of the file-like object should also have been wrapped with the
+ # on_queued callbacks to track the amount of bytes being transferred.
+ self.assertEqual(
+ self.recording_subscriber.calculate_bytes_seen(), len(self.content)
+ )
+
+ def test_yield_upload_part_bodies_are_interruptable(self):
+ # Adjust the chunk size to something more grainular for testing.
+ self.config.multipart_chunksize = 4
+ self.future.meta.provide_transfer_size(len(self.content))
+
+ # Get an iterator that will yield all of the bodies and their
+ # respective part number.
+ part_iterator = self.upload_input_manager.yield_upload_part_bodies(
+ self.future, self.config.multipart_chunksize
+ )
+
+ # Set an exception in the transfer coordinator
+ self.transfer_coordinator.set_exception(InterruptionError)
+ for _, read_file_chunk in part_iterator:
+ # Ensure that each read file chunk yielded can be interrupted
+ # with that error.
+ with self.assertRaises(InterruptionError):
+ read_file_chunk.read()
+
+
+class TestUploadSeekableInputManager(TestUploadFilenameInputManager):
+ def setUp(self):
+ super().setUp()
+ self.upload_input_manager = UploadSeekableInputManager(
+ self.osutil, self.transfer_coordinator
+ )
+ self.fileobj = open(self.filename, 'rb')
+ self.call_args = CallArgs(
+ fileobj=self.fileobj, subscribers=self.subscribers
+ )
+ self.future = self.get_transfer_future(self.call_args)
+
+ def tearDown(self):
+ self.fileobj.close()
+ super().tearDown()
+
+ def test_is_compatible_bytes_io(self):
+ self.assertTrue(self.upload_input_manager.is_compatible(BytesIO()))
+
+ def test_not_compatible_for_non_filelike_obj(self):
+ self.assertFalse(self.upload_input_manager.is_compatible(object()))
+
+ def test_stores_bodies_in_memory_upload_part(self):
+ self.assertTrue(
+ self.upload_input_manager.stores_body_in_memory('upload_part')
+ )
+
+ def test_get_put_object_body(self):
+ start_pos = 3
+ self.fileobj.seek(start_pos)
+ adjusted_size = len(self.content) - start_pos
+ self.future.meta.provide_transfer_size(adjusted_size)
+ read_file_chunk = self.upload_input_manager.get_put_object_body(
+ self.future
+ )
+
+ read_file_chunk.enable_callback()
+ # The fact that the file was seeked to start should be taken into
+ # account in length and content for the read file chunk.
+ with read_file_chunk:
+ self.assertEqual(len(read_file_chunk), adjusted_size)
+ self.assertEqual(read_file_chunk.read(), self.content[start_pos:])
+ self.assertEqual(
+ self.recording_subscriber.calculate_bytes_seen(), adjusted_size
+ )
+
+
+class TestUploadNonSeekableInputManager(TestUploadFilenameInputManager):
+ def setUp(self):
+ super().setUp()
+ self.upload_input_manager = UploadNonSeekableInputManager(
+ self.osutil, self.transfer_coordinator
+ )
+ self.fileobj = NonSeekableReader(self.content)
+ self.call_args = CallArgs(
+ fileobj=self.fileobj, subscribers=self.subscribers
+ )
+ self.future = self.get_transfer_future(self.call_args)
+
+ def assert_multipart_parts(self):
+ """
+ Asserts that the input manager will generate a multipart upload
+ and that each part is in order and the correct size.
+ """
+ # Assert that a multipart upload is required.
+ self.assertTrue(
+ self.upload_input_manager.requires_multipart_upload(
+ self.future, self.config
+ )
+ )
+
+ # Get a list of all the parts that would be sent.
+ parts = list(
+ self.upload_input_manager.yield_upload_part_bodies(
+ self.future, self.config.multipart_chunksize
+ )
+ )
+
+ # Assert that the actual number of parts is what we would expect
+ # based on the configuration.
+ size = self.config.multipart_chunksize
+ num_parts = math.ceil(len(self.content) / size)
+ self.assertEqual(len(parts), num_parts)
+
+ # Run for every part but the last part.
+ for i, part in enumerate(parts[:-1]):
+ # Assert the part number is correct.
+ self.assertEqual(part[0], i + 1)
+ # Assert the part contains the right amount of data.
+ data = part[1].read()
+ self.assertEqual(len(data), size)
+
+ # Assert that the last part is the correct size.
+ expected_final_size = len(self.content) - ((num_parts - 1) * size)
+ final_part = parts[-1]
+ self.assertEqual(len(final_part[1].read()), expected_final_size)
+
+ # Assert that the last part has the correct part number.
+ self.assertEqual(final_part[0], len(parts))
+
+ def test_provide_transfer_size(self):
+ self.upload_input_manager.provide_transfer_size(self.future)
+ # There is no way to get the size without reading the entire body.
+ self.assertEqual(self.future.meta.size, None)
+
+ def test_stores_bodies_in_memory_upload_part(self):
+ self.assertTrue(
+ self.upload_input_manager.stores_body_in_memory('upload_part')
+ )
+
+ def test_stores_bodies_in_memory_put_object(self):
+ self.assertTrue(
+ self.upload_input_manager.stores_body_in_memory('put_object')
+ )
+
+ def test_initial_data_parts_threshold_lesser(self):
+ # threshold < size
+ self.config.multipart_chunksize = 4
+ self.config.multipart_threshold = 2
+ self.assert_multipart_parts()
+
+ def test_initial_data_parts_threshold_equal(self):
+ # threshold == size
+ self.config.multipart_chunksize = 4
+ self.config.multipart_threshold = 4
+ self.assert_multipart_parts()
+
+ def test_initial_data_parts_threshold_greater(self):
+ # threshold > size
+ self.config.multipart_chunksize = 4
+ self.config.multipart_threshold = 8
+ self.assert_multipart_parts()
+
+
+class TestUploadSubmissionTask(BaseSubmissionTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+ self.content = b'0' * (MIN_UPLOAD_CHUNKSIZE * 3)
+ self.config.multipart_chunksize = MIN_UPLOAD_CHUNKSIZE
+ self.config.multipart_threshold = MIN_UPLOAD_CHUNKSIZE * 5
+
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.subscribers = []
+
+ # A list to keep track of all of the bodies sent over the wire
+ # and their order.
+ self.sent_bodies = []
+ self.client.meta.events.register(
+ 'before-parameter-build.s3.*', self.collect_body
+ )
+
+ self.call_args = self.get_call_args()
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.submission_main_kwargs = {
+ 'client': self.client,
+ 'config': self.config,
+ 'osutil': self.osutil,
+ 'request_executor': self.executor,
+ 'transfer_future': self.transfer_future,
+ }
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+
+ def collect_body(self, params, **kwargs):
+ if 'Body' in params:
+ self.sent_bodies.append(params['Body'].read())
+
+ def get_call_args(self, **kwargs):
+ default_call_args = {
+ 'fileobj': self.filename,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': self.extra_args,
+ 'subscribers': self.subscribers,
+ }
+ default_call_args.update(kwargs)
+ return CallArgs(**default_call_args)
+
+ def add_multipart_upload_stubbed_responses(self):
+ self.stubber.add_response(
+ method='create_multipart_upload',
+ service_response={'UploadId': 'my-id'},
+ )
+ self.stubber.add_response(
+ method='upload_part', service_response={'ETag': 'etag-1'}
+ )
+ self.stubber.add_response(
+ method='upload_part', service_response={'ETag': 'etag-2'}
+ )
+ self.stubber.add_response(
+ method='upload_part', service_response={'ETag': 'etag-3'}
+ )
+ self.stubber.add_response(
+ method='complete_multipart_upload', service_response={}
+ )
+
+ def wrap_executor_in_recorder(self):
+ self.executor = RecordingExecutor(self.executor)
+ self.submission_main_kwargs['request_executor'] = self.executor
+
+ def use_fileobj_in_call_args(self, fileobj):
+ self.call_args = self.get_call_args(fileobj=fileobj)
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.submission_main_kwargs['transfer_future'] = self.transfer_future
+
+ def assert_tag_value_for_put_object(self, tag_value):
+ self.assertEqual(self.executor.submissions[0]['tag'], tag_value)
+
+ def assert_tag_value_for_upload_parts(self, tag_value):
+ for submission in self.executor.submissions[1:-1]:
+ self.assertEqual(submission['tag'], tag_value)
+
+ def test_provide_file_size_on_put(self):
+ self.call_args.subscribers.append(FileSizeProvider(len(self.content)))
+ self.stubber.add_response(
+ method='put_object',
+ service_response={},
+ expected_params={
+ 'Body': ANY,
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ },
+ )
+
+ # With this submitter, it will fail to stat the file if a transfer
+ # size is not provided.
+ self.submission_main_kwargs['osutil'] = OSUtilsExceptionOnFileSize()
+
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(self.sent_bodies, [self.content])
+
+ def test_submits_no_tag_for_put_object_filename(self):
+ self.wrap_executor_in_recorder()
+ self.stubber.add_response('put_object', {})
+
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ # Make sure no tag to limit that task specifically was not associated
+ # to that task submission.
+ self.assert_tag_value_for_put_object(None)
+
+ def test_submits_no_tag_for_multipart_filename(self):
+ self.wrap_executor_in_recorder()
+
+ # Set up for a multipart upload.
+ self.add_multipart_upload_stubbed_responses()
+ self.config.multipart_threshold = 1
+
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ # Make sure no tag to limit any of the upload part tasks were
+ # were associated when submitted to the executor
+ self.assert_tag_value_for_upload_parts(None)
+
+ def test_submits_no_tag_for_put_object_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.stubber.add_response('put_object', {})
+
+ with open(self.filename, 'rb') as f:
+ self.use_fileobj_in_call_args(f)
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ # Make sure no tag to limit that task specifically was not associated
+ # to that task submission.
+ self.assert_tag_value_for_put_object(None)
+
+ def test_submits_tag_for_multipart_fileobj(self):
+ self.wrap_executor_in_recorder()
+
+ # Set up for a multipart upload.
+ self.add_multipart_upload_stubbed_responses()
+ self.config.multipart_threshold = 1
+
+ with open(self.filename, 'rb') as f:
+ self.use_fileobj_in_call_args(f)
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ # Make sure tags to limit all of the upload part tasks were
+ # were associated when submitted to the executor as these tasks will
+ # have chunks of data stored with them in memory.
+ self.assert_tag_value_for_upload_parts(IN_MEMORY_UPLOAD_TAG)
+
+
+class TestPutObjectTask(BaseUploadTest):
+ def test_main(self):
+ extra_args = {'Metadata': {'foo': 'bar'}}
+ with open(self.filename, 'rb') as fileobj:
+ task = self.get_task(
+ PutObjectTask,
+ main_kwargs={
+ 'client': self.client,
+ 'fileobj': fileobj,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': extra_args,
+ },
+ )
+ self.stubber.add_response(
+ method='put_object',
+ service_response={},
+ expected_params={
+ 'Body': ANY,
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'Metadata': {'foo': 'bar'},
+ },
+ )
+ task()
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(self.sent_bodies, [self.content])
+
+
+class TestUploadPartTask(BaseUploadTest):
+ def test_main(self):
+ extra_args = {'RequestPayer': 'requester'}
+ upload_id = 'my-id'
+ part_number = 1
+ etag = 'foo'
+ with open(self.filename, 'rb') as fileobj:
+ task = self.get_task(
+ UploadPartTask,
+ main_kwargs={
+ 'client': self.client,
+ 'fileobj': fileobj,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'upload_id': upload_id,
+ 'part_number': part_number,
+ 'extra_args': extra_args,
+ },
+ )
+ self.stubber.add_response(
+ method='upload_part',
+ service_response={'ETag': etag},
+ expected_params={
+ 'Body': ANY,
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': upload_id,
+ 'PartNumber': part_number,
+ 'RequestPayer': 'requester',
+ },
+ )
+ rval = task()
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(rval, {'ETag': etag, 'PartNumber': part_number})
+ self.assertEqual(self.sent_bodies, [self.content])
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_utils.py b/contrib/python/s3transfer/py3/tests/unit/test_utils.py
new file mode 100644
index 0000000000..a1ff904e7a
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/unit/test_utils.py
@@ -0,0 +1,1189 @@
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import io
+import os.path
+import random
+import re
+import shutil
+import tempfile
+import threading
+import time
+from io import BytesIO, StringIO
+
+from s3transfer.futures import TransferFuture, TransferMeta
+from s3transfer.utils import (
+ MAX_PARTS,
+ MAX_SINGLE_UPLOAD_SIZE,
+ MIN_UPLOAD_CHUNKSIZE,
+ CallArgs,
+ ChunksizeAdjuster,
+ CountCallbackInvoker,
+ DeferredOpenFile,
+ FunctionContainer,
+ NoResourcesAvailable,
+ OSUtils,
+ ReadFileChunk,
+ SlidingWindowSemaphore,
+ StreamReaderProgress,
+ TaskSemaphore,
+ calculate_num_parts,
+ calculate_range_parameter,
+ get_callbacks,
+ get_filtered_dict,
+ invoke_progress_callbacks,
+ random_file_extension,
+)
+from __tests__ import NonSeekableWriter, RecordingSubscriber, mock, unittest
+
+
+class TestGetCallbacks(unittest.TestCase):
+ def setUp(self):
+ self.subscriber = RecordingSubscriber()
+ self.second_subscriber = RecordingSubscriber()
+ self.call_args = CallArgs(
+ subscribers=[self.subscriber, self.second_subscriber]
+ )
+ self.transfer_meta = TransferMeta(self.call_args)
+ self.transfer_future = TransferFuture(self.transfer_meta)
+
+ def test_get_callbacks(self):
+ callbacks = get_callbacks(self.transfer_future, 'queued')
+ # Make sure two callbacks were added as both subscribers had
+ # an on_queued method.
+ self.assertEqual(len(callbacks), 2)
+
+ # Ensure that the callback was injected with the future by calling
+ # one of them and checking that the future was used in the call.
+ callbacks[0]()
+ self.assertEqual(
+ self.subscriber.on_queued_calls, [{'future': self.transfer_future}]
+ )
+
+ def test_get_callbacks_for_missing_type(self):
+ callbacks = get_callbacks(self.transfer_future, 'fake_state')
+ # There should be no callbacks as the subscribers will not have the
+ # on_fake_state method
+ self.assertEqual(len(callbacks), 0)
+
+
+class TestGetFilteredDict(unittest.TestCase):
+ def test_get_filtered_dict(self):
+ original = {'Include': 'IncludeValue', 'NotInlude': 'NotIncludeValue'}
+ whitelist = ['Include']
+ self.assertEqual(
+ get_filtered_dict(original, whitelist), {'Include': 'IncludeValue'}
+ )
+
+
+class TestCallArgs(unittest.TestCase):
+ def test_call_args(self):
+ call_args = CallArgs(foo='bar', biz='baz')
+ self.assertEqual(call_args.foo, 'bar')
+ self.assertEqual(call_args.biz, 'baz')
+
+
+class TestFunctionContainer(unittest.TestCase):
+ def get_args_kwargs(self, *args, **kwargs):
+ return args, kwargs
+
+ def test_call(self):
+ func_container = FunctionContainer(
+ self.get_args_kwargs, 'foo', bar='baz'
+ )
+ self.assertEqual(func_container(), (('foo',), {'bar': 'baz'}))
+
+ def test_repr(self):
+ func_container = FunctionContainer(
+ self.get_args_kwargs, 'foo', bar='baz'
+ )
+ self.assertEqual(
+ str(func_container),
+ 'Function: {} with args {} and kwargs {}'.format(
+ self.get_args_kwargs, ('foo',), {'bar': 'baz'}
+ ),
+ )
+
+
+class TestCountCallbackInvoker(unittest.TestCase):
+ def invoke_callback(self):
+ self.ref_results.append('callback invoked')
+
+ def assert_callback_invoked(self):
+ self.assertEqual(self.ref_results, ['callback invoked'])
+
+ def assert_callback_not_invoked(self):
+ self.assertEqual(self.ref_results, [])
+
+ def setUp(self):
+ self.ref_results = []
+ self.invoker = CountCallbackInvoker(self.invoke_callback)
+
+ def test_increment(self):
+ self.invoker.increment()
+ self.assertEqual(self.invoker.current_count, 1)
+
+ def test_decrement(self):
+ self.invoker.increment()
+ self.invoker.increment()
+ self.invoker.decrement()
+ self.assertEqual(self.invoker.current_count, 1)
+
+ def test_count_cannot_go_below_zero(self):
+ with self.assertRaises(RuntimeError):
+ self.invoker.decrement()
+
+ def test_callback_invoked_only_once_finalized(self):
+ self.invoker.increment()
+ self.invoker.decrement()
+ self.assert_callback_not_invoked()
+ self.invoker.finalize()
+ # Callback should only be invoked once finalized
+ self.assert_callback_invoked()
+
+ def test_callback_invoked_after_finalizing_and_count_reaching_zero(self):
+ self.invoker.increment()
+ self.invoker.finalize()
+ # Make sure that it does not get invoked immediately after
+ # finalizing as the count is currently one
+ self.assert_callback_not_invoked()
+ self.invoker.decrement()
+ self.assert_callback_invoked()
+
+ def test_cannot_increment_after_finalization(self):
+ self.invoker.finalize()
+ with self.assertRaises(RuntimeError):
+ self.invoker.increment()
+
+
+class TestRandomFileExtension(unittest.TestCase):
+ def test_has_proper_length(self):
+ self.assertEqual(len(random_file_extension(num_digits=4)), 4)
+
+
+class TestInvokeProgressCallbacks(unittest.TestCase):
+ def test_invoke_progress_callbacks(self):
+ recording_subscriber = RecordingSubscriber()
+ invoke_progress_callbacks([recording_subscriber.on_progress], 2)
+ self.assertEqual(recording_subscriber.calculate_bytes_seen(), 2)
+
+ def test_invoke_progress_callbacks_with_no_progress(self):
+ recording_subscriber = RecordingSubscriber()
+ invoke_progress_callbacks([recording_subscriber.on_progress], 0)
+ self.assertEqual(len(recording_subscriber.on_progress_calls), 0)
+
+
+class TestCalculateNumParts(unittest.TestCase):
+ def test_calculate_num_parts_divisible(self):
+ self.assertEqual(calculate_num_parts(size=4, part_size=2), 2)
+
+ def test_calculate_num_parts_not_divisible(self):
+ self.assertEqual(calculate_num_parts(size=3, part_size=2), 2)
+
+
+class TestCalculateRangeParameter(unittest.TestCase):
+ def setUp(self):
+ self.part_size = 5
+ self.part_index = 1
+ self.num_parts = 3
+
+ def test_calculate_range_paramter(self):
+ range_val = calculate_range_parameter(
+ self.part_size, self.part_index, self.num_parts
+ )
+ self.assertEqual(range_val, 'bytes=5-9')
+
+ def test_last_part_with_no_total_size(self):
+ range_val = calculate_range_parameter(
+ self.part_size, self.part_index, num_parts=2
+ )
+ self.assertEqual(range_val, 'bytes=5-')
+
+ def test_last_part_with_total_size(self):
+ range_val = calculate_range_parameter(
+ self.part_size, self.part_index, num_parts=2, total_size=8
+ )
+ self.assertEqual(range_val, 'bytes=5-7')
+
+
+class BaseUtilsTest(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'foo')
+ self.content = b'abc'
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+ self.amounts_seen = []
+ self.num_close_callback_calls = 0
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def callback(self, bytes_transferred):
+ self.amounts_seen.append(bytes_transferred)
+
+ def close_callback(self):
+ self.num_close_callback_calls += 1
+
+
+class TestOSUtils(BaseUtilsTest):
+ def test_get_file_size(self):
+ self.assertEqual(
+ OSUtils().get_file_size(self.filename), len(self.content)
+ )
+
+ def test_open_file_chunk_reader(self):
+ reader = OSUtils().open_file_chunk_reader(
+ self.filename, 0, 3, [self.callback]
+ )
+
+ # The returned reader should be a ReadFileChunk.
+ self.assertIsInstance(reader, ReadFileChunk)
+ # The content of the reader should be correct.
+ self.assertEqual(reader.read(), self.content)
+ # Callbacks should be disabled depspite being passed in.
+ self.assertEqual(self.amounts_seen, [])
+
+ def test_open_file_chunk_reader_from_fileobj(self):
+ with open(self.filename, 'rb') as f:
+ reader = OSUtils().open_file_chunk_reader_from_fileobj(
+ f, len(self.content), len(self.content), [self.callback]
+ )
+
+ # The returned reader should be a ReadFileChunk.
+ self.assertIsInstance(reader, ReadFileChunk)
+ # The content of the reader should be correct.
+ self.assertEqual(reader.read(), self.content)
+ reader.close()
+ # Callbacks should be disabled depspite being passed in.
+ self.assertEqual(self.amounts_seen, [])
+ self.assertEqual(self.num_close_callback_calls, 0)
+
+ def test_open_file(self):
+ fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w')
+ self.assertTrue(hasattr(fileobj, 'write'))
+
+ def test_remove_file_ignores_errors(self):
+ non_existent_file = os.path.join(self.tempdir, 'no-exist')
+ # This should not exist to start.
+ self.assertFalse(os.path.exists(non_existent_file))
+ try:
+ OSUtils().remove_file(non_existent_file)
+ except OSError as e:
+ self.fail('OSError should have been caught: %s' % e)
+
+ def test_remove_file_proxies_remove_file(self):
+ OSUtils().remove_file(self.filename)
+ self.assertFalse(os.path.exists(self.filename))
+
+ def test_rename_file(self):
+ new_filename = os.path.join(self.tempdir, 'newfoo')
+ OSUtils().rename_file(self.filename, new_filename)
+ self.assertFalse(os.path.exists(self.filename))
+ self.assertTrue(os.path.exists(new_filename))
+
+ def test_is_special_file_for_normal_file(self):
+ self.assertFalse(OSUtils().is_special_file(self.filename))
+
+ def test_is_special_file_for_non_existant_file(self):
+ non_existant_filename = os.path.join(self.tempdir, 'no-exist')
+ self.assertFalse(os.path.exists(non_existant_filename))
+ self.assertFalse(OSUtils().is_special_file(non_existant_filename))
+
+ def test_get_temp_filename(self):
+ filename = 'myfile'
+ self.assertIsNotNone(
+ re.match(
+ r'%s\.[0-9A-Fa-f]{8}$' % filename,
+ OSUtils().get_temp_filename(filename),
+ )
+ )
+
+ def test_get_temp_filename_len_255(self):
+ filename = 'a' * 255
+ temp_filename = OSUtils().get_temp_filename(filename)
+ self.assertLessEqual(len(temp_filename), 255)
+
+ def test_get_temp_filename_len_gt_255(self):
+ filename = 'a' * 280
+ temp_filename = OSUtils().get_temp_filename(filename)
+ self.assertLessEqual(len(temp_filename), 255)
+
+ def test_allocate(self):
+ truncate_size = 1
+ OSUtils().allocate(self.filename, truncate_size)
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(len(f.read()), truncate_size)
+
+ @mock.patch('s3transfer.utils.fallocate')
+ def test_allocate_with_io_error(self, mock_fallocate):
+ mock_fallocate.side_effect = IOError()
+ with self.assertRaises(IOError):
+ OSUtils().allocate(self.filename, 1)
+ self.assertFalse(os.path.exists(self.filename))
+
+ @mock.patch('s3transfer.utils.fallocate')
+ def test_allocate_with_os_error(self, mock_fallocate):
+ mock_fallocate.side_effect = OSError()
+ with self.assertRaises(OSError):
+ OSUtils().allocate(self.filename, 1)
+ self.assertFalse(os.path.exists(self.filename))
+
+
+class TestDeferredOpenFile(BaseUtilsTest):
+ def setUp(self):
+ super().setUp()
+ self.filename = os.path.join(self.tempdir, 'foo')
+ self.contents = b'my contents'
+ with open(self.filename, 'wb') as f:
+ f.write(self.contents)
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename, open_function=self.recording_open_function
+ )
+ self.open_call_args = []
+
+ def tearDown(self):
+ self.deferred_open_file.close()
+ super().tearDown()
+
+ def recording_open_function(self, filename, mode):
+ self.open_call_args.append((filename, mode))
+ return open(filename, mode)
+
+ def open_nonseekable(self, filename, mode):
+ self.open_call_args.append((filename, mode))
+ return NonSeekableWriter(BytesIO(self.content))
+
+ def test_instantiation_does_not_open_file(self):
+ DeferredOpenFile(
+ self.filename, open_function=self.recording_open_function
+ )
+ self.assertEqual(len(self.open_call_args), 0)
+
+ def test_name(self):
+ self.assertEqual(self.deferred_open_file.name, self.filename)
+
+ def test_read(self):
+ content = self.deferred_open_file.read(2)
+ self.assertEqual(content, self.contents[0:2])
+ content = self.deferred_open_file.read(2)
+ self.assertEqual(content, self.contents[2:4])
+ self.assertEqual(len(self.open_call_args), 1)
+
+ def test_write(self):
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename,
+ mode='wb',
+ open_function=self.recording_open_function,
+ )
+
+ write_content = b'foo'
+ self.deferred_open_file.write(write_content)
+ self.deferred_open_file.write(write_content)
+ self.deferred_open_file.close()
+ # Both of the writes should now be in the file.
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(f.read(), write_content * 2)
+ # Open should have only been called once.
+ self.assertEqual(len(self.open_call_args), 1)
+
+ def test_seek(self):
+ self.deferred_open_file.seek(2)
+ content = self.deferred_open_file.read(2)
+ self.assertEqual(content, self.contents[2:4])
+ self.assertEqual(len(self.open_call_args), 1)
+
+ def test_open_does_not_seek_with_zero_start_byte(self):
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename,
+ mode='wb',
+ start_byte=0,
+ open_function=self.open_nonseekable,
+ )
+
+ try:
+ # If this seeks, an UnsupportedOperation error will be raised.
+ self.deferred_open_file.write(b'data')
+ except io.UnsupportedOperation:
+ self.fail('DeferredOpenFile seeked upon opening')
+
+ def test_open_seeks_with_nonzero_start_byte(self):
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename,
+ mode='wb',
+ start_byte=5,
+ open_function=self.open_nonseekable,
+ )
+
+ # Since a non-seekable file is being opened, calling Seek will raise
+ # an UnsupportedOperation error.
+ with self.assertRaises(io.UnsupportedOperation):
+ self.deferred_open_file.write(b'data')
+
+ def test_tell(self):
+ self.deferred_open_file.tell()
+ # tell() should not have opened the file if it has not been seeked
+ # or read because we know the start bytes upfront.
+ self.assertEqual(len(self.open_call_args), 0)
+
+ self.deferred_open_file.seek(2)
+ self.assertEqual(self.deferred_open_file.tell(), 2)
+ self.assertEqual(len(self.open_call_args), 1)
+
+ def test_open_args(self):
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename,
+ mode='ab+',
+ open_function=self.recording_open_function,
+ )
+ # Force an open
+ self.deferred_open_file.write(b'data')
+ self.assertEqual(len(self.open_call_args), 1)
+ self.assertEqual(self.open_call_args[0], (self.filename, 'ab+'))
+
+ def test_context_handler(self):
+ with self.deferred_open_file:
+ self.assertEqual(len(self.open_call_args), 1)
+
+
+class TestReadFileChunk(BaseUtilsTest):
+ def test_read_entire_chunk(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=3
+ )
+ self.assertEqual(chunk.read(), b'one')
+ self.assertEqual(chunk.read(), b'')
+
+ def test_read_with_amount_size(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=11, chunk_size=4
+ )
+ self.assertEqual(chunk.read(1), b'f')
+ self.assertEqual(chunk.read(1), b'o')
+ self.assertEqual(chunk.read(1), b'u')
+ self.assertEqual(chunk.read(1), b'r')
+ self.assertEqual(chunk.read(1), b'')
+
+ def test_reset_stream_emulation(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=11, chunk_size=4
+ )
+ self.assertEqual(chunk.read(), b'four')
+ chunk.seek(0)
+ self.assertEqual(chunk.read(), b'four')
+
+ def test_read_past_end_of_file(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=36, chunk_size=100000
+ )
+ self.assertEqual(chunk.read(), b'ten')
+ self.assertEqual(chunk.read(), b'')
+ self.assertEqual(len(chunk), 3)
+
+ def test_tell_and_seek(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=36, chunk_size=100000
+ )
+ self.assertEqual(chunk.tell(), 0)
+ self.assertEqual(chunk.read(), b'ten')
+ self.assertEqual(chunk.tell(), 3)
+ chunk.seek(0)
+ self.assertEqual(chunk.tell(), 0)
+ chunk.seek(1, whence=1)
+ self.assertEqual(chunk.tell(), 1)
+ chunk.seek(-1, whence=1)
+ self.assertEqual(chunk.tell(), 0)
+ chunk.seek(-1, whence=2)
+ self.assertEqual(chunk.tell(), 2)
+
+ def test_tell_and_seek_boundaries(self):
+ # Test to ensure ReadFileChunk behaves the same as the
+ # Python standard library around seeking and reading out
+ # of bounds in a file object.
+ data = b'abcdefghij12345678klmnopqrst'
+ start_pos = 10
+ chunk_size = 8
+
+ # Create test file
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(data)
+
+ # ReadFileChunk should be a substring of only numbers
+ file_objects = [
+ ReadFileChunk.from_filename(
+ filename, start_byte=start_pos, chunk_size=chunk_size
+ )
+ ]
+
+ # Uncomment next line to validate we match Python's io.BytesIO
+ # file_objects.append(io.BytesIO(data[start_pos:start_pos+chunk_size]))
+
+ for obj in file_objects:
+ self._assert_whence_start_behavior(obj)
+ self._assert_whence_end_behavior(obj)
+ self._assert_whence_relative_behavior(obj)
+ self._assert_boundary_behavior(obj)
+
+ def _assert_whence_start_behavior(self, file_obj):
+ self.assertEqual(file_obj.tell(), 0)
+
+ file_obj.seek(1, 0)
+ self.assertEqual(file_obj.tell(), 1)
+
+ file_obj.seek(1)
+ self.assertEqual(file_obj.tell(), 1)
+ self.assertEqual(file_obj.read(), b'2345678')
+
+ file_obj.seek(3, 0)
+ self.assertEqual(file_obj.tell(), 3)
+
+ file_obj.seek(0, 0)
+ self.assertEqual(file_obj.tell(), 0)
+
+ def _assert_whence_relative_behavior(self, file_obj):
+ self.assertEqual(file_obj.tell(), 0)
+
+ file_obj.seek(2, 1)
+ self.assertEqual(file_obj.tell(), 2)
+
+ file_obj.seek(1, 1)
+ self.assertEqual(file_obj.tell(), 3)
+ self.assertEqual(file_obj.read(), b'45678')
+
+ file_obj.seek(20, 1)
+ self.assertEqual(file_obj.tell(), 28)
+
+ file_obj.seek(-30, 1)
+ self.assertEqual(file_obj.tell(), 0)
+ self.assertEqual(file_obj.read(), b'12345678')
+
+ file_obj.seek(-8, 1)
+ self.assertEqual(file_obj.tell(), 0)
+
+ def _assert_whence_end_behavior(self, file_obj):
+ self.assertEqual(file_obj.tell(), 0)
+
+ file_obj.seek(-1, 2)
+ self.assertEqual(file_obj.tell(), 7)
+
+ file_obj.seek(1, 2)
+ self.assertEqual(file_obj.tell(), 9)
+
+ file_obj.seek(3, 2)
+ self.assertEqual(file_obj.tell(), 11)
+ self.assertEqual(file_obj.read(), b'')
+
+ file_obj.seek(-15, 2)
+ self.assertEqual(file_obj.tell(), 0)
+ self.assertEqual(file_obj.read(), b'12345678')
+
+ file_obj.seek(-8, 2)
+ self.assertEqual(file_obj.tell(), 0)
+
+ def _assert_boundary_behavior(self, file_obj):
+ # Verify we're at the start
+ self.assertEqual(file_obj.tell(), 0)
+
+ # Verify we can't move backwards beyond start of file
+ file_obj.seek(-10, 1)
+ self.assertEqual(file_obj.tell(), 0)
+
+ # Verify we *can* move after end of file, but return nothing
+ file_obj.seek(10, 2)
+ self.assertEqual(file_obj.tell(), 18)
+ self.assertEqual(file_obj.read(), b'')
+ self.assertEqual(file_obj.read(10), b'')
+
+ # Verify we can partially rewind
+ file_obj.seek(-12, 1)
+ self.assertEqual(file_obj.tell(), 6)
+ self.assertEqual(file_obj.read(), b'78')
+ self.assertEqual(file_obj.tell(), 8)
+
+ # Verify we can rewind to start
+ file_obj.seek(0)
+ self.assertEqual(file_obj.tell(), 0)
+
+ def test_file_chunk_supports_context_manager(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'abc')
+ with ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=2
+ ) as chunk:
+ val = chunk.read()
+ self.assertEqual(val, b'ab')
+
+ def test_iter_is_always_empty(self):
+ # This tests the workaround for the httplib bug (see
+ # the source for more info).
+ filename = os.path.join(self.tempdir, 'foo')
+ open(filename, 'wb').close()
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=10
+ )
+ self.assertEqual(list(chunk), [])
+
+ def test_callback_is_invoked_on_read(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback],
+ )
+ chunk.read(1)
+ chunk.read(1)
+ chunk.read(1)
+ self.assertEqual(self.amounts_seen, [1, 1, 1])
+
+ def test_all_callbacks_invoked_on_read(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback, self.callback],
+ )
+ chunk.read(1)
+ chunk.read(1)
+ chunk.read(1)
+ # The list should be twice as long because there are two callbacks
+ # recording the amount read.
+ self.assertEqual(self.amounts_seen, [1, 1, 1, 1, 1, 1])
+
+ def test_callback_can_be_disabled(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback],
+ )
+ chunk.disable_callback()
+ # Now reading from the ReadFileChunk should not invoke
+ # the callback.
+ chunk.read()
+ self.assertEqual(self.amounts_seen, [])
+
+ def test_callback_will_also_be_triggered_by_seek(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback],
+ )
+ chunk.read(2)
+ chunk.seek(0)
+ chunk.read(2)
+ chunk.seek(1)
+ chunk.read(2)
+ self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2])
+
+ def test_callback_triggered_by_out_of_bound_seeks(self):
+ data = b'abcdefghij1234567890klmnopqr'
+
+ # Create test file
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(data)
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=10, chunk_size=10, callbacks=[self.callback]
+ )
+
+ # Seek calls that generate "0" progress are skipped by
+ # invoke_progress_callbacks and won't appear in the list.
+ expected_callback_prog = [10, -5, 5, -1, 1, -1, 1, -5, 5, -10]
+
+ self._assert_out_of_bound_start_seek(chunk, expected_callback_prog)
+ self._assert_out_of_bound_relative_seek(chunk, expected_callback_prog)
+ self._assert_out_of_bound_end_seek(chunk, expected_callback_prog)
+
+ def _assert_out_of_bound_start_seek(self, chunk, expected):
+ # clear amounts_seen
+ self.amounts_seen = []
+ self.assertEqual(self.amounts_seen, [])
+
+ # (position, change)
+ chunk.seek(20) # (20, 10)
+ chunk.seek(5) # (5, -5)
+ chunk.seek(20) # (20, 5)
+ chunk.seek(9) # (9, -1)
+ chunk.seek(20) # (20, 1)
+ chunk.seek(11) # (11, 0)
+ chunk.seek(20) # (20, 0)
+ chunk.seek(9) # (9, -1)
+ chunk.seek(20) # (20, 1)
+ chunk.seek(5) # (5, -5)
+ chunk.seek(20) # (20, 5)
+ chunk.seek(0) # (0, -10)
+ chunk.seek(0) # (0, 0)
+
+ self.assertEqual(self.amounts_seen, expected)
+
+ def _assert_out_of_bound_relative_seek(self, chunk, expected):
+ # clear amounts_seen
+ self.amounts_seen = []
+ self.assertEqual(self.amounts_seen, [])
+
+ # (position, change)
+ chunk.seek(20, 1) # (20, 10)
+ chunk.seek(-15, 1) # (5, -5)
+ chunk.seek(15, 1) # (20, 5)
+ chunk.seek(-11, 1) # (9, -1)
+ chunk.seek(11, 1) # (20, 1)
+ chunk.seek(-9, 1) # (11, 0)
+ chunk.seek(9, 1) # (20, 0)
+ chunk.seek(-11, 1) # (9, -1)
+ chunk.seek(11, 1) # (20, 1)
+ chunk.seek(-15, 1) # (5, -5)
+ chunk.seek(15, 1) # (20, 5)
+ chunk.seek(-20, 1) # (0, -10)
+ chunk.seek(-1000, 1) # (0, 0)
+
+ self.assertEqual(self.amounts_seen, expected)
+
+ def _assert_out_of_bound_end_seek(self, chunk, expected):
+ # clear amounts_seen
+ self.amounts_seen = []
+ self.assertEqual(self.amounts_seen, [])
+
+ # (position, change)
+ chunk.seek(10, 2) # (20, 10)
+ chunk.seek(-5, 2) # (5, -5)
+ chunk.seek(10, 2) # (20, 5)
+ chunk.seek(-1, 2) # (9, -1)
+ chunk.seek(10, 2) # (20, 1)
+ chunk.seek(1, 2) # (11, 0)
+ chunk.seek(10, 2) # (20, 0)
+ chunk.seek(-1, 2) # (9, -1)
+ chunk.seek(10, 2) # (20, 1)
+ chunk.seek(-5, 2) # (5, -5)
+ chunk.seek(10, 2) # (20, 5)
+ chunk.seek(-10, 2) # (0, -10)
+ chunk.seek(-1000, 2) # (0, 0)
+
+ self.assertEqual(self.amounts_seen, expected)
+
+ def test_close_callbacks(self):
+ with open(self.filename) as f:
+ chunk = ReadFileChunk(
+ f,
+ chunk_size=1,
+ full_file_size=3,
+ close_callbacks=[self.close_callback],
+ )
+ chunk.close()
+ self.assertEqual(self.num_close_callback_calls, 1)
+
+ def test_close_callbacks_when_not_enabled(self):
+ with open(self.filename) as f:
+ chunk = ReadFileChunk(
+ f,
+ chunk_size=1,
+ full_file_size=3,
+ enable_callbacks=False,
+ close_callbacks=[self.close_callback],
+ )
+ chunk.close()
+ self.assertEqual(self.num_close_callback_calls, 0)
+
+ def test_close_callbacks_when_context_handler_is_used(self):
+ with open(self.filename) as f:
+ with ReadFileChunk(
+ f,
+ chunk_size=1,
+ full_file_size=3,
+ close_callbacks=[self.close_callback],
+ ) as chunk:
+ chunk.read(1)
+ self.assertEqual(self.num_close_callback_calls, 1)
+
+ def test_signal_transferring(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback],
+ )
+ chunk.signal_not_transferring()
+ chunk.read(1)
+ self.assertEqual(self.amounts_seen, [])
+ chunk.signal_transferring()
+ chunk.read(1)
+ self.assertEqual(self.amounts_seen, [1])
+
+ def test_signal_transferring_to_underlying_fileobj(self):
+ underlying_stream = mock.Mock()
+ underlying_stream.tell.return_value = 0
+ chunk = ReadFileChunk(underlying_stream, 3, 3)
+ chunk.signal_transferring()
+ self.assertTrue(underlying_stream.signal_transferring.called)
+
+ def test_no_call_signal_transferring_to_underlying_fileobj(self):
+ underlying_stream = mock.Mock(io.RawIOBase)
+ underlying_stream.tell.return_value = 0
+ chunk = ReadFileChunk(underlying_stream, 3, 3)
+ try:
+ chunk.signal_transferring()
+ except AttributeError:
+ self.fail(
+ 'The stream should not have tried to call signal_transferring '
+ 'to the underlying stream.'
+ )
+
+ def test_signal_not_transferring_to_underlying_fileobj(self):
+ underlying_stream = mock.Mock()
+ underlying_stream.tell.return_value = 0
+ chunk = ReadFileChunk(underlying_stream, 3, 3)
+ chunk.signal_not_transferring()
+ self.assertTrue(underlying_stream.signal_not_transferring.called)
+
+ def test_no_call_signal_not_transferring_to_underlying_fileobj(self):
+ underlying_stream = mock.Mock(io.RawIOBase)
+ underlying_stream.tell.return_value = 0
+ chunk = ReadFileChunk(underlying_stream, 3, 3)
+ try:
+ chunk.signal_not_transferring()
+ except AttributeError:
+ self.fail(
+ 'The stream should not have tried to call '
+ 'signal_not_transferring to the underlying stream.'
+ )
+
+
+class TestStreamReaderProgress(BaseUtilsTest):
+ def test_proxies_to_wrapped_stream(self):
+ original_stream = StringIO('foobarbaz')
+ wrapped = StreamReaderProgress(original_stream)
+ self.assertEqual(wrapped.read(), 'foobarbaz')
+
+ def test_callback_invoked(self):
+ original_stream = StringIO('foobarbaz')
+ wrapped = StreamReaderProgress(
+ original_stream, [self.callback, self.callback]
+ )
+ self.assertEqual(wrapped.read(), 'foobarbaz')
+ self.assertEqual(self.amounts_seen, [9, 9])
+
+
+class TestTaskSemaphore(unittest.TestCase):
+ def setUp(self):
+ self.semaphore = TaskSemaphore(1)
+
+ def test_should_block_at_max_capacity(self):
+ self.semaphore.acquire('a', blocking=False)
+ with self.assertRaises(NoResourcesAvailable):
+ self.semaphore.acquire('a', blocking=False)
+
+ def test_release_capacity(self):
+ acquire_token = self.semaphore.acquire('a', blocking=False)
+ self.semaphore.release('a', acquire_token)
+ try:
+ self.semaphore.acquire('a', blocking=False)
+ except NoResourcesAvailable:
+ self.fail(
+ 'The release of the semaphore should have allowed for '
+ 'the second acquire to not be blocked'
+ )
+
+
+class TestSlidingWindowSemaphore(unittest.TestCase):
+ # These tests use block=False to tests will fail
+ # instead of hang the test runner in the case of x
+ # incorrect behavior.
+ def test_acquire_release_basic_case(self):
+ sem = SlidingWindowSemaphore(1)
+ # Count is 1
+
+ num = sem.acquire('a', blocking=False)
+ self.assertEqual(num, 0)
+ sem.release('a', 0)
+ # Count now back to 1.
+
+ def test_can_acquire_release_multiple_times(self):
+ sem = SlidingWindowSemaphore(1)
+ num = sem.acquire('a', blocking=False)
+ self.assertEqual(num, 0)
+ sem.release('a', num)
+
+ num = sem.acquire('a', blocking=False)
+ self.assertEqual(num, 1)
+ sem.release('a', num)
+
+ def test_can_acquire_a_range(self):
+ sem = SlidingWindowSemaphore(3)
+ self.assertEqual(sem.acquire('a', blocking=False), 0)
+ self.assertEqual(sem.acquire('a', blocking=False), 1)
+ self.assertEqual(sem.acquire('a', blocking=False), 2)
+ sem.release('a', 0)
+ sem.release('a', 1)
+ sem.release('a', 2)
+ # Now we're reset so we should be able to acquire the same
+ # sequence again.
+ self.assertEqual(sem.acquire('a', blocking=False), 3)
+ self.assertEqual(sem.acquire('a', blocking=False), 4)
+ self.assertEqual(sem.acquire('a', blocking=False), 5)
+ self.assertEqual(sem.current_count(), 0)
+
+ def test_counter_release_only_on_min_element(self):
+ sem = SlidingWindowSemaphore(3)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ # The count only increases when we free the min
+ # element. This means if we're currently failing to
+ # acquire now:
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('a', blocking=False)
+
+ # Then freeing a non-min element:
+ sem.release('a', 1)
+
+ # doesn't change anything. We still fail to acquire.
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('a', blocking=False)
+ self.assertEqual(sem.current_count(), 0)
+
+ def test_raises_error_when_count_is_zero(self):
+ sem = SlidingWindowSemaphore(3)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ # Count is now 0 so trying to acquire should fail.
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('a', blocking=False)
+
+ def test_release_counters_can_increment_counter_repeatedly(self):
+ sem = SlidingWindowSemaphore(3)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ # These two releases don't increment the counter
+ # because we're waiting on 0.
+ sem.release('a', 1)
+ sem.release('a', 2)
+ self.assertEqual(sem.current_count(), 0)
+ # But as soon as we release 0, we free up 0, 1, and 2.
+ sem.release('a', 0)
+ self.assertEqual(sem.current_count(), 3)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ def test_error_to_release_unknown_tag(self):
+ sem = SlidingWindowSemaphore(3)
+ with self.assertRaises(ValueError):
+ sem.release('a', 0)
+
+ def test_can_track_multiple_tags(self):
+ sem = SlidingWindowSemaphore(3)
+ self.assertEqual(sem.acquire('a', blocking=False), 0)
+ self.assertEqual(sem.acquire('b', blocking=False), 0)
+ self.assertEqual(sem.acquire('a', blocking=False), 1)
+
+ # We're at our max of 3 even though 2 are for A and 1 is for B.
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('a', blocking=False)
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('b', blocking=False)
+
+ def test_can_handle_multiple_tags_released(self):
+ sem = SlidingWindowSemaphore(4)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('b', blocking=False)
+ sem.acquire('b', blocking=False)
+
+ sem.release('b', 1)
+ sem.release('a', 1)
+ self.assertEqual(sem.current_count(), 0)
+
+ sem.release('b', 0)
+ self.assertEqual(sem.acquire('a', blocking=False), 2)
+
+ sem.release('a', 0)
+ self.assertEqual(sem.acquire('b', blocking=False), 2)
+
+ def test_is_error_to_release_unknown_sequence_number(self):
+ sem = SlidingWindowSemaphore(3)
+ sem.acquire('a', blocking=False)
+ with self.assertRaises(ValueError):
+ sem.release('a', 1)
+
+ def test_is_error_to_double_release(self):
+ # This is different than other error tests because
+ # we're verifying we can reset the state after an
+ # acquire/release cycle.
+ sem = SlidingWindowSemaphore(2)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.release('a', 0)
+ sem.release('a', 1)
+ self.assertEqual(sem.current_count(), 2)
+ with self.assertRaises(ValueError):
+ sem.release('a', 0)
+
+ def test_can_check_in_partial_range(self):
+ sem = SlidingWindowSemaphore(4)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ sem.release('a', 1)
+ sem.release('a', 3)
+ sem.release('a', 0)
+ self.assertEqual(sem.current_count(), 2)
+
+
+class TestThreadingPropertiesForSlidingWindowSemaphore(unittest.TestCase):
+ # These tests focus on mutithreaded properties of the range
+ # semaphore. Basic functionality is tested in TestSlidingWindowSemaphore.
+ def setUp(self):
+ self.threads = []
+
+ def tearDown(self):
+ self.join_threads()
+
+ def join_threads(self):
+ for thread in self.threads:
+ thread.join()
+ self.threads = []
+
+ def start_threads(self):
+ for thread in self.threads:
+ thread.start()
+
+ def test_acquire_blocks_until_release_is_called(self):
+ sem = SlidingWindowSemaphore(2)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ def acquire():
+ # This next call to acquire will block.
+ self.assertEqual(sem.acquire('a', blocking=True), 2)
+
+ t = threading.Thread(target=acquire)
+ self.threads.append(t)
+ # Starting the thread will block the sem.acquire()
+ # in the acquire function above.
+ t.start()
+ # This still will keep the thread blocked.
+ sem.release('a', 1)
+ # Releasing the min element will unblock the thread.
+ sem.release('a', 0)
+ t.join()
+ sem.release('a', 2)
+
+ def test_stress_invariants_random_order(self):
+ sem = SlidingWindowSemaphore(100)
+ for _ in range(10):
+ recorded = []
+ for _ in range(100):
+ recorded.append(sem.acquire('a', blocking=False))
+ # Release them in randomized order. As long as we
+ # eventually free all 100, we should have all the
+ # resources released.
+ random.shuffle(recorded)
+ for i in recorded:
+ sem.release('a', i)
+
+ # Everything's freed so should be back at count == 100
+ self.assertEqual(sem.current_count(), 100)
+
+ def test_blocking_stress(self):
+ sem = SlidingWindowSemaphore(5)
+ num_threads = 10
+ num_iterations = 50
+
+ def acquire():
+ for _ in range(num_iterations):
+ num = sem.acquire('a', blocking=True)
+ time.sleep(0.001)
+ sem.release('a', num)
+
+ for i in range(num_threads):
+ t = threading.Thread(target=acquire)
+ self.threads.append(t)
+ self.start_threads()
+ self.join_threads()
+ # Should have all the available resources freed.
+ self.assertEqual(sem.current_count(), 5)
+ # Should have acquired num_threads * num_iterations
+ self.assertEqual(
+ sem.acquire('a', blocking=False), num_threads * num_iterations
+ )
+
+
+class TestAdjustChunksize(unittest.TestCase):
+ def setUp(self):
+ self.adjuster = ChunksizeAdjuster()
+
+ def test_valid_chunksize(self):
+ chunksize = 7 * (1024 ** 2)
+ file_size = 8 * (1024 ** 2)
+ new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
+ self.assertEqual(new_size, chunksize)
+
+ def test_chunksize_below_minimum(self):
+ chunksize = MIN_UPLOAD_CHUNKSIZE - 1
+ file_size = 3 * MIN_UPLOAD_CHUNKSIZE
+ new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
+ self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE)
+
+ def test_chunksize_above_maximum(self):
+ chunksize = MAX_SINGLE_UPLOAD_SIZE + 1
+ file_size = MAX_SINGLE_UPLOAD_SIZE * 2
+ new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
+ self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE)
+
+ def test_chunksize_too_small(self):
+ chunksize = 7 * (1024 ** 2)
+ file_size = 5 * (1024 ** 4)
+ # If we try to upload a 5TB file, we'll need to use 896MB part
+ # sizes.
+ new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
+ self.assertEqual(new_size, 896 * (1024 ** 2))
+ num_parts = file_size / new_size
+ self.assertLessEqual(num_parts, MAX_PARTS)
+
+ def test_unknown_file_size_with_valid_chunksize(self):
+ chunksize = 7 * (1024 ** 2)
+ new_size = self.adjuster.adjust_chunksize(chunksize)
+ self.assertEqual(new_size, chunksize)
+
+ def test_unknown_file_size_below_minimum(self):
+ chunksize = MIN_UPLOAD_CHUNKSIZE - 1
+ new_size = self.adjuster.adjust_chunksize(chunksize)
+ self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE)
+
+ def test_unknown_file_size_above_maximum(self):
+ chunksize = MAX_SINGLE_UPLOAD_SIZE + 1
+ new_size = self.adjuster.adjust_chunksize(chunksize)
+ self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE)
diff --git a/contrib/python/s3transfer/py3/tests/ya.make b/contrib/python/s3transfer/py3/tests/ya.make
new file mode 100644
index 0000000000..fdbf22b0c5
--- /dev/null
+++ b/contrib/python/s3transfer/py3/tests/ya.make
@@ -0,0 +1,44 @@
+PY3TEST()
+
+OWNER(g:python-contrib)
+
+SIZE(MEDIUM)
+
+FORK_TESTS()
+
+PEERDIR(
+ contrib/python/mock
+ contrib/python/s3transfer
+)
+
+TEST_SRCS(
+ functional/__init__.py
+ functional/test_copy.py
+ functional/test_crt.py
+ functional/test_delete.py
+ functional/test_download.py
+ functional/test_manager.py
+ functional/test_processpool.py
+ functional/test_upload.py
+ functional/test_utils.py
+ __init__.py
+ unit/__init__.py
+ unit/test_bandwidth.py
+ unit/test_compat.py
+ unit/test_copies.py
+ unit/test_crt.py
+ unit/test_delete.py
+ unit/test_download.py
+ unit/test_futures.py
+ unit/test_manager.py
+ unit/test_processpool.py
+ unit/test_s3transfer.py
+ unit/test_subscribers.py
+ unit/test_tasks.py
+ unit/test_upload.py
+ unit/test_utils.py
+)
+
+NO_LINT()
+
+END()
diff --git a/contrib/python/s3transfer/py3/ya.make b/contrib/python/s3transfer/py3/ya.make
new file mode 100644
index 0000000000..964a630639
--- /dev/null
+++ b/contrib/python/s3transfer/py3/ya.make
@@ -0,0 +1,51 @@
+# Generated by devtools/yamaker (pypi).
+
+PY3_LIBRARY()
+
+OWNER(gebetix g:python-contrib)
+
+VERSION(0.5.1)
+
+LICENSE(Apache-2.0)
+
+PEERDIR(
+ contrib/python/botocore
+)
+
+NO_LINT()
+
+NO_CHECK_IMPORTS(
+ s3transfer.crt
+)
+
+PY_SRCS(
+ TOP_LEVEL
+ s3transfer/__init__.py
+ s3transfer/bandwidth.py
+ s3transfer/compat.py
+ s3transfer/constants.py
+ s3transfer/copies.py
+ s3transfer/crt.py
+ s3transfer/delete.py
+ s3transfer/download.py
+ s3transfer/exceptions.py
+ s3transfer/futures.py
+ s3transfer/manager.py
+ s3transfer/processpool.py
+ s3transfer/subscribers.py
+ s3transfer/tasks.py
+ s3transfer/upload.py
+ s3transfer/utils.py
+)
+
+RESOURCE_FILES(
+ PREFIX contrib/python/s3transfer/py3/
+ .dist-info/METADATA
+ .dist-info/top_level.txt
+)
+
+END()
+
+RECURSE_FOR_TESTS(
+ tests
+)