diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /contrib/python/s3transfer/py3 | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'contrib/python/s3transfer/py3')
50 files changed, 18490 insertions, 0 deletions
diff --git a/contrib/python/s3transfer/py3/.dist-info/METADATA b/contrib/python/s3transfer/py3/.dist-info/METADATA new file mode 100644 index 0000000000..7d635068d7 --- /dev/null +++ b/contrib/python/s3transfer/py3/.dist-info/METADATA @@ -0,0 +1,42 @@ +Metadata-Version: 2.1 +Name: s3transfer +Version: 0.5.1 +Summary: An Amazon S3 Transfer Manager +Home-page: https://github.com/boto/s3transfer +Author: Amazon Web Services +Author-email: kyknapp1@gmail.com +License: Apache License 2.0 +Platform: UNKNOWN +Classifier: Development Status :: 3 - Alpha +Classifier: Intended Audience :: Developers +Classifier: Natural Language :: English +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Requires-Python: >= 3.6 +License-File: LICENSE.txt +License-File: NOTICE.txt +Requires-Dist: botocore (<2.0a.0,>=1.12.36) +Provides-Extra: crt +Requires-Dist: botocore[crt] (<2.0a.0,>=1.20.29) ; extra == 'crt' + +===================================================== +s3transfer - An Amazon S3 Transfer Manager for Python +===================================================== + +S3transfer is a Python library for managing Amazon S3 transfers. +This project is maintained and published by Amazon Web Services. + +.. note:: + + This project is not currently GA. If you are planning to use this code in + production, make sure to lock to a minor version as interfaces may break + from minor version to minor version. 
For a basic, stable interface of + s3transfer, try the interfaces exposed in `boto3 <https://boto3.readthedocs.io/en/latest/guide/s3.html#using-the-transfer-manager>`__ + + diff --git a/contrib/python/s3transfer/py3/.dist-info/top_level.txt b/contrib/python/s3transfer/py3/.dist-info/top_level.txt new file mode 100644 index 0000000000..572c6a92fb --- /dev/null +++ b/contrib/python/s3transfer/py3/.dist-info/top_level.txt @@ -0,0 +1 @@ +s3transfer diff --git a/contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml b/contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml new file mode 100644 index 0000000000..f2f140fb3c --- /dev/null +++ b/contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml @@ -0,0 +1,2 @@ +exclude: +- tests/integration/.+ diff --git a/contrib/python/s3transfer/py3/LICENSE.txt b/contrib/python/s3transfer/py3/LICENSE.txt new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/contrib/python/s3transfer/py3/LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/contrib/python/s3transfer/py3/NOTICE.txt b/contrib/python/s3transfer/py3/NOTICE.txt new file mode 100644 index 0000000000..3e616fdf0c --- /dev/null +++ b/contrib/python/s3transfer/py3/NOTICE.txt @@ -0,0 +1,2 @@ +s3transfer +Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/contrib/python/s3transfer/py3/README.rst b/contrib/python/s3transfer/py3/README.rst new file mode 100644 index 0000000000..441029109e --- /dev/null +++ b/contrib/python/s3transfer/py3/README.rst @@ -0,0 +1,13 @@ +===================================================== +s3transfer - An Amazon S3 Transfer Manager for Python +===================================================== + +S3transfer is a Python library for managing Amazon S3 transfers. +This project is maintained and published by Amazon Web Services. + +.. note:: + + This project is not currently GA. If you are planning to use this code in + production, make sure to lock to a minor version as interfaces may break + from minor version to minor version. 
For a basic, stable interface of + s3transfer, try the interfaces exposed in `boto3 <https://boto3.readthedocs.io/en/latest/guide/s3.html#using-the-transfer-manager>`__ diff --git a/contrib/python/s3transfer/py3/patches/01-fix-tests.patch b/contrib/python/s3transfer/py3/patches/01-fix-tests.patch new file mode 100644 index 0000000000..aa8d3fab4e --- /dev/null +++ b/contrib/python/s3transfer/py3/patches/01-fix-tests.patch @@ -0,0 +1,242 @@ +--- contrib/python/s3transfer/py3/tests/functional/test_copy.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_copy.py (working tree) +@@ -15,7 +15,7 @@ from botocore.stub import Stubber + + from s3transfer.manager import TransferConfig, TransferManager + from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE +-from tests import BaseGeneralInterfaceTest, FileSizeProvider ++from __tests__ import BaseGeneralInterfaceTest, FileSizeProvider + + + class BaseCopyTest(BaseGeneralInterfaceTest): +--- contrib/python/s3transfer/py3/tests/functional/test_crt.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_crt.py (working tree) +@@ -18,7 +18,7 @@ from concurrent.futures import Future + from botocore.session import Session + + from s3transfer.subscribers import BaseSubscriber +-from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest ++from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest + + if HAS_CRT: + import awscrt +--- contrib/python/s3transfer/py3/tests/functional/test_delete.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_delete.py (working tree) +@@ -11,7 +11,7 @@ + # ANY KIND, either express or implied. See the License for the specific + # language governing permissions and limitations under the License. 
+ from s3transfer.manager import TransferManager +-from tests import BaseGeneralInterfaceTest ++from __tests__ import BaseGeneralInterfaceTest + + + class TestDeleteObject(BaseGeneralInterfaceTest): +--- contrib/python/s3transfer/py3/tests/functional/test_download.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_download.py (working tree) +@@ -23,7 +23,7 @@ from botocore.exceptions import ClientError + from s3transfer.compat import SOCKET_ERROR + from s3transfer.exceptions import RetriesExceededError + from s3transfer.manager import TransferConfig, TransferManager +-from tests import ( ++from __tests__ import ( + BaseGeneralInterfaceTest, + FileSizeProvider, + NonSeekableWriter, +--- contrib/python/s3transfer/py3/tests/functional/test_manager.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_manager.py (working tree) +@@ -17,7 +17,7 @@ from botocore.awsrequest import create_request_object + from s3transfer.exceptions import CancelledError, FatalError + from s3transfer.futures import BaseExecutor + from s3transfer.manager import TransferConfig, TransferManager +-from tests import StubbedClientTest, mock, skip_if_using_serial_implementation ++from __tests__ import StubbedClientTest, mock, skip_if_using_serial_implementation + + + class ArbitraryException(Exception): +--- contrib/python/s3transfer/py3/tests/functional/test_processpool.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_processpool.py (working tree) +@@ -21,7 +21,7 @@ from botocore.stub import Stubber + + from s3transfer.exceptions import CancelledError + from s3transfer.processpool import ProcessPoolDownloader, ProcessTransferConfig +-from tests import FileCreator, mock, unittest ++from __tests__ import FileCreator, mock, unittest + + + class StubbedClient: +--- contrib/python/s3transfer/py3/tests/functional/test_upload.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_upload.py (working tree) +@@ -23,7 +23,7 @@ from botocore.stub 
import ANY + + from s3transfer.manager import TransferConfig, TransferManager + from s3transfer.utils import ChunksizeAdjuster +-from tests import ( ++from __tests__ import ( + BaseGeneralInterfaceTest, + NonSeekableReader, + RecordingOSUtils, +--- contrib/python/s3transfer/py3/tests/functional/test_utils.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_utils.py (working tree) +@@ -16,7 +16,7 @@ import socket + import tempfile + + from s3transfer.utils import OSUtils +-from tests import skip_if_windows, unittest ++from __tests__ import skip_if_windows, unittest + + + @skip_if_windows('Windows does not support UNIX special files') +--- contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py (working tree) +@@ -25,7 +25,7 @@ from s3transfer.bandwidth import ( + TimeUtils, + ) + from s3transfer.futures import TransferCoordinator +-from tests import mock, unittest ++from __tests__ import mock, unittest + + + class FixedIncrementalTickTimeUtils(TimeUtils): +--- contrib/python/s3transfer/py3/tests/unit/test_compat.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_compat.py (working tree) +@@ -17,7 +17,7 @@ import tempfile + from io import BytesIO + + from s3transfer.compat import BaseManager, readable, seekable +-from tests import skip_if_windows, unittest ++from __tests__ import skip_if_windows, unittest + + + class ErrorRaisingSeekWrapper: +--- contrib/python/s3transfer/py3/tests/unit/test_copies.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_copies.py (working tree) +@@ -11,7 +11,7 @@ + # ANY KIND, either express or implied. See the License for the specific + # language governing permissions and limitations under the License. 
+ from s3transfer.copies import CopyObjectTask, CopyPartTask +-from tests import BaseTaskTest, RecordingSubscriber ++from __tests__ import BaseTaskTest, RecordingSubscriber + + + class BaseCopyTaskTest(BaseTaskTest): +--- contrib/python/s3transfer/py3/tests/unit/test_crt.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_crt.py (working tree) +@@ -15,7 +15,7 @@ from botocore.session import Session + + from s3transfer.exceptions import TransferNotDoneError + from s3transfer.utils import CallArgs +-from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest ++from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest + + if HAS_CRT: + import awscrt.s3 +--- contrib/python/s3transfer/py3/tests/unit/test_delete.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_delete.py (working tree) +@@ -11,7 +11,7 @@ + # ANY KIND, either express or implied. See the License for the specific + # language governing permissions and limitations under the License. + from s3transfer.delete import DeleteObjectTask +-from tests import BaseTaskTest ++from __tests__ import BaseTaskTest + + + class TestDeleteObjectTask(BaseTaskTest): +--- contrib/python/s3transfer/py3/tests/unit/test_download.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_download.py (working tree) +@@ -37,7 +37,7 @@ from s3transfer.download import ( + from s3transfer.exceptions import RetriesExceededError + from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG, BoundedExecutor + from s3transfer.utils import CallArgs, OSUtils +-from tests import ( ++from __tests__ import ( + BaseSubmissionTaskTest, + BaseTaskTest, + FileCreator, +--- contrib/python/s3transfer/py3/tests/unit/test_futures.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_futures.py (working tree) +@@ -37,7 +37,7 @@ from s3transfer.utils import ( + NoResourcesAvailable, + TaskSemaphore, + ) +-from tests import ( ++from __tests__ import ( + RecordingExecutor, + TransferCoordinatorWithInterrupt, 
+ mock, +--- contrib/python/s3transfer/py3/tests/unit/test_manager.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_manager.py (working tree) +@@ -16,7 +16,7 @@ from concurrent.futures import ThreadPoolExecutor + from s3transfer.exceptions import CancelledError, FatalError + from s3transfer.futures import TransferCoordinator + from s3transfer.manager import TransferConfig, TransferCoordinatorController +-from tests import TransferCoordinatorWithInterrupt, unittest ++from __tests__ import TransferCoordinatorWithInterrupt, unittest + + + class FutureResultException(Exception): +--- contrib/python/s3transfer/py3/tests/unit/test_processpool.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_processpool.py (working tree) +@@ -39,7 +39,7 @@ from s3transfer.processpool import ( + ignore_ctrl_c, + ) + from s3transfer.utils import CallArgs, OSUtils +-from tests import ( ++from __tests__ import ( + FileCreator, + StreamWithError, + StubbedClientTest, +--- contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py (working tree) +@@ -33,7 +33,7 @@ from s3transfer import ( + random_file_extension, + ) + from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError +-from tests import mock, unittest ++from __tests__ import mock, unittest + + + class InMemoryOSLayer(OSUtils): +--- contrib/python/s3transfer/py3/tests/unit/test_subscribers.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_subscribers.py (working tree) +@@ -12,7 +12,7 @@ + # language governing permissions and limitations under the License. 
+ from s3transfer.exceptions import InvalidSubscriberMethodError + from s3transfer.subscribers import BaseSubscriber +-from tests import unittest ++from __tests__ import unittest + + + class ExtraMethodsSubscriber(BaseSubscriber): +--- contrib/python/s3transfer/py3/tests/unit/test_tasks.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_tasks.py (working tree) +@@ -23,7 +23,7 @@ from s3transfer.tasks import ( + Task, + ) + from s3transfer.utils import CallArgs, FunctionContainer, get_callbacks +-from tests import ( ++from __tests__ import ( + BaseSubmissionTaskTest, + BaseTaskTest, + RecordingSubscriber, +--- contrib/python/s3transfer/py3/tests/unit/test_upload.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_upload.py (working tree) +@@ -32,7 +32,7 @@ from s3transfer.upload import ( + UploadSubmissionTask, + ) + from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE, CallArgs, OSUtils +-from tests import ( ++from __tests__ import ( + BaseSubmissionTaskTest, + BaseTaskTest, + FileSizeProvider, +--- contrib/python/s3transfer/py3/tests/unit/test_utils.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_utils.py (working tree) +@@ -43,7 +43,7 @@ from s3transfer.utils import ( + invoke_progress_callbacks, + random_file_extension, + ) +-from tests import NonSeekableWriter, RecordingSubscriber, mock, unittest ++from __tests__ import NonSeekableWriter, RecordingSubscriber, mock, unittest + + + class TestGetCallbacks(unittest.TestCase): diff --git a/contrib/python/s3transfer/py3/s3transfer/__init__.py b/contrib/python/s3transfer/py3/s3transfer/__init__.py new file mode 100644 index 0000000000..1a749c712e --- /dev/null +++ b/contrib/python/s3transfer/py3/s3transfer/__init__.py @@ -0,0 +1,875 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Abstractions over S3's upload/download operations. + +This module provides high level abstractions for efficient +uploads/downloads. It handles several things for the user: + +* Automatically switching to multipart transfers when + a file is over a specific size threshold +* Uploading/downloading a file in parallel +* Throttling based on max bandwidth +* Progress callbacks to monitor transfers +* Retries. While botocore handles retries for streaming uploads, + it is not possible for it to handle retries for streaming + downloads. This module handles retries for both cases so + you don't need to implement any retry logic yourself. + +This module has a reasonable set of defaults. It also allows you +to configure many aspects of the transfer process including: + +* Multipart threshold size +* Max parallel downloads +* Max bandwidth +* Socket timeouts +* Retry amounts + +There is no support for s3->s3 multipart copies at this +time. + + +.. _ref_s3transfer_usage: + +Usage +===== + +The simplest way to use this module is: + +.. code-block:: python + + client = boto3.client('s3', 'us-west-2') + transfer = S3Transfer(client) + # Upload /tmp/myfile to s3://bucket/key + transfer.upload_file('/tmp/myfile', 'bucket', 'key') + + # Download s3://bucket/key to /tmp/myfile + transfer.download_file('bucket', 'key', '/tmp/myfile') + +The ``upload_file`` and ``download_file`` methods also accept +``**kwargs``, which will be forwarded through to the corresponding +client operation. 
Here are a few examples using ``upload_file``:: + + # Making the object public + transfer.upload_file('/tmp/myfile', 'bucket', 'key', + extra_args={'ACL': 'public-read'}) + + # Setting metadata + transfer.upload_file('/tmp/myfile', 'bucket', 'key', + extra_args={'Metadata': {'a': 'b', 'c': 'd'}}) + + # Setting content type + transfer.upload_file('/tmp/myfile.json', 'bucket', 'key', + extra_args={'ContentType': "application/json"}) + + +The ``S3Transfer`` class also supports progress callbacks so you can +provide transfer progress to users. Both the ``upload_file`` and +``download_file`` methods take an optional ``callback`` parameter. +Here's an example of how to print a simple progress percentage +to the user: + +.. code-block:: python + + class ProgressPercentage(object): + def __init__(self, filename): + self._filename = filename + self._size = float(os.path.getsize(filename)) + self._seen_so_far = 0 + self._lock = threading.Lock() + + def __call__(self, bytes_amount): + # To simplify we'll assume this is hooked up + # to a single filename. + with self._lock: + self._seen_so_far += bytes_amount + percentage = (self._seen_so_far / self._size) * 100 + sys.stdout.write( + "\r%s %s / %s (%.2f%%)" % (self._filename, self._seen_so_far, + self._size, percentage)) + sys.stdout.flush() + + + transfer = S3Transfer(boto3.client('s3', 'us-west-2')) + # Upload /tmp/myfile to s3://bucket/key and print upload progress. + transfer.upload_file('/tmp/myfile', 'bucket', 'key', + callback=ProgressPercentage('/tmp/myfile')) + + + +You can also provide a TransferConfig object to the S3Transfer +object that gives you more fine grained control over the +transfer. For example: + +.. 
class ReadFileChunk:
    """Read-only, file-like view over one byte window of an open file.

    The window starts at ``start_byte`` of the wrapped file object and is
    at most ``chunk_size`` bytes long. ``read()`` never returns data from
    outside that window. An optional callback receives the number of bytes
    returned by each ``read()`` and the signed change in position caused by
    each ``seek()``.
    """

    def __init__(
        self,
        fileobj,
        start_byte,
        chunk_size,
        full_file_size,
        callback=None,
        enable_callback=True,
    ):
        """

        Given a file object shown below::

            |___________________________________________________|
            0          |                 |                 full_file_size
                       |----chunk_size---|
                    start_byte

        :type fileobj: file
        :param fileobj: File like object

        :type start_byte: int
        :param start_byte: The first byte from which to start reading.

        :type chunk_size: int
        :param chunk_size: The max chunk size to read.  Trying to read
            past the end of the chunk size will behave like you've
            reached the end of the file.

        :type full_file_size: int
        :param full_file_size: The entire content length associated
            with ``fileobj``.

        :type callback: function(amount_read)
        :param callback: Called whenever data is read from this object.

        :type enable_callback: bool
        :param enable_callback: Indicate whether to invoke callback
            during read() and seek() calls.

        """
        self._fileobj = fileobj
        self._start_byte = start_byte
        self._size = self._calculate_file_size(
            self._fileobj,
            requested_size=chunk_size,
            start_byte=start_byte,
            actual_file_size=full_file_size,
        )
        self._fileobj.seek(self._start_byte)
        self._amount_read = 0
        self._callback = callback
        self._callback_enabled = enable_callback

    @classmethod
    def from_filename(
        cls,
        filename,
        start_byte,
        chunk_size,
        callback=None,
        enable_callback=True,
    ):
        """Convenience factory function to create from a filename.

        The file is opened in binary mode and its full size is taken from
        ``os.fstat``; the returned object owns the handle and closes it in
        ``close()``.

        :type filename: str
        :param filename: Path of the file to open.

        :type start_byte: int
        :param start_byte: The first byte from which to start reading.

        :type chunk_size: int
        :param chunk_size: The max chunk size to read.  Trying to read
            past the end of the chunk size will behave like you've
            reached the end of the file.

        :type callback: function(amount_read)
        :param callback: Called whenever data is read from this object.

        :type enable_callback: bool
        :param enable_callback: Indicate whether to invoke callback
            during read() calls.

        :rtype: ``ReadFileChunk``
        :return: A new instance of ``ReadFileChunk``

        """
        f = open(filename, 'rb')
        try:
            file_size = os.fstat(f.fileno()).st_size
            return cls(
                f, start_byte, chunk_size, file_size, callback, enable_callback
            )
        except BaseException:
            # Nobody else holds a reference to the handle yet, so it would
            # leak if fstat/seek/construction failed.
            f.close()
            raise

    def _calculate_file_size(
        self, fileobj, requested_size, start_byte, actual_file_size
    ):
        # The window cannot extend past the end of the file. Clamp at zero
        # so that a start_byte beyond the end of the file yields an empty
        # chunk instead of a negative size (which would make len() raise).
        max_chunk_size = actual_file_size - start_byte
        return max(min(max_chunk_size, requested_size), 0)

    def read(self, amount=None):
        """Read up to ``amount`` bytes without leaving the chunk window.

        ``None`` or a negative ``amount`` means "everything left in the
        chunk", mirroring the usual file-object convention.
        """
        # Never negative: a negative size handed to the underlying read()
        # would mean "read to end of file" and escape the window.
        remaining = max(self._size - self._amount_read, 0)
        if amount is None or amount < 0:
            amount_to_read = remaining
        else:
            amount_to_read = min(remaining, amount)
        data = self._fileobj.read(amount_to_read)
        self._amount_read += len(data)
        if self._callback is not None and self._callback_enabled:
            self._callback(len(data))
        return data

    def enable_callback(self):
        self._callback_enabled = True

    def disable_callback(self):
        self._callback_enabled = False

    def seek(self, where):
        """Move to offset ``where`` relative to the start of the chunk."""
        self._fileobj.seek(self._start_byte + where)
        if self._callback is not None and self._callback_enabled:
            # To also rewind the callback() for an accurate progress report
            self._callback(where - self._amount_read)
        self._amount_read = where

    def close(self):
        self._fileobj.close()

    def tell(self):
        return self._amount_read

    def __len__(self):
        # __len__ is defined because requests will try to determine the length
        # of the stream to set a content length.  In the normal case
        # of the file it will just stat the file, but we need to change that
        # behavior.  By providing a __len__, requests will use that instead
        # of stat'ing the file.
        return self._size

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        self.close()

    def __iter__(self):
        # This is a workaround for http://bugs.python.org/issue17575
        # Basically httplib will try to iterate over the contents, even
        # if its a file like object.  This wasn't noticed because we've
        # already exhausted the stream so iterating over the file immediately
        # stops, which is what we're simulating here.
        return iter([])
+ upload_parts_args = {} + for key, value in extra_args.items(): + if key in self.UPLOAD_PART_ARGS: + upload_parts_args[key] = value + return upload_parts_args + + def upload_file(self, filename, bucket, key, callback, extra_args): + response = self._client.create_multipart_upload( + Bucket=bucket, Key=key, **extra_args + ) + upload_id = response['UploadId'] + try: + parts = self._upload_parts( + upload_id, filename, bucket, key, callback, extra_args + ) + except Exception as e: + logger.debug( + "Exception raised while uploading parts, " + "aborting multipart upload.", + exc_info=True, + ) + self._client.abort_multipart_upload( + Bucket=bucket, Key=key, UploadId=upload_id + ) + raise S3UploadFailedError( + "Failed to upload {} to {}: {}".format( + filename, '/'.join([bucket, key]), e + ) + ) + self._client.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={'Parts': parts}, + ) + + def _upload_parts( + self, upload_id, filename, bucket, key, callback, extra_args + ): + upload_parts_extra_args = self._extra_upload_part_args(extra_args) + parts = [] + part_size = self._config.multipart_chunksize + num_parts = int( + math.ceil(self._os.get_file_size(filename) / float(part_size)) + ) + max_workers = self._config.max_concurrency + with self._executor_cls(max_workers=max_workers) as executor: + upload_partial = functools.partial( + self._upload_one_part, + filename, + bucket, + key, + upload_id, + part_size, + upload_parts_extra_args, + callback, + ) + for part in executor.map(upload_partial, range(1, num_parts + 1)): + parts.append(part) + return parts + + def _upload_one_part( + self, + filename, + bucket, + key, + upload_id, + part_size, + extra_args, + callback, + part_number, + ): + open_chunk_reader = self._os.open_file_chunk_reader + with open_chunk_reader( + filename, part_size * (part_number - 1), part_size, callback + ) as body: + response = self._client.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, 
class MultipartDownloader:
    """Download one S3 object as concurrent ranged GETs into a single file.

    Worker threads fetch byte ranges and put ``(offset, data)`` tuples on a
    bounded queue; one IO thread takes them off the queue and writes each
    at its offset in the destination file. A sentinel on the queue tells the
    IO thread that all parts have been submitted.
    """

    def __init__(
        self,
        client,
        config,
        osutil,
        executor_cls=concurrent.futures.ThreadPoolExecutor,
    ):
        """
        :param client: S3 client; only ``get_object`` is called on it here.
        :param config: Object providing ``max_io_queue``,
            ``multipart_chunksize``, ``max_concurrency`` and
            ``num_download_attempts``.
        :param osutil: Object providing ``open(filename, mode)``.
        :param executor_cls: Executor class used for the controller threads
            and the range workers.
        """
        self._client = client
        self._config = config
        self._os = osutil
        self._executor_cls = executor_cls
        self._ioqueue = ShutdownQueue(self._config.max_io_queue)

    def download_file(
        self, bucket, key, filename, object_size, extra_args, callback=None
    ):
        """Download ``bucket/key`` into ``filename`` using ranged GETs.

        :param object_size: Size of the object in bytes; used to compute
            the number of parts.
        :param extra_args: Extra keyword arguments forwarded to every
            ``get_object`` call (may be ``None``).
        :param callback: Optional progress callback passed the number of
            bytes read from each response body.
        :raises: Whatever exception the download or IO future finished with.
        """
        with self._executor_cls(max_workers=2) as controller:
            # 1 thread for the future that manages the downloading of parts
            # 1 thread for the future that manages IO writes.
            download_parts_handler = functools.partial(
                self._download_file_as_future,
                bucket,
                key,
                filename,
                object_size,
                callback,
                extra_args=extra_args,
            )
            parts_future = controller.submit(download_parts_handler)

            io_writes_handler = functools.partial(
                self._perform_io_writes, filename
            )
            io_future = controller.submit(io_writes_handler)
            results = concurrent.futures.wait(
                [parts_future, io_future],
                return_when=concurrent.futures.FIRST_EXCEPTION,
            )
            self._process_future_results(results)

    def _process_future_results(self, futures):
        # ``futures`` is the (done, not_done) pair returned by wait();
        # calling result() re-raises any exception a finished future holds.
        finished, unfinished = futures
        for future in finished:
            future.result()

    def _download_file_as_future(
        self, bucket, key, filename, object_size, callback, extra_args=None
    ):
        part_size = self._config.multipart_chunksize
        num_parts = int(math.ceil(object_size / float(part_size)))
        max_workers = self._config.max_concurrency
        download_partial = functools.partial(
            self._download_range,
            bucket,
            key,
            filename,
            part_size,
            num_parts,
            callback,
            # Bound by keyword so executor.map() can keep supplying
            # part_index as the single positional argument.
            extra_args=extra_args,
        )
        try:
            with self._executor_cls(max_workers=max_workers) as executor:
                list(executor.map(download_partial, range(num_parts)))
        finally:
            # Always tell the IO thread that no more data is coming.
            self._ioqueue.put(SHUTDOWN_SENTINEL)

    def _calculate_range_param(self, part_size, part_index, num_parts):
        start_range = part_index * part_size
        if part_index == num_parts - 1:
            # Open-ended range for the last part: "bytes=N-".
            end_range = ''
        else:
            end_range = start_range + part_size - 1
        range_param = f'bytes={start_range}-{end_range}'
        return range_param

    def _download_range(
        self,
        bucket,
        key,
        filename,
        part_size,
        num_parts,
        callback,
        part_index,
        extra_args=None,
    ):
        """Fetch one byte range and queue its data for the IO thread.

        Retries up to ``config.num_download_attempts`` times on socket/OS
        level read errors, then raises ``RetriesExceededError``.
        """
        # Request-level arguments (e.g. VersionId, SSE-C keys, RequestPayer)
        # must accompany every ranged GET, not only the initial HEAD.
        request_args = dict(extra_args) if extra_args else {}
        try:
            range_param = self._calculate_range_param(
                part_size, part_index, num_parts
            )

            max_attempts = self._config.num_download_attempts
            last_exception = None
            for i in range(max_attempts):
                try:
                    logger.debug("Making get_object call.")
                    response = self._client.get_object(
                        Bucket=bucket,
                        Key=key,
                        Range=range_param,
                        **request_args,
                    )
                    streaming_body = StreamReaderProgress(
                        response['Body'], callback
                    )
                    buffer_size = 1024 * 16
                    current_index = part_size * part_index
                    for chunk in iter(
                        lambda: streaming_body.read(buffer_size), b''
                    ):
                        self._ioqueue.put((current_index, chunk))
                        current_index += len(chunk)
                    return
                except (
                    socket.timeout,
                    OSError,
                    ReadTimeoutError,
                    IncompleteReadError,
                ) as e:
                    logger.debug(
                        "Retrying exception caught (%s), "
                        "retrying request, (attempt %s / %s)",
                        e,
                        # Report attempts 1-based so the last one reads N / N.
                        i + 1,
                        max_attempts,
                        exc_info=True,
                    )
                    last_exception = e
                    continue
            raise RetriesExceededError(last_exception)
        finally:
            logger.debug("EXITING _download_range for part: %s", part_index)

    def _perform_io_writes(self, filename):
        """Write queued ``(offset, data)`` tuples until the sentinel arrives."""
        try:
            with self._os.open(filename, 'wb') as f:
                while True:
                    task = self._ioqueue.get()
                    if task is SHUTDOWN_SENTINEL:
                        logger.debug(
                            "Shutdown sentinel received in IO handler, "
                            "shutting down IO handler."
                        )
                        return
                    offset, data = task
                    f.seek(offset)
                    f.write(data)
        except Exception as e:
            # Covers a failed open() as well as failed writes: once this
            # thread stops consuming, producers must be refused instead of
            # blocking on a full queue.
            logger.debug(
                "Caught exception in IO thread: %s",
                e,
                exc_info=True,
            )
            self._ioqueue.trigger_shutdown()
            raise
'SSECustomerKeyMD5', + 'SSEKMSKeyId', + 'SSEKMSEncryptionContext', + 'Tagging', + ] + + def __init__(self, client, config=None, osutil=None): + self._client = client + if config is None: + config = TransferConfig() + self._config = config + if osutil is None: + osutil = OSUtils() + self._osutil = osutil + + def upload_file( + self, filename, bucket, key, callback=None, extra_args=None + ): + """Upload a file to an S3 object. + + Variants have also been injected into S3 client, Bucket and Object. + You don't have to use S3Transfer.upload_file() directly. + """ + if extra_args is None: + extra_args = {} + self._validate_all_known_args(extra_args, self.ALLOWED_UPLOAD_ARGS) + events = self._client.meta.events + events.register_first( + 'request-created.s3', + disable_upload_callbacks, + unique_id='s3upload-callback-disable', + ) + events.register_last( + 'request-created.s3', + enable_upload_callbacks, + unique_id='s3upload-callback-enable', + ) + if ( + self._osutil.get_file_size(filename) + >= self._config.multipart_threshold + ): + self._multipart_upload(filename, bucket, key, callback, extra_args) + else: + self._put_object(filename, bucket, key, callback, extra_args) + + def _put_object(self, filename, bucket, key, callback, extra_args): + # We're using open_file_chunk_reader so we can take advantage of the + # progress callback functionality. + open_chunk_reader = self._osutil.open_file_chunk_reader + with open_chunk_reader( + filename, + 0, + self._osutil.get_file_size(filename), + callback=callback, + ) as body: + self._client.put_object( + Bucket=bucket, Key=key, Body=body, **extra_args + ) + + def download_file( + self, bucket, key, filename, extra_args=None, callback=None + ): + """Download an S3 object to a file. + + Variants have also been injected into S3 client, Bucket and Object. + You don't have to use S3Transfer.download_file() directly. + """ + # This method will issue a ``head_object`` request to determine + # the size of the S3 object. 
This is used to determine if the + # object is downloaded in parallel. + if extra_args is None: + extra_args = {} + self._validate_all_known_args(extra_args, self.ALLOWED_DOWNLOAD_ARGS) + object_size = self._object_size(bucket, key, extra_args) + temp_filename = filename + os.extsep + random_file_extension() + try: + self._download_file( + bucket, key, temp_filename, object_size, extra_args, callback + ) + except Exception: + logger.debug( + "Exception caught in download_file, removing partial " + "file: %s", + temp_filename, + exc_info=True, + ) + self._osutil.remove_file(temp_filename) + raise + else: + self._osutil.rename_file(temp_filename, filename) + + def _download_file( + self, bucket, key, filename, object_size, extra_args, callback + ): + if object_size >= self._config.multipart_threshold: + self._ranged_download( + bucket, key, filename, object_size, extra_args, callback + ) + else: + self._get_object(bucket, key, filename, extra_args, callback) + + def _validate_all_known_args(self, actual, allowed): + for kwarg in actual: + if kwarg not in allowed: + raise ValueError( + "Invalid extra_args key '%s', " + "must be one of: %s" % (kwarg, ', '.join(allowed)) + ) + + def _ranged_download( + self, bucket, key, filename, object_size, extra_args, callback + ): + downloader = MultipartDownloader( + self._client, self._config, self._osutil + ) + downloader.download_file( + bucket, key, filename, object_size, extra_args, callback + ) + + def _get_object(self, bucket, key, filename, extra_args, callback): + # precondition: num_download_attempts > 0 + max_attempts = self._config.num_download_attempts + last_exception = None + for i in range(max_attempts): + try: + return self._do_get_object( + bucket, key, filename, extra_args, callback + ) + except ( + socket.timeout, + OSError, + ReadTimeoutError, + IncompleteReadError, + ) as e: + # TODO: we need a way to reset the callback if the + # download failed. 
+ logger.debug( + "Retrying exception caught (%s), " + "retrying request, (attempt %s / %s)", + e, + i, + max_attempts, + exc_info=True, + ) + last_exception = e + continue + raise RetriesExceededError(last_exception) + + def _do_get_object(self, bucket, key, filename, extra_args, callback): + response = self._client.get_object( + Bucket=bucket, Key=key, **extra_args + ) + streaming_body = StreamReaderProgress(response['Body'], callback) + with self._osutil.open(filename, 'wb') as f: + for chunk in iter(lambda: streaming_body.read(8192), b''): + f.write(chunk) + + def _object_size(self, bucket, key, extra_args): + return self._client.head_object(Bucket=bucket, Key=key, **extra_args)[ + 'ContentLength' + ] + + def _multipart_upload(self, filename, bucket, key, callback, extra_args): + uploader = MultipartUploader(self._client, self._config, self._osutil) + uploader.upload_file(filename, bucket, key, callback, extra_args) diff --git a/contrib/python/s3transfer/py3/s3transfer/bandwidth.py b/contrib/python/s3transfer/py3/s3transfer/bandwidth.py new file mode 100644 index 0000000000..9bac5885e1 --- /dev/null +++ b/contrib/python/s3transfer/py3/s3transfer/bandwidth.py @@ -0,0 +1,439 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
class RequestExceededException(Exception):
    """Raised when a consumption request exceeds the available amount."""

    def __init__(self, requested_amt, retry_time):
        """Record the rejected request and when it may be retried.

        The request that raised this error should be retried after waiting
        the time specified by ``retry_time``.

        :type requested_amt: int
        :param requested_amt: The originally requested byte amount

        :type retry_time: float
        :param retry_time: The length in time to wait to retry for the
            requested amount
        """
        template = 'Request amount {} exceeded the amount available. Retry in {}'
        Exception.__init__(self, template.format(requested_amt, retry_time))
        self.requested_amt = requested_amt
        self.retry_time = retry_time
+ """ + self._leaky_bucket = leaky_bucket + self._time_utils = time_utils + if time_utils is None: + self._time_utils = TimeUtils() + + def get_bandwith_limited_stream( + self, fileobj, transfer_coordinator, enabled=True + ): + """Wraps a fileobj in a bandwidth limited stream wrapper + + :type fileobj: file-like obj + :param fileobj: The file-like obj to wrap + + :type transfer_coordinator: s3transfer.futures.TransferCoordinator + param transfer_coordinator: The coordinator for the general transfer + that the wrapped stream is a part of + + :type enabled: boolean + :param enabled: Whether bandwidth limiting should be enabled to start + """ + stream = BandwidthLimitedStream( + fileobj, self._leaky_bucket, transfer_coordinator, self._time_utils + ) + if not enabled: + stream.disable_bandwidth_limiting() + return stream + + +class BandwidthLimitedStream: + def __init__( + self, + fileobj, + leaky_bucket, + transfer_coordinator, + time_utils=None, + bytes_threshold=256 * 1024, + ): + """Limits bandwidth for reads on a wrapped stream + + :type fileobj: file-like object + :param fileobj: The file like object to wrap + + :type leaky_bucket: LeakyBucket + :param leaky_bucket: The leaky bucket to use to throttle reads on + the stream + + :type transfer_coordinator: s3transfer.futures.TransferCoordinator + param transfer_coordinator: The coordinator for the general transfer + that the wrapped stream is a part of + + :type time_utils: TimeUtils + :param time_utils: The time utility to use for interacting with time + """ + self._fileobj = fileobj + self._leaky_bucket = leaky_bucket + self._transfer_coordinator = transfer_coordinator + self._time_utils = time_utils + if time_utils is None: + self._time_utils = TimeUtils() + self._bandwidth_limiting_enabled = True + self._request_token = RequestToken() + self._bytes_seen = 0 + self._bytes_threshold = bytes_threshold + + def enable_bandwidth_limiting(self): + """Enable bandwidth limiting on reads to the stream""" + 
self._bandwidth_limiting_enabled = True + + def disable_bandwidth_limiting(self): + """Disable bandwidth limiting on reads to the stream""" + self._bandwidth_limiting_enabled = False + + def read(self, amount): + """Read a specified amount + + Reads will only be throttled if bandwidth limiting is enabled. + """ + if not self._bandwidth_limiting_enabled: + return self._fileobj.read(amount) + + # We do not want to be calling consume on every read as the read + # amounts can be small causing the lock of the leaky bucket to + # introduce noticeable overhead. So instead we keep track of + # how many bytes we have seen and only call consume once we pass a + # certain threshold. + self._bytes_seen += amount + if self._bytes_seen < self._bytes_threshold: + return self._fileobj.read(amount) + + self._consume_through_leaky_bucket() + return self._fileobj.read(amount) + + def _consume_through_leaky_bucket(self): + # NOTE: If the read amount on the stream are high, it will result + # in large bursty behavior as there is not an interface for partial + # reads. However given the read's on this abstraction are at most 256KB + # (via downloads), it reduces the burstiness to be small KB bursts at + # worst. 
class LeakyBucket:
    def __init__(
        self,
        max_rate,
        time_utils=None,
        rate_tracker=None,
        consumption_scheduler=None,
    ):
        """A leaky bucket abstraction to limit bandwidth consumption

        :type max_rate: int
        :param max_rate: The maximum rate to allow. This rate is in terms
            of bytes per second.

        :type time_utils: TimeUtils
        :param time_utils: The time utility to use for interacting with time

        :type rate_tracker: BandwidthRateTracker
        :param rate_tracker: Tracks bandwidth consumption

        :type consumption_scheduler: ConsumptionScheduler
        :param consumption_scheduler: Schedules consumption retries when
            necessary
        """
        self._max_rate = float(max_rate)
        # Collaborators default to fresh instances when not injected.
        self._time_utils = TimeUtils() if time_utils is None else time_utils
        self._lock = threading.Lock()
        self._rate_tracker = (
            BandwidthRateTracker() if rate_tracker is None else rate_tracker
        )
        self._consumption_scheduler = (
            ConsumptionScheduler()
            if consumption_scheduler is None
            else consumption_scheduler
        )

    def consume(self, amt, request_token):
        """Consume a requested amount

        :type amt: int
        :param amt: The amount of bytes to request to consume

        :type request_token: RequestToken
        :param request_token: The token associated to the consumption
            request that is used to identify the request. So if a
            RequestExceededException is raised the token should be used
            in subsequent retry consume() request.

        :raises RequestExceededException: If the consumption amount would
            exceed the maximum allocated bandwidth

        :rtype: int
        :returns: The amount consumed
        """
        with self._lock:
            now = self._time_utils.time()
            # A token that was already scheduled is always let through.
            if self._consumption_scheduler.is_scheduled(request_token):
                return self._release_requested_amt_for_scheduled_request(
                    amt, request_token, now
                )
            if not self._projected_to_exceed_max_rate(amt, now):
                return self._release_requested_amt(amt, now)
            self._raise_request_exceeded_exception(amt, request_token, now)

    def _projected_to_exceed_max_rate(self, amt, time_now):
        return (
            self._rate_tracker.get_projected_rate(amt, time_now)
            > self._max_rate
        )

    def _release_requested_amt_for_scheduled_request(
        self, amt, request_token, time_now
    ):
        # Clear the schedule entry first, then account for the bytes.
        scheduler = self._consumption_scheduler
        scheduler.process_scheduled_consumption(request_token)
        return self._release_requested_amt(amt, time_now)

    def _raise_request_exceeded_exception(self, amt, request_token, time_now):
        # Time this amount takes to transfer at exactly the maximum rate.
        allocated_time = amt / float(self._max_rate)
        wait = self._consumption_scheduler.schedule_consumption(
            amt, request_token, allocated_time
        )
        raise RequestExceededException(requested_amt=amt, retry_time=wait)

    def _release_requested_amt(self, amt, time_now):
        self._rate_tracker.record_consumption_rate(amt, time_now)
        return amt
class BandwidthRateTracker:
    def __init__(self, alpha=0.8):
        """Tracks the rate of bandwidth consumption

        :type alpha: float
        :param alpha: The constant to use in calculating the exponential
            moving average of the bandwidth rate. Specifically it is used
            in the following calculation:

            current_rate = alpha * new_rate + (1 - alpha) * current_rate

            The value of this constant should be between 0 and 1.
        """
        self._alpha = alpha
        self._last_time = None
        self._current_rate = None

    @property
    def current_rate(self):
        """The current transfer rate

        :rtype: float
        :returns: The current tracked transfer rate
        """
        return 0.0 if self._last_time is None else self._current_rate

    def get_projected_rate(self, amt, time_at_consumption):
        """Get the projected rate using a provided amount and time

        :type amt: int
        :param amt: The proposed amount to consume

        :type time_at_consumption: float
        :param time_at_consumption: The proposed time to consume at

        :rtype: float
        :returns: The consumption rate if that amt and time were consumed
        """
        if self._last_time is not None:
            return self._calculate_exponential_moving_average_rate(
                amt, time_at_consumption
            )
        # Nothing recorded yet, so there is no interval to project over.
        return 0.0

    def record_consumption_rate(self, amt, time_at_consumption):
        """Record the consumption rate based off amount and time point

        :type amt: int
        :param amt: The amount that got consumed

        :type time_at_consumption: float
        :param time_at_consumption: The time at which the amount was consumed
        """
        if self._last_time is None:
            # First sample only establishes the reference time point.
            self._current_rate = 0.0
        else:
            self._current_rate = (
                self._calculate_exponential_moving_average_rate(
                    amt, time_at_consumption
                )
            )
        self._last_time = time_at_consumption

    def _calculate_rate(self, amt, time_at_consumption):
        elapsed = time_at_consumption - self._last_time
        # While it is really unlikely to see this in an actual transfer,
        # we do not want to be returning back a negative rate or try to
        # divide the amount by zero. So instead return back an infinite
        # rate as the time delta is infinitesimally small.
        return float('inf') if elapsed <= 0 else amt / elapsed

    def _calculate_exponential_moving_average_rate(
        self, amt, time_at_consumption
    ):
        weight = self._alpha
        sample = self._calculate_rate(amt, time_at_consumption)
        return weight * sample + (1 - weight) * self._current_rate
def seekable(fileobj):
    """Backwards compat function to determine if a fileobj is seekable

    :param fileobj: The file-like object to determine if seekable

    :returns: True, if seekable. False, otherwise.
    """
    # Prefer the object's own answer when it exposes seekable().
    if hasattr(fileobj, 'seekable'):
        return fileobj.seekable()
    # Without both seek and tell there is nothing to probe.
    if not (hasattr(fileobj, 'seek') and hasattr(fileobj, 'tell')):
        return False
    # Probe with a no-op relative seek to the current position.
    try:
        fileobj.seek(0, 1)
    except OSError:
        # If an io related error was thrown then it is not seekable.
        return False
    return True
import s3transfer

# Size units, expressed in bytes.
KB = 1024
MB = KB ** 2
GB = KB ** 3

# Extra arguments listed as allowed for downloads.
ALLOWED_DOWNLOAD_ARGS = [
    'VersionId',
    'SSECustomerAlgorithm',
    'SSECustomerKey',
    'SSECustomerKeyMD5',
    'RequestPayer',
    'ExpectedBucketOwner',
]

# User agent strings built from the installed s3transfer version.
USER_AGENT = f's3transfer/{s3transfer.__version__}'
PROCESS_USER_AGENT = f'{USER_AGENT} processpool'
class CopySubmissionTask(SubmissionTask):
    """Task for submitting tasks to execute a copy"""

    # Copy extra args and the HeadObject parameter each one is forwarded
    # as when the size of the source object has to be looked up.
    EXTRA_ARGS_TO_HEAD_ARGS_MAPPING = {
        'CopySourceIfMatch': 'IfMatch',
        'CopySourceIfModifiedSince': 'IfModifiedSince',
        'CopySourceIfNoneMatch': 'IfNoneMatch',
        'CopySourceIfUnmodifiedSince': 'IfUnmodifiedSince',
        'CopySourceSSECustomerKey': 'SSECustomerKey',
        'CopySourceSSECustomerAlgorithm': 'SSECustomerAlgorithm',
        'CopySourceSSECustomerKeyMD5': 'SSECustomerKeyMD5',
        'RequestPayer': 'RequestPayer',
        'ExpectedBucketOwner': 'ExpectedBucketOwner',
    }

    # Extra args that are forwarded to every upload_part_copy call.
    UPLOAD_PART_COPY_ARGS = [
        'CopySourceIfMatch',
        'CopySourceIfModifiedSince',
        'CopySourceIfNoneMatch',
        'CopySourceIfUnmodifiedSince',
        'CopySourceSSECustomerKey',
        'CopySourceSSECustomerAlgorithm',
        'CopySourceSSECustomerKeyMD5',
        'SSECustomerKey',
        'SSECustomerAlgorithm',
        'SSECustomerKeyMD5',
        'RequestPayer',
        'ExpectedBucketOwner',
    ]

    # Extra args that are withheld from the create multipart upload call.
    CREATE_MULTIPART_ARGS_BLACKLIST = [
        'CopySourceIfMatch',
        'CopySourceIfModifiedSince',
        'CopySourceIfNoneMatch',
        'CopySourceIfUnmodifiedSince',
        'CopySourceSSECustomerKey',
        'CopySourceSSECustomerAlgorithm',
        'CopySourceSSECustomerKeyMD5',
        'MetadataDirective',
        'TaggingDirective',
    ]

    # Extra args that are forwarded to the complete multipart upload call.
    COMPLETE_MULTIPART_ARGS = ['RequestPayer', 'ExpectedBucketOwner']

    def _submit(
        self, client, config, osutil, request_executor, transfer_future
    ):
        """Submit either a single copy or a multipart copy.

        :param client: The client associated with the transfer manager

        :type config: s3transfer.manager.TransferConfig
        :param config: The transfer config associated with the transfer
            manager

        :type osutil: s3transfer.utils.OSUtil
        :param osutil: The os utility associated to the transfer manager

        :type request_executor: s3transfer.futures.BoundedExecutor
        :param request_executor: The request executor associated with the
            transfer manager

        :type transfer_future: s3transfer.futures.TransferFuture
        :param transfer_future: The transfer future associated with the
            transfer request that tasks are being submitted for
        """
        meta = transfer_future.meta
        if meta.size is None:
            # No size was supplied, so ask S3 for it with a HeadObject on
            # the copy source. Only the source client from the call args
            # is used; if that cannot reach the object the caller has to
            # provide the size themselves.
            call_args = meta.call_args
            head_request = self._get_head_object_request_from_copy_source(
                call_args.copy_source
            )
            # Carry over every copy argument that has a HeadObject
            # equivalent.
            head_names = self.EXTRA_ARGS_TO_HEAD_ARGS_MAPPING
            for arg_name, arg_value in call_args.extra_args.items():
                if arg_name in head_names:
                    head_request[head_names[arg_name]] = arg_value

            head_response = call_args.source_client.head_object(
                **head_request
            )
            meta.provide_transfer_size(head_response['ContentLength'])

        # Sizes below the threshold are copied with a single request,
        # everything else is copied in parts.
        if meta.size < config.multipart_threshold:
            submit_request = self._submit_copy_request
        else:
            submit_request = self._submit_multipart_request
        submit_request(
            client, config, osutil, request_executor, transfer_future
        )

    def _submit_copy_request(
        self, client, config, osutil, request_executor, transfer_future
    ):
        """Submit the single CopyObjectTask that performs the whole copy."""
        call_args = transfer_future.meta.call_args

        # Progress callbacks are handed to the task so it can report the
        # bytes copied.
        progress_callbacks = get_callbacks(transfer_future, 'progress')

        copy_task = CopyObjectTask(
            transfer_coordinator=self._transfer_coordinator,
            main_kwargs={
                'client': client,
                'copy_source': call_args.copy_source,
                'bucket': call_args.bucket,
                'key': call_args.key,
                'extra_args': call_args.extra_args,
                'callbacks': progress_callbacks,
                'size': transfer_future.meta.size,
            },
            is_final=True,
        )
        self._transfer_coordinator.submit(request_executor, copy_task)

    def _submit_multipart_request(
        self, client, config, osutil, request_executor, transfer_future
    ):
        """Submit the create, copy-part and complete tasks of a multipart
        copy.
        """
        call_args = transfer_future.meta.call_args
        total_size = transfer_future.meta.size

        # The create call must not receive any of the arguments that are
        # only meant for copying parts.
        create_multipart_extra_args = {
            name: value
            for name, value in call_args.extra_args.items()
            if name not in self.CREATE_MULTIPART_ARGS_BLACKLIST
        }

        create_multipart_future = self._transfer_coordinator.submit(
            request_executor,
            CreateMultipartUploadTask(
                transfer_coordinator=self._transfer_coordinator,
                main_kwargs={
                    'client': client,
                    'bucket': call_args.bucket,
                    'key': call_args.key,
                    'extra_args': create_multipart_extra_args,
                },
            ),
        )

        # Work out the part size and, from it, how many parts are needed.
        part_size = ChunksizeAdjuster().adjust_chunksize(
            config.multipart_chunksize, total_size
        )
        num_parts = int(math.ceil(total_size / float(part_size)))

        part_futures = []
        progress_callbacks = get_callbacks(transfer_future, 'progress')

        # Part numbers are one based while the range calculation is zero
        # based, hence the two counters.
        for part_index in range(num_parts):
            part_number = part_index + 1
            extra_part_args = self._extra_upload_part_args(
                call_args.extra_args
            )
            extra_part_args['CopySourceRange'] = calculate_range_parameter(
                part_size,
                part_index,
                num_parts,
                total_size,
            )
            # The size of this part is what the progress callbacks get
            # told once the part has been copied.
            size = self._get_transfer_size(
                part_size,
                part_index,
                num_parts,
                total_size,
            )
            part_task = CopyPartTask(
                transfer_coordinator=self._transfer_coordinator,
                main_kwargs={
                    'client': client,
                    'copy_source': call_args.copy_source,
                    'bucket': call_args.bucket,
                    'key': call_args.key,
                    'part_number': part_number,
                    'extra_args': extra_part_args,
                    'callbacks': progress_callbacks,
                    'size': size,
                },
                pending_main_kwargs={'upload_id': create_multipart_future},
            )
            part_futures.append(
                self._transfer_coordinator.submit(request_executor, part_task)
            )

        complete_multipart_extra_args = self._extra_complete_multipart_args(
            call_args.extra_args
        )
        # The final task completes the upload from the upload id and the
        # results of all the part copies.
        self._transfer_coordinator.submit(
            request_executor,
            CompleteMultipartUploadTask(
                transfer_coordinator=self._transfer_coordinator,
                main_kwargs={
                    'client': client,
                    'bucket': call_args.bucket,
                    'key': call_args.key,
                    'extra_args': complete_multipart_extra_args,
                },
                pending_main_kwargs={
                    'upload_id': create_multipart_future,
                    'parts': part_futures,
                },
                is_final=True,
            ),
        )

    def _get_head_object_request_from_copy_source(self, copy_source):
        """Return a shallow copy of ``copy_source`` to use as the
        HeadObject request.

        :raises TypeError: If ``copy_source`` is not a dictionary.
        """
        if not isinstance(copy_source, dict):
            raise TypeError(
                'Expecting dictionary formatted: '
                '{"Bucket": bucket_name, "Key": key} '
                'but got %s or type %s.' % (copy_source, type(copy_source))
            )
        return copy.copy(copy_source)

    def _extra_upload_part_args(self, extra_args):
        # Only the args in UPLOAD_PART_COPY_ARGS actually need to be
        # passed onto the upload_part_copy calls.
        return get_filtered_dict(extra_args, self.UPLOAD_PART_COPY_ARGS)

    def _extra_complete_multipart_args(self, extra_args):
        # Only the args in COMPLETE_MULTIPART_ARGS are passed onto the
        # complete multipart upload call.
        return get_filtered_dict(extra_args, self.COMPLETE_MULTIPART_ARGS)

    def _get_transfer_size(
        self, part_size, part_index, num_parts, total_transfer_size
    ):
        """Return the number of bytes covered by the part at
        ``part_index``.
        """
        is_last_part = part_index == num_parts - 1
        if not is_last_part:
            return part_size
        # The last part may be different in size than the rest of the
        # parts: it covers whatever the earlier parts left over.
        return total_transfer_size - (part_index * part_size)
class CopyPartTask(Task):
    """Task to upload a part in a multipart copy"""

    def _main(
        self,
        client,
        copy_source,
        bucket,
        key,
        upload_id,
        part_number,
        extra_args,
        callbacks,
        size,
    ):
        """Copy one part and report its size to the callbacks.

        :param client: The client to use when calling UploadPartCopy
        :param copy_source: The CopySource parameter to use
        :param bucket: The name of the bucket to upload to
        :param key: The name of the key to upload to
        :param upload_id: The id of the upload
        :param part_number: The number representing the part of the multipart
            upload
        :param extra_args: A dictionary of any extra arguments that may be
            used in the upload.
        :param callbacks: List of callbacks to call after copy part
        :param size: The size of the transfer. This value is passed into
            the callbacks

        :rtype: dict
        :returns: A dictionary representing a part::

            {'ETag': etag_value, 'PartNumber': part_number}

            This value can be appended to a list to be used to complete
            the multipart upload.
        """
        copy_response = client.upload_part_copy(
            CopySource=copy_source,
            Bucket=bucket,
            Key=key,
            UploadId=upload_id,
            PartNumber=part_number,
            **extra_args
        )
        # Callbacks run before the response is inspected.
        for notify in callbacks:
            notify(bytes_transferred=size)
        return {
            'ETag': copy_response['CopyPartResult']['ETag'],
            'PartNumber': part_number,
        }
class CRTCredentialProviderAdapter:
    """Callable that turns botocore credentials into ``AwsCredentials``."""

    def __init__(self, botocore_credential_provider):
        """
        :param botocore_credential_provider: Object whose
            ``load_credentials()`` method supplies the credentials.
        """
        self._botocore_credential_provider = botocore_credential_provider
        # Filled in by the first successful load and reused afterwards.
        self._loaded_credentials = None
        self._lock = threading.Lock()

    def __call__(self):
        """Return ``AwsCredentials`` built from the frozen credentials."""
        frozen = self._get_credentials().get_frozen_credentials()
        return AwsCredentials(
            frozen.access_key, frozen.secret_key, frozen.token
        )

    def _get_credentials(self):
        """Return the loaded credentials, loading them on first use.

        :raises NoCredentialsError: If the provider returns ``None``.
        """
        with self._lock:
            if self._loaded_credentials is not None:
                return self._loaded_credentials
            fresh_credentials = (
                self._botocore_credential_provider.load_credentials()
            )
            if fresh_credentials is None:
                raise NoCredentialsError()
            self._loaded_credentials = fresh_credentials
            return fresh_credentials
target_throughput=5 * GB / 8, + part_size=8 * MB, + use_ssl=True, + verify=None, +): + """ + :type region: str + :param region: The region used for signing + + :type botocore_credential_provider: + Optional[botocore.credentials.CredentialResolver] + :param botocore_credential_provider: Provide credentials for CRT + to sign the request if not set, the request will not be signed + + :type num_threads: Optional[int] + :param num_threads: Number of worker threads generated. Default + is the number of processors in the machine. + + :type target_throughput: Optional[int] + :param target_throughput: Throughput target in Bytes. + Default is 0.625 GB/s (which translates to 5 Gb/s). + + :type part_size: Optional[int] + :param part_size: Size, in Bytes, of parts that files will be downloaded + or uploaded in. + + :type use_ssl: boolean + :param use_ssl: Whether or not to use SSL. By default, SSL is used. + Note that not all services support non-ssl connections. + + :type verify: Optional[boolean/string] + :param verify: Whether or not to verify SSL certificates. + By default SSL certificates are verified. You can provide the + following values: + + * False - do not validate SSL certificates. SSL will still be + used (unless use_ssl is False), but SSL certificates + will not be verified. + * path/to/cert/bundle.pem - A filename of the CA cert bundle to + use. Specify this argument if you want to use a custom CA cert + bundle instead of the default one on your system. 
class CRTTransferManager:
    def __init__(self, crt_s3_client, crt_request_serializer, osutil=None):
        """A transfer manager interface for Amazon S3 on CRT s3 client.

        :type crt_s3_client: awscrt.s3.S3Client
        :param crt_s3_client: The CRT s3 client, handling all the
            HTTP requests and functions under the hood

        :type crt_request_serializer: s3transfer.crt.BaseCRTRequestSerializer
        :param crt_request_serializer: Serializer, generates unsigned crt HTTP
            request.

        :type osutil: s3transfer.utils.OSUtils
        :param osutil: OSUtils object to use for os-related behavior when
            using with transfer manager. A default ``OSUtils()`` is created
            when omitted.
        """
        # ``self._osutil`` must be set in both cases: it used to be assigned
        # only when ``osutil`` was omitted, so passing one raised
        # AttributeError on the statement that builds the args creator.
        self._osutil = OSUtils() if osutil is None else osutil
        self._crt_s3_client = crt_s3_client
        self._s3_args_creator = S3ClientArgsCreator(
            crt_request_serializer, self._osutil
        )
        self._future_coordinators = []
        # Acquired on submission and released from the on-done callbacks.
        self._semaphore = threading.Semaphore(128)  # not configurable
        # A counter to create unique id's for each transfer submitted.
        self._id_counter = 0

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, *args):
        # Leaving the ``with`` block because of an exception cancels the
        # transfers that are still running.
        self._shutdown(bool(exc_type))

    def download(
        self, bucket, key, fileobj, extra_args=None, subscribers=None
    ):
        """Submit a ``get_object`` transfer.

        :param bucket: The name of the bucket to download from
        :param key: The name of the key to download
        :param fileobj: Destination passed through as ``CallArgs.fileobj``
        :param extra_args: Extra arguments for the request. Defaults to an
            empty dictionary.
        :param subscribers: Subscribers for the transfer. Defaults to an
            empty container.

        :rtype: s3transfer.crt.CRTTransferFuture
        :returns: The future for the submitted transfer
        """
        if extra_args is None:
            extra_args = {}
        if subscribers is None:
            subscribers = {}
        callargs = CallArgs(
            bucket=bucket,
            key=key,
            fileobj=fileobj,
            extra_args=extra_args,
            subscribers=subscribers,
        )
        return self._submit_transfer("get_object", callargs)

    def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None):
        """Submit a ``put_object`` transfer.

        :param fileobj: Source passed through as ``CallArgs.fileobj``
        :param bucket: The name of the bucket to upload to
        :param key: The name of the key to upload to
        :param extra_args: Extra arguments for the request. Defaults to an
            empty dictionary.
        :param subscribers: Subscribers for the transfer. Defaults to an
            empty container.

        :rtype: s3transfer.crt.CRTTransferFuture
        :returns: The future for the submitted transfer
        """
        if extra_args is None:
            extra_args = {}
        if subscribers is None:
            subscribers = {}
        callargs = CallArgs(
            bucket=bucket,
            key=key,
            fileobj=fileobj,
            extra_args=extra_args,
            subscribers=subscribers,
        )
        return self._submit_transfer("put_object", callargs)

    def delete(self, bucket, key, extra_args=None, subscribers=None):
        """Submit a ``delete_object`` transfer.

        :param bucket: The name of the bucket holding the object
        :param key: The name of the key to delete
        :param extra_args: Extra arguments for the request. Defaults to an
            empty dictionary.
        :param subscribers: Subscribers for the transfer. Defaults to an
            empty container.

        :rtype: s3transfer.crt.CRTTransferFuture
        :returns: The future for the submitted transfer
        """
        if extra_args is None:
            extra_args = {}
        if subscribers is None:
            subscribers = {}
        callargs = CallArgs(
            bucket=bucket,
            key=key,
            extra_args=extra_args,
            subscribers=subscribers,
        )
        return self._submit_transfer("delete_object", callargs)

    def shutdown(self, cancel=False):
        """Wait for the submitted transfers, optionally cancelling them.

        :param cancel: When true, unfinished transfers are cancelled before
            waiting.
        """
        self._shutdown(cancel)

    def _cancel_transfers(self):
        for coordinator in self._future_coordinators:
            if not coordinator.done():
                coordinator.cancel()

    def _finish_transfers(self):
        for coordinator in self._future_coordinators:
            coordinator.result()

    def _wait_transfers_done(self):
        for coordinator in self._future_coordinators:
            coordinator.wait_until_on_done_callbacks_complete()

    def _shutdown(self, cancel=False):
        if cancel:
            self._cancel_transfers()
        try:
            self._finish_transfers()
        except KeyboardInterrupt:
            self._cancel_transfers()
        except Exception:
            # Deliberately best-effort: a failed transfer does not make
            # shutdown itself raise.
            pass
        finally:
            # Always wait for the on-done callbacks before returning.
            self._wait_transfers_done()

    def _release_semaphore(self, **kwargs):
        self._semaphore.release()

    def _submit_transfer(self, request_type, call_args):
        """Create the future for a transfer and hand it to the CRT client.

        :param request_type: Name of the operation, e.g. ``"get_object"``
        :param call_args: The ``CallArgs`` describing the transfer

        :rtype: s3transfer.crt.CRTTransferFuture
        :returns: The future tracking the transfer. If submission fails the
            exception is stored on the coordinator and the done callbacks
            are invoked with it instead of raising here.
        """
        on_done_after_calls = [self._release_semaphore]
        coordinator = CRTTransferCoordinator(transfer_id=self._id_counter)
        components = {
            'meta': CRTTransferMeta(self._id_counter, call_args),
            'coordinator': coordinator,
        }
        future = CRTTransferFuture(**components)
        afterdone = AfterDoneHandler(coordinator)
        on_done_after_calls.append(afterdone)

        try:
            self._semaphore.acquire()
            on_queued = self._s3_args_creator.get_crt_callback(
                future, 'queued'
            )
            on_queued()
            crt_callargs = self._s3_args_creator.get_make_request_args(
                request_type,
                call_args,
                coordinator,
                future,
                on_done_after_calls,
            )
            crt_s3_request = self._crt_s3_client.make_request(**crt_callargs)
        except Exception as e:
            coordinator.set_exception(e, True)
            # Run the done callbacks by hand so the semaphore is released
            # and the coordinator is marked as finished.
            on_done = self._s3_args_creator.get_crt_callback(
                future, 'done', after_subscribers=on_done_after_calls
            )
            on_done(error=e)
        else:
            coordinator.set_s3_request(crt_s3_request)
        self._future_coordinators.append(coordinator)

        self._id_counter += 1
        return future
class BaseCRTRequestSerializer:
    """Interface for serializers that produce CRT HTTP requests."""

    def serialize_http_request(self, transfer_type, future):
        """Serialize CRT HTTP requests.

        Subclasses must override this method; this base implementation
        always raises.

        :type transfer_type: string
        :param transfer_type: the type of transfer made,
            e.g 'put_object', 'get_object', 'delete_object'

        :type future: s3transfer.crt.CRTTransferFuture

        :rtype: awscrt.http.HttpRequest
        :returns: An unsigned HTTP request to be used for the CRT S3 client

        :raises NotImplementedError: Always, in this base class.
        """
        method_name = 'serialize_http_request()'
        raise NotImplementedError(method_name)
+ """ + self._session = session + if client_kwargs is None: + client_kwargs = {} + self._resolve_client_config(session, client_kwargs) + self._client = session.create_client(**client_kwargs) + self._client.meta.events.register( + 'request-created.s3.*', self._capture_http_request + ) + self._client.meta.events.register( + 'after-call.s3.*', self._change_response_to_serialized_http_request + ) + self._client.meta.events.register( + 'before-send.s3.*', self._make_fake_http_response + ) + + def _resolve_client_config(self, session, client_kwargs): + user_provided_config = None + if session.get_default_client_config(): + user_provided_config = session.get_default_client_config() + if 'config' in client_kwargs: + user_provided_config = client_kwargs['config'] + + client_config = Config(signature_version=UNSIGNED) + if user_provided_config: + client_config = user_provided_config.merge(client_config) + client_kwargs['config'] = client_config + client_kwargs["service_name"] = "s3" + + def _crt_request_from_aws_request(self, aws_request): + url_parts = urlsplit(aws_request.url) + crt_path = url_parts.path + if url_parts.query: + crt_path = f'{crt_path}?{url_parts.query}' + headers_list = [] + for name, value in aws_request.headers.items(): + if isinstance(value, str): + headers_list.append((name, value)) + else: + headers_list.append((name, str(value, 'utf-8'))) + + crt_headers = awscrt.http.HttpHeaders(headers_list) + # CRT requires body (if it exists) to be an I/O stream. 
class FakeRawResponse(BytesIO):
    """In-memory bytes buffer that can also be read through ``stream()``."""

    def stream(self, amt=1024, decode_content=None):
        """Yield the remaining content in chunks read with ``read(amt)``.

        :param amt: Value passed to each ``read`` call.
        :param decode_content: Accepted but not used.

        :returns: A generator of non-empty ``bytes`` chunks; it stops at the
            first empty read.
        """
        # ``iter`` with a sentinel keeps calling ``read`` until it returns
        # the empty bytes object.
        yield from iter(lambda: self.read(amt), b'')
class S3ClientArgsCreator:
    """Builds the ``make_request`` arguments and callbacks for a transfer."""

    def __init__(self, crt_request_serializer, os_utils):
        """
        :param crt_request_serializer: Object whose
            ``serialize_http_request`` produces the request to send.
        :param os_utils: Object providing ``get_temp_filename`` and
            ``get_file_size``.
        """
        self._request_serializer = crt_request_serializer
        self._os_utils = os_utils

    def get_make_request_args(
        self, request_type, call_args, coordinator, future, on_done_after_calls
    ):
        """Return the keyword arguments for one ``make_request`` call.

        :param request_type: Name of the operation, e.g. ``'get_object'``.
            Names without a matching ``S3RequestType`` member fall back to
            ``S3RequestType.DEFAULT``.
        :param call_args: Call arguments of the transfer. For uploads its
            ``extra_args`` gains a ``ContentLength`` entry.
        :param coordinator: Coordinator handed to the temp file handler of
            a download.
        :param future: The future of the transfer.
        :param on_done_after_calls: Callbacks to run after the done
            subscribers.

        :rtype: dict
        :returns: Dictionary with the keys ``request``, ``type``,
            ``recv_filepath``, ``send_filepath``, ``on_done`` and
            ``on_progress``.
        """
        recv_filepath = None
        send_filepath = None
        on_done_before_calls = []
        s3_meta_request_type = getattr(
            S3RequestType, request_type.upper(), S3RequestType.DEFAULT
        )
        if s3_meta_request_type == S3RequestType.GET_OBJECT:
            # Downloads are received into a temporary file; the handler
            # added here deals with that file once the request is done.
            final_filepath = call_args.fileobj
            recv_filepath = self._os_utils.get_temp_filename(final_filepath)
            on_done_before_calls.append(
                RenameTempFileHandler(
                    coordinator, final_filepath, recv_filepath, self._os_utils
                )
            )
        elif s3_meta_request_type == S3RequestType.PUT_OBJECT:
            # Uploads record the file size before the request is serialized.
            send_filepath = call_args.fileobj
            call_args.extra_args["ContentLength"] = (
                self._os_utils.get_file_size(send_filepath)
            )

        crt_request = self._request_serializer.serialize_http_request(
            request_type, future
        )

        make_request_args = {
            'request': crt_request,
            'type': s3_meta_request_type,
            'recv_filepath': recv_filepath,
            'send_filepath': send_filepath,
            'on_done': self.get_crt_callback(
                future, 'done', on_done_before_calls, on_done_after_calls
            ),
            'on_progress': self.get_crt_callback(future, 'progress'),
        }
        return make_request_args

    def get_crt_callback(
        self,
        future,
        callback_type,
        before_subscribers=None,
        after_subscribers=None,
    ):
        """Return one callable that fans out to all callbacks of a type.

        :param future: The future whose subscriber callbacks are invoked.
        :param callback_type: Callback type, e.g. ``'done'`` or
            ``'progress'``.
        :param before_subscribers: Optional callables run before the
            subscriber callbacks.
        :param after_subscribers: Optional callables run after the
            subscriber callbacks.

        :returns: A callable that invokes every callback in order each
            time it is called.
        """
        forwards_progress = callback_type == "progress"

        def invoke_all_callbacks(*args, **kwargs):
            # The list is rebuilt on every invocation.
            pending = []
            if before_subscribers is not None:
                pending.extend(before_subscribers)
            pending.extend(get_callbacks(future, callback_type))
            if after_subscribers is not None:
                pending.extend(after_subscribers)
            for callback in pending:
                if forwards_progress:
                    # Progress callbacks take the first positional
                    # argument by keyword.
                    callback(bytes_transferred=args[0])
                else:
                    callback(*args, **kwargs)

        return invoke_all_callbacks
class DeleteSubmissionTask(SubmissionTask):
    """Task for submitting tasks to execute an object deletion."""

    def _submit(self, client, request_executor, transfer_future, **kwargs):
        """Submit the single task that deletes the object.

        :param client: The client associated with the transfer manager

        :type request_executor: s3transfer.futures.BoundedExecutor
        :param request_executor: The request executor associated with the
            transfer manager

        :type transfer_future: s3transfer.futures.TransferFuture
        :param transfer_future: The transfer future associated with the
            transfer request that tasks are being submitted for

        :param kwargs: Any further keyword arguments are accepted and not
            used.
        """
        call_args = transfer_future.meta.call_args

        delete_task = DeleteObjectTask(
            transfer_coordinator=self._transfer_coordinator,
            main_kwargs={
                'client': client,
                'bucket': call_args.bucket,
                'key': call_args.key,
                'extra_args': call_args.extra_args,
            },
            is_final=True,
        )
        self._transfer_coordinator.submit(request_executor, delete_task)
class DownloadOutputManager:
    """Base manager class for handling various types of files for downloads

    This class is typically used by the DownloadSubmissionTask class to
    answer the following:

    * Which fileobj should downloaded data be written to
    * Which task completes the transfer once everything has been written

    The answers/implementations differ for the various types of file outputs
    that may be accepted. All implementations must subclass and override
    public methods from this class.
    """

    def __init__(self, osutil, transfer_coordinator, io_executor):
        self._io_executor = io_executor
        self._transfer_coordinator = transfer_coordinator
        self._osutil = osutil

    @classmethod
    def is_compatible(cls, download_target, osutil):
        """Determines if the target for the download is compatible with manager

        :param download_target: The target that the download will write
            data to.

        :param osutil: The os utility to be used for the transfer

        :returns: True if the manager can handle the type of target specified
            otherwise returns False.
        """
        raise NotImplementedError('must implement is_compatible()')

    def get_download_task_tag(self):
        """Get the tag (if any) to associate all GetObjectTasks

        :rtype: s3transfer.futures.TaskTag
        :returns: The tag to associate all GetObjectTasks with; this base
            implementation returns None.
        """
        return None

    def get_fileobj_for_io_writes(self, transfer_future):
        """Get file-like object to use for io writes in the io executor

        :type transfer_future: s3transfer.futures.TransferFuture
        :param transfer_future: The future associated with the download
            request

        :returns: A file-like object to write to
        """
        raise NotImplementedError('must implement get_fileobj_for_io_writes()')

    def queue_file_io_task(self, fileobj, data, offset):
        """Queue IO write for submission to the IO executor.

        Builds the write task with ``get_io_write_task`` and submits it to
        the IO executor through the transfer coordinator.

        Subclasses may defer submission to the IO executor if necessary.
        """
        write_task = self.get_io_write_task(fileobj, data, offset)
        self._transfer_coordinator.submit(self._io_executor, write_task)

    def get_io_write_task(self, fileobj, data, offset):
        """Get an IO write task for the requested set of data

        This task can be run immediately or be submitted to the IO executor
        for it to run.

        :type fileobj: file-like object
        :param fileobj: The file-like object to write to

        :type data: bytes
        :param data: The data to write out

        :type offset: integer
        :param offset: The offset to write the data to in the file-like object

        :returns: An IO task to be used to write data to a file-like object
        """
        write_kwargs = {
            'fileobj': fileobj,
            'data': data,
            'offset': offset,
        }
        return IOWriteTask(self._transfer_coordinator, main_kwargs=write_kwargs)

    def get_final_io_task(self):
        """Get the final io task to complete the download

        This is needed because based on the architecture of the TransferManager
        the final tasks will be sent to the IO executor, but the executor
        needs a final task for it to signal that the transfer is done and
        all done callbacks can be run.

        :rtype: s3transfer.tasks.Task
        :returns: A final task to be completed in the io executor
        """
        raise NotImplementedError('must implement get_final_io_task()')

    def _get_fileobj_from_filename(self, filename):
        """Return a DeferredOpenFile for ``filename`` opened in ``'wb'`` mode.

        The file's ``close`` is registered as a failure cleanup on the
        transfer coordinator.
        """
        deferred_file = DeferredOpenFile(
            filename, mode='wb', open_function=self._osutil.open
        )
        # Make sure the file gets closed if anything goes wrong during the
        # process.
        self._transfer_coordinator.add_failure_cleanup(deferred_file.close)
        return deferred_file
class DownloadNonSeekableOutputManager(DownloadOutputManager):
    """Output manager for targets that expose ``write``.

    Incoming chunks are passed through a DeferQueue, and only the writes it
    releases are submitted, as streaming (no-seek) write tasks.
    """

    def __init__(
        self, osutil, transfer_coordinator, io_executor, defer_queue=None
    ):
        """
        :param osutil: The os utility associated to the transfer
        :param transfer_coordinator: The coordinator of the transfer
        :param io_executor: The executor that write tasks are submitted to
        :param defer_queue: Queue used to order incoming writes; a new
            DeferQueue is created when omitted.
        """
        super().__init__(osutil, transfer_coordinator, io_executor)
        if defer_queue is None:
            defer_queue = DeferQueue()
        self._defer_queue = defer_queue
        self._io_submit_lock = threading.Lock()

    @classmethod
    def is_compatible(cls, download_target, osutil):
        """Return True when ``download_target`` has a ``write`` attribute."""
        return hasattr(download_target, 'write')

    def get_download_task_tag(self):
        """Return the tag applied to the GetObjectTasks of this download."""
        return IN_MEMORY_DOWNLOAD_TAG

    def get_fileobj_for_io_writes(self, transfer_future):
        """Return the fileobj that was provided with the transfer request."""
        return transfer_future.meta.call_args.fileobj

    def get_final_io_task(self):
        """Return a no-op task marking the end of the download."""
        return CompleteDownloadNOOPTask(
            transfer_coordinator=self._transfer_coordinator
        )

    def queue_file_io_task(self, fileobj, data, offset):
        """Hand a chunk to the defer queue and submit whatever it releases.

        :param fileobj: The fileobj to write to
        :param data: The chunk that was just downloaded
        :param offset: The offset of ``data`` within the object
        """
        # Requesting the writes and submitting them happen under one lock
        # so that concurrent callers cannot interleave their submissions.
        with self._io_submit_lock:
            writes = self._defer_queue.request_writes(offset, data)
            for write in writes:
                logger.debug(
                    "Queueing IO offset %s for fileobj: %s",
                    write['offset'],
                    fileobj,
                )
                # Forward the released write's own data AND offset. The
                # incoming ``offset`` belongs to the chunk that triggered
                # the release, not necessarily to the chunk being written.
                super().queue_file_io_task(
                    fileobj, write['data'], write['offset']
                )

    def get_io_write_task(self, fileobj, data, offset):
        """Return a streaming write task; ``offset`` is not passed on."""
        return IOStreamingWriteTask(
            self._transfer_coordinator,
            main_kwargs={
                'fileobj': fileobj,
                'data': data,
            },
        )
get_final_io_task(self): + # Make sure the file gets closed once the transfer is done. + return IOCloseTask( + transfer_coordinator=self._transfer_coordinator, + is_final=True, + main_kwargs={'fileobj': self._fileobj}, + ) + + +class DownloadSubmissionTask(SubmissionTask): + """Task for submitting tasks to execute a download""" + + def _get_download_output_manager_cls(self, transfer_future, osutil): + """Retrieves a class for managing output for a download + + :type transfer_future: s3transfer.futures.TransferFuture + :param transfer_future: The transfer future for the request + + :type osutil: s3transfer.utils.OSUtils + :param osutil: The os utility associated to the transfer + + :rtype: class of DownloadOutputManager + :returns: The appropriate class to use for managing a specific type of + input for downloads. + """ + download_manager_resolver_chain = [ + DownloadSpecialFilenameOutputManager, + DownloadFilenameOutputManager, + DownloadSeekableOutputManager, + DownloadNonSeekableOutputManager, + ] + + fileobj = transfer_future.meta.call_args.fileobj + for download_manager_cls in download_manager_resolver_chain: + if download_manager_cls.is_compatible(fileobj, osutil): + return download_manager_cls + raise RuntimeError( + 'Output {} of type: {} is not supported.'.format( + fileobj, type(fileobj) + ) + ) + + def _submit( + self, + client, + config, + osutil, + request_executor, + io_executor, + transfer_future, + bandwidth_limiter=None, + ): + """ + :param client: The client associated with the transfer manager + + :type config: s3transfer.manager.TransferConfig + :param config: The transfer config associated with the transfer + manager + + :type osutil: s3transfer.utils.OSUtil + :param osutil: The os utility associated to the transfer manager + + :type request_executor: s3transfer.futures.BoundedExecutor + :param request_executor: The request executor associated with the + transfer manager + + :type io_executor: s3transfer.futures.BoundedExecutor + :param 
io_executor: The io executor associated with the + transfer manager + + :type transfer_future: s3transfer.futures.TransferFuture + :param transfer_future: The transfer future associated with the + transfer request that tasks are being submitted for + + :type bandwidth_limiter: s3transfer.bandwidth.BandwidthLimiter + :param bandwidth_limiter: The bandwidth limiter to use when + downloading streams + """ + if transfer_future.meta.size is None: + # If a size was not provided figure out the size for the + # user. + response = client.head_object( + Bucket=transfer_future.meta.call_args.bucket, + Key=transfer_future.meta.call_args.key, + **transfer_future.meta.call_args.extra_args, + ) + transfer_future.meta.provide_transfer_size( + response['ContentLength'] + ) + + download_output_manager = self._get_download_output_manager_cls( + transfer_future, osutil + )(osutil, self._transfer_coordinator, io_executor) + + # If it is greater than threshold do a ranged download, otherwise + # do a regular GetObject download. + if transfer_future.meta.size < config.multipart_threshold: + self._submit_download_request( + client, + config, + osutil, + request_executor, + io_executor, + download_output_manager, + transfer_future, + bandwidth_limiter, + ) + else: + self._submit_ranged_download_request( + client, + config, + osutil, + request_executor, + io_executor, + download_output_manager, + transfer_future, + bandwidth_limiter, + ) + + def _submit_download_request( + self, + client, + config, + osutil, + request_executor, + io_executor, + download_output_manager, + transfer_future, + bandwidth_limiter, + ): + call_args = transfer_future.meta.call_args + + # Get a handle to the file that will be used for writing downloaded + # contents + fileobj = download_output_manager.get_fileobj_for_io_writes( + transfer_future + ) + + # Get the needed callbacks for the task + progress_callbacks = get_callbacks(transfer_future, 'progress') + + # Get any associated tags for the get object task. 
+ get_object_tag = download_output_manager.get_download_task_tag() + + # Get the final io task to run once the download is complete. + final_task = download_output_manager.get_final_io_task() + + # Submit the task to download the object. + self._transfer_coordinator.submit( + request_executor, + ImmediatelyWriteIOGetObjectTask( + transfer_coordinator=self._transfer_coordinator, + main_kwargs={ + 'client': client, + 'bucket': call_args.bucket, + 'key': call_args.key, + 'fileobj': fileobj, + 'extra_args': call_args.extra_args, + 'callbacks': progress_callbacks, + 'max_attempts': config.num_download_attempts, + 'download_output_manager': download_output_manager, + 'io_chunksize': config.io_chunksize, + 'bandwidth_limiter': bandwidth_limiter, + }, + done_callbacks=[final_task], + ), + tag=get_object_tag, + ) + + def _submit_ranged_download_request( + self, + client, + config, + osutil, + request_executor, + io_executor, + download_output_manager, + transfer_future, + bandwidth_limiter, + ): + call_args = transfer_future.meta.call_args + + # Get the needed progress callbacks for the task + progress_callbacks = get_callbacks(transfer_future, 'progress') + + # Get a handle to the file that will be used for writing downloaded + # contents + fileobj = download_output_manager.get_fileobj_for_io_writes( + transfer_future + ) + + # Determine the number of parts + part_size = config.multipart_chunksize + num_parts = calculate_num_parts(transfer_future.meta.size, part_size) + + # Get any associated tags for the get object task. + get_object_tag = download_output_manager.get_download_task_tag() + + # Callback invoker to submit the final io task once all downloads + # are complete. 
class GetObjectTask(Task):
    """Task that issues GetObject and hands the body off chunk by chunk."""

    def _main(
        self,
        client,
        bucket,
        key,
        fileobj,
        extra_args,
        callbacks,
        max_attempts,
        download_output_manager,
        io_chunksize,
        start_index=0,
        bandwidth_limiter=None,
    ):
        """Downloads an object and places content into io queue

        :param client: The client to use when calling GetObject
        :param bucket: The bucket to download from
        :param key: The key to download from
        :param fileobj: The file handle to write content to
        :param extra_args: Any extra arguments to include in GetObject request
        :param callbacks: List of progress callbacks to invoke on download
        :param max_attempts: The maximum number of GetObject attempts to
            make before giving up
        :param download_output_manager: The download output manager associated
            with the current download.
        :param io_chunksize: The size of each io chunk to read from the
            download stream and queue in the io queue.
        :param start_index: The location in the file to start writing the
            content of the key to.
        :param bandwidth_limiter: The bandwidth limiter to use when throttling
            the downloading of data in streams.

        :raises RetriesExceededError: When every attempt failed with one of
            S3_RETRYABLE_DOWNLOAD_ERRORS; it carries the last such error.
        """
        last_exception = None
        for i in range(max_attempts):
            try:
                # Every attempt starts over from the beginning of the range.
                current_index = start_index
                response = client.get_object(
                    Bucket=bucket, Key=key, **extra_args
                )
                streaming_body = StreamReaderProgress(
                    response['Body'], callbacks
                )
                if bandwidth_limiter:
                    # NOTE: the method name is spelled exactly as the
                    # limiter object is called here; do not "correct" it.
                    streaming_body = (
                        bandwidth_limiter.get_bandwith_limited_stream(
                            streaming_body, self._transfer_coordinator
                        )
                    )

                chunks = DownloadChunkIterator(streaming_body, io_chunksize)
                for chunk in chunks:
                    # If the transfer is done because of a cancellation
                    # or error somewhere else, stop trying to submit more
                    # data to be written and break out of the download.
                    if self._transfer_coordinator.done():
                        return
                    self._handle_io(
                        download_output_manager,
                        fileobj,
                        chunk,
                        current_index,
                    )
                    current_index += len(chunk)
                return
            except S3_RETRYABLE_DOWNLOAD_ERRORS as e:
                logger.debug(
                    "Retrying exception caught (%s), "
                    "retrying request, (attempt %s / %s)",
                    e,
                    # ``i`` is zero-based; report a one-based attempt number
                    # so the last attempt reads "N / N".
                    i + 1,
                    max_attempts,
                    exc_info=True,
                )
                last_exception = e
                # Also invoke the progress callbacks to indicate that we
                # are trying to download the stream again and all progress
                # for this GetObject has been lost (the value passed is
                # zero or negative).
                invoke_progress_callbacks(
                    callbacks, start_index - current_index
                )
                continue
        raise RetriesExceededError(last_exception)

    def _handle_io(self, download_output_manager, fileobj, chunk, index):
        """Queue ``chunk`` for writing at ``index`` via the output manager."""
        download_output_manager.queue_file_io_task(fileobj, chunk, index)
class IORenameFileTask(Task):
    """A task that closes a file and renames it to its final filename

    :param fileobj: The file handle that content was written to; its
        ``name`` attribute is used as the source of the rename.
    :param final_filename: The final name of the file to rename to
        upon completion of writing the contents.
    :param osutil: OS utility providing ``rename_file``
    """

    def _main(self, fileobj, final_filename, osutil):
        # Close before renaming, then move the file under its own name to
        # the final location.
        fileobj.close()
        written_filename = fileobj.name
        osutil.rename_file(written_filename, final_filename)
class DeferQueue:
    """IO queue that defers write requests until they are queued sequentially.

    This class is used to track IO data for a *single* fileobj.

    You can send data to this queue, and it will defer any IO write requests
    until it has the next contiguous block available (starting at 0).

    Data that was already released is never released twice: a request that
    overlaps bytes which have already been handed out is trimmed to its
    unseen tail instead of being dropped whole, so a retry that re-reads a
    range with different chunk boundaries cannot leave a gap.
    """

    def __init__(self):
        # Min-heap of the offsets that are waiting to be released.
        self._writes = []
        # Maps each waiting offset to the data queued for it.
        self._pending_offsets = {}
        # Offset of the first byte that has not been released yet.
        self._next_offset = 0

    def request_writes(self, offset, data):
        """Request any available writes given new incoming data.

        You call this method by providing new data along with the
        offset associated with the data. If that new data unlocks
        any contiguous writes that can now be submitted, this
        method will return all applicable writes.

        This is done with 1 method call so you don't have to
        make two method calls (put(), get()) which acquires a lock
        each method call.

        :param offset: The offset of ``data`` within the stream
        :param data: The bytes to write at ``offset``

        :returns: A list of ``{'offset': ..., 'data': ...}`` dicts, in
            ascending contiguous order, that are now ready to be written.
            The list is empty when nothing new became available.
        """
        if offset < self._next_offset:
            # Part or all of this request has been released before. This
            # can happen in the event of a retry where if we retry at
            # offset N/2, we'll requeue offsets 0-N/2 again.
            already_released = self._next_offset - offset
            if len(data) <= already_released:
                return []
            # Keep only the bytes that have not been seen yet.
            data = data[already_released:]
            offset = self._next_offset

        if offset in self._pending_offsets:
            if len(data) <= len(self._pending_offsets[offset]):
                # We've already queued at least as much data for this
                # offset, so prefer what's already queued.
                return []
            # The new request carries more data for the same offset (for
            # example a retry after a short read): prefer the longer one.
            # The offset is already on the heap.
            self._pending_offsets[offset] = data
        else:
            heapq.heappush(self._writes, offset)
            self._pending_offsets[offset] = data

        writes = []
        while self._writes and self._writes[0] <= self._next_offset:
            queued_offset = heapq.heappop(self._writes)
            queued_data = self._pending_offsets.pop(queued_offset)
            stale = self._next_offset - queued_offset
            if stale:
                # A longer neighbouring write overtook this queued one;
                # release only the part that is still missing, if any.
                if len(queued_data) <= stale:
                    continue
                queued_data = queued_data[stale:]
                queued_offset = self._next_offset
            writes.append({'offset': queued_offset, 'data': queued_data})
            self._next_offset += len(queued_data)
        return writes
class RetriesExceededError(Exception):
    """Error that carries the last exception seen before giving up.

    :param last_exception: The exception to keep on ``last_exception``
    :param msg: The message of this error
    """

    def __init__(self, last_exception, msg='Max Retries Exceeded'):
        # Keep the underlying failure reachable for callers.
        self.last_exception = last_exception
        super().__init__(msg)
+ """ + raise NotImplementedError('done()') + + def result(self): + """Waits until TransferFuture is done and returns the result + + If the TransferFuture succeeded, it will return the result. If the + TransferFuture failed, it will raise the exception associated to the + failure. + """ + raise NotImplementedError('result()') + + def cancel(self): + """Cancels the request associated with the TransferFuture""" + raise NotImplementedError('cancel()') + + +class BaseTransferMeta: + @property + def call_args(self): + """The call args used in the transfer request""" + raise NotImplementedError('call_args') + + @property + def transfer_id(self): + """The unique id of the transfer""" + raise NotImplementedError('transfer_id') + + @property + def user_context(self): + """A dictionary that requesters can store data in""" + raise NotImplementedError('user_context') + + +class TransferFuture(BaseTransferFuture): + def __init__(self, meta=None, coordinator=None): + """The future associated to a submitted transfer request + + :type meta: TransferMeta + :param meta: The metadata associated to the request. This object + is visible to the requester. + + :type coordinator: TransferCoordinator + :param coordinator: The coordinator associated to the request. This + object is not visible to the requester. + """ + self._meta = meta + if meta is None: + self._meta = TransferMeta() + + self._coordinator = coordinator + if coordinator is None: + self._coordinator = TransferCoordinator() + + @property + def meta(self): + return self._meta + + def done(self): + return self._coordinator.done() + + def result(self): + try: + # Usually the result() method blocks until the transfer is done, + # however if a KeyboardInterrupt is raised we want want to exit + # out of this and propagate the exception. 
class TransferMeta(BaseTransferMeta):
    """Holds metadata about the TransferFuture"""

    def __init__(self, call_args=None, transfer_id=None):
        """
        :param call_args: The call args used in the transfer request
        :param transfer_id: The unique id of the transfer
        """
        # Each instance gets its own context dict; the size stays unknown
        # until provide_transfer_size() is called.
        self._user_context = {}
        self._size = None
        self._transfer_id = transfer_id
        self._call_args = call_args

    @property
    def call_args(self):
        """The call args used in the transfer request"""
        return self._call_args

    @property
    def transfer_id(self):
        """The unique id of the transfer"""
        return self._transfer_id

    @property
    def size(self):
        """The size of the transfer request, or None if not yet known"""
        return self._size

    @property
    def user_context(self):
        """A dictionary that requesters can store data in"""
        return self._user_context

    def provide_transfer_size(self, size):
        """A method to provide the size of a transfer request

        By providing this value, the TransferManager will not try to
        call HeadObject or use the OS to determine the size of the
        transfer.
        """
        self._size = size
+ * cancelled - Was cancelled + * failed - An exception other than CancelledError was thrown + * success - No exceptions were thrown and is done. + """ + return self._status + + def set_result(self, result): + """Set a result for the TransferFuture + + Implies that the TransferFuture succeeded. This will always set a + result because it is invoked on the final task where there is only + ever one final task and it is run at the very end of a transfer + process. So if a result is being set for this final task, the transfer + succeeded even if something came along and canceled the transfer + on the final task. + """ + with self._lock: + self._exception = None + self._result = result + self._status = 'success' + + def set_exception(self, exception, override=False): + """Set an exception for the TransferFuture + + Implies the TransferFuture failed. + + :param exception: The exception that caused the transfer to fail. + :param override: If True, override any existing state. + """ + with self._lock: + if not self.done() or override: + self._exception = exception + self._status = 'failed' + + def result(self): + """Waits until TransferFuture is done and returns the result + + If the TransferFuture succeeded, it will return the result. If the + TransferFuture failed, it will raise the exception associated to the + failure. + """ + # Doing a wait() with no timeout cannot be interrupted in python2 but + # can be interrupted in python3 so we just wait with the largest + # possible integer value, which is on the scale of billions of + # years... + self._done_event.wait(MAXINT) + + # Once done waiting, raise an exception if present or return the + # final result. 
+ if self._exception: + raise self._exception + return self._result + + def cancel(self, msg='', exc_type=CancelledError): + """Cancels the TransferFuture + + :param msg: The message to attach to the cancellation + :param exc_type: The type of exception to set for the cancellation + """ + with self._lock: + if not self.done(): + should_announce_done = False + logger.debug('%s cancel(%s) called', self, msg) + self._exception = exc_type(msg) + if self._status == 'not-started': + should_announce_done = True + self._status = 'cancelled' + if should_announce_done: + self.announce_done() + + def set_status_to_queued(self): + """Sets the TransferFuture's status to queued""" + self._transition_to_non_done_state('queued') + + def set_status_to_running(self): + """Sets the TransferFuture's status to running""" + self._transition_to_non_done_state('running') + + def _transition_to_non_done_state(self, desired_state): + with self._lock: + if self.done(): + raise RuntimeError( + 'Unable to transition from done state %s to non-done ' + 'state %s.' % (self.status, desired_state) + ) + self._status = desired_state + + def submit(self, executor, task, tag=None): + """Submits a task to a provided executor + + :type executor: s3transfer.futures.BoundedExecutor + :param executor: The executor to submit the callable to + + :type task: s3transfer.tasks.Task + :param task: The task to submit to the executor + + :type tag: s3transfer.futures.TaskTag + :param tag: A tag to associate to the submitted task + + :rtype: s3transfer.futures.ExecutorFuture + :returns: A future representing the submitted task + """ + logger.debug( + "Submitting task {} to executor {} for transfer request: {}.".format( + task, executor, self.transfer_id + ) + ) + future = executor.submit(task, tag=tag) + # Add this created future to the set of associated futures just + # in case it is needed during cleanups. 
+ self.add_associated_future(future) + future.add_done_callback( + FunctionContainer(self.remove_associated_future, future) + ) + return future + + def done(self): + """Determines if a TransferFuture has completed + + :returns: True if status is equal to 'failed', 'cancelled', or + 'success'. False, otherwise + """ + return self.status in ['failed', 'cancelled', 'success'] + + def add_associated_future(self, future): + """Adds a future to be associated with the TransferFuture""" + with self._associated_futures_lock: + self._associated_futures.add(future) + + def remove_associated_future(self, future): + """Removes a future's association to the TransferFuture""" + with self._associated_futures_lock: + self._associated_futures.remove(future) + + def add_done_callback(self, function, *args, **kwargs): + """Add a done callback to be invoked when transfer is done""" + with self._done_callbacks_lock: + self._done_callbacks.append( + FunctionContainer(function, *args, **kwargs) + ) + + def add_failure_cleanup(self, function, *args, **kwargs): + """Adds a callback to call upon failure""" + with self._failure_cleanups_lock: + self._failure_cleanups.append( + FunctionContainer(function, *args, **kwargs) + ) + + def announce_done(self): + """Announce that future is done running and run associated callbacks + + This will run any failure cleanups if the status is not 'success' and + they have not already been run, allows result() to be unblocked, and + will run any done callbacks associated to the TransferFuture if they + have not already been run. + """ + if self.status != 'success': + self._run_failure_cleanups() + self._done_event.set() + self._run_done_callbacks() + + def _run_done_callbacks(self): + # Run the callbacks and remove the callbacks from the internal + # list so they do not get run again if done is announced more than + # once. 
+ with self._done_callbacks_lock: + self._run_callbacks(self._done_callbacks) + self._done_callbacks = [] + + def _run_failure_cleanups(self): + # Run the cleanup callbacks and remove the callbacks from the internal + # list so they do not get run again if done is announced more than + # once. + with self._failure_cleanups_lock: + self._run_callbacks(self.failure_cleanups) + self._failure_cleanups = [] + + def _run_callbacks(self, callbacks): + for callback in callbacks: + self._run_callback(callback) + + def _run_callback(self, callback): + try: + callback() + # We do not want a callback interrupting the process, especially + # in the failure cleanups. So catch and log the exception. + except Exception: + logger.debug("Exception raised in %s." % callback, exc_info=True) + + +class BoundedExecutor: + EXECUTOR_CLS = futures.ThreadPoolExecutor + + def __init__( + self, max_size, max_num_threads, tag_semaphores=None, executor_cls=None + ): + """An executor implementation that has a maximum queued up tasks + + The executor will block if the number of tasks that have been + submitted and is currently working on is past its maximum. + + :param max_size: The maximum number of inflight futures. An inflight + future means that the task is either queued up or is currently + being executed. A size of None or 0 means that the executor will + have no bound in terms of the number of inflight futures. + + :param max_num_threads: The maximum number of threads the executor + uses. + + :type tag_semaphores: dict + :param tag_semaphores: A dictionary where the key is the name of the + tag and the value is the semaphore to use when limiting the + number of tasks the executor is processing at a time. + + :type executor_cls: BaseExecutor + :param executor_cls: The executor class that + gets bounded by this executor. If None is provided, the + concurrent.futures.ThreadPoolExecutor class is used. 
+ """ + self._max_num_threads = max_num_threads + if executor_cls is None: + executor_cls = self.EXECUTOR_CLS + self._executor = executor_cls(max_workers=self._max_num_threads) + self._semaphore = TaskSemaphore(max_size) + self._tag_semaphores = tag_semaphores + + def submit(self, task, tag=None, block=True): + """Submit a task to complete + + :type task: s3transfer.tasks.Task + :param task: The task to run __call__ on + + + :type tag: s3transfer.futures.TaskTag + :param tag: An optional tag to associate to the task. This + is used to override which semaphore to use. + + :type block: boolean + :param block: True if to wait till it is possible to submit a task. + False, if not to wait and raise an error if not able to submit + a task. + + :returns: The future associated to the submitted task + """ + semaphore = self._semaphore + # If a tag was provided, use the semaphore associated to that + # tag. + if tag: + semaphore = self._tag_semaphores[tag] + + # Call acquire on the semaphore. + acquire_token = semaphore.acquire(task.transfer_id, block) + # Create a callback to invoke when task is done in order to call + # release on the semaphore. + release_callback = FunctionContainer( + semaphore.release, task.transfer_id, acquire_token + ) + # Submit the task to the underlying executor. + future = ExecutorFuture(self._executor.submit(task)) + # Add the Semaphore.release() callback to the future such that + # it is invoked once the future completes. + future.add_done_callback(release_callback) + return future + + def shutdown(self, wait=True): + self._executor.shutdown(wait) + + +class ExecutorFuture: + def __init__(self, future): + """A future returned from the executor + + Currently, it is just a wrapper around a concurrent.futures.Future. + However, this can eventually grow to implement the needed functionality + of concurrent.futures.Future if we move off of the library and not + affect the rest of the codebase. 
+ + :type future: concurrent.futures.Future + :param future: The underlying future + """ + self._future = future + + def result(self): + return self._future.result() + + def add_done_callback(self, fn): + """Adds a callback to be completed once future is done + + :param fn: A callable that takes no arguments. Note that is different + than concurrent.futures.Future.add_done_callback that requires + a single argument for the future. + """ + # The done callback for concurrent.futures.Future will always pass a + # the future in as the only argument. So we need to create the + # proper signature wrapper that will invoke the callback provided. + def done_callback(future_passed_to_callback): + return fn() + + self._future.add_done_callback(done_callback) + + def done(self): + return self._future.done() + + +class BaseExecutor: + """Base Executor class implementation needed to work with s3transfer""" + + def __init__(self, max_workers=None): + pass + + def submit(self, fn, *args, **kwargs): + raise NotImplementedError('submit()') + + def shutdown(self, wait=True): + raise NotImplementedError('shutdown()') + + +class NonThreadedExecutor(BaseExecutor): + """A drop-in replacement non-threaded version of ThreadPoolExecutor""" + + def submit(self, fn, *args, **kwargs): + future = NonThreadedExecutorFuture() + try: + result = fn(*args, **kwargs) + future.set_result(result) + except Exception: + e, tb = sys.exc_info()[1:] + logger.debug( + 'Setting exception for %s to %s with traceback %s', + future, + e, + tb, + ) + future.set_exception_info(e, tb) + return future + + def shutdown(self, wait=True): + pass + + +class NonThreadedExecutorFuture: + """The Future returned from NonThreadedExecutor + + Note that this future is **not** thread-safe as it is being used + from the context of a non-threaded environment. 
+ """ + + def __init__(self): + self._result = None + self._exception = None + self._traceback = None + self._done = False + self._done_callbacks = [] + + def set_result(self, result): + self._result = result + self._set_done() + + def set_exception_info(self, exception, traceback): + self._exception = exception + self._traceback = traceback + self._set_done() + + def result(self, timeout=None): + if self._exception: + raise self._exception.with_traceback(self._traceback) + return self._result + + def _set_done(self): + self._done = True + for done_callback in self._done_callbacks: + self._invoke_done_callback(done_callback) + self._done_callbacks = [] + + def _invoke_done_callback(self, done_callback): + return done_callback(self) + + def done(self): + return self._done + + def add_done_callback(self, fn): + if self._done: + self._invoke_done_callback(fn) + else: + self._done_callbacks.append(fn) + + +TaskTag = namedtuple('TaskTag', ['name']) + +IN_MEMORY_UPLOAD_TAG = TaskTag('in_memory_upload') +IN_MEMORY_DOWNLOAD_TAG = TaskTag('in_memory_download') diff --git a/contrib/python/s3transfer/py3/s3transfer/manager.py b/contrib/python/s3transfer/py3/s3transfer/manager.py new file mode 100644 index 0000000000..ff6afa12c1 --- /dev/null +++ b/contrib/python/s3transfer/py3/s3transfer/manager.py @@ -0,0 +1,727 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+import copy +import logging +import re +import threading + +from s3transfer.bandwidth import BandwidthLimiter, LeakyBucket +from s3transfer.constants import ALLOWED_DOWNLOAD_ARGS, KB, MB +from s3transfer.copies import CopySubmissionTask +from s3transfer.delete import DeleteSubmissionTask +from s3transfer.download import DownloadSubmissionTask +from s3transfer.exceptions import CancelledError, FatalError +from s3transfer.futures import ( + IN_MEMORY_DOWNLOAD_TAG, + IN_MEMORY_UPLOAD_TAG, + BoundedExecutor, + TransferCoordinator, + TransferFuture, + TransferMeta, +) +from s3transfer.upload import UploadSubmissionTask +from s3transfer.utils import ( + CallArgs, + OSUtils, + SlidingWindowSemaphore, + TaskSemaphore, + get_callbacks, + signal_not_transferring, + signal_transferring, +) + +logger = logging.getLogger(__name__) + + +class TransferConfig: + def __init__( + self, + multipart_threshold=8 * MB, + multipart_chunksize=8 * MB, + max_request_concurrency=10, + max_submission_concurrency=5, + max_request_queue_size=1000, + max_submission_queue_size=1000, + max_io_queue_size=1000, + io_chunksize=256 * KB, + num_download_attempts=5, + max_in_memory_upload_chunks=10, + max_in_memory_download_chunks=10, + max_bandwidth=None, + ): + """Configurations for the transfer manager + + :param multipart_threshold: The threshold for which multipart + transfers occur. + + :param max_request_concurrency: The maximum number of S3 API + transfer-related requests that can happen at a time. + + :param max_submission_concurrency: The maximum number of threads + processing a call to a TransferManager method. Processing a + call usually entails determining which S3 API requests that need + to be enqueued, but does **not** entail making any of the + S3 API data transferring requests needed to perform the transfer. + The threads controlled by ``max_request_concurrency`` is + responsible for that. 
+ + :param multipart_chunksize: The size of each transfer if a request + becomes a multipart transfer. + + :param max_request_queue_size: The maximum amount of S3 API requests + that can be queued at a time. + + :param max_submission_queue_size: The maximum amount of + TransferManager method calls that can be queued at a time. + + :param max_io_queue_size: The maximum amount of read parts that + can be queued to be written to disk per download. The default + size for each elementin this queue is 8 KB. + + :param io_chunksize: The max size of each chunk in the io queue. + Currently, this is size used when reading from the downloaded + stream as well. + + :param num_download_attempts: The number of download attempts that + will be tried upon errors with downloading an object in S3. Note + that these retries account for errors that occur when streaming + down the data from s3 (i.e. socket errors and read timeouts that + occur after receiving an OK response from s3). + Other retryable exceptions such as throttling errors and 5xx errors + are already retried by botocore (this default is 5). The + ``num_download_attempts`` does not take into account the + number of exceptions retried by botocore. + + :param max_in_memory_upload_chunks: The number of chunks that can + be stored in memory at a time for all ongoing upload requests. + This pertains to chunks of data that need to be stored in memory + during an upload if the data is sourced from a file-like object. + The total maximum memory footprint due to a in-memory upload + chunks is roughly equal to: + + max_in_memory_upload_chunks * multipart_chunksize + + max_submission_concurrency * multipart_chunksize + + ``max_submission_concurrency`` has an affect on this value because + for each thread pulling data off of a file-like object, they may + be waiting with a single read chunk to be submitted for upload + because the ``max_in_memory_upload_chunks`` value has been reached + by the threads making the upload request. 
+ + :param max_in_memory_download_chunks: The number of chunks that can + be buffered in memory and **not** in the io queue at a time for all + ongoing download requests. This pertains specifically to file-like + objects that cannot be seeked. The total maximum memory footprint + due to a in-memory download chunks is roughly equal to: + + max_in_memory_download_chunks * multipart_chunksize + + :param max_bandwidth: The maximum bandwidth that will be consumed + in uploading and downloading file content. The value is in terms of + bytes per second. + """ + self.multipart_threshold = multipart_threshold + self.multipart_chunksize = multipart_chunksize + self.max_request_concurrency = max_request_concurrency + self.max_submission_concurrency = max_submission_concurrency + self.max_request_queue_size = max_request_queue_size + self.max_submission_queue_size = max_submission_queue_size + self.max_io_queue_size = max_io_queue_size + self.io_chunksize = io_chunksize + self.num_download_attempts = num_download_attempts + self.max_in_memory_upload_chunks = max_in_memory_upload_chunks + self.max_in_memory_download_chunks = max_in_memory_download_chunks + self.max_bandwidth = max_bandwidth + self._validate_attrs_are_nonzero() + + def _validate_attrs_are_nonzero(self): + for attr, attr_val in self.__dict__.items(): + if attr_val is not None and attr_val <= 0: + raise ValueError( + 'Provided parameter %s of value %s must be greater than ' + '0.' 
% (attr, attr_val) + ) + + +class TransferManager: + ALLOWED_DOWNLOAD_ARGS = ALLOWED_DOWNLOAD_ARGS + + ALLOWED_UPLOAD_ARGS = [ + 'ACL', + 'CacheControl', + 'ContentDisposition', + 'ContentEncoding', + 'ContentLanguage', + 'ContentType', + 'ExpectedBucketOwner', + 'Expires', + 'GrantFullControl', + 'GrantRead', + 'GrantReadACP', + 'GrantWriteACP', + 'Metadata', + 'RequestPayer', + 'ServerSideEncryption', + 'StorageClass', + 'SSECustomerAlgorithm', + 'SSECustomerKey', + 'SSECustomerKeyMD5', + 'SSEKMSKeyId', + 'SSEKMSEncryptionContext', + 'Tagging', + 'WebsiteRedirectLocation', + ] + + ALLOWED_COPY_ARGS = ALLOWED_UPLOAD_ARGS + [ + 'CopySourceIfMatch', + 'CopySourceIfModifiedSince', + 'CopySourceIfNoneMatch', + 'CopySourceIfUnmodifiedSince', + 'CopySourceSSECustomerAlgorithm', + 'CopySourceSSECustomerKey', + 'CopySourceSSECustomerKeyMD5', + 'MetadataDirective', + 'TaggingDirective', + ] + + ALLOWED_DELETE_ARGS = [ + 'MFA', + 'VersionId', + 'RequestPayer', + 'ExpectedBucketOwner', + ] + + VALIDATE_SUPPORTED_BUCKET_VALUES = True + + _UNSUPPORTED_BUCKET_PATTERNS = { + 'S3 Object Lambda': re.compile( + r'^arn:(aws).*:s3-object-lambda:[a-z\-0-9]+:[0-9]{12}:' + r'accesspoint[/:][a-zA-Z0-9\-]{1,63}' + ), + } + + def __init__(self, client, config=None, osutil=None, executor_cls=None): + """A transfer manager interface for Amazon S3 + + :param client: Client to be used by the manager + :param config: TransferConfig to associate specific configurations + :param osutil: OSUtils object to use for os-related behavior when + using with transfer manager. + + :type executor_cls: s3transfer.futures.BaseExecutor + :param executor_cls: The class of executor to use with the transfer + manager. By default, concurrent.futures.ThreadPoolExecutor is used. 
+ """ + self._client = client + self._config = config + if config is None: + self._config = TransferConfig() + self._osutil = osutil + if osutil is None: + self._osutil = OSUtils() + self._coordinator_controller = TransferCoordinatorController() + # A counter to create unique id's for each transfer submitted. + self._id_counter = 0 + + # The executor responsible for making S3 API transfer requests + self._request_executor = BoundedExecutor( + max_size=self._config.max_request_queue_size, + max_num_threads=self._config.max_request_concurrency, + tag_semaphores={ + IN_MEMORY_UPLOAD_TAG: TaskSemaphore( + self._config.max_in_memory_upload_chunks + ), + IN_MEMORY_DOWNLOAD_TAG: SlidingWindowSemaphore( + self._config.max_in_memory_download_chunks + ), + }, + executor_cls=executor_cls, + ) + + # The executor responsible for submitting the necessary tasks to + # perform the desired transfer + self._submission_executor = BoundedExecutor( + max_size=self._config.max_submission_queue_size, + max_num_threads=self._config.max_submission_concurrency, + executor_cls=executor_cls, + ) + + # There is one thread available for writing to disk. It will handle + # downloads for all files. + self._io_executor = BoundedExecutor( + max_size=self._config.max_io_queue_size, + max_num_threads=1, + executor_cls=executor_cls, + ) + + # The component responsible for limiting bandwidth usage if it + # is configured. 
+ self._bandwidth_limiter = None + if self._config.max_bandwidth is not None: + logger.debug( + 'Setting max_bandwidth to %s', self._config.max_bandwidth + ) + leaky_bucket = LeakyBucket(self._config.max_bandwidth) + self._bandwidth_limiter = BandwidthLimiter(leaky_bucket) + + self._register_handlers() + + @property + def client(self): + return self._client + + @property + def config(self): + return self._config + + def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None): + """Uploads a file to S3 + + :type fileobj: str or seekable file-like object + :param fileobj: The name of a file to upload or a seekable file-like + object to upload. It is recommended to use a filename because + file-like objects may result in higher memory usage. + + :type bucket: str + :param bucket: The name of the bucket to upload to + + :type key: str + :param key: The name of the key to upload to + + :type extra_args: dict + :param extra_args: Extra arguments that may be passed to the + client operation + + :type subscribers: list(s3transfer.subscribers.BaseSubscriber) + :param subscribers: The list of subscribers to be invoked in the + order provided based on the event emit during the process of + the transfer request. 
+ + :rtype: s3transfer.futures.TransferFuture + :returns: Transfer future representing the upload + """ + if extra_args is None: + extra_args = {} + if subscribers is None: + subscribers = [] + self._validate_all_known_args(extra_args, self.ALLOWED_UPLOAD_ARGS) + self._validate_if_bucket_supported(bucket) + call_args = CallArgs( + fileobj=fileobj, + bucket=bucket, + key=key, + extra_args=extra_args, + subscribers=subscribers, + ) + extra_main_kwargs = {} + if self._bandwidth_limiter: + extra_main_kwargs['bandwidth_limiter'] = self._bandwidth_limiter + return self._submit_transfer( + call_args, UploadSubmissionTask, extra_main_kwargs + ) + + def download( + self, bucket, key, fileobj, extra_args=None, subscribers=None + ): + """Downloads a file from S3 + + :type bucket: str + :param bucket: The name of the bucket to download from + + :type key: str + :param key: The name of the key to download from + + :type fileobj: str or seekable file-like object + :param fileobj: The name of a file to download or a seekable file-like + object to download. It is recommended to use a filename because + file-like objects may result in higher memory usage. + + :type extra_args: dict + :param extra_args: Extra arguments that may be passed to the + client operation + + :type subscribers: list(s3transfer.subscribers.BaseSubscriber) + :param subscribers: The list of subscribers to be invoked in the + order provided based on the event emit during the process of + the transfer request. 
+ + :rtype: s3transfer.futures.TransferFuture + :returns: Transfer future representing the download + """ + if extra_args is None: + extra_args = {} + if subscribers is None: + subscribers = [] + self._validate_all_known_args(extra_args, self.ALLOWED_DOWNLOAD_ARGS) + self._validate_if_bucket_supported(bucket) + call_args = CallArgs( + bucket=bucket, + key=key, + fileobj=fileobj, + extra_args=extra_args, + subscribers=subscribers, + ) + extra_main_kwargs = {'io_executor': self._io_executor} + if self._bandwidth_limiter: + extra_main_kwargs['bandwidth_limiter'] = self._bandwidth_limiter + return self._submit_transfer( + call_args, DownloadSubmissionTask, extra_main_kwargs + ) + + def copy( + self, + copy_source, + bucket, + key, + extra_args=None, + subscribers=None, + source_client=None, + ): + """Copies a file in S3 + + :type copy_source: dict + :param copy_source: The name of the source bucket, key name of the + source object, and optional version ID of the source object. The + dictionary format is: + ``{'Bucket': 'bucket', 'Key': 'key', 'VersionId': 'id'}``. Note + that the ``VersionId`` key is optional and may be omitted. + + :type bucket: str + :param bucket: The name of the bucket to copy to + + :type key: str + :param key: The name of the key to copy to + + :type extra_args: dict + :param extra_args: Extra arguments that may be passed to the + client operation + + :type subscribers: a list of subscribers + :param subscribers: The list of subscribers to be invoked in the + order provided based on the event emit during the process of + the transfer request. + + :type source_client: botocore or boto3 Client + :param source_client: The client to be used for operation that + may happen at the source object. For example, this client is + used for the head_object that determines the size of the copy. + If no client is provided, the transfer manager's client is used + as the client for the source object. 
+ + :rtype: s3transfer.futures.TransferFuture + :returns: Transfer future representing the copy + """ + if extra_args is None: + extra_args = {} + if subscribers is None: + subscribers = [] + if source_client is None: + source_client = self._client + self._validate_all_known_args(extra_args, self.ALLOWED_COPY_ARGS) + if isinstance(copy_source, dict): + self._validate_if_bucket_supported(copy_source.get('Bucket')) + self._validate_if_bucket_supported(bucket) + call_args = CallArgs( + copy_source=copy_source, + bucket=bucket, + key=key, + extra_args=extra_args, + subscribers=subscribers, + source_client=source_client, + ) + return self._submit_transfer(call_args, CopySubmissionTask) + + def delete(self, bucket, key, extra_args=None, subscribers=None): + """Delete an S3 object. + + :type bucket: str + :param bucket: The name of the bucket. + + :type key: str + :param key: The name of the S3 object to delete. + + :type extra_args: dict + :param extra_args: Extra arguments that may be passed to the + DeleteObject call. + + :type subscribers: list + :param subscribers: A list of subscribers to be invoked during the + process of the transfer request. Note that the ``on_progress`` + callback is not invoked during object deletion. + + :rtype: s3transfer.futures.TransferFuture + :return: Transfer future representing the deletion. + + """ + if extra_args is None: + extra_args = {} + if subscribers is None: + subscribers = [] + self._validate_all_known_args(extra_args, self.ALLOWED_DELETE_ARGS) + self._validate_if_bucket_supported(bucket) + call_args = CallArgs( + bucket=bucket, + key=key, + extra_args=extra_args, + subscribers=subscribers, + ) + return self._submit_transfer(call_args, DeleteSubmissionTask) + + def _validate_if_bucket_supported(self, bucket): + # s3 high level operations don't support some resources + # (eg. 
S3 Object Lambda) only direct API calls are available + # for such resources + if self.VALIDATE_SUPPORTED_BUCKET_VALUES: + for resource, pattern in self._UNSUPPORTED_BUCKET_PATTERNS.items(): + match = pattern.match(bucket) + if match: + raise ValueError( + 'TransferManager methods do not support %s ' + 'resource. Use direct client calls instead.' % resource + ) + + def _validate_all_known_args(self, actual, allowed): + for kwarg in actual: + if kwarg not in allowed: + raise ValueError( + "Invalid extra_args key '%s', " + "must be one of: %s" % (kwarg, ', '.join(allowed)) + ) + + def _submit_transfer( + self, call_args, submission_task_cls, extra_main_kwargs=None + ): + if not extra_main_kwargs: + extra_main_kwargs = {} + + # Create a TransferFuture to return back to the user + transfer_future, components = self._get_future_with_components( + call_args + ) + + # Add any provided done callbacks to the created transfer future + # to be invoked on the transfer future being complete. + for callback in get_callbacks(transfer_future, 'done'): + components['coordinator'].add_done_callback(callback) + + # Get the main kwargs needed to instantiate the submission task + main_kwargs = self._get_submission_task_main_kwargs( + transfer_future, extra_main_kwargs + ) + + # Submit a SubmissionTask that will submit all of the necessary + # tasks needed to complete the S3 transfer. + self._submission_executor.submit( + submission_task_cls( + transfer_coordinator=components['coordinator'], + main_kwargs=main_kwargs, + ) + ) + + # Increment the unique id counter for future transfer requests + self._id_counter += 1 + + return transfer_future + + def _get_future_with_components(self, call_args): + transfer_id = self._id_counter + # Creates a new transfer future along with its components + transfer_coordinator = TransferCoordinator(transfer_id=transfer_id) + # Track the transfer coordinator for transfers to manage. 
+ self._coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + # Also make sure that the transfer coordinator is removed once + # the transfer completes so it does not stick around in memory. + transfer_coordinator.add_done_callback( + self._coordinator_controller.remove_transfer_coordinator, + transfer_coordinator, + ) + components = { + 'meta': TransferMeta(call_args, transfer_id=transfer_id), + 'coordinator': transfer_coordinator, + } + transfer_future = TransferFuture(**components) + return transfer_future, components + + def _get_submission_task_main_kwargs( + self, transfer_future, extra_main_kwargs + ): + main_kwargs = { + 'client': self._client, + 'config': self._config, + 'osutil': self._osutil, + 'request_executor': self._request_executor, + 'transfer_future': transfer_future, + } + main_kwargs.update(extra_main_kwargs) + return main_kwargs + + def _register_handlers(self): + # Register handlers to enable/disable callbacks on uploads. + event_name = 'request-created.s3' + self._client.meta.events.register_first( + event_name, + signal_not_transferring, + unique_id='s3upload-not-transferring', + ) + self._client.meta.events.register_last( + event_name, signal_transferring, unique_id='s3upload-transferring' + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, *args): + cancel = False + cancel_msg = '' + cancel_exc_type = FatalError + # If a exception was raised in the context handler, signal to cancel + # all of the inprogress futures in the shutdown. + if exc_type: + cancel = True + cancel_msg = str(exc_value) + if not cancel_msg: + cancel_msg = repr(exc_value) + # If it was a KeyboardInterrupt, the cancellation was initiated + # by the user. 
+ if isinstance(exc_value, KeyboardInterrupt): + cancel_exc_type = CancelledError + self._shutdown(cancel, cancel_msg, cancel_exc_type) + + def shutdown(self, cancel=False, cancel_msg=''): + """Shutdown the TransferManager + + It will wait till all transfers complete before it completely shuts + down. + + :type cancel: boolean + :param cancel: If True, calls TransferFuture.cancel() for + all in-progress in transfers. This is useful if you want the + shutdown to happen quicker. + + :type cancel_msg: str + :param cancel_msg: The message to specify if canceling all in-progress + transfers. + """ + self._shutdown(cancel, cancel, cancel_msg) + + def _shutdown(self, cancel, cancel_msg, exc_type=CancelledError): + if cancel: + # Cancel all in-flight transfers if requested, before waiting + # for them to complete. + self._coordinator_controller.cancel(cancel_msg, exc_type) + try: + # Wait until there are no more in-progress transfers. This is + # wrapped in a try statement because this can be interrupted + # with a KeyboardInterrupt that needs to be caught. + self._coordinator_controller.wait() + except KeyboardInterrupt: + # If not errors were raised in the try block, the cancel should + # have no coordinators it needs to run cancel on. If there was + # an error raised in the try statement we want to cancel all of + # the inflight transfers before shutting down to speed that + # process up. + self._coordinator_controller.cancel('KeyboardInterrupt()') + raise + finally: + # Shutdown all of the executors. + self._submission_executor.shutdown() + self._request_executor.shutdown() + self._io_executor.shutdown() + + +class TransferCoordinatorController: + def __init__(self): + """Abstraction to control all transfer coordinators + + This abstraction allows the manager to wait for inprogress transfers + to complete and cancel all inprogress transfers. 
@property
def tracked_transfer_coordinators(self):
    """A snapshot of the transfer coordinators currently tracked"""
    with self._lock:
        # Callers iterate over the result while coordinators are being
        # added and removed, so hand back a copy instead of the live set.
        snapshot = copy.copy(self._tracked_transfer_coordinators)
    return snapshot


def add_transfer_coordinator(self, transfer_coordinator):
    """Start tracking a transfer coordinator so it can be canceled if needed

    :type transfer_coordinator: s3transfer.futures.TransferCoordinator
    :param transfer_coordinator: The transfer coordinator for the
        particular transfer
    """
    with self._lock:
        tracked = self._tracked_transfer_coordinators
        tracked.add(transfer_coordinator)


def remove_transfer_coordinator(self, transfer_coordinator):
    """Stop tracking a transfer coordinator

    Typically, this method is invoked by the transfer coordinator itself
    to remove itself when it completes its transfer.

    :type transfer_coordinator: s3transfer.futures.TransferCoordinator
    :param transfer_coordinator: The transfer coordinator for the
        particular transfer

    :raises KeyError: If the transfer coordinator is not being tracked.
    """
    with self._lock:
        tracked = self._tracked_transfer_coordinators
        tracked.remove(transfer_coordinator)
def wait(self):
    """Wait until there are no more inprogress transfers

    Every tracked transfer coordinator is waited on in turn. A failed
    transfer neither stops the wait nor has its error propagated, but
    the wait can be interrupted with a KeyboardInterrupt.

    :raises KeyboardInterrupt: If interrupted while waiting on a
        transfer coordinator.
    """
    transfer_coordinator = None
    try:
        for transfer_coordinator in self.tracked_transfer_coordinators:
            try:
                transfer_coordinator.result()
            except Exception:
                # result() raises when the transfer failed. That still
                # means this coordinator has completed, so ignore the
                # error and keep waiting on the remaining coordinators.
                pass
    except KeyboardInterrupt:
        logger.debug('Received KeyboardInterrupt in wait()')
        # If Keyboard interrupt is raised while waiting for
        # the result, then exit out of the wait and raise the
        # exception
        if transfer_coordinator:
            logger.debug(
                'On KeyboardInterrupt was waiting for %s',
                transfer_coordinator,
            )
        raise
+"""Speeds up S3 throughput by using processes + +Getting Started +=============== + +The :class:`ProcessPoolDownloader` can be used to download a single file by +calling :meth:`ProcessPoolDownloader.download_file`: + +.. code:: python + + from s3transfer.processpool import ProcessPoolDownloader + + with ProcessPoolDownloader() as downloader: + downloader.download_file('mybucket', 'mykey', 'myfile') + + +This snippet downloads the S3 object located in the bucket ``mybucket`` at the +key ``mykey`` to the local file ``myfile``. Any errors encountered during the +transfer are not propagated. To determine if a transfer succeeded or +failed, use the `Futures`_ interface. + + +The :class:`ProcessPoolDownloader` can be used to download multiple files as +well: + +.. code:: python + + from s3transfer.processpool import ProcessPoolDownloader + + with ProcessPoolDownloader() as downloader: + downloader.download_file('mybucket', 'mykey', 'myfile') + downloader.download_file('mybucket', 'myotherkey', 'myotherfile') + + +When running this snippet, the downloading of ``mykey`` and ``myotherkey`` +happen in parallel. The first ``download_file`` call does not block the +second ``download_file`` call. The snippet blocks when exiting +the context manager and blocks until both downloads are complete. + +Alternatively, the ``ProcessPoolDownloader`` can be instantiated +and explicitly be shutdown using :meth:`ProcessPoolDownloader.shutdown`: + +.. code:: python + + from s3transfer.processpool import ProcessPoolDownloader + + downloader = ProcessPoolDownloader() + downloader.download_file('mybucket', 'mykey', 'myfile') + downloader.download_file('mybucket', 'myotherkey', 'myotherfile') + downloader.shutdown() + + +For this code snippet, the call to ``shutdown`` blocks until both +downloads are complete. 
+ + +Additional Parameters +===================== + +Additional parameters can be provided to the ``download_file`` method: + +* ``extra_args``: A dictionary containing any additional client arguments + to include in the + `GetObject <https://botocore.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.get_object>`_ + API request. For example: + + .. code:: python + + from s3transfer.processpool import ProcessPoolDownloader + + with ProcessPoolDownloader() as downloader: + downloader.download_file( + 'mybucket', 'mykey', 'myfile', + extra_args={'VersionId': 'myversion'}) + + +* ``expected_size``: By default, the downloader will make a HeadObject + call to determine the size of the object. To opt-out of this additional + API call, you can provide the size of the object in bytes: + + .. code:: python + + from s3transfer.processpool import ProcessPoolDownloader + + MB = 1024 * 1024 + with ProcessPoolDownloader() as downloader: + downloader.download_file( + 'mybucket', 'mykey', 'myfile', expected_size=2 * MB) + + +Futures +======= + +When ``download_file`` is called, it immediately returns a +:class:`ProcessPoolTransferFuture`. The future can be used to poll the state +of a particular transfer. To get the result of the download, +call :meth:`ProcessPoolTransferFuture.result`. The method blocks +until the transfer completes, whether it succeeds or fails. For example: + +.. code:: python + + from s3transfer.processpool import ProcessPoolDownloader + + with ProcessPoolDownloader() as downloader: + future = downloader.download_file('mybucket', 'mykey', 'myfile') + print(future.result()) + + +If the download succeeds, the future returns ``None``: + +.. code:: python + + None + + +If the download fails, the exception causing the failure is raised. For +example, if ``mykey`` did not exist, the following error would be raised + + +.. 
code:: python + + botocore.exceptions.ClientError: An error occurred (404) when calling the HeadObject operation: Not Found + + +.. note:: + + :meth:`ProcessPoolTransferFuture.result` can only be called while the + ``ProcessPoolDownloader`` is running (e.g. before calling ``shutdown`` or + inside the context manager). + + +Process Pool Configuration +========================== + +By default, the downloader has the following configuration options: + +* ``multipart_threshold``: The threshold size for performing ranged downloads + in bytes. By default, ranged downloads happen for S3 objects that are + greater than or equal to 8 MB in size. + +* ``multipart_chunksize``: The size of each ranged download in bytes. By + default, the size of each ranged download is 8 MB. + +* ``max_request_processes``: The maximum number of processes used to download + S3 objects. By default, the maximum is 10 processes. + + +To change the default configuration, use the :class:`ProcessTransferConfig`: + +.. code:: python + + from s3transfer.processpool import ProcessPoolDownloader + from s3transfer.processpool import ProcessTransferConfig + + config = ProcessTransferConfig( + multipart_threshold=64 * 1024 * 1024, # 64 MB + max_request_processes=50 + ) + downloader = ProcessPoolDownloader(config=config) + + +Client Configuration +==================== + +The process pool downloader creates ``botocore`` clients on your behalf. In +order to affect how the client is created, pass the keyword arguments +that would have been used in the :meth:`botocore.Session.create_client` call: + +.. code:: python + + + from s3transfer.processpool import ProcessPoolDownloader + from s3transfer.processpool import ProcessTransferConfig + + downloader = ProcessPoolDownloader( + client_kwargs={'region_name': 'us-west-2'}) + + +This snippet ensures that all clients created by the ``ProcessPoolDownloader`` +are using ``us-west-2`` as their region. 
# A DownloadFileRequest is what the ProcessPoolDownloader hands to the
# GetObjectSubmitter so that the submitter can start submitting
# GetObjectJobs to the GetObjectWorkers. Its fields are:
#
#   transfer_id    The unique id for the transfer
#   bucket         The bucket to download the object from
#   key            The key to download the object from
#   filename       The user-requested download location
#   extra_args     Extra arguments to provide to client calls
#   expected_size  The user-provided expected size of the download
DownloadFileRequest = collections.namedtuple(
    'DownloadFileRequest',
    'transfer_id bucket key filename extra_args expected_size',
)
@contextlib.contextmanager
def ignore_ctrl_c():
    """Ignore SIGINT (Ctrl-C) for the duration of the ``with`` block

    The SIGINT handler that was installed on entry is put back when the
    block exits, including when the block exits by raising.
    """
    original_handler = _add_ignore_handler_for_interrupts()
    try:
        yield
    finally:
        # Restore even on error so an exception inside the block does not
        # leave SIGINT ignored.
        signal.signal(signal.SIGINT, original_handler)


def _add_ignore_handler_for_interrupts():
    """Make the process ignore SIGINT and return the previous handler"""
    # Windows is unable to pickle signal.signal directly so it needs to
    # be wrapped in a function defined at the module level
    return signal.signal(signal.SIGINT, signal.SIG_IGN)
+ + :type config: ProcessTransferConfig + :param config: Configuration for the downloader + """ + if client_kwargs is None: + client_kwargs = {} + self._client_factory = ClientFactory(client_kwargs) + + self._transfer_config = config + if config is None: + self._transfer_config = ProcessTransferConfig() + + self._download_request_queue = multiprocessing.Queue(1000) + self._worker_queue = multiprocessing.Queue(1000) + self._osutil = OSUtils() + + self._started = False + self._start_lock = threading.Lock() + + # These below are initialized in the start() method + self._manager = None + self._transfer_monitor = None + self._submitter = None + self._workers = [] + + def download_file( + self, bucket, key, filename, extra_args=None, expected_size=None + ): + """Downloads the object's contents to a file + + :type bucket: str + :param bucket: The name of the bucket to download from + + :type key: str + :param key: The name of the key to download from + + :type filename: str + :param filename: The name of a file to download to. + + :type extra_args: dict + :param extra_args: Extra arguments that may be passed to the + client operation + + :type expected_size: int + :param expected_size: The expected size in bytes of the download. If + provided, the downloader will not call HeadObject to determine the + object's size and use the provided value instead. The size is + needed to determine whether to do a multipart download. 
def __exit__(self, exc_type, exc_value, *args):
    """Shut the downloader down when leaving the context manager

    If the block is being left because of a KeyboardInterrupt and a
    transfer monitor exists, all in-progress transfers are first marked
    as cancelled. ``shutdown()`` is then called in every case.
    """
    interrupted = isinstance(exc_value, KeyboardInterrupt)
    if interrupted and self._transfer_monitor is not None:
        self._transfer_monitor.notify_cancel_all_in_progress()
    self.shutdown()
def _shutdown(self):
    """Stop the submitter, the workers and the monitor manager, in order

    Afterwards the downloader is marked as not started, so a later
    ``download_file`` call starts a fresh set of processes.
    """
    self._shutdown_submitter()
    self._shutdown_get_object_workers()
    self._shutdown_transfer_monitor_manager()
    # Forget the workers that were just shut down. Starting appends new
    # workers to this list and shutting down enqueues one shutdown signal
    # per entry, so stale entries would leave surplus shutdown signals on
    # the worker queue after a restart.
    self._workers = []
    self._started = False
GetObjectWorkers.') + for _ in self._workers: + self._worker_queue.put(SHUTDOWN_SIGNAL) + for worker in self._workers: + worker.join() + + +class ProcessPoolTransferFuture(BaseTransferFuture): + def __init__(self, monitor, meta): + """The future associated to a submitted process pool transfer request + + :type monitor: TransferMonitor + :param monitor: The monitor associated to the process pool downloader + + :type meta: ProcessPoolTransferMeta + :param meta: The metadata associated to the request. This object + is visible to the requester. + """ + self._monitor = monitor + self._meta = meta + + @property + def meta(self): + return self._meta + + def done(self): + return self._monitor.is_done(self._meta.transfer_id) + + def result(self): + try: + return self._monitor.poll_for_result(self._meta.transfer_id) + except KeyboardInterrupt: + # For the multiprocessing Manager, a thread is given a single + # connection to reuse in communicating between the thread in the + # main process and the Manager's process. If a Ctrl-C happens when + # polling for the result, it will make the main thread stop trying + # to receive from the connection, but the Manager process will not + # know that the main process has stopped trying to receive and + # will not close the connection. As a result if another message is + # sent to the Manager process, the listener in the Manager + # processes will not process the new message as it is still trying + # trying to process the previous message (that was Ctrl-C'd) and + # thus cause the thread in the main process to hang on its send. + # The only way around this is to create a new connection and send + # messages from that new connection instead. 
class ClientFactory:
    def __init__(self, client_kwargs=None):
        """Creates S3 clients for processes

        Botocore sessions and clients are not pickleable so they cannot be
        inherited across Process boundaries. Instead, they must be instantiated
        once a process is running.

        :type client_kwargs: dict
        :param client_kwargs: The keyword arguments to use when creating
            clients. The dictionary itself is not modified; the factory
            keeps its own copy in which ``config`` is replaced by a copy
            of the provided config (or a new ``Config``) whose
            ``user_agent_extra`` includes ``PROCESS_USER_AGENT``.
        """
        # Copy so that storing the adjusted 'config' below never writes
        # into the caller's dictionary. Otherwise a dictionary reused for
        # a second factory would get the user agent suffix appended again.
        if client_kwargs is None:
            self._client_kwargs = {}
        else:
            self._client_kwargs = dict(client_kwargs)

        client_config = deepcopy(self._client_kwargs.get('config', Config()))
        if not client_config.user_agent_extra:
            client_config.user_agent_extra = PROCESS_USER_AGENT
        else:
            client_config.user_agent_extra += " " + PROCESS_USER_AGENT
        self._client_kwargs['config'] = client_config

    def create_client(self):
        """Create a botocore S3 client"""
        return botocore.session.Session().create_client(
            's3', **self._client_kwargs
        )
def poll_for_result(self, transfer_id):
    """Block until a transfer is done and report how it ended

    :param transfer_id: Unique identifier for the transfer
    :return: ``None`` if no exception was recorded for the transfer.
    :raises: The exception recorded for the transfer, if there is one.
    """
    state = self._transfer_states[transfer_id]
    state.wait_till_done()
    failure = state.exception
    if not failure:
        return None
    raise failure
class TransferState:
    """Represents the current state of an individual transfer"""

    # NOTE: Ideally the various abstractions in the ProcessPoolDownloader
    # would use the TransferState object directly, removing the need for
    # the TransferMonitor. That would however require two hops to make or
    # get any change in the state of a transfer across processes: one hop
    # to get a proxy object for the TransferState and a second hop to call
    # the specific TransferState method.
    def __init__(self):
        self._jobs_to_complete = 0
        self._job_lock = threading.Lock()
        self._done_event = threading.Event()
        self._exception = None

    @property
    def done(self):
        """True once ``set_done()`` has been called, False before"""
        return self._done_event.is_set()

    def set_done(self):
        """Mark the transfer as done, releasing ``wait_till_done()`` callers"""
        self._done_event.set()

    def wait_till_done(self):
        """Block until the transfer is marked as done"""
        self._done_event.wait(MAXINT)

    @property
    def exception(self):
        """The exception recorded for the transfer, or None"""
        return self._exception

    @exception.setter
    def exception(self, val):
        self._exception = val

    @property
    def jobs_to_complete(self):
        """The number of jobs still outstanding for the transfer"""
        return self._jobs_to_complete

    @jobs_to_complete.setter
    def jobs_to_complete(self, val):
        self._jobs_to_complete = val

    def decrement_jobs_to_complete(self):
        """Count one job as finished and return how many jobs remain"""
        with self._job_lock:
            remaining = self._jobs_to_complete - 1
            self._jobs_to_complete = remaining
        return remaining
def _do_run(self):
    """Serve download file requests until the shutdown signal arrives

    Each request taken off the download request queue is turned into
    GetObjectJobs. If that fails, the exception is recorded on the
    transfer monitor and the transfer is marked as done.
    """
    next_request = self._download_request_queue.get
    # iter() with a sentinel keeps calling get() and stops as soon as the
    # value received equals SHUTDOWN_SIGNAL.
    for request in iter(next_request, SHUTDOWN_SIGNAL):
        try:
            self._submit_get_object_jobs(request)
        except Exception as e:
            logger.debug(
                'Exception caught when submitting jobs for '
                'download file request %s: %s',
                request,
                e,
                exc_info=True,
            )
            transfer_id = request.transfer_id
            self._transfer_monitor.notify_exception(transfer_id, e)
            self._transfer_monitor.notify_done(request.transfer_id)
    logger.debug('Submitter shutdown signal received.')
def _get_size(self, download_file_request):
    """Determine the size in bytes of the object to download

    :param download_file_request: The request being processed.
    :returns: The request's ``expected_size`` when one was provided.
        Otherwise the ``ContentLength`` of a HeadObject call made with
        the request's bucket, key and extra arguments.
    """
    size = download_file_request.expected_size
    if size is not None:
        # A size was provided, so no HeadObject call is needed.
        return size
    head_response = self._client.head_object(
        Bucket=download_file_request.bucket,
        Key=download_file_request.key,
        **download_file_request.extra_args,
    )
    return head_response['ContentLength']
self._worker_queue.put(GetObjectJob(**get_object_job_kwargs)) + + def _notify_jobs_to_complete(self, transfer_id, jobs_to_complete): + logger.debug( + 'Notifying %s job(s) to complete for transfer_id %s.', + jobs_to_complete, + transfer_id, + ) + self._transfer_monitor.notify_expected_jobs_to_complete( + transfer_id, jobs_to_complete + ) + + +class GetObjectWorker(BaseS3TransferProcess): + # TODO: It may make sense to expose these class variables as configuration + # options if users want to tweak them. + _MAX_ATTEMPTS = 5 + _IO_CHUNKSIZE = 2 * MB + + def __init__(self, queue, client_factory, transfer_monitor, osutil): + """Fulfills GetObjectJobs + + Downloads the S3 object, writes it to the specified file, and + renames the file to its final location if it completes the final + job for a particular transfer. + + :param queue: Queue for retrieving GetObjectJob's + :param client_factory: ClientFactory for creating S3 clients + :param transfer_monitor: Monitor for notifying + :param osutil: OSUtils object to use for os-related behavior when + performing the transfer. 
def _write_to_file(self, filename, offset, body):
    """Stream ``body`` into an already existing file starting at ``offset``.

    :param filename: Path of the file to write into. It is opened in
        ``'rb+'`` mode, so it must already exist.
    :param offset: Position in the file at which writing starts
    :param body: Object with a ``read(amount)`` method; it is read in
        pieces of ``self._IO_CHUNKSIZE`` until it returns ``b''``.
    """
    with open(filename, 'rb+') as destination:
        destination.seek(offset)
        while True:
            data = body.read(self._IO_CHUNKSIZE)
            # An empty bytes object marks the end of the stream.
            if data == b'':
                break
            destination.write(data)
def _do_file_rename(self, transfer_id, temp_filename, filename):
    """Move the temporary file to its final name.

    If the rename raises, the exception is reported to the transfer
    monitor for ``transfer_id`` and the temporary file is removed.

    :param transfer_id: Identifier of the transfer being finalized
    :param temp_filename: Path of the temporary download file
    :param filename: Final path for the downloaded file
    """
    try:
        self._osutil.rename_file(temp_filename, filename)
    except Exception as rename_error:
        # Record the failure first, then discard the partial download.
        monitor = self._transfer_monitor
        monitor.notify_exception(transfer_id, rename_error)
        self._osutil.remove_file(temp_filename)
@classmethod
def _validate_subscriber_methods(cls):
    """Check that every ``on_<type>`` attribute is a usable callback.

    For each entry of ``cls.VALID_SUBSCRIBER_TYPES`` the attribute
    ``on_<type>`` is looked up on the class. It must be callable and must
    accept keyword arguments.

    :raises InvalidSubscriberMethodError: If one of the attributes is not
        callable or does not accept ``**kwargs``.
    """
    for subscriber_type in cls.VALID_SUBSCRIBER_TYPES:
        subscriber_method = getattr(cls, 'on_' + subscriber_type)
        if not callable(subscriber_method):
            # The value is wrapped in a one-element tuple on purpose: a
            # bare tuple on the right-hand side of ``%`` is unpacked as
            # the format arguments and would raise TypeError instead of
            # the intended error.
            raise InvalidSubscriberMethodError(
                'Subscriber method %s must be callable.'
                % (subscriber_method,)
            )

        if not accepts_kwargs(subscriber_method):
            raise InvalidSubscriberMethodError(
                'Subscriber method %s must accept keyword '
                'arguments (**kwargs)' % (subscriber_method,)
            )
def __init__(
    self,
    transfer_coordinator,
    main_kwargs=None,
    pending_main_kwargs=None,
    done_callbacks=None,
    is_final=False,
):
    """
    :type transfer_coordinator: s3transfer.futures.TransferCoordinator
    :param transfer_coordinator: The context of the TransferFuture that
        this Task belongs to.

    :type main_kwargs: dict
    :param main_kwargs: Keyword arguments that are already known and can
        be handed to the task's _main() method as they are. Defaults to
        an empty dict.

    :type pending_main_kwargs: dict
    :param pending_main_kwargs: Keyword arguments whose values come from
        the result of other future(s). Each value is either:
            * a single future - the keyword argument receives the result
              of that future
            * a list of futures - the keyword argument receives the list
              of their results, in the order the futures were supplied
        Defaults to an empty dict.

    :type done_callbacks: list of callbacks
    :param done_callbacks: Callbacks invoked without arguments once the
        task has finished, whether it succeeded or raised. Defaults to
        an empty list.

    :type is_final: boolean
    :param is_final: True if this is the last task of the TransferFuture
        request, in which case the result of this task's main() method
        becomes the result of the whole TransferFuture.
    """
    self._transfer_coordinator = transfer_coordinator
    # A fresh container is created per instance whenever the caller did
    # not supply one.
    self._main_kwargs = {} if main_kwargs is None else main_kwargs
    self._pending_main_kwargs = (
        {} if pending_main_kwargs is None else pending_main_kwargs
    )
    self._done_callbacks = [] if done_callbacks is None else done_callbacks
    self._is_final = is_final
def _get_kwargs_with_params_to_include(self, kwargs, include):
    """Return the subset of ``kwargs`` whose keys are listed in ``include``.

    :param kwargs: Dictionary to take the values from
    :param include: Iterable of keys to keep; keys missing from
        ``kwargs`` are skipped.

    :returns: A new dictionary ordered like ``include``
    """
    return {name: kwargs[name] for name in include if name in kwargs}
def _wait_on_dependent_futures(self):
    """Wait for every future referenced by the pending main kwargs.

    Values of ``self._pending_main_kwargs`` that are lists contribute all
    of their elements; any other value is treated as a single future.
    The flattened list is handed to ``_wait_until_all_complete()``.
    """
    pending = []
    for dependency in self._pending_main_kwargs.values():
        if isinstance(dependency, list):
            # A list of futures: wait on each of its members.
            pending += dependency
        else:
            pending.append(dependency)
    self._wait_until_all_complete(pending)
def _get_all_main_kwargs(self):
    """Build the complete keyword arguments for the main() call.

    Starts from a shallow copy of ``self._main_kwargs`` and adds one
    entry per key of ``self._pending_main_kwargs``, using the result of
    the future (or the list of results for a list of futures).

    :returns: A new dictionary; ``self._main_kwargs`` is left untouched
    """
    resolved = copy.copy(self._main_kwargs)
    for name, pending in self._pending_main_kwargs.items():
        if isinstance(pending, list):
            # Collect the results in the order the futures were given.
            resolved[name] = [future.result() for future in pending]
        else:
            resolved[name] = pending.result()
    return resolved
def _main(self, transfer_future, **kwargs):
    """Run the on_queued callbacks and submit the tasks of a transfer

    :type transfer_future: s3transfer.futures.TransferFuture
    :param transfer_future: The transfer future associated with the
        transfer request that tasks are being submitted for

    :param kwargs: Any additional kwargs that you may want to pass
        to the _submit() method
    """
    try:
        self._transfer_coordinator.set_status_to_queued()

        # Before submitting any tasks, run all of the on_queued callbacks
        on_queued_callbacks = get_callbacks(transfer_future, 'queued')
        for on_queued_callback in on_queued_callbacks:
            on_queued_callback()

        # Once callbacks have been run set the status to running.
        self._transfer_coordinator.set_status_to_running()

        # Call the submit method to start submitting tasks to execute the
        # transfer.
        self._submit(transfer_future=transfer_future, **kwargs)
    except BaseException as e:
        # If there was an exception raised during the submission of task
        # there is a chance that the final task that signals if a transfer
        # is done and to run the cleanup may never have been submitted in
        # the first place so we need to account accordingly.
        #
        # Note that BaseException is caught, instead of Exception, because
        # for some implementations of executors, specifically the serial
        # implementation, the SubmissionTask is directly exposed to
        # KeyboardInterrupts and so needs to cleanup and signal done
        # for those as well.

        # Set the exception, that caused the process to fail.
        self._log_and_set_exception(e)

        # Wait for all possibly associated futures that may have spawned
        # from this submission task to finish before we announce the
        # transfer done.
        self._wait_for_all_submitted_futures_to_complete()

        # Announce the transfer as done, which will run any cleanups
        # and done callbacks as well.
        self._transfer_coordinator.announce_done()
class CreateMultipartUploadTask(Task):
    """Task that starts a multipart upload"""

    def _main(self, client, bucket, key, extra_args):
        """
        :param client: The client used to call CreateMultipartUpload
        :param bucket: The name of the bucket to upload to
        :param key: The name of the key to upload to
        :param extra_args: A dictionary of additional arguments passed
            along when creating the multipart upload.

        :returns: The upload id of the multipart upload
        """
        create_response = client.create_multipart_upload(
            Bucket=bucket, Key=key, **extra_args
        )
        upload_id = create_response['UploadId']

        # Make sure the upload gets aborted should the transfer fail at
        # any later point.
        abort_kwargs = {'Bucket': bucket, 'Key': key, 'UploadId': upload_id}
        self._transfer_coordinator.add_failure_cleanup(
            client.abort_multipart_upload, **abort_kwargs
        )
        return upload_id
+ """ + client.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={'Parts': parts}, + **extra_args, + ) diff --git a/contrib/python/s3transfer/py3/s3transfer/upload.py b/contrib/python/s3transfer/py3/s3transfer/upload.py new file mode 100644 index 0000000000..31ade051d7 --- /dev/null +++ b/contrib/python/s3transfer/py3/s3transfer/upload.py @@ -0,0 +1,795 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import math +from io import BytesIO + +from s3transfer.compat import readable, seekable +from s3transfer.futures import IN_MEMORY_UPLOAD_TAG +from s3transfer.tasks import ( + CompleteMultipartUploadTask, + CreateMultipartUploadTask, + SubmissionTask, + Task, +) +from s3transfer.utils import ( + ChunksizeAdjuster, + DeferredOpenFile, + get_callbacks, + get_filtered_dict, +) + + +class AggregatedProgressCallback: + def __init__(self, callbacks, threshold=1024 * 256): + """Aggregates progress updates for every provided progress callback + + :type callbacks: A list of functions that accepts bytes_transferred + as a single argument + :param callbacks: The callbacks to invoke when threshold is reached + + :type threshold: int + :param threshold: The progress threshold in which to take the + aggregated progress and invoke the progress callback with that + aggregated progress total + """ + self._callbacks = callbacks + self._threshold = threshold + self._bytes_seen = 0 + + def __call__(self, bytes_transferred): + 
class InterruptReader:
    """Wrapper that can interrupt reading using an error

    Before every read it checks the transfer coordinator. If an exception
    has been recorded there, that exception is raised instead of reading
    from the wrapped file-like object.

    :type fileobj: file-like obj
    :param fileobj: The file-like object to read from

    :type transfer_coordinator: s3transfer.futures.TransferCoordinator
    :param transfer_coordinator: The transfer coordinator to use if the
        reader needs to be interrupted.
    """

    def __init__(self, fileobj, transfer_coordinator):
        self._fileobj = fileobj
        self._transfer_coordinator = transfer_coordinator

    def read(self, amount=None):
        """Read from the wrapped object unless the transfer has failed

        :param amount: Passed through to the wrapped object's ``read()``

        :returns: Whatever the wrapped object's ``read()`` returns
        :raises: The exception stored on the transfer coordinator, if any
        """
        # If there is an exception, then raise the exception.
        # We raise an error instead of returning no bytes because for
        # requests where the content length and md5 was sent, it will
        # cause md5 mismatches and retries as there was no indication that
        # the stream being read from encountered any issues.
        if self._transfer_coordinator.exception:
            raise self._transfer_coordinator.exception
        return self._fileobj.read(amount)

    def seek(self, where, whence=0):
        """Seek the wrapped object and return what its ``seek()`` returns

        The result is passed through so that this wrapper behaves like
        the object it wraps, instead of always returning None.
        """
        return self._fileobj.seek(where, whence)

    def tell(self):
        """Return the current position of the wrapped object"""
        return self._fileobj.tell()

    def close(self):
        """Close the wrapped object"""
        self._fileobj.close()

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        self.close()
def get_put_object_body(self, transfer_future):
    """Returns the body to use for PutObject

    :type transfer_future: s3transfer.futures.TransferFuture
    :param transfer_future: The future associated with upload request

    :rtype: s3transfer.utils.ReadFileChunk
    :returns: A ReadFileChunk including all progress callbacks
        associated with the transfer future.

    :raises NotImplementedError: Always in this base implementation;
        it has to be overridden.
    """
    raise NotImplementedError('must implement get_put_object_body()')
def provide_transfer_size(self, transfer_future):
    """Provide the size of the file named by the upload request

    The size is obtained from ``self._osutil.get_file_size()`` for the
    ``fileobj`` call argument and handed to the future's meta.

    :type transfer_future: s3transfer.futures.TransferFuture
    :param transfer_future: The future associated with upload request
    """
    meta = transfer_future.meta
    source_path = meta.call_args.fileobj
    size_in_bytes = self._osutil.get_file_size(source_path)
    meta.provide_transfer_size(size_in_bytes)
def yield_upload_part_bodies(self, transfer_future, chunksize):
    """Yield the part number and body for each UploadPart

    :type transfer_future: s3transfer.futures.TransferFuture
    :param transfer_future: The future associated with upload request

    :type chunksize: int
    :param chunksize: The chunksize to use for this upload.

    :returns: Yields ``(part_number, body)`` pairs, with part numbers
        starting at 1. Each body is whatever
        ``self._osutil.open_file_chunk_reader_from_fileobj()`` returns
        for that part.
    """
    total_size = transfer_future.meta.size
    part_count = self._get_num_parts(transfer_future, chunksize)
    for index in range(part_count):
        part_number = index + 1
        # Every part gets its own set of progress and close callbacks.
        progress_callbacks = self._get_progress_callbacks(transfer_future)
        flush_callbacks = self._get_close_callbacks(progress_callbacks)

        # Get the file-like object for this part together with the full
        # size of the file that object belongs to.
        (
            part_fileobj,
            part_full_size,
        ) = self._get_upload_part_fileobj_with_full_size(
            transfer_future.meta.call_args.fileobj,
            start_byte=index * chunksize,
            part_size=chunksize,
            full_file_size=total_size,
        )

        # The interrupt reader lets an upload be cancelled quickly
        # instead of waiting for the socket to read all of the data.
        interruptible = self._wrap_fileobj(part_fileobj)

        # Wrap the file-like object into a chunk reader to get progress.
        chunk_reader = self._osutil.open_file_chunk_reader_from_fileobj(
            fileobj=interruptible,
            chunk_size=chunksize,
            full_file_size=part_full_size,
            callbacks=progress_callbacks,
            close_callbacks=flush_callbacks,
        )
        yield part_number, chunk_reader
+ start_position = fileobj.tell() + fileobj.seek(0, 2) + end_position = fileobj.tell() + fileobj.seek(start_position) + transfer_future.meta.provide_transfer_size( + end_position - start_position + ) + + def _get_upload_part_fileobj_with_full_size(self, fileobj, **kwargs): + # Note: It is unfortunate that in order to do a multithreaded + # multipart upload we cannot simply copy the filelike object + # since there is not really a mechanism in python (i.e. os.dup + # points to the same OS filehandle which causes concurrency + # issues). So instead we need to read from the fileobj and + # chunk the data out to separate file-like objects in memory. + data = fileobj.read(kwargs['part_size']) + # We return the length of the data instead of the full_file_size + # because we partitioned the data into separate BytesIO objects + # meaning the BytesIO object has no knowledge of its start position + # relative the input source nor access to the rest of the input + # source. So we must treat it as its own standalone file. + return BytesIO(data), len(data) + + def _get_put_object_fileobj_with_full_size(self, transfer_future): + fileobj = transfer_future.meta.call_args.fileobj + # The current position needs to be taken into account when retrieving + # the full size of the file. + size = fileobj.tell() + transfer_future.meta.size + return fileobj, size + + +class UploadNonSeekableInputManager(UploadInputManager): + """Upload utility for a file-like object that cannot seek.""" + + def __init__(self, osutil, transfer_coordinator, bandwidth_limiter=None): + super().__init__(osutil, transfer_coordinator, bandwidth_limiter) + self._initial_data = b'' + + @classmethod + def is_compatible(cls, upload_source): + return readable(upload_source) + + def stores_body_in_memory(self, operation_name): + return True + + def provide_transfer_size(self, transfer_future): + # No-op because there is no way to do this short of reading the entire + # body into memory. 
+ return + + def requires_multipart_upload(self, transfer_future, config): + # If the user has set the size, we can use that. + if transfer_future.meta.size is not None: + return transfer_future.meta.size >= config.multipart_threshold + + # This is tricky to determine in this case because we can't know how + # large the input is. So to figure it out, we read data into memory + # up until the threshold and compare how much data was actually read + # against the threshold. + fileobj = transfer_future.meta.call_args.fileobj + threshold = config.multipart_threshold + self._initial_data = self._read(fileobj, threshold, False) + if len(self._initial_data) < threshold: + return False + else: + return True + + def get_put_object_body(self, transfer_future): + callbacks = self._get_progress_callbacks(transfer_future) + close_callbacks = self._get_close_callbacks(callbacks) + fileobj = transfer_future.meta.call_args.fileobj + + body = self._wrap_data( + self._initial_data + fileobj.read(), callbacks, close_callbacks + ) + + # Zero out the stored data so we don't have additional copies + # hanging around in memory. + self._initial_data = None + return body + + def yield_upload_part_bodies(self, transfer_future, chunksize): + file_object = transfer_future.meta.call_args.fileobj + part_number = 0 + + # Continue reading parts from the file-like object until it is empty. + while True: + callbacks = self._get_progress_callbacks(transfer_future) + close_callbacks = self._get_close_callbacks(callbacks) + part_number += 1 + part_content = self._read(file_object, chunksize) + if not part_content: + break + part_object = self._wrap_data( + part_content, callbacks, close_callbacks + ) + + # Zero out part_content to avoid hanging on to additional data. + part_content = None + yield part_number, part_object + + def _read(self, fileobj, amount, truncate=True): + """ + Reads a specific amount of data from a stream and returns it. 
If there + is any data in initial_data, that will be popped out first. + + :type fileobj: A file-like object that implements read + :param fileobj: The stream to read from. + + :type amount: int + :param amount: The number of bytes to read from the stream. + + :type truncate: bool + :param truncate: Whether or not to truncate initial_data after + reading from it. + + :return: The bytes read, taken from the initial data first and then + from the stream. + """ + # If the initial data is empty, we simply read from the fileobj + if len(self._initial_data) == 0: + return fileobj.read(amount) + + # If the requested number of bytes is less than the amount of + # initial data, pull entirely from initial data. + if amount <= len(self._initial_data): + data = self._initial_data[:amount] + # Truncate initial data so we don't hang onto the data longer + # than we need. + if truncate: + self._initial_data = self._initial_data[amount:] + return data + + # At this point there is some initial data left, but not enough to + # satisfy the number of bytes requested. Pull out the remaining + # initial data and read the rest from the fileobj. + amount_to_read = amount - len(self._initial_data) + data = self._initial_data + fileobj.read(amount_to_read) + + # Zero out initial data so we don't hang onto the data any more. + if truncate: + self._initial_data = b'' + return data + + def _wrap_data(self, data, callbacks, close_callbacks): + """ + Wraps data with the interrupt reader and the file chunk reader. + + :type data: bytes + :param data: The data to wrap. + + :type callbacks: list + :param callbacks: The callbacks associated with the transfer future. + + :type close_callbacks: list + :param close_callbacks: The callbacks to be called when closing the + wrapper for the data. + + :return: Fully wrapped data.
+ """ + fileobj = self._wrap_fileobj(BytesIO(data)) + return self._osutil.open_file_chunk_reader_from_fileobj( + fileobj=fileobj, + chunk_size=len(data), + full_file_size=len(data), + callbacks=callbacks, + close_callbacks=close_callbacks, + ) + + +class UploadSubmissionTask(SubmissionTask): + """Task for submitting tasks to execute an upload""" + + UPLOAD_PART_ARGS = [ + 'SSECustomerKey', + 'SSECustomerAlgorithm', + 'SSECustomerKeyMD5', + 'RequestPayer', + 'ExpectedBucketOwner', + ] + + COMPLETE_MULTIPART_ARGS = ['RequestPayer', 'ExpectedBucketOwner'] + + def _get_upload_input_manager_cls(self, transfer_future): + """Retrieves a class for managing input for an upload based on file type + + :type transfer_future: s3transfer.futures.TransferFuture + :param transfer_future: The transfer future for the request + + :rtype: class of UploadInputManager + :returns: The appropriate class to use for managing a specific type of + input for uploads. + """ + upload_manager_resolver_chain = [ + UploadFilenameInputManager, + UploadSeekableInputManager, + UploadNonSeekableInputManager, + ] + + fileobj = transfer_future.meta.call_args.fileobj + for upload_manager_cls in upload_manager_resolver_chain: + if upload_manager_cls.is_compatible(fileobj): + return upload_manager_cls + raise RuntimeError( + 'Input {} of type: {} is not supported.'.format( + fileobj, type(fileobj) + ) + ) + + def _submit( + self, + client, + config, + osutil, + request_executor, + transfer_future, + bandwidth_limiter=None, + ): + """ + :param client: The client associated with the transfer manager + + :type config: s3transfer.manager.TransferConfig + :param config: The transfer config associated with the transfer + manager + + :type osutil: s3transfer.utils.OSUtil + :param osutil: The os utility associated to the transfer manager + + :type request_executor: s3transfer.futures.BoundedExecutor + :param request_executor: The request executor associated with the + transfer manager + + :type transfer_future: 
s3transfer.futures.TransferFuture + :param transfer_future: The transfer future associated with the + transfer request that tasks are being submitted for + """ + upload_input_manager = self._get_upload_input_manager_cls( + transfer_future + )(osutil, self._transfer_coordinator, bandwidth_limiter) + + # Determine the size if it was not provided + if transfer_future.meta.size is None: + upload_input_manager.provide_transfer_size(transfer_future) + + # Do a multipart upload if needed, otherwise do a regular put object. + if not upload_input_manager.requires_multipart_upload( + transfer_future, config + ): + self._submit_upload_request( + client, + config, + osutil, + request_executor, + transfer_future, + upload_input_manager, + ) + else: + self._submit_multipart_request( + client, + config, + osutil, + request_executor, + transfer_future, + upload_input_manager, + ) + + def _submit_upload_request( + self, + client, + config, + osutil, + request_executor, + transfer_future, + upload_input_manager, + ): + call_args = transfer_future.meta.call_args + + # Get any tags that need to be associated to the put object task + put_object_tag = self._get_upload_task_tag( + upload_input_manager, 'put_object' + ) + + # Submit the request of a single upload. + self._transfer_coordinator.submit( + request_executor, + PutObjectTask( + transfer_coordinator=self._transfer_coordinator, + main_kwargs={ + 'client': client, + 'fileobj': upload_input_manager.get_put_object_body( + transfer_future + ), + 'bucket': call_args.bucket, + 'key': call_args.key, + 'extra_args': call_args.extra_args, + }, + is_final=True, + ), + tag=put_object_tag, + ) + + def _submit_multipart_request( + self, + client, + config, + osutil, + request_executor, + transfer_future, + upload_input_manager, + ): + call_args = transfer_future.meta.call_args + + # Submit the request to create a multipart upload. 
+ create_multipart_future = self._transfer_coordinator.submit( + request_executor, + CreateMultipartUploadTask( + transfer_coordinator=self._transfer_coordinator, + main_kwargs={ + 'client': client, + 'bucket': call_args.bucket, + 'key': call_args.key, + 'extra_args': call_args.extra_args, + }, + ), + ) + + # Submit requests to upload the parts of the file. + part_futures = [] + extra_part_args = self._extra_upload_part_args(call_args.extra_args) + + # Get any tags that need to be associated to the submitted task + # for upload the data + upload_part_tag = self._get_upload_task_tag( + upload_input_manager, 'upload_part' + ) + + size = transfer_future.meta.size + adjuster = ChunksizeAdjuster() + chunksize = adjuster.adjust_chunksize(config.multipart_chunksize, size) + part_iterator = upload_input_manager.yield_upload_part_bodies( + transfer_future, chunksize + ) + + for part_number, fileobj in part_iterator: + part_futures.append( + self._transfer_coordinator.submit( + request_executor, + UploadPartTask( + transfer_coordinator=self._transfer_coordinator, + main_kwargs={ + 'client': client, + 'fileobj': fileobj, + 'bucket': call_args.bucket, + 'key': call_args.key, + 'part_number': part_number, + 'extra_args': extra_part_args, + }, + pending_main_kwargs={ + 'upload_id': create_multipart_future + }, + ), + tag=upload_part_tag, + ) + ) + + complete_multipart_extra_args = self._extra_complete_multipart_args( + call_args.extra_args + ) + # Submit the request to complete the multipart upload. 
+ self._transfer_coordinator.submit( + request_executor, + CompleteMultipartUploadTask( + transfer_coordinator=self._transfer_coordinator, + main_kwargs={ + 'client': client, + 'bucket': call_args.bucket, + 'key': call_args.key, + 'extra_args': complete_multipart_extra_args, + }, + pending_main_kwargs={ + 'upload_id': create_multipart_future, + 'parts': part_futures, + }, + is_final=True, + ), + ) + + def _extra_upload_part_args(self, extra_args): + # Only the args in UPLOAD_PART_ARGS actually need to be passed + # onto the upload_part calls. + return get_filtered_dict(extra_args, self.UPLOAD_PART_ARGS) + + def _extra_complete_multipart_args(self, extra_args): + return get_filtered_dict(extra_args, self.COMPLETE_MULTIPART_ARGS) + + def _get_upload_task_tag(self, upload_input_manager, operation_name): + tag = None + if upload_input_manager.stores_body_in_memory(operation_name): + tag = IN_MEMORY_UPLOAD_TAG + return tag + + +class PutObjectTask(Task): + """Task to do a nonmultipart upload""" + + def _main(self, client, fileobj, bucket, key, extra_args): + """ + :param client: The client to use when calling PutObject + :param fileobj: The file to upload. + :param bucket: The name of the bucket to upload to + :param key: The name of the key to upload to + :param extra_args: A dictionary of any extra arguments that may be + used in the upload. + """ + with fileobj as body: + client.put_object(Bucket=bucket, Key=key, Body=body, **extra_args) + + +class UploadPartTask(Task): + """Task to upload a part in a multipart upload""" + + def _main( + self, client, fileobj, bucket, key, upload_id, part_number, extra_args + ): + """ + :param client: The client to use when calling PutObject + :param fileobj: The file to upload. 
+ :param bucket: The name of the bucket to upload to + :param key: The name of the key to upload to + :param upload_id: The id of the upload + :param part_number: The number representing the part of the multipart + upload + :param extra_args: A dictionary of any extra arguments that may be + used in the upload. + + :rtype: dict + :returns: A dictionary representing a part:: + + {'ETag': etag_value, 'PartNumber': part_number} + + This value can be appended to a list to be used to complete + the multipart upload. + """ + with fileobj as body: + response = client.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=part_number, + Body=body, + **extra_args + ) + etag = response['ETag'] + return {'ETag': etag, 'PartNumber': part_number} diff --git a/contrib/python/s3transfer/py3/s3transfer/utils.py b/contrib/python/s3transfer/py3/s3transfer/utils.py new file mode 100644 index 0000000000..ba881c67dd --- /dev/null +++ b/contrib/python/s3transfer/py3/s3transfer/utils.py @@ -0,0 +1,802 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import functools +import logging +import math +import os +import random +import socket +import stat +import string +import threading +from collections import defaultdict + +from botocore.exceptions import IncompleteReadError, ReadTimeoutError + +from s3transfer.compat import SOCKET_ERROR, fallocate, rename_file + +MAX_PARTS = 10000 +# The maximum file size you can upload via S3 per request.
+# See: http://docs.aws.amazon.com/AmazonS3/latest/dev/UploadingObjects.html +# and: http://docs.aws.amazon.com/AmazonS3/latest/dev/qfacts.html +MAX_SINGLE_UPLOAD_SIZE = 5 * (1024 ** 3) +MIN_UPLOAD_CHUNKSIZE = 5 * (1024 ** 2) +logger = logging.getLogger(__name__) + + +S3_RETRYABLE_DOWNLOAD_ERRORS = ( + socket.timeout, + SOCKET_ERROR, + ReadTimeoutError, + IncompleteReadError, +) + + +def random_file_extension(num_digits=8): + return ''.join(random.choice(string.hexdigits) for _ in range(num_digits)) + + +def signal_not_transferring(request, operation_name, **kwargs): + if operation_name in ['PutObject', 'UploadPart'] and hasattr( + request.body, 'signal_not_transferring' + ): + request.body.signal_not_transferring() + + +def signal_transferring(request, operation_name, **kwargs): + if operation_name in ['PutObject', 'UploadPart'] and hasattr( + request.body, 'signal_transferring' + ): + request.body.signal_transferring() + + +def calculate_num_parts(size, part_size): + return int(math.ceil(size / float(part_size))) + + +def calculate_range_parameter( + part_size, part_index, num_parts, total_size=None +): + """Calculate the range parameter for multipart downloads/copies + + :type part_size: int + :param part_size: The size of the part + + :type part_index: int + :param part_index: The index for which this parts starts. 
This index starts + at zero + + :type num_parts: int + :param num_parts: The total number of parts in the transfer + + :returns: The value to use for Range parameter on downloads or + the CopySourceRange parameter for copies + """ + # Used to calculate the Range parameter + start_range = part_index * part_size + if part_index == num_parts - 1: + end_range = '' + if total_size is not None: + end_range = str(total_size - 1) + else: + end_range = start_range + part_size - 1 + range_param = f'bytes={start_range}-{end_range}' + return range_param + + +def get_callbacks(transfer_future, callback_type): + """Retrieves callbacks from a subscriber + + :type transfer_future: s3transfer.futures.TransferFuture + :param transfer_future: The transfer future the subscriber is associated + to. + + :type callback_type: str + :param callback_type: The type of callback to retrieve from the subscriber. + Valid types include: + * 'queued' + * 'progress' + * 'done' + + :returns: A list of callbacks for the type specified. All callbacks are + preinjected with the transfer future. + """ + callbacks = [] + for subscriber in transfer_future.meta.call_args.subscribers: + callback_name = 'on_' + callback_type + if hasattr(subscriber, callback_name): + callbacks.append( + functools.partial( + getattr(subscriber, callback_name), future=transfer_future + ) + ) + return callbacks + + +def invoke_progress_callbacks(callbacks, bytes_transferred): + """Calls all progress callbacks + + :param callbacks: A list of progress callbacks to invoke + :param bytes_transferred: The number of bytes transferred. This is passed + to the callbacks. If no bytes were transferred the callbacks will not + be invoked because no progress was achieved. It is also possible + to receive a negative amount which comes from retrying a transfer + request. + """ + # Only invoke the callbacks if bytes were actually transferred. 
+ if bytes_transferred: + for callback in callbacks: + callback(bytes_transferred=bytes_transferred) + + +def get_filtered_dict(original_dict, whitelisted_keys): + """Gets a dictionary filtered by whitelisted keys + + :param original_dict: The original dictionary of arguments to source keys + and values. + :param whitelisted_keys: A list of keys to include in the filtered + dictionary. + + :returns: A dictionary containing key/values from the original dictionary + whose key was included in the whitelist + """ + filtered_dict = {} + for key, value in original_dict.items(): + if key in whitelisted_keys: + filtered_dict[key] = value + return filtered_dict + + +class CallArgs: + def __init__(self, **kwargs): + """A class that records call arguments + + The call arguments must be passed as keyword arguments. It will set + each keyword argument as an attribute of the object along with its + associated value. + """ + for arg, value in kwargs.items(): + setattr(self, arg, value) + + +class FunctionContainer: + """An object that contains a function and any args or kwargs to call it + + When called the provided function will be called with provided args + and kwargs.
+ """ + + def __init__(self, func, *args, **kwargs): + self._func = func + self._args = args + self._kwargs = kwargs + + def __repr__(self): + return 'Function: {} with args {} and kwargs {}'.format( + self._func, self._args, self._kwargs + ) + + def __call__(self): + return self._func(*self._args, **self._kwargs) + + +class CountCallbackInvoker: + """An abstraction to invoke a callback when a shared count reaches zero + + :param callback: Callback invoke when finalized count reaches zero + """ + + def __init__(self, callback): + self._lock = threading.Lock() + self._callback = callback + self._count = 0 + self._is_finalized = False + + @property + def current_count(self): + with self._lock: + return self._count + + def increment(self): + """Increment the count by one""" + with self._lock: + if self._is_finalized: + raise RuntimeError( + 'Counter has been finalized it can no longer be ' + 'incremented.' + ) + self._count += 1 + + def decrement(self): + """Decrement the count by one""" + with self._lock: + if self._count == 0: + raise RuntimeError( + 'Counter is at zero. 
It cannot dip below zero' + ) + self._count -= 1 + if self._is_finalized and self._count == 0: + self._callback() + + def finalize(self): + """Finalize the counter + + Once finalized, the counter can never be incremented and the callback + can be invoked once the count reaches zero + """ + with self._lock: + self._is_finalized = True + if self._count == 0: + self._callback() + + +class OSUtils: + _MAX_FILENAME_LEN = 255 + + def get_file_size(self, filename): + return os.path.getsize(filename) + + def open_file_chunk_reader(self, filename, start_byte, size, callbacks): + return ReadFileChunk.from_filename( + filename, start_byte, size, callbacks, enable_callbacks=False + ) + + def open_file_chunk_reader_from_fileobj( + self, + fileobj, + chunk_size, + full_file_size, + callbacks, + close_callbacks=None, + ): + return ReadFileChunk( + fileobj, + chunk_size, + full_file_size, + callbacks=callbacks, + enable_callbacks=False, + close_callbacks=close_callbacks, + ) + + def open(self, filename, mode): + return open(filename, mode) + + def remove_file(self, filename): + """Remove a file, noop if file does not exist.""" + # Unlike os.remove, if the file does not exist, + # then this method does nothing. + try: + os.remove(filename) + except OSError: + pass + + def rename_file(self, current_filename, new_filename): + rename_file(current_filename, new_filename) + + def is_special_file(cls, filename): + """Checks to see if a file is a special UNIX file. + + It checks if the file is a character special device, block special + device, FIFO, or socket. + + :param filename: Name of the file + + :returns: True if the file is a special file. False if it is not. + """ + # If it does not exist, it must be a new file so it cannot be + # a special file. + if not os.path.exists(filename): + return False + mode = os.stat(filename).st_mode + # Character special device.
+ if stat.S_ISCHR(mode): + return True + # Block special device + if stat.S_ISBLK(mode): + return True + # Named pipe / FIFO + if stat.S_ISFIFO(mode): + return True + # Socket. + if stat.S_ISSOCK(mode): + return True + return False + + def get_temp_filename(self, filename): + suffix = os.extsep + random_file_extension() + path = os.path.dirname(filename) + name = os.path.basename(filename) + temp_filename = name[: self._MAX_FILENAME_LEN - len(suffix)] + suffix + return os.path.join(path, temp_filename) + + def allocate(self, filename, size): + try: + with self.open(filename, 'wb') as f: + fallocate(f, size) + except OSError: + self.remove_file(filename) + raise + + +class DeferredOpenFile: + def __init__(self, filename, start_byte=0, mode='rb', open_function=open): + """A class that defers the opening of a file till needed + + This is useful for deferring opening of a file till it is needed + in a separate thread, as there is a limit of how many open files + there can be in a single thread for most operating systems. The + file gets opened in the following methods: ``read()``, ``seek()``, + and ``__enter__()`` + + :type filename: str + :param filename: The name of the file to open + + :type start_byte: int + :param start_byte: The byte to seek to when the file is opened. 
+ + :type mode: str + :param mode: The mode to use to open the file + + :type open_function: function + :param open_function: The function to use to open the file + """ + self._filename = filename + self._fileobj = None + self._start_byte = start_byte + self._mode = mode + self._open_function = open_function + + def _open_if_needed(self): + if self._fileobj is None: + self._fileobj = self._open_function(self._filename, self._mode) + if self._start_byte != 0: + self._fileobj.seek(self._start_byte) + + @property + def name(self): + return self._filename + + def read(self, amount=None): + self._open_if_needed() + return self._fileobj.read(amount) + + def write(self, data): + self._open_if_needed() + self._fileobj.write(data) + + def seek(self, where, whence=0): + self._open_if_needed() + self._fileobj.seek(where, whence) + + def tell(self): + if self._fileobj is None: + return self._start_byte + return self._fileobj.tell() + + def close(self): + if self._fileobj: + self._fileobj.close() + + def __enter__(self): + self._open_if_needed() + return self + + def __exit__(self, *args, **kwargs): + self.close() + + +class ReadFileChunk: + def __init__( + self, + fileobj, + chunk_size, + full_file_size, + callbacks=None, + enable_callbacks=True, + close_callbacks=None, + ): + """ + + Given a file object shown below:: + + |___________________________________________________| + 0 | | full_file_size + |----chunk_size---| + f.tell() + + :type fileobj: file + :param fileobj: File like object + + :type chunk_size: int + :param chunk_size: The max chunk size to read. Trying to read + pass the end of the chunk size will behave like you've + reached the end of the file. + + :type full_file_size: int + :param full_file_size: The entire content length associated + with ``fileobj``. + + :type callbacks: A list of function(amount_read) + :param callbacks: Called whenever data is read from this object in the + order provided. 
+ + :type enable_callbacks: boolean + :param enable_callbacks: True if to run callbacks. Otherwise, do not + run callbacks + + :type close_callbacks: A list of function() + :param close_callbacks: Called when close is called. The function + should take no arguments. + """ + self._fileobj = fileobj + self._start_byte = self._fileobj.tell() + self._size = self._calculate_file_size( + self._fileobj, + requested_size=chunk_size, + start_byte=self._start_byte, + actual_file_size=full_file_size, + ) + # _amount_read represents the position in the chunk and may exceed + # the chunk size, but won't allow reads out of bounds. + self._amount_read = 0 + self._callbacks = callbacks + if callbacks is None: + self._callbacks = [] + self._callbacks_enabled = enable_callbacks + self._close_callbacks = close_callbacks + if close_callbacks is None: + self._close_callbacks = close_callbacks + + @classmethod + def from_filename( + cls, + filename, + start_byte, + chunk_size, + callbacks=None, + enable_callbacks=True, + ): + """Convenience factory function to create from a filename. + + :type start_byte: int + :param start_byte: The first byte from which to start reading. + + :type chunk_size: int + :param chunk_size: The max chunk size to read. Trying to read + pass the end of the chunk size will behave like you've + reached the end of the file. + + :type full_file_size: int + :param full_file_size: The entire content length associated + with ``fileobj``. + + :type callbacks: function(amount_read) + :param callbacks: Called whenever data is read from this object. + + :type enable_callbacks: bool + :param enable_callbacks: Indicate whether to invoke callback + during read() calls. 
+ + :rtype: ``ReadFileChunk`` + :return: A new instance of ``ReadFileChunk`` + + """ + f = open(filename, 'rb') + f.seek(start_byte) + file_size = os.fstat(f.fileno()).st_size + return cls(f, chunk_size, file_size, callbacks, enable_callbacks) + + def _calculate_file_size( + self, fileobj, requested_size, start_byte, actual_file_size + ): + max_chunk_size = actual_file_size - start_byte + return min(max_chunk_size, requested_size) + + def read(self, amount=None): + amount_left = max(self._size - self._amount_read, 0) + if amount is None: + amount_to_read = amount_left + else: + amount_to_read = min(amount_left, amount) + data = self._fileobj.read(amount_to_read) + self._amount_read += len(data) + if self._callbacks is not None and self._callbacks_enabled: + invoke_progress_callbacks(self._callbacks, len(data)) + return data + + def signal_transferring(self): + self.enable_callback() + if hasattr(self._fileobj, 'signal_transferring'): + self._fileobj.signal_transferring() + + def signal_not_transferring(self): + self.disable_callback() + if hasattr(self._fileobj, 'signal_not_transferring'): + self._fileobj.signal_not_transferring() + + def enable_callback(self): + self._callbacks_enabled = True + + def disable_callback(self): + self._callbacks_enabled = False + + def seek(self, where, whence=0): + if whence not in (0, 1, 2): + # Mimic io's error for invalid whence values + raise ValueError(f"invalid whence ({whence}, should be 0, 1 or 2)") + + # Recalculate where based on chunk attributes so seek from file + # start (whence=0) is always used + where += self._start_byte + if whence == 1: + where += self._amount_read + elif whence == 2: + where += self._size + + self._fileobj.seek(max(where, self._start_byte)) + if self._callbacks is not None and self._callbacks_enabled: + # To also rewind the callback() for an accurate progress report + bounded_where = max(min(where - self._start_byte, self._size), 0) + bounded_amount_read = min(self._amount_read, self._size) + 
amount = bounded_where - bounded_amount_read + invoke_progress_callbacks( + self._callbacks, bytes_transferred=amount + ) + self._amount_read = max(where - self._start_byte, 0) + + def close(self): + if self._close_callbacks is not None and self._callbacks_enabled: + for callback in self._close_callbacks: + callback() + self._fileobj.close() + + def tell(self): + return self._amount_read + + def __len__(self): + # __len__ is defined because requests will try to determine the length + # of the stream to set a content length. In the normal case + # of the file it will just stat the file, but we need to change that + # behavior. By providing a __len__, requests will use that instead + # of stat'ing the file. + return self._size + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + self.close() + + def __iter__(self): + # This is a workaround for http://bugs.python.org/issue17575 + # Basically httplib will try to iterate over the contents, even + # if its a file like object. This wasn't noticed because we've + # already exhausted the stream so iterating over the file immediately + # stops, which is what we're simulating here. + return iter([]) + + +class StreamReaderProgress: + """Wrapper for a read only stream that adds progress callbacks.""" + + def __init__(self, stream, callbacks=None): + self._stream = stream + self._callbacks = callbacks + if callbacks is None: + self._callbacks = [] + + def read(self, *args, **kwargs): + value = self._stream.read(*args, **kwargs) + invoke_progress_callbacks(self._callbacks, len(value)) + return value + + +class NoResourcesAvailable(Exception): + pass + + +class TaskSemaphore: + def __init__(self, count): + """A semaphore for the purpose of limiting the number of tasks + + :param count: The size of semaphore + """ + self._semaphore = threading.Semaphore(count) + + def acquire(self, tag, blocking=True): + """Acquire the semaphore + + :param tag: A tag identifying what is acquiring the semaphore. 
Note + that this is not really needed to directly use this class but is + needed for API compatibility with the SlidingWindowSemaphore + implementation. + :param blocking: If True, block until it can be acquired. If False, + do not block and raise an exception if it cannot be acquired. + + :returns: A token (can be None) to use when releasing the semaphore + """ + logger.debug("Acquiring %s", tag) + if not self._semaphore.acquire(blocking): + raise NoResourcesAvailable("Cannot acquire tag '%s'" % tag) + + def release(self, tag, acquire_token): + """Release the semaphore + + :param tag: A tag identifying what is releasing the semaphore + :param acquire_token: The token returned from when the semaphore was + acquired. Note that this is not really needed to directly use this + class but is needed for API compatibility with the + SlidingWindowSemaphore implementation. + """ + logger.debug(f"Releasing acquire {tag}/{acquire_token}") + self._semaphore.release() + + +class SlidingWindowSemaphore(TaskSemaphore): + """A semaphore used to coordinate sequential resource access. + + This class is similar to the stdlib BoundedSemaphore: + + * It's initialized with a count. + * Each call to ``acquire()`` decrements the counter. + * If the count is at zero, then ``acquire()`` will either block until the + count increases, or if ``blocking=False``, then it will raise + a NoResourcesAvailable exception indicating that it failed to acquire the + semaphore. + + The main difference is that this semaphore is used to limit + access to a resource that requires sequential access. For example, + if I want to access resource R that has 20 subresources R_0 - R_19, + this semaphore can also enforce that you only have a max range of + 10 at any given point in time. You must also specify a tag name + when you acquire the semaphore. The sliding window semantics apply + on a per tag basis. The internal count will only be incremented + when the minimum sequence number for a tag is released.
def release(self, tag, acquire_token):
    """Release a sequence number previously handed out by ``acquire()``.

    Method of ``SlidingWindowSemaphore``. A slot only becomes available
    again once the lowest outstanding sequence number of ``tag`` has been
    released. Releases of higher sequence numbers are parked in
    ``self._pending_release`` and are drained as soon as the window can
    slide past them.

    :param tag: The tag that was used to acquire the semaphore.
    :param acquire_token: The sequence number returned by ``acquire()``.

    :raises ValueError: If ``tag`` was never acquired, or if the sequence
        number is neither the lowest outstanding one nor an outstanding
        higher one for that tag.
    """
    sequence_number = acquire_token
    logger.debug("Releasing acquire %s/%s", tag, sequence_number)
    with self._condition:
        if tag not in self._tag_sequences:
            raise ValueError("Attempted to release unknown tag: %s" % tag)
        max_sequence = self._tag_sequences[tag]
        lowest = self._lowest_sequence[tag]
        if lowest == sequence_number:
            # The lowest outstanding sequence is being released, so the
            # window can slide forward immediately.
            lowest += 1
            freed = 1
            # The pending list is kept sorted in descending order, so the
            # next candidate for release is always the last element.
            queued = self._pending_release.get(tag, [])
            while queued and queued[-1] == lowest:
                queued.pop()
                lowest += 1
                freed += 1
            self._lowest_sequence[tag] = lowest
            self._count += freed
            # Wake one waiter per freed slot. Waking a single waiter while
            # several slots became free would leave the remaining waiters
            # blocked even though the count is non-zero.
            self._condition.notify(freed)
        elif lowest < sequence_number < max_sequence:
            # We can't do anything right now because we're still waiting
            # for the min sequence for the tag to be released. We have
            # to queue this for pending release.
            pending = self._pending_release.setdefault(tag, [])
            pending.append(sequence_number)
            pending.sort(reverse=True)
        else:
            raise ValueError(
                "Attempted to release unknown sequence number "
                "%s for tag: %s" % (sequence_number, tag)
            )
def _adjust_for_max_parts(self, current_chunksize, file_size):
    """Double the chunksize until the part count fits ``self.max_parts``.

    Method of ``ChunksizeAdjuster``.

    :param current_chunksize: The chunksize to start from.
    :param file_size: The size of the object being transferred.

    :returns: ``current_chunksize`` if it already yields at most
        ``self.max_parts`` parts, otherwise the first repeated doubling of
        it that does.
    """

    def part_count(size):
        # Number of parts needed to cover file_size with chunks of ``size``.
        return int(math.ceil(file_size / float(size)))

    adjusted = current_chunksize
    while part_count(adjusted) > self.max_parts:
        adjusted *= 2

    if adjusted != current_chunksize:
        logger.debug(
            "Chunksize would result in the number of parts exceeding the "
            "maximum. Setting to %s from %s."
            % (adjusted, current_chunksize)
        )

    return adjusted
def md5_checksum(filename):
    """Return the hexadecimal MD5 digest of the file at ``filename``.

    The file is opened in binary mode and read in 8192-byte blocks, so the
    whole file is never held in memory at once.
    """
    digest = hashlib.md5()
    with open(filename, 'rb') as stream:
        while True:
            block = stream.read(8192)
            if not block:
                # An empty read marks the end of the file.
                break
            digest.update(block)
    return digest.hexdigest()
class StreamWithError:
    """A stream wrapper that starts failing after a number of reads.

    :param stream: The underlying stream to read from
    :param exception_type: The exception type to throw
    :param num_reads: How many reads are served from the underlying stream
        before the exception is raised. Zero means the very first read
        raises.
    """

    def __init__(self, stream, exception_type, num_reads=0):
        self._stream = stream
        self._exception_type = exception_type
        self._num_reads = num_reads
        # Number of reads served successfully so far.
        self._count = 0

    def read(self, n=-1):
        """Read from the wrapped stream, or raise once the limit is hit.

        The counter is not advanced when raising, so every read after the
        limit raises as well.
        """
        served = self._count
        if served != self._num_reads:
            self._count = served + 1
            return self._stream.read(n)
        raise self._exception_type
def create_file_with_size(self, filename, filesize):
    """Create a file of exactly ``filesize`` bytes in the tmpdir.

    Method of ``FileCreator``.

    ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
    It will be translated into a full path in a tmp dir.
    The file is filled with the byte ``b'a'``. A ``filesize`` of zero or
    less produces an empty file.

    Returns the full path to the file.
    """
    # create_file() resolves the full path and creates any missing parent
    # directories; the contents are written below.
    filename = self.create_file(filename, contents='')
    chunksize = 8192
    chunk = b'a' * chunksize
    remaining = max(int(filesize), 0)
    with open(filename, 'wb') as f:
        while remaining > 0:
            # The last chunk is truncated so the file is not rounded up to
            # a multiple of the chunksize.
            to_write = min(chunksize, remaining)
            f.write(chunk[:to_write])
            remaining -= to_write
    return filename
class RecordingExecutor:
    """An executor wrapper that remembers every call made to submit()

    Each successful submission is appended to the ``submissions`` list as
    a dictionary formatted::

        {
            'task': the submitted task
            'tag': the tag passed to submit() (None by default)
            'block': the block flag passed to submit() (True by default)
        }
    """

    def __init__(self, executor):
        self._executor = executor
        self.submissions = []

    def submit(self, task, tag=None, block=True):
        """Forward to the wrapped executor, then record the call."""
        # Delegate first: a submission that raises is not recorded.
        result = self._executor.submit(task, tag, block)
        record = {'task': task, 'tag': tag, 'block': block}
        self.submissions.append(record)
        return result

    def shutdown(self):
        """Shut down the wrapped executor."""
        self._executor.shutdown()
class BaseTaskTest(StubbedClientTest):
    """Base class for task tests sharing a single TransferCoordinator."""

    def setUp(self):
        super().setUp()
        self.transfer_coordinator = TransferCoordinator()

    def get_task(self, task_cls, **kwargs):
        """Instantiate ``task_cls`` with ``kwargs``.

        The test's own transfer coordinator is supplied unless the caller
        passes a ``transfer_coordinator`` explicitly.
        """
        kwargs.setdefault('transfer_coordinator', self.transfer_coordinator)
        return task_cls(**kwargs)

    def get_transfer_future(self, call_args=None):
        """Build a TransferFuture bound to the test's coordinator."""
        meta = TransferMeta(call_args)
        return TransferFuture(
            meta=meta, coordinator=self.transfer_coordinator
        )
@property
def manager(self):
    """The transfer manager to use

    Member of ``BaseGeneralInterfaceTest``. It is a property because the
    shared tests read it as an attribute (``self.manager.shutdown()``);
    subclasses must override it to return their transfer manager.
    """
    raise NotImplementedError('manager is not implemented')
def _setup_default_stubbed_responses(self):
    """Queue every default stubbed response on the stubber, in order.

    Method of ``BaseGeneralInterfaceTest``. Each dictionary returned by
    ``create_stubbed_responses()`` is passed as keyword arguments to
    ``self.stubber.add_response()``.
    """
    responses = self.create_stubbed_responses()
    for response_kwargs in responses:
        self.stubber.add_response(**response_kwargs)
def test_for_callback_kwargs_correctness(self):
    """Subscriber callbacks receive exactly the expected keyword args.

    Method of ``BaseGeneralInterfaceTest``.
    """
    # The stubbed responses must be queued before the transfer starts.
    self._setup_default_stubbed_responses()

    recorder = RecordingSubscriber()
    transfer = self.method
    future = transfer(subscribers=[recorder], **self.create_call_kwargs())
    # Shut the manager down rather than waiting on the future: the future
    # can be finished while its done callback is still running, and the
    # manager's shutdown waits for everything to complete.
    self.manager.shutdown()

    # Every progress call is expected to carry the future as well.
    expected_progress = self.create_expected_progress_callback_info()
    for progress_kwargs in expected_progress:
        progress_kwargs.update(future=future)

    self.assertEqual(recorder.on_queued_calls, [{'future': future}])
    self.assertEqual(recorder.on_progress_calls, expected_progress)
    self.assertEqual(recorder.on_done_calls, [{'future': future}])
class NonSeekableWriter(io.RawIOBase):
    """A write-only, non-seekable raw stream wrapping ``fileobj``.

    Every write is forwarded to the wrapped file object; reading raises
    ``io.UnsupportedOperation``.
    """

    def __init__(self, fileobj):
        super().__init__()
        self._fileobj = fileobj

    def seekable(self):
        return False

    def writable(self):
        return True

    def readable(self):
        return False

    def write(self, b):
        """Write ``b`` to the wrapped file object.

        :returns: The number of bytes written, as reported by the wrapped
            file object, or ``len(b)`` if it reports nothing.
        """
        # A raw stream must report how many bytes it wrote: returning None
        # from RawIOBase.write() means "nothing could be written without
        # blocking", which breaks wrappers such as io.BufferedWriter.
        written = self._fileobj.write(b)
        if written is None:
            return len(b)
        return written

    def read(self, n=-1):
        raise io.UnsupportedOperation("read")
def create_stubbed_responses(self):
    """Stubbed responses for a successful non-multipart copy.

    Method of ``BaseCopyTest``. The HeadObject response reports the length
    of ``self.content``; it is followed by an empty CopyObject response.
    """
    head_object = {
        'method': 'head_object',
        'service_response': {'ContentLength': len(self.content)},
    }
    copy_object = {'method': 'copy_object', 'service_response': {}}
    return [head_object, copy_object]
def add_successful_copy_responses(
    self,
    expected_copy_params=None,
    expected_create_mpu_params=None,
    expected_complete_mpu_params=None,
):
    """Queue every response after HeadObject needed for a successful copy.

    Method of ``BaseCopyTest``. With a single remaining response that
    response is the copy itself. With more than one, the first is the
    multipart create, the last is the multipart complete, and everything in
    between is a part copy.

    :param expected_copy_params: Expected parameters for the copy
        response(s). A list is matched to the copy responses by index; any
        other value is applied to every copy response.
    :param expected_create_mpu_params: Expected parameters for the first
        response.
    :param expected_complete_mpu_params: Expected parameters for the last
        response.
    """
    # Skip the leading HeadObject stub.
    responses = self.create_stubbed_responses()[1:]

    if len(responses) > 1:
        part_responses = responses[1:-1]
    else:
        part_responses = responses[0:1]

    if expected_create_mpu_params:
        responses[0]['expected_params'] = expected_create_mpu_params

    if expected_copy_params:
        per_part = isinstance(expected_copy_params, list)
        for index, part_response in enumerate(part_responses):
            part_response['expected_params'] = (
                expected_copy_params[index]
                if per_part
                else expected_copy_params
            )

    if expected_complete_mpu_params:
        responses[-1]['expected_params'] = expected_complete_mpu_params

    for response in responses:
        self.stubber.add_response(**response)
def test_provide_copy_source_as_dict(self):
    """A VersionId in the copy source is forwarded to HeadObject.

    Method of ``BaseCopyTest``.
    """
    self.copy_source.update(VersionId='mysourceversionid')
    head_params = {
        'Bucket': 'mysourcebucket',
        'Key': 'mysourcekey',
        'VersionId': 'mysourceversionid',
    }

    self.add_head_object_response(expected_params=head_params)
    self.add_successful_copy_responses()

    copy = self.manager.copy
    transfer = copy(**self.create_call_kwargs())
    transfer.result()
    # Every queued response must have been consumed by the transfer.
    self.stubber.assert_no_pending_responses()
def test_copy(self):
    """A plain copy issues HeadObject and CopyObject with the right params.

    Method of ``TestNonMultipartCopy``.
    """
    head_params = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'}
    copy_params = {
        'Bucket': self.bucket,
        'Key': self.key,
        'CopySource': self.copy_source,
    }
    self.add_head_object_response(expected_params=head_params)
    self.add_successful_copy_responses(expected_copy_params=copy_params)

    copy = self.manager.copy
    transfer = copy(**self.create_call_kwargs())
    transfer.result()
    # Every queued response must have been consumed by the transfer.
    self.stubber.assert_no_pending_responses()
def test_copy_with_tagging(self):
    """Tagging extra args are passed through to the CopyObject call.

    Method of ``TestNonMultipartCopy``.
    """
    extra_args = {'Tagging': 'tag1=val1', 'TaggingDirective': 'REPLACE'}
    # The copy request is expected to carry the base parameters plus the
    # tagging arguments unchanged.
    expected_copy_params = {
        'Bucket': self.bucket,
        'Key': self.key,
        'CopySource': self.copy_source,
    }
    expected_copy_params.update(extra_args)

    self.add_head_object_response()
    self.add_successful_copy_responses(
        expected_copy_params=expected_copy_params
    )
    transfer = self.manager.copy(
        self.copy_source, self.bucket, self.key, extra_args
    )
    transfer.result()
    self.stubber.assert_no_pending_responses()
def create_expected_progress_callback_info(self):
    """Expected progress callback kwargs for the multipart copy.

    Method of ``TestMultipartCopy``. Two progress calls of
    ``MIN_UPLOAD_CHUNKSIZE`` bytes are expected, followed by one of
    ``self.half_chunksize`` bytes.
    """
    # A fresh dictionary per call: the shared tests add a 'future' key to
    # each returned entry.
    expected_sizes = (
        MIN_UPLOAD_CHUNKSIZE,
        MIN_UPLOAD_CHUNKSIZE,
        self.half_chunksize,
    )
    return [{'bytes_transferred': size} for size in expected_sizes]
def _add_params_to_expected_params(
    self, add_copy_kwargs, operation_types, new_params
):
    """Merge ``new_params`` into the expected params of chosen operations.

    Each entry of ``operation_types`` selects the
    ``'expected_<operation_type>_params'`` key of ``add_copy_kwargs``. The
    value stored there may be a single dict or a list of dicts; every dict
    found is updated in place with ``new_params``.
    """
    # Resolve every target before mutating anything, so that a missing key
    # raises before any expectation has been modified.
    targets = []
    for op_name in operation_types:
        entry = add_copy_kwargs['expected_' + op_name + '_params']
        targets += entry if isinstance(entry, list) else [entry]

    for params in targets:
        params.update(new_params)
def test_copy_passes_args_to_create_multipart_and_upload_part(self):
    # This argument is expected on the create multipart upload and on
    # every upload part copy request. It is not added to the expected
    # head object or complete multipart upload parameters.
    self.extra_args['SSECustomerAlgorithm'] = 'AES256'

    head_params, add_copy_kwargs = self._get_expected_params()
    self.add_head_object_response(expected_params=head_params)

    # Only the 'create_mpu' and 'copy' expectations receive the extra
    # argument, matching the name of this test.
    self._add_params_to_expected_params(
        add_copy_kwargs, ['create_mpu', 'copy'], self.extra_args
    )
    self.add_successful_copy_responses(**add_copy_kwargs)

    call_kwargs = self.create_call_kwargs()
    call_kwargs['extra_args'] = self.extra_args
    future = self.manager.copy(**call_kwargs)
    future.result()
    self.stubber.assert_no_pending_responses()
def test_mp_copy_with_tagging_directive(self):
    # Both Tagging and TaggingDirective are supplied by the caller, but the
    # create multipart upload request is expected to carry only Tagging.
    tagging = 'tag1=val1'
    extra_args = {'Tagging': tagging, 'TaggingDirective': 'REPLACE'}
    self.add_head_object_response()
    expected_create_mpu_params = {
        'Bucket': self.bucket,
        'Key': self.key,
        'Tagging': 'tag1=val1',
    }
    self.add_successful_copy_responses(
        expected_create_mpu_params=expected_create_mpu_params
    )
    copy_future = self.manager.copy(
        self.copy_source, self.bucket, self.key, extra_args
    )
    copy_future.result()
    self.stubber.assert_no_pending_responses()
class submitThread(threading.Thread):
    """Thread that submits one download and records the returned future.

    ``run`` calls ``transfer_manager.download(*callargs)`` and appends the
    result to the shared ``futures`` list.
    """

    def __init__(self, transfer_manager, futures, callargs):
        super().__init__()
        self._callargs = callargs
        self._futures = futures
        self._transfer_manager = transfer_manager

    def run(self):
        submitted = self._transfer_manager.download(*self._callargs)
        self._futures.append(submitted)
def _assert_subscribers_called(self, expected_future=None):
    """Assert the recording subscriber saw both lifecycle callbacks.

    :param expected_future: When not ``None``, the exact future object that
        must have been passed to both ``on_queued`` and ``on_done``.
    """
    subscriber = self.record_subscriber
    self.assertTrue(subscriber.on_queued_called)
    self.assertTrue(subscriber.on_done_called)
    # Compare against None explicitly: a future-like object that happens to
    # be falsy must still be identity-checked rather than silently skipped.
    if expected_future is not None:
        self.assertIs(subscriber.on_queued_future, expected_future)
        self.assertIs(subscriber.on_done_future, expected_future)
def test_blocks_when_max_requests_processes_reached(self):
    # Start one more submitting thread than the request limit, wait until
    # the limit has been reached, check that no extra request was made,
    # then complete one request and verify every submission went through.
    futures = []
    callargs = (self.bucket, self.key, self.filename, {}, [])
    max_request_processes = 128  # the hard coded max processes
    all_concurrent = max_request_processes + 1
    wait_timeout = 60  # seconds; generous upper bound for slow machines
    threads = []
    for _ in range(all_concurrent):
        thread = submitThread(self.transfer_manager, futures, callargs)
        # Daemon threads cannot keep the interpreter alive if this test
        # fails while a submission is still blocked.
        thread.daemon = True
        thread.start()
        threads.append(thread)
    # Sleep until the expected max requests has been reached, but give up
    # after a deadline instead of hanging the whole test run forever.
    deadline = time.monotonic() + wait_timeout
    while len(futures) < max_request_processes:
        if time.monotonic() > deadline:
            self.fail(
                'Timed out waiting for %d submissions; only %d completed'
                % (max_request_processes, len(futures))
            )
        time.sleep(0.05)
    self.assertLessEqual(
        self.s3_crt_client.make_request.call_count, max_request_processes
    )
    # Release lock
    callargs = self.s3_crt_client.make_request.call_args
    callargs_kwargs = callargs[1]
    on_done = callargs_kwargs["on_done"]
    on_done(error=None)
    for thread in threads:
        thread.join()
    self.assertEqual(
        self.s3_crt_client.make_request.call_count, all_concurrent
    )
class TestDeleteObject(BaseGeneralInterfaceTest):
    """General-interface checks for ``TransferManager.delete``."""

    __test__ = True

    def setUp(self):
        super().setUp()
        self.bucket, self.key = 'mybucket', 'mykey'
        self.manager = TransferManager(self.client)

    @property
    def method(self):
        """The transfer manager method under test (``delete``)."""
        return self.manager.delete

    def create_call_kwargs(self):
        """Keyword arguments passed to the transfer manager method."""
        call_kwargs = {'bucket': self.bucket, 'key': self.key}
        return call_kwargs

    def create_invalid_extra_args(self):
        """Extra arguments containing a key named ``BadKwargs``."""
        invalid_extra_args = {'BadKwargs': True}
        return invalid_extra_args

    def create_stubbed_responses(self):
        """Stubbed responses that let the delete request succeed.

        Each element is a dictionary used as keyword arguments to
        botocore.Stubber.add_response(); a delete needs a single
        ``delete_object`` response.
        """
        delete_response = {
            'method': 'delete_object',
            'service_response': {},
            'expected_params': {'Bucket': self.bucket, 'Key': self.key},
        }
        return [delete_response]

    def create_expected_progress_callback_info(self):
        """No progress callbacks are expected; the list is empty."""
        return []

    def test_known_allowed_args_in_input_shape(self):
        # Every allowed delete argument must be a DeleteObject input member.
        service_model = self.client.meta.service_model
        delete_object_model = service_model.operation_model('DeleteObject')
        for delete_arg in self.manager.ALLOWED_DELETE_ARGS:
            self.assertIn(delete_arg, delete_object_model.input_shape.members)

    def test_raise_exception_on_s3_object_lambda_resource(self):
        # Passing an S3 Object Lambda ARN as the bucket must be rejected.
        s3_object_lambda_arn = (
            'arn:aws:s3-object-lambda:us-west-2:123456789012:'
            'accesspoint:my-accesspoint'
        )
        with self.assertRaisesRegex(ValueError, 'methods do not support'):
            self.manager.delete(s3_object_lambda_arn, self.key)
def setUp(self):
    super().setUp()
    # Transfer manager limited to a single concurrent request.
    single_request_config = TransferConfig(max_request_concurrency=1)
    self.config = single_request_config
    self._manager = TransferManager(self.client, single_request_config)

    # Scratch directory and the path the download is written to.
    scratch_dir = tempfile.mkdtemp()
    self.tempdir = scratch_dir
    self.filename = os.path.join(scratch_dir, 'myfile')

    # Default call arguments shared by the tests.
    self.bucket, self.key = 'mybucket', 'mykey'
    self.extra_args = {}
    self.subscribers = []

    # Payload served by the stubbed get_object responses.
    payload = b'my content'
    self.content = payload
    self.stream = BytesIO(payload)
def test_download_temporary_file_does_not_exist(self):
    self.add_head_object_response()
    self.add_successful_get_object_responses()

    future = self.manager.download(**self.create_call_kwargs())
    future.result()
    # Make sure the file exists
    self.assertTrue(os.path.exists(self.filename))
    # Make sure the random temporary file does not exist. The pattern has
    # to be "<filename><extsep>*"; without the parentheses ``%`` binds
    # tighter than ``+`` and yields "<filename>*<extsep>", which only
    # matches names ending in the separator and makes this check vacuous.
    # NOTE(review): assumes temporary files are named
    # "<filename><extsep><suffix>" -- confirm against the download code.
    temp_file_pattern = '%s*' % (self.filename + os.extsep)
    possible_matches = glob.glob(temp_file_pattern)
    self.assertEqual(possible_matches, [])
def test_download_cleanup_on_failure(self):
    self.add_head_object_response()

    # The stubbed get_object call fails with a client error.
    self.stubber.add_client_error('get_object')

    download_kwargs = self.create_call_kwargs()
    future = self.manager.download(**download_kwargs)

    self.assertRaises(ClientError, future.result)
    # Neither the target file nor any file sharing its name as a prefix
    # (such as a temporary file) may be left behind.
    leftovers = glob.glob('%s*' % self.filename)
    self.assertEqual(leftovers, [])
def test_retry_failure(self):
    self.add_head_object_response()

    # Rebuild the manager so it uses the updated number of attempts.
    attempts = 3
    self.config.num_download_attempts = attempts
    self._manager = TransferManager(self.client, self.config)
    # Queue exactly as many retryable failures as there are attempts.
    self.add_n_retryable_get_object_responses(attempts)

    download_kwargs = self.create_call_kwargs()
    future = self.manager.download(**download_kwargs)

    # Exhausting the attempts must surface as RetriesExceededError.
    self.assertRaises(RetriesExceededError, future.result)

    # Every queued failure response must have been consumed.
    self.stubber.assert_no_pending_responses()
def test_can_provide_file_size(self):
    # Only get_object responses are stubbed; no head_object response is
    # added because the size is supplied by a subscriber instead.
    self.add_successful_get_object_responses()

    call_kwargs = dict(
        self.create_call_kwargs(),
        subscribers=[FileSizeProvider(len(self.content))],
    )

    download_future = self.manager.download(**call_kwargs)
    download_future.result()

    # All stubbed responses are consumed and the downloaded file holds
    # the expected content.
    self.stubber.assert_no_pending_responses()
    with open(self.filename, 'rb') as downloaded:
        self.assertEqual(self.content, downloaded.read())
@skip_if_windows('Windows does not support UNIX special files')
@skip_if_using_serial_implementation(
    'A separate thread is needed to read from the fifo'
)
def test_download_for_fifo_file(self):
    self.add_head_object_response()
    self.add_successful_get_object_responses()

    # The download target is a named pipe rather than a regular file.
    fifo_path = self.filename
    os.mkfifo(fifo_path)

    future = self.manager.download(
        self.bucket, self.key, fifo_path, self.extra_args
    )

    # Opening a fifo blocks until both a reader and a writer exist, so the
    # read side is opened only after the transfer has been started.
    with open(fifo_path, 'rb') as fifo:
        future.result()
        received = fifo.read()
        self.assertEqual(received, self.content)
def test_allowed_copy_params_are_valid(self):
    # Despite its name, this checks the download arguments: every entry of
    # ALLOWED_DOWNLOAD_ARGS has to be a member of the GetObject input shape.
    service_model = self.client.meta.service_model
    get_object_model = service_model.operation_model('GetObject')
    for download_arg in self._manager.ALLOWED_DOWNLOAD_ARGS:
        self.assertIn(download_arg, get_object_model.input_shape.members)
class TestRangedDownload(BaseDownloadTest):
    # TODO: Any test class that subclasses BaseDownloadTest has to set
    # ``__test__ = True`` itself, otherwise the test runner will not pick
    # its tests up. This stays necessary until there is a better way to
    # keep the general base class's test cases from being run.
    __test__ = True

    def setUp(self):
        super().setUp()
        # A 1-byte multipart threshold with a 4-byte chunk size; the tests
        # below expect the body to be fetched as three ranged requests.
        ranged_config = TransferConfig(
            max_request_concurrency=1,
            multipart_threshold=1,
            multipart_chunksize=4,
        )
        self.config = ranged_config
        self._manager = TransferManager(self.client, ranged_config)

    def create_stubbed_responses(self):
        # One head_object response followed by one get_object response per
        # slice of the content ([0:4], [4:8] and [8:]).
        responses = [
            {
                'method': 'head_object',
                'service_response': {'ContentLength': len(self.content)},
            }
        ]
        for piece in (
            self.content[0:4],
            self.content[4:8],
            self.content[8:],
        ):
            responses.append(
                {
                    'method': 'get_object',
                    'service_response': {'Body': BytesIO(piece)},
                }
            )
        return responses

    def create_expected_progress_callback_info(self):
        # One progress callback per stubbed get_object body: 4, 4 and 2
        # bytes.
        return [{'bytes_transferred': size} for size in (4, 4, 2)]

    def test_download(self):
        self.extra_args['RequestPayer'] = 'requester'
        expected_params = {
            'Bucket': self.bucket,
            'Key': self.key,
            'RequestPayer': 'requester',
        }
        # The last range is open-ended.
        expected_ranges = ['bytes=0-3', 'bytes=4-7', 'bytes=8-']
        self.add_head_object_response(expected_params)
        self.add_successful_get_object_responses(
            expected_params, expected_ranges
        )

        download_future = self.manager.download(
            self.bucket, self.key, self.filename, self.extra_args
        )
        download_future.result()

        # The file on disk must hold the complete content.
        with open(self.filename, 'rb') as downloaded:
            self.assertEqual(self.content, downloaded.read())
+from io import BytesIO + +from botocore.awsrequest import create_request_object + +from s3transfer.exceptions import CancelledError, FatalError +from s3transfer.futures import BaseExecutor +from s3transfer.manager import TransferConfig, TransferManager +from __tests__ import StubbedClientTest, mock, skip_if_using_serial_implementation + + +class ArbitraryException(Exception): + pass + + +class SignalTransferringBody(BytesIO): + """A mocked body with the ability to signal when transfers occur""" + + def __init__(self): + super().__init__() + self.signal_transferring_call_count = 0 + self.signal_not_transferring_call_count = 0 + + def signal_transferring(self): + self.signal_transferring_call_count += 1 + + def signal_not_transferring(self): + self.signal_not_transferring_call_count += 1 + + def seek(self, where, whence=0): + pass + + def tell(self): + return 0 + + def read(self, amount=0): + return b'' + + +class TestTransferManager(StubbedClientTest): + @skip_if_using_serial_implementation( + 'Exception is thrown once all transfers are submitted. ' + 'However for the serial implementation, transfers are performed ' + 'in main thread meaning all transfers will complete before the ' + 'exception being thrown.' + ) + def test_error_in_context_manager_cancels_incomplete_transfers(self): + # The purpose of this test is to make sure if an error is raised + # in the body of the context manager, incomplete transfers will + # be cancelled with value of the exception wrapped by a CancelledError + + # NOTE: The fact that delete() was chosen to test this is arbitrary + # other than it is the easiet to set up for the stubber. + # The specific operation is not important to the purpose of this test. 
+ num_transfers = 100 + futures = [] + ref_exception_msg = 'arbitrary exception' + + for _ in range(num_transfers): + self.stubber.add_response('delete_object', {}) + + manager = TransferManager( + self.client, + TransferConfig( + max_request_concurrency=1, max_submission_concurrency=1 + ), + ) + try: + with manager: + for i in range(num_transfers): + futures.append(manager.delete('mybucket', 'mykey')) + raise ArbitraryException(ref_exception_msg) + except ArbitraryException: + # At least one of the submitted futures should have been + # cancelled. + with self.assertRaisesRegex(FatalError, ref_exception_msg): + for future in futures: + future.result() + + @skip_if_using_serial_implementation( + 'Exception is thrown once all transfers are submitted. ' + 'However for the serial implementation, transfers are performed ' + 'in main thread meaning all transfers will complete before the ' + 'exception being thrown.' + ) + def test_cntrl_c_in_context_manager_cancels_incomplete_transfers(self): + # The purpose of this test is to make sure if an error is raised + # in the body of the context manager, incomplete transfers will + # be cancelled with value of the exception wrapped by a CancelledError + + # NOTE: The fact that delete() was chosen to test this is arbitrary + # other than it is the easiet to set up for the stubber. + # The specific operation is not important to the purpose of this test. + num_transfers = 100 + futures = [] + + for _ in range(num_transfers): + self.stubber.add_response('delete_object', {}) + + manager = TransferManager( + self.client, + TransferConfig( + max_request_concurrency=1, max_submission_concurrency=1 + ), + ) + try: + with manager: + for i in range(num_transfers): + futures.append(manager.delete('mybucket', 'mykey')) + raise KeyboardInterrupt() + except KeyboardInterrupt: + # At least one of the submitted futures should have been + # cancelled. 
+ with self.assertRaisesRegex(CancelledError, 'KeyboardInterrupt()'): + for future in futures: + future.result() + + def test_enable_disable_callbacks_only_ever_registered_once(self): + body = SignalTransferringBody() + request = create_request_object( + { + 'method': 'PUT', + 'url': 'https://s3.amazonaws.com', + 'body': body, + 'headers': {}, + 'context': {}, + } + ) + # Create two TransferManager's using the same client + TransferManager(self.client) + TransferManager(self.client) + self.client.meta.events.emit( + 'request-created.s3', request=request, operation_name='PutObject' + ) + # The client should have only have the enable/disable callback + # handlers registered once depite being used for two different + # TransferManagers. + self.assertEqual( + body.signal_transferring_call_count, + 1, + 'The enable_callback() should have only ever been registered once', + ) + self.assertEqual( + body.signal_not_transferring_call_count, + 1, + 'The disable_callback() should have only ever been registered ' + 'once', + ) + + def test_use_custom_executor_implementation(self): + mocked_executor_cls = mock.Mock(BaseExecutor) + transfer_manager = TransferManager( + self.client, executor_cls=mocked_executor_cls + ) + transfer_manager.delete('bucket', 'key') + self.assertTrue(mocked_executor_cls.return_value.submit.called) + + def test_unicode_exception_in_context_manager(self): + with self.assertRaises(ArbitraryException): + with TransferManager(self.client): + raise ArbitraryException('\u2713') + + def test_client_property(self): + manager = TransferManager(self.client) + self.assertIs(manager.client, self.client) + + def test_config_property(self): + config = TransferConfig() + manager = TransferManager(self.client, config) + self.assertIs(manager.config, config) + + def test_can_disable_bucket_validation(self): + s3_object_lambda_arn = ( + 'arn:aws:s3-object-lambda:us-west-2:123456789012:' + 'accesspoint:my-accesspoint' + ) + config = TransferConfig() + manager = 
TransferManager(self.client, config) + manager.VALIDATE_SUPPORTED_BUCKET_VALUES = False + manager.delete(s3_object_lambda_arn, 'my-key') diff --git a/contrib/python/s3transfer/py3/tests/functional/test_processpool.py b/contrib/python/s3transfer/py3/tests/functional/test_processpool.py new file mode 100644 index 0000000000..1396c919f2 --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/functional/test_processpool.py @@ -0,0 +1,281 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+import glob +import os +from io import BytesIO +from multiprocessing.managers import BaseManager + +import botocore.exceptions +import botocore.session +from botocore.stub import Stubber + +from s3transfer.exceptions import CancelledError +from s3transfer.processpool import ProcessPoolDownloader, ProcessTransferConfig +from __tests__ import FileCreator, mock, unittest + + +class StubbedClient: + def __init__(self): + self._client = botocore.session.get_session().create_client( + 's3', + 'us-west-2', + aws_access_key_id='foo', + aws_secret_access_key='bar', + ) + self._stubber = Stubber(self._client) + self._stubber.activate() + self._caught_stubber_errors = [] + + def get_object(self, **kwargs): + return self._client.get_object(**kwargs) + + def head_object(self, **kwargs): + return self._client.head_object(**kwargs) + + def add_response(self, *args, **kwargs): + self._stubber.add_response(*args, **kwargs) + + def add_client_error(self, *args, **kwargs): + self._stubber.add_client_error(*args, **kwargs) + + +class StubbedClientManager(BaseManager): + pass + + +StubbedClientManager.register('StubbedClient', StubbedClient) + + +# Ideally a Mock would be used here. However, they cannot be pickled +# for Windows. So instead we define a factory class at the module level that +# can return a stubbed client we initialized in the setUp. +class StubbedClientFactory: + def __init__(self, stubbed_client): + self._stubbed_client = stubbed_client + + def __call__(self, *args, **kwargs): + # The __call__ is defined so we can provide an instance of the + # StubbedClientFactory to mock.patch() and have the instance be + # returned when the patched class is instantiated. + return self + + def create_client(self): + return self._stubbed_client + + +class TestProcessPoolDownloader(unittest.TestCase): + def setUp(self): + # The stubbed client needs to run in a manager to be shared across + # processes and have it properly consume the stubbed response across + # processes. 
+ self.manager = StubbedClientManager() + self.manager.start() + self.stubbed_client = self.manager.StubbedClient() + self.stubbed_client_factory = StubbedClientFactory(self.stubbed_client) + + self.client_factory_patch = mock.patch( + 's3transfer.processpool.ClientFactory', self.stubbed_client_factory + ) + self.client_factory_patch.start() + self.files = FileCreator() + + self.config = ProcessTransferConfig(max_request_processes=1) + self.downloader = ProcessPoolDownloader(config=self.config) + self.bucket = 'mybucket' + self.key = 'mykey' + self.filename = self.files.full_path('filename') + self.remote_contents = b'my content' + self.stream = BytesIO(self.remote_contents) + + def tearDown(self): + self.manager.shutdown() + self.client_factory_patch.stop() + self.files.remove_all() + + def assert_contents(self, filename, expected_contents): + self.assertTrue(os.path.exists(filename)) + with open(filename, 'rb') as f: + self.assertEqual(f.read(), expected_contents) + + def test_download_file(self): + self.stubbed_client.add_response( + 'head_object', {'ContentLength': len(self.remote_contents)} + ) + self.stubbed_client.add_response('get_object', {'Body': self.stream}) + with self.downloader: + self.downloader.download_file(self.bucket, self.key, self.filename) + self.assert_contents(self.filename, self.remote_contents) + + def test_download_multiple_files(self): + self.stubbed_client.add_response('get_object', {'Body': self.stream}) + self.stubbed_client.add_response( + 'get_object', {'Body': BytesIO(self.remote_contents)} + ) + with self.downloader: + self.downloader.download_file( + self.bucket, + self.key, + self.filename, + expected_size=len(self.remote_contents), + ) + other_file = self.files.full_path('filename2') + self.downloader.download_file( + self.bucket, + self.key, + other_file, + expected_size=len(self.remote_contents), + ) + self.assert_contents(self.filename, self.remote_contents) + self.assert_contents(other_file, self.remote_contents) + + def 
test_download_file_ranged_download(self): + half_of_content_length = int(len(self.remote_contents) / 2) + self.stubbed_client.add_response( + 'head_object', {'ContentLength': len(self.remote_contents)} + ) + self.stubbed_client.add_response( + 'get_object', + {'Body': BytesIO(self.remote_contents[:half_of_content_length])}, + ) + self.stubbed_client.add_response( + 'get_object', + {'Body': BytesIO(self.remote_contents[half_of_content_length:])}, + ) + downloader = ProcessPoolDownloader( + config=ProcessTransferConfig( + multipart_chunksize=half_of_content_length, + multipart_threshold=half_of_content_length, + max_request_processes=1, + ) + ) + with downloader: + downloader.download_file(self.bucket, self.key, self.filename) + self.assert_contents(self.filename, self.remote_contents) + + def test_download_file_extra_args(self): + self.stubbed_client.add_response( + 'head_object', + {'ContentLength': len(self.remote_contents)}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'VersionId': 'versionid', + }, + ) + self.stubbed_client.add_response( + 'get_object', + {'Body': self.stream}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'VersionId': 'versionid', + }, + ) + with self.downloader: + self.downloader.download_file( + self.bucket, + self.key, + self.filename, + extra_args={'VersionId': 'versionid'}, + ) + self.assert_contents(self.filename, self.remote_contents) + + def test_download_file_expected_size(self): + self.stubbed_client.add_response('get_object', {'Body': self.stream}) + with self.downloader: + self.downloader.download_file( + self.bucket, + self.key, + self.filename, + expected_size=len(self.remote_contents), + ) + self.assert_contents(self.filename, self.remote_contents) + + def test_cleans_up_tempfile_on_failure(self): + self.stubbed_client.add_client_error('get_object', 'NoSuchKey') + with self.downloader: + self.downloader.download_file( + self.bucket, + self.key, + self.filename, + 
expected_size=len(self.remote_contents), + ) + self.assertFalse(os.path.exists(self.filename)) + # Any tempfile should have been erased as well + possible_matches = glob.glob('%s*' % self.filename + os.extsep) + self.assertEqual(possible_matches, []) + + def test_validates_extra_args(self): + with self.downloader: + with self.assertRaises(ValueError): + self.downloader.download_file( + self.bucket, + self.key, + self.filename, + extra_args={'NotSupported': 'NotSupported'}, + ) + + def test_result_with_success(self): + self.stubbed_client.add_response('get_object', {'Body': self.stream}) + with self.downloader: + future = self.downloader.download_file( + self.bucket, + self.key, + self.filename, + expected_size=len(self.remote_contents), + ) + self.assertIsNone(future.result()) + + def test_result_with_exception(self): + self.stubbed_client.add_client_error('get_object', 'NoSuchKey') + with self.downloader: + future = self.downloader.download_file( + self.bucket, + self.key, + self.filename, + expected_size=len(self.remote_contents), + ) + with self.assertRaises(botocore.exceptions.ClientError): + future.result() + + def test_result_with_cancel(self): + self.stubbed_client.add_response('get_object', {'Body': self.stream}) + with self.downloader: + future = self.downloader.download_file( + self.bucket, + self.key, + self.filename, + expected_size=len(self.remote_contents), + ) + future.cancel() + with self.assertRaises(CancelledError): + future.result() + + def test_shutdown_with_no_downloads(self): + downloader = ProcessPoolDownloader() + try: + downloader.shutdown() + except AttributeError: + self.fail( + 'The downloader should be able to be shutdown even though ' + 'the downloader was never started.' + ) + + def test_shutdown_with_no_downloads_and_ctrl_c(self): + # Special shutdown logic happens if a KeyboardInterrupt is raised in + # the context manager. However, this logic can not happen if the + # downloader was never started. 
So a KeyboardInterrupt should be + # the only exception propagated. + with self.assertRaises(KeyboardInterrupt): + with self.downloader: + raise KeyboardInterrupt() diff --git a/contrib/python/s3transfer/py3/tests/functional/test_upload.py b/contrib/python/s3transfer/py3/tests/functional/test_upload.py new file mode 100644 index 0000000000..4f294e85ad --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/functional/test_upload.py @@ -0,0 +1,538 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import os +import shutil +import tempfile +import time +from io import BytesIO + +from botocore.awsrequest import AWSRequest +from botocore.client import Config +from botocore.exceptions import ClientError +from botocore.stub import ANY + +from s3transfer.manager import TransferConfig, TransferManager +from s3transfer.utils import ChunksizeAdjuster +from __tests__ import ( + BaseGeneralInterfaceTest, + NonSeekableReader, + RecordingOSUtils, + RecordingSubscriber, + mock, +) + + +class BaseUploadTest(BaseGeneralInterfaceTest): + def setUp(self): + super().setUp() + # TODO: We do not want to use the real MIN_UPLOAD_CHUNKSIZE + # when we're adjusting parts. + # This is really wasteful and fails CI builds because self.contents + # would normally use 10MB+ of memory. + # Until there's an API to configure this, we're patching this with + # a min size of 1. 
We can't patch MIN_UPLOAD_CHUNKSIZE directly + # because it's already bound to a default value in the + # chunksize adjuster. Instead we need to patch out the + # chunksize adjuster class. + self.adjuster_patch = mock.patch( + 's3transfer.upload.ChunksizeAdjuster', + lambda: ChunksizeAdjuster(min_size=1), + ) + self.adjuster_patch.start() + self.config = TransferConfig(max_request_concurrency=1) + self._manager = TransferManager(self.client, self.config) + + # Create a temporary directory with files to read from + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, 'myfile') + self.content = b'my content' + + with open(self.filename, 'wb') as f: + f.write(self.content) + + # Initialize some default arguments + self.bucket = 'mybucket' + self.key = 'mykey' + self.extra_args = {} + self.subscribers = [] + + # A list to keep track of all of the bodies sent over the wire + # and their order. + self.sent_bodies = [] + self.client.meta.events.register( + 'before-parameter-build.s3.*', self.collect_body + ) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tempdir) + self.adjuster_patch.stop() + + def collect_body(self, params, model, **kwargs): + # A handler to simulate the reading of the body including the + # request-created event that signals to simulate the progress + # callbacks + if 'Body' in params: + # TODO: This is not ideal. 
Need to figure out a better idea of + # simulating reading of the request across the wire to trigger + # progress callbacks + request = AWSRequest( + method='PUT', + url='https://s3.amazonaws.com', + data=params['Body'], + ) + self.client.meta.events.emit( + 'request-created.s3.%s' % model.name, + request=request, + operation_name=model.name, + ) + self.sent_bodies.append(self._stream_body(params['Body'])) + + def _stream_body(self, body): + read_amt = 8 * 1024 + data = body.read(read_amt) + collected_body = data + while data: + data = body.read(read_amt) + collected_body += data + return collected_body + + @property + def manager(self): + return self._manager + + @property + def method(self): + return self.manager.upload + + def create_call_kwargs(self): + return { + 'fileobj': self.filename, + 'bucket': self.bucket, + 'key': self.key, + } + + def create_invalid_extra_args(self): + return {'Foo': 'bar'} + + def create_stubbed_responses(self): + return [{'method': 'put_object', 'service_response': {}}] + + def create_expected_progress_callback_info(self): + return [{'bytes_transferred': 10}] + + def assert_expected_client_calls_were_correct(self): + # We assert that expected client calls were made by ensuring that + # there are no more pending responses. If there are no more pending + # responses, then all stubbed responses were consumed. 
+ self.stubber.assert_no_pending_responses() + + +class TestNonMultipartUpload(BaseUploadTest): + __test__ = True + + def add_put_object_response_with_default_expected_params( + self, extra_expected_params=None + ): + expected_params = {'Body': ANY, 'Bucket': self.bucket, 'Key': self.key} + if extra_expected_params: + expected_params.update(extra_expected_params) + upload_response = self.create_stubbed_responses()[0] + upload_response['expected_params'] = expected_params + self.stubber.add_response(**upload_response) + + def assert_put_object_body_was_correct(self): + self.assertEqual(self.sent_bodies, [self.content]) + + def test_upload(self): + self.extra_args['RequestPayer'] = 'requester' + self.add_put_object_response_with_default_expected_params( + extra_expected_params={'RequestPayer': 'requester'} + ) + future = self.manager.upload( + self.filename, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assert_put_object_body_was_correct() + + def test_upload_for_fileobj(self): + self.add_put_object_response_with_default_expected_params() + with open(self.filename, 'rb') as f: + future = self.manager.upload( + f, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assert_put_object_body_was_correct() + + def test_upload_for_seekable_filelike_obj(self): + self.add_put_object_response_with_default_expected_params() + bytes_io = BytesIO(self.content) + future = self.manager.upload( + bytes_io, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assert_put_object_body_was_correct() + + def test_upload_for_seekable_filelike_obj_that_has_been_seeked(self): + self.add_put_object_response_with_default_expected_params() + bytes_io = BytesIO(self.content) + seek_pos = 5 + bytes_io.seek(seek_pos) + future = self.manager.upload( + bytes_io, self.bucket, self.key, 
self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assertEqual(b''.join(self.sent_bodies), self.content[seek_pos:]) + + def test_upload_for_non_seekable_filelike_obj(self): + self.add_put_object_response_with_default_expected_params() + body = NonSeekableReader(self.content) + future = self.manager.upload( + body, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assert_put_object_body_was_correct() + + def test_sigv4_progress_callbacks_invoked_once(self): + # Reset the client and manager to use sigv4 + self.reset_stubber_with_new_client( + {'config': Config(signature_version='s3v4')} + ) + self.client.meta.events.register( + 'before-parameter-build.s3.*', self.collect_body + ) + self._manager = TransferManager(self.client, self.config) + + # Add the stubbed response. + self.add_put_object_response_with_default_expected_params() + + subscriber = RecordingSubscriber() + future = self.manager.upload( + self.filename, self.bucket, self.key, subscribers=[subscriber] + ) + future.result() + self.assert_expected_client_calls_were_correct() + + # The amount of bytes seen should be the same as the file size + self.assertEqual(subscriber.calculate_bytes_seen(), len(self.content)) + + def test_uses_provided_osutil(self): + osutil = RecordingOSUtils() + # Use the recording os utility for the transfer manager + self._manager = TransferManager(self.client, self.config, osutil) + + self.add_put_object_response_with_default_expected_params() + + future = self.manager.upload(self.filename, self.bucket, self.key) + future.result() + + # The upload should have used the os utility. We check this by making + # sure that the recorded opens are as expected. 
+ expected_opens = [(self.filename, 'rb')] + self.assertEqual(osutil.open_records, expected_opens) + + def test_allowed_upload_params_are_valid(self): + op_model = self.client.meta.service_model.operation_model('PutObject') + for allowed_upload_arg in self._manager.ALLOWED_UPLOAD_ARGS: + self.assertIn(allowed_upload_arg, op_model.input_shape.members) + + def test_upload_with_bandwidth_limiter(self): + self.content = b'a' * 1024 * 1024 + with open(self.filename, 'wb') as f: + f.write(self.content) + self.config = TransferConfig( + max_request_concurrency=1, max_bandwidth=len(self.content) / 2 + ) + self._manager = TransferManager(self.client, self.config) + + self.add_put_object_response_with_default_expected_params() + start = time.time() + future = self.manager.upload(self.filename, self.bucket, self.key) + future.result() + # This is just a smoke test to make sure that the limiter is + # being used and not necessary its exactness. So we set the maximum + # bandwidth to len(content)/2 per sec and make sure that it is + # noticeably slower. Ideally it will take more than two seconds, but + # given tracking at the beginning of transfers are not entirely + # accurate setting at the initial start of a transfer, we give us + # some flexibility by setting the expected time to half of the + # theoretical time to take. 
+ self.assertGreaterEqual(time.time() - start, 1) + + self.assert_expected_client_calls_were_correct() + self.assert_put_object_body_was_correct() + + def test_raise_exception_on_s3_object_lambda_resource(self): + s3_object_lambda_arn = ( + 'arn:aws:s3-object-lambda:us-west-2:123456789012:' + 'accesspoint:my-accesspoint' + ) + with self.assertRaisesRegex(ValueError, 'methods do not support'): + self.manager.upload(self.filename, s3_object_lambda_arn, self.key) + + +class TestMultipartUpload(BaseUploadTest): + __test__ = True + + def setUp(self): + super().setUp() + self.chunksize = 4 + self.config = TransferConfig( + max_request_concurrency=1, + multipart_threshold=1, + multipart_chunksize=self.chunksize, + ) + self._manager = TransferManager(self.client, self.config) + self.multipart_id = 'my-upload-id' + + def create_stubbed_responses(self): + return [ + { + 'method': 'create_multipart_upload', + 'service_response': {'UploadId': self.multipart_id}, + }, + {'method': 'upload_part', 'service_response': {'ETag': 'etag-1'}}, + {'method': 'upload_part', 'service_response': {'ETag': 'etag-2'}}, + {'method': 'upload_part', 'service_response': {'ETag': 'etag-3'}}, + {'method': 'complete_multipart_upload', 'service_response': {}}, + ] + + def create_expected_progress_callback_info(self): + return [ + {'bytes_transferred': 4}, + {'bytes_transferred': 4}, + {'bytes_transferred': 2}, + ] + + def assert_upload_part_bodies_were_correct(self): + expected_contents = [] + for i in range(0, len(self.content), self.chunksize): + end_i = i + self.chunksize + if end_i > len(self.content): + expected_contents.append(self.content[i:]) + else: + expected_contents.append(self.content[i:end_i]) + self.assertEqual(self.sent_bodies, expected_contents) + + def add_create_multipart_response_with_default_expected_params( + self, extra_expected_params=None + ): + expected_params = {'Bucket': self.bucket, 'Key': self.key} + if extra_expected_params: + 
expected_params.update(extra_expected_params) + response = self.create_stubbed_responses()[0] + response['expected_params'] = expected_params + self.stubber.add_response(**response) + + def add_upload_part_responses_with_default_expected_params( + self, extra_expected_params=None + ): + num_parts = 3 + upload_part_responses = self.create_stubbed_responses()[1:-1] + for i in range(num_parts): + upload_part_response = upload_part_responses[i] + expected_params = { + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': self.multipart_id, + 'Body': ANY, + 'PartNumber': i + 1, + } + if extra_expected_params: + expected_params.update(extra_expected_params) + upload_part_response['expected_params'] = expected_params + self.stubber.add_response(**upload_part_response) + + def add_complete_multipart_response_with_default_expected_params( + self, extra_expected_params=None + ): + expected_params = { + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': self.multipart_id, + 'MultipartUpload': { + 'Parts': [ + {'ETag': 'etag-1', 'PartNumber': 1}, + {'ETag': 'etag-2', 'PartNumber': 2}, + {'ETag': 'etag-3', 'PartNumber': 3}, + ] + }, + } + if extra_expected_params: + expected_params.update(extra_expected_params) + response = self.create_stubbed_responses()[-1] + response['expected_params'] = expected_params + self.stubber.add_response(**response) + + def test_upload(self): + self.extra_args['RequestPayer'] = 'requester' + + # Add requester pays to the create multipart upload and upload parts. 
+ self.add_create_multipart_response_with_default_expected_params( + extra_expected_params={'RequestPayer': 'requester'} + ) + self.add_upload_part_responses_with_default_expected_params( + extra_expected_params={'RequestPayer': 'requester'} + ) + self.add_complete_multipart_response_with_default_expected_params( + extra_expected_params={'RequestPayer': 'requester'} + ) + + future = self.manager.upload( + self.filename, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + + def test_upload_for_fileobj(self): + self.add_create_multipart_response_with_default_expected_params() + self.add_upload_part_responses_with_default_expected_params() + self.add_complete_multipart_response_with_default_expected_params() + with open(self.filename, 'rb') as f: + future = self.manager.upload( + f, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assert_upload_part_bodies_were_correct() + + def test_upload_for_seekable_filelike_obj(self): + self.add_create_multipart_response_with_default_expected_params() + self.add_upload_part_responses_with_default_expected_params() + self.add_complete_multipart_response_with_default_expected_params() + bytes_io = BytesIO(self.content) + future = self.manager.upload( + bytes_io, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assert_upload_part_bodies_were_correct() + + def test_upload_for_seekable_filelike_obj_that_has_been_seeked(self): + self.add_create_multipart_response_with_default_expected_params() + self.add_upload_part_responses_with_default_expected_params() + self.add_complete_multipart_response_with_default_expected_params() + bytes_io = BytesIO(self.content) + seek_pos = 1 + bytes_io.seek(seek_pos) + future = self.manager.upload( + bytes_io, self.bucket, self.key, self.extra_args + ) + future.result() + 
self.assert_expected_client_calls_were_correct() + self.assertEqual(b''.join(self.sent_bodies), self.content[seek_pos:]) + + def test_upload_for_non_seekable_filelike_obj(self): + self.add_create_multipart_response_with_default_expected_params() + self.add_upload_part_responses_with_default_expected_params() + self.add_complete_multipart_response_with_default_expected_params() + stream = NonSeekableReader(self.content) + future = self.manager.upload( + stream, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assert_upload_part_bodies_were_correct() + + def test_limits_in_memory_chunks_for_fileobj(self): + # Limit the maximum in memory chunks to one but make number of + # threads more than one. This means that the upload will have to + # happen sequentially despite having many threads available because + # data is sequentially partitioned into chunks in memory and since + # there can only every be one in memory chunk, each upload part will + # have to happen one at a time. + self.config.max_request_concurrency = 10 + self.config.max_in_memory_upload_chunks = 1 + self._manager = TransferManager(self.client, self.config) + + # Add some default stubbed responses. + # These responses are added in order of part number so if the + # multipart upload is not done sequentially, which it should because + # we limit the in memory upload chunks to one, the stubber will + # raise exceptions for mismatching parameters for partNumber when + # once the upload() method is called on the transfer manager. 
+ # If there is a mismatch, the stubber error will propagate on + # the future.result() + self.add_create_multipart_response_with_default_expected_params() + self.add_upload_part_responses_with_default_expected_params() + self.add_complete_multipart_response_with_default_expected_params() + with open(self.filename, 'rb') as f: + future = self.manager.upload( + f, self.bucket, self.key, self.extra_args + ) + future.result() + + # Make sure that the stubber had all of its stubbed responses consumed. + self.assert_expected_client_calls_were_correct() + # Ensure the contents were uploaded in sequentially order by checking + # the sent contents were in order. + self.assert_upload_part_bodies_were_correct() + + def test_upload_failure_invokes_abort(self): + self.stubber.add_response( + method='create_multipart_upload', + service_response={'UploadId': self.multipart_id}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + self.stubber.add_response( + method='upload_part', + service_response={'ETag': 'etag-1'}, + expected_params={ + 'Bucket': self.bucket, + 'Body': ANY, + 'Key': self.key, + 'UploadId': self.multipart_id, + 'PartNumber': 1, + }, + ) + # With the upload part failing this should immediately initiate + # an abort multipart with no more upload parts called. + self.stubber.add_client_error(method='upload_part') + + self.stubber.add_response( + method='abort_multipart_upload', + service_response={}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': self.multipart_id, + }, + ) + + future = self.manager.upload(self.filename, self.bucket, self.key) + # The exception should get propagated to the future and not be + # a cancelled error or something. 
+ with self.assertRaises(ClientError): + future.result() + self.assert_expected_client_calls_were_correct() + + def test_upload_passes_select_extra_args(self): + self.extra_args['Metadata'] = {'foo': 'bar'} + + # Add metadata to expected create multipart upload call + self.add_create_multipart_response_with_default_expected_params( + extra_expected_params={'Metadata': {'foo': 'bar'}} + ) + self.add_upload_part_responses_with_default_expected_params() + self.add_complete_multipart_response_with_default_expected_params() + + future = self.manager.upload( + self.filename, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() diff --git a/contrib/python/s3transfer/py3/tests/functional/test_utils.py b/contrib/python/s3transfer/py3/tests/functional/test_utils.py new file mode 100644 index 0000000000..fd4a232ecc --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/functional/test_utils.py @@ -0,0 +1,41 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
import os
import shutil
import socket
import tempfile

from s3transfer.utils import OSUtils
from __tests__ import skip_if_windows, unittest


@skip_if_windows('Windows does not support UNIX special files')
class TestOSUtilsSpecialFiles(unittest.TestCase):
    """Check that ``OSUtils.is_special_file`` reports UNIX special files."""

    def setUp(self):
        # Each test gets a private directory in which it may create the
        # special file it needs.
        self.tempdir = tempfile.mkdtemp()
        self.filename = os.path.join(self.tempdir, 'myfile')

    def tearDown(self):
        shutil.rmtree(self.tempdir)

    def test_character_device(self):
        self.assertTrue(OSUtils().is_special_file('/dev/null'))

    def test_fifo(self):
        os.mkfifo(self.filename)
        self.assertTrue(OSUtils().is_special_file(self.filename))

    def test_socket(self):
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        # Close the descriptor when the test ends; previously it was leaked
        # (and reported as a ResourceWarning) because nothing closed it.
        self.addCleanup(sock.close)
        sock.bind(self.filename)
        self.assertTrue(OSUtils().is_special_file(self.filename))


# diff --git a/contrib/python/s3transfer/py3/tests/unit/__init__.py b/contrib/python/s3transfer/py3/tests/unit/__init__.py
# new file mode 100644
# index 0000000000..79ef91c6a2
# --- /dev/null
# +++ b/contrib/python/s3transfer/py3/tests/unit/__init__.py
# @@ -0,0 +1,12 @@
# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
# diff --git a/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py b/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py
# new file mode 100644
# index 0000000000..b796f8f24c
# --- /dev/null
# +++ b/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py
# @@ -0,0 +1,452 @@
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import os
import shutil
import tempfile

from s3transfer.bandwidth import (
    BandwidthLimitedStream,
    BandwidthLimiter,
    BandwidthRateTracker,
    ConsumptionScheduler,
    LeakyBucket,
    RequestExceededException,
    RequestToken,
    TimeUtils,
)
from s3transfer.futures import TransferCoordinator
from __tests__ import mock, unittest


class FixedIncrementalTickTimeUtils(TimeUtils):
    """A ``TimeUtils`` whose clock advances a fixed step on every read.

    :param seconds_per_tick: Amount added to the clock after each
        ``time()`` call.
    """

    def __init__(self, seconds_per_tick=1.0):
        self._count = 0
        self._seconds_per_tick = seconds_per_tick

    def time(self):
        # Return the value *before* advancing, so the first call yields 0.
        current_count = self._count
        self._count += self._seconds_per_tick
        return current_count


class TestTimeUtils(unittest.TestCase):
    """``TimeUtils`` should delegate to the ``time`` module."""

    @mock.patch('time.time')
    def test_time(self, mock_time):
        mock_return_val = 1
        mock_time.return_value = mock_return_val
        time_utils = TimeUtils()
        self.assertEqual(time_utils.time(), mock_return_val)

    @mock.patch('time.sleep')
    def test_sleep(self, mock_sleep):
        time_utils = TimeUtils()
        time_utils.sleep(1)
        self.assertEqual(mock_sleep.call_args_list, [mock.call(1)])


class BaseBandwidthLimitTest(unittest.TestCase):
    """Shared fixture: mocked bucket/clock and a 1 MiB file on disk."""

    def setUp(self):
        self.leaky_bucket = mock.Mock(LeakyBucket)
        self.time_utils = mock.Mock(TimeUtils)
        self.tempdir = tempfile.mkdtemp()
        self.content = b'a' * 1024 * 1024
        self.filename = os.path.join(self.tempdir, 'myfile')
        with open(self.filename, 'wb') as f:
            f.write(self.content)
        self.coordinator = TransferCoordinator()

    def tearDown(self):
        shutil.rmtree(self.tempdir)

    def assert_consume_calls(self, amts):
        """Assert ``consume()`` was called once per amount, in order."""
        expected_consume_args = [mock.call(amt, mock.ANY) for amt in amts]
        self.assertEqual(
            self.leaky_bucket.consume.call_args_list, expected_consume_args
        )


class TestBandwidthLimiter(BaseBandwidthLimitTest):
    def setUp(self):
        super().setUp()
        self.bandwidth_limiter = BandwidthLimiter(self.leaky_bucket)

    def test_get_bandwidth_limited_stream(self):
        with open(self.filename, 'rb') as f:
            # NOTE: "bandwith" is the spelling used by the API under test.
            stream = self.bandwidth_limiter.get_bandwith_limited_stream(
                f, self.coordinator
            )
            self.assertIsInstance(stream, BandwidthLimitedStream)
            self.assertEqual(stream.read(len(self.content)), self.content)
            self.assert_consume_calls(amts=[len(self.content)])

    def test_get_disabled_bandwidth_limited_stream(self):
        with open(self.filename, 'rb') as f:
            stream = self.bandwidth_limiter.get_bandwith_limited_stream(
                f, self.coordinator, enabled=False
            )
            self.assertIsInstance(stream, BandwidthLimitedStream)
            self.assertEqual(stream.read(len(self.content)), self.content)
            self.leaky_bucket.consume.assert_not_called()


class TestBandwidthLimitedStream(BaseBandwidthLimitTest):
    # tearDown is inherited from BaseBandwidthLimitTest; the override that
    # used to live here was an exact duplicate of it.

    def setUp(self):
        super().setUp()
        self.bytes_threshold = 1

    def get_bandwidth_limited_stream(self, f):
        """Wrap ``f`` using the fixture's mocks and current threshold."""
        return BandwidthLimitedStream(
            f,
            self.leaky_bucket,
            self.coordinator,
            self.time_utils,
            self.bytes_threshold,
        )

    def assert_sleep_calls(self, amts):
        """Assert ``sleep()`` was called once per amount, in order."""
        expected_sleep_args_list = [mock.call(amt) for amt in amts]
        self.assertEqual(
            self.time_utils.sleep.call_args_list, expected_sleep_args_list
        )

    def get_unique_consume_request_tokens(self):
        """Return the set of request tokens passed to ``consume()``."""
        return {
            call_args[0][1]
            for call_args in self.leaky_bucket.consume.call_args_list
        }

    def test_read(self):
        with open(self.filename, 'rb') as f:
            stream = self.get_bandwidth_limited_stream(f)
            data = stream.read(len(self.content))
            self.assertEqual(self.content, data)
            self.assert_consume_calls(amts=[len(self.content)])
            self.assert_sleep_calls(amts=[])

    def test_retries_on_request_exceeded(self):
        with open(self.filename, 'rb') as f:
            stream = self.get_bandwidth_limited_stream(f)
            retry_time = 1
            amt_requested = len(self.content)
            # First consume() is rejected, the retry succeeds.
            self.leaky_bucket.consume.side_effect = [
                RequestExceededException(amt_requested, retry_time),
                len(self.content),
            ]
            data = stream.read(len(self.content))
            self.assertEqual(self.content, data)
            self.assert_consume_calls(amts=[amt_requested, amt_requested])
            self.assert_sleep_calls(amts=[retry_time])

    def test_with_transfer_coordinator_exception(self):
        self.coordinator.set_exception(ValueError())
        with open(self.filename, 'rb') as f:
            stream = self.get_bandwidth_limited_stream(f)
            with self.assertRaises(ValueError):
                stream.read(len(self.content))

    def test_read_when_bandwidth_limiting_disabled(self):
        with open(self.filename, 'rb') as f:
            stream = self.get_bandwidth_limited_stream(f)
            stream.disable_bandwidth_limiting()
            data = stream.read(len(self.content))
            self.assertEqual(self.content, data)
            self.assertFalse(self.leaky_bucket.consume.called)

    def test_read_toggle_disable_enable_bandwidth_limiting(self):
        with open(self.filename, 'rb') as f:
            stream = self.get_bandwidth_limited_stream(f)
            stream.disable_bandwidth_limiting()
            data = stream.read(1)
            self.assertEqual(self.content[:1], data)
            self.assert_consume_calls(amts=[])
            stream.enable_bandwidth_limiting()
            data = stream.read(len(self.content) - 1)
            self.assertEqual(self.content[1:], data)
            self.assert_consume_calls(amts=[len(self.content) - 1])

    def test_seek(self):
        mock_fileobj = mock.Mock()
        stream = self.get_bandwidth_limited_stream(mock_fileobj)
        stream.seek(1)
        self.assertEqual(mock_fileobj.seek.call_args_list, [mock.call(1, 0)])

    def test_tell(self):
        mock_fileobj = mock.Mock()
        stream = self.get_bandwidth_limited_stream(mock_fileobj)
        stream.tell()
        self.assertEqual(mock_fileobj.tell.call_args_list, [mock.call()])

    def test_close(self):
        mock_fileobj = mock.Mock()
        stream = self.get_bandwidth_limited_stream(mock_fileobj)
        stream.close()
        self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()])

    def test_context_manager(self):
        mock_fileobj = mock.Mock()
        stream = self.get_bandwidth_limited_stream(mock_fileobj)
        with stream as stream_handle:
            self.assertIs(stream_handle, stream)
        self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()])

    def test_reuses_request_token(self):
        with open(self.filename, 'rb') as f:
            stream = self.get_bandwidth_limited_stream(f)
            stream.read(1)
            stream.read(1)
        self.assertEqual(len(self.get_unique_consume_request_tokens()), 1)

    def test_request_tokens_unique_per_stream(self):
        with open(self.filename, 'rb') as f:
            stream = self.get_bandwidth_limited_stream(f)
            stream.read(1)
        with open(self.filename, 'rb') as f:
            stream = self.get_bandwidth_limited_stream(f)
            stream.read(1)
        self.assertEqual(len(self.get_unique_consume_request_tokens()), 2)

    def test_call_consume_after_reaching_threshold(self):
        self.bytes_threshold = 2
        with open(self.filename, 'rb') as f:
            stream = self.get_bandwidth_limited_stream(f)
            self.assertEqual(stream.read(1), self.content[:1])
            self.assert_consume_calls(amts=[])
            self.assertEqual(stream.read(1), self.content[1:2])
            self.assert_consume_calls(amts=[2])

    def test_resets_after_reaching_threshold(self):
        self.bytes_threshold = 2
        with open(self.filename, 'rb') as f:
            stream = self.get_bandwidth_limited_stream(f)
            self.assertEqual(stream.read(2), self.content[:2])
            self.assert_consume_calls(amts=[2])
            self.assertEqual(stream.read(1), self.content[2:3])
            self.assert_consume_calls(amts=[2])

    def test_pending_bytes_seen_on_close(self):
        self.bytes_threshold = 2
        with open(self.filename, 'rb') as f:
            stream = self.get_bandwidth_limited_stream(f)
            self.assertEqual(stream.read(1), self.content[:1])
            self.assert_consume_calls(amts=[])
            stream.close()
            self.assert_consume_calls(amts=[1])

    def test_no_bytes_remaining_on(self):
        self.bytes_threshold = 2
        with open(self.filename, 'rb') as f:
            stream = self.get_bandwidth_limited_stream(f)
            self.assertEqual(stream.read(2), self.content[:2])
            self.assert_consume_calls(amts=[2])
            stream.close()
            # There should have been no more consume() calls made
            # as all bytes have been accounted for in the previous
            # consume() call.
            self.assert_consume_calls(amts=[2])

    def test_disable_bandwidth_limiting_with_pending_bytes_seen_on_close(self):
        self.bytes_threshold = 2
        with open(self.filename, 'rb') as f:
            stream = self.get_bandwidth_limited_stream(f)
            self.assertEqual(stream.read(1), self.content[:1])
            self.assert_consume_calls(amts=[])
            stream.disable_bandwidth_limiting()
            stream.close()
            self.assert_consume_calls(amts=[])

    def test_signal_transferring(self):
        with open(self.filename, 'rb') as f:
            stream = self.get_bandwidth_limited_stream(f)
            stream.signal_not_transferring()
            data = stream.read(1)
            self.assertEqual(self.content[:1], data)
            self.assert_consume_calls(amts=[])
            stream.signal_transferring()
            data = stream.read(len(self.content) - 1)
            self.assertEqual(self.content[1:], data)
            self.assert_consume_calls(amts=[len(self.content) - 1])


class TestLeakyBucket(unittest.TestCase):
    def setUp(self):
        self.max_rate = 1
        self.time_now = 1.0
        self.time_utils = mock.Mock(TimeUtils)
        self.time_utils.time.return_value = self.time_now
        self.scheduler = mock.Mock(ConsumptionScheduler)
        self.scheduler.is_scheduled.return_value = False
        self.rate_tracker = mock.Mock(BandwidthRateTracker)
        self.leaky_bucket = LeakyBucket(
            self.max_rate, self.time_utils, self.rate_tracker, self.scheduler
        )

    def set_projected_rate(self, rate):
        self.rate_tracker.get_projected_rate.return_value = rate

    def set_retry_time(self, retry_time):
        self.scheduler.schedule_consumption.return_value = retry_time

    def assert_recorded_consumed_amt(self, expected_amt):
        self.assertEqual(
            self.rate_tracker.record_consumption_rate.call_args,
            mock.call(expected_amt, self.time_utils.time.return_value),
        )

    def assert_was_scheduled(self, amt, token):
        self.assertEqual(
            self.scheduler.schedule_consumption.call_args,
            mock.call(amt, token, amt / (self.max_rate)),
        )

    def assert_nothing_scheduled(self):
        self.assertFalse(self.scheduler.schedule_consumption.called)

    def assert_processed_request_token(self, request_token):
        self.assertEqual(
            self.scheduler.process_scheduled_consumption.call_args,
            mock.call(request_token),
        )

    def test_consume_under_max_rate(self):
        amt = 1
        self.set_projected_rate(self.max_rate / 2)
        self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
        self.assert_recorded_consumed_amt(amt)
        self.assert_nothing_scheduled()

    def test_consume_at_max_rate(self):
        amt = 1
        self.set_projected_rate(self.max_rate)
        self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
        self.assert_recorded_consumed_amt(amt)
        self.assert_nothing_scheduled()

    def test_consume_over_max_rate(self):
        amt = 1
        retry_time = 2.0
        self.set_projected_rate(self.max_rate + 1)
        self.set_retry_time(retry_time)
        request_token = RequestToken()
        # assertRaises fails the test by itself when nothing is raised,
        # replacing the manual try / self.fail() / except construct.
        with self.assertRaises(RequestExceededException) as raised:
            self.leaky_bucket.consume(amt, request_token)
        self.assertEqual(raised.exception.requested_amt, amt)
        self.assertEqual(raised.exception.retry_time, retry_time)
        self.assert_was_scheduled(amt, request_token)

    def test_consume_with_scheduled_retry(self):
        amt = 1
        self.set_projected_rate(self.max_rate + 1)
        self.scheduler.is_scheduled.return_value = True
        request_token = RequestToken()
        self.assertEqual(self.leaky_bucket.consume(amt, request_token), amt)
        # Nothing new should have been scheduled but the request token
        # should have been processed.
        self.assert_nothing_scheduled()
        self.assert_processed_request_token(request_token)


class TestConsumptionScheduler(unittest.TestCase):
    def setUp(self):
        self.scheduler = ConsumptionScheduler()

    def test_schedule_consumption(self):
        token = RequestToken()
        consume_time = 5
        actual_wait_time = self.scheduler.schedule_consumption(
            1, token, consume_time
        )
        self.assertEqual(consume_time, actual_wait_time)

    def test_schedule_consumption_for_multiple_requests(self):
        token = RequestToken()
        consume_time = 5
        actual_wait_time = self.scheduler.schedule_consumption(
            1, token, consume_time
        )
        self.assertEqual(consume_time, actual_wait_time)

        other_consume_time = 3
        other_token = RequestToken()
        next_wait_time = self.scheduler.schedule_consumption(
            1, other_token, other_consume_time
        )

        # This wait time should be the previous time plus its desired
        # wait time
        self.assertEqual(next_wait_time, consume_time + other_consume_time)

    def test_is_scheduled(self):
        token = RequestToken()
        consume_time = 5
        self.scheduler.schedule_consumption(1, token, consume_time)
        self.assertTrue(self.scheduler.is_scheduled(token))

    def test_is_not_scheduled(self):
        self.assertFalse(self.scheduler.is_scheduled(RequestToken()))

    def test_process_scheduled_consumption(self):
        token = RequestToken()
        consume_time = 5
        self.scheduler.schedule_consumption(1, token, consume_time)
        self.scheduler.process_scheduled_consumption(token)
        self.assertFalse(self.scheduler.is_scheduled(token))
        different_time = 7
        # The previous consume time should have no effect on the next wait
        # time as it has been completed.
        self.assertEqual(
            self.scheduler.schedule_consumption(1, token, different_time),
            different_time,
        )


class TestBandwidthRateTracker(unittest.TestCase):
    def setUp(self):
        self.alpha = 0.8
        self.rate_tracker = BandwidthRateTracker(self.alpha)

    def test_current_rate_at_initilizations(self):
        self.assertEqual(self.rate_tracker.current_rate, 0.0)

    def test_current_rate_after_one_recorded_point(self):
        self.rate_tracker.record_consumption_rate(1, 1)
        # There is no last time point to do a diff against so return a
        # current rate of 0.0
        self.assertEqual(self.rate_tracker.current_rate, 0.0)

    def test_current_rate(self):
        self.rate_tracker.record_consumption_rate(1, 1)
        self.rate_tracker.record_consumption_rate(1, 2)
        self.rate_tracker.record_consumption_rate(1, 3)
        self.assertEqual(self.rate_tracker.current_rate, 0.96)

    def test_get_projected_rate_at_initilizations(self):
        self.assertEqual(self.rate_tracker.get_projected_rate(1, 1), 0.0)

    def test_get_projected_rate(self):
        self.rate_tracker.record_consumption_rate(1, 1)
        self.rate_tracker.record_consumption_rate(1, 2)
        projected_rate = self.rate_tracker.get_projected_rate(1, 3)
        self.assertEqual(projected_rate, 0.96)
        self.rate_tracker.record_consumption_rate(1, 3)
        self.assertEqual(self.rate_tracker.current_rate, projected_rate)

    def test_get_projected_rate_for_same_timestamp(self):
        self.rate_tracker.record_consumption_rate(1, 1)
        self.assertEqual(
            self.rate_tracker.get_projected_rate(1, 1), float('inf')
        )


# diff --git a/contrib/python/s3transfer/py3/tests/unit/test_compat.py b/contrib/python/s3transfer/py3/tests/unit/test_compat.py
# new file mode 100644
# index 0000000000..78fdc25845
# --- /dev/null
# +++ b/contrib/python/s3transfer/py3/tests/unit/test_compat.py
# @@ -0,0 +1,105 @@
# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import os
import shutil
import signal
import tempfile
from io import BytesIO

from s3transfer.compat import BaseManager, readable, seekable
from __tests__ import skip_if_windows, unittest


class ErrorRaisingSeekWrapper:
    """An object wrapper that throws an error when seeked on

    :param fileobj: The fileobj that it wraps
    :param exception: The exception to raise when seeked on.
    """

    def __init__(self, fileobj, exception):
        self._fileobj = fileobj
        self._exception = exception

    def seek(self, offset, whence=0):
        raise self._exception

    def tell(self):
        return self._fileobj.tell()


class TestSeekable(unittest.TestCase):
    """Behaviour of ``seekable()`` for real files and impostors."""

    def setUp(self):
        self.tempdir = tempfile.mkdtemp()
        self.filename = os.path.join(self.tempdir, 'foo')

    def tearDown(self):
        shutil.rmtree(self.tempdir)

    def test_seekable_fileobj(self):
        with open(self.filename, 'w') as f:
            self.assertTrue(seekable(f))

    def test_non_file_like_obj(self):
        # Fails because there is no seekable(), seek(), nor tell()
        self.assertFalse(seekable(object()))

    def test_non_seekable_ioerror(self):
        # Should return False if IOError is thrown.
        with open(self.filename, 'w') as f:
            self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, IOError())))

    def test_non_seekable_oserror(self):
        # Should return False if OSError is thrown.
        with open(self.filename, 'w') as f:
            self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, OSError())))


class TestReadable(unittest.TestCase):
    """Behaviour of ``readable()`` for files, file-likes and others."""

    def test_readable_fileobj(self):
        with tempfile.TemporaryFile() as f:
            self.assertTrue(readable(f))

    def test_readable_file_like_obj(self):
        self.assertTrue(readable(BytesIO()))

    def test_non_file_like_obj(self):
        self.assertFalse(readable(object()))


class TestBaseManager(unittest.TestCase):
    def create_pid_manager(self):
        """Return an unstarted manager exposing ``os.getpid`` as ``getpid``."""

        class PIDManager(BaseManager):
            pass

        PIDManager.register('getpid', os.getpid)
        return PIDManager()

    def get_pid(self, pid_manager):
        pid = pid_manager.getpid()
        # A proxy object is returned back. The needed value can be acquired
        # from the repr and converting that to an integer
        return int(str(pid))

    @skip_if_windows('os.kill() with SIGINT not supported on Windows')
    def test_can_provide_signal_handler_initializers_to_start(self):
        manager = self.create_pid_manager()
        manager.start(signal.signal, (signal.SIGINT, signal.SIG_IGN))
        # Stop the manager's server process once the test is over; it was
        # previously left running. ``shutdown`` only exists after start().
        # NOTE(review): assumes s3transfer.compat.BaseManager follows the
        # multiprocessing.managers.BaseManager API -- confirm.
        self.addCleanup(manager.shutdown)
        pid = self.get_pid(manager)
        try:
            os.kill(pid, signal.SIGINT)
        except KeyboardInterrupt:
            pass
        # Try using the manager after the os.kill on the parent process. The
        # manager should not have died and should still be usable.
        self.assertEqual(pid, self.get_pid(manager))


# diff --git a/contrib/python/s3transfer/py3/tests/unit/test_copies.py b/contrib/python/s3transfer/py3/tests/unit/test_copies.py
# new file mode 100644
# index 0000000000..3681f69b94
# --- /dev/null
# +++ b/contrib/python/s3transfer/py3/tests/unit/test_copies.py
# @@ -0,0 +1,177 @@
# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file.
# This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from s3transfer.copies import CopyObjectTask, CopyPartTask
from __tests__ import BaseTaskTest, RecordingSubscriber


class BaseCopyTaskTest(BaseTaskTest):
    """Common bucket/key/source fixture for the copy task tests."""

    def setUp(self):
        super().setUp()
        self.bucket = 'mybucket'
        self.key = 'mykey'
        self.copy_source = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'}
        self.extra_args = {}
        self.callbacks = []
        self.size = 5


class TestCopyObjectTask(BaseCopyTaskTest):
    """Exercise ``CopyObjectTask`` against a stubbed ``copy_object``."""

    def get_copy_task(self, **kwargs):
        """Build a ``CopyObjectTask``; ``kwargs`` override the defaults."""
        main_kwargs = {
            'client': self.client,
            'copy_source': self.copy_source,
            'bucket': self.bucket,
            'key': self.key,
            'extra_args': self.extra_args,
            'callbacks': self.callbacks,
            'size': self.size,
            **kwargs,
        }
        return self.get_task(CopyObjectTask, main_kwargs=main_kwargs)

    def _stub_copy_object(self, **extra_params):
        # Queue one successful copy_object response expecting the fixture's
        # destination and source plus any additional parameters.
        params = {
            'Bucket': self.bucket,
            'Key': self.key,
            'CopySource': self.copy_source,
        }
        params.update(extra_params)
        self.stubber.add_response(
            'copy_object', service_response={}, expected_params=params
        )

    def test_main(self):
        self._stub_copy_object()
        copy_task = self.get_copy_task()
        copy_task()

        self.stubber.assert_no_pending_responses()

    def test_extra_args(self):
        self.extra_args['ACL'] = 'private'
        self._stub_copy_object(ACL='private')
        copy_task = self.get_copy_task()
        copy_task()

        self.stubber.assert_no_pending_responses()

    def test_callbacks_invoked(self):
        recorder = RecordingSubscriber()
        self.callbacks.append(recorder.on_progress)
        self._stub_copy_object()
        copy_task = self.get_copy_task()
        copy_task()

        self.stubber.assert_no_pending_responses()
        self.assertEqual(recorder.calculate_bytes_seen(), self.size)


class TestCopyPartTask(BaseCopyTaskTest):
    """Exercise ``CopyPartTask`` against a stubbed ``upload_part_copy``."""

    def setUp(self):
        super().setUp()
        self.copy_source_range = 'bytes=5-9'
        self.extra_args['CopySourceRange'] = self.copy_source_range
        self.upload_id = 'myuploadid'
        self.part_number = 1
        self.result_etag = 'my-etag'

    def get_copy_task(self, **kwargs):
        """Build a ``CopyPartTask``; ``kwargs`` override the defaults."""
        main_kwargs = {
            'client': self.client,
            'copy_source': self.copy_source,
            'bucket': self.bucket,
            'key': self.key,
            'upload_id': self.upload_id,
            'part_number': self.part_number,
            'extra_args': self.extra_args,
            'callbacks': self.callbacks,
            'size': self.size,
            **kwargs,
        }
        return self.get_task(CopyPartTask, main_kwargs=main_kwargs)

    def _stub_upload_part_copy(self, **extra_params):
        # Queue one upload_part_copy response that returns the fixture ETag.
        params = {
            'Bucket': self.bucket,
            'Key': self.key,
            'CopySource': self.copy_source,
            'UploadId': self.upload_id,
            'PartNumber': self.part_number,
            'CopySourceRange': self.copy_source_range,
        }
        params.update(extra_params)
        self.stubber.add_response(
            'upload_part_copy',
            service_response={'CopyPartResult': {'ETag': self.result_etag}},
            expected_params=params,
        )

    def _assert_task_returns_part(self, copy_task):
        # Running the task must yield the part descriptor for this part.
        expected_part = {
            'PartNumber': self.part_number,
            'ETag': self.result_etag,
        }
        self.assertEqual(copy_task(), expected_part)

    def test_main(self):
        self._stub_upload_part_copy()
        self._assert_task_returns_part(self.get_copy_task())
        self.stubber.assert_no_pending_responses()

    def test_extra_args(self):
        self.extra_args['RequestPayer'] = 'requester'
        self._stub_upload_part_copy(RequestPayer='requester')
        self._assert_task_returns_part(self.get_copy_task())
        self.stubber.assert_no_pending_responses()

    def test_callbacks_invoked(self):
        recorder = RecordingSubscriber()
        self.callbacks.append(recorder.on_progress)
        self._stub_upload_part_copy()
        self._assert_task_returns_part(self.get_copy_task())
        self.stubber.assert_no_pending_responses()
        self.assertEqual(recorder.calculate_bytes_seen(), self.size)


# diff --git a/contrib/python/s3transfer/py3/tests/unit/test_crt.py b/contrib/python/s3transfer/py3/tests/unit/test_crt.py
# new file mode 100644
# index 0000000000..8c32668eab
# --- /dev/null
# +++ b/contrib/python/s3transfer/py3/tests/unit/test_crt.py
# @@ -0,0 +1,173 @@
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
+from botocore.credentials import CredentialResolver, ReadOnlyCredentials +from botocore.session import Session + +from s3transfer.exceptions import TransferNotDoneError +from s3transfer.utils import CallArgs +from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest + +if HAS_CRT: + import awscrt.s3 + + import s3transfer.crt + + +class CustomFutureException(Exception): + pass + + +@requires_crt +class TestBotocoreCRTRequestSerializer(unittest.TestCase): + def setUp(self): + self.region = 'us-west-2' + self.session = Session() + self.session.set_config_variable('region', self.region) + self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer( + self.session + ) + self.bucket = "test_bucket" + self.key = "test_key" + self.files = FileCreator() + self.filename = self.files.create_file('myfile', 'my content') + self.expected_path = "/" + self.bucket + "/" + self.key + self.expected_host = "s3.%s.amazonaws.com" % (self.region) + + def tearDown(self): + self.files.remove_all() + + def test_upload_request(self): + callargs = CallArgs( + bucket=self.bucket, + key=self.key, + fileobj=self.filename, + extra_args={}, + subscribers=[], + ) + coordinator = s3transfer.crt.CRTTransferCoordinator() + future = s3transfer.crt.CRTTransferFuture( + s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator + ) + crt_request = self.request_serializer.serialize_http_request( + "put_object", future + ) + self.assertEqual("PUT", crt_request.method) + self.assertEqual(self.expected_path, crt_request.path) + self.assertEqual(self.expected_host, crt_request.headers.get("host")) + self.assertIsNone(crt_request.headers.get("Authorization")) + + def test_download_request(self): + callargs = CallArgs( + bucket=self.bucket, + key=self.key, + fileobj=self.filename, + extra_args={}, + subscribers=[], + ) + coordinator = s3transfer.crt.CRTTransferCoordinator() + future = s3transfer.crt.CRTTransferFuture( + s3transfer.crt.CRTTransferMeta(call_args=callargs), 
coordinator + ) + crt_request = self.request_serializer.serialize_http_request( + "get_object", future + ) + self.assertEqual("GET", crt_request.method) + self.assertEqual(self.expected_path, crt_request.path) + self.assertEqual(self.expected_host, crt_request.headers.get("host")) + self.assertIsNone(crt_request.headers.get("Authorization")) + + def test_delete_request(self): + callargs = CallArgs( + bucket=self.bucket, key=self.key, extra_args={}, subscribers=[] + ) + coordinator = s3transfer.crt.CRTTransferCoordinator() + future = s3transfer.crt.CRTTransferFuture( + s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator + ) + crt_request = self.request_serializer.serialize_http_request( + "delete_object", future + ) + self.assertEqual("DELETE", crt_request.method) + self.assertEqual(self.expected_path, crt_request.path) + self.assertEqual(self.expected_host, crt_request.headers.get("host")) + self.assertIsNone(crt_request.headers.get("Authorization")) + + +@requires_crt +class TestCRTCredentialProviderAdapter(unittest.TestCase): + def setUp(self): + self.botocore_credential_provider = mock.Mock(CredentialResolver) + self.access_key = "access_key" + self.secret_key = "secret_key" + self.token = "token" + self.botocore_credential_provider.load_credentials.return_value.get_frozen_credentials.return_value = ReadOnlyCredentials( + self.access_key, self.secret_key, self.token + ) + + def _call_adapter_and_check(self, credentails_provider_adapter): + credentials = credentails_provider_adapter() + self.assertEqual(credentials.access_key_id, self.access_key) + self.assertEqual(credentials.secret_access_key, self.secret_key) + self.assertEqual(credentials.session_token, self.token) + + def test_fetch_crt_credentials_successfully(self): + credentails_provider_adapter = ( + s3transfer.crt.CRTCredentialProviderAdapter( + self.botocore_credential_provider + ) + ) + self._call_adapter_and_check(credentails_provider_adapter) + + def test_load_credentials_once(self): + 
@requires_crt
class TestCRTTransferFuture(unittest.TestCase):
    """Checks ``CRTTransferFuture.set_exception`` against a mocked request."""

    def setUp(self):
        s3_request = mock.Mock(awscrt.s3.S3RequestType)
        crt_future = mock.Mock(awscrt.s3.Future)
        s3_request.finished_future = crt_future
        self.mock_s3_request = s3_request
        self.mock_crt_future = crt_future

        coordinator = s3transfer.crt.CRTTransferCoordinator()
        coordinator.set_s3_request(s3_request)
        self.coordinator = coordinator
        self.future = s3transfer.crt.CRTTransferFuture(coordinator=coordinator)

    def test_set_exception(self):
        # An exception set on the future is re-raised by result().
        self.future.set_exception(CustomFutureException())
        self.assertRaises(CustomFutureException, self.future.result)

    def test_set_exception_raises_error_when_not_done(self):
        # With the CRT future reporting "not done", setting an exception
        # is rejected.
        self.mock_crt_future.done.return_value = False
        self.assertRaises(
            TransferNotDoneError,
            self.future.set_exception,
            CustomFutureException(),
        )

    def test_set_exception_can_override_previous_exception(self):
        # The exception set last is the one result() raises.
        for error in (Exception(), CustomFutureException()):
            self.future.set_exception(error)
        self.assertRaises(CustomFutureException, self.future.result)
class TestDeleteObjectTask(BaseTaskTest):
    """Runs ``DeleteObjectTask`` against a stubbed ``delete_object`` call."""

    def setUp(self):
        super().setUp()
        self.bucket = 'mybucket'
        self.key = 'mykey'
        self.extra_args = {}
        self.callbacks = []

    def get_delete_task(self, **kwargs):
        # Caller-supplied kwargs take precedence over the defaults.
        main_kwargs = {
            **{
                'client': self.client,
                'bucket': self.bucket,
                'key': self.key,
                'extra_args': self.extra_args,
            },
            **kwargs,
        }
        return self.get_task(DeleteObjectTask, main_kwargs=main_kwargs)

    def _run_and_verify(self, expected_params):
        # Stub one delete_object response, run the task, and make sure the
        # stubbed response was consumed.
        self.stubber.add_response(
            'delete_object',
            service_response={},
            expected_params=expected_params,
        )
        delete_task = self.get_delete_task()
        delete_task()
        self.stubber.assert_no_pending_responses()

    def test_main(self):
        self._run_and_verify({'Bucket': self.bucket, 'Key': self.key})

    def test_extra_args(self):
        self.extra_args['MFA'] = 'mfa-code'
        self.extra_args['VersionId'] = '12345'
        self._run_and_verify(
            {
                'Bucket': self.bucket,
                'Key': self.key,
                # These extra_args should be injected into the
                # expected params for the delete_object call.
                'MFA': 'mfa-code',
                'VersionId': '12345',
            }
        )
class CancelledStreamWrapper:
    """A wrapper to trigger a cancellation while stream reading

    Forces the transfer coordinator to cancel after a certain amount of
    reads, while still behaving like the wrapped stream: every ``read()``
    returns whatever the underlying stream returned.

    :param stream: The underlying stream to read from
    :param transfer_coordinator: The coordinator for the transfer; only its
        ``cancel()`` method is used.
    :param num_reads: On which read to signal a cancellation. 0 is the first
        read.
    """

    def __init__(self, stream, transfer_coordinator, num_reads=0):
        self._stream = stream
        self._transfer_coordinator = transfer_coordinator
        self._num_reads = num_reads
        # Number of reads completed so far.
        self._count = 0

    def read(self, *args, **kwargs):
        """Read from the wrapped stream, cancelling on the configured read.

        All arguments are forwarded to the underlying stream's ``read``.

        :returns: The data returned by the underlying stream.
        """
        if self._num_reads == self._count:
            self._transfer_coordinator.cancel()
        # Previously the data was dropped and ``None`` returned, which
        # breaks any consumer that uses the chunk or takes its length.
        data = self._stream.read(*args, **kwargs)
        self._count += 1
        return data
class TestDownloadSpecialFilenameOutputManager(BaseDownloadOutputManagerTest):
    """Tests ``DownloadSpecialFilenameOutputManager`` with an ``OSUtils``
    that reports every filename as a special file.
    """

    def setUp(self):
        super().setUp()
        self.osutil = AlwaysIndicatesSpecialFileOSUtils()
        self.download_output_manager = DownloadSpecialFilenameOutputManager(
            self.osutil,
            self.transfer_coordinator,
            io_executor=self.io_executor,
        )

    def test_is_compatible_for_special_file(self):
        special_osutil = AlwaysIndicatesSpecialFileOSUtils()
        compatible = self.download_output_manager.is_compatible(
            self.filename, special_osutil
        )
        self.assertTrue(compatible)

    def test_is_not_compatible_for_non_special_file(self):
        regular_osutil = OSUtils()
        compatible = self.download_output_manager.is_compatible(
            self.filename, regular_osutil
        )
        self.assertFalse(compatible)

    def test_get_fileobj_for_io_writes(self):
        manager = self.download_output_manager
        with manager.get_fileobj_for_io_writes(self.future) as fileobj:
            # Ensure it is a file like object returned
            self.assertTrue(hasattr(fileobj, 'read'))
            # Make sure the name of the file returned is the same as the
            # final filename as we should not be writing to a temporary file.
            self.assertEqual(fileobj.name, self.filename)

    def test_get_final_io_task(self):
        final_task = self.download_output_manager.get_final_io_task()
        self.assertIsInstance(final_task, IOCloseTask)

    def test_can_queue_file_io_task(self):
        collector = WriteCollector()
        for offset, data in ((0, 'foo'), (3, 'bar')):
            self.download_output_manager.queue_file_io_task(
                fileobj=collector, data=data, offset=offset
            )
        # Shutting the executor down lets the queued writes run before the
        # recorded writes are inspected.
        self.io_executor.shutdown()
        self.assertEqual(collector.writes, [(0, 'foo'), (3, 'bar')])
self.assertIsNone(self.download_output_manager.get_download_task_tag()) + + def test_get_fileobj_for_io_writes(self): + self.assertIs( + self.download_output_manager.get_fileobj_for_io_writes( + self.future + ), + self.fileobj, + ) + + def test_get_final_io_task(self): + self.assertIsInstance( + self.download_output_manager.get_final_io_task(), + CompleteDownloadNOOPTask, + ) + + def test_can_queue_file_io_task(self): + fileobj = WriteCollector() + self.download_output_manager.queue_file_io_task( + fileobj=fileobj, data='foo', offset=0 + ) + self.download_output_manager.queue_file_io_task( + fileobj=fileobj, data='bar', offset=3 + ) + self.io_executor.shutdown() + self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')]) + + def test_get_file_io_write_task(self): + fileobj = WriteCollector() + io_write_task = self.download_output_manager.get_io_write_task( + fileobj=fileobj, data='foo', offset=3 + ) + self.assertIsInstance(io_write_task, IOWriteTask) + + io_write_task() + self.assertEqual(fileobj.writes, [(3, 'foo')]) + + +class TestDownloadNonSeekableOutputManager(BaseDownloadOutputManagerTest): + def setUp(self): + super().setUp() + self.download_output_manager = DownloadNonSeekableOutputManager( + self.osutil, self.transfer_coordinator, io_executor=None + ) + + def test_is_compatible_with_seekable_stream(self): + with open(self.filename, 'wb') as f: + self.assertTrue( + self.download_output_manager.is_compatible(f, self.osutil) + ) + + def test_not_compatible_with_filename(self): + self.assertFalse( + self.download_output_manager.is_compatible( + self.filename, self.osutil + ) + ) + + def test_compatible_with_non_seekable_stream(self): + class NonSeekable: + def write(self, data): + pass + + f = NonSeekable() + self.assertTrue( + self.download_output_manager.is_compatible(f, self.osutil) + ) + + def test_is_compatible_with_bytesio(self): + self.assertTrue( + self.download_output_manager.is_compatible(BytesIO(), self.osutil) + ) + + def 
class TestDownloadSubmissionTask(BaseSubmissionTaskTest):
    """Checks the executor tag ``DownloadSubmissionTask`` attaches to its
    get-object submissions.

    Downloads to a filename or to a seekable file object are expected to be
    submitted without a tag; downloads to a non-seekable file object are
    expected to carry ``IN_MEMORY_DOWNLOAD_TAG``. Both the single get and
    the ranged get paths are covered.
    """

    def setUp(self):
        super().setUp()
        self.tempdir = tempfile.mkdtemp()
        self.filename = os.path.join(self.tempdir, 'myfile')

        self.bucket = 'mybucket'
        self.key = 'mykey'
        self.extra_args = {}
        self.subscribers = []

        # Create a stream to read from
        self.content = b'my content'
        self.stream = BytesIO(self.content)

        self.call_args = self.get_call_args()
        self.transfer_future = self.get_transfer_future(self.call_args)
        self.io_executor = BoundedExecutor(1000, 1)
        self.submission_main_kwargs = {
            'client': self.client,
            'config': self.config,
            'osutil': self.osutil,
            'request_executor': self.executor,
            'io_executor': self.io_executor,
            'transfer_future': self.transfer_future,
        }
        self.submission_task = self.get_download_submission_task()

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tempdir)

    def get_call_args(self, **kwargs):
        """Return ``CallArgs`` built from the defaults, overridden by kwargs."""
        default_call_args = {
            'fileobj': self.filename,
            'bucket': self.bucket,
            'key': self.key,
            'extra_args': self.extra_args,
            'subscribers': self.subscribers,
        }
        default_call_args.update(kwargs)
        return CallArgs(**default_call_args)

    def wrap_executor_in_recorder(self):
        """Record request submissions so their tags can be inspected."""
        self.executor = RecordingExecutor(self.executor)
        self.submission_main_kwargs['request_executor'] = self.executor

    def use_fileobj_in_call_args(self, fileobj):
        """Rebuild the call args and transfer future around ``fileobj``."""
        self.call_args = self.get_call_args(fileobj=fileobj)
        self.transfer_future = self.get_transfer_future(self.call_args)
        self.submission_main_kwargs['transfer_future'] = self.transfer_future

    def assert_tag_for_get_object(self, tag_value):
        """Assert every recorded get-object submission used ``tag_value``."""
        submissions_to_compare = self.executor.submissions
        if len(submissions_to_compare) > 1:
            # If it was ranged get, make sure we do not include the join task.
            submissions_to_compare = submissions_to_compare[:-1]
        for submission in submissions_to_compare:
            self.assertEqual(submission['tag'], tag_value)

    def add_head_object_response(self):
        self.stubber.add_response(
            'head_object', {'ContentLength': len(self.content)}
        )

    def add_get_responses(self):
        """Stub one ``get_object`` response per multipart chunk of content."""
        chunksize = self.config.multipart_chunksize
        for i in range(0, len(self.content), chunksize):
            # Slicing clamps at the end of the content, so the final
            # (possibly shorter) chunk needs no special casing.
            stream = BytesIO(self.content[i : i + chunksize])
            self.stubber.add_response('get_object', {'Body': stream})

    def configure_for_ranged_get(self):
        # Content longer than the threshold and chunk size, so the
        # stubbed download is split into several get_object calls.
        self.config.multipart_threshold = 1
        self.config.multipart_chunksize = 4

    def get_download_submission_task(self):
        return self.get_task(
            DownloadSubmissionTask, main_kwargs=self.submission_main_kwargs
        )

    def wait_and_assert_completed_successfully(self, submission_task):
        submission_task()
        self.transfer_future.result()
        self.stubber.assert_no_pending_responses()

    def test_submits_no_tag_for_get_object_filename(self):
        self.wrap_executor_in_recorder()
        self.add_head_object_response()
        self.add_get_responses()

        self.submission_task = self.get_download_submission_task()
        self.wait_and_assert_completed_successfully(self.submission_task)

        # Make sure no limiting tag was associated with the get object
        # submission.
        self.assert_tag_for_get_object(None)

    def test_submits_no_tag_for_ranged_get_filename(self):
        self.wrap_executor_in_recorder()
        self.configure_for_ranged_get()
        self.add_head_object_response()
        self.add_get_responses()

        self.submission_task = self.get_download_submission_task()
        self.wait_and_assert_completed_successfully(self.submission_task)

        # Make sure no limiting tag was associated with any of the ranged
        # get object submissions.
        self.assert_tag_for_get_object(None)

    def test_submits_no_tag_for_get_object_fileobj(self):
        self.wrap_executor_in_recorder()
        self.add_head_object_response()
        self.add_get_responses()

        with open(self.filename, 'wb') as f:
            self.use_fileobj_in_call_args(f)
            self.submission_task = self.get_download_submission_task()
            self.wait_and_assert_completed_successfully(self.submission_task)

        # Make sure no limiting tag was associated with the get object
        # submission.
        self.assert_tag_for_get_object(None)

    def test_submits_no_tag_for_ranged_get_object_fileobj(self):
        self.wrap_executor_in_recorder()
        self.configure_for_ranged_get()
        self.add_head_object_response()
        self.add_get_responses()

        with open(self.filename, 'wb') as f:
            self.use_fileobj_in_call_args(f)
            self.submission_task = self.get_download_submission_task()
            self.wait_and_assert_completed_successfully(self.submission_task)

        # Make sure no limiting tag was associated with any of the ranged
        # get object submissions.
        self.assert_tag_for_get_object(None)

    # NOTE: the ``tests_`` prefix on the two methods below is kept as-is so
    # existing test IDs stay stable; unittest still collects them because
    # the names start with ``test``.
    def tests_submits_tag_for_get_object_nonseekable_fileobj(self):
        self.wrap_executor_in_recorder()
        self.add_head_object_response()
        self.add_get_responses()

        with open(self.filename, 'wb') as f:
            self.use_fileobj_in_call_args(NonSeekableWriter(f))
            self.submission_task = self.get_download_submission_task()
            self.wait_and_assert_completed_successfully(self.submission_task)

        # Make sure the in-memory download tag was associated with the
        # get object submission.
        self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG)

    def tests_submits_tag_for_ranged_get_object_nonseekable_fileobj(self):
        self.wrap_executor_in_recorder()
        self.configure_for_ranged_get()
        self.add_head_object_response()
        self.add_get_responses()

        with open(self.filename, 'wb') as f:
            self.use_fileobj_in_call_args(NonSeekableWriter(f))
            self.submission_task = self.get_download_submission_task()
            self.wait_and_assert_completed_successfully(self.submission_task)

        # Make sure the in-memory download tag was associated with every
        # ranged get object submission.
        self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG)
+ self.io_executor.shutdown() + self.assertEqual(self.fileobj.writes, expected_writes) + + def test_main(self): + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task() + task() + + self.stubber.assert_no_pending_responses() + self.assert_io_writes([(0, self.content)]) + + def test_extra_args(self): + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'Range': 'bytes=0-', + }, + ) + self.extra_args['Range'] = 'bytes=0-' + task = self.get_download_task() + task() + + self.stubber.assert_no_pending_responses() + self.assert_io_writes([(0, self.content)]) + + def test_control_chunk_size(self): + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task(io_chunksize=1) + task() + + self.stubber.assert_no_pending_responses() + expected_contents = [] + for i in range(len(self.content)): + expected_contents.append((i, bytes(self.content[i : i + 1]))) + + self.assert_io_writes(expected_contents) + + def test_start_index(self): + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task(start_index=5) + task() + + self.stubber.assert_no_pending_responses() + self.assert_io_writes([(5, self.content)]) + + def test_uses_bandwidth_limiter(self): + bandwidth_limiter = mock.Mock(BandwidthLimiter) + + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task(bandwidth_limiter=bandwidth_limiter) + task() + + self.stubber.assert_no_pending_responses() + self.assertEqual( + 
bandwidth_limiter.get_bandwith_limited_stream.call_args_list, + [mock.call(mock.ANY, self.transfer_coordinator)], + ) + + def test_retries_succeeds(self): + self.stubber.add_response( + 'get_object', + service_response={ + 'Body': StreamWithError(self.stream, SOCKET_ERROR) + }, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task() + task() + + # Retryable error should have not affected the bytes placed into + # the io queue. + self.stubber.assert_no_pending_responses() + self.assert_io_writes([(0, self.content)]) + + def test_retries_failure(self): + for _ in range(self.max_attempts): + self.stubber.add_response( + 'get_object', + service_response={ + 'Body': StreamWithError(self.stream, SOCKET_ERROR) + }, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + + task = self.get_download_task() + task() + self.transfer_coordinator.announce_done() + + # Should have failed out on a RetriesExceededError + with self.assertRaises(RetriesExceededError): + self.transfer_coordinator.result() + self.stubber.assert_no_pending_responses() + + def test_retries_in_middle_of_streaming(self): + # After the first read a retryable error will be thrown + self.stubber.add_response( + 'get_object', + service_response={ + 'Body': StreamWithError( + copy.deepcopy(self.stream), SOCKET_ERROR, 1 + ) + }, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task(io_chunksize=1) + task() + + self.stubber.assert_no_pending_responses() + expected_contents = [] + # This is the content initially read in before the retry hit on the + # second read() + expected_contents.append((0, bytes(self.content[0:1]))) + + 
class TestImmediatelyWriteIOGetObjectTask(TestGetObjectTask):
    """Re-runs the ``TestGetObjectTask`` cases against
    ``ImmediatelyWriteIOGetObjectTask`` with no IO executor configured.
    """

    def setUp(self):
        super().setUp()
        self.task_cls = ImmediatelyWriteIOGetObjectTask
        # When data is written out, it should not use the io executor at all
        # if it does use the io executor that is a deviation from expected
        # behavior as the data should be written immediately to the file
        # object once downloaded.
        self.io_executor = None
        output_manager = DownloadSeekableOutputManager(
            self.osutil, self.transfer_coordinator, self.io_executor
        )
        self.download_output_manager = output_manager

    def assert_io_writes(self, expected_writes):
        # There is no executor to drain here, so the recorded writes are
        # compared directly.
        recorded_writes = self.fileobj.writes
        self.assertEqual(recorded_writes, expected_writes)
class TestDownloadChunkIterator(unittest.TestCase):
    """Checks the chunks produced by iterating ``DownloadChunkIterator``."""

    def test_iter(self):
        # A chunk size equal to the content length gives a single chunk.
        content = b'my content'
        chunks = list(DownloadChunkIterator(BytesIO(content), len(content)))
        self.assertEqual(chunks, [b'my content'])

    def test_iter_chunksize(self):
        # Four bytes read three at a time give a full and a partial chunk.
        chunks = list(DownloadChunkIterator(BytesIO(b'1234'), 3))
        self.assertEqual(chunks, [b'123', b'4'])

    def test_empty_content(self):
        # An empty body still produces one empty chunk.
        chunks = list(DownloadChunkIterator(BytesIO(b''), 3))
        self.assertEqual(chunks, [b''])
data='bar') + self.assertEqual(writes, []) + + def test_writes_returned_in_order(self): + self.assertEqual(self.q.request_writes(offset=3, data='d'), []) + self.assertEqual(self.q.request_writes(offset=2, data='c'), []) + self.assertEqual(self.q.request_writes(offset=1, data='b'), []) + + # Everything at this point has been deferred, but as soon as we + # send offset=0, that will unlock offsets 0-3. + writes = self.q.request_writes(offset=0, data='a') + self.assertEqual( + writes, + [ + {'offset': 0, 'data': 'a'}, + {'offset': 1, 'data': 'b'}, + {'offset': 2, 'data': 'c'}, + {'offset': 3, 'data': 'd'}, + ], + ) + + def test_unlocks_partial_range(self): + self.assertEqual(self.q.request_writes(offset=5, data='f'), []) + self.assertEqual(self.q.request_writes(offset=1, data='b'), []) + + # offset=0 unlocks 0-1, but offset=5 still needs to see 2-4 first. + writes = self.q.request_writes(offset=0, data='a') + self.assertEqual( + writes, + [ + {'offset': 0, 'data': 'a'}, + {'offset': 1, 'data': 'b'}, + ], + ) + + def test_data_can_be_any_size(self): + self.q.request_writes(offset=5, data='hello world') + writes = self.q.request_writes(offset=0, data='abcde') + self.assertEqual( + writes, + [ + {'offset': 0, 'data': 'abcde'}, + {'offset': 5, 'data': 'hello world'}, + ], + ) + + def test_data_queued_in_order(self): + # This immediately gets returned because offset=0 is the + # next range we're waiting on. 
+ writes = self.q.request_writes(offset=0, data='hello world') + self.assertEqual(writes, [{'offset': 0, 'data': 'hello world'}]) + # Same thing here but with offset + writes = self.q.request_writes(offset=11, data='hello again') + self.assertEqual(writes, [{'offset': 11, 'data': 'hello again'}]) + + def test_writes_below_min_offset_are_ignored(self): + self.q.request_writes(offset=0, data='a') + self.q.request_writes(offset=1, data='b') + self.q.request_writes(offset=2, data='c') + + # At this point we're expecting offset=3, so if a write + # comes in below 3, we ignore it. + self.assertEqual(self.q.request_writes(offset=0, data='a'), []) + self.assertEqual(self.q.request_writes(offset=1, data='b'), []) + + self.assertEqual( + self.q.request_writes(offset=3, data='d'), + [{'offset': 3, 'data': 'd'}], + ) + + def test_duplicate_writes_are_ignored(self): + self.q.request_writes(offset=2, data='c') + self.q.request_writes(offset=1, data='b') + + # We're still waiting for offset=0, but if + # a duplicate write comes in for offset=2/offset=1 + # it's ignored. This gives "first one wins" behavior. + self.assertEqual(self.q.request_writes(offset=2, data='X'), []) + self.assertEqual(self.q.request_writes(offset=1, data='Y'), []) + + self.assertEqual( + self.q.request_writes(offset=0, data='a'), + [ + {'offset': 0, 'data': 'a'}, + # Note we're seeing 'b' 'c', and not 'X', 'Y'. + {'offset': 1, 'data': 'b'}, + {'offset': 2, 'data': 'c'}, + ], + ) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_futures.py b/contrib/python/s3transfer/py3/tests/unit/test_futures.py new file mode 100644 index 0000000000..ca2888a654 --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/unit/test_futures.py @@ -0,0 +1,696 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import os +import sys +import time +import traceback +from concurrent.futures import ThreadPoolExecutor + +from s3transfer.exceptions import ( + CancelledError, + FatalError, + TransferNotDoneError, +) +from s3transfer.futures import ( + BaseExecutor, + BoundedExecutor, + ExecutorFuture, + NonThreadedExecutor, + NonThreadedExecutorFuture, + TransferCoordinator, + TransferFuture, + TransferMeta, +) +from s3transfer.tasks import Task +from s3transfer.utils import ( + FunctionContainer, + NoResourcesAvailable, + TaskSemaphore, +) +from __tests__ import ( + RecordingExecutor, + TransferCoordinatorWithInterrupt, + mock, + unittest, +) + + +def return_call_args(*args, **kwargs): + return args, kwargs + + +def raise_exception(exception): + raise exception + + +def get_exc_info(exception): + try: + raise_exception(exception) + except Exception: + return sys.exc_info() + + +class RecordingTransferCoordinator(TransferCoordinator): + def __init__(self): + self.all_transfer_futures_ever_associated = set() + super().__init__() + + def add_associated_future(self, future): + self.all_transfer_futures_ever_associated.add(future) + super().add_associated_future(future) + + +class ReturnFooTask(Task): + def _main(self, **kwargs): + return 'foo' + + +class SleepTask(Task): + def _main(self, sleep_time, **kwargs): + time.sleep(sleep_time) + + +class TestTransferFuture(unittest.TestCase): + def setUp(self): + self.meta = TransferMeta() + self.coordinator = TransferCoordinator() + self.future = self._get_transfer_future() + + def _get_transfer_future(self, **kwargs): + components = { + 'meta': self.meta, + 'coordinator': 
self.coordinator, + } + for component_name, component in kwargs.items(): + components[component_name] = component + return TransferFuture(**components) + + def test_meta(self): + self.assertIs(self.future.meta, self.meta) + + def test_done(self): + self.assertFalse(self.future.done()) + self.coordinator.set_result(None) + self.assertTrue(self.future.done()) + + def test_result(self): + result = 'foo' + self.coordinator.set_result(result) + self.coordinator.announce_done() + self.assertEqual(self.future.result(), result) + + def test_keyboard_interrupt_on_result_does_not_block(self): + # This should raise a KeyboardInterrupt when result is called on it. + self.coordinator = TransferCoordinatorWithInterrupt() + self.future = self._get_transfer_future() + + # result() should not block and immediately raise the keyboard + # interrupt exception. + with self.assertRaises(KeyboardInterrupt): + self.future.result() + + def test_cancel(self): + self.future.cancel() + self.assertTrue(self.future.done()) + self.assertEqual(self.coordinator.status, 'cancelled') + + def test_set_exception(self): + # Set the result such that there is no exception + self.coordinator.set_result('result') + self.coordinator.announce_done() + self.assertEqual(self.future.result(), 'result') + + self.future.set_exception(ValueError()) + with self.assertRaises(ValueError): + self.future.result() + + def test_set_exception_only_after_done(self): + with self.assertRaises(TransferNotDoneError): + self.future.set_exception(ValueError()) + + self.coordinator.set_result('result') + self.coordinator.announce_done() + self.future.set_exception(ValueError()) + with self.assertRaises(ValueError): + self.future.result() + + +class TestTransferMeta(unittest.TestCase): + def setUp(self): + self.transfer_meta = TransferMeta() + + def test_size(self): + self.assertEqual(self.transfer_meta.size, None) + self.transfer_meta.provide_transfer_size(5) + self.assertEqual(self.transfer_meta.size, 5) + + def 
test_call_args(self): + call_args = object() + transfer_meta = TransferMeta(call_args) + # Assert the that call args provided is the same as is returned + self.assertIs(transfer_meta.call_args, call_args) + + def test_transfer_id(self): + transfer_meta = TransferMeta(transfer_id=1) + self.assertEqual(transfer_meta.transfer_id, 1) + + def test_user_context(self): + self.transfer_meta.user_context['foo'] = 'bar' + self.assertEqual(self.transfer_meta.user_context, {'foo': 'bar'}) + + +class TestTransferCoordinator(unittest.TestCase): + def setUp(self): + self.transfer_coordinator = TransferCoordinator() + + def test_transfer_id(self): + transfer_coordinator = TransferCoordinator(transfer_id=1) + self.assertEqual(transfer_coordinator.transfer_id, 1) + + def test_repr(self): + transfer_coordinator = TransferCoordinator(transfer_id=1) + self.assertEqual( + repr(transfer_coordinator), 'TransferCoordinator(transfer_id=1)' + ) + + def test_initial_status(self): + # A TransferCoordinator with no progress should have the status + # of not-started + self.assertEqual(self.transfer_coordinator.status, 'not-started') + + def test_set_status_to_queued(self): + self.transfer_coordinator.set_status_to_queued() + self.assertEqual(self.transfer_coordinator.status, 'queued') + + def test_cannot_set_status_to_queued_from_done_state(self): + self.transfer_coordinator.set_exception(RuntimeError) + with self.assertRaises(RuntimeError): + self.transfer_coordinator.set_status_to_queued() + + def test_status_running(self): + self.transfer_coordinator.set_status_to_running() + self.assertEqual(self.transfer_coordinator.status, 'running') + + def test_cannot_set_status_to_running_from_done_state(self): + self.transfer_coordinator.set_exception(RuntimeError) + with self.assertRaises(RuntimeError): + self.transfer_coordinator.set_status_to_running() + + def test_set_result(self): + success_result = 'foo' + self.transfer_coordinator.set_result(success_result) + 
self.transfer_coordinator.announce_done() + # Setting result should result in a success state and the return value + # that was set. + self.assertEqual(self.transfer_coordinator.status, 'success') + self.assertEqual(self.transfer_coordinator.result(), success_result) + + def test_set_exception(self): + exception_result = RuntimeError + self.transfer_coordinator.set_exception(exception_result) + self.transfer_coordinator.announce_done() + # Setting an exception should result in a failed state and the return + # value should be the raised exception + self.assertEqual(self.transfer_coordinator.status, 'failed') + self.assertEqual(self.transfer_coordinator.exception, exception_result) + with self.assertRaises(exception_result): + self.transfer_coordinator.result() + + def test_exception_cannot_override_done_state(self): + self.transfer_coordinator.set_result('foo') + self.transfer_coordinator.set_exception(RuntimeError) + # It status should be success even after the exception is set because + # success is a done state. + self.assertEqual(self.transfer_coordinator.status, 'success') + + def test_exception_can_override_done_state_with_override_flag(self): + self.transfer_coordinator.set_result('foo') + self.transfer_coordinator.set_exception(RuntimeError, override=True) + self.assertEqual(self.transfer_coordinator.status, 'failed') + + def test_cancel(self): + self.assertEqual(self.transfer_coordinator.status, 'not-started') + self.transfer_coordinator.cancel() + # This should set the state to cancelled and raise the CancelledError + # exception and should have also set the done event so that result() + # is no longer set. 
+ self.assertEqual(self.transfer_coordinator.status, 'cancelled') + with self.assertRaises(CancelledError): + self.transfer_coordinator.result() + + def test_cancel_can_run_done_callbacks_that_uses_result(self): + exceptions = [] + + def capture_exception(transfer_coordinator, captured_exceptions): + try: + transfer_coordinator.result() + except Exception as e: + captured_exceptions.append(e) + + self.assertEqual(self.transfer_coordinator.status, 'not-started') + self.transfer_coordinator.add_done_callback( + capture_exception, self.transfer_coordinator, exceptions + ) + self.transfer_coordinator.cancel() + + self.assertEqual(len(exceptions), 1) + self.assertIsInstance(exceptions[0], CancelledError) + + def test_cancel_with_message(self): + message = 'my message' + self.transfer_coordinator.cancel(message) + self.transfer_coordinator.announce_done() + with self.assertRaisesRegex(CancelledError, message): + self.transfer_coordinator.result() + + def test_cancel_with_provided_exception(self): + message = 'my message' + self.transfer_coordinator.cancel(message, exc_type=FatalError) + self.transfer_coordinator.announce_done() + with self.assertRaisesRegex(FatalError, message): + self.transfer_coordinator.result() + + def test_cancel_cannot_override_done_state(self): + self.transfer_coordinator.set_result('foo') + self.transfer_coordinator.cancel() + # It status should be success even after cancel is called because + # success is a done state. + self.assertEqual(self.transfer_coordinator.status, 'success') + + def test_set_result_can_override_cancel(self): + self.transfer_coordinator.cancel() + # Result setting should override any cancel or set exception as this + # is always invoked by the final task. + self.transfer_coordinator.set_result('foo') + self.transfer_coordinator.announce_done() + self.assertEqual(self.transfer_coordinator.status, 'success') + + def test_submit(self): + # Submit a callable to the transfer coordinator. It should submit it + # to the executor. 
+ executor = RecordingExecutor( + BoundedExecutor(1, 1, {'my-tag': TaskSemaphore(1)}) + ) + task = ReturnFooTask(self.transfer_coordinator) + future = self.transfer_coordinator.submit(executor, task, tag='my-tag') + executor.shutdown() + # Make sure the future got submit and executed as well by checking its + # result value which should include the provided future tag. + self.assertEqual( + executor.submissions, + [{'block': True, 'tag': 'my-tag', 'task': task}], + ) + self.assertEqual(future.result(), 'foo') + + def test_association_and_disassociation_on_submit(self): + self.transfer_coordinator = RecordingTransferCoordinator() + + # Submit a callable to the transfer coordinator. + executor = BoundedExecutor(1, 1) + task = ReturnFooTask(self.transfer_coordinator) + future = self.transfer_coordinator.submit(executor, task) + executor.shutdown() + + # Make sure the future that got submitted was associated to the + # transfer future at some point. + self.assertEqual( + self.transfer_coordinator.all_transfer_futures_ever_associated, + {future}, + ) + + # Make sure the future got disassociated once the future is now done + # by looking at the currently associated futures. 
+ self.assertEqual(self.transfer_coordinator.associated_futures, set()) + + def test_done(self): + # These should result in not done state: + # queued + self.assertFalse(self.transfer_coordinator.done()) + # running + self.transfer_coordinator.set_status_to_running() + self.assertFalse(self.transfer_coordinator.done()) + + # These should result in done state: + # failed + self.transfer_coordinator.set_exception(Exception) + self.assertTrue(self.transfer_coordinator.done()) + + # success + self.transfer_coordinator.set_result('foo') + self.assertTrue(self.transfer_coordinator.done()) + + # cancelled + self.transfer_coordinator.cancel() + self.assertTrue(self.transfer_coordinator.done()) + + def test_result_waits_until_done(self): + execution_order = [] + + def sleep_then_set_result(transfer_coordinator, execution_order): + time.sleep(0.05) + execution_order.append('setting_result') + transfer_coordinator.set_result(None) + self.transfer_coordinator.announce_done() + + with ThreadPoolExecutor(max_workers=1) as executor: + executor.submit( + sleep_then_set_result, + self.transfer_coordinator, + execution_order, + ) + self.transfer_coordinator.result() + execution_order.append('after_result') + + # The result() call should have waited until the other thread set + # the result after sleeping for 0.05 seconds. + self.assertTrue(execution_order, ['setting_result', 'after_result']) + + def test_failure_cleanups(self): + args = (1, 2) + kwargs = {'foo': 'bar'} + + second_args = (2, 4) + second_kwargs = {'biz': 'baz'} + + self.transfer_coordinator.add_failure_cleanup( + return_call_args, *args, **kwargs + ) + self.transfer_coordinator.add_failure_cleanup( + return_call_args, *second_args, **second_kwargs + ) + + # Ensure the callbacks got added. + self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 2) + + result_list = [] + # Ensure they will get called in the correct order. 
+ for cleanup in self.transfer_coordinator.failure_cleanups: + result_list.append(cleanup()) + self.assertEqual( + result_list, [(args, kwargs), (second_args, second_kwargs)] + ) + + def test_associated_futures(self): + first_future = object() + # Associate one future to the transfer + self.transfer_coordinator.add_associated_future(first_future) + associated_futures = self.transfer_coordinator.associated_futures + # The first future should be in the returned list of futures. + self.assertEqual(associated_futures, {first_future}) + + second_future = object() + # Associate another future to the transfer. + self.transfer_coordinator.add_associated_future(second_future) + # The association should not have mutated the returned list from + # before. + self.assertEqual(associated_futures, {first_future}) + + # Both futures should be in the returned list. + self.assertEqual( + self.transfer_coordinator.associated_futures, + {first_future, second_future}, + ) + + def test_done_callbacks_on_done(self): + done_callback_invocations = [] + callback = FunctionContainer( + done_callback_invocations.append, 'done callback called' + ) + + # Add the done callback to the transfer. + self.transfer_coordinator.add_done_callback(callback) + + # Announce that the transfer is done. This should invoke the done + # callback. + self.transfer_coordinator.announce_done() + self.assertEqual(done_callback_invocations, ['done callback called']) + + # If done is announced again, we should not invoke the callback again + # because done has already been announced and thus the callback has + # been ran as well. + self.transfer_coordinator.announce_done() + self.assertEqual(done_callback_invocations, ['done callback called']) + + def test_failure_cleanups_on_done(self): + cleanup_invocations = [] + callback = FunctionContainer( + cleanup_invocations.append, 'cleanup called' + ) + + # Add the failure cleanup to the transfer. 
+ self.transfer_coordinator.add_failure_cleanup(callback) + + # Announce that the transfer is done. This should invoke the failure + # cleanup. + self.transfer_coordinator.announce_done() + self.assertEqual(cleanup_invocations, ['cleanup called']) + + # If done is announced again, we should not invoke the cleanup again + # because done has already been announced and thus the cleanup has + # been ran as well. + self.transfer_coordinator.announce_done() + self.assertEqual(cleanup_invocations, ['cleanup called']) + + +class TestBoundedExecutor(unittest.TestCase): + def setUp(self): + self.coordinator = TransferCoordinator() + self.tag_semaphores = {} + self.executor = self.get_executor() + + def get_executor(self, max_size=1, max_num_threads=1): + return BoundedExecutor(max_size, max_num_threads, self.tag_semaphores) + + def get_task(self, task_cls, main_kwargs=None): + return task_cls(self.coordinator, main_kwargs=main_kwargs) + + def get_sleep_task(self, sleep_time=0.01): + return self.get_task(SleepTask, main_kwargs={'sleep_time': sleep_time}) + + def add_semaphore(self, task_tag, count): + self.tag_semaphores[task_tag] = TaskSemaphore(count) + + def assert_submit_would_block(self, task, tag=None): + with self.assertRaises(NoResourcesAvailable): + self.executor.submit(task, tag=tag, block=False) + + def assert_submit_would_not_block(self, task, tag=None, **kwargs): + try: + self.executor.submit(task, tag=tag, block=False) + except NoResourcesAvailable: + self.fail( + 'Task {} should not have been blocked. 
Caused by:\n{}'.format( + task, traceback.format_exc() + ) + ) + + def add_done_callback_to_future(self, future, fn, *args, **kwargs): + callback_for_future = FunctionContainer(fn, *args, **kwargs) + future.add_done_callback(callback_for_future) + + def test_submit_single_task(self): + # Ensure we can submit a task to the executor + task = self.get_task(ReturnFooTask) + future = self.executor.submit(task) + + # Ensure what we get back is a Future + self.assertIsInstance(future, ExecutorFuture) + # Ensure the callable got executed. + self.assertEqual(future.result(), 'foo') + + @unittest.skipIf( + os.environ.get('USE_SERIAL_EXECUTOR'), + "Not supported with serial executor tests", + ) + def test_executor_blocks_on_full_capacity(self): + first_task = self.get_sleep_task() + second_task = self.get_sleep_task() + self.executor.submit(first_task) + # The first task should be sleeping for a substantial period of + # time such that on the submission of the second task, it will + # raise an error saying that it cannot be submitted as the max + # capacity of the semaphore is one. + self.assert_submit_would_block(second_task) + + def test_executor_clears_capacity_on_done_tasks(self): + first_task = self.get_sleep_task() + second_task = self.get_task(ReturnFooTask) + + # Submit a task. + future = self.executor.submit(first_task) + + # Submit a new task when the first task finishes. This should not get + # blocked because the first task should have finished clearing up + # capacity. + self.add_done_callback_to_future( + future, self.assert_submit_would_not_block, second_task + ) + + # Wait for it to complete. + self.executor.shutdown() + + @unittest.skipIf( + os.environ.get('USE_SERIAL_EXECUTOR'), + "Not supported with serial executor tests", + ) + def test_would_not_block_when_full_capacity_in_other_semaphore(self): + first_task = self.get_sleep_task() + + # Now let's create a new task with a tag and so it uses different + # semaphore. 
+ task_tag = 'other' + other_task = self.get_sleep_task() + self.add_semaphore(task_tag, 1) + + # Submit the normal first task + self.executor.submit(first_task) + + # Even though The first task should be sleeping for a substantial + # period of time, the submission of the second task should not + # raise an error because it should use a different semaphore + self.assert_submit_would_not_block(other_task, task_tag) + + # Another submission of the other task though should raise + # an exception as the capacity is equal to one for that tag. + self.assert_submit_would_block(other_task, task_tag) + + def test_shutdown(self): + slow_task = self.get_sleep_task() + future = self.executor.submit(slow_task) + self.executor.shutdown() + # Ensure that the shutdown waits until the task is done + self.assertTrue(future.done()) + + @unittest.skipIf( + os.environ.get('USE_SERIAL_EXECUTOR'), + "Not supported with serial executor tests", + ) + def test_shutdown_no_wait(self): + slow_task = self.get_sleep_task() + future = self.executor.submit(slow_task) + self.executor.shutdown(False) + # Ensure that the shutdown returns immediately even if the task is + # not done, which it should not be because it it slow. 
+ self.assertFalse(future.done()) + + def test_replace_underlying_executor(self): + mocked_executor_cls = mock.Mock(BaseExecutor) + executor = BoundedExecutor(10, 1, {}, mocked_executor_cls) + executor.submit(self.get_task(ReturnFooTask)) + self.assertTrue(mocked_executor_cls.return_value.submit.called) + + +class TestExecutorFuture(unittest.TestCase): + def test_result(self): + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(return_call_args, 'foo', biz='baz') + wrapped_future = ExecutorFuture(future) + self.assertEqual(wrapped_future.result(), (('foo',), {'biz': 'baz'})) + + def test_done(self): + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(return_call_args, 'foo', biz='baz') + wrapped_future = ExecutorFuture(future) + self.assertTrue(wrapped_future.done()) + + def test_add_done_callback(self): + done_callbacks = [] + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(return_call_args, 'foo', biz='baz') + wrapped_future = ExecutorFuture(future) + wrapped_future.add_done_callback( + FunctionContainer(done_callbacks.append, 'called') + ) + self.assertEqual(done_callbacks, ['called']) + + +class TestNonThreadedExecutor(unittest.TestCase): + def test_submit(self): + executor = NonThreadedExecutor() + future = executor.submit(return_call_args, 1, 2, foo='bar') + self.assertIsInstance(future, NonThreadedExecutorFuture) + self.assertEqual(future.result(), ((1, 2), {'foo': 'bar'})) + + def test_submit_with_exception(self): + executor = NonThreadedExecutor() + future = executor.submit(raise_exception, RuntimeError()) + self.assertIsInstance(future, NonThreadedExecutorFuture) + with self.assertRaises(RuntimeError): + future.result() + + def test_submit_with_exception_and_captures_info(self): + exception = ValueError('message') + tb = get_exc_info(exception)[2] + future = NonThreadedExecutor().submit(raise_exception, exception) + try: + future.result() + # An exception should have 
been raised + self.fail('Future should have raised a ValueError') + except ValueError: + actual_tb = sys.exc_info()[2] + last_frame = traceback.extract_tb(actual_tb)[-1] + last_expected_frame = traceback.extract_tb(tb)[-1] + self.assertEqual(last_frame, last_expected_frame) + + +class TestNonThreadedExecutorFuture(unittest.TestCase): + def setUp(self): + self.future = NonThreadedExecutorFuture() + + def test_done_starts_false(self): + self.assertFalse(self.future.done()) + + def test_done_after_setting_result(self): + self.future.set_result('result') + self.assertTrue(self.future.done()) + + def test_done_after_setting_exception(self): + self.future.set_exception_info(Exception(), None) + self.assertTrue(self.future.done()) + + def test_result(self): + self.future.set_result('result') + self.assertEqual(self.future.result(), 'result') + + def test_exception_result(self): + exception = ValueError('message') + self.future.set_exception_info(exception, None) + with self.assertRaisesRegex(ValueError, 'message'): + self.future.result() + + def test_exception_result_doesnt_modify_last_frame(self): + exception = ValueError('message') + tb = get_exc_info(exception)[2] + self.future.set_exception_info(exception, tb) + try: + self.future.result() + # An exception should have been raised + self.fail() + except ValueError: + actual_tb = sys.exc_info()[2] + last_frame = traceback.extract_tb(actual_tb)[-1] + last_expected_frame = traceback.extract_tb(tb)[-1] + self.assertEqual(last_frame, last_expected_frame) + + def test_done_callback(self): + done_futures = [] + self.future.add_done_callback(done_futures.append) + self.assertEqual(done_futures, []) + self.future.set_result('result') + self.assertEqual(done_futures, [self.future]) + + def test_done_callback_after_done(self): + self.future.set_result('result') + done_futures = [] + self.future.add_done_callback(done_futures.append) + self.assertEqual(done_futures, [self.future]) diff --git 
a/contrib/python/s3transfer/py3/tests/unit/test_manager.py b/contrib/python/s3transfer/py3/tests/unit/test_manager.py new file mode 100644 index 0000000000..fc3caa843f --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/unit/test_manager.py @@ -0,0 +1,143 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import time +from concurrent.futures import ThreadPoolExecutor + +from s3transfer.exceptions import CancelledError, FatalError +from s3transfer.futures import TransferCoordinator +from s3transfer.manager import TransferConfig, TransferCoordinatorController +from __tests__ import TransferCoordinatorWithInterrupt, unittest + + +class FutureResultException(Exception): + pass + + +class TestTransferConfig(unittest.TestCase): + def test_exception_on_zero_attr_value(self): + with self.assertRaises(ValueError): + TransferConfig(max_request_queue_size=0) + + +class TestTransferCoordinatorController(unittest.TestCase): + def setUp(self): + self.coordinator_controller = TransferCoordinatorController() + + def sleep_then_announce_done(self, transfer_coordinator, sleep_time): + time.sleep(sleep_time) + transfer_coordinator.set_result('done') + transfer_coordinator.announce_done() + + def assert_coordinator_is_cancelled(self, transfer_coordinator): + self.assertEqual(transfer_coordinator.status, 'cancelled') + + def test_add_transfer_coordinator(self): + transfer_coordinator = TransferCoordinator() + # Add the transfer coordinator + 
self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + # Ensure that is tracked. + self.assertEqual( + self.coordinator_controller.tracked_transfer_coordinators, + {transfer_coordinator}, + ) + + def test_remove_transfer_coordinator(self): + transfer_coordinator = TransferCoordinator() + # Add the coordinator + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + # Now remove the coordinator + self.coordinator_controller.remove_transfer_coordinator( + transfer_coordinator + ) + # Make sure that it is no longer getting tracked. + self.assertEqual( + self.coordinator_controller.tracked_transfer_coordinators, set() + ) + + def test_cancel(self): + transfer_coordinator = TransferCoordinator() + # Add the transfer coordinator + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + # Cancel with the canceler + self.coordinator_controller.cancel() + # Check that coordinator got canceled + self.assert_coordinator_is_cancelled(transfer_coordinator) + + def test_cancel_with_message(self): + message = 'my cancel message' + transfer_coordinator = TransferCoordinator() + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + self.coordinator_controller.cancel(message) + transfer_coordinator.announce_done() + with self.assertRaisesRegex(CancelledError, message): + transfer_coordinator.result() + + def test_cancel_with_provided_exception(self): + message = 'my cancel message' + transfer_coordinator = TransferCoordinator() + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + self.coordinator_controller.cancel(message, exc_type=FatalError) + transfer_coordinator.announce_done() + with self.assertRaisesRegex(FatalError, message): + transfer_coordinator.result() + + def test_wait_for_done_transfer_coordinators(self): + # Create a coordinator and add it to the canceler + transfer_coordinator = TransferCoordinator() + 
self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + + sleep_time = 0.02 + with ThreadPoolExecutor(max_workers=1) as executor: + # In a separate thread sleep and then set the transfer coordinator + # to done after sleeping. + start_time = time.time() + executor.submit( + self.sleep_then_announce_done, transfer_coordinator, sleep_time + ) + # Now call wait to wait for the transfer coordinator to be done. + self.coordinator_controller.wait() + end_time = time.time() + wait_time = end_time - start_time + # The time waited should not be less than the time it took to sleep in + # the separate thread because the wait ending should be dependent on + # the sleeping thread announcing that the transfer coordinator is done. + self.assertTrue(sleep_time <= wait_time) + + def test_wait_does_not_propogate_exceptions_from_result(self): + transfer_coordinator = TransferCoordinator() + transfer_coordinator.set_exception(FutureResultException()) + transfer_coordinator.announce_done() + try: + self.coordinator_controller.wait() + except FutureResultException as e: + self.fail('%s should not have been raised.' % e) + + def test_wait_can_be_interrupted(self): + inject_interrupt_coordinator = TransferCoordinatorWithInterrupt() + self.coordinator_controller.add_transfer_coordinator( + inject_interrupt_coordinator + ) + with self.assertRaises(KeyboardInterrupt): + self.coordinator_controller.wait() diff --git a/contrib/python/s3transfer/py3/tests/unit/test_processpool.py b/contrib/python/s3transfer/py3/tests/unit/test_processpool.py new file mode 100644 index 0000000000..d77b5e0240 --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/unit/test_processpool.py @@ -0,0 +1,728 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import os +import queue +import signal +import threading +import time +from io import BytesIO + +from botocore.client import BaseClient +from botocore.config import Config +from botocore.exceptions import ClientError, ReadTimeoutError + +from s3transfer.constants import PROCESS_USER_AGENT +from s3transfer.exceptions import CancelledError, RetriesExceededError +from s3transfer.processpool import ( + SHUTDOWN_SIGNAL, + ClientFactory, + DownloadFileRequest, + GetObjectJob, + GetObjectSubmitter, + GetObjectWorker, + ProcessPoolDownloader, + ProcessPoolTransferFuture, + ProcessPoolTransferMeta, + ProcessTransferConfig, + TransferMonitor, + TransferState, + ignore_ctrl_c, +) +from s3transfer.utils import CallArgs, OSUtils +from __tests__ import ( + FileCreator, + StreamWithError, + StubbedClientTest, + mock, + skip_if_windows, + unittest, +) + + +class RenameFailingOSUtils(OSUtils): + def __init__(self, exception): + self.exception = exception + + def rename_file(self, current_filename, new_filename): + raise self.exception + + +class TestIgnoreCtrlC(unittest.TestCase): + @skip_if_windows('os.kill() with SIGINT not supported on Windows') + def test_ignore_ctrl_c(self): + with ignore_ctrl_c(): + try: + os.kill(os.getpid(), signal.SIGINT) + except KeyboardInterrupt: + self.fail( + 'The ignore_ctrl_c context manager should have ' + 'ignored the KeyboardInterrupt exception' + ) + + +class TestProcessPoolDownloader(unittest.TestCase): + def test_uses_client_kwargs(self): + with mock.patch('s3transfer.processpool.ClientFactory') as factory: + ProcessPoolDownloader(client_kwargs={'region_name': 'myregion'}) + 
self.assertEqual( + factory.call_args[0][0], {'region_name': 'myregion'} + ) + + +class TestProcessPoolTransferFuture(unittest.TestCase): + def setUp(self): + self.monitor = TransferMonitor() + self.transfer_id = self.monitor.notify_new_transfer() + self.meta = ProcessPoolTransferMeta( + transfer_id=self.transfer_id, call_args=CallArgs() + ) + self.future = ProcessPoolTransferFuture( + monitor=self.monitor, meta=self.meta + ) + + def test_meta(self): + self.assertEqual(self.future.meta, self.meta) + + def test_done(self): + self.assertFalse(self.future.done()) + self.monitor.notify_done(self.transfer_id) + self.assertTrue(self.future.done()) + + def test_result(self): + self.monitor.notify_done(self.transfer_id) + self.assertIsNone(self.future.result()) + + def test_result_with_exception(self): + self.monitor.notify_exception(self.transfer_id, RuntimeError()) + self.monitor.notify_done(self.transfer_id) + with self.assertRaises(RuntimeError): + self.future.result() + + def test_result_with_keyboard_interrupt(self): + mock_monitor = mock.Mock(TransferMonitor) + mock_monitor._connect = mock.Mock() + mock_monitor.poll_for_result.side_effect = KeyboardInterrupt() + future = ProcessPoolTransferFuture( + monitor=mock_monitor, meta=self.meta + ) + with self.assertRaises(KeyboardInterrupt): + future.result() + self.assertTrue(mock_monitor._connect.called) + self.assertTrue(mock_monitor.notify_exception.called) + call_args = mock_monitor.notify_exception.call_args[0] + self.assertEqual(call_args[0], self.transfer_id) + self.assertIsInstance(call_args[1], CancelledError) + + def test_cancel(self): + self.future.cancel() + self.monitor.notify_done(self.transfer_id) + with self.assertRaises(CancelledError): + self.future.result() + + +class TestProcessPoolTransferMeta(unittest.TestCase): + def test_transfer_id(self): + meta = ProcessPoolTransferMeta(1, CallArgs()) + self.assertEqual(meta.transfer_id, 1) + + def test_call_args(self): + call_args = CallArgs() + meta = 
ProcessPoolTransferMeta(1, call_args) + self.assertEqual(meta.call_args, call_args) + + def test_user_context(self): + meta = ProcessPoolTransferMeta(1, CallArgs()) + self.assertEqual(meta.user_context, {}) + meta.user_context['mykey'] = 'myvalue' + self.assertEqual(meta.user_context, {'mykey': 'myvalue'}) + + +class TestClientFactory(unittest.TestCase): + def test_create_client(self): + client = ClientFactory().create_client() + self.assertIsInstance(client, BaseClient) + self.assertEqual(client.meta.service_model.service_name, 's3') + self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) + + def test_create_client_with_client_kwargs(self): + client = ClientFactory({'region_name': 'myregion'}).create_client() + self.assertEqual(client.meta.region_name, 'myregion') + + def test_user_agent_with_config(self): + client = ClientFactory({'config': Config()}).create_client() + self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) + + def test_user_agent_with_existing_user_agent_extra(self): + config = Config(user_agent_extra='foo/1.0') + client = ClientFactory({'config': config}).create_client() + self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) + + def test_user_agent_with_existing_user_agent(self): + config = Config(user_agent='foo/1.0') + client = ClientFactory({'config': config}).create_client() + self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) + + +class TestTransferMonitor(unittest.TestCase): + def setUp(self): + self.monitor = TransferMonitor() + self.transfer_id = self.monitor.notify_new_transfer() + + def test_notify_new_transfer_creates_new_state(self): + monitor = TransferMonitor() + transfer_id = monitor.notify_new_transfer() + self.assertFalse(monitor.is_done(transfer_id)) + self.assertIsNone(monitor.get_exception(transfer_id)) + + def test_notify_new_transfer_increments_transfer_id(self): + monitor = TransferMonitor() + self.assertEqual(monitor.notify_new_transfer(), 0) + 
self.assertEqual(monitor.notify_new_transfer(), 1) + + def test_notify_get_exception(self): + exception = Exception() + self.monitor.notify_exception(self.transfer_id, exception) + self.assertEqual( + self.monitor.get_exception(self.transfer_id), exception + ) + + def test_get_no_exception(self): + self.assertIsNone(self.monitor.get_exception(self.transfer_id)) + + def test_notify_jobs(self): + self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2) + self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1) + self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 0) + + def test_notify_jobs_for_multiple_transfers(self): + self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2) + other_transfer_id = self.monitor.notify_new_transfer() + self.monitor.notify_expected_jobs_to_complete(other_transfer_id, 2) + self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1) + self.assertEqual( + self.monitor.notify_job_complete(other_transfer_id), 1 + ) + + def test_done(self): + self.assertFalse(self.monitor.is_done(self.transfer_id)) + self.monitor.notify_done(self.transfer_id) + self.assertTrue(self.monitor.is_done(self.transfer_id)) + + def test_poll_for_result(self): + self.monitor.notify_done(self.transfer_id) + self.assertIsNone(self.monitor.poll_for_result(self.transfer_id)) + + def test_poll_for_result_raises_error(self): + self.monitor.notify_exception(self.transfer_id, RuntimeError()) + self.monitor.notify_done(self.transfer_id) + with self.assertRaises(RuntimeError): + self.monitor.poll_for_result(self.transfer_id) + + def test_poll_for_result_waits_till_done(self): + event_order = [] + + def sleep_then_notify_done(): + time.sleep(0.05) + event_order.append('notify_done') + self.monitor.notify_done(self.transfer_id) + + t = threading.Thread(target=sleep_then_notify_done) + t.start() + + self.monitor.poll_for_result(self.transfer_id) + event_order.append('done_polling') + self.assertEqual(event_order, 
['notify_done', 'done_polling']) + + def test_notify_cancel_all_in_progress(self): + monitor = TransferMonitor() + transfer_ids = [] + for _ in range(10): + transfer_ids.append(monitor.notify_new_transfer()) + monitor.notify_cancel_all_in_progress() + for transfer_id in transfer_ids: + self.assertIsInstance( + monitor.get_exception(transfer_id), CancelledError + ) + # Cancelling a transfer does not mean it is done as there may + # be cleanup work left to do. + self.assertFalse(monitor.is_done(transfer_id)) + + def test_notify_cancel_does_not_affect_done_transfers(self): + self.monitor.notify_done(self.transfer_id) + self.monitor.notify_cancel_all_in_progress() + self.assertTrue(self.monitor.is_done(self.transfer_id)) + self.assertIsNone(self.monitor.get_exception(self.transfer_id)) + + +class TestTransferState(unittest.TestCase): + def setUp(self): + self.state = TransferState() + + def test_done(self): + self.assertFalse(self.state.done) + self.state.set_done() + self.assertTrue(self.state.done) + + def test_waits_till_done_is_set(self): + event_order = [] + + def sleep_then_set_done(): + time.sleep(0.05) + event_order.append('set_done') + self.state.set_done() + + t = threading.Thread(target=sleep_then_set_done) + t.start() + + self.state.wait_till_done() + event_order.append('done_waiting') + self.assertEqual(event_order, ['set_done', 'done_waiting']) + + def test_exception(self): + exception = RuntimeError() + self.state.exception = exception + self.assertEqual(self.state.exception, exception) + + def test_jobs_to_complete(self): + self.state.jobs_to_complete = 5 + self.assertEqual(self.state.jobs_to_complete, 5) + + def test_decrement_jobs_to_complete(self): + self.state.jobs_to_complete = 5 + self.assertEqual(self.state.decrement_jobs_to_complete(), 4) + + +class TestGetObjectSubmitter(StubbedClientTest): + def setUp(self): + super().setUp() + self.transfer_config = ProcessTransferConfig() + self.client_factory = mock.Mock(ClientFactory) + 
self.client_factory.create_client.return_value = self.client + self.transfer_monitor = TransferMonitor() + self.osutil = mock.Mock(OSUtils) + self.download_request_queue = queue.Queue() + self.worker_queue = queue.Queue() + self.submitter = GetObjectSubmitter( + transfer_config=self.transfer_config, + client_factory=self.client_factory, + transfer_monitor=self.transfer_monitor, + osutil=self.osutil, + download_request_queue=self.download_request_queue, + worker_queue=self.worker_queue, + ) + self.transfer_id = self.transfer_monitor.notify_new_transfer() + self.bucket = 'bucket' + self.key = 'key' + self.filename = 'myfile' + self.temp_filename = 'myfile.temp' + self.osutil.get_temp_filename.return_value = self.temp_filename + self.extra_args = {} + self.expected_size = None + + def add_download_file_request(self, **override_kwargs): + kwargs = { + 'transfer_id': self.transfer_id, + 'bucket': self.bucket, + 'key': self.key, + 'filename': self.filename, + 'extra_args': self.extra_args, + 'expected_size': self.expected_size, + } + kwargs.update(override_kwargs) + self.download_request_queue.put(DownloadFileRequest(**kwargs)) + + def add_shutdown(self): + self.download_request_queue.put(SHUTDOWN_SIGNAL) + + def assert_submitted_get_object_jobs(self, expected_jobs): + actual_jobs = [] + while not self.worker_queue.empty(): + actual_jobs.append(self.worker_queue.get()) + self.assertEqual(actual_jobs, expected_jobs) + + def test_run_for_non_ranged_download(self): + self.add_download_file_request(expected_size=1) + self.add_shutdown() + self.submitter.run() + self.osutil.allocate.assert_called_with(self.temp_filename, 1) + self.assert_submitted_get_object_jobs( + [ + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=0, + extra_args={}, + filename=self.filename, + ) + ] + ) + + def test_run_for_ranged_download(self): + self.transfer_config.multipart_chunksize = 2 + 
self.transfer_config.multipart_threshold = 4 + self.add_download_file_request(expected_size=4) + self.add_shutdown() + self.submitter.run() + self.osutil.allocate.assert_called_with(self.temp_filename, 4) + self.assert_submitted_get_object_jobs( + [ + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=0, + extra_args={'Range': 'bytes=0-1'}, + filename=self.filename, + ), + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=2, + extra_args={'Range': 'bytes=2-'}, + filename=self.filename, + ), + ] + ) + + def test_run_when_expected_size_not_provided(self): + self.stubber.add_response( + 'head_object', + {'ContentLength': 1}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + self.add_download_file_request(expected_size=None) + self.add_shutdown() + self.submitter.run() + self.stubber.assert_no_pending_responses() + self.osutil.allocate.assert_called_with(self.temp_filename, 1) + self.assert_submitted_get_object_jobs( + [ + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=0, + extra_args={}, + filename=self.filename, + ) + ] + ) + + def test_run_with_extra_args(self): + self.stubber.add_response( + 'head_object', + {'ContentLength': 1}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'VersionId': 'versionid', + }, + ) + self.add_download_file_request( + extra_args={'VersionId': 'versionid'}, expected_size=None + ) + self.add_shutdown() + self.submitter.run() + self.stubber.assert_no_pending_responses() + self.osutil.allocate.assert_called_with(self.temp_filename, 1) + self.assert_submitted_get_object_jobs( + [ + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=0, + extra_args={'VersionId': 'versionid'}, + 
filename=self.filename, + ) + ] + ) + + def test_run_with_exception(self): + self.stubber.add_client_error('head_object', 'NoSuchKey', 404) + self.add_download_file_request(expected_size=None) + self.add_shutdown() + self.submitter.run() + self.stubber.assert_no_pending_responses() + self.assert_submitted_get_object_jobs([]) + self.assertIsInstance( + self.transfer_monitor.get_exception(self.transfer_id), ClientError + ) + + def test_run_with_error_in_allocating_temp_file(self): + self.osutil.allocate.side_effect = OSError() + self.add_download_file_request(expected_size=1) + self.add_shutdown() + self.submitter.run() + self.assert_submitted_get_object_jobs([]) + self.assertIsInstance( + self.transfer_monitor.get_exception(self.transfer_id), OSError + ) + + @skip_if_windows('os.kill() with SIGINT not supported on Windows') + def test_submitter_cannot_be_killed(self): + self.add_download_file_request(expected_size=None) + self.add_shutdown() + + def raise_ctrl_c(**kwargs): + os.kill(os.getpid(), signal.SIGINT) + + mock_client = mock.Mock() + mock_client.head_object = raise_ctrl_c + self.client_factory.create_client.return_value = mock_client + + try: + self.submitter.run() + except KeyboardInterrupt: + self.fail( + 'The submitter should have not been killed by the ' + 'KeyboardInterrupt' + ) + + +class TestGetObjectWorker(StubbedClientTest): + def setUp(self): + super().setUp() + self.files = FileCreator() + self.queue = queue.Queue() + self.client_factory = mock.Mock(ClientFactory) + self.client_factory.create_client.return_value = self.client + self.transfer_monitor = TransferMonitor() + self.osutil = OSUtils() + self.worker = GetObjectWorker( + queue=self.queue, + client_factory=self.client_factory, + transfer_monitor=self.transfer_monitor, + osutil=self.osutil, + ) + self.transfer_id = self.transfer_monitor.notify_new_transfer() + self.bucket = 'bucket' + self.key = 'key' + self.remote_contents = b'my content' + self.temp_filename = 
self.files.create_file('tempfile', '') + self.extra_args = {} + self.offset = 0 + self.final_filename = self.files.full_path('final_filename') + self.stream = BytesIO(self.remote_contents) + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1000 + ) + + def tearDown(self): + super().tearDown() + self.files.remove_all() + + def add_get_object_job(self, **override_kwargs): + kwargs = { + 'transfer_id': self.transfer_id, + 'bucket': self.bucket, + 'key': self.key, + 'temp_filename': self.temp_filename, + 'extra_args': self.extra_args, + 'offset': self.offset, + 'filename': self.final_filename, + } + kwargs.update(override_kwargs) + self.queue.put(GetObjectJob(**kwargs)) + + def add_shutdown(self): + self.queue.put(SHUTDOWN_SIGNAL) + + def add_stubbed_get_object_response(self, body=None, expected_params=None): + if body is None: + body = self.stream + get_object_response = {'Body': body} + + if expected_params is None: + expected_params = {'Bucket': self.bucket, 'Key': self.key} + + self.stubber.add_response( + 'get_object', get_object_response, expected_params + ) + + def assert_contents(self, filename, contents): + self.assertTrue(os.path.exists(filename)) + with open(filename, 'rb') as f: + self.assertEqual(f.read(), contents) + + def assert_does_not_exist(self, filename): + self.assertFalse(os.path.exists(filename)) + + def test_run_is_final_job(self): + self.add_get_object_job() + self.add_shutdown() + self.add_stubbed_get_object_response() + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1 + ) + + self.worker.run() + self.stubber.assert_no_pending_responses() + self.assert_does_not_exist(self.temp_filename) + self.assert_contents(self.final_filename, self.remote_contents) + + def test_run_jobs_is_not_final_job(self): + self.add_get_object_job() + self.add_shutdown() + self.add_stubbed_get_object_response() + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1000 + ) + + 
self.worker.run() + self.stubber.assert_no_pending_responses() + self.assert_contents(self.temp_filename, self.remote_contents) + self.assert_does_not_exist(self.final_filename) + + def test_run_with_extra_args(self): + self.add_get_object_job(extra_args={'VersionId': 'versionid'}) + self.add_shutdown() + self.add_stubbed_get_object_response( + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'VersionId': 'versionid', + } + ) + + self.worker.run() + self.stubber.assert_no_pending_responses() + + def test_run_with_offset(self): + offset = 1 + self.add_get_object_job(offset=offset) + self.add_shutdown() + self.add_stubbed_get_object_response() + + self.worker.run() + with open(self.temp_filename, 'rb') as f: + f.seek(offset) + self.assertEqual(f.read(), self.remote_contents) + + def test_run_error_in_get_object(self): + self.add_get_object_job() + self.add_shutdown() + self.stubber.add_client_error('get_object', 'NoSuchKey', 404) + self.add_stubbed_get_object_response() + + self.worker.run() + self.assertIsInstance( + self.transfer_monitor.get_exception(self.transfer_id), ClientError + ) + + def test_run_does_retries_for_get_object(self): + self.add_get_object_job() + self.add_shutdown() + self.add_stubbed_get_object_response( + body=StreamWithError( + self.stream, ReadTimeoutError(endpoint_url='') + ) + ) + self.add_stubbed_get_object_response() + + self.worker.run() + self.stubber.assert_no_pending_responses() + self.assert_contents(self.temp_filename, self.remote_contents) + + def test_run_can_exhaust_retries_for_get_object(self): + self.add_get_object_job() + self.add_shutdown() + # 5 is the current setting for max number of GetObject attempts + for _ in range(5): + self.add_stubbed_get_object_response( + body=StreamWithError( + self.stream, ReadTimeoutError(endpoint_url='') + ) + ) + + self.worker.run() + self.stubber.assert_no_pending_responses() + self.assertIsInstance( + self.transfer_monitor.get_exception(self.transfer_id), + 
RetriesExceededError, + ) + + def test_run_skips_get_object_on_previous_exception(self): + self.add_get_object_job() + self.add_shutdown() + self.transfer_monitor.notify_exception(self.transfer_id, Exception()) + + self.worker.run() + # Note we did not add a stubbed response for get_object + self.stubber.assert_no_pending_responses() + + def test_run_final_job_removes_file_on_previous_exception(self): + self.add_get_object_job() + self.add_shutdown() + self.transfer_monitor.notify_exception(self.transfer_id, Exception()) + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1 + ) + + self.worker.run() + self.stubber.assert_no_pending_responses() + self.assert_does_not_exist(self.temp_filename) + self.assert_does_not_exist(self.final_filename) + + def test_run_fails_to_rename_file(self): + exception = OSError() + osutil = RenameFailingOSUtils(exception) + self.worker = GetObjectWorker( + queue=self.queue, + client_factory=self.client_factory, + transfer_monitor=self.transfer_monitor, + osutil=osutil, + ) + self.add_get_object_job() + self.add_shutdown() + self.add_stubbed_get_object_response() + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1 + ) + + self.worker.run() + self.assertEqual( + self.transfer_monitor.get_exception(self.transfer_id), exception + ) + self.assert_does_not_exist(self.temp_filename) + self.assert_does_not_exist(self.final_filename) + + @skip_if_windows('os.kill() with SIGINT not supported on Windows') + def test_worker_cannot_be_killed(self): + self.add_get_object_job() + self.add_shutdown() + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1 + ) + + def raise_ctrl_c(**kwargs): + os.kill(os.getpid(), signal.SIGINT) + + mock_client = mock.Mock() + mock_client.get_object = raise_ctrl_c + self.client_factory.create_client.return_value = mock_client + + try: + self.worker.run() + except KeyboardInterrupt: + self.fail( + 'The worker should have not been killed by the 
' + 'KeyboardInterrupt' + ) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py b/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py new file mode 100644 index 0000000000..35cf4a22dd --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py @@ -0,0 +1,780 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import os +import shutil +import socket +import tempfile +from concurrent import futures +from contextlib import closing +from io import BytesIO, StringIO + +from s3transfer import ( + MultipartDownloader, + MultipartUploader, + OSUtils, + QueueShutdownError, + ReadFileChunk, + S3Transfer, + ShutdownQueue, + StreamReaderProgress, + TransferConfig, + disable_upload_callbacks, + enable_upload_callbacks, + random_file_extension, +) +from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError +from __tests__ import mock, unittest + + +class InMemoryOSLayer(OSUtils): + def __init__(self, filemap): + self.filemap = filemap + + def get_file_size(self, filename): + return len(self.filemap[filename]) + + def open_file_chunk_reader(self, filename, start_byte, size, callback): + return closing(BytesIO(self.filemap[filename])) + + def open(self, filename, mode): + if 'wb' in mode: + fileobj = BytesIO() + self.filemap[filename] = fileobj + return closing(fileobj) + else: + return closing(self.filemap[filename]) + + def remove_file(self, filename): + if filename in self.filemap: + del 
self.filemap[filename] + + def rename_file(self, current_filename, new_filename): + if current_filename in self.filemap: + self.filemap[new_filename] = self.filemap.pop(current_filename) + + +class SequentialExecutor: + def __init__(self, max_workers): + pass + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + pass + + # The real map() interface actually takes *args, but we specifically do + # _not_ use this interface. + def map(self, function, args): + results = [] + for arg in args: + results.append(function(arg)) + return results + + def submit(self, function): + future = futures.Future() + future.set_result(function()) + return future + + +class TestOSUtils(unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def test_get_file_size(self): + with mock.patch('os.path.getsize') as m: + OSUtils().get_file_size('myfile') + m.assert_called_with('myfile') + + def test_open_file_chunk_reader(self): + with mock.patch('s3transfer.ReadFileChunk') as m: + OSUtils().open_file_chunk_reader('myfile', 0, 100, None) + m.from_filename.assert_called_with( + 'myfile', 0, 100, None, enable_callback=False + ) + + def test_open_file(self): + fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w') + self.assertTrue(hasattr(fileobj, 'write')) + + def test_remove_file_ignores_errors(self): + with mock.patch('os.remove') as remove: + remove.side_effect = OSError('fake error') + OSUtils().remove_file('foo') + remove.assert_called_with('foo') + + def test_remove_file_proxies_remove_file(self): + with mock.patch('os.remove') as remove: + OSUtils().remove_file('foo') + remove.assert_called_with('foo') + + def test_rename_file(self): + with mock.patch('s3transfer.compat.rename_file') as rename_file: + OSUtils().rename_file('foo', 'newfoo') + rename_file.assert_called_with('foo', 'newfoo') + + +class TestReadFileChunk(unittest.TestCase): + def setUp(self): + self.tempdir = 
tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def test_read_entire_chunk(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=3 + ) + self.assertEqual(chunk.read(), b'one') + self.assertEqual(chunk.read(), b'') + + def test_read_with_amount_size(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=11, chunk_size=4 + ) + self.assertEqual(chunk.read(1), b'f') + self.assertEqual(chunk.read(1), b'o') + self.assertEqual(chunk.read(1), b'u') + self.assertEqual(chunk.read(1), b'r') + self.assertEqual(chunk.read(1), b'') + + def test_reset_stream_emulation(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=11, chunk_size=4 + ) + self.assertEqual(chunk.read(), b'four') + chunk.seek(0) + self.assertEqual(chunk.read(), b'four') + + def test_read_past_end_of_file(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=36, chunk_size=100000 + ) + self.assertEqual(chunk.read(), b'ten') + self.assertEqual(chunk.read(), b'') + self.assertEqual(len(chunk), 3) + + def test_tell_and_seek(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=36, chunk_size=100000 + ) + self.assertEqual(chunk.tell(), 0) + self.assertEqual(chunk.read(), b'ten') + self.assertEqual(chunk.tell(), 3) + chunk.seek(0) + 
self.assertEqual(chunk.tell(), 0) + + def test_file_chunk_supports_context_manager(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'abc') + with ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=2 + ) as chunk: + val = chunk.read() + self.assertEqual(val, b'ab') + + def test_iter_is_always_empty(self): + # This tests the workaround for the httplib bug (see + # the source for more info). + filename = os.path.join(self.tempdir, 'foo') + open(filename, 'wb').close() + chunk = ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=10 + ) + self.assertEqual(list(chunk), []) + + +class TestReadFileChunkWithCallback(TestReadFileChunk): + def setUp(self): + super().setUp() + self.filename = os.path.join(self.tempdir, 'foo') + with open(self.filename, 'wb') as f: + f.write(b'abc') + self.amounts_seen = [] + + def callback(self, amount): + self.amounts_seen.append(amount) + + def test_callback_is_invoked_on_read(self): + chunk = ReadFileChunk.from_filename( + self.filename, start_byte=0, chunk_size=3, callback=self.callback + ) + chunk.read(1) + chunk.read(1) + chunk.read(1) + self.assertEqual(self.amounts_seen, [1, 1, 1]) + + def test_callback_can_be_disabled(self): + chunk = ReadFileChunk.from_filename( + self.filename, start_byte=0, chunk_size=3, callback=self.callback + ) + chunk.disable_callback() + # Now reading from the ReadFileChunk should not invoke + # the callback. 
+ chunk.read() + self.assertEqual(self.amounts_seen, []) + + def test_callback_will_also_be_triggered_by_seek(self): + chunk = ReadFileChunk.from_filename( + self.filename, start_byte=0, chunk_size=3, callback=self.callback + ) + chunk.read(2) + chunk.seek(0) + chunk.read(2) + chunk.seek(1) + chunk.read(2) + self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2]) + + +class TestStreamReaderProgress(unittest.TestCase): + def test_proxies_to_wrapped_stream(self): + original_stream = StringIO('foobarbaz') + wrapped = StreamReaderProgress(original_stream) + self.assertEqual(wrapped.read(), 'foobarbaz') + + def test_callback_invoked(self): + amounts_seen = [] + + def callback(amount): + amounts_seen.append(amount) + + original_stream = StringIO('foobarbaz') + wrapped = StreamReaderProgress(original_stream, callback) + self.assertEqual(wrapped.read(), 'foobarbaz') + self.assertEqual(amounts_seen, [9]) + + +class TestMultipartUploader(unittest.TestCase): + def test_multipart_upload_uses_correct_client_calls(self): + client = mock.Mock() + uploader = MultipartUploader( + client, + TransferConfig(), + InMemoryOSLayer({'filename': b'foobar'}), + SequentialExecutor, + ) + client.create_multipart_upload.return_value = {'UploadId': 'upload_id'} + client.upload_part.return_value = {'ETag': 'first'} + + uploader.upload_file('filename', 'bucket', 'key', None, {}) + + # We need to check both the sequence of calls (create/upload/complete) + # as well as the params passed between the calls, including + # 1. The upload_id was plumbed through + # 2. The collected etags were added to the complete call. + client.create_multipart_upload.assert_called_with( + Bucket='bucket', Key='key' + ) + # Should be two parts. 
+ client.upload_part.assert_called_with( + Body=mock.ANY, + Bucket='bucket', + UploadId='upload_id', + Key='key', + PartNumber=1, + ) + client.complete_multipart_upload.assert_called_with( + MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]}, + Bucket='bucket', + UploadId='upload_id', + Key='key', + ) + + def test_multipart_upload_injects_proper_kwargs(self): + client = mock.Mock() + uploader = MultipartUploader( + client, + TransferConfig(), + InMemoryOSLayer({'filename': b'foobar'}), + SequentialExecutor, + ) + client.create_multipart_upload.return_value = {'UploadId': 'upload_id'} + client.upload_part.return_value = {'ETag': 'first'} + + extra_args = { + 'SSECustomerKey': 'fakekey', + 'SSECustomerAlgorithm': 'AES256', + 'StorageClass': 'REDUCED_REDUNDANCY', + } + uploader.upload_file('filename', 'bucket', 'key', None, extra_args) + + client.create_multipart_upload.assert_called_with( + Bucket='bucket', + Key='key', + # The initial call should inject all the storage class params. + SSECustomerKey='fakekey', + SSECustomerAlgorithm='AES256', + StorageClass='REDUCED_REDUNDANCY', + ) + client.upload_part.assert_called_with( + Body=mock.ANY, + Bucket='bucket', + UploadId='upload_id', + Key='key', + PartNumber=1, + # We only have to forward certain **extra_args in subsequent + # UploadPart calls. + SSECustomerKey='fakekey', + SSECustomerAlgorithm='AES256', + ) + client.complete_multipart_upload.assert_called_with( + MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]}, + Bucket='bucket', + UploadId='upload_id', + Key='key', + ) + + def test_multipart_upload_is_aborted_on_error(self): + # If the create_multipart_upload succeeds and any upload_part + # fails, then abort_multipart_upload will be called. 
+ client = mock.Mock() + uploader = MultipartUploader( + client, + TransferConfig(), + InMemoryOSLayer({'filename': b'foobar'}), + SequentialExecutor, + ) + client.create_multipart_upload.return_value = {'UploadId': 'upload_id'} + client.upload_part.side_effect = Exception( + "Some kind of error occurred." + ) + + with self.assertRaises(S3UploadFailedError): + uploader.upload_file('filename', 'bucket', 'key', None, {}) + + client.abort_multipart_upload.assert_called_with( + Bucket='bucket', Key='key', UploadId='upload_id' + ) + + +class TestMultipartDownloader(unittest.TestCase): + + maxDiff = None + + def test_multipart_download_uses_correct_client_calls(self): + client = mock.Mock() + response_body = b'foobarbaz' + client.get_object.return_value = {'Body': BytesIO(response_body)} + + downloader = MultipartDownloader( + client, TransferConfig(), InMemoryOSLayer({}), SequentialExecutor + ) + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + client.get_object.assert_called_with( + Range='bytes=0-', Bucket='bucket', Key='key' + ) + + def test_multipart_download_with_multiple_parts(self): + client = mock.Mock() + response_body = b'foobarbaz' + client.get_object.return_value = {'Body': BytesIO(response_body)} + # For testing purposes, we're testing with a multipart threshold + # of 4 bytes and a chunksize of 4 bytes. Given b'foobarbaz', + # this should result in 3 calls. In python slices this would be: + # r[0:4], r[4:8], r[8:9]. But the Range param will be slightly + # different because they use inclusive ranges. + config = TransferConfig(multipart_threshold=4, multipart_chunksize=4) + + downloader = MultipartDownloader( + client, config, InMemoryOSLayer({}), SequentialExecutor + ) + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + # We're storing these in **extra because the assertEqual + # below is really about verifying we have the correct value + # for the Range param. 
+ extra = {'Bucket': 'bucket', 'Key': 'key'} + self.assertEqual( + client.get_object.call_args_list, + # Note these are inclusive ranges. + [ + mock.call(Range='bytes=0-3', **extra), + mock.call(Range='bytes=4-7', **extra), + mock.call(Range='bytes=8-', **extra), + ], + ) + + def test_retry_on_failures_from_stream_reads(self): + # If we get an exception during a call to the response body's .read() + # method, we should retry the request. + client = mock.Mock() + response_body = b'foobarbaz' + stream_with_errors = mock.Mock() + stream_with_errors.read.side_effect = [ + socket.error("fake error"), + response_body, + ] + client.get_object.return_value = {'Body': stream_with_errors} + config = TransferConfig(multipart_threshold=4, multipart_chunksize=4) + + downloader = MultipartDownloader( + client, config, InMemoryOSLayer({}), SequentialExecutor + ) + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + # We're storing these in **extra because the assertEqual + # below is really about verifying we have the correct value + # for the Range param. + extra = {'Bucket': 'bucket', 'Key': 'key'} + self.assertEqual( + client.get_object.call_args_list, + # The first call to range=0-3 fails because of the + # side_effect above where we make the .read() raise a + # socket.error. + # The second call to range=0-3 then succeeds. 
+ [ + mock.call(Range='bytes=0-3', **extra), + mock.call(Range='bytes=0-3', **extra), + mock.call(Range='bytes=4-7', **extra), + mock.call(Range='bytes=8-', **extra), + ], + ) + + def test_exception_raised_on_exceeded_retries(self): + client = mock.Mock() + response_body = b'foobarbaz' + stream_with_errors = mock.Mock() + stream_with_errors.read.side_effect = socket.error("fake error") + client.get_object.return_value = {'Body': stream_with_errors} + config = TransferConfig(multipart_threshold=4, multipart_chunksize=4) + + downloader = MultipartDownloader( + client, config, InMemoryOSLayer({}), SequentialExecutor + ) + with self.assertRaises(RetriesExceededError): + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + def test_io_thread_failure_triggers_shutdown(self): + client = mock.Mock() + response_body = b'foobarbaz' + client.get_object.return_value = {'Body': BytesIO(response_body)} + os_layer = mock.Mock() + mock_fileobj = mock.MagicMock() + mock_fileobj.__enter__.return_value = mock_fileobj + mock_fileobj.write.side_effect = Exception("fake IO error") + os_layer.open.return_value = mock_fileobj + + downloader = MultipartDownloader( + client, TransferConfig(), os_layer, SequentialExecutor + ) + # We're verifying that the exception raised from the IO future + # propagates back up via download_file(). + with self.assertRaisesRegex(Exception, "fake IO error"): + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + def test_download_futures_fail_triggers_shutdown(self): + class FailedDownloadParts(SequentialExecutor): + def __init__(self, max_workers): + self.is_first = True + + def submit(self, function): + future = futures.Future() + if self.is_first: + # This is the download_parts_thread. 
+ future.set_exception( + Exception("fake download parts error") + ) + self.is_first = False + return future + + client = mock.Mock() + response_body = b'foobarbaz' + client.get_object.return_value = {'Body': BytesIO(response_body)} + + downloader = MultipartDownloader( + client, TransferConfig(), InMemoryOSLayer({}), FailedDownloadParts + ) + with self.assertRaisesRegex(Exception, "fake download parts error"): + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + +class TestS3Transfer(unittest.TestCase): + def setUp(self): + self.client = mock.Mock() + self.random_file_patch = mock.patch('s3transfer.random_file_extension') + self.random_file = self.random_file_patch.start() + self.random_file.return_value = 'RANDOM' + + def tearDown(self): + self.random_file_patch.stop() + + def test_callback_handlers_register_on_put_item(self): + osutil = InMemoryOSLayer({'smallfile': b'foobar'}) + transfer = S3Transfer(self.client, osutil=osutil) + transfer.upload_file('smallfile', 'bucket', 'key') + events = self.client.meta.events + events.register_first.assert_called_with( + 'request-created.s3', + disable_upload_callbacks, + unique_id='s3upload-callback-disable', + ) + events.register_last.assert_called_with( + 'request-created.s3', + enable_upload_callbacks, + unique_id='s3upload-callback-enable', + ) + + def test_upload_below_multipart_threshold_uses_put_object(self): + fake_files = { + 'smallfile': b'foobar', + } + osutil = InMemoryOSLayer(fake_files) + transfer = S3Transfer(self.client, osutil=osutil) + transfer.upload_file('smallfile', 'bucket', 'key') + self.client.put_object.assert_called_with( + Bucket='bucket', Key='key', Body=mock.ANY + ) + + def test_extra_args_on_uploaded_passed_to_api_call(self): + extra_args = {'ACL': 'public-read'} + fake_files = {'smallfile': b'hello world'} + osutil = InMemoryOSLayer(fake_files) + transfer = S3Transfer(self.client, osutil=osutil) + transfer.upload_file( + 'smallfile', 'bucket', 'key', 
extra_args=extra_args + ) + self.client.put_object.assert_called_with( + Bucket='bucket', Key='key', Body=mock.ANY, ACL='public-read' + ) + + def test_uses_multipart_upload_when_over_threshold(self): + with mock.patch('s3transfer.MultipartUploader') as uploader: + fake_files = { + 'smallfile': b'foobar', + } + osutil = InMemoryOSLayer(fake_files) + config = TransferConfig( + multipart_threshold=2, multipart_chunksize=2 + ) + transfer = S3Transfer(self.client, osutil=osutil, config=config) + transfer.upload_file('smallfile', 'bucket', 'key') + + uploader.return_value.upload_file.assert_called_with( + 'smallfile', 'bucket', 'key', None, {} + ) + + def test_uses_multipart_download_when_over_threshold(self): + with mock.patch('s3transfer.MultipartDownloader') as downloader: + osutil = InMemoryOSLayer({}) + over_multipart_threshold = 100 * 1024 * 1024 + transfer = S3Transfer(self.client, osutil=osutil) + callback = mock.sentinel.CALLBACK + self.client.head_object.return_value = { + 'ContentLength': over_multipart_threshold, + } + transfer.download_file( + 'bucket', 'key', 'filename', callback=callback + ) + + downloader.return_value.download_file.assert_called_with( + # Note how we're downloading to a temporary random file. 
+ 'bucket', + 'key', + 'filename.RANDOM', + over_multipart_threshold, + {}, + callback, + ) + + def test_download_file_with_invalid_extra_args(self): + below_threshold = 20 + osutil = InMemoryOSLayer({}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + with self.assertRaises(ValueError): + transfer.download_file( + 'bucket', + 'key', + '/tmp/smallfile', + extra_args={'BadValue': 'foo'}, + ) + + def test_upload_file_with_invalid_extra_args(self): + osutil = InMemoryOSLayer({}) + transfer = S3Transfer(self.client, osutil=osutil) + bad_args = {"WebsiteRedirectLocation": "/foo"} + with self.assertRaises(ValueError): + transfer.upload_file( + 'bucket', 'key', '/tmp/smallfile', extra_args=bad_args + ) + + def test_download_file_fowards_extra_args(self): + extra_args = { + 'SSECustomerKey': 'foo', + 'SSECustomerAlgorithm': 'AES256', + } + below_threshold = 20 + osutil = InMemoryOSLayer({'smallfile': b'hello world'}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + self.client.get_object.return_value = {'Body': BytesIO(b'foobar')} + transfer.download_file( + 'bucket', 'key', '/tmp/smallfile', extra_args=extra_args + ) + + # Note that we need to invoke the HeadObject call + # and the PutObject call with the extra_args. + # This is necessary. Trying to HeadObject an SSE object + # will return a 400 if you don't provide the required + # params. 
+ self.client.get_object.assert_called_with( + Bucket='bucket', + Key='key', + SSECustomerAlgorithm='AES256', + SSECustomerKey='foo', + ) + + def test_get_object_stream_is_retried_and_succeeds(self): + below_threshold = 20 + osutil = InMemoryOSLayer({'smallfile': b'hello world'}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + self.client.get_object.side_effect = [ + # First request fails. + socket.error("fake error"), + # Second succeeds. + {'Body': BytesIO(b'foobar')}, + ] + transfer.download_file('bucket', 'key', '/tmp/smallfile') + + self.assertEqual(self.client.get_object.call_count, 2) + + def test_get_object_stream_uses_all_retries_and_errors_out(self): + below_threshold = 20 + osutil = InMemoryOSLayer({}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + # Here we're raising an exception every single time, which + # will exhaust our retry count and propagate a + # RetriesExceededError. + self.client.get_object.side_effect = socket.error("fake error") + with self.assertRaises(RetriesExceededError): + transfer.download_file('bucket', 'key', 'smallfile') + + self.assertEqual(self.client.get_object.call_count, 5) + # We should have also cleaned up the in progress file + # we were downloading to. 
+ self.assertEqual(osutil.filemap, {}) + + def test_download_below_multipart_threshold(self): + below_threshold = 20 + osutil = InMemoryOSLayer({'smallfile': b'hello world'}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + self.client.get_object.return_value = {'Body': BytesIO(b'foobar')} + transfer.download_file('bucket', 'key', 'smallfile') + + self.client.get_object.assert_called_with(Bucket='bucket', Key='key') + + def test_can_create_with_just_client(self): + transfer = S3Transfer(client=mock.Mock()) + self.assertIsInstance(transfer, S3Transfer) + + +class TestShutdownQueue(unittest.TestCase): + def test_handles_normal_put_get_requests(self): + q = ShutdownQueue() + q.put('foo') + self.assertEqual(q.get(), 'foo') + + def test_put_raises_error_on_shutdown(self): + q = ShutdownQueue() + q.trigger_shutdown() + with self.assertRaises(QueueShutdownError): + q.put('foo') + + +class TestRandomFileExtension(unittest.TestCase): + def test_has_proper_length(self): + self.assertEqual(len(random_file_extension(num_digits=4)), 4) + + +class TestCallbackHandlers(unittest.TestCase): + def setUp(self): + self.request = mock.Mock() + + def test_disable_request_on_put_object(self): + disable_upload_callbacks(self.request, 'PutObject') + self.request.body.disable_callback.assert_called_with() + + def test_disable_request_on_upload_part(self): + disable_upload_callbacks(self.request, 'UploadPart') + self.request.body.disable_callback.assert_called_with() + + def test_enable_object_on_put_object(self): + enable_upload_callbacks(self.request, 'PutObject') + self.request.body.enable_callback.assert_called_with() + + def test_enable_object_on_upload_part(self): + enable_upload_callbacks(self.request, 'UploadPart') + self.request.body.enable_callback.assert_called_with() + + def test_dont_disable_if_missing_interface(self): + del self.request.body.disable_callback + 
disable_upload_callbacks(self.request, 'PutObject') + self.assertEqual(self.request.body.method_calls, []) + + def test_dont_enable_if_missing_interface(self): + del self.request.body.enable_callback + enable_upload_callbacks(self.request, 'PutObject') + self.assertEqual(self.request.body.method_calls, []) + + def test_dont_disable_if_wrong_operation(self): + disable_upload_callbacks(self.request, 'OtherOperation') + self.assertFalse(self.request.body.disable_callback.called) + + def test_dont_enable_if_wrong_operation(self): + enable_upload_callbacks(self.request, 'OtherOperation') + self.assertFalse(self.request.body.enable_callback.called) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py b/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py new file mode 100644 index 0000000000..a26d3a548c --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py @@ -0,0 +1,91 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from s3transfer.exceptions import InvalidSubscriberMethodError +from s3transfer.subscribers import BaseSubscriber +from __tests__ import unittest + + +class ExtraMethodsSubscriber(BaseSubscriber): + def extra_method(self): + return 'called extra method' + + +class NotCallableSubscriber(BaseSubscriber): + on_done = 'foo' + + +class NoKwargsSubscriber(BaseSubscriber): + def on_done(self): + pass + + +class OverrideMethodSubscriber(BaseSubscriber): + def on_queued(self, **kwargs): + return kwargs + + +class OverrideConstructorSubscriber(BaseSubscriber): + def __init__(self, arg1, arg2): + self.arg1 = arg1 + self.arg2 = arg2 + + +class TestSubscribers(unittest.TestCase): + def test_can_instantiate_base_subscriber(self): + try: + BaseSubscriber() + except InvalidSubscriberMethodError: + self.fail('BaseSubscriber should be instantiable') + + def test_can_call_base_subscriber_method(self): + subscriber = BaseSubscriber() + try: + subscriber.on_done(future=None) + except Exception as e: + self.fail( + 'Should be able to call base class subscriber method. ' + 'instead got: %s' % e + ) + + def test_subclass_can_have_and_call_additional_methods(self): + subscriber = ExtraMethodsSubscriber() + self.assertEqual(subscriber.extra_method(), 'called extra method') + + def test_can_subclass_and_override_method_from_base_subscriber(self): + subscriber = OverrideMethodSubscriber() + # Make sure that the overridden method is called + self.assertEqual(subscriber.on_queued(foo='bar'), {'foo': 'bar'}) + + def test_can_subclass_and_override_constructor_from_base_class(self): + subscriber = OverrideConstructorSubscriber('foo', arg2='bar') + # Make sure you can create a custom constructor. + self.assertEqual(subscriber.arg1, 'foo') + self.assertEqual(subscriber.arg2, 'bar') + + def test_invalid_arguments_in_constructor_of_subclass_subscriber(self): + # The override constructor should still have validation of + # constructor args. 
+ with self.assertRaises(TypeError): + OverrideConstructorSubscriber() + + def test_not_callable_in_subclass_subscriber_method(self): + with self.assertRaisesRegex( + InvalidSubscriberMethodError, 'must be callable' + ): + NotCallableSubscriber() + + def test_no_kwargs_in_subclass_subscriber_method(self): + with self.assertRaisesRegex( + InvalidSubscriberMethodError, 'must accept keyword' + ): + NoKwargsSubscriber() diff --git a/contrib/python/s3transfer/py3/tests/unit/test_tasks.py b/contrib/python/s3transfer/py3/tests/unit/test_tasks.py new file mode 100644 index 0000000000..4f0bc4d1cc --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/unit/test_tasks.py @@ -0,0 +1,833 @@ +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from concurrent import futures +from functools import partial +from threading import Event + +from s3transfer.futures import BoundedExecutor, TransferCoordinator +from s3transfer.subscribers import BaseSubscriber +from s3transfer.tasks import ( + CompleteMultipartUploadTask, + CreateMultipartUploadTask, + SubmissionTask, + Task, +) +from s3transfer.utils import CallArgs, FunctionContainer, get_callbacks +from __tests__ import ( + BaseSubmissionTaskTest, + BaseTaskTest, + RecordingSubscriber, + unittest, +) + + +class TaskFailureException(Exception): + pass + + +class SuccessTask(Task): + def _main( + self, return_value='success', callbacks=None, failure_cleanups=None + ): + if callbacks: + for callback in callbacks: + callback() + if failure_cleanups: + for failure_cleanup in failure_cleanups: + self._transfer_coordinator.add_failure_cleanup(failure_cleanup) + return return_value + + +class FailureTask(Task): + def _main(self, exception=TaskFailureException): + raise exception() + + +class ReturnKwargsTask(Task): + def _main(self, **kwargs): + return kwargs + + +class SubmitMoreTasksTask(Task): + def _main(self, executor, tasks_to_submit): + for task_to_submit in tasks_to_submit: + self._transfer_coordinator.submit(executor, task_to_submit) + + +class NOOPSubmissionTask(SubmissionTask): + def _submit(self, transfer_future, **kwargs): + pass + + +class ExceptionSubmissionTask(SubmissionTask): + def _submit( + self, + transfer_future, + executor=None, + tasks_to_submit=None, + additional_callbacks=None, + exception=TaskFailureException, + ): + if executor and tasks_to_submit: + for task_to_submit in tasks_to_submit: + self._transfer_coordinator.submit(executor, task_to_submit) + if additional_callbacks: + for callback in additional_callbacks: + callback() + raise exception() + + +class StatusRecordingTransferCoordinator(TransferCoordinator): + def __init__(self, transfer_id=None): + super().__init__(transfer_id) + self.status_changes = [self._status] + + def 
set_status_to_queued(self): + super().set_status_to_queued() + self._record_status_change() + + def set_status_to_running(self): + super().set_status_to_running() + self._record_status_change() + + def _record_status_change(self): + self.status_changes.append(self._status) + + +class RecordingStateSubscriber(BaseSubscriber): + def __init__(self, transfer_coordinator): + self._transfer_coordinator = transfer_coordinator + self.status_during_on_queued = None + + def on_queued(self, **kwargs): + self.status_during_on_queued = self._transfer_coordinator.status + + +class TestSubmissionTask(BaseSubmissionTaskTest): + def setUp(self): + super().setUp() + self.executor = BoundedExecutor(1000, 5) + self.call_args = CallArgs(subscribers=[]) + self.transfer_future = self.get_transfer_future(self.call_args) + self.main_kwargs = {'transfer_future': self.transfer_future} + + def test_transitions_from_not_started_to_queued_to_running(self): + self.transfer_coordinator = StatusRecordingTransferCoordinator() + submission_task = self.get_task( + NOOPSubmissionTask, main_kwargs=self.main_kwargs + ) + # Status should be queued until submission task has been ran. + self.assertEqual(self.transfer_coordinator.status, 'not-started') + + submission_task() + # Once submission task has been ran, the status should now be running. + self.assertEqual(self.transfer_coordinator.status, 'running') + + # Ensure the transitions were as expected as well. + self.assertEqual( + self.transfer_coordinator.status_changes, + ['not-started', 'queued', 'running'], + ) + + def test_on_queued_callbacks(self): + submission_task = self.get_task( + NOOPSubmissionTask, main_kwargs=self.main_kwargs + ) + + subscriber = RecordingSubscriber() + self.call_args.subscribers.append(subscriber) + submission_task() + # Make sure the on_queued callback of the subscriber is called. 
+ self.assertEqual( + subscriber.on_queued_calls, [{'future': self.transfer_future}] + ) + + def test_on_queued_status_in_callbacks(self): + submission_task = self.get_task( + NOOPSubmissionTask, main_kwargs=self.main_kwargs + ) + + subscriber = RecordingStateSubscriber(self.transfer_coordinator) + self.call_args.subscribers.append(subscriber) + submission_task() + # Make sure the status was queued during on_queued callback. + self.assertEqual(subscriber.status_during_on_queued, 'queued') + + def test_sets_exception_from_submit(self): + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + submission_task() + + # Make sure the status of the future is failed + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the future propagates the exception encountered in the + # submission task. + with self.assertRaises(TaskFailureException): + self.transfer_future.result() + + def test_catches_and_sets_keyboard_interrupt_exception_from_submit(self): + self.main_kwargs['exception'] = KeyboardInterrupt + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + submission_task() + + self.assertEqual(self.transfer_coordinator.status, 'failed') + with self.assertRaises(KeyboardInterrupt): + self.transfer_future.result() + + def test_calls_done_callbacks_on_exception(self): + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + subscriber = RecordingSubscriber() + self.call_args.subscribers.append(subscriber) + + # Add the done callback to the callbacks to be invoked when the + # transfer is done. 
+ done_callbacks = get_callbacks(self.transfer_future, 'done') + for done_callback in done_callbacks: + self.transfer_coordinator.add_done_callback(done_callback) + submission_task() + + # Make sure the task failed to start + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the on_done callback of the subscriber is called. + self.assertEqual( + subscriber.on_done_calls, [{'future': self.transfer_future}] + ) + + def test_calls_failure_cleanups_on_exception(self): + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + # Add the callback to the callbacks to be invoked when the + # transfer fails. + invocations_of_cleanup = [] + cleanup_callback = FunctionContainer( + invocations_of_cleanup.append, 'cleanup happened' + ) + self.transfer_coordinator.add_failure_cleanup(cleanup_callback) + submission_task() + + # Make sure the task failed to start + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the cleanup was called. + self.assertEqual(invocations_of_cleanup, ['cleanup happened']) + + def test_cleanups_only_ran_once_on_exception(self): + # We want to be able to handle the case where the final task completes + # and anounces done but there is an error in the submission task + # which will cause it to need to announce done as well. In this case, + # we do not want the done callbacks to be invoke more than once. + + final_task = self.get_task(FailureTask, is_final=True) + self.main_kwargs['executor'] = self.executor + self.main_kwargs['tasks_to_submit'] = [final_task] + + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + subscriber = RecordingSubscriber() + self.call_args.subscribers.append(subscriber) + + # Add the done callback to the callbacks to be invoked when the + # transfer is done. 
+ done_callbacks = get_callbacks(self.transfer_future, 'done') + for done_callback in done_callbacks: + self.transfer_coordinator.add_done_callback(done_callback) + + submission_task() + + # Make sure the task failed to start + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the on_done callback of the subscriber is called only once. + self.assertEqual( + subscriber.on_done_calls, [{'future': self.transfer_future}] + ) + + def test_done_callbacks_only_ran_once_on_exception(self): + # We want to be able to handle the case where the final task completes + # and anounces done but there is an error in the submission task + # which will cause it to need to announce done as well. In this case, + # we do not want the failure cleanups to be invoked more than once. + + final_task = self.get_task(FailureTask, is_final=True) + self.main_kwargs['executor'] = self.executor + self.main_kwargs['tasks_to_submit'] = [final_task] + + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + # Add the callback to the callbacks to be invoked when the + # transfer fails. + invocations_of_cleanup = [] + cleanup_callback = FunctionContainer( + invocations_of_cleanup.append, 'cleanup happened' + ) + self.transfer_coordinator.add_failure_cleanup(cleanup_callback) + submission_task() + + # Make sure the task failed to start + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the cleanup was called only once. + self.assertEqual(invocations_of_cleanup, ['cleanup happened']) + + def test_handles_cleanups_submitted_in_other_tasks(self): + invocations_of_cleanup = [] + event = Event() + cleanup_callback = FunctionContainer( + invocations_of_cleanup.append, 'cleanup happened' + ) + # We want the cleanup to be added in the execution of the task and + # still be executed by the submission task when it fails. 
+ task = self.get_task( + SuccessTask, + main_kwargs={ + 'callbacks': [event.set], + 'failure_cleanups': [cleanup_callback], + }, + ) + + self.main_kwargs['executor'] = self.executor + self.main_kwargs['tasks_to_submit'] = [task] + self.main_kwargs['additional_callbacks'] = [event.wait] + + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + submission_task() + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the cleanup was called even though the callback got + # added in a completely different task. + self.assertEqual(invocations_of_cleanup, ['cleanup happened']) + + def test_waits_for_tasks_submitted_by_other_tasks_on_exception(self): + # In this test, we want to make sure that any tasks that may be + # submitted in another task complete before we start performing + # cleanups. + # + # This is tested by doing the following: + # + # ExecutionSubmissionTask + # | + # +--submits-->SubmitMoreTasksTask + # | + # +--submits-->SuccessTask + # | + # +-->sleeps-->adds failure cleanup + # + # In the end, the failure cleanup of the SuccessTask should be ran + # when the ExecutionSubmissionTask fails. If the + # ExeceptionSubmissionTask did not run the failure cleanup it is most + # likely that it did not wait for the SuccessTask to complete, which + # it needs to because the ExeceptionSubmissionTask does not know + # what failure cleanups it needs to run until all spawned tasks have + # completed. 
+ invocations_of_cleanup = [] + event = Event() + cleanup_callback = FunctionContainer( + invocations_of_cleanup.append, 'cleanup happened' + ) + + cleanup_task = self.get_task( + SuccessTask, + main_kwargs={ + 'callbacks': [event.set], + 'failure_cleanups': [cleanup_callback], + }, + ) + task_for_submitting_cleanup_task = self.get_task( + SubmitMoreTasksTask, + main_kwargs={ + 'executor': self.executor, + 'tasks_to_submit': [cleanup_task], + }, + ) + + self.main_kwargs['executor'] = self.executor + self.main_kwargs['tasks_to_submit'] = [ + task_for_submitting_cleanup_task + ] + self.main_kwargs['additional_callbacks'] = [event.wait] + + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + submission_task() + self.assertEqual(self.transfer_coordinator.status, 'failed') + self.assertEqual(invocations_of_cleanup, ['cleanup happened']) + + def test_submission_task_announces_done_if_cancelled_before_main(self): + invocations_of_done = [] + done_callback = FunctionContainer( + invocations_of_done.append, 'done announced' + ) + self.transfer_coordinator.add_done_callback(done_callback) + + self.transfer_coordinator.cancel() + submission_task = self.get_task( + NOOPSubmissionTask, main_kwargs=self.main_kwargs + ) + submission_task() + + # Because the submission task was cancelled before being run + # it did not submit any extra tasks so a result it is responsible + # for making sure it announces done as nothing else will. + self.assertEqual(invocations_of_done, ['done announced']) + + +class TestTask(unittest.TestCase): + def setUp(self): + self.transfer_id = 1 + self.transfer_coordinator = TransferCoordinator( + transfer_id=self.transfer_id + ) + + def test_repr(self): + main_kwargs = {'bucket': 'mybucket', 'param_to_not_include': 'foo'} + task = ReturnKwargsTask( + self.transfer_coordinator, main_kwargs=main_kwargs + ) + # The repr should not include the other parameter because it is not + # a desired parameter to include. 
+ self.assertEqual( + repr(task), + 'ReturnKwargsTask(transfer_id={}, {})'.format( + self.transfer_id, {'bucket': 'mybucket'} + ), + ) + + def test_transfer_id(self): + task = SuccessTask(self.transfer_coordinator) + # Make sure that the id is the one provided to the id associated + # to the transfer coordinator. + self.assertEqual(task.transfer_id, self.transfer_id) + + def test_context_status_transitioning_success(self): + # The status should be set to running. + self.transfer_coordinator.set_status_to_running() + self.assertEqual(self.transfer_coordinator.status, 'running') + + # If a task is called, the status still should be running. + SuccessTask(self.transfer_coordinator)() + self.assertEqual(self.transfer_coordinator.status, 'running') + + # Once the final task is called, the status should be set to success. + SuccessTask(self.transfer_coordinator, is_final=True)() + self.assertEqual(self.transfer_coordinator.status, 'success') + + def test_context_status_transitioning_failed(self): + self.transfer_coordinator.set_status_to_running() + + SuccessTask(self.transfer_coordinator)() + self.assertEqual(self.transfer_coordinator.status, 'running') + + # A failure task should result in the failed status + FailureTask(self.transfer_coordinator)() + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Even if the final task comes in and succeeds, it should stay failed. + SuccessTask(self.transfer_coordinator, is_final=True)() + self.assertEqual(self.transfer_coordinator.status, 'failed') + + def test_result_setting_for_success(self): + override_return = 'foo' + SuccessTask(self.transfer_coordinator)() + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': override_return}, + is_final=True, + )() + + # The return value for the transfer future should be of the final + # task. 
+ self.assertEqual(self.transfer_coordinator.result(), override_return) + + def test_result_setting_for_error(self): + FailureTask(self.transfer_coordinator)() + + # If another failure comes in, the result should still throw the + # original exception when result() is eventually called. + FailureTask( + self.transfer_coordinator, main_kwargs={'exception': Exception} + )() + + # Even if a success task comes along, the result of the future + # should be the original exception + SuccessTask(self.transfer_coordinator, is_final=True)() + with self.assertRaises(TaskFailureException): + self.transfer_coordinator.result() + + def test_done_callbacks_success(self): + callback_results = [] + SuccessTask( + self.transfer_coordinator, + done_callbacks=[ + partial(callback_results.append, 'first'), + partial(callback_results.append, 'second'), + ], + )() + # For successful tasks, the done callbacks should get called. + self.assertEqual(callback_results, ['first', 'second']) + + def test_done_callbacks_failure(self): + callback_results = [] + FailureTask( + self.transfer_coordinator, + done_callbacks=[ + partial(callback_results.append, 'first'), + partial(callback_results.append, 'second'), + ], + )() + # For even failed tasks, the done callbacks should get called. 
+ self.assertEqual(callback_results, ['first', 'second']) + + # Callbacks should continue to be called even after a related failure + SuccessTask( + self.transfer_coordinator, + done_callbacks=[ + partial(callback_results.append, 'third'), + partial(callback_results.append, 'fourth'), + ], + )() + self.assertEqual( + callback_results, ['first', 'second', 'third', 'fourth'] + ) + + def test_failure_cleanups_on_failure(self): + callback_results = [] + self.transfer_coordinator.add_failure_cleanup( + callback_results.append, 'first' + ) + self.transfer_coordinator.add_failure_cleanup( + callback_results.append, 'second' + ) + FailureTask(self.transfer_coordinator)() + # The failure callbacks should have not been called yet because it + # is not the last task + self.assertEqual(callback_results, []) + + # Now the failure callbacks should get called. + SuccessTask(self.transfer_coordinator, is_final=True)() + self.assertEqual(callback_results, ['first', 'second']) + + def test_no_failure_cleanups_on_success(self): + callback_results = [] + self.transfer_coordinator.add_failure_cleanup( + callback_results.append, 'first' + ) + self.transfer_coordinator.add_failure_cleanup( + callback_results.append, 'second' + ) + SuccessTask(self.transfer_coordinator, is_final=True)() + # The failure cleanups should not have been called because no task + # failed for the transfer context. 
+ self.assertEqual(callback_results, []) + + def test_passing_main_kwargs(self): + main_kwargs = {'foo': 'bar', 'baz': 'biz'} + ReturnKwargsTask( + self.transfer_coordinator, main_kwargs=main_kwargs, is_final=True + )() + # The kwargs should have been passed to the main() + self.assertEqual(self.transfer_coordinator.result(), main_kwargs) + + def test_passing_pending_kwargs_single_futures(self): + pending_kwargs = {} + ref_main_kwargs = {'foo': 'bar', 'baz': 'biz'} + + # Pass some tasks to an executor + with futures.ThreadPoolExecutor(1) as executor: + pending_kwargs['foo'] = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': ref_main_kwargs['foo']}, + ) + ) + pending_kwargs['baz'] = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': ref_main_kwargs['baz']}, + ) + ) + + # Create a task that depends on the tasks passed to the executor + ReturnKwargsTask( + self.transfer_coordinator, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The result should have the pending keyword arg values flushed + # out. 
+ self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs) + + def test_passing_pending_kwargs_list_of_futures(self): + pending_kwargs = {} + ref_main_kwargs = {'foo': ['first', 'second']} + + # Pass some tasks to an executor + with futures.ThreadPoolExecutor(1) as executor: + first_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': ref_main_kwargs['foo'][0]}, + ) + ) + second_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': ref_main_kwargs['foo'][1]}, + ) + ) + # Make the pending keyword arg value a list + pending_kwargs['foo'] = [first_future, second_future] + + # Create a task that depends on the tasks passed to the executor + ReturnKwargsTask( + self.transfer_coordinator, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The result should have the pending keyword arg values flushed + # out in the expected order. + self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs) + + def test_passing_pending_and_non_pending_kwargs(self): + main_kwargs = {'nonpending_value': 'foo'} + pending_kwargs = {} + ref_main_kwargs = { + 'nonpending_value': 'foo', + 'pending_value': 'bar', + 'pending_list': ['first', 'second'], + } + + # Create the pending tasks + with futures.ThreadPoolExecutor(1) as executor: + pending_kwargs['pending_value'] = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={ + 'return_value': ref_main_kwargs['pending_value'] + }, + ) + ) + + first_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={ + 'return_value': ref_main_kwargs['pending_list'][0] + }, + ) + ) + second_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={ + 'return_value': ref_main_kwargs['pending_list'][1] + }, + ) + ) + # Make the pending keyword arg value a list + pending_kwargs['pending_list'] = [first_future, second_future] + + # Create a task that depends 
on the tasks passed to the executor + # and just regular nonpending kwargs. + ReturnKwargsTask( + self.transfer_coordinator, + main_kwargs=main_kwargs, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The result should have all of the kwargs (both pending and + # nonpending) + self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs) + + def test_single_failed_pending_future(self): + pending_kwargs = {} + + # Pass some tasks to an executor. Make one successful and the other + # a failure. + with futures.ThreadPoolExecutor(1) as executor: + pending_kwargs['foo'] = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': 'bar'}, + ) + ) + pending_kwargs['baz'] = executor.submit( + FailureTask(self.transfer_coordinator) + ) + + # Create a task that depends on the tasks passed to the executor + ReturnKwargsTask( + self.transfer_coordinator, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The end result should raise the exception from the initial + # pending future value + with self.assertRaises(TaskFailureException): + self.transfer_coordinator.result() + + def test_single_failed_pending_future_in_list(self): + pending_kwargs = {} + + # Pass some tasks to an executor. Make one successful and the other + # a failure. 
+ with futures.ThreadPoolExecutor(1) as executor: + first_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': 'bar'}, + ) + ) + second_future = executor.submit( + FailureTask(self.transfer_coordinator) + ) + + pending_kwargs['pending_list'] = [first_future, second_future] + + # Create a task that depends on the tasks passed to the executor + ReturnKwargsTask( + self.transfer_coordinator, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The end result should raise the exception from the initial + # pending future value in the list + with self.assertRaises(TaskFailureException): + self.transfer_coordinator.result() + + +class BaseMultipartTaskTest(BaseTaskTest): + def setUp(self): + super().setUp() + self.bucket = 'mybucket' + self.key = 'foo' + + +class TestCreateMultipartUploadTask(BaseMultipartTaskTest): + def test_main(self): + upload_id = 'foo' + extra_args = {'Metadata': {'foo': 'bar'}} + response = {'UploadId': upload_id} + task = self.get_task( + CreateMultipartUploadTask, + main_kwargs={ + 'client': self.client, + 'bucket': self.bucket, + 'key': self.key, + 'extra_args': extra_args, + }, + ) + self.stubber.add_response( + method='create_multipart_upload', + service_response=response, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'Metadata': {'foo': 'bar'}, + }, + ) + result_id = task() + self.stubber.assert_no_pending_responses() + # Ensure the upload id returned is correct + self.assertEqual(upload_id, result_id) + + # Make sure that the abort was added as a cleanup failure + self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 1) + + # Make sure if it is called, it will abort correctly + self.stubber.add_response( + method='abort_multipart_upload', + service_response={}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': upload_id, + }, + ) + self.transfer_coordinator.failure_cleanups[0]() + self.stubber.assert_no_pending_responses() 


class TestCompleteMultipartUploadTask(BaseMultipartTaskTest):
    """Checks the request that CompleteMultipartUploadTask sends."""

    def _assert_completes_upload(self, extra_args):
        # Shared driver: run the task with ``extra_args`` and verify the
        # stubbed complete_multipart_upload call received the base
        # parameters plus every extra argument.
        upload_id = 'my-id'
        parts = [{'ETag': 'etag', 'PartNumber': 0}]

        expected_params = {
            'Bucket': self.bucket,
            'Key': self.key,
            'UploadId': upload_id,
            'MultipartUpload': {'Parts': parts},
        }
        expected_params.update(extra_args)

        main_kwargs = {
            'client': self.client,
            'bucket': self.bucket,
            'key': self.key,
            'upload_id': upload_id,
            'parts': parts,
            'extra_args': extra_args,
        }
        task = self.get_task(
            CompleteMultipartUploadTask, main_kwargs=main_kwargs
        )
        self.stubber.add_response(
            method='complete_multipart_upload',
            service_response={},
            expected_params=expected_params,
        )
        task()
        self.stubber.assert_no_pending_responses()

    def test_main(self):
        # No extra arguments: only the base parameters are expected.
        self._assert_completes_upload({})

    def test_includes_extra_args(self):
        # Extra arguments must be forwarded to the API call as-is.
        self._assert_completes_upload({'RequestPayer': 'requester'})


# diff --git a/contrib/python/s3transfer/py3/tests/unit/test_upload.py b/contrib/python/s3transfer/py3/tests/unit/test_upload.py
# new file mode 100644
# index 0000000000..1ac38b3616
# --- /dev/null
# +++ b/contrib/python/s3transfer/py3/tests/unit/test_upload.py
# @@ -0,0 +1,694 @@
# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file.
# This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import math
import os
import shutil
import tempfile
from io import BytesIO

from botocore.stub import ANY

from s3transfer.futures import IN_MEMORY_UPLOAD_TAG
from s3transfer.manager import TransferConfig
from s3transfer.upload import (
    AggregatedProgressCallback,
    InterruptReader,
    PutObjectTask,
    UploadFilenameInputManager,
    UploadNonSeekableInputManager,
    UploadPartTask,
    UploadSeekableInputManager,
    UploadSubmissionTask,
)
from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE, CallArgs, OSUtils
from __tests__ import (
    BaseSubmissionTaskTest,
    BaseTaskTest,
    FileSizeProvider,
    NonSeekableReader,
    RecordingExecutor,
    RecordingSubscriber,
    unittest,
)


class InterruptionError(Exception):
    """Exception the tests set on the coordinator to simulate interruption."""


class OSUtilsExceptionOnFileSize(OSUtils):
    """OSUtils whose ``get_file_size`` always fails.

    Used to prove that a code path never asks for the file size.
    """

    def get_file_size(self, filename):
        raise AssertionError(
            "The file %s should not have been stated" % filename
        )


class BaseUploadTest(BaseTaskTest):
    """Common fixture: a temp file with known content and a body recorder."""

    def setUp(self):
        super().setUp()
        self.bucket = 'mybucket'
        self.key = 'foo'
        self.osutil = OSUtils()

        self.tempdir = tempfile.mkdtemp()
        # Register removal right away: cleanups run even when a later step
        # of setUp raises, whereas tearDown would be skipped in that case.
        self.addCleanup(shutil.rmtree, self.tempdir)
        self.filename = os.path.join(self.tempdir, 'myfile')
        self.content = b'my content'
        self.subscribers = []

        with open(self.filename, 'wb') as f:
            f.write(self.content)

        # A list to keep track of all of the bodies sent over the wire
        # and their order.
        self.sent_bodies = []
        self.client.meta.events.register(
            'before-parameter-build.s3.*', self.collect_body
        )

    def collect_body(self, params, **kwargs):
        """Event handler: record the fully-read ``Body`` of each request."""
        if 'Body' in params:
            self.sent_bodies.append(params['Body'].read())


class TestAggregatedProgressCallback(unittest.TestCase):
    """Checks when AggregatedProgressCallback forwards accumulated bytes."""

    def setUp(self):
        self.aggregated_amounts = []
        self.threshold = 3
        self.aggregated_progress_callback = AggregatedProgressCallback(
            [self.callback], self.threshold
        )

    def callback(self, bytes_transferred):
        self.aggregated_amounts.append(bytes_transferred)

    def test_under_threshold(self):
        one_under_threshold_amount = self.threshold - 1
        self.aggregated_progress_callback(one_under_threshold_amount)
        self.assertEqual(self.aggregated_amounts, [])
        self.aggregated_progress_callback(1)
        self.assertEqual(self.aggregated_amounts, [self.threshold])

    def test_at_threshold(self):
        self.aggregated_progress_callback(self.threshold)
        self.assertEqual(self.aggregated_amounts, [self.threshold])

    def test_over_threshold(self):
        over_threshold_amount = self.threshold + 1
        self.aggregated_progress_callback(over_threshold_amount)
        self.assertEqual(self.aggregated_amounts, [over_threshold_amount])

    def test_flush(self):
        under_threshold_amount = self.threshold - 1
        self.aggregated_progress_callback(under_threshold_amount)
        self.assertEqual(self.aggregated_amounts, [])
        self.aggregated_progress_callback.flush()
        self.assertEqual(self.aggregated_amounts, [under_threshold_amount])

    def test_flush_with_nothing_to_flush(self):
        under_threshold_amount = self.threshold - 1
        self.aggregated_progress_callback(under_threshold_amount)
        self.assertEqual(self.aggregated_amounts, [])
        self.aggregated_progress_callback.flush()
        self.assertEqual(self.aggregated_amounts, [under_threshold_amount])
        # Flushing again should do nothing as it was just flushed
        self.aggregated_progress_callback.flush()
        self.assertEqual(self.aggregated_amounts, [under_threshold_amount])


class TestInterruptReader(BaseUploadTest):
    """Checks read/seek/tell pass-through and interruption on read."""

    def test_read_raises_exception(self):
        with open(self.filename, 'rb') as f:
            reader = InterruptReader(f, self.transfer_coordinator)
            # Read some bytes to show it can be read.
            self.assertEqual(reader.read(1), self.content[0:1])
            # Then set an exception in the transfer coordinator
            self.transfer_coordinator.set_exception(InterruptionError())
            # The next read should have the exception propagate
            with self.assertRaises(InterruptionError):
                reader.read()

    def test_seek(self):
        with open(self.filename, 'rb') as f:
            reader = InterruptReader(f, self.transfer_coordinator)
            # Ensure it can seek correctly
            reader.seek(1)
            self.assertEqual(reader.read(1), self.content[1:2])

    def test_tell(self):
        with open(self.filename, 'rb') as f:
            reader = InterruptReader(f, self.transfer_coordinator)
            # Ensure it can tell correctly
            reader.seek(1)
            self.assertEqual(reader.tell(), 1)


class BaseUploadInputManagerTest(BaseUploadTest):
    """Adds a TransferConfig and a recording subscriber to the fixture."""

    def setUp(self):
        # ``self.osutil`` is already created by BaseUploadTest.setUp.
        super().setUp()
        self.config = TransferConfig()
        self.recording_subscriber = RecordingSubscriber()
        self.subscribers.append(self.recording_subscriber)

    def _get_expected_body_for_part(self, part_number):
        # A helper method for retrieving the expected body for a specific
        # (1-based) part number of the data. Slicing clamps at the end of
        # the content, so the final, shorter part needs no special case.
        chunk_size = self.config.multipart_chunksize
        start_index = (part_number - 1) * chunk_size
        end_index = part_number * chunk_size
        return self.content[start_index:end_index]


class TestUploadFilenameInputManager(BaseUploadInputManagerTest):
    """Exercises UploadFilenameInputManager with a filename as input."""

    def setUp(self):
        super().setUp()
        self.upload_input_manager = UploadFilenameInputManager(
            self.osutil, self.transfer_coordinator
        )
        self.call_args = CallArgs(
            fileobj=self.filename, subscribers=self.subscribers
        )
        self.future = self.get_transfer_future(self.call_args)

    def test_is_compatible(self):
        self.assertTrue(
            self.upload_input_manager.is_compatible(
                self.future.meta.call_args.fileobj
            )
        )

    def test_stores_bodies_in_memory_put_object(self):
        self.assertFalse(
            self.upload_input_manager.stores_body_in_memory('put_object')
        )

    def test_stores_bodies_in_memory_upload_part(self):
        self.assertFalse(
            self.upload_input_manager.stores_body_in_memory('upload_part')
        )

    def test_provide_transfer_size(self):
        self.upload_input_manager.provide_transfer_size(self.future)
        # The provided file size should be equal to size of the contents of
        # the file.
        self.assertEqual(self.future.meta.size, len(self.content))

    def test_requires_multipart_upload(self):
        self.future.meta.provide_transfer_size(len(self.content))
        # With the default multipart threshold, the length of the content
        # should be smaller than the threshold thus not requiring a multipart
        # transfer.
        self.assertFalse(
            self.upload_input_manager.requires_multipart_upload(
                self.future, self.config
            )
        )
        # Decreasing the threshold to that of the length of the content of
        # the file should trigger the need for a multipart upload.
        self.config.multipart_threshold = len(self.content)
        self.assertTrue(
            self.upload_input_manager.requires_multipart_upload(
                self.future, self.config
            )
        )

    def test_get_put_object_body(self):
        self.future.meta.provide_transfer_size(len(self.content))
        read_file_chunk = self.upload_input_manager.get_put_object_body(
            self.future
        )
        read_file_chunk.enable_callback()
        # The file-like object provided back should be the same as the content
        # of the file.
        with read_file_chunk:
            self.assertEqual(read_file_chunk.read(), self.content)
        # The file-like object should also have been wrapped with the
        # on_queued callbacks to track the amount of bytes being transferred.
        self.assertEqual(
            self.recording_subscriber.calculate_bytes_seen(), len(self.content)
        )

    def test_get_put_object_body_is_interruptable(self):
        self.future.meta.provide_transfer_size(len(self.content))
        read_file_chunk = self.upload_input_manager.get_put_object_body(
            self.future
        )

        # Set an exception in the transfer coordinator
        self.transfer_coordinator.set_exception(InterruptionError)
        # Ensure the returned read file chunk can be interrupted with that
        # error. The chunk is used as a context manager, as in the other
        # tests, so it is closed instead of being leaked.
        with read_file_chunk:
            with self.assertRaises(InterruptionError):
                read_file_chunk.read()

    def test_yield_upload_part_bodies(self):
        # Adjust the chunk size to something more granular for testing.
        self.config.multipart_chunksize = 4
        self.future.meta.provide_transfer_size(len(self.content))

        # Get an iterator that will yield all of the bodies and their
        # respective part number.
        part_iterator = self.upload_input_manager.yield_upload_part_bodies(
            self.future, self.config.multipart_chunksize
        )
        expected_part_number = 1
        for part_number, read_file_chunk in part_iterator:
            # Ensure that the part number is as expected
            self.assertEqual(part_number, expected_part_number)
            read_file_chunk.enable_callback()
            # Ensure that the body is correct for that part.
            with read_file_chunk:
                self.assertEqual(
                    read_file_chunk.read(),
                    self._get_expected_body_for_part(part_number),
                )
            expected_part_number += 1

        # All of the file-like object should also have been wrapped with the
        # on_queued callbacks to track the amount of bytes being transferred.
        self.assertEqual(
            self.recording_subscriber.calculate_bytes_seen(), len(self.content)
        )

    def test_yield_upload_part_bodies_are_interruptable(self):
        # Adjust the chunk size to something more granular for testing.
        self.config.multipart_chunksize = 4
        self.future.meta.provide_transfer_size(len(self.content))

        # Get an iterator that will yield all of the bodies and their
        # respective part number.
        part_iterator = self.upload_input_manager.yield_upload_part_bodies(
            self.future, self.config.multipart_chunksize
        )

        # Set an exception in the transfer coordinator
        self.transfer_coordinator.set_exception(InterruptionError)
        for _, read_file_chunk in part_iterator:
            # Ensure that each read file chunk yielded can be interrupted
            # with that error, and close each chunk once it was checked.
            with read_file_chunk:
                with self.assertRaises(InterruptionError):
                    read_file_chunk.read()


class TestUploadSeekableInputManager(TestUploadFilenameInputManager):
    """Re-runs the filename tests against an open, seekable file object."""

    def setUp(self):
        super().setUp()
        self.upload_input_manager = UploadSeekableInputManager(
            self.osutil, self.transfer_coordinator
        )
        self.fileobj = open(self.filename, 'rb')
        self.call_args = CallArgs(
            fileobj=self.fileobj, subscribers=self.subscribers
        )
        self.future = self.get_transfer_future(self.call_args)

    def tearDown(self):
        self.fileobj.close()
        super().tearDown()

    def test_is_compatible_bytes_io(self):
        self.assertTrue(self.upload_input_manager.is_compatible(BytesIO()))

    def test_not_compatible_for_non_filelike_obj(self):
        self.assertFalse(self.upload_input_manager.is_compatible(object()))

    def test_stores_bodies_in_memory_upload_part(self):
        self.assertTrue(
            self.upload_input_manager.stores_body_in_memory('upload_part')
        )

    def test_get_put_object_body(self):
        start_pos = 3
        self.fileobj.seek(start_pos)
        adjusted_size = len(self.content) - start_pos
        self.future.meta.provide_transfer_size(adjusted_size)
        read_file_chunk = self.upload_input_manager.get_put_object_body(
            self.future
        )

        read_file_chunk.enable_callback()
        # The fact that the file was seeked to start should be taken into
        # account in length and content for the read file chunk.
        with read_file_chunk:
            self.assertEqual(len(read_file_chunk), adjusted_size)
            self.assertEqual(read_file_chunk.read(), self.content[start_pos:])
        self.assertEqual(
            self.recording_subscriber.calculate_bytes_seen(), adjusted_size
        )


class TestUploadNonSeekableInputManager(TestUploadFilenameInputManager):
    """Re-runs the filename tests against a non-seekable stream."""

    def setUp(self):
        super().setUp()
        self.upload_input_manager = UploadNonSeekableInputManager(
            self.osutil, self.transfer_coordinator
        )
        self.fileobj = NonSeekableReader(self.content)
        self.call_args = CallArgs(
            fileobj=self.fileobj, subscribers=self.subscribers
        )
        self.future = self.get_transfer_future(self.call_args)

    def assert_multipart_parts(self):
        """
        Asserts that the input manager will generate a multipart upload
        and that each part is in order and the correct size.
        """
        # Assert that a multipart upload is required.
        self.assertTrue(
            self.upload_input_manager.requires_multipart_upload(
                self.future, self.config
            )
        )

        # Get a list of all the parts that would be sent.
        parts = list(
            self.upload_input_manager.yield_upload_part_bodies(
                self.future, self.config.multipart_chunksize
            )
        )

        # Assert that the actual number of parts is what we would expect
        # based on the configuration.
        size = self.config.multipart_chunksize
        num_parts = math.ceil(len(self.content) / size)
        self.assertEqual(len(parts), num_parts)

        # Run for every part but the last part.
        for i, part in enumerate(parts[:-1]):
            # Assert the part number is correct.
            self.assertEqual(part[0], i + 1)
            # Assert the part contains the right amount of data.
            data = part[1].read()
            self.assertEqual(len(data), size)

        # Assert that the last part is the correct size.
        expected_final_size = len(self.content) - ((num_parts - 1) * size)
        final_part = parts[-1]
        self.assertEqual(len(final_part[1].read()), expected_final_size)

        # Assert that the last part has the correct part number.
        self.assertEqual(final_part[0], len(parts))

    def test_provide_transfer_size(self):
        self.upload_input_manager.provide_transfer_size(self.future)
        # There is no way to get the size without reading the entire body.
        self.assertIsNone(self.future.meta.size)

    def test_stores_bodies_in_memory_upload_part(self):
        self.assertTrue(
            self.upload_input_manager.stores_body_in_memory('upload_part')
        )

    def test_stores_bodies_in_memory_put_object(self):
        self.assertTrue(
            self.upload_input_manager.stores_body_in_memory('put_object')
        )

    def test_initial_data_parts_threshold_lesser(self):
        # threshold < size
        self.config.multipart_chunksize = 4
        self.config.multipart_threshold = 2
        self.assert_multipart_parts()

    def test_initial_data_parts_threshold_equal(self):
        # threshold == size
        self.config.multipart_chunksize = 4
        self.config.multipart_threshold = 4
        self.assert_multipart_parts()

    def test_initial_data_parts_threshold_greater(self):
        # threshold > size
        self.config.multipart_chunksize = 4
        self.config.multipart_threshold = 8
        self.assert_multipart_parts()


class TestUploadSubmissionTask(BaseSubmissionTaskTest):
    """Drives UploadSubmissionTask end to end against a stubbed client."""

    def setUp(self):
        super().setUp()
        self.tempdir = tempfile.mkdtemp()
        # Register removal right away: cleanups run even when a later step
        # of setUp raises, whereas tearDown would be skipped in that case.
        self.addCleanup(shutil.rmtree, self.tempdir)
        self.filename = os.path.join(self.tempdir, 'myfile')
        self.content = b'0' * (MIN_UPLOAD_CHUNKSIZE * 3)
        self.config.multipart_chunksize = MIN_UPLOAD_CHUNKSIZE
        self.config.multipart_threshold = MIN_UPLOAD_CHUNKSIZE * 5

        with open(self.filename, 'wb') as f:
            f.write(self.content)

        self.bucket = 'mybucket'
        self.key = 'mykey'
        self.extra_args = {}
        self.subscribers = []

        # A list to keep track of all of the bodies sent over the wire
        # and their order.
        self.sent_bodies = []
        self.client.meta.events.register(
            'before-parameter-build.s3.*', self.collect_body
        )

        self.call_args = self.get_call_args()
        self.transfer_future = self.get_transfer_future(self.call_args)
        self.submission_main_kwargs = {
            'client': self.client,
            'config': self.config,
            'osutil': self.osutil,
            'request_executor': self.executor,
            'transfer_future': self.transfer_future,
        }
        self.submission_task = self.get_task(
            UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
        )

    def collect_body(self, params, **kwargs):
        """Event handler: record the fully-read ``Body`` of each request."""
        if 'Body' in params:
            self.sent_bodies.append(params['Body'].read())

    def get_call_args(self, **kwargs):
        """Build CallArgs from the fixture defaults, overridden by kwargs."""
        default_call_args = {
            'fileobj': self.filename,
            'bucket': self.bucket,
            'key': self.key,
            'extra_args': self.extra_args,
            'subscribers': self.subscribers,
        }
        default_call_args.update(kwargs)
        return CallArgs(**default_call_args)

    def add_multipart_upload_stubbed_responses(self):
        """Queue stub responses for a three-part multipart upload."""
        self.stubber.add_response(
            method='create_multipart_upload',
            service_response={'UploadId': 'my-id'},
        )
        self.stubber.add_response(
            method='upload_part', service_response={'ETag': 'etag-1'}
        )
        self.stubber.add_response(
            method='upload_part', service_response={'ETag': 'etag-2'}
        )
        self.stubber.add_response(
            method='upload_part', service_response={'ETag': 'etag-3'}
        )
        self.stubber.add_response(
            method='complete_multipart_upload', service_response={}
        )

    def wrap_executor_in_recorder(self):
        """Swap in a RecordingExecutor so submissions can be inspected."""
        self.executor = RecordingExecutor(self.executor)
        self.submission_main_kwargs['request_executor'] = self.executor

    def use_fileobj_in_call_args(self, fileobj):
        """Rebuild call args and transfer future around ``fileobj``."""
        self.call_args = self.get_call_args(fileobj=fileobj)
        self.transfer_future = self.get_transfer_future(self.call_args)
        self.submission_main_kwargs['transfer_future'] = self.transfer_future

    def assert_tag_value_for_put_object(self, tag_value):
        self.assertEqual(self.executor.submissions[0]['tag'], tag_value)

    def assert_tag_value_for_upload_parts(self, tag_value):
        # Skip the first and the last recorded submission; only the ones
        # in between are checked here.
        for submission in self.executor.submissions[1:-1]:
            self.assertEqual(submission['tag'], tag_value)

    def test_provide_file_size_on_put(self):
        self.call_args.subscribers.append(FileSizeProvider(len(self.content)))
        self.stubber.add_response(
            method='put_object',
            service_response={},
            expected_params={
                'Body': ANY,
                'Bucket': self.bucket,
                'Key': self.key,
            },
        )

        # With this submitter, it will fail to stat the file if a transfer
        # size is not provided.
        self.submission_main_kwargs['osutil'] = OSUtilsExceptionOnFileSize()

        self.submission_task = self.get_task(
            UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
        )
        self.submission_task()
        self.transfer_future.result()
        self.stubber.assert_no_pending_responses()
        self.assertEqual(self.sent_bodies, [self.content])

    def test_submits_no_tag_for_put_object_filename(self):
        self.wrap_executor_in_recorder()
        self.stubber.add_response('put_object', {})

        self.submission_task = self.get_task(
            UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
        )
        self.submission_task()
        self.transfer_future.result()
        self.stubber.assert_no_pending_responses()

        # Make sure no tag limiting that task specifically was associated
        # with that task submission.
        self.assert_tag_value_for_put_object(None)

    def test_submits_no_tag_for_multipart_filename(self):
        self.wrap_executor_in_recorder()

        # Set up for a multipart upload.
        self.add_multipart_upload_stubbed_responses()
        self.config.multipart_threshold = 1

        self.submission_task = self.get_task(
            UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
        )
        self.submission_task()
        self.transfer_future.result()
        self.stubber.assert_no_pending_responses()

        # Make sure no tags limiting any of the upload part tasks were
        # associated when submitted to the executor
        self.assert_tag_value_for_upload_parts(None)

    def test_submits_no_tag_for_put_object_fileobj(self):
        self.wrap_executor_in_recorder()
        self.stubber.add_response('put_object', {})

        with open(self.filename, 'rb') as f:
            self.use_fileobj_in_call_args(f)
            self.submission_task = self.get_task(
                UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
            )
            self.submission_task()
            self.transfer_future.result()
            self.stubber.assert_no_pending_responses()

        # Make sure no tag limiting that task specifically was associated
        # with that task submission.
        self.assert_tag_value_for_put_object(None)

    def test_submits_tag_for_multipart_fileobj(self):
        self.wrap_executor_in_recorder()

        # Set up for a multipart upload.
        self.add_multipart_upload_stubbed_responses()
        self.config.multipart_threshold = 1

        with open(self.filename, 'rb') as f:
            self.use_fileobj_in_call_args(f)
            self.submission_task = self.get_task(
                UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
            )
            self.submission_task()
            self.transfer_future.result()
            self.stubber.assert_no_pending_responses()

        # Make sure tags limiting all of the upload part tasks were
        # associated when submitted to the executor as these tasks will
        # have chunks of data stored with them in memory.
        self.assert_tag_value_for_upload_parts(IN_MEMORY_UPLOAD_TAG)


class TestPutObjectTask(BaseUploadTest):
    """Checks the request and body that PutObjectTask sends."""

    def test_main(self):
        extra_args = {'Metadata': {'foo': 'bar'}}
        with open(self.filename, 'rb') as fileobj:
            task = self.get_task(
                PutObjectTask,
                main_kwargs={
                    'client': self.client,
                    'fileobj': fileobj,
                    'bucket': self.bucket,
                    'key': self.key,
                    'extra_args': extra_args,
                },
            )
            self.stubber.add_response(
                method='put_object',
                service_response={},
                expected_params={
                    'Body': ANY,
                    'Bucket': self.bucket,
                    'Key': self.key,
                    'Metadata': {'foo': 'bar'},
                },
            )
            task()
            self.stubber.assert_no_pending_responses()
            self.assertEqual(self.sent_bodies, [self.content])


class TestUploadPartTask(BaseUploadTest):
    """Checks the request, body and return value of UploadPartTask."""

    def test_main(self):
        extra_args = {'RequestPayer': 'requester'}
        upload_id = 'my-id'
        part_number = 1
        etag = 'foo'
        with open(self.filename, 'rb') as fileobj:
            task = self.get_task(
                UploadPartTask,
                main_kwargs={
                    'client': self.client,
                    'fileobj': fileobj,
                    'bucket': self.bucket,
                    'key': self.key,
                    'upload_id': upload_id,
                    'part_number': part_number,
                    'extra_args': extra_args,
                },
            )
            self.stubber.add_response(
                method='upload_part',
                service_response={'ETag': etag},
                expected_params={
                    'Body': ANY,
                    'Bucket': self.bucket,
                    'Key': self.key,
                    'UploadId': upload_id,
                    'PartNumber': part_number,
                    'RequestPayer': 'requester',
                },
            )
            rval = task()
            self.stubber.assert_no_pending_responses()
            self.assertEqual(rval, {'ETag': etag, 'PartNumber': part_number})
            self.assertEqual(self.sent_bodies, [self.content])


# diff --git a/contrib/python/s3transfer/py3/tests/unit/test_utils.py b/contrib/python/s3transfer/py3/tests/unit/test_utils.py
# new file mode 100644
# index 0000000000..a1ff904e7a
# --- /dev/null
# +++ b/contrib/python/s3transfer/py3/tests/unit/test_utils.py
# @@ -0,0 +1,1189 @@
# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import io +import os.path +import random +import re +import shutil +import tempfile +import threading +import time +from io import BytesIO, StringIO + +from s3transfer.futures import TransferFuture, TransferMeta +from s3transfer.utils import ( + MAX_PARTS, + MAX_SINGLE_UPLOAD_SIZE, + MIN_UPLOAD_CHUNKSIZE, + CallArgs, + ChunksizeAdjuster, + CountCallbackInvoker, + DeferredOpenFile, + FunctionContainer, + NoResourcesAvailable, + OSUtils, + ReadFileChunk, + SlidingWindowSemaphore, + StreamReaderProgress, + TaskSemaphore, + calculate_num_parts, + calculate_range_parameter, + get_callbacks, + get_filtered_dict, + invoke_progress_callbacks, + random_file_extension, +) +from __tests__ import NonSeekableWriter, RecordingSubscriber, mock, unittest + + +class TestGetCallbacks(unittest.TestCase): + def setUp(self): + self.subscriber = RecordingSubscriber() + self.second_subscriber = RecordingSubscriber() + self.call_args = CallArgs( + subscribers=[self.subscriber, self.second_subscriber] + ) + self.transfer_meta = TransferMeta(self.call_args) + self.transfer_future = TransferFuture(self.transfer_meta) + + def test_get_callbacks(self): + callbacks = get_callbacks(self.transfer_future, 'queued') + # Make sure two callbacks were added as both subscribers had + # an on_queued method. + self.assertEqual(len(callbacks), 2) + + # Ensure that the callback was injected with the future by calling + # one of them and checking that the future was used in the call. 
+ callbacks[0]() + self.assertEqual( + self.subscriber.on_queued_calls, [{'future': self.transfer_future}] + ) + + def test_get_callbacks_for_missing_type(self): + callbacks = get_callbacks(self.transfer_future, 'fake_state') + # There should be no callbacks as the subscribers will not have the + # on_fake_state method + self.assertEqual(len(callbacks), 0) + + +class TestGetFilteredDict(unittest.TestCase): + def test_get_filtered_dict(self): + original = {'Include': 'IncludeValue', 'NotInlude': 'NotIncludeValue'} + whitelist = ['Include'] + self.assertEqual( + get_filtered_dict(original, whitelist), {'Include': 'IncludeValue'} + ) + + +class TestCallArgs(unittest.TestCase): + def test_call_args(self): + call_args = CallArgs(foo='bar', biz='baz') + self.assertEqual(call_args.foo, 'bar') + self.assertEqual(call_args.biz, 'baz') + + +class TestFunctionContainer(unittest.TestCase): + def get_args_kwargs(self, *args, **kwargs): + return args, kwargs + + def test_call(self): + func_container = FunctionContainer( + self.get_args_kwargs, 'foo', bar='baz' + ) + self.assertEqual(func_container(), (('foo',), {'bar': 'baz'})) + + def test_repr(self): + func_container = FunctionContainer( + self.get_args_kwargs, 'foo', bar='baz' + ) + self.assertEqual( + str(func_container), + 'Function: {} with args {} and kwargs {}'.format( + self.get_args_kwargs, ('foo',), {'bar': 'baz'} + ), + ) + + +class TestCountCallbackInvoker(unittest.TestCase): + def invoke_callback(self): + self.ref_results.append('callback invoked') + + def assert_callback_invoked(self): + self.assertEqual(self.ref_results, ['callback invoked']) + + def assert_callback_not_invoked(self): + self.assertEqual(self.ref_results, []) + + def setUp(self): + self.ref_results = [] + self.invoker = CountCallbackInvoker(self.invoke_callback) + + def test_increment(self): + self.invoker.increment() + self.assertEqual(self.invoker.current_count, 1) + + def test_decrement(self): + self.invoker.increment() + 
self.invoker.increment() + self.invoker.decrement() + self.assertEqual(self.invoker.current_count, 1) + + def test_count_cannot_go_below_zero(self): + with self.assertRaises(RuntimeError): + self.invoker.decrement() + + def test_callback_invoked_only_once_finalized(self): + self.invoker.increment() + self.invoker.decrement() + self.assert_callback_not_invoked() + self.invoker.finalize() + # Callback should only be invoked once finalized + self.assert_callback_invoked() + + def test_callback_invoked_after_finalizing_and_count_reaching_zero(self): + self.invoker.increment() + self.invoker.finalize() + # Make sure that it does not get invoked immediately after + # finalizing as the count is currently one + self.assert_callback_not_invoked() + self.invoker.decrement() + self.assert_callback_invoked() + + def test_cannot_increment_after_finalization(self): + self.invoker.finalize() + with self.assertRaises(RuntimeError): + self.invoker.increment() + + +class TestRandomFileExtension(unittest.TestCase): + def test_has_proper_length(self): + self.assertEqual(len(random_file_extension(num_digits=4)), 4) + + +class TestInvokeProgressCallbacks(unittest.TestCase): + def test_invoke_progress_callbacks(self): + recording_subscriber = RecordingSubscriber() + invoke_progress_callbacks([recording_subscriber.on_progress], 2) + self.assertEqual(recording_subscriber.calculate_bytes_seen(), 2) + + def test_invoke_progress_callbacks_with_no_progress(self): + recording_subscriber = RecordingSubscriber() + invoke_progress_callbacks([recording_subscriber.on_progress], 0) + self.assertEqual(len(recording_subscriber.on_progress_calls), 0) + + +class TestCalculateNumParts(unittest.TestCase): + def test_calculate_num_parts_divisible(self): + self.assertEqual(calculate_num_parts(size=4, part_size=2), 2) + + def test_calculate_num_parts_not_divisible(self): + self.assertEqual(calculate_num_parts(size=3, part_size=2), 2) + + +class TestCalculateRangeParameter(unittest.TestCase): + def 
setUp(self): + self.part_size = 5 + self.part_index = 1 + self.num_parts = 3 + + def test_calculate_range_paramter(self): + range_val = calculate_range_parameter( + self.part_size, self.part_index, self.num_parts + ) + self.assertEqual(range_val, 'bytes=5-9') + + def test_last_part_with_no_total_size(self): + range_val = calculate_range_parameter( + self.part_size, self.part_index, num_parts=2 + ) + self.assertEqual(range_val, 'bytes=5-') + + def test_last_part_with_total_size(self): + range_val = calculate_range_parameter( + self.part_size, self.part_index, num_parts=2, total_size=8 + ) + self.assertEqual(range_val, 'bytes=5-7') + + +class BaseUtilsTest(unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, 'foo') + self.content = b'abc' + with open(self.filename, 'wb') as f: + f.write(self.content) + self.amounts_seen = [] + self.num_close_callback_calls = 0 + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def callback(self, bytes_transferred): + self.amounts_seen.append(bytes_transferred) + + def close_callback(self): + self.num_close_callback_calls += 1 + + +class TestOSUtils(BaseUtilsTest): + def test_get_file_size(self): + self.assertEqual( + OSUtils().get_file_size(self.filename), len(self.content) + ) + + def test_open_file_chunk_reader(self): + reader = OSUtils().open_file_chunk_reader( + self.filename, 0, 3, [self.callback] + ) + + # The returned reader should be a ReadFileChunk. + self.assertIsInstance(reader, ReadFileChunk) + # The content of the reader should be correct. + self.assertEqual(reader.read(), self.content) + # Callbacks should be disabled depspite being passed in. 
+ self.assertEqual(self.amounts_seen, []) + + def test_open_file_chunk_reader_from_fileobj(self): + with open(self.filename, 'rb') as f: + reader = OSUtils().open_file_chunk_reader_from_fileobj( + f, len(self.content), len(self.content), [self.callback] + ) + + # The returned reader should be a ReadFileChunk. + self.assertIsInstance(reader, ReadFileChunk) + # The content of the reader should be correct. + self.assertEqual(reader.read(), self.content) + reader.close() + # Callbacks should be disabled depspite being passed in. + self.assertEqual(self.amounts_seen, []) + self.assertEqual(self.num_close_callback_calls, 0) + + def test_open_file(self): + fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w') + self.assertTrue(hasattr(fileobj, 'write')) + + def test_remove_file_ignores_errors(self): + non_existent_file = os.path.join(self.tempdir, 'no-exist') + # This should not exist to start. + self.assertFalse(os.path.exists(non_existent_file)) + try: + OSUtils().remove_file(non_existent_file) + except OSError as e: + self.fail('OSError should have been caught: %s' % e) + + def test_remove_file_proxies_remove_file(self): + OSUtils().remove_file(self.filename) + self.assertFalse(os.path.exists(self.filename)) + + def test_rename_file(self): + new_filename = os.path.join(self.tempdir, 'newfoo') + OSUtils().rename_file(self.filename, new_filename) + self.assertFalse(os.path.exists(self.filename)) + self.assertTrue(os.path.exists(new_filename)) + + def test_is_special_file_for_normal_file(self): + self.assertFalse(OSUtils().is_special_file(self.filename)) + + def test_is_special_file_for_non_existant_file(self): + non_existant_filename = os.path.join(self.tempdir, 'no-exist') + self.assertFalse(os.path.exists(non_existant_filename)) + self.assertFalse(OSUtils().is_special_file(non_existant_filename)) + + def test_get_temp_filename(self): + filename = 'myfile' + self.assertIsNotNone( + re.match( + r'%s\.[0-9A-Fa-f]{8}$' % filename, + 
OSUtils().get_temp_filename(filename), + ) + ) + + def test_get_temp_filename_len_255(self): + filename = 'a' * 255 + temp_filename = OSUtils().get_temp_filename(filename) + self.assertLessEqual(len(temp_filename), 255) + + def test_get_temp_filename_len_gt_255(self): + filename = 'a' * 280 + temp_filename = OSUtils().get_temp_filename(filename) + self.assertLessEqual(len(temp_filename), 255) + + def test_allocate(self): + truncate_size = 1 + OSUtils().allocate(self.filename, truncate_size) + with open(self.filename, 'rb') as f: + self.assertEqual(len(f.read()), truncate_size) + + @mock.patch('s3transfer.utils.fallocate') + def test_allocate_with_io_error(self, mock_fallocate): + mock_fallocate.side_effect = IOError() + with self.assertRaises(IOError): + OSUtils().allocate(self.filename, 1) + self.assertFalse(os.path.exists(self.filename)) + + @mock.patch('s3transfer.utils.fallocate') + def test_allocate_with_os_error(self, mock_fallocate): + mock_fallocate.side_effect = OSError() + with self.assertRaises(OSError): + OSUtils().allocate(self.filename, 1) + self.assertFalse(os.path.exists(self.filename)) + + +class TestDeferredOpenFile(BaseUtilsTest): + def setUp(self): + super().setUp() + self.filename = os.path.join(self.tempdir, 'foo') + self.contents = b'my contents' + with open(self.filename, 'wb') as f: + f.write(self.contents) + self.deferred_open_file = DeferredOpenFile( + self.filename, open_function=self.recording_open_function + ) + self.open_call_args = [] + + def tearDown(self): + self.deferred_open_file.close() + super().tearDown() + + def recording_open_function(self, filename, mode): + self.open_call_args.append((filename, mode)) + return open(filename, mode) + + def open_nonseekable(self, filename, mode): + self.open_call_args.append((filename, mode)) + return NonSeekableWriter(BytesIO(self.content)) + + def test_instantiation_does_not_open_file(self): + DeferredOpenFile( + self.filename, open_function=self.recording_open_function + ) + 
self.assertEqual(len(self.open_call_args), 0) + + def test_name(self): + self.assertEqual(self.deferred_open_file.name, self.filename) + + def test_read(self): + content = self.deferred_open_file.read(2) + self.assertEqual(content, self.contents[0:2]) + content = self.deferred_open_file.read(2) + self.assertEqual(content, self.contents[2:4]) + self.assertEqual(len(self.open_call_args), 1) + + def test_write(self): + self.deferred_open_file = DeferredOpenFile( + self.filename, + mode='wb', + open_function=self.recording_open_function, + ) + + write_content = b'foo' + self.deferred_open_file.write(write_content) + self.deferred_open_file.write(write_content) + self.deferred_open_file.close() + # Both of the writes should now be in the file. + with open(self.filename, 'rb') as f: + self.assertEqual(f.read(), write_content * 2) + # Open should have only been called once. + self.assertEqual(len(self.open_call_args), 1) + + def test_seek(self): + self.deferred_open_file.seek(2) + content = self.deferred_open_file.read(2) + self.assertEqual(content, self.contents[2:4]) + self.assertEqual(len(self.open_call_args), 1) + + def test_open_does_not_seek_with_zero_start_byte(self): + self.deferred_open_file = DeferredOpenFile( + self.filename, + mode='wb', + start_byte=0, + open_function=self.open_nonseekable, + ) + + try: + # If this seeks, an UnsupportedOperation error will be raised. + self.deferred_open_file.write(b'data') + except io.UnsupportedOperation: + self.fail('DeferredOpenFile seeked upon opening') + + def test_open_seeks_with_nonzero_start_byte(self): + self.deferred_open_file = DeferredOpenFile( + self.filename, + mode='wb', + start_byte=5, + open_function=self.open_nonseekable, + ) + + # Since a non-seekable file is being opened, calling Seek will raise + # an UnsupportedOperation error. 
+ with self.assertRaises(io.UnsupportedOperation): + self.deferred_open_file.write(b'data') + + def test_tell(self): + self.deferred_open_file.tell() + # tell() should not have opened the file if it has not been seeked + # or read because we know the start bytes upfront. + self.assertEqual(len(self.open_call_args), 0) + + self.deferred_open_file.seek(2) + self.assertEqual(self.deferred_open_file.tell(), 2) + self.assertEqual(len(self.open_call_args), 1) + + def test_open_args(self): + self.deferred_open_file = DeferredOpenFile( + self.filename, + mode='ab+', + open_function=self.recording_open_function, + ) + # Force an open + self.deferred_open_file.write(b'data') + self.assertEqual(len(self.open_call_args), 1) + self.assertEqual(self.open_call_args[0], (self.filename, 'ab+')) + + def test_context_handler(self): + with self.deferred_open_file: + self.assertEqual(len(self.open_call_args), 1) + + +class TestReadFileChunk(BaseUtilsTest): + def test_read_entire_chunk(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=3 + ) + self.assertEqual(chunk.read(), b'one') + self.assertEqual(chunk.read(), b'') + + def test_read_with_amount_size(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=11, chunk_size=4 + ) + self.assertEqual(chunk.read(1), b'f') + self.assertEqual(chunk.read(1), b'o') + self.assertEqual(chunk.read(1), b'u') + self.assertEqual(chunk.read(1), b'r') + self.assertEqual(chunk.read(1), b'') + + def test_reset_stream_emulation(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=11, chunk_size=4 + 
) + self.assertEqual(chunk.read(), b'four') + chunk.seek(0) + self.assertEqual(chunk.read(), b'four') + + def test_read_past_end_of_file(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=36, chunk_size=100000 + ) + self.assertEqual(chunk.read(), b'ten') + self.assertEqual(chunk.read(), b'') + self.assertEqual(len(chunk), 3) + + def test_tell_and_seek(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=36, chunk_size=100000 + ) + self.assertEqual(chunk.tell(), 0) + self.assertEqual(chunk.read(), b'ten') + self.assertEqual(chunk.tell(), 3) + chunk.seek(0) + self.assertEqual(chunk.tell(), 0) + chunk.seek(1, whence=1) + self.assertEqual(chunk.tell(), 1) + chunk.seek(-1, whence=1) + self.assertEqual(chunk.tell(), 0) + chunk.seek(-1, whence=2) + self.assertEqual(chunk.tell(), 2) + + def test_tell_and_seek_boundaries(self): + # Test to ensure ReadFileChunk behaves the same as the + # Python standard library around seeking and reading out + # of bounds in a file object. 
+ data = b'abcdefghij12345678klmnopqrst' + start_pos = 10 + chunk_size = 8 + + # Create test file + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(data) + + # ReadFileChunk should be a substring of only numbers + file_objects = [ + ReadFileChunk.from_filename( + filename, start_byte=start_pos, chunk_size=chunk_size + ) + ] + + # Uncomment next line to validate we match Python's io.BytesIO + # file_objects.append(io.BytesIO(data[start_pos:start_pos+chunk_size])) + + for obj in file_objects: + self._assert_whence_start_behavior(obj) + self._assert_whence_end_behavior(obj) + self._assert_whence_relative_behavior(obj) + self._assert_boundary_behavior(obj) + + def _assert_whence_start_behavior(self, file_obj): + self.assertEqual(file_obj.tell(), 0) + + file_obj.seek(1, 0) + self.assertEqual(file_obj.tell(), 1) + + file_obj.seek(1) + self.assertEqual(file_obj.tell(), 1) + self.assertEqual(file_obj.read(), b'2345678') + + file_obj.seek(3, 0) + self.assertEqual(file_obj.tell(), 3) + + file_obj.seek(0, 0) + self.assertEqual(file_obj.tell(), 0) + + def _assert_whence_relative_behavior(self, file_obj): + self.assertEqual(file_obj.tell(), 0) + + file_obj.seek(2, 1) + self.assertEqual(file_obj.tell(), 2) + + file_obj.seek(1, 1) + self.assertEqual(file_obj.tell(), 3) + self.assertEqual(file_obj.read(), b'45678') + + file_obj.seek(20, 1) + self.assertEqual(file_obj.tell(), 28) + + file_obj.seek(-30, 1) + self.assertEqual(file_obj.tell(), 0) + self.assertEqual(file_obj.read(), b'12345678') + + file_obj.seek(-8, 1) + self.assertEqual(file_obj.tell(), 0) + + def _assert_whence_end_behavior(self, file_obj): + self.assertEqual(file_obj.tell(), 0) + + file_obj.seek(-1, 2) + self.assertEqual(file_obj.tell(), 7) + + file_obj.seek(1, 2) + self.assertEqual(file_obj.tell(), 9) + + file_obj.seek(3, 2) + self.assertEqual(file_obj.tell(), 11) + self.assertEqual(file_obj.read(), b'') + + file_obj.seek(-15, 2) + self.assertEqual(file_obj.tell(), 0) + 
self.assertEqual(file_obj.read(), b'12345678') + + file_obj.seek(-8, 2) + self.assertEqual(file_obj.tell(), 0) + + def _assert_boundary_behavior(self, file_obj): + # Verify we're at the start + self.assertEqual(file_obj.tell(), 0) + + # Verify we can't move backwards beyond start of file + file_obj.seek(-10, 1) + self.assertEqual(file_obj.tell(), 0) + + # Verify we *can* move after end of file, but return nothing + file_obj.seek(10, 2) + self.assertEqual(file_obj.tell(), 18) + self.assertEqual(file_obj.read(), b'') + self.assertEqual(file_obj.read(10), b'') + + # Verify we can partially rewind + file_obj.seek(-12, 1) + self.assertEqual(file_obj.tell(), 6) + self.assertEqual(file_obj.read(), b'78') + self.assertEqual(file_obj.tell(), 8) + + # Verify we can rewind to start + file_obj.seek(0) + self.assertEqual(file_obj.tell(), 0) + + def test_file_chunk_supports_context_manager(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'abc') + with ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=2 + ) as chunk: + val = chunk.read() + self.assertEqual(val, b'ab') + + def test_iter_is_always_empty(self): + # This tests the workaround for the httplib bug (see + # the source for more info). 
+ filename = os.path.join(self.tempdir, 'foo') + open(filename, 'wb').close() + chunk = ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=10 + ) + self.assertEqual(list(chunk), []) + + def test_callback_is_invoked_on_read(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback], + ) + chunk.read(1) + chunk.read(1) + chunk.read(1) + self.assertEqual(self.amounts_seen, [1, 1, 1]) + + def test_all_callbacks_invoked_on_read(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback, self.callback], + ) + chunk.read(1) + chunk.read(1) + chunk.read(1) + # The list should be twice as long because there are two callbacks + # recording the amount read. + self.assertEqual(self.amounts_seen, [1, 1, 1, 1, 1, 1]) + + def test_callback_can_be_disabled(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback], + ) + chunk.disable_callback() + # Now reading from the ReadFileChunk should not invoke + # the callback. + chunk.read() + self.assertEqual(self.amounts_seen, []) + + def test_callback_will_also_be_triggered_by_seek(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback], + ) + chunk.read(2) + chunk.seek(0) + chunk.read(2) + chunk.seek(1) + chunk.read(2) + self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2]) + + def test_callback_triggered_by_out_of_bound_seeks(self): + data = b'abcdefghij1234567890klmnopqr' + + # Create test file + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(data) + chunk = ReadFileChunk.from_filename( + filename, start_byte=10, chunk_size=10, callbacks=[self.callback] + ) + + # Seek calls that generate "0" progress are skipped by + # invoke_progress_callbacks and won't appear in the list. 
+ expected_callback_prog = [10, -5, 5, -1, 1, -1, 1, -5, 5, -10] + + self._assert_out_of_bound_start_seek(chunk, expected_callback_prog) + self._assert_out_of_bound_relative_seek(chunk, expected_callback_prog) + self._assert_out_of_bound_end_seek(chunk, expected_callback_prog) + + def _assert_out_of_bound_start_seek(self, chunk, expected): + # clear amounts_seen + self.amounts_seen = [] + self.assertEqual(self.amounts_seen, []) + + # (position, change) + chunk.seek(20) # (20, 10) + chunk.seek(5) # (5, -5) + chunk.seek(20) # (20, 5) + chunk.seek(9) # (9, -1) + chunk.seek(20) # (20, 1) + chunk.seek(11) # (11, 0) + chunk.seek(20) # (20, 0) + chunk.seek(9) # (9, -1) + chunk.seek(20) # (20, 1) + chunk.seek(5) # (5, -5) + chunk.seek(20) # (20, 5) + chunk.seek(0) # (0, -10) + chunk.seek(0) # (0, 0) + + self.assertEqual(self.amounts_seen, expected) + + def _assert_out_of_bound_relative_seek(self, chunk, expected): + # clear amounts_seen + self.amounts_seen = [] + self.assertEqual(self.amounts_seen, []) + + # (position, change) + chunk.seek(20, 1) # (20, 10) + chunk.seek(-15, 1) # (5, -5) + chunk.seek(15, 1) # (20, 5) + chunk.seek(-11, 1) # (9, -1) + chunk.seek(11, 1) # (20, 1) + chunk.seek(-9, 1) # (11, 0) + chunk.seek(9, 1) # (20, 0) + chunk.seek(-11, 1) # (9, -1) + chunk.seek(11, 1) # (20, 1) + chunk.seek(-15, 1) # (5, -5) + chunk.seek(15, 1) # (20, 5) + chunk.seek(-20, 1) # (0, -10) + chunk.seek(-1000, 1) # (0, 0) + + self.assertEqual(self.amounts_seen, expected) + + def _assert_out_of_bound_end_seek(self, chunk, expected): + # clear amounts_seen + self.amounts_seen = [] + self.assertEqual(self.amounts_seen, []) + + # (position, change) + chunk.seek(10, 2) # (20, 10) + chunk.seek(-5, 2) # (5, -5) + chunk.seek(10, 2) # (20, 5) + chunk.seek(-1, 2) # (9, -1) + chunk.seek(10, 2) # (20, 1) + chunk.seek(1, 2) # (11, 0) + chunk.seek(10, 2) # (20, 0) + chunk.seek(-1, 2) # (9, -1) + chunk.seek(10, 2) # (20, 1) + chunk.seek(-5, 2) # (5, -5) + chunk.seek(10, 2) # (20, 5) + 
chunk.seek(-10, 2) # (0, -10) + chunk.seek(-1000, 2) # (0, 0) + + self.assertEqual(self.amounts_seen, expected) + + def test_close_callbacks(self): + with open(self.filename) as f: + chunk = ReadFileChunk( + f, + chunk_size=1, + full_file_size=3, + close_callbacks=[self.close_callback], + ) + chunk.close() + self.assertEqual(self.num_close_callback_calls, 1) + + def test_close_callbacks_when_not_enabled(self): + with open(self.filename) as f: + chunk = ReadFileChunk( + f, + chunk_size=1, + full_file_size=3, + enable_callbacks=False, + close_callbacks=[self.close_callback], + ) + chunk.close() + self.assertEqual(self.num_close_callback_calls, 0) + + def test_close_callbacks_when_context_handler_is_used(self): + with open(self.filename) as f: + with ReadFileChunk( + f, + chunk_size=1, + full_file_size=3, + close_callbacks=[self.close_callback], + ) as chunk: + chunk.read(1) + self.assertEqual(self.num_close_callback_calls, 1) + + def test_signal_transferring(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback], + ) + chunk.signal_not_transferring() + chunk.read(1) + self.assertEqual(self.amounts_seen, []) + chunk.signal_transferring() + chunk.read(1) + self.assertEqual(self.amounts_seen, [1]) + + def test_signal_transferring_to_underlying_fileobj(self): + underlying_stream = mock.Mock() + underlying_stream.tell.return_value = 0 + chunk = ReadFileChunk(underlying_stream, 3, 3) + chunk.signal_transferring() + self.assertTrue(underlying_stream.signal_transferring.called) + + def test_no_call_signal_transferring_to_underlying_fileobj(self): + underlying_stream = mock.Mock(io.RawIOBase) + underlying_stream.tell.return_value = 0 + chunk = ReadFileChunk(underlying_stream, 3, 3) + try: + chunk.signal_transferring() + except AttributeError: + self.fail( + 'The stream should not have tried to call signal_transferring ' + 'to the underlying stream.' 
+ ) + + def test_signal_not_transferring_to_underlying_fileobj(self): + underlying_stream = mock.Mock() + underlying_stream.tell.return_value = 0 + chunk = ReadFileChunk(underlying_stream, 3, 3) + chunk.signal_not_transferring() + self.assertTrue(underlying_stream.signal_not_transferring.called) + + def test_no_call_signal_not_transferring_to_underlying_fileobj(self): + underlying_stream = mock.Mock(io.RawIOBase) + underlying_stream.tell.return_value = 0 + chunk = ReadFileChunk(underlying_stream, 3, 3) + try: + chunk.signal_not_transferring() + except AttributeError: + self.fail( + 'The stream should not have tried to call ' + 'signal_not_transferring to the underlying stream.' + ) + + +class TestStreamReaderProgress(BaseUtilsTest): + def test_proxies_to_wrapped_stream(self): + original_stream = StringIO('foobarbaz') + wrapped = StreamReaderProgress(original_stream) + self.assertEqual(wrapped.read(), 'foobarbaz') + + def test_callback_invoked(self): + original_stream = StringIO('foobarbaz') + wrapped = StreamReaderProgress( + original_stream, [self.callback, self.callback] + ) + self.assertEqual(wrapped.read(), 'foobarbaz') + self.assertEqual(self.amounts_seen, [9, 9]) + + +class TestTaskSemaphore(unittest.TestCase): + def setUp(self): + self.semaphore = TaskSemaphore(1) + + def test_should_block_at_max_capacity(self): + self.semaphore.acquire('a', blocking=False) + with self.assertRaises(NoResourcesAvailable): + self.semaphore.acquire('a', blocking=False) + + def test_release_capacity(self): + acquire_token = self.semaphore.acquire('a', blocking=False) + self.semaphore.release('a', acquire_token) + try: + self.semaphore.acquire('a', blocking=False) + except NoResourcesAvailable: + self.fail( + 'The release of the semaphore should have allowed for ' + 'the second acquire to not be blocked' + ) + + +class TestSlidingWindowSemaphore(unittest.TestCase): + # These tests use block=False to tests will fail + # instead of hang the test runner in the case of x + # 
incorrect behavior. + def test_acquire_release_basic_case(self): + sem = SlidingWindowSemaphore(1) + # Count is 1 + + num = sem.acquire('a', blocking=False) + self.assertEqual(num, 0) + sem.release('a', 0) + # Count now back to 1. + + def test_can_acquire_release_multiple_times(self): + sem = SlidingWindowSemaphore(1) + num = sem.acquire('a', blocking=False) + self.assertEqual(num, 0) + sem.release('a', num) + + num = sem.acquire('a', blocking=False) + self.assertEqual(num, 1) + sem.release('a', num) + + def test_can_acquire_a_range(self): + sem = SlidingWindowSemaphore(3) + self.assertEqual(sem.acquire('a', blocking=False), 0) + self.assertEqual(sem.acquire('a', blocking=False), 1) + self.assertEqual(sem.acquire('a', blocking=False), 2) + sem.release('a', 0) + sem.release('a', 1) + sem.release('a', 2) + # Now we're reset so we should be able to acquire the same + # sequence again. + self.assertEqual(sem.acquire('a', blocking=False), 3) + self.assertEqual(sem.acquire('a', blocking=False), 4) + self.assertEqual(sem.acquire('a', blocking=False), 5) + self.assertEqual(sem.current_count(), 0) + + def test_counter_release_only_on_min_element(self): + sem = SlidingWindowSemaphore(3) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + + # The count only increases when we free the min + # element. This means if we're currently failing to + # acquire now: + with self.assertRaises(NoResourcesAvailable): + sem.acquire('a', blocking=False) + + # Then freeing a non-min element: + sem.release('a', 1) + + # doesn't change anything. We still fail to acquire. + with self.assertRaises(NoResourcesAvailable): + sem.acquire('a', blocking=False) + self.assertEqual(sem.current_count(), 0) + + def test_raises_error_when_count_is_zero(self): + sem = SlidingWindowSemaphore(3) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + + # Count is now 0 so trying to acquire should fail. 
+ with self.assertRaises(NoResourcesAvailable): + sem.acquire('a', blocking=False) + + def test_release_counters_can_increment_counter_repeatedly(self): + sem = SlidingWindowSemaphore(3) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + + # These two releases don't increment the counter + # because we're waiting on 0. + sem.release('a', 1) + sem.release('a', 2) + self.assertEqual(sem.current_count(), 0) + # But as soon as we release 0, we free up 0, 1, and 2. + sem.release('a', 0) + self.assertEqual(sem.current_count(), 3) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + + def test_error_to_release_unknown_tag(self): + sem = SlidingWindowSemaphore(3) + with self.assertRaises(ValueError): + sem.release('a', 0) + + def test_can_track_multiple_tags(self): + sem = SlidingWindowSemaphore(3) + self.assertEqual(sem.acquire('a', blocking=False), 0) + self.assertEqual(sem.acquire('b', blocking=False), 0) + self.assertEqual(sem.acquire('a', blocking=False), 1) + + # We're at our max of 3 even though 2 are for A and 1 is for B. 
+ with self.assertRaises(NoResourcesAvailable): + sem.acquire('a', blocking=False) + with self.assertRaises(NoResourcesAvailable): + sem.acquire('b', blocking=False) + + def test_can_handle_multiple_tags_released(self): + sem = SlidingWindowSemaphore(4) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('b', blocking=False) + sem.acquire('b', blocking=False) + + sem.release('b', 1) + sem.release('a', 1) + self.assertEqual(sem.current_count(), 0) + + sem.release('b', 0) + self.assertEqual(sem.acquire('a', blocking=False), 2) + + sem.release('a', 0) + self.assertEqual(sem.acquire('b', blocking=False), 2) + + def test_is_error_to_release_unknown_sequence_number(self): + sem = SlidingWindowSemaphore(3) + sem.acquire('a', blocking=False) + with self.assertRaises(ValueError): + sem.release('a', 1) + + def test_is_error_to_double_release(self): + # This is different than other error tests because + # we're verifying we can reset the state after an + # acquire/release cycle. + sem = SlidingWindowSemaphore(2) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.release('a', 0) + sem.release('a', 1) + self.assertEqual(sem.current_count(), 2) + with self.assertRaises(ValueError): + sem.release('a', 0) + + def test_can_check_in_partial_range(self): + sem = SlidingWindowSemaphore(4) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + + sem.release('a', 1) + sem.release('a', 3) + sem.release('a', 0) + self.assertEqual(sem.current_count(), 2) + + +class TestThreadingPropertiesForSlidingWindowSemaphore(unittest.TestCase): + # These tests focus on mutithreaded properties of the range + # semaphore. Basic functionality is tested in TestSlidingWindowSemaphore. 
+ def setUp(self): + self.threads = [] + + def tearDown(self): + self.join_threads() + + def join_threads(self): + for thread in self.threads: + thread.join() + self.threads = [] + + def start_threads(self): + for thread in self.threads: + thread.start() + + def test_acquire_blocks_until_release_is_called(self): + sem = SlidingWindowSemaphore(2) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + + def acquire(): + # This next call to acquire will block. + self.assertEqual(sem.acquire('a', blocking=True), 2) + + t = threading.Thread(target=acquire) + self.threads.append(t) + # Starting the thread will block the sem.acquire() + # in the acquire function above. + t.start() + # This still will keep the thread blocked. + sem.release('a', 1) + # Releasing the min element will unblock the thread. + sem.release('a', 0) + t.join() + sem.release('a', 2) + + def test_stress_invariants_random_order(self): + sem = SlidingWindowSemaphore(100) + for _ in range(10): + recorded = [] + for _ in range(100): + recorded.append(sem.acquire('a', blocking=False)) + # Release them in randomized order. As long as we + # eventually free all 100, we should have all the + # resources released. + random.shuffle(recorded) + for i in recorded: + sem.release('a', i) + + # Everything's freed so should be back at count == 100 + self.assertEqual(sem.current_count(), 100) + + def test_blocking_stress(self): + sem = SlidingWindowSemaphore(5) + num_threads = 10 + num_iterations = 50 + + def acquire(): + for _ in range(num_iterations): + num = sem.acquire('a', blocking=True) + time.sleep(0.001) + sem.release('a', num) + + for i in range(num_threads): + t = threading.Thread(target=acquire) + self.threads.append(t) + self.start_threads() + self.join_threads() + # Should have all the available resources freed. 
+ self.assertEqual(sem.current_count(), 5) + # Should have acquired num_threads * num_iterations + self.assertEqual( + sem.acquire('a', blocking=False), num_threads * num_iterations + ) + + +class TestAdjustChunksize(unittest.TestCase): + def setUp(self): + self.adjuster = ChunksizeAdjuster() + + def test_valid_chunksize(self): + chunksize = 7 * (1024 ** 2) + file_size = 8 * (1024 ** 2) + new_size = self.adjuster.adjust_chunksize(chunksize, file_size) + self.assertEqual(new_size, chunksize) + + def test_chunksize_below_minimum(self): + chunksize = MIN_UPLOAD_CHUNKSIZE - 1 + file_size = 3 * MIN_UPLOAD_CHUNKSIZE + new_size = self.adjuster.adjust_chunksize(chunksize, file_size) + self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE) + + def test_chunksize_above_maximum(self): + chunksize = MAX_SINGLE_UPLOAD_SIZE + 1 + file_size = MAX_SINGLE_UPLOAD_SIZE * 2 + new_size = self.adjuster.adjust_chunksize(chunksize, file_size) + self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE) + + def test_chunksize_too_small(self): + chunksize = 7 * (1024 ** 2) + file_size = 5 * (1024 ** 4) + # If we try to upload a 5TB file, we'll need to use 896MB part + # sizes. 
+ new_size = self.adjuster.adjust_chunksize(chunksize, file_size) + self.assertEqual(new_size, 896 * (1024 ** 2)) + num_parts = file_size / new_size + self.assertLessEqual(num_parts, MAX_PARTS) + + def test_unknown_file_size_with_valid_chunksize(self): + chunksize = 7 * (1024 ** 2) + new_size = self.adjuster.adjust_chunksize(chunksize) + self.assertEqual(new_size, chunksize) + + def test_unknown_file_size_below_minimum(self): + chunksize = MIN_UPLOAD_CHUNKSIZE - 1 + new_size = self.adjuster.adjust_chunksize(chunksize) + self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE) + + def test_unknown_file_size_above_maximum(self): + chunksize = MAX_SINGLE_UPLOAD_SIZE + 1 + new_size = self.adjuster.adjust_chunksize(chunksize) + self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE) diff --git a/contrib/python/s3transfer/py3/tests/ya.make b/contrib/python/s3transfer/py3/tests/ya.make new file mode 100644 index 0000000000..fdbf22b0c5 --- /dev/null +++ b/contrib/python/s3transfer/py3/tests/ya.make @@ -0,0 +1,44 @@ +PY3TEST() + +OWNER(g:python-contrib) + +SIZE(MEDIUM) + +FORK_TESTS() + +PEERDIR( + contrib/python/mock + contrib/python/s3transfer +) + +TEST_SRCS( + functional/__init__.py + functional/test_copy.py + functional/test_crt.py + functional/test_delete.py + functional/test_download.py + functional/test_manager.py + functional/test_processpool.py + functional/test_upload.py + functional/test_utils.py + __init__.py + unit/__init__.py + unit/test_bandwidth.py + unit/test_compat.py + unit/test_copies.py + unit/test_crt.py + unit/test_delete.py + unit/test_download.py + unit/test_futures.py + unit/test_manager.py + unit/test_processpool.py + unit/test_s3transfer.py + unit/test_subscribers.py + unit/test_tasks.py + unit/test_upload.py + unit/test_utils.py +) + +NO_LINT() + +END() diff --git a/contrib/python/s3transfer/py3/ya.make b/contrib/python/s3transfer/py3/ya.make new file mode 100644 index 0000000000..964a630639 --- /dev/null +++ b/contrib/python/s3transfer/py3/ya.make @@ -0,0 
+1,51 @@ +# Generated by devtools/yamaker (pypi). + +PY3_LIBRARY() + +OWNER(gebetix g:python-contrib) + +VERSION(0.5.1) + +LICENSE(Apache-2.0) + +PEERDIR( + contrib/python/botocore +) + +NO_LINT() + +NO_CHECK_IMPORTS( + s3transfer.crt +) + +PY_SRCS( + TOP_LEVEL + s3transfer/__init__.py + s3transfer/bandwidth.py + s3transfer/compat.py + s3transfer/constants.py + s3transfer/copies.py + s3transfer/crt.py + s3transfer/delete.py + s3transfer/download.py + s3transfer/exceptions.py + s3transfer/futures.py + s3transfer/manager.py + s3transfer/processpool.py + s3transfer/subscribers.py + s3transfer/tasks.py + s3transfer/upload.py + s3transfer/utils.py +) + +RESOURCE_FILES( + PREFIX contrib/python/s3transfer/py3/ + .dist-info/METADATA + .dist-info/top_level.txt +) + +END() + +RECURSE_FOR_TESTS( + tests +) |